this repo has no description
1use chrono::{DateTime, Utc};
2use serde::{de::DeserializeOwned, Serialize};
3use sqlx::PgPool;
4
5use super::{
6 AuthorizationRequestParameters, ClientAuth, DeviceData, OAuthError, RequestData, TokenData,
7 AuthorizedClientData,
8};
9
10fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> {
11 serde_json::to_value(value).map_err(|e| {
12 tracing::error!("JSON serialization error: {}", e);
13 OAuthError::ServerError("Internal serialization error".to_string())
14 })
15}
16
17fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> {
18 serde_json::from_value(value).map_err(|e| {
19 tracing::error!("JSON deserialization error: {}", e);
20 OAuthError::ServerError("Internal data corruption".to_string())
21 })
22}
23
24pub async fn create_device(
25 pool: &PgPool,
26 device_id: &str,
27 data: &DeviceData,
28) -> Result<(), OAuthError> {
29 sqlx::query!(
30 r#"
31 INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at)
32 VALUES ($1, $2, $3, $4, $5)
33 "#,
34 device_id,
35 data.session_id,
36 data.user_agent,
37 data.ip_address,
38 data.last_seen_at,
39 )
40 .execute(pool)
41 .await?;
42
43 Ok(())
44}
45
46pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> {
47 let row = sqlx::query!(
48 r#"
49 SELECT session_id, user_agent, ip_address, last_seen_at
50 FROM oauth_device
51 WHERE id = $1
52 "#,
53 device_id
54 )
55 .fetch_optional(pool)
56 .await?;
57
58 Ok(row.map(|r| DeviceData {
59 session_id: r.session_id,
60 user_agent: r.user_agent,
61 ip_address: r.ip_address,
62 last_seen_at: r.last_seen_at,
63 }))
64}
65
66pub async fn update_device_last_seen(
67 pool: &PgPool,
68 device_id: &str,
69) -> Result<(), OAuthError> {
70 sqlx::query!(
71 r#"
72 UPDATE oauth_device
73 SET last_seen_at = NOW()
74 WHERE id = $1
75 "#,
76 device_id
77 )
78 .execute(pool)
79 .await?;
80
81 Ok(())
82}
83
84pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> {
85 sqlx::query!(
86 r#"
87 DELETE FROM oauth_device WHERE id = $1
88 "#,
89 device_id
90 )
91 .execute(pool)
92 .await?;
93
94 Ok(())
95}
96
97pub async fn create_authorization_request(
98 pool: &PgPool,
99 request_id: &str,
100 data: &RequestData,
101) -> Result<(), OAuthError> {
102 let client_auth_json = match &data.client_auth {
103 Some(ca) => Some(to_json(ca)?),
104 None => None,
105 };
106 let parameters_json = to_json(&data.parameters)?;
107
108 sqlx::query!(
109 r#"
110 INSERT INTO oauth_authorization_request
111 (id, did, device_id, client_id, client_auth, parameters, expires_at, code)
112 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
113 "#,
114 request_id,
115 data.did,
116 data.device_id,
117 data.client_id,
118 client_auth_json,
119 parameters_json,
120 data.expires_at,
121 data.code,
122 )
123 .execute(pool)
124 .await?;
125
126 Ok(())
127}
128
129pub async fn get_authorization_request(
130 pool: &PgPool,
131 request_id: &str,
132) -> Result<Option<RequestData>, OAuthError> {
133 let row = sqlx::query!(
134 r#"
135 SELECT did, device_id, client_id, client_auth, parameters, expires_at, code
136 FROM oauth_authorization_request
137 WHERE id = $1
138 "#,
139 request_id
140 )
141 .fetch_optional(pool)
142 .await?;
143
144 match row {
145 Some(r) => {
146 let client_auth: Option<ClientAuth> = match r.client_auth {
147 Some(v) => Some(from_json(v)?),
148 None => None,
149 };
150 let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
151
152 Ok(Some(RequestData {
153 client_id: r.client_id,
154 client_auth,
155 parameters,
156 expires_at: r.expires_at,
157 did: r.did,
158 device_id: r.device_id,
159 code: r.code,
160 }))
161 }
162 None => Ok(None),
163 }
164}
165
166pub async fn update_authorization_request(
167 pool: &PgPool,
168 request_id: &str,
169 did: &str,
170 device_id: Option<&str>,
171 code: &str,
172) -> Result<(), OAuthError> {
173 sqlx::query!(
174 r#"
175 UPDATE oauth_authorization_request
176 SET did = $2, device_id = $3, code = $4
177 WHERE id = $1
178 "#,
179 request_id,
180 did,
181 device_id,
182 code
183 )
184 .execute(pool)
185 .await?;
186
187 Ok(())
188}
189
190pub async fn consume_authorization_request_by_code(
191 pool: &PgPool,
192 code: &str,
193) -> Result<Option<RequestData>, OAuthError> {
194 let row = sqlx::query!(
195 r#"
196 DELETE FROM oauth_authorization_request
197 WHERE code = $1
198 RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code
199 "#,
200 code
201 )
202 .fetch_optional(pool)
203 .await?;
204
205 match row {
206 Some(r) => {
207 let client_auth: Option<ClientAuth> = match r.client_auth {
208 Some(v) => Some(from_json(v)?),
209 None => None,
210 };
211 let parameters: AuthorizationRequestParameters = from_json(r.parameters)?;
212
213 Ok(Some(RequestData {
214 client_id: r.client_id,
215 client_auth,
216 parameters,
217 expires_at: r.expires_at,
218 did: r.did,
219 device_id: r.device_id,
220 code: r.code,
221 }))
222 }
223 None => Ok(None),
224 }
225}
226
227pub async fn delete_authorization_request(
228 pool: &PgPool,
229 request_id: &str,
230) -> Result<(), OAuthError> {
231 sqlx::query!(
232 r#"
233 DELETE FROM oauth_authorization_request WHERE id = $1
234 "#,
235 request_id
236 )
237 .execute(pool)
238 .await?;
239
240 Ok(())
241}
242
243pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> {
244 let result = sqlx::query!(
245 r#"
246 DELETE FROM oauth_authorization_request
247 WHERE expires_at < NOW()
248 "#
249 )
250 .execute(pool)
251 .await?;
252
253 Ok(result.rows_affected())
254}
255
256pub async fn create_token(
257 pool: &PgPool,
258 data: &TokenData,
259) -> Result<i32, OAuthError> {
260 let client_auth_json = to_json(&data.client_auth)?;
261 let parameters_json = to_json(&data.parameters)?;
262
263 let row = sqlx::query!(
264 r#"
265 INSERT INTO oauth_token
266 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
267 device_id, parameters, details, code, current_refresh_token, scope)
268 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
269 RETURNING id
270 "#,
271 data.did,
272 data.token_id,
273 data.created_at,
274 data.updated_at,
275 data.expires_at,
276 data.client_id,
277 client_auth_json,
278 data.device_id,
279 parameters_json,
280 data.details,
281 data.code,
282 data.current_refresh_token,
283 data.scope,
284 )
285 .fetch_one(pool)
286 .await?;
287
288 Ok(row.id)
289}
290
291pub async fn get_token_by_id(
292 pool: &PgPool,
293 token_id: &str,
294) -> Result<Option<TokenData>, OAuthError> {
295 let row = sqlx::query!(
296 r#"
297 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
298 device_id, parameters, details, code, current_refresh_token, scope
299 FROM oauth_token
300 WHERE token_id = $1
301 "#,
302 token_id
303 )
304 .fetch_optional(pool)
305 .await?;
306
307 match row {
308 Some(r) => Ok(Some(TokenData {
309 did: r.did,
310 token_id: r.token_id,
311 created_at: r.created_at,
312 updated_at: r.updated_at,
313 expires_at: r.expires_at,
314 client_id: r.client_id,
315 client_auth: from_json(r.client_auth)?,
316 device_id: r.device_id,
317 parameters: from_json(r.parameters)?,
318 details: r.details,
319 code: r.code,
320 current_refresh_token: r.current_refresh_token,
321 scope: r.scope,
322 })),
323 None => Ok(None),
324 }
325}
326
327pub async fn get_token_by_refresh_token(
328 pool: &PgPool,
329 refresh_token: &str,
330) -> Result<Option<(i32, TokenData)>, OAuthError> {
331 let row = sqlx::query!(
332 r#"
333 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
334 device_id, parameters, details, code, current_refresh_token, scope
335 FROM oauth_token
336 WHERE current_refresh_token = $1
337 "#,
338 refresh_token
339 )
340 .fetch_optional(pool)
341 .await?;
342
343 match row {
344 Some(r) => Ok(Some((
345 r.id,
346 TokenData {
347 did: r.did,
348 token_id: r.token_id,
349 created_at: r.created_at,
350 updated_at: r.updated_at,
351 expires_at: r.expires_at,
352 client_id: r.client_id,
353 client_auth: from_json(r.client_auth)?,
354 device_id: r.device_id,
355 parameters: from_json(r.parameters)?,
356 details: r.details,
357 code: r.code,
358 current_refresh_token: r.current_refresh_token,
359 scope: r.scope,
360 },
361 ))),
362 None => Ok(None),
363 }
364}
365
366pub async fn rotate_token(
367 pool: &PgPool,
368 old_db_id: i32,
369 new_token_id: &str,
370 new_refresh_token: &str,
371 new_expires_at: DateTime<Utc>,
372) -> Result<(), OAuthError> {
373 let mut tx = pool.begin().await?;
374
375 let old_refresh = sqlx::query_scalar!(
376 r#"
377 SELECT current_refresh_token FROM oauth_token WHERE id = $1
378 "#,
379 old_db_id
380 )
381 .fetch_one(&mut *tx)
382 .await?;
383
384 if let Some(old_rt) = old_refresh {
385 sqlx::query!(
386 r#"
387 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
388 VALUES ($1, $2)
389 "#,
390 old_rt,
391 old_db_id
392 )
393 .execute(&mut *tx)
394 .await?;
395 }
396
397 sqlx::query!(
398 r#"
399 UPDATE oauth_token
400 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
401 WHERE id = $1
402 "#,
403 old_db_id,
404 new_token_id,
405 new_refresh_token,
406 new_expires_at
407 )
408 .execute(&mut *tx)
409 .await?;
410
411 tx.commit().await?;
412 Ok(())
413}
414
415pub async fn check_refresh_token_used(
416 pool: &PgPool,
417 refresh_token: &str,
418) -> Result<Option<i32>, OAuthError> {
419 let row = sqlx::query_scalar!(
420 r#"
421 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
422 "#,
423 refresh_token
424 )
425 .fetch_optional(pool)
426 .await?;
427
428 Ok(row)
429}
430
431pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
432 sqlx::query!(
433 r#"
434 DELETE FROM oauth_token WHERE token_id = $1
435 "#,
436 token_id
437 )
438 .execute(pool)
439 .await?;
440
441 Ok(())
442}
443
444pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
445 sqlx::query!(
446 r#"
447 DELETE FROM oauth_token WHERE id = $1
448 "#,
449 db_id
450 )
451 .execute(pool)
452 .await?;
453
454 Ok(())
455}
456
457pub async fn upsert_account_device(
458 pool: &PgPool,
459 did: &str,
460 device_id: &str,
461) -> Result<(), OAuthError> {
462 sqlx::query!(
463 r#"
464 INSERT INTO oauth_account_device (did, device_id, created_at, updated_at)
465 VALUES ($1, $2, NOW(), NOW())
466 ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW()
467 "#,
468 did,
469 device_id
470 )
471 .execute(pool)
472 .await?;
473
474 Ok(())
475}
476
477pub async fn upsert_authorized_client(
478 pool: &PgPool,
479 did: &str,
480 client_id: &str,
481 data: &AuthorizedClientData,
482) -> Result<(), OAuthError> {
483 let data_json = to_json(data)?;
484
485 sqlx::query!(
486 r#"
487 INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data)
488 VALUES ($1, $2, NOW(), NOW(), $3)
489 ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3
490 "#,
491 did,
492 client_id,
493 data_json
494 )
495 .execute(pool)
496 .await?;
497
498 Ok(())
499}
500
501pub async fn get_authorized_client(
502 pool: &PgPool,
503 did: &str,
504 client_id: &str,
505) -> Result<Option<AuthorizedClientData>, OAuthError> {
506 let row = sqlx::query_scalar!(
507 r#"
508 SELECT data FROM oauth_authorized_client
509 WHERE did = $1 AND client_id = $2
510 "#,
511 did,
512 client_id
513 )
514 .fetch_optional(pool)
515 .await?;
516
517 match row {
518 Some(v) => Ok(Some(from_json(v)?)),
519 None => Ok(None),
520 }
521}
522
523pub async fn list_tokens_for_user(
524 pool: &PgPool,
525 did: &str,
526) -> Result<Vec<TokenData>, OAuthError> {
527 let rows = sqlx::query!(
528 r#"
529 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
530 device_id, parameters, details, code, current_refresh_token, scope
531 FROM oauth_token
532 WHERE did = $1
533 "#,
534 did
535 )
536 .fetch_all(pool)
537 .await?;
538
539 let mut tokens = Vec::with_capacity(rows.len());
540 for r in rows {
541 tokens.push(TokenData {
542 did: r.did,
543 token_id: r.token_id,
544 created_at: r.created_at,
545 updated_at: r.updated_at,
546 expires_at: r.expires_at,
547 client_id: r.client_id,
548 client_auth: from_json(r.client_auth)?,
549 device_id: r.device_id,
550 parameters: from_json(r.parameters)?,
551 details: r.details,
552 code: r.code,
553 current_refresh_token: r.current_refresh_token,
554 scope: r.scope,
555 });
556 }
557 Ok(tokens)
558}
559
560pub async fn check_and_record_dpop_jti(
561 pool: &PgPool,
562 jti: &str,
563) -> Result<bool, OAuthError> {
564 let result = sqlx::query!(
565 r#"
566 INSERT INTO oauth_dpop_jti (jti)
567 VALUES ($1)
568 ON CONFLICT (jti) DO NOTHING
569 "#,
570 jti
571 )
572 .execute(pool)
573 .await?;
574
575 Ok(result.rows_affected() > 0)
576}
577
578pub async fn cleanup_expired_dpop_jtis(
579 pool: &PgPool,
580 max_age_secs: i64,
581) -> Result<u64, OAuthError> {
582 let result = sqlx::query!(
583 r#"
584 DELETE FROM oauth_dpop_jti
585 WHERE created_at < NOW() - INTERVAL '1 second' * $1
586 "#,
587 max_age_secs as f64
588 )
589 .execute(pool)
590 .await?;
591
592 Ok(result.rows_affected())
593}
594
595pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
596 let count = sqlx::query_scalar!(
597 r#"
598 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
599 "#,
600 did
601 )
602 .fetch_one(pool)
603 .await?;
604
605 Ok(count)
606}
607
608pub async fn delete_oldest_tokens_for_user(
609 pool: &PgPool,
610 did: &str,
611 keep_count: i64,
612) -> Result<u64, OAuthError> {
613 let result = sqlx::query!(
614 r#"
615 DELETE FROM oauth_token
616 WHERE id IN (
617 SELECT id FROM oauth_token
618 WHERE did = $1
619 ORDER BY updated_at ASC
620 OFFSET $2
621 )
622 "#,
623 did,
624 keep_count
625 )
626 .execute(pool)
627 .await?;
628
629 Ok(result.rows_affected())
630}
631
632const MAX_TOKENS_PER_USER: i64 = 100;
633
634pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
635 let count = count_tokens_for_user(pool, did).await?;
636 if count > MAX_TOKENS_PER_USER {
637 let to_keep = MAX_TOKENS_PER_USER - 1;
638 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
639 }
640 Ok(())
641}