this repo has no description
1use super::super::{OAuthError, RefreshTokenState, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub enum RefreshTokenLookup {
7 Valid {
8 db_id: i32,
9 token_data: TokenData,
10 },
11 InGracePeriod {
12 db_id: i32,
13 token_data: TokenData,
14 rotated_at: DateTime<Utc>,
15 },
16 Used {
17 original_token_id: i32,
18 },
19 Expired {
20 db_id: i32,
21 },
22 NotFound,
23}
24
25impl RefreshTokenLookup {
26 pub fn state(&self) -> RefreshTokenState {
27 match self {
28 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid,
29 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => {
30 RefreshTokenState::InGracePeriod {
31 rotated_at: *rotated_at,
32 }
33 }
34 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() },
35 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired,
36 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked,
37 }
38 }
39}
40
41pub async fn lookup_refresh_token(
42 pool: &PgPool,
43 refresh_token: &str,
44) -> Result<RefreshTokenLookup, OAuthError> {
45 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? {
46 if let Some((db_id, token_data)) =
47 get_token_by_previous_refresh_token(pool, refresh_token).await?
48 {
49 let rotated_at = token_data.updated_at;
50 return Ok(RefreshTokenLookup::InGracePeriod {
51 db_id,
52 token_data,
53 rotated_at,
54 });
55 }
56 return Ok(RefreshTokenLookup::Used {
57 original_token_id: token_id,
58 });
59 }
60
61 match get_token_by_refresh_token(pool, refresh_token).await? {
62 Some((db_id, token_data)) => {
63 if token_data.expires_at < Utc::now() {
64 Ok(RefreshTokenLookup::Expired { db_id })
65 } else {
66 Ok(RefreshTokenLookup::Valid { db_id, token_data })
67 }
68 }
69 None => Ok(RefreshTokenLookup::NotFound),
70 }
71}
72
73pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
74 let client_auth_json = to_json(&data.client_auth)?;
75 let parameters_json = to_json(&data.parameters)?;
76 let row = sqlx::query!(
77 r#"
78 INSERT INTO oauth_token
79 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
80 device_id, parameters, details, code, current_refresh_token, scope, controller_did)
81 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
82 RETURNING id
83 "#,
84 data.did,
85 data.token_id,
86 data.created_at,
87 data.updated_at,
88 data.expires_at,
89 data.client_id,
90 client_auth_json,
91 data.device_id,
92 parameters_json,
93 data.details,
94 data.code,
95 data.current_refresh_token,
96 data.scope,
97 data.controller_did,
98 )
99 .fetch_one(pool)
100 .await?;
101 Ok(row.id)
102}
103
104pub async fn get_token_by_id(
105 pool: &PgPool,
106 token_id: &str,
107) -> Result<Option<TokenData>, OAuthError> {
108 let row = sqlx::query!(
109 r#"
110 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
111 device_id, parameters, details, code, current_refresh_token, scope, controller_did
112 FROM oauth_token
113 WHERE token_id = $1
114 "#,
115 token_id
116 )
117 .fetch_optional(pool)
118 .await?;
119 match row {
120 Some(r) => Ok(Some(TokenData {
121 did: r.did,
122 token_id: r.token_id,
123 created_at: r.created_at,
124 updated_at: r.updated_at,
125 expires_at: r.expires_at,
126 client_id: r.client_id,
127 client_auth: from_json(r.client_auth)?,
128 device_id: r.device_id,
129 parameters: from_json(r.parameters)?,
130 details: r.details,
131 code: r.code,
132 current_refresh_token: r.current_refresh_token,
133 scope: r.scope,
134 controller_did: r.controller_did,
135 })),
136 None => Ok(None),
137 }
138}
139
140pub async fn get_token_by_refresh_token(
141 pool: &PgPool,
142 refresh_token: &str,
143) -> Result<Option<(i32, TokenData)>, OAuthError> {
144 let row = sqlx::query!(
145 r#"
146 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
147 device_id, parameters, details, code, current_refresh_token, scope, controller_did
148 FROM oauth_token
149 WHERE current_refresh_token = $1
150 "#,
151 refresh_token
152 )
153 .fetch_optional(pool)
154 .await?;
155 match row {
156 Some(r) => Ok(Some((
157 r.id,
158 TokenData {
159 did: r.did,
160 token_id: r.token_id,
161 created_at: r.created_at,
162 updated_at: r.updated_at,
163 expires_at: r.expires_at,
164 client_id: r.client_id,
165 client_auth: from_json(r.client_auth)?,
166 device_id: r.device_id,
167 parameters: from_json(r.parameters)?,
168 details: r.details,
169 code: r.code,
170 current_refresh_token: r.current_refresh_token,
171 scope: r.scope,
172 controller_did: r.controller_did,
173 },
174 ))),
175 None => Ok(None),
176 }
177}
178
179pub async fn rotate_token(
180 pool: &PgPool,
181 old_db_id: i32,
182 new_token_id: &str,
183 new_refresh_token: &str,
184 new_expires_at: DateTime<Utc>,
185) -> Result<(), OAuthError> {
186 let mut tx = pool.begin().await?;
187 let old_refresh = sqlx::query_scalar!(
188 r#"
189 SELECT current_refresh_token FROM oauth_token WHERE id = $1
190 "#,
191 old_db_id
192 )
193 .fetch_one(&mut *tx)
194 .await?;
195 if let Some(ref old_rt) = old_refresh {
196 sqlx::query!(
197 r#"
198 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
199 VALUES ($1, $2)
200 "#,
201 old_rt,
202 old_db_id
203 )
204 .execute(&mut *tx)
205 .await?;
206 }
207 sqlx::query!(
208 r#"
209 UPDATE oauth_token
210 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(),
211 previous_refresh_token = $5, rotated_at = NOW()
212 WHERE id = $1
213 "#,
214 old_db_id,
215 new_token_id,
216 new_refresh_token,
217 new_expires_at,
218 old_refresh
219 )
220 .execute(&mut *tx)
221 .await?;
222 tx.commit().await?;
223 Ok(())
224}
225
226pub async fn check_refresh_token_used(
227 pool: &PgPool,
228 refresh_token: &str,
229) -> Result<Option<i32>, OAuthError> {
230 let row = sqlx::query_scalar!(
231 r#"
232 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
233 "#,
234 refresh_token
235 )
236 .fetch_optional(pool)
237 .await?;
238 Ok(row)
239}
240
241const REFRESH_GRACE_PERIOD_SECS: i64 = 60;
242
243pub async fn get_token_by_previous_refresh_token(
244 pool: &PgPool,
245 refresh_token: &str,
246) -> Result<Option<(i32, TokenData)>, OAuthError> {
247 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
248 let row = sqlx::query!(
249 r#"
250 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
251 device_id, parameters, details, code, current_refresh_token, scope, controller_did
252 FROM oauth_token
253 WHERE previous_refresh_token = $1 AND rotated_at > $2
254 "#,
255 refresh_token,
256 grace_cutoff
257 )
258 .fetch_optional(pool)
259 .await?;
260 match row {
261 Some(r) => Ok(Some((
262 r.id,
263 TokenData {
264 did: r.did,
265 token_id: r.token_id,
266 created_at: r.created_at,
267 updated_at: r.updated_at,
268 expires_at: r.expires_at,
269 client_id: r.client_id,
270 client_auth: from_json(r.client_auth)?,
271 device_id: r.device_id,
272 parameters: from_json(r.parameters)?,
273 details: r.details,
274 code: r.code,
275 current_refresh_token: r.current_refresh_token,
276 scope: r.scope,
277 controller_did: r.controller_did,
278 },
279 ))),
280 None => Ok(None),
281 }
282}
283
284pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
285 sqlx::query!(
286 r#"
287 DELETE FROM oauth_token WHERE token_id = $1
288 "#,
289 token_id
290 )
291 .execute(pool)
292 .await?;
293 Ok(())
294}
295
296pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
297 sqlx::query!(
298 r#"
299 DELETE FROM oauth_token WHERE id = $1
300 "#,
301 db_id
302 )
303 .execute(pool)
304 .await?;
305 Ok(())
306}
307
308pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
309 let rows = sqlx::query!(
310 r#"
311 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
312 device_id, parameters, details, code, current_refresh_token, scope, controller_did
313 FROM oauth_token
314 WHERE did = $1
315 "#,
316 did
317 )
318 .fetch_all(pool)
319 .await?;
320 let mut tokens = Vec::with_capacity(rows.len());
321 for r in rows {
322 tokens.push(TokenData {
323 did: r.did,
324 token_id: r.token_id,
325 created_at: r.created_at,
326 updated_at: r.updated_at,
327 expires_at: r.expires_at,
328 client_id: r.client_id,
329 client_auth: from_json(r.client_auth)?,
330 device_id: r.device_id,
331 parameters: from_json(r.parameters)?,
332 details: r.details,
333 code: r.code,
334 current_refresh_token: r.current_refresh_token,
335 scope: r.scope,
336 controller_did: r.controller_did,
337 });
338 }
339 Ok(tokens)
340}
341
342pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
343 let count = sqlx::query_scalar!(
344 r#"
345 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
346 "#,
347 did
348 )
349 .fetch_one(pool)
350 .await?;
351 Ok(count)
352}
353
354pub async fn delete_oldest_tokens_for_user(
355 pool: &PgPool,
356 did: &str,
357 keep_count: i64,
358) -> Result<u64, OAuthError> {
359 let result = sqlx::query!(
360 r#"
361 DELETE FROM oauth_token
362 WHERE id IN (
363 SELECT id FROM oauth_token
364 WHERE did = $1
365 ORDER BY updated_at ASC
366 OFFSET $2
367 )
368 "#,
369 did,
370 keep_count
371 )
372 .execute(pool)
373 .await?;
374 Ok(result.rows_affected())
375}
376
377const MAX_TOKENS_PER_USER: i64 = 100;
378
379pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
380 let count = count_tokens_for_user(pool, did).await?;
381 if count > MAX_TOKENS_PER_USER {
382 let to_keep = MAX_TOKENS_PER_USER - 1;
383 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
384 }
385 Ok(())
386}
387
388pub async fn revoke_tokens_for_client(
389 pool: &PgPool,
390 did: &str,
391 client_id: &str,
392) -> Result<u64, OAuthError> {
393 let result = sqlx::query!(
394 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
395 did,
396 client_id
397 )
398 .execute(pool)
399 .await?;
400 Ok(result.rows_affected())
401}
402
403pub async fn revoke_tokens_for_controller(
404 pool: &PgPool,
405 delegated_did: &str,
406 controller_did: &str,
407) -> Result<u64, OAuthError> {
408 let result = sqlx::query!(
409 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2",
410 delegated_did,
411 controller_did
412 )
413 .execute(pool)
414 .await?;
415 Ok(result.rows_affected())
416}