|
1 # coding: utf-8 |
|
2 |
|
3 from datetime import datetime, timedelta |
|
4 from flask import Flask |
|
5 from flask import session, request |
|
6 from flask import render_template, redirect, jsonify |
|
7 from flask_sqlalchemy import SQLAlchemy |
|
8 from werkzeug.security import gen_salt |
|
9 from flask_oauthlib.provider import OAuth2Provider |
|
10 from settings.oauth_settings import OAuthSettings |
|
11 |
|
# Application bootstrap: create the Flask app, load its configuration, then
# bind the database layer and the OAuth2 provider to it.
app = Flask(__name__, template_folder='templates')

# NOTE(review): debug mode and a literal secret key are hard-coded here;
# confirm neither reaches a production deployment.
app.debug = True
app.secret_key = 'secret'

# Project settings are loaded after the assignments above, so OAuthSettings
# may override them -- TODO confirm against settings/oauth_settings.
app.config.from_object(OAuthSettings)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///db.sqlite'

db = SQLAlchemy(app)
oauth = OAuth2Provider(app)
|
21 |
|
22 |
|
class User(db.Model):
    """Account record identified by a unique username."""

    # Surrogate integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Login name: at most 40 characters, unique across all users.
    username = db.Column(db.String(length=40), unique=True)
|
26 |
|
27 |
|
class Client(db.Model):
    """OAuth2 client application registered with this provider.

    Redirect URIs and default scopes are persisted as whitespace-separated
    text and exposed as lists through read-only properties.
    """

    client_id = db.Column(db.String(40), primary_key=True)
    client_secret = db.Column(db.String(55), nullable=False)

    # Whitespace-separated storage backing the list-valued properties below.
    _redirect_uris = db.Column(db.Text)
    _default_scopes = db.Column(db.Text)

    @property
    def client_type(self):
        """Return the client type; every client is reported as 'public'."""
        return 'public'

    @property
    def redirect_uris(self):
        """Return the registered redirect URIs as a list (empty if unset)."""
        if self._redirect_uris:
            return self._redirect_uris.split()
        return []

    @property
    def default_redirect_uri(self):
        """Return the first registered redirect URI, or None if there is none.

        A client stored without any redirect URI previously raised
        IndexError here; returning None lets the caller decide how to
        report the missing URI instead of crashing on the index lookup.
        """
        uris = self.redirect_uris
        return uris[0] if uris else None

    @property
    def default_scopes(self):
        """Return the default scopes as a list (empty if unset)."""
        if self._default_scopes:
            return self._default_scopes.split()
        return []
|
54 |
|
55 |
|
class Grant(db.Model):
    """Authorization grant tying a code to a user and a client."""

    id = db.Column(db.Integer, primary_key=True)

    # Owning user; the foreign key is declared with ON DELETE CASCADE.
    user_id = db.Column(
        db.Integer,
        db.ForeignKey('user.id', ondelete='CASCADE'),
    )
    user = db.relationship('User')

    # Client the grant was issued to (required).
    client_id = db.Column(
        db.String(40),
        db.ForeignKey('client.client_id'),
        nullable=False,
    )
    client = db.relationship('Client')

    # The authorization code itself; indexed for lookups by code.
    code = db.Column(db.String(255), index=True, nullable=False)

    redirect_uri = db.Column(db.String(255))
    expires = db.Column(db.DateTime)

    # Whitespace-separated scope names backing the `scopes` property.
    _scopes = db.Column(db.Text)

    def delete(self):
        """Delete this grant, commit the session, and return the instance."""
        db.session.delete(self)
        db.session.commit()
        return self

    @property
    def scopes(self):
        """Return the granted scopes as a list (empty if unset)."""
        # A None or empty column falls back to '', whose split() is [].
        return (self._scopes or '').split()
|
87 |
|
88 |
|
class Token(db.Model):
    """Issued token pair (access and refresh) for a client and a user."""

    id = db.Column(db.Integer, primary_key=True)

    # Client the token was issued to (required).
    client_id = db.Column(
        db.String(40),
        db.ForeignKey('client.client_id'),
        nullable=False,
    )
    client = db.relationship('Client')

    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    user = db.relationship('User')

    # currently only bearer is supported
    token_type = db.Column(db.String(40))

    # Both token strings are unique across the table.
    access_token = db.Column(db.String(255), unique=True)
    refresh_token = db.Column(db.String(255), unique=True)
    expires = db.Column(db.DateTime)
    # Whitespace-separated scope names backing the `scopes` property.
    _scopes = db.Column(db.Text)

    @property
    def scopes(self):
        """Return the token's scopes as a list (empty if unset)."""
        # A None or empty column falls back to '', whose split() is [].
        return (self._scopes or '').split()
|
115 |
|
116 |
|
def current_user():
    """Return the User whose id is stored in the session, or None.

    None is returned when the session carries no 'id' key; otherwise the
    result is whatever the primary-key lookup yields for that id.
    """
    if 'id' not in session:
        return None
    return User.query.get(session['id'])
|
122 |
|
123 |
|
@app.route('/', methods=('GET', 'POST'))
def home():
    """Render the home page (GET) or log a user in by username (POST).

    On POST the submitted ``username`` is looked up, and a new User is
    created if none exists; its id is stored in the session. The response
    is always a redirect back to '/'. A missing or blank username is
    ignored rather than creating and logging in a nameless account.

    On GET the home template is rendered with the current user (or None).
    """
    if request.method == 'POST':
        username = request.form.get('username')
        # Guard against an absent or whitespace-only field; the value is
        # otherwise stored exactly as submitted.
        if username and username.strip():
            user = User.query.filter_by(username=username).first()
            if user is None:
                user = User(username=username)
                db.session.add(user)
                db.session.commit()
            session['id'] = user.id
        return redirect('/')
    return render_template('oauth/home.html', user=current_user())
|
137 |
|
def generate_credentials(redirect_uris):
    """Create and persist a new Client, returning its credentials as JSON.

    ``redirect_uris`` is an iterable of URI strings; they are stored joined
    by single spaces. The client id and secret are random strings of 40 and
    50 characters, and the default scope is 'basic'.
    """
    new_client = Client(
        client_id=gen_salt(40),
        client_secret=gen_salt(50),
        _redirect_uris=' '.join(redirect_uris),
        _default_scopes='basic',
    )
    db.session.add(new_client)
    db.session.commit()

    payload = {
        'client_id': new_client.client_id,
        'client_secret': new_client.client_secret,
    }
    return jsonify(**payload)
|
151 |
|
@app.route('/get-client-credentials')
def make_client_credentials():
    """Issue new client credentials using the CLIENT_REDIRECT_URIS setting.

    NOTE(review): this route has no visible access control -- confirm that
    is intended.
    """
    redirect_uris = app.config.get("CLIENT_REDIRECT_URIS", [])
    return generate_credentials(redirect_uris)
|
155 |
|
@app.route('/get-renkan-credentials')
def make_renkan_credentials():
    """Issue new client credentials using the RENKAN_REDIRECT_URIS setting.

    NOTE(review): this route has no visible access control -- confirm that
    is intended.
    """
    redirect_uris = app.config.get("RENKAN_REDIRECT_URIS", [])
    return generate_credentials(redirect_uris)
|
159 |
|
@oauth.clientgetter
def load_client(client_id):
    """Return the Client with the given ``client_id``, or None if absent."""
    matching = Client.query.filter_by(client_id=client_id)
    return matching.first()
|
163 |
|
164 |
|
@oauth.grantgetter
def load_grant(client_id, code):
    """Return the Grant matching both ``client_id`` and ``code``, or None."""
    criteria = {'client_id': client_id, 'code': code}
    return Grant.query.filter_by(**criteria).first()
|
168 |
|
169 |
|
@oauth.grantsetter
def save_grant(client_id, code, request, *args, **kwargs):
    """Persist an authorization grant for the current user and return it.

    ``code`` is a mapping whose 'code' entry holds the authorization code.
    ``request`` is the provider's request object (it shadows Flask's
    ``request`` inside this function) and supplies ``redirect_uri`` and
    ``scopes``. The grant expires 100 seconds after creation (naive UTC).
    """
    # decide the expires time yourself
    expires_at = datetime.utcnow() + timedelta(seconds=100)

    fields = {
        'client_id': client_id,
        'code': code['code'],
        'redirect_uri': request.redirect_uri,
        '_scopes': ' '.join(request.scopes),
        'user': current_user(),
        'expires': expires_at,
    }
    new_grant = Grant(**fields)
    db.session.add(new_grant)
    db.session.commit()
    return new_grant
|
185 |
|
186 |
|
@oauth.tokengetter
def load_token(access_token=None, refresh_token=None):
    """Look up a Token by access token, else by refresh token.

    The access token takes precedence when both are given. Returns None
    when neither argument is truthy or when no row matches.
    """
    if access_token:
        criteria = {'access_token': access_token}
    elif refresh_token:
        criteria = {'refresh_token': refresh_token}
    else:
        return None
    return Token.query.filter_by(**criteria).first()
|
193 |
|
194 |
|
@oauth.tokensetter
def save_token(token, request, *args, **kwargs):
    """Replace the stored token for this client/user pair and return it.

    ``token`` is a mapping with 'access_token', 'token_type', 'scope' and
    'expires_in' (seconds); 'refresh_token' is optional and stored as None
    when absent. ``request`` is the provider's request object and supplies
    ``client.client_id`` and ``user.id``.

    The ``token`` mapping is only read, never modified. Removing
    'expires_in' from it would change the caller's dict -- presumably the
    same one the provider serialises into the token response (confirm
    against flask_oauthlib).

    Raises:
        KeyError: if a required entry (including 'expires_in') is missing.
    """
    client_id = request.client.client_id
    user_id = request.user.id

    # make sure that every client has only one token connected to a user
    for stale in Token.query.filter_by(client_id=client_id, user_id=user_id):
        db.session.delete(stale)

    # Naive UTC expiry computed from the relative lifetime.
    expires = datetime.utcnow() + timedelta(seconds=token['expires_in'])

    tok = Token(
        access_token=token['access_token'],
        # Not every token mapping carries a refresh token; tolerate absence.
        refresh_token=token.get('refresh_token'),
        token_type=token['token_type'],
        _scopes=token['scope'],
        expires=expires,
        client_id=client_id,
        user_id=user_id,
    )
    db.session.add(tok)
    db.session.commit()
    return tok
|
220 |
|
221 |
|
@app.route('/oauth/token', methods=['GET', 'POST'])
@oauth.token_handler
def access_token():
    """Token endpoint; the view itself contributes nothing.

    The function returns None and leaves request handling to the
    ``oauth.token_handler`` decorator (behaviour defined by flask_oauthlib,
    not visible in this file).
    """
    return None
|
226 |
|
227 |
|
@app.route('/oauth/authorize', methods=['GET', 'POST'])
@oauth.authorize_handler
def authorize(*args, **kwargs):
    """Authorization endpoint: show the consent page or record the answer.

    Anonymous visitors are redirected to '/'. On GET the consent template
    is rendered with the requesting client and the current user added to
    the keyword arguments. On any other method the return value is True
    only when the submitted 'confirm' field equals 'yes'.
    """
    user = current_user()
    if not user:
        return redirect('/')

    if request.method != 'GET':
        # A missing 'confirm' field defaults to 'no', i.e. a refusal.
        return request.form.get('confirm', 'no') == 'yes'

    requested_id = kwargs.get('client_id')
    kwargs['client'] = Client.query.filter_by(client_id=requested_id).first()
    kwargs['user'] = user
    return render_template('oauth/authorize.html', **kwargs)
|
243 |
|
244 |
|
@app.route('/api/me')
@oauth.require_oauth()
def me():
    """Return the id and username of the user attached to the OAuth request."""
    owner = request.oauth.user
    profile = {'id': owner.id, 'username': owner.username}
    return jsonify(**profile)
|
250 |
|
251 |
|
if __name__ == '__main__':
    # Create the tables inside an application context: Flask-SQLAlchemy 3.x
    # raises RuntimeError when create_all() runs without one, and older
    # releases accept the explicit context as well.
    with app.app_context():
        db.create_all()
    app.run()