Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

fix: oauth consolidation, include-scope improvements #4

merged opened by lewis.moe targeting main from fix/oauth-on-niche-apps
  • auth extraction should be happening in the auth crate, yes, who coulda thought
  • include: scope should actually be doing the right thing and going out and requesting stuff to expand out the perms
  • more tests!!!1!
  • more correct parsing of the #bsky-appview or whatever suffixes on did webs that come through auth
Labels

None yet.

assignee
Participants 2
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3md3bluniqt22
+4881 -3339
Diff #2
-77
.sqlx/query-06eb7c6e1983b6121526ba63612236391290c2e63d37d2bb1cd89ea822950a82.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n SELECT token, request_uri, provider as \"provider: SsoProviderType\",\n provider_user_id, provider_username, provider_email, created_at, expires_at\n FROM sso_pending_registration\n WHERE token = $1 AND expires_at > NOW()\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "token", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "request_uri", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "provider: SsoProviderType", 19 - "type_info": { 20 - "Custom": { 21 - "name": "sso_provider_type", 22 - "kind": { 23 - "Enum": [ 24 - "github", 25 - "discord", 26 - "google", 27 - "gitlab", 28 - "oidc" 29 - ] 30 - } 31 - } 32 - } 33 - }, 34 - { 35 - "ordinal": 3, 36 - "name": "provider_user_id", 37 - "type_info": "Text" 38 - }, 39 - { 40 - "ordinal": 4, 41 - "name": "provider_username", 42 - "type_info": "Text" 43 - }, 44 - { 45 - "ordinal": 5, 46 - "name": "provider_email", 47 - "type_info": "Text" 48 - }, 49 - { 50 - "ordinal": 6, 51 - "name": "created_at", 52 - "type_info": "Timestamptz" 53 - }, 54 - { 55 - "ordinal": 7, 56 - "name": "expires_at", 57 - "type_info": "Timestamptz" 58 - } 59 - ], 60 - "parameters": { 61 - "Left": [ 62 - "Text" 63 - ] 64 - }, 65 - "nullable": [ 66 - false, 67 - false, 68 - false, 69 - false, 70 - true, 71 - true, 72 - false, 73 - false 74 - ] 75 - }, 76 - "hash": "06eb7c6e1983b6121526ba63612236391290c2e63d37d2bb1cd89ea822950a82" 77 - }
-77
.sqlx/query-5031b96c65078d6c54954ce6e57ff9cbba4c48dd8a7546882ab5647114ffab4a.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n DELETE FROM sso_pending_registration\n WHERE token = $1 AND expires_at > NOW()\n RETURNING token, request_uri, provider as \"provider: SsoProviderType\",\n provider_user_id, provider_username, provider_email, created_at, expires_at\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "token", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "request_uri", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "provider: SsoProviderType", 19 - "type_info": { 20 - "Custom": { 21 - "name": "sso_provider_type", 22 - "kind": { 23 - "Enum": [ 24 - "github", 25 - "discord", 26 - "google", 27 - "gitlab", 28 - "oidc" 29 - ] 30 - } 31 - } 32 - } 33 - }, 34 - { 35 - "ordinal": 3, 36 - "name": "provider_user_id", 37 - "type_info": "Text" 38 - }, 39 - { 40 - "ordinal": 4, 41 - "name": "provider_username", 42 - "type_info": "Text" 43 - }, 44 - { 45 - "ordinal": 5, 46 - "name": "provider_email", 47 - "type_info": "Text" 48 - }, 49 - { 50 - "ordinal": 6, 51 - "name": "created_at", 52 - "type_info": "Timestamptz" 53 - }, 54 - { 55 - "ordinal": 7, 56 - "name": "expires_at", 57 - "type_info": "Timestamptz" 58 - } 59 - ], 60 - "parameters": { 61 - "Left": [ 62 - "Text" 63 - ] 64 - }, 65 - "nullable": [ 66 - false, 67 - false, 68 - false, 69 - false, 70 - true, 71 - true, 72 - false, 73 - false 74 - ] 75 - }, 76 - "hash": "5031b96c65078d6c54954ce6e57ff9cbba4c48dd8a7546882ab5647114ffab4a" 77 - }
-22
.sqlx/query-6258398accee69e0c5f455a3c0ecc273b3da6ef5bb4d8660adafe63d8e3cd2d4.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT email_verified FROM users WHERE email = $1 OR handle = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "email_verified", 9 - "type_info": "Bool" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text" 15 - ] 16 - }, 17 - "nullable": [ 18 - false 19 - ] 20 - }, 21 - "hash": "6258398accee69e0c5f455a3c0ecc273b3da6ef5bb4d8660adafe63d8e3cd2d4" 22 - }
-31
.sqlx/query-a4dc8fb22bd094d414c55b9da20b610f7b122b485ab0fd0d0646d68ae8e64fe6.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO external_identities (did, provider, provider_user_id, provider_username, provider_email)\n VALUES ($1, $2, $3, $4, $5)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Text", 9 - { 10 - "Custom": { 11 - "name": "sso_provider_type", 12 - "kind": { 13 - "Enum": [ 14 - "github", 15 - "discord", 16 - "google", 17 - "gitlab", 18 - "oidc" 19 - ] 20 - } 21 - } 22 - }, 23 - "Text", 24 - "Text", 25 - "Text" 26 - ] 27 - }, 28 - "nullable": [] 29 - }, 30 - "hash": "a4dc8fb22bd094d414c55b9da20b610f7b122b485ab0fd0d0646d68ae8e64fe6" 31 - }
-32
.sqlx/query-dec3a21a8e60cc8d2c5dad727750bc88f5535dedae244f7b6e4afa95769b8f1a.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_username, provider_email)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Text", 9 - "Text", 10 - { 11 - "Custom": { 12 - "name": "sso_provider_type", 13 - "kind": { 14 - "Enum": [ 15 - "github", 16 - "discord", 17 - "google", 18 - "gitlab", 19 - "oidc" 20 - ] 21 - } 22 - } 23 - }, 24 - "Text", 25 - "Text", 26 - "Text" 27 - ] 28 - }, 29 - "nullable": [] 30 - }, 31 - "hash": "dec3a21a8e60cc8d2c5dad727750bc88f5535dedae244f7b6e4afa95769b8f1a" 32 - }
+2
Cargo.lock
··· 6156 6156 dependencies = [ 6157 6157 "axum", 6158 6158 "futures", 6159 + "hickory-resolver", 6159 6160 "reqwest", 6160 6161 "serde", 6161 6162 "serde_json", 6162 6163 "tokio", 6163 6164 "tracing", 6165 + "urlencoding", 6164 6166 ] 6165 6167 6166 6168 [[package]]
+21 -12
crates/tranquil-pds/src/api/actor/preferences.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, ··· 32 32 pub struct GetPreferencesOutput { 33 33 pub preferences: Vec<Value>, 34 34 } 35 - pub async fn get_preferences( 36 - State(state): State<AppState>, 37 - auth: BearerAuthAllowDeactivated, 38 - ) -> Response { 39 - let auth_user = auth.0; 40 - let has_full_access = auth_user.permissions().has_full_access(); 41 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 35 + pub async fn get_preferences(State(state): State<AppState>, auth: RequiredAuth) -> Response { 36 + let user = match auth.0.require_user() { 37 + Ok(u) => u, 38 + Err(e) => return e.into_response(), 39 + }; 40 + if let Err(e) = user.require_not_takendown() { 41 + return e.into_response(); 42 + } 43 + let has_full_access = user.permissions().has_full_access(); 44 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&user.did).await { 42 45 Ok(Some(id)) => id, 43 46 _ => { 44 47 return ApiError::InternalError(Some("User not found".into())).into_response(); ··· 93 96 } 94 97 pub async fn put_preferences( 95 98 State(state): State<AppState>, 96 - auth: BearerAuthAllowDeactivated, 99 + auth: RequiredAuth, 97 100 Json(input): Json<PutPreferencesInput>, 98 101 ) -> Response { 99 - let auth_user = auth.0; 100 - let has_full_access = auth_user.permissions().has_full_access(); 101 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 102 + let user = match auth.0.require_user() { 103 + Ok(u) => u, 104 + Err(e) => return e.into_response(), 105 + }; 106 + if let Err(e) = user.require_not_takendown() { 107 + return e.into_response(); 108 + } 109 + let has_full_access = user.permissions().has_full_access(); 110 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&user.did).await { 102 111 Ok(Some(id)) => id, 103 112 _ => { 104 113 return ApiError::InternalError(Some("User not found".into())).into_response();
+22 -18
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::AppState; 5 5 use crate::types::Did; 6 6 use axum::{ ··· 18 18 19 19 pub async fn delete_account( 20 20 State(state): State<AppState>, 21 - _auth: BearerAuthAdmin, 21 + auth: RequiredAuth, 22 22 Json(input): Json<DeleteAccountInput>, 23 - ) -> Response { 23 + ) -> Result<Response, ApiError> { 24 + auth.0.require_user()?.require_active()?.require_admin()?; 25 + 24 26 let did = &input.did; 25 - let (user_id, handle) = match state.user_repo.get_id_and_handle_by_did(did).await { 26 - Ok(Some(row)) => (row.id, row.handle), 27 - Ok(None) => { 28 - return ApiError::AccountNotFound.into_response(); 29 - } 30 - Err(e) => { 27 + let (user_id, handle) = state 28 + .user_repo 29 + .get_id_and_handle_by_did(did) 30 + .await 31 + .map_err(|e| { 31 32 error!("DB error in delete_account: {:?}", e); 32 - return ApiError::InternalError(None).into_response(); 33 - } 34 - }; 35 - if let Err(e) = state 33 + ApiError::InternalError(None) 34 + })? 35 + .ok_or(ApiError::AccountNotFound) 36 + .map(|row| (row.id, row.handle))?; 37 + 38 + state 36 39 .user_repo 37 40 .admin_delete_account_complete(user_id, did) 38 41 .await 39 - { 40 - error!("Failed to delete account {}: {:?}", did, e); 41 - return ApiError::InternalError(Some("Failed to delete account".into())).into_response(); 42 - } 42 + .map_err(|e| { 43 + error!("Failed to delete account {}: {:?}", did, e); 44 + ApiError::InternalError(Some("Failed to delete account".into())) 45 + })?; 46 + 43 47 if let Err(e) = 44 48 crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await 45 49 { ··· 49 53 ); 50 54 } 51 55 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 52 - EmptyResponse::ok().into_response() 56 + Ok(EmptyResponse::ok().into_response()) 53 57 }
+18 -21
crates/tranquil-pds/src/api/admin/account/email.rs
··· 1 1 use crate::api::error::{ApiError, AtpJson}; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use crate::types::Did; 5 5 use axum::{ ··· 28 28 29 29 pub async fn send_email( 30 30 State(state): State<AppState>, 31 - _auth: BearerAuthAdmin, 31 + auth: RequiredAuth, 32 32 AtpJson(input): AtpJson<SendEmailInput>, 33 - ) -> Response { 33 + ) -> Result<Response, ApiError> { 34 + auth.0.require_user()?.require_active()?.require_admin()?; 35 + 34 36 let content = input.content.trim(); 35 37 if content.is_empty() { 36 - return ApiError::InvalidRequest("content is required".into()).into_response(); 38 + return Err(ApiError::InvalidRequest("content is required".into())); 37 39 } 38 - let user = match state.user_repo.get_by_did(&input.recipient_did).await { 39 - Ok(Some(row)) => row, 40 - Ok(None) => { 41 - return ApiError::AccountNotFound.into_response(); 42 - } 43 - Err(e) => { 40 + let user = state 41 + .user_repo 42 + .get_by_did(&input.recipient_did) 43 + .await 44 + .map_err(|e| { 44 45 error!("DB error in send_email: {:?}", e); 45 - return ApiError::InternalError(None).into_response(); 46 - } 47 - }; 48 - let email = match user.email { 49 - Some(e) => e, 50 - None => { 51 - return ApiError::NoEmail.into_response(); 52 - } 53 - }; 46 + ApiError::InternalError(None) 47 + })? 48 + .ok_or(ApiError::AccountNotFound)?; 49 + 50 + let email = user.email.ok_or(ApiError::NoEmail)?; 54 51 let (user_id, handle) = (user.id, user.handle); 55 52 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 56 53 let subject = input ··· 76 73 handle, 77 74 input.recipient_did 78 75 ); 79 - (StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response() 76 + Ok((StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()) 80 77 } 81 78 Err(e) => { 82 79 warn!("Failed to enqueue admin email: {:?}", e); 83 - (StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response() 80 + Ok((StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()) 84 81 } 85 82 } 86 83 }
+22 -24
crates/tranquil-pds/src/api/admin/account/info.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; 5 5 use axum::{ ··· 67 67 68 68 pub async fn get_account_info( 69 69 State(state): State<AppState>, 70 - _auth: BearerAuthAdmin, 70 + auth: RequiredAuth, 71 71 Query(params): Query<GetAccountInfoParams>, 72 - ) -> Response { 73 - let account = match state 72 + ) -> Result<Response, ApiError> { 73 + auth.0.require_user()?.require_active()?.require_admin()?; 74 + 75 + let account = state 74 76 .infra_repo 75 77 .get_admin_account_info_by_did(&params.did) 76 78 .await 77 - { 78 - Ok(Some(a)) => a, 79 - Ok(None) => return ApiError::AccountNotFound.into_response(), 80 - Err(e) => { 79 + .map_err(|e| { 81 80 error!("DB error in get_account_info: {:?}", e); 82 - return ApiError::InternalError(None).into_response(); 83 - } 84 - }; 81 + ApiError::InternalError(None) 82 + })? 83 + .ok_or(ApiError::AccountNotFound)?; 85 84 86 85 let invited_by = get_invited_by(&state, account.id).await; 87 86 let invites = get_invites_for_user(&state, account.id).await; 88 87 89 - ( 88 + Ok(( 90 89 StatusCode::OK, 91 90 Json(AccountInfo { 92 91 did: account.did, ··· 105 104 invites, 106 105 }), 107 106 ) 108 - .into_response() 107 + .into_response()) 109 108 } 110 109 111 110 async fn get_invited_by(state: &AppState, user_id: uuid::Uuid) -> Option<InviteCodeInfo> { ··· 200 199 201 200 pub async fn get_account_infos( 202 201 State(state): State<AppState>, 203 - _auth: BearerAuthAdmin, 202 + auth: RequiredAuth, 204 203 RawQuery(raw_query): RawQuery, 205 - ) -> Response { 204 + ) -> Result<Response, ApiError> { 205 + auth.0.require_user()?.require_active()?.require_admin()?; 206 + 206 207 let dids: Vec<String> = crate::util::parse_repeated_query_param(raw_query.as_deref(), "dids") 207 208 .into_iter() 208 209 .filter(|d| !d.is_empty()) 209 210 .collect(); 210 211 211 212 if dids.is_empty() { 212 - return ApiError::InvalidRequest("dids is required".into()).into_response(); 213 + return Err(ApiError::InvalidRequest("dids is required".into())); 213 214 } 214 215 215 216 let dids_typed: Vec<Did> = dids.iter().filter_map(|d| d.parse().ok()).collect(); 216 - let accounts = match state 217 + let accounts = state 217 218 .infra_repo 218 219 .get_admin_account_infos_by_dids(&dids_typed) 219 220 .await 220 - { 221 - Ok(accounts) => accounts, 222 - Err(e) => { 221 + .map_err(|e| { 223 222 error!("Failed to fetch account infos: {:?}", e); 224 - return ApiError::InternalError(None).into_response(); 225 - } 226 - }; 223 + ApiError::InternalError(None) 224 + })?; 227 225 228 226 let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect(); 229 227 ··· 316 314 }) 317 315 .collect(); 318 316 319 - (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() 317 + Ok((StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()) 320 318 }
+41 -42
crates/tranquil-pds/src/api/admin/account/search.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; 5 5 use axum::{ ··· 50 50 51 51 pub async fn search_accounts( 52 52 State(state): State<AppState>, 53 - _auth: BearerAuthAdmin, 53 + auth: RequiredAuth, 54 54 Query(params): Query<SearchAccountsParams>, 55 - ) -> Response { 55 + ) -> Result<Response, ApiError> { 56 + auth.0.require_user()?.require_active()?.require_admin()?; 57 + 56 58 let limit = params.limit.clamp(1, 100); 57 59 let email_filter = params.email.as_deref().map(|e| format!("%{}%", e)); 58 60 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 59 61 let cursor_did: Option<Did> = params.cursor.as_ref().and_then(|c| c.parse().ok()); 60 - let result = state 62 + let rows = state 61 63 .user_repo 62 64 .search_accounts( 63 65 cursor_did.as_ref(), ··· 65 67 handle_filter.as_deref(), 66 68 limit + 1, 67 69 ) 68 - .await; 69 - match result { 70 - Ok(rows) => { 71 - let has_more = rows.len() > limit as usize; 72 - let accounts: Vec<AccountView> = rows 73 - .into_iter() 74 - .take(limit as usize) 75 - .map(|row| AccountView { 76 - did: row.did.clone(), 77 - handle: row.handle, 78 - email: row.email, 79 - indexed_at: row.created_at.to_rfc3339(), 80 - email_confirmed_at: if row.email_verified { 81 - Some(row.created_at.to_rfc3339()) 82 - } else { 83 - None 84 - }, 85 - deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 86 - invites_disabled: row.invites_disabled, 87 - }) 88 - .collect(); 89 - let next_cursor = if has_more { 90 - accounts.last().map(|a| a.did.to_string()) 70 + .await 71 + .map_err(|e| { 72 + error!("DB error in search_accounts: {:?}", e); 73 + ApiError::InternalError(None) 74 + })?; 75 + 76 + let has_more = rows.len() > limit as usize; 77 + let accounts: Vec<AccountView> = rows 78 + .into_iter() 79 + .take(limit as usize) 80 + .map(|row| AccountView { 81 + did: row.did.clone(), 82 + handle: row.handle, 83 + email: row.email, 84 + indexed_at: row.created_at.to_rfc3339(), 85 + email_confirmed_at: if row.email_verified { 86 + Some(row.created_at.to_rfc3339()) 91 87 } else { 92 88 None 93 - }; 94 - ( 95 - StatusCode::OK, 96 - Json(SearchAccountsOutput { 97 - cursor: next_cursor, 98 - accounts, 99 - }), 100 - ) 101 - .into_response() 102 - } 103 - Err(e) => { 104 - error!("DB error in search_accounts: {:?}", e); 105 - ApiError::InternalError(None).into_response() 106 - } 107 - } 89 + }, 90 + deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 91 + invites_disabled: row.invites_disabled, 92 + }) 93 + .collect(); 94 + let next_cursor = if has_more { 95 + accounts.last().map(|a| a.did.to_string()) 96 + } else { 97 + None 98 + }; 99 + Ok(( 100 + StatusCode::OK, 101 + Json(SearchAccountsOutput { 102 + cursor: next_cursor, 103 + accounts, 104 + }), 105 + ) 106 + .into_response()) 108 107 }
+45 -36
crates/tranquil-pds/src/api/admin/account/update.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::AppState; 5 5 use crate::types::{Did, Handle, PlainPassword}; 6 6 use axum::{ ··· 19 19 20 20 pub async fn update_account_email( 21 21 State(state): State<AppState>, 22 - _auth: BearerAuthAdmin, 22 + auth: RequiredAuth, 23 23 Json(input): Json<UpdateAccountEmailInput>, 24 - ) -> Response { 24 + ) -> Result<Response, ApiError> { 25 + auth.0.require_user()?.require_active()?.require_admin()?; 26 + 25 27 let account = input.account.trim(); 26 28 let email = input.email.trim(); 27 29 if account.is_empty() || email.is_empty() { 28 - return ApiError::InvalidRequest("account and email are required".into()).into_response(); 30 + return Err(ApiError::InvalidRequest( 31 + "account and email are required".into(), 32 + )); 29 33 } 30 - let account_did: Did = match account.parse() { 31 - Ok(d) => d, 32 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 33 - }; 34 + let account_did: Did = account 35 + .parse() 36 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 37 + 34 38 match state 35 39 .user_repo 36 40 .admin_update_email(&account_did, email) 37 41 .await 38 42 { 39 - Ok(0) => ApiError::AccountNotFound.into_response(), 40 - Ok(_) => EmptyResponse::ok().into_response(), 43 + Ok(0) => Err(ApiError::AccountNotFound), 44 + Ok(_) => Ok(EmptyResponse::ok().into_response()), 41 45 Err(e) => { 42 46 error!("DB error updating email: {:?}", e); 43 - ApiError::InternalError(None).into_response() 47 + Err(ApiError::InternalError(None)) 44 48 } 45 49 } 46 50 } ··· 53 57 54 58 pub async fn update_account_handle( 55 59 State(state): State<AppState>, 56 - _auth: BearerAuthAdmin, 60 + auth: RequiredAuth, 57 61 Json(input): Json<UpdateAccountHandleInput>, 58 - ) -> Response { 62 + ) -> Result<Response, ApiError> { 63 + auth.0.require_user()?.require_active()?.require_admin()?; 64 + 59 65 let did = &input.did; 60 66 let input_handle = input.handle.trim(); 61 67 if input_handle.is_empty() { 62 - return ApiError::InvalidRequest("handle is required".into()).into_response(); 68 + return Err(ApiError::InvalidRequest("handle is required".into())); 63 69 } 64 70 if !input_handle 65 71 .chars() 66 72 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 67 73 { 68 - return ApiError::InvalidHandle(None).into_response(); 74 + return Err(ApiError::InvalidHandle(None)); 69 75 } 70 76 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 71 77 let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); ··· 75 81 input_handle.to_string() 76 82 }; 77 83 let old_handle = state.user_repo.get_handle_by_did(did).await.ok().flatten(); 78 - let user_id = match state.user_repo.get_id_by_did(did).await { 79 - Ok(Some(id)) => id, 80 - _ => return ApiError::AccountNotFound.into_response(), 81 - }; 84 + let user_id = state 85 + .user_repo 86 + .get_id_by_did(did) 87 + .await 88 + .ok() 89 + .flatten() 90 + .ok_or(ApiError::AccountNotFound)?; 82 91 let handle_for_check = Handle::new_unchecked(&handle); 83 92 if let Ok(true) = state 84 93 .user_repo 85 94 .check_handle_exists(&handle_for_check, user_id) 86 95 .await 87 96 { 88 - return ApiError::HandleTaken.into_response(); 97 + return Err(ApiError::HandleTaken); 89 98 } 90 99 match state 91 100 .user_repo 92 101 .admin_update_handle(did, &handle_for_check) 93 102 .await 94 103 { 95 - Ok(0) => ApiError::AccountNotFound.into_response(), 104 + Ok(0) => Err(ApiError::AccountNotFound), 96 105 Ok(_) => { 97 106 if let Some(old) = old_handle { 98 107 let _ = state.cache.delete(&format!("handle:{}", old)).await; ··· 115 124 { 116 125 warn!("Failed to update PLC handle for admin handle update: {}", e); 117 126 } 118 - EmptyResponse::ok().into_response() 127 + Ok(EmptyResponse::ok().into_response()) 119 128 } 120 129 Err(e) => { 121 130 error!("DB error updating handle: {:?}", e); 122 - ApiError::InternalError(None).into_response() 131 + Err(ApiError::InternalError(None)) 123 132 } 124 133 } 125 134 } ··· 132 141 133 142 pub async fn update_account_password( 134 143 State(state): State<AppState>, 135 - _auth: BearerAuthAdmin, 144 + auth: RequiredAuth, 136 145 Json(input): Json<UpdateAccountPasswordInput>, 137 - ) -> Response { 146 + ) -> Result<Response, ApiError> { 147 + auth.0.require_user()?.require_active()?.require_admin()?; 148 + 138 149 let did = &input.did; 139 150 let password = input.password.trim(); 140 151 if password.is_empty() { 141 - return ApiError::InvalidRequest("password is required".into()).into_response(); 152 + return Err(ApiError::InvalidRequest("password is required".into())); 142 153 } 143 - let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) { 144 - Ok(h) => h, 145 - Err(e) => { 146 - error!("Failed to hash password: {:?}", e); 147 - return ApiError::InternalError(None).into_response(); 148 - } 149 - }; 154 + let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST).map_err(|e| { 155 + error!("Failed to hash password: {:?}", e); 156 + ApiError::InternalError(None) 157 + })?; 158 + 150 159 match state 151 160 .user_repo 152 161 .admin_update_password(did, &password_hash) 153 162 .await 154 163 { 155 - Ok(0) => ApiError::AccountNotFound.into_response(), 156 - Ok(_) => EmptyResponse::ok().into_response(), 164 + Ok(0) => Err(ApiError::AccountNotFound), 165 + Ok(_) => Ok(EmptyResponse::ok().into_response()), 157 166 Err(e) => { 158 167 error!("DB error updating password: {:?}", e); 159 - ApiError::InternalError(None).into_response() 168 + Err(ApiError::InternalError(None)) 160 169 } 161 170 } 162 171 }
+4 -2
crates/tranquil-pds/src/api/admin/config.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use axum::{Json, extract::State}; 5 5 use serde::{Deserialize, Serialize}; ··· 78 78 79 79 pub async fn update_server_config( 80 80 State(state): State<AppState>, 81 - _admin: BearerAuthAdmin, 81 + auth: RequiredAuth, 82 82 Json(req): Json<UpdateServerConfigRequest>, 83 83 ) -> Result<Json<UpdateServerConfigResponse>, ApiError> { 84 + auth.0.require_user()?.require_active()?.require_admin()?; 85 + 84 86 if let Some(server_name) = req.server_name { 85 87 let trimmed = server_name.trim(); 86 88 if trimmed.is_empty() || trimmed.len() > 100 {
+40 -35
crates/tranquil-pds/src/api/admin/invite.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 Json, ··· 21 21 22 22 pub async fn disable_invite_codes( 23 23 State(state): State<AppState>, 24 - _auth: BearerAuthAdmin, 24 + auth: RequiredAuth, 25 25 Json(input): Json<DisableInviteCodesInput>, 26 - ) -> Response { 26 + ) -> Result<Response, ApiError> { 27 + auth.0.require_user()?.require_active()?.require_admin()?; 28 + 27 29 if let Some(codes) = &input.codes 28 30 && let Err(e) = state.infra_repo.disable_invite_codes_by_code(codes).await 29 31 { ··· 40 42 error!("DB error disabling invite codes by account: {:?}", e); 41 43 } 42 44 } 43 - EmptyResponse::ok().into_response() 45 + Ok(EmptyResponse::ok().into_response()) 44 46 } 45 47 46 48 #[derive(Deserialize)] ··· 78 80 79 81 pub async fn get_invite_codes( 80 82 State(state): State<AppState>, 81 - _auth: BearerAuthAdmin, 83 + auth: RequiredAuth, 82 84 Query(params): Query<GetInviteCodesParams>, 83 - ) -> Response { 85 + ) -> Result<Response, ApiError> { 86 + auth.0.require_user()?.require_active()?.require_admin()?; 87 + 84 88 let limit = params.limit.unwrap_or(100).clamp(1, 500); 85 89 let sort_order = match params.sort.as_deref() { 86 90 Some("usage") => InviteCodeSortOrder::Usage, 87 91 _ => InviteCodeSortOrder::Recent, 88 92 }; 89 93 90 - let codes_rows = match state 94 + let codes_rows = state 91 95 .infra_repo 92 96 .list_invite_codes(params.cursor.as_deref(), limit, sort_order) 93 97 .await 94 - { 95 - Ok(rows) => rows, 96 - Err(e) => { 98 + .map_err(|e| { 97 99 error!("DB error fetching invite codes: {:?}", e); 98 - return ApiError::InternalError(None).into_response(); 99 - } 100 - }; 100 + ApiError::InternalError(None) 101 + })?; 101 102 102 103 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect(); 103 104 let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect(); ··· 155 156 } else { 156 157 None 157 158 }; 158 - ( 159 + Ok(( 159 160 StatusCode::OK, 160 161 Json(GetInviteCodesOutput { 161 162 cursor: next_cursor, 162 163 codes, 163 164 }), 164 165 ) 165 - .into_response() 166 + .into_response()) 166 167 } 167 168 168 169 #[derive(Deserialize)] ··· 172 173 173 174 pub async fn disable_account_invites( 174 175 State(state): State<AppState>, 175 - _auth: BearerAuthAdmin, 176 + auth: RequiredAuth, 176 177 Json(input): Json<DisableAccountInvitesInput>, 177 - ) -> Response { 178 + ) -> Result<Response, ApiError> { 179 + auth.0.require_user()?.require_active()?.require_admin()?; 180 + 178 181 let account = input.account.trim(); 179 182 if account.is_empty() { 180 - return ApiError::InvalidRequest("account is required".into()).into_response(); 183 + return Err(ApiError::InvalidRequest("account is required".into())); 181 184 } 182 - let account_did: tranquil_types::Did = match account.parse() { 183 - Ok(d) => d, 184 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 185 - }; 185 + let account_did: tranquil_types::Did = account 186 + .parse() 187 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 188 + 186 189 match state 187 190 .user_repo 188 191 .set_invites_disabled(&account_did, true) 189 192 .await 190 193 { 191 - Ok(true) => EmptyResponse::ok().into_response(), 192 - Ok(false) => ApiError::AccountNotFound.into_response(), 194 + Ok(true) => Ok(EmptyResponse::ok().into_response()), 195 + Ok(false) => Err(ApiError::AccountNotFound), 193 196 Err(e) => { 194 197 error!("DB error disabling account invites: {:?}", e); 195 - ApiError::InternalError(None).into_response() 198 + Err(ApiError::InternalError(None)) 196 199 } 197 200 } 198 201 } ··· 204 207 205 208 pub async fn enable_account_invites( 206 209 State(state): State<AppState>, 207 - _auth: BearerAuthAdmin, 210 + auth: RequiredAuth, 208 211 Json(input): Json<EnableAccountInvitesInput>, 209 - ) -> Response { 212 + ) -> Result<Response, ApiError> { 213 + auth.0.require_user()?.require_active()?.require_admin()?; 214 + 210 215 let account = input.account.trim(); 211 216 if account.is_empty() { 212 - return ApiError::InvalidRequest("account is required".into()).into_response(); 217 + return Err(ApiError::InvalidRequest("account is required".into())); 213 218 } 214 - let account_did: tranquil_types::Did = match account.parse() { 215 - Ok(d) => d, 216 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 217 - }; 219 + let account_did: tranquil_types::Did = account 220 + .parse() 221 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 222 + 218 223 match state 219 224 .user_repo 220 225 .set_invites_disabled(&account_did, false) 221 226 .await 222 227 { 223 - Ok(true) => EmptyResponse::ok().into_response(), 224 - Ok(false) => ApiError::AccountNotFound.into_response(), 228 + Ok(true) => Ok(EmptyResponse::ok().into_response()), 229 + Ok(false) => Err(ApiError::AccountNotFound), 225 230 Err(e) => { 226 231 error!("DB error enabling account invites: {:?}", e); 227 - ApiError::InternalError(None).into_response() 232 + Err(ApiError::InternalError(None)) 228 233 } 229 234 } 230 235 }
+10 -4
crates/tranquil-pds/src/api/admin/server_stats.rs
··· 1 - use crate::auth::BearerAuthAdmin; 1 + use crate::api::error::ApiError; 2 + use crate::auth::RequiredAuth; 2 3 use crate::state::AppState; 3 4 use axum::{ 4 5 Json, ··· 16 17 pub blob_storage_bytes: i64, 17 18 } 18 19 19 - pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response { 20 + pub async fn get_server_stats( 21 + State(state): State<AppState>, 22 + auth: RequiredAuth, 23 + ) -> Result<Response, ApiError> { 24 + auth.0.require_user()?.require_active()?.require_admin()?; 25 + 20 26 let user_count = state.user_repo.count_users().await.unwrap_or(0); 21 27 let repo_count = state.repo_repo.count_repos().await.unwrap_or(0); 22 28 let record_count = state.repo_repo.count_all_records().await.unwrap_or(0); 23 29 let blob_storage_bytes = state.blob_repo.sum_blob_storage().await.unwrap_or(0); 24 30 25 - Json(ServerStatsResponse { 31 + Ok(Json(ServerStatsResponse { 26 32 user_count, 27 33 repo_count, 28 34 record_count, 29 35 blob_storage_bytes, 30 36 }) 31 - .into_response() 37 + .into_response()) 32 38 }
+77 -94
crates/tranquil-pds/src/api/admin/status.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use crate::types::{CidLink, Did}; 5 5 use axum::{ ··· 35 35 36 36 pub async fn get_subject_status( 37 37 State(state): State<AppState>, 38 - _auth: BearerAuthAdmin, 38 + auth: RequiredAuth, 39 39 Query(params): Query<GetSubjectStatusParams>, 40 - ) -> Response { 40 + ) -> Result<Response, ApiError> { 41 + auth.0.require_user()?.require_active()?.require_admin()?; 42 + 41 43 if params.did.is_none() && params.uri.is_none() && params.blob.is_none() { 42 - return ApiError::InvalidRequest("Must provide did, uri, or blob".into()).into_response(); 44 + return Err(ApiError::InvalidRequest( 45 + "Must provide did, uri, or blob".into(), 46 + )); 43 47 } 44 48 if let Some(did_str) = &params.did { 45 - let did: Did = match did_str.parse() { 46 - Ok(d) => d, 47 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 48 - }; 49 + let did: Did = did_str 50 + .parse() 51 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 49 52 match state.user_repo.get_status_by_did(&did).await { 50 53 Ok(Some(status)) => { 51 54 let deactivated = status.deactivated_at.map(|_| StatusAttr { ··· 56 59 applied: true, 57 60 r#ref: Some(r.clone()), 58 61 }); 59 - return ( 62 + return Ok(( 60 63 StatusCode::OK, 61 64 Json(SubjectStatus { 62 65 subject: json!({ ··· 67 70 deactivated, 68 71 }), 69 72 ) 70 - .into_response(); 73 + .into_response()); 71 74 } 72 75 Ok(None) => { 73 - return ApiError::SubjectNotFound.into_response(); 76 + return Err(ApiError::SubjectNotFound); 74 77 } 75 78 Err(e) => { 76 79 error!("DB error in get_subject_status: {:?}", e); 77 - return ApiError::InternalError(None).into_response(); 80 + return Err(ApiError::InternalError(None)); 78 81 } 79 82 } 80 83 } 81 84 if let Some(uri_str) = &params.uri { 82 - let cid: CidLink = match uri_str.parse() { 83 - Ok(c) => c, 84 - Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 85 - }; 85 + let cid: CidLink = uri_str 86 + .parse() 87 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 86 88 match state.repo_repo.get_record_by_cid(&cid).await { 87 89 Ok(Some(record)) => { 88 90 let takedown = record.takedown_ref.as_ref().map(|r| StatusAttr { 89 91 applied: true, 90 92 r#ref: Some(r.clone()), 91 93 }); 92 - return ( 94 + return Ok(( 93 95 StatusCode::OK, 94 96 Json(SubjectStatus { 95 97 subject: json!({ ··· 101 103 deactivated: None, 102 104 }), 103 105 ) 104 - .into_response(); 106 + .into_response()); 105 107 } 106 108 Ok(None) => { 107 - return ApiError::RecordNotFound.into_response(); 109 + return Err(ApiError::RecordNotFound); 108 110 } 109 111 Err(e) => { 110 112 error!("DB error in get_subject_status: {:?}", e); 111 - return ApiError::InternalError(None).into_response(); 113 + return Err(ApiError::InternalError(None)); 112 114 } 113 115 } 114 116 } 115 117 if let Some(blob_cid_str) = &params.blob { 116 - let blob_cid: CidLink = match blob_cid_str.parse() { 117 - Ok(c) => c, 118 - Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 119 - }; 120 - let did = match &params.did { 121 - Some(d) => d, 122 - None => { 123 - return ApiError::InvalidRequest("Must provide a did to request blob state".into()) 124 - .into_response(); 125 - } 126 - }; 118 + let blob_cid: CidLink = blob_cid_str 119 + .parse() 120 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 121 + let did = params.did.as_ref().ok_or_else(|| { 122 + ApiError::InvalidRequest("Must provide a did to request blob state".into()) 123 + })?; 127 124 match state.blob_repo.get_blob_with_takedown(&blob_cid).await { 128 125 Ok(Some(blob)) => { 129 126 let takedown = blob.takedown_ref.as_ref().map(|r| StatusAttr { 130 127 applied: true, 131 128 r#ref: Some(r.clone()), 132 129 }); 133 - return ( 130 + return Ok(( 134 131 StatusCode::OK, 135 132 Json(SubjectStatus { 136 133 subject: json!({ ··· 142 139 deactivated: None, 143 140 }), 144 141 ) 145 - .into_response(); 142 + .into_response()); 146 143 } 147 144 Ok(None) => { 148 - return ApiError::BlobNotFound(None).into_response(); 145 + return Err(ApiError::BlobNotFound(None)); 149 146 } 150 147 Err(e) => { 151 148 error!("DB error in get_subject_status: {:?}", e); 152 - return ApiError::InternalError(None).into_response(); 149 + return Err(ApiError::InternalError(None)); 153 150 } 154 151 } 155 152 } 156 - ApiError::InvalidRequest("Invalid subject type".into()).into_response() 153 + Err(ApiError::InvalidRequest("Invalid subject type".into())) 157 154 } 158 155 159 156 #[derive(Deserialize)] ··· 172 169 173 170 pub async fn update_subject_status( 174 171 State(state): State<AppState>, 175 - _auth: BearerAuthAdmin, 172 + auth: RequiredAuth, 176 173 Json(input): Json<UpdateSubjectStatusInput>, 177 - ) -> Response { 174 + ) -> Result<Response, ApiError> { 175 + auth.0.require_user()?.require_active()?.require_admin()?; 176 + 178 177 let subject_type = input.subject.get("$type").and_then(|t| t.as_str()); 179 178 match subject_type { 180 179 Some("com.atproto.admin.defs#repoRef") => { ··· 187 186 } else { 188 187 None 189 188 }; 190 - if let Err(e) = state.user_repo.set_user_takedown(&did, takedown_ref).await { 191 - error!("Failed to update user takedown status for {}: {:?}", did, e); 192 - return ApiError::InternalError(Some( 193 - "Failed to update takedown status".into(), 194 - )) 195 - .into_response(); 196 - } 189 + state 190 + .user_repo 191 + .set_user_takedown(&did, takedown_ref) 192 + .await 193 + .map_err(|e| { 194 + error!("Failed to update user takedown status for {}: {:?}", did, e); 195 + ApiError::InternalError(Some("Failed to update takedown status".into())) 196 + })?; 197 197 } 198 198 if let Some(deactivated) = &input.deactivated { 199 199 let result = if deactivated.applied { ··· 201 201 } else { 202 202 state.user_repo.activate_account(&did).await 203 203 }; 204 - if let Err(e) = result { 204 + result.map_err(|e| { 205 205 error!( 206 206 "Failed to update user deactivation status for {}: {:?}", 207 207 did, e 208 208 ); 209 - return ApiError::InternalError(Some( 210 - "Failed to update deactivation status".into(), 211 - )) 212 - .into_response(); 213 - } 209 + ApiError::InternalError(Some("Failed to update deactivation status".into())) 210 + })?; 214 211 } 215 212 if let Some(takedown) = &input.takedown { 216 213 let status = if takedown.applied { ··· 249 246 if let Ok(Some(handle)) = state.user_repo.get_handle_by_did(&did).await { 250 247 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 251 248 } 252 - return ( 249 + return Ok(( 253 250 StatusCode::OK, 254 251 Json(json!({ 255 252 "subject": input.subject, ··· 262 259 })) 263 260 })), 264 261 ) 265 - .into_response(); 262 + .into_response()); 266 263 } 267 264 } 268 265 Some("com.atproto.repo.strongRef") => { 269 266 let uri_str = input.subject.get("uri").and_then(|u| u.as_str()); 270 267 if let Some(uri_str) = uri_str { 271 - let cid: CidLink = match uri_str.parse() { 272 - Ok(c) => c, 273 - Err(_) => { 274 - return ApiError::InvalidRequest("Invalid CID format".into()) 275 - .into_response(); 276 - } 277 - }; 268 + let cid: CidLink = uri_str 269 + .parse() 270 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 278 271 if let Some(takedown) = &input.takedown { 279 272 let takedown_ref = if takedown.applied { 280 273 takedown.r#ref.as_deref() 281 274 } else { 282 275 None 283 276 }; 284 - if let Err(e) = state 277 + state 285 278 .repo_repo 286 279 .set_record_takedown(&cid, takedown_ref) 287 280 .await 288 - { 289 - error!( 290 - "Failed to update record takedown status for {}: {:?}", 291 - uri_str, e 292 - ); 293 - return ApiError::InternalError(Some( 294 - "Failed to update takedown status".into(), 295 - )) 296 - .into_response(); 297 - } 281 + .map_err(|e| { 282 + error!( 283 + "Failed to update record takedown status for {}: {:?}", 284 + uri_str, e 285 + ); 286 + ApiError::InternalError(Some("Failed to update takedown status".into())) 287 + })?; 298 288 } 299 - return ( 289 + return Ok(( 300 290 StatusCode::OK, 301 291 Json(json!({ 302 292 "subject": input.subject, ··· 306 296 })) 307 297 })), 308 298 ) 309 - .into_response(); 299 + .into_response()); 310 300 } 311 301 } 312 302 Some("com.atproto.admin.defs#repoBlobRef") => { 313 303 let cid_str = input.subject.get("cid").and_then(|c| c.as_str()); 314 304 if let Some(cid_str) = cid_str { 315 - let cid: CidLink = match cid_str.parse() { 316 - Ok(c) => c, 317 - Err(_) => { 318 - return ApiError::InvalidRequest("Invalid CID format".into()) 319 - .into_response(); 320 - } 321 - }; 305 + let cid: CidLink = cid_str 306 + .parse() 307 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 322 308 if let Some(takedown) = &input.takedown { 323 309 let takedown_ref = if takedown.applied { 324 310 takedown.r#ref.as_deref() 325 311 } else { 326 312 None 327 313 }; 328 - if let Err(e) = state 314 + state 329 315 .blob_repo 330 316 .update_blob_takedown(&cid, takedown_ref) 331 317 .await 332 - { 333 - error!( 334 - "Failed to update blob takedown status for {}: {:?}", 335 - cid_str, e 336 - ); 337 - return ApiError::InternalError(Some( 338 - "Failed to update takedown status".into(), 339 - )) 340 - .into_response(); 341 - } 318 + .map_err(|e| { 319 + error!( 320 + "Failed to update blob takedown status for {}: {:?}", 321 + cid_str, e 322 + ); 323 + ApiError::InternalError(Some("Failed to update takedown status".into())) 324 + })?; 342 325 } 343 - return ( 326 + return Ok(( 344 327 StatusCode::OK, 345 328 Json(json!({ 346 329 "subject": input.subject, ··· 350 333 })) 351 334 })), 352 335 ) 353 - .into_response(); 336 + .into_response()); 354 337 } 355 338 } 356 339 _ => {} 357 340 } 358 - ApiError::InvalidRequest("Invalid subject type".into()).into_response() 341 + Err(ApiError::InvalidRequest("Invalid subject type".into())) 359 342 }
+99 -73
crates/tranquil-pds/src/api/backup.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, EnabledResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::scheduled::generate_full_backup; 5 5 use crate::state::AppState; 6 6 use crate::storage::{BackupStorage, backup_retention_count}; ··· 35 35 pub backup_enabled: bool, 36 36 } 37 37 38 - pub async fn list_backups(State(state): State<AppState>, auth: BearerAuth) -> Response { 39 - let (user_id, backup_enabled) = 40 - match state.backup_repo.get_user_backup_status(&auth.0.did).await { 41 - Ok(Some(status)) => status, 42 - Ok(None) => { 43 - return ApiError::AccountNotFound.into_response(); 44 - } 45 - Err(e) => { 46 - error!("DB error fetching user: {:?}", e); 47 - return ApiError::InternalError(None).into_response(); 48 - } 49 - }; 38 + pub async fn list_backups( 39 + State(state): State<AppState>, 40 + auth: RequiredAuth, 41 + ) -> Result<Response, crate::api::error::ApiError> { 42 + let user = auth.0.require_user()?.require_active()?; 43 + let (user_id, backup_enabled) = match state.backup_repo.get_user_backup_status(&user.did).await 44 + { 45 + Ok(Some(status)) => status, 46 + Ok(None) => { 47 + return Ok(ApiError::AccountNotFound.into_response()); 48 + } 49 + Err(e) => { 50 + error!("DB error fetching user: {:?}", e); 51 + return Ok(ApiError::InternalError(None).into_response()); 52 + } 53 + }; 50 54 51 55 let backups = match state.backup_repo.list_backups_for_user(user_id).await { 52 56 Ok(rows) => rows, 53 57 Err(e) => { 54 58 error!("DB error fetching backups: {:?}", e); 55 - return ApiError::InternalError(None).into_response(); 59 + return Ok(ApiError::InternalError(None).into_response()); 56 60 } 57 61 }; 58 62 ··· 68 72 }) 69 73 .collect(); 70 74 71 - ( 75 + Ok(( 72 76 StatusCode::OK, 73 77 Json(ListBackupsOutput { 74 78 backups: backup_list, 75 79 backup_enabled, 76 80 }), 77 81 ) 78 - .into_response() 82 + .into_response()) 79 83 } 80 84 81 85 #[derive(Deserialize)] ··· 85 89 86 90 pub async fn get_backup( 87 91 State(state): State<AppState>, 88 - auth: BearerAuth, 92 + auth: RequiredAuth, 89 93 Query(query): Query<GetBackupQuery>, 90 - ) -> Response { 94 + ) -> Result<Response, crate::api::error::ApiError> { 95 + let user = auth.0.require_user()?.require_active()?; 91 96 let backup_id = match uuid::Uuid::parse_str(&query.id) { 92 97 Ok(id) => id, 93 98 Err(_) => { 94 - return ApiError::InvalidRequest("Invalid backup ID".into()).into_response(); 99 + return Ok(ApiError::InvalidRequest("Invalid backup ID".into()).into_response()); 95 100 } 96 101 }; 97 102 98 103 let backup_info = match state 99 104 .backup_repo 100 - .get_backup_storage_info(backup_id, &auth.0.did) 105 + .get_backup_storage_info(backup_id, &user.did) 101 106 .await 102 107 { 103 108 Ok(Some(b)) => b, 104 109 Ok(None) => { 105 - return ApiError::BackupNotFound.into_response(); 110 + return Ok(ApiError::BackupNotFound.into_response()); 106 111 } 107 112 Err(e) => { 108 113 error!("DB error fetching backup: {:?}", e); 109 - return ApiError::InternalError(None).into_response(); 114 + return Ok(ApiError::InternalError(None).into_response()); 110 115 } 111 116 }; 112 117 113 118 let backup_storage = match state.backup_storage.as_ref() { 114 119 Some(storage) => storage, 115 120 None => { 116 - return ApiError::BackupsDisabled.into_response(); 121 + return Ok(ApiError::BackupsDisabled.into_response()); 117 122 } 118 123 }; 119 124 ··· 121 126 Ok(bytes) => bytes, 122 127 Err(e) => { 123 128 error!("Failed to fetch backup from storage: {:?}", e); 124 - return ApiError::InternalError(Some("Failed to retrieve backup".into())) 125 - .into_response(); 129 + return Ok( 130 + ApiError::InternalError(Some("Failed to retrieve backup".into())).into_response(), 131 + ); 126 132 } 127 133 }; 128 134 129 - ( 135 + Ok(( 130 136 StatusCode::OK, 131 137 [ 132 138 (axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car"), ··· 137 143 ], 138 144 car_bytes, 139 145 ) 140 - .into_response() 146 + .into_response()) 141 147 } 142 148 143 149 #[derive(Serialize)] ··· 149 155 pub block_count: i32, 150 156 } 151 157 152 - pub async fn create_backup(State(state): State<AppState>, auth: BearerAuth) -> Response { 158 + pub async fn create_backup( 159 + State(state): State<AppState>, 160 + auth: RequiredAuth, 161 + ) -> Result<Response, crate::api::error::ApiError> { 162 + let auth_user = auth.0.require_user()?.require_active()?; 153 163 let backup_storage = match state.backup_storage.as_ref() { 154 164 Some(storage) => storage, 155 165 None => { 156 - return ApiError::BackupsDisabled.into_response(); 166 + return Ok(ApiError::BackupsDisabled.into_response()); 157 167 } 158 168 }; 159 169 160 - let user = match state.backup_repo.get_user_for_backup(&auth.0.did).await { 170 + let user = match state.backup_repo.get_user_for_backup(&auth_user.did).await { 161 171 Ok(Some(u)) => u, 162 172 Ok(None) => { 163 - return ApiError::AccountNotFound.into_response(); 173 + return Ok(ApiError::AccountNotFound.into_response()); 164 174 } 165 175 Err(e) => { 166 176 error!("DB error fetching user: {:?}", e); 167 - return ApiError::InternalError(None).into_response(); 177 + return Ok(ApiError::InternalError(None).into_response()); 168 178 } 169 179 }; 170 180 171 181 if user.deactivated_at.is_some() { 172 - return ApiError::AccountDeactivated.into_response(); 182 + return Ok(ApiError::AccountDeactivated.into_response()); 173 183 } 174 184 175 185 let repo_rev = match &user.repo_rev { 176 186 Some(rev) => rev.clone(), 177 187 None => { 178 - return ApiError::RepoNotReady.into_response(); 188 + return Ok(ApiError::RepoNotReady.into_response()); 179 189 } 180 190 }; 181 191 182 192 let head_cid = match Cid::from_str(&user.repo_root_cid) { 183 193 Ok(c) => c, 184 194 Err(_) => { 185 - return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 195 + return Ok( 196 + ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(), 197 + ); 186 198 } 187 199 }; 188 200 ··· 197 209 Ok(bytes) => bytes, 198 210 Err(e) => { 199 211 error!("Failed to generate CAR: {:?}", e); 200 - return ApiError::InternalError(Some("Failed to generate backup".into())) 201 - .into_response(); 212 + return Ok( 213 + ApiError::InternalError(Some("Failed to generate backup".into())).into_response(), 214 + ); 202 215 } 203 216 }; 204 217 ··· 212 225 Ok(key) => key, 213 226 Err(e) => { 214 227 error!("Failed to upload backup: {:?}", e); 215 - return ApiError::InternalError(Some("Failed to store backup".into())).into_response(); 228 + return Ok( 229 + ApiError::InternalError(Some("Failed to store backup".into())).into_response(), 230 + ); 216 231 } 217 232 }; 218 233 ··· 238 253 "Failed to rollback orphaned backup from S3" 239 254 ); 240 255 } 241 - return ApiError::InternalError(Some("Failed to record backup".into())).into_response(); 256 + return Ok( 257 + ApiError::InternalError(Some("Failed to record backup".into())).into_response(), 258 + ); 242 259 } 243 260 }; 244 261 ··· 261 278 warn!(did = %user.did, error = %e, "Failed to cleanup old backups after manual backup"); 262 279 } 263 280 264 - ( 281 + Ok(( 265 282 StatusCode::OK, 266 283 Json(CreateBackupOutput { 267 284 id: backup_id.to_string(), ··· 270 287 block_count, 271 288 }), 272 289 ) 273 - .into_response() 290 + .into_response()) 274 291 } 275 292 276 293 async fn cleanup_old_backups( ··· 310 327 311 328 pub async fn delete_backup( 312 329 State(state): State<AppState>, 313 - auth: BearerAuth, 330 + auth: RequiredAuth, 314 331 Query(query): Query<DeleteBackupQuery>, 315 - ) -> Response { 332 + ) -> Result<Response, crate::api::error::ApiError> { 333 + let user = auth.0.require_user()?.require_active()?; 316 334 let backup_id = match uuid::Uuid::parse_str(&query.id) { 317 335 Ok(id) => id, 318 336 Err(_) => { 319 - return ApiError::InvalidRequest("Invalid backup ID".into()).into_response(); 337 + return Ok(ApiError::InvalidRequest("Invalid backup ID".into()).into_response()); 320 338 } 321 339 }; 322 340 323 341 let backup = match state 324 342 .backup_repo 325 - .get_backup_for_deletion(backup_id, &auth.0.did) 343 + .get_backup_for_deletion(backup_id, &user.did) 326 344 .await 327 345 { 328 346 Ok(Some(b)) => b, 329 347 Ok(None) => { 330 - return ApiError::BackupNotFound.into_response(); 348 + return Ok(ApiError::BackupNotFound.into_response()); 331 349 } 332 350 Err(e) => { 333 351 error!("DB error fetching backup: {:?}", e); 334 - return ApiError::InternalError(None).into_response(); 352 + return Ok(ApiError::InternalError(None).into_response()); 335 353 } 336 354 }; 337 355 338 356 if backup.deactivated_at.is_some() { 339 - return ApiError::AccountDeactivated.into_response(); 357 + return Ok(ApiError::AccountDeactivated.into_response()); 340 358 } 341 359 342 360 if let Some(backup_storage) = state.backup_storage.as_ref() ··· 351 369 352 370 if let Err(e) = state.backup_repo.delete_backup(backup.id).await { 353 371 error!("DB error deleting backup: {:?}", e); 354 - return ApiError::InternalError(Some("Failed to delete backup".into())).into_response(); 372 + return Ok(ApiError::InternalError(Some("Failed to delete backup".into())).into_response()); 355 373 } 356 374 357 - info!(did = %auth.0.did, backup_id = %backup_id, "Deleted backup"); 375 + info!(did = %user.did, backup_id = %backup_id, "Deleted backup"); 358 376 359 - EmptyResponse::ok().into_response() 377 + Ok(EmptyResponse::ok().into_response()) 360 378 } 361 379 362 380 #[derive(Deserialize)] ··· 367 385 368 386 pub async fn set_backup_enabled( 369 387 State(state): State<AppState>, 370 - auth: BearerAuth, 388 + auth: RequiredAuth, 371 389 Json(input): Json<SetBackupEnabledInput>, 372 - ) -> Response { 390 + ) -> Result<Response, crate::api::error::ApiError> { 391 + let user = auth.0.require_user()?.require_active()?; 373 392 let deactivated_at = match state 374 393 .backup_repo 375 - .get_user_deactivated_status(&auth.0.did) 394 + .get_user_deactivated_status(&user.did) 376 395 .await 377 396 { 378 397 Ok(Some(status)) => status, 379 398 Ok(None) => { 380 - return ApiError::AccountNotFound.into_response(); 399 + return Ok(ApiError::AccountNotFound.into_response()); 381 400 } 382 401 Err(e) => { 383 402 error!("DB error fetching user: {:?}", e); 384 - return ApiError::InternalError(None).into_response(); 403 + return Ok(ApiError::InternalError(None).into_response()); 385 404 } 386 405 }; 387 406 388 407 if deactivated_at.is_some() { 389 - return ApiError::AccountDeactivated.into_response(); 408 + return Ok(ApiError::AccountDeactivated.into_response()); 390 409 } 391 410 392 411 if let Err(e) = state 393 412 .backup_repo 394 - .update_backup_enabled(&auth.0.did, input.enabled) 413 + .update_backup_enabled(&user.did, input.enabled) 395 414 .await 396 415 { 397 416 error!("DB error updating backup_enabled: {:?}", e); 398 - return ApiError::InternalError(Some("Failed to update setting".into())).into_response(); 417 + return Ok( 418 + ApiError::InternalError(Some("Failed to update setting".into())).into_response(), 419 + ); 399 420 } 400 421 401 - info!(did = %auth.0.did, enabled = input.enabled, "Updated backup_enabled setting"); 422 + info!(did = %user.did, enabled = input.enabled, "Updated backup_enabled setting"); 402 423 403 - EnabledResponse::response(input.enabled).into_response() 424 + Ok(EnabledResponse::response(input.enabled).into_response()) 404 425 } 405 426 406 - pub async fn export_blobs(State(state): State<AppState>, auth: BearerAuth) -> Response { 407 - let user_id = match state.backup_repo.get_user_id_by_did(&auth.0.did).await { 427 + pub async fn export_blobs( 428 + State(state): State<AppState>, 429 + auth: RequiredAuth, 430 + ) -> Result<Response, crate::api::error::ApiError> { 431 + let user = auth.0.require_user()?.require_active()?; 432 + let user_id = match state.backup_repo.get_user_id_by_did(&user.did).await { 408 433 Ok(Some(id)) => id, 409 434 Ok(None) => { 410 - return ApiError::AccountNotFound.into_response(); 435 + return Ok(ApiError::AccountNotFound.into_response()); 411 436 } 412 437 Err(e) => { 413 438 error!("DB error fetching user: {:?}", e); 414 - return ApiError::InternalError(None).into_response(); 439 + return Ok(ApiError::InternalError(None).into_response()); 415 440 } 416 441 }; 417 442 ··· 419 444 Ok(rows) => rows, 420 445 Err(e) => { 421 446 error!("DB error fetching blobs: {:?}", e); 422 - return ApiError::InternalError(None).into_response(); 447 + return Ok(ApiError::InternalError(None).into_response()); 423 448 } 424 449 }; 425 450 426 451 if blobs.is_empty() { 427 - return ( 452 + return Ok(( 428 453 StatusCode::OK, 429 454 [ 430 455 (axum::http::header::CONTENT_TYPE, "application/zip"), ··· 435 460 ], 436 461 Vec::<u8>::new(), 437 462 ) 438 - .into_response(); 463 + .into_response()); 439 464 } 440 465 441 466 let mut zip_buffer = std::io::Cursor::new(Vec::new()); ··· 513 538 514 539 if let Err(e) = zip.finish() { 515 540 error!("Failed to finish zip: {:?}", e); 516 - return ApiError::InternalError(Some("Failed to create zip file".into())) 517 - .into_response(); 541 + return Ok( 542 + ApiError::InternalError(Some("Failed to create zip file".into())).into_response(), 543 + ); 518 544 } 519 545 } 520 546 521 547 let zip_bytes = zip_buffer.into_inner(); 522 548 523 - info!(did = %auth.0.did, blob_count = blobs.len(), size_bytes = zip_bytes.len(), "Exported blobs"); 549 + info!(did = %user.did, blob_count = blobs.len(), size_bytes = zip_bytes.len(), "Exported blobs"); 524 550 525 - ( 551 + Ok(( 526 552 StatusCode::OK, 527 553 [ 528 554 (axum::http::header::CONTENT_TYPE, "application/zip"), ··· 533 559 ], 534 560 zip_bytes, 535 561 ) 536 - .into_response() 562 + .into_response()) 537 563 } 538 564 539 565 fn mime_to_extension(mime_type: &str) -> &'static str {
+122 -98
crates/tranquil-pds/src/api/delegation.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::create_signed_commit; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes}; 5 5 use crate::state::{AppState, RateLimitKind}; 6 6 use crate::types::{Did, Handle, Nsid, Rkey}; ··· 33 33 pub controllers: Vec<ControllerInfo>, 34 34 } 35 35 36 - pub async fn list_controllers(State(state): State<AppState>, auth: BearerAuth) -> Response { 36 + pub async fn list_controllers( 37 + State(state): State<AppState>, 38 + auth: RequiredAuth, 39 + ) -> Result<Response, ApiError> { 40 + let user = auth.0.require_user()?.require_active()?; 37 41 let controllers = match state 38 42 .delegation_repo 39 - .get_delegations_for_account(&auth.0.did) 43 + .get_delegations_for_account(&user.did) 40 44 .await 41 45 { 42 46 Ok(c) => c, 43 47 Err(e) => { 44 48 tracing::error!("Failed to list controllers: {:?}", e); 45 - return ApiError::InternalError(Some("Failed to list controllers".into())) 46 - .into_response(); 49 + return Ok( 50 + ApiError::InternalError(Some("Failed to list controllers".into())).into_response(), 51 + ); 47 52 } 48 53 }; 49 54 50 - Json(ListControllersResponse { 55 + Ok(Json(ListControllersResponse { 51 56 controllers: controllers 52 57 .into_iter() 53 58 .map(|c| ControllerInfo { ··· 59 64 }) 60 65 .collect(), 61 66 }) 62 - .into_response() 67 + .into_response()) 63 68 } 64 69 65 70 #[derive(Debug, Deserialize)] ··· 70 75 71 76 pub async fn add_controller( 72 77 State(state): State<AppState>, 73 - auth: BearerAuth, 78 + auth: RequiredAuth, 74 79 Json(input): Json<AddControllerInput>, 75 - ) -> Response { 80 + ) -> Result<Response, ApiError> { 81 + let user = auth.0.require_user()?.require_active()?; 76 82 if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 77 - return ApiError::InvalidScopes(e).into_response(); 83 + return Ok(ApiError::InvalidScopes(e).into_response()); 78 84 } 79 85 80 86 let controller_exists = state ··· 86 92 .is_some(); 87 93 88 94 if !controller_exists { 89 - return ApiError::ControllerNotFound.into_response(); 95 + return Ok(ApiError::ControllerNotFound.into_response()); 90 96 } 91 97 92 - match state 93 - .delegation_repo 94 - .controls_any_accounts(&auth.0.did) 95 - .await 96 - { 98 + match state.delegation_repo.controls_any_accounts(&user.did).await { 97 99 Ok(true) => { 98 - return ApiError::InvalidDelegation( 100 + return Ok(ApiError::InvalidDelegation( 99 101 "Cannot add controllers to an account that controls other accounts".into(), 100 102 ) 101 - .into_response(); 103 + .into_response()); 102 104 } 103 105 Err(e) => { 104 106 tracing::error!("Failed to check delegation status: {:?}", e); 105 - return ApiError::InternalError(Some("Failed to verify delegation status".into())) 106 - .into_response(); 107 + return Ok( 108 + ApiError::InternalError(Some("Failed to verify delegation status".into())) 109 + .into_response(), 110 + ); 107 111 } 108 112 Ok(false) => {} 109 113 } ··· 114 118 .await 115 119 { 116 120 Ok(true) => { 117 - return ApiError::InvalidDelegation( 121 + return Ok(ApiError::InvalidDelegation( 118 122 "Cannot add a controlled account as a controller".into(), 119 123 ) 120 - .into_response(); 124 + .into_response()); 121 125 } 122 126 Err(e) => { 123 127 tracing::error!("Failed to check controller status: {:?}", e); 124 - return ApiError::InternalError(Some("Failed to verify controller status".into())) 125 - .into_response(); 128 + return Ok( 129 + ApiError::InternalError(Some("Failed to verify controller status".into())) 130 + .into_response(), 131 + ); 126 132 } 127 133 Ok(false) => {} 128 134 } ··· 130 136 match state 131 137 .delegation_repo 132 138 .create_delegation( 133 - &auth.0.did, 139 + &user.did, 134 140 &input.controller_did, 135 141 &input.granted_scopes, 136 - &auth.0.did, 142 + &user.did, 137 143 ) 138 144 .await 139 145 { ··· 141 147 let _ = state 142 148 .delegation_repo 143 149 .log_delegation_action( 144 - &auth.0.did, 145 - &auth.0.did, 150 + &user.did, 151 + &user.did, 146 152 Some(&input.controller_did), 147 153 DelegationActionType::GrantCreated, 148 154 Some(serde_json::json!({ ··· 153 159 ) 154 160 .await; 155 161 156 - ( 162 + Ok(( 157 163 StatusCode::OK, 158 164 Json(serde_json::json!({ 159 165 "success": true 160 166 })), 161 167 ) 162 - .into_response() 168 + .into_response()) 163 169 } 164 170 Err(e) => { 165 171 tracing::error!("Failed to add controller: {:?}", e); 166 - ApiError::InternalError(Some("Failed to add controller".into())).into_response() 172 + Ok(ApiError::InternalError(Some("Failed to add controller".into())).into_response()) 167 173 } 168 174 } 169 175 } ··· 175 181 176 182 pub async fn remove_controller( 177 183 State(state): State<AppState>, 178 - auth: BearerAuth, 184 + auth: RequiredAuth, 179 185 Json(input): Json<RemoveControllerInput>, 180 - ) -> Response { 186 + ) -> Result<Response, ApiError> { 187 + let user = auth.0.require_user()?.require_active()?; 181 188 match state 182 189 .delegation_repo 183 - .revoke_delegation(&auth.0.did, &input.controller_did, &auth.0.did) 190 + .revoke_delegation(&user.did, &input.controller_did, &user.did) 184 191 .await 185 192 { 186 193 Ok(true) => { 187 194 let revoked_app_passwords = state 188 195 .session_repo 189 - .delete_app_passwords_by_controller(&auth.0.did, &input.controller_did) 196 + .delete_app_passwords_by_controller(&user.did, &input.controller_did) 190 197 .await 191 198 .unwrap_or(0) as usize; 192 199 193 200 let revoked_oauth_tokens = state 194 201 .oauth_repo 195 - .revoke_tokens_for_controller(&auth.0.did, &input.controller_did) 202 + .revoke_tokens_for_controller(&user.did, &input.controller_did) 196 203 .await 197 204 .unwrap_or(0); 198 205 199 206 let _ = state 200 207 .delegation_repo 201 208 .log_delegation_action( 202 - &auth.0.did, 203 - &auth.0.did, 209 + &user.did, 210 + &user.did, 204 211 Some(&input.controller_did), 205 212 DelegationActionType::GrantRevoked, 206 213 Some(serde_json::json!({ ··· 212 219 ) 213 220 .await; 214 221 215 - ( 222 + Ok(( 216 223 StatusCode::OK, 217 224 Json(serde_json::json!({ 218 225 "success": true 219 226 })), 220 227 ) 221 - .into_response() 228 + .into_response()) 222 229 } 223 - Ok(false) => ApiError::DelegationNotFound.into_response(), 230 + Ok(false) => Ok(ApiError::DelegationNotFound.into_response()), 224 231 Err(e) => { 225 232 tracing::error!("Failed to remove controller: {:?}", e); 226 - ApiError::InternalError(Some("Failed to remove controller".into())).into_response() 233 + Ok(ApiError::InternalError(Some("Failed to remove controller".into())).into_response()) 227 234 } 228 235 } 229 236 } ··· 236 243 237 244 pub async fn update_controller_scopes( 238 245 State(state): State<AppState>, 239 - auth: BearerAuth, 246 + auth: RequiredAuth, 240 247 Json(input): Json<UpdateControllerScopesInput>, 241 - ) -> Response { 248 + ) -> Result<Response, ApiError> { 249 + let user = auth.0.require_user()?.require_active()?; 242 250 if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 243 - return ApiError::InvalidScopes(e).into_response(); 251 + return Ok(ApiError::InvalidScopes(e).into_response()); 244 252 } 245 253 246 254 match state 247 255 .delegation_repo 248 - .update_delegation_scopes(&auth.0.did, &input.controller_did, &input.granted_scopes) 256 + .update_delegation_scopes(&user.did, &input.controller_did, &input.granted_scopes) 249 257 .await 250 258 { 251 259 Ok(true) => { 252 260 let _ = state 253 261 .delegation_repo 254 262 .log_delegation_action( 255 - &auth.0.did, 256 - &auth.0.did, 263 + &user.did, 264 + &user.did, 257 265 Some(&input.controller_did), 258 266 DelegationActionType::ScopesModified, 259 267 Some(serde_json::json!({ ··· 264 272 ) 265 273 .await; 266 274 267 - ( 275 + Ok(( 268 276 StatusCode::OK, 269 277 Json(serde_json::json!({ 270 278 "success": true 271 279 })), 272 280 ) 273 - .into_response() 281 + .into_response()) 274 282 } 275 - Ok(false) => ApiError::DelegationNotFound.into_response(), 283 + Ok(false) => Ok(ApiError::DelegationNotFound.into_response()), 276 284 Err(e) => { 277 285 tracing::error!("Failed to update controller scopes: {:?}", e); 278 - ApiError::InternalError(Some("Failed to update controller scopes".into())) 279 - .into_response() 286 + Ok( 287 + ApiError::InternalError(Some("Failed to update controller scopes".into())) 288 + .into_response(), 289 + ) 280 290 } 281 291 } 282 292 } ··· 295 305 pub accounts: Vec<DelegatedAccountInfo>, 296 306 } 297 307 298 - pub async fn list_controlled_accounts(State(state): State<AppState>, auth: BearerAuth) -> Response { 308 + pub async fn list_controlled_accounts( 309 + State(state): State<AppState>, 310 + auth: RequiredAuth, 311 + ) -> Result<Response, ApiError> { 312 + let user = auth.0.require_user()?.require_active()?; 299 313 let accounts = match state 300 314 .delegation_repo 301 - .get_accounts_controlled_by(&auth.0.did) 315 + .get_accounts_controlled_by(&user.did) 302 316 .await 303 317 { 304 318 Ok(a) => a, 305 319 Err(e) => { 306 320 tracing::error!("Failed to list controlled accounts: {:?}", e); 307 - return ApiError::InternalError(Some("Failed to list controlled accounts".into())) 308 - .into_response(); 321 + return Ok( 322 + ApiError::InternalError(Some("Failed to list controlled accounts".into())) 323 + .into_response(), 324 + ); 309 325 } 310 326 }; 311 327 312 - Json(ListControlledAccountsResponse { 328 + Ok(Json(ListControlledAccountsResponse { 313 329 accounts: accounts 314 330 .into_iter() 315 331 .map(|a| DelegatedAccountInfo { ··· 320 336 }) 321 337 .collect(), 322 338 }) 323 - .into_response() 339 + .into_response()) 324 340 } 325 341 326 342 #[derive(Debug, Deserialize)] ··· 355 371 356 372 pub async fn get_audit_log( 357 373 State(state): State<AppState>, 358 - auth: BearerAuth, 374 + auth: RequiredAuth, 359 375 Query(params): Query<AuditLogParams>, 360 - ) -> Response { 376 + ) -> Result<Response, ApiError> { 377 + let user = auth.0.require_user()?.require_active()?; 361 378 let limit = params.limit.clamp(1, 100); 362 379 let offset = params.offset.max(0); 363 380 364 381 let entries = match state 365 382 .delegation_repo 366 - .get_audit_log_for_account(&auth.0.did, limit, offset) 383 + .get_audit_log_for_account(&user.did, limit, offset) 367 384 .await 368 385 { 369 386 Ok(e) => e, 370 387 Err(e) => { 371 388 tracing::error!("Failed to get audit log: {:?}", e); 372 - return ApiError::InternalError(Some("Failed to get audit log".into())).into_response(); 389 + return Ok( 390 + ApiError::InternalError(Some("Failed to get audit log".into())).into_response(), 391 + ); 373 392 } 374 393 }; 375 394 376 395 let total = state 377 396 .delegation_repo 378 - .count_audit_log_entries(&auth.0.did) 397 + .count_audit_log_entries(&user.did) 379 398 .await 380 399 .unwrap_or_default(); 381 400 382 - Json(GetAuditLogResponse { 401 + Ok(Json(GetAuditLogResponse { 383 402 entries: entries 384 403 .into_iter() 385 404 .map(|e| AuditLogEntry { ··· 394 413 .collect(), 395 414 total, 396 415 }) 397 - .into_response() 416 + .into_response()) 398 417 } 399 418 400 419 #[derive(Debug, Serialize)] ··· 444 463 pub async fn create_delegated_account( 445 464 State(state): State<AppState>, 446 465 headers: HeaderMap, 447 - auth: BearerAuth, 466 + auth: RequiredAuth, 448 467 Json(input): Json<CreateDelegatedAccountInput>, 449 - ) -> Response { 468 + ) -> Result<Response, ApiError> { 469 + let user = auth.0.require_user()?.require_active()?; 450 470 let client_ip = extract_client_ip(&headers); 451 471 if !state 452 472 .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 453 473 .await 454 474 { 455 475 warn!(ip = %client_ip, "Delegated account creation rate limit exceeded"); 456 - return ApiError::RateLimitExceeded(Some( 476 + return Ok(ApiError::RateLimitExceeded(Some( 457 477 "Too many account creation attempts. Please try again later.".into(), 458 478 )) 459 - .into_response(); 479 + .into_response()); 460 480 } 461 481 462 482 if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) { 463 - return ApiError::InvalidScopes(e).into_response(); 483 + return Ok(ApiError::InvalidScopes(e).into_response()); 464 484 } 465 485 466 - match state.delegation_repo.has_any_controllers(&auth.0.did).await { 486 + match state.delegation_repo.has_any_controllers(&user.did).await { 467 487 Ok(true) => { 468 - return ApiError::InvalidDelegation( 488 + return Ok(ApiError::InvalidDelegation( 469 489 "Cannot create delegated accounts from a controlled account".into(), 470 490 ) 471 - .into_response(); 491 + .into_response()); 472 492 } 473 493 Err(e) => { 474 494 tracing::error!("Failed to check controller status: {:?}", e); 475 - return ApiError::InternalError(Some("Failed to verify controller status".into())) 476 - .into_response(); 495 + return Ok( 496 + ApiError::InternalError(Some("Failed to verify controller status".into())) 497 + .into_response(), 498 + ); 477 499 } 478 500 Ok(false) => {} 479 501 } ··· 494 516 match crate::api::validation::validate_short_handle(handle_to_validate) { 495 517 Ok(h) => format!("{}.{}", h, hostname_for_handles), 496 518 Err(e) => { 497 - return ApiError::InvalidRequest(e.to_string()).into_response(); 519 + return Ok(ApiError::InvalidRequest(e.to_string()).into_response()); 498 520 } 499 521 } 500 522 } else { ··· 509 531 if let Some(ref email) = email 510 532 && !crate::api::validation::is_valid_email(email) 511 533 { 512 - return ApiError::InvalidEmail.into_response(); 534 + return Ok(ApiError::InvalidEmail.into_response()); 513 535 } 514 536 515 537 if let Some(ref code) = input.invite_code { ··· 520 542 .unwrap_or(false); 521 543 522 544 if !valid { 523 - return ApiError::InvalidInviteCode.into_response(); 545 + return Ok(ApiError::InvalidInviteCode.into_response()); 524 546 } 525 547 } else { 526 548 let invite_required = std::env::var("INVITE_CODE_REQUIRED") 527 549 .map(|v| v == "true" || v == "1") 528 550 .unwrap_or(false); 529 551 if invite_required { 530 - return ApiError::InviteCodeRequired.into_response(); 552 + return Ok(ApiError::InviteCodeRequired.into_response()); 531 553 } 532 554 } 533 555 ··· 542 564 Ok(k) => k, 543 565 Err(e) => { 544 566 error!("Error creating signing key: {:?}", e); 545 - return ApiError::InternalError(None).into_response(); 567 + return Ok(ApiError::InternalError(None).into_response()); 546 568 } 547 569 }; 548 570 ··· 558 580 Ok(r) => r, 559 581 Err(e) => { 560 582 error!("Error creating PLC genesis operation: {:?}", e); 561 - return ApiError::InternalError(Some("Failed to create PLC operation".into())) 562 - .into_response(); 583 + return Ok( 584 + ApiError::InternalError(Some("Failed to create PLC operation".into())) 585 + .into_response(), 586 + ); 563 587 } 564 588 }; 565 589 ··· 569 593 .await 570 594 { 571 595 error!("Failed to submit PLC genesis operation: {:?}", e); 572 - return ApiError::UpstreamErrorMsg(format!( 596 + return Ok(ApiError::UpstreamErrorMsg(format!( 573 597 "Failed to register DID with PLC directory: {}", 574 598 e 575 599 )) 576 - .into_response(); 600 + .into_response()); 577 601 } 578 602 579 603 let did = Did::new_unchecked(&genesis_result.did); 580 604 let handle = Handle::new_unchecked(&handle); 581 - info!(did = %did, handle = %handle, controller = %&auth.0.did, "Created DID for delegated account"); 605 + info!(did = %did, handle = %handle, controller = %&user.did, "Created DID for delegated account"); 582 606 583 607 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 584 608 Ok(bytes) => bytes, 585 609 Err(e) => { 586 610 error!("Error encrypting signing key: {:?}", e); 587 - return ApiError::InternalError(None).into_response(); 611 + return Ok(ApiError::InternalError(None).into_response()); 588 612 } 589 613 }; 590 614 ··· 593 617 Ok(c) => c, 594 618 Err(e) => { 595 619 error!("Error persisting MST: {:?}", e); 596 - return ApiError::InternalError(None).into_response(); 620 + return Ok(ApiError::InternalError(None).into_response()); 597 621 } 598 622 }; 599 623 let rev = Tid::now(LimitedU32::MIN); ··· 602 626 Ok(result) => result, 603 627 Err(e) => { 604 628 error!("Error creating genesis commit: {:?}", e); 605 - return ApiError::InternalError(None).into_response(); 629 + return Ok(ApiError::InternalError(None).into_response()); 606 630 } 607 631 }; 608 632 let commit_cid: cid::Cid = match state.block_store.put(&commit_bytes).await { 609 633 Ok(c) => c, 610 634 Err(e) => { 611 635 error!("Error saving genesis commit: {:?}", e); 612 - return ApiError::InternalError(None).into_response(); 636 + return Ok(ApiError::InternalError(None).into_response()); 613 637 } 614 638 }; 615 639 let genesis_block_cids = vec![mst_root.to_bytes(), commit_cid.to_bytes()]; ··· 618 642 handle: handle.clone(), 619 643 email: email.clone(), 620 644 did: did.clone(), 621 - controller_did: auth.0.did.clone(), 645 + controller_did: user.did.clone(), 622 646 controller_scopes: input.controller_scopes.clone(), 623 647 encrypted_key_bytes, 624 648 encryption_version: crate::config::ENCRYPTION_VERSION, ··· 635 659 { 636 660 Ok(id) => id, 637 661 Err(tranquil_db_traits::CreateAccountError::HandleTaken) => { 638 - return ApiError::HandleNotAvailable(None).into_response(); 662 + return Ok(ApiError::HandleNotAvailable(None).into_response()); 639 663 } 640 664 Err(tranquil_db_traits::CreateAccountError::EmailTaken) => { 641 - return ApiError::EmailTaken.into_response(); 665 + return Ok(ApiError::EmailTaken.into_response()); 642 666 } 643 667 Err(e) => { 644 668 error!("Error creating delegated account: {:?}", e); 645 - return ApiError::InternalError(None).into_response(); 669 + return Ok(ApiError::InternalError(None).into_response()); 646 670 } 647 671 }; 648 672 ··· 678 702 .delegation_repo 679 703 .log_delegation_action( 680 704 &did, 681 - &auth.0.did, 682 - Some(&auth.0.did), 705 + &user.did, 706 + Some(&user.did), 683 707 DelegationActionType::GrantCreated, 684 708 Some(json!({ 685 709 "account_created": true, ··· 690 714 ) 691 715 .await; 692 716 693 - info!(did = %did, handle = %handle, controller = %&auth.0.did, "Delegated account created"); 717 + info!(did = %did, handle = %handle, controller = %&user.did, "Delegated account created"); 694 718 695 - Json(CreateDelegatedAccountResponse { did, handle }).into_response() 719 + Ok(Json(CreateDelegatedAccountResponse { did, handle }).into_response()) 696 720 }
+7
crates/tranquil-pds/src/api/error.rs
··· 543 543 crate::auth::extractor::AuthError::AccountDeactivated => Self::AccountDeactivated, 544 544 crate::auth::extractor::AuthError::AccountTakedown => Self::AccountTakedown, 545 545 crate::auth::extractor::AuthError::AdminRequired => Self::AdminRequired, 546 + crate::auth::extractor::AuthError::OAuthExpiredToken(msg) => { 547 + Self::OAuthExpiredToken(Some(msg)) 548 + } 549 + crate::auth::extractor::AuthError::UseDpopNonce(_) 550 + | crate::auth::extractor::AuthError::InvalidDpopProof(_) => { 551 + Self::AuthenticationFailed(None) 552 + } 546 553 } 547 554 } 548 555 }
+4 -4
crates/tranquil-pds/src/api/identity/account.rs
··· 1 1 use super::did::verify_did_web; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::utils::create_signed_commit; 4 - use crate::auth::{ServiceTokenVerifier, is_service_token}; 4 + use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token}; 5 5 use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 6 6 use crate::state::{AppState, RateLimitKind}; 7 7 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; ··· 96 96 .into_response(); 97 97 } 98 98 99 - let migration_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( 100 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 101 - ) { 99 + let migration_auth = if let Some(extracted) = 100 + extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 101 + { 102 102 let token = extracted.token; 103 103 if is_service_token(&token) { 104 104 let verifier = ServiceTokenVerifier::new();
+87 -99
crates/tranquil-pds/src/api/identity/did.rs
··· 1 1 use crate::api::{ApiError, DidResponse, EmptyResponse}; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::RequiredAuth; 3 3 use crate::plc::signing_key_to_did_key; 4 4 use crate::state::AppState; 5 5 use crate::types::Handle; ··· 518 518 519 519 pub async fn get_recommended_did_credentials( 520 520 State(state): State<AppState>, 521 - auth: BearerAuthAllowDeactivated, 522 - ) -> Response { 523 - let auth_user = auth.0; 524 - let handle = match state.user_repo.get_handle_by_did(&auth_user.did).await { 525 - Ok(Some(h)) => h, 526 - Ok(None) => return ApiError::InternalError(None).into_response(), 527 - Err(_) => return ApiError::InternalError(None).into_response(), 528 - }; 529 - let key_bytes = match auth_user.key_bytes { 530 - Some(kb) => kb, 531 - None => { 532 - return ApiError::AuthenticationFailed(Some( 533 - "OAuth tokens cannot get DID credentials".into(), 534 - )) 535 - .into_response(); 536 - } 537 - }; 521 + auth: RequiredAuth, 522 + ) -> Result<Response, ApiError> { 523 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 524 + let handle = state 525 + .user_repo 526 + .get_handle_by_did(&auth_user.did) 527 + .await 528 + .map_err(|_| ApiError::InternalError(None))? 529 + .ok_or(ApiError::InternalError(None))?; 530 + 531 + let key_bytes = auth_user.key_bytes.clone().ok_or_else(|| { 532 + ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into())) 533 + })?; 534 + 538 535 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 539 536 let pds_endpoint = format!("https://{}", hostname); 540 - let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) { 541 - Ok(k) => k, 542 - Err(_) => return ApiError::InternalError(None).into_response(), 543 - }; 537 + let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes) 538 + .map_err(|_| ApiError::InternalError(None))?; 544 539 let did_key = signing_key_to_did_key(&signing_key); 545 540 let rotation_keys = if auth_user.did.starts_with("did:web:") { 546 541 vec![] ··· 556 551 }; 557 552 vec![server_rotation_key] 558 553 }; 559 - ( 554 + Ok(( 560 555 StatusCode::OK, 561 556 Json(GetRecommendedDidCredentialsOutput { 562 557 rotation_keys, ··· 570 565 }, 571 566 }), 572 567 ) 573 - .into_response() 568 + .into_response()) 574 569 } 575 570 576 571 #[derive(Deserialize)] ··· 580 575 581 576 pub async fn update_handle( 582 577 State(state): State<AppState>, 583 - auth: BearerAuthAllowDeactivated, 578 + auth: RequiredAuth, 584 579 Json(input): Json<UpdateHandleInput>, 585 - ) -> Response { 586 - let auth_user = auth.0; 580 + ) -> Result<Response, ApiError> { 581 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 587 582 if let Err(e) = crate::auth::scope_check::check_identity_scope( 588 583 auth_user.is_oauth, 589 584 auth_user.scope.as_deref(), 590 585 crate::oauth::scopes::IdentityAttr::Handle, 591 586 ) { 592 - return e; 587 + return Ok(e); 593 588 } 594 - let did = auth_user.did; 589 + let did = auth_user.did.clone(); 595 590 if !state 596 591 .check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did) 597 592 .await 598 593 { 599 - return ApiError::RateLimitExceeded(Some( 594 + return Err(ApiError::RateLimitExceeded(Some( 600 595 "Too many handle updates. Try again later.".into(), 601 - )) 602 - .into_response(); 596 + ))); 603 597 } 604 598 if !state 605 599 .check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did) 606 600 .await 607 601 { 608 - return ApiError::RateLimitExceeded(Some("Daily handle update limit exceeded.".into())) 609 - .into_response(); 602 + return Err(ApiError::RateLimitExceeded(Some( 603 + "Daily handle update limit exceeded.".into(), 604 + ))); 610 605 } 611 - let user_row = match state.user_repo.get_id_and_handle_by_did(&did).await { 612 - Ok(Some(row)) => row, 613 - Ok(None) => return ApiError::InternalError(None).into_response(), 614 - Err(_) => return ApiError::InternalError(None).into_response(), 615 - }; 606 + let user_row = state 607 + .user_repo 608 + .get_id_and_handle_by_did(&did) 609 + .await 610 + .map_err(|_| ApiError::InternalError(None))? 611 + .ok_or(ApiError::InternalError(None))?; 616 612 let user_id = user_row.id; 617 613 let current_handle = user_row.handle; 618 614 let new_handle = input.handle.trim().to_ascii_lowercase(); 619 615 if new_handle.is_empty() { 620 - return ApiError::InvalidRequest("handle is required".into()).into_response(); 616 + return Err(ApiError::InvalidRequest("handle is required".into())); 621 617 } 622 618 if !new_handle 623 619 .chars() 624 620 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-') 625 621 { 626 - return ApiError::InvalidHandle(Some("Handle contains invalid characters".into())) 627 - .into_response(); 622 + return Err(ApiError::InvalidHandle(Some( 623 + "Handle contains invalid characters".into(), 624 + ))); 628 625 } 629 626 if new_handle.split('.').any(|segment| segment.is_empty()) { 630 - return ApiError::InvalidHandle(Some("Handle contains empty segment".into())) 631 - .into_response(); 627 + return Err(ApiError::InvalidHandle(Some( 628 + "Handle contains empty segment".into(), 629 + ))); 632 630 } 633 631 if new_handle 634 632 .split('.') 635 633 .any(|segment| segment.starts_with('-') || segment.ends_with('-')) 636 634 { 637 - return ApiError::InvalidHandle(Some( 635 + return Err(ApiError::InvalidHandle(Some( 638 636 "Handle segment cannot start or end with hyphen".into(), 639 - )) 640 - .into_response(); 637 + ))); 641 638 } 642 639 if crate::moderation::has_explicit_slur(&new_handle) { 643 - return ApiError::InvalidHandle(Some("Inappropriate language in handle".into())) 644 - .into_response(); 640 + return Err(ApiError::InvalidHandle(Some( 641 + "Inappropriate language in handle".into(), 642 + ))); 645 643 } 646 644 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 647 645 let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); ··· 667 665 { 668 666 warn!("Failed to sequence identity event for handle update: {}", e); 669 667 } 670 - return EmptyResponse::ok().into_response(); 668 + return Ok(EmptyResponse::ok().into_response()); 671 669 } 672 670 if short_part.contains('.') { 673 - return ApiError::InvalidHandle(Some( 671 + return Err(ApiError::InvalidHandle(Some( 674 672 "Nested subdomains are not allowed. Use a simple handle without dots.".into(), 675 - )) 676 - .into_response(); 673 + ))); 677 674 } 678 675 if short_part.len() < 3 { 679 - return ApiError::InvalidHandle(Some("Handle too short".into())).into_response(); 676 + return Err(ApiError::InvalidHandle(Some("Handle too short".into()))); 680 677 } 681 678 if short_part.len() > 18 { 682 - return ApiError::InvalidHandle(Some("Handle too long".into())).into_response(); 679 + return Err(ApiError::InvalidHandle(Some("Handle too long".into()))); 683 680 } 684 681 full_handle 685 682 } else { ··· 691 688 { 692 689 warn!("Failed to sequence identity event for handle update: {}", e); 693 690 } 694 - return EmptyResponse::ok().into_response(); 691 + return Ok(EmptyResponse::ok().into_response()); 695 692 } 696 693 match crate::handle::verify_handle_ownership(&new_handle, &did).await { 697 694 Ok(()) => {} 698 695 Err(crate::handle::HandleResolutionError::NotFound) => { 699 - return ApiError::HandleNotAvailable(None).into_response(); 696 + return Err(ApiError::HandleNotAvailable(None)); 700 697 } 701 698 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => { 702 - return ApiError::HandleNotAvailable(Some(format!( 699 + return Err(ApiError::HandleNotAvailable(Some(format!( 703 700 "Handle points to different DID. Expected {}, got {}", 704 701 expected, actual 705 - ))) 706 - .into_response(); 702 + )))); 707 703 } 708 704 Err(e) => { 709 705 warn!("Handle verification failed: {}", e); 710 - return ApiError::HandleNotAvailable(Some(format!( 706 + return Err(ApiError::HandleNotAvailable(Some(format!( 711 707 "Handle verification failed: {}", 712 708 e 713 - ))) 714 - .into_response(); 709 + )))); 715 710 } 716 711 } 717 712 new_handle.clone() 718 713 }; 719 - let handle_typed: Handle = match handle.parse() { 720 - Ok(h) => h, 721 - Err(_) => { 722 - return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(); 723 - } 724 - }; 725 - let handle_exists = match state 714 + let handle_typed: Handle = handle 715 + .parse() 716 + .map_err(|_| ApiError::InvalidHandle(Some("Invalid handle format".into())))?; 717 + let handle_exists = state 726 718 .user_repo 727 719 .check_handle_exists(&handle_typed, user_id) 728 720 .await 729 - { 730 - Ok(exists) => exists, 731 - Err(_) => return ApiError::InternalError(None).into_response(), 732 - }; 721 + .map_err(|_| ApiError::InternalError(None))?; 733 722 if handle_exists { 734 - return ApiError::HandleTaken.into_response(); 723 + return Err(ApiError::HandleTaken); 735 724 } 736 - let result = state.user_repo.update_handle(user_id, &handle_typed).await; 737 - match result { 738 - Ok(_) => { 739 - if !current_handle.is_empty() { 740 - let _ = state 741 - .cache 742 - .delete(&format!("handle:{}", current_handle)) 743 - .await; 744 - } 745 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 746 - if let Err(e) = 747 - crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 748 - .await 749 - { 750 - warn!("Failed to sequence identity event for handle update: {}", e); 751 - } 752 - if let Err(e) = update_plc_handle(&state, &did, &handle_typed).await { 753 - warn!("Failed to update PLC handle: {}", e); 754 - } 755 - EmptyResponse::ok().into_response() 756 - } 757 - Err(e) => { 725 + state 726 + .user_repo 727 + .update_handle(user_id, &handle_typed) 728 + .await 729 + .map_err(|e| { 758 730 error!("DB error updating handle: {:?}", e); 759 - ApiError::InternalError(None).into_response() 760 - } 731 + ApiError::InternalError(None) 732 + })?; 733 + 734 + if !current_handle.is_empty() { 735 + let _ = state 736 + .cache 737 + .delete(&format!("handle:{}", current_handle)) 738 + .await; 739 + } 740 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 741 + if let Err(e) = 742 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)).await 743 + { 744 + warn!("Failed to sequence identity event for handle update: {}", e); 745 + } 746 + if let Err(e) = update_plc_handle(&state, &did, &handle_typed).await { 747 + warn!("Failed to update PLC handle: {}", e); 761 748 } 749 + Ok(EmptyResponse::ok().into_response()) 762 750 } 763 751 764 752 pub async fn update_plc_handle(
+21 -18
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAllowDeactivated; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 extract::State, ··· 15 15 16 16 pub async fn request_plc_operation_signature( 17 17 State(state): State<AppState>, 18 - auth: BearerAuthAllowDeactivated, 19 - ) -> Response { 20 - let auth_user = auth.0; 18 + auth: RequiredAuth, 19 + ) -> Result<Response, ApiError> { 20 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 21 21 if let Err(e) = crate::auth::scope_check::check_identity_scope( 22 22 auth_user.is_oauth, 23 23 auth_user.scope.as_deref(), 24 24 crate::oauth::scopes::IdentityAttr::Wildcard, 25 25 ) { 26 - return e; 26 + return Ok(e); 27 27 } 28 - let user_id = match state.user_repo.get_id_by_did(&auth_user.did).await { 29 - Ok(Some(id)) => id, 30 - Ok(None) => return ApiError::AccountNotFound.into_response(), 31 - Err(e) => { 28 + let user_id = state 29 + .user_repo 30 + .get_id_by_did(&auth_user.did) 31 + .await 32 + .map_err(|e| { 32 33 error!("DB error: {:?}", e); 33 - return ApiError::InternalError(None).into_response(); 34 - } 35 - }; 34 + ApiError::InternalError(None) 35 + })? 36 + .ok_or(ApiError::AccountNotFound)?; 37 + 36 38 let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await; 37 39 let plc_token = generate_plc_token(); 38 40 let expires_at = Utc::now() + Duration::minutes(10); 39 - if let Err(e) = state 41 + state 40 42 .infra_repo 41 43 .insert_plc_token(user_id, &plc_token, expires_at) 42 44 .await 43 - { 44 - error!("Failed to create PLC token: {:?}", e); 45 - return ApiError::InternalError(None).into_response(); 46 - } 45 + .map_err(|e| { 46 + error!("Failed to create PLC token: {:?}", e); 47 + ApiError::InternalError(None) 48 + })?; 49 + 47 50 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 48 51 if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation( 49 52 state.user_repo.as_ref(), ··· 60 63 "PLC operation signature requested for user {}", 61 64 auth_user.did 62 65 ); 63 - EmptyResponse::ok().into_response() 66 + Ok(EmptyResponse::ok().into_response()) 64 67 }
+68 -82
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::RequiredAuth; 3 3 use crate::circuit_breaker::with_circuit_breaker; 4 4 use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation}; 5 5 use crate::state::AppState; ··· 40 40 41 41 pub async fn sign_plc_operation( 42 42 State(state): State<AppState>, 43 - auth: BearerAuthAllowDeactivated, 43 + auth: RequiredAuth, 44 44 Json(input): Json<SignPlcOperationInput>, 45 - ) -> Response { 46 - let auth_user = auth.0; 45 + ) -> Result<Response, ApiError> { 46 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 47 47 if let Err(e) = crate::auth::scope_check::check_identity_scope( 48 48 auth_user.is_oauth, 49 49 auth_user.scope.as_deref(), 50 50 crate::oauth::scopes::IdentityAttr::Wildcard, 51 51 ) { 52 - return e; 52 + return Ok(e); 53 53 } 54 54 let did = &auth_user.did; 55 55 if did.starts_with("did:web:") { 56 - return ApiError::InvalidRequest( 56 + return Err(ApiError::InvalidRequest( 57 57 "PLC operations are only valid for did:plc identities".into(), 58 - ) 59 - .into_response(); 58 + )); 60 59 } 61 - let token = match &input.token { 62 - Some(t) => t, 63 - None => { 64 - return ApiError::InvalidRequest( 65 - "Email confirmation token required to sign PLC operations".into(), 66 - ) 67 - .into_response(); 68 - } 69 - }; 70 - let user_id = match state.user_repo.get_id_by_did(did).await { 71 - Ok(Some(id)) => id, 72 - Ok(None) => return ApiError::AccountNotFound.into_response(), 73 - Err(e) => { 60 + let token = input.token.as_ref().ok_or_else(|| { 61 + ApiError::InvalidRequest("Email confirmation token required to sign PLC operations".into()) 62 + })?; 63 + 64 + let user_id = state 65 + .user_repo 66 + .get_id_by_did(did) 67 + .await 68 + .map_err(|e| { 74 69 error!("DB error: {:?}", e); 75 - return ApiError::InternalError(None).into_response(); 76 - } 77 - }; 78 - let token_expiry = match state.infra_repo.get_plc_token_expiry(user_id, token).await { 79 - Ok(Some(expiry)) => expiry, 80 - Ok(None) => { 81 - return ApiError::InvalidToken(Some("Invalid or expired token".into())).into_response(); 82 - } 83 - Err(e) => { 70 + ApiError::InternalError(None) 71 + })? 72 + .ok_or(ApiError::AccountNotFound)?; 73 + 74 + let token_expiry = state 75 + .infra_repo 76 + .get_plc_token_expiry(user_id, token) 77 + .await 78 + .map_err(|e| { 84 79 error!("DB error: {:?}", e); 85 - return ApiError::InternalError(None).into_response(); 86 - } 87 - }; 80 + ApiError::InternalError(None) 81 + })? 82 + .ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?; 83 + 88 84 if Utc::now() > token_expiry { 89 85 let _ = state.infra_repo.delete_plc_token(user_id, token).await; 90 - return ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 86 + return Err(ApiError::ExpiredToken(Some("Token has expired".into()))); 91 87 } 92 - let key_row = match state.user_repo.get_user_key_by_id(user_id).await { 93 - Ok(Some(row)) => row, 94 - Ok(None) => { 95 - return ApiError::InternalError(Some("User signing key not found".into())) 96 - .into_response(); 97 - } 98 - Err(e) => { 88 + let key_row = state 89 + .user_repo 90 + .get_user_key_by_id(user_id) 91 + .await 92 + .map_err(|e| { 99 93 error!("DB error: {:?}", e); 100 - return ApiError::InternalError(None).into_response(); 101 - } 102 - }; 103 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 104 - { 105 - Ok(k) => k, 106 - Err(e) => { 94 + ApiError::InternalError(None) 95 + })? 96 + .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 97 + 98 + let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 99 + .map_err(|e| { 107 100 error!("Failed to decrypt user key: {}", e); 108 - return ApiError::InternalError(None).into_response(); 109 - } 110 - }; 111 - let signing_key = match SigningKey::from_slice(&key_bytes) { 112 - Ok(k) => k, 113 - Err(e) => { 114 - error!("Failed to create signing key: {:?}", e); 115 - return ApiError::InternalError(None).into_response(); 116 - } 117 - }; 101 + ApiError::InternalError(None) 102 + })?; 103 + 104 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 105 + error!("Failed to create signing key: {:?}", e); 106 + ApiError::InternalError(None) 107 + })?; 108 + 118 109 let plc_client = PlcClient::with_cache(None, Some(state.cache.clone())); 119 110 let did_clone = did.clone(); 120 - let last_op = match with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 111 + let last_op = with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 121 112 plc_client.get_last_op(&did_clone).await 122 113 }) 123 114 .await 124 - { 125 - Ok(op) => op, 126 - Err(e) => return ApiError::from(e).into_response(), 127 - }; 115 + .map_err(ApiError::from)?; 116 + 128 117 if last_op.is_tombstone() { 129 - return ApiError::from(PlcError::Tombstoned).into_response(); 118 + return Err(ApiError::from(PlcError::Tombstoned)); 130 119 } 131 120 let services = input.services.map(|s| { 132 121 s.into_iter() ··· 141 130 }) 142 131 .collect() 143 132 }); 144 - let unsigned_op = match create_update_op( 133 + let unsigned_op = create_update_op( 145 134 &last_op, 146 135 input.rotation_keys, 147 136 input.verification_methods, 148 137 input.also_known_as, 149 138 services, 150 - ) { 151 - Ok(op) => op, 152 - Err(PlcError::Tombstoned) => { 153 - return ApiError::InvalidRequest("Cannot update tombstoned DID".into()).into_response(); 154 - } 155 - Err(e) => { 139 + ) 140 + .map_err(|e| match e { 141 + PlcError::Tombstoned => ApiError::InvalidRequest("Cannot update tombstoned DID".into()), 142 + _ => { 156 143 error!("Failed to create PLC operation: {:?}", e); 157 - return ApiError::InternalError(None).into_response(); 144 + ApiError::InternalError(None) 158 145 } 159 - }; 160 - let signed_op = match sign_operation(&unsigned_op, &signing_key) { 161 - Ok(op) => op, 162 - Err(e) => { 163 - error!("Failed to sign PLC operation: {:?}", e); 164 - return ApiError::InternalError(None).into_response(); 165 - } 166 - }; 146 + })?; 147 + 148 + let signed_op = sign_operation(&unsigned_op, &signing_key).map_err(|e| { 149 + error!("Failed to sign PLC operation: {:?}", e); 150 + ApiError::InternalError(None) 151 + })?; 152 + 167 153 let _ = state.infra_repo.delete_plc_token(user_id, token).await; 168 154 info!("Signed PLC operation for user {}", did); 169 - ( 155 + Ok(( 170 156 StatusCode::OK, 171 157 Json(SignPlcOperationOutput { 172 158 operation: signed_op, 173 159 }), 174 160 ) 175 - .into_response() 161 + .into_response()) 176 162 }
+56 -58
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 1 1 use crate::api::{ApiError, EmptyResponse}; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::RequiredAuth; 3 3 use crate::circuit_breaker::with_circuit_breaker; 4 4 use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation}; 5 5 use crate::state::AppState; ··· 20 20 21 21 pub async fn submit_plc_operation( 22 22 State(state): State<AppState>, 23 - auth: BearerAuthAllowDeactivated, 23 + auth: RequiredAuth, 24 24 Json(input): Json<SubmitPlcOperationInput>, 25 - ) -> Response { 26 - let auth_user = auth.0; 25 + ) -> Result<Response, ApiError> { 26 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 27 27 if let Err(e) = crate::auth::scope_check::check_identity_scope( 28 28 auth_user.is_oauth, 29 29 auth_user.scope.as_deref(), 30 30 crate::oauth::scopes::IdentityAttr::Wildcard, 31 31 ) { 32 - return e; 32 + return Ok(e); 33 33 } 34 34 let did = &auth_user.did; 35 35 if did.starts_with("did:web:") { 36 - return ApiError::InvalidRequest( 36 + return Err(ApiError::InvalidRequest( 37 37 "PLC operations are only valid for did:plc identities".into(), 38 - ) 39 - .into_response(); 38 + )); 40 39 } 41 - if let Err(e) = validate_plc_operation(&input.operation) { 42 - return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); 43 - } 40 + validate_plc_operation(&input.operation) 41 + .map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?; 42 + 44 43 let op = &input.operation; 45 44 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 46 45 let public_url = format!("https://{}", hostname); 47 - let user = match state.user_repo.get_id_and_handle_by_did(did).await { 48 - Ok(Some(u)) => u, 49 - Ok(None) => return ApiError::AccountNotFound.into_response(), 50 - Err(e) => { 46 + let user = state 47 + .user_repo 48 + .get_id_and_handle_by_did(did) 49 + .await 50 + .map_err(|e| { 51 51 error!("DB error: {:?}", e); 52 - return ApiError::InternalError(None).into_response(); 53 - } 54 - }; 55 - let key_row = match state.user_repo.get_user_key_by_id(user.id).await { 56 - Ok(Some(row)) => row, 57 - Ok(None) => { 58 - return ApiError::InternalError(Some("User signing key not found".into())) 59 - .into_response(); 60 - } 61 - Err(e) => { 52 + ApiError::InternalError(None) 53 + })? 54 + .ok_or(ApiError::AccountNotFound)?; 55 + 56 + let key_row = state 57 + .user_repo 58 + .get_user_key_by_id(user.id) 59 + .await 60 + .map_err(|e| { 62 61 error!("DB error: {:?}", e); 63 - return ApiError::InternalError(None).into_response(); 64 - } 65 - }; 66 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 67 - { 68 - Ok(k) => k, 69 - Err(e) => { 62 + ApiError::InternalError(None) 63 + })? 64 + .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 65 + 66 + let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 67 + .map_err(|e| { 70 68 error!("Failed to decrypt user key: {}", e); 71 - return ApiError::InternalError(None).into_response(); 72 - } 73 - }; 74 - let signing_key = match SigningKey::from_slice(&key_bytes) { 75 - Ok(k) => k, 76 - Err(e) => { 77 - error!("Failed to create signing key: {:?}", e); 78 - return ApiError::InternalError(None).into_response(); 79 - } 80 - }; 69 + ApiError::InternalError(None) 70 + })?; 71 + 72 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 73 + error!("Failed to create signing key: {:?}", e); 74 + ApiError::InternalError(None) 75 + })?; 76 + 81 77 let user_did_key = signing_key_to_did_key(&signing_key); 82 78 let server_rotation_key = 83 79 std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); ··· 86 82 .iter() 87 83 .any(|k| k.as_str() == Some(&server_rotation_key)); 88 84 if !has_server_key { 89 - return ApiError::InvalidRequest( 85 + return Err(ApiError::InvalidRequest( 90 86 "Rotation keys do not include server's rotation key".into(), 91 - ) 92 - .into_response(); 87 + )); 93 88 } 94 89 } 95 90 if let Some(services) = op.get("services").and_then(|v| v.as_object()) ··· 98 93 let service_type = pds.get("type").and_then(|v| v.as_str()); 99 94 let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 100 95 if service_type != Some("AtprotoPersonalDataServer") { 101 - return ApiError::InvalidRequest("Incorrect type on atproto_pds service".into()) 102 - .into_response(); 96 + return Err(ApiError::InvalidRequest( 97 + "Incorrect type on atproto_pds service".into(), 98 + )); 103 99 } 104 100 if endpoint != Some(&public_url) { 105 - return ApiError::InvalidRequest("Incorrect endpoint on atproto_pds service".into()) 106 - .into_response(); 101 + return Err(ApiError::InvalidRequest( 102 + "Incorrect endpoint on atproto_pds service".into(), 103 + )); 107 104 } 108 105 } 109 106 if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) 110 107 && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 111 108 && atproto_key != user_did_key 112 109 { 113 - return ApiError::InvalidRequest("Incorrect signing key in verificationMethods".into()) 114 - .into_response(); 110 + return Err(ApiError::InvalidRequest( 111 + "Incorrect signing key in verificationMethods".into(), 112 + )); 115 113 } 116 114 if let Some(also_known_as) = (!user.handle.is_empty()) 117 115 .then(|| op.get("alsoKnownAs").and_then(|v| v.as_array())) ··· 120 118 let expected_handle = format!("at://{}", user.handle); 121 119 let first_aka = also_known_as.first().and_then(|v| v.as_str()); 122 120 if first_aka != Some(&expected_handle) { 123 - return ApiError::InvalidRequest("Incorrect handle in alsoKnownAs".into()) 124 - .into_response(); 121 + return Err(ApiError::InvalidRequest( 122 + "Incorrect handle in alsoKnownAs".into(), 123 + )); 125 124 } 126 125 } 127 126 let plc_client = PlcClient::with_cache(None, Some(state.cache.clone())); 128 127 let operation_clone = input.operation.clone(); 129 128 let did_clone = did.clone(); 130 - if let Err(e) = with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 129 + with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 131 130 plc_client 132 131 .send_operation(&did_clone, &operation_clone) 133 132 .await 134 133 }) 135 134 .await 136 - { 137 - return ApiError::from(e).into_response(); 138 - } 135 + .map_err(ApiError::from)?; 136 + 139 137 match state 140 138 .repo_repo 141 139 .insert_identity_event(did, Some(&user.handle)) ··· 157 155 warn!(did = %did, "Failed to refresh DID cache after PLC update"); 158 156 } 159 157 info!(did = %did, "PLC operation submitted successfully"); 160 - EmptyResponse::ok().into_response() 158 + Ok(EmptyResponse::ok().into_response()) 161 159 }
+8 -5
crates/tranquil-pds/src/api/moderation/mod.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::api::proxy_client::{is_ssrf_safe, proxy_client}; 3 - use crate::auth::extractor::BearerAuthAllowTakendown; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 Json, ··· 42 42 43 43 pub async fn create_report( 44 44 State(state): State<AppState>, 45 - auth: BearerAuthAllowTakendown, 45 + auth: RequiredAuth, 46 46 Json(input): Json<CreateReportInput>, 47 47 ) -> Response { 48 - let auth_user = auth.0; 48 + let auth_user = match auth.0.require_user() { 49 + Ok(u) => u, 50 + Err(e) => return e.into_response(), 51 + }; 49 52 let did = &auth_user.did; 50 53 51 54 if let Some((service_url, service_did)) = get_report_service_config() { 52 - return proxy_to_report_service(&state, &auth_user, &service_url, &service_did, &input) 55 + return proxy_to_report_service(&state, auth_user, &service_url, &service_did, &input) 53 56 .await; 54 57 } 55 58 56 - create_report_locally(&state, did, auth_user.is_takendown(), input).await 59 + create_report_locally(&state, did, auth_user.status.is_takendown(), input).await 57 60 } 58 61 59 62 async fn proxy_to_report_service(
+74 -85
crates/tranquil-pds/src/api/notification_prefs.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuth; 2 + use crate::auth::RequiredAuth; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, ··· 23 23 pub signal_verified: bool, 24 24 } 25 25 26 - pub async fn get_notification_prefs(State(state): State<AppState>, auth: BearerAuth) -> Response { 27 - let user = auth.0; 28 - let prefs = match state.user_repo.get_notification_prefs(&user.did).await { 29 - Ok(Some(p)) => p, 30 - Ok(None) => return ApiError::AccountNotFound.into_response(), 31 - Err(e) => { 32 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 33 - } 34 - }; 35 - Json(NotificationPrefsResponse { 26 + pub async fn get_notification_prefs( 27 + State(state): State<AppState>, 28 + auth: RequiredAuth, 29 + ) -> Result<Response, ApiError> { 30 + let user = auth.0.require_user()?.require_active()?; 31 + let prefs = state 32 + .user_repo 33 + .get_notification_prefs(&user.did) 34 + .await 35 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 36 + .ok_or(ApiError::AccountNotFound)?; 37 + Ok(Json(NotificationPrefsResponse { 36 38 preferred_channel: prefs.preferred_channel, 37 39 email: prefs.email, 38 40 discord_id: prefs.discord_id, ··· 42 44 signal_number: prefs.signal_number, 43 45 signal_verified: prefs.signal_verified, 44 46 }) 45 - .into_response() 47 + .into_response()) 46 48 } 47 49 48 50 #[derive(Serialize)] ··· 62 64 pub notifications: Vec<NotificationHistoryEntry>, 63 65 } 64 66 65 - pub async fn get_notification_history(State(state): State<AppState>, auth: BearerAuth) -> Response { 66 - let user = auth.0; 67 + pub async fn get_notification_history( 68 + State(state): State<AppState>, 69 + auth: RequiredAuth, 70 + ) -> Result<Response, ApiError> { 71 + let user = auth.0.require_user()?.require_active()?; 67 72 68 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&user.did).await { 69 - Ok(Some(id)) => id, 70 - Ok(None) => return ApiError::AccountNotFound.into_response(), 71 - Err(e) => { 72 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 73 - } 74 - }; 73 + let user_id = state 74 + .user_repo 75 + .get_id_by_did(&user.did) 76 + .await 77 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 78 + .ok_or(ApiError::AccountNotFound)?; 75 79 76 - let rows = match state.infra_repo.get_notification_history(user_id, 50).await { 77 - Ok(r) => r, 78 - Err(e) => { 79 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 80 - } 81 - }; 80 + let rows = state 81 + .infra_repo 82 + .get_notification_history(user_id, 50) 83 + .await 84 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 82 85 83 86 let sensitive_types = [ 84 87 "email_verification", ··· 111 114 }) 112 115 .collect(); 113 116 114 - Json(GetNotificationHistoryResponse { notifications }).into_response() 117 + Ok(Json(GetNotificationHistoryResponse { notifications }).into_response()) 115 118 } 116 119 117 120 #[derive(Deserialize)] ··· 184 187 185 188 pub async fn update_notification_prefs( 186 189 State(state): State<AppState>, 187 - auth: BearerAuth, 190 + auth: RequiredAuth, 188 191 Json(input): Json<UpdateNotificationPrefsInput>, 189 - ) -> Response { 190 - let user = auth.0; 192 + ) -> Result<Response, ApiError> { 193 + let user = auth.0.require_user()?.require_active()?; 191 194 192 - let user_row = match state.user_repo.get_id_handle_email_by_did(&user.did).await { 193 - Ok(Some(row)) => row, 194 - Ok(None) => return ApiError::AccountNotFound.into_response(), 195 - Err(e) => { 196 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 197 - } 198 - }; 195 + let user_row = state 196 + .user_repo 197 + .get_id_handle_email_by_did(&user.did) 198 + .await 199 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 200 + .ok_or(ApiError::AccountNotFound)?; 199 201 200 202 let user_id = user_row.id; 201 203 let handle = user_row.handle; ··· 206 208 if let Some(ref channel) = input.preferred_channel { 207 209 let valid_channels = ["email", "discord", "telegram", "signal"]; 208 210 if !valid_channels.contains(&channel.as_str()) { 209 - return ApiError::InvalidRequest( 211 + return Err(ApiError::InvalidRequest( 210 212 "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 211 - ) 212 - .into_response(); 213 + )); 213 214 } 214 - if let Err(e) = state 215 + state 215 216 .user_repo 216 217 .update_preferred_comms_channel(&user.did, channel) 217 218 .await 218 - { 219 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 220 - } 219 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 221 220 info!(did = %user.did, channel = %channel, "Updated preferred notification channel"); 222 221 } 223 222 224 223 if let Some(ref new_email) = input.email { 225 224 let email_clean = new_email.trim().to_lowercase(); 226 225 if email_clean.is_empty() { 227 - return ApiError::InvalidRequest("Email cannot be empty".into()).into_response(); 226 + return Err(ApiError::InvalidRequest("Email cannot be empty".into())); 228 227 } 229 228 230 229 if !crate::api::validation::is_valid_email(&email_clean) { 231 - return ApiError::InvalidEmail.into_response(); 230 + return Err(ApiError::InvalidEmail); 232 231 } 233 232 234 - if current_email.as_ref().map(|e| e.to_lowercase()) == Some(email_clean.clone()) { 235 - info!(did = %user.did, "Email unchanged, skipping"); 236 - } else { 237 - if let Err(e) = request_channel_verification( 233 + if current_email.as_ref().map(|e| e.to_lowercase()) != Some(email_clean.clone()) { 234 + request_channel_verification( 238 235 &state, 239 236 user_id, 240 237 &user.did, ··· 243 240 Some(&handle), 244 241 ) 245 242 .await 246 - { 247 - return ApiError::InternalError(Some(e)).into_response(); 248 - } 243 + .map_err(|e| ApiError::InternalError(Some(e)))?; 249 244 verification_required.push("email".to_string()); 250 245 info!(did = %user.did, "Requested email verification"); 251 246 } ··· 253 248 254 249 if let Some(ref discord_id) = input.discord_id { 255 250 if discord_id.is_empty() { 256 - if let Err(e) = state.user_repo.clear_discord(user_id).await { 257 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 258 - .into_response(); 259 - } 251 + state 252 + .user_repo 253 + .clear_discord(user_id) 254 + .await 255 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 260 256 info!(did = %user.did, "Cleared Discord ID"); 261 257 } else { 262 - if let Err(e) = request_channel_verification( 263 - &state, user_id, &user.did, "discord", discord_id, None, 264 - ) 265 - .await 266 - { 267 - return ApiError::InternalError(Some(e)).into_response(); 268 - } 258 + request_channel_verification(&state, user_id, &user.did, "discord", discord_id, None) 259 + .await 260 + .map_err(|e| ApiError::InternalError(Some(e)))?; 269 261 verification_required.push("discord".to_string()); 270 262 info!(did = %user.did, "Requested Discord verification"); 271 263 } ··· 274 266 if let Some(ref telegram) = input.telegram_username { 275 267 let telegram_clean = telegram.trim_start_matches('@'); 276 268 if telegram_clean.is_empty() { 277 - if let Err(e) = state.user_repo.clear_telegram(user_id).await { 278 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 279 - .into_response(); 280 - } 269 + state 270 + .user_repo 271 + .clear_telegram(user_id) 272 + .await 273 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 281 274 info!(did = %user.did, "Cleared Telegram username"); 282 275 } else { 283 - if let Err(e) = request_channel_verification( 276 + request_channel_verification( 284 277 &state, 285 278 user_id, 286 279 &user.did, ··· 289 282 None, 290 283 ) 291 284 .await 292 - { 293 - return ApiError::InternalError(Some(e)).into_response(); 294 - } 285 + .map_err(|e| ApiError::InternalError(Some(e)))?; 295 286 verification_required.push("telegram".to_string()); 296 287 info!(did = %user.did, "Requested Telegram verification"); 297 288 } ··· 299 290 300 291 if let Some(ref signal) = input.signal_number { 301 292 if signal.is_empty() { 302 - if let Err(e) = state.user_repo.clear_signal(user_id).await { 303 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 304 - .into_response(); 305 - } 293 + state 294 + .user_repo 295 + .clear_signal(user_id) 296 + .await 297 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 306 298 info!(did = %user.did, "Cleared Signal number"); 307 299 } else { 308 - if let Err(e) = 309 - request_channel_verification(&state, user_id, &user.did, "signal", signal, None) 310 - .await 311 - { 312 - return ApiError::InternalError(Some(e)).into_response(); 313 - } 300 + request_channel_verification(&state, user_id, &user.did, "signal", signal, None) 301 + .await 302 + .map_err(|e| ApiError::InternalError(Some(e)))?; 314 303 verification_required.push("signal".to_string()); 315 304 info!(did = %user.did, "Requested Signal verification"); 316 305 } 317 306 } 318 307 319 - Json(UpdateNotificationPrefsResponse { 308 + Ok(Json(UpdateNotificationPrefsResponse { 320 309 success: true, 321 310 verification_required, 322 311 }) 323 - .into_response() 312 + .into_response()) 324 313 }
+12 -3
crates/tranquil-pds/src/api/proxy.rs
··· 267 267 } 268 268 } 269 269 Err(e) => { 270 - warn!("Token validation failed: {:?}", e); 271 - if matches!(e, crate::auth::TokenValidationError::OAuthTokenExpired) { 272 - return ApiError::from(e).into_response(); 270 + info!(error = ?e, "Proxy token validation failed, returning error to client"); 271 + if matches!( 272 + e, 273 + crate::auth::TokenValidationError::OAuthTokenExpired 274 + | crate::auth::TokenValidationError::TokenExpired 275 + ) { 276 + let mut response = ApiError::from(e).into_response(); 277 + let nonce = crate::oauth::verify::generate_dpop_nonce(); 278 + if let Ok(nonce_val) = nonce.parse() { 279 + response.headers_mut().insert("DPoP-Nonce", nonce_val); 280 + } 281 + return response; 273 282 } 274 283 } 275 284 }
+76 -118
crates/tranquil-pds/src/api/repo/blob.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::{BearerAuthAllowDeactivated, ServiceTokenVerifier, is_service_token}; 2 + use crate::auth::{AuthenticatedEntity, RequiredAuth}; 3 3 use crate::delegation::DelegationActionType; 4 4 use crate::state::AppState; 5 5 use crate::types::{CidLink, Did}; ··· 44 44 pub async fn upload_blob( 45 45 State(state): State<AppState>, 46 46 headers: axum::http::HeaderMap, 47 + auth: RequiredAuth, 47 48 body: Body, 48 - ) -> Response { 49 - let extracted = match crate::auth::extract_auth_token_from_header( 50 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 51 - ) { 52 - Some(t) => t, 53 - None => return ApiError::AuthenticationRequired.into_response(), 54 - }; 55 - let token = extracted.token; 56 - 57 - let is_service_auth = is_service_token(&token); 58 - 59 - let (did, _is_migration, controller_did): (Did, bool, Option<Did>) = if is_service_auth { 60 - debug!("Verifying service token for blob upload"); 61 - let verifier = ServiceTokenVerifier::new(); 62 - match verifier 63 - .verify_service_token(&token, Some("com.atproto.repo.uploadBlob")) 64 - .await 65 - { 66 - Ok(claims) => { 67 - debug!("Service token verified for DID: {}", claims.iss); 68 - let did: Did = match claims.iss.parse() { 69 - Ok(d) => d, 70 - Err(_) => { 71 - return ApiError::InvalidDid("Invalid DID format".into()).into_response(); 72 - } 73 - }; 74 - (did, false, None) 49 + ) -> Result<Response, ApiError> { 50 + let (did, controller_did): (Did, Option<Did>) = match &auth.0 { 51 + AuthenticatedEntity::Service { did, claims } => { 52 + match &claims.lxm { 53 + Some(lxm) if lxm == "*" || lxm == "com.atproto.repo.uploadBlob" => {} 54 + Some(lxm) => { 55 + return Err(ApiError::AuthorizationError(format!( 56 + "Token lxm '{}' does not permit 'com.atproto.repo.uploadBlob'", 57 + lxm 58 + ))); 59 + } 60 + None => { 61 + return Err(ApiError::AuthorizationError( 62 + "Token missing lxm claim".to_string(), 63 + )); 64 + } 75 65 } 76 - Err(e) => { 77 - error!("Service token verification failed: {:?}", e); 78 - return ApiError::AuthenticationFailed(Some(format!( 79 - "Service token verification failed: {}", 80 - e 81 - ))) 82 - .into_response(); 83 - } 66 + (did.clone(), None) 84 67 } 85 - } else { 86 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 87 - let http_uri = format!( 88 - "https://{}/xrpc/com.atproto.repo.uploadBlob", 89 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 90 - ); 91 - match crate::auth::validate_token_with_dpop( 92 - state.user_repo.as_ref(), 93 - state.oauth_repo.as_ref(), 94 - &token, 95 - extracted.is_dpop, 96 - dpop_proof, 97 - "POST", 98 - &http_uri, 99 - true, 100 - false, 101 - ) 102 - .await 103 - { 104 - Ok(user) => { 105 - let mime_type_for_check = headers 106 - .get("content-type") 107 - .and_then(|h| h.to_str().ok()) 108 - .unwrap_or("application/octet-stream"); 109 - if let Err(e) = crate::auth::scope_check::check_blob_scope( 110 - user.is_oauth, 111 - user.scope.as_deref(), 112 - mime_type_for_check, 113 - ) { 114 - return e; 115 - } 116 - let deactivated = state 117 - .user_repo 118 - .get_status_by_did(&user.did) 119 - .await 120 - .ok() 121 - .flatten() 122 - .and_then(|s| s.deactivated_at); 123 - let ctrl_did = user.controller_did.clone(); 124 - (user.did, deactivated.is_some(), ctrl_did) 68 + AuthenticatedEntity::User(auth_user) => { 69 + if auth_user.status.is_takendown() { 70 + return Err(ApiError::AccountTakedown); 125 71 } 126 - Err(_) => { 127 - return ApiError::AuthenticationFailed(None).into_response(); 72 + let mime_type_for_check = headers 73 + .get("content-type") 74 + .and_then(|h| h.to_str().ok()) 75 + .unwrap_or("application/octet-stream"); 76 + if let Err(e) = crate::auth::scope_check::check_blob_scope( 77 + auth_user.is_oauth, 78 + auth_user.scope.as_deref(), 79 + mime_type_for_check, 80 + ) { 81 + return Ok(e); 128 82 } 83 + let ctrl_did = auth_user.controller_did.clone(); 84 + (auth_user.did.clone(), ctrl_did) 129 85 } 130 86 }; 131 87 ··· 135 91 .await 136 92 .unwrap_or(false) 137 93 { 138 - return ApiError::Forbidden.into_response(); 94 + return Err(ApiError::Forbidden); 139 95 } 140 96 141 97 let client_mime_hint = headers ··· 143 99 .and_then(|h| h.to_str().ok()) 144 100 .unwrap_or("application/octet-stream"); 145 101 146 - let user_id = match state.user_repo.get_id_by_did(&did).await { 147 - Ok(Some(id)) => id, 148 - _ => { 149 - return ApiError::InternalError(None).into_response(); 150 - } 151 - }; 102 + let user_id = state 103 + .user_repo 104 + .get_id_by_did(&did) 105 + .await 106 + .ok() 107 + .flatten() 108 + .ok_or(ApiError::InternalError(None))?; 152 109 153 110 let temp_key = format!("temp/{}", uuid::Uuid::new_v4()); 154 111 let max_size = get_max_blob_size() as u64; ··· 161 118 162 119 info!("Starting streaming blob upload to temp key: {}", temp_key); 163 120 164 - let upload_result = match state.blob_store.put_stream(&temp_key, pinned_stream).await { 165 - Ok(result) => result, 166 - Err(e) => { 121 + let upload_result = state 122 + .blob_store 123 + .put_stream(&temp_key, pinned_stream) 124 + .await 125 + .map_err(|e| { 167 126 error!("Failed to stream blob to storage: {:?}", e); 168 - return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 169 - } 170 - }; 127 + ApiError::InternalError(Some("Failed to store blob".into())) 128 + })?; 171 129 172 130 let size = upload_result.size; 173 131 if size > max_size { 174 132 let _ = state.blob_store.delete(&temp_key).await; 175 - return ApiError::InvalidRequest(format!( 133 + return Err(ApiError::InvalidRequest(format!( 176 134 "Blob size {} exceeds maximum of {} bytes", 177 135 size, max_size 178 - )) 179 - .into_response(); 136 + ))); 180 137 } 181 138 182 139 let mime_type = match state.blob_store.get_head(&temp_key, 8192).await { ··· 192 149 Err(e) => { 193 150 let _ = state.blob_store.delete(&temp_key).await; 194 151 error!("Failed to create multihash for blob: {:?}", e); 195 - return ApiError::InternalError(Some("Failed to hash blob".into())).into_response(); 152 + return Err(ApiError::InternalError(Some("Failed to hash blob".into()))); 196 153 } 197 154 }; 198 155 let cid = Cid::new_v1(0x55, multihash); ··· 215 172 Err(e) => { 216 173 let _ = state.blob_store.delete(&temp_key).await; 217 174 error!("Failed to insert blob record: {:?}", e); 218 - return ApiError::InternalError(None).into_response(); 175 + return Err(ApiError::InternalError(None)); 219 176 } 220 177 }; 221 178 222 179 if was_inserted && let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { 223 180 let _ = state.blob_store.delete(&temp_key).await; 224 181 error!("Failed to copy blob to final location: {:?}", e); 225 - return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 182 + return Err(ApiError::InternalError(Some("Failed to store blob".into()))); 226 183 } 227 184 228 185 let _ = state.blob_store.delete(&temp_key).await; ··· 246 203 .await; 247 204 } 248 205 249 - Json(json!({ 206 + Ok(Json(json!({ 250 207 "blob": { 251 208 "$type": "blob", 252 209 "ref": { ··· 256 213 "size": size 257 214 } 258 215 })) 259 - .into_response() 216 + .into_response()) 260 217 } 261 218 262 219 #[derive(Deserialize)] ··· 281 238 282 239 pub async fn list_missing_blobs( 283 240 State(state): State<AppState>, 284 - auth: BearerAuthAllowDeactivated, 241 + auth: RequiredAuth, 285 242 Query(params): Query<ListMissingBlobsParams>, 286 - ) -> Response { 287 - let auth_user = auth.0; 243 + ) -> Result<Response, ApiError> { 244 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 245 + 288 246 let did = &auth_user.did; 289 - let user = match state.user_repo.get_by_did(did).await { 290 - Ok(Some(u)) => u, 291 - Ok(None) => return ApiError::InternalError(None).into_response(), 292 - Err(e) => { 247 + let user = state 248 + .user_repo 249 + .get_by_did(did) 250 + .await 251 + .map_err(|e| { 293 252 error!("DB error fetching user: {:?}", e); 294 - return ApiError::InternalError(None).into_response(); 295 - } 296 - }; 253 + ApiError::InternalError(None) 254 + })? 255 + .ok_or(ApiError::InternalError(None))?; 256 + 297 257 let limit = params.limit.unwrap_or(500).clamp(1, 1000); 298 258 let cursor = params.cursor.as_deref(); 299 - let missing = match state 259 + let missing = state 300 260 .blob_repo 301 261 .list_missing_blobs(user.id, cursor, limit + 1) 302 262 .await 303 - { 304 - Ok(m) => m, 305 - Err(e) => { 263 + .map_err(|e| { 306 264 error!("DB error fetching missing blobs: {:?}", e); 307 - return ApiError::InternalError(None).into_response(); 308 - } 309 - }; 265 + ApiError::InternalError(None) 266 + })?; 267 + 310 268 let has_more = missing.len() > limit as usize; 311 269 let blobs: Vec<RecordBlob> = missing 312 270 .into_iter() ··· 321 279 } else { 322 280 None 323 281 }; 324 - ( 282 + Ok(( 325 283 StatusCode::OK, 326 284 Json(ListMissingBlobsOutput { 327 285 cursor: next_cursor, 328 286 blobs, 329 287 }), 330 288 ) 331 - .into_response() 289 + .into_response()) 332 290 }
+129 -130
crates/tranquil-pds/src/api/repo/import.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::create_signed_commit; 4 - use crate::auth::BearerAuthAllowDeactivated; 4 + use crate::auth::RequiredAuth; 5 5 use crate::state::AppState; 6 6 use crate::sync::import::{ImportError, apply_import, parse_car}; 7 7 use crate::sync::verify::CarVerifier; ··· 23 23 24 24 pub async fn import_repo( 25 25 State(state): State<AppState>, 26 - auth: BearerAuthAllowDeactivated, 26 + auth: RequiredAuth, 27 27 body: Bytes, 28 - ) -> Response { 28 + ) -> Result<Response, ApiError> { 29 29 let accepting_imports = std::env::var("ACCEPTING_REPO_IMPORTS") 30 30 .map(|v| v != "false" && v != "0") 31 31 .unwrap_or(true); 32 32 if !accepting_imports { 33 - return ApiError::InvalidRequest("Service is not accepting repo imports".into()) 34 - .into_response(); 33 + return Err(ApiError::InvalidRequest( 34 + "Service is not accepting repo imports".into(), 35 + )); 35 36 } 36 37 let max_size: usize = std::env::var("MAX_IMPORT_SIZE") 37 38 .ok() 38 39 .and_then(|s| s.parse().ok()) 39 40 .unwrap_or(DEFAULT_MAX_IMPORT_SIZE); 40 41 if body.len() > max_size { 41 - return ApiError::PayloadTooLarge(format!( 42 + return Err(ApiError::PayloadTooLarge(format!( 42 43 "Import size exceeds limit of {} bytes", 43 44 max_size 44 - )) 45 - .into_response(); 45 + ))); 46 46 } 47 - let auth_user = auth.0; 47 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 48 48 let did = &auth_user.did; 49 - let user = match state.user_repo.get_by_did(did).await { 50 - Ok(Some(row)) => row, 51 - Ok(None) => { 52 - return ApiError::AccountNotFound.into_response(); 53 - } 54 - Err(e) => { 49 + let user = state 50 + .user_repo 51 + .get_by_did(did) 52 + .await 53 + .map_err(|e| { 55 54 error!("DB error fetching user: {:?}", e); 56 - return ApiError::InternalError(None).into_response(); 57 - } 58 - }; 55 + ApiError::InternalError(None) 56 + })? 57 + .ok_or(ApiError::AccountNotFound)?; 59 58 if user.takedown_ref.is_some() { 60 - return ApiError::AccountTakedown.into_response(); 59 + return Err(ApiError::AccountTakedown); 61 60 } 62 61 let user_id = user.id; 63 62 let (root, blocks) = match parse_car(&body).await { 64 63 Ok((r, b)) => (r, b), 65 64 Err(ImportError::InvalidRootCount) => { 66 - return ApiError::InvalidRequest("Expected exactly one root in CAR file".into()) 67 - .into_response(); 65 + return Err(ApiError::InvalidRequest( 66 + "Expected exactly one root in CAR file".into(), 67 + )); 68 68 } 69 69 Err(ImportError::CarParse(msg)) => { 70 - return ApiError::InvalidRequest(format!("Failed to parse CAR file: {}", msg)) 71 - .into_response(); 70 + return Err(ApiError::InvalidRequest(format!( 71 + "Failed to parse CAR file: {}", 72 + msg 73 + ))); 72 74 } 73 75 Err(e) => { 74 76 error!("CAR parsing error: {:?}", e); 75 - return ApiError::InvalidRequest(format!("Invalid CAR file: {}", e)).into_response(); 77 + return Err(ApiError::InvalidRequest(format!("Invalid CAR file: {}", e))); 76 78 } 77 79 }; 78 80 info!( ··· 82 84 root 83 85 ); 84 86 let Some(root_block) = blocks.get(&root) else { 85 - return ApiError::InvalidRequest("Root block not found in CAR file".into()).into_response(); 87 + return Err(ApiError::InvalidRequest( 88 + "Root block not found in CAR file".into(), 89 + )); 86 90 }; 87 91 let commit_did = match jacquard_repo::commit::Commit::from_cbor(root_block) { 88 92 Ok(commit) => commit.did().to_string(), 89 93 Err(e) => { 90 - return ApiError::InvalidRequest(format!("Invalid commit: {}", e)).into_response(); 94 + return Err(ApiError::InvalidRequest(format!("Invalid commit: {}", e))); 91 95 } 92 96 }; 93 97 if commit_did != *did { 94 - return ApiError::InvalidRepo(format!( 98 + return Err(ApiError::InvalidRepo(format!( 95 99 "CAR file is for DID {} but you are authenticated as {}", 96 100 commit_did, did 97 - )) 98 - .into_response(); 101 + ))); 99 102 } 100 103 let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION") 101 104 .map(|v| v == "true" || v == "1") ··· 117 120 commit_did, 118 121 expected_did, 119 122 }) => { 120 - return ApiError::InvalidRepo(format!( 123 + return Err(ApiError::InvalidRepo(format!( 121 124 "CAR file is for DID {} but you are authenticated as {}", 122 125 commit_did, expected_did 123 - )) 124 - .into_response(); 126 + ))); 125 127 } 126 128 Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => { 127 - return ApiError::InvalidRequest(format!("MST validation failed: {}", msg)) 128 - .into_response(); 129 + return Err(ApiError::InvalidRequest(format!( 130 + "MST validation failed: {}", 131 + msg 132 + ))); 129 133 } 130 134 Err(e) => { 131 135 error!("CAR structure verification error: {:?}", e); 132 - return ApiError::InvalidRequest(format!("CAR verification failed: {}", e)) 133 - .into_response(); 136 + return Err(ApiError::InvalidRequest(format!( 137 + "CAR verification failed: {}", 138 + e 139 + ))); 134 140 } 135 141 } 136 142 } else { ··· 147 153 commit_did, 148 154 expected_did, 149 155 }) => { 150 - return ApiError::InvalidRepo(format!( 156 + return Err(ApiError::InvalidRepo(format!( 151 157 "CAR file is for DID {} but you are authenticated as {}", 152 158 commit_did, expected_did 153 - )) 154 - .into_response(); 159 + ))); 155 160 } 156 161 Err(crate::sync::verify::VerifyError::InvalidSignature) => { 157 - return ApiError::InvalidRequest( 162 + return Err(ApiError::InvalidRequest( 158 163 "CAR file commit signature verification failed".into(), 159 - ) 160 - .into_response(); 164 + )); 161 165 } 162 166 Err(crate::sync::verify::VerifyError::DidResolutionFailed(msg)) => { 163 167 warn!("DID resolution failed during import verification: {}", msg); 164 - return ApiError::InvalidRequest(format!("Failed to verify DID: {}", msg)) 165 - .into_response(); 168 + return Err(ApiError::InvalidRequest(format!( 169 + "Failed to verify DID: {}", 170 + msg 171 + ))); 166 172 } 167 173 Err(crate::sync::verify::VerifyError::NoSigningKey) => { 168 - return ApiError::InvalidRequest( 174 + return Err(ApiError::InvalidRequest( 169 175 "DID document does not contain a signing key".into(), 170 - ) 171 - .into_response(); 176 + )); 172 177 } 173 178 Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => { 174 - return ApiError::InvalidRequest(format!("MST validation failed: {}", msg)) 175 - .into_response(); 179 + return Err(ApiError::InvalidRequest(format!( 180 + "MST validation failed: {}", 181 + msg 182 + ))); 176 183 } 177 184 Err(e) => { 178 185 error!("CAR verification error: {:?}", e); 179 - return ApiError::InvalidRequest(format!("CAR verification failed: {}", e)) 180 - .into_response(); 186 + return Err(ApiError::InvalidRequest(format!( 187 + "CAR verification failed: {}", 188 + e 189 + ))); 181 190 } 182 191 } 183 192 } ··· 227 236 } 228 237 } 229 238 } 230 - let key_row = match state.user_repo.get_user_with_key_by_did(did).await { 231 - Ok(Some(row)) => row, 232 - Ok(None) => { 239 + let key_row = state 240 + .user_repo 241 + .get_user_with_key_by_did(did) 242 + .await 243 + .map_err(|e| { 244 + error!("DB error fetching signing key: {:?}", e); 245 + ApiError::InternalError(None) 246 + })? 247 + .ok_or_else(|| { 233 248 error!("No signing key found for user {}", did); 234 - return ApiError::InternalError(Some("Signing key not found".into())) 235 - .into_response(); 236 - } 237 - Err(e) => { 238 - error!("DB error fetching signing key: {:?}", e); 239 - return ApiError::InternalError(None).into_response(); 240 - } 241 - }; 249 + ApiError::InternalError(Some("Signing key not found".into())) 250 + })?; 242 251 let key_bytes = 243 - match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) { 244 - Ok(k) => k, 245 - Err(e) => { 252 + crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 253 + .map_err(|e| { 246 254 error!("Failed to decrypt signing key: {}", e); 247 - return ApiError::InternalError(None).into_response(); 248 - } 249 - }; 250 - let signing_key = match SigningKey::from_slice(&key_bytes) { 251 - Ok(k) => k, 252 - Err(e) => { 253 - error!("Invalid signing key: {:?}", e); 254 - return ApiError::InternalError(None).into_response(); 255 - } 256 - }; 255 + ApiError::InternalError(None) 256 + })?; 257 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 258 + error!("Invalid signing key: {:?}", e); 259 + ApiError::InternalError(None) 260 + })?; 257 261 let new_rev = Tid::now(LimitedU32::MIN); 258 262 let new_rev_str = new_rev.to_string(); 259 - let (commit_bytes, _sig) = match create_signed_commit( 263 + let (commit_bytes, _sig) = create_signed_commit( 260 264 did, 261 265 import_result.data_cid, 262 266 &new_rev_str, 263 267 None, 264 268 &signing_key, 265 - ) { 266 - Ok(result) => result, 267 - Err(e) => { 268 - error!("Failed to create new commit: {}", e); 269 - return ApiError::InternalError(None).into_response(); 270 - } 271 - }; 272 - let new_root_cid: cid::Cid = match state.block_store.put(&commit_bytes).await { 273 - Ok(cid) => cid, 274 - Err(e) => { 269 + ) 270 + .map_err(|e| { 271 + error!("Failed to create new commit: {}", e); 272 + ApiError::InternalError(None) 273 + })?; 274 + let new_root_cid: cid::Cid = 275 + state.block_store.put(&commit_bytes).await.map_err(|e| { 275 276 error!("Failed to store new commit block: {:?}", e); 276 - return ApiError::InternalError(None).into_response(); 277 - } 278 - }; 277 + ApiError::InternalError(None) 278 + })?; 279 279 let new_root_cid_link = CidLink::new_unchecked(new_root_cid.to_string()); 280 - if let Err(e) = state 280 + state 281 281 .repo_repo 282 282 .update_repo_root(user_id, &new_root_cid_link, &new_rev_str) 283 283 .await 284 - { 285 - error!("Failed to update repo root: {:?}", e); 286 - return ApiError::InternalError(None).into_response(); 287 - } 284 + .map_err(|e| { 285 + error!("Failed to update repo root: {:?}", e); 286 + ApiError::InternalError(None) 287 + })?; 288 288 let mut all_block_cids: Vec<Vec<u8>> = blocks.keys().map(|c| c.to_bytes()).collect(); 289 289 all_block_cids.push(new_root_cid.to_bytes()); 290 - if let Err(e) = state 290 + state 291 291 .repo_repo 292 292 .insert_user_blocks(user_id, &all_block_cids, &new_rev_str) 293 293 .await 294 - { 295 - error!("Failed to insert user_blocks: {:?}", e); 296 - return ApiError::InternalError(None).into_response(); 297 - } 294 + .map_err(|e| { 295 + error!("Failed to insert user_blocks: {:?}", e); 296 + ApiError::InternalError(None) 297 + })?; 298 298 let new_root_str = new_root_cid.to_string(); 299 299 info!( 300 300 "Created new commit for imported repo: cid={}, rev={}", ··· 324 324 ); 325 325 } 326 326 } 327 - EmptyResponse::ok().into_response() 327 + Ok(EmptyResponse::ok().into_response()) 328 328 } 329 - Err(ImportError::SizeLimitExceeded) => { 330 - ApiError::PayloadTooLarge(format!("Import exceeds block limit of {}", max_blocks)) 331 - .into_response() 332 - } 333 - Err(ImportError::RepoNotFound) => { 334 - ApiError::RepoNotFound(Some("Repository not initialized for this account".into())) 335 - .into_response() 336 - } 337 - Err(ImportError::InvalidCbor(msg)) => { 338 - ApiError::InvalidRequest(format!("Invalid CBOR data: {}", msg)).into_response() 339 - } 340 - Err(ImportError::InvalidCommit(msg)) => { 341 - ApiError::InvalidRequest(format!("Invalid commit structure: {}", msg)).into_response() 342 - } 343 - Err(ImportError::BlockNotFound(cid)) => { 344 - ApiError::InvalidRequest(format!("Referenced block not found in CAR: {}", cid)) 345 - .into_response() 346 - } 347 - Err(ImportError::ConcurrentModification) => ApiError::InvalidSwap(Some( 329 + Err(ImportError::SizeLimitExceeded) => Err(ApiError::PayloadTooLarge(format!( 330 + "Import exceeds block limit of {}", 331 + max_blocks 332 + ))), 333 + Err(ImportError::RepoNotFound) => Err(ApiError::RepoNotFound(Some( 334 + "Repository not initialized for this account".into(), 335 + ))), 336 + Err(ImportError::InvalidCbor(msg)) => Err(ApiError::InvalidRequest(format!( 337 + "Invalid CBOR data: {}", 338 + msg 339 + ))), 340 + Err(ImportError::InvalidCommit(msg)) => Err(ApiError::InvalidRequest(format!( 341 + "Invalid commit structure: {}", 342 + msg 343 + ))), 344 + Err(ImportError::BlockNotFound(cid)) => Err(ApiError::InvalidRequest(format!( 345 + "Referenced block not found in CAR: {}", 346 + cid 347 + ))), 348 + Err(ImportError::ConcurrentModification) => Err(ApiError::InvalidSwap(Some( 348 349 "Repository is being modified by another operation, please retry".into(), 349 - )) 350 - .into_response(), 351 - Err(ImportError::VerificationFailed(ve)) => { 352 - ApiError::InvalidRequest(format!("CAR verification failed: {}", ve)).into_response() 353 - } 354 - Err(ImportError::DidMismatch { car_did, auth_did }) => ApiError::InvalidRequest(format!( 355 - "CAR is for {} but authenticated as {}", 356 - car_did, auth_did 357 - )) 358 - .into_response(), 350 + ))), 351 + Err(ImportError::VerificationFailed(ve)) => Err(ApiError::InvalidRequest(format!( 352 + "CAR verification failed: {}", 353 + ve 354 + ))), 355 + Err(ImportError::DidMismatch { car_did, auth_did }) => Err(ApiError::InvalidRequest( 356 + format!("CAR is for {} but authenticated as {}", car_did, auth_did), 357 + )), 359 358 Err(e) => { 360 359 error!("Import error: {:?}", e); 361 - ApiError::InternalError(None).into_response() 360 + Err(ApiError::InternalError(None)) 362 361 } 363 362 } 364 363 }
+61 -62
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record_with_status; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 4 - use crate::auth::BearerAuth; 4 + use crate::auth::RequiredAuth; 5 5 use crate::delegation::DelegationActionType; 6 6 use crate::repo::tracking::TrackingBlockStore; 7 7 use crate::state::AppState; ··· 262 262 263 263 pub async fn apply_writes( 264 264 State(state): State<AppState>, 265 - auth: BearerAuth, 265 + auth: RequiredAuth, 266 266 Json(input): Json<ApplyWritesInput>, 267 - ) -> Response { 267 + ) -> Result<Response, ApiError> { 268 268 info!( 269 269 "apply_writes called: repo={}, writes={}", 270 270 input.repo, 271 271 input.writes.len() 272 272 ); 273 - let auth_user = auth.0; 273 + let auth_user = auth.0.require_user()?.require_active()?; 274 274 let did = auth_user.did.clone(); 275 275 let is_oauth = auth_user.is_oauth; 276 - let scope = auth_user.scope; 276 + let scope = auth_user.scope.clone(); 277 277 let controller_did = auth_user.controller_did.clone(); 278 278 if input.repo.as_str() != did { 279 - return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 280 - .into_response(); 279 + return Err(ApiError::InvalidRepo( 280 + "Repo does not match authenticated user".into(), 281 + )); 281 282 } 282 283 if state 283 284 .user_repo ··· 285 286 .await 286 287 .unwrap_or(false) 287 288 { 288 - return ApiError::AccountMigrated.into_response(); 289 + return Err(ApiError::AccountMigrated); 289 290 } 290 291 let is_verified = state 291 292 .user_repo ··· 298 299 .await 299 300 .unwrap_or(false); 300 301 if !is_verified && !is_delegated { 301 - return ApiError::AccountNotVerified.into_response(); 302 + return Err(ApiError::AccountNotVerified); 302 303 } 303 304 if input.writes.is_empty() { 304 - return ApiError::InvalidRequest("writes array is empty".into()).into_response(); 305 + return Err(ApiError::InvalidRequest("writes array is empty".into())); 305 306 } 306 307 if input.writes.len() > MAX_BATCH_WRITES { 307 - return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES)) 308 - .into_response(); 308 + return Err(ApiError::InvalidRequest(format!( 309 + "Too many writes (max {})", 310 + MAX_BATCH_WRITES 311 + ))); 309 312 } 310 313 311 314 let has_custom_scope = scope ··· 374 377 }) 375 378 .next() 376 379 { 377 - return err; 380 + return Ok(err); 378 381 } 379 382 } 380 383 381 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&did).await { 382 - Ok(Some(id)) => id, 383 - _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 384 - }; 385 - let root_cid_str = match state.repo_repo.get_repo_root_cid_by_user_id(user_id).await { 386 - Ok(Some(cid_str)) => cid_str, 387 - _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), 388 - }; 389 - let current_root_cid = match Cid::from_str(&root_cid_str) { 390 - Ok(c) => c, 391 - Err(_) => { 392 - return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 393 - } 394 - }; 384 + let user_id: uuid::Uuid = state 385 + .user_repo 386 + .get_id_by_did(&did) 387 + .await 388 + .ok() 389 + .flatten() 390 + .ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?; 391 + let root_cid_str = state 392 + .repo_repo 393 + .get_repo_root_cid_by_user_id(user_id) 394 + .await 395 + .ok() 396 + .flatten() 397 + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; 398 + let current_root_cid = Cid::from_str(&root_cid_str) 399 + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?; 395 400 if let Some(swap_commit) = &input.swap_commit 396 401 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 397 402 { 398 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 403 + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 399 404 } 400 405 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 401 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 402 - Ok(Some(b)) => b, 403 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 404 - }; 405 - let commit = match Commit::from_cbor(&commit_bytes) { 406 - Ok(c) => c, 407 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 408 - }; 406 + let commit_bytes = tracking_store 407 + .get(&current_root_cid) 408 + .await 409 + .ok() 410 + .flatten() 411 + .ok_or_else(|| ApiError::InternalError(Some("Commit block not found".into())))?; 412 + let commit = Commit::from_cbor(&commit_bytes) 413 + .map_err(|_| ApiError::InternalError(Some("Failed to parse commit".into())))?; 409 414 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 410 415 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 411 416 let WriteAccumulator { ··· 424 429 .await 425 430 { 426 431 Ok(acc) => acc, 427 - Err(response) => return response, 432 + Err(response) => return Ok(response), 428 433 }; 429 - let new_mst_root = match mst.persist().await { 430 - Ok(c) => c, 431 - Err(_) => { 432 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 433 - } 434 - }; 434 + let new_mst_root = mst 435 + .persist() 436 + .await 437 + .map_err(|_| ApiError::InternalError(Some("Failed to persist MST".into())))?; 435 438 let (new_mst_blocks, old_mst_blocks) = { 436 439 let mut new_blocks = std::collections::BTreeMap::new(); 437 440 let mut old_blocks = std::collections::BTreeMap::new(); 438 441 for key in &modified_keys { 439 - if mst.blocks_for_path(key, &mut new_blocks).await.is_err() { 440 - return ApiError::InternalError(Some( 441 - "Failed to get new MST blocks for path".into(), 442 - )) 443 - .into_response(); 444 - } 445 - if original_mst 442 + mst.blocks_for_path(key, &mut new_blocks) 443 + .await 444 + .map_err(|_| { 445 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 446 + })?; 447 + original_mst 446 448 .blocks_for_path(key, &mut old_blocks) 447 449 .await 448 - .is_err() 449 - { 450 - return ApiError::InternalError(Some( 451 - "Failed to get old MST blocks for path".into(), 452 - )) 453 - .into_response(); 454 - } 450 + .map_err(|_| { 451 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 452 + })?; 455 453 } 456 454 (new_blocks, old_blocks) 457 455 }; ··· 503 501 { 504 502 Ok(res) => res, 505 503 Err(e) if e.contains("ConcurrentModification") => { 506 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 504 + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 507 505 } 508 506 Err(e) => { 509 507 error!("Commit failed: {}", e); 510 - return ApiError::InternalError(Some("Failed to commit changes".into())) 511 - .into_response(); 508 + return Err(ApiError::InternalError(Some( 509 + "Failed to commit changes".into(), 510 + ))); 512 511 } 513 512 }; 514 513 ··· 557 556 .await; 558 557 } 559 558 560 - ( 559 + Ok(( 561 560 StatusCode::OK, 562 561 Json(ApplyWritesOutput { 563 562 commit: CommitInfo { ··· 567 566 results, 568 567 }), 569 568 ) 570 - .into_response() 569 + .into_response()) 571 570 }
+42 -32
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 + use crate::auth::RequiredAuth; 4 5 use crate::delegation::DelegationActionType; 5 6 use crate::repo::tracking::TrackingBlockStore; 6 7 use crate::state::AppState; ··· 8 9 use axum::{ 9 10 Json, 10 11 extract::State, 11 - http::{HeaderMap, StatusCode}, 12 + http::StatusCode, 12 13 response::{IntoResponse, Response}, 13 14 }; 14 15 use cid::Cid; ··· 39 40 40 41 pub async fn delete_record( 41 42 State(state): State<AppState>, 42 - headers: HeaderMap, 43 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 43 + auth: RequiredAuth, 44 44 Json(input): Json<DeleteRecordInput>, 45 - ) -> Response { 46 - let auth = match prepare_repo_write( 47 - &state, 48 - &headers, 49 - &input.repo, 50 - "POST", 51 - &crate::util::build_full_url(&uri.to_string()), 52 - ) 53 - .await 54 - { 45 + ) -> Result<Response, crate::api::error::ApiError> { 46 + let user = auth.0.require_user()?.require_active()?; 47 + let auth = match prepare_repo_write(&state, user, &input.repo).await { 55 48 Ok(res) => res, 56 - Err(err_res) => return err_res, 49 + Err(err_res) => return Ok(err_res), 57 50 }; 58 51 59 52 if let Err(e) = crate::auth::scope_check::check_repo_scope( ··· 62 55 crate::oauth::RepoAction::Delete, 63 56 &input.collection, 64 57 ) { 65 - return e; 58 + return Ok(e); 66 59 } 67 60 68 61 let did = auth.did; ··· 73 66 if let Some(swap_commit) = &input.swap_commit 74 67 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 75 68 { 76 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 69 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 77 70 } 78 71 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 79 72 let commit_bytes = match tracking_store.get(&current_root_cid).await { 80 73 Ok(Some(b)) => b, 81 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 74 + _ => { 75 + return Ok( 76 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 77 + ); 78 + } 82 79 }; 83 80 let commit = match Commit::from_cbor(&commit_bytes) { 84 81 Ok(c) => c, 85 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 82 + _ => { 83 + return Ok( 84 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 85 + ); 86 + } 86 87 }; 87 88 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 88 89 let key = format!("{}/{}", input.collection, input.rkey); ··· 90 91 let expected_cid = Cid::from_str(swap_record_str).ok(); 91 92 let actual_cid = mst.get(&key).await.ok().flatten(); 92 93 if expected_cid != actual_cid { 93 - return ApiError::InvalidSwap(Some( 94 + return Ok(ApiError::InvalidSwap(Some( 94 95 "Record has been modified or does not exist".into(), 95 96 )) 96 - .into_response(); 97 + .into_response()); 97 98 } 98 99 } 99 100 let prev_record_cid = mst.get(&key).await.ok().flatten(); 100 101 if prev_record_cid.is_none() { 101 - return (StatusCode::OK, Json(DeleteRecordOutput { commit: None })).into_response(); 102 + return Ok((StatusCode::OK, Json(DeleteRecordOutput { commit: None })).into_response()); 102 103 } 103 104 let new_mst = match mst.delete(&key).await { 104 105 Ok(m) => m, 105 106 Err(e) => { 106 107 error!("Failed to delete from MST: {:?}", e); 107 - return ApiError::InternalError(Some(format!("Failed to delete from MST: {:?}", e))) 108 - .into_response(); 108 + return Ok(ApiError::InternalError(Some(format!( 109 + "Failed to delete from MST: {:?}", 110 + e 111 + ))) 112 + .into_response()); 109 113 } 110 114 }; 111 115 let new_mst_root = match new_mst.persist().await { 112 116 Ok(c) => c, 113 117 Err(e) => { 114 118 error!("Failed to persist MST: {:?}", e); 115 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 119 + return Ok( 120 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 121 + ); 116 122 } 117 123 }; 118 124 let collection_for_audit = input.collection.to_string(); ··· 129 135 .await 130 136 .is_err() 131 137 { 132 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 133 - .into_response(); 138 + return Ok( 139 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 140 + .into_response(), 141 + ); 134 142 } 135 143 if mst 136 144 .blocks_for_path(&key, &mut old_mst_blocks) 137 145 .await 138 146 .is_err() 139 147 { 140 - return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 141 - .into_response(); 148 + return Ok( 149 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 150 + .into_response(), 151 + ); 142 152 } 143 153 let mut relevant_blocks = new_mst_blocks.clone(); 144 154 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); ··· 177 187 { 178 188 Ok(res) => res, 179 189 Err(e) if e.contains("ConcurrentModification") => { 180 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 190 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 181 191 } 182 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 192 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 183 193 }; 184 194 185 195 if let Some(ref controller) = controller_did { ··· 210 220 error!("Failed to remove backlinks for {}: {}", deleted_uri, e); 211 221 } 212 222 213 - ( 223 + Ok(( 214 224 StatusCode::OK, 215 225 Json(DeleteRecordOutput { 216 226 commit: Some(CommitInfo { ··· 219 229 }), 220 230 }), 221 231 ) 222 - .into_response() 232 + .into_response()) 223 233 } 224 234 225 235 use crate::types::Did;
+101 -106
crates/tranquil-pds/src/api/repo/record/write.rs
··· 3 3 use crate::api::repo::record::utils::{ 4 4 CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 5 5 }; 6 + use crate::auth::RequiredAuth; 6 7 use crate::delegation::DelegationActionType; 7 8 use crate::repo::tracking::TrackingBlockStore; 8 9 use crate::state::AppState; ··· 10 11 use axum::{ 11 12 Json, 12 13 extract::State, 13 - http::{HeaderMap, StatusCode}, 14 + http::StatusCode, 14 15 response::{IntoResponse, Response}, 15 16 }; 16 17 use cid::Cid; ··· 33 34 34 35 pub async fn prepare_repo_write( 35 36 state: &AppState, 36 - headers: &HeaderMap, 37 + auth_user: &crate::auth::AuthenticatedUser, 37 38 repo: &AtIdentifier, 38 - http_method: &str, 39 - http_uri: &str, 40 39 ) -> Result<RepoWriteAuth, Response> { 41 - let extracted = crate::auth::extract_auth_token_from_header( 42 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 43 - ) 44 - .ok_or_else(|| ApiError::AuthenticationRequired.into_response())?; 45 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 46 - let auth_user = crate::auth::validate_token_with_dpop( 47 - state.user_repo.as_ref(), 48 - state.oauth_repo.as_ref(), 49 - &extracted.token, 50 - extracted.is_dpop, 51 - dpop_proof, 52 - http_method, 53 - http_uri, 54 - false, 55 - false, 56 - ) 57 - .await 58 - .map_err(|e| { 59 - tracing::warn!(error = ?e, is_dpop = extracted.is_dpop, "Token validation failed in prepare_repo_write"); 60 - ApiError::from(e).into_response() 61 - })?; 62 40 if repo.as_str() != auth_user.did.as_str() { 63 41 return Err( 64 42 ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), ··· 113 91 user_id, 114 92 current_root_cid, 115 93 is_oauth: auth_user.is_oauth, 116 - scope: auth_user.scope, 94 + scope: auth_user.scope.clone(), 117 95 controller_did: auth_user.controller_did.clone(), 118 96 }) 119 97 } ··· 146 124 } 147 125 pub async fn create_record( 148 126 State(state): State<AppState>, 149 - headers: HeaderMap, 150 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 127 + auth: RequiredAuth, 151 128 Json(input): Json<CreateRecordInput>, 152 - ) -> Response { 153 - let auth = match prepare_repo_write( 154 - &state, 155 - &headers, 156 - &input.repo, 157 - "POST", 158 - &crate::util::build_full_url(&uri.to_string()), 159 - ) 160 - .await 161 - { 129 + ) -> Result<Response, crate::api::error::ApiError> { 130 + let user = auth.0.require_user()?.require_active()?; 131 + let auth = match prepare_repo_write(&state, user, &input.repo).await { 162 132 Ok(res) => res, 163 - Err(err_res) => return err_res, 133 + Err(err_res) => return Ok(err_res), 164 134 }; 165 135 166 136 if let Err(e) = crate::auth::scope_check::check_repo_scope( ··· 169 139 crate::oauth::RepoAction::Create, 170 140 &input.collection, 171 141 ) { 172 - return e; 142 + return Ok(e); 173 143 } 174 144 175 145 let did = auth.did; ··· 180 150 if let Some(swap_commit) = &input.swap_commit 181 151 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 182 152 { 183 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 153 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 184 154 } 185 155 186 156 let validation_status = if input.validate == Some(false) { ··· 194 164 require_lexicon, 195 165 ) { 196 166 Ok(status) => Some(status), 197 - Err(err_response) => return *err_response, 167 + Err(err_response) => return Ok(*err_response), 198 168 } 199 169 }; 200 170 let rkey = input.rkey.unwrap_or_else(Rkey::generate); ··· 202 172 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 203 173 let commit_bytes = match tracking_store.get(&current_root_cid).await { 204 174 Ok(Some(b)) => b, 205 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 175 + _ => { 176 + return Ok( 177 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 178 + ); 179 + } 206 180 }; 207 181 let commit = match Commit::from_cbor(&commit_bytes) { 208 182 Ok(c) => c, 209 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 183 + _ => { 184 + return Ok( 185 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 186 + ); 187 + } 210 188 }; 211 189 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 212 190 let initial_mst_root = commit.data; ··· 228 206 Ok(c) => c, 229 207 Err(e) => { 230 208 error!("Failed to check backlink conflicts: {}", e); 231 - return ApiError::InternalError(None).into_response(); 209 + return Ok(ApiError::InternalError(None).into_response()); 232 210 } 233 211 }; 234 212 ··· 281 259 let record_ipld = crate::util::json_to_ipld(&input.record); 282 260 let mut record_bytes = Vec::new(); 283 261 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 284 - return ApiError::InvalidRecord("Failed to serialize record".into()).into_response(); 262 + return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 285 263 } 286 264 let record_cid = match tracking_store.put(&record_bytes).await { 287 265 Ok(c) => c, 288 266 _ => { 289 - return ApiError::InternalError(Some("Failed to save record block".into())) 290 - .into_response(); 267 + return Ok( 268 + ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 269 + ); 291 270 } 292 271 }; 293 272 let key = format!("{}/{}", input.collection, rkey); ··· 302 281 303 282 let new_mst = match mst.add(&key, record_cid).await { 304 283 Ok(m) => m, 305 - _ => return ApiError::InternalError(Some("Failed to add to MST".into())).into_response(), 284 + _ => { 285 + return Ok(ApiError::InternalError(Some("Failed to add to MST".into())).into_response()); 286 + } 306 287 }; 307 288 let new_mst_root = match new_mst.persist().await { 308 289 Ok(c) => c, 309 - _ => return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 290 + _ => { 291 + return Ok( 292 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 293 + ); 294 + } 310 295 }; 311 296 312 297 ops.push(RecordOp::Create { ··· 321 306 .await 322 307 .is_err() 323 308 { 324 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 325 - .into_response(); 309 + return Ok( 310 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 311 + .into_response(), 312 + ); 326 313 } 327 314 328 315 let mut relevant_blocks = new_mst_blocks.clone(); ··· 364 351 { 365 352 Ok(res) => res, 366 353 Err(e) if e.contains("ConcurrentModification") => { 367 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 354 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 368 355 } 369 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 356 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 370 357 }; 371 358 372 359 for conflict_uri in conflict_uris_to_cleanup { ··· 406 393 error!("Failed to add backlinks for {}: {}", created_uri, e); 407 394 } 408 395 409 - ( 396 + Ok(( 410 397 StatusCode::OK, 411 398 Json(CreateRecordOutput { 412 399 uri: created_uri, ··· 418 405 validation_status: validation_status.map(|s| s.to_string()), 419 406 }), 420 407 ) 421 - .into_response() 408 + .into_response()) 422 409 } 423 410 #[derive(Deserialize)] 424 411 #[allow(dead_code)] ··· 445 432 } 446 433 pub async fn put_record( 447 434 State(state): State<AppState>, 448 - headers: HeaderMap, 449 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 435 + auth: RequiredAuth, 450 436 Json(input): Json<PutRecordInput>, 451 - ) -> Response { 452 - let auth = match prepare_repo_write( 453 - &state, 454 - &headers, 455 - &input.repo, 456 - "POST", 457 - &crate::util::build_full_url(&uri.to_string()), 458 - ) 459 - .await 460 - { 437 + ) -> Result<Response, crate::api::error::ApiError> { 438 + let user = auth.0.require_user()?.require_active()?; 439 + let auth = match prepare_repo_write(&state, user, &input.repo).await { 461 440 Ok(res) => res, 462 - Err(err_res) => return err_res, 441 + Err(err_res) => return Ok(err_res), 463 442 }; 464 443 465 444 if let Err(e) = crate::auth::scope_check::check_repo_scope( ··· 468 447 crate::oauth::RepoAction::Create, 469 448 &input.collection, 470 449 ) { 471 - return e; 450 + return Ok(e); 472 451 } 473 452 if let Err(e) = crate::auth::scope_check::check_repo_scope( 474 453 auth.is_oauth, ··· 476 455 crate::oauth::RepoAction::Update, 477 456 &input.collection, 478 457 ) { 479 - return e; 458 + return Ok(e); 480 459 } 481 460 482 461 let did = auth.did; ··· 487 466 if let Some(swap_commit) = &input.swap_commit 488 467 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 489 468 { 490 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 469 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 491 470 } 492 471 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 493 472 let commit_bytes = match tracking_store.get(&current_root_cid).await { 494 473 Ok(Some(b)) => b, 495 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 474 + _ => { 475 + return Ok( 476 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 477 + ); 478 + } 496 479 }; 497 480 let commit = match Commit::from_cbor(&commit_bytes) { 498 481 Ok(c) => c, 499 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 482 + _ => { 483 + return Ok( 484 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 485 + ); 486 + } 500 487 }; 501 488 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 502 489 let key = format!("{}/{}", input.collection, input.rkey); ··· 511 498 require_lexicon, 512 499 ) { 513 500 Ok(status) => Some(status), 514 - Err(err_response) => return *err_response, 501 + Err(err_response) => return Ok(*err_response), 515 502 } 516 503 }; 517 504 if let Some(swap_record_str) = &input.swap_record { 518 505 let expected_cid = Cid::from_str(swap_record_str).ok(); 519 506 let actual_cid = mst.get(&key).await.ok().flatten(); 520 507 if expected_cid != actual_cid { 521 - return ApiError::InvalidSwap(Some( 508 + return Ok(ApiError::InvalidSwap(Some( 522 509 "Record has been modified or does not exist".into(), 523 510 )) 524 - .into_response(); 511 + .into_response()); 525 512 } 526 513 } 527 514 let existing_cid = mst.get(&key).await.ok().flatten(); 528 515 let record_ipld = crate::util::json_to_ipld(&input.record); 529 516 let mut record_bytes = Vec::new(); 530 517 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 531 - return ApiError::InvalidRecord("Failed to serialize record".into()).into_response(); 518 + return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 532 519 } 533 520 let record_cid = match tracking_store.put(&record_bytes).await { 534 521 Ok(c) => c, 535 522 _ => { 536 - return ApiError::InternalError(Some("Failed to save record block".into())) 537 - .into_response(); 523 + return Ok( 524 + ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 525 + ); 538 526 } 539 527 }; 540 528 if existing_cid == Some(record_cid) { 541 - return ( 529 + return Ok(( 542 530 StatusCode::OK, 543 531 Json(PutRecordOutput { 544 532 uri: AtUri::from_parts(&did, &input.collection, &input.rkey), ··· 547 535 validation_status: validation_status.map(|s| s.to_string()), 548 536 }), 549 537 ) 550 - .into_response(); 538 + .into_response()); 551 539 } 552 - let new_mst = if existing_cid.is_some() { 553 - match mst.update(&key, record_cid).await { 554 - Ok(m) => m, 555 - Err(_) => { 556 - return ApiError::InternalError(Some("Failed to update MST".into())) 557 - .into_response(); 540 + let new_mst = 541 + if existing_cid.is_some() { 542 + match mst.update(&key, record_cid).await { 543 + Ok(m) => m, 544 + Err(_) => { 545 + return Ok(ApiError::InternalError(Some("Failed to update MST".into())) 546 + .into_response()); 547 + } 558 548 } 559 - } 560 - } else { 561 - match mst.add(&key, record_cid).await { 562 - Ok(m) => m, 563 - Err(_) => { 564 - return ApiError::InternalError(Some("Failed to add to MST".into())) 565 - .into_response(); 549 + } else { 550 + match mst.add(&key, record_cid).await { 551 + Ok(m) => m, 552 + Err(_) => { 553 + return Ok(ApiError::InternalError(Some("Failed to add to MST".into())) 554 + .into_response()); 555 + } 566 556 } 567 - } 568 - }; 557 + }; 569 558 let new_mst_root = match new_mst.persist().await { 570 559 Ok(c) => c, 571 560 Err(_) => { 572 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 561 + return Ok( 562 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 563 + ); 573 564 } 574 565 }; 575 566 let op = if existing_cid.is_some() { ··· 593 584 .await 594 585 .is_err() 595 586 { 596 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 597 - .into_response(); 587 + return Ok( 588 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 589 + .into_response(), 590 + ); 598 591 } 599 592 if mst 600 593 .blocks_for_path(&key, &mut old_mst_blocks) 601 594 .await 602 595 .is_err() 603 596 { 604 - return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 605 - .into_response(); 597 + return Ok( 598 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 599 + .into_response(), 600 + ); 606 601 } 607 602 let mut relevant_blocks = new_mst_blocks.clone(); 608 603 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); ··· 644 639 { 645 640 Ok(res) => res, 646 641 Err(e) if e.contains("ConcurrentModification") => { 647 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 642 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 648 643 } 649 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 644 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 650 645 }; 651 646 652 647 if let Some(ref controller) = controller_did { ··· 668 663 .await; 669 664 } 670 665 671 - ( 666 + Ok(( 672 667 StatusCode::OK, 673 668 Json(PutRecordOutput { 674 669 uri: AtUri::from_parts(&did, &input.collection, &input.rkey), ··· 680 675 validation_status: validation_status.map(|s| s.to_string()), 681 676 }), 682 677 ) 683 - .into_response() 678 + .into_response()) 684 679 }
+52 -160
crates/tranquil-pds/src/api/server/account_status.rs
··· 40 40 41 41 pub async fn check_account_status( 42 42 State(state): State<AppState>, 43 - headers: axum::http::HeaderMap, 44 - ) -> Response { 45 - let extracted = match crate::auth::extract_auth_token_from_header( 46 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 47 - ) { 48 - Some(t) => t, 49 - None => return ApiError::AuthenticationRequired.into_response(), 50 - }; 51 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 52 - let http_uri = format!( 53 - "https://{}/xrpc/com.atproto.server.checkAccountStatus", 54 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 55 - ); 56 - let did = match crate::auth::validate_token_with_dpop( 57 - state.user_repo.as_ref(), 58 - state.oauth_repo.as_ref(), 59 - &extracted.token, 60 - extracted.is_dpop, 61 - dpop_proof, 62 - "GET", 63 - &http_uri, 64 - true, 65 - false, 66 - ) 67 - .await 68 - { 69 - Ok(user) => user.did, 70 - Err(e) => return ApiError::from(e).into_response(), 71 - }; 72 - let user_id = match state.user_repo.get_id_by_did(&did).await { 73 - Ok(Some(id)) => id, 74 - _ => { 75 - return ApiError::InternalError(None).into_response(); 76 - } 77 - }; 43 + auth: crate::auth::RequiredAuth, 44 + ) -> Result<Response, ApiError> { 45 + let user = auth.0.require_user()?.require_not_takendown()?; 46 + let did = &user.did; 47 + let user_id = state 48 + .user_repo 49 + .get_id_by_did(did) 50 + .await 51 + .map_err(|_| ApiError::InternalError(None))? 52 + .ok_or(ApiError::InternalError(None))?; 78 53 let is_active = state 79 54 .user_repo 80 - .is_account_active_by_did(&did) 55 + .is_account_active_by_did(did) 81 56 .await 82 57 .ok() 83 58 .flatten() ··· 121 96 .await 122 97 .unwrap_or(0); 123 98 let valid_did = 124 - is_valid_did_for_service(state.user_repo.as_ref(), state.cache.clone(), &did).await; 125 - ( 99 + is_valid_did_for_service(state.user_repo.as_ref(), state.cache.clone(), did).await; 100 + Ok(( 126 101 StatusCode::OK, 127 102 Json(CheckAccountStatusOutput { 128 103 activated: is_active, ··· 136 111 imported_blobs, 137 112 }), 138 113 ) 139 - .into_response() 114 + .into_response()) 140 115 } 141 116 142 117 async fn is_valid_did_for_service( ··· 331 306 332 307 pub async fn activate_account( 333 308 State(state): State<AppState>, 334 - headers: axum::http::HeaderMap, 335 - ) -> Response { 309 + auth: crate::auth::RequiredAuth, 310 + ) -> Result<Response, ApiError> { 336 311 info!("[MIGRATION] activateAccount called"); 337 - let extracted = match crate::auth::extract_auth_token_from_header( 338 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 339 - ) { 340 - Some(t) => t, 341 - None => { 342 - info!("[MIGRATION] activateAccount: No auth token"); 343 - return ApiError::AuthenticationRequired.into_response(); 344 - } 345 - }; 346 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 347 - let http_uri = format!( 348 - "https://{}/xrpc/com.atproto.server.activateAccount", 349 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 350 - ); 351 - let auth_user = match crate::auth::validate_token_with_dpop( 352 - state.user_repo.as_ref(), 353 - state.oauth_repo.as_ref(), 354 - &extracted.token, 355 - extracted.is_dpop, 356 - dpop_proof, 357 - "POST", 358 - &http_uri, 359 - true, 360 - false, 361 - ) 362 - .await 363 - { 364 - Ok(user) => user, 365 - Err(e) => { 366 - info!("[MIGRATION] activateAccount: Auth failed: {:?}", e); 367 - return ApiError::from(e).into_response(); 368 - } 369 - }; 312 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 370 313 info!( 371 314 "[MIGRATION] activateAccount: Authenticated user did={}", 372 315 auth_user.did ··· 379 322 crate::oauth::scopes::AccountAction::Manage, 380 323 ) { 381 324 info!("[MIGRATION] activateAccount: Scope check failed"); 382 - return e; 325 + return Ok(e); 383 326 } 384 327 385 - let did = auth_user.did; 328 + let did = auth_user.did.clone(); 386 329 387 330 info!( 388 331 "[MIGRATION] activateAccount: Validating DID document for did={}", ··· 402 345 did, 403 346 did_validation_start.elapsed() 404 347 ); 405 - return e.into_response(); 348 + return Err(e); 406 349 } 407 350 info!( 408 351 "[MIGRATION] activateAccount: DID document validation SUCCESS for {} (took {:?})", ··· 508 451 ); 509 452 } 510 453 info!("[MIGRATION] activateAccount: SUCCESS for did={}", did); 511 - EmptyResponse::ok().into_response() 454 + Ok(EmptyResponse::ok().into_response()) 512 455 } 513 456 Err(e) => { 514 457 error!( 515 458 "[MIGRATION] activateAccount: DB error activating account: {:?}", 516 459 e 517 460 ); 518 - ApiError::InternalError(None).into_response() 461 + Err(ApiError::InternalError(None)) 519 462 } 520 463 } 521 464 } ··· 528 471 529 472 pub async fn deactivate_account( 530 473 State(state): State<AppState>, 531 - headers: axum::http::HeaderMap, 474 + auth: crate::auth::RequiredAuth, 532 475 Json(input): Json<DeactivateAccountInput>, 533 - ) -> Response { 534 - let extracted = match crate::auth::extract_auth_token_from_header( 535 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 536 - ) { 537 - Some(t) => t, 538 - None => return ApiError::AuthenticationRequired.into_response(), 539 - }; 540 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 541 - let http_uri = format!( 542 - "https://{}/xrpc/com.atproto.server.deactivateAccount", 543 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 544 - ); 545 - let auth_user = match crate::auth::validate_token_with_dpop( 546 - state.user_repo.as_ref(), 547 - state.oauth_repo.as_ref(), 548 - &extracted.token, 549 - extracted.is_dpop, 550 - dpop_proof, 551 - "POST", 552 - &http_uri, 553 - false, 554 - false, 555 - ) 556 - .await 557 - { 558 - Ok(user) => user, 559 - Err(e) => return ApiError::from(e).into_response(), 560 - }; 476 + ) -> Result<Response, ApiError> { 477 + let auth_user = auth.0.require_user()?.require_active()?; 561 478 562 479 if let Err(e) = crate::auth::scope_check::check_account_scope( 563 480 auth_user.is_oauth, ··· 565 482 crate::oauth::scopes::AccountAttr::Repo, 566 483 crate::oauth::scopes::AccountAction::Manage, 567 484 ) { 568 - return e; 485 + return Ok(e); 569 486 } 570 487 571 488 let delete_after: Option<chrono::DateTime<chrono::Utc>> = input ··· 574 491 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok()) 575 492 .map(|dt| dt.with_timezone(&chrono::Utc)); 576 493 577 - let did = auth_user.did; 494 + let did = auth_user.did.clone(); 578 495 579 496 let handle = state.user_repo.get_handle_by_did(&did).await.ok().flatten(); 580 497 ··· 595 512 { 596 513 warn!("Failed to sequence account deactivated event: {}", e); 597 514 } 598 - EmptyResponse::ok().into_response() 515 + Ok(EmptyResponse::ok().into_response()) 599 516 } 600 - Ok(false) => EmptyResponse::ok().into_response(), 517 + Ok(false) => Ok(EmptyResponse::ok().into_response()), 601 518 Err(e) => { 602 519 error!("DB error deactivating account: {:?}", e); 603 - ApiError::InternalError(None).into_response() 520 + Err(ApiError::InternalError(None)) 604 521 } 605 522 } 606 523 } 607 524 608 525 pub async fn request_account_delete( 609 526 State(state): State<AppState>, 610 - headers: axum::http::HeaderMap, 611 - ) -> Response { 612 - let extracted = match crate::auth::extract_auth_token_from_header( 613 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 614 - ) { 615 - Some(t) => t, 616 - None => return ApiError::AuthenticationRequired.into_response(), 617 - }; 618 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 619 - let http_uri = format!( 620 - "https://{}/xrpc/com.atproto.server.requestAccountDelete", 621 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 622 - ); 623 - let validated = match crate::auth::validate_token_with_dpop( 624 - state.user_repo.as_ref(), 625 - state.oauth_repo.as_ref(), 626 - &extracted.token, 627 - extracted.is_dpop, 628 - dpop_proof, 629 - "POST", 630 - &http_uri, 631 - true, 632 - false, 633 - ) 634 - .await 635 - { 636 - Ok(user) => user, 637 - Err(e) => return ApiError::from(e).into_response(), 638 - }; 639 - let did = validated.did.clone(); 527 + auth: crate::auth::RequiredAuth, 528 + ) -> Result<Response, ApiError> { 529 + let user = auth.0.require_user()?.require_not_takendown()?; 530 + let did = &user.did; 640 531 641 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &did).await { 642 - return crate::api::server::reauth::legacy_mfa_required_response( 532 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await { 533 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 643 534 &*state.user_repo, 644 535 &*state.session_repo, 645 - &did, 536 + did, 646 537 ) 647 - .await; 538 + .await); 648 539 } 649 540 650 - let user_id = match state.user_repo.get_id_by_did(&did).await { 651 - Ok(Some(id)) => id, 652 - _ => { 653 - return ApiError::InternalError(None).into_response(); 654 - } 655 - }; 541 + let user_id = state 542 + .user_repo 543 + .get_id_by_did(did) 544 + .await 545 + .ok() 546 + .flatten() 547 + .ok_or(ApiError::InternalError(None))?; 656 548 let confirmation_token = Uuid::new_v4().to_string(); 657 549 let expires_at = Utc::now() + Duration::minutes(15); 658 - if let Err(e) = state 550 + state 659 551 .infra_repo 660 - .create_deletion_request(&confirmation_token, &did, expires_at) 552 + .create_deletion_request(&confirmation_token, did, expires_at) 661 553 .await 662 - { 663 - error!("DB error creating deletion token: {:?}", e); 664 - return ApiError::InternalError(None).into_response(); 665 - } 554 + .map_err(|e| { 555 + error!("DB error creating deletion token: {:?}", e); 556 + ApiError::InternalError(None) 557 + })?; 666 558 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 667 559 if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion( 668 560 state.user_repo.as_ref(), ··· 676 568 warn!("Failed to enqueue account deletion notification: {:?}", e); 677 569 } 678 570 info!("Account deletion requested for user {}", did); 679 - EmptyResponse::ok().into_response() 571 + Ok(EmptyResponse::ok().into_response()) 680 572 } 681 573 682 574 #[derive(Deserialize)]
+124 -117
crates/tranquil-pds/src/api/server/app_password.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::{BearerAuth, generate_app_password}; 3 + use crate::auth::{RequiredAuth, generate_app_password}; 4 4 use crate::delegation::{DelegationActionType, intersect_scopes}; 5 5 use crate::state::{AppState, RateLimitKind}; 6 6 use axum::{ ··· 33 33 34 34 pub async fn list_app_passwords( 35 35 State(state): State<AppState>, 36 - BearerAuth(auth_user): BearerAuth, 37 - ) -> Response { 38 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 39 - Ok(Some(u)) => u, 40 - Ok(None) => return ApiError::AccountNotFound.into_response(), 41 - Err(e) => { 36 + auth: RequiredAuth, 37 + ) -> Result<Response, ApiError> { 38 + let auth_user = auth.0.require_user()?.require_active()?; 39 + let user = state 40 + .user_repo 41 + .get_by_did(&auth_user.did) 42 + .await 43 + .map_err(|e| { 42 44 error!("DB error getting user: {:?}", e); 43 - return ApiError::InternalError(None).into_response(); 44 - } 45 - }; 45 + ApiError::InternalError(None) 46 + })? 47 + .ok_or(ApiError::AccountNotFound)?; 46 48 47 - match state.session_repo.list_app_passwords(user.id).await { 48 - Ok(rows) => { 49 - let passwords: Vec<AppPassword> = rows 50 - .iter() 51 - .map(|row| AppPassword { 52 - name: row.name.clone(), 53 - created_at: row.created_at.to_rfc3339(), 54 - privileged: row.privileged, 55 - scopes: row.scopes.clone(), 56 - created_by_controller: row 57 - .created_by_controller_did 58 - .as_ref() 59 - .map(|d| d.to_string()), 60 - }) 61 - .collect(); 62 - Json(ListAppPasswordsOutput { passwords }).into_response() 63 - } 64 - Err(e) => { 49 + let rows = state 50 + .session_repo 51 + .list_app_passwords(user.id) 52 + .await 53 + .map_err(|e| { 65 54 error!("DB error listing app passwords: {:?}", e); 66 - ApiError::InternalError(None).into_response() 67 - } 68 - } 55 + ApiError::InternalError(None) 56 + })?; 57 + let passwords: Vec<AppPassword> = rows 58 + .iter() 59 + .map(|row| AppPassword { 60 + name: row.name.clone(), 61 + created_at: row.created_at.to_rfc3339(), 62 + privileged: row.privileged, 63 + scopes: row.scopes.clone(), 64 + created_by_controller: row 65 + .created_by_controller_did 66 + .as_ref() 67 + .map(|d| d.to_string()), 68 + }) 69 + .collect(); 70 + Ok(Json(ListAppPasswordsOutput { passwords }).into_response()) 69 71 } 70 72 71 73 #[derive(Deserialize)] ··· 89 91 pub async fn create_app_password( 90 92 State(state): State<AppState>, 91 93 headers: HeaderMap, 92 - BearerAuth(auth_user): BearerAuth, 94 + auth: RequiredAuth, 93 95 Json(input): Json<CreateAppPasswordInput>, 94 - ) -> Response { 96 + ) -> Result<Response, ApiError> { 97 + let auth_user = auth.0.require_user()?.require_active()?; 95 98 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 96 99 if !state 97 100 .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 98 101 .await 99 102 { 100 103 warn!(ip = %client_ip, "App password creation rate limit exceeded"); 101 - return ApiError::RateLimitExceeded(None).into_response(); 104 + return Err(ApiError::RateLimitExceeded(None)); 102 105 } 103 106 104 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 105 - Ok(Some(u)) => u, 106 - Ok(None) => return ApiError::AccountNotFound.into_response(), 107 - Err(e) => { 107 + let user = state 108 + .user_repo 109 + .get_by_did(&auth_user.did) 110 + .await 111 + .map_err(|e| { 108 112 error!("DB error getting user: {:?}", e); 109 - return ApiError::InternalError(None).into_response(); 110 - } 111 - }; 113 + ApiError::InternalError(None) 114 + })? 115 + .ok_or(ApiError::AccountNotFound)?; 112 116 113 117 let name = input.name.trim(); 114 118 if name.is_empty() { 115 - return ApiError::InvalidRequest("name is required".into()).into_response(); 119 + return Err(ApiError::InvalidRequest("name is required".into())); 116 120 } 117 121 118 - match state 122 + if state 119 123 .session_repo 120 124 .get_app_password_by_name(user.id, name) 121 125 .await 122 - { 123 - Ok(Some(_)) => return ApiError::DuplicateAppPassword.into_response(), 124 - Err(e) => { 126 + .map_err(|e| { 125 127 error!("DB error checking app password: {:?}", e); 126 - return ApiError::InternalError(None).into_response(); 127 - } 128 - Ok(None) => {} 128 + ApiError::InternalError(None) 129 + })? 130 + .is_some() 131 + { 132 + return Err(ApiError::DuplicateAppPassword); 129 133 } 130 134 131 135 let (final_scopes, controller_did) = if let Some(ref controller) = auth_user.controller_did { ··· 141 145 let intersected = intersect_scopes(requested, &granted_scopes); 142 146 143 147 if intersected.is_empty() && !granted_scopes.is_empty() { 144 - return ApiError::InsufficientScope(None).into_response(); 148 + return Err(ApiError::InsufficientScope(None)); 145 149 } 146 150 147 151 let scope_result = if intersected.is_empty() { ··· 157 161 let password = generate_app_password(); 158 162 159 163 let password_clone = password.clone(); 160 - let password_hash = match tokio::task::spawn_blocking(move || { 161 - bcrypt::hash(&password_clone, bcrypt::DEFAULT_COST) 162 - }) 163 - .await 164 - { 165 - Ok(Ok(h)) => h, 166 - Ok(Err(e)) => { 167 - error!("Failed to hash password: {:?}", e); 168 - return ApiError::InternalError(None).into_response(); 169 - } 170 - Err(e) => { 171 - error!("Failed to spawn blocking task: {:?}", e); 172 - return ApiError::InternalError(None).into_response(); 173 - } 174 - }; 164 + let password_hash = 165 + tokio::task::spawn_blocking(move || bcrypt::hash(&password_clone, bcrypt::DEFAULT_COST)) 166 + .await 167 + .map_err(|e| { 168 + error!("Failed to spawn blocking task: {:?}", e); 169 + ApiError::InternalError(None) 170 + })? 171 + .map_err(|e| { 172 + error!("Failed to hash password: {:?}", e); 173 + ApiError::InternalError(None) 174 + })?; 175 175 176 176 let privileged = input.privileged.unwrap_or(false); 177 177 let created_at = chrono::Utc::now(); ··· 185 185 created_by_controller_did: controller_did.clone(), 186 186 }; 187 187 188 - match state.session_repo.create_app_password(&create_data).await { 189 - Ok(_) => { 190 - if let Some(ref controller) = controller_did { 191 - let _ = state 192 - .delegation_repo 193 - .log_delegation_action( 194 - &auth_user.did, 195 - controller, 196 - Some(controller), 197 - DelegationActionType::AccountAction, 198 - Some(json!({ 199 - "action": "create_app_password", 200 - "name": name, 201 - "scopes": final_scopes 202 - })), 203 - None, 204 - None, 205 - ) 206 - .await; 207 - } 208 - Json(CreateAppPasswordOutput { 209 - name: name.to_string(), 210 - password, 211 - created_at: created_at.to_rfc3339(), 212 - privileged, 213 - scopes: final_scopes, 214 - }) 215 - .into_response() 216 - } 217 - Err(e) => { 188 + state 189 + .session_repo 190 + .create_app_password(&create_data) 191 + .await 192 + .map_err(|e| { 218 193 error!("DB error creating app password: {:?}", e); 219 - ApiError::InternalError(None).into_response() 220 - } 194 + ApiError::InternalError(None) 195 + })?; 196 + 197 + if let Some(ref controller) = controller_did { 198 + let _ = state 199 + .delegation_repo 200 + .log_delegation_action( 201 + &auth_user.did, 202 + controller, 203 + Some(controller), 204 + DelegationActionType::AccountAction, 205 + Some(json!({ 206 + "action": "create_app_password", 207 + "name": name, 208 + "scopes": final_scopes 209 + })), 210 + None, 211 + None, 212 + ) 213 + .await; 221 214 } 215 + Ok(Json(CreateAppPasswordOutput { 216 + name: name.to_string(), 217 + password, 218 + created_at: created_at.to_rfc3339(), 219 + privileged, 220 + scopes: final_scopes, 221 + }) 222 + .into_response()) 222 223 } 223 224 224 225 #[derive(Deserialize)] ··· 228 229 229 230 pub async fn revoke_app_password( 230 231 State(state): State<AppState>, 231 - BearerAuth(auth_user): BearerAuth, 232 + auth: RequiredAuth, 232 233 Json(input): Json<RevokeAppPasswordInput>, 233 - ) -> Response { 234 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 235 - Ok(Some(u)) => u, 236 - Ok(None) => return ApiError::AccountNotFound.into_response(), 237 - Err(e) => { 234 + ) -> Result<Response, ApiError> { 235 + let auth_user = auth.0.require_user()?.require_active()?; 236 + let user = state 237 + .user_repo 238 + .get_by_did(&auth_user.did) 239 + .await 240 + .map_err(|e| { 238 241 error!("DB error getting user: {:?}", e); 239 - return ApiError::InternalError(None).into_response(); 240 - } 241 - }; 242 + ApiError::InternalError(None) 243 + })? 244 + .ok_or(ApiError::AccountNotFound)?; 242 245 243 246 let name = input.name.trim(); 244 247 if name.is_empty() { 245 - return ApiError::InvalidRequest("name is required".into()).into_response(); 248 + return Err(ApiError::InvalidRequest("name is required".into())); 246 249 } 247 250 248 251 let sessions_to_invalidate = state ··· 251 254 .await 252 255 .unwrap_or_default(); 253 256 254 - if let Err(e) = state 257 + state 255 258 .session_repo 256 259 .delete_sessions_by_app_password(&auth_user.did, name) 257 260 .await 258 - { 259 - error!("DB error revoking sessions for app password: {:?}", e); 260 - return ApiError::InternalError(None).into_response(); 261 - } 261 + .map_err(|e| { 262 + error!("DB error revoking sessions for app password: {:?}", e); 263 + ApiError::InternalError(None) 264 + })?; 262 265 263 266 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 264 267 let cache_key = format!("auth:session:{}:{}", &auth_user.did, jti); ··· 269 272 })) 270 273 .await; 271 274 272 - if let Err(e) = state.session_repo.delete_app_password(user.id, name).await { 273 - error!("DB error revoking app password: {:?}", e); 274 - return ApiError::InternalError(None).into_response(); 275 - } 275 + state 276 + .session_repo 277 + .delete_app_password(user.id, name) 278 + .await 279 + .map_err(|e| { 280 + error!("DB error revoking app password: {:?}", e); 281 + ApiError::InternalError(None) 282 + })?; 276 283 277 - EmptyResponse::ok().into_response() 284 + Ok(EmptyResponse::ok().into_response()) 278 285 }
+93 -85
crates/tranquil-pds/src/api/server/email.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use axum::{ 6 6 Json, ··· 45 45 pub async fn request_email_update( 46 46 State(state): State<AppState>, 47 47 headers: axum::http::HeaderMap, 48 - auth: BearerAuth, 48 + auth: RequiredAuth, 49 49 input: Option<Json<RequestEmailUpdateInput>>, 50 - ) -> Response { 50 + ) -> Result<Response, ApiError> { 51 + let auth_user = auth.0.require_user()?.require_active()?; 51 52 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 52 53 if !state 53 54 .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 54 55 .await 55 56 { 56 57 warn!(ip = %client_ip, "Email update rate limit exceeded"); 57 - return ApiError::RateLimitExceeded(None).into_response(); 58 + return Err(ApiError::RateLimitExceeded(None)); 58 59 } 59 60 60 61 if let Err(e) = crate::auth::scope_check::check_account_scope( 61 - auth.0.is_oauth, 62 - auth.0.scope.as_deref(), 62 + auth_user.is_oauth, 63 + auth_user.scope.as_deref(), 63 64 crate::oauth::scopes::AccountAttr::Email, 64 65 crate::oauth::scopes::AccountAction::Manage, 65 66 ) { 66 - return e; 67 + return Ok(e); 67 68 } 68 69 69 - let user = match state.user_repo.get_email_info_by_did(&auth.0.did).await { 70 - Ok(Some(row)) => row, 71 - Ok(None) => { 72 - return ApiError::AccountNotFound.into_response(); 73 - } 74 - Err(e) => { 70 + let user = state 71 + .user_repo 72 + .get_email_info_by_did(&auth_user.did) 73 + .await 74 + .map_err(|e| { 75 75 error!("DB error: {:?}", e); 76 - return ApiError::InternalError(None).into_response(); 77 - } 78 - }; 76 + ApiError::InternalError(None) 77 + })? 78 + .ok_or(ApiError::AccountNotFound)?; 79 79 80 80 let Some(current_email) = user.email else { 81 - return ApiError::InvalidRequest("account does not have an email address".into()) 82 - .into_response(); 81 + return Err(ApiError::InvalidRequest( 82 + "account does not have an email address".into(), 83 + )); 83 84 }; 84 85 85 86 let token_required = user.email_verified; 86 87 87 88 if token_required { 88 89 let code = crate::auth::verification_token::generate_channel_update_token( 89 - &auth.0.did, 90 + &auth_user.did, 90 91 "email_update", 91 92 &current_email.to_lowercase(), 92 93 ); ··· 103 104 authorized: false, 104 105 }; 105 106 if let Ok(json) = serde_json::to_string(&pending) { 106 - let cache_key = email_update_cache_key(&auth.0.did); 107 + let cache_key = email_update_cache_key(&auth_user.did); 107 108 if let Err(e) = state.cache.set(&cache_key, &json, EMAIL_UPDATE_TTL).await { 108 109 warn!("Failed to cache pending email update: {:?}", e); 109 110 } ··· 127 128 } 128 129 129 130 info!("Email update requested for user {}", user.id); 130 - TokenRequiredResponse::response(token_required).into_response() 131 + Ok(TokenRequiredResponse::response(token_required).into_response()) 131 132 } 132 133 133 134 #[derive(Deserialize)] ··· 140 141 pub async fn confirm_email( 141 142 State(state): State<AppState>, 142 143 headers: axum::http::HeaderMap, 143 - auth: BearerAuth, 144 + auth: RequiredAuth, 144 145 Json(input): Json<ConfirmEmailInput>, 145 - ) -> Response { 146 + ) -> Result<Response, ApiError> { 147 + let auth_user = auth.0.require_user()?.require_active()?; 146 148 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 147 149 if !state 148 150 .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 149 151 .await 150 152 { 151 153 warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 152 - return ApiError::RateLimitExceeded(None).into_response(); 154 + return Err(ApiError::RateLimitExceeded(None)); 153 155 } 154 156 155 157 if let Err(e) = crate::auth::scope_check::check_account_scope( 156 - auth.0.is_oauth, 157 - auth.0.scope.as_deref(), 158 + auth_user.is_oauth, 159 + auth_user.scope.as_deref(), 158 160 crate::oauth::scopes::AccountAttr::Email, 159 161 crate::oauth::scopes::AccountAction::Manage, 160 162 ) { 161 - return e; 163 + return Ok(e); 162 164 } 163 165 164 - let did = &auth.0.did; 165 - let user = match state.user_repo.get_email_info_by_did(did).await { 166 - Ok(Some(row)) => row, 167 - Ok(None) => { 168 - return ApiError::AccountNotFound.into_response(); 169 - } 170 - Err(e) => { 166 + let did = &auth_user.did; 167 + let user = state 168 + .user_repo 169 + .get_email_info_by_did(did) 170 + .await 171 + .map_err(|e| { 171 172 error!("DB error: {:?}", e); 172 - return ApiError::InternalError(None).into_response(); 173 - } 174 - }; 173 + ApiError::InternalError(None) 174 + })? 175 + .ok_or(ApiError::AccountNotFound)?; 175 176 176 177 let Some(ref email) = user.email else { 177 - return ApiError::InvalidEmail.into_response(); 178 + return Err(ApiError::InvalidEmail); 178 179 }; 179 180 let current_email = email.to_lowercase(); 180 181 181 182 let provided_email = input.email.trim().to_lowercase(); 182 183 if provided_email != current_email { 183 - return ApiError::InvalidEmail.into_response(); 184 + return Err(ApiError::InvalidEmail); 184 185 } 185 186 186 187 if user.email_verified { 187 - return EmptyResponse::ok().into_response(); 188 + return Ok(EmptyResponse::ok().into_response()); 188 189 } 189 190 190 191 let confirmation_code = ··· 199 200 match verified { 200 201 Ok(token_data) => { 201 202 if token_data.did != did.as_str() { 202 - return ApiError::InvalidToken(None).into_response(); 203 + return Err(ApiError::InvalidToken(None)); 203 204 } 204 205 } 205 206 Err(crate::auth::verification_token::VerifyError::Expired) => { 206 - return ApiError::ExpiredToken(None).into_response(); 207 + return Err(ApiError::ExpiredToken(None)); 207 208 } 208 209 Err(_) => { 209 - return ApiError::InvalidToken(None).into_response(); 210 + return Err(ApiError::InvalidToken(None)); 210 211 } 211 212 } 212 213 213 - if let Err(e) = state.user_repo.set_email_verified(user.id, true).await { 214 - error!("DB error confirming email: {:?}", e); 215 - return ApiError::InternalError(None).into_response(); 216 - } 214 + state 215 + .user_repo 216 + .set_email_verified(user.id, true) 217 + .await 218 + .map_err(|e| { 219 + error!("DB error confirming email: {:?}", e); 220 + ApiError::InternalError(None) 221 + })?; 217 222 218 223 info!("Email confirmed for user {}", user.id); 219 - EmptyResponse::ok().into_response() 224 + Ok(EmptyResponse::ok().into_response()) 220 225 } 221 226 222 227 #[derive(Deserialize)] ··· 230 235 231 236 pub async fn update_email( 232 237 State(state): State<AppState>, 233 - auth: BearerAuth, 238 + auth: RequiredAuth, 234 239 Json(input): Json<UpdateEmailInput>, 235 - ) -> Response { 236 - let auth_user = auth.0; 240 + ) -> Result<Response, ApiError> { 241 + let auth_user = auth.0.require_user()?.require_active()?; 237 242 238 243 if let Err(e) = crate::auth::scope_check::check_account_scope( 239 244 auth_user.is_oauth, ··· 241 246 crate::oauth::scopes::AccountAttr::Email, 242 247 crate::oauth::scopes::AccountAction::Manage, 243 248 ) { 244 - return e; 249 + return Ok(e); 245 250 } 246 251 247 252 let did = &auth_user.did; 248 - let user = match state.user_repo.get_email_info_by_did(did).await { 249 - Ok(Some(row)) => row, 250 - Ok(None) => { 251 - return ApiError::AccountNotFound.into_response(); 252 - } 253 - Err(e) => { 253 + let user = state 254 + .user_repo 255 + .get_email_info_by_did(did) 256 + .await 257 + .map_err(|e| { 254 258 error!("DB error: {:?}", e); 255 - return ApiError::InternalError(None).into_response(); 256 - } 257 - }; 259 + ApiError::InternalError(None) 260 + })? 261 + .ok_or(ApiError::AccountNotFound)?; 258 262 259 263 let user_id = user.id; 260 264 let current_email = user.email.clone(); ··· 262 266 let new_email = input.email.trim().to_lowercase(); 263 267 264 268 if !crate::api::validation::is_valid_email(&new_email) { 265 - return ApiError::InvalidRequest( 269 + return Err(ApiError::InvalidRequest( 266 270 "This email address is not supported, please use a different email.".into(), 267 - ) 268 - .into_response(); 271 + )); 269 272 } 270 273 271 274 if let Some(ref current) = current_email 272 275 && new_email == current.to_lowercase() 273 276 { 274 - return EmptyResponse::ok().into_response(); 277 + return Ok(EmptyResponse::ok().into_response()); 275 278 } 276 279 277 280 if email_verified { ··· 290 293 291 294 if !authorized_via_link { 292 295 let Some(ref t) = input.token else { 293 - return ApiError::TokenRequired.into_response(); 296 + return Err(ApiError::TokenRequired); 294 297 }; 295 298 let confirmation_token = 296 299 crate::auth::verification_token::normalize_token_input(t.trim()); ··· 309 312 match verified { 310 313 Ok(token_data) => { 311 314 if token_data.did != did.as_str() { 312 - return ApiError::InvalidToken(None).into_response(); 315 + return Err(ApiError::InvalidToken(None)); 313 316 } 314 317 } 315 318 Err(crate::auth::verification_token::VerifyError::Expired) => { 316 - return ApiError::ExpiredToken(None).into_response(); 319 + return Err(ApiError::ExpiredToken(None)); 317 320 } 318 321 Err(_) => { 319 - return ApiError::InvalidToken(None).into_response(); 322 + return Err(ApiError::InvalidToken(None)); 320 323 } 321 324 } 322 325 } 323 326 } 324 327 325 - if let Err(e) = state.user_repo.update_email(user_id, &new_email).await { 326 - error!("DB error updating email: {:?}", e); 327 - return ApiError::InternalError(None).into_response(); 328 - } 328 + state 329 + .user_repo 330 + .update_email(user_id, &new_email) 331 + .await 332 + .map_err(|e| { 333 + error!("DB error updating email: {:?}", e); 334 + ApiError::InternalError(None) 335 + })?; 329 336 330 337 let verification_token = 331 338 crate::auth::verification_token::generate_signup_token(did, "email", &new_email); ··· 358 365 } 359 366 360 367 info!("Email updated for user {}", user_id); 361 - EmptyResponse::ok().into_response() 368 + Ok(EmptyResponse::ok().into_response()) 362 369 } 363 370 364 371 #[derive(Deserialize)] ··· 497 504 pub async fn check_email_update_status( 498 505 State(state): State<AppState>, 499 506 headers: axum::http::HeaderMap, 500 - auth: BearerAuth, 501 - ) -> Response { 507 + auth: RequiredAuth, 508 + ) -> Result<Response, ApiError> { 509 + let auth_user = auth.0.require_user()?.require_active()?; 502 510 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 503 511 if !state 504 512 .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 505 513 .await 506 514 { 507 - return ApiError::RateLimitExceeded(None).into_response(); 515 + return Err(ApiError::RateLimitExceeded(None)); 508 516 } 509 517 510 518 if let Err(e) = crate::auth::scope_check::check_account_scope( 511 - auth.0.is_oauth, 512 - auth.0.scope.as_deref(), 519 + auth_user.is_oauth, 520 + auth_user.scope.as_deref(), 513 521 crate::oauth::scopes::AccountAttr::Email, 514 522 crate::oauth::scopes::AccountAction::Read, 515 523 ) { 516 - return e; 524 + return Ok(e); 517 525 } 518 526 519 - let cache_key = email_update_cache_key(&auth.0.did); 527 + let cache_key = email_update_cache_key(&auth_user.did); 520 528 let pending_json = match state.cache.get(&cache_key).await { 521 529 Some(json) => json, 522 530 None => { 523 - return Json(json!({ "pending": false, "authorized": false })).into_response(); 531 + return Ok(Json(json!({ "pending": false, "authorized": false })).into_response()); 524 532 } 525 533 }; 526 534 527 535 let pending: PendingEmailUpdate = match serde_json::from_str(&pending_json) { 528 536 Ok(p) => p, 529 537 Err(_) => { 530 - return Json(json!({ "pending": false, "authorized": false })).into_response(); 538 + return Ok(Json(json!({ "pending": false, "authorized": false })).into_response()); 531 539 } 532 540 }; 533 541 534 - Json(json!({ 542 + Ok(Json(json!({ 535 543 "pending": true, 536 544 "authorized": pending.authorized, 537 545 "newEmail": pending.new_email, 538 546 })) 539 - .into_response() 547 + .into_response()) 540 548 } 541 549 542 550 #[derive(Deserialize)]
+46 -45
crates/tranquil-pds/src/api/server/invite.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::auth::BearerAuth; 3 - use crate::auth::extractor::BearerAuthAdmin; 2 + use crate::auth::RequiredAuth; 4 3 use crate::state::AppState; 5 4 use crate::types::Did; 6 5 use axum::{ ··· 44 43 45 44 pub async fn create_invite_code( 46 45 State(state): State<AppState>, 47 - BearerAuthAdmin(auth_user): BearerAuthAdmin, 46 + auth: RequiredAuth, 48 47 Json(input): Json<CreateInviteCodeInput>, 49 - ) -> Response { 48 + ) -> Result<Response, ApiError> { 49 + let auth_user = auth.0.require_user()?.require_active()?.require_admin()?; 50 50 if input.use_count < 1 { 51 - return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 51 + return Err(ApiError::InvalidRequest( 52 + "useCount must be at least 1".into(), 53 + )); 52 54 } 53 55 54 56 let for_account: Did = match &input.for_account { 55 - Some(acct) => match acct.parse() { 56 - Ok(d) => d, 57 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 58 - }, 57 + Some(acct) => acct 58 + .parse() 59 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?, 59 60 None => auth_user.did.clone(), 60 61 }; 61 62 let code = gen_invite_code(); ··· 65 66 .create_invite_code(&code, input.use_count, Some(&for_account)) 66 67 .await 67 68 { 68 - Ok(true) => Json(CreateInviteCodeOutput { code }).into_response(), 69 + Ok(true) => Ok(Json(CreateInviteCodeOutput { code }).into_response()), 69 70 Ok(false) => { 70 71 error!("No admin user found to create invite code"); 71 - ApiError::InternalError(None).into_response() 72 + Err(ApiError::InternalError(None)) 72 73 } 73 74 Err(e) => { 74 75 error!("DB error creating invite code: {:?}", e); 75 - ApiError::InternalError(None).into_response() 76 + Err(ApiError::InternalError(None)) 76 77 } 77 78 } 78 79 } ··· 98 99 99 100 pub async fn create_invite_codes( 100 101 State(state): State<AppState>, 101 - BearerAuthAdmin(auth_user): BearerAuthAdmin, 102 + auth: RequiredAuth, 102 103 Json(input): Json<CreateInviteCodesInput>, 103 - ) -> Response { 104 + ) -> Result<Response, ApiError> { 105 + let auth_user = auth.0.require_user()?.require_active()?.require_admin()?; 104 106 if input.use_count < 1 { 105 - return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 107 + return Err(ApiError::InvalidRequest( 108 + "useCount must be at least 1".into(), 109 + )); 106 110 } 107 111 108 112 let code_count = input.code_count.unwrap_or(1).max(1); 109 113 let for_accounts: Vec<Did> = match &input.for_accounts { 110 - Some(accounts) if !accounts.is_empty() => { 111 - let parsed: Result<Vec<Did>, _> = accounts.iter().map(|a| a.parse()).collect(); 112 - match parsed { 113 - Ok(dids) => dids, 114 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 115 - } 116 - } 114 + Some(accounts) if !accounts.is_empty() => accounts 115 + .iter() 116 + .map(|a| a.parse()) 117 + .collect::<Result<Vec<Did>, _>>() 118 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?, 117 119 _ => vec![auth_user.did.clone()], 118 120 }; 119 121 120 - let admin_user_id = match state.user_repo.get_any_admin_user_id().await { 121 - Ok(Some(id)) => id, 122 - Ok(None) => { 122 + let admin_user_id = state 123 + .user_repo 124 + .get_any_admin_user_id() 125 + .await 126 + .map_err(|e| { 127 + error!("DB error looking up admin user: {:?}", e); 128 + ApiError::InternalError(None) 129 + })? 130 + .ok_or_else(|| { 123 131 error!("No admin user found to create invite codes"); 124 - return ApiError::InternalError(None).into_response(); 125 - } 126 - Err(e) => { 127 - error!("DB error looking up admin user: {:?}", e); 128 - return ApiError::InternalError(None).into_response(); 129 - } 130 - }; 132 + ApiError::InternalError(None) 133 + })?; 131 134 132 135 let result = futures::future::try_join_all(for_accounts.into_iter().map(|account| { 133 136 let infra_repo = state.infra_repo.clone(); ··· 146 149 .await; 147 150 148 151 match result { 149 - Ok(result_codes) => Json(CreateInviteCodesOutput { 152 + Ok(result_codes) => Ok(Json(CreateInviteCodesOutput { 150 153 codes: result_codes, 151 154 }) 152 - .into_response(), 155 + .into_response()), 153 156 Err(e) => { 154 157 error!("DB error creating invite codes: {:?}", e); 155 - ApiError::InternalError(None).into_response() 158 + Err(ApiError::InternalError(None)) 156 159 } 157 160 } 158 161 } ··· 192 195 193 196 pub async fn get_account_invite_codes( 194 197 State(state): State<AppState>, 195 - BearerAuth(auth_user): BearerAuth, 198 + auth: RequiredAuth, 196 199 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 197 - ) -> Response { 200 + ) -> Result<Response, ApiError> { 201 + let auth_user = auth.0.require_user()?.require_active()?; 198 202 let include_used = params.include_used.unwrap_or(true); 199 203 200 - let codes_info = match state 204 + let codes_info = state 201 205 .infra_repo 202 206 .get_invite_codes_for_account(&auth_user.did) 203 207 .await 204 - { 205 - Ok(info) => info, 206 - Err(e) => { 208 + .map_err(|e| { 207 209 error!("DB error fetching invite codes: {:?}", e); 208 - return ApiError::InternalError(None).into_response(); 209 - } 210 - }; 210 + ApiError::InternalError(None) 211 + })?; 211 212 212 213 let filtered_codes: Vec<_> = codes_info 213 214 .into_iter() ··· 254 255 .await; 255 256 256 257 let codes: Vec<InviteCode> = codes.into_iter().flatten().collect(); 257 - Json(GetAccountInviteCodesOutput { codes }).into_response() 258 + Ok(Json(GetAccountInviteCodesOutput { codes }).into_response()) 258 259 }
+42 -95
crates/tranquil-pds/src/api/server/migration.rs
··· 1 1 use crate::api::ApiError; 2 + use crate::auth::RequiredAuth; 2 3 use crate::state::AppState; 3 4 use axum::{ 4 5 Json, ··· 35 36 36 37 pub async fn update_did_document( 37 38 State(state): State<AppState>, 38 - headers: axum::http::HeaderMap, 39 + auth: RequiredAuth, 39 40 Json(input): Json<UpdateDidDocumentInput>, 40 - ) -> Response { 41 - let extracted = match crate::auth::extract_auth_token_from_header( 42 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 43 - ) { 44 - Some(t) => t, 45 - None => return ApiError::AuthenticationRequired.into_response(), 46 - }; 47 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 48 - let http_uri = format!( 49 - "https://{}/xrpc/_account.updateDidDocument", 50 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 51 - ); 52 - let auth_user = match crate::auth::validate_token_with_dpop( 53 - state.user_repo.as_ref(), 54 - state.oauth_repo.as_ref(), 55 - &extracted.token, 56 - extracted.is_dpop, 57 - dpop_proof, 58 - "POST", 59 - &http_uri, 60 - true, 61 - false, 62 - ) 63 - .await 64 - { 65 - Ok(user) => user, 66 - Err(e) => return ApiError::from(e).into_response(), 67 - }; 41 + ) -> Result<Response, ApiError> { 42 + let auth_user = auth.0.require_user()?.require_active()?; 68 43 69 44 if !auth_user.did.starts_with("did:web:") { 70 - return ApiError::InvalidRequest( 45 + return Err(ApiError::InvalidRequest( 71 46 "DID document updates are only available for did:web accounts".into(), 72 - ) 73 - .into_response(); 47 + )); 74 48 } 75 49 76 - let user = match state.user_repo.get_user_for_did_doc(&auth_user.did).await { 77 - Ok(Some(u)) => u, 78 - Ok(None) => return ApiError::AccountNotFound.into_response(), 79 - Err(e) => { 50 + let user = state 51 + .user_repo 52 + .get_user_for_did_doc(&auth_user.did) 53 + .await 54 + .map_err(|e| { 80 55 tracing::error!("DB error getting user: {:?}", e); 81 - return ApiError::InternalError(None).into_response(); 82 - } 83 - }; 84 - 85 - if user.deactivated_at.is_some() { 86 - return ApiError::AccountDeactivated.into_response(); 87 - } 56 + ApiError::InternalError(None) 57 + })? 58 + .ok_or(ApiError::AccountNotFound)?; 88 59 89 60 if let Some(ref methods) = input.verification_methods { 90 61 if methods.is_empty() { 91 - return ApiError::InvalidRequest("verification_methods cannot be empty".into()) 92 - .into_response(); 62 + return Err(ApiError::InvalidRequest( 63 + "verification_methods cannot be empty".into(), 64 + )); 93 65 } 94 66 let validation_error = methods.iter().find_map(|method| { 95 67 if method.id.is_empty() { ··· 105 77 } 106 78 }); 107 79 if let Some(err) = validation_error { 108 - return ApiError::InvalidRequest(err.into()).into_response(); 80 + return Err(ApiError::InvalidRequest(err.into())); 109 81 } 110 82 } 111 83 112 84 if let Some(ref handles) = input.also_known_as 113 85 && handles.iter().any(|h| !h.starts_with("at://")) 114 86 { 115 - return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) 116 - .into_response(); 87 + return Err(ApiError::InvalidRequest( 88 + "alsoKnownAs entries must be at:// URIs".into(), 89 + )); 117 90 } 118 91 119 92 if let Some(ref endpoint) = input.service_endpoint { 120 93 let endpoint = endpoint.trim(); 121 94 if !endpoint.starts_with("https://") { 122 - return ApiError::InvalidRequest("serviceEndpoint must start with https://".into()) 123 - .into_response(); 95 + return Err(ApiError::InvalidRequest( 96 + "serviceEndpoint must start with https://".into(), 97 + )); 124 98 } 125 99 } 126 100 ··· 131 105 132 106 let also_known_as: Option<Vec<String>> = input.also_known_as.clone(); 133 107 134 - if let Err(e) = state 108 + state 135 109 .user_repo 136 110 .upsert_did_web_overrides(user.id, verification_methods_json, also_known_as) 137 111 .await 138 - { 139 - tracing::error!("DB error upserting did_web_overrides: {:?}", e); 140 - return ApiError::InternalError(None).into_response(); 141 - } 112 + .map_err(|e| { 113 + tracing::error!("DB error upserting did_web_overrides: {:?}", e); 114 + ApiError::InternalError(None) 115 + })?; 142 116 143 117 if let Some(ref endpoint) = input.service_endpoint { 144 118 let endpoint_clean = endpoint.trim().trim_end_matches('/'); 145 - if let Err(e) = state 119 + state 146 120 .user_repo 147 121 .update_migrated_to_pds(&auth_user.did, endpoint_clean) 148 122 .await 149 - { 150 - tracing::error!("DB error updating service endpoint: {:?}", e); 151 - return ApiError::InternalError(None).into_response(); 152 - } 123 + .map_err(|e| { 124 + tracing::error!("DB error updating service endpoint: {:?}", e); 125 + ApiError::InternalError(None) 126 + })?; 153 127 } 154 128 155 129 let did_doc = build_did_document(&state, &auth_user.did).await; 156 130 157 131 tracing::info!("Updated DID document for {}", &auth_user.did); 158 132 159 - ( 133 + Ok(( 160 134 StatusCode::OK, 161 135 Json(UpdateDidDocumentOutput { 162 136 success: true, 163 137 did_document: did_doc, 164 138 }), 165 139 ) 166 - .into_response() 140 + .into_response()) 167 141 } 168 142 169 143 pub async fn get_did_document( 170 144 State(state): State<AppState>, 171 - headers: axum::http::HeaderMap, 172 - ) -> Response { 173 - let extracted = match crate::auth::extract_auth_token_from_header( 174 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 175 - ) { 176 - Some(t) => t, 177 - None => return ApiError::AuthenticationRequired.into_response(), 178 - }; 179 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 180 - let http_uri = format!( 181 - "https://{}/xrpc/_account.getDidDocument", 182 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 183 - ); 184 - let auth_user = match crate::auth::validate_token_with_dpop( 185 - state.user_repo.as_ref(), 186 - state.oauth_repo.as_ref(), 187 - &extracted.token, 188 - extracted.is_dpop, 189 - dpop_proof, 190 - "GET", 191 - &http_uri, 192 - true, 193 - false, 194 - ) 195 - .await 196 - { 197 - Ok(user) => user, 198 - Err(e) => return ApiError::from(e).into_response(), 199 - }; 145 + auth: RequiredAuth, 146 + ) -> Result<Response, ApiError> { 147 + let auth_user = auth.0.require_user()?.require_active()?; 200 148 201 149 if !auth_user.did.starts_with("did:web:") { 202 - return ApiError::InvalidRequest( 150 + return Err(ApiError::InvalidRequest( 203 151 "This endpoint is only available for did:web accounts".into(), 204 - ) 205 - .into_response(); 152 + )); 206 153 } 207 154 208 155 let did_doc = build_did_document(&state, &auth_user.did).await; 209 156 210 - (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response() 157 + Ok((StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response()) 211 158 } 212 159 213 160 async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value {
+114 -146
crates/tranquil-pds/src/api/server/passkeys.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::auth::webauthn::WebAuthnConfig; 5 5 use crate::state::AppState; 6 6 use axum::{ ··· 34 34 35 35 pub async fn start_passkey_registration( 36 36 State(state): State<AppState>, 37 - auth: BearerAuth, 37 + auth: RequiredAuth, 38 38 Json(input): Json<StartRegistrationInput>, 39 - ) -> Response { 40 - let webauthn = match get_webauthn() { 41 - Ok(w) => w, 42 - Err(e) => return e.into_response(), 43 - }; 39 + ) -> Result<Response, ApiError> { 40 + let auth_user = auth.0.require_user()?.require_active()?; 41 + let webauthn = get_webauthn()?; 44 42 45 - let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 46 - Ok(Some(h)) => h, 47 - Ok(None) => { 48 - return ApiError::AccountNotFound.into_response(); 49 - } 50 - Err(e) => { 43 + let handle = state 44 + .user_repo 45 + .get_handle_by_did(&auth_user.did) 46 + .await 47 + .map_err(|e| { 51 48 error!("DB error fetching user: {:?}", e); 52 - return ApiError::InternalError(None).into_response(); 53 - } 54 - }; 49 + ApiError::InternalError(None) 50 + })? 51 + .ok_or(ApiError::AccountNotFound)?; 55 52 56 - let existing_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 57 - Ok(passkeys) => passkeys, 58 - Err(e) => { 53 + let existing_passkeys = state 54 + .user_repo 55 + .get_passkeys_for_user(&auth_user.did) 56 + .await 57 + .map_err(|e| { 59 58 error!("DB error fetching existing passkeys: {:?}", e); 60 - return ApiError::InternalError(None).into_response(); 61 - } 62 - }; 59 + ApiError::InternalError(None) 60 + })?; 63 61 64 62 let exclude_credentials: Vec<CredentialID> = existing_passkeys 65 63 .iter() ··· 68 66 69 67 let display_name = input.friendly_name.as_deref().unwrap_or(&handle); 70 68 71 - let (ccr, reg_state) = match webauthn.start_registration( 72 - &auth.0.did, 73 - &handle, 74 - display_name, 75 - exclude_credentials, 76 - ) { 77 - Ok(result) => result, 78 - Err(e) => { 69 + let (ccr, reg_state) = webauthn 70 + .start_registration(&auth_user.did, &handle, display_name, exclude_credentials) 71 + .map_err(|e| { 79 72 error!("Failed to start passkey registration: {}", e); 80 - return ApiError::InternalError(Some("Failed to start registration".into())) 81 - .into_response(); 82 - } 83 - }; 73 + ApiError::InternalError(Some("Failed to start registration".into())) 74 + })?; 84 75 85 - let state_json = match serde_json::to_string(&reg_state) { 86 - Ok(s) => s, 87 - Err(e) => { 88 - error!("Failed to serialize registration state: {:?}", e); 89 - return ApiError::InternalError(None).into_response(); 90 - } 91 - }; 76 + let state_json = serde_json::to_string(&reg_state).map_err(|e| { 77 + error!("Failed to serialize registration state: {:?}", e); 78 + ApiError::InternalError(None) 79 + })?; 92 80 93 - if let Err(e) = state 81 + state 94 82 .user_repo 95 - .save_webauthn_challenge(&auth.0.did, "registration", &state_json) 83 + .save_webauthn_challenge(&auth_user.did, "registration", &state_json) 96 84 .await 97 - { 98 - error!("Failed to save registration state: {:?}", e); 99 - return ApiError::InternalError(None).into_response(); 100 - } 85 + .map_err(|e| { 86 + error!("Failed to save registration state: {:?}", e); 87 + ApiError::InternalError(None) 88 + })?; 101 89 102 90 let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({})); 103 91 104 - info!(did = %auth.0.did, "Passkey registration started"); 92 + info!(did = %auth_user.did, "Passkey registration started"); 105 93 106 - Json(StartRegistrationResponse { options }).into_response() 94 + Ok(Json(StartRegistrationResponse { options }).into_response()) 107 95 } 108 96 109 97 #[derive(Deserialize)] ··· 122 110 123 111 pub async fn finish_passkey_registration( 124 112 State(state): State<AppState>, 125 - auth: BearerAuth, 113 + auth: RequiredAuth, 126 114 Json(input): Json<FinishRegistrationInput>, 127 - ) -> Response { 128 - let webauthn = match get_webauthn() { 129 - Ok(w) => w, 130 - Err(e) => return e.into_response(), 131 - }; 115 + ) -> Result<Response, ApiError> { 116 + let auth_user = auth.0.require_user()?.require_active()?; 117 + let webauthn = get_webauthn()?; 132 118 133 - let reg_state_json = match state 119 + let reg_state_json = state 134 120 .user_repo 135 - .load_webauthn_challenge(&auth.0.did, "registration") 121 + .load_webauthn_challenge(&auth_user.did, "registration") 136 122 .await 137 - { 138 - Ok(Some(json)) => json, 139 - Ok(None) => { 140 - return ApiError::NoRegistrationInProgress.into_response(); 141 - } 142 - Err(e) => { 123 + .map_err(|e| { 143 124 error!("DB error loading registration state: {:?}", e); 144 - return ApiError::InternalError(None).into_response(); 145 - } 146 - }; 125 + ApiError::InternalError(None) 126 + })? 127 + .ok_or(ApiError::NoRegistrationInProgress)?; 147 128 148 - let reg_state: SecurityKeyRegistration = match serde_json::from_str(&reg_state_json) { 149 - Ok(s) => s, 150 - Err(e) => { 129 + let reg_state: SecurityKeyRegistration = 130 + serde_json::from_str(&reg_state_json).map_err(|e| { 151 131 error!("Failed to deserialize registration state: {:?}", e); 152 - return ApiError::InternalError(None).into_response(); 153 - } 154 - }; 132 + ApiError::InternalError(None) 133 + })?; 155 134 156 - let credential: RegisterPublicKeyCredential = match serde_json::from_value(input.credential) { 157 - Ok(c) => c, 158 - Err(e) => { 135 + let credential: RegisterPublicKeyCredential = serde_json::from_value(input.credential) 136 + .map_err(|e| { 159 137 warn!("Failed to parse credential: {:?}", e); 160 - return ApiError::InvalidCredential.into_response(); 161 - } 162 - }; 138 + ApiError::InvalidCredential 139 + })?; 163 140 164 - let passkey = match webauthn.finish_registration(&credential, &reg_state) { 165 - Ok(pk) => pk, 166 - Err(e) => { 141 + let passkey = webauthn 142 + .finish_registration(&credential, &reg_state) 143 + .map_err(|e| { 167 144 warn!("Failed to finish passkey registration: {}", e); 168 - return ApiError::RegistrationFailed.into_response(); 169 - } 170 - }; 145 + ApiError::RegistrationFailed 146 + })?; 171 147 172 - let public_key = match serde_json::to_vec(&passkey) { 173 - Ok(pk) => pk, 174 - Err(e) => { 175 - error!("Failed to serialize passkey: {:?}", e); 176 - return ApiError::InternalError(None).into_response(); 177 - } 178 - }; 148 + let public_key = serde_json::to_vec(&passkey).map_err(|e| { 149 + error!("Failed to serialize passkey: {:?}", e); 150 + ApiError::InternalError(None) 151 + })?; 179 152 180 - let passkey_id = match state 153 + let passkey_id = state 181 154 .user_repo 182 155 .save_passkey( 183 - &auth.0.did, 156 + &auth_user.did, 184 157 passkey.cred_id(), 185 158 &public_key, 186 159 input.friendly_name.as_deref(), 187 160 ) 188 161 .await 189 - { 190 - Ok(id) => id, 191 - Err(e) => { 162 + .map_err(|e| { 192 163 error!("Failed to save passkey: {:?}", e); 193 - return ApiError::InternalError(None).into_response(); 194 - } 195 - }; 164 + ApiError::InternalError(None) 165 + })?; 196 166 197 167 if let Err(e) = state 198 168 .user_repo 199 - .delete_webauthn_challenge(&auth.0.did, "registration") 169 + .delete_webauthn_challenge(&auth_user.did, "registration") 200 170 .await 201 171 { 202 172 warn!("Failed to delete registration state: {:?}", e); ··· 207 177 passkey.cred_id(), 208 178 ); 209 179 210 - info!(did = %auth.0.did, passkey_id = %passkey_id, "Passkey registered"); 180 + info!(did = %auth_user.did, passkey_id = %passkey_id, "Passkey registered"); 211 181 212 - Json(FinishRegistrationResponse { 182 + Ok(Json(FinishRegistrationResponse { 213 183 id: passkey_id.to_string(), 214 184 credential_id: credential_id_base64, 215 185 }) 216 - .into_response() 186 + .into_response()) 217 187 } 218 188 219 189 #[derive(Serialize)] ··· 232 202 pub passkeys: Vec<PasskeyInfo>, 233 203 } 234 204 235 - pub async fn list_passkeys(State(state): State<AppState>, auth: BearerAuth) -> Response { 236 - let passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 237 - Ok(pks) => pks, 238 - Err(e) => { 205 + pub async fn list_passkeys( 206 + State(state): State<AppState>, 207 + auth: RequiredAuth, 208 + ) -> Result<Response, ApiError> { 209 + let auth_user = auth.0.require_user()?.require_active()?; 210 + let passkeys = state 211 + .user_repo 212 + .get_passkeys_for_user(&auth_user.did) 213 + .await 214 + .map_err(|e| { 239 215 error!("DB error fetching passkeys: {:?}", e); 240 - return ApiError::InternalError(None).into_response(); 241 - } 242 - }; 216 + ApiError::InternalError(None) 217 + })?; 243 218 244 219 let passkey_infos: Vec<PasskeyInfo> = passkeys 245 220 .into_iter() ··· 252 227 }) 253 228 .collect(); 254 229 255 - Json(ListPasskeysResponse { 230 + Ok(Json(ListPasskeysResponse { 256 231 passkeys: passkey_infos, 257 232 }) 258 - .into_response() 233 + .into_response()) 259 234 } 260 235 261 236 #[derive(Deserialize)] ··· 266 241 267 242 pub async fn delete_passkey( 268 243 State(state): State<AppState>, 269 - auth: BearerAuth, 244 + auth: RequiredAuth, 270 245 Json(input): Json<DeletePasskeyInput>, 271 - ) -> Response { 272 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 246 + ) -> Result<Response, ApiError> { 247 + let auth_user = auth.0.require_user()?.require_active()?; 248 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth_user.did) 273 249 .await 274 250 { 275 - return crate::api::server::reauth::legacy_mfa_required_response( 251 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 276 252 &*state.user_repo, 277 253 &*state.session_repo, 278 - &auth.0.did, 254 + &auth_user.did, 279 255 ) 280 - .await; 256 + .await); 281 257 } 282 258 283 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 284 - return crate::api::server::reauth::reauth_required_response( 259 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth_user.did).await 260 + { 261 + return Ok(crate::api::server::reauth::reauth_required_response( 285 262 &*state.user_repo, 286 263 &*state.session_repo, 287 - &auth.0.did, 264 + &auth_user.did, 288 265 ) 289 - .await; 266 + .await); 290 267 } 291 268 292 - let id: uuid::Uuid = match input.id.parse() { 293 - Ok(id) => id, 294 - Err(_) => { 295 - return ApiError::InvalidId.into_response(); 296 - } 297 - }; 269 + let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 298 270 299 - match state.user_repo.delete_passkey(id, &auth.0.did).await { 271 + match state.user_repo.delete_passkey(id, &auth_user.did).await { 300 272 Ok(true) => { 301 - info!(did = %auth.0.did, passkey_id = %id, "Passkey deleted"); 302 - EmptyResponse::ok().into_response() 273 + info!(did = %auth_user.did, passkey_id = %id, "Passkey deleted"); 274 + Ok(EmptyResponse::ok().into_response()) 303 275 } 304 - Ok(false) => ApiError::PasskeyNotFound.into_response(), 276 + Ok(false) => Err(ApiError::PasskeyNotFound), 305 277 Err(e) => { 306 278 error!("DB error deleting passkey: {:?}", e); 307 - ApiError::InternalError(None).into_response() 279 + Err(ApiError::InternalError(None)) 308 280 } 309 281 } 310 282 } ··· 318 290 319 291 pub async fn update_passkey( 320 292 State(state): State<AppState>, 321 - auth: BearerAuth, 293 + auth: RequiredAuth, 322 294 Json(input): Json<UpdatePasskeyInput>, 323 - ) -> Response { 324 - let id: uuid::Uuid = match input.id.parse() { 325 - Ok(id) => id, 326 - Err(_) => { 327 - return ApiError::InvalidId.into_response(); 328 - } 329 - }; 295 + ) -> Result<Response, ApiError> { 296 + let auth_user = auth.0.require_user()?.require_active()?; 297 + let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 330 298 331 299 match state 332 300 .user_repo 333 - .update_passkey_name(id, &auth.0.did, &input.friendly_name) 301 + .update_passkey_name(id, &auth_user.did, &input.friendly_name) 334 302 .await 335 303 { 336 304 Ok(true) => { 337 - info!(did = %auth.0.did, passkey_id = %id, "Passkey renamed"); 338 - EmptyResponse::ok().into_response() 305 + info!(did = %auth_user.did, passkey_id = %id, "Passkey renamed"); 306 + Ok(EmptyResponse::ok().into_response()) 339 307 } 340 - Ok(false) => ApiError::PasskeyNotFound.into_response(), 308 + Ok(false) => Err(ApiError::PasskeyNotFound), 341 309 Err(e) => { 342 310 error!("DB error updating passkey: {:?}", e); 343 - ApiError::InternalError(None).into_response() 311 + Err(ApiError::InternalError(None)) 344 312 } 345 313 } 346 314 }
+131 -124
crates/tranquil-pds/src/api/server/password.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use crate::types::PlainPassword; 6 6 use crate::validation::validate_password; ··· 227 227 228 228 pub async fn change_password( 229 229 State(state): State<AppState>, 230 - auth: BearerAuth, 230 + auth: RequiredAuth, 231 231 Json(input): Json<ChangePasswordInput>, 232 - ) -> Response { 233 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 232 + ) -> Result<Response, ApiError> { 233 + let auth_user = auth.0.require_user()?.require_active()?; 234 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth_user.did) 234 235 .await 235 236 { 236 - return crate::api::server::reauth::legacy_mfa_required_response( 237 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 237 238 &*state.user_repo, 238 239 &*state.session_repo, 239 - &auth.0.did, 240 + &auth_user.did, 240 241 ) 241 - .await; 242 + .await); 242 243 } 243 244 244 245 let current_password = &input.current_password; 245 246 let new_password = &input.new_password; 246 247 if current_password.is_empty() { 247 - return ApiError::InvalidRequest("currentPassword is required".into()).into_response(); 248 + return Err(ApiError::InvalidRequest( 249 + "currentPassword is required".into(), 250 + )); 248 251 } 249 252 if new_password.is_empty() { 250 - return ApiError::InvalidRequest("newPassword is required".into()).into_response(); 253 + return Err(ApiError::InvalidRequest("newPassword is required".into())); 251 254 } 252 255 if let Err(e) = validate_password(new_password) { 253 - return ApiError::InvalidRequest(e.to_string()).into_response(); 256 + return Err(ApiError::InvalidRequest(e.to_string())); 254 257 } 255 - let user = match state 258 + let user = state 256 259 .user_repo 257 - .get_id_and_password_hash_by_did(&auth.0.did) 260 + .get_id_and_password_hash_by_did(&auth_user.did) 258 261 .await 259 - { 260 - Ok(Some(u)) => u, 261 - Ok(None) => { 262 - return ApiError::AccountNotFound.into_response(); 263 - } 264 - Err(e) => { 262 + .map_err(|e| { 265 263 error!("DB error in change_password: {:?}", e); 266 - return ApiError::InternalError(None).into_response(); 267 - } 268 - }; 264 + ApiError::InternalError(None) 265 + })? 266 + .ok_or(ApiError::AccountNotFound)?; 267 + 269 268 let (user_id, password_hash) = (user.id, user.password_hash); 270 - let valid = match verify(current_password, &password_hash) { 271 - Ok(v) => v, 272 - Err(e) => { 273 - error!("Password verification error: {:?}", e); 274 - return ApiError::InternalError(None).into_response(); 275 - } 276 - }; 269 + let valid = verify(current_password, &password_hash).map_err(|e| { 270 + error!("Password verification error: {:?}", e); 271 + ApiError::InternalError(None) 272 + })?; 277 273 if !valid { 278 - return ApiError::InvalidPassword("Current password is incorrect".into()).into_response(); 274 + return Err(ApiError::InvalidPassword( 275 + "Current password is incorrect".into(), 276 + )); 279 277 } 280 278 let new_password_clone = new_password.to_string(); 281 - let new_hash = 282 - match tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)).await { 283 - Ok(Ok(h)) => h, 284 - Ok(Err(e)) => { 285 - error!("Failed to hash password: {:?}", e); 286 - return ApiError::InternalError(None).into_response(); 287 - } 288 - Err(e) => { 289 - error!("Failed to spawn blocking task: {:?}", e); 290 - return ApiError::InternalError(None).into_response(); 291 - } 292 - }; 293 - if let Err(e) = state 279 + let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 280 + .await 281 + .map_err(|e| { 282 + error!("Failed to spawn blocking task: {:?}", e); 283 + ApiError::InternalError(None) 284 + })? 285 + .map_err(|e| { 286 + error!("Failed to hash password: {:?}", e); 287 + ApiError::InternalError(None) 288 + })?; 289 + 290 + state 294 291 .user_repo 295 292 .update_password_hash(user_id, &new_hash) 296 293 .await 297 - { 298 - error!("DB error updating password: {:?}", e); 299 - return ApiError::InternalError(None).into_response(); 300 - } 301 - info!(did = %&auth.0.did, "Password changed successfully"); 302 - EmptyResponse::ok().into_response() 294 + .map_err(|e| { 295 + error!("DB error updating password: {:?}", e); 296 + ApiError::InternalError(None) 297 + })?; 298 + 299 + info!(did = %&auth_user.did, "Password changed successfully"); 300 + Ok(EmptyResponse::ok().into_response()) 303 301 } 304 302 305 - pub async fn get_password_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 306 - match state.user_repo.has_password_by_did(&auth.0.did).await { 307 - Ok(Some(has)) => HasPasswordResponse::response(has).into_response(), 308 - Ok(None) => ApiError::AccountNotFound.into_response(), 303 + pub async fn get_password_status( 304 + State(state): State<AppState>, 305 + auth: RequiredAuth, 306 + ) -> Result<Response, ApiError> { 307 + let auth_user = auth.0.require_user()?.require_active()?; 308 + match state.user_repo.has_password_by_did(&auth_user.did).await { 309 + Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()), 310 + Ok(None) => Err(ApiError::AccountNotFound), 309 311 Err(e) => { 310 312 error!("DB error: {:?}", e); 311 - ApiError::InternalError(None).into_response() 313 + Err(ApiError::InternalError(None)) 312 314 } 313 315 } 314 316 } 315 317 316 - pub async fn remove_password(State(state): State<AppState>, auth: BearerAuth) -> Response { 317 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 318 + pub async fn remove_password( 319 + State(state): State<AppState>, 320 + auth: RequiredAuth, 321 + ) -> Result<Response, ApiError> { 322 + let auth_user = auth.0.require_user()?.require_active()?; 323 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth_user.did) 318 324 .await 319 325 { 320 - return crate::api::server::reauth::legacy_mfa_required_response( 326 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 321 327 &*state.user_repo, 322 328 &*state.session_repo, 323 - &auth.0.did, 329 + &auth_user.did, 324 330 ) 325 - .await; 331 + .await); 326 332 } 327 333 328 334 if crate::api::server::reauth::check_reauth_required_cached( 329 335 &*state.session_repo, 330 336 &state.cache, 331 - &auth.0.did, 337 + &auth_user.did, 332 338 ) 333 339 .await 334 340 { 335 - return crate::api::server::reauth::reauth_required_response( 341 + return Ok(crate::api::server::reauth::reauth_required_response( 336 342 &*state.user_repo, 337 343 &*state.session_repo, 338 - &auth.0.did, 344 + &auth_user.did, 339 345 ) 340 - .await; 346 + .await); 341 347 } 342 348 343 349 let has_passkeys = state 344 350 .user_repo 345 - .has_passkeys(&auth.0.did) 351 + .has_passkeys(&auth_user.did) 346 352 .await 347 353 .unwrap_or(false); 348 354 if !has_passkeys { 349 - return ApiError::InvalidRequest( 355 + return Err(ApiError::InvalidRequest( 350 356 "You must have at least one passkey registered before removing your password".into(), 351 - ) 352 - .into_response(); 357 + )); 353 358 } 354 359 355 - let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 356 - Ok(Some(u)) => u, 357 - Ok(None) => { 358 - return ApiError::AccountNotFound.into_response(); 359 - } 360 - Err(e) => { 360 + let user = state 361 + .user_repo 362 + .get_password_info_by_did(&auth_user.did) 363 + .await 364 + .map_err(|e| { 361 365 error!("DB error: {:?}", e); 362 - return ApiError::InternalError(None).into_response(); 363 - } 364 - }; 366 + ApiError::InternalError(None) 367 + })? 368 + .ok_or(ApiError::AccountNotFound)?; 365 369 366 370 if user.password_hash.is_none() { 367 - return ApiError::InvalidRequest("Account already has no password".into()).into_response(); 371 + return Err(ApiError::InvalidRequest( 372 + "Account already has no password".into(), 373 + )); 368 374 } 369 375 370 - if let Err(e) = state.user_repo.remove_user_password(user.id).await { 371 - error!("DB error removing password: {:?}", e); 372 - return ApiError::InternalError(None).into_response(); 373 - } 376 + state 377 + .user_repo 378 + .remove_user_password(user.id) 379 + .await 380 + .map_err(|e| { 381 + error!("DB error removing password: {:?}", e); 382 + ApiError::InternalError(None) 383 + })?; 374 384 375 - info!(did = %&auth.0.did, "Password removed - account is now passkey-only"); 376 - SuccessResponse::ok().into_response() 385 + info!(did = %&auth_user.did, "Password removed - account is now passkey-only"); 386 + Ok(SuccessResponse::ok().into_response()) 377 387 } 378 388 379 389 #[derive(Deserialize)] ··· 384 394 385 395 pub async fn set_password( 386 396 State(state): State<AppState>, 387 - auth: BearerAuth, 397 + auth: RequiredAuth, 388 398 Json(input): Json<SetPasswordInput>, 389 - ) -> Response { 399 + ) -> Result<Response, ApiError> { 400 + let auth_user = auth.0.require_user()?.require_active()?; 390 401 let has_password = state 391 402 .user_repo 392 - .has_password_by_did(&auth.0.did) 403 + .has_password_by_did(&auth_user.did) 393 404 .await 394 405 .ok() 395 406 .flatten() 396 407 .unwrap_or(false); 397 408 let has_passkeys = state 398 409 .user_repo 399 - .has_passkeys(&auth.0.did) 410 + .has_passkeys(&auth_user.did) 400 411 .await 401 412 .unwrap_or(false); 402 413 let has_totp = state 403 414 .user_repo 404 - .has_totp_enabled(&auth.0.did) 415 + .has_totp_enabled(&auth_user.did) 405 416 .await 406 417 .unwrap_or(false); 407 418 ··· 411 422 && crate::api::server::reauth::check_reauth_required_cached( 412 423 &*state.session_repo, 413 424 &state.cache, 414 - &auth.0.did, 425 + &auth_user.did, 415 426 ) 416 427 .await 417 428 { 418 - return crate::api::server::reauth::reauth_required_response( 429 + return Ok(crate::api::server::reauth::reauth_required_response( 419 430 &*state.user_repo, 420 431 &*state.session_repo, 421 - &auth.0.did, 432 + &auth_user.did, 422 433 ) 423 - .await; 434 + .await); 424 435 } 425 436 426 437 let new_password = &input.new_password; 427 438 if new_password.is_empty() { 428 - return ApiError::InvalidRequest("newPassword is required".into()).into_response(); 439 + return Err(ApiError::InvalidRequest("newPassword is required".into())); 429 440 } 430 441 if let Err(e) = validate_password(new_password) { 431 - return ApiError::InvalidRequest(e.to_string()).into_response(); 442 + return Err(ApiError::InvalidRequest(e.to_string())); 432 443 } 433 444 434 - let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 435 - Ok(Some(u)) => u, 436 - Ok(None) => { 437 - return ApiError::AccountNotFound.into_response(); 438 - } 439 - Err(e) => { 445 + let user = state 446 + .user_repo 447 + .get_password_info_by_did(&auth_user.did) 448 + .await 449 + .map_err(|e| { 440 450 error!("DB error: {:?}", e); 441 - return ApiError::InternalError(None).into_response(); 442 - } 443 - }; 451 + ApiError::InternalError(None) 452 + })? 453 + .ok_or(ApiError::AccountNotFound)?; 444 454 445 455 if user.password_hash.is_some() { 446 - return ApiError::InvalidRequest( 456 + return Err(ApiError::InvalidRequest( 447 457 "Account already has a password. Use changePassword instead.".into(), 448 - ) 449 - .into_response(); 458 + )); 450 459 } 451 460 452 461 let new_password_clone = new_password.to_string(); 453 - let new_hash = 454 - match tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)).await { 455 - Ok(Ok(h)) => h, 456 - Ok(Err(e)) => { 457 - error!("Failed to hash password: {:?}", e); 458 - return ApiError::InternalError(None).into_response(); 459 - } 460 - Err(e) => { 461 - error!("Failed to spawn blocking task: {:?}", e); 462 - return ApiError::InternalError(None).into_response(); 463 - } 464 - }; 462 + let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 463 + .await 464 + .map_err(|e| { 465 + error!("Failed to spawn blocking task: {:?}", e); 466 + ApiError::InternalError(None) 467 + })? 468 + .map_err(|e| { 469 + error!("Failed to hash password: {:?}", e); 470 + ApiError::InternalError(None) 471 + })?; 465 472 466 - if let Err(e) = state 473 + state 467 474 .user_repo 468 475 .set_new_user_password(user.id, &new_hash) 469 476 .await 470 - { 471 - error!("DB error setting password: {:?}", e); 472 - return ApiError::InternalError(None).into_response(); 473 - } 477 + .map_err(|e| { 478 + error!("DB error setting password: {:?}", e); 479 + ApiError::InternalError(None) 480 + })?; 474 481 475 - info!(did = %&auth.0.did, "Password set for passkey-only account"); 476 - SuccessResponse::ok().into_response() 482 + info!(did = %&auth_user.did, "Password set for passkey-only account"); 483 + Ok(SuccessResponse::ok().into_response()) 477 484 }
+137 -147
crates/tranquil-pds/src/api/server/reauth.rs
··· 10 10 use tracing::{error, info, warn}; 11 11 use tranquil_db_traits::{SessionRepository, UserRepository}; 12 12 13 - use crate::auth::BearerAuth; 13 + use crate::auth::RequiredAuth; 14 14 use crate::state::{AppState, RateLimitKind}; 15 15 use crate::types::PlainPassword; 16 16 ··· 24 24 pub available_methods: Vec<String>, 25 25 } 26 26 27 - pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 - let last_reauth_at = match state.session_repo.get_last_reauth_at(&auth.0.did).await { 29 - Ok(t) => t, 30 - Err(e) => { 27 + pub async fn get_reauth_status( 28 + State(state): State<AppState>, 29 + auth: RequiredAuth, 30 + ) -> Result<Response, ApiError> { 31 + let auth_user = auth.0.require_user()?.require_active()?; 32 + let last_reauth_at = state 33 + .session_repo 34 + .get_last_reauth_at(&auth_user.did) 35 + .await 36 + .map_err(|e| { 31 37 error!("DB error: {:?}", e); 32 - return ApiError::InternalError(None).into_response(); 33 - } 34 - }; 38 + ApiError::InternalError(None) 39 + })?; 35 40 36 41 let reauth_required = is_reauth_required(last_reauth_at); 37 42 let available_methods = 38 - get_available_reauth_methods(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 43 + get_available_reauth_methods(&*state.user_repo, &*state.session_repo, &auth_user.did).await; 39 44 40 - Json(ReauthStatusResponse { 45 + Ok(Json(ReauthStatusResponse { 41 46 last_reauth_at, 42 47 reauth_required, 43 48 available_methods, 44 49 }) 45 - .into_response() 50 + .into_response()) 46 51 } 47 52 48 53 #[derive(Deserialize)] ··· 59 64 60 65 pub async fn reauth_password( 61 66 State(state): State<AppState>, 62 - auth: BearerAuth, 67 + auth: RequiredAuth, 63 68 Json(input): Json<PasswordReauthInput>, 64 - ) -> Response { 65 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 66 - Ok(Some(hash)) => hash, 67 - Ok(None) => { 68 - return ApiError::AccountNotFound.into_response(); 69 - } 70 - Err(e) => { 69 + ) -> Result<Response, ApiError> { 70 + let auth_user = auth.0.require_user()?.require_active()?; 71 + let password_hash = state 72 + .user_repo 73 + .get_password_hash_by_did(&auth_user.did) 74 + .await 75 + .map_err(|e| { 71 76 error!("DB error: {:?}", e); 72 - return ApiError::InternalError(None).into_response(); 73 - } 74 - }; 77 + ApiError::InternalError(None) 78 + })? 79 + .ok_or(ApiError::AccountNotFound)?; 75 80 76 81 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 77 82 78 83 if !password_valid { 79 84 let app_password_hashes = state 80 85 .session_repo 81 - .get_app_password_hashes_by_did(&auth.0.did) 86 + .get_app_password_hashes_by_did(&auth_user.did) 82 87 .await 83 88 .unwrap_or_default(); 84 89 ··· 87 92 }); 88 93 89 94 if !app_password_valid { 90 - warn!(did = %&auth.0.did, "Re-auth failed: invalid password"); 91 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 95 + warn!(did = %&auth_user.did, "Re-auth failed: invalid password"); 96 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 92 97 } 93 98 } 94 99 95 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 96 - Ok(reauthed_at) => { 97 - info!(did = %&auth.0.did, "Re-auth successful via password"); 98 - Json(ReauthResponse { reauthed_at }).into_response() 99 - } 100 - Err(e) => { 100 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth_user.did) 101 + .await 102 + .map_err(|e| { 101 103 error!("DB error updating reauth: {:?}", e); 102 - ApiError::InternalError(None).into_response() 103 - } 104 - } 104 + ApiError::InternalError(None) 105 + })?; 106 + 107 + info!(did = %&auth_user.did, "Re-auth successful via password"); 108 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 105 109 } 106 110 107 111 #[derive(Deserialize)] ··· 112 116 113 117 pub async fn reauth_totp( 114 118 State(state): State<AppState>, 115 - auth: BearerAuth, 119 + auth: RequiredAuth, 116 120 Json(input): Json<TotpReauthInput>, 117 - ) -> Response { 121 + ) -> Result<Response, ApiError> { 122 + let auth_user = auth.0.require_user()?.require_active()?; 118 123 if !state 119 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 124 + .check_rate_limit(RateLimitKind::TotpVerify, &auth_user.did) 120 125 .await 121 126 { 122 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 123 - return ApiError::RateLimitExceeded(Some( 127 + warn!(did = %&auth_user.did, "TOTP verification rate limit exceeded"); 128 + return Err(ApiError::RateLimitExceeded(Some( 124 129 "Too many verification attempts. Please try again in a few minutes.".into(), 125 - )) 126 - .into_response(); 130 + ))); 127 131 } 128 132 129 - let valid = 130 - crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code) 131 - .await; 133 + let valid = crate::api::server::totp::verify_totp_or_backup_for_user( 134 + &state, 135 + &auth_user.did, 136 + &input.code, 137 + ) 138 + .await; 132 139 133 140 if !valid { 134 - warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code"); 135 - return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response(); 141 + warn!(did = %&auth_user.did, "Re-auth failed: invalid TOTP code"); 142 + return Err(ApiError::InvalidCode(Some( 143 + "Invalid TOTP or backup code".into(), 144 + ))); 136 145 } 137 146 138 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 139 - Ok(reauthed_at) => { 140 - info!(did = %&auth.0.did, "Re-auth successful via TOTP"); 141 - Json(ReauthResponse { reauthed_at }).into_response() 142 - } 143 - Err(e) => { 147 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth_user.did) 148 + .await 149 + .map_err(|e| { 144 150 error!("DB error updating reauth: {:?}", e); 145 - ApiError::InternalError(None).into_response() 146 - } 147 - } 151 + ApiError::InternalError(None) 152 + })?; 153 + 154 + info!(did = %&auth_user.did, "Re-auth successful via TOTP"); 155 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 148 156 } 149 157 150 158 #[derive(Serialize)] ··· 153 161 pub options: serde_json::Value, 154 162 } 155 163 156 - pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response { 164 + pub async fn reauth_passkey_start( 165 + State(state): State<AppState>, 166 + auth: RequiredAuth, 167 + ) -> Result<Response, ApiError> { 168 + let auth_user = auth.0.require_user()?.require_active()?; 157 169 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 158 170 159 - let stored_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 160 - Ok(pks) => pks, 161 - Err(e) => { 171 + let stored_passkeys = state 172 + .user_repo 173 + .get_passkeys_for_user(&auth_user.did) 174 + .await 175 + .map_err(|e| { 162 176 error!("Failed to get passkeys: {:?}", e); 163 - return ApiError::InternalError(None).into_response(); 164 - } 165 - }; 177 + ApiError::InternalError(None) 178 + })?; 166 179 167 180 if stored_passkeys.is_empty() { 168 - return ApiError::NoPasskeys.into_response(); 181 + return Err(ApiError::NoPasskeys); 169 182 } 170 183 171 184 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys ··· 174 187 .collect(); 175 188 176 189 if passkeys.is_empty() { 177 - return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response(); 190 + return Err(ApiError::InternalError(Some( 191 + "Failed to load passkeys".into(), 192 + ))); 178 193 } 179 194 180 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 181 - Ok(w) => w, 182 - Err(e) => { 183 - error!("Failed to create WebAuthn config: {:?}", e); 184 - return ApiError::InternalError(None).into_response(); 185 - } 186 - }; 195 + let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 196 + error!("Failed to create WebAuthn config: {:?}", e); 197 + ApiError::InternalError(None) 198 + })?; 187 199 188 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 189 - Ok(result) => result, 190 - Err(e) => { 191 - error!("Failed to start passkey authentication: {:?}", e); 192 - return ApiError::InternalError(None).into_response(); 193 - } 194 - }; 200 + let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| { 201 + error!("Failed to start passkey authentication: {:?}", e); 202 + ApiError::InternalError(None) 203 + })?; 195 204 196 - let state_json = match serde_json::to_string(&auth_state) { 197 - Ok(s) => s, 198 - Err(e) => { 199 - error!("Failed to serialize authentication state: {:?}", e); 200 - return ApiError::InternalError(None).into_response(); 201 - } 202 - }; 205 + let state_json = serde_json::to_string(&auth_state).map_err(|e| { 206 + error!("Failed to serialize authentication state: {:?}", e); 207 + ApiError::InternalError(None) 208 + })?; 203 209 204 - if let Err(e) = state 210 + state 205 211 .user_repo 206 - .save_webauthn_challenge(&auth.0.did, "authentication", &state_json) 212 + .save_webauthn_challenge(&auth_user.did, "authentication", &state_json) 207 213 .await 208 - { 209 - error!("Failed to save authentication state: {:?}", e); 210 - return ApiError::InternalError(None).into_response(); 211 - } 214 + .map_err(|e| { 215 + error!("Failed to save authentication state: {:?}", e); 216 + ApiError::InternalError(None) 217 + })?; 212 218 213 219 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 214 - Json(PasskeyReauthStartResponse { options }).into_response() 220 + Ok(Json(PasskeyReauthStartResponse { options }).into_response()) 215 221 } 216 222 217 223 #[derive(Deserialize)] ··· 222 228 223 229 pub async fn reauth_passkey_finish( 224 230 State(state): State<AppState>, 225 - auth: BearerAuth, 231 + auth: RequiredAuth, 226 232 Json(input): Json<PasskeyReauthFinishInput>, 227 - ) -> Response { 233 + ) -> Result<Response, ApiError> { 234 + let auth_user = auth.0.require_user()?.require_active()?; 228 235 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 229 236 230 - let auth_state_json = match state 237 + let auth_state_json = state 231 238 .user_repo 232 - .load_webauthn_challenge(&auth.0.did, "authentication") 239 + .load_webauthn_challenge(&auth_user.did, "authentication") 233 240 .await 234 - { 235 - Ok(Some(json)) => json, 236 - Ok(None) => { 237 - return ApiError::NoChallengeInProgress.into_response(); 238 - } 239 - Err(e) => { 241 + .map_err(|e| { 240 242 error!("Failed to load authentication state: {:?}", e); 241 - return ApiError::InternalError(None).into_response(); 242 - } 243 - }; 243 + ApiError::InternalError(None) 244 + })? 245 + .ok_or(ApiError::NoChallengeInProgress)?; 244 246 245 247 let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = 246 - match serde_json::from_str(&auth_state_json) { 247 - Ok(s) => s, 248 - Err(e) => { 249 - error!("Failed to deserialize authentication state: {:?}", e); 250 - return ApiError::InternalError(None).into_response(); 251 - } 252 - }; 248 + serde_json::from_str(&auth_state_json).map_err(|e| { 249 + error!("Failed to deserialize authentication state: {:?}", e); 250 + ApiError::InternalError(None) 251 + })?; 253 252 254 253 let credential: webauthn_rs::prelude::PublicKeyCredential = 255 - match serde_json::from_value(input.credential) { 256 - Ok(c) => c, 257 - Err(e) => { 258 - warn!("Failed to parse credential: {:?}", e); 259 - return ApiError::InvalidCredential.into_response(); 260 - } 261 - }; 254 + serde_json::from_value(input.credential).map_err(|e| { 255 + warn!("Failed to parse credential: {:?}", e); 256 + ApiError::InvalidCredential 257 + })?; 262 258 263 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 264 - Ok(w) => w, 265 - Err(e) => { 266 - error!("Failed to create WebAuthn config: {:?}", e); 267 - return ApiError::InternalError(None).into_response(); 268 - } 269 - }; 259 + let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 260 + error!("Failed to create WebAuthn config: {:?}", e); 261 + ApiError::InternalError(None) 262 + })?; 270 263 271 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 272 - Ok(r) => r, 273 - Err(e) => { 274 - warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e); 275 - return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 276 - .into_response(); 277 - } 278 - }; 264 + let auth_result = webauthn 265 + .finish_authentication(&credential, &auth_state) 266 + .map_err(|e| { 267 + warn!(did = %&auth_user.did, "Passkey re-auth failed: {:?}", e); 268 + ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 269 + })?; 279 270 280 271 let cred_id_bytes = auth_result.cred_id().as_ref(); 281 272 match state ··· 284 275 .await 285 276 { 286 277 Ok(false) => { 287 - warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key"); 278 + warn!(did = %&auth_user.did, "Passkey counter anomaly detected - possible cloned key"); 288 279 let _ = state 289 280 .user_repo 290 - .delete_webauthn_challenge(&auth.0.did, "authentication") 281 + .delete_webauthn_challenge(&auth_user.did, "authentication") 291 282 .await; 292 - return ApiError::PasskeyCounterAnomaly.into_response(); 283 + return Err(ApiError::PasskeyCounterAnomaly); 293 284 } 294 285 Err(e) => { 295 286 error!("Failed to update passkey counter: {:?}", e); ··· 299 290 300 291 let _ = state 301 292 .user_repo 302 - .delete_webauthn_challenge(&auth.0.did, "authentication") 293 + .delete_webauthn_challenge(&auth_user.did, "authentication") 303 294 .await; 304 295 305 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 306 - Ok(reauthed_at) => { 307 - info!(did = %&auth.0.did, "Re-auth successful via passkey"); 308 - Json(ReauthResponse { reauthed_at }).into_response() 309 - } 310 - Err(e) => { 296 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth_user.did) 297 + .await 298 + .map_err(|e| { 311 299 error!("DB error updating reauth: {:?}", e); 312 - ApiError::InternalError(None).into_response() 313 - } 314 - } 300 + ApiError::InternalError(None) 301 + })?; 302 + 303 + info!(did = %&auth_user.did, "Re-auth successful via passkey"); 304 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 315 305 } 316 306 317 307 pub async fn update_last_reauth_cached(
+174 -169
crates/tranquil-pds/src/api/server/session.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, SuccessResponse}; 3 - use crate::auth::{BearerAuth, BearerAuthAllowDeactivated}; 3 + use crate::auth::RequiredAuth; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use crate::types::{AccountState, Did, Handle, PlainPassword}; 6 6 use axum::{ ··· 279 279 280 280 pub async fn get_session( 281 281 State(state): State<AppState>, 282 - BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated, 283 - ) -> Response { 282 + auth: RequiredAuth, 283 + ) -> Result<Response, ApiError> { 284 + let auth_user = auth.0.require_user()?.require_not_takendown()?; 284 285 let permissions = auth_user.permissions(); 285 286 let can_read_email = permissions.allows_email_read(); 286 287 ··· 337 338 if let Some(doc) = did_doc { 338 339 response["didDoc"] = doc; 339 340 } 340 - Json(response).into_response() 341 + Ok(Json(response).into_response()) 341 342 } 342 - Ok(None) => ApiError::AuthenticationFailed(None).into_response(), 343 + Ok(None) => Err(ApiError::AuthenticationFailed(None)), 343 344 Err(e) => { 344 345 error!("Database error in get_session: {:?}", e); 345 - ApiError::InternalError(None).into_response() 346 + Err(ApiError::InternalError(None)) 346 347 } 347 348 } 348 349 } ··· 350 351 pub async fn delete_session( 351 352 State(state): State<AppState>, 352 353 headers: axum::http::HeaderMap, 353 - _auth: BearerAuth, 354 - ) -> Response { 355 - let extracted = match crate::auth::extract_auth_token_from_header( 354 + auth: RequiredAuth, 355 + ) -> Result<Response, ApiError> { 356 + auth.0.require_user()?.require_active()?; 357 + let extracted = crate::auth::extract_auth_token_from_header( 356 358 headers.get("Authorization").and_then(|h| h.to_str().ok()), 357 - ) { 358 - Some(t) => t, 359 - None => return ApiError::AuthenticationRequired.into_response(), 360 - }; 361 - let jti = match crate::auth::get_jti_from_token(&extracted.token) { 362 - Ok(jti) => jti, 363 - Err(_) => return ApiError::AuthenticationFailed(None).into_response(), 364 - }; 359 + ) 360 + .ok_or(ApiError::AuthenticationRequired)?; 361 + let jti = crate::auth::get_jti_from_token(&extracted.token) 362 + .map_err(|_| ApiError::AuthenticationFailed(None))?; 365 363 let did = crate::auth::get_did_from_token(&extracted.token).ok(); 366 364 match state.session_repo.delete_session_by_access_jti(&jti).await { 367 365 Ok(rows) if rows > 0 => { ··· 369 367 let session_cache_key = format!("auth:session:{}:{}", did, jti); 370 368 let _ = state.cache.delete(&session_cache_key).await; 371 369 } 372 - EmptyResponse::ok().into_response() 370 + Ok(EmptyResponse::ok().into_response()) 373 371 } 374 - Ok(_) => ApiError::AuthenticationFailed(None).into_response(), 375 - Err(_) => ApiError::AuthenticationFailed(None).into_response(), 372 + Ok(_) => Err(ApiError::AuthenticationFailed(None)), 373 + Err(_) => Err(ApiError::AuthenticationFailed(None)), 376 374 } 377 375 } 378 376 ··· 796 794 pub async fn list_sessions( 797 795 State(state): State<AppState>, 798 796 headers: HeaderMap, 799 - auth: BearerAuth, 800 - ) -> Response { 797 + auth: RequiredAuth, 798 + ) -> Result<Response, ApiError> { 799 + let auth_user = auth.0.require_user()?.require_active()?; 801 800 let current_jti = headers 802 801 .get("authorization") 803 802 .and_then(|v| v.to_str().ok()) 804 803 .and_then(|v| v.strip_prefix("Bearer ")) 805 804 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 806 805 807 - let jwt_rows = match state.session_repo.list_sessions_by_did(&auth.0.did).await { 808 - Ok(rows) => rows, 809 - Err(e) => { 806 + let jwt_rows = state 807 + .session_repo 808 + .list_sessions_by_did(&auth_user.did) 809 + .await 810 + .map_err(|e| { 810 811 error!("DB error fetching JWT sessions: {:?}", e); 811 - return ApiError::InternalError(None).into_response(); 812 - } 813 - }; 812 + ApiError::InternalError(None) 813 + })?; 814 814 815 - let oauth_rows = match state.oauth_repo.list_sessions_by_did(&auth.0.did).await { 816 - Ok(rows) => rows, 817 - Err(e) => { 815 + let oauth_rows = state 816 + .oauth_repo 817 + .list_sessions_by_did(&auth_user.did) 818 + .await 819 + .map_err(|e| { 818 820 error!("DB error fetching OAuth sessions: {:?}", e); 819 - return ApiError::InternalError(None).into_response(); 820 - } 821 - }; 821 + ApiError::InternalError(None) 822 + })?; 822 823 823 824 let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 824 825 id: format!("jwt:{}", row.id), ··· 829 830 is_current: current_jti.as_ref() == Some(&row.access_jti), 830 831 }); 831 832 832 - let is_oauth = auth.0.is_oauth; 833 + let is_oauth = auth_user.is_oauth; 833 834 let oauth_sessions = oauth_rows.into_iter().map(|row| { 834 835 let client_name = extract_client_name(&row.client_id); 835 836 let is_current_oauth = is_oauth && current_jti.as_deref() == Some(row.token_id.as_str()); ··· 846 847 let mut sessions: Vec<SessionInfo> = jwt_sessions.chain(oauth_sessions).collect(); 847 848 sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); 848 849 849 - (StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response() 850 + Ok((StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()) 850 851 } 851 852 852 853 fn extract_client_name(client_id: &str) -> String { ··· 867 868 868 869 pub async fn revoke_session( 869 870 State(state): State<AppState>, 870 - auth: BearerAuth, 871 + auth: RequiredAuth, 871 872 Json(input): Json<RevokeSessionInput>, 872 - ) -> Response { 873 + ) -> Result<Response, ApiError> { 874 + let auth_user = auth.0.require_user()?.require_active()?; 873 875 if let Some(jwt_id) = input.session_id.strip_prefix("jwt:") { 874 - let Ok(session_id) = jwt_id.parse::<i32>() else { 875 - return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 876 - }; 877 - let access_jti = match state 876 + let session_id: i32 = jwt_id 877 + .parse() 878 + .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 879 + let access_jti = state 878 880 .session_repo 879 - .get_session_access_jti_by_id(session_id, &auth.0.did) 881 + .get_session_access_jti_by_id(session_id, &auth_user.did) 880 882 .await 881 - { 882 - Ok(Some(jti)) => jti, 883 - Ok(None) => { 884 - return ApiError::SessionNotFound.into_response(); 885 - } 886 - Err(e) => { 883 + .map_err(|e| { 887 884 error!("DB error in revoke_session: {:?}", e); 888 - return ApiError::InternalError(None).into_response(); 889 - } 890 - }; 891 - if let Err(e) = state.session_repo.delete_session_by_id(session_id).await { 892 - error!("DB error deleting session: {:?}", e); 893 - return ApiError::InternalError(None).into_response(); 894 - } 895 - let cache_key = format!("auth:session:{}:{}", &auth.0.did, access_jti); 885 + ApiError::InternalError(None) 886 + })? 887 + .ok_or(ApiError::SessionNotFound)?; 888 + state 889 + .session_repo 890 + .delete_session_by_id(session_id) 891 + .await 892 + .map_err(|e| { 893 + error!("DB error deleting session: {:?}", e); 894 + ApiError::InternalError(None) 895 + })?; 896 + let cache_key = format!("auth:session:{}:{}", &auth_user.did, access_jti); 896 897 if let Err(e) = state.cache.delete(&cache_key).await { 897 898 warn!("Failed to invalidate session cache: {:?}", e); 898 899 } 899 - info!(did = %&auth.0.did, session_id = %session_id, "JWT session revoked"); 900 + info!(did = %&auth_user.did, session_id = %session_id, "JWT session revoked"); 900 901 } else if let Some(oauth_id) = input.session_id.strip_prefix("oauth:") { 901 - let Ok(session_id) = oauth_id.parse::<i32>() else { 902 - return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 903 - }; 904 - match state 902 + let session_id: i32 = oauth_id 903 + .parse() 904 + .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 905 + let deleted = state 905 906 .oauth_repo 906 - .delete_session_by_id(session_id, &auth.0.did) 907 + .delete_session_by_id(session_id, &auth_user.did) 907 908 .await 908 - { 909 - Ok(0) => { 910 - return ApiError::SessionNotFound.into_response(); 911 - } 912 - Err(e) => { 909 + .map_err(|e| { 913 910 error!("DB error deleting OAuth session: {:?}", e); 914 - return ApiError::InternalError(None).into_response(); 915 - } 916 - _ => {} 911 + ApiError::InternalError(None) 912 + })?; 913 + if deleted == 0 { 914 + return Err(ApiError::SessionNotFound); 917 915 } 918 - info!(did = %&auth.0.did, session_id = %session_id, "OAuth session revoked"); 916 + info!(did = %&auth_user.did, session_id = %session_id, "OAuth session revoked"); 919 917 } else { 920 - return ApiError::InvalidRequest("Invalid session ID format".into()).into_response(); 918 + return Err(ApiError::InvalidRequest("Invalid session ID format".into())); 921 919 } 922 - EmptyResponse::ok().into_response() 920 + Ok(EmptyResponse::ok().into_response()) 923 921 } 924 922 925 923 pub async fn revoke_all_sessions( 926 924 State(state): State<AppState>, 927 925 headers: HeaderMap, 928 - auth: BearerAuth, 929 - ) -> Response { 930 - let current_jti = crate::auth::extract_auth_token_from_header( 926 + auth: RequiredAuth, 927 + ) -> Result<Response, ApiError> { 928 + let auth_user = auth.0.require_user()?.require_active()?; 929 + let jti = crate::auth::extract_auth_token_from_header( 931 930 headers.get("authorization").and_then(|v| v.to_str().ok()), 932 931 ) 933 - .and_then(|extracted| crate::auth::get_jti_from_token(&extracted.token).ok()); 932 + .and_then(|extracted| crate::auth::get_jti_from_token(&extracted.token).ok()) 933 + .ok_or(ApiError::InvalidToken(None))?; 934 934 935 - let Some(ref jti) = current_jti else { 936 - return ApiError::InvalidToken(None).into_response(); 937 - }; 938 - 939 - if auth.0.is_oauth { 940 - if let Err(e) = state.session_repo.delete_sessions_by_did(&auth.0.did).await { 941 - error!("DB error revoking JWT sessions: {:?}", e); 942 - return ApiError::InternalError(None).into_response(); 943 - } 935 + if auth_user.is_oauth { 936 + state 937 + .session_repo 938 + .delete_sessions_by_did(&auth_user.did) 939 + .await 940 + .map_err(|e| { 941 + error!("DB error revoking JWT sessions: {:?}", e); 942 + ApiError::InternalError(None) 943 + })?; 944 944 let jti_typed = TokenId::from(jti.clone()); 945 - if let Err(e) = state 945 + state 946 946 .oauth_repo 947 - .delete_sessions_by_did_except(&auth.0.did, &jti_typed) 947 + .delete_sessions_by_did_except(&auth_user.did, &jti_typed) 948 948 .await 949 - { 950 - error!("DB error revoking OAuth sessions: {:?}", e); 951 - return ApiError::InternalError(None).into_response(); 952 - } 949 + .map_err(|e| { 950 + error!("DB error revoking OAuth sessions: {:?}", e); 951 + ApiError::InternalError(None) 952 + })?; 953 953 } else { 954 - if let Err(e) = state 954 + state 955 955 .session_repo 956 - .delete_sessions_by_did_except_jti(&auth.0.did, jti) 956 + .delete_sessions_by_did_except_jti(&auth_user.did, &jti) 957 + .await 958 + .map_err(|e| { 959 + error!("DB error revoking JWT sessions: {:?}", e); 960 + ApiError::InternalError(None) 961 + })?; 962 + state 963 + .oauth_repo 964 + .delete_sessions_by_did(&auth_user.did) 957 965 .await 958 - { 959 - error!("DB error revoking JWT sessions: {:?}", e); 960 - return ApiError::InternalError(None).into_response(); 961 - } 962 - if let Err(e) = state.oauth_repo.delete_sessions_by_did(&auth.0.did).await { 963 - error!("DB error revoking OAuth sessions: {:?}", e); 964 - return ApiError::InternalError(None).into_response(); 965 - } 966 + .map_err(|e| { 967 + error!("DB error revoking OAuth sessions: {:?}", e); 968 + ApiError::InternalError(None) 969 + })?; 966 970 } 967 971 968 - info!(did = %&auth.0.did, "All other sessions revoked"); 969 - SuccessResponse::ok().into_response() 972 + info!(did = %&auth_user.did, "All other sessions revoked"); 973 + Ok(SuccessResponse::ok().into_response()) 970 974 } 971 975 972 976 #[derive(Serialize)] ··· 978 982 979 983 pub async fn get_legacy_login_preference( 980 984 State(state): State<AppState>, 981 - auth: BearerAuth, 982 - ) -> Response { 983 - match state.user_repo.get_legacy_login_pref(&auth.0.did).await { 984 - Ok(Some(pref)) => Json(LegacyLoginPreferenceOutput { 985 - allow_legacy_login: pref.allow_legacy_login, 986 - has_mfa: pref.has_mfa, 987 - }) 988 - .into_response(), 989 - Ok(None) => ApiError::AccountNotFound.into_response(), 990 - Err(e) => { 985 + auth: RequiredAuth, 986 + ) -> Result<Response, ApiError> { 987 + let auth_user = auth.0.require_user()?.require_active()?; 988 + let pref = state 989 + .user_repo 990 + .get_legacy_login_pref(&auth_user.did) 991 + .await 992 + .map_err(|e| { 991 993 error!("DB error: {:?}", e); 992 - ApiError::InternalError(None).into_response() 993 - } 994 - } 994 + ApiError::InternalError(None) 995 + })? 996 + .ok_or(ApiError::AccountNotFound)?; 997 + Ok(Json(LegacyLoginPreferenceOutput { 998 + allow_legacy_login: pref.allow_legacy_login, 999 + has_mfa: pref.has_mfa, 1000 + }) 1001 + .into_response()) 995 1002 } 996 1003 997 1004 #[derive(Deserialize)] ··· 1002 1009 1003 1010 pub async fn update_legacy_login_preference( 1004 1011 State(state): State<AppState>, 1005 - auth: BearerAuth, 1012 + auth: RequiredAuth, 1006 1013 Json(input): Json<UpdateLegacyLoginInput>, 1007 - ) -> Response { 1008 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 1014 + ) -> Result<Response, ApiError> { 1015 + let auth_user = auth.0.require_user()?.require_active()?; 1016 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth_user.did) 1009 1017 .await 1010 1018 { 1011 - return crate::api::server::reauth::legacy_mfa_required_response( 1019 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 1012 1020 &*state.user_repo, 1013 1021 &*state.session_repo, 1014 - &auth.0.did, 1022 + &auth_user.did, 1015 1023 ) 1016 - .await; 1024 + .await); 1017 1025 } 1018 1026 1019 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 1020 - return crate::api::server::reauth::reauth_required_response( 1027 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth_user.did).await 1028 + { 1029 + return Ok(crate::api::server::reauth::reauth_required_response( 1021 1030 &*state.user_repo, 1022 1031 &*state.session_repo, 1023 - &auth.0.did, 1032 + &auth_user.did, 1024 1033 ) 1025 - .await; 1034 + .await); 1026 1035 } 1027 1036 1028 - match state 1037 + let updated = state 1029 1038 .user_repo 1030 - .update_legacy_login(&auth.0.did, input.allow_legacy_login) 1039 + .update_legacy_login(&auth_user.did, input.allow_legacy_login) 1031 1040 .await 1032 - { 1033 - Ok(true) => { 1034 - info!( 1035 - did = %&auth.0.did, 1036 - allow_legacy_login = input.allow_legacy_login, 1037 - "Legacy login preference updated" 1038 - ); 1039 - Json(json!({ 1040 - "allowLegacyLogin": input.allow_legacy_login 1041 - })) 1042 - .into_response() 1043 - } 1044 - Ok(false) => ApiError::AccountNotFound.into_response(), 1045 - Err(e) => { 1041 + .map_err(|e| { 1046 1042 error!("DB error: {:?}", e); 1047 - ApiError::InternalError(None).into_response() 1048 - } 1043 + ApiError::InternalError(None) 1044 + })?; 1045 + if !updated { 1046 + return Err(ApiError::AccountNotFound); 1049 1047 } 1048 + info!( 1049 + did = %&auth_user.did, 1050 + allow_legacy_login = input.allow_legacy_login, 1051 + "Legacy login preference updated" 1052 + ); 1053 + Ok(Json(json!({ 1054 + "allowLegacyLogin": input.allow_legacy_login 1055 + })) 1056 + .into_response()) 1050 1057 } 1051 1058 1052 1059 use crate::comms::VALID_LOCALES; ··· 1059 1066 1060 1067 pub async fn update_locale( 1061 1068 State(state): State<AppState>, 1062 - auth: BearerAuth, 1069 + auth: RequiredAuth, 1063 1070 Json(input): Json<UpdateLocaleInput>, 1064 - ) -> Response { 1071 + ) -> Result<Response, ApiError> { 1072 + let auth_user = auth.0.require_user()?.require_active()?; 1065 1073 if !VALID_LOCALES.contains(&input.preferred_locale.as_str()) { 1066 - return ApiError::InvalidRequest(format!( 1074 + return Err(ApiError::InvalidRequest(format!( 1067 1075 "Invalid locale. Valid options: {}", 1068 1076 VALID_LOCALES.join(", ") 1069 - )) 1070 - .into_response(); 1077 + ))); 1071 1078 } 1072 1079 1073 - match state 1080 + let updated = state 1074 1081 .user_repo 1075 - .update_locale(&auth.0.did, &input.preferred_locale) 1082 + .update_locale(&auth_user.did, &input.preferred_locale) 1076 1083 .await 1077 - { 1078 - Ok(true) => { 1079 - info!( 1080 - did = %&auth.0.did, 1081 - locale = %input.preferred_locale, 1082 - "User locale preference updated" 1083 - ); 1084 - Json(json!({ 1085 - "preferredLocale": input.preferred_locale 1086 - })) 1087 - .into_response() 1088 - } 1089 - Ok(false) => ApiError::AccountNotFound.into_response(), 1090 - Err(e) => { 1084 + .map_err(|e| { 1091 1085 error!("DB error updating locale: {:?}", e); 1092 - ApiError::InternalError(None).into_response() 1093 - } 1086 + ApiError::InternalError(None) 1087 + })?; 1088 + if !updated { 1089 + return Err(ApiError::AccountNotFound); 1094 1090 } 1091 + info!( 1092 + did = %&auth_user.did, 1093 + locale = %input.preferred_locale, 1094 + "User locale preference updated" 1095 + ); 1096 + Ok(Json(json!({ 1097 + "preferredLocale": input.preferred_locale 1098 + })) 1099 + .into_response()) 1095 1100 }
+166 -160
crates/tranquil-pds/src/api/server/totp.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::RequiredAuth; 4 4 use crate::auth::{ 5 5 decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64, 6 6 generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format, ··· 26 26 pub qr_base64: String, 27 27 } 28 28 29 - pub async fn create_totp_secret(State(state): State<AppState>, auth: BearerAuth) -> Response { 30 - match state.user_repo.get_totp_record(&auth.0.did).await { 31 - Ok(Some(record)) if record.verified => return ApiError::TotpAlreadyEnabled.into_response(), 29 + pub async fn create_totp_secret( 30 + State(state): State<AppState>, 31 + auth: RequiredAuth, 32 + ) -> Result<Response, ApiError> { 33 + let auth_user = auth.0.require_user()?.require_active()?; 34 + match state.user_repo.get_totp_record(&auth_user.did).await { 35 + Ok(Some(record)) if record.verified => return Err(ApiError::TotpAlreadyEnabled), 32 36 Ok(_) => {} 33 37 Err(e) => { 34 38 error!("DB error checking TOTP: {:?}", e); 35 - return ApiError::InternalError(None).into_response(); 39 + return Err(ApiError::InternalError(None)); 36 40 } 37 41 } 38 42 39 43 let secret = generate_totp_secret(); 40 44 41 - let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 42 - Ok(Some(h)) => h, 43 - Ok(None) => return ApiError::AccountNotFound.into_response(), 44 - Err(e) => { 45 + let handle = state 46 + .user_repo 47 + .get_handle_by_did(&auth_user.did) 48 + .await 49 + .map_err(|e| { 45 50 error!("DB error fetching handle: {:?}", e); 46 - return ApiError::InternalError(None).into_response(); 47 - } 48 - }; 51 + ApiError::InternalError(None) 52 + })? 53 + .ok_or(ApiError::AccountNotFound)?; 49 54 50 55 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 51 56 let uri = generate_totp_uri(&secret, &handle, &hostname); 52 57 53 - let qr_code = match generate_qr_png_base64(&secret, &handle, &hostname) { 54 - Ok(qr) => qr, 55 - Err(e) => { 56 - error!("Failed to generate QR code: {:?}", e); 57 - return ApiError::InternalError(Some("Failed to generate QR code".into())) 58 - .into_response(); 59 - } 60 - }; 58 + let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| { 59 + error!("Failed to generate QR code: {:?}", e); 60 + ApiError::InternalError(Some("Failed to generate QR code".into())) 61 + })?; 61 62 62 - let encrypted_secret = match encrypt_totp_secret(&secret) { 63 - Ok(enc) => enc, 64 - Err(e) => { 65 - error!("Failed to encrypt TOTP secret: {:?}", e); 66 - return ApiError::InternalError(None).into_response(); 67 - } 68 - }; 63 + let encrypted_secret = encrypt_totp_secret(&secret).map_err(|e| { 64 + error!("Failed to encrypt TOTP secret: {:?}", e); 65 + ApiError::InternalError(None) 66 + })?; 69 67 70 - if let Err(e) = state 68 + state 71 69 .user_repo 72 - .upsert_totp_secret(&auth.0.did, &encrypted_secret, ENCRYPTION_VERSION) 70 + .upsert_totp_secret(&auth_user.did, &encrypted_secret, ENCRYPTION_VERSION) 73 71 .await 74 - { 75 - error!("Failed to store TOTP secret: {:?}", e); 76 - return ApiError::InternalError(None).into_response(); 77 - } 72 + .map_err(|e| { 73 + error!("Failed to store TOTP secret: {:?}", e); 74 + ApiError::InternalError(None) 75 + })?; 78 76 79 77 let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); 80 78 81 - info!(did = %&auth.0.did, "TOTP secret created (pending verification)"); 79 + info!(did = %&auth_user.did, "TOTP secret created (pending verification)"); 82 80 83 - Json(CreateTotpSecretResponse { 81 + Ok(Json(CreateTotpSecretResponse { 84 82 secret: secret_base32, 85 83 uri, 86 84 qr_base64: qr_code, 87 85 }) 88 - .into_response() 86 + .into_response()) 89 87 } 90 88 91 89 #[derive(Deserialize)] ··· 101 99 102 100 pub async fn enable_totp( 103 101 State(state): State<AppState>, 104 - auth: BearerAuth, 102 + auth: RequiredAuth, 105 103 Json(input): Json<EnableTotpInput>, 106 - ) -> Response { 104 + ) -> Result<Response, ApiError> { 105 + let auth_user = auth.0.require_user()?.require_active()?; 107 106 if !state 108 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 107 + .check_rate_limit(RateLimitKind::TotpVerify, &auth_user.did) 109 108 .await 110 109 { 111 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 112 - return ApiError::RateLimitExceeded(None).into_response(); 110 + warn!(did = %&auth_user.did, "TOTP verification rate limit exceeded"); 111 + return Err(ApiError::RateLimitExceeded(None)); 113 112 } 114 113 115 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 114 + let totp_record = match state.user_repo.get_totp_record(&auth_user.did).await { 116 115 Ok(Some(row)) => row, 117 - Ok(None) => return ApiError::TotpNotEnabled.into_response(), 116 + Ok(None) => return Err(ApiError::TotpNotEnabled), 118 117 Err(e) => { 119 118 error!("DB error fetching TOTP: {:?}", e); 120 - return ApiError::InternalError(None).into_response(); 119 + return Err(ApiError::InternalError(None)); 121 120 } 122 121 }; 123 122 124 123 if totp_record.verified { 125 - return ApiError::TotpAlreadyEnabled.into_response(); 124 + return Err(ApiError::TotpAlreadyEnabled); 126 125 } 127 126 128 - let secret = match decrypt_totp_secret( 127 + let secret = decrypt_totp_secret( 129 128 &totp_record.secret_encrypted, 130 129 totp_record.encryption_version, 131 - ) { 132 - Ok(s) => s, 133 - Err(e) => { 134 - error!("Failed to decrypt TOTP secret: {:?}", e); 135 - return ApiError::InternalError(None).into_response(); 136 - } 137 - }; 130 + ) 131 + .map_err(|e| { 132 + error!("Failed to decrypt TOTP secret: {:?}", e); 133 + ApiError::InternalError(None) 134 + })?; 138 135 139 136 let code = input.code.trim(); 140 137 if !verify_totp_code(&secret, code) { 141 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 138 + return Err(ApiError::InvalidCode(Some( 139 + "Invalid verification code".into(), 140 + ))); 142 141 } 143 142 144 143 let backup_codes = generate_backup_codes(); 145 - let backup_hashes: Result<Vec<_>, _> = 146 - backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 147 - let backup_hashes = match backup_hashes { 148 - Ok(hashes) => hashes, 149 - Err(e) => { 144 + let backup_hashes: Vec<_> = backup_codes 145 + .iter() 146 + .map(|c| hash_backup_code(c)) 147 + .collect::<Result<Vec<_>, _>>() 148 + .map_err(|e| { 150 149 error!("Failed to hash backup code: {:?}", e); 151 - return ApiError::InternalError(None).into_response(); 152 - } 153 - }; 150 + ApiError::InternalError(None) 151 + })?; 154 152 155 - if let Err(e) = state 153 + state 156 154 .user_repo 157 - .enable_totp_with_backup_codes(&auth.0.did, &backup_hashes) 155 + .enable_totp_with_backup_codes(&auth_user.did, &backup_hashes) 158 156 .await 159 - { 160 - error!("Failed to enable TOTP: {:?}", e); 161 - return ApiError::InternalError(None).into_response(); 162 - } 157 + .map_err(|e| { 158 + error!("Failed to enable TOTP: {:?}", e); 159 + ApiError::InternalError(None) 160 + })?; 163 161 164 - info!(did = %&auth.0.did, "TOTP enabled with {} backup codes", backup_codes.len()); 162 + info!(did = %&auth_user.did, "TOTP enabled with {} backup codes", backup_codes.len()); 165 163 166 - Json(EnableTotpResponse { backup_codes }).into_response() 164 + Ok(Json(EnableTotpResponse { backup_codes }).into_response()) 167 165 } 168 166 169 167 #[derive(Deserialize)] ··· 174 172 175 173 pub async fn disable_totp( 176 174 State(state): State<AppState>, 177 - auth: BearerAuth, 175 + auth: RequiredAuth, 178 176 Json(input): Json<DisableTotpInput>, 179 - ) -> Response { 180 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 177 + ) -> Result<Response, ApiError> { 178 + let auth_user = auth.0.require_user()?.require_active()?; 179 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth_user.did) 181 180 .await 182 181 { 183 - return crate::api::server::reauth::legacy_mfa_required_response( 182 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 184 183 &*state.user_repo, 185 184 &*state.session_repo, 186 - &auth.0.did, 185 + &auth_user.did, 187 186 ) 188 - .await; 187 + .await); 189 188 } 190 189 191 190 if !state 192 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 191 + .check_rate_limit(RateLimitKind::TotpVerify, &auth_user.did) 193 192 .await 194 193 { 195 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 196 - return ApiError::RateLimitExceeded(None).into_response(); 194 + warn!(did = %&auth_user.did, "TOTP verification rate limit exceeded"); 195 + return Err(ApiError::RateLimitExceeded(None)); 197 196 } 198 197 199 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 200 - Ok(Some(hash)) => hash, 201 - Ok(None) => return ApiError::AccountNotFound.into_response(), 202 - Err(e) => { 198 + let password_hash = state 199 + .user_repo 200 + .get_password_hash_by_did(&auth_user.did) 201 + .await 202 + .map_err(|e| { 203 203 error!("DB error fetching user: {:?}", e); 204 - return ApiError::InternalError(None).into_response(); 205 - } 206 - }; 204 + ApiError::InternalError(None) 205 + })? 206 + .ok_or(ApiError::AccountNotFound)?; 207 207 208 208 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 209 209 if !password_valid { 210 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 210 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 211 211 } 212 212 213 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 213 + let totp_record = match state.user_repo.get_totp_record(&auth_user.did).await { 214 214 Ok(Some(row)) if row.verified => row, 215 - Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 215 + Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 216 216 Err(e) => { 217 217 error!("DB error fetching TOTP: {:?}", e); 218 - return ApiError::InternalError(None).into_response(); 218 + return Err(ApiError::InternalError(None)); 219 219 } 220 220 }; 221 221 222 222 let code = input.code.trim(); 223 223 let code_valid = if is_backup_code_format(code) { 224 - verify_backup_code_for_user(&state, &auth.0.did, code).await 224 + verify_backup_code_for_user(&state, &auth_user.did, code).await 225 225 } else { 226 - let secret = match decrypt_totp_secret( 226 + let secret = decrypt_totp_secret( 227 227 &totp_record.secret_encrypted, 228 228 totp_record.encryption_version, 229 - ) { 230 - Ok(s) => s, 231 - Err(e) => { 232 - error!("Failed to decrypt TOTP secret: {:?}", e); 233 - return ApiError::InternalError(None).into_response(); 234 - } 235 - }; 229 + ) 230 + .map_err(|e| { 231 + error!("Failed to decrypt TOTP secret: {:?}", e); 232 + ApiError::InternalError(None) 233 + })?; 236 234 verify_totp_code(&secret, code) 237 235 }; 238 236 239 237 if !code_valid { 240 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 238 + return Err(ApiError::InvalidCode(Some( 239 + "Invalid verification code".into(), 240 + ))); 241 241 } 242 242 243 - if let Err(e) = state 243 + state 244 244 .user_repo 245 - .delete_totp_and_backup_codes(&auth.0.did) 245 + .delete_totp_and_backup_codes(&auth_user.did) 246 246 .await 247 - { 248 - error!("Failed to delete TOTP: {:?}", e); 249 - return ApiError::InternalError(None).into_response(); 250 - } 247 + .map_err(|e| { 248 + error!("Failed to delete TOTP: {:?}", e); 249 + ApiError::InternalError(None) 250 + })?; 251 251 252 - info!(did = %&auth.0.did, "TOTP disabled"); 252 + info!(did = %&auth_user.did, "TOTP disabled"); 253 253 254 - EmptyResponse::ok().into_response() 254 + Ok(EmptyResponse::ok().into_response()) 255 255 } 256 256 257 257 #[derive(Serialize)] ··· 262 262 pub backup_codes_remaining: i64, 263 263 } 264 264 265 - pub async fn get_totp_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 266 - let enabled = match state.user_repo.get_totp_record(&auth.0.did).await { 265 + pub async fn get_totp_status( 266 + State(state): State<AppState>, 267 + auth: RequiredAuth, 268 + ) -> Result<Response, ApiError> { 269 + let auth_user = auth.0.require_user()?.require_active()?; 270 + let enabled = match state.user_repo.get_totp_record(&auth_user.did).await { 267 271 Ok(Some(row)) => row.verified, 268 272 Ok(None) => false, 269 273 Err(e) => { 270 274 error!("DB error fetching TOTP status: {:?}", e); 271 - return ApiError::InternalError(None).into_response(); 275 + return Err(ApiError::InternalError(None)); 272 276 } 273 277 }; 274 278 275 - let backup_count = match state.user_repo.count_unused_backup_codes(&auth.0.did).await { 276 - Ok(count) => count, 277 - Err(e) => { 279 + let backup_count = state 280 + .user_repo 281 + .count_unused_backup_codes(&auth_user.did) 282 + .await 283 + .map_err(|e| { 278 284 error!("DB error counting backup codes: {:?}", e); 279 - return ApiError::InternalError(None).into_response(); 280 - } 281 - }; 285 + ApiError::InternalError(None) 286 + })?; 282 287 283 - Json(GetTotpStatusResponse { 288 + Ok(Json(GetTotpStatusResponse { 284 289 enabled, 285 290 has_backup_codes: backup_count > 0, 286 291 backup_codes_remaining: backup_count, 287 292 }) 288 - .into_response() 293 + .into_response()) 289 294 } 290 295 291 296 #[derive(Deserialize)] ··· 302 307 303 308 pub async fn regenerate_backup_codes( 304 309 State(state): State<AppState>, 305 - auth: BearerAuth, 310 + auth: RequiredAuth, 306 311 Json(input): Json<RegenerateBackupCodesInput>, 307 - ) -> Response { 312 + ) -> Result<Response, ApiError> { 313 + let auth_user = auth.0.require_user()?.require_active()?; 308 314 if !state 309 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 315 + .check_rate_limit(RateLimitKind::TotpVerify, &auth_user.did) 310 316 .await 311 317 { 312 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 313 - return ApiError::RateLimitExceeded(None).into_response(); 318 + warn!(did = %&auth_user.did, "TOTP verification rate limit exceeded"); 319 + return Err(ApiError::RateLimitExceeded(None)); 314 320 } 315 321 316 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 317 - Ok(Some(hash)) => hash, 318 - Ok(None) => return ApiError::AccountNotFound.into_response(), 319 - Err(e) => { 322 + let password_hash = state 323 + .user_repo 324 + .get_password_hash_by_did(&auth_user.did) 325 + .await 326 + .map_err(|e| { 320 327 error!("DB error fetching user: {:?}", e); 321 - return ApiError::InternalError(None).into_response(); 322 - } 323 - }; 328 + ApiError::InternalError(None) 329 + })? 330 + .ok_or(ApiError::AccountNotFound)?; 324 331 325 332 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 326 333 if !password_valid { 327 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 334 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 328 335 } 329 336 330 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 337 + let totp_record = match state.user_repo.get_totp_record(&auth_user.did).await { 331 338 Ok(Some(row)) if row.verified => row, 332 - Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 339 + Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 333 340 Err(e) => { 334 341 error!("DB error fetching TOTP: {:?}", e); 335 - return ApiError::InternalError(None).into_response(); 342 + return Err(ApiError::InternalError(None)); 336 343 } 337 344 }; 338 345 339 - let secret = match decrypt_totp_secret( 346 + let secret = decrypt_totp_secret( 340 347 &totp_record.secret_encrypted, 341 348 totp_record.encryption_version, 342 - ) { 343 - Ok(s) => s, 344 - Err(e) => { 345 - error!("Failed to decrypt TOTP secret: {:?}", e); 346 - return ApiError::InternalError(None).into_response(); 347 - } 348 - }; 349 + ) 350 + .map_err(|e| { 351 + error!("Failed to decrypt TOTP secret: {:?}", e); 352 + ApiError::InternalError(None) 353 + })?; 349 354 350 355 let code = input.code.trim(); 351 356 if !verify_totp_code(&secret, code) { 352 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 357 + return Err(ApiError::InvalidCode(Some( 358 + "Invalid verification code".into(), 359 + ))); 353 360 } 354 361 355 362 let backup_codes = generate_backup_codes(); 356 - let backup_hashes: Result<Vec<_>, _> = 357 - backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 358 - let backup_hashes = match backup_hashes { 359 - Ok(hashes) => hashes, 360 - Err(e) => { 363 + let backup_hashes: Vec<_> = backup_codes 364 + .iter() 365 + .map(|c| hash_backup_code(c)) 366 + .collect::<Result<Vec<_>, _>>() 367 + .map_err(|e| { 361 368 error!("Failed to hash backup code: {:?}", e); 362 - return ApiError::InternalError(None).into_response(); 363 - } 364 - }; 369 + ApiError::InternalError(None) 370 + })?; 365 371 366 - if let Err(e) = state 372 + state 367 373 .user_repo 368 - .replace_backup_codes(&auth.0.did, &backup_hashes) 374 + .replace_backup_codes(&auth_user.did, &backup_hashes) 369 375 .await 370 - { 371 - error!("Failed to regenerate backup codes: {:?}", e); 372 - return ApiError::InternalError(None).into_response(); 373 - } 376 + .map_err(|e| { 377 + error!("Failed to regenerate backup codes: {:?}", e); 378 + ApiError::InternalError(None) 379 + })?; 374 380 375 - info!(did = %&auth.0.did, "Backup codes regenerated"); 381 + info!(did = %&auth_user.did, "Backup codes regenerated"); 376 382 377 - Json(RegenerateBackupCodesResponse { backup_codes }).into_response() 383 + Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response()) 378 384 } 379 385 380 386 async fn verify_backup_code_for_user(
+60 -55
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 11 11 use tranquil_db_traits::OAuthRepository; 12 12 use tranquil_types::DeviceId; 13 13 14 - use crate::auth::BearerAuth; 14 + use crate::auth::RequiredAuth; 15 15 use crate::state::AppState; 16 16 17 17 const TRUST_DURATION_DAYS: i64 = 30; ··· 71 71 pub devices: Vec<TrustedDevice>, 72 72 } 73 73 74 - pub async fn list_trusted_devices(State(state): State<AppState>, auth: BearerAuth) -> Response { 75 - match state.oauth_repo.list_trusted_devices(&auth.0.did).await { 76 - Ok(rows) => { 77 - let devices = rows 78 - .into_iter() 79 - .map(|row| { 80 - let trust_state = 81 - DeviceTrustState::from_timestamps(row.trusted_at, row.trusted_until); 82 - TrustedDevice { 83 - id: row.id, 84 - user_agent: row.user_agent, 85 - friendly_name: row.friendly_name, 86 - trusted_at: row.trusted_at, 87 - trusted_until: row.trusted_until, 88 - last_seen_at: row.last_seen_at, 89 - trust_state, 90 - } 91 - }) 92 - .collect(); 93 - Json(ListTrustedDevicesResponse { devices }).into_response() 94 - } 95 - Err(e) => { 74 + pub async fn list_trusted_devices( 75 + State(state): State<AppState>, 76 + auth: RequiredAuth, 77 + ) -> Result<Response, ApiError> { 78 + let auth_user = auth.0.require_user()?.require_active()?; 79 + let rows = state 80 + .oauth_repo 81 + .list_trusted_devices(&auth_user.did) 82 + .await 83 + .map_err(|e| { 96 84 error!("DB error: {:?}", e); 97 - ApiError::InternalError(None).into_response() 98 - } 99 - } 85 + ApiError::InternalError(None) 86 + })?; 87 + 88 + let devices = rows 89 + .into_iter() 90 + .map(|row| { 91 + let trust_state = DeviceTrustState::from_timestamps(row.trusted_at, row.trusted_until); 92 + TrustedDevice { 93 + id: row.id, 94 + user_agent: row.user_agent, 95 + friendly_name: row.friendly_name, 96 + trusted_at: row.trusted_at, 97 + trusted_until: row.trusted_until, 98 + last_seen_at: row.last_seen_at, 99 + trust_state, 100 + } 101 + }) 102 + .collect(); 103 + 104 + Ok(Json(ListTrustedDevicesResponse { devices }).into_response()) 100 105 } 101 106 102 107 #[derive(Deserialize)] ··· 107 112 108 113 pub async fn revoke_trusted_device( 109 114 State(state): State<AppState>, 110 - auth: BearerAuth, 115 + auth: RequiredAuth, 111 116 Json(input): Json<RevokeTrustedDeviceInput>, 112 - ) -> Response { 117 + ) -> Result<Response, ApiError> { 118 + let auth_user = auth.0.require_user()?.require_active()?; 113 119 let device_id = DeviceId::from(input.device_id.clone()); 114 120 match state 115 121 .oauth_repo 116 - .device_belongs_to_user(&device_id, &auth.0.did) 122 + .device_belongs_to_user(&device_id, &auth_user.did) 117 123 .await 118 124 { 119 125 Ok(true) => {} 120 126 Ok(false) => { 121 - return ApiError::DeviceNotFound.into_response(); 127 + return Err(ApiError::DeviceNotFound); 122 128 } 123 129 Err(e) => { 124 130 error!("DB error: {:?}", e); 125 - return ApiError::InternalError(None).into_response(); 131 + return Err(ApiError::InternalError(None)); 126 132 } 127 133 } 128 134 129 - match state.oauth_repo.revoke_device_trust(&device_id).await { 130 - Ok(()) => { 131 - info!(did = %&auth.0.did, device_id = %input.device_id, "Trusted device revoked"); 132 - SuccessResponse::ok().into_response() 133 - } 134 - Err(e) => { 135 + state 136 + .oauth_repo 137 + .revoke_device_trust(&device_id) 138 + .await 139 + .map_err(|e| { 135 140 error!("DB error: {:?}", e); 136 - ApiError::InternalError(None).into_response() 137 - } 138 - } 141 + ApiError::InternalError(None) 142 + })?; 143 + 144 + info!(did = %&auth_user.did, device_id = %input.device_id, "Trusted device revoked"); 145 + Ok(SuccessResponse::ok().into_response()) 139 146 } 140 147 141 148 #[derive(Deserialize)] ··· 147 154 148 155 pub async fn update_trusted_device( 149 156 State(state): State<AppState>, 150 - auth: BearerAuth, 157 + auth: RequiredAuth, 151 158 Json(input): Json<UpdateTrustedDeviceInput>, 152 - ) -> Response { 159 + ) -> Result<Response, ApiError> { 160 + let auth_user = auth.0.require_user()?.require_active()?; 153 161 let device_id = DeviceId::from(input.device_id.clone()); 154 162 match state 155 163 .oauth_repo 156 - .device_belongs_to_user(&device_id, &auth.0.did) 164 + .device_belongs_to_user(&device_id, &auth_user.did) 157 165 .await 158 166 { 159 167 Ok(true) => {} 160 168 Ok(false) => { 161 - return ApiError::DeviceNotFound.into_response(); 169 + return Err(ApiError::DeviceNotFound); 162 170 } 163 171 Err(e) => { 164 172 error!("DB error: {:?}", e); 165 - return ApiError::InternalError(None).into_response(); 173 + return Err(ApiError::InternalError(None)); 166 174 } 167 175 } 168 176 169 - match state 177 + state 170 178 .oauth_repo 171 179 .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 172 180 .await 173 - { 174 - Ok(()) => { 175 - info!(did = %auth.0.did, device_id = %input.device_id, "Trusted device updated"); 176 - SuccessResponse::ok().into_response() 177 - } 178 - Err(e) => { 181 + .map_err(|e| { 179 182 error!("DB error: {:?}", e); 180 - ApiError::InternalError(None).into_response() 181 - } 182 - } 183 + ApiError::InternalError(None) 184 + })?; 185 + 186 + info!(did = %auth_user.did, device_id = %input.device_id, "Trusted device updated"); 187 + Ok(SuccessResponse::ok().into_response()) 183 188 } 184 189 185 190 pub async fn get_device_trust_state(
+11 -27
crates/tranquil-pds/src/api/temp.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::{BearerAuth, extract_auth_token_from_header, validate_token_with_dpop}; 2 + use crate::auth::{OptionalAuth, RequiredAuth}; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, 6 6 extract::State, 7 - http::HeaderMap, 8 7 response::{IntoResponse, Response}, 9 8 }; 10 9 use cid::Cid; ··· 22 21 pub estimated_time_ms: Option<i64>, 23 22 } 24 23 25 - pub async fn check_signup_queue(State(state): State<AppState>, headers: HeaderMap) -> Response { 26 - if let Some(extracted) = 27 - extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 24 + pub async fn check_signup_queue(auth: OptionalAuth) -> Response { 25 + if let Some(entity) = auth.0 26 + && let Some(user) = entity.as_user() 27 + && user.is_oauth 28 28 { 29 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 30 - if let Ok(user) = validate_token_with_dpop( 31 - state.user_repo.as_ref(), 32 - state.oauth_repo.as_ref(), 33 - &extracted.token, 34 - extracted.is_dpop, 35 - dpop_proof, 36 - "GET", 37 - "/", 38 - false, 39 - false, 40 - ) 41 - .await 42 - && user.is_oauth 43 - { 44 - return ApiError::Forbidden.into_response(); 45 - } 29 + return ApiError::Forbidden.into_response(); 46 30 } 47 31 Json(CheckSignupQueueOutput { 48 32 activated: true, ··· 66 50 67 51 pub async fn dereference_scope( 68 52 State(state): State<AppState>, 69 - auth: BearerAuth, 53 + auth: RequiredAuth, 70 54 Json(input): Json<DereferenceScopeInput>, 71 - ) -> Response { 72 - let _ = auth; 55 + ) -> Result<Response, ApiError> { 56 + let _user = auth.0.require_user()?.require_active()?; 73 57 74 58 let scope_parts: Vec<&str> = input.scope.split_whitespace().collect(); 75 59 let mut resolved_scopes: Vec<String> = Vec::new(); ··· 135 119 } 136 120 } 137 121 138 - Json(DereferenceScopeOutput { 122 + Ok(Json(DereferenceScopeOutput { 139 123 scope: resolved_scopes.join(" "), 140 124 }) 141 - .into_response() 125 + .into_response()) 142 126 }
+547
crates/tranquil-pds/src/auth/auth_extractor.rs
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client, create_account_and_login, pds_endpoint}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); 20 + (code_verifier, code_challenge) 21 + } 22 + 23 + async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer { 24 + let mock_server = MockServer::start().await; 25 + let metadata = json!({ 26 + "client_id": mock_server.uri(), 27 + "client_name": "Auth Extractor Test Client", 28 + "redirect_uris": [redirect_uri], 29 + "grant_types": ["authorization_code", "refresh_token"], 30 + "response_types": ["code"], 31 + "token_endpoint_auth_method": "none", 32 + "dpop_bound_access_tokens": dpop_bound 33 + }); 34 + Mock::given(method("GET")) 35 + .and(path("/")) 36 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 37 + .mount(&mock_server) 38 + .await; 39 + mock_server 40 + } 41 + 42 + async fn get_oauth_session( 43 + http_client: &reqwest::Client, 44 + url: &str, 45 + dpop_bound: bool, 46 + ) -> (String, String, String, String) { 47 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 48 + let handle = format!("ae{}", suffix); 49 + let password = "AuthExtract123!"; 50 + let create_res = http_client 51 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 52 + .json(&json!({ 53 + "handle": handle, 54 + "email": format!("{}@example.com", handle), 55 + "password": password 56 + })) 57 + .send() 58 + .await 59 + .unwrap(); 60 + assert_eq!(create_res.status(), StatusCode::OK); 61 + let account: Value = create_res.json().await.unwrap(); 62 + let did = account["did"].as_str().unwrap().to_string(); 63 + verify_new_account(http_client, &did).await; 64 + 65 + let redirect_uri = "https://example.com/auth-callback"; 66 + let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await; 67 + let client_id = mock_client.uri(); 68 + let (code_verifier, code_challenge) = generate_pkce(); 69 + 70 + let par_body: Value = http_client 71 + .post(format!("{}/oauth/par", url)) 72 + .form(&[ 73 + ("response_type", "code"), 74 + ("client_id", &client_id), 75 + ("redirect_uri", redirect_uri), 76 + ("code_challenge", &code_challenge), 77 + ("code_challenge_method", "S256"), 78 + ]) 79 + .send() 80 + .await 81 + .unwrap() 82 + .json() 83 + .await 84 + .unwrap(); 85 + let request_uri = par_body["request_uri"].as_str().unwrap(); 86 + 87 + let auth_res = http_client 88 + .post(format!("{}/oauth/authorize", url)) 89 + .header("Content-Type", "application/json") 90 + .header("Accept", "application/json") 91 + .json(&json!({ 92 + "request_uri": request_uri, 93 + "username": &handle, 94 + "password": password, 95 + "remember_device": false 96 + })) 97 + .send() 98 + .await 99 + .unwrap(); 100 + let auth_body: Value = auth_res.json().await.unwrap(); 101 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 102 + 103 + if location.contains("/oauth/consent") { 104 + let consent_res = http_client 105 + .post(format!("{}/oauth/authorize/consent", url)) 106 + .header("Content-Type", "application/json") 107 + .json(&json!({ 108 + "request_uri": request_uri, 109 + "approved_scopes": ["atproto"], 110 + "remember": false 111 + })) 112 + .send() 113 + .await 114 + .unwrap(); 115 + let consent_body: Value = consent_res.json().await.unwrap(); 116 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 117 + } 118 + 119 + let code = location 120 + .split("code=") 121 + .nth(1) 122 + .unwrap() 123 + .split('&') 124 + .next() 125 + .unwrap(); 126 + 127 + let token_body: Value = http_client 128 + .post(format!("{}/oauth/token", url)) 129 + .form(&[ 130 + ("grant_type", "authorization_code"), 131 + ("code", code), 132 + ("redirect_uri", redirect_uri), 133 + ("code_verifier", &code_verifier), 134 + ("client_id", &client_id), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap() 139 + .json() 140 + .await 141 + .unwrap(); 142 + 143 + ( 144 + token_body["access_token"].as_str().unwrap().to_string(), 145 + token_body["refresh_token"].as_str().unwrap().to_string(), 146 + client_id, 147 + did, 148 + ) 149 + } 150 + 151 + #[tokio::test] 152 + async fn test_oauth_token_works_with_bearer_auth() { 153 + let url = base_url().await; 154 + let http_client = client(); 155 + let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await; 156 + 157 + let res = http_client 158 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 159 + .bearer_auth(&access_token) 160 + .send() 161 + .await 162 + .unwrap(); 163 + 164 + assert_eq!(res.status(), StatusCode::OK, "OAuth token should work with RequiredAuth extractor"); 165 + let body: Value = res.json().await.unwrap(); 166 + assert_eq!(body["did"].as_str().unwrap(), did); 167 + } 168 + 169 + #[tokio::test] 170 + async fn test_session_token_still_works() { 171 + let url = base_url().await; 172 + let http_client = client(); 173 + let (jwt, did) = create_account_and_login(&http_client).await; 174 + 175 + let res = http_client 176 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 177 + .bearer_auth(&jwt) 178 + .send() 179 + .await 180 + .unwrap(); 181 + 182 + assert_eq!(res.status(), StatusCode::OK, "Session token should still work"); 183 + let body: Value = res.json().await.unwrap(); 184 + assert_eq!(body["did"].as_str().unwrap(), did); 185 + } 186 + 187 + 188 + #[tokio::test] 189 + async fn test_oauth_admin_extractor_allows_oauth_tokens() { 190 + let url = base_url().await; 191 + let http_client = client(); 192 + 193 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 194 + let handle = format!("adm{}", suffix); 195 + let password = "AdminOAuth123!"; 196 + let create_res = http_client 197 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 198 + .json(&json!({ 199 + "handle": handle, 200 + "email": format!("{}@example.com", handle), 201 + "password": password 202 + })) 203 + .send() 204 + .await 205 + .unwrap(); 206 + assert_eq!(create_res.status(), StatusCode::OK); 207 + let account: Value = create_res.json().await.unwrap(); 208 + let did = account["did"].as_str().unwrap().to_string(); 209 + verify_new_account(&http_client, &did).await; 210 + 211 + let pool = common::get_test_db_pool().await; 212 + sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 213 + .execute(pool) 214 + .await 215 + .expect("Failed to mark user as admin"); 216 + 217 + let redirect_uri = "https://example.com/admin-callback"; 218 + let mock_client = setup_mock_client_metadata(redirect_uri, false).await; 219 + let client_id = mock_client.uri(); 220 + let (code_verifier, code_challenge) = generate_pkce(); 221 + 222 + let par_body: Value = http_client 223 + .post(format!("{}/oauth/par", url)) 224 + .form(&[ 225 + ("response_type", "code"), 226 + ("client_id", &client_id), 227 + ("redirect_uri", redirect_uri), 228 + ("code_challenge", &code_challenge), 229 + ("code_challenge_method", "S256"), 230 + ]) 231 + .send() 232 + .await 233 + .unwrap() 234 + .json() 235 + .await 236 + .unwrap(); 237 + let request_uri = par_body["request_uri"].as_str().unwrap(); 238 + 239 + let auth_res = http_client 240 + .post(format!("{}/oauth/authorize", url)) 241 + .header("Content-Type", "application/json") 242 + .header("Accept", "application/json") 243 + .json(&json!({ 244 + "request_uri": request_uri, 245 + "username": &handle, 246 + "password": password, 247 + "remember_device": false 248 + })) 249 + .send() 250 + .await 251 + .unwrap(); 252 + let auth_body: Value = auth_res.json().await.unwrap(); 253 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 254 + if location.contains("/oauth/consent") { 255 + let consent_res = http_client 256 + .post(format!("{}/oauth/authorize/consent", url)) 257 + .header("Content-Type", "application/json") 258 + .json(&json!({ 259 + "request_uri": request_uri, 260 + "approved_scopes": ["atproto"], 261 + "remember": false 262 + })) 263 + .send() 264 + .await 265 + .unwrap(); 266 + let consent_body: Value = consent_res.json().await.unwrap(); 267 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 268 + } 269 + 270 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 271 + let token_body: Value = http_client 272 + .post(format!("{}/oauth/token", url)) 273 + .form(&[ 274 + ("grant_type", "authorization_code"), 275 + ("code", code), 276 + ("redirect_uri", redirect_uri), 277 + ("code_verifier", &code_verifier), 278 + ("client_id", &client_id), 279 + ]) 280 + .send() 281 + .await 282 + .unwrap() 283 + .json() 284 + .await 285 + .unwrap(); 286 + let access_token = token_body["access_token"].as_str().unwrap(); 287 + 288 + let res = http_client 289 + .get(format!("{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", url, did)) 290 + .bearer_auth(access_token) 291 + .send() 292 + .await 293 + .unwrap(); 294 + 295 + assert_eq!( 296 + res.status(), 297 + StatusCode::OK, 298 + "OAuth token for admin user should work with admin endpoint" 299 + ); 300 + } 301 + 302 + #[tokio::test] 303 + async fn test_expired_oauth_token_returns_proper_error() { 304 + let url = base_url().await; 305 + let http_client = client(); 306 + 307 + let now = Utc::now().timestamp(); 308 + let header = json!({"alg": "HS256", "typ": "at+jwt"}); 309 + let payload = json!({ 310 + "iss": url, 311 + "sub": "did:plc:test123", 312 + "aud": url, 313 + "iat": now - 7200, 314 + "exp": now - 3600, 315 + "jti": "expired-token", 316 + "sid": "expired-session", 317 + "scope": "atproto", 318 + "client_id": "https://example.com" 319 + }); 320 + let fake_token = format!( 321 + "{}.{}.{}", 322 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 323 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 324 + URL_SAFE_NO_PAD.encode([1u8; 32]) 325 + ); 326 + 327 + let res = http_client 328 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 329 + .bearer_auth(&fake_token) 330 + .send() 331 + .await 332 + .unwrap(); 333 + 334 + assert_eq!( 335 + res.status(), 336 + StatusCode::UNAUTHORIZED, 337 + "Expired token should be rejected" 338 + ); 339 + } 340 + 341 + #[tokio::test] 342 + async fn test_dpop_nonce_error_has_proper_headers() { 343 + let url = base_url().await; 344 + let pds_url = pds_endpoint(); 345 + let http_client = client(); 346 + 347 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 348 + let handle = format!("dpop{}", suffix); 349 + let create_res = http_client 350 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 351 + .json(&json!({ 352 + "handle": handle, 353 + "email": format!("{}@test.com", handle), 354 + "password": "DpopTest123!" 355 + })) 356 + .send() 357 + .await 358 + .unwrap(); 359 + assert_eq!(create_res.status(), StatusCode::OK); 360 + let account: Value = create_res.json().await.unwrap(); 361 + let did = account["did"].as_str().unwrap(); 362 + verify_new_account(&http_client, did).await; 363 + 364 + let redirect_uri = "https://example.com/dpop-callback"; 365 + let mock_server = MockServer::start().await; 366 + let client_id = mock_server.uri(); 367 + let metadata = json!({ 368 + "client_id": &client_id, 369 + "client_name": "DPoP Test Client", 370 + "redirect_uris": [redirect_uri], 371 + "grant_types": ["authorization_code", "refresh_token"], 372 + "response_types": ["code"], 373 + "token_endpoint_auth_method": "none", 374 + "dpop_bound_access_tokens": true 375 + }); 376 + Mock::given(method("GET")) 377 + .and(path("/")) 378 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 379 + .mount(&mock_server) 380 + .await; 381 + 382 + let (code_verifier, code_challenge) = generate_pkce(); 383 + let par_body: Value = http_client 384 + .post(format!("{}/oauth/par", url)) 385 + .form(&[ 386 + ("response_type", "code"), 387 + ("client_id", &client_id), 388 + ("redirect_uri", redirect_uri), 389 + ("code_challenge", &code_challenge), 390 + ("code_challenge_method", "S256"), 391 + ]) 392 + .send() 393 + .await 394 + .unwrap() 395 + .json() 396 + .await 397 + .unwrap(); 398 + 399 + let request_uri = par_body["request_uri"].as_str().unwrap(); 400 + let auth_res = http_client 401 + .post(format!("{}/oauth/authorize", url)) 402 + .header("Content-Type", "application/json") 403 + .header("Accept", "application/json") 404 + .json(&json!({ 405 + "request_uri": request_uri, 406 + "username": &handle, 407 + "password": "DpopTest123!", 408 + "remember_device": false 409 + })) 410 + .send() 411 + .await 412 + .unwrap(); 413 + let auth_body: Value = auth_res.json().await.unwrap(); 414 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 415 + if location.contains("/oauth/consent") { 416 + let consent_res = http_client 417 + .post(format!("{}/oauth/authorize/consent", url)) 418 + .header("Content-Type", "application/json") 419 + .json(&json!({ 420 + "request_uri": request_uri, 421 + "approved_scopes": ["atproto"], 422 + "remember": false 423 + })) 424 + .send() 425 + .await 426 + .unwrap(); 427 + let consent_body: Value = consent_res.json().await.unwrap(); 428 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 429 + } 430 + 431 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 432 + 433 + let token_endpoint = format!("{}/oauth/token", pds_url); 434 + let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None); 435 + 436 + let token_res = http_client 437 + .post(format!("{}/oauth/token", url)) 438 + .header("DPoP", &dpop_proof) 439 + .form(&[ 440 + ("grant_type", "authorization_code"), 441 + ("code", code), 442 + ("redirect_uri", redirect_uri), 443 + ("code_verifier", &code_verifier), 444 + ("client_id", &client_id), 445 + ]) 446 + .send() 447 + .await 448 + .unwrap(); 449 + 450 + let token_status = token_res.status(); 451 + let token_nonce = token_res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap().to_string()); 452 + let token_body: Value = token_res.json().await.unwrap(); 453 + 454 + let access_token = if token_status == StatusCode::OK { 455 + token_body["access_token"].as_str().unwrap().to_string() 456 + } else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") { 457 + let nonce = token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error"); 458 + let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce)); 459 + 460 + let retry_res = http_client 461 + .post(format!("{}/oauth/token", url)) 462 + .header("DPoP", &dpop_proof_with_nonce) 463 + .form(&[ 464 + ("grant_type", "authorization_code"), 465 + ("code", code), 466 + ("redirect_uri", redirect_uri), 467 + ("code_verifier", &code_verifier), 468 + ("client_id", &client_id), 469 + ]) 470 + .send() 471 + .await 472 + .unwrap(); 473 + let retry_body: Value = retry_res.json().await.unwrap(); 474 + retry_body["access_token"].as_str().expect("Should get access_token after nonce retry").to_string() 475 + } else { 476 + panic!("Token exchange failed unexpectedly: {:?}", token_body); 477 + }; 478 + 479 + let res = http_client 480 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 481 + .header("Authorization", format!("DPoP {}", access_token)) 482 + .send() 483 + .await 484 + .unwrap(); 485 + 486 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DPoP token without proof should fail"); 487 + 488 + let www_auth = res.headers().get("www-authenticate").map(|h| h.to_str().unwrap()); 489 + assert!(www_auth.is_some(), "Should have WWW-Authenticate header"); 490 + assert!( 491 + www_auth.unwrap().contains("use_dpop_nonce"), 492 + "WWW-Authenticate should indicate dpop nonce required" 493 + ); 494 + 495 + let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap()); 496 + assert!(nonce.is_some(), "Should return DPoP-Nonce header"); 497 + 498 + let body: Value = res.json().await.unwrap(); 499 + assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce"); 500 + } 501 + 502 + fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) { 503 + use p256::ecdsa::{SigningKey, signature::Signer}; 504 + use p256::elliptic_curve::rand_core::OsRng; 505 + 506 + let signing_key = SigningKey::random(&mut OsRng); 507 + let verifying_key = signing_key.verifying_key(); 508 + let point = verifying_key.to_encoded_point(false); 509 + let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 510 + let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 511 + 512 + let jwk = json!({ 513 + "kty": "EC", 514 + "crv": "P-256", 515 + "x": x, 516 + "y": y 517 + }); 518 + 519 + let header = { 520 + let h = json!({ 521 + "typ": "dpop+jwt", 522 + "alg": "ES256", 523 + "jwk": jwk.clone() 524 + }); 525 + h 526 + }; 527 + 528 + let mut payload = json!({ 529 + "jti": uuid::Uuid::new_v4().to_string(), 530 + "htm": method, 531 + "htu": uri, 532 + "iat": Utc::now().timestamp() 533 + }); 534 + if let Some(n) = nonce { 535 + payload["nonce"] = json!(n); 536 + } 537 + 538 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 539 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 540 + let signing_input = format!("{}.{}", header_b64, payload_b64); 541 + 542 + let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 543 + let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 544 + 545 + let proof = format!("{}.{}", signing_input, sig_b64); 546 + (jwk, proof) 547 + }
+254 -242
crates/tranquil-pds/src/auth/extractor.rs
··· 1 1 use axum::{ 2 2 extract::FromRequestParts, 3 - http::{header::AUTHORIZATION, request::Parts}, 3 + http::{StatusCode, header::AUTHORIZATION, request::Parts}, 4 4 response::{IntoResponse, Response}, 5 5 }; 6 + use tracing::{debug, error, info}; 6 7 7 8 use super::{ 8 - AuthenticatedUser, TokenValidationError, validate_bearer_token_allow_takendown, 9 - validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated, 10 - validate_token_with_dpop, 9 + AccountStatus, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier, is_service_token, 10 + validate_bearer_token_for_service_auth, 11 11 }; 12 12 use crate::api::error::ApiError; 13 13 use crate::state::AppState; 14 + use crate::types::Did; 14 15 use crate::util::build_full_url; 15 - 16 - pub struct BearerAuth(pub AuthenticatedUser); 17 16 18 17 #[derive(Debug)] 19 18 pub enum AuthError { ··· 24 23 AccountDeactivated, 25 24 AccountTakedown, 26 25 AdminRequired, 26 + OAuthExpiredToken(String), 27 + UseDpopNonce(String), 28 + InvalidDpopProof(String), 27 29 } 28 30 29 31 impl IntoResponse for AuthError { 30 32 fn into_response(self) -> Response { 31 - ApiError::from(self).into_response() 33 + match self { 34 + Self::UseDpopNonce(nonce) => ( 35 + StatusCode::UNAUTHORIZED, 36 + [ 37 + ("DPoP-Nonce", nonce.as_str()), 38 + ("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\""), 39 + ], 40 + axum::Json(serde_json::json!({ 41 + "error": "use_dpop_nonce", 42 + "message": "DPoP nonce required" 43 + })), 44 + ) 45 + .into_response(), 46 + Self::OAuthExpiredToken(msg) => ApiError::OAuthExpiredToken(Some(msg)).into_response(), 47 + Self::InvalidDpopProof(msg) => ( 48 + StatusCode::UNAUTHORIZED, 49 + [("WWW-Authenticate", "DPoP error=\"invalid_dpop_proof\"")], 50 + axum::Json(serde_json::json!({ 51 + "error": "invalid_dpop_proof", 52 + "message": msg 53 + })), 54 + ) 55 + .into_response(), 56 + other => ApiError::from(other).into_response(), 57 + } 32 58 } 33 59 } 34 60 35 - #[cfg(test)] 36 - fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 37 - let auth_header = auth_header.trim(); 38 - 39 - if auth_header.len() < 8 { 40 - return Err(AuthError::InvalidFormat); 41 - } 42 - 43 - let prefix = &auth_header[..7]; 44 - if !prefix.eq_ignore_ascii_case("bearer ") { 45 - return Err(AuthError::InvalidFormat); 46 - } 47 - 48 - let token = auth_header[7..].trim(); 49 - if token.is_empty() { 50 - return Err(AuthError::InvalidFormat); 51 - } 52 - 53 - Ok(token) 61 + pub struct ExtractedToken { 62 + pub token: String, 63 + pub is_dpop: bool, 54 64 } 55 65 56 66 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { ··· 73 83 Some(token.to_string()) 74 84 } 75 85 76 - pub struct ExtractedToken { 77 - pub token: String, 78 - pub is_dpop: bool, 79 - } 80 - 81 86 pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 82 87 let header = auth_header?; 83 88 let header = header.trim(); ··· 107 112 None 108 113 } 109 114 110 - impl FromRequestParts<AppState> for BearerAuth { 111 - type Rejection = AuthError; 115 + pub enum AuthenticatedEntity { 116 + User(AuthenticatedUser), 117 + Service { 118 + did: Did, 119 + claims: ServiceTokenClaims, 120 + }, 121 + } 112 122 113 - async fn from_request_parts( 114 - parts: &mut Parts, 115 - state: &AppState, 116 - ) -> Result<Self, Self::Rejection> { 117 - let auth_header = parts 118 - .headers 119 - .get(AUTHORIZATION) 120 - .ok_or(AuthError::MissingToken)? 121 - .to_str() 122 - .map_err(|_| AuthError::InvalidFormat)?; 123 + impl AuthenticatedEntity { 124 + pub fn did(&self) -> &Did { 125 + match self { 126 + Self::User(user) => &user.did, 127 + Self::Service { did, .. } => did, 128 + } 129 + } 123 130 124 - let extracted = 125 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 131 + pub fn as_user(&self) -> Option<&AuthenticatedUser> { 132 + match self { 133 + Self::User(user) => Some(user), 134 + Self::Service { .. } => None, 135 + } 136 + } 126 137 127 - if extracted.is_dpop { 128 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 129 - let method = parts.method.as_str(); 130 - let uri = build_full_url(&parts.uri.to_string()); 138 + pub fn as_service(&self) -> Option<(&Did, &ServiceTokenClaims)> { 139 + match self { 140 + Self::User(_) => None, 141 + Self::Service { did, claims } => Some((did, claims)), 142 + } 143 + } 131 144 132 - match validate_token_with_dpop( 133 - state.user_repo.as_ref(), 134 - state.oauth_repo.as_ref(), 135 - &extracted.token, 136 - true, 137 - dpop_proof, 138 - method, 139 - &uri, 140 - false, 141 - false, 142 - ) 143 - .await 144 - { 145 - Ok(user) => Ok(BearerAuth(user)), 146 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 147 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 148 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 149 - Err(_) => Err(AuthError::AuthenticationFailed), 150 - } 151 - } else { 152 - match validate_bearer_token_cached( 153 - state.user_repo.as_ref(), 154 - state.cache.as_ref(), 155 - &extracted.token, 156 - ) 157 - .await 158 - { 159 - Ok(user) => Ok(BearerAuth(user)), 160 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 161 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 162 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 163 - Err(_) => Err(AuthError::AuthenticationFailed), 164 - } 145 + pub fn require_user(&self) -> Result<&AuthenticatedUser, ApiError> { 146 + match self { 147 + Self::User(user) => Ok(user), 148 + Self::Service { .. } => Err(ApiError::AuthenticationFailed(Some( 149 + "User authentication required".to_string(), 150 + ))), 151 + } 152 + } 153 + 154 + pub fn require_service(&self) -> Result<(&Did, &ServiceTokenClaims), ApiError> { 155 + match self { 156 + Self::User(_) => Err(ApiError::AuthenticationFailed(Some( 157 + "Service authentication required".to_string(), 158 + ))), 159 + Self::Service { did, claims } => Ok((did, claims)), 160 + } 161 + } 162 + 163 + pub fn require_service_lxm( 164 + &self, 165 + expected_lxm: &str, 166 + ) -> Result<(&Did, &ServiceTokenClaims), ApiError> { 167 + let (did, claims) = self.require_service()?; 168 + match &claims.lxm { 169 + Some(lxm) if lxm == "*" || lxm == expected_lxm => Ok((did, claims)), 170 + Some(lxm) => Err(ApiError::AuthorizationError(format!( 171 + "Token lxm '{}' does not permit '{}'", 172 + lxm, expected_lxm 173 + ))), 174 + None => Err(ApiError::AuthorizationError( 175 + "Token missing lxm claim".to_string(), 176 + )), 177 + } 178 + } 179 + 180 + pub fn into_user(self) -> Result<AuthenticatedUser, ApiError> { 181 + match self { 182 + Self::User(user) => Ok(user), 183 + Self::Service { .. } => Err(ApiError::AuthenticationFailed(Some( 184 + "User authentication required".to_string(), 185 + ))), 165 186 } 166 187 } 167 188 } 168 189 169 - pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 190 + impl AuthenticatedUser { 191 + pub fn require_active(&self) -> Result<&Self, ApiError> { 192 + if self.status.is_deactivated() { 193 + return Err(ApiError::AccountDeactivated); 194 + } 195 + if self.status.is_takendown() { 196 + return Err(ApiError::AccountTakedown); 197 + } 198 + Ok(self) 199 + } 170 200 171 - impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 172 - type Rejection = AuthError; 201 + pub fn require_not_takendown(&self) -> Result<&Self, ApiError> { 202 + if self.status.is_takendown() { 203 + return Err(ApiError::AccountTakedown); 204 + } 205 + Ok(self) 206 + } 173 207 174 - async fn from_request_parts( 175 - parts: &mut Parts, 176 - state: &AppState, 177 - ) -> Result<Self, Self::Rejection> { 178 - let auth_header = parts 179 - .headers 180 - .get(AUTHORIZATION) 181 - .ok_or(AuthError::MissingToken)? 182 - .to_str() 183 - .map_err(|_| AuthError::InvalidFormat)?; 208 + pub fn require_admin(&self) -> Result<&Self, ApiError> { 209 + if !self.is_admin { 210 + return Err(ApiError::AdminRequired); 211 + } 212 + Ok(self) 213 + } 214 + } 184 215 185 - let extracted = 186 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 216 + async fn verify_oauth_token_and_build_user( 217 + state: &AppState, 218 + token: &str, 219 + dpop_proof: Option<&str>, 220 + method: &str, 221 + uri: &str, 222 + ) -> Result<AuthenticatedUser, AuthError> { 223 + match crate::oauth::verify::verify_oauth_access_token( 224 + state.oauth_repo.as_ref(), 225 + token, 226 + dpop_proof, 227 + method, 228 + uri, 229 + ) 230 + .await 231 + { 232 + Ok(result) => { 233 + let user_info = state 234 + .user_repo 235 + .get_user_info_by_did(&result.did) 236 + .await 237 + .ok() 238 + .flatten() 239 + .ok_or(AuthError::AuthenticationFailed)?; 240 + let status = AccountStatus::from_db_fields( 241 + user_info.takedown_ref.as_deref(), 242 + user_info.deactivated_at, 243 + ); 244 + Ok(AuthenticatedUser { 245 + did: result.did, 246 + key_bytes: user_info.key_bytes.and_then(|kb| { 247 + crate::config::decrypt_key(&kb, user_info.encryption_version).ok() 248 + }), 249 + is_oauth: true, 250 + is_admin: user_info.is_admin, 251 + status, 252 + scope: result.scope, 253 + controller_did: None, 254 + }) 255 + } 256 + Err(crate::oauth::OAuthError::ExpiredToken(msg)) => Err(AuthError::OAuthExpiredToken(msg)), 257 + Err(crate::oauth::OAuthError::UseDpopNonce(nonce)) => Err(AuthError::UseDpopNonce(nonce)), 258 + Err(crate::oauth::OAuthError::InvalidDpopProof(msg)) => { 259 + Err(AuthError::InvalidDpopProof(msg)) 260 + } 261 + Err(_) => Err(AuthError::AuthenticationFailed), 262 + } 263 + } 187 264 188 - if extracted.is_dpop { 189 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 190 - let method = parts.method.as_str(); 191 - let uri = build_full_url(&parts.uri.to_string()); 265 + async fn verify_service_token(token: &str) -> Result<(Did, ServiceTokenClaims), AuthError> { 266 + let verifier = ServiceTokenVerifier::new(); 267 + let claims = verifier 268 + .verify_service_token(token, None) 269 + .await 270 + .map_err(|e| { 271 + error!("Service token verification failed: {:?}", e); 272 + AuthError::AuthenticationFailed 273 + })?; 192 274 193 - match validate_token_with_dpop( 194 - state.user_repo.as_ref(), 195 - state.oauth_repo.as_ref(), 196 - &extracted.token, 197 - true, 198 - dpop_proof, 199 - method, 200 - &uri, 201 - true, 202 - false, 203 - ) 204 - .await 205 - { 206 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 207 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 208 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 209 - Err(_) => Err(AuthError::AuthenticationFailed), 210 - } 211 - } else { 212 - match validate_bearer_token_cached_allow_deactivated( 213 - state.user_repo.as_ref(), 214 - state.cache.as_ref(), 215 - &extracted.token, 216 - ) 217 - .await 218 - { 219 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 220 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 221 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 222 - Err(_) => Err(AuthError::AuthenticationFailed), 223 - } 275 + let did: Did = claims 276 + .iss 277 + .parse() 278 + .map_err(|_| AuthError::AuthenticationFailed)?; 279 + 280 + debug!("Service token verified for DID: {}", did); 281 + 282 + Ok((did, claims)) 283 + } 284 + 285 + async fn extract_auth_internal( 286 + parts: &mut Parts, 287 + state: &AppState, 288 + ) -> Result<AuthenticatedEntity, AuthError> { 289 + let auth_header = parts 290 + .headers 291 + .get(AUTHORIZATION) 292 + .ok_or(AuthError::MissingToken)? 293 + .to_str() 294 + .map_err(|_| AuthError::InvalidFormat)?; 295 + 296 + let extracted = 297 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 298 + 299 + if is_service_token(&extracted.token) { 300 + let (did, claims) = verify_service_token(&extracted.token).await?; 301 + return Ok(AuthenticatedEntity::Service { did, claims }); 302 + } 303 + 304 + let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok()); 305 + let method = parts.method.as_str(); 306 + let uri = build_full_url(&parts.uri.to_string()); 307 + 308 + match validate_bearer_token_for_service_auth(state.user_repo.as_ref(), &extracted.token).await { 309 + Ok(user) if !user.is_oauth => { 310 + return Ok(AuthenticatedEntity::User(user)); 224 311 } 312 + Ok(_) => {} 313 + Err(super::TokenValidationError::TokenExpired) => { 314 + info!("JWT access token expired, returning ExpiredToken"); 315 + return Err(AuthError::TokenExpired); 316 + } 317 + Err(_) => {} 225 318 } 319 + 320 + let user = verify_oauth_token_and_build_user(state, &extracted.token, dpop_proof, method, &uri) 321 + .await?; 322 + 323 + Ok(AuthenticatedEntity::User(user)) 226 324 } 227 325 228 - pub struct BearerAuthAllowTakendown(pub AuthenticatedUser); 326 + pub struct RequiredAuth(pub AuthenticatedEntity); 229 327 230 - impl FromRequestParts<AppState> for BearerAuthAllowTakendown { 328 + impl FromRequestParts<AppState> for RequiredAuth { 231 329 type Rejection = AuthError; 232 330 233 331 async fn from_request_parts( 234 332 parts: &mut Parts, 235 333 state: &AppState, 236 334 ) -> Result<Self, Self::Rejection> { 237 - let auth_header = parts 238 - .headers 239 - .get(AUTHORIZATION) 240 - .ok_or(AuthError::MissingToken)? 241 - .to_str() 242 - .map_err(|_| AuthError::InvalidFormat)?; 243 - 244 - let extracted = 245 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 246 - 247 - if extracted.is_dpop { 248 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 249 - let method = parts.method.as_str(); 250 - let uri = build_full_url(&parts.uri.to_string()); 251 - 252 - match validate_token_with_dpop( 253 - state.user_repo.as_ref(), 254 - state.oauth_repo.as_ref(), 255 - &extracted.token, 256 - true, 257 - dpop_proof, 258 - method, 259 - &uri, 260 - false, 261 - true, 262 - ) 263 - .await 264 - { 265 - Ok(user) => Ok(BearerAuthAllowTakendown(user)), 266 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 267 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 268 - Err(_) => Err(AuthError::AuthenticationFailed), 269 - } 270 - } else { 271 - match validate_bearer_token_allow_takendown(state.user_repo.as_ref(), &extracted.token) 272 - .await 273 - { 274 - Ok(user) => Ok(BearerAuthAllowTakendown(user)), 275 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 276 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 277 - Err(_) => Err(AuthError::AuthenticationFailed), 278 - } 279 - } 335 + extract_auth_internal(parts, state).await.map(RequiredAuth) 280 336 } 281 337 } 282 338 283 - pub struct BearerAuthAdmin(pub AuthenticatedUser); 339 + pub struct OptionalAuth(pub Option<AuthenticatedEntity>); 284 340 285 - impl FromRequestParts<AppState> for BearerAuthAdmin { 286 - type Rejection = AuthError; 341 + impl FromRequestParts<AppState> for OptionalAuth { 342 + type Rejection = std::convert::Infallible; 287 343 288 344 async fn from_request_parts( 289 345 parts: &mut Parts, 290 346 state: &AppState, 291 347 ) -> Result<Self, Self::Rejection> { 292 - let auth_header = parts 293 - .headers 294 - .get(AUTHORIZATION) 295 - .ok_or(AuthError::MissingToken)? 296 - .to_str() 297 - .map_err(|_| AuthError::InvalidFormat)?; 348 + Ok(OptionalAuth(extract_auth_internal(parts, state).await.ok())) 349 + } 350 + } 298 351 299 - let extracted = 300 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 352 + #[cfg(test)] 353 + fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 354 + let auth_header = auth_header.trim(); 301 355 302 - let user = if extracted.is_dpop { 303 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 304 - let method = parts.method.as_str(); 305 - let uri = build_full_url(&parts.uri.to_string()); 356 + if auth_header.len() < 8 { 357 + return Err(AuthError::InvalidFormat); 358 + } 306 359 307 - match validate_token_with_dpop( 308 - state.user_repo.as_ref(), 309 - state.oauth_repo.as_ref(), 310 - &extracted.token, 311 - true, 312 - dpop_proof, 313 - method, 314 - &uri, 315 - false, 316 - false, 317 - ) 318 - .await 319 - { 320 - Ok(user) => user, 321 - Err(TokenValidationError::AccountDeactivated) => { 322 - return Err(AuthError::AccountDeactivated); 323 - } 324 - Err(TokenValidationError::AccountTakedown) => { 325 - return Err(AuthError::AccountTakedown); 326 - } 327 - Err(TokenValidationError::TokenExpired) => { 328 - return Err(AuthError::TokenExpired); 329 - } 330 - Err(_) => return Err(AuthError::AuthenticationFailed), 331 - } 332 - } else { 333 - match validate_bearer_token_cached( 334 - state.user_repo.as_ref(), 335 - state.cache.as_ref(), 336 - &extracted.token, 337 - ) 338 - .await 339 - { 340 - Ok(user) => user, 341 - Err(TokenValidationError::AccountDeactivated) => { 342 - return Err(AuthError::AccountDeactivated); 343 - } 344 - Err(TokenValidationError::AccountTakedown) => { 345 - return Err(AuthError::AccountTakedown); 346 - } 347 - Err(TokenValidationError::TokenExpired) => { 348 - return Err(AuthError::TokenExpired); 349 - } 350 - Err(_) => return Err(AuthError::AuthenticationFailed), 351 - } 352 - }; 360 + let prefix = &auth_header[..7]; 361 + if !prefix.eq_ignore_ascii_case("bearer ") { 362 + return Err(AuthError::InvalidFormat); 363 + } 353 364 354 - if !user.is_admin { 355 - return Err(AuthError::AdminRequired); 356 - } 357 - Ok(BearerAuthAdmin(user)) 365 + let token = auth_header[7..].trim(); 366 + if token.is_empty() { 367 + return Err(AuthError::InvalidFormat); 358 368 } 369 + 370 + Ok(token) 359 371 } 360 372 361 373 #[cfg(test)]
+1 -1
crates/tranquil-pds/src/auth/mod.rs
··· 16 16 pub mod webauthn; 17 17 18 18 pub use extractor::{ 19 - AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 19 + AuthError, AuthenticatedEntity, ExtractedToken, OptionalAuth, RequiredAuth, 20 20 extract_auth_token_from_header, extract_bearer_token_from_header, 21 21 }; 22 22 pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
+11 -2
crates/tranquil-pds/src/lib.rs
··· 528 528 )); 529 529 let xrpc_service = ServiceBuilder::new() 530 530 .layer(XrpcProxyLayer::new(state.clone())) 531 - .service(xrpc_router.with_state(state.clone())); 531 + .service( 532 + xrpc_router 533 + .layer(middleware::from_fn(oauth::verify::dpop_nonce_middleware)) 534 + .with_state(state.clone()), 535 + ); 532 536 533 537 let oauth_router = Router::new() 534 538 .route("/jwks", get(oauth::endpoints::oauth_jwks)) ··· 568 572 "/register/complete", 569 573 post(oauth::endpoints::register_complete), 570 574 ) 575 + .route( 576 + "/establish-session", 577 + post(oauth::endpoints::establish_session), 578 + ) 571 579 .route("/authorize/consent", get(oauth::endpoints::consent_get)) 572 580 .route("/authorize/consent", post(oauth::endpoints::consent_post)) 573 581 .route( ··· 605 613 .route( 606 614 "/sso/check-handle-available", 607 615 get(sso::endpoints::check_handle_available), 608 - ); 616 + ) 617 + .layer(middleware::from_fn(oauth::verify::dpop_nonce_middleware)); 609 618 610 619 let well_known_router = Router::new() 611 620 .route("/did.json", get(api::identity::well_known_did))
+173 -25
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 1 use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 2 use crate::oauth::{ 3 3 AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, 4 - db::should_show_consent, 4 + db::should_show_consent, scopes::expand_include_scopes, 5 5 }; 6 6 use crate::state::{AppState, RateLimitKind}; 7 7 use crate::types::{Did, Handle, PlainPassword}; ··· 1106 1106 .oauth_repo 1107 1107 .upsert_account_device(&did, &select_device_typed) 1108 1108 .await; 1109 + 1110 + let requested_scope_str = request_data 1111 + .parameters 1112 + .scope 1113 + .as_deref() 1114 + .unwrap_or("atproto"); 1115 + let requested_scopes: Vec<String> = requested_scope_str 1116 + .split_whitespace() 1117 + .map(|s| s.to_string()) 1118 + .collect(); 1119 + let client_id_typed = ClientId::from(request_data.parameters.client_id.clone()); 1120 + let needs_consent = should_show_consent( 1121 + state.oauth_repo.as_ref(), 1122 + &did, 1123 + &client_id_typed, 1124 + &requested_scopes, 1125 + ) 1126 + .await 1127 + .unwrap_or(true); 1128 + 1129 + if needs_consent { 1130 + if state 1131 + .oauth_repo 1132 + .set_authorization_did(&select_request_id, &did, Some(&select_device_typed)) 1133 + .await 1134 + .is_err() 1135 + { 1136 + return json_error( 1137 + StatusCode::INTERNAL_SERVER_ERROR, 1138 + "server_error", 1139 + "An error occurred. Please try again.", 1140 + ); 1141 + } 1142 + let consent_url = format!( 1143 + "/app/oauth/consent?request_uri={}", 1144 + url_encode(&form.request_uri) 1145 + ); 1146 + return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 1147 + } 1148 + 1109 1149 let code = Code::generate(); 1110 1150 let select_code = AuthorizationCode::from(code.0.clone()); 1111 1151 if state ··· 1475 1515 requested_scope_str.to_string() 1476 1516 }; 1477 1517 1478 - let requested_scopes: Vec<&str> = effective_scope_str.split_whitespace().collect(); 1518 + let expanded_scope_str = expand_include_scopes(&effective_scope_str).await; 1519 + let requested_scopes: Vec<&str> = expanded_scope_str.split_whitespace().collect(); 1479 1520 let consent_client_id = ClientId::from(request_data.parameters.client_id.clone()); 1480 1521 let preferences = state 1481 1522 .oauth_repo ··· 2407 2448 } 2408 2449 2409 2450 let delegation_from_param = match &form.delegated_did { 2410 - Some(delegated_did_str) => { 2411 - match delegated_did_str.parse::<tranquil_types::Did>() { 2412 - Ok(delegated_did) if delegated_did != user.did => { 2413 - match state 2414 - .delegation_repo 2415 - .get_delegation(&delegated_did, &user.did) 2416 - .await 2417 - { 2418 - Ok(Some(_)) => Some(delegated_did), 2419 - Ok(None) => None, 2420 - Err(e) => { 2421 - tracing::warn!( 2422 - error = %e, 2423 - delegated_did = %delegated_did, 2424 - controller_did = %user.did, 2425 - "Failed to verify delegation relationship" 2426 - ); 2427 - None 2428 - } 2451 + Some(delegated_did_str) => match delegated_did_str.parse::<tranquil_types::Did>() { 2452 + Ok(delegated_did) if delegated_did != user.did => { 2453 + match state 2454 + .delegation_repo 2455 + .get_delegation(&delegated_did, &user.did) 2456 + .await 2457 + { 2458 + Ok(Some(_)) => Some(delegated_did), 2459 + Ok(None) => None, 2460 + Err(e) => { 2461 + tracing::warn!( 2462 + error = %e, 2463 + delegated_did = %delegated_did, 2464 + controller_did = %user.did, 2465 + "Failed to verify delegation relationship" 2466 + ); 2467 + None 2429 2468 } 2430 2469 } 2431 - _ => None, 2432 2470 } 2433 - } 2471 + _ => None, 2472 + }, 2434 2473 None => None, 2435 2474 }; 2436 2475 2437 2476 let is_delegation_flow = delegation_from_param.is_some() 2438 - || request_data.did.as_ref().map_or(false, |existing_did| { 2477 + || request_data.did.as_ref().is_some_and(|existing_did| { 2439 2478 existing_did 2440 2479 .parse::<tranquil_types::Did>() 2441 2480 .ok() 2442 - .map_or(false, |parsed| parsed != user.did) 2481 + .is_some_and(|parsed| parsed != user.did) 2443 2482 }); 2444 2483 2445 2484 if let Some(delegated_did) = delegation_from_param { ··· 3601 3640 ); 3602 3641 Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 3603 3642 } 3643 + 3644 + pub async fn establish_session( 3645 + State(state): State<AppState>, 3646 + headers: HeaderMap, 3647 + auth: crate::auth::RequiredAuth, 3648 + ) -> Response { 3649 + let user = match auth.0.require_user() { 3650 + Ok(u) => match u.require_active() { 3651 + Ok(u) => u, 3652 + Err(_) => { 3653 + return ( 3654 + StatusCode::FORBIDDEN, 3655 + Json(serde_json::json!({ 3656 + "error": "access_denied", 3657 + "error_description": "Account is deactivated" 3658 + })), 3659 + ) 3660 + .into_response(); 3661 + } 3662 + }, 3663 + Err(_) => { 3664 + return ( 3665 + StatusCode::UNAUTHORIZED, 3666 + Json(serde_json::json!({ 3667 + "error": "invalid_token", 3668 + "error_description": "Authentication required" 3669 + })), 3670 + ) 3671 + .into_response(); 3672 + } 3673 + }; 3674 + let did = &user.did; 3675 + 3676 + let existing_device = extract_device_cookie(&headers); 3677 + 3678 + let (device_id, new_cookie) = match existing_device { 3679 + Some(id) => { 3680 + let device_typed = DeviceIdType::from(id.clone()); 3681 + let _ = state 3682 + .oauth_repo 3683 + .upsert_account_device(did, &device_typed) 3684 + .await; 3685 + (id, None) 3686 + } 3687 + None => { 3688 + let new_id = DeviceId::generate(); 3689 + let device_data = DeviceData { 3690 + session_id: SessionId::generate().0, 3691 + user_agent: extract_user_agent(&headers), 3692 + ip_address: extract_client_ip(&headers), 3693 + last_seen_at: Utc::now(), 3694 + }; 3695 + let device_typed = DeviceIdType::from(new_id.0.clone()); 3696 + 3697 + if let Err(e) = state 3698 + .oauth_repo 3699 + .create_device(&device_typed, &device_data) 3700 + .await 3701 + { 3702 + tracing::error!(error = ?e, "Failed to create device"); 3703 + return ( 3704 + StatusCode::INTERNAL_SERVER_ERROR, 3705 + Json(serde_json::json!({ 3706 + "error": "server_error", 3707 + "error_description": "Failed to establish session" 3708 + })), 3709 + ) 3710 + .into_response(); 3711 + } 3712 + 3713 + if let Err(e) = state 3714 + .oauth_repo 3715 + .upsert_account_device(did, &device_typed) 3716 + .await 3717 + { 3718 + tracing::error!(error = ?e, "Failed to link device to account"); 3719 + return ( 3720 + StatusCode::INTERNAL_SERVER_ERROR, 3721 + Json(serde_json::json!({ 3722 + "error": "server_error", 3723 + "error_description": "Failed to establish session" 3724 + })), 3725 + ) 3726 + .into_response(); 3727 + } 3728 + 3729 + (new_id.0.clone(), Some(make_device_cookie(&new_id.0))) 3730 + } 3731 + }; 3732 + 3733 + tracing::info!(did = %did, device_id = %device_id, "Device session established"); 3734 + 3735 + match new_cookie { 3736 + Some(cookie) => ( 3737 + StatusCode::OK, 3738 + [(SET_COOKIE, cookie)], 3739 + Json(serde_json::json!({ 3740 + "success": true, 3741 + "device_id": device_id 3742 + })), 3743 + ) 3744 + .into_response(), 3745 + None => Json(serde_json::json!({ 3746 + "success": true, 3747 + "device_id": device_id 3748 + })) 3749 + .into_response(), 3750 + } 3751 + }
+24 -49
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
··· 1 - use crate::auth::{extract_auth_token_from_header, validate_token_with_dpop}; 1 + use crate::auth::RequiredAuth; 2 2 use crate::delegation::DelegationActionType; 3 3 use crate::state::{AppState, RateLimitKind}; 4 4 use crate::types::PlainPassword; 5 - use crate::util::{build_full_url, extract_client_ip}; 5 + use crate::util::extract_client_ip; 6 6 use axum::{ 7 7 Json, 8 8 extract::State, ··· 463 463 pub async fn delegation_auth_token( 464 464 State(state): State<AppState>, 465 465 headers: HeaderMap, 466 + auth: RequiredAuth, 466 467 Json(form): Json<DelegationTokenAuthSubmit>, 467 468 ) -> Response { 468 - let auth_header = headers.get("authorization").and_then(|v| v.to_str().ok()); 469 - 470 - let extracted = match extract_auth_token_from_header(auth_header) { 471 - Some(e) => e, 472 - None => { 473 - return ( 474 - StatusCode::UNAUTHORIZED, 475 - Json(DelegationAuthResponse { 469 + let user = match auth.0.require_user() { 470 + Ok(u) => match u.require_active() { 471 + Ok(u) => u, 472 + Err(_) => { 473 + return Json(DelegationAuthResponse { 476 474 success: false, 477 475 needs_totp: None, 478 476 redirect_uri: None, 479 - error: Some("Missing or invalid authorization header".to_string()), 480 - }), 481 - ) 477 + error: Some("Account is deactivated".to_string()), 478 + }) 482 479 .into_response(); 483 - } 484 - }; 485 - 486 - let dpop_proof = headers.get("dpop").and_then(|h| h.to_str().ok()); 487 - let uri = build_full_url("/oauth/delegation/auth-token"); 488 - 489 - let auth_user = match validate_token_with_dpop( 490 - state.user_repo.as_ref(), 491 - state.oauth_repo.as_ref(), 492 - &extracted.token, 493 - extracted.is_dpop, 494 - dpop_proof, 495 - "POST", 496 - &uri, 497 - false, 498 - false, 499 - ) 500 - .await 501 - { 502 - Ok(user) => user, 480 + } 481 + }, 503 482 Err(_) => { 504 - return ( 505 - StatusCode::UNAUTHORIZED, 506 - Json(DelegationAuthResponse { 507 - success: false, 508 - needs_totp: None, 509 - redirect_uri: None, 510 - error: Some("Invalid or expired access token".to_string()), 511 - }), 512 - ) 513 - .into_response(); 483 + return Json(DelegationAuthResponse { 484 + success: false, 485 + needs_totp: None, 486 + redirect_uri: None, 487 + error: Some("Authentication required".to_string()), 488 + }) 489 + .into_response(); 514 490 } 515 491 }; 516 - 517 - let controller_did = auth_user.did; 492 + let controller_did = &user.did; 518 493 519 494 let delegated_did: Did = match form.delegated_did.parse() { 520 495 Ok(d) => d, ··· 558 533 559 534 let grant = match state 560 535 .delegation_repo 561 - .get_delegation(&delegated_did, &controller_did) 536 + .get_delegation(&delegated_did, controller_did) 562 537 .await 563 538 { 564 539 Ok(Some(g)) => g, ··· 599 574 600 575 if state 601 576 .oauth_repo 602 - .set_controller_did(&request_id, &controller_did) 577 + .set_controller_did(&request_id, controller_did) 603 578 .await 604 579 .is_err() 605 580 { ··· 622 597 .delegation_repo 623 598 .log_delegation_action( 624 599 &delegated_did, 625 - &controller_did, 626 - Some(&controller_did), 600 + controller_did, 601 + Some(controller_did), 627 602 DelegationActionType::TokenIssued, 628 603 Some(serde_json::json!({ 629 604 "client_id": request.client_id,
+31 -13
crates/tranquil-pds/src/oauth/verify.rs
··· 10 10 use sha2::Sha256; 11 11 use subtle::ConstantTimeEq; 12 12 use tranquil_db_traits::{OAuthRepository, UserRepository}; 13 - use tranquil_types::TokenId; 13 + use tranquil_types::{ClientId, TokenId}; 14 + 15 + use crate::types::Did; 14 16 15 17 use super::scopes::ScopePermissions; 16 18 use super::{DPoPVerifier, OAuthError}; ··· 27 29 } 28 30 29 31 pub struct VerifyResult { 30 - pub did: String, 31 - pub token_id: String, 32 - pub client_id: String, 32 + pub did: Did, 33 + pub token_id: TokenId, 34 + pub client_id: ClientId, 33 35 pub scope: Option<String>, 34 36 } 35 37 ··· 91 93 )); 92 94 } 93 95 } 96 + let did: Did = token_data 97 + .did 98 + .parse() 99 + .map_err(|_| OAuthError::InvalidToken("Invalid DID in token".to_string()))?; 94 100 Ok(VerifyResult { 95 - did: token_data.did, 96 - token_id: token_id.to_string(), 97 - client_id: token_data.client_id, 101 + did, 102 + token_id, 103 + client_id: ClientId::from(token_data.client_id), 98 104 scope: token_data.scope, 99 105 }) 100 106 } ··· 202 208 } 203 209 204 210 pub struct OAuthUser { 205 - pub did: String, 206 - pub client_id: Option<String>, 211 + pub did: Did, 212 + pub client_id: Option<ClientId>, 207 213 pub scope: Option<String>, 208 214 pub is_oauth: bool, 209 215 pub permissions: ScopePermissions, ··· 382 388 } 383 389 384 390 struct LegacyAuthResult { 385 - did: String, 391 + did: Did, 386 392 } 387 393 388 394 async fn try_legacy_auth( ··· 390 396 token: &str, 391 397 ) -> Result<LegacyAuthResult, ()> { 392 398 match crate::auth::validate_bearer_token(user_repo, token).await { 393 - Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { 394 - did: user.did.to_string(), 395 - }), 399 + Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { did: user.did }), 396 400 _ => Err(()), 397 401 } 398 402 } 403 + 404 + pub async fn dpop_nonce_middleware( 405 + req: axum::http::Request<axum::body::Body>, 406 + next: axum::middleware::Next, 407 + ) -> Response { 408 + let mut response = next.run(req).await; 409 + let config = AuthConfig::get(); 410 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 411 + let nonce = verifier.generate_nonce(); 412 + if let Ok(nonce_val) = nonce.parse() { 413 + response.headers_mut().insert("DPoP-Nonce", nonce_val); 414 + } 415 + response 416 + }
+4 -2
crates/tranquil-pds/src/sso/endpoints.rs
··· 644 644 645 645 pub async fn get_linked_accounts( 646 646 State(state): State<AppState>, 647 - crate::auth::extractor::BearerAuth(auth): crate::auth::extractor::BearerAuth, 647 + auth: crate::auth::RequiredAuth, 648 648 ) -> Result<Json<LinkedAccountsResponse>, ApiError> { 649 + let auth = auth.0.require_user()?.require_active()?; 649 650 let identities = state 650 651 .sso_repo 651 652 .get_external_identities_by_did(&auth.did) ··· 679 680 680 681 pub async fn unlink_account( 681 682 State(state): State<AppState>, 682 - crate::auth::extractor::BearerAuth(auth): crate::auth::extractor::BearerAuth, 683 + auth: crate::auth::RequiredAuth, 683 684 Json(input): Json<UnlinkAccountRequest>, 684 685 ) -> Result<Json<UnlinkAccountResponse>, ApiError> { 686 + let auth = auth.0.require_user()?.require_active()?; 685 687 if !state 686 688 .check_rate_limit(RateLimitKind::SsoUnlink, auth.did.as_str()) 687 689 .await
+583
crates/tranquil-pds/tests/auth_extractor.rs
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client, create_account_and_login, pds_endpoint}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); 20 + (code_verifier, code_challenge) 21 + } 22 + 23 + async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer { 24 + let mock_server = MockServer::start().await; 25 + let metadata = json!({ 26 + "client_id": mock_server.uri(), 27 + "client_name": "Auth Extractor Test Client", 28 + "redirect_uris": [redirect_uri], 29 + "grant_types": ["authorization_code", "refresh_token"], 30 + "response_types": ["code"], 31 + "token_endpoint_auth_method": "none", 32 + "dpop_bound_access_tokens": dpop_bound 33 + }); 34 + Mock::given(method("GET")) 35 + .and(path("/")) 36 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 37 + .mount(&mock_server) 38 + .await; 39 + mock_server 40 + } 41 + 42 + async fn get_oauth_session( 43 + http_client: &reqwest::Client, 44 + url: &str, 45 + dpop_bound: bool, 46 + ) -> (String, String, String, String) { 47 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 48 + let handle = format!("ae{}", suffix); 49 + let password = "AuthExtract123!"; 50 + let create_res = http_client 51 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 52 + .json(&json!({ 53 + "handle": handle, 54 + "email": format!("{}@example.com", handle), 55 + "password": password 56 + })) 57 + .send() 58 + .await 59 + .unwrap(); 60 + assert_eq!(create_res.status(), StatusCode::OK); 61 + let account: Value = create_res.json().await.unwrap(); 62 + let did = account["did"].as_str().unwrap().to_string(); 63 + verify_new_account(http_client, &did).await; 64 + 65 + let redirect_uri = "https://example.com/auth-callback"; 66 + let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await; 67 + let client_id = mock_client.uri(); 68 + let (code_verifier, code_challenge) = generate_pkce(); 69 + 70 + let par_body: Value = http_client 71 + .post(format!("{}/oauth/par", url)) 72 + .form(&[ 73 + ("response_type", "code"), 74 + ("client_id", &client_id), 75 + ("redirect_uri", redirect_uri), 76 + ("code_challenge", &code_challenge), 77 + ("code_challenge_method", "S256"), 78 + ]) 79 + .send() 80 + .await 81 + .unwrap() 82 + .json() 83 + .await 84 + .unwrap(); 85 + let request_uri = par_body["request_uri"].as_str().unwrap(); 86 + 87 + let auth_res = http_client 88 + .post(format!("{}/oauth/authorize", url)) 89 + .header("Content-Type", "application/json") 90 + .header("Accept", "application/json") 91 + .json(&json!({ 92 + "request_uri": request_uri, 93 + "username": &handle, 94 + "password": password, 95 + "remember_device": false 96 + })) 97 + .send() 98 + .await 99 + .unwrap(); 100 + let auth_body: Value = auth_res.json().await.unwrap(); 101 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 102 + 103 + if location.contains("/oauth/consent") { 104 + let consent_res = http_client 105 + .post(format!("{}/oauth/authorize/consent", url)) 106 + .header("Content-Type", "application/json") 107 + .json(&json!({ 108 + "request_uri": request_uri, 109 + "approved_scopes": ["atproto"], 110 + "remember": false 111 + })) 112 + .send() 113 + .await 114 + .unwrap(); 115 + let consent_body: Value = consent_res.json().await.unwrap(); 116 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 117 + } 118 + 119 + let code = location 120 + .split("code=") 121 + .nth(1) 122 + .unwrap() 123 + .split('&') 124 + .next() 125 + .unwrap(); 126 + 127 + let token_body: Value = http_client 128 + .post(format!("{}/oauth/token", url)) 129 + .form(&[ 130 + ("grant_type", "authorization_code"), 131 + ("code", code), 132 + ("redirect_uri", redirect_uri), 133 + ("code_verifier", &code_verifier), 134 + ("client_id", &client_id), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap() 139 + .json() 140 + .await 141 + .unwrap(); 142 + 143 + ( 144 + token_body["access_token"].as_str().unwrap().to_string(), 145 + token_body["refresh_token"].as_str().unwrap().to_string(), 146 + client_id, 147 + did, 148 + ) 149 + } 150 + 151 + #[tokio::test] 152 + async fn test_oauth_token_works_with_bearer_auth() { 153 + let url = base_url().await; 154 + let http_client = client(); 155 + let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await; 156 + 157 + let res = http_client 158 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 159 + .bearer_auth(&access_token) 160 + .send() 161 + .await 162 + .unwrap(); 163 + 164 + assert_eq!( 165 + res.status(), 166 + StatusCode::OK, 167 + "OAuth token should work with BearerAuth extractor" 168 + ); 169 + let body: Value = res.json().await.unwrap(); 170 + assert_eq!(body["did"].as_str().unwrap(), did); 171 + } 172 + 173 + #[tokio::test] 174 + async fn test_session_token_still_works() { 175 + let url = base_url().await; 176 + let http_client = client(); 177 + let (jwt, did) = create_account_and_login(&http_client).await; 178 + 179 + let res = http_client 180 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 181 + .bearer_auth(&jwt) 182 + .send() 183 + .await 184 + .unwrap(); 185 + 186 + assert_eq!( 187 + res.status(), 188 + StatusCode::OK, 189 + "Session token should still work" 190 + ); 191 + let body: Value = res.json().await.unwrap(); 192 + assert_eq!(body["did"].as_str().unwrap(), did); 193 + } 194 + 195 + #[tokio::test] 196 + async fn test_oauth_admin_extractor_allows_oauth_tokens() { 197 + let url = base_url().await; 198 + let http_client = client(); 199 + 200 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 201 + let handle = format!("adm{}", suffix); 202 + let password = "AdminOAuth123!"; 203 + let create_res = http_client 204 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 205 + .json(&json!({ 206 + "handle": handle, 207 + "email": format!("{}@example.com", handle), 208 + "password": password 209 + })) 210 + .send() 211 + .await 212 + .unwrap(); 213 + assert_eq!(create_res.status(), StatusCode::OK); 214 + let account: Value = create_res.json().await.unwrap(); 215 + let did = account["did"].as_str().unwrap().to_string(); 216 + verify_new_account(&http_client, &did).await; 217 + 218 + let pool = common::get_test_db_pool().await; 219 + sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 220 + .execute(pool) 221 + .await 222 + .expect("Failed to mark user as admin"); 223 + 224 + let redirect_uri = "https://example.com/admin-callback"; 225 + let mock_client = setup_mock_client_metadata(redirect_uri, false).await; 226 + let client_id = mock_client.uri(); 227 + let (code_verifier, code_challenge) = generate_pkce(); 228 + 229 + let par_body: Value = http_client 230 + .post(format!("{}/oauth/par", url)) 231 + .form(&[ 232 + ("response_type", "code"), 233 + ("client_id", &client_id), 234 + ("redirect_uri", redirect_uri), 235 + ("code_challenge", &code_challenge), 236 + ("code_challenge_method", "S256"), 237 + ]) 238 + .send() 239 + .await 240 + .unwrap() 241 + .json() 242 + .await 243 + .unwrap(); 244 + let request_uri = par_body["request_uri"].as_str().unwrap(); 245 + 246 + let auth_res = http_client 247 + .post(format!("{}/oauth/authorize", url)) 248 + .header("Content-Type", "application/json") 249 + .header("Accept", "application/json") 250 + .json(&json!({ 251 + "request_uri": request_uri, 252 + "username": &handle, 253 + "password": password, 254 + "remember_device": false 255 + })) 256 + .send() 257 + .await 258 + .unwrap(); 259 + let auth_body: Value = auth_res.json().await.unwrap(); 260 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 261 + if location.contains("/oauth/consent") { 262 + let consent_res = http_client 263 + .post(format!("{}/oauth/authorize/consent", url)) 264 + .header("Content-Type", "application/json") 265 + .json(&json!({ 266 + "request_uri": request_uri, 267 + "approved_scopes": ["atproto"], 268 + "remember": false 269 + })) 270 + .send() 271 + .await 272 + .unwrap(); 273 + let consent_body: Value = consent_res.json().await.unwrap(); 274 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 275 + } 276 + 277 + let code = location 278 + .split("code=") 279 + .nth(1) 280 + .unwrap() 281 + .split('&') 282 + .next() 283 + .unwrap(); 284 + let token_body: Value = http_client 285 + .post(format!("{}/oauth/token", url)) 286 + .form(&[ 287 + ("grant_type", "authorization_code"), 288 + ("code", code), 289 + ("redirect_uri", redirect_uri), 290 + ("code_verifier", &code_verifier), 291 + ("client_id", &client_id), 292 + ]) 293 + .send() 294 + .await 295 + .unwrap() 296 + .json() 297 + .await 298 + .unwrap(); 299 + let access_token = token_body["access_token"].as_str().unwrap(); 300 + 301 + let res = http_client 302 + .get(format!( 303 + "{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", 304 + url, did 305 + )) 306 + .bearer_auth(access_token) 307 + .send() 308 + .await 309 + .unwrap(); 310 + 311 + assert_eq!( 312 + res.status(), 313 + StatusCode::OK, 314 + "OAuth token for admin user should work with admin endpoint" 315 + ); 316 + } 317 + 318 + #[tokio::test] 319 + async fn test_expired_oauth_token_returns_proper_error() { 320 + let url = base_url().await; 321 + let http_client = client(); 322 + 323 + let now = Utc::now().timestamp(); 324 + let header = json!({"alg": "HS256", "typ": "at+jwt"}); 325 + let payload = json!({ 326 + "iss": url, 327 + "sub": "did:plc:test123", 328 + "aud": url, 329 + "iat": now - 7200, 330 + "exp": now - 3600, 331 + "jti": "expired-token", 332 + "sid": "expired-session", 333 + "scope": "atproto", 334 + "client_id": "https://example.com" 335 + }); 336 + let fake_token = format!( 337 + "{}.{}.{}", 338 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 339 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 340 + URL_SAFE_NO_PAD.encode([1u8; 32]) 341 + ); 342 + 343 + let res = http_client 344 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 345 + .bearer_auth(&fake_token) 346 + .send() 347 + .await 348 + .unwrap(); 349 + 350 + assert_eq!( 351 + res.status(), 352 + StatusCode::UNAUTHORIZED, 353 + "Expired token should be rejected" 354 + ); 355 + } 356 + 357 + #[tokio::test] 358 + async fn test_dpop_nonce_error_has_proper_headers() { 359 + let url = base_url().await; 360 + let pds_url = pds_endpoint(); 361 + let http_client = client(); 362 + 363 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 364 + let handle = format!("dpop{}", suffix); 365 + let create_res = http_client 366 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 367 + .json(&json!({ 368 + "handle": handle, 369 + "email": format!("{}@test.com", handle), 370 + "password": "DpopTest123!" 371 + })) 372 + .send() 373 + .await 374 + .unwrap(); 375 + assert_eq!(create_res.status(), StatusCode::OK); 376 + let account: Value = create_res.json().await.unwrap(); 377 + let did = account["did"].as_str().unwrap(); 378 + verify_new_account(&http_client, did).await; 379 + 380 + let redirect_uri = "https://example.com/dpop-callback"; 381 + let mock_server = MockServer::start().await; 382 + let client_id = mock_server.uri(); 383 + let metadata = json!({ 384 + "client_id": &client_id, 385 + "client_name": "DPoP Test Client", 386 + "redirect_uris": [redirect_uri], 387 + "grant_types": ["authorization_code", "refresh_token"], 388 + "response_types": ["code"], 389 + "token_endpoint_auth_method": "none", 390 + "dpop_bound_access_tokens": true 391 + }); 392 + Mock::given(method("GET")) 393 + .and(path("/")) 394 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 395 + .mount(&mock_server) 396 + .await; 397 + 398 + let (code_verifier, code_challenge) = generate_pkce(); 399 + let par_body: Value = http_client 400 + .post(format!("{}/oauth/par", url)) 401 + .form(&[ 402 + ("response_type", "code"), 403 + ("client_id", &client_id), 404 + ("redirect_uri", redirect_uri), 405 + ("code_challenge", &code_challenge), 406 + ("code_challenge_method", "S256"), 407 + ]) 408 + .send() 409 + .await 410 + .unwrap() 411 + .json() 412 + .await 413 + .unwrap(); 414 + 415 + let request_uri = par_body["request_uri"].as_str().unwrap(); 416 + let auth_res = http_client 417 + .post(format!("{}/oauth/authorize", url)) 418 + .header("Content-Type", "application/json") 419 + .header("Accept", "application/json") 420 + .json(&json!({ 421 + "request_uri": request_uri, 422 + "username": &handle, 423 + "password": "DpopTest123!", 424 + "remember_device": false 425 + })) 426 + .send() 427 + .await 428 + .unwrap(); 429 + let auth_body: Value = auth_res.json().await.unwrap(); 430 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 431 + if location.contains("/oauth/consent") { 432 + let consent_res = http_client 433 + .post(format!("{}/oauth/authorize/consent", url)) 434 + .header("Content-Type", "application/json") 435 + .json(&json!({ 436 + "request_uri": request_uri, 437 + "approved_scopes": ["atproto"], 438 + "remember": false 439 + })) 440 + .send() 441 + .await 442 + .unwrap(); 443 + let consent_body: Value = consent_res.json().await.unwrap(); 444 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 445 + } 446 + 447 + let code = location 448 + .split("code=") 449 + .nth(1) 450 + .unwrap() 451 + .split('&') 452 + .next() 453 + .unwrap(); 454 + 455 + let token_endpoint = format!("{}/oauth/token", pds_url); 456 + let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None); 457 + 458 + let token_res = http_client 459 + .post(format!("{}/oauth/token", url)) 460 + .header("DPoP", &dpop_proof) 461 + .form(&[ 462 + ("grant_type", "authorization_code"), 463 + ("code", code), 464 + ("redirect_uri", redirect_uri), 465 + ("code_verifier", &code_verifier), 466 + ("client_id", &client_id), 467 + ]) 468 + .send() 469 + .await 470 + .unwrap(); 471 + 472 + let token_status = token_res.status(); 473 + let token_nonce = token_res 474 + .headers() 475 + .get("dpop-nonce") 476 + .map(|h| h.to_str().unwrap().to_string()); 477 + let token_body: Value = token_res.json().await.unwrap(); 478 + 479 + let access_token = if token_status == StatusCode::OK { 480 + token_body["access_token"].as_str().unwrap().to_string() 481 + } else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") { 482 + let nonce = 483 + token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error"); 484 + let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce)); 485 + 486 + let retry_res = http_client 487 + .post(format!("{}/oauth/token", url)) 488 + .header("DPoP", &dpop_proof_with_nonce) 489 + .form(&[ 490 + ("grant_type", "authorization_code"), 491 + ("code", code), 492 + ("redirect_uri", redirect_uri), 493 + ("code_verifier", &code_verifier), 494 + ("client_id", &client_id), 495 + ]) 496 + .send() 497 + .await 498 + .unwrap(); 499 + let retry_body: Value = retry_res.json().await.unwrap(); 500 + retry_body["access_token"] 501 + .as_str() 502 + .expect("Should get access_token after nonce retry") 503 + .to_string() 504 + } else { 505 + panic!("Token exchange failed unexpectedly: {:?}", token_body); 506 + }; 507 + 508 + let res = http_client 509 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 510 + .header("Authorization", format!("DPoP {}", access_token)) 511 + .send() 512 + .await 513 + .unwrap(); 514 + 515 + assert_eq!( 516 + res.status(), 517 + StatusCode::UNAUTHORIZED, 518 + "DPoP token without proof should fail" 519 + ); 520 + 521 + let www_auth = res 522 + .headers() 523 + .get("www-authenticate") 524 + .map(|h| h.to_str().unwrap()); 525 + assert!(www_auth.is_some(), "Should have WWW-Authenticate header"); 526 + assert!( 527 + www_auth.unwrap().contains("use_dpop_nonce"), 528 + "WWW-Authenticate should indicate dpop nonce required" 529 + ); 530 + 531 + let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap()); 532 + assert!(nonce.is_some(), "Should return DPoP-Nonce header"); 533 + 534 + let body: Value = res.json().await.unwrap(); 535 + assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce"); 536 + } 537 + 538 + fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) { 539 + use p256::ecdsa::{SigningKey, signature::Signer}; 540 + use p256::elliptic_curve::rand_core::OsRng; 541 + 542 + let signing_key = SigningKey::random(&mut OsRng); 543 + let verifying_key = signing_key.verifying_key(); 544 + let point = verifying_key.to_encoded_point(false); 545 + let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 546 + let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 547 + 548 + let jwk = json!({ 549 + "kty": "EC", 550 + "crv": "P-256", 551 + "x": x, 552 + "y": y 553 + }); 554 + 555 + let header = { 556 + let h = json!({ 557 + "typ": "dpop+jwt", 558 + "alg": "ES256", 559 + "jwk": jwk.clone() 560 + }); 561 + h 562 + }; 563 + 564 + let mut payload = json!({ 565 + "jti": uuid::Uuid::new_v4().to_string(), 566 + "htm": method, 567 + "htu": uri, 568 + "iat": Utc::now().timestamp() 569 + }); 570 + if let Some(n) = nonce { 571 + payload["nonce"] = json!(n); 572 + } 573 + 574 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 575 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 576 + let signing_input = format!("{}.{}", header_b64, payload_b64); 577 + 578 + let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 579 + let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 580 + 581 + let proof = format!("{}.{}", signing_input, sig_b64); 582 + (jwk, proof) 583 + }
+3 -3
crates/tranquil-pds/tests/common/mod.rs
··· 1 - #[cfg(feature = "s3-storage")] 1 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 2 2 use aws_config::BehaviorVersion; 3 - #[cfg(feature = "s3-storage")] 3 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 4 4 use aws_sdk_s3::Client as S3Client; 5 - #[cfg(feature = "s3-storage")] 5 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 6 6 use aws_sdk_s3::config::Credentials; 7 7 use chrono::Utc; 8 8 use reqwest::{Client, StatusCode, header};
+8 -2
crates/tranquil-pds/tests/oauth_security.rs
··· 1373 1373 .send() 1374 1374 .await 1375 1375 .unwrap(); 1376 - assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed"); 1376 + assert_eq!( 1377 + token_res.status(), 1378 + StatusCode::OK, 1379 + "Token exchange should succeed" 1380 + ); 1377 1381 let tokens: Value = token_res.json().await.unwrap(); 1378 1382 1379 - let sub = tokens["sub"].as_str().expect("Token response should have sub claim"); 1383 + let sub = tokens["sub"] 1384 + .as_str() 1385 + .expect("Token response should have sub claim"); 1380 1386 1381 1387 assert_eq!( 1382 1388 sub, delegated_did,
+2
crates/tranquil-scopes/Cargo.toml
··· 7 7 [dependencies] 8 8 axum = { workspace = true } 9 9 futures = { workspace = true } 10 + hickory-resolver = { version = "0.24", features = ["tokio-runtime"] } 10 11 reqwest = { workspace = true } 11 12 serde = { workspace = true } 12 13 serde_json = { workspace = true } 13 14 tokio = { workspace = true } 14 15 tracing = { workspace = true } 16 + urlencoding = "2"
+521 -59
crates/tranquil-scopes/src/permission_set.rs
··· 1 + use hickory_resolver::TokioAsyncResolver; 1 2 use reqwest::Client; 2 3 use serde::Deserialize; 3 4 use std::collections::HashMap; ··· 17 18 const CACHE_TTL_SECS: u64 = 3600; 18 19 19 20 #[derive(Debug, Deserialize)] 21 + struct PlcDocument { 22 + service: Vec<PlcService>, 23 + } 24 + 25 + #[derive(Debug, Deserialize)] 26 + struct PlcService { 27 + id: String, 28 + #[serde(rename = "serviceEndpoint")] 29 + service_endpoint: String, 30 + } 31 + 32 + #[derive(Debug, Deserialize)] 33 + struct GetRecordResponse { 34 + value: LexiconDoc, 35 + } 36 + 37 + #[derive(Debug, Deserialize)] 20 38 struct LexiconDoc { 21 39 defs: HashMap<String, LexiconDef>, 22 40 } ··· 31 49 #[derive(Debug, Deserialize)] 32 50 struct PermissionEntry { 33 51 resource: String, 52 + action: Option<Vec<String>>, 34 53 collection: Option<Vec<String>>, 54 + lxm: Option<Vec<String>>, 55 + aud: Option<String>, 35 56 } 36 57 37 58 pub async fn expand_include_scopes(scope_string: &str) -> String { ··· 39 60 .split_whitespace() 40 61 .map(|scope| async move { 41 62 match scope.strip_prefix("include:") { 42 - Some(nsid) => { 43 - let nsid_base = nsid.split('?').next().unwrap_or(nsid); 44 - expand_permission_set(nsid_base).await.unwrap_or_else(|e| { 45 - warn!(nsid = nsid_base, error = %e, "Failed to expand permission set, keeping original"); 46 - scope.to_string() 47 - }) 63 + Some(rest) => { 64 + let (nsid_base, aud) = parse_include_scope(rest); 65 + expand_permission_set(nsid_base, aud) 66 + .await 67 + .unwrap_or_else(|e| { 68 + warn!(nsid = nsid_base, error = %e, "Failed to expand permission set, keeping original"); 69 + scope.to_string() 70 + }) 48 71 } 49 72 None => scope.to_string(), 50 73 } ··· 54 77 futures::future::join_all(futures).await.join(" ") 55 78 } 56 79 57 - async fn expand_permission_set(nsid: &str) -> Result<String, String> { 80 + fn parse_include_scope(rest: &str) -> (&str, Option<&str>) { 81 + rest.split_once('?') 82 + .map(|(nsid, params)| { 83 + let aud = params.split('&').find_map(|p| p.strip_prefix("aud=")); 84 + (nsid, aud) 85 + }) 86 + .unwrap_or((rest, None)) 87 + } 88 + 89 + async fn expand_permission_set(nsid: &str, aud: Option<&str>) -> Result<String, String> { 90 + let cache_key = match aud { 91 + Some(a) => format!("{}?aud={}", nsid, a), 92 + None => nsid.to_string(), 93 + }; 94 + 58 95 { 59 96 let cache = LEXICON_CACHE.read().await; 60 - if let Some(cached) = cache.get(nsid) 97 + if let Some(cached) = cache.get(&cache_key) 61 98 && cached.cached_at.elapsed().as_secs() < CACHE_TTL_SECS 62 99 { 63 100 debug!(nsid, "Using cached permission set expansion"); ··· 65 102 } 66 103 } 67 104 105 + let lexicon = fetch_lexicon_via_atproto(nsid).await?; 106 + 107 + let main_def = lexicon 108 + .defs 109 + .get("main") 110 + .ok_or("Missing 'main' definition in lexicon")?; 111 + 112 + if main_def.def_type != "permission-set" { 113 + return Err(format!( 114 + "Expected permission-set type, got: {}", 115 + main_def.def_type 116 + )); 117 + } 118 + 119 + let permissions = main_def 120 + .permissions 121 + .as_ref() 122 + .ok_or("Missing permissions in permission-set")?; 123 + 124 + let namespace_authority = extract_namespace_authority(nsid); 125 + let expanded = build_expanded_scopes(permissions, aud, &namespace_authority); 126 + 127 + if expanded.is_empty() { 128 + return Err("No valid permissions found in permission-set".to_string()); 129 + } 130 + 131 + { 132 + let mut cache = LEXICON_CACHE.write().await; 133 + cache.insert( 134 + cache_key, 135 + CachedLexicon { 136 + expanded_scope: expanded.clone(), 137 + cached_at: std::time::Instant::now(), 138 + }, 139 + ); 140 + } 141 + 142 + debug!(nsid, expanded = %expanded, "Successfully expanded permission set"); 143 + Ok(expanded) 144 + } 145 + 146 + async fn fetch_lexicon_via_atproto(nsid: &str) -> Result<LexiconDoc, String> { 68 147 let parts: Vec<&str> = nsid.split('.').collect(); 69 148 if parts.len() < 3 { 70 149 return Err(format!("Invalid NSID format: {}", nsid)); 71 150 } 72 151 73 - let domain_parts: Vec<&str> = parts[..2].iter().rev().cloned().collect(); 74 - let domain = domain_parts.join("."); 75 - let path = parts[2..].join("/"); 152 + let authority = parts[..2] 153 + .iter() 154 + .rev() 155 + .cloned() 156 + .collect::<Vec<_>>() 157 + .join("."); 158 + debug!(nsid, authority = %authority, "Resolving lexicon DID authority via DNS"); 76 159 77 - let url = format!("https://{}/lexicons/{}.json", domain, path); 78 - debug!(nsid, url = %url, "Fetching permission set lexicon"); 160 + let did = resolve_lexicon_did_authority(&authority).await?; 161 + debug!(nsid, did = %did, "Resolved lexicon DID authority"); 162 + 163 + let pds_endpoint = resolve_did_to_pds(&did).await?; 164 + debug!(nsid, pds = %pds_endpoint, "Resolved DID to PDS endpoint"); 79 165 80 166 let client = Client::builder() 81 167 .timeout(std::time::Duration::from_secs(10)) 82 168 .build() 83 169 .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 84 170 171 + let url = format!( 172 + "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=com.atproto.lexicon.schema&rkey={}", 173 + pds_endpoint, 174 + urlencoding::encode(&did), 175 + urlencoding::encode(nsid) 176 + ); 177 + debug!(nsid, url = %url, "Fetching lexicon from PDS"); 178 + 85 179 let response = client 86 180 .get(&url) 87 181 .header("Accept", "application/json") ··· 96 190 )); 97 191 } 98 192 99 - let lexicon: LexiconDoc = response 193 + let record: GetRecordResponse = response 100 194 .json() 101 195 .await 102 - .map_err(|e| format!("Failed to parse lexicon: {}", e))?; 196 + .map_err(|e| format!("Failed to parse lexicon response: {}", e))?; 103 197 104 - let main_def = lexicon 105 - .defs 106 - .get("main") 107 - .ok_or("Missing 'main' definition in lexicon")?; 198 + Ok(record.value) 199 + } 108 200 109 - if main_def.def_type != "permission-set" { 110 - return Err(format!( 111 - "Expected permission-set type, got: {}", 112 - main_def.def_type 113 - )); 201 + async fn resolve_lexicon_did_authority(authority: &str) -> Result<String, String> { 202 + let resolver = TokioAsyncResolver::tokio_from_system_conf() 203 + .map_err(|e| format!("Failed to create DNS resolver: {}", e))?; 204 + 205 + let dns_name = format!("_lexicon.{}", authority); 206 + debug!(dns_name = %dns_name, "Looking up DNS TXT record"); 207 + 208 + let txt_records = resolver 209 + .txt_lookup(&dns_name) 210 + .await 211 + .map_err(|e| format!("DNS lookup failed for {}: {}", dns_name, e))?; 212 + 213 + txt_records 214 + .iter() 215 + .flat_map(|record| record.iter()) 216 + .find_map(|data| { 217 + let txt = String::from_utf8_lossy(data); 218 + txt.strip_prefix("did=").map(|did| did.to_string()) 219 + }) 220 + .ok_or_else(|| format!("No valid did= TXT record found at {}", dns_name)) 221 + } 222 + 223 + async fn resolve_did_to_pds(did: &str) -> Result<String, String> { 224 + let client = Client::builder() 225 + .timeout(std::time::Duration::from_secs(10)) 226 + .build() 227 + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 228 + 229 + let url = if did.starts_with("did:plc:") { 230 + format!("https://plc.directory/{}", did) 231 + } else if did.starts_with("did:web:") { 232 + let domain = did.strip_prefix("did:web:").unwrap(); 233 + format!("https://{}/.well-known/did.json", domain) 234 + } else { 235 + return Err(format!("Unsupported DID method: {}", did)); 236 + }; 237 + 238 + let response = client 239 + .get(&url) 240 + .header("Accept", "application/json") 241 + .send() 242 + .await 243 + .map_err(|e| format!("Failed to resolve DID: {}", e))?; 244 + 245 + if !response.status().is_success() { 246 + return Err(format!("Failed to resolve DID: HTTP {}", response.status())); 114 247 } 115 248 116 - let permissions = main_def 117 - .permissions 118 - .as_ref() 119 - .ok_or("Missing permissions in permission-set")?; 249 + let doc: PlcDocument = response 250 + .json() 251 + .await 252 + .map_err(|e| format!("Failed to parse DID document: {}", e))?; 120 253 121 - let mut collections: Vec<String> = permissions 254 + doc.service 122 255 .iter() 123 - .filter(|perm| perm.resource == "repo") 124 - .filter_map(|perm| perm.collection.as_ref()) 125 - .flatten() 126 - .cloned() 127 - .collect(); 256 + .find(|s| s.id == "#atproto_pds") 257 + .map(|s| s.service_endpoint.clone()) 258 + .ok_or_else(|| "No #atproto_pds service found in DID document".to_string()) 259 + } 128 260 129 - if collections.is_empty() { 130 - return Err("No repo collections found in permission-set".to_string()); 261 + fn extract_namespace_authority(nsid: &str) -> String { 262 + let parts: Vec<&str> = nsid.split('.').collect(); 263 + if parts.len() >= 2 { 264 + parts[..parts.len() - 1].join(".") 265 + } else { 266 + nsid.to_string() 131 267 } 268 + } 132 269 133 - collections.sort(); 270 + fn is_under_authority(target_nsid: &str, authority: &str) -> bool { 271 + target_nsid.starts_with(authority) 272 + && target_nsid 273 + .chars() 274 + .nth(authority.len()) 275 + .is_some_and(|c| c == '.') 276 + } 277 + 278 + const DEFAULT_ACTIONS: &[&str] = &["create", "update", "delete"]; 279 + 280 + fn build_expanded_scopes( 281 + permissions: &[PermissionEntry], 282 + default_aud: Option<&str>, 283 + namespace_authority: &str, 284 + ) -> String { 285 + let mut scopes: Vec<String> = Vec::new(); 134 286 135 - let collection_params: Vec<String> = collections 287 + permissions 136 288 .iter() 137 - .map(|c| format!("collection={}", c)) 138 - .collect(); 289 + .for_each(|perm| match perm.resource.as_str() { 290 + "repo" => { 291 + if let Some(collections) = &perm.collection { 292 + let actions: Vec<&str> = perm 293 + .action 294 + .as_ref() 295 + .map(|a| a.iter().map(String::as_str).collect()) 296 + .unwrap_or_else(|| DEFAULT_ACTIONS.to_vec()); 139 297 140 - let expanded = format!("repo?{}", collection_params.join("&")); 298 + collections 299 + .iter() 300 + .filter(|coll| is_under_authority(coll, namespace_authority)) 301 + .for_each(|coll| { 302 + actions.iter().for_each(|action| { 303 + scopes.push(format!("repo:{}?action={}", coll, action)); 304 + }); 305 + }); 306 + } 307 + } 308 + "rpc" => { 309 + if let Some(lxms) = &perm.lxm { 310 + let perm_aud = perm.aud.as_deref().or(default_aud); 141 311 142 - { 143 - let mut cache = LEXICON_CACHE.write().await; 144 - cache.insert( 145 - nsid.to_string(), 146 - CachedLexicon { 147 - expanded_scope: expanded.clone(), 148 - cached_at: std::time::Instant::now(), 149 - }, 150 - ); 151 - } 312 + lxms.iter().for_each(|lxm| { 313 + let scope = match perm_aud { 314 + Some(aud) => format!("rpc:{}?aud={}", lxm, aud), 315 + None => format!("rpc:{}", lxm), 316 + }; 317 + scopes.push(scope); 318 + }); 319 + } 320 + } 321 + _ => {} 322 + }); 152 323 153 - debug!(nsid, expanded = %expanded, "Successfully expanded permission set"); 154 - Ok(expanded) 324 + scopes.join(" ") 155 325 } 156 326 157 327 #[cfg(test)] 158 328 mod tests { 329 + use super::*; 330 + 159 331 #[test] 160 - fn test_nsid_to_url() { 332 + fn test_parse_include_scope() { 333 + let (nsid, aud) = parse_include_scope("io.atcr.authFullApp"); 334 + assert_eq!(nsid, "io.atcr.authFullApp"); 335 + assert_eq!(aud, None); 336 + 337 + let (nsid, aud) = parse_include_scope("io.atcr.authFullApp?aud=did:web:api.bsky.app"); 338 + assert_eq!(nsid, "io.atcr.authFullApp"); 339 + assert_eq!(aud, Some("did:web:api.bsky.app")); 340 + } 341 + 342 + #[test] 343 + fn test_parse_include_scope_with_multiple_params() { 344 + let (nsid, aud) = 345 + parse_include_scope("io.atcr.authFullApp?foo=bar&aud=did:web:example.com&baz=qux"); 346 + assert_eq!(nsid, "io.atcr.authFullApp"); 347 + assert_eq!(aud, Some("did:web:example.com")); 348 + } 349 + 350 + #[test] 351 + fn test_extract_namespace_authority() { 352 + assert_eq!( 353 + extract_namespace_authority("io.atcr.authFullApp"), 354 + "io.atcr" 355 + ); 356 + assert_eq!( 357 + extract_namespace_authority("app.bsky.authFullApp"), 358 + "app.bsky" 359 + ); 360 + } 361 + 362 + #[test] 363 + fn test_extract_namespace_authority_deep_nesting() { 364 + assert_eq!( 365 + extract_namespace_authority("io.atcr.sailor.star.collection"), 366 + "io.atcr.sailor.star" 367 + ); 368 + } 369 + 370 + #[test] 371 + fn test_extract_namespace_authority_single_segment() { 372 + assert_eq!(extract_namespace_authority("single"), "single"); 373 + } 374 + 375 + #[test] 376 + fn test_is_under_authority() { 377 + assert!(is_under_authority("io.atcr.manifest", "io.atcr")); 378 + assert!(is_under_authority("io.atcr.sailor.star", "io.atcr")); 379 + assert!(!is_under_authority("app.bsky.feed.post", "io.atcr")); 380 + assert!(!is_under_authority("io.atcr", "io.atcr")); 381 + } 382 + 383 + #[test] 384 + fn test_is_under_authority_prefix_collision() { 385 + assert!(!is_under_authority("io.atcritical.something", "io.atcr")); 386 + assert!(is_under_authority("io.atcr.something", "io.atcr")); 387 + } 388 + 389 + #[test] 390 + fn test_build_expanded_scopes_repo() { 391 + let permissions = vec![PermissionEntry { 392 + resource: "repo".to_string(), 393 + action: Some(vec!["create".to_string(), "delete".to_string()]), 394 + collection: Some(vec![ 395 + "io.atcr.manifest".to_string(), 396 + "io.atcr.sailor.star".to_string(), 397 + "app.bsky.feed.post".to_string(), 398 + ]), 399 + lxm: None, 400 + aud: None, 401 + }]; 402 + 403 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 404 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 405 + assert!(expanded.contains("repo:io.atcr.manifest?action=delete")); 406 + assert!(expanded.contains("repo:io.atcr.sailor.star?action=create")); 407 + assert!(!expanded.contains("app.bsky.feed.post")); 408 + } 409 + 410 + #[test] 411 + fn test_build_expanded_scopes_repo_default_actions() { 412 + let permissions = vec![PermissionEntry { 413 + resource: "repo".to_string(), 414 + action: None, 415 + collection: Some(vec!["io.atcr.manifest".to_string()]), 416 + lxm: None, 417 + aud: None, 418 + }]; 419 + 420 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 421 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 422 + assert!(expanded.contains("repo:io.atcr.manifest?action=update")); 423 + assert!(expanded.contains("repo:io.atcr.manifest?action=delete")); 424 + } 425 + 426 + #[test] 427 + fn test_build_expanded_scopes_rpc() { 428 + let permissions = vec![PermissionEntry { 429 + resource: "rpc".to_string(), 430 + action: None, 431 + collection: None, 432 + lxm: Some(vec![ 433 + "io.atcr.getManifest".to_string(), 434 + "com.atproto.repo.getRecord".to_string(), 435 + ]), 436 + aud: Some("*".to_string()), 437 + }]; 438 + 439 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 440 + assert!(expanded.contains("rpc:io.atcr.getManifest?aud=*")); 441 + assert!(expanded.contains("rpc:com.atproto.repo.getRecord?aud=*")); 442 + } 443 + 444 + #[test] 445 + fn test_build_expanded_scopes_rpc_with_default_aud() { 446 + let permissions = vec![PermissionEntry { 447 + resource: "rpc".to_string(), 448 + action: None, 449 + collection: None, 450 + lxm: Some(vec!["io.atcr.getManifest".to_string()]), 451 + aud: None, 452 + }]; 453 + 454 + let expanded = 455 + build_expanded_scopes(&permissions, Some("did:web:api.example.com"), "io.atcr"); 456 + assert!(expanded.contains("rpc:io.atcr.getManifest?aud=did:web:api.example.com")); 457 + } 458 + 459 + #[test] 460 + fn test_build_expanded_scopes_rpc_no_aud() { 461 + let permissions = vec![PermissionEntry { 462 + resource: "rpc".to_string(), 463 + action: None, 464 + collection: None, 465 + lxm: Some(vec!["io.atcr.getManifest".to_string()]), 466 + aud: None, 467 + }]; 468 + 469 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 470 + assert_eq!(expanded, "rpc:io.atcr.getManifest"); 471 + } 472 + 473 + #[test] 474 + fn test_build_expanded_scopes_mixed_permissions() { 475 + let permissions = vec![ 476 + PermissionEntry { 477 + resource: "repo".to_string(), 478 + action: Some(vec!["create".to_string()]), 479 + collection: Some(vec!["io.atcr.manifest".to_string()]), 480 + lxm: None, 481 + aud: None, 482 + }, 483 + PermissionEntry { 484 + resource: "rpc".to_string(), 485 + action: None, 486 + collection: None, 487 + lxm: Some(vec!["com.atproto.repo.getRecord".to_string()]), 488 + aud: Some("*".to_string()), 489 + }, 490 + ]; 491 + 492 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 493 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 494 + assert!(expanded.contains("rpc:com.atproto.repo.getRecord?aud=*")); 495 + } 496 + 497 + #[test] 498 + fn test_build_expanded_scopes_unknown_resource_ignored() { 499 + let permissions = vec![PermissionEntry { 500 + resource: "unknown".to_string(), 501 + action: None, 502 + collection: Some(vec!["io.atcr.manifest".to_string()]), 503 + lxm: None, 504 + aud: None, 505 + }]; 506 + 507 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 508 + assert!(expanded.is_empty()); 509 + } 510 + 511 + #[test] 512 + fn test_build_expanded_scopes_empty_permissions() { 513 + let permissions: Vec<PermissionEntry> = vec![]; 514 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 515 + assert!(expanded.is_empty()); 516 + } 517 + 518 + #[tokio::test] 519 + async fn test_expand_include_scopes_passthrough_non_include() { 520 + let result = expand_include_scopes("atproto transition:generic").await; 521 + assert_eq!(result, "atproto transition:generic"); 522 + } 523 + 524 + #[tokio::test] 525 + async fn test_expand_include_scopes_mixed_with_regular() { 526 + let result = expand_include_scopes("atproto repo:app.bsky.feed.post?action=create").await; 527 + assert!(result.contains("atproto")); 528 + assert!(result.contains("repo:app.bsky.feed.post?action=create")); 529 + } 530 + 531 + #[tokio::test] 532 + async fn test_cache_population_and_retrieval() { 533 + let cache_key = "test.cached.scope"; 534 + let cached_value = "repo:test.cached.collection?action=create"; 535 + 536 + { 537 + let mut cache = LEXICON_CACHE.write().await; 538 + cache.insert( 539 + cache_key.to_string(), 540 + CachedLexicon { 541 + expanded_scope: cached_value.to_string(), 542 + cached_at: std::time::Instant::now(), 543 + }, 544 + ); 545 + } 546 + 547 + let result = expand_permission_set(cache_key, None).await; 548 + assert!(result.is_ok()); 549 + assert_eq!(result.unwrap(), cached_value); 550 + 551 + { 552 + let mut cache = LEXICON_CACHE.write().await; 553 + cache.remove(cache_key); 554 + } 555 + } 556 + 557 + #[tokio::test] 558 + async fn test_cache_with_aud_parameter() { 559 + let nsid = "test.aud.scope"; 560 + let aud = "did:web:example.com"; 561 + let cache_key = format!("{}?aud={}", nsid, aud); 562 + let cached_value = "rpc:test.aud.method?aud=did:web:example.com"; 563 + 564 + { 565 + let mut cache = LEXICON_CACHE.write().await; 566 + cache.insert( 567 + cache_key.clone(), 568 + CachedLexicon { 569 + expanded_scope: cached_value.to_string(), 570 + cached_at: std::time::Instant::now(), 571 + }, 572 + ); 573 + } 574 + 575 + let result = expand_permission_set(nsid, Some(aud)).await; 576 + assert!(result.is_ok()); 577 + assert_eq!(result.unwrap(), cached_value); 578 + 579 + { 580 + let mut cache = LEXICON_CACHE.write().await; 581 + cache.remove(&cache_key); 582 + } 583 + } 584 + 585 + #[tokio::test] 586 + async fn test_expired_cache_triggers_refresh() { 587 + let cache_key = "test.expired.scope"; 588 + 589 + { 590 + let mut cache = LEXICON_CACHE.write().await; 591 + cache.insert( 592 + cache_key.to_string(), 593 + CachedLexicon { 594 + expanded_scope: "old_value".to_string(), 595 + cached_at: std::time::Instant::now() 596 + - std::time::Duration::from_secs(CACHE_TTL_SECS + 1), 597 + }, 598 + ); 599 + } 600 + 601 + let result = expand_permission_set(cache_key, None).await; 602 + assert!(result.is_err()); 603 + 604 + { 605 + let mut cache = LEXICON_CACHE.write().await; 606 + cache.remove(cache_key); 607 + } 608 + } 609 + 610 + #[test] 611 + fn test_nsid_authority_extraction_for_dns() { 161 612 let nsid = "io.atcr.authFullApp"; 162 613 let parts: Vec<&str> = nsid.split('.').collect(); 163 - let domain_parts: Vec<&str> = parts[..2].iter().rev().cloned().collect(); 164 - let domain = domain_parts.join("."); 165 - let path = parts[2..].join("/"); 614 + let authority = parts[..2] 615 + .iter() 616 + .rev() 617 + .cloned() 618 + .collect::<Vec<_>>() 619 + .join("."); 620 + assert_eq!(authority, "atcr.io"); 166 621 167 - assert_eq!(domain, "atcr.io"); 168 - assert_eq!(path, "authFullApp"); 622 + let nsid2 = "app.bsky.feed.post"; 623 + let parts2: Vec<&str> = nsid2.split('.').collect(); 624 + let authority2 = parts2[..2] 625 + .iter() 626 + .rev() 627 + .cloned() 628 + .collect::<Vec<_>>() 629 + .join("."); 630 + assert_eq!(authority2, "bsky.app"); 169 631 } 170 632 }
+38 -3
crates/tranquil-scopes/src/permissions.rs
··· 126 126 return Ok(()); 127 127 } 128 128 129 - let has_permission = self.find_repo_scopes().any(|repo_scope| { 129 + let has_repo_permission = self.find_repo_scopes().any(|repo_scope| { 130 130 repo_scope.actions.contains(&action) 131 131 && match &repo_scope.collection { 132 132 None => true, ··· 140 140 } 141 141 }); 142 142 143 - if has_permission { 143 + if has_repo_permission { 144 144 Ok(()) 145 145 } else { 146 146 Err(ScopeError::InsufficientScope { ··· 181 181 return Ok(()); 182 182 } 183 183 184 + let aud_base = aud.split('#').next().unwrap_or(aud); 185 + 184 186 let has_permission = self.find_rpc_scopes().any(|rpc_scope| { 185 187 let lxm_matches = match &rpc_scope.lxm { 186 188 None => true, ··· 195 197 let aud_matches = match &rpc_scope.aud { 196 198 None => true, 197 199 Some(scope_aud) if scope_aud == "*" => true, 198 - Some(scope_aud) => scope_aud == aud, 200 + Some(scope_aud) => { 201 + let scope_aud_base = scope_aud.split('#').next().unwrap_or(scope_aud); 202 + scope_aud_base == aud_base 203 + } 199 204 }; 200 205 201 206 lxm_matches && aud_matches ··· 520 525 assert!(perms.allows_repo(RepoAction::Update, "any.collection")); 521 526 assert!(perms.allows_blob("image/png")); 522 527 assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 528 + } 529 + 530 + #[test] 531 + fn test_rpc_scope_with_did_fragment() { 532 + let perms = ScopePermissions::from_scope_string(Some( 533 + "rpc:app.bsky.feed.getAuthorFeed?aud=did:web:api.bsky.app#bsky_appview", 534 + )); 535 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 536 + assert!(perms.allows_rpc( 537 + "did:web:api.bsky.app#bsky_appview", 538 + "app.bsky.feed.getAuthorFeed" 539 + )); 540 + assert!(perms.allows_rpc( 541 + "did:web:api.bsky.app#other_service", 542 + "app.bsky.feed.getAuthorFeed" 543 + )); 544 + assert!(!perms.allows_rpc("did:web:other.app", "app.bsky.feed.getAuthorFeed")); 545 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 546 + } 547 + 548 + #[test] 549 + fn test_rpc_scope_without_fragment_matches_with_fragment() { 550 + let perms = ScopePermissions::from_scope_string(Some( 551 + "rpc:app.bsky.feed.getAuthorFeed?aud=did:web:api.bsky.app", 552 + )); 553 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 554 + assert!(perms.allows_rpc( 555 + "did:web:api.bsky.app#bsky_appview", 556 + "app.bsky.feed.getAuthorFeed" 557 + )); 523 558 } 524 559 }
+24 -16
crates/tranquil-storage/src/lib.rs
··· 22 22 const CID_SHARD_PREFIX_LEN: usize = 9; 23 23 24 24 fn split_cid_path(key: &str) -> Option<(&str, &str)> { 25 - let is_cid = key.get(..3).map_or(false, |p| p.eq_ignore_ascii_case("baf")); 26 - (key.len() > CID_SHARD_PREFIX_LEN && is_cid) 27 - .then(|| key.split_at(CID_SHARD_PREFIX_LEN)) 25 + let is_cid = key.get(..3).is_some_and(|p| p.eq_ignore_ascii_case("baf")); 26 + (key.len() > CID_SHARD_PREFIX_LEN && is_cid).then(|| key.split_at(CID_SHARD_PREFIX_LEN)) 28 27 } 29 28 30 29 fn validate_key(key: &str) -> Result<(), StorageError> { ··· 771 770 let cid = "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"; 772 771 assert_eq!( 773 772 split_cid_path(cid), 774 - Some(("bafkreihd", "wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku")) 773 + Some(( 774 + "bafkreihd", 775 + "wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 776 + )) 775 777 ); 776 778 } 777 779 ··· 780 782 let cid = "bafyreigdmqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje"; 781 783 assert_eq!( 782 784 split_cid_path(cid), 783 - Some(("bafyreigd", "mqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje")) 785 + Some(( 786 + "bafyreigd", 787 + "mqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje" 788 + )) 784 789 ); 785 790 } 786 791 ··· 810 815 let mixed = "BaFkReIhDwDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu"; 811 816 assert_eq!( 812 817 split_cid_path(upper), 813 - Some(("BAFKREIHD", "WDCEFGH4DQKJV67UZCMW7OJEE6XEDZDETOJUZJEVTENXQUVYKU")) 818 + Some(( 819 + "BAFKREIHD", 820 + "WDCEFGH4DQKJV67UZCMW7OJEE6XEDZDETOJUZJEVTENXQUVYKU" 821 + )) 814 822 ); 815 823 assert_eq!( 816 824 split_cid_path(mixed), 817 - Some(("BaFkReIhD", "wDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu")) 825 + Some(( 826 + "BaFkReIhD", 827 + "wDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu" 828 + )) 818 829 ); 819 830 } 820 831 ··· 829 840 let base = PathBuf::from("/blobs"); 830 841 let cid = "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"; 831 842 832 - let expected = PathBuf::from("/blobs/bafkreihd/wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"); 833 - let result = split_cid_path(cid).map_or_else( 834 - || base.join(cid), 835 - |(dir, file)| base.join(dir).join(file), 836 - ); 843 + let expected = 844 + PathBuf::from("/blobs/bafkreihd/wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"); 845 + let result = split_cid_path(cid) 846 + .map_or_else(|| base.join(cid), |(dir, file)| base.join(dir).join(file)); 837 847 assert_eq!(result, expected); 838 848 } 839 849 ··· 843 853 let key = "temp/abc123"; 844 854 845 855 let expected = PathBuf::from("/blobs/temp/abc123"); 846 - let result = split_cid_path(key).map_or_else( 847 - || base.join(key), 848 - |(dir, file)| base.join(dir).join(file), 849 - ); 856 + let result = split_cid_path(key) 857 + .map_or_else(|| base.join(key), |(dir, file)| base.join(dir).join(file)); 850 858 assert_eq!(result, expected); 851 859 } 852 860 }
+100 -35
frontend/src/lib/api.ts
··· 16 16 unsafeAsISODate, 17 17 unsafeAsRefreshToken, 18 18 } from "./types/branded.ts"; 19 + import { 20 + createDPoPProofForRequest, 21 + getDPoPNonce, 22 + setDPoPNonce, 23 + } from "./oauth.ts"; 19 24 import type { 20 25 AccountInfo, 21 26 ApiErrorCode, ··· 91 96 } 92 97 } 93 98 94 - let tokenRefreshCallback: (() => Promise<string | null>) | null = null; 99 + let tokenRefreshCallback: (() => Promise<AccessToken | null>) | null = null; 95 100 96 101 export function setTokenRefreshCallback( 97 - callback: () => Promise<string | null>, 102 + callback: () => Promise<AccessToken | null>, 98 103 ) { 99 104 tokenRefreshCallback = callback; 100 105 } 101 106 107 + interface AuthenticatedFetchOptions { 108 + method?: "GET" | "POST"; 109 + token: AccessToken | RefreshToken; 110 + headers?: Record<string, string>; 111 + body?: BodyInit; 112 + } 113 + 114 + async function authenticatedFetch( 115 + url: string, 116 + options: AuthenticatedFetchOptions, 117 + ): Promise<Response> { 118 + const { method = "GET", token, headers = {}, body } = options; 119 + const fullUrl = url.startsWith("http") 120 + ? url 121 + : `${globalThis.location.origin}${url}`; 122 + const dpopProof = await createDPoPProofForRequest(method, fullUrl, token); 123 + const res = await fetch(url, { 124 + method, 125 + headers: { 126 + ...headers, 127 + Authorization: `DPoP ${token}`, 128 + DPoP: dpopProof, 129 + }, 130 + body, 131 + }); 132 + const dpopNonce = res.headers.get("DPoP-Nonce"); 133 + if (dpopNonce) { 134 + setDPoPNonce(dpopNonce); 135 + } 136 + return res; 137 + } 138 + 102 139 interface XrpcOptions { 103 140 method?: "GET" | "POST"; 104 141 params?: Record<string, string>; 105 142 body?: unknown; 106 - token?: string; 143 + token?: AccessToken | RefreshToken; 107 144 skipRetry?: boolean; 145 + skipDpopRetry?: boolean; 108 146 } 109 147 110 148 async function xrpc<T>(method: string, options?: XrpcOptions): Promise<T> { 111 - const { method: httpMethod = "GET", params, body, token, skipRetry } = 112 - options ?? {}; 149 + const { 150 + method: httpMethod = "GET", 151 + params, 152 + body, 153 + token, 154 + skipRetry, 155 + skipDpopRetry, 156 + } = options ?? {}; 113 157 let url = `${API_BASE}/${method}`; 114 158 if (params) { 115 159 const searchParams = new URLSearchParams(params); 116 160 url += `?${searchParams}`; 117 161 } 118 162 const headers: Record<string, string> = {}; 119 - if (token) { 120 - headers["Authorization"] = `Bearer ${token}`; 121 - } 122 163 if (body) { 123 164 headers["Content-Type"] = "application/json"; 124 165 } 125 - const res = await fetch(url, { 126 - method: httpMethod, 127 - headers, 128 - body: body ? JSON.stringify(body) : undefined, 129 - }); 166 + const res = token 167 + ? await authenticatedFetch(url, { 168 + method: httpMethod, 169 + token, 170 + headers, 171 + body: body ? JSON.stringify(body) : undefined, 172 + }) 173 + : await fetch(url, { 174 + method: httpMethod, 175 + headers, 176 + body: body ? JSON.stringify(body) : undefined, 177 + }); 130 178 if (!res.ok) { 131 179 const errData = await res.json().catch(() => ({ 132 180 error: "Unknown", ··· 134 182 })); 135 183 if ( 136 184 res.status === 401 && 185 + errData.error === "use_dpop_nonce" && 186 + token && 187 + !skipDpopRetry && 188 + getDPoPNonce() 189 + ) { 190 + return xrpc(method, { ...options, skipDpopRetry: true }); 191 + } 192 + if ( 193 + res.status === 401 && 137 194 (errData.error === "AuthenticationFailed" || 138 - errData.error === "ExpiredToken") && 139 - token && tokenRefreshCallback && !skipRetry 195 + errData.error === "ExpiredToken" || 196 + errData.error === "OAuthExpiredToken") && 197 + token && 198 + tokenRefreshCallback && 199 + !skipRetry 140 200 ) { 141 201 const newToken = await tokenRefreshCallback(); 142 202 if (newToken && newToken !== token) { ··· 536 596 token: AccessToken, 537 597 file: File, 538 598 ): Promise<UploadBlobResponse> { 539 - const res = await fetch("/xrpc/com.atproto.repo.uploadBlob", { 599 + const res = await authenticatedFetch("/xrpc/com.atproto.repo.uploadBlob", { 540 600 method: "POST", 541 - headers: { 542 - "Authorization": `Bearer ${token}`, 543 - "Content-Type": file.type, 544 - }, 601 + token, 602 + headers: { "Content-Type": file.type }, 545 603 body: file, 546 604 }); 547 605 if (!res.ok) { ··· 1084 1142 }, 1085 1143 1086 1144 async getRepo(token: AccessToken, did: Did): Promise<ArrayBuffer> { 1087 - const url = `${API_BASE}/com.atproto.sync.getRepo?did=${ 1088 - encodeURIComponent(did) 1089 - }`; 1090 - const res = await fetch(url, { 1091 - headers: { Authorization: `Bearer ${token}` }, 1092 - }); 1145 + const url = `${API_BASE}/com.atproto.sync.getRepo?did=${encodeURIComponent(did)}`; 1146 + const res = await authenticatedFetch(url, { token }); 1093 1147 if (!res.ok) { 1094 1148 const errData = await res.json().catch(() => ({ 1095 1149 error: "Unknown", ··· 1106 1160 1107 1161 async getBackup(token: AccessToken, id: string): Promise<Blob> { 1108 1162 const url = `${API_BASE}/_backup.getBackup?id=${encodeURIComponent(id)}`; 1109 - const res = await fetch(url, { 1110 - headers: { Authorization: `Bearer ${token}` }, 1111 - }); 1163 + const res = await authenticatedFetch(url, { token }); 1112 1164 if (!res.ok) { 1113 1165 const errData = await res.json().catch(() => ({ 1114 1166 error: "Unknown", ··· 1146 1198 }, 1147 1199 1148 1200 async importRepo(token: AccessToken, car: Uint8Array): Promise<void> { 1149 - const url = `${API_BASE}/com.atproto.repo.importRepo`; 1150 - const res = await fetch(url, { 1201 + const res = await authenticatedFetch(`${API_BASE}/com.atproto.repo.importRepo`, { 1151 1202 method: "POST", 1152 - headers: { 1153 - Authorization: `Bearer ${token}`, 1154 - "Content-Type": "application/vnd.ipld.car", 1155 - }, 1203 + token, 1204 + headers: { "Content-Type": "application/vnd.ipld.car" }, 1156 1205 body: car as unknown as BodyInit, 1157 1206 }); 1158 1207 if (!res.ok) { ··· 1162 1211 })); 1163 1212 throw new ApiError(res.status, errData.error, errData.message); 1164 1213 } 1214 + }, 1215 + 1216 + async establishOAuthSession(token: AccessToken): Promise<{ success: boolean; device_id: string }> { 1217 + const res = await authenticatedFetch("/oauth/establish-session", { 1218 + method: "POST", 1219 + token, 1220 + headers: { "Content-Type": "application/json" }, 1221 + }); 1222 + if (!res.ok) { 1223 + const errData = await res.json().catch(() => ({ 1224 + error: "Unknown", 1225 + message: res.statusText, 1226 + })); 1227 + throw new ApiError(res.status, errData.error, errData.message); 1228 + } 1229 + return res.json(); 1165 1230 }, 1166 1231 }; 1167 1232
+1 -1
frontend/src/lib/auth.svelte.ts
··· 281 281 } 282 282 } 283 283 284 - async function tryRefreshToken(): Promise<string | null> { 284 + async function tryRefreshToken(): Promise<AccessToken | null> { 285 285 if (state.current.kind !== "authenticated") return null; 286 286 const currentSession = state.current.session; 287 287 try {
+18 -1
frontend/src/lib/migration/atproto-client.ts
··· 240 240 }&cid=${encodeURIComponent(cid)}`; 241 241 const headers: Record<string, string> = {}; 242 242 if (this.accessToken) { 243 - headers["Authorization"] = `Bearer ${this.accessToken}`; 243 + if (this.dpopKeyPair) { 244 + headers["Authorization"] = `DPoP ${this.accessToken}`; 245 + const tokenHash = await computeAccessTokenHash(this.accessToken); 246 + const dpopProof = await createDPoPProof( 247 + this.dpopKeyPair, 248 + "GET", 249 + url.split("?")[0], 250 + this.dpopNonce ?? undefined, 251 + tokenHash, 252 + ); 253 + headers["DPoP"] = dpopProof; 254 + } else { 255 + headers["Authorization"] = `Bearer ${this.accessToken}`; 256 + } 244 257 } 245 258 const res = await fetch(url, { headers }); 259 + const newNonce = res.headers.get("DPoP-Nonce"); 260 + if (newNonce) { 261 + this.dpopNonce = newNonce; 262 + } 246 263 if (!res.ok) { 247 264 const err = await res.json().catch(() => ({ 248 265 error: "Unknown",
+3 -1
frontend/src/lib/migration/flow.svelte.ts
··· 88 88 89 89 function setStep(step: InboundStep) { 90 90 state.step = step; 91 - state.error = null; 91 + if (step !== "error") { 92 + state.error = null; 93 + } 92 94 if (step !== "success") { 93 95 saveMigrationState(state); 94 96 updateStep(step);
+3 -1
frontend/src/lib/migration/offline-flow.svelte.ts
··· 177 177 178 178 function setStep(step: OfflineInboundStep) { 179 179 state.step = step; 180 - state.error = null; 180 + if (step !== "error") { 181 + state.error = null; 182 + } 181 183 if (step !== "success") { 182 184 saveOfflineState(state); 183 185 }
+3 -3
frontend/src/lib/oauth.ts
··· 246 246 return base64UrlEncode(hash); 247 247 } 248 248 249 - function getDPoPNonce(): string | null { 249 + export function getDPoPNonce(): string | null { 250 250 return sessionStorage.getItem(DPOP_NONCE_KEY); 251 251 } 252 252 253 - function setDPoPNonce(nonce: string): void { 253 + export function setDPoPNonce(nonce: string): void { 254 254 sessionStorage.setItem(DPOP_NONCE_KEY, nonce); 255 255 } 256 256 257 - function extractDPoPNonceFromResponse(response: Response): void { 257 + export function extractDPoPNonceFromResponse(response: Response): void { 258 258 const nonce = response.headers.get("DPoP-Nonce"); 259 259 if (nonce) { 260 260 setDPoPNonce(nonce);
+5
frontend/src/locales/en.json
··· 779 779 "name": "Manage Account", 780 780 "description": "Manage account settings and preferences" 781 781 } 782 + }, 783 + "unexpectedState": { 784 + "title": "Unexpected State", 785 + "description": "The consent page is in an unexpected state. Please check the browser console for errors.", 786 + "reload": "Reload Page" 782 787 } 783 788 }, 784 789 "accounts": {
+5
frontend/src/locales/fi.json
··· 785 785 "name": "Hallitse tiliรค", 786 786 "description": "Hallitse tilin asetuksia ja asetuksia" 787 787 } 788 + }, 789 + "unexpectedState": { 790 + "title": "Odottamaton tila", 791 + "description": "Suostumussivulla on odottamaton tila. Tarkista selaimen konsoli virheiden varalta.", 792 + "reload": "Lataa sivu uudelleen" 788 793 } 789 794 }, 790 795 "accounts": {
+5
frontend/src/locales/ja.json
··· 778 778 "name": "ใ‚ขใ‚ซใ‚ฆใƒณใƒˆ็ฎก็†", 779 779 "description": "ใ‚ขใ‚ซใ‚ฆใƒณใƒˆ่จญๅฎšใจ่จญๅฎšใ‚’็ฎก็†" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "ไบˆๆœŸใ—ใชใ„็Šถๆ…‹", 784 + "description": "ๅŒๆ„ใƒšใƒผใ‚ธใŒไบˆๆœŸใ—ใชใ„็Šถๆ…‹ใงใ™ใ€‚ใƒ–ใƒฉใ‚ฆใ‚ถใฎใ‚ณใƒณใ‚ฝใƒผใƒซใงใ‚จใƒฉใƒผใ‚’็ขบ่ชใ—ใฆใใ ใ•ใ„ใ€‚", 785 + "reload": "ใƒšใƒผใ‚ธใ‚’ๅ†่ชญใฟ่พผใฟ" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/ko.json
··· 778 778 "name": "๊ณ„์ • ๊ด€๋ฆฌ", 779 779 "description": "๊ณ„์ • ์„ค์ • ๋ฐ ํ™˜๊ฒฝ์„ค์ • ๊ด€๋ฆฌ" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "์˜ˆ๊ธฐ์น˜ ์•Š์€ ์ƒํƒœ", 784 + "description": "๋™์˜ ํŽ˜์ด์ง€๊ฐ€ ์˜ˆ๊ธฐ์น˜ ์•Š์€ ์ƒํƒœ์ž…๋‹ˆ๋‹ค. ๋ธŒ๋ผ์šฐ์ € ์ฝ˜์†”์—์„œ ์˜ค๋ฅ˜๋ฅผ ํ™•์ธํ•˜์„ธ์š”.", 785 + "reload": "ํŽ˜์ด์ง€ ์ƒˆ๋กœ๊ณ ์นจ" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/sv.json
··· 778 778 "name": "Hantera konto", 779 779 "description": "Hantera kontoinstรคllningar och preferenser" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "Ovรคntat tillstรฅnd", 784 + "description": "Samtyckes-sidan รคr i ett ovรคntat tillstรฅnd. Kontrollera webblรคsarens konsol fรถr fel.", 785 + "reload": "Ladda om sidan" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/zh.json
··· 778 778 "name": "็ฎก็†่ดฆๆˆท", 779 779 "description": "็ฎก็†่ดฆๆˆท่ฎพ็ฝฎๅ’Œๅๅฅฝ" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "ๆ„ๅค–็Šถๆ€", 784 + "description": "ๅŒๆ„้กต้ขๅค„ไบŽๆ„ๅค–็Šถๆ€ใ€‚่ฏทๆฃ€ๆŸฅๆต่งˆๅ™จๆŽงๅˆถๅฐไปฅๆŸฅ็œ‹้”™่ฏฏใ€‚", 785 + "reload": "้‡ๆ–ฐๅŠ ่ฝฝ้กต้ข" 781 786 } 782 787 }, 783 788 "accounts": {
+37 -16
frontend/src/routes/Migration.svelte
··· 2 2 import { setSession } from '../lib/auth.svelte' 3 3 import { navigate, routes } from '../lib/router.svelte' 4 4 import { _ } from '../lib/i18n' 5 + import { api } from '../lib/api' 6 + import { startOAuthLogin } from '../lib/oauth' 7 + import { unsafeAsAccessToken } from '../lib/types/branded' 5 8 import { 6 9 createInboundMigrationFlow, 7 10 createOfflineInboundMigrationFlow, ··· 143 146 direction = 'select' 144 147 } 145 148 146 - function handleInboundComplete() { 149 + async function handleInboundComplete() { 147 150 const session = inboundFlow?.getLocalSession() 148 151 if (session) { 149 - setSession({ 150 - did: session.did, 151 - handle: session.handle, 152 - accessJwt: session.accessJwt, 153 - refreshJwt: '', 154 - }) 152 + try { 153 + await api.establishOAuthSession(unsafeAsAccessToken(session.accessJwt)) 154 + clearMigrationState() 155 + await startOAuthLogin(session.handle) 156 + } catch (e) { 157 + console.error('Failed to establish OAuth session, falling back to direct login:', e) 158 + setSession({ 159 + did: session.did, 160 + handle: session.handle, 161 + accessJwt: session.accessJwt, 162 + refreshJwt: '', 163 + }) 164 + navigate(routes.dashboard) 165 + } 166 + } else { 167 + navigate(routes.dashboard) 155 168 } 156 - navigate(routes.dashboard) 157 169 } 158 170 159 - function handleOfflineComplete() { 171 + async function handleOfflineComplete() { 160 172 const session = offlineFlow?.getLocalSession() 161 173 if (session) { 162 - setSession({ 163 - did: session.did, 164 - handle: session.handle, 165 - accessJwt: session.accessJwt, 166 - refreshJwt: '', 167 - }) 174 + try { 175 + await api.establishOAuthSession(unsafeAsAccessToken(session.accessJwt)) 176 + clearOfflineState() 177 + await startOAuthLogin(session.handle) 178 + } catch (e) { 179 + console.error('Failed to establish OAuth session, falling back to direct login:', e) 180 + setSession({ 181 + did: session.did, 182 + handle: session.handle, 183 + accessJwt: session.accessJwt, 184 + refreshJwt: '', 185 + }) 186 + navigate(routes.dashboard) 187 + } 188 + } else { 189 + navigate(routes.dashboard) 168 190 } 169 - navigate(routes.dashboard) 170 191 } 171 192 </script> 172 193
+5 -31
frontend/src/routes/OAuthAccounts.svelte
··· 196 196 display: flex; 197 197 align-items: center; 198 198 padding: var(--space-4); 199 - background: var(--bg-card); 199 + background: var(--bg-secondary); 200 200 border: 1px solid var(--border-color); 201 201 border-radius: var(--radius-xl); 202 202 cursor: pointer; 203 203 text-align: left; 204 204 width: 100%; 205 - transition: border-color var(--transition-fast), box-shadow var(--transition-fast); 205 + transition: border-color var(--transition-fast), background var(--transition-fast); 206 206 } 207 207 208 208 .account-item:hover:not(.disabled) { 209 209 border-color: var(--accent); 210 - box-shadow: var(--shadow-sm); 210 + background: var(--bg-tertiary); 211 211 } 212 212 213 213 .account-item.disabled { ··· 231 231 color: var(--text-secondary); 232 232 } 233 233 234 - button { 235 - padding: var(--space-3); 236 - background: var(--accent); 237 - color: var(--text-inverse); 238 - border: none; 239 - border-radius: var(--radius-md); 240 - font-size: var(--text-base); 241 - cursor: pointer; 242 - } 243 - 244 - button:hover:not(:disabled) { 245 - background: var(--accent-hover); 246 - } 247 - 248 - button:disabled { 249 - opacity: 0.6; 250 - cursor: not-allowed; 251 - } 252 - 253 - button.secondary { 254 - background: transparent; 255 - color: var(--accent); 256 - border: 1px solid var(--accent); 234 + .different-account { 235 + margin-top: var(--space-4); 257 236 width: 100%; 258 - } 259 - 260 - button.secondary:hover:not(:disabled) { 261 - background: var(--accent); 262 - color: var(--text-inverse); 263 237 } 264 238 265 239 .different-account {
+38 -3
frontend/src/routes/OAuthConsent.svelte
··· 65 65 async function fetchConsentData() { 66 66 const requestUri = getRequestUri() 67 67 if (!requestUri) { 68 + console.error('[OAuthConsent] No request_uri in URL') 68 69 error = $_('oauth.error.genericError') 69 70 loading = false 70 71 return ··· 74 75 const response = await fetch(`/oauth/authorize/consent?request_uri=${encodeURIComponent(requestUri)}`) 75 76 if (!response.ok) { 76 77 const data = await response.json() 78 + console.error('[OAuthConsent] Consent fetch failed:', data) 77 79 error = data.error_description || data.error || $_('oauth.error.genericError') 78 80 loading = false 79 81 return 80 82 } 81 83 const data: ConsentData = await response.json() 84 + 85 + if (!data.scopes || !Array.isArray(data.scopes)) { 86 + console.error('[OAuthConsent] Invalid scopes data:', data.scopes) 87 + error = 'Invalid consent data received' 88 + loading = false 89 + return 90 + } 91 + 82 92 consentData = data 83 93 84 94 scopeSelections = Object.fromEntries( ··· 91 101 if (!data.show_consent) { 92 102 await submitConsent() 93 103 } 94 - } catch { 104 + } catch (e) { 105 + console.error('[OAuthConsent] Error during consent fetch:', e) 95 106 error = $_('oauth.error.genericError') 96 107 } finally { 97 108 loading = false ··· 104 115 } 105 116 106 117 async function submitConsent() { 107 - if (!consentData) return 118 + if (!consentData) { 119 + console.error('[OAuthConsent] submitConsent called but no consentData') 120 + return 121 + } 108 122 109 123 submitting = true 110 124 let approvedScopes = Object.entries(scopeSelections) ··· 128 142 129 143 if (!response.ok) { 130 144 const data = await response.json() 145 + console.error('[OAuthConsent] Submit failed:', data) 131 146 error = data.error_description || data.error || $_('oauth.error.genericError') 132 147 submitting = false 133 148 return ··· 136 151 const data = await response.json() 137 152 if (data.redirect_uri) { 138 153 window.location.href = data.redirect_uri 154 + } else { 155 + console.error('[OAuthConsent] No redirect_uri in response') 156 + error = 'Authorization failed - no redirect received' 157 + submitting = false 139 158 } 140 - } catch { 159 + } catch (e) { 160 + console.error('[OAuthConsent] Submit error:', e) 141 161 error = $_('oauth.error.genericError') 142 162 submitting = false 143 163 } ··· 249 269 <div class="spinner"></div> 250 270 <p>{$_('common.loading')}</p> 251 271 </div> 272 + {:else} 273 + <p style="color: var(--text-muted); font-size: 0.875rem;">Loading consent data...</p> 252 274 {/if} 253 275 </div> 254 276 {:else if error} ··· 370 392 </button> 371 393 <button type="button" class="approve-btn" onclick={submitConsent} disabled={submitting}> 372 394 {submitting ? $_('oauth.consent.authorizing') : $_('oauth.consent.authorize')} 395 + </button> 396 + </div> 397 + {:else} 398 + <div class="error-container"> 399 + <h1>{$_('oauth.consent.unexpectedState.title')}</h1> 400 + <p style="color: var(--text-secondary);"> 401 + {$_('oauth.consent.unexpectedState.description')} 402 + </p> 403 + <p style="color: var(--text-muted); font-size: 0.75rem; font-family: monospace;"> 404 + loading={loading}, error={error ? 'set' : 'null'}, consentData={consentData ? 'set' : 'null'}, submitting={submitting} 405 + </p> 406 + <button type="button" onclick={() => window.location.reload()}> 407 + {$_('oauth.consent.unexpectedState.reload')} 373 408 </button> 374 409 </div> 375 410 {/if}

History

6 rounds 1 comment
sign up or login to add to the discussion
3 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
fix: match ref pds permission-levels for some endpoints
expand 0 comments
pull request successfully merged
3 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
fix: match ref pds permission-levels for some endpoints
expand 0 comments
2 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
expand 0 comments
2 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
expand 1 comment

so three things:

crates/tranquil-pds/src/auth/auth_extractor.rs does not seem to be used at all anywhere?

crates/tranquil-pds/src/api/temp.rs feels like it just ... shouldnt excist based on the name

and i still dont really like these extractors. the separation of inter service auth is weird to me. inter-service auth is a form of user auth. it shouldnt be separated out from the other types. the AuthExtractor should just be AuthExtractor(pub AuthenticatedUser).

i also discovered https://docs.rs/axum/0.8.8/axum/extract/trait.OptionalFromRequestParts.html which we should be able to reduce optional vs not optional with just an AuthExtractor vs Option.

principly id want whether or not its required that the account is active or not and whether its an admin account or not to both also be type safe configurations on the extractor. probably with generics of some sort. but i cant think of a specific design i like right now so. if you come up with one feel free to do it. otherwise we can do it later

1 commit
expand
fix: oauth consolidation, include-scope improvements
expand 0 comments
lewis.moe submitted #0
1 commit
expand
fix: oauth consolidation, include-scope improvements
expand 0 comments