this repo has no description
1use super::super::{OAuthError, RefreshTokenState, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub enum RefreshTokenLookup {
7 Valid { db_id: i32, token_data: TokenData },
8 InGracePeriod { db_id: i32, token_data: TokenData, rotated_at: DateTime<Utc> },
9 Used { original_token_id: i32 },
10 Expired { db_id: i32 },
11 NotFound,
12}
13
14impl RefreshTokenLookup {
15 pub fn state(&self) -> RefreshTokenState {
16 match self {
17 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid,
18 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => {
19 RefreshTokenState::InGracePeriod { rotated_at: *rotated_at }
20 }
21 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() },
22 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired,
23 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked,
24 }
25 }
26}
27
28pub async fn lookup_refresh_token(
29 pool: &PgPool,
30 refresh_token: &str,
31) -> Result<RefreshTokenLookup, OAuthError> {
32 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? {
33 if let Some((db_id, token_data)) = get_token_by_previous_refresh_token(pool, refresh_token).await? {
34 let rotated_at = token_data.updated_at;
35 return Ok(RefreshTokenLookup::InGracePeriod { db_id, token_data, rotated_at });
36 }
37 return Ok(RefreshTokenLookup::Used { original_token_id: token_id });
38 }
39
40 match get_token_by_refresh_token(pool, refresh_token).await? {
41 Some((db_id, token_data)) => {
42 if token_data.expires_at < Utc::now() {
43 Ok(RefreshTokenLookup::Expired { db_id })
44 } else {
45 Ok(RefreshTokenLookup::Valid { db_id, token_data })
46 }
47 }
48 None => Ok(RefreshTokenLookup::NotFound),
49 }
50}
51
52pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
53 let client_auth_json = to_json(&data.client_auth)?;
54 let parameters_json = to_json(&data.parameters)?;
55 let row = sqlx::query!(
56 r#"
57 INSERT INTO oauth_token
58 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
59 device_id, parameters, details, code, current_refresh_token, scope, controller_did)
60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
61 RETURNING id
62 "#,
63 data.did,
64 data.token_id,
65 data.created_at,
66 data.updated_at,
67 data.expires_at,
68 data.client_id,
69 client_auth_json,
70 data.device_id,
71 parameters_json,
72 data.details,
73 data.code,
74 data.current_refresh_token,
75 data.scope,
76 data.controller_did,
77 )
78 .fetch_one(pool)
79 .await?;
80 Ok(row.id)
81}
82
83pub async fn get_token_by_id(
84 pool: &PgPool,
85 token_id: &str,
86) -> Result<Option<TokenData>, OAuthError> {
87 let row = sqlx::query!(
88 r#"
89 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
90 device_id, parameters, details, code, current_refresh_token, scope, controller_did
91 FROM oauth_token
92 WHERE token_id = $1
93 "#,
94 token_id
95 )
96 .fetch_optional(pool)
97 .await?;
98 match row {
99 Some(r) => Ok(Some(TokenData {
100 did: r.did,
101 token_id: r.token_id,
102 created_at: r.created_at,
103 updated_at: r.updated_at,
104 expires_at: r.expires_at,
105 client_id: r.client_id,
106 client_auth: from_json(r.client_auth)?,
107 device_id: r.device_id,
108 parameters: from_json(r.parameters)?,
109 details: r.details,
110 code: r.code,
111 current_refresh_token: r.current_refresh_token,
112 scope: r.scope,
113 controller_did: r.controller_did,
114 })),
115 None => Ok(None),
116 }
117}
118
119pub async fn get_token_by_refresh_token(
120 pool: &PgPool,
121 refresh_token: &str,
122) -> Result<Option<(i32, TokenData)>, OAuthError> {
123 let row = sqlx::query!(
124 r#"
125 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
126 device_id, parameters, details, code, current_refresh_token, scope, controller_did
127 FROM oauth_token
128 WHERE current_refresh_token = $1
129 "#,
130 refresh_token
131 )
132 .fetch_optional(pool)
133 .await?;
134 match row {
135 Some(r) => Ok(Some((
136 r.id,
137 TokenData {
138 did: r.did,
139 token_id: r.token_id,
140 created_at: r.created_at,
141 updated_at: r.updated_at,
142 expires_at: r.expires_at,
143 client_id: r.client_id,
144 client_auth: from_json(r.client_auth)?,
145 device_id: r.device_id,
146 parameters: from_json(r.parameters)?,
147 details: r.details,
148 code: r.code,
149 current_refresh_token: r.current_refresh_token,
150 scope: r.scope,
151 controller_did: r.controller_did,
152 },
153 ))),
154 None => Ok(None),
155 }
156}
157
158pub async fn rotate_token(
159 pool: &PgPool,
160 old_db_id: i32,
161 new_token_id: &str,
162 new_refresh_token: &str,
163 new_expires_at: DateTime<Utc>,
164) -> Result<(), OAuthError> {
165 let mut tx = pool.begin().await?;
166 let old_refresh = sqlx::query_scalar!(
167 r#"
168 SELECT current_refresh_token FROM oauth_token WHERE id = $1
169 "#,
170 old_db_id
171 )
172 .fetch_one(&mut *tx)
173 .await?;
174 if let Some(ref old_rt) = old_refresh {
175 sqlx::query!(
176 r#"
177 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
178 VALUES ($1, $2)
179 "#,
180 old_rt,
181 old_db_id
182 )
183 .execute(&mut *tx)
184 .await?;
185 }
186 sqlx::query!(
187 r#"
188 UPDATE oauth_token
189 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(),
190 previous_refresh_token = $5, rotated_at = NOW()
191 WHERE id = $1
192 "#,
193 old_db_id,
194 new_token_id,
195 new_refresh_token,
196 new_expires_at,
197 old_refresh
198 )
199 .execute(&mut *tx)
200 .await?;
201 tx.commit().await?;
202 Ok(())
203}
204
205pub async fn check_refresh_token_used(
206 pool: &PgPool,
207 refresh_token: &str,
208) -> Result<Option<i32>, OAuthError> {
209 let row = sqlx::query_scalar!(
210 r#"
211 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
212 "#,
213 refresh_token
214 )
215 .fetch_optional(pool)
216 .await?;
217 Ok(row)
218}
219
220const REFRESH_GRACE_PERIOD_SECS: i64 = 60;
221
222pub async fn get_token_by_previous_refresh_token(
223 pool: &PgPool,
224 refresh_token: &str,
225) -> Result<Option<(i32, TokenData)>, OAuthError> {
226 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
227 let row = sqlx::query!(
228 r#"
229 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
230 device_id, parameters, details, code, current_refresh_token, scope, controller_did
231 FROM oauth_token
232 WHERE previous_refresh_token = $1 AND rotated_at > $2
233 "#,
234 refresh_token,
235 grace_cutoff
236 )
237 .fetch_optional(pool)
238 .await?;
239 match row {
240 Some(r) => Ok(Some((
241 r.id,
242 TokenData {
243 did: r.did,
244 token_id: r.token_id,
245 created_at: r.created_at,
246 updated_at: r.updated_at,
247 expires_at: r.expires_at,
248 client_id: r.client_id,
249 client_auth: from_json(r.client_auth)?,
250 device_id: r.device_id,
251 parameters: from_json(r.parameters)?,
252 details: r.details,
253 code: r.code,
254 current_refresh_token: r.current_refresh_token,
255 scope: r.scope,
256 controller_did: r.controller_did,
257 },
258 ))),
259 None => Ok(None),
260 }
261}
262
263pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
264 sqlx::query!(
265 r#"
266 DELETE FROM oauth_token WHERE token_id = $1
267 "#,
268 token_id
269 )
270 .execute(pool)
271 .await?;
272 Ok(())
273}
274
275pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
276 sqlx::query!(
277 r#"
278 DELETE FROM oauth_token WHERE id = $1
279 "#,
280 db_id
281 )
282 .execute(pool)
283 .await?;
284 Ok(())
285}
286
287pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
288 let rows = sqlx::query!(
289 r#"
290 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
291 device_id, parameters, details, code, current_refresh_token, scope, controller_did
292 FROM oauth_token
293 WHERE did = $1
294 "#,
295 did
296 )
297 .fetch_all(pool)
298 .await?;
299 let mut tokens = Vec::with_capacity(rows.len());
300 for r in rows {
301 tokens.push(TokenData {
302 did: r.did,
303 token_id: r.token_id,
304 created_at: r.created_at,
305 updated_at: r.updated_at,
306 expires_at: r.expires_at,
307 client_id: r.client_id,
308 client_auth: from_json(r.client_auth)?,
309 device_id: r.device_id,
310 parameters: from_json(r.parameters)?,
311 details: r.details,
312 code: r.code,
313 current_refresh_token: r.current_refresh_token,
314 scope: r.scope,
315 controller_did: r.controller_did,
316 });
317 }
318 Ok(tokens)
319}
320
321pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
322 let count = sqlx::query_scalar!(
323 r#"
324 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
325 "#,
326 did
327 )
328 .fetch_one(pool)
329 .await?;
330 Ok(count)
331}
332
333pub async fn delete_oldest_tokens_for_user(
334 pool: &PgPool,
335 did: &str,
336 keep_count: i64,
337) -> Result<u64, OAuthError> {
338 let result = sqlx::query!(
339 r#"
340 DELETE FROM oauth_token
341 WHERE id IN (
342 SELECT id FROM oauth_token
343 WHERE did = $1
344 ORDER BY updated_at ASC
345 OFFSET $2
346 )
347 "#,
348 did,
349 keep_count
350 )
351 .execute(pool)
352 .await?;
353 Ok(result.rows_affected())
354}
355
356const MAX_TOKENS_PER_USER: i64 = 100;
357
358pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
359 let count = count_tokens_for_user(pool, did).await?;
360 if count > MAX_TOKENS_PER_USER {
361 let to_keep = MAX_TOKENS_PER_USER - 1;
362 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
363 }
364 Ok(())
365}
366
367pub async fn revoke_tokens_for_client(
368 pool: &PgPool,
369 did: &str,
370 client_id: &str,
371) -> Result<u64, OAuthError> {
372 let result = sqlx::query!(
373 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
374 did,
375 client_id
376 )
377 .execute(pool)
378 .await?;
379 Ok(result.rows_affected())
380}
381
382pub async fn revoke_tokens_for_controller(
383 pool: &PgPool,
384 delegated_did: &str,
385 controller_did: &str,
386) -> Result<u64, OAuthError> {
387 let result = sqlx::query!(
388 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2",
389 delegated_did,
390 controller_did
391 )
392 .execute(pool)
393 .await?;
394 Ok(result.rows_affected())
395}