i18n+filtering fork - fluent-templates v2
at main 526 lines 17 kB view raw
1use std::borrow::Cow; 2 3use chrono::{DateTime, Utc}; 4use serde_json::json; 5 6use crate::{ 7 jose::jwk::WrappedJsonWebKey, 8 storage::{errors::StorageError, handle::model::Handle, StoragePool}, 9}; 10use model::{OAuthRequest, OAuthSession}; 11 12pub struct OAuthRequestParams { 13 pub oauth_state: Cow<'static, str>, 14 pub issuer: Cow<'static, str>, 15 pub did: Cow<'static, str>, 16 pub nonce: Cow<'static, str>, 17 pub pkce_verifier: Cow<'static, str>, 18 pub secret_jwk_id: Cow<'static, str>, 19 pub dpop_jwk: Option<WrappedJsonWebKey>, 20 pub destination: Option<Cow<'static, str>>, 21 pub created_at: DateTime<Utc>, 22 pub expires_at: DateTime<Utc>, 23} 24 25pub async fn oauth_request_insert( 26 pool: &StoragePool, 27 params: OAuthRequestParams, 28) -> Result<(), StorageError> { 29 // Validate required input parameters 30 if params.oauth_state.trim().is_empty() { 31 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 32 "OAuth state cannot be empty".into(), 33 ))); 34 } 35 36 if params.issuer.trim().is_empty() { 37 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 38 "Issuer cannot be empty".into(), 39 ))); 40 } 41 42 if params.did.trim().is_empty() { 43 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 44 "DID cannot be empty".into(), 45 ))); 46 } 47 48 if params.nonce.trim().is_empty() { 49 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 50 "Nonce cannot be empty".into(), 51 ))); 52 } 53 54 if params.pkce_verifier.trim().is_empty() { 55 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 56 "PKCE verifier cannot be empty".into(), 57 ))); 58 } 59 60 if params.secret_jwk_id.trim().is_empty() { 61 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 62 "Secret JWK ID cannot be empty".into(), 63 ))); 64 } 65 66 let mut tx = pool 67 .begin() 68 .await 69 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 70 71 let dpop_jwk_value = params 72 .dpop_jwk 73 .map(|jwk| json!(jwk)) 74 .unwrap_or_else(|| json!({})); 75 76 sqlx::query("INSERT INTO oauth_requests (oauth_state, issuer, did, nonce, pkce_verifier, secret_jwk_id, dpop_jwk, destination, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)") 77 .bind(&params.oauth_state) 78 .bind(&params.issuer) 79 .bind(&params.did) 80 .bind(&params.nonce) 81 .bind(&params.pkce_verifier) 82 .bind(&params.secret_jwk_id) 83 .bind(dpop_jwk_value) 84 .bind(params.destination) 85 .bind(params.created_at) 86 .bind(params.expires_at) 87 .execute(tx.as_mut()) 88 .await 89 .map_err(StorageError::UnableToExecuteQuery)?; 90 91 tx.commit() 92 .await 93 .map_err(StorageError::CannotCommitDatabaseTransaction) 94} 95 96pub async fn oauth_request_get( 97 pool: &StoragePool, 98 oauth_state: &str, 99) -> Result<OAuthRequest, StorageError> { 100 // Validate oauth_state is not empty 101 if oauth_state.trim().is_empty() { 102 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 103 "OAuth state cannot be empty".into(), 104 ))); 105 } 106 107 let mut tx = pool 108 .begin() 109 .await 110 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 111 112 let record = 113 sqlx::query_as::<_, OAuthRequest>("SELECT * FROM oauth_requests WHERE oauth_state = $1") 114 .bind(oauth_state) 115 .fetch_one(tx.as_mut()) 116 .await 117 .map_err(|err| match err { 118 sqlx::Error::RowNotFound => StorageError::OAuthRequestNotFound, 119 other => StorageError::UnableToExecuteQuery(other), 120 })?; 121 122 tx.commit() 123 .await 124 .map_err(StorageError::CannotCommitDatabaseTransaction)?; 125 126 Ok(record) 127} 128 129pub async fn oauth_request_remove( 130 pool: &StoragePool, 131 oauth_state: &str, 132) -> Result<(), StorageError> { 133 // Validate oauth_state is not empty 134 if oauth_state.trim().is_empty() { 135 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 136 "OAuth state cannot be empty".into(), 137 ))); 138 } 139 140 let mut tx = pool 141 .begin() 142 .await 143 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 144 145 sqlx::query("DELETE FROM oauth_requests WHERE oauth_state = $1") 146 .bind(oauth_state) 147 .execute(tx.as_mut()) 148 .await 149 .map_err(StorageError::UnableToExecuteQuery)?; 150 151 tx.commit() 152 .await 153 .map_err(StorageError::CannotCommitDatabaseTransaction) 154} 155 156pub struct OAuthSessionParams { 157 pub session_group: Cow<'static, str>, 158 pub access_token: Cow<'static, str>, 159 pub did: Cow<'static, str>, 160 pub issuer: Cow<'static, str>, 161 pub refresh_token: Cow<'static, str>, 162 pub secret_jwk_id: Cow<'static, str>, 163 pub dpop_jwk: WrappedJsonWebKey, 164 pub created_at: DateTime<Utc>, 165 pub access_token_expires_at: DateTime<Utc>, 166} 167 168pub async fn oauth_session_insert( 169 pool: &StoragePool, 170 params: OAuthSessionParams, 171) -> Result<(), StorageError> { 172 // Validate required input parameters 173 if params.session_group.trim().is_empty() { 174 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 175 "Session group cannot be empty".into(), 176 ))); 177 } 178 179 if params.access_token.trim().is_empty() { 180 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 181 "Access token cannot be empty".into(), 182 ))); 183 } 184 185 if params.did.trim().is_empty() { 186 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 187 "DID cannot be empty".into(), 188 ))); 189 } 190 191 if params.issuer.trim().is_empty() { 192 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 193 "Issuer cannot be empty".into(), 194 ))); 195 } 196 197 if params.refresh_token.trim().is_empty() { 198 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 199 "Refresh token cannot be empty".into(), 200 ))); 201 } 202 203 if params.secret_jwk_id.trim().is_empty() { 204 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 205 "Secret JWK ID cannot be empty".into(), 206 ))); 207 } 208 209 let mut tx = pool 210 .begin() 211 .await 212 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 213 214 sqlx::query("INSERT INTO oauth_sessions (session_group, access_token, did, issuer, refresh_token, secret_jwk_id, dpop_jwk, created_at, access_token_expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)") 215 .bind(&params.session_group) 216 .bind(&params.access_token) 217 .bind(&params.did) 218 .bind(&params.issuer) 219 .bind(&params.refresh_token) 220 .bind(&params.secret_jwk_id) 221 .bind(json!(params.dpop_jwk)) 222 .bind(params.created_at) 223 .bind(params.access_token_expires_at) 224 .execute(tx.as_mut()) 225 .await 226 .map_err(StorageError::UnableToExecuteQuery)?; 227 228 tx.commit() 229 .await 230 .map_err(StorageError::CannotCommitDatabaseTransaction) 231} 232 233pub async fn oauth_session_update( 234 pool: &StoragePool, 235 session_group: Cow<'_, str>, 236 access_token: Cow<'_, str>, 237 refresh_token: Cow<'_, str>, 238 access_token_expires_at: DateTime<Utc>, 239) -> Result<(), StorageError> { 240 // Validate input parameters 241 if session_group.trim().is_empty() { 242 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 243 "Session group cannot be empty".into(), 244 ))); 245 } 246 247 if access_token.trim().is_empty() { 248 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 249 "Access token cannot be empty".into(), 250 ))); 251 } 252 253 if refresh_token.trim().is_empty() { 254 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 255 "Refresh token cannot be empty".into(), 256 ))); 257 } 258 259 let mut tx = pool 260 .begin() 261 .await 262 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 263 264 sqlx::query("UPDATE oauth_sessions SET access_token = $1, refresh_token = $2, access_token_expires_at = $3 WHERE session_group = $4") 265 .bind(access_token) 266 .bind(refresh_token) 267 .bind(access_token_expires_at) 268 .bind(session_group) 269 .execute(tx.as_mut()) 270 .await 271 .map_err(StorageError::UnableToExecuteQuery)?; 272 273 tx.commit() 274 .await 275 .map_err(StorageError::CannotCommitDatabaseTransaction) 276} 277 278/// Delete an OAuth session by its session group. 279pub async fn oauth_session_delete( 280 pool: &StoragePool, 281 session_group: &str, 282) -> Result<(), StorageError> { 283 // Validate session_group is not empty 284 if session_group.trim().is_empty() { 285 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 286 "Session group cannot be empty".into(), 287 ))); 288 } 289 290 let mut tx = pool 291 .begin() 292 .await 293 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 294 295 sqlx::query("DELETE FROM oauth_sessions WHERE session_group = $1") 296 .bind(session_group) 297 .execute(tx.as_mut()) 298 .await 299 .map_err(StorageError::UnableToExecuteQuery)?; 300 301 tx.commit() 302 .await 303 .map_err(StorageError::CannotCommitDatabaseTransaction) 304} 305 306/// Look up a web session by session group and optionally filter by DID. 307pub async fn web_session_lookup( 308 pool: &StoragePool, 309 session_group: &str, 310 did: Option<&str>, 311) -> Result<(Handle, OAuthSession), StorageError> { 312 // Validate session_group is not empty 313 if session_group.trim().is_empty() { 314 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 315 "Session group cannot be empty".into(), 316 ))); 317 } 318 319 // If did is provided, validate it's not empty 320 if let Some(did_value) = did { 321 if did_value.trim().is_empty() { 322 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 323 "DID cannot be empty".into(), 324 ))); 325 } 326 } 327 328 let mut tx = pool 329 .begin() 330 .await 331 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 332 333 let oauth_session = match did { 334 Some(did_value) => { 335 sqlx::query_as::<_, OAuthSession>( 336 "SELECT * FROM oauth_sessions WHERE session_group = $1 AND did = $2 ORDER BY created_at DESC LIMIT 1", 337 ) 338 .bind(session_group) 339 .bind(did_value) 340 .fetch_one(tx.as_mut()) 341 .await 342 }, 343 None => { 344 sqlx::query_as::<_, OAuthSession>( 345 "SELECT * FROM oauth_sessions WHERE session_group = $1 ORDER BY created_at DESC LIMIT 1", 346 ) 347 .bind(session_group) 348 .fetch_one(tx.as_mut()) 349 .await 350 } 351 } 352 .map_err(|err| match err { 353 sqlx::Error::RowNotFound => StorageError::WebSessionNotFound, 354 other => StorageError::UnableToExecuteQuery(other), 355 })?; 356 357 let did_for_handle = did.unwrap_or(&oauth_session.did); 358 359 let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1") 360 .bind(did_for_handle) 361 .fetch_one(tx.as_mut()) 362 .await 363 .map_err(|err| match err { 364 sqlx::Error::RowNotFound => StorageError::HandleNotFound, 365 other => StorageError::UnableToExecuteQuery(other), 366 })?; 367 368 tx.commit() 369 .await 370 .map_err(StorageError::CannotCommitDatabaseTransaction)?; 371 372 Ok((handle, oauth_session)) 373} 374 375pub mod model { 376 use anyhow::Error; 377 use chrono::{DateTime, Utc}; 378 use p256::SecretKey; 379 use serde::Deserialize; 380 use sqlx::FromRow; 381 382 use crate::{ 383 atproto::auth::SimpleOAuthSessionProvider, jose::jwk::WrappedJsonWebKey, 384 storage::errors::OAuthModelError, 385 }; 386 387 #[derive(Clone, FromRow, Deserialize)] 388 pub struct OAuthRequest { 389 pub oauth_state: String, 390 pub issuer: String, 391 pub did: String, 392 pub nonce: String, 393 pub pkce_verifier: String, 394 pub secret_jwk_id: String, 395 pub destination: Option<String>, 396 pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>, 397 pub created_at: DateTime<Utc>, 398 pub expires_at: DateTime<Utc>, 399 } 400 401 pub struct OAuthRequestState { 402 pub state: String, 403 pub nonce: String, 404 pub code_challenge: String, 405 } 406 407 #[derive(Clone, FromRow, Deserialize)] 408 pub struct OAuthSession { 409 pub session_group: String, 410 pub access_token: String, 411 pub did: String, 412 pub issuer: String, 413 pub refresh_token: String, 414 pub secret_jwk_id: String, 415 pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>, 416 pub created_at: DateTime<Utc>, 417 pub access_token_expires_at: DateTime<Utc>, 418 } 419 420 impl TryFrom<OAuthSession> for SimpleOAuthSessionProvider { 421 type Error = Error; 422 423 fn try_from(value: OAuthSession) -> Result<Self, Self::Error> { 424 let dpop_secret = SecretKey::from_jwk(&value.dpop_jwk.jwk) 425 .map_err(OAuthModelError::DpopSecretFromJwkFailed)?; 426 427 Ok(SimpleOAuthSessionProvider { 428 access_token: value.access_token, 429 issuer: value.issuer, 430 dpop_secret, 431 }) 432 } 433 } 434} 435 436#[cfg(test)] 437pub mod test { 438 use sqlx::PgPool; 439 440 use crate::{ 441 jose, 442 storage::oauth::{ 443 oauth_request_get, oauth_request_insert, oauth_request_remove, oauth_session_insert, 444 web_session_lookup, OAuthRequestParams, OAuthSessionParams, 445 }, 446 }; 447 448 #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 449 async fn test_oauth_request(pool: PgPool) -> anyhow::Result<()> { 450 let dpop_jwk = jose::jwk::generate(); 451 let created_at = chrono::Utc::now(); 452 let expires_at = created_at + chrono::Duration::seconds(60 as i64); 453 454 let res = oauth_request_insert( 455 &pool, 456 OAuthRequestParams { 457 oauth_state: "oauth_state".to_string().into(), 458 issuer: "pds.examplepds.com".to_string().into(), 459 did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(), 460 nonce: "nonce".to_string().into(), 461 pkce_verifier: "pkce_verifier".to_string().into(), 462 secret_jwk_id: "secret_jwk_id".to_string().into(), 463 dpop_jwk: Some(dpop_jwk.clone()), 464 destination: None, 465 created_at, 466 expires_at, 467 }, 468 ) 469 .await; 470 471 assert!(!res.is_err()); 472 473 let oauth_request = oauth_request_get(&pool, "oauth_state").await; 474 assert!(!oauth_request.is_err()); 475 let oauth_request = oauth_request.unwrap(); 476 477 assert_eq!(oauth_request.did, "did:plc:d5c1ed6d01421a67b96f68fa"); 478 assert_eq!(oauth_request.dpop_jwk.as_ref(), &dpop_jwk); 479 480 let res = oauth_request_remove(&pool, "oauth_state").await; 481 assert!(!res.is_err()); 482 483 { 484 let oauth_request = oauth_request_get(&pool, "oauth_state").await; 485 assert!(oauth_request.is_err()); 486 } 487 488 Ok(()) 489 } 490 491 #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 492 async fn test_oauth_session(pool: PgPool) -> anyhow::Result<()> { 493 let dpop_jwk = jose::jwk::generate(); 494 495 let session_group = ulid::Ulid::new().to_string(); 496 let now = chrono::Utc::now(); 497 498 let insert_session_res = oauth_session_insert( 499 &pool, 500 OAuthSessionParams { 501 session_group: session_group.clone().into(), 502 access_token: "access_token".to_string().into(), 503 did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(), 504 issuer: "pds.examplepds.com".to_string().into(), 505 refresh_token: "refresh_token".to_string().into(), 506 secret_jwk_id: "secret_jwk_id".to_string().into(), 507 dpop_jwk: dpop_jwk.clone(), 508 created_at: now, 509 access_token_expires_at: now + chrono::Duration::seconds(60 as i64), 510 }, 511 ) 512 .await; 513 514 assert!(!insert_session_res.is_err()); 515 516 let web_session = web_session_lookup( 517 &pool, 518 &session_group, 519 Some("did:plc:d5c1ed6d01421a67b96f68fa"), 520 ) 521 .await; 522 assert!(!web_session.is_err()); 523 524 Ok(()) 525 } 526}