this repo has no description
1use super::super::{OAuthError, RefreshTokenState, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub enum RefreshTokenLookup {
7 Valid {
8 db_id: i32,
9 token_data: TokenData,
10 },
11 InGracePeriod {
12 db_id: i32,
13 token_data: TokenData,
14 rotated_at: DateTime<Utc>,
15 },
16 Used {
17 original_token_id: i32,
18 },
19 Expired {
20 db_id: i32,
21 },
22 NotFound,
23}
24
25impl RefreshTokenLookup {
26 pub fn state(&self) -> RefreshTokenState {
27 match self {
28 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid,
29 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => {
30 RefreshTokenState::InGracePeriod {
31 rotated_at: *rotated_at,
32 }
33 }
34 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() },
35 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired,
36 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked,
37 }
38 }
39}
40
41pub async fn lookup_refresh_token(
42 pool: &PgPool,
43 refresh_token: &str,
44) -> Result<RefreshTokenLookup, OAuthError> {
45 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? {
46 if let Some((db_id, token_data)) =
47 get_token_by_previous_refresh_token(pool, refresh_token).await?
48 {
49 let rotated_at = token_data.updated_at;
50 return Ok(RefreshTokenLookup::InGracePeriod {
51 db_id,
52 token_data,
53 rotated_at,
54 });
55 }
56 return Ok(RefreshTokenLookup::Used {
57 original_token_id: token_id,
58 });
59 }
60
61 match get_token_by_refresh_token(pool, refresh_token).await? {
62 Some((db_id, token_data)) => {
63 if token_data.expires_at < Utc::now() {
64 Ok(RefreshTokenLookup::Expired { db_id })
65 } else {
66 Ok(RefreshTokenLookup::Valid { db_id, token_data })
67 }
68 }
69 None => Ok(RefreshTokenLookup::NotFound),
70 }
71}
72
73pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
74 let client_auth_json = to_json(&data.client_auth)?;
75 let parameters_json = to_json(&data.parameters)?;
76 let row = sqlx::query!(
77 r#"
78 INSERT INTO oauth_token
79 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
80 device_id, parameters, details, code, current_refresh_token, scope, controller_did)
81 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
82 RETURNING id
83 "#,
84 data.did,
85 data.token_id,
86 data.created_at,
87 data.updated_at,
88 data.expires_at,
89 data.client_id,
90 client_auth_json,
91 data.device_id,
92 parameters_json,
93 data.details,
94 data.code,
95 data.current_refresh_token,
96 data.scope,
97 data.controller_did,
98 )
99 .fetch_one(pool)
100 .await?;
101 Ok(row.id)
102}
103
104pub async fn get_token_by_id(
105 pool: &PgPool,
106 token_id: &str,
107) -> Result<Option<TokenData>, OAuthError> {
108 let row = sqlx::query!(
109 r#"
110 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
111 device_id, parameters, details, code, current_refresh_token, scope, controller_did
112 FROM oauth_token
113 WHERE token_id = $1
114 "#,
115 token_id
116 )
117 .fetch_optional(pool)
118 .await?;
119 match row {
120 Some(r) => Ok(Some(TokenData {
121 did: r.did,
122 token_id: r.token_id,
123 created_at: r.created_at,
124 updated_at: r.updated_at,
125 expires_at: r.expires_at,
126 client_id: r.client_id,
127 client_auth: from_json(r.client_auth)?,
128 device_id: r.device_id,
129 parameters: from_json(r.parameters)?,
130 details: r.details,
131 code: r.code,
132 current_refresh_token: r.current_refresh_token,
133 scope: r.scope,
134 controller_did: r.controller_did,
135 })),
136 None => Ok(None),
137 }
138}
139
140pub async fn get_token_by_refresh_token(
141 pool: &PgPool,
142 refresh_token: &str,
143) -> Result<Option<(i32, TokenData)>, OAuthError> {
144 let row = sqlx::query!(
145 r#"
146 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
147 device_id, parameters, details, code, current_refresh_token, scope, controller_did
148 FROM oauth_token
149 WHERE current_refresh_token = $1
150 "#,
151 refresh_token
152 )
153 .fetch_optional(pool)
154 .await?;
155 match row {
156 Some(r) => Ok(Some((
157 r.id,
158 TokenData {
159 did: r.did,
160 token_id: r.token_id,
161 created_at: r.created_at,
162 updated_at: r.updated_at,
163 expires_at: r.expires_at,
164 client_id: r.client_id,
165 client_auth: from_json(r.client_auth)?,
166 device_id: r.device_id,
167 parameters: from_json(r.parameters)?,
168 details: r.details,
169 code: r.code,
170 current_refresh_token: r.current_refresh_token,
171 scope: r.scope,
172 controller_did: r.controller_did,
173 },
174 ))),
175 None => Ok(None),
176 }
177}
178
179pub async fn rotate_token(
180 pool: &PgPool,
181 old_db_id: i32,
182 new_refresh_token: &str,
183 new_expires_at: DateTime<Utc>,
184) -> Result<(), OAuthError> {
185 let mut tx = pool.begin().await?;
186 let old_refresh = sqlx::query_scalar!(
187 r#"
188 SELECT current_refresh_token FROM oauth_token WHERE id = $1
189 "#,
190 old_db_id
191 )
192 .fetch_one(&mut *tx)
193 .await?;
194 if let Some(ref old_rt) = old_refresh {
195 sqlx::query!(
196 r#"
197 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
198 VALUES ($1, $2)
199 "#,
200 old_rt,
201 old_db_id
202 )
203 .execute(&mut *tx)
204 .await?;
205 }
206 sqlx::query!(
207 r#"
208 UPDATE oauth_token
209 SET current_refresh_token = $2, expires_at = $3, updated_at = NOW(),
210 previous_refresh_token = $4, rotated_at = NOW()
211 WHERE id = $1
212 "#,
213 old_db_id,
214 new_refresh_token,
215 new_expires_at,
216 old_refresh
217 )
218 .execute(&mut *tx)
219 .await?;
220 tx.commit().await?;
221 Ok(())
222}
223
224pub async fn check_refresh_token_used(
225 pool: &PgPool,
226 refresh_token: &str,
227) -> Result<Option<i32>, OAuthError> {
228 let row = sqlx::query_scalar!(
229 r#"
230 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
231 "#,
232 refresh_token
233 )
234 .fetch_optional(pool)
235 .await?;
236 Ok(row)
237}
238
239const REFRESH_GRACE_PERIOD_SECS: i64 = 60;
240
241pub async fn get_token_by_previous_refresh_token(
242 pool: &PgPool,
243 refresh_token: &str,
244) -> Result<Option<(i32, TokenData)>, OAuthError> {
245 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
246 let row = sqlx::query!(
247 r#"
248 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
249 device_id, parameters, details, code, current_refresh_token, scope, controller_did
250 FROM oauth_token
251 WHERE previous_refresh_token = $1 AND rotated_at > $2
252 "#,
253 refresh_token,
254 grace_cutoff
255 )
256 .fetch_optional(pool)
257 .await?;
258 match row {
259 Some(r) => Ok(Some((
260 r.id,
261 TokenData {
262 did: r.did,
263 token_id: r.token_id,
264 created_at: r.created_at,
265 updated_at: r.updated_at,
266 expires_at: r.expires_at,
267 client_id: r.client_id,
268 client_auth: from_json(r.client_auth)?,
269 device_id: r.device_id,
270 parameters: from_json(r.parameters)?,
271 details: r.details,
272 code: r.code,
273 current_refresh_token: r.current_refresh_token,
274 scope: r.scope,
275 controller_did: r.controller_did,
276 },
277 ))),
278 None => Ok(None),
279 }
280}
281
282pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
283 sqlx::query!(
284 r#"
285 DELETE FROM oauth_token WHERE token_id = $1
286 "#,
287 token_id
288 )
289 .execute(pool)
290 .await?;
291 Ok(())
292}
293
294pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
295 sqlx::query!(
296 r#"
297 DELETE FROM oauth_token WHERE id = $1
298 "#,
299 db_id
300 )
301 .execute(pool)
302 .await?;
303 Ok(())
304}
305
306pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
307 let rows = sqlx::query!(
308 r#"
309 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
310 device_id, parameters, details, code, current_refresh_token, scope, controller_did
311 FROM oauth_token
312 WHERE did = $1
313 "#,
314 did
315 )
316 .fetch_all(pool)
317 .await?;
318 rows.into_iter()
319 .map(|r| {
320 Ok(TokenData {
321 did: r.did,
322 token_id: r.token_id,
323 created_at: r.created_at,
324 updated_at: r.updated_at,
325 expires_at: r.expires_at,
326 client_id: r.client_id,
327 client_auth: from_json(r.client_auth)?,
328 device_id: r.device_id,
329 parameters: from_json(r.parameters)?,
330 details: r.details,
331 code: r.code,
332 current_refresh_token: r.current_refresh_token,
333 scope: r.scope,
334 controller_did: r.controller_did,
335 })
336 })
337 .collect()
338}
339
340pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
341 let count = sqlx::query_scalar!(
342 r#"
343 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
344 "#,
345 did
346 )
347 .fetch_one(pool)
348 .await?;
349 Ok(count)
350}
351
352pub async fn delete_oldest_tokens_for_user(
353 pool: &PgPool,
354 did: &str,
355 keep_count: i64,
356) -> Result<u64, OAuthError> {
357 let result = sqlx::query!(
358 r#"
359 DELETE FROM oauth_token
360 WHERE id IN (
361 SELECT id FROM oauth_token
362 WHERE did = $1
363 ORDER BY updated_at ASC
364 OFFSET $2
365 )
366 "#,
367 did,
368 keep_count
369 )
370 .execute(pool)
371 .await?;
372 Ok(result.rows_affected())
373}
374
375const MAX_TOKENS_PER_USER: i64 = 100;
376
377pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
378 let count = count_tokens_for_user(pool, did).await?;
379 if count > MAX_TOKENS_PER_USER {
380 let to_keep = MAX_TOKENS_PER_USER - 1;
381 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
382 }
383 Ok(())
384}
385
386pub async fn revoke_tokens_for_client(
387 pool: &PgPool,
388 did: &str,
389 client_id: &str,
390) -> Result<u64, OAuthError> {
391 let result = sqlx::query!(
392 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
393 did,
394 client_id
395 )
396 .execute(pool)
397 .await?;
398 Ok(result.rows_affected())
399}
400
401pub async fn revoke_tokens_for_controller(
402 pool: &PgPool,
403 delegated_did: &str,
404 controller_did: &str,
405) -> Result<u64, OAuthError> {
406 let result = sqlx::query!(
407 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2",
408 delegated_did,
409 controller_did
410 )
411 .execute(pool)
412 .await?;
413 Ok(result.rows_affected())
414}