This repository has no description.
1use chrono::{DateTime, Utc}; 2use serde::{de::DeserializeOwned, Serialize}; 3use sqlx::PgPool; 4 5use super::{ 6 AuthorizationRequestParameters, ClientAuth, DeviceData, OAuthError, RequestData, TokenData, 7 AuthorizedClientData, 8}; 9 10fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 11 serde_json::to_value(value).map_err(|e| { 12 tracing::error!("JSON serialization error: {}", e); 13 OAuthError::ServerError("Internal serialization error".to_string()) 14 }) 15} 16 17fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 18 serde_json::from_value(value).map_err(|e| { 19 tracing::error!("JSON deserialization error: {}", e); 20 OAuthError::ServerError("Internal data corruption".to_string()) 21 }) 22} 23 24pub async fn create_device( 25 pool: &PgPool, 26 device_id: &str, 27 data: &DeviceData, 28) -> Result<(), OAuthError> { 29 sqlx::query!( 30 r#" 31 INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at) 32 VALUES ($1, $2, $3, $4, $5) 33 "#, 34 device_id, 35 data.session_id, 36 data.user_agent, 37 data.ip_address, 38 data.last_seen_at, 39 ) 40 .execute(pool) 41 .await?; 42 43 Ok(()) 44} 45 46pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 47 let row = sqlx::query!( 48 r#" 49 SELECT session_id, user_agent, ip_address, last_seen_at 50 FROM oauth_device 51 WHERE id = $1 52 "#, 53 device_id 54 ) 55 .fetch_optional(pool) 56 .await?; 57 58 Ok(row.map(|r| DeviceData { 59 session_id: r.session_id, 60 user_agent: r.user_agent, 61 ip_address: r.ip_address, 62 last_seen_at: r.last_seen_at, 63 })) 64} 65 66pub async fn update_device_last_seen( 67 pool: &PgPool, 68 device_id: &str, 69) -> Result<(), OAuthError> { 70 sqlx::query!( 71 r#" 72 UPDATE oauth_device 73 SET last_seen_at = NOW() 74 WHERE id = $1 75 "#, 76 device_id 77 ) 78 .execute(pool) 79 .await?; 80 81 Ok(()) 82} 83 84pub async fn delete_device(pool: &PgPool, device_id: &str) -> 
Result<(), OAuthError> { 85 sqlx::query!( 86 r#" 87 DELETE FROM oauth_device WHERE id = $1 88 "#, 89 device_id 90 ) 91 .execute(pool) 92 .await?; 93 94 Ok(()) 95} 96 97pub async fn create_authorization_request( 98 pool: &PgPool, 99 request_id: &str, 100 data: &RequestData, 101) -> Result<(), OAuthError> { 102 let client_auth_json = match &data.client_auth { 103 Some(ca) => Some(to_json(ca)?), 104 None => None, 105 }; 106 let parameters_json = to_json(&data.parameters)?; 107 108 sqlx::query!( 109 r#" 110 INSERT INTO oauth_authorization_request 111 (id, did, device_id, client_id, client_auth, parameters, expires_at, code) 112 VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 113 "#, 114 request_id, 115 data.did, 116 data.device_id, 117 data.client_id, 118 client_auth_json, 119 parameters_json, 120 data.expires_at, 121 data.code, 122 ) 123 .execute(pool) 124 .await?; 125 126 Ok(()) 127} 128 129pub async fn get_authorization_request( 130 pool: &PgPool, 131 request_id: &str, 132) -> Result<Option<RequestData>, OAuthError> { 133 let row = sqlx::query!( 134 r#" 135 SELECT did, device_id, client_id, client_auth, parameters, expires_at, code 136 FROM oauth_authorization_request 137 WHERE id = $1 138 "#, 139 request_id 140 ) 141 .fetch_optional(pool) 142 .await?; 143 144 match row { 145 Some(r) => { 146 let client_auth: Option<ClientAuth> = match r.client_auth { 147 Some(v) => Some(from_json(v)?), 148 None => None, 149 }; 150 let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 151 152 Ok(Some(RequestData { 153 client_id: r.client_id, 154 client_auth, 155 parameters, 156 expires_at: r.expires_at, 157 did: r.did, 158 device_id: r.device_id, 159 code: r.code, 160 })) 161 } 162 None => Ok(None), 163 } 164} 165 166pub async fn update_authorization_request( 167 pool: &PgPool, 168 request_id: &str, 169 did: &str, 170 device_id: Option<&str>, 171 code: &str, 172) -> Result<(), OAuthError> { 173 sqlx::query!( 174 r#" 175 UPDATE oauth_authorization_request 176 SET did = 
$2, device_id = $3, code = $4 177 WHERE id = $1 178 "#, 179 request_id, 180 did, 181 device_id, 182 code 183 ) 184 .execute(pool) 185 .await?; 186 187 Ok(()) 188} 189 190pub async fn consume_authorization_request_by_code( 191 pool: &PgPool, 192 code: &str, 193) -> Result<Option<RequestData>, OAuthError> { 194 let row = sqlx::query!( 195 r#" 196 DELETE FROM oauth_authorization_request 197 WHERE code = $1 198 RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code 199 "#, 200 code 201 ) 202 .fetch_optional(pool) 203 .await?; 204 205 match row { 206 Some(r) => { 207 let client_auth: Option<ClientAuth> = match r.client_auth { 208 Some(v) => Some(from_json(v)?), 209 None => None, 210 }; 211 let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 212 213 Ok(Some(RequestData { 214 client_id: r.client_id, 215 client_auth, 216 parameters, 217 expires_at: r.expires_at, 218 did: r.did, 219 device_id: r.device_id, 220 code: r.code, 221 })) 222 } 223 None => Ok(None), 224 } 225} 226 227pub async fn delete_authorization_request( 228 pool: &PgPool, 229 request_id: &str, 230) -> Result<(), OAuthError> { 231 sqlx::query!( 232 r#" 233 DELETE FROM oauth_authorization_request WHERE id = $1 234 "#, 235 request_id 236 ) 237 .execute(pool) 238 .await?; 239 240 Ok(()) 241} 242 243pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 244 let result = sqlx::query!( 245 r#" 246 DELETE FROM oauth_authorization_request 247 WHERE expires_at < NOW() 248 "# 249 ) 250 .execute(pool) 251 .await?; 252 253 Ok(result.rows_affected()) 254} 255 256pub async fn create_token( 257 pool: &PgPool, 258 data: &TokenData, 259) -> Result<i32, OAuthError> { 260 let client_auth_json = to_json(&data.client_auth)?; 261 let parameters_json = to_json(&data.parameters)?; 262 263 let row = sqlx::query!( 264 r#" 265 INSERT INTO oauth_token 266 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 267 device_id, 
parameters, details, code, current_refresh_token, scope) 268 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 269 RETURNING id 270 "#, 271 data.did, 272 data.token_id, 273 data.created_at, 274 data.updated_at, 275 data.expires_at, 276 data.client_id, 277 client_auth_json, 278 data.device_id, 279 parameters_json, 280 data.details, 281 data.code, 282 data.current_refresh_token, 283 data.scope, 284 ) 285 .fetch_one(pool) 286 .await?; 287 288 Ok(row.id) 289} 290 291pub async fn get_token_by_id( 292 pool: &PgPool, 293 token_id: &str, 294) -> Result<Option<TokenData>, OAuthError> { 295 let row = sqlx::query!( 296 r#" 297 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 298 device_id, parameters, details, code, current_refresh_token, scope 299 FROM oauth_token 300 WHERE token_id = $1 301 "#, 302 token_id 303 ) 304 .fetch_optional(pool) 305 .await?; 306 307 match row { 308 Some(r) => Ok(Some(TokenData { 309 did: r.did, 310 token_id: r.token_id, 311 created_at: r.created_at, 312 updated_at: r.updated_at, 313 expires_at: r.expires_at, 314 client_id: r.client_id, 315 client_auth: from_json(r.client_auth)?, 316 device_id: r.device_id, 317 parameters: from_json(r.parameters)?, 318 details: r.details, 319 code: r.code, 320 current_refresh_token: r.current_refresh_token, 321 scope: r.scope, 322 })), 323 None => Ok(None), 324 } 325} 326 327pub async fn get_token_by_refresh_token( 328 pool: &PgPool, 329 refresh_token: &str, 330) -> Result<Option<(i32, TokenData)>, OAuthError> { 331 let row = sqlx::query!( 332 r#" 333 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 334 device_id, parameters, details, code, current_refresh_token, scope 335 FROM oauth_token 336 WHERE current_refresh_token = $1 337 "#, 338 refresh_token 339 ) 340 .fetch_optional(pool) 341 .await?; 342 343 match row { 344 Some(r) => Ok(Some(( 345 r.id, 346 TokenData { 347 did: r.did, 348 token_id: r.token_id, 349 created_at: 
r.created_at, 350 updated_at: r.updated_at, 351 expires_at: r.expires_at, 352 client_id: r.client_id, 353 client_auth: from_json(r.client_auth)?, 354 device_id: r.device_id, 355 parameters: from_json(r.parameters)?, 356 details: r.details, 357 code: r.code, 358 current_refresh_token: r.current_refresh_token, 359 scope: r.scope, 360 }, 361 ))), 362 None => Ok(None), 363 } 364} 365 366pub async fn rotate_token( 367 pool: &PgPool, 368 old_db_id: i32, 369 new_token_id: &str, 370 new_refresh_token: &str, 371 new_expires_at: DateTime<Utc>, 372) -> Result<(), OAuthError> { 373 let mut tx = pool.begin().await?; 374 375 let old_refresh = sqlx::query_scalar!( 376 r#" 377 SELECT current_refresh_token FROM oauth_token WHERE id = $1 378 "#, 379 old_db_id 380 ) 381 .fetch_one(&mut *tx) 382 .await?; 383 384 if let Some(old_rt) = old_refresh { 385 sqlx::query!( 386 r#" 387 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 388 VALUES ($1, $2) 389 "#, 390 old_rt, 391 old_db_id 392 ) 393 .execute(&mut *tx) 394 .await?; 395 } 396 397 sqlx::query!( 398 r#" 399 UPDATE oauth_token 400 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW() 401 WHERE id = $1 402 "#, 403 old_db_id, 404 new_token_id, 405 new_refresh_token, 406 new_expires_at 407 ) 408 .execute(&mut *tx) 409 .await?; 410 411 tx.commit().await?; 412 Ok(()) 413} 414 415pub async fn check_refresh_token_used( 416 pool: &PgPool, 417 refresh_token: &str, 418) -> Result<Option<i32>, OAuthError> { 419 let row = sqlx::query_scalar!( 420 r#" 421 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 422 "#, 423 refresh_token 424 ) 425 .fetch_optional(pool) 426 .await?; 427 428 Ok(row) 429} 430 431pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 432 sqlx::query!( 433 r#" 434 DELETE FROM oauth_token WHERE token_id = $1 435 "#, 436 token_id 437 ) 438 .execute(pool) 439 .await?; 440 441 Ok(()) 442} 443 444pub async fn delete_token_family(pool: 
&PgPool, db_id: i32) -> Result<(), OAuthError> { 445 sqlx::query!( 446 r#" 447 DELETE FROM oauth_token WHERE id = $1 448 "#, 449 db_id 450 ) 451 .execute(pool) 452 .await?; 453 454 Ok(()) 455} 456 457pub async fn upsert_account_device( 458 pool: &PgPool, 459 did: &str, 460 device_id: &str, 461) -> Result<(), OAuthError> { 462 sqlx::query!( 463 r#" 464 INSERT INTO oauth_account_device (did, device_id, created_at, updated_at) 465 VALUES ($1, $2, NOW(), NOW()) 466 ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW() 467 "#, 468 did, 469 device_id 470 ) 471 .execute(pool) 472 .await?; 473 474 Ok(()) 475} 476 477pub async fn upsert_authorized_client( 478 pool: &PgPool, 479 did: &str, 480 client_id: &str, 481 data: &AuthorizedClientData, 482) -> Result<(), OAuthError> { 483 let data_json = to_json(data)?; 484 485 sqlx::query!( 486 r#" 487 INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data) 488 VALUES ($1, $2, NOW(), NOW(), $3) 489 ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3 490 "#, 491 did, 492 client_id, 493 data_json 494 ) 495 .execute(pool) 496 .await?; 497 498 Ok(()) 499} 500 501pub async fn get_authorized_client( 502 pool: &PgPool, 503 did: &str, 504 client_id: &str, 505) -> Result<Option<AuthorizedClientData>, OAuthError> { 506 let row = sqlx::query_scalar!( 507 r#" 508 SELECT data FROM oauth_authorized_client 509 WHERE did = $1 AND client_id = $2 510 "#, 511 did, 512 client_id 513 ) 514 .fetch_optional(pool) 515 .await?; 516 517 match row { 518 Some(v) => Ok(Some(from_json(v)?)), 519 None => Ok(None), 520 } 521} 522 523pub async fn list_tokens_for_user( 524 pool: &PgPool, 525 did: &str, 526) -> Result<Vec<TokenData>, OAuthError> { 527 let rows = sqlx::query!( 528 r#" 529 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 530 device_id, parameters, details, code, current_refresh_token, scope 531 FROM oauth_token 532 WHERE did = $1 533 "#, 534 did 535 ) 536 
.fetch_all(pool) 537 .await?; 538 539 let mut tokens = Vec::with_capacity(rows.len()); 540 for r in rows { 541 tokens.push(TokenData { 542 did: r.did, 543 token_id: r.token_id, 544 created_at: r.created_at, 545 updated_at: r.updated_at, 546 expires_at: r.expires_at, 547 client_id: r.client_id, 548 client_auth: from_json(r.client_auth)?, 549 device_id: r.device_id, 550 parameters: from_json(r.parameters)?, 551 details: r.details, 552 code: r.code, 553 current_refresh_token: r.current_refresh_token, 554 scope: r.scope, 555 }); 556 } 557 Ok(tokens) 558} 559 560pub async fn check_and_record_dpop_jti( 561 pool: &PgPool, 562 jti: &str, 563) -> Result<bool, OAuthError> { 564 let result = sqlx::query!( 565 r#" 566 INSERT INTO oauth_dpop_jti (jti) 567 VALUES ($1) 568 ON CONFLICT (jti) DO NOTHING 569 "#, 570 jti 571 ) 572 .execute(pool) 573 .await?; 574 575 Ok(result.rows_affected() > 0) 576} 577 578pub async fn cleanup_expired_dpop_jtis( 579 pool: &PgPool, 580 max_age_secs: i64, 581) -> Result<u64, OAuthError> { 582 let result = sqlx::query!( 583 r#" 584 DELETE FROM oauth_dpop_jti 585 WHERE created_at < NOW() - INTERVAL '1 second' * $1 586 "#, 587 max_age_secs as f64 588 ) 589 .execute(pool) 590 .await?; 591 592 Ok(result.rows_affected()) 593} 594 595pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 596 let count = sqlx::query_scalar!( 597 r#" 598 SELECT COUNT(*) as "count!" 
FROM oauth_token WHERE did = $1 599 "#, 600 did 601 ) 602 .fetch_one(pool) 603 .await?; 604 605 Ok(count) 606} 607 608pub async fn delete_oldest_tokens_for_user( 609 pool: &PgPool, 610 did: &str, 611 keep_count: i64, 612) -> Result<u64, OAuthError> { 613 let result = sqlx::query!( 614 r#" 615 DELETE FROM oauth_token 616 WHERE id IN ( 617 SELECT id FROM oauth_token 618 WHERE did = $1 619 ORDER BY updated_at ASC 620 OFFSET $2 621 ) 622 "#, 623 did, 624 keep_count 625 ) 626 .execute(pool) 627 .await?; 628 629 Ok(result.rows_affected()) 630} 631 632const MAX_TOKENS_PER_USER: i64 = 100; 633 634pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 635 let count = count_tokens_for_user(pool, did).await?; 636 if count > MAX_TOKENS_PER_USER { 637 let to_keep = MAX_TOKENS_PER_USER - 1; 638 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 639 } 640 Ok(()) 641}