Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

fix: oauth consolidation, include-scope improvements #4

merged opened by lewis.moe targeting main from fix/oauth-on-niche-apps
  • auth extraction should be happening in the auth crate, yes, who coulda thought
  • include: scope should actually be doing the right thing and going out and requesting stuff to expand out the perms
  • more tests!!!1!
  • more correct parsing of the #bsky-appview or whatever suffixes on did webs that come through auth
Labels

None yet.

assignee
Participants 2
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3md3bluniqt22
+5198 -3437
Diff #5
-77
.sqlx/query-06eb7c6e1983b6121526ba63612236391290c2e63d37d2bb1cd89ea822950a82.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n SELECT token, request_uri, provider as \"provider: SsoProviderType\",\n provider_user_id, provider_username, provider_email, created_at, expires_at\n FROM sso_pending_registration\n WHERE token = $1 AND expires_at > NOW()\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "token", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "request_uri", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "provider: SsoProviderType", 19 - "type_info": { 20 - "Custom": { 21 - "name": "sso_provider_type", 22 - "kind": { 23 - "Enum": [ 24 - "github", 25 - "discord", 26 - "google", 27 - "gitlab", 28 - "oidc" 29 - ] 30 - } 31 - } 32 - } 33 - }, 34 - { 35 - "ordinal": 3, 36 - "name": "provider_user_id", 37 - "type_info": "Text" 38 - }, 39 - { 40 - "ordinal": 4, 41 - "name": "provider_username", 42 - "type_info": "Text" 43 - }, 44 - { 45 - "ordinal": 5, 46 - "name": "provider_email", 47 - "type_info": "Text" 48 - }, 49 - { 50 - "ordinal": 6, 51 - "name": "created_at", 52 - "type_info": "Timestamptz" 53 - }, 54 - { 55 - "ordinal": 7, 56 - "name": "expires_at", 57 - "type_info": "Timestamptz" 58 - } 59 - ], 60 - "parameters": { 61 - "Left": [ 62 - "Text" 63 - ] 64 - }, 65 - "nullable": [ 66 - false, 67 - false, 68 - false, 69 - false, 70 - true, 71 - true, 72 - false, 73 - false 74 - ] 75 - }, 76 - "hash": "06eb7c6e1983b6121526ba63612236391290c2e63d37d2bb1cd89ea822950a82" 77 - }
-77
.sqlx/query-5031b96c65078d6c54954ce6e57ff9cbba4c48dd8a7546882ab5647114ffab4a.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n DELETE FROM sso_pending_registration\n WHERE token = $1 AND expires_at > NOW()\n RETURNING token, request_uri, provider as \"provider: SsoProviderType\",\n provider_user_id, provider_username, provider_email, created_at, expires_at\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "token", 9 - "type_info": "Text" 10 - }, 11 - { 12 - "ordinal": 1, 13 - "name": "request_uri", 14 - "type_info": "Text" 15 - }, 16 - { 17 - "ordinal": 2, 18 - "name": "provider: SsoProviderType", 19 - "type_info": { 20 - "Custom": { 21 - "name": "sso_provider_type", 22 - "kind": { 23 - "Enum": [ 24 - "github", 25 - "discord", 26 - "google", 27 - "gitlab", 28 - "oidc" 29 - ] 30 - } 31 - } 32 - } 33 - }, 34 - { 35 - "ordinal": 3, 36 - "name": "provider_user_id", 37 - "type_info": "Text" 38 - }, 39 - { 40 - "ordinal": 4, 41 - "name": "provider_username", 42 - "type_info": "Text" 43 - }, 44 - { 45 - "ordinal": 5, 46 - "name": "provider_email", 47 - "type_info": "Text" 48 - }, 49 - { 50 - "ordinal": 6, 51 - "name": "created_at", 52 - "type_info": "Timestamptz" 53 - }, 54 - { 55 - "ordinal": 7, 56 - "name": "expires_at", 57 - "type_info": "Timestamptz" 58 - } 59 - ], 60 - "parameters": { 61 - "Left": [ 62 - "Text" 63 - ] 64 - }, 65 - "nullable": [ 66 - false, 67 - false, 68 - false, 69 - false, 70 - true, 71 - true, 72 - false, 73 - false 74 - ] 75 - }, 76 - "hash": "5031b96c65078d6c54954ce6e57ff9cbba4c48dd8a7546882ab5647114ffab4a" 77 - }
-22
.sqlx/query-6258398accee69e0c5f455a3c0ecc273b3da6ef5bb4d8660adafe63d8e3cd2d4.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "SELECT email_verified FROM users WHERE email = $1 OR handle = $1", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "ordinal": 0, 8 - "name": "email_verified", 9 - "type_info": "Bool" 10 - } 11 - ], 12 - "parameters": { 13 - "Left": [ 14 - "Text" 15 - ] 16 - }, 17 - "nullable": [ 18 - false 19 - ] 20 - }, 21 - "hash": "6258398accee69e0c5f455a3c0ecc273b3da6ef5bb4d8660adafe63d8e3cd2d4" 22 - }
-31
.sqlx/query-a4dc8fb22bd094d414c55b9da20b610f7b122b485ab0fd0d0646d68ae8e64fe6.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO external_identities (did, provider, provider_user_id, provider_username, provider_email)\n VALUES ($1, $2, $3, $4, $5)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Text", 9 - { 10 - "Custom": { 11 - "name": "sso_provider_type", 12 - "kind": { 13 - "Enum": [ 14 - "github", 15 - "discord", 16 - "google", 17 - "gitlab", 18 - "oidc" 19 - ] 20 - } 21 - } 22 - }, 23 - "Text", 24 - "Text", 25 - "Text" 26 - ] 27 - }, 28 - "nullable": [] 29 - }, 30 - "hash": "a4dc8fb22bd094d414c55b9da20b610f7b122b485ab0fd0d0646d68ae8e64fe6" 31 - }
-32
.sqlx/query-dec3a21a8e60cc8d2c5dad727750bc88f5535dedae244f7b6e4afa95769b8f1a.json
··· 1 - { 2 - "db_name": "PostgreSQL", 3 - "query": "\n INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_username, provider_email)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Left": [ 8 - "Text", 9 - "Text", 10 - { 11 - "Custom": { 12 - "name": "sso_provider_type", 13 - "kind": { 14 - "Enum": [ 15 - "github", 16 - "discord", 17 - "google", 18 - "gitlab", 19 - "oidc" 20 - ] 21 - } 22 - } 23 - }, 24 - "Text", 25 - "Text", 26 - "Text" 27 - ] 28 - }, 29 - "nullable": [] 30 - }, 31 - "hash": "dec3a21a8e60cc8d2c5dad727750bc88f5535dedae244f7b6e4afa95769b8f1a" 32 - }
+2
Cargo.lock
··· 6156 6156 dependencies = [ 6157 6157 "axum", 6158 6158 "futures", 6159 + "hickory-resolver", 6159 6160 "reqwest", 6160 6161 "serde", 6161 6162 "serde_json", 6162 6163 "tokio", 6163 6164 "tracing", 6165 + "urlencoding", 6164 6166 ] 6165 6167 6166 6168 [[package]]
+7 -12
crates/tranquil-pds/src/api/actor/preferences.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::{Auth, NotTakendown, Permissive}; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, ··· 32 32 pub struct GetPreferencesOutput { 33 33 pub preferences: Vec<Value>, 34 34 } 35 - pub async fn get_preferences( 36 - State(state): State<AppState>, 37 - auth: BearerAuthAllowDeactivated, 38 - ) -> Response { 39 - let auth_user = auth.0; 40 - let has_full_access = auth_user.permissions().has_full_access(); 41 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 35 + pub async fn get_preferences(State(state): State<AppState>, auth: Auth<Permissive>) -> Response { 36 + let has_full_access = auth.permissions().has_full_access(); 37 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth.did).await { 42 38 Ok(Some(id)) => id, 43 39 _ => { 44 40 return ApiError::InternalError(Some("User not found".into())).into_response(); ··· 93 89 } 94 90 pub async fn put_preferences( 95 91 State(state): State<AppState>, 96 - auth: BearerAuthAllowDeactivated, 92 + auth: Auth<NotTakendown>, 97 93 Json(input): Json<PutPreferencesInput>, 98 94 ) -> Response { 99 - let auth_user = auth.0; 100 - let has_full_access = auth_user.permissions().has_full_access(); 101 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 95 + let has_full_access = auth.permissions().has_full_access(); 96 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth.did).await { 102 97 Ok(Some(id)) => id, 103 98 _ => { 104 99 return ApiError::InternalError(Some("User not found".into())).into_response();
+20 -18
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use crate::types::Did; 6 6 use axum::{ ··· 18 18 19 19 pub async fn delete_account( 20 20 State(state): State<AppState>, 21 - _auth: BearerAuthAdmin, 21 + _auth: Auth<Admin>, 22 22 Json(input): Json<DeleteAccountInput>, 23 - ) -> Response { 23 + ) -> Result<Response, ApiError> { 24 24 let did = &input.did; 25 - let (user_id, handle) = match state.user_repo.get_id_and_handle_by_did(did).await { 26 - Ok(Some(row)) => (row.id, row.handle), 27 - Ok(None) => { 28 - return ApiError::AccountNotFound.into_response(); 29 - } 30 - Err(e) => { 25 + let (user_id, handle) = state 26 + .user_repo 27 + .get_id_and_handle_by_did(did) 28 + .await 29 + .map_err(|e| { 31 30 error!("DB error in delete_account: {:?}", e); 32 - return ApiError::InternalError(None).into_response(); 33 - } 34 - }; 35 - if let Err(e) = state 31 + ApiError::InternalError(None) 32 + })? 33 + .ok_or(ApiError::AccountNotFound) 34 + .map(|row| (row.id, row.handle))?; 35 + 36 + state 36 37 .user_repo 37 38 .admin_delete_account_complete(user_id, did) 38 39 .await 39 - { 40 - error!("Failed to delete account {}: {:?}", did, e); 41 - return ApiError::InternalError(Some("Failed to delete account".into())).into_response(); 42 - } 40 + .map_err(|e| { 41 + error!("Failed to delete account {}: {:?}", did, e); 42 + ApiError::InternalError(Some("Failed to delete account".into())) 43 + })?; 44 + 43 45 if let Err(e) = 44 46 crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await 45 47 { ··· 49 51 ); 50 52 } 51 53 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 52 - EmptyResponse::ok().into_response() 54 + Ok(EmptyResponse::ok().into_response()) 53 55 }
+16 -21
crates/tranquil-pds/src/api/admin/account/email.rs
··· 1 1 use crate::api::error::{ApiError, AtpJson}; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::Did; 5 5 use axum::{ ··· 28 28 29 29 pub async fn send_email( 30 30 State(state): State<AppState>, 31 - _auth: BearerAuthAdmin, 31 + _auth: Auth<Admin>, 32 32 AtpJson(input): AtpJson<SendEmailInput>, 33 - ) -> Response { 33 + ) -> Result<Response, ApiError> { 34 34 let content = input.content.trim(); 35 35 if content.is_empty() { 36 - return ApiError::InvalidRequest("content is required".into()).into_response(); 36 + return Err(ApiError::InvalidRequest("content is required".into())); 37 37 } 38 - let user = match state.user_repo.get_by_did(&input.recipient_did).await { 39 - Ok(Some(row)) => row, 40 - Ok(None) => { 41 - return ApiError::AccountNotFound.into_response(); 42 - } 43 - Err(e) => { 38 + let user = state 39 + .user_repo 40 + .get_by_did(&input.recipient_did) 41 + .await 42 + .map_err(|e| { 44 43 error!("DB error in send_email: {:?}", e); 45 - return ApiError::InternalError(None).into_response(); 46 - } 47 - }; 48 - let email = match user.email { 49 - Some(e) => e, 50 - None => { 51 - return ApiError::NoEmail.into_response(); 52 - } 53 - }; 44 + ApiError::InternalError(None) 45 + })? 46 + .ok_or(ApiError::AccountNotFound)?; 47 + 48 + let email = user.email.ok_or(ApiError::NoEmail)?; 54 49 let (user_id, handle) = (user.id, user.handle); 55 50 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 56 51 let subject = input ··· 76 71 handle, 77 72 input.recipient_did 78 73 ); 79 - (StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response() 74 + Ok((StatusCode::OK, Json(SendEmailOutput { sent: true })).into_response()) 80 75 } 81 76 Err(e) => { 82 77 warn!("Failed to enqueue admin email: {:?}", e); 83 - (StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response() 78 + Ok((StatusCode::OK, Json(SendEmailOutput { sent: false })).into_response()) 84 79 } 85 80 } 86 81 }
+18 -24
crates/tranquil-pds/src/api/admin/account/info.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; 5 5 use axum::{ ··· 67 67 68 68 pub async fn get_account_info( 69 69 State(state): State<AppState>, 70 - _auth: BearerAuthAdmin, 70 + _auth: Auth<Admin>, 71 71 Query(params): Query<GetAccountInfoParams>, 72 - ) -> Response { 73 - let account = match state 72 + ) -> Result<Response, ApiError> { 73 + let account = state 74 74 .infra_repo 75 75 .get_admin_account_info_by_did(&params.did) 76 76 .await 77 - { 78 - Ok(Some(a)) => a, 79 - Ok(None) => return ApiError::AccountNotFound.into_response(), 80 - Err(e) => { 77 + .map_err(|e| { 81 78 error!("DB error in get_account_info: {:?}", e); 82 - return ApiError::InternalError(None).into_response(); 83 - } 84 - }; 79 + ApiError::InternalError(None) 80 + })? 81 + .ok_or(ApiError::AccountNotFound)?; 85 82 86 83 let invited_by = get_invited_by(&state, account.id).await; 87 84 let invites = get_invites_for_user(&state, account.id).await; 88 85 89 - ( 86 + Ok(( 90 87 StatusCode::OK, 91 88 Json(AccountInfo { 92 89 did: account.did, ··· 105 102 invites, 106 103 }), 107 104 ) 108 - .into_response() 105 + .into_response()) 109 106 } 110 107 111 108 async fn get_invited_by(state: &AppState, user_id: uuid::Uuid) -> Option<InviteCodeInfo> { ··· 200 197 201 198 pub async fn get_account_infos( 202 199 State(state): State<AppState>, 203 - _auth: BearerAuthAdmin, 200 + _auth: Auth<Admin>, 204 201 RawQuery(raw_query): RawQuery, 205 - ) -> Response { 202 + ) -> Result<Response, ApiError> { 206 203 let dids: Vec<String> = crate::util::parse_repeated_query_param(raw_query.as_deref(), "dids") 207 204 .into_iter() 208 205 .filter(|d| !d.is_empty()) 209 206 .collect(); 210 207 211 208 if dids.is_empty() { 212 - return ApiError::InvalidRequest("dids is required".into()).into_response(); 209 + return Err(ApiError::InvalidRequest("dids is required".into())); 213 210 } 214 211 215 212 let dids_typed: Vec<Did> = dids.iter().filter_map(|d| d.parse().ok()).collect(); 216 - let accounts = match state 213 + let accounts = state 217 214 .infra_repo 218 215 .get_admin_account_infos_by_dids(&dids_typed) 219 216 .await 220 - { 221 - Ok(accounts) => accounts, 222 - Err(e) => { 217 + .map_err(|e| { 223 218 error!("Failed to fetch account infos: {:?}", e); 224 - return ApiError::InternalError(None).into_response(); 225 - } 226 - }; 219 + ApiError::InternalError(None) 220 + })?; 227 221 228 222 let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect(); 229 223 ··· 316 310 }) 317 311 .collect(); 318 312 319 - (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() 313 + Ok((StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response()) 320 314 }
+39 -42
crates/tranquil-pds/src/api/admin/account/search.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; 5 5 use axum::{ ··· 50 50 51 51 pub async fn search_accounts( 52 52 State(state): State<AppState>, 53 - _auth: BearerAuthAdmin, 53 + _auth: Auth<Admin>, 54 54 Query(params): Query<SearchAccountsParams>, 55 - ) -> Response { 55 + ) -> Result<Response, ApiError> { 56 56 let limit = params.limit.clamp(1, 100); 57 57 let email_filter = params.email.as_deref().map(|e| format!("%{}%", e)); 58 58 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 59 59 let cursor_did: Option<Did> = params.cursor.as_ref().and_then(|c| c.parse().ok()); 60 - let result = state 60 + let rows = state 61 61 .user_repo 62 62 .search_accounts( 63 63 cursor_did.as_ref(), ··· 65 65 handle_filter.as_deref(), 66 66 limit + 1, 67 67 ) 68 - .await; 69 - match result { 70 - Ok(rows) => { 71 - let has_more = rows.len() > limit as usize; 72 - let accounts: Vec<AccountView> = rows 73 - .into_iter() 74 - .take(limit as usize) 75 - .map(|row| AccountView { 76 - did: row.did.clone(), 77 - handle: row.handle, 78 - email: row.email, 79 - indexed_at: row.created_at.to_rfc3339(), 80 - email_confirmed_at: if row.email_verified { 81 - Some(row.created_at.to_rfc3339()) 82 - } else { 83 - None 84 - }, 85 - deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 86 - invites_disabled: row.invites_disabled, 87 - }) 88 - .collect(); 89 - let next_cursor = if has_more { 90 - accounts.last().map(|a| a.did.to_string()) 68 + .await 69 + .map_err(|e| { 70 + error!("DB error in search_accounts: {:?}", e); 71 + ApiError::InternalError(None) 72 + })?; 73 + 74 + let has_more = rows.len() > limit as usize; 75 + let accounts: Vec<AccountView> = rows 76 + .into_iter() 77 + .take(limit as usize) 78 + .map(|row| AccountView { 79 + did: row.did.clone(), 80 + handle: row.handle, 81 + email: row.email, 82 + indexed_at: row.created_at.to_rfc3339(), 83 + email_confirmed_at: if row.email_verified { 84 + Some(row.created_at.to_rfc3339()) 91 85 } else { 92 86 None 93 - }; 94 - ( 95 - StatusCode::OK, 96 - Json(SearchAccountsOutput { 97 - cursor: next_cursor, 98 - accounts, 99 - }), 100 - ) 101 - .into_response() 102 - } 103 - Err(e) => { 104 - error!("DB error in search_accounts: {:?}", e); 105 - ApiError::InternalError(None).into_response() 106 - } 107 - } 87 + }, 88 + deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 89 + invites_disabled: row.invites_disabled, 90 + }) 91 + .collect(); 92 + let next_cursor = if has_more { 93 + accounts.last().map(|a| a.did.to_string()) 94 + } else { 95 + None 96 + }; 97 + Ok(( 98 + StatusCode::OK, 99 + Json(SearchAccountsOutput { 100 + cursor: next_cursor, 101 + accounts, 102 + }), 103 + ) 104 + .into_response()) 108 105 }
+39 -36
crates/tranquil-pds/src/api/admin/account/update.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use crate::types::{Did, Handle, PlainPassword}; 6 6 use axum::{ ··· 19 19 20 20 pub async fn update_account_email( 21 21 State(state): State<AppState>, 22 - _auth: BearerAuthAdmin, 22 + _auth: Auth<Admin>, 23 23 Json(input): Json<UpdateAccountEmailInput>, 24 - ) -> Response { 24 + ) -> Result<Response, ApiError> { 25 25 let account = input.account.trim(); 26 26 let email = input.email.trim(); 27 27 if account.is_empty() || email.is_empty() { 28 - return ApiError::InvalidRequest("account and email are required".into()).into_response(); 28 + return Err(ApiError::InvalidRequest( 29 + "account and email are required".into(), 30 + )); 29 31 } 30 - let account_did: Did = match account.parse() { 31 - Ok(d) => d, 32 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 33 - }; 32 + let account_did: Did = account 33 + .parse() 34 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 35 + 34 36 match state 35 37 .user_repo 36 38 .admin_update_email(&account_did, email) 37 39 .await 38 40 { 39 - Ok(0) => ApiError::AccountNotFound.into_response(), 40 - Ok(_) => EmptyResponse::ok().into_response(), 41 + Ok(0) => Err(ApiError::AccountNotFound), 42 + Ok(_) => Ok(EmptyResponse::ok().into_response()), 41 43 Err(e) => { 42 44 error!("DB error updating email: {:?}", e); 43 - ApiError::InternalError(None).into_response() 45 + Err(ApiError::InternalError(None)) 44 46 } 45 47 } 46 48 } ··· 53 55 54 56 pub async fn update_account_handle( 55 57 State(state): State<AppState>, 56 - _auth: BearerAuthAdmin, 58 + _auth: Auth<Admin>, 57 59 Json(input): Json<UpdateAccountHandleInput>, 58 - ) -> Response { 60 + ) -> Result<Response, ApiError> { 59 61 let did = &input.did; 60 62 let input_handle = input.handle.trim(); 61 63 if input_handle.is_empty() { 62 - return ApiError::InvalidRequest("handle is required".into()).into_response(); 64 + return Err(ApiError::InvalidRequest("handle is required".into())); 63 65 } 64 66 if !input_handle 65 67 .chars() 66 68 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_') 67 69 { 68 - return ApiError::InvalidHandle(None).into_response(); 70 + return Err(ApiError::InvalidHandle(None)); 69 71 } 70 72 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 71 73 let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); ··· 75 77 input_handle.to_string() 76 78 }; 77 79 let old_handle = state.user_repo.get_handle_by_did(did).await.ok().flatten(); 78 - let user_id = match state.user_repo.get_id_by_did(did).await { 79 - Ok(Some(id)) => id, 80 - _ => return ApiError::AccountNotFound.into_response(), 81 - }; 80 + let user_id = state 81 + .user_repo 82 + .get_id_by_did(did) 83 + .await 84 + .ok() 85 + .flatten() 86 + .ok_or(ApiError::AccountNotFound)?; 82 87 let handle_for_check = Handle::new_unchecked(&handle); 83 88 if let Ok(true) = state 84 89 .user_repo 85 90 .check_handle_exists(&handle_for_check, user_id) 86 91 .await 87 92 { 88 - return ApiError::HandleTaken.into_response(); 93 + return Err(ApiError::HandleTaken); 89 94 } 90 95 match state 91 96 .user_repo 92 97 .admin_update_handle(did, &handle_for_check) 93 98 .await 94 99 { 95 - Ok(0) => ApiError::AccountNotFound.into_response(), 100 + Ok(0) => Err(ApiError::AccountNotFound), 96 101 Ok(_) => { 97 102 if let Some(old) = old_handle { 98 103 let _ = state.cache.delete(&format!("handle:{}", old)).await; ··· 115 120 { 116 121 warn!("Failed to update PLC handle for admin handle update: {}", e); 117 122 } 118 - EmptyResponse::ok().into_response() 123 + Ok(EmptyResponse::ok().into_response()) 119 124 } 120 125 Err(e) => { 121 126 error!("DB error updating handle: {:?}", e); 122 - ApiError::InternalError(None).into_response() 127 + Err(ApiError::InternalError(None)) 123 128 } 124 129 } 125 130 } ··· 132 137 133 138 pub async fn update_account_password( 134 139 State(state): State<AppState>, 135 - _auth: BearerAuthAdmin, 140 + _auth: Auth<Admin>, 136 141 Json(input): Json<UpdateAccountPasswordInput>, 137 - ) -> Response { 142 + ) -> Result<Response, ApiError> { 138 143 let did = &input.did; 139 144 let password = input.password.trim(); 140 145 if password.is_empty() { 141 - return ApiError::InvalidRequest("password is required".into()).into_response(); 146 + return Err(ApiError::InvalidRequest("password is required".into())); 142 147 } 143 - let password_hash = match bcrypt::hash(password, bcrypt::DEFAULT_COST) { 144 - Ok(h) => h, 145 - Err(e) => { 146 - error!("Failed to hash password: {:?}", e); 147 - return ApiError::InternalError(None).into_response(); 148 - } 149 - }; 148 + let password_hash = bcrypt::hash(password, bcrypt::DEFAULT_COST).map_err(|e| { 149 + error!("Failed to hash password: {:?}", e); 150 + ApiError::InternalError(None) 151 + })?; 152 + 150 153 match state 151 154 .user_repo 152 155 .admin_update_password(did, &password_hash) 153 156 .await 154 157 { 155 - Ok(0) => ApiError::AccountNotFound.into_response(), 156 - Ok(_) => EmptyResponse::ok().into_response(), 158 + Ok(0) => Err(ApiError::AccountNotFound), 159 + Ok(_) => Ok(EmptyResponse::ok().into_response()), 157 160 Err(e) => { 158 161 error!("DB error updating password: {:?}", e); 159 - ApiError::InternalError(None).into_response() 162 + Err(ApiError::InternalError(None)) 160 163 } 161 164 } 162 165 }
+2 -2
crates/tranquil-pds/src/api/admin/config.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use axum::{Json, extract::State}; 5 5 use serde::{Deserialize, Serialize}; ··· 78 78 79 79 pub async fn update_server_config( 80 80 State(state): State<AppState>, 81 - _admin: BearerAuthAdmin, 81 + _auth: Auth<Admin>, 82 82 Json(req): Json<UpdateServerConfigRequest>, 83 83 ) -> Result<Json<UpdateServerConfigResponse>, ApiError> { 84 84 if let Some(server_name) = req.server_name {
+32 -35
crates/tranquil-pds/src/api/admin/invite.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAdmin; 3 + use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 Json, ··· 21 21 22 22 pub async fn disable_invite_codes( 23 23 State(state): State<AppState>, 24 - _auth: BearerAuthAdmin, 24 + _auth: Auth<Admin>, 25 25 Json(input): Json<DisableInviteCodesInput>, 26 - ) -> Response { 26 + ) -> Result<Response, ApiError> { 27 27 if let Some(codes) = &input.codes 28 28 && let Err(e) = state.infra_repo.disable_invite_codes_by_code(codes).await 29 29 { ··· 40 40 error!("DB error disabling invite codes by account: {:?}", e); 41 41 } 42 42 } 43 - EmptyResponse::ok().into_response() 43 + Ok(EmptyResponse::ok().into_response()) 44 44 } 45 45 46 46 #[derive(Deserialize)] ··· 78 78 79 79 pub async fn get_invite_codes( 80 80 State(state): State<AppState>, 81 - _auth: BearerAuthAdmin, 81 + _auth: Auth<Admin>, 82 82 Query(params): Query<GetInviteCodesParams>, 83 - ) -> Response { 83 + ) -> Result<Response, ApiError> { 84 84 let limit = params.limit.unwrap_or(100).clamp(1, 500); 85 85 let sort_order = match params.sort.as_deref() { 86 86 Some("usage") => InviteCodeSortOrder::Usage, 87 87 _ => InviteCodeSortOrder::Recent, 88 88 }; 89 89 90 - let codes_rows = match state 90 + let codes_rows = state 91 91 .infra_repo 92 92 .list_invite_codes(params.cursor.as_deref(), limit, sort_order) 93 93 .await 94 - { 95 - Ok(rows) => rows, 96 - Err(e) => { 94 + .map_err(|e| { 97 95 error!("DB error fetching invite codes: {:?}", e); 98 - return ApiError::InternalError(None).into_response(); 99 - } 100 - }; 96 + ApiError::InternalError(None) 97 + })?; 101 98 102 99 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect(); 103 100 let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect(); ··· 155 152 } else { 156 153 None 157 154 }; 158 - ( 155 + Ok(( 159 156 StatusCode::OK, 160 157 Json(GetInviteCodesOutput { 161 158 cursor: next_cursor, 162 159 codes, 163 160 }), 164 161 ) 165 - .into_response() 162 + .into_response()) 166 163 } 167 164 168 165 #[derive(Deserialize)] ··· 172 169 173 170 pub async fn disable_account_invites( 174 171 State(state): State<AppState>, 175 - _auth: BearerAuthAdmin, 172 + _auth: Auth<Admin>, 176 173 Json(input): Json<DisableAccountInvitesInput>, 177 - ) -> Response { 174 + ) -> Result<Response, ApiError> { 178 175 let account = input.account.trim(); 179 176 if account.is_empty() { 180 - return ApiError::InvalidRequest("account is required".into()).into_response(); 177 + return Err(ApiError::InvalidRequest("account is required".into())); 181 178 } 182 - let account_did: tranquil_types::Did = match account.parse() { 183 - Ok(d) => d, 184 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 185 - }; 179 + let account_did: tranquil_types::Did = account 180 + .parse() 181 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 182 + 186 183 match state 187 184 .user_repo 188 185 .set_invites_disabled(&account_did, true) 189 186 .await 190 187 { 191 - Ok(true) => EmptyResponse::ok().into_response(), 192 - Ok(false) => ApiError::AccountNotFound.into_response(), 188 + Ok(true) => Ok(EmptyResponse::ok().into_response()), 189 + Ok(false) => Err(ApiError::AccountNotFound), 193 190 Err(e) => { 194 191 error!("DB error disabling account invites: {:?}", e); 195 - ApiError::InternalError(None).into_response() 192 + Err(ApiError::InternalError(None)) 196 193 } 197 194 } 198 195 } ··· 204 201 205 202 pub async fn enable_account_invites( 206 203 State(state): State<AppState>, 207 - _auth: BearerAuthAdmin, 204 + _auth: Auth<Admin>, 208 205 Json(input): Json<EnableAccountInvitesInput>, 209 - ) -> Response { 206 + ) -> Result<Response, ApiError> { 210 207 let account = input.account.trim(); 211 208 if account.is_empty() { 212 - return ApiError::InvalidRequest("account is required".into()).into_response(); 209 + return Err(ApiError::InvalidRequest("account is required".into())); 213 210 } 214 - let account_did: tranquil_types::Did = match account.parse() { 215 - Ok(d) => d, 216 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 217 - }; 211 + let account_did: tranquil_types::Did = account 212 + .parse() 213 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 214 + 218 215 match state 219 216 .user_repo 220 217 .set_invites_disabled(&account_did, false) 221 218 .await 222 219 { 223 - Ok(true) => EmptyResponse::ok().into_response(), 224 - Ok(false) => ApiError::AccountNotFound.into_response(), 220 + Ok(true) => Ok(EmptyResponse::ok().into_response()), 221 + Ok(false) => Err(ApiError::AccountNotFound), 225 222 Err(e) => { 226 223 error!("DB error enabling account invites: {:?}", e); 227 - ApiError::InternalError(None).into_response() 224 + Err(ApiError::InternalError(None)) 228 225 } 229 226 } 230 227 }
+8 -4
crates/tranquil-pds/src/api/admin/server_stats.rs
··· 1 - use crate::auth::BearerAuthAdmin; 1 + use crate::api::error::ApiError; 2 + use crate::auth::{Admin, Auth}; 2 3 use crate::state::AppState; 3 4 use axum::{ 4 5 Json, ··· 16 17 pub blob_storage_bytes: i64, 17 18 } 18 19 19 - pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response { 20 + pub async fn get_server_stats( 21 + State(state): State<AppState>, 22 + _auth: Auth<Admin>, 23 + ) -> Result<Response, ApiError> { 20 24 let user_count = state.user_repo.count_users().await.unwrap_or(0); 21 25 let repo_count = state.repo_repo.count_repos().await.unwrap_or(0); 22 26 let record_count = state.repo_repo.count_all_records().await.unwrap_or(0); 23 27 let blob_storage_bytes = state.blob_repo.sum_blob_storage().await.unwrap_or(0); 24 28 25 - Json(ServerStatsResponse { 29 + Ok(Json(ServerStatsResponse { 26 30 user_count, 27 31 repo_count, 28 32 record_count, 29 33 blob_storage_bytes, 30 34 }) 31 - .into_response() 35 + .into_response()) 32 36 }
+73 -94
crates/tranquil-pds/src/api/admin/status.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{CidLink, Did}; 5 5 use axum::{ ··· 35 35 36 36 pub async fn get_subject_status( 37 37 State(state): State<AppState>, 38 - _auth: BearerAuthAdmin, 38 + _auth: Auth<Admin>, 39 39 Query(params): Query<GetSubjectStatusParams>, 40 - ) -> Response { 40 + ) -> Result<Response, ApiError> { 41 41 if params.did.is_none() && params.uri.is_none() && params.blob.is_none() { 42 - return ApiError::InvalidRequest("Must provide did, uri, or blob".into()).into_response(); 42 + return Err(ApiError::InvalidRequest( 43 + "Must provide did, uri, or blob".into(), 44 + )); 43 45 } 44 46 if let Some(did_str) = &params.did { 45 - let did: Did = match did_str.parse() { 46 - Ok(d) => d, 47 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 48 - }; 47 + let did: Did = did_str 48 + .parse() 49 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 49 50 match state.user_repo.get_status_by_did(&did).await { 50 51 Ok(Some(status)) => { 51 52 let deactivated = status.deactivated_at.map(|_| StatusAttr { ··· 56 57 applied: true, 57 58 r#ref: Some(r.clone()), 58 59 }); 59 - return ( 60 + return Ok(( 60 61 StatusCode::OK, 61 62 Json(SubjectStatus { 62 63 subject: json!({ ··· 67 68 deactivated, 68 69 }), 69 70 ) 70 - .into_response(); 71 + .into_response()); 71 72 } 72 73 Ok(None) => { 73 - return ApiError::SubjectNotFound.into_response(); 74 + return Err(ApiError::SubjectNotFound); 74 75 } 75 76 Err(e) => { 76 77 error!("DB error in get_subject_status: {:?}", e); 77 - return ApiError::InternalError(None).into_response(); 78 + return Err(ApiError::InternalError(None)); 78 79 } 79 80 } 80 81 } 81 82 if let Some(uri_str) = &params.uri { 82 - let cid: CidLink = match uri_str.parse() { 83 - Ok(c) => c, 84 - Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 85 - }; 83 + let cid: CidLink = uri_str 84 + .parse() 85 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 86 86 match state.repo_repo.get_record_by_cid(&cid).await { 87 87 Ok(Some(record)) => { 88 88 let takedown = record.takedown_ref.as_ref().map(|r| StatusAttr { 89 89 applied: true, 90 90 r#ref: Some(r.clone()), 91 91 }); 92 - return ( 92 + return Ok(( 93 93 StatusCode::OK, 94 94 Json(SubjectStatus { 95 95 subject: json!({ ··· 101 101 deactivated: None, 102 102 }), 103 103 ) 104 - .into_response(); 104 + .into_response()); 105 105 } 106 106 Ok(None) => { 107 - return ApiError::RecordNotFound.into_response(); 107 + return Err(ApiError::RecordNotFound); 108 108 } 109 109 Err(e) => { 110 110 error!("DB error in get_subject_status: {:?}", e); 111 - return ApiError::InternalError(None).into_response(); 111 + return Err(ApiError::InternalError(None)); 112 112 } 113 113 } 114 114 } 115 115 if let Some(blob_cid_str) = &params.blob { 116 - let blob_cid: CidLink = match blob_cid_str.parse() { 117 - Ok(c) => c, 118 - Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 119 - }; 120 - let did = match &params.did { 121 - Some(d) => d, 122 - None => { 123 - return ApiError::InvalidRequest("Must provide a did to request blob state".into()) 124 - .into_response(); 125 - } 126 - }; 116 + let blob_cid: CidLink = blob_cid_str 117 + .parse() 118 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 119 + let did = params.did.as_ref().ok_or_else(|| { 120 + ApiError::InvalidRequest("Must provide a did to request blob state".into()) 121 + })?; 127 122 match state.blob_repo.get_blob_with_takedown(&blob_cid).await { 128 123 Ok(Some(blob)) => { 129 124 let takedown = blob.takedown_ref.as_ref().map(|r| StatusAttr { 130 125 applied: true, 131 126 r#ref: Some(r.clone()), 132 127 }); 133 - return ( 128 + return Ok(( 134 129 StatusCode::OK, 135 130 Json(SubjectStatus { 136 131 subject: json!({ ··· 142 137 deactivated: None, 143 138 }), 144 139 ) 145 - .into_response(); 140 + .into_response()); 146 141 } 147 142 Ok(None) => { 148 - return ApiError::BlobNotFound(None).into_response(); 143 + return Err(ApiError::BlobNotFound(None)); 149 144 } 150 145 Err(e) => { 151 146 error!("DB error in get_subject_status: {:?}", e); 152 - return ApiError::InternalError(None).into_response(); 147 + return Err(ApiError::InternalError(None)); 153 148 } 154 149 } 155 150 } 156 - ApiError::InvalidRequest("Invalid subject type".into()).into_response() 151 + Err(ApiError::InvalidRequest("Invalid subject type".into())) 157 152 } 158 153 159 154 #[derive(Deserialize)] ··· 172 167 173 168 pub async fn update_subject_status( 174 169 State(state): State<AppState>, 175 - _auth: BearerAuthAdmin, 170 + _auth: Auth<Admin>, 176 171 Json(input): Json<UpdateSubjectStatusInput>, 177 - ) -> Response { 172 + ) -> Result<Response, ApiError> { 178 173 let subject_type = input.subject.get("$type").and_then(|t| t.as_str()); 179 174 match subject_type { 180 175 Some("com.atproto.admin.defs#repoRef") => { ··· 187 182 } else { 188 183 None 189 184 }; 190 - if let Err(e) = state.user_repo.set_user_takedown(&did, takedown_ref).await { 191 - error!("Failed to update user takedown status for {}: {:?}", did, e); 192 - return ApiError::InternalError(Some( 193 - "Failed to update takedown status".into(), 194 - )) 195 - .into_response(); 196 - } 185 + state 186 + .user_repo 187 + .set_user_takedown(&did, takedown_ref) 188 + .await 189 + .map_err(|e| { 190 + error!("Failed to update user takedown status for {}: {:?}", did, e); 191 + ApiError::InternalError(Some("Failed to update takedown status".into())) 192 + })?; 197 193 } 198 194 if let Some(deactivated) = &input.deactivated { 199 195 let result = if deactivated.applied { ··· 201 197 } else { 202 198 state.user_repo.activate_account(&did).await 203 199 }; 204 - if let Err(e) = result { 200 + result.map_err(|e| { 205 201 error!( 206 202 "Failed to update user deactivation status for {}: {:?}", 207 203 did, e 208 204 ); 209 - return ApiError::InternalError(Some( 210 - "Failed to update deactivation status".into(), 211 - )) 212 - .into_response(); 213 - } 205 + ApiError::InternalError(Some("Failed to update deactivation status".into())) 206 + })?; 214 207 } 215 208 if let Some(takedown) = &input.takedown { 216 209 let status = if takedown.applied { ··· 249 242 if let Ok(Some(handle)) = state.user_repo.get_handle_by_did(&did).await { 250 243 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 251 244 } 252 - return ( 245 + return Ok(( 253 246 StatusCode::OK, 254 247 Json(json!({ 255 248 "subject": input.subject, ··· 262 255 })) 263 256 })), 264 257 ) 265 - .into_response(); 258 + .into_response()); 266 259 } 267 260 } 268 261 Some("com.atproto.repo.strongRef") => { 269 262 let uri_str = input.subject.get("uri").and_then(|u| u.as_str()); 270 263 if let Some(uri_str) = uri_str { 271 - let cid: CidLink = match uri_str.parse() { 272 - Ok(c) => c, 273 - Err(_) => { 274 - return ApiError::InvalidRequest("Invalid CID format".into()) 275 - .into_response(); 276 - } 277 - }; 264 + let cid: CidLink = uri_str 265 + .parse() 266 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 278 267 if let Some(takedown) = &input.takedown { 279 268 let takedown_ref = if takedown.applied { 280 269 takedown.r#ref.as_deref() 281 270 } else { 282 271 None 283 272 }; 284 - if let Err(e) = state 273 + state 285 274 .repo_repo 286 275 .set_record_takedown(&cid, takedown_ref) 287 276 .await 288 - { 289 - error!( 290 - "Failed to update record takedown status for {}: {:?}", 291 - uri_str, e 292 - ); 293 - return ApiError::InternalError(Some( 294 - "Failed to update takedown status".into(), 295 - )) 296 - .into_response(); 297 - } 277 + .map_err(|e| { 278 + error!( 279 + "Failed to update record takedown status for {}: {:?}", 280 + uri_str, e 281 + ); 282 + ApiError::InternalError(Some("Failed to update takedown status".into())) 283 + })?; 298 284 } 299 - return ( 285 + return Ok(( 300 286 StatusCode::OK, 301 287 Json(json!({ 302 288 "subject": input.subject, ··· 306 292 })) 307 293 })), 308 294 ) 309 - .into_response(); 295 + .into_response()); 310 296 } 311 297 } 312 298 Some("com.atproto.admin.defs#repoBlobRef") => { 313 299 let cid_str = input.subject.get("cid").and_then(|c| c.as_str()); 314 300 if let Some(cid_str) = cid_str { 315 - let cid: CidLink = match cid_str.parse() { 316 - Ok(c) => c, 317 - Err(_) => { 318 - return ApiError::InvalidRequest("Invalid CID format".into()) 319 - .into_response(); 320 - } 321 - }; 301 + let cid: CidLink = cid_str 302 + .parse() 303 + .map_err(|_| ApiError::InvalidRequest("Invalid CID format".into()))?; 322 304 if let Some(takedown) = &input.takedown { 323 305 let takedown_ref = if takedown.applied { 324 306 takedown.r#ref.as_deref() 325 307 } else { 326 308 None 327 309 }; 328 - if let Err(e) = state 310 + state 329 311 .blob_repo 330 312 .update_blob_takedown(&cid, takedown_ref) 331 313 .await 332 - { 333 - error!( 334 - "Failed to update blob takedown status for {}: {:?}", 335 - cid_str, e 336 - ); 337 - return ApiError::InternalError(Some( 338 - "Failed to update takedown status".into(), 339 - )) 340 - .into_response(); 341 - } 314 + .map_err(|e| { 315 + error!( 316 + "Failed to update blob takedown status for {}: {:?}", 317 + cid_str, e 318 + ); 319 + ApiError::InternalError(Some("Failed to update takedown status".into())) 320 + })?; 342 321 } 343 - return ( 322 + return Ok(( 344 323 StatusCode::OK, 345 324 Json(json!({ 346 325 "subject": input.subject, ··· 350 329 })) 351 330 })), 352 331 ) 353 - .into_response(); 332 + .into_response()); 354 333 } 355 334 } 356 335 _ => {} 357 336 } 358 - ApiError::InvalidRequest("Invalid subject type".into()).into_response() 337 + Err(ApiError::InvalidRequest("Invalid subject type".into())) 359 338 }
+93 -73
crates/tranquil-pds/src/api/backup.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, EnabledResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::{Active, Auth}; 4 4 use crate::scheduled::generate_full_backup; 5 5 use crate::state::AppState; 6 6 use crate::storage::{BackupStorage, backup_retention_count}; ··· 35 35 pub backup_enabled: bool, 36 36 } 37 37 38 - pub async fn list_backups(State(state): State<AppState>, auth: BearerAuth) -> Response { 39 - let (user_id, backup_enabled) = 40 - match state.backup_repo.get_user_backup_status(&auth.0.did).await { 41 - Ok(Some(status)) => status, 42 - Ok(None) => { 43 - return ApiError::AccountNotFound.into_response(); 44 - } 45 - Err(e) => { 46 - error!("DB error fetching user: {:?}", e); 47 - return ApiError::InternalError(None).into_response(); 48 - } 49 - }; 38 + pub async fn list_backups( 39 + State(state): State<AppState>, 40 + auth: Auth<Active>, 41 + ) -> Result<Response, crate::api::error::ApiError> { 42 + let (user_id, backup_enabled) = match state.backup_repo.get_user_backup_status(&auth.did).await 43 + { 44 + Ok(Some(status)) => status, 45 + Ok(None) => { 46 + return Ok(ApiError::AccountNotFound.into_response()); 47 + } 48 + Err(e) => { 49 + error!("DB error fetching user: {:?}", e); 50 + return Ok(ApiError::InternalError(None).into_response()); 51 + } 52 + }; 50 53 51 54 let backups = match state.backup_repo.list_backups_for_user(user_id).await { 52 55 Ok(rows) => rows, 53 56 Err(e) => { 54 57 error!("DB error fetching backups: {:?}", e); 55 - return ApiError::InternalError(None).into_response(); 58 + return Ok(ApiError::InternalError(None).into_response()); 56 59 } 57 60 }; 58 61 ··· 68 71 }) 69 72 .collect(); 70 73 71 - ( 74 + Ok(( 72 75 StatusCode::OK, 73 76 Json(ListBackupsOutput { 74 77 backups: backup_list, 75 78 backup_enabled, 76 79 }), 77 80 ) 78 - .into_response() 81 + .into_response()) 79 82 } 80 83 81 84 #[derive(Deserialize)] ··· 85 88 86 89 pub async fn get_backup( 87 90 State(state): State<AppState>, 88 - auth: BearerAuth, 91 + auth: Auth<Active>, 89 92 Query(query): Query<GetBackupQuery>, 90 - ) -> Response { 93 + ) -> Result<Response, crate::api::error::ApiError> { 91 94 let backup_id = match uuid::Uuid::parse_str(&query.id) { 92 95 Ok(id) => id, 93 96 Err(_) => { 94 - return ApiError::InvalidRequest("Invalid backup ID".into()).into_response(); 97 + return Ok(ApiError::InvalidRequest("Invalid backup ID".into()).into_response()); 95 98 } 96 99 }; 97 100 98 101 let backup_info = match state 99 102 .backup_repo 100 - .get_backup_storage_info(backup_id, &auth.0.did) 103 + .get_backup_storage_info(backup_id, &auth.did) 101 104 .await 102 105 { 103 106 Ok(Some(b)) => b, 104 107 Ok(None) => { 105 - return ApiError::BackupNotFound.into_response(); 108 + return Ok(ApiError::BackupNotFound.into_response()); 106 109 } 107 110 Err(e) => { 108 111 error!("DB error fetching backup: {:?}", e); 109 - return ApiError::InternalError(None).into_response(); 112 + return Ok(ApiError::InternalError(None).into_response()); 110 113 } 111 114 }; 112 115 113 116 let backup_storage = match state.backup_storage.as_ref() { 114 117 Some(storage) => storage, 115 118 None => { 116 - return ApiError::BackupsDisabled.into_response(); 119 + return Ok(ApiError::BackupsDisabled.into_response()); 117 120 } 118 121 }; 119 122 ··· 121 124 Ok(bytes) => bytes, 122 125 Err(e) => { 123 126 error!("Failed to fetch backup from storage: {:?}", e); 124 - return ApiError::InternalError(Some("Failed to retrieve backup".into())) 125 - .into_response(); 127 + return Ok( 128 + ApiError::InternalError(Some("Failed to retrieve backup".into())).into_response(), 129 + ); 126 130 } 127 131 }; 128 132 129 - ( 133 + Ok(( 130 134 StatusCode::OK, 131 135 [ 132 136 (axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car"), ··· 137 141 ], 138 142 car_bytes, 139 143 ) 140 - .into_response() 144 + .into_response()) 141 145 } 142 146 143 147 #[derive(Serialize)] ··· 149 153 pub block_count: i32, 150 154 } 151 155 152 - pub async fn create_backup(State(state): State<AppState>, auth: BearerAuth) -> Response { 156 + pub async fn create_backup( 157 + State(state): State<AppState>, 158 + auth: Auth<Active>, 159 + ) -> Result<Response, crate::api::error::ApiError> { 153 160 let backup_storage = match state.backup_storage.as_ref() { 154 161 Some(storage) => storage, 155 162 None => { 156 - return ApiError::BackupsDisabled.into_response(); 163 + return Ok(ApiError::BackupsDisabled.into_response()); 157 164 } 158 165 }; 159 166 160 - let user = match state.backup_repo.get_user_for_backup(&auth.0.did).await { 167 + let user = match state.backup_repo.get_user_for_backup(&auth.did).await { 161 168 Ok(Some(u)) => u, 162 169 Ok(None) => { 163 - return ApiError::AccountNotFound.into_response(); 170 + return Ok(ApiError::AccountNotFound.into_response()); 164 171 } 165 172 Err(e) => { 166 173 error!("DB error fetching user: {:?}", e); 167 - return ApiError::InternalError(None).into_response(); 174 + return Ok(ApiError::InternalError(None).into_response()); 168 175 } 169 176 }; 170 177 171 178 if user.deactivated_at.is_some() { 172 - return ApiError::AccountDeactivated.into_response(); 179 + return Ok(ApiError::AccountDeactivated.into_response()); 173 180 } 174 181 175 182 let repo_rev = match &user.repo_rev { 176 183 Some(rev) => rev.clone(), 177 184 None => { 178 - return ApiError::RepoNotReady.into_response(); 185 + return Ok(ApiError::RepoNotReady.into_response()); 179 186 } 180 187 }; 181 188 182 189 let head_cid = match Cid::from_str(&user.repo_root_cid) { 183 190 Ok(c) => c, 184 191 Err(_) => { 185 - return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 192 + return Ok( 193 + ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(), 194 + ); 186 195 } 187 196 }; 188 197 ··· 197 206 Ok(bytes) => bytes, 198 207 Err(e) => { 199 208 error!("Failed to generate CAR: {:?}", e); 200 - return ApiError::InternalError(Some("Failed to generate backup".into())) 201 - .into_response(); 209 + return Ok( 210 + ApiError::InternalError(Some("Failed to generate backup".into())).into_response(), 211 + ); 202 212 } 203 213 }; 204 214 ··· 212 222 Ok(key) => key, 213 223 Err(e) => { 214 224 error!("Failed to upload backup: {:?}", e); 215 - return ApiError::InternalError(Some("Failed to store backup".into())).into_response(); 225 + return Ok( 226 + ApiError::InternalError(Some("Failed to store backup".into())).into_response(), 227 + ); 216 228 } 217 229 }; 218 230 ··· 238 250 "Failed to rollback orphaned backup from S3" 239 251 ); 240 252 } 241 - return ApiError::InternalError(Some("Failed to record backup".into())).into_response(); 253 + return Ok( 254 + ApiError::InternalError(Some("Failed to record backup".into())).into_response(), 255 + ); 242 256 } 243 257 }; 244 258 ··· 261 275 warn!(did = %user.did, error = %e, "Failed to cleanup old backups after manual backup"); 262 276 } 263 277 264 - ( 278 + Ok(( 265 279 StatusCode::OK, 266 280 Json(CreateBackupOutput { 267 281 id: backup_id.to_string(), ··· 270 284 block_count, 271 285 }), 272 286 ) 273 - .into_response() 287 + .into_response()) 274 288 } 275 289 276 290 async fn cleanup_old_backups( ··· 310 324 311 325 pub async fn delete_backup( 312 326 State(state): State<AppState>, 313 - auth: BearerAuth, 327 + auth: Auth<Active>, 314 328 Query(query): Query<DeleteBackupQuery>, 315 - ) -> Response { 329 + ) -> Result<Response, crate::api::error::ApiError> { 316 330 let backup_id = match uuid::Uuid::parse_str(&query.id) { 317 331 Ok(id) => id, 318 332 Err(_) => { 319 - return ApiError::InvalidRequest("Invalid backup ID".into()).into_response(); 333 + return Ok(ApiError::InvalidRequest("Invalid backup ID".into()).into_response()); 320 334 } 321 335 }; 322 336 323 337 let backup = match state 324 338 .backup_repo 325 - .get_backup_for_deletion(backup_id, &auth.0.did) 339 + .get_backup_for_deletion(backup_id, &auth.did) 326 340 .await 327 341 { 328 342 Ok(Some(b)) => b, 329 343 Ok(None) => { 330 - return ApiError::BackupNotFound.into_response(); 344 + return Ok(ApiError::BackupNotFound.into_response()); 331 345 } 332 346 Err(e) => { 333 347 error!("DB error fetching backup: {:?}", e); 334 - return ApiError::InternalError(None).into_response(); 348 + return Ok(ApiError::InternalError(None).into_response()); 335 349 } 336 350 }; 337 351 338 352 if backup.deactivated_at.is_some() { 339 - return ApiError::AccountDeactivated.into_response(); 353 + return Ok(ApiError::AccountDeactivated.into_response()); 340 354 } 341 355 342 356 if let Some(backup_storage) = state.backup_storage.as_ref() ··· 351 365 352 366 if let Err(e) = state.backup_repo.delete_backup(backup.id).await { 353 367 error!("DB error deleting backup: {:?}", e); 354 - return ApiError::InternalError(Some("Failed to delete backup".into())).into_response(); 368 + return Ok(ApiError::InternalError(Some("Failed to delete backup".into())).into_response()); 355 369 } 356 370 357 - info!(did = %auth.0.did, backup_id = %backup_id, "Deleted backup"); 371 + info!(did = %auth.did, backup_id = %backup_id, "Deleted backup"); 358 372 359 - EmptyResponse::ok().into_response() 373 + Ok(EmptyResponse::ok().into_response()) 360 374 } 361 375 362 376 #[derive(Deserialize)] ··· 367 381 368 382 pub async fn set_backup_enabled( 369 383 State(state): State<AppState>, 370 - auth: BearerAuth, 384 + auth: Auth<Active>, 371 385 Json(input): Json<SetBackupEnabledInput>, 372 - ) -> Response { 386 + ) -> Result<Response, crate::api::error::ApiError> { 373 387 let deactivated_at = match state 374 388 .backup_repo 375 - .get_user_deactivated_status(&auth.0.did) 389 + .get_user_deactivated_status(&auth.did) 376 390 .await 377 391 { 378 392 Ok(Some(status)) => status, 379 393 Ok(None) => { 380 - return ApiError::AccountNotFound.into_response(); 394 + return Ok(ApiError::AccountNotFound.into_response()); 381 395 } 382 396 Err(e) => { 383 397 error!("DB error fetching user: {:?}", e); 384 - return ApiError::InternalError(None).into_response(); 398 + return Ok(ApiError::InternalError(None).into_response()); 385 399 } 386 400 }; 387 401 388 402 if deactivated_at.is_some() { 389 - return ApiError::AccountDeactivated.into_response(); 403 + return Ok(ApiError::AccountDeactivated.into_response()); 390 404 } 391 405 392 406 if let Err(e) = state 393 407 .backup_repo 394 - .update_backup_enabled(&auth.0.did, input.enabled) 408 + .update_backup_enabled(&auth.did, input.enabled) 395 409 .await 396 410 { 397 411 error!("DB error updating backup_enabled: {:?}", e); 398 - return ApiError::InternalError(Some("Failed to update setting".into())).into_response(); 412 + return Ok( 413 + ApiError::InternalError(Some("Failed to update setting".into())).into_response(), 414 + ); 399 415 } 400 416 401 - info!(did = %auth.0.did, enabled = input.enabled, "Updated backup_enabled setting"); 417 + info!(did = %auth.did, enabled = input.enabled, "Updated backup_enabled setting"); 402 418 403 - EnabledResponse::response(input.enabled).into_response() 419 + Ok(EnabledResponse::response(input.enabled).into_response()) 404 420 } 405 421 406 - pub async fn export_blobs(State(state): State<AppState>, auth: BearerAuth) -> Response { 407 - let user_id = match state.backup_repo.get_user_id_by_did(&auth.0.did).await { 422 + pub async fn export_blobs( 423 + State(state): State<AppState>, 424 + auth: Auth<Active>, 425 + ) -> Result<Response, crate::api::error::ApiError> { 426 + let user_id = match state.backup_repo.get_user_id_by_did(&auth.did).await { 408 427 Ok(Some(id)) => id, 409 428 Ok(None) => { 410 - return ApiError::AccountNotFound.into_response(); 429 + return Ok(ApiError::AccountNotFound.into_response()); 411 430 } 412 431 Err(e) => { 413 432 error!("DB error fetching user: {:?}", e); 414 - return ApiError::InternalError(None).into_response(); 433 + return Ok(ApiError::InternalError(None).into_response()); 415 434 } 416 435 }; 417 436 ··· 419 438 Ok(rows) => rows, 420 439 Err(e) => { 421 440 error!("DB error fetching blobs: {:?}", e); 422 - return ApiError::InternalError(None).into_response(); 441 + return Ok(ApiError::InternalError(None).into_response()); 423 442 } 424 443 }; 425 444 426 445 if blobs.is_empty() { 427 - return ( 446 + return Ok(( 428 447 StatusCode::OK, 429 448 [ 430 449 (axum::http::header::CONTENT_TYPE, "application/zip"), ··· 435 454 ], 436 455 Vec::<u8>::new(), 437 456 ) 438 - .into_response(); 457 + .into_response()); 439 458 } 440 459 441 460 let mut zip_buffer = std::io::Cursor::new(Vec::new()); ··· 513 532 514 533 if let Err(e) = zip.finish() { 515 534 error!("Failed to finish zip: {:?}", e); 516 - return ApiError::InternalError(Some("Failed to create zip file".into())) 517 - .into_response(); 535 + return Ok( 536 + ApiError::InternalError(Some("Failed to create zip file".into())).into_response(), 537 + ); 518 538 } 519 539 } 520 540 521 541 let zip_bytes = zip_buffer.into_inner(); 522 542 523 - info!(did = %auth.0.did, blob_count = blobs.len(), size_bytes = zip_bytes.len(), "Exported blobs"); 543 + info!(did = %auth.did, blob_count = blobs.len(), size_bytes = zip_bytes.len(), "Exported blobs"); 524 544 525 - ( 545 + Ok(( 526 546 StatusCode::OK, 527 547 [ 528 548 (axum::http::header::CONTENT_TYPE, "application/zip"), ··· 533 553 ], 534 554 zip_bytes, 535 555 ) 536 - .into_response() 556 + .into_response()) 537 557 } 538 558 539 559 fn mime_to_extension(mime_type: &str) -> &'static str {
+115 -98
crates/tranquil-pds/src/api/delegation.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::create_signed_commit; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::{Active, Auth}; 4 4 use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes}; 5 5 use crate::state::{AppState, RateLimitKind}; 6 6 use crate::types::{Did, Handle, Nsid, Rkey}; ··· 33 33 pub controllers: Vec<ControllerInfo>, 34 34 } 35 35 36 - pub async fn list_controllers(State(state): State<AppState>, auth: BearerAuth) -> Response { 36 + pub async fn list_controllers( 37 + State(state): State<AppState>, 38 + auth: Auth<Active>, 39 + ) -> Result<Response, ApiError> { 37 40 let controllers = match state 38 41 .delegation_repo 39 - .get_delegations_for_account(&auth.0.did) 42 + .get_delegations_for_account(&auth.did) 40 43 .await 41 44 { 42 45 Ok(c) => c, 43 46 Err(e) => { 44 47 tracing::error!("Failed to list controllers: {:?}", e); 45 - return ApiError::InternalError(Some("Failed to list controllers".into())) 46 - .into_response(); 48 + return Ok( 49 + ApiError::InternalError(Some("Failed to list controllers".into())).into_response(), 50 + ); 47 51 } 48 52 }; 49 53 50 - Json(ListControllersResponse { 54 + Ok(Json(ListControllersResponse { 51 55 controllers: controllers 52 56 .into_iter() 53 57 .map(|c| ControllerInfo { ··· 59 63 }) 60 64 .collect(), 61 65 }) 62 - .into_response() 66 + .into_response()) 63 67 } 64 68 65 69 #[derive(Debug, Deserialize)] ··· 70 74 71 75 pub async fn add_controller( 72 76 State(state): State<AppState>, 73 - auth: BearerAuth, 77 + auth: Auth<Active>, 74 78 Json(input): Json<AddControllerInput>, 75 - ) -> Response { 79 + ) -> Result<Response, ApiError> { 76 80 if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 77 - return ApiError::InvalidScopes(e).into_response(); 81 + return Ok(ApiError::InvalidScopes(e).into_response()); 78 82 } 79 83 80 84 let controller_exists = state ··· 86 90 .is_some(); 87 91 88 92 if !controller_exists { 89 - return ApiError::ControllerNotFound.into_response(); 93 + return Ok(ApiError::ControllerNotFound.into_response()); 90 94 } 91 95 92 - match state 93 - .delegation_repo 94 - .controls_any_accounts(&auth.0.did) 95 - .await 96 - { 96 + match state.delegation_repo.controls_any_accounts(&auth.did).await { 97 97 Ok(true) => { 98 - return ApiError::InvalidDelegation( 98 + return Ok(ApiError::InvalidDelegation( 99 99 "Cannot add controllers to an account that controls other accounts".into(), 100 100 ) 101 - .into_response(); 101 + .into_response()); 102 102 } 103 103 Err(e) => { 104 104 tracing::error!("Failed to check delegation status: {:?}", e); 105 - return ApiError::InternalError(Some("Failed to verify delegation status".into())) 106 - .into_response(); 105 + return Ok( 106 + ApiError::InternalError(Some("Failed to verify delegation status".into())) 107 + .into_response(), 108 + ); 107 109 } 108 110 Ok(false) => {} 109 111 } ··· 114 116 .await 115 117 { 116 118 Ok(true) => { 117 - return ApiError::InvalidDelegation( 119 + return Ok(ApiError::InvalidDelegation( 118 120 "Cannot add a controlled account as a controller".into(), 119 121 ) 120 - .into_response(); 122 + .into_response()); 121 123 } 122 124 Err(e) => { 123 125 tracing::error!("Failed to check controller status: {:?}", e); 124 - return ApiError::InternalError(Some("Failed to verify controller status".into())) 125 - .into_response(); 126 + return Ok( 127 + ApiError::InternalError(Some("Failed to verify controller status".into())) 128 + .into_response(), 129 + ); 126 130 } 127 131 Ok(false) => {} 128 132 } ··· 130 134 match state 131 135 .delegation_repo 132 136 .create_delegation( 133 - &auth.0.did, 137 + &auth.did, 134 138 &input.controller_did, 135 139 &input.granted_scopes, 136 - &auth.0.did, 140 + &auth.did, 137 141 ) 138 142 .await 139 143 { ··· 141 145 let _ = state 142 146 .delegation_repo 143 147 .log_delegation_action( 144 - &auth.0.did, 145 - &auth.0.did, 148 + &auth.did, 149 + &auth.did, 146 150 Some(&input.controller_did), 147 151 DelegationActionType::GrantCreated, 148 152 Some(serde_json::json!({ ··· 153 157 ) 154 158 .await; 155 159 156 - ( 160 + Ok(( 157 161 StatusCode::OK, 158 162 Json(serde_json::json!({ 159 163 "success": true 160 164 })), 161 165 ) 162 - .into_response() 166 + .into_response()) 163 167 } 164 168 Err(e) => { 165 169 tracing::error!("Failed to add controller: {:?}", e); 166 - ApiError::InternalError(Some("Failed to add controller".into())).into_response() 170 + Ok(ApiError::InternalError(Some("Failed to add controller".into())).into_response()) 167 171 } 168 172 } 169 173 } ··· 175 179 176 180 pub async fn remove_controller( 177 181 State(state): State<AppState>, 178 - auth: BearerAuth, 182 + auth: Auth<Active>, 179 183 Json(input): Json<RemoveControllerInput>, 180 - ) -> Response { 184 + ) -> Result<Response, ApiError> { 181 185 match state 182 186 .delegation_repo 183 - .revoke_delegation(&auth.0.did, &input.controller_did, &auth.0.did) 187 + .revoke_delegation(&auth.did, &input.controller_did, &auth.did) 184 188 .await 185 189 { 186 190 Ok(true) => { 187 191 let revoked_app_passwords = state 188 192 .session_repo 189 - .delete_app_passwords_by_controller(&auth.0.did, &input.controller_did) 193 + .delete_app_passwords_by_controller(&auth.did, &input.controller_did) 190 194 .await 191 195 .unwrap_or(0) as usize; 192 196 193 197 let revoked_oauth_tokens = state 194 198 .oauth_repo 195 - .revoke_tokens_for_controller(&auth.0.did, &input.controller_did) 199 + .revoke_tokens_for_controller(&auth.did, &input.controller_did) 196 200 .await 197 201 .unwrap_or(0); 198 202 199 203 let _ = state 200 204 .delegation_repo 201 205 .log_delegation_action( 202 - &auth.0.did, 203 - &auth.0.did, 206 + &auth.did, 207 + &auth.did, 204 208 Some(&input.controller_did), 205 209 DelegationActionType::GrantRevoked, 206 210 Some(serde_json::json!({ ··· 212 216 ) 213 217 .await; 214 218 215 - ( 219 + Ok(( 216 220 StatusCode::OK, 217 221 Json(serde_json::json!({ 218 222 "success": true 219 223 })), 220 224 ) 221 - .into_response() 225 + .into_response()) 222 226 } 223 - Ok(false) => ApiError::DelegationNotFound.into_response(), 227 + Ok(false) => Ok(ApiError::DelegationNotFound.into_response()), 224 228 Err(e) => { 225 229 tracing::error!("Failed to remove controller: {:?}", e); 226 - ApiError::InternalError(Some("Failed to remove controller".into())).into_response() 230 + Ok(ApiError::InternalError(Some("Failed to remove controller".into())).into_response()) 227 231 } 228 232 } 229 233 } ··· 236 240 237 241 pub async fn update_controller_scopes( 238 242 State(state): State<AppState>, 239 - auth: BearerAuth, 243 + auth: Auth<Active>, 240 244 Json(input): Json<UpdateControllerScopesInput>, 241 - ) -> Response { 245 + ) -> Result<Response, ApiError> { 242 246 if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 243 - return ApiError::InvalidScopes(e).into_response(); 247 + return Ok(ApiError::InvalidScopes(e).into_response()); 244 248 } 245 249 246 250 match state 247 251 .delegation_repo 248 - .update_delegation_scopes(&auth.0.did, &input.controller_did, &input.granted_scopes) 252 + .update_delegation_scopes(&auth.did, &input.controller_did, &input.granted_scopes) 249 253 .await 250 254 { 251 255 Ok(true) => { 252 256 let _ = state 253 257 .delegation_repo 254 258 .log_delegation_action( 255 - &auth.0.did, 256 - &auth.0.did, 259 + &auth.did, 260 + &auth.did, 257 261 Some(&input.controller_did), 258 262 DelegationActionType::ScopesModified, 259 263 Some(serde_json::json!({ ··· 264 268 ) 265 269 .await; 266 270 267 - ( 271 + Ok(( 268 272 StatusCode::OK, 269 273 Json(serde_json::json!({ 270 274 "success": true 271 275 })), 272 276 ) 273 - .into_response() 277 + .into_response()) 274 278 } 275 - Ok(false) => ApiError::DelegationNotFound.into_response(), 279 + Ok(false) => Ok(ApiError::DelegationNotFound.into_response()), 276 280 Err(e) => { 277 281 tracing::error!("Failed to update controller scopes: {:?}", e); 278 - ApiError::InternalError(Some("Failed to update controller scopes".into())) 279 - .into_response() 282 + Ok( 283 + ApiError::InternalError(Some("Failed to update controller scopes".into())) 284 + .into_response(), 285 + ) 280 286 } 281 287 } 282 288 } ··· 295 301 pub accounts: Vec<DelegatedAccountInfo>, 296 302 } 297 303 298 - pub async fn list_controlled_accounts(State(state): State<AppState>, auth: BearerAuth) -> Response { 304 + pub async fn list_controlled_accounts( 305 + State(state): State<AppState>, 306 + auth: Auth<Active>, 307 + ) -> Result<Response, ApiError> { 299 308 let accounts = match state 300 309 .delegation_repo 301 - .get_accounts_controlled_by(&auth.0.did) 310 + .get_accounts_controlled_by(&auth.did) 302 311 .await 303 312 { 304 313 Ok(a) => a, 305 314 Err(e) => { 306 315 tracing::error!("Failed to list controlled accounts: {:?}", e); 307 - return ApiError::InternalError(Some("Failed to list controlled accounts".into())) 308 - .into_response(); 316 + return Ok( 317 + ApiError::InternalError(Some("Failed to list controlled accounts".into())) 318 + .into_response(), 319 + ); 309 320 } 310 321 }; 311 322 312 - Json(ListControlledAccountsResponse { 323 + Ok(Json(ListControlledAccountsResponse { 313 324 accounts: accounts 314 325 .into_iter() 315 326 .map(|a| DelegatedAccountInfo { ··· 320 331 }) 321 332 .collect(), 322 333 }) 323 - .into_response() 334 + .into_response()) 324 335 } 325 336 326 337 #[derive(Debug, Deserialize)] ··· 355 366 356 367 pub async fn get_audit_log( 357 368 State(state): State<AppState>, 358 - auth: BearerAuth, 369 + auth: Auth<Active>, 359 370 Query(params): Query<AuditLogParams>, 360 - ) -> Response { 371 + ) -> Result<Response, ApiError> { 361 372 let limit = params.limit.clamp(1, 100); 362 373 let offset = params.offset.max(0); 363 374 364 375 let entries = match state 365 376 .delegation_repo 366 - .get_audit_log_for_account(&auth.0.did, limit, offset) 377 + .get_audit_log_for_account(&auth.did, limit, offset) 367 378 .await 368 379 { 369 380 Ok(e) => e, 370 381 Err(e) => { 371 382 tracing::error!("Failed to get audit log: {:?}", e); 372 - return ApiError::InternalError(Some("Failed to get audit log".into())).into_response(); 383 + return Ok( 384 + ApiError::InternalError(Some("Failed to get audit log".into())).into_response(), 385 + ); 373 386 } 374 387 }; 375 388 376 389 let total = state 377 390 .delegation_repo 378 - .count_audit_log_entries(&auth.0.did) 391 + .count_audit_log_entries(&auth.did) 379 392 .await 380 393 .unwrap_or_default(); 381 394 382 - Json(GetAuditLogResponse { 395 + Ok(Json(GetAuditLogResponse { 383 396 entries: entries 384 397 .into_iter() 385 398 .map(|e| AuditLogEntry { ··· 394 407 .collect(), 395 408 total, 396 409 }) 397 - .into_response() 410 + .into_response()) 398 411 } 399 412 400 413 #[derive(Debug, Serialize)] ··· 444 457 pub async fn create_delegated_account( 445 458 State(state): State<AppState>, 446 459 headers: HeaderMap, 447 - auth: BearerAuth, 460 + auth: Auth<Active>, 448 461 Json(input): Json<CreateDelegatedAccountInput>, 449 - ) -> Response { 462 + ) -> Result<Response, ApiError> { 450 463 let client_ip = extract_client_ip(&headers); 451 464 if !state 452 465 .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 453 466 .await 454 467 { 455 468 warn!(ip = %client_ip, "Delegated account creation rate limit exceeded"); 456 - return ApiError::RateLimitExceeded(Some( 469 + return Ok(ApiError::RateLimitExceeded(Some( 457 470 "Too many account creation attempts. Please try again later.".into(), 458 471 )) 459 - .into_response(); 472 + .into_response()); 460 473 } 461 474 462 475 if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) { 463 - return ApiError::InvalidScopes(e).into_response(); 476 + return Ok(ApiError::InvalidScopes(e).into_response()); 464 477 } 465 478 466 - match state.delegation_repo.has_any_controllers(&auth.0.did).await { 479 + match state.delegation_repo.has_any_controllers(&auth.did).await { 467 480 Ok(true) => { 468 - return ApiError::InvalidDelegation( 481 + return Ok(ApiError::InvalidDelegation( 469 482 "Cannot create delegated accounts from a controlled account".into(), 470 483 ) 471 - .into_response(); 484 + .into_response()); 472 485 } 473 486 Err(e) => { 474 487 tracing::error!("Failed to check controller status: {:?}", e); 475 - return ApiError::InternalError(Some("Failed to verify controller status".into())) 476 - .into_response(); 488 + return Ok( 489 + ApiError::InternalError(Some("Failed to verify controller status".into())) 490 + .into_response(), 491 + ); 477 492 } 478 493 Ok(false) => {} 479 494 } ··· 494 509 match crate::api::validation::validate_short_handle(handle_to_validate) { 495 510 Ok(h) => format!("{}.{}", h, hostname_for_handles), 496 511 Err(e) => { 497 - return ApiError::InvalidRequest(e.to_string()).into_response(); 512 + return Ok(ApiError::InvalidRequest(e.to_string()).into_response()); 498 513 } 499 514 } 500 515 } else { ··· 509 524 if let Some(ref email) = email 510 525 && !crate::api::validation::is_valid_email(email) 511 526 { 512 - return ApiError::InvalidEmail.into_response(); 527 + return Ok(ApiError::InvalidEmail.into_response()); 513 528 } 514 529 515 530 if let Some(ref code) = input.invite_code { ··· 520 535 .unwrap_or(false); 521 536 522 537 if !valid { 523 - return ApiError::InvalidInviteCode.into_response(); 538 + return Ok(ApiError::InvalidInviteCode.into_response()); 524 539 } 525 540 } else { 526 541 let invite_required = std::env::var("INVITE_CODE_REQUIRED") 527 542 .map(|v| v == "true" || v == "1") 528 543 .unwrap_or(false); 529 544 if invite_required { 530 - return ApiError::InviteCodeRequired.into_response(); 545 + return Ok(ApiError::InviteCodeRequired.into_response()); 531 546 } 532 547 } 533 548 ··· 542 557 Ok(k) => k, 543 558 Err(e) => { 544 559 error!("Error creating signing key: {:?}", e); 545 - return ApiError::InternalError(None).into_response(); 560 + return Ok(ApiError::InternalError(None).into_response()); 546 561 } 547 562 }; 548 563 ··· 558 573 Ok(r) => r, 559 574 Err(e) => { 560 575 error!("Error creating PLC genesis operation: {:?}", e); 561 - return ApiError::InternalError(Some("Failed to create PLC operation".into())) 562 - .into_response(); 576 + return Ok( 577 + ApiError::InternalError(Some("Failed to create PLC operation".into())) 578 + .into_response(), 579 + ); 563 580 } 564 581 }; 565 582 ··· 569 586 .await 570 587 { 571 588 error!("Failed to submit PLC genesis operation: {:?}", e); 572 - return ApiError::UpstreamErrorMsg(format!( 589 + return Ok(ApiError::UpstreamErrorMsg(format!( 573 590 "Failed to register DID with PLC directory: {}", 574 591 e 575 592 )) 576 - .into_response(); 593 + .into_response()); 577 594 } 578 595 579 596 let did = Did::new_unchecked(&genesis_result.did); 580 597 let handle = Handle::new_unchecked(&handle); 581 - info!(did = %did, handle = %handle, controller = %&auth.0.did, "Created DID for delegated account"); 598 + info!(did = %did, handle = %handle, controller = %&auth.did, "Created DID for delegated account"); 582 599 583 600 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 584 601 Ok(bytes) => bytes, 585 602 Err(e) => { 586 603 error!("Error encrypting signing key: {:?}", e); 587 - return ApiError::InternalError(None).into_response(); 604 + return Ok(ApiError::InternalError(None).into_response()); 588 605 } 589 606 }; 590 607 ··· 593 610 Ok(c) => c, 594 611 Err(e) => { 595 612 error!("Error persisting MST: {:?}", e); 596 - return ApiError::InternalError(None).into_response(); 613 + return Ok(ApiError::InternalError(None).into_response()); 597 614 } 598 615 }; 599 616 let rev = Tid::now(LimitedU32::MIN); ··· 602 619 Ok(result) => result, 603 620 Err(e) => { 604 621 error!("Error creating genesis commit: {:?}", e); 605 - return ApiError::InternalError(None).into_response(); 622 + return Ok(ApiError::InternalError(None).into_response()); 606 623 } 607 624 }; 608 625 let commit_cid: cid::Cid = match state.block_store.put(&commit_bytes).await { 609 626 Ok(c) => c, 610 627 Err(e) => { 611 628 error!("Error saving genesis commit: {:?}", e); 612 - return ApiError::InternalError(None).into_response(); 629 + return Ok(ApiError::InternalError(None).into_response()); 613 630 } 614 631 }; 615 632 let genesis_block_cids = vec![mst_root.to_bytes(), commit_cid.to_bytes()]; ··· 618 635 handle: handle.clone(), 619 636 email: email.clone(), 620 637 did: did.clone(), 621 - controller_did: auth.0.did.clone(), 638 + controller_did: auth.did.clone(), 622 639 controller_scopes: input.controller_scopes.clone(), 623 640 encrypted_key_bytes, 624 641 encryption_version: crate::config::ENCRYPTION_VERSION, ··· 635 652 { 636 653 Ok(id) => id, 637 654 Err(tranquil_db_traits::CreateAccountError::HandleTaken) => { 638 - return ApiError::HandleNotAvailable(None).into_response(); 655 + return Ok(ApiError::HandleNotAvailable(None).into_response()); 639 656 } 640 657 Err(tranquil_db_traits::CreateAccountError::EmailTaken) => { 641 - return ApiError::EmailTaken.into_response(); 658 + return Ok(ApiError::EmailTaken.into_response()); 642 659 } 643 660 Err(e) => { 644 661 error!("Error creating delegated account: {:?}", e); 645 - return ApiError::InternalError(None).into_response(); 662 + return Ok(ApiError::InternalError(None).into_response()); 646 663 } 647 664 }; 648 665 ··· 678 695 .delegation_repo 679 696 .log_delegation_action( 680 697 &did, 681 - &auth.0.did, 682 - Some(&auth.0.did), 698 + &auth.did, 699 + Some(&auth.did), 683 700 DelegationActionType::GrantCreated, 684 701 Some(json!({ 685 702 "account_created": true, ··· 690 707 ) 691 708 .await; 692 709 693 - info!(did = %did, handle = %handle, controller = %&auth.0.did, "Delegated account created"); 710 + info!(did = %did, handle = %handle, controller = %&auth.did, "Delegated account created"); 694 711 695 - Json(CreateDelegatedAccountResponse { did, handle }).into_response() 712 + Ok(Json(CreateDelegatedAccountResponse { did, handle }).into_response()) 696 713 }
+13
crates/tranquil-pds/src/api/error.rs
··· 543 543 crate::auth::extractor::AuthError::AccountDeactivated => Self::AccountDeactivated, 544 544 crate::auth::extractor::AuthError::AccountTakedown => Self::AccountTakedown, 545 545 crate::auth::extractor::AuthError::AdminRequired => Self::AdminRequired, 546 + crate::auth::extractor::AuthError::ServiceAuthNotAllowed => Self::AuthenticationFailed( 547 + Some("Service authentication not allowed for this endpoint".to_string()), 548 + ), 549 + crate::auth::extractor::AuthError::InsufficientScope(msg) => { 550 + Self::InsufficientScope(Some(msg)) 551 + } 552 + crate::auth::extractor::AuthError::OAuthExpiredToken(msg) => { 553 + Self::OAuthExpiredToken(Some(msg)) 554 + } 555 + crate::auth::extractor::AuthError::UseDpopNonce(_) 556 + | crate::auth::extractor::AuthError::InvalidDpopProof(_) => { 557 + Self::AuthenticationFailed(None) 558 + } 546 559 } 547 560 } 548 561 }
+4 -4
crates/tranquil-pds/src/api/identity/account.rs
··· 1 1 use super::did::verify_did_web; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::utils::create_signed_commit; 4 - use crate::auth::{ServiceTokenVerifier, is_service_token}; 4 + use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token}; 5 5 use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 6 6 use crate::state::{AppState, RateLimitKind}; 7 7 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; ··· 96 96 .into_response(); 97 97 } 98 98 99 - let migration_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( 100 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 101 - ) { 99 + let migration_auth = if let Some(extracted) = 100 + extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 101 + { 102 102 let token = extracted.token; 103 103 if is_service_token(&token) { 104 104 let verifier = ServiceTokenVerifier::new();
+88 -102
crates/tranquil-pds/src/api/identity/did.rs
··· 1 1 use crate::api::{ApiError, DidResponse, EmptyResponse}; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::{Auth, NotTakendown}; 3 3 use crate::plc::signing_key_to_did_key; 4 4 use crate::state::AppState; 5 5 use crate::types::Handle; ··· 518 518 519 519 pub async fn get_recommended_did_credentials( 520 520 State(state): State<AppState>, 521 - auth: BearerAuthAllowDeactivated, 522 - ) -> Response { 523 - let auth_user = auth.0; 524 - let handle = match state.user_repo.get_handle_by_did(&auth_user.did).await { 525 - Ok(Some(h)) => h, 526 - Ok(None) => return ApiError::InternalError(None).into_response(), 527 - Err(_) => return ApiError::InternalError(None).into_response(), 528 - }; 529 - let key_bytes = match auth_user.key_bytes { 530 - Some(kb) => kb, 531 - None => { 532 - return ApiError::AuthenticationFailed(Some( 533 - "OAuth tokens cannot get DID credentials".into(), 534 - )) 535 - .into_response(); 536 - } 537 - }; 521 + auth: Auth<NotTakendown>, 522 + ) -> Result<Response, ApiError> { 523 + let handle = state 524 + .user_repo 525 + .get_handle_by_did(&auth.did) 526 + .await 527 + .map_err(|_| ApiError::InternalError(None))? 528 + .ok_or(ApiError::InternalError(None))?; 529 + 530 + let key_bytes = auth.key_bytes.clone().ok_or_else(|| { 531 + ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into())) 532 + })?; 533 + 538 534 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 539 535 let pds_endpoint = format!("https://{}", hostname); 540 - let signing_key = match k256::ecdsa::SigningKey::from_slice(&key_bytes) { 541 - Ok(k) => k, 542 - Err(_) => return ApiError::InternalError(None).into_response(), 543 - }; 536 + let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes) 537 + .map_err(|_| ApiError::InternalError(None))?; 544 538 let did_key = signing_key_to_did_key(&signing_key); 545 - let rotation_keys = if auth_user.did.starts_with("did:web:") { 539 + let rotation_keys = if auth.did.starts_with("did:web:") { 546 540 vec![] 547 541 } else { 548 542 let server_rotation_key = match std::env::var("PLC_ROTATION_KEY") { ··· 556 550 }; 557 551 vec![server_rotation_key] 558 552 }; 559 - ( 553 + Ok(( 560 554 StatusCode::OK, 561 555 Json(GetRecommendedDidCredentialsOutput { 562 556 rotation_keys, ··· 570 564 }, 571 565 }), 572 566 ) 573 - .into_response() 567 + .into_response()) 574 568 } 575 569 576 570 #[derive(Deserialize)] ··· 580 574 581 575 pub async fn update_handle( 582 576 State(state): State<AppState>, 583 - auth: BearerAuthAllowDeactivated, 577 + auth: Auth<NotTakendown>, 584 578 Json(input): Json<UpdateHandleInput>, 585 - ) -> Response { 586 - let auth_user = auth.0; 579 + ) -> Result<Response, ApiError> { 587 580 if let Err(e) = crate::auth::scope_check::check_identity_scope( 588 - auth_user.is_oauth, 589 - auth_user.scope.as_deref(), 581 + auth.is_oauth(), 582 + auth.scope.as_deref(), 590 583 crate::oauth::scopes::IdentityAttr::Handle, 591 584 ) { 592 - return e; 585 + return Ok(e); 593 586 } 594 - let did = auth_user.did; 587 + let did = auth.did.clone(); 595 588 if !state 596 589 .check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did) 597 590 .await 598 591 { 599 - return ApiError::RateLimitExceeded(Some( 592 + return Err(ApiError::RateLimitExceeded(Some( 600 593 "Too many handle updates. Try again later.".into(), 601 - )) 602 - .into_response(); 594 + ))); 603 595 } 604 596 if !state 605 597 .check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did) 606 598 .await 607 599 { 608 - return ApiError::RateLimitExceeded(Some("Daily handle update limit exceeded.".into())) 609 - .into_response(); 600 + return Err(ApiError::RateLimitExceeded(Some( 601 + "Daily handle update limit exceeded.".into(), 602 + ))); 610 603 } 611 - let user_row = match state.user_repo.get_id_and_handle_by_did(&did).await { 612 - Ok(Some(row)) => row, 613 - Ok(None) => return ApiError::InternalError(None).into_response(), 614 - Err(_) => return ApiError::InternalError(None).into_response(), 615 - }; 604 + let user_row = state 605 + .user_repo 606 + .get_id_and_handle_by_did(&did) 607 + .await 608 + .map_err(|_| ApiError::InternalError(None))? 609 + .ok_or(ApiError::InternalError(None))?; 616 610 let user_id = user_row.id; 617 611 let current_handle = user_row.handle; 618 612 let new_handle = input.handle.trim().to_ascii_lowercase(); 619 613 if new_handle.is_empty() { 620 - return ApiError::InvalidRequest("handle is required".into()).into_response(); 614 + return Err(ApiError::InvalidRequest("handle is required".into())); 621 615 } 622 616 if !new_handle 623 617 .chars() 624 618 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-') 625 619 { 626 - return ApiError::InvalidHandle(Some("Handle contains invalid characters".into())) 627 - .into_response(); 620 + return Err(ApiError::InvalidHandle(Some( 621 + "Handle contains invalid characters".into(), 622 + ))); 628 623 } 629 624 if new_handle.split('.').any(|segment| segment.is_empty()) { 630 - return ApiError::InvalidHandle(Some("Handle contains empty segment".into())) 631 - .into_response(); 625 + return Err(ApiError::InvalidHandle(Some( 626 + "Handle contains empty segment".into(), 627 + ))); 632 628 } 633 629 if new_handle 634 630 .split('.') 635 631 .any(|segment| segment.starts_with('-') || segment.ends_with('-')) 636 632 { 637 - return ApiError::InvalidHandle(Some( 633 + return Err(ApiError::InvalidHandle(Some( 638 634 "Handle segment cannot start or end with hyphen".into(), 639 - )) 640 - .into_response(); 635 + ))); 641 636 } 642 637 if crate::moderation::has_explicit_slur(&new_handle) { 643 - return ApiError::InvalidHandle(Some("Inappropriate language in handle".into())) 644 - .into_response(); 638 + return Err(ApiError::InvalidHandle(Some( 639 + "Inappropriate language in handle".into(), 640 + ))); 645 641 } 646 642 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 647 643 let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); ··· 667 663 { 668 664 warn!("Failed to sequence identity event for handle update: {}", e); 669 665 } 670 - return EmptyResponse::ok().into_response(); 666 + return Ok(EmptyResponse::ok().into_response()); 671 667 } 672 668 if short_part.contains('.') { 673 - return ApiError::InvalidHandle(Some( 669 + return Err(ApiError::InvalidHandle(Some( 674 670 "Nested subdomains are not allowed. Use a simple handle without dots.".into(), 675 - )) 676 - .into_response(); 671 + ))); 677 672 } 678 673 if short_part.len() < 3 { 679 - return ApiError::InvalidHandle(Some("Handle too short".into())).into_response(); 674 + return Err(ApiError::InvalidHandle(Some("Handle too short".into()))); 680 675 } 681 676 if short_part.len() > 18 { 682 - return ApiError::InvalidHandle(Some("Handle too long".into())).into_response(); 677 + return Err(ApiError::InvalidHandle(Some("Handle too long".into()))); 683 678 } 684 679 full_handle 685 680 } else { ··· 691 686 { 692 687 warn!("Failed to sequence identity event for handle update: {}", e); 693 688 } 694 - return EmptyResponse::ok().into_response(); 689 + return Ok(EmptyResponse::ok().into_response()); 695 690 } 696 691 match crate::handle::verify_handle_ownership(&new_handle, &did).await { 697 692 Ok(()) => {} 698 693 Err(crate::handle::HandleResolutionError::NotFound) => { 699 - return ApiError::HandleNotAvailable(None).into_response(); 694 + return Err(ApiError::HandleNotAvailable(None)); 700 695 } 701 696 Err(crate::handle::HandleResolutionError::DidMismatch { expected, actual }) => { 702 - return ApiError::HandleNotAvailable(Some(format!( 697 + return Err(ApiError::HandleNotAvailable(Some(format!( 703 698 "Handle points to different DID. Expected {}, got {}", 704 699 expected, actual 705 - ))) 706 - .into_response(); 700 + )))); 707 701 } 708 702 Err(e) => { 709 703 warn!("Handle verification failed: {}", e); 710 - return ApiError::HandleNotAvailable(Some(format!( 704 + return Err(ApiError::HandleNotAvailable(Some(format!( 711 705 "Handle verification failed: {}", 712 706 e 713 - ))) 714 - .into_response(); 707 + )))); 715 708 } 716 709 } 717 710 new_handle.clone() 718 711 }; 719 - let handle_typed: Handle = match handle.parse() { 720 - Ok(h) => h, 721 - Err(_) => { 722 - return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(); 723 - } 724 - }; 725 - let handle_exists = match state 712 + let handle_typed: Handle = handle 713 + .parse() 714 + .map_err(|_| ApiError::InvalidHandle(Some("Invalid handle format".into())))?; 715 + let handle_exists = state 726 716 .user_repo 727 717 .check_handle_exists(&handle_typed, user_id) 728 718 .await 729 - { 730 - Ok(exists) => exists, 731 - Err(_) => return ApiError::InternalError(None).into_response(), 732 - }; 719 + .map_err(|_| ApiError::InternalError(None))?; 733 720 if handle_exists { 734 - return ApiError::HandleTaken.into_response(); 721 + return Err(ApiError::HandleTaken); 735 722 } 736 - let result = state.user_repo.update_handle(user_id, &handle_typed).await; 737 - match result { 738 - Ok(_) => { 739 - if !current_handle.is_empty() { 740 - let _ = state 741 - .cache 742 - .delete(&format!("handle:{}", current_handle)) 743 - .await; 744 - } 745 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 746 - if let Err(e) = 747 - crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 748 - .await 749 - { 750 - warn!("Failed to sequence identity event for handle update: {}", e); 751 - } 752 - if let Err(e) = update_plc_handle(&state, &did, &handle_typed).await { 753 - warn!("Failed to update PLC handle: {}", e); 754 - } 755 - EmptyResponse::ok().into_response() 756 - } 757 - Err(e) => { 723 + state 724 + .user_repo 725 + .update_handle(user_id, &handle_typed) 726 + .await 727 + .map_err(|e| { 758 728 error!("DB error updating handle: {:?}", e); 759 - ApiError::InternalError(None).into_response() 760 - } 729 + ApiError::InternalError(None) 730 + })?; 731 + 732 + if !current_handle.is_empty() { 733 + let _ = state 734 + .cache 735 + .delete(&format!("handle:{}", current_handle)) 736 + .await; 737 + } 738 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 739 + if let Err(e) = 740 + crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)).await 741 + { 742 + warn!("Failed to sequence identity event for handle update: {}", e); 743 + } 744 + if let Err(e) = update_plc_handle(&state, &did, &handle_typed).await { 745 + warn!("Failed to update PLC handle: {}", e); 761 746 } 747 + Ok(EmptyResponse::ok().into_response()) 762 748 } 763 749 764 750 pub async fn update_plc_handle(
+23 -24
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuthAllowDeactivated; 3 + use crate::auth::{Auth, Permissive}; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 extract::State, ··· 15 15 16 16 pub async fn request_plc_operation_signature( 17 17 State(state): State<AppState>, 18 - auth: BearerAuthAllowDeactivated, 19 - ) -> Response { 20 - let auth_user = auth.0; 18 + auth: Auth<Permissive>, 19 + ) -> Result<Response, ApiError> { 21 20 if let Err(e) = crate::auth::scope_check::check_identity_scope( 22 - auth_user.is_oauth, 23 - auth_user.scope.as_deref(), 21 + auth.is_oauth(), 22 + auth.scope.as_deref(), 24 23 crate::oauth::scopes::IdentityAttr::Wildcard, 25 24 ) { 26 - return e; 25 + return Ok(e); 27 26 } 28 - let user_id = match state.user_repo.get_id_by_did(&auth_user.did).await { 29 - Ok(Some(id)) => id, 30 - Ok(None) => return ApiError::AccountNotFound.into_response(), 31 - Err(e) => { 27 + let user_id = state 28 + .user_repo 29 + .get_id_by_did(&auth.did) 30 + .await 31 + .map_err(|e| { 32 32 error!("DB error: {:?}", e); 33 - return ApiError::InternalError(None).into_response(); 34 - } 35 - }; 33 + ApiError::InternalError(None) 34 + })? 35 + .ok_or(ApiError::AccountNotFound)?; 36 + 36 37 let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await; 37 38 let plc_token = generate_plc_token(); 38 39 let expires_at = Utc::now() + Duration::minutes(10); 39 - if let Err(e) = state 40 + state 40 41 .infra_repo 41 42 .insert_plc_token(user_id, &plc_token, expires_at) 42 43 .await 43 - { 44 - error!("Failed to create PLC token: {:?}", e); 45 - return ApiError::InternalError(None).into_response(); 46 - } 44 + .map_err(|e| { 45 + error!("Failed to create PLC token: {:?}", e); 46 + ApiError::InternalError(None) 47 + })?; 48 + 47 49 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 48 50 if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation( 49 51 state.user_repo.as_ref(), ··· 56 58 { 57 59 warn!("Failed to enqueue PLC operation notification: {:?}", e); 58 60 } 59 - info!( 60 - "PLC operation signature requested for user {}", 61 - auth_user.did 62 - ); 63 - EmptyResponse::ok().into_response() 61 + info!("PLC operation signature requested for user {}", auth.did); 62 + Ok(EmptyResponse::ok().into_response()) 64 63 }
+70 -85
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::{Auth, Permissive}; 3 3 use crate::circuit_breaker::with_circuit_breaker; 4 4 use crate::plc::{PlcClient, PlcError, PlcService, create_update_op, sign_operation}; 5 5 use crate::state::AppState; ··· 40 40 41 41 pub async fn sign_plc_operation( 42 42 State(state): State<AppState>, 43 - auth: BearerAuthAllowDeactivated, 43 + auth: Auth<Permissive>, 44 44 Json(input): Json<SignPlcOperationInput>, 45 - ) -> Response { 46 - let auth_user = auth.0; 45 + ) -> Result<Response, ApiError> { 47 46 if let Err(e) = crate::auth::scope_check::check_identity_scope( 48 - auth_user.is_oauth, 49 - auth_user.scope.as_deref(), 47 + auth.is_oauth(), 48 + auth.scope.as_deref(), 50 49 crate::oauth::scopes::IdentityAttr::Wildcard, 51 50 ) { 52 - return e; 51 + return Ok(e); 53 52 } 54 - let did = &auth_user.did; 53 + let did = &auth.did; 55 54 if did.starts_with("did:web:") { 56 - return ApiError::InvalidRequest( 55 + return Err(ApiError::InvalidRequest( 57 56 "PLC operations are only valid for did:plc identities".into(), 58 - ) 59 - .into_response(); 57 + )); 60 58 } 61 - let token = match &input.token { 62 - Some(t) => t, 63 - None => { 64 - return ApiError::InvalidRequest( 65 - "Email confirmation token required to sign PLC operations".into(), 66 - ) 67 - .into_response(); 68 - } 69 - }; 70 - let user_id = match state.user_repo.get_id_by_did(did).await { 71 - Ok(Some(id)) => id, 72 - Ok(None) => return ApiError::AccountNotFound.into_response(), 73 - Err(e) => { 59 + let token = input.token.as_ref().ok_or_else(|| { 60 + ApiError::InvalidRequest("Email confirmation token required to sign PLC operations".into()) 61 + })?; 62 + 63 + let user_id = state 64 + .user_repo 65 + .get_id_by_did(did) 66 + .await 67 + .map_err(|e| { 74 68 error!("DB error: {:?}", e); 75 - return ApiError::InternalError(None).into_response(); 76 - } 77 - }; 78 - let token_expiry = match state.infra_repo.get_plc_token_expiry(user_id, token).await { 79 - Ok(Some(expiry)) => expiry, 80 - Ok(None) => { 81 - return ApiError::InvalidToken(Some("Invalid or expired token".into())).into_response(); 82 - } 83 - Err(e) => { 69 + ApiError::InternalError(None) 70 + })? 71 + .ok_or(ApiError::AccountNotFound)?; 72 + 73 + let token_expiry = state 74 + .infra_repo 75 + .get_plc_token_expiry(user_id, token) 76 + .await 77 + .map_err(|e| { 84 78 error!("DB error: {:?}", e); 85 - return ApiError::InternalError(None).into_response(); 86 - } 87 - }; 79 + ApiError::InternalError(None) 80 + })? 81 + .ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?; 82 + 88 83 if Utc::now() > token_expiry { 89 84 let _ = state.infra_repo.delete_plc_token(user_id, token).await; 90 - return ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 85 + return Err(ApiError::ExpiredToken(Some("Token has expired".into()))); 91 86 } 92 - let key_row = match state.user_repo.get_user_key_by_id(user_id).await { 93 - Ok(Some(row)) => row, 94 - Ok(None) => { 95 - return ApiError::InternalError(Some("User signing key not found".into())) 96 - .into_response(); 97 - } 98 - Err(e) => { 87 + let key_row = state 88 + .user_repo 89 + .get_user_key_by_id(user_id) 90 + .await 91 + .map_err(|e| { 99 92 error!("DB error: {:?}", e); 100 - return ApiError::InternalError(None).into_response(); 101 - } 102 - }; 103 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 104 - { 105 - Ok(k) => k, 106 - Err(e) => { 93 + ApiError::InternalError(None) 94 + })? 95 + .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 96 + 97 + let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 98 + .map_err(|e| { 107 99 error!("Failed to decrypt user key: {}", e); 108 - return ApiError::InternalError(None).into_response(); 109 - } 110 - }; 111 - let signing_key = match SigningKey::from_slice(&key_bytes) { 112 - Ok(k) => k, 113 - Err(e) => { 114 - error!("Failed to create signing key: {:?}", e); 115 - return ApiError::InternalError(None).into_response(); 116 - } 117 - }; 100 + ApiError::InternalError(None) 101 + })?; 102 + 103 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 104 + error!("Failed to create signing key: {:?}", e); 105 + ApiError::InternalError(None) 106 + })?; 107 + 118 108 let plc_client = PlcClient::with_cache(None, Some(state.cache.clone())); 119 109 let did_clone = did.clone(); 120 - let last_op = match with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 110 + let last_op = with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 121 111 plc_client.get_last_op(&did_clone).await 122 112 }) 123 113 .await 124 - { 125 - Ok(op) => op, 126 - Err(e) => return ApiError::from(e).into_response(), 127 - }; 114 + .map_err(ApiError::from)?; 115 + 128 116 if last_op.is_tombstone() { 129 - return ApiError::from(PlcError::Tombstoned).into_response(); 117 + return Err(ApiError::from(PlcError::Tombstoned)); 130 118 } 131 119 let services = input.services.map(|s| { 132 120 s.into_iter() ··· 141 129 }) 142 130 .collect() 143 131 }); 144 - let unsigned_op = match create_update_op( 132 + let unsigned_op = create_update_op( 145 133 &last_op, 146 134 input.rotation_keys, 147 135 input.verification_methods, 148 136 input.also_known_as, 149 137 services, 150 - ) { 151 - Ok(op) => op, 152 - Err(PlcError::Tombstoned) => { 153 - return ApiError::InvalidRequest("Cannot update tombstoned DID".into()).into_response(); 154 - } 155 - Err(e) => { 138 + ) 139 + .map_err(|e| match e { 140 + PlcError::Tombstoned => ApiError::InvalidRequest("Cannot update tombstoned DID".into()), 141 + _ => { 156 142 error!("Failed to create PLC operation: {:?}", e); 157 - return ApiError::InternalError(None).into_response(); 158 - } 159 - }; 160 - let signed_op = match sign_operation(&unsigned_op, &signing_key) { 161 - Ok(op) => op, 162 - Err(e) => { 163 - error!("Failed to sign PLC operation: {:?}", e); 164 - return ApiError::InternalError(None).into_response(); 143 + ApiError::InternalError(None) 165 144 } 166 - }; 145 + })?; 146 + 147 + let signed_op = sign_operation(&unsigned_op, &signing_key).map_err(|e| { 148 + error!("Failed to sign PLC operation: {:?}", e); 149 + ApiError::InternalError(None) 150 + })?; 151 + 167 152 let _ = state.infra_repo.delete_plc_token(user_id, token).await; 168 153 info!("Signed PLC operation for user {}", did); 169 - ( 154 + Ok(( 170 155 StatusCode::OK, 171 156 Json(SignPlcOperationOutput { 172 157 operation: signed_op, 173 158 }), 174 159 ) 175 - .into_response() 160 + .into_response()) 176 161 }
+58 -61
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 1 1 use crate::api::{ApiError, EmptyResponse}; 2 - use crate::auth::BearerAuthAllowDeactivated; 2 + use crate::auth::{Auth, Permissive}; 3 3 use crate::circuit_breaker::with_circuit_breaker; 4 4 use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation}; 5 5 use crate::state::AppState; ··· 20 20 21 21 pub async fn submit_plc_operation( 22 22 State(state): State<AppState>, 23 - auth: BearerAuthAllowDeactivated, 23 + auth: Auth<Permissive>, 24 24 Json(input): Json<SubmitPlcOperationInput>, 25 - ) -> Response { 26 - let auth_user = auth.0; 25 + ) -> Result<Response, ApiError> { 27 26 if let Err(e) = crate::auth::scope_check::check_identity_scope( 28 - auth_user.is_oauth, 29 - auth_user.scope.as_deref(), 27 + auth.is_oauth(), 28 + auth.scope.as_deref(), 30 29 crate::oauth::scopes::IdentityAttr::Wildcard, 31 30 ) { 32 - return e; 31 + return Ok(e); 33 32 } 34 - let did = &auth_user.did; 33 + let did = &auth.did; 35 34 if did.starts_with("did:web:") { 36 - return ApiError::InvalidRequest( 35 + return Err(ApiError::InvalidRequest( 37 36 "PLC operations are only valid for did:plc identities".into(), 38 - ) 39 - .into_response(); 37 + )); 40 38 } 41 - if let Err(e) = validate_plc_operation(&input.operation) { 42 - return ApiError::InvalidRequest(format!("Invalid operation: {}", e)).into_response(); 43 - } 39 + validate_plc_operation(&input.operation) 40 + .map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?; 41 + 44 42 let op = &input.operation; 45 43 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 46 44 let public_url = format!("https://{}", hostname); 47 - let user = match state.user_repo.get_id_and_handle_by_did(did).await { 48 - Ok(Some(u)) => u, 49 - Ok(None) => return ApiError::AccountNotFound.into_response(), 50 - Err(e) => { 45 + let user = state 46 + .user_repo 47 + .get_id_and_handle_by_did(did) 48 + .await 49 + .map_err(|e| { 51 50 error!("DB error: {:?}", e); 52 - return ApiError::InternalError(None).into_response(); 53 - } 54 - }; 55 - let key_row = match state.user_repo.get_user_key_by_id(user.id).await { 56 - Ok(Some(row)) => row, 57 - Ok(None) => { 58 - return ApiError::InternalError(Some("User signing key not found".into())) 59 - .into_response(); 60 - } 61 - Err(e) => { 51 + ApiError::InternalError(None) 52 + })? 53 + .ok_or(ApiError::AccountNotFound)?; 54 + 55 + let key_row = state 56 + .user_repo 57 + .get_user_key_by_id(user.id) 58 + .await 59 + .map_err(|e| { 62 60 error!("DB error: {:?}", e); 63 - return ApiError::InternalError(None).into_response(); 64 - } 65 - }; 66 - let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 67 - { 68 - Ok(k) => k, 69 - Err(e) => { 61 + ApiError::InternalError(None) 62 + })? 63 + .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 64 + 65 + let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 66 + .map_err(|e| { 70 67 error!("Failed to decrypt user key: {}", e); 71 - return ApiError::InternalError(None).into_response(); 72 - } 73 - }; 74 - let signing_key = match SigningKey::from_slice(&key_bytes) { 75 - Ok(k) => k, 76 - Err(e) => { 77 - error!("Failed to create signing key: {:?}", e); 78 - return ApiError::InternalError(None).into_response(); 79 - } 80 - }; 68 + ApiError::InternalError(None) 69 + })?; 70 + 71 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 72 + error!("Failed to create signing key: {:?}", e); 73 + ApiError::InternalError(None) 74 + })?; 75 + 81 76 let user_did_key = signing_key_to_did_key(&signing_key); 82 77 let server_rotation_key = 83 78 std::env::var("PLC_ROTATION_KEY").unwrap_or_else(|_| user_did_key.clone()); ··· 86 81 .iter() 87 82 .any(|k| k.as_str() == Some(&server_rotation_key)); 88 83 if !has_server_key { 89 - return ApiError::InvalidRequest( 84 + return Err(ApiError::InvalidRequest( 90 85 "Rotation keys do not include server's rotation key".into(), 91 - ) 92 - .into_response(); 86 + )); 93 87 } 94 88 } 95 89 if let Some(services) = op.get("services").and_then(|v| v.as_object()) ··· 98 92 let service_type = pds.get("type").and_then(|v| v.as_str()); 99 93 let endpoint = pds.get("endpoint").and_then(|v| v.as_str()); 100 94 if service_type != Some("AtprotoPersonalDataServer") { 101 - return ApiError::InvalidRequest("Incorrect type on atproto_pds service".into()) 102 - .into_response(); 95 + return Err(ApiError::InvalidRequest( 96 + "Incorrect type on atproto_pds service".into(), 97 + )); 103 98 } 104 99 if endpoint != Some(&public_url) { 105 - return ApiError::InvalidRequest("Incorrect endpoint on atproto_pds service".into()) 106 - .into_response(); 100 + return Err(ApiError::InvalidRequest( 101 + "Incorrect endpoint on atproto_pds service".into(), 102 + )); 107 103 } 108 104 } 109 105 if let Some(verification_methods) = op.get("verificationMethods").and_then(|v| v.as_object()) 110 106 && let Some(atproto_key) = verification_methods.get("atproto").and_then(|v| v.as_str()) 111 107 && atproto_key != user_did_key 112 108 { 113 - return ApiError::InvalidRequest("Incorrect signing key in verificationMethods".into()) 114 - .into_response(); 109 + return Err(ApiError::InvalidRequest( 110 + "Incorrect signing key in verificationMethods".into(), 111 + )); 115 112 } 116 113 if let Some(also_known_as) = (!user.handle.is_empty()) 117 114 .then(|| op.get("alsoKnownAs").and_then(|v| v.as_array())) ··· 120 117 let expected_handle = format!("at://{}", user.handle); 121 118 let first_aka = also_known_as.first().and_then(|v| v.as_str()); 122 119 if first_aka != Some(&expected_handle) { 123 - return ApiError::InvalidRequest("Incorrect handle in alsoKnownAs".into()) 124 - .into_response(); 120 + return Err(ApiError::InvalidRequest( 121 + "Incorrect handle in alsoKnownAs".into(), 122 + )); 125 123 } 126 124 } 127 125 let plc_client = PlcClient::with_cache(None, Some(state.cache.clone())); 128 126 let operation_clone = input.operation.clone(); 129 127 let did_clone = did.clone(); 130 - if let Err(e) = with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 128 + with_circuit_breaker(&state.circuit_breakers.plc_directory, || async { 131 129 plc_client 132 130 .send_operation(&did_clone, &operation_clone) 133 131 .await 134 132 }) 135 133 .await 136 - { 137 - return ApiError::from(e).into_response(); 138 - } 134 + .map_err(ApiError::from)?; 135 + 139 136 match state 140 137 .repo_repo 141 138 .insert_identity_event(did, Some(&user.handle)) ··· 157 154 warn!(did = %did, "Failed to refresh DID cache after PLC update"); 158 155 } 159 156 info!(did = %did, "PLC operation submitted successfully"); 160 - EmptyResponse::ok().into_response() 157 + Ok(EmptyResponse::ok().into_response()) 161 158 }
+5 -7
crates/tranquil-pds/src/api/moderation/mod.rs
··· 1 1 use crate::api::ApiError; 2 2 use crate::api::proxy_client::{is_ssrf_safe, proxy_client}; 3 - use crate::auth::extractor::BearerAuthAllowTakendown; 3 + use crate::auth::{AnyUser, Auth}; 4 4 use crate::state::AppState; 5 5 use axum::{ 6 6 Json, ··· 42 42 43 43 pub async fn create_report( 44 44 State(state): State<AppState>, 45 - auth: BearerAuthAllowTakendown, 45 + auth: Auth<AnyUser>, 46 46 Json(input): Json<CreateReportInput>, 47 47 ) -> Response { 48 - let auth_user = auth.0; 49 - let did = &auth_user.did; 48 + let did = &auth.did; 50 49 51 50 if let Some((service_url, service_did)) = get_report_service_config() { 52 - return proxy_to_report_service(&state, &auth_user, &service_url, &service_did, &input) 53 - .await; 51 + return proxy_to_report_service(&state, &auth, &service_url, &service_did, &input).await; 54 52 } 55 53 56 - create_report_locally(&state, did, auth_user.is_takendown(), input).await 54 + create_report_locally(&state, did, auth.status.is_takendown(), input).await 57 55 } 58 56 59 57 async fn proxy_to_report_service(
+82 -98
crates/tranquil-pds/src/api/notification_prefs.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::BearerAuth; 2 + use crate::auth::{Active, Auth}; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, ··· 23 23 pub signal_verified: bool, 24 24 } 25 25 26 - pub async fn get_notification_prefs(State(state): State<AppState>, auth: BearerAuth) -> Response { 27 - let user = auth.0; 28 - let prefs = match state.user_repo.get_notification_prefs(&user.did).await { 29 - Ok(Some(p)) => p, 30 - Ok(None) => return ApiError::AccountNotFound.into_response(), 31 - Err(e) => { 32 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 33 - } 34 - }; 35 - Json(NotificationPrefsResponse { 26 + pub async fn get_notification_prefs( 27 + State(state): State<AppState>, 28 + auth: Auth<Active>, 29 + ) -> Result<Response, ApiError> { 30 + let prefs = state 31 + .user_repo 32 + .get_notification_prefs(&auth.did) 33 + .await 34 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 35 + .ok_or(ApiError::AccountNotFound)?; 36 + Ok(Json(NotificationPrefsResponse { 36 37 preferred_channel: prefs.preferred_channel, 37 38 email: prefs.email, 38 39 discord_id: prefs.discord_id, ··· 42 43 signal_number: prefs.signal_number, 43 44 signal_verified: prefs.signal_verified, 44 45 }) 45 - .into_response() 46 + .into_response()) 46 47 } 47 48 48 49 #[derive(Serialize)] ··· 62 63 pub notifications: Vec<NotificationHistoryEntry>, 63 64 } 64 65 65 - pub async fn get_notification_history(State(state): State<AppState>, auth: BearerAuth) -> Response { 66 - let user = auth.0; 67 - 68 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&user.did).await { 69 - Ok(Some(id)) => id, 70 - Ok(None) => return ApiError::AccountNotFound.into_response(), 71 - Err(e) => { 72 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 73 - } 74 - }; 66 + pub async fn get_notification_history( 67 + State(state): State<AppState>, 68 + auth: Auth<Active>, 69 + ) -> Result<Response, ApiError> { 70 + let user_id = state 71 + .user_repo 72 + .get_id_by_did(&auth.did) 73 + .await 74 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 75 + .ok_or(ApiError::AccountNotFound)?; 75 76 76 - let rows = match state.infra_repo.get_notification_history(user_id, 50).await { 77 - Ok(r) => r, 78 - Err(e) => { 79 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 80 - } 81 - }; 77 + let rows = state 78 + .infra_repo 79 + .get_notification_history(user_id, 50) 80 + .await 81 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 82 82 83 83 let sensitive_types = [ 84 84 "email_verification", ··· 111 111 }) 112 112 .collect(); 113 113 114 - Json(GetNotificationHistoryResponse { notifications }).into_response() 114 + Ok(Json(GetNotificationHistoryResponse { notifications }).into_response()) 115 115 } 116 116 117 117 #[derive(Deserialize)] ··· 184 184 185 185 pub async fn update_notification_prefs( 186 186 State(state): State<AppState>, 187 - auth: BearerAuth, 187 + auth: Auth<Active>, 188 188 Json(input): Json<UpdateNotificationPrefsInput>, 189 - ) -> Response { 190 - let user = auth.0; 191 - 192 - let user_row = match state.user_repo.get_id_handle_email_by_did(&user.did).await { 193 - Ok(Some(row)) => row, 194 - Ok(None) => return ApiError::AccountNotFound.into_response(), 195 - Err(e) => { 196 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 197 - } 198 - }; 189 + ) -> Result<Response, ApiError> { 190 + let user_row = state 191 + .user_repo 192 + .get_id_handle_email_by_did(&auth.did) 193 + .await 194 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))? 195 + .ok_or(ApiError::AccountNotFound)?; 199 196 200 197 let user_id = user_row.id; 201 198 let handle = user_row.handle; ··· 206 203 if let Some(ref channel) = input.preferred_channel { 207 204 let valid_channels = ["email", "discord", "telegram", "signal"]; 208 205 if !valid_channels.contains(&channel.as_str()) { 209 - return ApiError::InvalidRequest( 206 + return Err(ApiError::InvalidRequest( 210 207 "Invalid channel. Must be one of: email, discord, telegram, signal".into(), 211 - ) 212 - .into_response(); 208 + )); 213 209 } 214 - if let Err(e) = state 210 + state 215 211 .user_repo 216 - .update_preferred_comms_channel(&user.did, channel) 212 + .update_preferred_comms_channel(&auth.did, channel) 217 213 .await 218 - { 219 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 220 - } 221 - info!(did = %user.did, channel = %channel, "Updated preferred notification channel"); 214 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 215 + info!(did = %auth.did, channel = %channel, "Updated preferred notification channel"); 222 216 } 223 217 224 218 if let Some(ref new_email) = input.email { 225 219 let email_clean = new_email.trim().to_lowercase(); 226 220 if email_clean.is_empty() { 227 - return ApiError::InvalidRequest("Email cannot be empty".into()).into_response(); 221 + return Err(ApiError::InvalidRequest("Email cannot be empty".into())); 228 222 } 229 223 230 224 if !crate::api::validation::is_valid_email(&email_clean) { 231 - return ApiError::InvalidEmail.into_response(); 225 + return Err(ApiError::InvalidEmail); 232 226 } 233 227 234 - if current_email.as_ref().map(|e| e.to_lowercase()) == Some(email_clean.clone()) { 235 - info!(did = %user.did, "Email unchanged, skipping"); 236 - } else { 237 - if let Err(e) = request_channel_verification( 228 + if current_email.as_ref().map(|e| e.to_lowercase()) != Some(email_clean.clone()) { 229 + request_channel_verification( 238 230 &state, 239 231 user_id, 240 - &user.did, 232 + &auth.did, 241 233 "email", 242 234 &email_clean, 243 235 Some(&handle), 244 236 ) 245 237 .await 246 - { 247 - return ApiError::InternalError(Some(e)).into_response(); 248 - } 238 + .map_err(|e| ApiError::InternalError(Some(e)))?; 249 239 verification_required.push("email".to_string()); 250 - info!(did = %user.did, "Requested email verification"); 240 + info!(did = %auth.did, "Requested email verification"); 251 241 } 252 242 } 253 243 254 244 if let Some(ref discord_id) = input.discord_id { 255 245 if discord_id.is_empty() { 256 - if let Err(e) = state.user_repo.clear_discord(user_id).await { 257 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 258 - .into_response(); 259 - } 260 - info!(did = %user.did, "Cleared Discord ID"); 246 + state 247 + .user_repo 248 + .clear_discord(user_id) 249 + .await 250 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 251 + info!(did = %auth.did, "Cleared Discord ID"); 261 252 } else { 262 - if let Err(e) = request_channel_verification( 263 - &state, user_id, &user.did, "discord", discord_id, None, 264 - ) 265 - .await 266 - { 267 - return ApiError::InternalError(Some(e)).into_response(); 268 - } 253 + request_channel_verification(&state, user_id, &auth.did, "discord", discord_id, None) 254 + .await 255 + .map_err(|e| ApiError::InternalError(Some(e)))?; 269 256 verification_required.push("discord".to_string()); 270 - info!(did = %user.did, "Requested Discord verification"); 257 + info!(did = %auth.did, "Requested Discord verification"); 271 258 } 272 259 } 273 260 274 261 if let Some(ref telegram) = input.telegram_username { 275 262 let telegram_clean = telegram.trim_start_matches('@'); 276 263 if telegram_clean.is_empty() { 277 - if let Err(e) = state.user_repo.clear_telegram(user_id).await { 278 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 279 - .into_response(); 280 - } 281 - info!(did = %user.did, "Cleared Telegram username"); 264 + state 265 + .user_repo 266 + .clear_telegram(user_id) 267 + .await 268 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 269 + info!(did = %auth.did, "Cleared Telegram username"); 282 270 } else { 283 - if let Err(e) = request_channel_verification( 271 + request_channel_verification( 284 272 &state, 285 273 user_id, 286 - &user.did, 274 + &auth.did, 287 275 "telegram", 288 276 telegram_clean, 289 277 None, 290 278 ) 291 279 .await 292 - { 293 - return ApiError::InternalError(Some(e)).into_response(); 294 - } 280 + .map_err(|e| ApiError::InternalError(Some(e)))?; 295 281 verification_required.push("telegram".to_string()); 296 - info!(did = %user.did, "Requested Telegram verification"); 282 + info!(did = %auth.did, "Requested Telegram verification"); 297 283 } 298 284 } 299 285 300 286 if let Some(ref signal) = input.signal_number { 301 287 if signal.is_empty() { 302 - if let Err(e) = state.user_repo.clear_signal(user_id).await { 303 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 304 - .into_response(); 305 - } 306 - info!(did = %user.did, "Cleared Signal number"); 288 + state 289 + .user_repo 290 + .clear_signal(user_id) 291 + .await 292 + .map_err(|e| ApiError::InternalError(Some(format!("Database error: {}", e))))?; 293 + info!(did = %auth.did, "Cleared Signal number"); 307 294 } else { 308 - if let Err(e) = 309 - request_channel_verification(&state, user_id, &user.did, "signal", signal, None) 310 - .await 311 - { 312 - return ApiError::InternalError(Some(e)).into_response(); 313 - } 295 + request_channel_verification(&state, user_id, &auth.did, "signal", signal, None) 296 + .await 297 + .map_err(|e| ApiError::InternalError(Some(e)))?; 314 298 verification_required.push("signal".to_string()); 315 - info!(did = %user.did, "Requested Signal verification"); 299 + info!(did = %auth.did, "Requested Signal verification"); 316 300 } 317 301 } 318 302 319 - Json(UpdateNotificationPrefsResponse { 303 + Ok(Json(UpdateNotificationPrefsResponse { 320 304 success: true, 321 305 verification_required, 322 306 }) 323 - .into_response() 307 + .into_response()) 324 308 }
+13 -4
crates/tranquil-pds/src/api/proxy.rs
··· 238 238 { 239 239 Ok(auth_user) => { 240 240 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 241 - auth_user.is_oauth, 241 + auth_user.is_oauth(), 242 242 auth_user.scope.as_deref(), 243 243 &resolved.did, 244 244 method, ··· 267 267 } 268 268 } 269 269 Err(e) => { 270 - warn!("Token validation failed: {:?}", e); 271 - if matches!(e, crate::auth::TokenValidationError::OAuthTokenExpired) { 272 - return ApiError::from(e).into_response(); 270 + info!(error = ?e, "Proxy token validation failed, returning error to client"); 271 + if matches!( 272 + e, 273 + crate::auth::TokenValidationError::OAuthTokenExpired 274 + | crate::auth::TokenValidationError::TokenExpired 275 + ) { 276 + let mut response = ApiError::from(e).into_response(); 277 + let nonce = crate::oauth::verify::generate_dpop_nonce(); 278 + if let Ok(nonce_val) = nonce.parse() { 279 + response.headers_mut().insert("DPoP-Nonce", nonce_val); 280 + } 281 + return response; 273 282 } 274 283 } 275 284 }
+62 -120
crates/tranquil-pds/src/api/repo/blob.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::{BearerAuthAllowDeactivated, ServiceTokenVerifier, is_service_token}; 2 + use crate::auth::{Auth, AuthAny, NotTakendown, Permissive}; 3 3 use crate::delegation::DelegationActionType; 4 4 use crate::state::AppState; 5 5 use crate::types::{CidLink, Did}; ··· 44 44 pub async fn upload_blob( 45 45 State(state): State<AppState>, 46 46 headers: axum::http::HeaderMap, 47 + auth: AuthAny<Permissive>, 47 48 body: Body, 48 - ) -> Response { 49 - let extracted = match crate::auth::extract_auth_token_from_header( 50 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 51 - ) { 52 - Some(t) => t, 53 - None => return ApiError::AuthenticationRequired.into_response(), 54 - }; 55 - let token = extracted.token; 56 - 57 - let is_service_auth = is_service_token(&token); 58 - 59 - let (did, _is_migration, controller_did): (Did, bool, Option<Did>) = if is_service_auth { 60 - debug!("Verifying service token for blob upload"); 61 - let verifier = ServiceTokenVerifier::new(); 62 - match verifier 63 - .verify_service_token(&token, Some("com.atproto.repo.uploadBlob")) 64 - .await 65 - { 66 - Ok(claims) => { 67 - debug!("Service token verified for DID: {}", claims.iss); 68 - let did: Did = match claims.iss.parse() { 69 - Ok(d) => d, 70 - Err(_) => { 71 - return ApiError::InvalidDid("Invalid DID format".into()).into_response(); 72 - } 73 - }; 74 - (did, false, None) 75 - } 76 - Err(e) => { 77 - error!("Service token verification failed: {:?}", e); 78 - return ApiError::AuthenticationFailed(Some(format!( 79 - "Service token verification failed: {}", 80 - e 81 - ))) 82 - .into_response(); 83 - } 49 + ) -> Result<Response, ApiError> { 50 + let (did, controller_did): (Did, Option<Did>) = match &auth { 51 + AuthAny::Service(service) => { 52 + service.require_lxm("com.atproto.repo.uploadBlob")?; 53 + (service.did.clone(), None) 84 54 } 85 - } else { 86 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 87 - let http_uri = format!( 88 - "https://{}/xrpc/com.atproto.repo.uploadBlob", 89 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 90 - ); 91 - match crate::auth::validate_token_with_dpop( 92 - state.user_repo.as_ref(), 93 - state.oauth_repo.as_ref(), 94 - &token, 95 - extracted.is_dpop, 96 - dpop_proof, 97 - "POST", 98 - &http_uri, 99 - true, 100 - false, 101 - ) 102 - .await 103 - { 104 - Ok(user) => { 105 - let mime_type_for_check = headers 106 - .get("content-type") 107 - .and_then(|h| h.to_str().ok()) 108 - .unwrap_or("application/octet-stream"); 109 - if let Err(e) = crate::auth::scope_check::check_blob_scope( 110 - user.is_oauth, 111 - user.scope.as_deref(), 112 - mime_type_for_check, 113 - ) { 114 - return e; 115 - } 116 - let deactivated = state 117 - .user_repo 118 - .get_status_by_did(&user.did) 119 - .await 120 - .ok() 121 - .flatten() 122 - .and_then(|s| s.deactivated_at); 123 - let ctrl_did = user.controller_did.clone(); 124 - (user.did, deactivated.is_some(), ctrl_did) 55 + AuthAny::User(user) => { 56 + if user.status.is_takendown() { 57 + return Err(ApiError::AccountTakedown); 125 58 } 126 - Err(_) => { 127 - return ApiError::AuthenticationFailed(None).into_response(); 59 + let mime_type_for_check = headers 60 + .get("content-type") 61 + .and_then(|h| h.to_str().ok()) 62 + .unwrap_or("application/octet-stream"); 63 + if let Err(e) = crate::auth::scope_check::check_blob_scope( 64 + user.is_oauth(), 65 + user.scope.as_deref(), 66 + mime_type_for_check, 67 + ) { 68 + return Ok(e); 128 69 } 70 + (user.did.clone(), user.controller_did.clone()) 129 71 } 130 72 }; 131 73 ··· 135 77 .await 136 78 .unwrap_or(false) 137 79 { 138 - return ApiError::Forbidden.into_response(); 80 + return Err(ApiError::Forbidden); 139 81 } 140 82 141 83 let client_mime_hint = headers ··· 143 85 .and_then(|h| h.to_str().ok()) 144 86 .unwrap_or("application/octet-stream"); 145 87 146 - let user_id = match state.user_repo.get_id_by_did(&did).await { 147 - Ok(Some(id)) => id, 148 - _ => { 149 - return ApiError::InternalError(None).into_response(); 150 - } 151 - }; 88 + let user_id = state 89 + .user_repo 90 + .get_id_by_did(&did) 91 + .await 92 + .ok() 93 + .flatten() 94 + .ok_or(ApiError::InternalError(None))?; 152 95 153 96 let temp_key = format!("temp/{}", uuid::Uuid::new_v4()); 154 97 let max_size = get_max_blob_size() as u64; ··· 161 104 162 105 info!("Starting streaming blob upload to temp key: {}", temp_key); 163 106 164 - let upload_result = match state.blob_store.put_stream(&temp_key, pinned_stream).await { 165 - Ok(result) => result, 166 - Err(e) => { 107 + let upload_result = state 108 + .blob_store 109 + .put_stream(&temp_key, pinned_stream) 110 + .await 111 + .map_err(|e| { 167 112 error!("Failed to stream blob to storage: {:?}", e); 168 - return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 169 - } 170 - }; 113 + ApiError::InternalError(Some("Failed to store blob".into())) 114 + })?; 171 115 172 116 let size = upload_result.size; 173 117 if size > max_size { 174 118 let _ = state.blob_store.delete(&temp_key).await; 175 - return ApiError::InvalidRequest(format!( 119 + return Err(ApiError::InvalidRequest(format!( 176 120 "Blob size {} exceeds maximum of {} bytes", 177 121 size, max_size 178 - )) 179 - .into_response(); 122 + ))); 180 123 } 181 124 182 125 let mime_type = match state.blob_store.get_head(&temp_key, 8192).await { ··· 192 135 Err(e) => { 193 136 let _ = state.blob_store.delete(&temp_key).await; 194 137 error!("Failed to create multihash for blob: {:?}", e); 195 - return ApiError::InternalError(Some("Failed to hash blob".into())).into_response(); 138 + return Err(ApiError::InternalError(Some("Failed to hash blob".into()))); 196 139 } 197 140 }; 198 141 let cid = Cid::new_v1(0x55, multihash); ··· 215 158 Err(e) => { 216 159 let _ = state.blob_store.delete(&temp_key).await; 217 160 error!("Failed to insert blob record: {:?}", e); 218 - return ApiError::InternalError(None).into_response(); 161 + return Err(ApiError::InternalError(None)); 219 162 } 220 163 }; 221 164 222 165 if was_inserted && let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { 223 166 let _ = state.blob_store.delete(&temp_key).await; 224 167 error!("Failed to copy blob to final location: {:?}", e); 225 - return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 168 + return Err(ApiError::InternalError(Some("Failed to store blob".into()))); 226 169 } 227 170 228 171 let _ = state.blob_store.delete(&temp_key).await; ··· 246 189 .await; 247 190 } 248 191 249 - Json(json!({ 192 + Ok(Json(json!({ 250 193 "blob": { 251 194 "$type": "blob", 252 195 "ref": { ··· 256 199 "size": size 257 200 } 258 201 })) 259 - .into_response() 202 + .into_response()) 260 203 } 261 204 262 205 #[derive(Deserialize)] ··· 281 224 282 225 pub async fn list_missing_blobs( 283 226 State(state): State<AppState>, 284 - auth: BearerAuthAllowDeactivated, 227 + auth: Auth<NotTakendown>, 285 228 Query(params): Query<ListMissingBlobsParams>, 286 - ) -> Response { 287 - let auth_user = auth.0; 288 - let did = &auth_user.did; 289 - let user = match state.user_repo.get_by_did(did).await { 290 - Ok(Some(u)) => u, 291 - Ok(None) => return ApiError::InternalError(None).into_response(), 292 - Err(e) => { 229 + ) -> Result<Response, ApiError> { 230 + let did = &auth.did; 231 + let user = state 232 + .user_repo 233 + .get_by_did(did) 234 + .await 235 + .map_err(|e| { 293 236 error!("DB error fetching user: {:?}", e); 294 - return ApiError::InternalError(None).into_response(); 295 - } 296 - }; 237 + ApiError::InternalError(None) 238 + })? 239 + .ok_or(ApiError::InternalError(None))?; 240 + 297 241 let limit = params.limit.unwrap_or(500).clamp(1, 1000); 298 242 let cursor = params.cursor.as_deref(); 299 - let missing = match state 243 + let missing = state 300 244 .blob_repo 301 245 .list_missing_blobs(user.id, cursor, limit + 1) 302 246 .await 303 - { 304 - Ok(m) => m, 305 - Err(e) => { 247 + .map_err(|e| { 306 248 error!("DB error fetching missing blobs: {:?}", e); 307 - return ApiError::InternalError(None).into_response(); 308 - } 309 - }; 249 + ApiError::InternalError(None) 250 + })?; 251 + 310 252 let has_more = missing.len() > limit as usize; 311 253 let blobs: Vec<RecordBlob> = missing 312 254 .into_iter() ··· 321 263 } else { 322 264 None 323 265 }; 324 - ( 266 + Ok(( 325 267 StatusCode::OK, 326 268 Json(ListMissingBlobsOutput { 327 269 cursor: next_cursor, 328 270 blobs, 329 271 }), 330 272 ) 331 - .into_response() 273 + .into_response()) 332 274 }
+129 -131
crates/tranquil-pds/src/api/repo/import.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::create_signed_commit; 4 - use crate::auth::BearerAuthAllowDeactivated; 4 + use crate::auth::{Auth, NotTakendown}; 5 5 use crate::state::AppState; 6 6 use crate::sync::import::{ImportError, apply_import, parse_car}; 7 7 use crate::sync::verify::CarVerifier; ··· 23 23 24 24 pub async fn import_repo( 25 25 State(state): State<AppState>, 26 - auth: BearerAuthAllowDeactivated, 26 + auth: Auth<NotTakendown>, 27 27 body: Bytes, 28 - ) -> Response { 28 + ) -> Result<Response, ApiError> { 29 29 let accepting_imports = std::env::var("ACCEPTING_REPO_IMPORTS") 30 30 .map(|v| v != "false" && v != "0") 31 31 .unwrap_or(true); 32 32 if !accepting_imports { 33 - return ApiError::InvalidRequest("Service is not accepting repo imports".into()) 34 - .into_response(); 33 + return Err(ApiError::InvalidRequest( 34 + "Service is not accepting repo imports".into(), 35 + )); 35 36 } 36 37 let max_size: usize = std::env::var("MAX_IMPORT_SIZE") 37 38 .ok() 38 39 .and_then(|s| s.parse().ok()) 39 40 .unwrap_or(DEFAULT_MAX_IMPORT_SIZE); 40 41 if body.len() > max_size { 41 - return ApiError::PayloadTooLarge(format!( 42 + return Err(ApiError::PayloadTooLarge(format!( 42 43 "Import size exceeds limit of {} bytes", 43 44 max_size 44 - )) 45 - .into_response(); 45 + ))); 46 46 } 47 - let auth_user = auth.0; 48 - let did = &auth_user.did; 49 - let user = match state.user_repo.get_by_did(did).await { 50 - Ok(Some(row)) => row, 51 - Ok(None) => { 52 - return ApiError::AccountNotFound.into_response(); 53 - } 54 - Err(e) => { 47 + let did = &auth.did; 48 + let user = state 49 + .user_repo 50 + .get_by_did(did) 51 + .await 52 + .map_err(|e| { 55 53 error!("DB error fetching user: {:?}", e); 56 - return ApiError::InternalError(None).into_response(); 57 - } 58 - }; 54 + ApiError::InternalError(None) 55 + })? 56 + .ok_or(ApiError::AccountNotFound)?; 59 57 if user.takedown_ref.is_some() { 60 - return ApiError::AccountTakedown.into_response(); 58 + return Err(ApiError::AccountTakedown); 61 59 } 62 60 let user_id = user.id; 63 61 let (root, blocks) = match parse_car(&body).await { 64 62 Ok((r, b)) => (r, b), 65 63 Err(ImportError::InvalidRootCount) => { 66 - return ApiError::InvalidRequest("Expected exactly one root in CAR file".into()) 67 - .into_response(); 64 + return Err(ApiError::InvalidRequest( 65 + "Expected exactly one root in CAR file".into(), 66 + )); 68 67 } 69 68 Err(ImportError::CarParse(msg)) => { 70 - return ApiError::InvalidRequest(format!("Failed to parse CAR file: {}", msg)) 71 - .into_response(); 69 + return Err(ApiError::InvalidRequest(format!( 70 + "Failed to parse CAR file: {}", 71 + msg 72 + ))); 72 73 } 73 74 Err(e) => { 74 75 error!("CAR parsing error: {:?}", e); 75 - return ApiError::InvalidRequest(format!("Invalid CAR file: {}", e)).into_response(); 76 + return Err(ApiError::InvalidRequest(format!("Invalid CAR file: {}", e))); 76 77 } 77 78 }; 78 79 info!( ··· 82 83 root 83 84 ); 84 85 let Some(root_block) = blocks.get(&root) else { 85 - return ApiError::InvalidRequest("Root block not found in CAR file".into()).into_response(); 86 + return Err(ApiError::InvalidRequest( 87 + "Root block not found in CAR file".into(), 88 + )); 86 89 }; 87 90 let commit_did = match jacquard_repo::commit::Commit::from_cbor(root_block) { 88 91 Ok(commit) => commit.did().to_string(), 89 92 Err(e) => { 90 - return ApiError::InvalidRequest(format!("Invalid commit: {}", e)).into_response(); 93 + return Err(ApiError::InvalidRequest(format!("Invalid commit: {}", e))); 91 94 } 92 95 }; 93 96 if commit_did != *did { 94 - return ApiError::InvalidRepo(format!( 97 + return Err(ApiError::InvalidRepo(format!( 95 98 "CAR file is for DID {} but you are authenticated as {}", 96 99 commit_did, did 97 - )) 98 - .into_response(); 100 + ))); 99 101 } 100 102 let skip_verification = std::env::var("SKIP_IMPORT_VERIFICATION") 101 103 .map(|v| v == "true" || v == "1") ··· 117 119 commit_did, 118 120 expected_did, 119 121 }) => { 120 - return ApiError::InvalidRepo(format!( 122 + return Err(ApiError::InvalidRepo(format!( 121 123 "CAR file is for DID {} but you are authenticated as {}", 122 124 commit_did, expected_did 123 - )) 124 - .into_response(); 125 + ))); 125 126 } 126 127 Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => { 127 - return ApiError::InvalidRequest(format!("MST validation failed: {}", msg)) 128 - .into_response(); 128 + return Err(ApiError::InvalidRequest(format!( 129 + "MST validation failed: {}", 130 + msg 131 + ))); 129 132 } 130 133 Err(e) => { 131 134 error!("CAR structure verification error: {:?}", e); 132 - return ApiError::InvalidRequest(format!("CAR verification failed: {}", e)) 133 - .into_response(); 135 + return Err(ApiError::InvalidRequest(format!( 136 + "CAR verification failed: {}", 137 + e 138 + ))); 134 139 } 135 140 } 136 141 } else { ··· 147 152 commit_did, 148 153 expected_did, 149 154 }) => { 150 - return ApiError::InvalidRepo(format!( 155 + return Err(ApiError::InvalidRepo(format!( 151 156 "CAR file is for DID {} but you are authenticated as {}", 152 157 commit_did, expected_did 153 - )) 154 - .into_response(); 158 + ))); 155 159 } 156 160 Err(crate::sync::verify::VerifyError::InvalidSignature) => { 157 - return ApiError::InvalidRequest( 161 + return Err(ApiError::InvalidRequest( 158 162 "CAR file commit signature verification failed".into(), 159 - ) 160 - .into_response(); 163 + )); 161 164 } 162 165 Err(crate::sync::verify::VerifyError::DidResolutionFailed(msg)) => { 163 166 warn!("DID resolution failed during import verification: {}", msg); 164 - return ApiError::InvalidRequest(format!("Failed to verify DID: {}", msg)) 165 - .into_response(); 167 + return Err(ApiError::InvalidRequest(format!( 168 + "Failed to verify DID: {}", 169 + msg 170 + ))); 166 171 } 167 172 Err(crate::sync::verify::VerifyError::NoSigningKey) => { 168 - return ApiError::InvalidRequest( 173 + return Err(ApiError::InvalidRequest( 169 174 "DID document does not contain a signing key".into(), 170 - ) 171 - .into_response(); 175 + )); 172 176 } 173 177 Err(crate::sync::verify::VerifyError::MstValidationFailed(msg)) => { 174 - return ApiError::InvalidRequest(format!("MST validation failed: {}", msg)) 175 - .into_response(); 178 + return Err(ApiError::InvalidRequest(format!( 179 + "MST validation failed: {}", 180 + msg 181 + ))); 176 182 } 177 183 Err(e) => { 178 184 error!("CAR verification error: {:?}", e); 179 - return ApiError::InvalidRequest(format!("CAR verification failed: {}", e)) 180 - .into_response(); 185 + return Err(ApiError::InvalidRequest(format!( 186 + "CAR verification failed: {}", 187 + e 188 + ))); 181 189 } 182 190 } 183 191 } ··· 227 235 } 228 236 } 229 237 } 230 - let key_row = match state.user_repo.get_user_with_key_by_did(did).await { 231 - Ok(Some(row)) => row, 232 - Ok(None) => { 233 - error!("No signing key found for user {}", did); 234 - return ApiError::InternalError(Some("Signing key not found".into())) 235 - .into_response(); 236 - } 237 - Err(e) => { 238 + let key_row = state 239 + .user_repo 240 + .get_user_with_key_by_did(did) 241 + .await 242 + .map_err(|e| { 238 243 error!("DB error fetching signing key: {:?}", e); 239 - return ApiError::InternalError(None).into_response(); 240 - } 241 - }; 244 + ApiError::InternalError(None) 245 + })? 246 + .ok_or_else(|| { 247 + error!("No signing key found for user {}", did); 248 + ApiError::InternalError(Some("Signing key not found".into())) 249 + })?; 242 250 let key_bytes = 243 - match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) { 244 - Ok(k) => k, 245 - Err(e) => { 251 + crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 252 + .map_err(|e| { 246 253 error!("Failed to decrypt signing key: {}", e); 247 - return ApiError::InternalError(None).into_response(); 248 - } 249 - }; 250 - let signing_key = match SigningKey::from_slice(&key_bytes) { 251 - Ok(k) => k, 252 - Err(e) => { 253 - error!("Invalid signing key: {:?}", e); 254 - return ApiError::InternalError(None).into_response(); 255 - } 256 - }; 254 + ApiError::InternalError(None) 255 + })?; 256 + let signing_key = SigningKey::from_slice(&key_bytes).map_err(|e| { 257 + error!("Invalid signing key: {:?}", e); 258 + ApiError::InternalError(None) 259 + })?; 257 260 let new_rev = Tid::now(LimitedU32::MIN); 258 261 let new_rev_str = new_rev.to_string(); 259 - let (commit_bytes, _sig) = match create_signed_commit( 262 + let (commit_bytes, _sig) = create_signed_commit( 260 263 did, 261 264 import_result.data_cid, 262 265 &new_rev_str, 263 266 None, 264 267 &signing_key, 265 - ) { 266 - Ok(result) => result, 267 - Err(e) => { 268 - error!("Failed to create new commit: {}", e); 269 - return ApiError::InternalError(None).into_response(); 270 - } 271 - }; 272 - let new_root_cid: cid::Cid = match state.block_store.put(&commit_bytes).await { 273 - Ok(cid) => cid, 274 - Err(e) => { 268 + ) 269 + .map_err(|e| { 270 + error!("Failed to create new commit: {}", e); 271 + ApiError::InternalError(None) 272 + })?; 273 + let new_root_cid: cid::Cid = 274 + state.block_store.put(&commit_bytes).await.map_err(|e| { 275 275 error!("Failed to store new commit block: {:?}", e); 276 - return ApiError::InternalError(None).into_response(); 277 - } 278 - }; 276 + ApiError::InternalError(None) 277 + })?; 279 278 let new_root_cid_link = CidLink::new_unchecked(new_root_cid.to_string()); 280 - if let Err(e) = state 279 + state 281 280 .repo_repo 282 281 .update_repo_root(user_id, &new_root_cid_link, &new_rev_str) 283 282 .await 284 - { 285 - error!("Failed to update repo root: {:?}", e); 286 - return ApiError::InternalError(None).into_response(); 287 - } 283 + .map_err(|e| { 284 + error!("Failed to update repo root: {:?}", e); 285 + ApiError::InternalError(None) 286 + })?; 288 287 let mut all_block_cids: Vec<Vec<u8>> = blocks.keys().map(|c| c.to_bytes()).collect(); 289 288 all_block_cids.push(new_root_cid.to_bytes()); 290 - if let Err(e) = state 289 + state 291 290 .repo_repo 292 291 .insert_user_blocks(user_id, &all_block_cids, &new_rev_str) 293 292 .await 294 - { 295 - error!("Failed to insert user_blocks: {:?}", e); 296 - return ApiError::InternalError(None).into_response(); 297 - } 293 + .map_err(|e| { 294 + error!("Failed to insert user_blocks: {:?}", e); 295 + ApiError::InternalError(None) 296 + })?; 298 297 let new_root_str = new_root_cid.to_string(); 299 298 info!( 300 299 "Created new commit for imported repo: cid={}, rev={}", ··· 324 323 ); 325 324 } 326 325 } 327 - EmptyResponse::ok().into_response() 326 + Ok(EmptyResponse::ok().into_response()) 328 327 } 329 - Err(ImportError::SizeLimitExceeded) => { 330 - ApiError::PayloadTooLarge(format!("Import exceeds block limit of {}", max_blocks)) 331 - .into_response() 332 - } 333 - Err(ImportError::RepoNotFound) => { 334 - ApiError::RepoNotFound(Some("Repository not initialized for this account".into())) 335 - .into_response() 336 - } 337 - Err(ImportError::InvalidCbor(msg)) => { 338 - ApiError::InvalidRequest(format!("Invalid CBOR data: {}", msg)).into_response() 339 - } 340 - Err(ImportError::InvalidCommit(msg)) => { 341 - ApiError::InvalidRequest(format!("Invalid commit structure: {}", msg)).into_response() 342 - } 343 - Err(ImportError::BlockNotFound(cid)) => { 344 - ApiError::InvalidRequest(format!("Referenced block not found in CAR: {}", cid)) 345 - .into_response() 346 - } 347 - Err(ImportError::ConcurrentModification) => ApiError::InvalidSwap(Some( 328 + Err(ImportError::SizeLimitExceeded) => Err(ApiError::PayloadTooLarge(format!( 329 + "Import exceeds block limit of {}", 330 + max_blocks 331 + ))), 332 + Err(ImportError::RepoNotFound) => Err(ApiError::RepoNotFound(Some( 333 + "Repository not initialized for this account".into(), 334 + ))), 335 + Err(ImportError::InvalidCbor(msg)) => Err(ApiError::InvalidRequest(format!( 336 + "Invalid CBOR data: {}", 337 + msg 338 + ))), 339 + Err(ImportError::InvalidCommit(msg)) => Err(ApiError::InvalidRequest(format!( 340 + "Invalid commit structure: {}", 341 + msg 342 + ))), 343 + Err(ImportError::BlockNotFound(cid)) => Err(ApiError::InvalidRequest(format!( 344 + "Referenced block not found in CAR: {}", 345 + cid 346 + ))), 347 + Err(ImportError::ConcurrentModification) => Err(ApiError::InvalidSwap(Some( 348 348 "Repository is being modified by another operation, please retry".into(), 349 - )) 350 - .into_response(), 351 - Err(ImportError::VerificationFailed(ve)) => { 352 - ApiError::InvalidRequest(format!("CAR verification failed: {}", ve)).into_response() 353 - } 354 - Err(ImportError::DidMismatch { car_did, auth_did }) => ApiError::InvalidRequest(format!( 355 - "CAR is for {} but authenticated as {}", 356 - car_did, auth_did 357 - )) 358 - .into_response(), 349 + ))), 350 + Err(ImportError::VerificationFailed(ve)) => Err(ApiError::InvalidRequest(format!( 351 + "CAR verification failed: {}", 352 + ve 353 + ))), 354 + Err(ImportError::DidMismatch { car_did, auth_did }) => Err(ApiError::InvalidRequest( 355 + format!("CAR is for {} but authenticated as {}", car_did, auth_did), 356 + )), 359 357 Err(e) => { 360 358 error!("Import error: {:?}", e); 361 - ApiError::InternalError(None).into_response() 359 + Err(ApiError::InternalError(None)) 362 360 } 363 361 } 364 362 }
+63 -65
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record_with_status; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 4 - use crate::auth::BearerAuth; 4 + use crate::auth::{Active, Auth}; 5 5 use crate::delegation::DelegationActionType; 6 6 use crate::repo::tracking::TrackingBlockStore; 7 7 use crate::state::AppState; ··· 262 262 263 263 pub async fn apply_writes( 264 264 State(state): State<AppState>, 265 - auth: BearerAuth, 265 + auth: Auth<Active>, 266 266 Json(input): Json<ApplyWritesInput>, 267 - ) -> Response { 267 + ) -> Result<Response, ApiError> { 268 268 info!( 269 269 "apply_writes called: repo={}, writes={}", 270 270 input.repo, 271 271 input.writes.len() 272 272 ); 273 - let auth_user = auth.0; 274 - let did = auth_user.did.clone(); 275 - let is_oauth = auth_user.is_oauth; 276 - let scope = auth_user.scope; 277 - let controller_did = auth_user.controller_did.clone(); 273 + let did = auth.did.clone(); 274 + let is_oauth = auth.is_oauth(); 275 + let scope = auth.scope.clone(); 276 + let controller_did = auth.controller_did.clone(); 278 277 if input.repo.as_str() != did { 279 - return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 280 - .into_response(); 278 + return Err(ApiError::InvalidRepo( 279 + "Repo does not match authenticated user".into(), 280 + )); 281 281 } 282 282 if state 283 283 .user_repo ··· 285 285 .await 286 286 .unwrap_or(false) 287 287 { 288 - return ApiError::AccountMigrated.into_response(); 288 + return Err(ApiError::AccountMigrated); 289 289 } 290 290 let is_verified = state 291 291 .user_repo ··· 298 298 .await 299 299 .unwrap_or(false); 300 300 if !is_verified && !is_delegated { 301 - return ApiError::AccountNotVerified.into_response(); 301 + return Err(ApiError::AccountNotVerified); 302 302 } 303 303 if input.writes.is_empty() { 304 - return ApiError::InvalidRequest("writes array is empty".into()).into_response(); 304 + return Err(ApiError::InvalidRequest("writes array is empty".into())); 305 305 } 306 306 if input.writes.len() > MAX_BATCH_WRITES { 307 - return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES)) 308 - .into_response(); 307 + return Err(ApiError::InvalidRequest(format!( 308 + "Too many writes (max {})", 309 + MAX_BATCH_WRITES 310 + ))); 309 311 } 310 312 311 313 let has_custom_scope = scope ··· 374 376 }) 375 377 .next() 376 378 { 377 - return err; 379 + return Ok(err); 378 380 } 379 381 } 380 382 381 - let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&did).await { 382 - Ok(Some(id)) => id, 383 - _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 384 - }; 385 - let root_cid_str = match state.repo_repo.get_repo_root_cid_by_user_id(user_id).await { 386 - Ok(Some(cid_str)) => cid_str, 387 - _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), 388 - }; 389 - let current_root_cid = match Cid::from_str(&root_cid_str) { 390 - Ok(c) => c, 391 - Err(_) => { 392 - return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 393 - } 394 - }; 383 + let user_id: uuid::Uuid = state 384 + .user_repo 385 + .get_id_by_did(&did) 386 + .await 387 + .ok() 388 + .flatten() 389 + .ok_or_else(|| ApiError::InternalError(Some("User not found".into())))?; 390 + let root_cid_str = state 391 + .repo_repo 392 + .get_repo_root_cid_by_user_id(user_id) 393 + .await 394 + .ok() 395 + .flatten() 396 + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; 397 + let current_root_cid = Cid::from_str(&root_cid_str) 398 + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?; 395 399 if let Some(swap_commit) = &input.swap_commit 396 400 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 397 401 { 398 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 402 + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 399 403 } 400 404 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 401 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 402 - Ok(Some(b)) => b, 403 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 404 - }; 405 - let commit = match Commit::from_cbor(&commit_bytes) { 406 - Ok(c) => c, 407 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 408 - }; 405 + let commit_bytes = tracking_store 406 + .get(&current_root_cid) 407 + .await 408 + .ok() 409 + .flatten() 410 + .ok_or_else(|| ApiError::InternalError(Some("Commit block not found".into())))?; 411 + let commit = Commit::from_cbor(&commit_bytes) 412 + .map_err(|_| ApiError::InternalError(Some("Failed to parse commit".into())))?; 409 413 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 410 414 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 411 415 let WriteAccumulator { ··· 424 428 .await 425 429 { 426 430 Ok(acc) => acc, 427 - Err(response) => return response, 428 - }; 429 - let new_mst_root = match mst.persist().await { 430 - Ok(c) => c, 431 - Err(_) => { 432 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 433 - } 431 + Err(response) => return Ok(response), 434 432 }; 433 + let new_mst_root = mst 434 + .persist() 435 + .await 436 + .map_err(|_| ApiError::InternalError(Some("Failed to persist MST".into())))?; 435 437 let (new_mst_blocks, old_mst_blocks) = { 436 438 let mut new_blocks = std::collections::BTreeMap::new(); 437 439 let mut old_blocks = std::collections::BTreeMap::new(); 438 440 for key in &modified_keys { 439 - if mst.blocks_for_path(key, &mut new_blocks).await.is_err() { 440 - return ApiError::InternalError(Some( 441 - "Failed to get new MST blocks for path".into(), 442 - )) 443 - .into_response(); 444 - } 445 - if original_mst 441 + mst.blocks_for_path(key, &mut new_blocks) 442 + .await 443 + .map_err(|_| { 444 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 445 + })?; 446 + original_mst 446 447 .blocks_for_path(key, &mut old_blocks) 447 448 .await 448 - .is_err() 449 - { 450 - return ApiError::InternalError(Some( 451 - "Failed to get old MST blocks for path".into(), 452 - )) 453 - .into_response(); 454 - } 449 + .map_err(|_| { 450 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 451 + })?; 455 452 } 456 453 (new_blocks, old_blocks) 457 454 }; ··· 503 500 { 504 501 Ok(res) => res, 505 502 Err(e) if e.contains("ConcurrentModification") => { 506 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 503 + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 507 504 } 508 505 Err(e) => { 509 506 error!("Commit failed: {}", e); 510 - return ApiError::InternalError(Some("Failed to commit changes".into())) 511 - .into_response(); 507 + return Err(ApiError::InternalError(Some( 508 + "Failed to commit changes".into(), 509 + ))); 512 510 } 513 511 }; 514 512 ··· 557 555 .await; 558 556 } 559 557 560 - ( 558 + Ok(( 561 559 StatusCode::OK, 562 560 Json(ApplyWritesOutput { 563 561 commit: CommitInfo { ··· 567 565 results, 568 566 }), 569 567 ) 570 - .into_response() 568 + .into_response()) 571 569 }
+47 -38
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 + use crate::auth::{Active, Auth}; 4 5 use crate::delegation::DelegationActionType; 5 6 use crate::repo::tracking::TrackingBlockStore; 6 7 use crate::state::AppState; ··· 8 9 use axum::{ 9 10 Json, 10 11 extract::State, 11 - http::{HeaderMap, StatusCode}, 12 + http::StatusCode, 12 13 response::{IntoResponse, Response}, 13 14 }; 14 15 use cid::Cid; ··· 39 40 40 41 pub async fn delete_record( 41 42 State(state): State<AppState>, 42 - headers: HeaderMap, 43 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 43 + auth: Auth<Active>, 44 44 Json(input): Json<DeleteRecordInput>, 45 - ) -> Response { 46 - let auth = match prepare_repo_write( 47 - &state, 48 - &headers, 49 - &input.repo, 50 - "POST", 51 - &crate::util::build_full_url(&uri.to_string()), 52 - ) 53 - .await 54 - { 45 + ) -> Result<Response, crate::api::error::ApiError> { 46 + let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 55 47 Ok(res) => res, 56 - Err(err_res) => return err_res, 48 + Err(err_res) => return Ok(err_res), 57 49 }; 58 50 59 51 if let Err(e) = crate::auth::scope_check::check_repo_scope( 60 - auth.is_oauth, 61 - auth.scope.as_deref(), 52 + repo_auth.is_oauth, 53 + repo_auth.scope.as_deref(), 62 54 crate::oauth::RepoAction::Delete, 63 55 &input.collection, 64 56 ) { 65 - return e; 57 + return Ok(e); 66 58 } 67 59 68 - let did = auth.did; 69 - let user_id = auth.user_id; 70 - let current_root_cid = auth.current_root_cid; 71 - let controller_did = auth.controller_did; 60 + let did = repo_auth.did; 61 + let user_id = repo_auth.user_id; 62 + let current_root_cid = repo_auth.current_root_cid; 63 + let controller_did = repo_auth.controller_did; 72 64 73 65 if let Some(swap_commit) = &input.swap_commit 74 66 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 75 67 { 76 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 68 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 77 69 } 78 70 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 79 71 let commit_bytes = match tracking_store.get(&current_root_cid).await { 80 72 Ok(Some(b)) => b, 81 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 73 + _ => { 74 + return Ok( 75 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 76 + ); 77 + } 82 78 }; 83 79 let commit = match Commit::from_cbor(&commit_bytes) { 84 80 Ok(c) => c, 85 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 81 + _ => { 82 + return Ok( 83 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 84 + ); 85 + } 86 86 }; 87 87 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 88 88 let key = format!("{}/{}", input.collection, input.rkey); ··· 90 90 let expected_cid = Cid::from_str(swap_record_str).ok(); 91 91 let actual_cid = mst.get(&key).await.ok().flatten(); 92 92 if expected_cid != actual_cid { 93 - return ApiError::InvalidSwap(Some( 93 + return Ok(ApiError::InvalidSwap(Some( 94 94 "Record has been modified or does not exist".into(), 95 95 )) 96 - .into_response(); 96 + .into_response()); 97 97 } 98 98 } 99 99 let prev_record_cid = mst.get(&key).await.ok().flatten(); 100 100 if prev_record_cid.is_none() { 101 - return (StatusCode::OK, Json(DeleteRecordOutput { commit: None })).into_response(); 101 + return Ok((StatusCode::OK, Json(DeleteRecordOutput { commit: None })).into_response()); 102 102 } 103 103 let new_mst = match mst.delete(&key).await { 104 104 Ok(m) => m, 105 105 Err(e) => { 106 106 error!("Failed to delete from MST: {:?}", e); 107 - return ApiError::InternalError(Some(format!("Failed to delete from MST: {:?}", e))) 108 - .into_response(); 107 + return Ok(ApiError::InternalError(Some(format!( 108 + "Failed to delete from MST: {:?}", 109 + e 110 + ))) 111 + .into_response()); 109 112 } 110 113 }; 111 114 let new_mst_root = match new_mst.persist().await { 112 115 Ok(c) => c, 113 116 Err(e) => { 114 117 error!("Failed to persist MST: {:?}", e); 115 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 118 + return Ok( 119 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 120 + ); 116 121 } 117 122 }; 118 123 let collection_for_audit = input.collection.to_string(); ··· 129 134 .await 130 135 .is_err() 131 136 { 132 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 133 - .into_response(); 137 + return Ok( 138 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 139 + .into_response(), 140 + ); 134 141 } 135 142 if mst 136 143 .blocks_for_path(&key, &mut old_mst_blocks) 137 144 .await 138 145 .is_err() 139 146 { 140 - return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 141 - .into_response(); 147 + return Ok( 148 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 149 + .into_response(), 150 + ); 142 151 } 143 152 let mut relevant_blocks = new_mst_blocks.clone(); 144 153 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); ··· 177 186 { 178 187 Ok(res) => res, 179 188 Err(e) if e.contains("ConcurrentModification") => { 180 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 189 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 181 190 } 182 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 191 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 183 192 }; 184 193 185 194 if let Some(ref controller) = controller_did { ··· 210 219 error!("Failed to remove backlinks for {}: {}", deleted_uri, e); 211 220 } 212 221 213 - ( 222 + Ok(( 214 223 StatusCode::OK, 215 224 Json(DeleteRecordOutput { 216 225 commit: Some(CommitInfo { ··· 219 228 }), 220 229 }), 221 230 ) 222 - .into_response() 231 + .into_response()) 223 232 } 224 233 225 234 use crate::types::Did;
+114 -121
crates/tranquil-pds/src/api/repo/record/write.rs
··· 3 3 use crate::api::repo::record::utils::{ 4 4 CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 5 5 }; 6 + use crate::auth::{Active, Auth}; 6 7 use crate::delegation::DelegationActionType; 7 8 use crate::repo::tracking::TrackingBlockStore; 8 9 use crate::state::AppState; ··· 10 11 use axum::{ 11 12 Json, 12 13 extract::State, 13 - http::{HeaderMap, StatusCode}, 14 + http::StatusCode, 14 15 response::{IntoResponse, Response}, 15 16 }; 16 17 use cid::Cid; ··· 33 34 34 35 pub async fn prepare_repo_write( 35 36 state: &AppState, 36 - headers: &HeaderMap, 37 + auth_user: &crate::auth::AuthenticatedUser, 37 38 repo: &AtIdentifier, 38 - http_method: &str, 39 - http_uri: &str, 40 39 ) -> Result<RepoWriteAuth, Response> { 41 - let extracted = crate::auth::extract_auth_token_from_header( 42 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 43 - ) 44 - .ok_or_else(|| ApiError::AuthenticationRequired.into_response())?; 45 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 46 - let auth_user = crate::auth::validate_token_with_dpop( 47 - state.user_repo.as_ref(), 48 - state.oauth_repo.as_ref(), 49 - &extracted.token, 50 - extracted.is_dpop, 51 - dpop_proof, 52 - http_method, 53 - http_uri, 54 - false, 55 - false, 56 - ) 57 - .await 58 - .map_err(|e| { 59 - tracing::warn!(error = ?e, is_dpop = extracted.is_dpop, "Token validation failed in prepare_repo_write"); 60 - ApiError::from(e).into_response() 61 - })?; 62 40 if repo.as_str() != auth_user.did.as_str() { 63 41 return Err( 64 42 ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), ··· 112 90 did: auth_user.did.clone(), 113 91 user_id, 114 92 current_root_cid, 115 - is_oauth: auth_user.is_oauth, 116 - scope: auth_user.scope, 93 + is_oauth: auth_user.is_oauth(), 94 + scope: auth_user.scope.clone(), 117 95 controller_did: auth_user.controller_did.clone(), 118 96 }) 119 97 } ··· 146 124 } 147 125 pub async fn create_record( 148 126 State(state): State<AppState>, 149 - headers: HeaderMap, 150 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 127 + auth: Auth<Active>, 151 128 Json(input): Json<CreateRecordInput>, 152 - ) -> Response { 153 - let auth = match prepare_repo_write( 154 - &state, 155 - &headers, 156 - &input.repo, 157 - "POST", 158 - &crate::util::build_full_url(&uri.to_string()), 159 - ) 160 - .await 161 - { 129 + ) -> Result<Response, crate::api::error::ApiError> { 130 + let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 162 131 Ok(res) => res, 163 - Err(err_res) => return err_res, 132 + Err(err_res) => return Ok(err_res), 164 133 }; 165 134 166 135 if let Err(e) = crate::auth::scope_check::check_repo_scope( 167 - auth.is_oauth, 168 - auth.scope.as_deref(), 136 + repo_auth.is_oauth, 137 + repo_auth.scope.as_deref(), 169 138 crate::oauth::RepoAction::Create, 170 139 &input.collection, 171 140 ) { 172 - return e; 141 + return Ok(e); 173 142 } 174 143 175 - let did = auth.did; 176 - let user_id = auth.user_id; 177 - let current_root_cid = auth.current_root_cid; 178 - let controller_did = auth.controller_did; 144 + let did = repo_auth.did; 145 + let user_id = repo_auth.user_id; 146 + let current_root_cid = repo_auth.current_root_cid; 147 + let controller_did = repo_auth.controller_did; 179 148 180 149 if let Some(swap_commit) = &input.swap_commit 181 150 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 182 151 { 183 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 152 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 184 153 } 185 154 186 155 let validation_status = if input.validate == Some(false) { ··· 194 163 require_lexicon, 195 164 ) { 196 165 Ok(status) => Some(status), 197 - Err(err_response) => return *err_response, 166 + Err(err_response) => return Ok(*err_response), 198 167 } 199 168 }; 200 169 let rkey = input.rkey.unwrap_or_else(Rkey::generate); ··· 202 171 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 203 172 let commit_bytes = match tracking_store.get(&current_root_cid).await { 204 173 Ok(Some(b)) => b, 205 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 174 + _ => { 175 + return Ok( 176 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 177 + ); 178 + } 206 179 }; 207 180 let commit = match Commit::from_cbor(&commit_bytes) { 208 181 Ok(c) => c, 209 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 182 + _ => { 183 + return Ok( 184 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 185 + ); 186 + } 210 187 }; 211 188 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 212 189 let initial_mst_root = commit.data; ··· 228 205 Ok(c) => c, 229 206 Err(e) => { 230 207 error!("Failed to check backlink conflicts: {}", e); 231 - return ApiError::InternalError(None).into_response(); 208 + return Ok(ApiError::InternalError(None).into_response()); 232 209 } 233 210 }; 234 211 ··· 281 258 let record_ipld = crate::util::json_to_ipld(&input.record); 282 259 let mut record_bytes = Vec::new(); 283 260 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 284 - return ApiError::InvalidRecord("Failed to serialize record".into()).into_response(); 261 + return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 285 262 } 286 263 let record_cid = match tracking_store.put(&record_bytes).await { 287 264 Ok(c) => c, 288 265 _ => { 289 - return ApiError::InternalError(Some("Failed to save record block".into())) 290 - .into_response(); 266 + return Ok( 267 + ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 268 + ); 291 269 } 292 270 }; 293 271 let key = format!("{}/{}", input.collection, rkey); ··· 302 280 303 281 let new_mst = match mst.add(&key, record_cid).await { 304 282 Ok(m) => m, 305 - _ => return ApiError::InternalError(Some("Failed to add to MST".into())).into_response(), 283 + _ => { 284 + return Ok(ApiError::InternalError(Some("Failed to add to MST".into())).into_response()); 285 + } 306 286 }; 307 287 let new_mst_root = match new_mst.persist().await { 308 288 Ok(c) => c, 309 - _ => return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 289 + _ => { 290 + return Ok( 291 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 292 + ); 293 + } 310 294 }; 311 295 312 296 ops.push(RecordOp::Create { ··· 321 305 .await 322 306 .is_err() 323 307 { 324 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 325 - .into_response(); 308 + return Ok( 309 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 310 + .into_response(), 311 + ); 326 312 } 327 313 328 314 let mut relevant_blocks = new_mst_blocks.clone(); ··· 364 350 { 365 351 Ok(res) => res, 366 352 Err(e) if e.contains("ConcurrentModification") => { 367 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 353 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 368 354 } 369 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 355 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 370 356 }; 371 357 372 358 for conflict_uri in conflict_uris_to_cleanup { ··· 406 392 error!("Failed to add backlinks for {}: {}", created_uri, e); 407 393 } 408 394 409 - ( 395 + Ok(( 410 396 StatusCode::OK, 411 397 Json(CreateRecordOutput { 412 398 uri: created_uri, ··· 418 404 validation_status: validation_status.map(|s| s.to_string()), 419 405 }), 420 406 ) 421 - .into_response() 407 + .into_response()) 422 408 } 423 409 #[derive(Deserialize)] 424 410 #[allow(dead_code)] ··· 445 431 } 446 432 pub async fn put_record( 447 433 State(state): State<AppState>, 448 - headers: HeaderMap, 449 - axum::extract::OriginalUri(uri): axum::extract::OriginalUri, 434 + auth: Auth<Active>, 450 435 Json(input): Json<PutRecordInput>, 451 - ) -> Response { 452 - let auth = match prepare_repo_write( 453 - &state, 454 - &headers, 455 - &input.repo, 456 - "POST", 457 - &crate::util::build_full_url(&uri.to_string()), 458 - ) 459 - .await 460 - { 436 + ) -> Result<Response, crate::api::error::ApiError> { 437 + let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 461 438 Ok(res) => res, 462 - Err(err_res) => return err_res, 439 + Err(err_res) => return Ok(err_res), 463 440 }; 464 441 465 442 if let Err(e) = crate::auth::scope_check::check_repo_scope( 466 - auth.is_oauth, 467 - auth.scope.as_deref(), 443 + repo_auth.is_oauth, 444 + repo_auth.scope.as_deref(), 468 445 crate::oauth::RepoAction::Create, 469 446 &input.collection, 470 447 ) { 471 - return e; 448 + return Ok(e); 472 449 } 473 450 if let Err(e) = crate::auth::scope_check::check_repo_scope( 474 - auth.is_oauth, 475 - auth.scope.as_deref(), 451 + repo_auth.is_oauth, 452 + repo_auth.scope.as_deref(), 476 453 crate::oauth::RepoAction::Update, 477 454 &input.collection, 478 455 ) { 479 - return e; 456 + return Ok(e); 480 457 } 481 458 482 - let did = auth.did; 483 - let user_id = auth.user_id; 484 - let current_root_cid = auth.current_root_cid; 485 - let controller_did = auth.controller_did; 459 + let did = repo_auth.did; 460 + let user_id = repo_auth.user_id; 461 + let current_root_cid = repo_auth.current_root_cid; 462 + let controller_did = repo_auth.controller_did; 486 463 487 464 if let Some(swap_commit) = &input.swap_commit 488 465 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 489 466 { 490 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 467 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 491 468 } 492 469 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 493 470 let commit_bytes = match tracking_store.get(&current_root_cid).await { 494 471 Ok(Some(b)) => b, 495 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 472 + _ => { 473 + return Ok( 474 + ApiError::InternalError(Some("Commit block not found".into())).into_response(), 475 + ); 476 + } 496 477 }; 497 478 let commit = match Commit::from_cbor(&commit_bytes) { 498 479 Ok(c) => c, 499 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 480 + _ => { 481 + return Ok( 482 + ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 483 + ); 484 + } 500 485 }; 501 486 let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 502 487 let key = format!("{}/{}", input.collection, input.rkey); ··· 511 496 require_lexicon, 512 497 ) { 513 498 Ok(status) => Some(status), 514 - Err(err_response) => return *err_response, 499 + Err(err_response) => return Ok(*err_response), 515 500 } 516 501 }; 517 502 if let Some(swap_record_str) = &input.swap_record { 518 503 let expected_cid = Cid::from_str(swap_record_str).ok(); 519 504 let actual_cid = mst.get(&key).await.ok().flatten(); 520 505 if expected_cid != actual_cid { 521 - return ApiError::InvalidSwap(Some( 506 + return Ok(ApiError::InvalidSwap(Some( 522 507 "Record has been modified or does not exist".into(), 523 508 )) 524 - .into_response(); 509 + .into_response()); 525 510 } 526 511 } 527 512 let existing_cid = mst.get(&key).await.ok().flatten(); 528 513 let record_ipld = crate::util::json_to_ipld(&input.record); 529 514 let mut record_bytes = Vec::new(); 530 515 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 531 - return ApiError::InvalidRecord("Failed to serialize record".into()).into_response(); 516 + return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 532 517 } 533 518 let record_cid = match tracking_store.put(&record_bytes).await { 534 519 Ok(c) => c, 535 520 _ => { 536 - return ApiError::InternalError(Some("Failed to save record block".into())) 537 - .into_response(); 521 + return Ok( 522 + ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 523 + ); 538 524 } 539 525 }; 540 526 if existing_cid == Some(record_cid) { 541 - return ( 527 + return Ok(( 542 528 StatusCode::OK, 543 529 Json(PutRecordOutput { 544 530 uri: AtUri::from_parts(&did, &input.collection, &input.rkey), ··· 547 533 validation_status: validation_status.map(|s| s.to_string()), 548 534 }), 549 535 ) 550 - .into_response(); 536 + .into_response()); 551 537 } 552 - let new_mst = if existing_cid.is_some() { 553 - match mst.update(&key, record_cid).await { 554 - Ok(m) => m, 555 - Err(_) => { 556 - return ApiError::InternalError(Some("Failed to update MST".into())) 557 - .into_response(); 538 + let new_mst = 539 + if existing_cid.is_some() { 540 + match mst.update(&key, record_cid).await { 541 + Ok(m) => m, 542 + Err(_) => { 543 + return Ok(ApiError::InternalError(Some("Failed to update MST".into())) 544 + .into_response()); 545 + } 558 546 } 559 - } 560 - } else { 561 - match mst.add(&key, record_cid).await { 562 - Ok(m) => m, 563 - Err(_) => { 564 - return ApiError::InternalError(Some("Failed to add to MST".into())) 565 - .into_response(); 547 + } else { 548 + match mst.add(&key, record_cid).await { 549 + Ok(m) => m, 550 + Err(_) => { 551 + return Ok(ApiError::InternalError(Some("Failed to add to MST".into())) 552 + .into_response()); 553 + } 566 554 } 567 - } 568 - }; 555 + }; 569 556 let new_mst_root = match new_mst.persist().await { 570 557 Ok(c) => c, 571 558 Err(_) => { 572 - return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 559 + return Ok( 560 + ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 561 + ); 573 562 } 574 563 }; 575 564 let op = if existing_cid.is_some() { ··· 593 582 .await 594 583 .is_err() 595 584 { 596 - return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 597 - .into_response(); 585 + return Ok( 586 + ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 587 + .into_response(), 588 + ); 598 589 } 599 590 if mst 600 591 .blocks_for_path(&key, &mut old_mst_blocks) 601 592 .await 602 593 .is_err() 603 594 { 604 - return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 605 - .into_response(); 595 + return Ok( 596 + ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 597 + .into_response(), 598 + ); 606 599 } 607 600 let mut relevant_blocks = new_mst_blocks.clone(); 608 601 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); ··· 644 637 { 645 638 Ok(res) => res, 646 639 Err(e) if e.contains("ConcurrentModification") => { 647 - return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 640 + return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 648 641 } 649 - Err(e) => return ApiError::InternalError(Some(e)).into_response(), 642 + Err(e) => return Ok(ApiError::InternalError(Some(e)).into_response()), 650 643 }; 651 644 652 645 if let Some(ref controller) = controller_did { ··· 668 661 .await; 669 662 } 670 663 671 - ( 664 + Ok(( 672 665 StatusCode::OK, 673 666 Json(PutRecordOutput { 674 667 uri: AtUri::from_parts(&did, &input.collection, &input.rkey), ··· 680 673 validation_status: validation_status.map(|s| s.to_string()), 681 674 }), 682 675 ) 683 - .into_response() 676 + .into_response()) 684 677 }
+54 -166
crates/tranquil-pds/src/api/server/account_status.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 + use crate::auth::{Auth, NotTakendown, Permissive}; 3 4 use crate::cache::Cache; 4 5 use crate::plc::PlcClient; 5 6 use crate::state::AppState; ··· 40 41 41 42 pub async fn check_account_status( 42 43 State(state): State<AppState>, 43 - headers: axum::http::HeaderMap, 44 - ) -> Response { 45 - let extracted = match crate::auth::extract_auth_token_from_header( 46 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 47 - ) { 48 - Some(t) => t, 49 - None => return ApiError::AuthenticationRequired.into_response(), 50 - }; 51 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 52 - let http_uri = format!( 53 - "https://{}/xrpc/com.atproto.server.checkAccountStatus", 54 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 55 - ); 56 - let did = match crate::auth::validate_token_with_dpop( 57 - state.user_repo.as_ref(), 58 - state.oauth_repo.as_ref(), 59 - &extracted.token, 60 - extracted.is_dpop, 61 - dpop_proof, 62 - "GET", 63 - &http_uri, 64 - true, 65 - false, 66 - ) 67 - .await 68 - { 69 - Ok(user) => user.did, 70 - Err(e) => return ApiError::from(e).into_response(), 71 - }; 72 - let user_id = match state.user_repo.get_id_by_did(&did).await { 73 - Ok(Some(id)) => id, 74 - _ => { 75 - return ApiError::InternalError(None).into_response(); 76 - } 77 - }; 44 + auth: Auth<Permissive>, 45 + ) -> Result<Response, ApiError> { 46 + let did = &auth.did; 47 + let user_id = state 48 + .user_repo 49 + .get_id_by_did(did) 50 + .await 51 + .map_err(|_| ApiError::InternalError(None))? 52 + .ok_or(ApiError::InternalError(None))?; 78 53 let is_active = state 79 54 .user_repo 80 - .is_account_active_by_did(&did) 55 + .is_account_active_by_did(did) 81 56 .await 82 57 .ok() 83 58 .flatten() ··· 121 96 .await 122 97 .unwrap_or(0); 123 98 let valid_did = 124 - is_valid_did_for_service(state.user_repo.as_ref(), state.cache.clone(), &did).await; 125 - ( 99 + is_valid_did_for_service(state.user_repo.as_ref(), state.cache.clone(), did).await; 100 + Ok(( 126 101 StatusCode::OK, 127 102 Json(CheckAccountStatusOutput { 128 103 activated: is_active, ··· 136 111 imported_blobs, 137 112 }), 138 113 ) 139 - .into_response() 114 + .into_response()) 140 115 } 141 116 142 117 async fn is_valid_did_for_service( ··· 331 306 332 307 pub async fn activate_account( 333 308 State(state): State<AppState>, 334 - headers: axum::http::HeaderMap, 335 - ) -> Response { 309 + auth: Auth<Permissive>, 310 + ) -> Result<Response, ApiError> { 336 311 info!("[MIGRATION] activateAccount called"); 337 - let extracted = match crate::auth::extract_auth_token_from_header( 338 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 339 - ) { 340 - Some(t) => t, 341 - None => { 342 - info!("[MIGRATION] activateAccount: No auth token"); 343 - return ApiError::AuthenticationRequired.into_response(); 344 - } 345 - }; 346 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 347 - let http_uri = format!( 348 - "https://{}/xrpc/com.atproto.server.activateAccount", 349 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 350 - ); 351 - let auth_user = match crate::auth::validate_token_with_dpop( 352 - state.user_repo.as_ref(), 353 - state.oauth_repo.as_ref(), 354 - &extracted.token, 355 - extracted.is_dpop, 356 - dpop_proof, 357 - "POST", 358 - &http_uri, 359 - true, 360 - false, 361 - ) 362 - .await 363 - { 364 - Ok(user) => user, 365 - Err(e) => { 366 - info!("[MIGRATION] activateAccount: Auth failed: {:?}", e); 367 - return ApiError::from(e).into_response(); 368 - } 369 - }; 370 312 info!( 371 313 "[MIGRATION] activateAccount: Authenticated user did={}", 372 - auth_user.did 314 + auth.did 373 315 ); 374 316 375 317 if let Err(e) = crate::auth::scope_check::check_account_scope( 376 - auth_user.is_oauth, 377 - auth_user.scope.as_deref(), 318 + auth.is_oauth(), 319 + auth.scope.as_deref(), 378 320 crate::oauth::scopes::AccountAttr::Repo, 379 321 crate::oauth::scopes::AccountAction::Manage, 380 322 ) { 381 323 info!("[MIGRATION] activateAccount: Scope check failed"); 382 - return e; 324 + return Ok(e); 383 325 } 384 326 385 - let did = auth_user.did; 327 + let did = auth.did.clone(); 386 328 387 329 info!( 388 330 "[MIGRATION] activateAccount: Validating DID document for did={}", ··· 402 344 did, 403 345 did_validation_start.elapsed() 404 346 ); 405 - return e.into_response(); 347 + return Err(e); 406 348 } 407 349 info!( 408 350 "[MIGRATION] activateAccount: DID document validation SUCCESS for {} (took {:?})", ··· 508 450 ); 509 451 } 510 452 info!("[MIGRATION] activateAccount: SUCCESS for did={}", did); 511 - EmptyResponse::ok().into_response() 453 + Ok(EmptyResponse::ok().into_response()) 512 454 } 513 455 Err(e) => { 514 456 error!( 515 457 "[MIGRATION] activateAccount: DB error activating account: {:?}", 516 458 e 517 459 ); 518 - ApiError::InternalError(None).into_response() 460 + Err(ApiError::InternalError(None)) 519 461 } 520 462 } 521 463 } ··· 528 470 529 471 pub async fn deactivate_account( 530 472 State(state): State<AppState>, 531 - headers: axum::http::HeaderMap, 473 + auth: Auth<Permissive>, 532 474 Json(input): Json<DeactivateAccountInput>, 533 - ) -> Response { 534 - let extracted = match crate::auth::extract_auth_token_from_header( 535 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 536 - ) { 537 - Some(t) => t, 538 - None => return ApiError::AuthenticationRequired.into_response(), 539 - }; 540 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 541 - let http_uri = format!( 542 - "https://{}/xrpc/com.atproto.server.deactivateAccount", 543 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 544 - ); 545 - let auth_user = match crate::auth::validate_token_with_dpop( 546 - state.user_repo.as_ref(), 547 - state.oauth_repo.as_ref(), 548 - &extracted.token, 549 - extracted.is_dpop, 550 - dpop_proof, 551 - "POST", 552 - &http_uri, 553 - false, 554 - false, 555 - ) 556 - .await 557 - { 558 - Ok(user) => user, 559 - Err(e) => return ApiError::from(e).into_response(), 560 - }; 561 - 475 + ) -> Result<Response, ApiError> { 562 476 if let Err(e) = crate::auth::scope_check::check_account_scope( 563 - auth_user.is_oauth, 564 - auth_user.scope.as_deref(), 477 + auth.is_oauth(), 478 + auth.scope.as_deref(), 565 479 crate::oauth::scopes::AccountAttr::Repo, 566 480 crate::oauth::scopes::AccountAction::Manage, 567 481 ) { 568 - return e; 482 + return Ok(e); 569 483 } 570 484 571 485 let delete_after: Option<chrono::DateTime<chrono::Utc>> = input ··· 574 488 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok()) 575 489 .map(|dt| dt.with_timezone(&chrono::Utc)); 576 490 577 - let did = auth_user.did; 491 + let did = auth.did.clone(); 578 492 579 493 let handle = state.user_repo.get_handle_by_did(&did).await.ok().flatten(); 580 494 ··· 595 509 { 596 510 warn!("Failed to sequence account deactivated event: {}", e); 597 511 } 598 - EmptyResponse::ok().into_response() 512 + Ok(EmptyResponse::ok().into_response()) 599 513 } 600 - Ok(false) => EmptyResponse::ok().into_response(), 514 + Ok(false) => Ok(EmptyResponse::ok().into_response()), 601 515 Err(e) => { 602 516 error!("DB error deactivating account: {:?}", e); 603 - ApiError::InternalError(None).into_response() 517 + Err(ApiError::InternalError(None)) 604 518 } 605 519 } 606 520 } 607 521 608 522 pub async fn request_account_delete( 609 523 State(state): State<AppState>, 610 - headers: axum::http::HeaderMap, 611 - ) -> Response { 612 - let extracted = match crate::auth::extract_auth_token_from_header( 613 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 614 - ) { 615 - Some(t) => t, 616 - None => return ApiError::AuthenticationRequired.into_response(), 617 - }; 618 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 619 - let http_uri = format!( 620 - "https://{}/xrpc/com.atproto.server.requestAccountDelete", 621 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 622 - ); 623 - let validated = match crate::auth::validate_token_with_dpop( 624 - state.user_repo.as_ref(), 625 - state.oauth_repo.as_ref(), 626 - &extracted.token, 627 - extracted.is_dpop, 628 - dpop_proof, 629 - "POST", 630 - &http_uri, 631 - true, 632 - false, 633 - ) 634 - .await 635 - { 636 - Ok(user) => user, 637 - Err(e) => return ApiError::from(e).into_response(), 638 - }; 639 - let did = validated.did.clone(); 524 + auth: Auth<NotTakendown>, 525 + ) -> Result<Response, ApiError> { 526 + let did = &auth.did; 640 527 641 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &did).await { 642 - return crate::api::server::reauth::legacy_mfa_required_response( 528 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await { 529 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 643 530 &*state.user_repo, 644 531 &*state.session_repo, 645 - &did, 532 + did, 646 533 ) 647 - .await; 534 + .await); 648 535 } 649 536 650 - let user_id = match state.user_repo.get_id_by_did(&did).await { 651 - Ok(Some(id)) => id, 652 - _ => { 653 - return ApiError::InternalError(None).into_response(); 654 - } 655 - }; 537 + let user_id = state 538 + .user_repo 539 + .get_id_by_did(did) 540 + .await 541 + .ok() 542 + .flatten() 543 + .ok_or(ApiError::InternalError(None))?; 656 544 let confirmation_token = Uuid::new_v4().to_string(); 657 545 let expires_at = Utc::now() + Duration::minutes(15); 658 - if let Err(e) = state 546 + state 659 547 .infra_repo 660 - .create_deletion_request(&confirmation_token, &did, expires_at) 548 + .create_deletion_request(&confirmation_token, did, expires_at) 661 549 .await 662 - { 663 - error!("DB error creating deletion token: {:?}", e); 664 - return ApiError::InternalError(None).into_response(); 665 - } 550 + .map_err(|e| { 551 + error!("DB error creating deletion token: {:?}", e); 552 + ApiError::InternalError(None) 553 + })?; 666 554 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 667 555 if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion( 668 556 state.user_repo.as_ref(), ··· 676 564 warn!("Failed to enqueue account deletion notification: {:?}", e); 677 565 } 678 566 info!("Account deletion requested for user {}", did); 679 - EmptyResponse::ok().into_response() 567 + Ok(EmptyResponse::ok().into_response()) 680 568 } 681 569 682 570 #[derive(Deserialize)]
+126 -122
crates/tranquil-pds/src/api/server/app_password.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::{BearerAuth, generate_app_password}; 3 + use crate::auth::{Auth, NotTakendown, Permissive, generate_app_password}; 4 4 use crate::delegation::{DelegationActionType, intersect_scopes}; 5 5 use crate::state::{AppState, RateLimitKind}; 6 6 use axum::{ ··· 33 33 34 34 pub async fn list_app_passwords( 35 35 State(state): State<AppState>, 36 - BearerAuth(auth_user): BearerAuth, 37 - ) -> Response { 38 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 39 - Ok(Some(u)) => u, 40 - Ok(None) => return ApiError::AccountNotFound.into_response(), 41 - Err(e) => { 36 + auth: Auth<Permissive>, 37 + ) -> Result<Response, ApiError> { 38 + let user = state 39 + .user_repo 40 + .get_by_did(&auth.did) 41 + .await 42 + .map_err(|e| { 42 43 error!("DB error getting user: {:?}", e); 43 - return ApiError::InternalError(None).into_response(); 44 - } 45 - }; 44 + ApiError::InternalError(None) 45 + })? 46 + .ok_or(ApiError::AccountNotFound)?; 46 47 47 - match state.session_repo.list_app_passwords(user.id).await { 48 - Ok(rows) => { 49 - let passwords: Vec<AppPassword> = rows 50 - .iter() 51 - .map(|row| AppPassword { 52 - name: row.name.clone(), 53 - created_at: row.created_at.to_rfc3339(), 54 - privileged: row.privileged, 55 - scopes: row.scopes.clone(), 56 - created_by_controller: row 57 - .created_by_controller_did 58 - .as_ref() 59 - .map(|d| d.to_string()), 60 - }) 61 - .collect(); 62 - Json(ListAppPasswordsOutput { passwords }).into_response() 63 - } 64 - Err(e) => { 48 + let rows = state 49 + .session_repo 50 + .list_app_passwords(user.id) 51 + .await 52 + .map_err(|e| { 65 53 error!("DB error listing app passwords: {:?}", e); 66 - ApiError::InternalError(None).into_response() 67 - } 68 - } 54 + ApiError::InternalError(None) 55 + })?; 56 + let passwords: Vec<AppPassword> = rows 57 + .iter() 58 + .map(|row| AppPassword { 59 + name: row.name.clone(), 60 + created_at: row.created_at.to_rfc3339(), 61 + privileged: row.privileged, 62 + scopes: row.scopes.clone(), 63 + created_by_controller: row 64 + .created_by_controller_did 65 + .as_ref() 66 + .map(|d| d.to_string()), 67 + }) 68 + .collect(); 69 + Ok(Json(ListAppPasswordsOutput { passwords }).into_response()) 69 70 } 70 71 71 72 #[derive(Deserialize)] ··· 89 90 pub async fn create_app_password( 90 91 State(state): State<AppState>, 91 92 headers: HeaderMap, 92 - BearerAuth(auth_user): BearerAuth, 93 + auth: Auth<NotTakendown>, 93 94 Json(input): Json<CreateAppPasswordInput>, 94 - ) -> Response { 95 + ) -> Result<Response, ApiError> { 95 96 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 96 97 if !state 97 98 .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 98 99 .await 99 100 { 100 101 warn!(ip = %client_ip, "App password creation rate limit exceeded"); 101 - return ApiError::RateLimitExceeded(None).into_response(); 102 + return Err(ApiError::RateLimitExceeded(None)); 102 103 } 103 104 104 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 105 - Ok(Some(u)) => u, 106 - Ok(None) => return ApiError::AccountNotFound.into_response(), 107 - Err(e) => { 105 + let user = state 106 + .user_repo 107 + .get_by_did(&auth.did) 108 + .await 109 + .map_err(|e| { 108 110 error!("DB error getting user: {:?}", e); 109 - return ApiError::InternalError(None).into_response(); 110 - } 111 - }; 111 + ApiError::InternalError(None) 112 + })? 113 + .ok_or(ApiError::AccountNotFound)?; 112 114 113 115 let name = input.name.trim(); 114 116 if name.is_empty() { 115 - return ApiError::InvalidRequest("name is required".into()).into_response(); 117 + return Err(ApiError::InvalidRequest("name is required".into())); 116 118 } 117 119 118 - match state 120 + if state 119 121 .session_repo 120 122 .get_app_password_by_name(user.id, name) 121 123 .await 122 - { 123 - Ok(Some(_)) => return ApiError::DuplicateAppPassword.into_response(), 124 - Err(e) => { 124 + .map_err(|e| { 125 125 error!("DB error checking app password: {:?}", e); 126 - return ApiError::InternalError(None).into_response(); 127 - } 128 - Ok(None) => {} 126 + ApiError::InternalError(None) 127 + })? 128 + .is_some() 129 + { 130 + return Err(ApiError::DuplicateAppPassword); 129 131 } 130 132 131 - let (final_scopes, controller_did) = if let Some(ref controller) = auth_user.controller_did { 133 + let (final_scopes, controller_did) = if let Some(ref controller) = auth.controller_did { 132 134 let grant = state 133 135 .delegation_repo 134 - .get_delegation(&auth_user.did, controller) 136 + .get_delegation(&auth.did, controller) 135 137 .await 136 138 .ok() 137 139 .flatten(); ··· 141 143 let intersected = intersect_scopes(requested, &granted_scopes); 142 144 143 145 if intersected.is_empty() && !granted_scopes.is_empty() { 144 - return ApiError::InsufficientScope(None).into_response(); 146 + return Err(ApiError::InsufficientScope(None)); 145 147 } 146 148 147 149 let scope_result = if intersected.is_empty() { ··· 157 159 let password = generate_app_password(); 158 160 159 161 let password_clone = password.clone(); 160 - let password_hash = match tokio::task::spawn_blocking(move || { 161 - bcrypt::hash(&password_clone, bcrypt::DEFAULT_COST) 162 - }) 163 - .await 164 - { 165 - Ok(Ok(h)) => h, 166 - Ok(Err(e)) => { 167 - error!("Failed to hash password: {:?}", e); 168 - return ApiError::InternalError(None).into_response(); 169 - } 170 - Err(e) => { 171 - error!("Failed to spawn blocking task: {:?}", e); 172 - return ApiError::InternalError(None).into_response(); 173 - } 174 - }; 162 + let password_hash = 163 + tokio::task::spawn_blocking(move || bcrypt::hash(&password_clone, bcrypt::DEFAULT_COST)) 164 + .await 165 + .map_err(|e| { 166 + error!("Failed to spawn blocking task: {:?}", e); 167 + ApiError::InternalError(None) 168 + })? 169 + .map_err(|e| { 170 + error!("Failed to hash password: {:?}", e); 171 + ApiError::InternalError(None) 172 + })?; 175 173 176 174 let privileged = input.privileged.unwrap_or(false); 177 175 let created_at = chrono::Utc::now(); ··· 185 183 created_by_controller_did: controller_did.clone(), 186 184 }; 187 185 188 - match state.session_repo.create_app_password(&create_data).await { 189 - Ok(_) => { 190 - if let Some(ref controller) = controller_did { 191 - let _ = state 192 - .delegation_repo 193 - .log_delegation_action( 194 - &auth_user.did, 195 - controller, 196 - Some(controller), 197 - DelegationActionType::AccountAction, 198 - Some(json!({ 199 - "action": "create_app_password", 200 - "name": name, 201 - "scopes": final_scopes 202 - })), 203 - None, 204 - None, 205 - ) 206 - .await; 207 - } 208 - Json(CreateAppPasswordOutput { 209 - name: name.to_string(), 210 - password, 211 - created_at: created_at.to_rfc3339(), 212 - privileged, 213 - scopes: final_scopes, 214 - }) 215 - .into_response() 216 - } 217 - Err(e) => { 186 + state 187 + .session_repo 188 + .create_app_password(&create_data) 189 + .await 190 + .map_err(|e| { 218 191 error!("DB error creating app password: {:?}", e); 219 - ApiError::InternalError(None).into_response() 220 - } 192 + ApiError::InternalError(None) 193 + })?; 194 + 195 + if let Some(ref controller) = controller_did { 196 + let _ = state 197 + .delegation_repo 198 + .log_delegation_action( 199 + &auth.did, 200 + controller, 201 + Some(controller), 202 + DelegationActionType::AccountAction, 203 + Some(json!({ 204 + "action": "create_app_password", 205 + "name": name, 206 + "scopes": final_scopes 207 + })), 208 + None, 209 + None, 210 + ) 211 + .await; 221 212 } 213 + Ok(Json(CreateAppPasswordOutput { 214 + name: name.to_string(), 215 + password, 216 + created_at: created_at.to_rfc3339(), 217 + privileged, 218 + scopes: final_scopes, 219 + }) 220 + .into_response()) 222 221 } 223 222 224 223 #[derive(Deserialize)] ··· 228 227 229 228 pub async fn revoke_app_password( 230 229 State(state): State<AppState>, 231 - BearerAuth(auth_user): BearerAuth, 230 + auth: Auth<Permissive>, 232 231 Json(input): Json<RevokeAppPasswordInput>, 233 - ) -> Response { 234 - let user = match state.user_repo.get_by_did(&auth_user.did).await { 235 - Ok(Some(u)) => u, 236 - Ok(None) => return ApiError::AccountNotFound.into_response(), 237 - Err(e) => { 232 + ) -> Result<Response, ApiError> { 233 + let user = state 234 + .user_repo 235 + .get_by_did(&auth.did) 236 + .await 237 + .map_err(|e| { 238 238 error!("DB error getting user: {:?}", e); 239 - return ApiError::InternalError(None).into_response(); 240 - } 241 - }; 239 + ApiError::InternalError(None) 240 + })? 241 + .ok_or(ApiError::AccountNotFound)?; 242 242 243 243 let name = input.name.trim(); 244 244 if name.is_empty() { 245 - return ApiError::InvalidRequest("name is required".into()).into_response(); 245 + return Err(ApiError::InvalidRequest("name is required".into())); 246 246 } 247 247 248 248 let sessions_to_invalidate = state 249 249 .session_repo 250 - .get_session_jtis_by_app_password(&auth_user.did, name) 250 + .get_session_jtis_by_app_password(&auth.did, name) 251 251 .await 252 252 .unwrap_or_default(); 253 253 254 - if let Err(e) = state 254 + state 255 255 .session_repo 256 - .delete_sessions_by_app_password(&auth_user.did, name) 256 + .delete_sessions_by_app_password(&auth.did, name) 257 257 .await 258 - { 259 - error!("DB error revoking sessions for app password: {:?}", e); 260 - return ApiError::InternalError(None).into_response(); 261 - } 258 + .map_err(|e| { 259 + error!("DB error revoking sessions for app password: {:?}", e); 260 + ApiError::InternalError(None) 261 + })?; 262 262 263 263 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 264 - let cache_key = format!("auth:session:{}:{}", &auth_user.did, jti); 264 + let cache_key = format!("auth:session:{}:{}", &auth.did, jti); 265 265 let cache = state.cache.clone(); 266 266 async move { 267 267 let _ = cache.delete(&cache_key).await; ··· 269 269 })) 270 270 .await; 271 271 272 - if let Err(e) = state.session_repo.delete_app_password(user.id, name).await { 273 - error!("DB error revoking app password: {:?}", e); 274 - return ApiError::InternalError(None).into_response(); 275 - } 272 + state 273 + .session_repo 274 + .delete_app_password(user.id, name) 275 + .await 276 + .map_err(|e| { 277 + error!("DB error revoking app password: {:?}", e); 278 + ApiError::InternalError(None) 279 + })?; 276 280 277 - EmptyResponse::ok().into_response() 281 + Ok(EmptyResponse::ok().into_response()) 278 282 }
+92 -89
crates/tranquil-pds/src/api/server/email.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::{Auth, NotTakendown}; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use axum::{ 6 6 Json, ··· 45 45 pub async fn request_email_update( 46 46 State(state): State<AppState>, 47 47 headers: axum::http::HeaderMap, 48 - auth: BearerAuth, 48 + auth: Auth<NotTakendown>, 49 49 input: Option<Json<RequestEmailUpdateInput>>, 50 - ) -> Response { 50 + ) -> Result<Response, ApiError> { 51 51 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 52 52 if !state 53 53 .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 54 54 .await 55 55 { 56 56 warn!(ip = %client_ip, "Email update rate limit exceeded"); 57 - return ApiError::RateLimitExceeded(None).into_response(); 57 + return Err(ApiError::RateLimitExceeded(None)); 58 58 } 59 59 60 60 if let Err(e) = crate::auth::scope_check::check_account_scope( 61 - auth.0.is_oauth, 62 - auth.0.scope.as_deref(), 61 + auth.is_oauth(), 62 + auth.scope.as_deref(), 63 63 crate::oauth::scopes::AccountAttr::Email, 64 64 crate::oauth::scopes::AccountAction::Manage, 65 65 ) { 66 - return e; 66 + return Ok(e); 67 67 } 68 68 69 - let user = match state.user_repo.get_email_info_by_did(&auth.0.did).await { 70 - Ok(Some(row)) => row, 71 - Ok(None) => { 72 - return ApiError::AccountNotFound.into_response(); 73 - } 74 - Err(e) => { 69 + let user = state 70 + .user_repo 71 + .get_email_info_by_did(&auth.did) 72 + .await 73 + .map_err(|e| { 75 74 error!("DB error: {:?}", e); 76 - return ApiError::InternalError(None).into_response(); 77 - } 78 - }; 75 + ApiError::InternalError(None) 76 + })? 77 + .ok_or(ApiError::AccountNotFound)?; 79 78 80 79 let Some(current_email) = user.email else { 81 - return ApiError::InvalidRequest("account does not have an email address".into()) 82 - .into_response(); 80 + return Err(ApiError::InvalidRequest( 81 + "account does not have an email address".into(), 82 + )); 83 83 }; 84 84 85 85 let token_required = user.email_verified; 86 86 87 87 if token_required { 88 88 let code = crate::auth::verification_token::generate_channel_update_token( 89 - &auth.0.did, 89 + &auth.did, 90 90 "email_update", 91 91 &current_email.to_lowercase(), 92 92 ); ··· 103 103 authorized: false, 104 104 }; 105 105 if let Ok(json) = serde_json::to_string(&pending) { 106 - let cache_key = email_update_cache_key(&auth.0.did); 106 + let cache_key = email_update_cache_key(&auth.did); 107 107 if let Err(e) = state.cache.set(&cache_key, &json, EMAIL_UPDATE_TTL).await { 108 108 warn!("Failed to cache pending email update: {:?}", e); 109 109 } ··· 127 127 } 128 128 129 129 info!("Email update requested for user {}", user.id); 130 - TokenRequiredResponse::response(token_required).into_response() 130 + Ok(TokenRequiredResponse::response(token_required).into_response()) 131 131 } 132 132 133 133 #[derive(Deserialize)] ··· 140 140 pub async fn confirm_email( 141 141 State(state): State<AppState>, 142 142 headers: axum::http::HeaderMap, 143 - auth: BearerAuth, 143 + auth: Auth<NotTakendown>, 144 144 Json(input): Json<ConfirmEmailInput>, 145 - ) -> Response { 145 + ) -> Result<Response, ApiError> { 146 146 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 147 147 if !state 148 148 .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 149 149 .await 150 150 { 151 151 warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 152 - return ApiError::RateLimitExceeded(None).into_response(); 152 + return Err(ApiError::RateLimitExceeded(None)); 153 153 } 154 154 155 155 if let Err(e) = crate::auth::scope_check::check_account_scope( 156 - auth.0.is_oauth, 157 - auth.0.scope.as_deref(), 156 + auth.is_oauth(), 157 + auth.scope.as_deref(), 158 158 crate::oauth::scopes::AccountAttr::Email, 159 159 crate::oauth::scopes::AccountAction::Manage, 160 160 ) { 161 - return e; 161 + return Ok(e); 162 162 } 163 163 164 - let did = &auth.0.did; 165 - let user = match state.user_repo.get_email_info_by_did(did).await { 166 - Ok(Some(row)) => row, 167 - Ok(None) => { 168 - return ApiError::AccountNotFound.into_response(); 169 - } 170 - Err(e) => { 164 + let did = &auth.did; 165 + let user = state 166 + .user_repo 167 + .get_email_info_by_did(did) 168 + .await 169 + .map_err(|e| { 171 170 error!("DB error: {:?}", e); 172 - return ApiError::InternalError(None).into_response(); 173 - } 174 - }; 171 + ApiError::InternalError(None) 172 + })? 173 + .ok_or(ApiError::AccountNotFound)?; 175 174 176 175 let Some(ref email) = user.email else { 177 - return ApiError::InvalidEmail.into_response(); 176 + return Err(ApiError::InvalidEmail); 178 177 }; 179 178 let current_email = email.to_lowercase(); 180 179 181 180 let provided_email = input.email.trim().to_lowercase(); 182 181 if provided_email != current_email { 183 - return ApiError::InvalidEmail.into_response(); 182 + return Err(ApiError::InvalidEmail); 184 183 } 185 184 186 185 if user.email_verified { 187 - return EmptyResponse::ok().into_response(); 186 + return Ok(EmptyResponse::ok().into_response()); 188 187 } 189 188 190 189 let confirmation_code = ··· 199 198 match verified { 200 199 Ok(token_data) => { 201 200 if token_data.did != did.as_str() { 202 - return ApiError::InvalidToken(None).into_response(); 201 + return Err(ApiError::InvalidToken(None)); 203 202 } 204 203 } 205 204 Err(crate::auth::verification_token::VerifyError::Expired) => { 206 - return ApiError::ExpiredToken(None).into_response(); 205 + return Err(ApiError::ExpiredToken(None)); 207 206 } 208 207 Err(_) => { 209 - return ApiError::InvalidToken(None).into_response(); 208 + return Err(ApiError::InvalidToken(None)); 210 209 } 211 210 } 212 211 213 - if let Err(e) = state.user_repo.set_email_verified(user.id, true).await { 214 - error!("DB error confirming email: {:?}", e); 215 - return ApiError::InternalError(None).into_response(); 216 - } 212 + state 213 + .user_repo 214 + .set_email_verified(user.id, true) 215 + .await 216 + .map_err(|e| { 217 + error!("DB error confirming email: {:?}", e); 218 + ApiError::InternalError(None) 219 + })?; 217 220 218 221 info!("Email confirmed for user {}", user.id); 219 - EmptyResponse::ok().into_response() 222 + Ok(EmptyResponse::ok().into_response()) 220 223 } 221 224 222 225 #[derive(Deserialize)] ··· 230 233 231 234 pub async fn update_email( 232 235 State(state): State<AppState>, 233 - auth: BearerAuth, 236 + auth: Auth<NotTakendown>, 234 237 Json(input): Json<UpdateEmailInput>, 235 - ) -> Response { 236 - let auth_user = auth.0; 237 - 238 + ) -> Result<Response, ApiError> { 238 239 if let Err(e) = crate::auth::scope_check::check_account_scope( 239 - auth_user.is_oauth, 240 - auth_user.scope.as_deref(), 240 + auth.is_oauth(), 241 + auth.scope.as_deref(), 241 242 crate::oauth::scopes::AccountAttr::Email, 242 243 crate::oauth::scopes::AccountAction::Manage, 243 244 ) { 244 - return e; 245 + return Ok(e); 245 246 } 246 247 247 - let did = &auth_user.did; 248 - let user = match state.user_repo.get_email_info_by_did(did).await { 249 - Ok(Some(row)) => row, 250 - Ok(None) => { 251 - return ApiError::AccountNotFound.into_response(); 252 - } 253 - Err(e) => { 248 + let did = &auth.did; 249 + let user = state 250 + .user_repo 251 + .get_email_info_by_did(did) 252 + .await 253 + .map_err(|e| { 254 254 error!("DB error: {:?}", e); 255 - return ApiError::InternalError(None).into_response(); 256 - } 257 - }; 255 + ApiError::InternalError(None) 256 + })? 257 + .ok_or(ApiError::AccountNotFound)?; 258 258 259 259 let user_id = user.id; 260 260 let current_email = user.email.clone(); ··· 262 262 let new_email = input.email.trim().to_lowercase(); 263 263 264 264 if !crate::api::validation::is_valid_email(&new_email) { 265 - return ApiError::InvalidRequest( 265 + return Err(ApiError::InvalidRequest( 266 266 "This email address is not supported, please use a different email.".into(), 267 - ) 268 - .into_response(); 267 + )); 269 268 } 270 269 271 270 if let Some(ref current) = current_email 272 271 && new_email == current.to_lowercase() 273 272 { 274 - return EmptyResponse::ok().into_response(); 273 + return Ok(EmptyResponse::ok().into_response()); 275 274 } 276 275 277 276 if email_verified { ··· 290 289 291 290 if !authorized_via_link { 292 291 let Some(ref t) = input.token else { 293 - return ApiError::TokenRequired.into_response(); 292 + return Err(ApiError::TokenRequired); 294 293 }; 295 294 let confirmation_token = 296 295 crate::auth::verification_token::normalize_token_input(t.trim()); ··· 309 308 match verified { 310 309 Ok(token_data) => { 311 310 if token_data.did != did.as_str() { 312 - return ApiError::InvalidToken(None).into_response(); 311 + return Err(ApiError::InvalidToken(None)); 313 312 } 314 313 } 315 314 Err(crate::auth::verification_token::VerifyError::Expired) => { 316 - return ApiError::ExpiredToken(None).into_response(); 315 + return Err(ApiError::ExpiredToken(None)); 317 316 } 318 317 Err(_) => { 319 - return ApiError::InvalidToken(None).into_response(); 318 + return Err(ApiError::InvalidToken(None)); 320 319 } 321 320 } 322 321 } 323 322 } 324 323 325 - if let Err(e) = state.user_repo.update_email(user_id, &new_email).await { 326 - error!("DB error updating email: {:?}", e); 327 - return ApiError::InternalError(None).into_response(); 328 - } 324 + state 325 + .user_repo 326 + .update_email(user_id, &new_email) 327 + .await 328 + .map_err(|e| { 329 + error!("DB error updating email: {:?}", e); 330 + ApiError::InternalError(None) 331 + })?; 329 332 330 333 let verification_token = 331 334 crate::auth::verification_token::generate_signup_token(did, "email", &new_email); ··· 358 361 } 359 362 360 363 info!("Email updated for user {}", user_id); 361 - EmptyResponse::ok().into_response() 364 + Ok(EmptyResponse::ok().into_response()) 362 365 } 363 366 364 367 #[derive(Deserialize)] ··· 497 500 pub async fn check_email_update_status( 498 501 State(state): State<AppState>, 499 502 headers: axum::http::HeaderMap, 500 - auth: BearerAuth, 501 - ) -> Response { 503 + auth: Auth<NotTakendown>, 504 + ) -> Result<Response, ApiError> { 502 505 let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 503 506 if !state 504 507 .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 505 508 .await 506 509 { 507 - return ApiError::RateLimitExceeded(None).into_response(); 510 + return Err(ApiError::RateLimitExceeded(None)); 508 511 } 509 512 510 513 if let Err(e) = crate::auth::scope_check::check_account_scope( 511 - auth.0.is_oauth, 512 - auth.0.scope.as_deref(), 514 + auth.is_oauth(), 515 + auth.scope.as_deref(), 513 516 crate::oauth::scopes::AccountAttr::Email, 514 517 crate::oauth::scopes::AccountAction::Read, 515 518 ) { 516 - return e; 519 + return Ok(e); 517 520 } 518 521 519 - let cache_key = email_update_cache_key(&auth.0.did); 522 + let cache_key = email_update_cache_key(&auth.did); 520 523 let pending_json = match state.cache.get(&cache_key).await { 521 524 Some(json) => json, 522 525 None => { 523 - return Json(json!({ "pending": false, "authorized": false })).into_response(); 526 + return Ok(Json(json!({ "pending": false, "authorized": false })).into_response()); 524 527 } 525 528 }; 526 529 527 530 let pending: PendingEmailUpdate = match serde_json::from_str(&pending_json) { 528 531 Ok(p) => p, 529 532 Err(_) => { 530 - return Json(json!({ "pending": false, "authorized": false })).into_response(); 533 + return Ok(Json(json!({ "pending": false, "authorized": false })).into_response()); 531 534 } 532 535 }; 533 536 534 - Json(json!({ 537 + Ok(Json(json!({ 535 538 "pending": true, 536 539 "authorized": pending.authorized, 537 540 "newEmail": pending.new_email, 538 541 })) 539 - .into_response() 542 + .into_response()) 540 543 } 541 544 542 545 #[derive(Deserialize)]
+46 -48
crates/tranquil-pds/src/api/server/invite.rs
··· 1 1 use crate::api::ApiError; 2 - use crate::auth::BearerAuth; 3 - use crate::auth::extractor::BearerAuthAdmin; 2 + use crate::auth::{Admin, Auth, NotTakendown}; 4 3 use crate::state::AppState; 5 4 use crate::types::Did; 6 5 use axum::{ ··· 44 43 45 44 pub async fn create_invite_code( 46 45 State(state): State<AppState>, 47 - BearerAuthAdmin(auth_user): BearerAuthAdmin, 46 + auth: Auth<Admin>, 48 47 Json(input): Json<CreateInviteCodeInput>, 49 - ) -> Response { 48 + ) -> Result<Response, ApiError> { 50 49 if input.use_count < 1 { 51 - return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 50 + return Err(ApiError::InvalidRequest( 51 + "useCount must be at least 1".into(), 52 + )); 52 53 } 53 54 54 55 let for_account: Did = match &input.for_account { 55 - Some(acct) => match acct.parse() { 56 - Ok(d) => d, 57 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 58 - }, 59 - None => auth_user.did.clone(), 56 + Some(acct) => acct 57 + .parse() 58 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?, 59 + None => auth.did.clone(), 60 60 }; 61 61 let code = gen_invite_code(); 62 62 ··· 65 65 .create_invite_code(&code, input.use_count, Some(&for_account)) 66 66 .await 67 67 { 68 - Ok(true) => Json(CreateInviteCodeOutput { code }).into_response(), 68 + Ok(true) => Ok(Json(CreateInviteCodeOutput { code }).into_response()), 69 69 Ok(false) => { 70 70 error!("No admin user found to create invite code"); 71 - ApiError::InternalError(None).into_response() 71 + Err(ApiError::InternalError(None)) 72 72 } 73 73 Err(e) => { 74 74 error!("DB error creating invite code: {:?}", e); 75 - ApiError::InternalError(None).into_response() 75 + Err(ApiError::InternalError(None)) 76 76 } 77 77 } 78 78 } ··· 98 98 99 99 pub async fn create_invite_codes( 100 100 State(state): State<AppState>, 101 - BearerAuthAdmin(auth_user): BearerAuthAdmin, 101 + auth: Auth<Admin>, 102 102 Json(input): Json<CreateInviteCodesInput>, 103 - ) -> Response { 103 + ) -> Result<Response, ApiError> { 104 104 if input.use_count < 1 { 105 - return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 105 + return Err(ApiError::InvalidRequest( 106 + "useCount must be at least 1".into(), 107 + )); 106 108 } 107 109 108 110 let code_count = input.code_count.unwrap_or(1).max(1); 109 111 let for_accounts: Vec<Did> = match &input.for_accounts { 110 - Some(accounts) if !accounts.is_empty() => { 111 - let parsed: Result<Vec<Did>, _> = accounts.iter().map(|a| a.parse()).collect(); 112 - match parsed { 113 - Ok(dids) => dids, 114 - Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 115 - } 116 - } 117 - _ => vec![auth_user.did.clone()], 112 + Some(accounts) if !accounts.is_empty() => accounts 113 + .iter() 114 + .map(|a| a.parse()) 115 + .collect::<Result<Vec<Did>, _>>() 116 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?, 117 + _ => vec![auth.did.clone()], 118 118 }; 119 119 120 - let admin_user_id = match state.user_repo.get_any_admin_user_id().await { 121 - Ok(Some(id)) => id, 122 - Ok(None) => { 123 - error!("No admin user found to create invite codes"); 124 - return ApiError::InternalError(None).into_response(); 125 - } 126 - Err(e) => { 120 + let admin_user_id = state 121 + .user_repo 122 + .get_any_admin_user_id() 123 + .await 124 + .map_err(|e| { 127 125 error!("DB error looking up admin user: {:?}", e); 128 - return ApiError::InternalError(None).into_response(); 129 - } 130 - }; 126 + ApiError::InternalError(None) 127 + })? 128 + .ok_or_else(|| { 129 + error!("No admin user found to create invite codes"); 130 + ApiError::InternalError(None) 131 + })?; 131 132 132 133 let result = futures::future::try_join_all(for_accounts.into_iter().map(|account| { 133 134 let infra_repo = state.infra_repo.clone(); ··· 146 147 .await; 147 148 148 149 match result { 149 - Ok(result_codes) => Json(CreateInviteCodesOutput { 150 + Ok(result_codes) => Ok(Json(CreateInviteCodesOutput { 150 151 codes: result_codes, 151 152 }) 152 - .into_response(), 153 + .into_response()), 153 154 Err(e) => { 154 155 error!("DB error creating invite codes: {:?}", e); 155 - ApiError::InternalError(None).into_response() 156 + Err(ApiError::InternalError(None)) 156 157 } 157 158 } 158 159 } ··· 192 193 193 194 pub async fn get_account_invite_codes( 194 195 State(state): State<AppState>, 195 - BearerAuth(auth_user): BearerAuth, 196 + auth: Auth<NotTakendown>, 196 197 axum::extract::Query(params): axum::extract::Query<GetAccountInviteCodesParams>, 197 - ) -> Response { 198 + ) -> Result<Response, ApiError> { 198 199 let include_used = params.include_used.unwrap_or(true); 199 200 200 - let codes_info = match state 201 + let codes_info = state 201 202 .infra_repo 202 - .get_invite_codes_for_account(&auth_user.did) 203 + .get_invite_codes_for_account(&auth.did) 203 204 .await 204 - { 205 - Ok(info) => info, 206 - Err(e) => { 205 + .map_err(|e| { 207 206 error!("DB error fetching invite codes: {:?}", e); 208 - return ApiError::InternalError(None).into_response(); 209 - } 210 - }; 207 + ApiError::InternalError(None) 208 + })?; 211 209 212 210 let filtered_codes: Vec<_> = codes_info 213 211 .into_iter() ··· 254 252 .await; 255 253 256 254 let codes: Vec<InviteCode> = codes.into_iter().flatten().collect(); 257 - Json(GetAccountInviteCodesOutput { codes }).into_response() 255 + Ok(Json(GetAccountInviteCodesOutput { codes }).into_response()) 258 256 }
+46 -103
crates/tranquil-pds/src/api/server/migration.rs
··· 1 1 use crate::api::ApiError; 2 + use crate::auth::{Active, Auth}; 2 3 use crate::state::AppState; 3 4 use axum::{ 4 5 Json, ··· 35 36 36 37 pub async fn update_did_document( 37 38 State(state): State<AppState>, 38 - headers: axum::http::HeaderMap, 39 + auth: Auth<Active>, 39 40 Json(input): Json<UpdateDidDocumentInput>, 40 - ) -> Response { 41 - let extracted = match crate::auth::extract_auth_token_from_header( 42 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 43 - ) { 44 - Some(t) => t, 45 - None => return ApiError::AuthenticationRequired.into_response(), 46 - }; 47 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 48 - let http_uri = format!( 49 - "https://{}/xrpc/_account.updateDidDocument", 50 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 51 - ); 52 - let auth_user = match crate::auth::validate_token_with_dpop( 53 - state.user_repo.as_ref(), 54 - state.oauth_repo.as_ref(), 55 - &extracted.token, 56 - extracted.is_dpop, 57 - dpop_proof, 58 - "POST", 59 - &http_uri, 60 - true, 61 - false, 62 - ) 63 - .await 64 - { 65 - Ok(user) => user, 66 - Err(e) => return ApiError::from(e).into_response(), 67 - }; 68 - 69 - if !auth_user.did.starts_with("did:web:") { 70 - return ApiError::InvalidRequest( 41 + ) -> Result<Response, ApiError> { 42 + if !auth.did.starts_with("did:web:") { 43 + return Err(ApiError::InvalidRequest( 71 44 "DID document updates are only available for did:web accounts".into(), 72 - ) 73 - .into_response(); 45 + )); 74 46 } 75 47 76 - let user = match state.user_repo.get_user_for_did_doc(&auth_user.did).await { 77 - Ok(Some(u)) => u, 78 - Ok(None) => return ApiError::AccountNotFound.into_response(), 79 - Err(e) => { 48 + let user = state 49 + .user_repo 50 + .get_user_for_did_doc(&auth.did) 51 + .await 52 + .map_err(|e| { 80 53 tracing::error!("DB error getting user: {:?}", e); 81 - return ApiError::InternalError(None).into_response(); 82 - } 83 - }; 84 - 85 - if user.deactivated_at.is_some() { 86 - return ApiError::AccountDeactivated.into_response(); 87 - } 54 + ApiError::InternalError(None) 55 + })? 56 + .ok_or(ApiError::AccountNotFound)?; 88 57 89 58 if let Some(ref methods) = input.verification_methods { 90 59 if methods.is_empty() { 91 - return ApiError::InvalidRequest("verification_methods cannot be empty".into()) 92 - .into_response(); 60 + return Err(ApiError::InvalidRequest( 61 + "verification_methods cannot be empty".into(), 62 + )); 93 63 } 94 64 let validation_error = methods.iter().find_map(|method| { 95 65 if method.id.is_empty() { ··· 105 75 } 106 76 }); 107 77 if let Some(err) = validation_error { 108 - return ApiError::InvalidRequest(err.into()).into_response(); 78 + return Err(ApiError::InvalidRequest(err.into())); 109 79 } 110 80 } 111 81 112 82 if let Some(ref handles) = input.also_known_as 113 83 && handles.iter().any(|h| !h.starts_with("at://")) 114 84 { 115 - return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) 116 - .into_response(); 85 + return Err(ApiError::InvalidRequest( 86 + "alsoKnownAs entries must be at:// URIs".into(), 87 + )); 117 88 } 118 89 119 90 if let Some(ref endpoint) = input.service_endpoint { 120 91 let endpoint = endpoint.trim(); 121 92 if !endpoint.starts_with("https://") { 122 - return ApiError::InvalidRequest("serviceEndpoint must start with https://".into()) 123 - .into_response(); 93 + return Err(ApiError::InvalidRequest( 94 + "serviceEndpoint must start with https://".into(), 95 + )); 124 96 } 125 97 } 126 98 ··· 131 103 132 104 let also_known_as: Option<Vec<String>> = input.also_known_as.clone(); 133 105 134 - if let Err(e) = state 106 + state 135 107 .user_repo 136 108 .upsert_did_web_overrides(user.id, verification_methods_json, also_known_as) 137 109 .await 138 - { 139 - tracing::error!("DB error upserting did_web_overrides: {:?}", e); 140 - return ApiError::InternalError(None).into_response(); 141 - } 110 + .map_err(|e| { 111 + tracing::error!("DB error upserting did_web_overrides: {:?}", e); 112 + ApiError::InternalError(None) 113 + })?; 142 114 143 115 if let Some(ref endpoint) = input.service_endpoint { 144 116 let endpoint_clean = endpoint.trim().trim_end_matches('/'); 145 - if let Err(e) = state 117 + state 146 118 .user_repo 147 - .update_migrated_to_pds(&auth_user.did, endpoint_clean) 119 + .update_migrated_to_pds(&auth.did, endpoint_clean) 148 120 .await 149 - { 150 - tracing::error!("DB error updating service endpoint: {:?}", e); 151 - return ApiError::InternalError(None).into_response(); 152 - } 121 + .map_err(|e| { 122 + tracing::error!("DB error updating service endpoint: {:?}", e); 123 + ApiError::InternalError(None) 124 + })?; 153 125 } 154 126 155 - let did_doc = build_did_document(&state, &auth_user.did).await; 127 + let did_doc = build_did_document(&state, &auth.did).await; 156 128 157 - tracing::info!("Updated DID document for {}", &auth_user.did); 129 + tracing::info!("Updated DID document for {}", &auth.did); 158 130 159 - ( 131 + Ok(( 160 132 StatusCode::OK, 161 133 Json(UpdateDidDocumentOutput { 162 134 success: true, 163 135 did_document: did_doc, 164 136 }), 165 137 ) 166 - .into_response() 138 + .into_response()) 167 139 } 168 140 169 141 pub async fn get_did_document( 170 142 State(state): State<AppState>, 171 - headers: axum::http::HeaderMap, 172 - ) -> Response { 173 - let extracted = match crate::auth::extract_auth_token_from_header( 174 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 175 - ) { 176 - Some(t) => t, 177 - None => return ApiError::AuthenticationRequired.into_response(), 178 - }; 179 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 180 - let http_uri = format!( 181 - "https://{}/xrpc/_account.getDidDocument", 182 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 183 - ); 184 - let auth_user = match crate::auth::validate_token_with_dpop( 185 - state.user_repo.as_ref(), 186 - state.oauth_repo.as_ref(), 187 - &extracted.token, 188 - extracted.is_dpop, 189 - dpop_proof, 190 - "GET", 191 - &http_uri, 192 - true, 193 - false, 194 - ) 195 - .await 196 - { 197 - Ok(user) => user, 198 - Err(e) => return ApiError::from(e).into_response(), 199 - }; 200 - 201 - if !auth_user.did.starts_with("did:web:") { 202 - return ApiError::InvalidRequest( 143 + auth: Auth<Active>, 144 + ) -> Result<Response, ApiError> { 145 + if !auth.did.starts_with("did:web:") { 146 + return Err(ApiError::InvalidRequest( 203 147 "This endpoint is only available for did:web accounts".into(), 204 - ) 205 - .into_response(); 148 + )); 206 149 } 207 150 208 - let did_doc = build_did_document(&state, &auth_user.did).await; 151 + let did_doc = build_did_document(&state, &auth.did).await; 209 152 210 - (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response() 153 + Ok((StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response()) 211 154 } 212 155 213 156 async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value {
+108 -147
crates/tranquil-pds/src/api/server/passkeys.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuth; 4 3 use crate::auth::webauthn::WebAuthnConfig; 4 + use crate::auth::{Active, Auth}; 5 5 use crate::state::AppState; 6 6 use axum::{ 7 7 Json, ··· 34 34 35 35 pub async fn start_passkey_registration( 36 36 State(state): State<AppState>, 37 - auth: BearerAuth, 37 + auth: Auth<Active>, 38 38 Json(input): Json<StartRegistrationInput>, 39 - ) -> Response { 40 - let webauthn = match get_webauthn() { 41 - Ok(w) => w, 42 - Err(e) => return e.into_response(), 43 - }; 39 + ) -> Result<Response, ApiError> { 40 + let webauthn = get_webauthn()?; 44 41 45 - let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 46 - Ok(Some(h)) => h, 47 - Ok(None) => { 48 - return ApiError::AccountNotFound.into_response(); 49 - } 50 - Err(e) => { 42 + let handle = state 43 + .user_repo 44 + .get_handle_by_did(&auth.did) 45 + .await 46 + .map_err(|e| { 51 47 error!("DB error fetching user: {:?}", e); 52 - return ApiError::InternalError(None).into_response(); 53 - } 54 - }; 48 + ApiError::InternalError(None) 49 + })? 50 + .ok_or(ApiError::AccountNotFound)?; 55 51 56 - let existing_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 57 - Ok(passkeys) => passkeys, 58 - Err(e) => { 52 + let existing_passkeys = state 53 + .user_repo 54 + .get_passkeys_for_user(&auth.did) 55 + .await 56 + .map_err(|e| { 59 57 error!("DB error fetching existing passkeys: {:?}", e); 60 - return ApiError::InternalError(None).into_response(); 61 - } 62 - }; 58 + ApiError::InternalError(None) 59 + })?; 63 60 64 61 let exclude_credentials: Vec<CredentialID> = existing_passkeys 65 62 .iter() ··· 68 65 69 66 let display_name = input.friendly_name.as_deref().unwrap_or(&handle); 70 67 71 - let (ccr, reg_state) = match webauthn.start_registration( 72 - &auth.0.did, 73 - &handle, 74 - display_name, 75 - exclude_credentials, 76 - ) { 77 - Ok(result) => result, 78 - Err(e) => { 68 + let (ccr, reg_state) = webauthn 69 + .start_registration(&auth.did, &handle, display_name, exclude_credentials) 70 + .map_err(|e| { 79 71 error!("Failed to start passkey registration: {}", e); 80 - return ApiError::InternalError(Some("Failed to start registration".into())) 81 - .into_response(); 82 - } 83 - }; 72 + ApiError::InternalError(Some("Failed to start registration".into())) 73 + })?; 84 74 85 - let state_json = match serde_json::to_string(&reg_state) { 86 - Ok(s) => s, 87 - Err(e) => { 88 - error!("Failed to serialize registration state: {:?}", e); 89 - return ApiError::InternalError(None).into_response(); 90 - } 91 - }; 75 + let state_json = serde_json::to_string(&reg_state).map_err(|e| { 76 + error!("Failed to serialize registration state: {:?}", e); 77 + ApiError::InternalError(None) 78 + })?; 92 79 93 - if let Err(e) = state 80 + state 94 81 .user_repo 95 - .save_webauthn_challenge(&auth.0.did, "registration", &state_json) 82 + .save_webauthn_challenge(&auth.did, "registration", &state_json) 96 83 .await 97 - { 98 - error!("Failed to save registration state: {:?}", e); 99 - return ApiError::InternalError(None).into_response(); 100 - } 84 + .map_err(|e| { 85 + error!("Failed to save registration state: {:?}", e); 86 + ApiError::InternalError(None) 87 + })?; 101 88 102 89 let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({})); 103 90 104 - info!(did = %auth.0.did, "Passkey registration started"); 91 + info!(did = %auth.did, "Passkey registration started"); 105 92 106 - Json(StartRegistrationResponse { options }).into_response() 93 + Ok(Json(StartRegistrationResponse { options }).into_response()) 107 94 } 108 95 109 96 #[derive(Deserialize)] ··· 122 109 123 110 pub async fn finish_passkey_registration( 124 111 State(state): State<AppState>, 125 - auth: BearerAuth, 112 + auth: Auth<Active>, 126 113 Json(input): Json<FinishRegistrationInput>, 127 - ) -> Response { 128 - let webauthn = match get_webauthn() { 129 - Ok(w) => w, 130 - Err(e) => return e.into_response(), 131 - }; 114 + ) -> Result<Response, ApiError> { 115 + let webauthn = get_webauthn()?; 132 116 133 - let reg_state_json = match state 117 + let reg_state_json = state 134 118 .user_repo 135 - .load_webauthn_challenge(&auth.0.did, "registration") 119 + .load_webauthn_challenge(&auth.did, "registration") 136 120 .await 137 - { 138 - Ok(Some(json)) => json, 139 - Ok(None) => { 140 - return ApiError::NoRegistrationInProgress.into_response(); 141 - } 142 - Err(e) => { 121 + .map_err(|e| { 143 122 error!("DB error loading registration state: {:?}", e); 144 - return ApiError::InternalError(None).into_response(); 145 - } 146 - }; 123 + ApiError::InternalError(None) 124 + })? 125 + .ok_or(ApiError::NoRegistrationInProgress)?; 147 126 148 - let reg_state: SecurityKeyRegistration = match serde_json::from_str(&reg_state_json) { 149 - Ok(s) => s, 150 - Err(e) => { 127 + let reg_state: SecurityKeyRegistration = 128 + serde_json::from_str(&reg_state_json).map_err(|e| { 151 129 error!("Failed to deserialize registration state: {:?}", e); 152 - return ApiError::InternalError(None).into_response(); 153 - } 154 - }; 130 + ApiError::InternalError(None) 131 + })?; 155 132 156 - let credential: RegisterPublicKeyCredential = match serde_json::from_value(input.credential) { 157 - Ok(c) => c, 158 - Err(e) => { 133 + let credential: RegisterPublicKeyCredential = serde_json::from_value(input.credential) 134 + .map_err(|e| { 159 135 warn!("Failed to parse credential: {:?}", e); 160 - return ApiError::InvalidCredential.into_response(); 161 - } 162 - }; 136 + ApiError::InvalidCredential 137 + })?; 163 138 164 - let passkey = match webauthn.finish_registration(&credential, &reg_state) { 165 - Ok(pk) => pk, 166 - Err(e) => { 139 + let passkey = webauthn 140 + .finish_registration(&credential, &reg_state) 141 + .map_err(|e| { 167 142 warn!("Failed to finish passkey registration: {}", e); 168 - return ApiError::RegistrationFailed.into_response(); 169 - } 170 - }; 143 + ApiError::RegistrationFailed 144 + })?; 171 145 172 - let public_key = match serde_json::to_vec(&passkey) { 173 - Ok(pk) => pk, 174 - Err(e) => { 175 - error!("Failed to serialize passkey: {:?}", e); 176 - return ApiError::InternalError(None).into_response(); 177 - } 178 - }; 146 + let public_key = serde_json::to_vec(&passkey).map_err(|e| { 147 + error!("Failed to serialize passkey: {:?}", e); 148 + ApiError::InternalError(None) 149 + })?; 179 150 180 - let passkey_id = match state 151 + let passkey_id = state 181 152 .user_repo 182 153 .save_passkey( 183 - &auth.0.did, 154 + &auth.did, 184 155 passkey.cred_id(), 185 156 &public_key, 186 157 input.friendly_name.as_deref(), 187 158 ) 188 159 .await 189 - { 190 - Ok(id) => id, 191 - Err(e) => { 160 + .map_err(|e| { 192 161 error!("Failed to save passkey: {:?}", e); 193 - return ApiError::InternalError(None).into_response(); 194 - } 195 - }; 162 + ApiError::InternalError(None) 163 + })?; 196 164 197 165 if let Err(e) = state 198 166 .user_repo 199 - .delete_webauthn_challenge(&auth.0.did, "registration") 167 + .delete_webauthn_challenge(&auth.did, "registration") 200 168 .await 201 169 { 202 170 warn!("Failed to delete registration state: {:?}", e); ··· 207 175 passkey.cred_id(), 208 176 ); 209 177 210 - info!(did = %auth.0.did, passkey_id = %passkey_id, "Passkey registered"); 178 + info!(did = %auth.did, passkey_id = %passkey_id, "Passkey registered"); 211 179 212 - Json(FinishRegistrationResponse { 180 + Ok(Json(FinishRegistrationResponse { 213 181 id: passkey_id.to_string(), 214 182 credential_id: credential_id_base64, 215 183 }) 216 - .into_response() 184 + .into_response()) 217 185 } 218 186 219 187 #[derive(Serialize)] ··· 232 200 pub passkeys: Vec<PasskeyInfo>, 233 201 } 234 202 235 - pub async fn list_passkeys(State(state): State<AppState>, auth: BearerAuth) -> Response { 236 - let passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 237 - Ok(pks) => pks, 238 - Err(e) => { 203 + pub async fn list_passkeys( 204 + State(state): State<AppState>, 205 + auth: Auth<Active>, 206 + ) -> Result<Response, ApiError> { 207 + let passkeys = state 208 + .user_repo 209 + .get_passkeys_for_user(&auth.did) 210 + .await 211 + .map_err(|e| { 239 212 error!("DB error fetching passkeys: {:?}", e); 240 - return ApiError::InternalError(None).into_response(); 241 - } 242 - }; 213 + ApiError::InternalError(None) 214 + })?; 243 215 244 216 let passkey_infos: Vec<PasskeyInfo> = passkeys 245 217 .into_iter() ··· 252 224 }) 253 225 .collect(); 254 226 255 - Json(ListPasskeysResponse { 227 + Ok(Json(ListPasskeysResponse { 256 228 passkeys: passkey_infos, 257 229 }) 258 - .into_response() 230 + .into_response()) 259 231 } 260 232 261 233 #[derive(Deserialize)] ··· 266 238 267 239 pub async fn delete_passkey( 268 240 State(state): State<AppState>, 269 - auth: BearerAuth, 241 + auth: Auth<Active>, 270 242 Json(input): Json<DeletePasskeyInput>, 271 - ) -> Response { 272 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 273 - .await 243 + ) -> Result<Response, ApiError> { 244 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 274 245 { 275 - return crate::api::server::reauth::legacy_mfa_required_response( 246 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 276 247 &*state.user_repo, 277 248 &*state.session_repo, 278 - &auth.0.did, 249 + &auth.did, 279 250 ) 280 - .await; 251 + .await); 281 252 } 282 253 283 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 284 - return crate::api::server::reauth::reauth_required_response( 254 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 255 + return Ok(crate::api::server::reauth::reauth_required_response( 285 256 &*state.user_repo, 286 257 &*state.session_repo, 287 - &auth.0.did, 258 + &auth.did, 288 259 ) 289 - .await; 260 + .await); 290 261 } 291 262 292 - let id: uuid::Uuid = match input.id.parse() { 293 - Ok(id) => id, 294 - Err(_) => { 295 - return ApiError::InvalidId.into_response(); 296 - } 297 - }; 263 + let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 298 264 299 - match state.user_repo.delete_passkey(id, &auth.0.did).await { 265 + match state.user_repo.delete_passkey(id, &auth.did).await { 300 266 Ok(true) => { 301 - info!(did = %auth.0.did, passkey_id = %id, "Passkey deleted"); 302 - EmptyResponse::ok().into_response() 267 + info!(did = %auth.did, passkey_id = %id, "Passkey deleted"); 268 + Ok(EmptyResponse::ok().into_response()) 303 269 } 304 - Ok(false) => ApiError::PasskeyNotFound.into_response(), 270 + Ok(false) => Err(ApiError::PasskeyNotFound), 305 271 Err(e) => { 306 272 error!("DB error deleting passkey: {:?}", e); 307 - ApiError::InternalError(None).into_response() 273 + Err(ApiError::InternalError(None)) 308 274 } 309 275 } 310 276 } ··· 318 284 319 285 pub async fn update_passkey( 320 286 State(state): State<AppState>, 321 - auth: BearerAuth, 287 + auth: Auth<Active>, 322 288 Json(input): Json<UpdatePasskeyInput>, 323 - ) -> Response { 324 - let id: uuid::Uuid = match input.id.parse() { 325 - Ok(id) => id, 326 - Err(_) => { 327 - return ApiError::InvalidId.into_response(); 328 - } 329 - }; 289 + ) -> Result<Response, ApiError> { 290 + let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 330 291 331 292 match state 332 293 .user_repo 333 - .update_passkey_name(id, &auth.0.did, &input.friendly_name) 294 + .update_passkey_name(id, &auth.did, &input.friendly_name) 334 295 .await 335 296 { 336 297 Ok(true) => { 337 - info!(did = %auth.0.did, passkey_id = %id, "Passkey renamed"); 338 - EmptyResponse::ok().into_response() 298 + info!(did = %auth.did, passkey_id = %id, "Passkey renamed"); 299 + Ok(EmptyResponse::ok().into_response()) 339 300 } 340 - Ok(false) => ApiError::PasskeyNotFound.into_response(), 301 + Ok(false) => Err(ApiError::PasskeyNotFound), 341 302 Err(e) => { 342 303 error!("DB error updating passkey: {:?}", e); 343 - ApiError::InternalError(None).into_response() 304 + Err(ApiError::InternalError(None)) 344 305 } 345 306 } 346 307 }
+127 -126
crates/tranquil-pds/src/api/server/password.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse}; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::{Active, Auth}; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use crate::types::PlainPassword; 6 6 use crate::validation::validate_password; ··· 227 227 228 228 pub async fn change_password( 229 229 State(state): State<AppState>, 230 - auth: BearerAuth, 230 + auth: Auth<Active>, 231 231 Json(input): Json<ChangePasswordInput>, 232 - ) -> Response { 233 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 234 - .await 232 + ) -> Result<Response, ApiError> { 233 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 235 234 { 236 - return crate::api::server::reauth::legacy_mfa_required_response( 235 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 237 236 &*state.user_repo, 238 237 &*state.session_repo, 239 - &auth.0.did, 238 + &auth.did, 240 239 ) 241 - .await; 240 + .await); 242 241 } 243 242 244 243 let current_password = &input.current_password; 245 244 let new_password = &input.new_password; 246 245 if current_password.is_empty() { 247 - return ApiError::InvalidRequest("currentPassword is required".into()).into_response(); 246 + return Err(ApiError::InvalidRequest( 247 + "currentPassword is required".into(), 248 + )); 248 249 } 249 250 if new_password.is_empty() { 250 - return ApiError::InvalidRequest("newPassword is required".into()).into_response(); 251 + return Err(ApiError::InvalidRequest("newPassword is required".into())); 251 252 } 252 253 if let Err(e) = validate_password(new_password) { 253 - return ApiError::InvalidRequest(e.to_string()).into_response(); 254 + return Err(ApiError::InvalidRequest(e.to_string())); 254 255 } 255 - let user = match state 256 + let user = state 256 257 .user_repo 257 - .get_id_and_password_hash_by_did(&auth.0.did) 258 + .get_id_and_password_hash_by_did(&auth.did) 258 259 .await 259 - { 260 - Ok(Some(u)) => u, 261 - Ok(None) => { 262 - return ApiError::AccountNotFound.into_response(); 263 - } 264 - Err(e) => { 260 + .map_err(|e| { 265 261 error!("DB error in change_password: {:?}", e); 266 - return ApiError::InternalError(None).into_response(); 267 - } 268 - }; 262 + ApiError::InternalError(None) 263 + })? 264 + .ok_or(ApiError::AccountNotFound)?; 265 + 269 266 let (user_id, password_hash) = (user.id, user.password_hash); 270 - let valid = match verify(current_password, &password_hash) { 271 - Ok(v) => v, 272 - Err(e) => { 273 - error!("Password verification error: {:?}", e); 274 - return ApiError::InternalError(None).into_response(); 275 - } 276 - }; 267 + let valid = verify(current_password, &password_hash).map_err(|e| { 268 + error!("Password verification error: {:?}", e); 269 + ApiError::InternalError(None) 270 + })?; 277 271 if !valid { 278 - return ApiError::InvalidPassword("Current password is incorrect".into()).into_response(); 272 + return Err(ApiError::InvalidPassword( 273 + "Current password is incorrect".into(), 274 + )); 279 275 } 280 276 let new_password_clone = new_password.to_string(); 281 - let new_hash = 282 - match tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)).await { 283 - Ok(Ok(h)) => h, 284 - Ok(Err(e)) => { 285 - error!("Failed to hash password: {:?}", e); 286 - return ApiError::InternalError(None).into_response(); 287 - } 288 - Err(e) => { 289 - error!("Failed to spawn blocking task: {:?}", e); 290 - return ApiError::InternalError(None).into_response(); 291 - } 292 - }; 293 - if let Err(e) = state 277 + let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 278 + .await 279 + .map_err(|e| { 280 + error!("Failed to spawn blocking task: {:?}", e); 281 + ApiError::InternalError(None) 282 + })? 283 + .map_err(|e| { 284 + error!("Failed to hash password: {:?}", e); 285 + ApiError::InternalError(None) 286 + })?; 287 + 288 + state 294 289 .user_repo 295 290 .update_password_hash(user_id, &new_hash) 296 291 .await 297 - { 298 - error!("DB error updating password: {:?}", e); 299 - return ApiError::InternalError(None).into_response(); 300 - } 301 - info!(did = %&auth.0.did, "Password changed successfully"); 302 - EmptyResponse::ok().into_response() 292 + .map_err(|e| { 293 + error!("DB error updating password: {:?}", e); 294 + ApiError::InternalError(None) 295 + })?; 296 + 297 + info!(did = %&auth.did, "Password changed successfully"); 298 + Ok(EmptyResponse::ok().into_response()) 303 299 } 304 300 305 - pub async fn get_password_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 306 - match state.user_repo.has_password_by_did(&auth.0.did).await { 307 - Ok(Some(has)) => HasPasswordResponse::response(has).into_response(), 308 - Ok(None) => ApiError::AccountNotFound.into_response(), 301 + pub async fn get_password_status( 302 + State(state): State<AppState>, 303 + auth: Auth<Active>, 304 + ) -> Result<Response, ApiError> { 305 + match state.user_repo.has_password_by_did(&auth.did).await { 306 + Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()), 307 + Ok(None) => Err(ApiError::AccountNotFound), 309 308 Err(e) => { 310 309 error!("DB error: {:?}", e); 311 - ApiError::InternalError(None).into_response() 310 + Err(ApiError::InternalError(None)) 312 311 } 313 312 } 314 313 } 315 314 316 - pub async fn remove_password(State(state): State<AppState>, auth: BearerAuth) -> Response { 317 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 318 - .await 315 + pub async fn remove_password( 316 + State(state): State<AppState>, 317 + auth: Auth<Active>, 318 + ) -> Result<Response, ApiError> { 319 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 319 320 { 320 - return crate::api::server::reauth::legacy_mfa_required_response( 321 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 321 322 &*state.user_repo, 322 323 &*state.session_repo, 323 - &auth.0.did, 324 + &auth.did, 324 325 ) 325 - .await; 326 + .await); 326 327 } 327 328 328 329 if crate::api::server::reauth::check_reauth_required_cached( 329 330 &*state.session_repo, 330 331 &state.cache, 331 - &auth.0.did, 332 + &auth.did, 332 333 ) 333 334 .await 334 335 { 335 - return crate::api::server::reauth::reauth_required_response( 336 + return Ok(crate::api::server::reauth::reauth_required_response( 336 337 &*state.user_repo, 337 338 &*state.session_repo, 338 - &auth.0.did, 339 + &auth.did, 339 340 ) 340 - .await; 341 + .await); 341 342 } 342 343 343 344 let has_passkeys = state 344 345 .user_repo 345 - .has_passkeys(&auth.0.did) 346 + .has_passkeys(&auth.did) 346 347 .await 347 348 .unwrap_or(false); 348 349 if !has_passkeys { 349 - return ApiError::InvalidRequest( 350 + return Err(ApiError::InvalidRequest( 350 351 "You must have at least one passkey registered before removing your password".into(), 351 - ) 352 - .into_response(); 352 + )); 353 353 } 354 354 355 - let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 356 - Ok(Some(u)) => u, 357 - Ok(None) => { 358 - return ApiError::AccountNotFound.into_response(); 359 - } 360 - Err(e) => { 355 + let user = state 356 + .user_repo 357 + .get_password_info_by_did(&auth.did) 358 + .await 359 + .map_err(|e| { 361 360 error!("DB error: {:?}", e); 362 - return ApiError::InternalError(None).into_response(); 363 - } 364 - }; 361 + ApiError::InternalError(None) 362 + })? 363 + .ok_or(ApiError::AccountNotFound)?; 365 364 366 365 if user.password_hash.is_none() { 367 - return ApiError::InvalidRequest("Account already has no password".into()).into_response(); 366 + return Err(ApiError::InvalidRequest( 367 + "Account already has no password".into(), 368 + )); 368 369 } 369 370 370 - if let Err(e) = state.user_repo.remove_user_password(user.id).await { 371 - error!("DB error removing password: {:?}", e); 372 - return ApiError::InternalError(None).into_response(); 373 - } 371 + state 372 + .user_repo 373 + .remove_user_password(user.id) 374 + .await 375 + .map_err(|e| { 376 + error!("DB error removing password: {:?}", e); 377 + ApiError::InternalError(None) 378 + })?; 374 379 375 - info!(did = %&auth.0.did, "Password removed - account is now passkey-only"); 376 - SuccessResponse::ok().into_response() 380 + info!(did = %&auth.did, "Password removed - account is now passkey-only"); 381 + Ok(SuccessResponse::ok().into_response()) 377 382 } 378 383 379 384 #[derive(Deserialize)] ··· 384 389 385 390 pub async fn set_password( 386 391 State(state): State<AppState>, 387 - auth: BearerAuth, 392 + auth: Auth<Active>, 388 393 Json(input): Json<SetPasswordInput>, 389 - ) -> Response { 394 + ) -> Result<Response, ApiError> { 390 395 let has_password = state 391 396 .user_repo 392 - .has_password_by_did(&auth.0.did) 397 + .has_password_by_did(&auth.did) 393 398 .await 394 399 .ok() 395 400 .flatten() 396 401 .unwrap_or(false); 397 402 let has_passkeys = state 398 403 .user_repo 399 - .has_passkeys(&auth.0.did) 404 + .has_passkeys(&auth.did) 400 405 .await 401 406 .unwrap_or(false); 402 407 let has_totp = state 403 408 .user_repo 404 - .has_totp_enabled(&auth.0.did) 409 + .has_totp_enabled(&auth.did) 405 410 .await 406 411 .unwrap_or(false); 407 412 ··· 411 416 && crate::api::server::reauth::check_reauth_required_cached( 412 417 &*state.session_repo, 413 418 &state.cache, 414 - &auth.0.did, 419 + &auth.did, 415 420 ) 416 421 .await 417 422 { 418 - return crate::api::server::reauth::reauth_required_response( 423 + return Ok(crate::api::server::reauth::reauth_required_response( 419 424 &*state.user_repo, 420 425 &*state.session_repo, 421 - &auth.0.did, 426 + &auth.did, 422 427 ) 423 - .await; 428 + .await); 424 429 } 425 430 426 431 let new_password = &input.new_password; 427 432 if new_password.is_empty() { 428 - return ApiError::InvalidRequest("newPassword is required".into()).into_response(); 433 + return Err(ApiError::InvalidRequest("newPassword is required".into())); 429 434 } 430 435 if let Err(e) = validate_password(new_password) { 431 - return ApiError::InvalidRequest(e.to_string()).into_response(); 436 + return Err(ApiError::InvalidRequest(e.to_string())); 432 437 } 433 438 434 - let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 435 - Ok(Some(u)) => u, 436 - Ok(None) => { 437 - return ApiError::AccountNotFound.into_response(); 438 - } 439 - Err(e) => { 439 + let user = state 440 + .user_repo 441 + .get_password_info_by_did(&auth.did) 442 + .await 443 + .map_err(|e| { 440 444 error!("DB error: {:?}", e); 441 - return ApiError::InternalError(None).into_response(); 442 - } 443 - }; 445 + ApiError::InternalError(None) 446 + })? 447 + .ok_or(ApiError::AccountNotFound)?; 444 448 445 449 if user.password_hash.is_some() { 446 - return ApiError::InvalidRequest( 450 + return Err(ApiError::InvalidRequest( 447 451 "Account already has a password. Use changePassword instead.".into(), 448 - ) 449 - .into_response(); 452 + )); 450 453 } 451 454 452 455 let new_password_clone = new_password.to_string(); 453 - let new_hash = 454 - match tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)).await { 455 - Ok(Ok(h)) => h, 456 - Ok(Err(e)) => { 457 - error!("Failed to hash password: {:?}", e); 458 - return ApiError::InternalError(None).into_response(); 459 - } 460 - Err(e) => { 461 - error!("Failed to spawn blocking task: {:?}", e); 462 - return ApiError::InternalError(None).into_response(); 463 - } 464 - }; 456 + let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 457 + .await 458 + .map_err(|e| { 459 + error!("Failed to spawn blocking task: {:?}", e); 460 + ApiError::InternalError(None) 461 + })? 462 + .map_err(|e| { 463 + error!("Failed to hash password: {:?}", e); 464 + ApiError::InternalError(None) 465 + })?; 465 466 466 - if let Err(e) = state 467 + state 467 468 .user_repo 468 469 .set_new_user_password(user.id, &new_hash) 469 470 .await 470 - { 471 - error!("DB error setting password: {:?}", e); 472 - return ApiError::InternalError(None).into_response(); 473 - } 471 + .map_err(|e| { 472 + error!("DB error setting password: {:?}", e); 473 + ApiError::InternalError(None) 474 + })?; 474 475 475 - info!(did = %&auth.0.did, "Password set for passkey-only account"); 476 - SuccessResponse::ok().into_response() 476 + info!(did = %&auth.did, "Password set for passkey-only account"); 477 + Ok(SuccessResponse::ok().into_response()) 477 478 }
+127 -145
crates/tranquil-pds/src/api/server/reauth.rs
··· 10 10 use tracing::{error, info, warn}; 11 11 use tranquil_db_traits::{SessionRepository, UserRepository}; 12 12 13 - use crate::auth::BearerAuth; 13 + use crate::auth::{Active, Auth}; 14 14 use crate::state::{AppState, RateLimitKind}; 15 15 use crate::types::PlainPassword; 16 16 ··· 24 24 pub available_methods: Vec<String>, 25 25 } 26 26 27 - pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 - let last_reauth_at = match state.session_repo.get_last_reauth_at(&auth.0.did).await { 29 - Ok(t) => t, 30 - Err(e) => { 27 + pub async fn get_reauth_status( 28 + State(state): State<AppState>, 29 + auth: Auth<Active>, 30 + ) -> Result<Response, ApiError> { 31 + let last_reauth_at = state 32 + .session_repo 33 + .get_last_reauth_at(&auth.did) 34 + .await 35 + .map_err(|e| { 31 36 error!("DB error: {:?}", e); 32 - return ApiError::InternalError(None).into_response(); 33 - } 34 - }; 37 + ApiError::InternalError(None) 38 + })?; 35 39 36 40 let reauth_required = is_reauth_required(last_reauth_at); 37 41 let available_methods = 38 - get_available_reauth_methods(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 42 + get_available_reauth_methods(&*state.user_repo, &*state.session_repo, &auth.did).await; 39 43 40 - Json(ReauthStatusResponse { 44 + Ok(Json(ReauthStatusResponse { 41 45 last_reauth_at, 42 46 reauth_required, 43 47 available_methods, 44 48 }) 45 - .into_response() 49 + .into_response()) 46 50 } 47 51 48 52 #[derive(Deserialize)] ··· 59 63 60 64 pub async fn reauth_password( 61 65 State(state): State<AppState>, 62 - auth: BearerAuth, 66 + auth: Auth<Active>, 63 67 Json(input): Json<PasswordReauthInput>, 64 - ) -> Response { 65 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 66 - Ok(Some(hash)) => hash, 67 - Ok(None) => { 68 - return ApiError::AccountNotFound.into_response(); 69 - } 70 - Err(e) => { 68 + ) -> Result<Response, ApiError> { 69 + let password_hash = state 70 + .user_repo 71 + .get_password_hash_by_did(&auth.did) 72 + .await 73 + .map_err(|e| { 71 74 error!("DB error: {:?}", e); 72 - return ApiError::InternalError(None).into_response(); 73 - } 74 - }; 75 + ApiError::InternalError(None) 76 + })? 77 + .ok_or(ApiError::AccountNotFound)?; 75 78 76 79 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 77 80 78 81 if !password_valid { 79 82 let app_password_hashes = state 80 83 .session_repo 81 - .get_app_password_hashes_by_did(&auth.0.did) 84 + .get_app_password_hashes_by_did(&auth.did) 82 85 .await 83 86 .unwrap_or_default(); 84 87 ··· 87 90 }); 88 91 89 92 if !app_password_valid { 90 - warn!(did = %&auth.0.did, "Re-auth failed: invalid password"); 91 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 93 + warn!(did = %&auth.did, "Re-auth failed: invalid password"); 94 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 92 95 } 93 96 } 94 97 95 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 96 - Ok(reauthed_at) => { 97 - info!(did = %&auth.0.did, "Re-auth successful via password"); 98 - Json(ReauthResponse { reauthed_at }).into_response() 99 - } 100 - Err(e) => { 98 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 99 + .await 100 + .map_err(|e| { 101 101 error!("DB error updating reauth: {:?}", e); 102 - ApiError::InternalError(None).into_response() 103 - } 104 - } 102 + ApiError::InternalError(None) 103 + })?; 104 + 105 + info!(did = %&auth.did, "Re-auth successful via password"); 106 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 105 107 } 106 108 107 109 #[derive(Deserialize)] ··· 112 114 113 115 pub async fn reauth_totp( 114 116 State(state): State<AppState>, 115 - auth: BearerAuth, 117 + auth: Auth<Active>, 116 118 Json(input): Json<TotpReauthInput>, 117 - ) -> Response { 119 + ) -> Result<Response, ApiError> { 118 120 if !state 119 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 121 + .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 120 122 .await 121 123 { 122 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 123 - return ApiError::RateLimitExceeded(Some( 124 + warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 125 + return Err(ApiError::RateLimitExceeded(Some( 124 126 "Too many verification attempts. Please try again in a few minutes.".into(), 125 - )) 126 - .into_response(); 127 + ))); 127 128 } 128 129 129 130 let valid = 130 - crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.0.did, &input.code) 131 + crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.did, &input.code) 131 132 .await; 132 133 133 134 if !valid { 134 - warn!(did = %&auth.0.did, "Re-auth failed: invalid TOTP code"); 135 - return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response(); 135 + warn!(did = %&auth.did, "Re-auth failed: invalid TOTP code"); 136 + return Err(ApiError::InvalidCode(Some( 137 + "Invalid TOTP or backup code".into(), 138 + ))); 136 139 } 137 140 138 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 139 - Ok(reauthed_at) => { 140 - info!(did = %&auth.0.did, "Re-auth successful via TOTP"); 141 - Json(ReauthResponse { reauthed_at }).into_response() 142 - } 143 - Err(e) => { 141 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 142 + .await 143 + .map_err(|e| { 144 144 error!("DB error updating reauth: {:?}", e); 145 - ApiError::InternalError(None).into_response() 146 - } 147 - } 145 + ApiError::InternalError(None) 146 + })?; 147 + 148 + info!(did = %&auth.did, "Re-auth successful via TOTP"); 149 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 148 150 } 149 151 150 152 #[derive(Serialize)] ··· 153 155 pub options: serde_json::Value, 154 156 } 155 157 156 - pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response { 158 + pub async fn reauth_passkey_start( 159 + State(state): State<AppState>, 160 + auth: Auth<Active>, 161 + ) -> Result<Response, ApiError> { 157 162 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 158 163 159 - let stored_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 160 - Ok(pks) => pks, 161 - Err(e) => { 164 + let stored_passkeys = state 165 + .user_repo 166 + .get_passkeys_for_user(&auth.did) 167 + .await 168 + .map_err(|e| { 162 169 error!("Failed to get passkeys: {:?}", e); 163 - return ApiError::InternalError(None).into_response(); 164 - } 165 - }; 170 + ApiError::InternalError(None) 171 + })?; 166 172 167 173 if stored_passkeys.is_empty() { 168 - return ApiError::NoPasskeys.into_response(); 174 + return Err(ApiError::NoPasskeys); 169 175 } 170 176 171 177 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys ··· 174 180 .collect(); 175 181 176 182 if passkeys.is_empty() { 177 - return ApiError::InternalError(Some("Failed to load passkeys".into())).into_response(); 183 + return Err(ApiError::InternalError(Some( 184 + "Failed to load passkeys".into(), 185 + ))); 178 186 } 179 187 180 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 181 - Ok(w) => w, 182 - Err(e) => { 183 - error!("Failed to create WebAuthn config: {:?}", e); 184 - return ApiError::InternalError(None).into_response(); 185 - } 186 - }; 188 + let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 189 + error!("Failed to create WebAuthn config: {:?}", e); 190 + ApiError::InternalError(None) 191 + })?; 187 192 188 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 189 - Ok(result) => result, 190 - Err(e) => { 191 - error!("Failed to start passkey authentication: {:?}", e); 192 - return ApiError::InternalError(None).into_response(); 193 - } 194 - }; 193 + let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| { 194 + error!("Failed to start passkey authentication: {:?}", e); 195 + ApiError::InternalError(None) 196 + })?; 195 197 196 - let state_json = match serde_json::to_string(&auth_state) { 197 - Ok(s) => s, 198 - Err(e) => { 199 - error!("Failed to serialize authentication state: {:?}", e); 200 - return ApiError::InternalError(None).into_response(); 201 - } 202 - }; 198 + let state_json = serde_json::to_string(&auth_state).map_err(|e| { 199 + error!("Failed to serialize authentication state: {:?}", e); 200 + ApiError::InternalError(None) 201 + })?; 203 202 204 - if let Err(e) = state 203 + state 205 204 .user_repo 206 - .save_webauthn_challenge(&auth.0.did, "authentication", &state_json) 205 + .save_webauthn_challenge(&auth.did, "authentication", &state_json) 207 206 .await 208 - { 209 - error!("Failed to save authentication state: {:?}", e); 210 - return ApiError::InternalError(None).into_response(); 211 - } 207 + .map_err(|e| { 208 + error!("Failed to save authentication state: {:?}", e); 209 + ApiError::InternalError(None) 210 + })?; 212 211 213 212 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 214 - Json(PasskeyReauthStartResponse { options }).into_response() 213 + Ok(Json(PasskeyReauthStartResponse { options }).into_response()) 215 214 } 216 215 217 216 #[derive(Deserialize)] ··· 222 221 223 222 pub async fn reauth_passkey_finish( 224 223 State(state): State<AppState>, 225 - auth: BearerAuth, 224 + auth: Auth<Active>, 226 225 Json(input): Json<PasskeyReauthFinishInput>, 227 - ) -> Response { 226 + ) -> Result<Response, ApiError> { 228 227 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 229 228 230 - let auth_state_json = match state 229 + let auth_state_json = state 231 230 .user_repo 232 - .load_webauthn_challenge(&auth.0.did, "authentication") 231 + .load_webauthn_challenge(&auth.did, "authentication") 233 232 .await 234 - { 235 - Ok(Some(json)) => json, 236 - Ok(None) => { 237 - return ApiError::NoChallengeInProgress.into_response(); 238 - } 239 - Err(e) => { 233 + .map_err(|e| { 240 234 error!("Failed to load authentication state: {:?}", e); 241 - return ApiError::InternalError(None).into_response(); 242 - } 243 - }; 235 + ApiError::InternalError(None) 236 + })? 237 + .ok_or(ApiError::NoChallengeInProgress)?; 244 238 245 239 let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = 246 - match serde_json::from_str(&auth_state_json) { 247 - Ok(s) => s, 248 - Err(e) => { 249 - error!("Failed to deserialize authentication state: {:?}", e); 250 - return ApiError::InternalError(None).into_response(); 251 - } 252 - }; 240 + serde_json::from_str(&auth_state_json).map_err(|e| { 241 + error!("Failed to deserialize authentication state: {:?}", e); 242 + ApiError::InternalError(None) 243 + })?; 253 244 254 245 let credential: webauthn_rs::prelude::PublicKeyCredential = 255 - match serde_json::from_value(input.credential) { 256 - Ok(c) => c, 257 - Err(e) => { 258 - warn!("Failed to parse credential: {:?}", e); 259 - return ApiError::InvalidCredential.into_response(); 260 - } 261 - }; 246 + serde_json::from_value(input.credential).map_err(|e| { 247 + warn!("Failed to parse credential: {:?}", e); 248 + ApiError::InvalidCredential 249 + })?; 262 250 263 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 264 - Ok(w) => w, 265 - Err(e) => { 266 - error!("Failed to create WebAuthn config: {:?}", e); 267 - return ApiError::InternalError(None).into_response(); 268 - } 269 - }; 251 + let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 252 + error!("Failed to create WebAuthn config: {:?}", e); 253 + ApiError::InternalError(None) 254 + })?; 270 255 271 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 272 - Ok(r) => r, 273 - Err(e) => { 274 - warn!(did = %&auth.0.did, "Passkey re-auth failed: {:?}", e); 275 - return ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 276 - .into_response(); 277 - } 278 - }; 256 + let auth_result = webauthn 257 + .finish_authentication(&credential, &auth_state) 258 + .map_err(|e| { 259 + warn!(did = %&auth.did, "Passkey re-auth failed: {:?}", e); 260 + ApiError::AuthenticationFailed(Some("Passkey authentication failed".into())) 261 + })?; 279 262 280 263 let cred_id_bytes = auth_result.cred_id().as_ref(); 281 264 match state ··· 284 267 .await 285 268 { 286 269 Ok(false) => { 287 - warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key"); 270 + warn!(did = %&auth.did, "Passkey counter anomaly detected - possible cloned key"); 288 271 let _ = state 289 272 .user_repo 290 - .delete_webauthn_challenge(&auth.0.did, "authentication") 273 + .delete_webauthn_challenge(&auth.did, "authentication") 291 274 .await; 292 - return ApiError::PasskeyCounterAnomaly.into_response(); 275 + return Err(ApiError::PasskeyCounterAnomaly); 293 276 } 294 277 Err(e) => { 295 278 error!("Failed to update passkey counter: {:?}", e); ··· 299 282 300 283 let _ = state 301 284 .user_repo 302 - .delete_webauthn_challenge(&auth.0.did, "authentication") 285 + .delete_webauthn_challenge(&auth.did, "authentication") 303 286 .await; 304 287 305 - match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 306 - Ok(reauthed_at) => { 307 - info!(did = %&auth.0.did, "Re-auth successful via passkey"); 308 - Json(ReauthResponse { reauthed_at }).into_response() 309 - } 310 - Err(e) => { 288 + let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 289 + .await 290 + .map_err(|e| { 311 291 error!("DB error updating reauth: {:?}", e); 312 - ApiError::InternalError(None).into_response() 313 - } 314 - } 292 + ApiError::InternalError(None) 293 + })?; 294 + 295 + info!(did = %&auth.did, "Re-auth successful via passkey"); 296 + Ok(Json(ReauthResponse { reauthed_at }).into_response()) 315 297 } 316 298 317 299 pub async fn update_last_reauth_cached(
+4 -4
crates/tranquil-pds/src/api/server/service_auth.rs
··· 95 95 { 96 96 Ok(result) => crate::auth::AuthenticatedUser { 97 97 did: Did::new_unchecked(result.did), 98 - is_oauth: true, 99 98 is_admin: false, 100 99 status: AccountStatus::Active, 101 100 scope: result.scope, 102 101 key_bytes: None, 103 102 controller_did: None, 103 + auth_source: crate::auth::AuthSource::OAuth, 104 104 }, 105 105 Err(crate::oauth::OAuthError::UseDpopNonce(nonce)) => { 106 106 return ( ··· 131 131 }; 132 132 info!( 133 133 did = %&auth_user.did, 134 - is_oauth = auth_user.is_oauth, 134 + is_oauth = auth_user.is_oauth(), 135 135 has_key = auth_user.key_bytes.is_some(), 136 136 "getServiceAuth auth validated" 137 137 ); ··· 180 180 181 181 if let Some(method) = lxm { 182 182 if let Err(e) = crate::auth::scope_check::check_rpc_scope( 183 - auth_user.is_oauth, 183 + auth_user.is_oauth(), 184 184 auth_user.scope.as_deref(), 185 185 &params.aud, 186 186 method, 187 187 ) { 188 188 return e; 189 189 } 190 - } else if auth_user.is_oauth { 190 + } else if auth_user.is_oauth() { 191 191 let permissions = auth_user.permissions(); 192 192 if !permissions.has_full_access() { 193 193 return ApiError::InvalidRequest(
+169 -174
crates/tranquil-pds/src/api/server/session.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::{EmptyResponse, SuccessResponse}; 3 - use crate::auth::{BearerAuth, BearerAuthAllowDeactivated}; 3 + use crate::auth::{Active, Auth, Permissive}; 4 4 use crate::state::{AppState, RateLimitKind}; 5 5 use crate::types::{AccountState, Did, Handle, PlainPassword}; 6 6 use axum::{ ··· 279 279 280 280 pub async fn get_session( 281 281 State(state): State<AppState>, 282 - BearerAuthAllowDeactivated(auth_user): BearerAuthAllowDeactivated, 283 - ) -> Response { 284 - let permissions = auth_user.permissions(); 282 + auth: Auth<Permissive>, 283 + ) -> Result<Response, ApiError> { 284 + let permissions = auth.permissions(); 285 285 let can_read_email = permissions.allows_email_read(); 286 286 287 - let did_for_doc = auth_user.did.clone(); 287 + let did_for_doc = auth.did.clone(); 288 288 let did_resolver = state.did_resolver.clone(); 289 289 let (db_result, did_doc) = tokio::join!( 290 - state.user_repo.get_session_info_by_did(&auth_user.did), 290 + state.user_repo.get_session_info_by_did(&auth.did), 291 291 did_resolver.resolve_did_document(&did_for_doc) 292 292 ); 293 293 match db_result { ··· 316 316 let email_confirmed_value = can_read_email && row.email_verified; 317 317 let mut response = json!({ 318 318 "handle": handle, 319 - "did": &auth_user.did, 319 + "did": &auth.did, 320 320 "active": account_state.is_active(), 321 321 "preferredChannel": preferred_channel, 322 322 "preferredChannelVerified": preferred_channel_verified, ··· 337 337 if let Some(doc) = did_doc { 338 338 response["didDoc"] = doc; 339 339 } 340 - Json(response).into_response() 340 + Ok(Json(response).into_response()) 341 341 } 342 - Ok(None) => ApiError::AuthenticationFailed(None).into_response(), 342 + Ok(None) => Err(ApiError::AuthenticationFailed(None)), 343 343 Err(e) => { 344 344 error!("Database error in get_session: {:?}", e); 345 - ApiError::InternalError(None).into_response() 345 + Err(ApiError::InternalError(None)) 346 346 } 347 347 } 348 348 } ··· 350 350 pub async fn delete_session( 351 351 State(state): State<AppState>, 352 352 headers: axum::http::HeaderMap, 353 - _auth: BearerAuth, 354 - ) -> Response { 355 - let extracted = match crate::auth::extract_auth_token_from_header( 353 + _auth: Auth<Active>, 354 + ) -> Result<Response, ApiError> { 355 + let extracted = crate::auth::extract_auth_token_from_header( 356 356 headers.get("Authorization").and_then(|h| h.to_str().ok()), 357 - ) { 358 - Some(t) => t, 359 - None => return ApiError::AuthenticationRequired.into_response(), 360 - }; 361 - let jti = match crate::auth::get_jti_from_token(&extracted.token) { 362 - Ok(jti) => jti, 363 - Err(_) => return ApiError::AuthenticationFailed(None).into_response(), 364 - }; 357 + ) 358 + .ok_or(ApiError::AuthenticationRequired)?; 359 + let jti = crate::auth::get_jti_from_token(&extracted.token) 360 + .map_err(|_| ApiError::AuthenticationFailed(None))?; 365 361 let did = crate::auth::get_did_from_token(&extracted.token).ok(); 366 362 match state.session_repo.delete_session_by_access_jti(&jti).await { 367 363 Ok(rows) if rows > 0 => { ··· 369 365 let session_cache_key = format!("auth:session:{}:{}", did, jti); 370 366 let _ = state.cache.delete(&session_cache_key).await; 371 367 } 372 - EmptyResponse::ok().into_response() 368 + Ok(EmptyResponse::ok().into_response()) 373 369 } 374 - Ok(_) => ApiError::AuthenticationFailed(None).into_response(), 375 - Err(_) => ApiError::AuthenticationFailed(None).into_response(), 370 + Ok(_) => Err(ApiError::AuthenticationFailed(None)), 371 + Err(_) => Err(ApiError::AuthenticationFailed(None)), 376 372 } 377 373 } 378 374 ··· 796 792 pub async fn list_sessions( 797 793 State(state): State<AppState>, 798 794 headers: HeaderMap, 799 - auth: BearerAuth, 800 - ) -> Response { 795 + auth: Auth<Active>, 796 + ) -> Result<Response, ApiError> { 801 797 let current_jti = headers 802 798 .get("authorization") 803 799 .and_then(|v| v.to_str().ok()) 804 800 .and_then(|v| v.strip_prefix("Bearer ")) 805 801 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 806 802 807 - let jwt_rows = match state.session_repo.list_sessions_by_did(&auth.0.did).await { 808 - Ok(rows) => rows, 809 - Err(e) => { 803 + let jwt_rows = state 804 + .session_repo 805 + .list_sessions_by_did(&auth.did) 806 + .await 807 + .map_err(|e| { 810 808 error!("DB error fetching JWT sessions: {:?}", e); 811 - return ApiError::InternalError(None).into_response(); 812 - } 813 - }; 809 + ApiError::InternalError(None) 810 + })?; 814 811 815 - let oauth_rows = match state.oauth_repo.list_sessions_by_did(&auth.0.did).await { 816 - Ok(rows) => rows, 817 - Err(e) => { 812 + let oauth_rows = state 813 + .oauth_repo 814 + .list_sessions_by_did(&auth.did) 815 + .await 816 + .map_err(|e| { 818 817 error!("DB error fetching OAuth sessions: {:?}", e); 819 - return ApiError::InternalError(None).into_response(); 820 - } 821 - }; 818 + ApiError::InternalError(None) 819 + })?; 822 820 823 821 let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 824 822 id: format!("jwt:{}", row.id), ··· 829 827 is_current: current_jti.as_ref() == Some(&row.access_jti), 830 828 }); 831 829 832 - let is_oauth = auth.0.is_oauth; 830 + let is_oauth = auth.is_oauth(); 833 831 let oauth_sessions = oauth_rows.into_iter().map(|row| { 834 832 let client_name = extract_client_name(&row.client_id); 835 833 let is_current_oauth = is_oauth && current_jti.as_deref() == Some(row.token_id.as_str()); ··· 846 844 let mut sessions: Vec<SessionInfo> = jwt_sessions.chain(oauth_sessions).collect(); 847 845 sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); 848 846 849 - (StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response() 847 + Ok((StatusCode::OK, Json(ListSessionsOutput { sessions })).into_response()) 850 848 } 851 849 852 850 fn extract_client_name(client_id: &str) -> String { ··· 867 865 868 866 pub async fn revoke_session( 869 867 State(state): State<AppState>, 870 - auth: BearerAuth, 868 + auth: Auth<Active>, 871 869 Json(input): Json<RevokeSessionInput>, 872 - ) -> Response { 870 + ) -> Result<Response, ApiError> { 873 871 if let Some(jwt_id) = input.session_id.strip_prefix("jwt:") { 874 - let Ok(session_id) = jwt_id.parse::<i32>() else { 875 - return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 876 - }; 877 - let access_jti = match state 872 + let session_id: i32 = jwt_id 873 + .parse() 874 + .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 875 + let access_jti = state 878 876 .session_repo 879 - .get_session_access_jti_by_id(session_id, &auth.0.did) 877 + .get_session_access_jti_by_id(session_id, &auth.did) 880 878 .await 881 - { 882 - Ok(Some(jti)) => jti, 883 - Ok(None) => { 884 - return ApiError::SessionNotFound.into_response(); 885 - } 886 - Err(e) => { 879 + .map_err(|e| { 887 880 error!("DB error in revoke_session: {:?}", e); 888 - return ApiError::InternalError(None).into_response(); 889 - } 890 - }; 891 - if let Err(e) = state.session_repo.delete_session_by_id(session_id).await { 892 - error!("DB error deleting session: {:?}", e); 893 - return ApiError::InternalError(None).into_response(); 894 - } 895 - let cache_key = format!("auth:session:{}:{}", &auth.0.did, access_jti); 881 + ApiError::InternalError(None) 882 + })? 883 + .ok_or(ApiError::SessionNotFound)?; 884 + state 885 + .session_repo 886 + .delete_session_by_id(session_id) 887 + .await 888 + .map_err(|e| { 889 + error!("DB error deleting session: {:?}", e); 890 + ApiError::InternalError(None) 891 + })?; 892 + let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti); 896 893 if let Err(e) = state.cache.delete(&cache_key).await { 897 894 warn!("Failed to invalidate session cache: {:?}", e); 898 895 } 899 - info!(did = %&auth.0.did, session_id = %session_id, "JWT session revoked"); 896 + info!(did = %&auth.did, session_id = %session_id, "JWT session revoked"); 900 897 } else if let Some(oauth_id) = input.session_id.strip_prefix("oauth:") { 901 - let Ok(session_id) = oauth_id.parse::<i32>() else { 902 - return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 903 - }; 904 - match state 898 + let session_id: i32 = oauth_id 899 + .parse() 900 + .map_err(|_| ApiError::InvalidRequest("Invalid session ID".into()))?; 901 + let deleted = state 905 902 .oauth_repo 906 - .delete_session_by_id(session_id, &auth.0.did) 903 + .delete_session_by_id(session_id, &auth.did) 907 904 .await 908 - { 909 - Ok(0) => { 910 - return ApiError::SessionNotFound.into_response(); 911 - } 912 - Err(e) => { 905 + .map_err(|e| { 913 906 error!("DB error deleting OAuth session: {:?}", e); 914 - return ApiError::InternalError(None).into_response(); 915 - } 916 - _ => {} 907 + ApiError::InternalError(None) 908 + })?; 909 + if deleted == 0 { 910 + return Err(ApiError::SessionNotFound); 917 911 } 918 - info!(did = %&auth.0.did, session_id = %session_id, "OAuth session revoked"); 912 + info!(did = %&auth.did, session_id = %session_id, "OAuth session revoked"); 919 913 } else { 920 - return ApiError::InvalidRequest("Invalid session ID format".into()).into_response(); 914 + return Err(ApiError::InvalidRequest("Invalid session ID format".into())); 921 915 } 922 - EmptyResponse::ok().into_response() 916 + Ok(EmptyResponse::ok().into_response()) 923 917 } 924 918 925 919 pub async fn revoke_all_sessions( 926 920 State(state): State<AppState>, 927 921 headers: HeaderMap, 928 - auth: BearerAuth, 929 - ) -> Response { 930 - let current_jti = crate::auth::extract_auth_token_from_header( 922 + auth: Auth<Active>, 923 + ) -> Result<Response, ApiError> { 924 + let jti = crate::auth::extract_auth_token_from_header( 931 925 headers.get("authorization").and_then(|v| v.to_str().ok()), 932 926 ) 933 - .and_then(|extracted| crate::auth::get_jti_from_token(&extracted.token).ok()); 927 + .and_then(|extracted| crate::auth::get_jti_from_token(&extracted.token).ok()) 928 + .ok_or(ApiError::InvalidToken(None))?; 934 929 935 - let Some(ref jti) = current_jti else { 936 - return ApiError::InvalidToken(None).into_response(); 937 - }; 938 - 939 - if auth.0.is_oauth { 940 - if let Err(e) = state.session_repo.delete_sessions_by_did(&auth.0.did).await { 941 - error!("DB error revoking JWT sessions: {:?}", e); 942 - return ApiError::InternalError(None).into_response(); 943 - } 930 + if auth.is_oauth() { 931 + state 932 + .session_repo 933 + .delete_sessions_by_did(&auth.did) 934 + .await 935 + .map_err(|e| { 936 + error!("DB error revoking JWT sessions: {:?}", e); 937 + ApiError::InternalError(None) 938 + })?; 944 939 let jti_typed = TokenId::from(jti.clone()); 945 - if let Err(e) = state 940 + state 946 941 .oauth_repo 947 - .delete_sessions_by_did_except(&auth.0.did, &jti_typed) 942 + .delete_sessions_by_did_except(&auth.did, &jti_typed) 948 943 .await 949 - { 950 - error!("DB error revoking OAuth sessions: {:?}", e); 951 - return ApiError::InternalError(None).into_response(); 952 - } 944 + .map_err(|e| { 945 + error!("DB error revoking OAuth sessions: {:?}", e); 946 + ApiError::InternalError(None) 947 + })?; 953 948 } else { 954 - if let Err(e) = state 949 + state 955 950 .session_repo 956 - .delete_sessions_by_did_except_jti(&auth.0.did, jti) 951 + .delete_sessions_by_did_except_jti(&auth.did, &jti) 952 + .await 953 + .map_err(|e| { 954 + error!("DB error revoking JWT sessions: {:?}", e); 955 + ApiError::InternalError(None) 956 + })?; 957 + state 958 + .oauth_repo 959 + .delete_sessions_by_did(&auth.did) 957 960 .await 958 - { 959 - error!("DB error revoking JWT sessions: {:?}", e); 960 - return ApiError::InternalError(None).into_response(); 961 - } 962 - if let Err(e) = state.oauth_repo.delete_sessions_by_did(&auth.0.did).await { 963 - error!("DB error revoking OAuth sessions: {:?}", e); 964 - return ApiError::InternalError(None).into_response(); 965 - } 961 + .map_err(|e| { 962 + error!("DB error revoking OAuth sessions: {:?}", e); 963 + ApiError::InternalError(None) 964 + })?; 966 965 } 967 966 968 - info!(did = %&auth.0.did, "All other sessions revoked"); 969 - SuccessResponse::ok().into_response() 967 + info!(did = %&auth.did, "All other sessions revoked"); 968 + Ok(SuccessResponse::ok().into_response()) 970 969 } 971 970 972 971 #[derive(Serialize)] ··· 978 977 979 978 pub async fn get_legacy_login_preference( 980 979 State(state): State<AppState>, 981 - auth: BearerAuth, 982 - ) -> Response { 983 - match state.user_repo.get_legacy_login_pref(&auth.0.did).await { 984 - Ok(Some(pref)) => Json(LegacyLoginPreferenceOutput { 985 - allow_legacy_login: pref.allow_legacy_login, 986 - has_mfa: pref.has_mfa, 987 - }) 988 - .into_response(), 989 - Ok(None) => ApiError::AccountNotFound.into_response(), 990 - Err(e) => { 980 + auth: Auth<Active>, 981 + ) -> Result<Response, ApiError> { 982 + let pref = state 983 + .user_repo 984 + .get_legacy_login_pref(&auth.did) 985 + .await 986 + .map_err(|e| { 991 987 error!("DB error: {:?}", e); 992 - ApiError::InternalError(None).into_response() 993 - } 994 - } 988 + ApiError::InternalError(None) 989 + })? 990 + .ok_or(ApiError::AccountNotFound)?; 991 + Ok(Json(LegacyLoginPreferenceOutput { 992 + allow_legacy_login: pref.allow_legacy_login, 993 + has_mfa: pref.has_mfa, 994 + }) 995 + .into_response()) 995 996 } 996 997 997 998 #[derive(Deserialize)] ··· 1002 1003 1003 1004 pub async fn update_legacy_login_preference( 1004 1005 State(state): State<AppState>, 1005 - auth: BearerAuth, 1006 + auth: Auth<Active>, 1006 1007 Json(input): Json<UpdateLegacyLoginInput>, 1007 - ) -> Response { 1008 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 1009 - .await 1008 + ) -> Result<Response, ApiError> { 1009 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 1010 1010 { 1011 - return crate::api::server::reauth::legacy_mfa_required_response( 1011 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 1012 1012 &*state.user_repo, 1013 1013 &*state.session_repo, 1014 - &auth.0.did, 1014 + &auth.did, 1015 1015 ) 1016 - .await; 1016 + .await); 1017 1017 } 1018 1018 1019 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 1020 - return crate::api::server::reauth::reauth_required_response( 1019 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 1020 + return Ok(crate::api::server::reauth::reauth_required_response( 1021 1021 &*state.user_repo, 1022 1022 &*state.session_repo, 1023 - &auth.0.did, 1023 + &auth.did, 1024 1024 ) 1025 - .await; 1025 + .await); 1026 1026 } 1027 1027 1028 - match state 1028 + let updated = state 1029 1029 .user_repo 1030 - .update_legacy_login(&auth.0.did, input.allow_legacy_login) 1030 + .update_legacy_login(&auth.did, input.allow_legacy_login) 1031 1031 .await 1032 - { 1033 - Ok(true) => { 1034 - info!( 1035 - did = %&auth.0.did, 1036 - allow_legacy_login = input.allow_legacy_login, 1037 - "Legacy login preference updated" 1038 - ); 1039 - Json(json!({ 1040 - "allowLegacyLogin": input.allow_legacy_login 1041 - })) 1042 - .into_response() 1043 - } 1044 - Ok(false) => ApiError::AccountNotFound.into_response(), 1045 - Err(e) => { 1032 + .map_err(|e| { 1046 1033 error!("DB error: {:?}", e); 1047 - ApiError::InternalError(None).into_response() 1048 - } 1034 + ApiError::InternalError(None) 1035 + })?; 1036 + if !updated { 1037 + return Err(ApiError::AccountNotFound); 1049 1038 } 1039 + info!( 1040 + did = %&auth.did, 1041 + allow_legacy_login = input.allow_legacy_login, 1042 + "Legacy login preference updated" 1043 + ); 1044 + Ok(Json(json!({ 1045 + "allowLegacyLogin": input.allow_legacy_login 1046 + })) 1047 + .into_response()) 1050 1048 } 1051 1049 1052 1050 use crate::comms::VALID_LOCALES; ··· 1059 1057 1060 1058 pub async fn update_locale( 1061 1059 State(state): State<AppState>, 1062 - auth: BearerAuth, 1060 + auth: Auth<Active>, 1063 1061 Json(input): Json<UpdateLocaleInput>, 1064 - ) -> Response { 1062 + ) -> Result<Response, ApiError> { 1065 1063 if !VALID_LOCALES.contains(&input.preferred_locale.as_str()) { 1066 - return ApiError::InvalidRequest(format!( 1064 + return Err(ApiError::InvalidRequest(format!( 1067 1065 "Invalid locale. Valid options: {}", 1068 1066 VALID_LOCALES.join(", ") 1069 - )) 1070 - .into_response(); 1067 + ))); 1071 1068 } 1072 1069 1073 - match state 1070 + let updated = state 1074 1071 .user_repo 1075 - .update_locale(&auth.0.did, &input.preferred_locale) 1072 + .update_locale(&auth.did, &input.preferred_locale) 1076 1073 .await 1077 - { 1078 - Ok(true) => { 1079 - info!( 1080 - did = %&auth.0.did, 1081 - locale = %input.preferred_locale, 1082 - "User locale preference updated" 1083 - ); 1084 - Json(json!({ 1085 - "preferredLocale": input.preferred_locale 1086 - })) 1087 - .into_response() 1088 - } 1089 - Ok(false) => ApiError::AccountNotFound.into_response(), 1090 - Err(e) => { 1074 + .map_err(|e| { 1091 1075 error!("DB error updating locale: {:?}", e); 1092 - ApiError::InternalError(None).into_response() 1093 - } 1076 + ApiError::InternalError(None) 1077 + })?; 1078 + if !updated { 1079 + return Err(ApiError::AccountNotFound); 1094 1080 } 1081 + info!( 1082 + did = %&auth.did, 1083 + locale = %input.preferred_locale, 1084 + "User locale preference updated" 1085 + ); 1086 + Ok(Json(json!({ 1087 + "preferredLocale": input.preferred_locale 1088 + })) 1089 + .into_response()) 1095 1090 }
+161 -161
crates/tranquil-pds/src/api/server/totp.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 - use crate::auth::BearerAuth; 3 + use crate::auth::{Active, Auth}; 4 4 use crate::auth::{ 5 5 decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64, 6 6 generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format, ··· 26 26 pub qr_base64: String, 27 27 } 28 28 29 - pub async fn create_totp_secret(State(state): State<AppState>, auth: BearerAuth) -> Response { 30 - match state.user_repo.get_totp_record(&auth.0.did).await { 31 - Ok(Some(record)) if record.verified => return ApiError::TotpAlreadyEnabled.into_response(), 29 + pub async fn create_totp_secret( 30 + State(state): State<AppState>, 31 + auth: Auth<Active>, 32 + ) -> Result<Response, ApiError> { 33 + match state.user_repo.get_totp_record(&auth.did).await { 34 + Ok(Some(record)) if record.verified => return Err(ApiError::TotpAlreadyEnabled), 32 35 Ok(_) => {} 33 36 Err(e) => { 34 37 error!("DB error checking TOTP: {:?}", e); 35 - return ApiError::InternalError(None).into_response(); 38 + return Err(ApiError::InternalError(None)); 36 39 } 37 40 } 38 41 39 42 let secret = generate_totp_secret(); 40 43 41 - let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 42 - Ok(Some(h)) => h, 43 - Ok(None) => return ApiError::AccountNotFound.into_response(), 44 - Err(e) => { 44 + let handle = state 45 + .user_repo 46 + .get_handle_by_did(&auth.did) 47 + .await 48 + .map_err(|e| { 45 49 error!("DB error fetching handle: {:?}", e); 46 - return ApiError::InternalError(None).into_response(); 47 - } 48 - }; 50 + ApiError::InternalError(None) 51 + })? 52 + .ok_or(ApiError::AccountNotFound)?; 49 53 50 54 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 51 55 let uri = generate_totp_uri(&secret, &handle, &hostname); 52 56 53 - let qr_code = match generate_qr_png_base64(&secret, &handle, &hostname) { 54 - Ok(qr) => qr, 55 - Err(e) => { 56 - error!("Failed to generate QR code: {:?}", e); 57 - return ApiError::InternalError(Some("Failed to generate QR code".into())) 58 - .into_response(); 59 - } 60 - }; 57 + let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| { 58 + error!("Failed to generate QR code: {:?}", e); 59 + ApiError::InternalError(Some("Failed to generate QR code".into())) 60 + })?; 61 61 62 - let encrypted_secret = match encrypt_totp_secret(&secret) { 63 - Ok(enc) => enc, 64 - Err(e) => { 65 - error!("Failed to encrypt TOTP secret: {:?}", e); 66 - return ApiError::InternalError(None).into_response(); 67 - } 68 - }; 62 + let encrypted_secret = encrypt_totp_secret(&secret).map_err(|e| { 63 + error!("Failed to encrypt TOTP secret: {:?}", e); 64 + ApiError::InternalError(None) 65 + })?; 69 66 70 - if let Err(e) = state 67 + state 71 68 .user_repo 72 - .upsert_totp_secret(&auth.0.did, &encrypted_secret, ENCRYPTION_VERSION) 69 + .upsert_totp_secret(&auth.did, &encrypted_secret, ENCRYPTION_VERSION) 73 70 .await 74 - { 75 - error!("Failed to store TOTP secret: {:?}", e); 76 - return ApiError::InternalError(None).into_response(); 77 - } 71 + .map_err(|e| { 72 + error!("Failed to store TOTP secret: {:?}", e); 73 + ApiError::InternalError(None) 74 + })?; 78 75 79 76 let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); 80 77 81 - info!(did = %&auth.0.did, "TOTP secret created (pending verification)"); 78 + info!(did = %&auth.did, "TOTP secret created (pending verification)"); 82 79 83 - Json(CreateTotpSecretResponse { 80 + Ok(Json(CreateTotpSecretResponse { 84 81 secret: secret_base32, 85 82 uri, 86 83 qr_base64: qr_code, 87 84 }) 88 - .into_response() 85 + .into_response()) 89 86 } 90 87 91 88 #[derive(Deserialize)] ··· 101 98 102 99 pub async fn enable_totp( 103 100 State(state): State<AppState>, 104 - auth: BearerAuth, 101 + auth: Auth<Active>, 105 102 Json(input): Json<EnableTotpInput>, 106 - ) -> Response { 103 + ) -> Result<Response, ApiError> { 107 104 if !state 108 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 105 + .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 109 106 .await 110 107 { 111 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 112 - return ApiError::RateLimitExceeded(None).into_response(); 108 + warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 109 + return Err(ApiError::RateLimitExceeded(None)); 113 110 } 114 111 115 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 112 + let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 116 113 Ok(Some(row)) => row, 117 - Ok(None) => return ApiError::TotpNotEnabled.into_response(), 114 + Ok(None) => return Err(ApiError::TotpNotEnabled), 118 115 Err(e) => { 119 116 error!("DB error fetching TOTP: {:?}", e); 120 - return ApiError::InternalError(None).into_response(); 117 + return Err(ApiError::InternalError(None)); 121 118 } 122 119 }; 123 120 124 121 if totp_record.verified { 125 - return ApiError::TotpAlreadyEnabled.into_response(); 122 + return Err(ApiError::TotpAlreadyEnabled); 126 123 } 127 124 128 - let secret = match decrypt_totp_secret( 125 + let secret = decrypt_totp_secret( 129 126 &totp_record.secret_encrypted, 130 127 totp_record.encryption_version, 131 - ) { 132 - Ok(s) => s, 133 - Err(e) => { 134 - error!("Failed to decrypt TOTP secret: {:?}", e); 135 - return ApiError::InternalError(None).into_response(); 136 - } 137 - }; 128 + ) 129 + .map_err(|e| { 130 + error!("Failed to decrypt TOTP secret: {:?}", e); 131 + ApiError::InternalError(None) 132 + })?; 138 133 139 134 let code = input.code.trim(); 140 135 if !verify_totp_code(&secret, code) { 141 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 136 + return Err(ApiError::InvalidCode(Some( 137 + "Invalid verification code".into(), 138 + ))); 142 139 } 143 140 144 141 let backup_codes = generate_backup_codes(); 145 - let backup_hashes: Result<Vec<_>, _> = 146 - backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 147 - let backup_hashes = match backup_hashes { 148 - Ok(hashes) => hashes, 149 - Err(e) => { 142 + let backup_hashes: Vec<_> = backup_codes 143 + .iter() 144 + .map(|c| hash_backup_code(c)) 145 + .collect::<Result<Vec<_>, _>>() 146 + .map_err(|e| { 150 147 error!("Failed to hash backup code: {:?}", e); 151 - return ApiError::InternalError(None).into_response(); 152 - } 153 - }; 148 + ApiError::InternalError(None) 149 + })?; 154 150 155 - if let Err(e) = state 151 + state 156 152 .user_repo 157 - .enable_totp_with_backup_codes(&auth.0.did, &backup_hashes) 153 + .enable_totp_with_backup_codes(&auth.did, &backup_hashes) 158 154 .await 159 - { 160 - error!("Failed to enable TOTP: {:?}", e); 161 - return ApiError::InternalError(None).into_response(); 162 - } 155 + .map_err(|e| { 156 + error!("Failed to enable TOTP: {:?}", e); 157 + ApiError::InternalError(None) 158 + })?; 163 159 164 - info!(did = %&auth.0.did, "TOTP enabled with {} backup codes", backup_codes.len()); 160 + info!(did = %&auth.did, "TOTP enabled with {} backup codes", backup_codes.len()); 165 161 166 - Json(EnableTotpResponse { backup_codes }).into_response() 162 + Ok(Json(EnableTotpResponse { backup_codes }).into_response()) 167 163 } 168 164 169 165 #[derive(Deserialize)] ··· 174 170 175 171 pub async fn disable_totp( 176 172 State(state): State<AppState>, 177 - auth: BearerAuth, 173 + auth: Auth<Active>, 178 174 Json(input): Json<DisableTotpInput>, 179 - ) -> Response { 180 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did) 181 - .await 175 + ) -> Result<Response, ApiError> { 176 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 182 177 { 183 - return crate::api::server::reauth::legacy_mfa_required_response( 178 + return Ok(crate::api::server::reauth::legacy_mfa_required_response( 184 179 &*state.user_repo, 185 180 &*state.session_repo, 186 - &auth.0.did, 181 + &auth.did, 187 182 ) 188 - .await; 183 + .await); 189 184 } 190 185 191 186 if !state 192 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 187 + .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 193 188 .await 194 189 { 195 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 196 - return ApiError::RateLimitExceeded(None).into_response(); 190 + warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 191 + return Err(ApiError::RateLimitExceeded(None)); 197 192 } 198 193 199 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 200 - Ok(Some(hash)) => hash, 201 - Ok(None) => return ApiError::AccountNotFound.into_response(), 202 - Err(e) => { 194 + let password_hash = state 195 + .user_repo 196 + .get_password_hash_by_did(&auth.did) 197 + .await 198 + .map_err(|e| { 203 199 error!("DB error fetching user: {:?}", e); 204 - return ApiError::InternalError(None).into_response(); 205 - } 206 - }; 200 + ApiError::InternalError(None) 201 + })? 202 + .ok_or(ApiError::AccountNotFound)?; 207 203 208 204 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 209 205 if !password_valid { 210 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 206 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 211 207 } 212 208 213 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 209 + let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 214 210 Ok(Some(row)) if row.verified => row, 215 - Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 211 + Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 216 212 Err(e) => { 217 213 error!("DB error fetching TOTP: {:?}", e); 218 - return ApiError::InternalError(None).into_response(); 214 + return Err(ApiError::InternalError(None)); 219 215 } 220 216 }; 221 217 222 218 let code = input.code.trim(); 223 219 let code_valid = if is_backup_code_format(code) { 224 - verify_backup_code_for_user(&state, &auth.0.did, code).await 220 + verify_backup_code_for_user(&state, &auth.did, code).await 225 221 } else { 226 - let secret = match decrypt_totp_secret( 222 + let secret = decrypt_totp_secret( 227 223 &totp_record.secret_encrypted, 228 224 totp_record.encryption_version, 229 - ) { 230 - Ok(s) => s, 231 - Err(e) => { 232 - error!("Failed to decrypt TOTP secret: {:?}", e); 233 - return ApiError::InternalError(None).into_response(); 234 - } 235 - }; 225 + ) 226 + .map_err(|e| { 227 + error!("Failed to decrypt TOTP secret: {:?}", e); 228 + ApiError::InternalError(None) 229 + })?; 236 230 verify_totp_code(&secret, code) 237 231 }; 238 232 239 233 if !code_valid { 240 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 234 + return Err(ApiError::InvalidCode(Some( 235 + "Invalid verification code".into(), 236 + ))); 241 237 } 242 238 243 - if let Err(e) = state 239 + state 244 240 .user_repo 245 - .delete_totp_and_backup_codes(&auth.0.did) 241 + .delete_totp_and_backup_codes(&auth.did) 246 242 .await 247 - { 248 - error!("Failed to delete TOTP: {:?}", e); 249 - return ApiError::InternalError(None).into_response(); 250 - } 243 + .map_err(|e| { 244 + error!("Failed to delete TOTP: {:?}", e); 245 + ApiError::InternalError(None) 246 + })?; 251 247 252 - info!(did = %&auth.0.did, "TOTP disabled"); 248 + info!(did = %&auth.did, "TOTP disabled"); 253 249 254 - EmptyResponse::ok().into_response() 250 + Ok(EmptyResponse::ok().into_response()) 255 251 } 256 252 257 253 #[derive(Serialize)] ··· 262 258 pub backup_codes_remaining: i64, 263 259 } 264 260 265 - pub async fn get_totp_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 266 - let enabled = match state.user_repo.get_totp_record(&auth.0.did).await { 261 + pub async fn get_totp_status( 262 + State(state): State<AppState>, 263 + auth: Auth<Active>, 264 + ) -> Result<Response, ApiError> { 265 + let enabled = match state.user_repo.get_totp_record(&auth.did).await { 267 266 Ok(Some(row)) => row.verified, 268 267 Ok(None) => false, 269 268 Err(e) => { 270 269 error!("DB error fetching TOTP status: {:?}", e); 271 - return ApiError::InternalError(None).into_response(); 270 + return Err(ApiError::InternalError(None)); 272 271 } 273 272 }; 274 273 275 - let backup_count = match state.user_repo.count_unused_backup_codes(&auth.0.did).await { 276 - Ok(count) => count, 277 - Err(e) => { 274 + let backup_count = state 275 + .user_repo 276 + .count_unused_backup_codes(&auth.did) 277 + .await 278 + .map_err(|e| { 278 279 error!("DB error counting backup codes: {:?}", e); 279 - return ApiError::InternalError(None).into_response(); 280 - } 281 - }; 280 + ApiError::InternalError(None) 281 + })?; 282 282 283 - Json(GetTotpStatusResponse { 283 + Ok(Json(GetTotpStatusResponse { 284 284 enabled, 285 285 has_backup_codes: backup_count > 0, 286 286 backup_codes_remaining: backup_count, 287 287 }) 288 - .into_response() 288 + .into_response()) 289 289 } 290 290 291 291 #[derive(Deserialize)] ··· 302 302 303 303 pub async fn regenerate_backup_codes( 304 304 State(state): State<AppState>, 305 - auth: BearerAuth, 305 + auth: Auth<Active>, 306 306 Json(input): Json<RegenerateBackupCodesInput>, 307 - ) -> Response { 307 + ) -> Result<Response, ApiError> { 308 308 if !state 309 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.0.did) 309 + .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 310 310 .await 311 311 { 312 - warn!(did = %&auth.0.did, "TOTP verification rate limit exceeded"); 313 - return ApiError::RateLimitExceeded(None).into_response(); 312 + warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 313 + return Err(ApiError::RateLimitExceeded(None)); 314 314 } 315 315 316 - let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 317 - Ok(Some(hash)) => hash, 318 - Ok(None) => return ApiError::AccountNotFound.into_response(), 319 - Err(e) => { 316 + let password_hash = state 317 + .user_repo 318 + .get_password_hash_by_did(&auth.did) 319 + .await 320 + .map_err(|e| { 320 321 error!("DB error fetching user: {:?}", e); 321 - return ApiError::InternalError(None).into_response(); 322 - } 323 - }; 322 + ApiError::InternalError(None) 323 + })? 324 + .ok_or(ApiError::AccountNotFound)?; 324 325 325 326 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 326 327 if !password_valid { 327 - return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 328 + return Err(ApiError::InvalidPassword("Password is incorrect".into())); 328 329 } 329 330 330 - let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 331 + let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 331 332 Ok(Some(row)) if row.verified => row, 332 - Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 333 + Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 333 334 Err(e) => { 334 335 error!("DB error fetching TOTP: {:?}", e); 335 - return ApiError::InternalError(None).into_response(); 336 + return Err(ApiError::InternalError(None)); 336 337 } 337 338 }; 338 339 339 - let secret = match decrypt_totp_secret( 340 + let secret = decrypt_totp_secret( 340 341 &totp_record.secret_encrypted, 341 342 totp_record.encryption_version, 342 - ) { 343 - Ok(s) => s, 344 - Err(e) => { 345 - error!("Failed to decrypt TOTP secret: {:?}", e); 346 - return ApiError::InternalError(None).into_response(); 347 - } 348 - }; 343 + ) 344 + .map_err(|e| { 345 + error!("Failed to decrypt TOTP secret: {:?}", e); 346 + ApiError::InternalError(None) 347 + })?; 349 348 350 349 let code = input.code.trim(); 351 350 if !verify_totp_code(&secret, code) { 352 - return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 351 + return Err(ApiError::InvalidCode(Some( 352 + "Invalid verification code".into(), 353 + ))); 353 354 } 354 355 355 356 let backup_codes = generate_backup_codes(); 356 - let backup_hashes: Result<Vec<_>, _> = 357 - backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 358 - let backup_hashes = match backup_hashes { 359 - Ok(hashes) => hashes, 360 - Err(e) => { 357 + let backup_hashes: Vec<_> = backup_codes 358 + .iter() 359 + .map(|c| hash_backup_code(c)) 360 + .collect::<Result<Vec<_>, _>>() 361 + .map_err(|e| { 361 362 error!("Failed to hash backup code: {:?}", e); 362 - return ApiError::InternalError(None).into_response(); 363 - } 364 - }; 363 + ApiError::InternalError(None) 364 + })?; 365 365 366 - if let Err(e) = state 366 + state 367 367 .user_repo 368 - .replace_backup_codes(&auth.0.did, &backup_hashes) 368 + .replace_backup_codes(&auth.did, &backup_hashes) 369 369 .await 370 - { 371 - error!("Failed to regenerate backup codes: {:?}", e); 372 - return ApiError::InternalError(None).into_response(); 373 - } 370 + .map_err(|e| { 371 + error!("Failed to regenerate backup codes: {:?}", e); 372 + ApiError::InternalError(None) 373 + })?; 374 374 375 - info!(did = %&auth.0.did, "Backup codes regenerated"); 375 + info!(did = %&auth.did, "Backup codes regenerated"); 376 376 377 - Json(RegenerateBackupCodesResponse { backup_codes }).into_response() 377 + Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response()) 378 378 } 379 379 380 380 async fn verify_backup_code_for_user(
+57 -55
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 11 11 use tranquil_db_traits::OAuthRepository; 12 12 use tranquil_types::DeviceId; 13 13 14 - use crate::auth::BearerAuth; 14 + use crate::auth::{Active, Auth}; 15 15 use crate::state::AppState; 16 16 17 17 const TRUST_DURATION_DAYS: i64 = 30; ··· 71 71 pub devices: Vec<TrustedDevice>, 72 72 } 73 73 74 - pub async fn list_trusted_devices(State(state): State<AppState>, auth: BearerAuth) -> Response { 75 - match state.oauth_repo.list_trusted_devices(&auth.0.did).await { 76 - Ok(rows) => { 77 - let devices = rows 78 - .into_iter() 79 - .map(|row| { 80 - let trust_state = 81 - DeviceTrustState::from_timestamps(row.trusted_at, row.trusted_until); 82 - TrustedDevice { 83 - id: row.id, 84 - user_agent: row.user_agent, 85 - friendly_name: row.friendly_name, 86 - trusted_at: row.trusted_at, 87 - trusted_until: row.trusted_until, 88 - last_seen_at: row.last_seen_at, 89 - trust_state, 90 - } 91 - }) 92 - .collect(); 93 - Json(ListTrustedDevicesResponse { devices }).into_response() 94 - } 95 - Err(e) => { 74 + pub async fn list_trusted_devices( 75 + State(state): State<AppState>, 76 + auth: Auth<Active>, 77 + ) -> Result<Response, ApiError> { 78 + let rows = state 79 + .oauth_repo 80 + .list_trusted_devices(&auth.did) 81 + .await 82 + .map_err(|e| { 96 83 error!("DB error: {:?}", e); 97 - ApiError::InternalError(None).into_response() 98 - } 99 - } 84 + ApiError::InternalError(None) 85 + })?; 86 + 87 + let devices = rows 88 + .into_iter() 89 + .map(|row| { 90 + let trust_state = DeviceTrustState::from_timestamps(row.trusted_at, row.trusted_until); 91 + TrustedDevice { 92 + id: row.id, 93 + user_agent: row.user_agent, 94 + friendly_name: row.friendly_name, 95 + trusted_at: row.trusted_at, 96 + trusted_until: row.trusted_until, 97 + last_seen_at: row.last_seen_at, 98 + trust_state, 99 + } 100 + }) 101 + .collect(); 102 + 103 + Ok(Json(ListTrustedDevicesResponse { devices }).into_response()) 100 104 } 101 105 102 106 #[derive(Deserialize)] ··· 107 111 108 112 pub async fn revoke_trusted_device( 109 113 State(state): State<AppState>, 110 - auth: BearerAuth, 114 + auth: Auth<Active>, 111 115 Json(input): Json<RevokeTrustedDeviceInput>, 112 - ) -> Response { 116 + ) -> Result<Response, ApiError> { 113 117 let device_id = DeviceId::from(input.device_id.clone()); 114 118 match state 115 119 .oauth_repo 116 - .device_belongs_to_user(&device_id, &auth.0.did) 120 + .device_belongs_to_user(&device_id, &auth.did) 117 121 .await 118 122 { 119 123 Ok(true) => {} 120 124 Ok(false) => { 121 - return ApiError::DeviceNotFound.into_response(); 125 + return Err(ApiError::DeviceNotFound); 122 126 } 123 127 Err(e) => { 124 128 error!("DB error: {:?}", e); 125 - return ApiError::InternalError(None).into_response(); 129 + return Err(ApiError::InternalError(None)); 126 130 } 127 131 } 128 132 129 - match state.oauth_repo.revoke_device_trust(&device_id).await { 130 - Ok(()) => { 131 - info!(did = %&auth.0.did, device_id = %input.device_id, "Trusted device revoked"); 132 - SuccessResponse::ok().into_response() 133 - } 134 - Err(e) => { 133 + state 134 + .oauth_repo 135 + .revoke_device_trust(&device_id) 136 + .await 137 + .map_err(|e| { 135 138 error!("DB error: {:?}", e); 136 - ApiError::InternalError(None).into_response() 137 - } 138 - } 139 + ApiError::InternalError(None) 140 + })?; 141 + 142 + info!(did = %&auth.did, device_id = %input.device_id, "Trusted device revoked"); 143 + Ok(SuccessResponse::ok().into_response()) 139 144 } 140 145 141 146 #[derive(Deserialize)] ··· 147 152 148 153 pub async fn update_trusted_device( 149 154 State(state): State<AppState>, 150 - auth: BearerAuth, 155 + auth: Auth<Active>, 151 156 Json(input): Json<UpdateTrustedDeviceInput>, 152 - ) -> Response { 157 + ) -> Result<Response, ApiError> { 153 158 let device_id = DeviceId::from(input.device_id.clone()); 154 159 match state 155 160 .oauth_repo 156 - .device_belongs_to_user(&device_id, &auth.0.did) 161 + .device_belongs_to_user(&device_id, &auth.did) 157 162 .await 158 163 { 159 164 Ok(true) => {} 160 165 Ok(false) => { 161 - return ApiError::DeviceNotFound.into_response(); 166 + return Err(ApiError::DeviceNotFound); 162 167 } 163 168 Err(e) => { 164 169 error!("DB error: {:?}", e); 165 - return ApiError::InternalError(None).into_response(); 170 + return Err(ApiError::InternalError(None)); 166 171 } 167 172 } 168 173 169 - match state 174 + state 170 175 .oauth_repo 171 176 .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 172 177 .await 173 - { 174 - Ok(()) => { 175 - info!(did = %auth.0.did, device_id = %input.device_id, "Trusted device updated"); 176 - SuccessResponse::ok().into_response() 177 - } 178 - Err(e) => { 178 + .map_err(|e| { 179 179 error!("DB error: {:?}", e); 180 - ApiError::InternalError(None).into_response() 181 - } 182 - } 180 + ApiError::InternalError(None) 181 + })?; 182 + 183 + info!(did = %auth.did, device_id = %input.device_id, "Trusted device updated"); 184 + Ok(SuccessResponse::ok().into_response()) 183 185 } 184 186 185 187 pub async fn get_device_trust_state(
+9 -28
crates/tranquil-pds/src/api/temp.rs
··· 1 1 use crate::api::error::ApiError; 2 - use crate::auth::{BearerAuth, extract_auth_token_from_header, validate_token_with_dpop}; 2 + use crate::auth::{Active, Auth, Permissive}; 3 3 use crate::state::AppState; 4 4 use axum::{ 5 5 Json, 6 6 extract::State, 7 - http::HeaderMap, 8 7 response::{IntoResponse, Response}, 9 8 }; 10 9 use cid::Cid; ··· 22 21 pub estimated_time_ms: Option<i64>, 23 22 } 24 23 25 - pub async fn check_signup_queue(State(state): State<AppState>, headers: HeaderMap) -> Response { 26 - if let Some(extracted) = 27 - extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 24 + pub async fn check_signup_queue(auth: Option<Auth<Permissive>>) -> Response { 25 + if let Some(ref user) = auth 26 + && user.is_oauth() 28 27 { 29 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 30 - if let Ok(user) = validate_token_with_dpop( 31 - state.user_repo.as_ref(), 32 - state.oauth_repo.as_ref(), 33 - &extracted.token, 34 - extracted.is_dpop, 35 - dpop_proof, 36 - "GET", 37 - "/", 38 - false, 39 - false, 40 - ) 41 - .await 42 - && user.is_oauth 43 - { 44 - return ApiError::Forbidden.into_response(); 45 - } 28 + return ApiError::Forbidden.into_response(); 46 29 } 47 30 Json(CheckSignupQueueOutput { 48 31 activated: true, ··· 66 49 67 50 pub async fn dereference_scope( 68 51 State(state): State<AppState>, 69 - auth: BearerAuth, 52 + _auth: Auth<Active>, 70 53 Json(input): Json<DereferenceScopeInput>, 71 - ) -> Response { 72 - let _ = auth; 73 - 54 + ) -> Result<Response, ApiError> { 74 55 let scope_parts: Vec<&str> = input.scope.split_whitespace().collect(); 75 56 let mut resolved_scopes: Vec<String> = Vec::new(); 76 57 ··· 135 116 } 136 117 } 137 118 138 - Json(DereferenceScopeOutput { 119 + Ok(Json(DereferenceScopeOutput { 139 120 scope: resolved_scopes.join(" "), 140 121 }) 141 - .into_response() 122 + .into_response()) 142 123 }
+547
crates/tranquil-pds/src/auth/auth_extractor.rs
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client, create_account_and_login, pds_endpoint}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); 20 + (code_verifier, code_challenge) 21 + } 22 + 23 + async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer { 24 + let mock_server = MockServer::start().await; 25 + let metadata = json!({ 26 + "client_id": mock_server.uri(), 27 + "client_name": "Auth Extractor Test Client", 28 + "redirect_uris": [redirect_uri], 29 + "grant_types": ["authorization_code", "refresh_token"], 30 + "response_types": ["code"], 31 + "token_endpoint_auth_method": "none", 32 + "dpop_bound_access_tokens": dpop_bound 33 + }); 34 + Mock::given(method("GET")) 35 + .and(path("/")) 36 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 37 + .mount(&mock_server) 38 + .await; 39 + mock_server 40 + } 41 + 42 + async fn get_oauth_session( 43 + http_client: &reqwest::Client, 44 + url: &str, 45 + dpop_bound: bool, 46 + ) -> (String, String, String, String) { 47 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 48 + let handle = format!("ae{}", suffix); 49 + let password = "AuthExtract123!"; 50 + let create_res = http_client 51 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 52 + .json(&json!({ 53 + "handle": handle, 54 + "email": format!("{}@example.com", handle), 55 + "password": password 56 + })) 57 + .send() 58 + .await 59 + .unwrap(); 60 + assert_eq!(create_res.status(), StatusCode::OK); 61 + let account: Value = create_res.json().await.unwrap(); 62 + let did = account["did"].as_str().unwrap().to_string(); 63 + verify_new_account(http_client, &did).await; 64 + 65 + let redirect_uri = "https://example.com/auth-callback"; 66 + let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await; 67 + let client_id = mock_client.uri(); 68 + let (code_verifier, code_challenge) = generate_pkce(); 69 + 70 + let par_body: Value = http_client 71 + .post(format!("{}/oauth/par", url)) 72 + .form(&[ 73 + ("response_type", "code"), 74 + ("client_id", &client_id), 75 + ("redirect_uri", redirect_uri), 76 + ("code_challenge", &code_challenge), 77 + ("code_challenge_method", "S256"), 78 + ]) 79 + .send() 80 + .await 81 + .unwrap() 82 + .json() 83 + .await 84 + .unwrap(); 85 + let request_uri = par_body["request_uri"].as_str().unwrap(); 86 + 87 + let auth_res = http_client 88 + .post(format!("{}/oauth/authorize", url)) 89 + .header("Content-Type", "application/json") 90 + .header("Accept", "application/json") 91 + .json(&json!({ 92 + "request_uri": request_uri, 93 + "username": &handle, 94 + "password": password, 95 + "remember_device": false 96 + })) 97 + .send() 98 + .await 99 + .unwrap(); 100 + let auth_body: Value = auth_res.json().await.unwrap(); 101 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 102 + 103 + if location.contains("/oauth/consent") { 104 + let consent_res = http_client 105 + .post(format!("{}/oauth/authorize/consent", url)) 106 + .header("Content-Type", "application/json") 107 + .json(&json!({ 108 + "request_uri": request_uri, 109 + "approved_scopes": ["atproto"], 110 + "remember": false 111 + })) 112 + .send() 113 + .await 114 + .unwrap(); 115 + let consent_body: Value = consent_res.json().await.unwrap(); 116 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 117 + } 118 + 119 + let code = location 120 + .split("code=") 121 + .nth(1) 122 + .unwrap() 123 + .split('&') 124 + .next() 125 + .unwrap(); 126 + 127 + let token_body: Value = http_client 128 + .post(format!("{}/oauth/token", url)) 129 + .form(&[ 130 + ("grant_type", "authorization_code"), 131 + ("code", code), 132 + ("redirect_uri", redirect_uri), 133 + ("code_verifier", &code_verifier), 134 + ("client_id", &client_id), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap() 139 + .json() 140 + .await 141 + .unwrap(); 142 + 143 + ( 144 + token_body["access_token"].as_str().unwrap().to_string(), 145 + token_body["refresh_token"].as_str().unwrap().to_string(), 146 + client_id, 147 + did, 148 + ) 149 + } 150 + 151 + #[tokio::test] 152 + async fn test_oauth_token_works_with_bearer_auth() { 153 + let url = base_url().await; 154 + let http_client = client(); 155 + let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await; 156 + 157 + let res = http_client 158 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 159 + .bearer_auth(&access_token) 160 + .send() 161 + .await 162 + .unwrap(); 163 + 164 + assert_eq!(res.status(), StatusCode::OK, "OAuth token should work with RequiredAuth extractor"); 165 + let body: Value = res.json().await.unwrap(); 166 + assert_eq!(body["did"].as_str().unwrap(), did); 167 + } 168 + 169 + #[tokio::test] 170 + async fn test_session_token_still_works() { 171 + let url = base_url().await; 172 + let http_client = client(); 173 + let (jwt, did) = create_account_and_login(&http_client).await; 174 + 175 + let res = http_client 176 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 177 + .bearer_auth(&jwt) 178 + .send() 179 + .await 180 + .unwrap(); 181 + 182 + assert_eq!(res.status(), StatusCode::OK, "Session token should still work"); 183 + let body: Value = res.json().await.unwrap(); 184 + assert_eq!(body["did"].as_str().unwrap(), did); 185 + } 186 + 187 + 188 + #[tokio::test] 189 + async fn test_oauth_admin_extractor_allows_oauth_tokens() { 190 + let url = base_url().await; 191 + let http_client = client(); 192 + 193 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 194 + let handle = format!("adm{}", suffix); 195 + let password = "AdminOAuth123!"; 196 + let create_res = http_client 197 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 198 + .json(&json!({ 199 + "handle": handle, 200 + "email": format!("{}@example.com", handle), 201 + "password": password 202 + })) 203 + .send() 204 + .await 205 + .unwrap(); 206 + assert_eq!(create_res.status(), StatusCode::OK); 207 + let account: Value = create_res.json().await.unwrap(); 208 + let did = account["did"].as_str().unwrap().to_string(); 209 + verify_new_account(&http_client, &did).await; 210 + 211 + let pool = common::get_test_db_pool().await; 212 + sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 213 + .execute(pool) 214 + .await 215 + .expect("Failed to mark user as admin"); 216 + 217 + let redirect_uri = "https://example.com/admin-callback"; 218 + let mock_client = setup_mock_client_metadata(redirect_uri, false).await; 219 + let client_id = mock_client.uri(); 220 + let (code_verifier, code_challenge) = generate_pkce(); 221 + 222 + let par_body: Value = http_client 223 + .post(format!("{}/oauth/par", url)) 224 + .form(&[ 225 + ("response_type", "code"), 226 + ("client_id", &client_id), 227 + ("redirect_uri", redirect_uri), 228 + ("code_challenge", &code_challenge), 229 + ("code_challenge_method", "S256"), 230 + ]) 231 + .send() 232 + .await 233 + .unwrap() 234 + .json() 235 + .await 236 + .unwrap(); 237 + let request_uri = par_body["request_uri"].as_str().unwrap(); 238 + 239 + let auth_res = http_client 240 + .post(format!("{}/oauth/authorize", url)) 241 + .header("Content-Type", "application/json") 242 + .header("Accept", "application/json") 243 + .json(&json!({ 244 + "request_uri": request_uri, 245 + "username": &handle, 246 + "password": password, 247 + "remember_device": false 248 + })) 249 + .send() 250 + .await 251 + .unwrap(); 252 + let auth_body: Value = auth_res.json().await.unwrap(); 253 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 254 + if location.contains("/oauth/consent") { 255 + let consent_res = http_client 256 + .post(format!("{}/oauth/authorize/consent", url)) 257 + .header("Content-Type", "application/json") 258 + .json(&json!({ 259 + "request_uri": request_uri, 260 + "approved_scopes": ["atproto"], 261 + "remember": false 262 + })) 263 + .send() 264 + .await 265 + .unwrap(); 266 + let consent_body: Value = consent_res.json().await.unwrap(); 267 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 268 + } 269 + 270 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 271 + let token_body: Value = http_client 272 + .post(format!("{}/oauth/token", url)) 273 + .form(&[ 274 + ("grant_type", "authorization_code"), 275 + ("code", code), 276 + ("redirect_uri", redirect_uri), 277 + ("code_verifier", &code_verifier), 278 + ("client_id", &client_id), 279 + ]) 280 + .send() 281 + .await 282 + .unwrap() 283 + .json() 284 + .await 285 + .unwrap(); 286 + let access_token = token_body["access_token"].as_str().unwrap(); 287 + 288 + let res = http_client 289 + .get(format!("{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", url, did)) 290 + .bearer_auth(access_token) 291 + .send() 292 + .await 293 + .unwrap(); 294 + 295 + assert_eq!( 296 + res.status(), 297 + StatusCode::OK, 298 + "OAuth token for admin user should work with admin endpoint" 299 + ); 300 + } 301 + 302 + #[tokio::test] 303 + async fn test_expired_oauth_token_returns_proper_error() { 304 + let url = base_url().await; 305 + let http_client = client(); 306 + 307 + let now = Utc::now().timestamp(); 308 + let header = json!({"alg": "HS256", "typ": "at+jwt"}); 309 + let payload = json!({ 310 + "iss": url, 311 + "sub": "did:plc:test123", 312 + "aud": url, 313 + "iat": now - 7200, 314 + "exp": now - 3600, 315 + "jti": "expired-token", 316 + "sid": "expired-session", 317 + "scope": "atproto", 318 + "client_id": "https://example.com" 319 + }); 320 + let fake_token = format!( 321 + "{}.{}.{}", 322 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 323 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 324 + URL_SAFE_NO_PAD.encode([1u8; 32]) 325 + ); 326 + 327 + let res = http_client 328 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 329 + .bearer_auth(&fake_token) 330 + .send() 331 + .await 332 + .unwrap(); 333 + 334 + assert_eq!( 335 + res.status(), 336 + StatusCode::UNAUTHORIZED, 337 + "Expired token should be rejected" 338 + ); 339 + } 340 + 341 + #[tokio::test] 342 + async fn test_dpop_nonce_error_has_proper_headers() { 343 + let url = base_url().await; 344 + let pds_url = pds_endpoint(); 345 + let http_client = client(); 346 + 347 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 348 + let handle = format!("dpop{}", suffix); 349 + let create_res = http_client 350 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 351 + .json(&json!({ 352 + "handle": handle, 353 + "email": format!("{}@test.com", handle), 354 + "password": "DpopTest123!" 355 + })) 356 + .send() 357 + .await 358 + .unwrap(); 359 + assert_eq!(create_res.status(), StatusCode::OK); 360 + let account: Value = create_res.json().await.unwrap(); 361 + let did = account["did"].as_str().unwrap(); 362 + verify_new_account(&http_client, did).await; 363 + 364 + let redirect_uri = "https://example.com/dpop-callback"; 365 + let mock_server = MockServer::start().await; 366 + let client_id = mock_server.uri(); 367 + let metadata = json!({ 368 + "client_id": &client_id, 369 + "client_name": "DPoP Test Client", 370 + "redirect_uris": [redirect_uri], 371 + "grant_types": ["authorization_code", "refresh_token"], 372 + "response_types": ["code"], 373 + "token_endpoint_auth_method": "none", 374 + "dpop_bound_access_tokens": true 375 + }); 376 + Mock::given(method("GET")) 377 + .and(path("/")) 378 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 379 + .mount(&mock_server) 380 + .await; 381 + 382 + let (code_verifier, code_challenge) = generate_pkce(); 383 + let par_body: Value = http_client 384 + .post(format!("{}/oauth/par", url)) 385 + .form(&[ 386 + ("response_type", "code"), 387 + ("client_id", &client_id), 388 + ("redirect_uri", redirect_uri), 389 + ("code_challenge", &code_challenge), 390 + ("code_challenge_method", "S256"), 391 + ]) 392 + .send() 393 + .await 394 + .unwrap() 395 + .json() 396 + .await 397 + .unwrap(); 398 + 399 + let request_uri = par_body["request_uri"].as_str().unwrap(); 400 + let auth_res = http_client 401 + .post(format!("{}/oauth/authorize", url)) 402 + .header("Content-Type", "application/json") 403 + .header("Accept", "application/json") 404 + .json(&json!({ 405 + "request_uri": request_uri, 406 + "username": &handle, 407 + "password": "DpopTest123!", 408 + "remember_device": false 409 + })) 410 + .send() 411 + .await 412 + .unwrap(); 413 + let auth_body: Value = auth_res.json().await.unwrap(); 414 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 415 + if location.contains("/oauth/consent") { 416 + let consent_res = http_client 417 + .post(format!("{}/oauth/authorize/consent", url)) 418 + .header("Content-Type", "application/json") 419 + .json(&json!({ 420 + "request_uri": request_uri, 421 + "approved_scopes": ["atproto"], 422 + "remember": false 423 + })) 424 + .send() 425 + .await 426 + .unwrap(); 427 + let consent_body: Value = consent_res.json().await.unwrap(); 428 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 429 + } 430 + 431 + let code = location.split("code=").nth(1).unwrap().split('&').next().unwrap(); 432 + 433 + let token_endpoint = format!("{}/oauth/token", pds_url); 434 + let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None); 435 + 436 + let token_res = http_client 437 + .post(format!("{}/oauth/token", url)) 438 + .header("DPoP", &dpop_proof) 439 + .form(&[ 440 + ("grant_type", "authorization_code"), 441 + ("code", code), 442 + ("redirect_uri", redirect_uri), 443 + ("code_verifier", &code_verifier), 444 + ("client_id", &client_id), 445 + ]) 446 + .send() 447 + .await 448 + .unwrap(); 449 + 450 + let token_status = token_res.status(); 451 + let token_nonce = token_res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap().to_string()); 452 + let token_body: Value = token_res.json().await.unwrap(); 453 + 454 + let access_token = if token_status == StatusCode::OK { 455 + token_body["access_token"].as_str().unwrap().to_string() 456 + } else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") { 457 + let nonce = token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error"); 458 + let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce)); 459 + 460 + let retry_res = http_client 461 + .post(format!("{}/oauth/token", url)) 462 + .header("DPoP", &dpop_proof_with_nonce) 463 + .form(&[ 464 + ("grant_type", "authorization_code"), 465 + ("code", code), 466 + ("redirect_uri", redirect_uri), 467 + ("code_verifier", &code_verifier), 468 + ("client_id", &client_id), 469 + ]) 470 + .send() 471 + .await 472 + .unwrap(); 473 + let retry_body: Value = retry_res.json().await.unwrap(); 474 + retry_body["access_token"].as_str().expect("Should get access_token after nonce retry").to_string() 475 + } else { 476 + panic!("Token exchange failed unexpectedly: {:?}", token_body); 477 + }; 478 + 479 + let res = http_client 480 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 481 + .header("Authorization", format!("DPoP {}", access_token)) 482 + .send() 483 + .await 484 + .unwrap(); 485 + 486 + assert_eq!(res.status(), StatusCode::UNAUTHORIZED, "DPoP token without proof should fail"); 487 + 488 + let www_auth = res.headers().get("www-authenticate").map(|h| h.to_str().unwrap()); 489 + assert!(www_auth.is_some(), "Should have WWW-Authenticate header"); 490 + assert!( 491 + www_auth.unwrap().contains("use_dpop_nonce"), 492 + "WWW-Authenticate should indicate dpop nonce required" 493 + ); 494 + 495 + let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap()); 496 + assert!(nonce.is_some(), "Should return DPoP-Nonce header"); 497 + 498 + let body: Value = res.json().await.unwrap(); 499 + assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce"); 500 + } 501 + 502 + fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) { 503 + use p256::ecdsa::{SigningKey, signature::Signer}; 504 + use p256::elliptic_curve::rand_core::OsRng; 505 + 506 + let signing_key = SigningKey::random(&mut OsRng); 507 + let verifying_key = signing_key.verifying_key(); 508 + let point = verifying_key.to_encoded_point(false); 509 + let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 510 + let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 511 + 512 + let jwk = json!({ 513 + "kty": "EC", 514 + "crv": "P-256", 515 + "x": x, 516 + "y": y 517 + }); 518 + 519 + let header = { 520 + let h = json!({ 521 + "typ": "dpop+jwt", 522 + "alg": "ES256", 523 + "jwk": jwk.clone() 524 + }); 525 + h 526 + }; 527 + 528 + let mut payload = json!({ 529 + "jti": uuid::Uuid::new_v4().to_string(), 530 + "htm": method, 531 + "htu": uri, 532 + "iat": Utc::now().timestamp() 533 + }); 534 + if let Some(n) = nonce { 535 + payload["nonce"] = json!(n); 536 + } 537 + 538 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 539 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 540 + let signing_input = format!("{}.{}", header_b64, payload_b64); 541 + 542 + let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 543 + let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 544 + 545 + let proof = format!("{}.{}", signing_input, sig_b64); 546 + (jwk, proof) 547 + }
+432 -231
crates/tranquil-pds/src/auth/extractor.rs
··· 1 + use std::marker::PhantomData; 2 + 1 3 use axum::{ 2 - extract::FromRequestParts, 3 - http::{header::AUTHORIZATION, request::Parts}, 4 + extract::{FromRequestParts, OptionalFromRequestParts}, 5 + http::{StatusCode, header::AUTHORIZATION, request::Parts}, 4 6 response::{IntoResponse, Response}, 5 7 }; 8 + use tracing::{debug, error, info}; 6 9 7 10 use super::{ 8 - AuthenticatedUser, TokenValidationError, validate_bearer_token_allow_takendown, 9 - validate_bearer_token_cached, validate_bearer_token_cached_allow_deactivated, 10 - validate_token_with_dpop, 11 + AccountStatus, AuthSource, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier, 12 + is_service_token, validate_bearer_token_for_service_auth, 11 13 }; 12 14 use crate::api::error::ApiError; 15 + use crate::oauth::scopes::{RepoAction, ScopePermissions}; 13 16 use crate::state::AppState; 17 + use crate::types::Did; 14 18 use crate::util::build_full_url; 15 - 16 - pub struct BearerAuth(pub AuthenticatedUser); 17 19 18 20 #[derive(Debug)] 19 21 pub enum AuthError { ··· 24 26 AccountDeactivated, 25 27 AccountTakedown, 26 28 AdminRequired, 29 + ServiceAuthNotAllowed, 30 + InsufficientScope(String), 31 + OAuthExpiredToken(String), 32 + UseDpopNonce(String), 33 + InvalidDpopProof(String), 27 34 } 28 35 29 36 impl IntoResponse for AuthError { 30 37 fn into_response(self) -> Response { 31 - ApiError::from(self).into_response() 38 + match self { 39 + Self::UseDpopNonce(nonce) => ( 40 + StatusCode::UNAUTHORIZED, 41 + [ 42 + ("DPoP-Nonce", nonce.as_str()), 43 + ("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\""), 44 + ], 45 + axum::Json(serde_json::json!({ 46 + "error": "use_dpop_nonce", 47 + "message": "DPoP nonce required" 48 + })), 49 + ) 50 + .into_response(), 51 + Self::OAuthExpiredToken(msg) => ApiError::OAuthExpiredToken(Some(msg)).into_response(), 52 + Self::InvalidDpopProof(msg) => ( 53 + StatusCode::UNAUTHORIZED, 54 + [("WWW-Authenticate", "DPoP error=\"invalid_dpop_proof\"")], 55 + axum::Json(serde_json::json!({ 56 + "error": "invalid_dpop_proof", 57 + "message": msg 58 + })), 59 + ) 60 + .into_response(), 61 + Self::InsufficientScope(msg) => ApiError::InsufficientScope(Some(msg)).into_response(), 62 + other => ApiError::from(other).into_response(), 63 + } 32 64 } 33 65 } 34 66 35 - #[cfg(test)] 36 - fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 37 - let auth_header = auth_header.trim(); 38 - 39 - if auth_header.len() < 8 { 40 - return Err(AuthError::InvalidFormat); 41 - } 42 - 43 - let prefix = &auth_header[..7]; 44 - if !prefix.eq_ignore_ascii_case("bearer ") { 45 - return Err(AuthError::InvalidFormat); 46 - } 47 - 48 - let token = auth_header[7..].trim(); 49 - if token.is_empty() { 50 - return Err(AuthError::InvalidFormat); 51 - } 52 - 53 - Ok(token) 67 + pub struct ExtractedToken { 68 + pub token: String, 69 + pub is_dpop: bool, 54 70 } 55 71 56 72 pub fn extract_bearer_token_from_header(auth_header: Option<&str>) -> Option<String> { ··· 73 89 Some(token.to_string()) 74 90 } 75 91 76 - pub struct ExtractedToken { 77 - pub token: String, 78 - pub is_dpop: bool, 79 - } 80 - 81 92 pub fn extract_auth_token_from_header(auth_header: Option<&str>) -> Option<ExtractedToken> { 82 93 let header = auth_header?; 83 94 let header = header.trim(); ··· 107 118 None 108 119 } 109 120 110 - impl FromRequestParts<AppState> for BearerAuth { 121 + pub trait AuthPolicy: Send + Sync + 'static { 122 + fn validate(user: &AuthenticatedUser) -> Result<(), AuthError>; 123 + } 124 + 125 + pub struct Permissive; 126 + 127 + impl AuthPolicy for Permissive { 128 + fn validate(_user: &AuthenticatedUser) -> Result<(), AuthError> { 129 + Ok(()) 130 + } 131 + } 132 + 133 + pub struct Active; 134 + 135 + impl AuthPolicy for Active { 136 + fn validate(user: &AuthenticatedUser) -> Result<(), AuthError> { 137 + if user.status.is_deactivated() { 138 + return Err(AuthError::AccountDeactivated); 139 + } 140 + if user.status.is_takendown() { 141 + return Err(AuthError::AccountTakedown); 142 + } 143 + Ok(()) 144 + } 145 + } 146 + 147 + pub struct NotTakendown; 148 + 149 + impl AuthPolicy for NotTakendown { 150 + fn validate(user: &AuthenticatedUser) -> Result<(), AuthError> { 151 + if user.status.is_takendown() { 152 + return Err(AuthError::AccountTakedown); 153 + } 154 + Ok(()) 155 + } 156 + } 157 + 158 + pub struct AnyUser; 159 + 160 + impl AuthPolicy for AnyUser { 161 + fn validate(_user: &AuthenticatedUser) -> Result<(), AuthError> { 162 + Ok(()) 163 + } 164 + } 165 + 166 + pub struct Admin; 167 + 168 + impl AuthPolicy for Admin { 169 + fn validate(user: &AuthenticatedUser) -> Result<(), AuthError> { 170 + if user.status.is_deactivated() { 171 + return Err(AuthError::AccountDeactivated); 172 + } 173 + if user.status.is_takendown() { 174 + return Err(AuthError::AccountTakedown); 175 + } 176 + if !user.is_admin { 177 + return Err(AuthError::AdminRequired); 178 + } 179 + Ok(()) 180 + } 181 + } 182 + 183 + impl AuthenticatedUser { 184 + pub fn require_active(&self) -> Result<&Self, ApiError> { 185 + if self.status.is_deactivated() { 186 + return Err(ApiError::AccountDeactivated); 187 + } 188 + if self.status.is_takendown() { 189 + return Err(ApiError::AccountTakedown); 190 + } 191 + Ok(self) 192 + } 193 + 194 + pub fn require_not_takendown(&self) -> Result<&Self, ApiError> { 195 + if self.status.is_takendown() { 196 + return Err(ApiError::AccountTakedown); 197 + } 198 + Ok(self) 199 + } 200 + 201 + pub fn require_admin(&self) -> Result<&Self, ApiError> { 202 + if !self.is_admin { 203 + return Err(ApiError::AdminRequired); 204 + } 205 + Ok(self) 206 + } 207 + } 208 + 209 + async fn verify_oauth_token_and_build_user( 210 + state: &AppState, 211 + token: &str, 212 + dpop_proof: Option<&str>, 213 + method: &str, 214 + uri: &str, 215 + ) -> Result<AuthenticatedUser, AuthError> { 216 + match crate::oauth::verify::verify_oauth_access_token( 217 + state.oauth_repo.as_ref(), 218 + token, 219 + dpop_proof, 220 + method, 221 + uri, 222 + ) 223 + .await 224 + { 225 + Ok(result) => { 226 + let user_info = state 227 + .user_repo 228 + .get_user_info_by_did(&result.did) 229 + .await 230 + .ok() 231 + .flatten() 232 + .ok_or(AuthError::AuthenticationFailed)?; 233 + let status = AccountStatus::from_db_fields( 234 + user_info.takedown_ref.as_deref(), 235 + user_info.deactivated_at, 236 + ); 237 + Ok(AuthenticatedUser { 238 + did: result.did, 239 + key_bytes: user_info.key_bytes.and_then(|kb| { 240 + crate::config::decrypt_key(&kb, user_info.encryption_version).ok() 241 + }), 242 + is_admin: user_info.is_admin, 243 + status, 244 + scope: result.scope, 245 + controller_did: None, 246 + auth_source: AuthSource::OAuth, 247 + }) 248 + } 249 + Err(crate::oauth::OAuthError::ExpiredToken(msg)) => Err(AuthError::OAuthExpiredToken(msg)), 250 + Err(crate::oauth::OAuthError::UseDpopNonce(nonce)) => Err(AuthError::UseDpopNonce(nonce)), 251 + Err(crate::oauth::OAuthError::InvalidDpopProof(msg)) => { 252 + Err(AuthError::InvalidDpopProof(msg)) 253 + } 254 + Err(_) => Err(AuthError::AuthenticationFailed), 255 + } 256 + } 257 + 258 + async fn verify_service_token(token: &str) -> Result<ServiceTokenClaims, AuthError> { 259 + let verifier = ServiceTokenVerifier::new(); 260 + let claims = verifier 261 + .verify_service_token(token, None) 262 + .await 263 + .map_err(|e| { 264 + error!("Service token verification failed: {:?}", e); 265 + AuthError::AuthenticationFailed 266 + })?; 267 + 268 + debug!("Service token verified for DID: {}", claims.iss); 269 + Ok(claims) 270 + } 271 + 272 + enum ExtractedAuth { 273 + User(AuthenticatedUser), 274 + Service(ServiceTokenClaims), 275 + } 276 + 277 + async fn extract_auth_internal( 278 + parts: &mut Parts, 279 + state: &AppState, 280 + ) -> Result<ExtractedAuth, AuthError> { 281 + let auth_header = parts 282 + .headers 283 + .get(AUTHORIZATION) 284 + .ok_or(AuthError::MissingToken)? 285 + .to_str() 286 + .map_err(|_| AuthError::InvalidFormat)?; 287 + 288 + let extracted = 289 + extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 290 + 291 + if is_service_token(&extracted.token) { 292 + let claims = verify_service_token(&extracted.token).await?; 293 + return Ok(ExtractedAuth::Service(claims)); 294 + } 295 + 296 + let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok()); 297 + let method = parts.method.as_str(); 298 + let uri = build_full_url(&parts.uri.to_string()); 299 + 300 + match validate_bearer_token_for_service_auth(state.user_repo.as_ref(), &extracted.token).await { 301 + Ok(user) if !user.auth_source.is_oauth() => { 302 + return Ok(ExtractedAuth::User(user)); 303 + } 304 + Ok(_) => {} 305 + Err(super::TokenValidationError::TokenExpired) => { 306 + info!("JWT access token expired, returning ExpiredToken"); 307 + return Err(AuthError::TokenExpired); 308 + } 309 + Err(_) => {} 310 + } 311 + 312 + let user = verify_oauth_token_and_build_user(state, &extracted.token, dpop_proof, method, &uri) 313 + .await?; 314 + Ok(ExtractedAuth::User(user)) 315 + } 316 + 317 + async fn extract_user_auth_internal( 318 + parts: &mut Parts, 319 + state: &AppState, 320 + ) -> Result<AuthenticatedUser, AuthError> { 321 + match extract_auth_internal(parts, state).await? { 322 + ExtractedAuth::User(user) => Ok(user), 323 + ExtractedAuth::Service(_) => Err(AuthError::ServiceAuthNotAllowed), 324 + } 325 + } 326 + 327 + pub struct Auth<P: AuthPolicy = Active>(pub AuthenticatedUser, PhantomData<P>); 328 + 329 + impl<P: AuthPolicy> Auth<P> { 330 + pub fn into_inner(self) -> AuthenticatedUser { 331 + self.0 332 + } 333 + 334 + pub fn needs_scope_check(&self) -> bool { 335 + self.0.is_oauth() 336 + } 337 + 338 + pub fn permissions(&self) -> ScopePermissions { 339 + self.0.permissions() 340 + } 341 + 342 + #[allow(clippy::result_large_err)] 343 + pub fn check_repo_scope(&self, action: RepoAction, collection: &str) -> Result<(), Response> { 344 + if !self.needs_scope_check() { 345 + return Ok(()); 346 + } 347 + self.permissions() 348 + .assert_repo(action, collection) 349 + .map_err(|e| ApiError::InsufficientScope(Some(e.to_string())).into_response()) 350 + } 351 + } 352 + 353 + impl<P: AuthPolicy> std::ops::Deref for Auth<P> { 354 + type Target = AuthenticatedUser; 355 + 356 + fn deref(&self) -> &Self::Target { 357 + &self.0 358 + } 359 + } 360 + 361 + impl<P: AuthPolicy> FromRequestParts<AppState> for Auth<P> { 111 362 type Rejection = AuthError; 112 363 113 364 async fn from_request_parts( 114 365 parts: &mut Parts, 115 366 state: &AppState, 116 367 ) -> Result<Self, Self::Rejection> { 117 - let auth_header = parts 118 - .headers 119 - .get(AUTHORIZATION) 120 - .ok_or(AuthError::MissingToken)? 121 - .to_str() 122 - .map_err(|_| AuthError::InvalidFormat)?; 368 + let user = extract_user_auth_internal(parts, state).await?; 369 + P::validate(&user)?; 370 + Ok(Auth(user, PhantomData)) 371 + } 372 + } 123 373 124 - let extracted = 125 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 374 + impl<P: AuthPolicy> OptionalFromRequestParts<AppState> for Auth<P> { 375 + type Rejection = AuthError; 126 376 127 - if extracted.is_dpop { 128 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 129 - let method = parts.method.as_str(); 130 - let uri = build_full_url(&parts.uri.to_string()); 131 - 132 - match validate_token_with_dpop( 133 - state.user_repo.as_ref(), 134 - state.oauth_repo.as_ref(), 135 - &extracted.token, 136 - true, 137 - dpop_proof, 138 - method, 139 - &uri, 140 - false, 141 - false, 142 - ) 143 - .await 144 - { 145 - Ok(user) => Ok(BearerAuth(user)), 146 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 147 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 148 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 149 - Err(_) => Err(AuthError::AuthenticationFailed), 377 + async fn from_request_parts( 378 + parts: &mut Parts, 379 + state: &AppState, 380 + ) -> Result<Option<Self>, Self::Rejection> { 381 + match extract_user_auth_internal(parts, state).await { 382 + Ok(user) => { 383 + P::validate(&user)?; 384 + Ok(Some(Auth(user, PhantomData))) 150 385 } 151 - } else { 152 - match validate_bearer_token_cached( 153 - state.user_repo.as_ref(), 154 - state.cache.as_ref(), 155 - &extracted.token, 156 - ) 157 - .await 158 - { 159 - Ok(user) => Ok(BearerAuth(user)), 160 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 161 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 162 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 163 - Err(_) => Err(AuthError::AuthenticationFailed), 164 - } 386 + Err(AuthError::MissingToken) => Ok(None), 387 + Err(e) => Err(e), 165 388 } 166 389 } 167 390 } 168 391 169 - pub struct BearerAuthAllowDeactivated(pub AuthenticatedUser); 392 + pub struct ServiceAuth { 393 + pub did: Did, 394 + pub claims: ServiceTokenClaims, 395 + } 396 + 397 + impl ServiceAuth { 398 + pub fn require_lxm(&self, expected_lxm: &str) -> Result<(), ApiError> { 399 + match &self.claims.lxm { 400 + Some(lxm) if lxm == "*" || lxm == expected_lxm => Ok(()), 401 + Some(lxm) => Err(ApiError::AuthorizationError(format!( 402 + "Token lxm '{}' does not permit '{}'", 403 + lxm, expected_lxm 404 + ))), 405 + None => Err(ApiError::AuthorizationError( 406 + "Token missing lxm claim".to_string(), 407 + )), 408 + } 409 + } 410 + } 170 411 171 - impl FromRequestParts<AppState> for BearerAuthAllowDeactivated { 412 + impl FromRequestParts<AppState> for ServiceAuth { 172 413 type Rejection = AuthError; 173 414 174 415 async fn from_request_parts( 175 416 parts: &mut Parts, 176 417 state: &AppState, 177 418 ) -> Result<Self, Self::Rejection> { 178 - let auth_header = parts 179 - .headers 180 - .get(AUTHORIZATION) 181 - .ok_or(AuthError::MissingToken)? 182 - .to_str() 183 - .map_err(|_| AuthError::InvalidFormat)?; 184 - 185 - let extracted = 186 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 419 + match extract_auth_internal(parts, state).await? { 420 + ExtractedAuth::Service(claims) => { 421 + let did: Did = claims 422 + .iss 423 + .parse() 424 + .map_err(|_| AuthError::AuthenticationFailed)?; 425 + Ok(ServiceAuth { did, claims }) 426 + } 427 + ExtractedAuth::User(_) => Err(AuthError::AuthenticationFailed), 428 + } 429 + } 430 + } 187 431 188 - if extracted.is_dpop { 189 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 190 - let method = parts.method.as_str(); 191 - let uri = build_full_url(&parts.uri.to_string()); 432 + impl OptionalFromRequestParts<AppState> for ServiceAuth { 433 + type Rejection = AuthError; 192 434 193 - match validate_token_with_dpop( 194 - state.user_repo.as_ref(), 195 - state.oauth_repo.as_ref(), 196 - &extracted.token, 197 - true, 198 - dpop_proof, 199 - method, 200 - &uri, 201 - true, 202 - false, 203 - ) 204 - .await 205 - { 206 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 207 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 208 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 209 - Err(_) => Err(AuthError::AuthenticationFailed), 435 + async fn from_request_parts( 436 + parts: &mut Parts, 437 + state: &AppState, 438 + ) -> Result<Option<Self>, Self::Rejection> { 439 + match extract_auth_internal(parts, state).await { 440 + Ok(ExtractedAuth::Service(claims)) => { 441 + let did: Did = claims 442 + .iss 443 + .parse() 444 + .map_err(|_| AuthError::AuthenticationFailed)?; 445 + Ok(Some(ServiceAuth { did, claims })) 210 446 } 211 - } else { 212 - match validate_bearer_token_cached_allow_deactivated( 213 - state.user_repo.as_ref(), 214 - state.cache.as_ref(), 215 - &extracted.token, 216 - ) 217 - .await 218 - { 219 - Ok(user) => Ok(BearerAuthAllowDeactivated(user)), 220 - Err(TokenValidationError::AccountTakedown) => Err(AuthError::AccountTakedown), 221 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 222 - Err(_) => Err(AuthError::AuthenticationFailed), 223 - } 447 + Ok(ExtractedAuth::User(_)) => Err(AuthError::AuthenticationFailed), 448 + Err(AuthError::MissingToken) => Ok(None), 449 + Err(e) => Err(e), 224 450 } 225 451 } 226 452 } 227 453 228 - pub struct BearerAuthAllowTakendown(pub AuthenticatedUser); 454 + pub enum AuthAny<P: AuthPolicy = Active> { 455 + User(Auth<P>), 456 + Service(ServiceAuth), 457 + } 458 + 459 + impl<P: AuthPolicy> AuthAny<P> { 460 + pub fn did(&self) -> &Did { 461 + match self { 462 + Self::User(auth) => &auth.did, 463 + Self::Service(auth) => &auth.did, 464 + } 465 + } 466 + 467 + pub fn as_user(&self) -> Option<&Auth<P>> { 468 + match self { 469 + Self::User(auth) => Some(auth), 470 + Self::Service(_) => None, 471 + } 472 + } 473 + 474 + pub fn as_service(&self) -> Option<&ServiceAuth> { 475 + match self { 476 + Self::User(_) => None, 477 + Self::Service(auth) => Some(auth), 478 + } 479 + } 480 + 481 + pub fn is_service(&self) -> bool { 482 + matches!(self, Self::Service(_)) 483 + } 484 + 485 + pub fn require_lxm(&self, expected_lxm: &str) -> Result<(), ApiError> { 486 + match self { 487 + Self::User(_) => Ok(()), 488 + Self::Service(auth) => auth.require_lxm(expected_lxm), 489 + } 490 + } 491 + } 229 492 230 - impl FromRequestParts<AppState> for BearerAuthAllowTakendown { 493 + impl<P: AuthPolicy> FromRequestParts<AppState> for AuthAny<P> { 231 494 type Rejection = AuthError; 232 495 233 496 async fn from_request_parts( 234 497 parts: &mut Parts, 235 498 state: &AppState, 236 499 ) -> Result<Self, Self::Rejection> { 237 - let auth_header = parts 238 - .headers 239 - .get(AUTHORIZATION) 240 - .ok_or(AuthError::MissingToken)? 241 - .to_str() 242 - .map_err(|_| AuthError::InvalidFormat)?; 243 - 244 - let extracted = 245 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 246 - 247 - if extracted.is_dpop { 248 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 249 - let method = parts.method.as_str(); 250 - let uri = build_full_url(&parts.uri.to_string()); 251 - 252 - match validate_token_with_dpop( 253 - state.user_repo.as_ref(), 254 - state.oauth_repo.as_ref(), 255 - &extracted.token, 256 - true, 257 - dpop_proof, 258 - method, 259 - &uri, 260 - false, 261 - true, 262 - ) 263 - .await 264 - { 265 - Ok(user) => Ok(BearerAuthAllowTakendown(user)), 266 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 267 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 268 - Err(_) => Err(AuthError::AuthenticationFailed), 500 + match extract_auth_internal(parts, state).await? { 501 + ExtractedAuth::User(user) => { 502 + P::validate(&user)?; 503 + Ok(AuthAny::User(Auth(user, PhantomData))) 269 504 } 270 - } else { 271 - match validate_bearer_token_allow_takendown(state.user_repo.as_ref(), &extracted.token) 272 - .await 273 - { 274 - Ok(user) => Ok(BearerAuthAllowTakendown(user)), 275 - Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 276 - Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), 277 - Err(_) => Err(AuthError::AuthenticationFailed), 505 + ExtractedAuth::Service(claims) => { 506 + let did: Did = claims 507 + .iss 508 + .parse() 509 + .map_err(|_| AuthError::AuthenticationFailed)?; 510 + Ok(AuthAny::Service(ServiceAuth { did, claims })) 278 511 } 279 512 } 280 513 } 281 514 } 282 515 283 - pub struct BearerAuthAdmin(pub AuthenticatedUser); 284 - 285 - impl FromRequestParts<AppState> for BearerAuthAdmin { 516 + impl<P: AuthPolicy> OptionalFromRequestParts<AppState> for AuthAny<P> { 286 517 type Rejection = AuthError; 287 518 288 519 async fn from_request_parts( 289 520 parts: &mut Parts, 290 521 state: &AppState, 291 - ) -> Result<Self, Self::Rejection> { 292 - let auth_header = parts 293 - .headers 294 - .get(AUTHORIZATION) 295 - .ok_or(AuthError::MissingToken)? 296 - .to_str() 297 - .map_err(|_| AuthError::InvalidFormat)?; 522 + ) -> Result<Option<Self>, Self::Rejection> { 523 + match extract_auth_internal(parts, state).await { 524 + Ok(ExtractedAuth::User(user)) => { 525 + P::validate(&user)?; 526 + Ok(Some(AuthAny::User(Auth(user, PhantomData)))) 527 + } 528 + Ok(ExtractedAuth::Service(claims)) => { 529 + let did: Did = claims 530 + .iss 531 + .parse() 532 + .map_err(|_| AuthError::AuthenticationFailed)?; 533 + Ok(Some(AuthAny::Service(ServiceAuth { did, claims }))) 534 + } 535 + Err(AuthError::MissingToken) => Ok(None), 536 + Err(e) => Err(e), 537 + } 538 + } 539 + } 298 540 299 - let extracted = 300 - extract_auth_token_from_header(Some(auth_header)).ok_or(AuthError::InvalidFormat)?; 541 + #[cfg(test)] 542 + fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> { 543 + let auth_header = auth_header.trim(); 301 544 302 - let user = if extracted.is_dpop { 303 - let dpop_proof = parts.headers.get("dpop").and_then(|h| h.to_str().ok()); 304 - let method = parts.method.as_str(); 305 - let uri = build_full_url(&parts.uri.to_string()); 545 + if auth_header.len() < 8 { 546 + return Err(AuthError::InvalidFormat); 547 + } 306 548 307 - match validate_token_with_dpop( 308 - state.user_repo.as_ref(), 309 - state.oauth_repo.as_ref(), 310 - &extracted.token, 311 - true, 312 - dpop_proof, 313 - method, 314 - &uri, 315 - false, 316 - false, 317 - ) 318 - .await 319 - { 320 - Ok(user) => user, 321 - Err(TokenValidationError::AccountDeactivated) => { 322 - return Err(AuthError::AccountDeactivated); 323 - } 324 - Err(TokenValidationError::AccountTakedown) => { 325 - return Err(AuthError::AccountTakedown); 326 - } 327 - Err(TokenValidationError::TokenExpired) => { 328 - return Err(AuthError::TokenExpired); 329 - } 330 - Err(_) => return Err(AuthError::AuthenticationFailed), 331 - } 332 - } else { 333 - match validate_bearer_token_cached( 334 - state.user_repo.as_ref(), 335 - state.cache.as_ref(), 336 - &extracted.token, 337 - ) 338 - .await 339 - { 340 - Ok(user) => user, 341 - Err(TokenValidationError::AccountDeactivated) => { 342 - return Err(AuthError::AccountDeactivated); 343 - } 344 - Err(TokenValidationError::AccountTakedown) => { 345 - return Err(AuthError::AccountTakedown); 346 - } 347 - Err(TokenValidationError::TokenExpired) => { 348 - return Err(AuthError::TokenExpired); 349 - } 350 - Err(_) => return Err(AuthError::AuthenticationFailed), 351 - } 352 - }; 549 + let prefix = &auth_header[..7]; 550 + if !prefix.eq_ignore_ascii_case("bearer ") { 551 + return Err(AuthError::InvalidFormat); 552 + } 353 553 354 - if !user.is_admin { 355 - return Err(AuthError::AdminRequired); 356 - } 357 - Ok(BearerAuthAdmin(user)) 554 + let token = auth_header[7..].trim(); 555 + if token.is_empty() { 556 + return Err(AuthError::InvalidFormat); 358 557 } 558 + 559 + Ok(token) 359 560 } 360 561 361 562 #[cfg(test)]
+74 -7
crates/tranquil-pds/src/auth/mod.rs
··· 3 3 use std::time::Duration; 4 4 5 5 use crate::AccountStatus; 6 + use crate::api::ApiError; 6 7 use crate::cache::Cache; 7 8 use crate::oauth::scopes::ScopePermissions; 8 9 use crate::types::Did; ··· 16 17 pub mod webauthn; 17 18 18 19 pub use extractor::{ 19 - AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken, 20 - extract_auth_token_from_header, extract_bearer_token_from_header, 20 + Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown, 21 + Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header, 21 22 }; 22 23 pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 23 24 ··· 93 94 } 94 95 } 95 96 97 + pub enum AuthSource { 98 + Session, 99 + OAuth, 100 + Service { claims: ServiceTokenClaims }, 101 + } 102 + 103 + impl AuthSource { 104 + pub fn is_oauth(&self) -> bool { 105 + matches!(self, Self::OAuth) 106 + } 107 + 108 + pub fn is_service(&self) -> bool { 109 + matches!(self, Self::Service { .. }) 110 + } 111 + 112 + pub fn service_claims(&self) -> Option<&ServiceTokenClaims> { 113 + match self { 114 + Self::Service { claims } => Some(claims), 115 + _ => None, 116 + } 117 + } 118 + } 119 + 96 120 pub struct AuthenticatedUser { 97 121 pub did: Did, 98 122 pub key_bytes: Option<Vec<u8>>, 99 - pub is_oauth: bool, 100 123 pub is_admin: bool, 101 124 pub status: AccountStatus, 102 125 pub scope: Option<String>, 103 126 pub controller_did: Option<Did>, 127 + pub auth_source: AuthSource, 128 + } 129 + 130 + impl AuthenticatedUser { 131 + pub fn is_oauth(&self) -> bool { 132 + self.auth_source.is_oauth() 133 + } 134 + 135 + pub fn is_service(&self) -> bool { 136 + self.auth_source.is_service() 137 + } 138 + 139 + pub fn service_claims(&self) -> Option<&ServiceTokenClaims> { 140 + self.auth_source.service_claims() 141 + } 142 + 143 + pub fn require_lxm(&self, expected_lxm: &str) -> Result<(), ApiError> { 144 + match self.auth_source.service_claims() { 145 + Some(claims) => match &claims.lxm { 146 + Some(lxm) if lxm == "*" || lxm == expected_lxm => Ok(()), 147 + Some(lxm) => Err(ApiError::AuthorizationError(format!( 148 + "Token lxm '{}' does not permit '{}'", 149 + lxm, expected_lxm 150 + ))), 151 + None => Err(ApiError::AuthorizationError( 152 + "Token missing lxm claim".to_string(), 153 + )), 154 + }, 155 + None => Ok(()), 156 + } 157 + } 158 + 159 + pub fn require_user(&self) -> Result<&Self, ApiError> { 160 + if self.is_service() { 161 + return Err(ApiError::AuthenticationFailed(Some( 162 + "User authentication required".to_string(), 163 + ))); 164 + } 165 + Ok(self) 166 + } 167 + 168 + pub fn as_user(&self) -> Option<&Self> { 169 + if self.is_service() { None } else { Some(self) } 170 + } 104 171 } 105 172 106 173 impl AuthenticatedUser { ··· 110 177 { 111 178 return ScopePermissions::from_scope_string(Some(scope)); 112 179 } 113 - if !self.is_oauth { 180 + if !self.is_oauth() { 114 181 return ScopePermissions::from_scope_string(Some("atproto")); 115 182 } 116 183 ScopePermissions::from_scope_string(self.scope.as_deref()) ··· 348 415 return Ok(AuthenticatedUser { 349 416 did: did.clone(), 350 417 key_bytes: Some(decrypted_key), 351 - is_oauth: false, 352 418 is_admin, 353 419 status, 354 420 scope: token_data.claims.scope.clone(), 355 421 controller_did, 422 + auth_source: AuthSource::Session, 356 423 }); 357 424 } 358 425 } ··· 396 463 return Ok(AuthenticatedUser { 397 464 did: Did::new_unchecked(oauth_token.did), 398 465 key_bytes, 399 - is_oauth: true, 400 466 is_admin: oauth_token.is_admin, 401 467 status, 402 468 scope: oauth_info.scope, 403 469 controller_did: oauth_info.controller_did.map(Did::new_unchecked), 470 + auth_source: AuthSource::OAuth, 404 471 }); 405 472 } else { 406 473 return Err(TokenValidationError::TokenExpired); ··· 480 547 Ok(AuthenticatedUser { 481 548 did: Did::new_unchecked(result.did), 482 549 key_bytes, 483 - is_oauth: true, 484 550 is_admin: user_info.is_admin, 485 551 status, 486 552 scope: result.scope, 487 553 controller_did: None, 554 + auth_source: AuthSource::OAuth, 488 555 }) 489 556 } 490 557 Err(crate::oauth::OAuthError::ExpiredToken(_)) => {
+11 -2
crates/tranquil-pds/src/lib.rs
··· 528 528 )); 529 529 let xrpc_service = ServiceBuilder::new() 530 530 .layer(XrpcProxyLayer::new(state.clone())) 531 - .service(xrpc_router.with_state(state.clone())); 531 + .service( 532 + xrpc_router 533 + .layer(middleware::from_fn(oauth::verify::dpop_nonce_middleware)) 534 + .with_state(state.clone()), 535 + ); 532 536 533 537 let oauth_router = Router::new() 534 538 .route("/jwks", get(oauth::endpoints::oauth_jwks)) ··· 568 572 "/register/complete", 569 573 post(oauth::endpoints::register_complete), 570 574 ) 575 + .route( 576 + "/establish-session", 577 + post(oauth::endpoints::establish_session), 578 + ) 571 579 .route("/authorize/consent", get(oauth::endpoints::consent_get)) 572 580 .route("/authorize/consent", post(oauth::endpoints::consent_post)) 573 581 .route( ··· 605 613 .route( 606 614 "/sso/check-handle-available", 607 615 get(sso::endpoints::check_handle_available), 608 - ); 616 + ) 617 + .layer(middleware::from_fn(oauth::verify::dpop_nonce_middleware)); 609 618 610 619 let well_known_router = Router::new() 611 620 .route("/did.json", get(api::identity::well_known_did))
+148 -25
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 1 use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 2 use crate::oauth::{ 3 3 AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, 4 - db::should_show_consent, 4 + db::should_show_consent, scopes::expand_include_scopes, 5 5 }; 6 6 use crate::state::{AppState, RateLimitKind}; 7 7 use crate::types::{Did, Handle, PlainPassword}; ··· 1106 1106 .oauth_repo 1107 1107 .upsert_account_device(&did, &select_device_typed) 1108 1108 .await; 1109 + 1110 + let requested_scope_str = request_data 1111 + .parameters 1112 + .scope 1113 + .as_deref() 1114 + .unwrap_or("atproto"); 1115 + let requested_scopes: Vec<String> = requested_scope_str 1116 + .split_whitespace() 1117 + .map(|s| s.to_string()) 1118 + .collect(); 1119 + let client_id_typed = ClientId::from(request_data.parameters.client_id.clone()); 1120 + let needs_consent = should_show_consent( 1121 + state.oauth_repo.as_ref(), 1122 + &did, 1123 + &client_id_typed, 1124 + &requested_scopes, 1125 + ) 1126 + .await 1127 + .unwrap_or(true); 1128 + 1129 + if needs_consent { 1130 + if state 1131 + .oauth_repo 1132 + .set_authorization_did(&select_request_id, &did, Some(&select_device_typed)) 1133 + .await 1134 + .is_err() 1135 + { 1136 + return json_error( 1137 + StatusCode::INTERNAL_SERVER_ERROR, 1138 + "server_error", 1139 + "An error occurred. Please try again.", 1140 + ); 1141 + } 1142 + let consent_url = format!( 1143 + "/app/oauth/consent?request_uri={}", 1144 + url_encode(&form.request_uri) 1145 + ); 1146 + return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 1147 + } 1148 + 1109 1149 let code = Code::generate(); 1110 1150 let select_code = AuthorizationCode::from(code.0.clone()); 1111 1151 if state ··· 1475 1515 requested_scope_str.to_string() 1476 1516 }; 1477 1517 1478 - let requested_scopes: Vec<&str> = effective_scope_str.split_whitespace().collect(); 1518 + let expanded_scope_str = expand_include_scopes(&effective_scope_str).await; 1519 + let requested_scopes: Vec<&str> = expanded_scope_str.split_whitespace().collect(); 1479 1520 let consent_client_id = ClientId::from(request_data.parameters.client_id.clone()); 1480 1521 let preferences = state 1481 1522 .oauth_repo ··· 2407 2448 } 2408 2449 2409 2450 let delegation_from_param = match &form.delegated_did { 2410 - Some(delegated_did_str) => { 2411 - match delegated_did_str.parse::<tranquil_types::Did>() { 2412 - Ok(delegated_did) if delegated_did != user.did => { 2413 - match state 2414 - .delegation_repo 2415 - .get_delegation(&delegated_did, &user.did) 2416 - .await 2417 - { 2418 - Ok(Some(_)) => Some(delegated_did), 2419 - Ok(None) => None, 2420 - Err(e) => { 2421 - tracing::warn!( 2422 - error = %e, 2423 - delegated_did = %delegated_did, 2424 - controller_did = %user.did, 2425 - "Failed to verify delegation relationship" 2426 - ); 2427 - None 2428 - } 2451 + Some(delegated_did_str) => match delegated_did_str.parse::<tranquil_types::Did>() { 2452 + Ok(delegated_did) if delegated_did != user.did => { 2453 + match state 2454 + .delegation_repo 2455 + .get_delegation(&delegated_did, &user.did) 2456 + .await 2457 + { 2458 + Ok(Some(_)) => Some(delegated_did), 2459 + Ok(None) => None, 2460 + Err(e) => { 2461 + tracing::warn!( 2462 + error = %e, 2463 + delegated_did = %delegated_did, 2464 + controller_did = %user.did, 2465 + "Failed to verify delegation relationship" 2466 + ); 2467 + None 2429 2468 } 2430 2469 } 2431 - _ => None, 2432 2470 } 2433 - } 2471 + _ => None, 2472 + }, 2434 2473 None => None, 2435 2474 }; 2436 2475 2437 2476 let is_delegation_flow = delegation_from_param.is_some() 2438 - || request_data.did.as_ref().map_or(false, |existing_did| { 2477 + || request_data.did.as_ref().is_some_and(|existing_did| { 2439 2478 existing_did 2440 2479 .parse::<tranquil_types::Did>() 2441 2480 .ok() 2442 - .map_or(false, |parsed| parsed != user.did) 2481 + .is_some_and(|parsed| parsed != user.did) 2443 2482 }); 2444 2483 2445 2484 if let Some(delegated_did) = delegation_from_param { ··· 3601 3640 ); 3602 3641 Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 3603 3642 } 3643 + 3644 + pub async fn establish_session( 3645 + State(state): State<AppState>, 3646 + headers: HeaderMap, 3647 + auth: crate::auth::Auth<crate::auth::Active>, 3648 + ) -> Response { 3649 + let did = &auth.did; 3650 + 3651 + let existing_device = extract_device_cookie(&headers); 3652 + 3653 + let (device_id, new_cookie) = match existing_device { 3654 + Some(id) => { 3655 + let device_typed = DeviceIdType::from(id.clone()); 3656 + let _ = state 3657 + .oauth_repo 3658 + .upsert_account_device(did, &device_typed) 3659 + .await; 3660 + (id, None) 3661 + } 3662 + None => { 3663 + let new_id = DeviceId::generate(); 3664 + let device_data = DeviceData { 3665 + session_id: SessionId::generate().0, 3666 + user_agent: extract_user_agent(&headers), 3667 + ip_address: extract_client_ip(&headers), 3668 + last_seen_at: Utc::now(), 3669 + }; 3670 + let device_typed = DeviceIdType::from(new_id.0.clone()); 3671 + 3672 + if let Err(e) = state 3673 + .oauth_repo 3674 + .create_device(&device_typed, &device_data) 3675 + .await 3676 + { 3677 + tracing::error!(error = ?e, "Failed to create device"); 3678 + return ( 3679 + StatusCode::INTERNAL_SERVER_ERROR, 3680 + Json(serde_json::json!({ 3681 + "error": "server_error", 3682 + "error_description": "Failed to establish session" 3683 + })), 3684 + ) 3685 + .into_response(); 3686 + } 3687 + 3688 + if let Err(e) = state 3689 + .oauth_repo 3690 + .upsert_account_device(did, &device_typed) 3691 + .await 3692 + { 3693 + tracing::error!(error = ?e, "Failed to link device to account"); 3694 + return ( 3695 + StatusCode::INTERNAL_SERVER_ERROR, 3696 + Json(serde_json::json!({ 3697 + "error": "server_error", 3698 + "error_description": "Failed to establish session" 3699 + })), 3700 + ) 3701 + .into_response(); 3702 + } 3703 + 3704 + (new_id.0.clone(), Some(make_device_cookie(&new_id.0))) 3705 + } 3706 + }; 3707 + 3708 + tracing::info!(did = %did, device_id = %device_id, "Device session established"); 3709 + 3710 + match new_cookie { 3711 + Some(cookie) => ( 3712 + StatusCode::OK, 3713 + [(SET_COOKIE, cookie)], 3714 + Json(serde_json::json!({ 3715 + "success": true, 3716 + "device_id": device_id 3717 + })), 3718 + ) 3719 + .into_response(), 3720 + None => Json(serde_json::json!({ 3721 + "success": true, 3722 + "device_id": device_id 3723 + })) 3724 + .into_response(), 3725 + } 3726 + }
+8 -56
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
··· 1 - use crate::auth::{extract_auth_token_from_header, validate_token_with_dpop}; 1 + use crate::auth::{Active, Auth}; 2 2 use crate::delegation::DelegationActionType; 3 3 use crate::state::{AppState, RateLimitKind}; 4 4 use crate::types::PlainPassword; 5 - use crate::util::{build_full_url, extract_client_ip}; 5 + use crate::util::extract_client_ip; 6 6 use axum::{ 7 7 Json, 8 8 extract::State, ··· 463 463 pub async fn delegation_auth_token( 464 464 State(state): State<AppState>, 465 465 headers: HeaderMap, 466 + auth: Auth<Active>, 466 467 Json(form): Json<DelegationTokenAuthSubmit>, 467 468 ) -> Response { 468 - let auth_header = headers.get("authorization").and_then(|v| v.to_str().ok()); 469 - 470 - let extracted = match extract_auth_token_from_header(auth_header) { 471 - Some(e) => e, 472 - None => { 473 - return ( 474 - StatusCode::UNAUTHORIZED, 475 - Json(DelegationAuthResponse { 476 - success: false, 477 - needs_totp: None, 478 - redirect_uri: None, 479 - error: Some("Missing or invalid authorization header".to_string()), 480 - }), 481 - ) 482 - .into_response(); 483 - } 484 - }; 485 - 486 - let dpop_proof = headers.get("dpop").and_then(|h| h.to_str().ok()); 487 - let uri = build_full_url("/oauth/delegation/auth-token"); 488 - 489 - let auth_user = match validate_token_with_dpop( 490 - state.user_repo.as_ref(), 491 - state.oauth_repo.as_ref(), 492 - &extracted.token, 493 - extracted.is_dpop, 494 - dpop_proof, 495 - "POST", 496 - &uri, 497 - false, 498 - false, 499 - ) 500 - .await 501 - { 502 - Ok(user) => user, 503 - Err(_) => { 504 - return ( 505 - StatusCode::UNAUTHORIZED, 506 - Json(DelegationAuthResponse { 507 - success: false, 508 - needs_totp: None, 509 - redirect_uri: None, 510 - error: Some("Invalid or expired access token".to_string()), 511 - }), 512 - ) 513 - .into_response(); 514 - } 515 - }; 516 - 517 - let controller_did = auth_user.did; 469 + let controller_did = &auth.did; 518 470 519 471 let delegated_did: Did = match form.delegated_did.parse() { 520 472 Ok(d) => d, ··· 558 510 559 511 let grant = match state 560 512 .delegation_repo 561 - .get_delegation(&delegated_did, &controller_did) 513 + .get_delegation(&delegated_did, controller_did) 562 514 .await 563 515 { 564 516 Ok(Some(g)) => g, ··· 599 551 600 552 if state 601 553 .oauth_repo 602 - .set_controller_did(&request_id, &controller_did) 554 + .set_controller_did(&request_id, controller_did) 603 555 .await 604 556 .is_err() 605 557 { ··· 622 574 .delegation_repo 623 575 .log_delegation_action( 624 576 &delegated_did, 625 - &controller_did, 626 - Some(&controller_did), 577 + controller_did, 578 + Some(controller_did), 627 579 DelegationActionType::TokenIssued, 628 580 Some(serde_json::json!({ 629 581 "client_id": request.client_id,
+31 -13
crates/tranquil-pds/src/oauth/verify.rs
··· 10 10 use sha2::Sha256; 11 11 use subtle::ConstantTimeEq; 12 12 use tranquil_db_traits::{OAuthRepository, UserRepository}; 13 - use tranquil_types::TokenId; 13 + use tranquil_types::{ClientId, TokenId}; 14 + 15 + use crate::types::Did; 14 16 15 17 use super::scopes::ScopePermissions; 16 18 use super::{DPoPVerifier, OAuthError}; ··· 27 29 } 28 30 29 31 pub struct VerifyResult { 30 - pub did: String, 31 - pub token_id: String, 32 - pub client_id: String, 32 + pub did: Did, 33 + pub token_id: TokenId, 34 + pub client_id: ClientId, 33 35 pub scope: Option<String>, 34 36 } 35 37 ··· 91 93 )); 92 94 } 93 95 } 96 + let did: Did = token_data 97 + .did 98 + .parse() 99 + .map_err(|_| OAuthError::InvalidToken("Invalid DID in token".to_string()))?; 94 100 Ok(VerifyResult { 95 - did: token_data.did, 96 - token_id: token_id.to_string(), 97 - client_id: token_data.client_id, 101 + did, 102 + token_id, 103 + client_id: ClientId::from(token_data.client_id), 98 104 scope: token_data.scope, 99 105 }) 100 106 } ··· 202 208 } 203 209 204 210 pub struct OAuthUser { 205 - pub did: String, 206 - pub client_id: Option<String>, 211 + pub did: Did, 212 + pub client_id: Option<ClientId>, 207 213 pub scope: Option<String>, 208 214 pub is_oauth: bool, 209 215 pub permissions: ScopePermissions, ··· 382 388 } 383 389 384 390 struct LegacyAuthResult { 385 - did: String, 391 + did: Did, 386 392 } 387 393 388 394 async fn try_legacy_auth( ··· 390 396 token: &str, 391 397 ) -> Result<LegacyAuthResult, ()> { 392 398 match crate::auth::validate_bearer_token(user_repo, token).await { 393 - Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { 394 - did: user.did.to_string(), 395 - }), 399 + Ok(user) if !user.is_oauth() => Ok(LegacyAuthResult { did: user.did }), 396 400 _ => Err(()), 397 401 } 398 402 } 403 + 404 + pub async fn dpop_nonce_middleware( 405 + req: axum::http::Request<axum::body::Body>, 406 + next: axum::middleware::Next, 407 + ) -> Response { 408 + let mut response = next.run(req).await; 409 + let config = AuthConfig::get(); 410 + let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 411 + let nonce = verifier.generate_nonce(); 412 + if let Ok(nonce_val) = nonce.parse() { 413 + response.headers_mut().insert("DPoP-Nonce", nonce_val); 414 + } 415 + response 416 + }
+2 -2
crates/tranquil-pds/src/sso/endpoints.rs
··· 644 644 645 645 pub async fn get_linked_accounts( 646 646 State(state): State<AppState>, 647 - crate::auth::extractor::BearerAuth(auth): crate::auth::extractor::BearerAuth, 647 + auth: crate::auth::Auth<crate::auth::Active>, 648 648 ) -> Result<Json<LinkedAccountsResponse>, ApiError> { 649 649 let identities = state 650 650 .sso_repo ··· 679 679 680 680 pub async fn unlink_account( 681 681 State(state): State<AppState>, 682 - crate::auth::extractor::BearerAuth(auth): crate::auth::extractor::BearerAuth, 682 + auth: crate::auth::Auth<crate::auth::Active>, 683 683 Json(input): Json<UnlinkAccountRequest>, 684 684 ) -> Result<Json<UnlinkAccountResponse>, ApiError> { 685 685 if !state
+102
crates/tranquil-pds/tests/actor.rs
··· 436 436 assert_eq!(declared_age["isOverAge16"], false); 437 437 assert_eq!(declared_age["isOverAge18"], false); 438 438 } 439 + 440 + #[tokio::test] 441 + async fn test_deactivated_account_can_get_preferences() { 442 + let client = client(); 443 + let base = base_url().await; 444 + let (token, _did) = create_account_and_login(&client).await; 445 + 446 + let prefs = json!({ 447 + "preferences": [ 448 + { 449 + "$type": "app.bsky.actor.defs#adultContentPref", 450 + "enabled": true 451 + } 452 + ] 453 + }); 454 + let put_resp = client 455 + .post(format!("{}/xrpc/app.bsky.actor.putPreferences", base)) 456 + .header("Authorization", format!("Bearer {}", token)) 457 + .json(&prefs) 458 + .send() 459 + .await 460 + .unwrap(); 461 + assert_eq!(put_resp.status(), 200); 462 + 463 + let deactivate = client 464 + .post(format!( 465 + "{}/xrpc/com.atproto.server.deactivateAccount", 466 + base 467 + )) 468 + .header("Authorization", format!("Bearer {}", token)) 469 + .json(&json!({})) 470 + .send() 471 + .await 472 + .unwrap(); 473 + assert_eq!(deactivate.status(), 200); 474 + 475 + let get_resp = client 476 + .get(format!("{}/xrpc/app.bsky.actor.getPreferences", base)) 477 + .header("Authorization", format!("Bearer {}", token)) 478 + .send() 479 + .await 480 + .unwrap(); 481 + assert_eq!( 482 + get_resp.status(), 483 + 200, 484 + "Deactivated account should still be able to get preferences" 485 + ); 486 + let body: Value = get_resp.json().await.unwrap(); 487 + let prefs_arr = body["preferences"].as_array().unwrap(); 488 + assert_eq!(prefs_arr.len(), 1); 489 + } 490 + 491 + #[tokio::test] 492 + async fn test_deactivated_account_can_put_preferences() { 493 + let client = client(); 494 + let base = base_url().await; 495 + let (token, _did) = create_account_and_login(&client).await; 496 + 497 + let deactivate = client 498 + .post(format!( 499 + "{}/xrpc/com.atproto.server.deactivateAccount", 500 + base 501 + )) 502 + .header("Authorization", format!("Bearer {}", token)) 503 + .json(&json!({})) 504 + .send() 505 + .await 506 + .unwrap(); 507 + assert_eq!(deactivate.status(), 200); 508 + 509 + let prefs = json!({ 510 + "preferences": [ 511 + { 512 + "$type": "app.bsky.actor.defs#adultContentPref", 513 + "enabled": true 514 + } 515 + ] 516 + }); 517 + let put_resp = client 518 + .post(format!("{}/xrpc/app.bsky.actor.putPreferences", base)) 519 + .header("Authorization", format!("Bearer {}", token)) 520 + .json(&prefs) 521 + .send() 522 + .await 523 + .unwrap(); 524 + assert_eq!( 525 + put_resp.status(), 526 + 200, 527 + "Deactivated account should still be able to put preferences" 528 + ); 529 + 530 + let get_resp = client 531 + .get(format!("{}/xrpc/app.bsky.actor.getPreferences", base)) 532 + .header("Authorization", format!("Bearer {}", token)) 533 + .send() 534 + .await 535 + .unwrap(); 536 + assert_eq!(get_resp.status(), 200); 537 + let body: Value = get_resp.json().await.unwrap(); 538 + let prefs_arr = body["preferences"].as_array().unwrap(); 539 + assert_eq!(prefs_arr.len(), 1); 540 + }
+648
crates/tranquil-pds/tests/auth_extractor.rs
··· 1 + mod common; 2 + mod helpers; 3 + 4 + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; 5 + use chrono::Utc; 6 + use common::{base_url, client, create_account_and_login, pds_endpoint}; 7 + use helpers::verify_new_account; 8 + use reqwest::StatusCode; 9 + use serde_json::{Value, json}; 10 + use sha2::{Digest, Sha256}; 11 + use wiremock::matchers::{method, path}; 12 + use wiremock::{Mock, MockServer, ResponseTemplate}; 13 + 14 + fn generate_pkce() -> (String, String) { 15 + let verifier_bytes: [u8; 32] = rand::random(); 16 + let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); 17 + let mut hasher = Sha256::new(); 18 + hasher.update(code_verifier.as_bytes()); 19 + let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize()); 20 + (code_verifier, code_challenge) 21 + } 22 + 23 + async fn setup_mock_client_metadata(redirect_uri: &str, dpop_bound: bool) -> MockServer { 24 + let mock_server = MockServer::start().await; 25 + let metadata = json!({ 26 + "client_id": mock_server.uri(), 27 + "client_name": "Auth Extractor Test Client", 28 + "redirect_uris": [redirect_uri], 29 + "grant_types": ["authorization_code", "refresh_token"], 30 + "response_types": ["code"], 31 + "token_endpoint_auth_method": "none", 32 + "dpop_bound_access_tokens": dpop_bound 33 + }); 34 + Mock::given(method("GET")) 35 + .and(path("/")) 36 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 37 + .mount(&mock_server) 38 + .await; 39 + mock_server 40 + } 41 + 42 + async fn get_oauth_session( 43 + http_client: &reqwest::Client, 44 + url: &str, 45 + dpop_bound: bool, 46 + ) -> (String, String, String, String) { 47 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 48 + let handle = format!("ae{}", suffix); 49 + let password = "AuthExtract123!"; 50 + let create_res = http_client 51 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 52 + .json(&json!({ 53 + "handle": handle, 54 + "email": format!("{}@example.com", handle), 55 + "password": password 56 + })) 57 + .send() 58 + .await 59 + .unwrap(); 60 + assert_eq!(create_res.status(), StatusCode::OK); 61 + let account: Value = create_res.json().await.unwrap(); 62 + let did = account["did"].as_str().unwrap().to_string(); 63 + verify_new_account(http_client, &did).await; 64 + 65 + let redirect_uri = "https://example.com/auth-callback"; 66 + let mock_client = setup_mock_client_metadata(redirect_uri, dpop_bound).await; 67 + let client_id = mock_client.uri(); 68 + let (code_verifier, code_challenge) = generate_pkce(); 69 + 70 + let par_body: Value = http_client 71 + .post(format!("{}/oauth/par", url)) 72 + .form(&[ 73 + ("response_type", "code"), 74 + ("client_id", &client_id), 75 + ("redirect_uri", redirect_uri), 76 + ("code_challenge", &code_challenge), 77 + ("code_challenge_method", "S256"), 78 + ]) 79 + .send() 80 + .await 81 + .unwrap() 82 + .json() 83 + .await 84 + .unwrap(); 85 + let request_uri = par_body["request_uri"].as_str().unwrap(); 86 + 87 + let auth_res = http_client 88 + .post(format!("{}/oauth/authorize", url)) 89 + .header("Content-Type", "application/json") 90 + .header("Accept", "application/json") 91 + .json(&json!({ 92 + "request_uri": request_uri, 93 + "username": &handle, 94 + "password": password, 95 + "remember_device": false 96 + })) 97 + .send() 98 + .await 99 + .unwrap(); 100 + let auth_body: Value = auth_res.json().await.unwrap(); 101 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 102 + 103 + if location.contains("/oauth/consent") { 104 + let consent_res = http_client 105 + .post(format!("{}/oauth/authorize/consent", url)) 106 + .header("Content-Type", "application/json") 107 + .json(&json!({ 108 + "request_uri": request_uri, 109 + "approved_scopes": ["atproto"], 110 + "remember": false 111 + })) 112 + .send() 113 + .await 114 + .unwrap(); 115 + let consent_body: Value = consent_res.json().await.unwrap(); 116 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 117 + } 118 + 119 + let code = location 120 + .split("code=") 121 + .nth(1) 122 + .unwrap() 123 + .split('&') 124 + .next() 125 + .unwrap(); 126 + 127 + let token_body: Value = http_client 128 + .post(format!("{}/oauth/token", url)) 129 + .form(&[ 130 + ("grant_type", "authorization_code"), 131 + ("code", code), 132 + ("redirect_uri", redirect_uri), 133 + ("code_verifier", &code_verifier), 134 + ("client_id", &client_id), 135 + ]) 136 + .send() 137 + .await 138 + .unwrap() 139 + .json() 140 + .await 141 + .unwrap(); 142 + 143 + ( 144 + token_body["access_token"].as_str().unwrap().to_string(), 145 + token_body["refresh_token"].as_str().unwrap().to_string(), 146 + client_id, 147 + did, 148 + ) 149 + } 150 + 151 + #[tokio::test] 152 + async fn test_oauth_token_works_with_bearer_auth() { 153 + let url = base_url().await; 154 + let http_client = client(); 155 + let (access_token, _, _, did) = get_oauth_session(&http_client, url, false).await; 156 + 157 + let res = http_client 158 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 159 + .bearer_auth(&access_token) 160 + .send() 161 + .await 162 + .unwrap(); 163 + 164 + assert_eq!( 165 + res.status(), 166 + StatusCode::OK, 167 + "OAuth token should work with BearerAuth extractor" 168 + ); 169 + let body: Value = res.json().await.unwrap(); 170 + assert_eq!(body["did"].as_str().unwrap(), did); 171 + } 172 + 173 + #[tokio::test] 174 + async fn test_session_token_still_works() { 175 + let url = base_url().await; 176 + let http_client = client(); 177 + let (jwt, did) = create_account_and_login(&http_client).await; 178 + 179 + let res = http_client 180 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 181 + .bearer_auth(&jwt) 182 + .send() 183 + .await 184 + .unwrap(); 185 + 186 + assert_eq!( 187 + res.status(), 188 + StatusCode::OK, 189 + "Session token should still work" 190 + ); 191 + let body: Value = res.json().await.unwrap(); 192 + assert_eq!(body["did"].as_str().unwrap(), did); 193 + } 194 + 195 + #[tokio::test] 196 + async fn test_oauth_admin_extractor_allows_oauth_tokens() { 197 + let url = base_url().await; 198 + let http_client = client(); 199 + 200 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 201 + let handle = format!("adm{}", suffix); 202 + let password = "AdminOAuth123!"; 203 + let create_res = http_client 204 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 205 + .json(&json!({ 206 + "handle": handle, 207 + "email": format!("{}@example.com", handle), 208 + "password": password 209 + })) 210 + .send() 211 + .await 212 + .unwrap(); 213 + assert_eq!(create_res.status(), StatusCode::OK); 214 + let account: Value = create_res.json().await.unwrap(); 215 + let did = account["did"].as_str().unwrap().to_string(); 216 + verify_new_account(&http_client, &did).await; 217 + 218 + let pool = common::get_test_db_pool().await; 219 + sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did) 220 + .execute(pool) 221 + .await 222 + .expect("Failed to mark user as admin"); 223 + 224 + let redirect_uri = "https://example.com/admin-callback"; 225 + let mock_client = setup_mock_client_metadata(redirect_uri, false).await; 226 + let client_id = mock_client.uri(); 227 + let (code_verifier, code_challenge) = generate_pkce(); 228 + 229 + let par_body: Value = http_client 230 + .post(format!("{}/oauth/par", url)) 231 + .form(&[ 232 + ("response_type", "code"), 233 + ("client_id", &client_id), 234 + ("redirect_uri", redirect_uri), 235 + ("code_challenge", &code_challenge), 236 + ("code_challenge_method", "S256"), 237 + ]) 238 + .send() 239 + .await 240 + .unwrap() 241 + .json() 242 + .await 243 + .unwrap(); 244 + let request_uri = par_body["request_uri"].as_str().unwrap(); 245 + 246 + let auth_res = http_client 247 + .post(format!("{}/oauth/authorize", url)) 248 + .header("Content-Type", "application/json") 249 + .header("Accept", "application/json") 250 + .json(&json!({ 251 + "request_uri": request_uri, 252 + "username": &handle, 253 + "password": password, 254 + "remember_device": false 255 + })) 256 + .send() 257 + .await 258 + .unwrap(); 259 + let auth_body: Value = auth_res.json().await.unwrap(); 260 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 261 + if location.contains("/oauth/consent") { 262 + let consent_res = http_client 263 + .post(format!("{}/oauth/authorize/consent", url)) 264 + .header("Content-Type", "application/json") 265 + .json(&json!({ 266 + "request_uri": request_uri, 267 + "approved_scopes": ["atproto"], 268 + "remember": false 269 + })) 270 + .send() 271 + .await 272 + .unwrap(); 273 + let consent_body: Value = consent_res.json().await.unwrap(); 274 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 275 + } 276 + 277 + let code = location 278 + .split("code=") 279 + .nth(1) 280 + .unwrap() 281 + .split('&') 282 + .next() 283 + .unwrap(); 284 + let token_body: Value = http_client 285 + .post(format!("{}/oauth/token", url)) 286 + .form(&[ 287 + ("grant_type", "authorization_code"), 288 + ("code", code), 289 + ("redirect_uri", redirect_uri), 290 + ("code_verifier", &code_verifier), 291 + ("client_id", &client_id), 292 + ]) 293 + .send() 294 + .await 295 + .unwrap() 296 + .json() 297 + .await 298 + .unwrap(); 299 + let access_token = token_body["access_token"].as_str().unwrap(); 300 + 301 + let res = http_client 302 + .get(format!( 303 + "{}/xrpc/com.atproto.admin.getAccountInfos?dids={}", 304 + url, did 305 + )) 306 + .bearer_auth(access_token) 307 + .send() 308 + .await 309 + .unwrap(); 310 + 311 + assert_eq!( 312 + res.status(), 313 + StatusCode::OK, 314 + "OAuth token for admin user should work with admin endpoint" 315 + ); 316 + } 317 + 318 + #[tokio::test] 319 + async fn test_expired_oauth_token_returns_proper_error() { 320 + let url = base_url().await; 321 + let http_client = client(); 322 + 323 + let now = Utc::now().timestamp(); 324 + let header = json!({"alg": "HS256", "typ": "at+jwt"}); 325 + let payload = json!({ 326 + "iss": url, 327 + "sub": "did:plc:test123", 328 + "aud": url, 329 + "iat": now - 7200, 330 + "exp": now - 3600, 331 + "jti": "expired-token", 332 + "sid": "expired-session", 333 + "scope": "atproto", 334 + "client_id": "https://example.com" 335 + }); 336 + let fake_token = format!( 337 + "{}.{}.{}", 338 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()), 339 + URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()), 340 + URL_SAFE_NO_PAD.encode([1u8; 32]) 341 + ); 342 + 343 + let res = http_client 344 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 345 + .bearer_auth(&fake_token) 346 + .send() 347 + .await 348 + .unwrap(); 349 + 350 + assert_eq!( 351 + res.status(), 352 + StatusCode::UNAUTHORIZED, 353 + "Expired token should be rejected" 354 + ); 355 + } 356 + 357 + #[tokio::test] 358 + async fn test_dpop_nonce_error_has_proper_headers() { 359 + let url = base_url().await; 360 + let pds_url = pds_endpoint(); 361 + let http_client = client(); 362 + 363 + let suffix = &uuid::Uuid::new_v4().simple().to_string()[..8]; 364 + let handle = format!("dpop{}", suffix); 365 + let create_res = http_client 366 + .post(format!("{}/xrpc/com.atproto.server.createAccount", url)) 367 + .json(&json!({ 368 + "handle": handle, 369 + "email": format!("{}@test.com", handle), 370 + "password": "DpopTest123!" 371 + })) 372 + .send() 373 + .await 374 + .unwrap(); 375 + assert_eq!(create_res.status(), StatusCode::OK); 376 + let account: Value = create_res.json().await.unwrap(); 377 + let did = account["did"].as_str().unwrap(); 378 + verify_new_account(&http_client, did).await; 379 + 380 + let redirect_uri = "https://example.com/dpop-callback"; 381 + let mock_server = MockServer::start().await; 382 + let client_id = mock_server.uri(); 383 + let metadata = json!({ 384 + "client_id": &client_id, 385 + "client_name": "DPoP Test Client", 386 + "redirect_uris": [redirect_uri], 387 + "grant_types": ["authorization_code", "refresh_token"], 388 + "response_types": ["code"], 389 + "token_endpoint_auth_method": "none", 390 + "dpop_bound_access_tokens": true 391 + }); 392 + Mock::given(method("GET")) 393 + .and(path("/")) 394 + .respond_with(ResponseTemplate::new(200).set_body_json(metadata)) 395 + .mount(&mock_server) 396 + .await; 397 + 398 + let (code_verifier, code_challenge) = generate_pkce(); 399 + let par_body: Value = http_client 400 + .post(format!("{}/oauth/par", url)) 401 + .form(&[ 402 + ("response_type", "code"), 403 + ("client_id", &client_id), 404 + ("redirect_uri", redirect_uri), 405 + ("code_challenge", &code_challenge), 406 + ("code_challenge_method", "S256"), 407 + ]) 408 + .send() 409 + .await 410 + .unwrap() 411 + .json() 412 + .await 413 + .unwrap(); 414 + 415 + let request_uri = par_body["request_uri"].as_str().unwrap(); 416 + let auth_res = http_client 417 + .post(format!("{}/oauth/authorize", url)) 418 + .header("Content-Type", "application/json") 419 + .header("Accept", "application/json") 420 + .json(&json!({ 421 + "request_uri": request_uri, 422 + "username": &handle, 423 + "password": "DpopTest123!", 424 + "remember_device": false 425 + })) 426 + .send() 427 + .await 428 + .unwrap(); 429 + let auth_body: Value = auth_res.json().await.unwrap(); 430 + let mut location = auth_body["redirect_uri"].as_str().unwrap().to_string(); 431 + if location.contains("/oauth/consent") { 432 + let consent_res = http_client 433 + .post(format!("{}/oauth/authorize/consent", url)) 434 + .header("Content-Type", "application/json") 435 + .json(&json!({ 436 + "request_uri": request_uri, 437 + "approved_scopes": ["atproto"], 438 + "remember": false 439 + })) 440 + .send() 441 + .await 442 + .unwrap(); 443 + let consent_body: Value = consent_res.json().await.unwrap(); 444 + location = consent_body["redirect_uri"].as_str().unwrap().to_string(); 445 + } 446 + 447 + let code = location 448 + .split("code=") 449 + .nth(1) 450 + .unwrap() 451 + .split('&') 452 + .next() 453 + .unwrap(); 454 + 455 + let token_endpoint = format!("{}/oauth/token", pds_url); 456 + let (_, dpop_proof) = generate_dpop_proof("POST", &token_endpoint, None); 457 + 458 + let token_res = http_client 459 + .post(format!("{}/oauth/token", url)) 460 + .header("DPoP", &dpop_proof) 461 + .form(&[ 462 + ("grant_type", "authorization_code"), 463 + ("code", code), 464 + ("redirect_uri", redirect_uri), 465 + ("code_verifier", &code_verifier), 466 + ("client_id", &client_id), 467 + ]) 468 + .send() 469 + .await 470 + .unwrap(); 471 + 472 + let token_status = token_res.status(); 473 + let token_nonce = token_res 474 + .headers() 475 + .get("dpop-nonce") 476 + .map(|h| h.to_str().unwrap().to_string()); 477 + let token_body: Value = token_res.json().await.unwrap(); 478 + 479 + let access_token = if token_status == StatusCode::OK { 480 + token_body["access_token"].as_str().unwrap().to_string() 481 + } else if token_body.get("error").and_then(|e| e.as_str()) == Some("use_dpop_nonce") { 482 + let nonce = 483 + token_nonce.expect("Token endpoint should return DPoP-Nonce on use_dpop_nonce error"); 484 + let (_, dpop_proof_with_nonce) = generate_dpop_proof("POST", &token_endpoint, Some(&nonce)); 485 + 486 + let retry_res = http_client 487 + .post(format!("{}/oauth/token", url)) 488 + .header("DPoP", &dpop_proof_with_nonce) 489 + .form(&[ 490 + ("grant_type", "authorization_code"), 491 + ("code", code), 492 + ("redirect_uri", redirect_uri), 493 + ("code_verifier", &code_verifier), 494 + ("client_id", &client_id), 495 + ]) 496 + .send() 497 + .await 498 + .unwrap(); 499 + let retry_body: Value = retry_res.json().await.unwrap(); 500 + retry_body["access_token"] 501 + .as_str() 502 + .expect("Should get access_token after nonce retry") 503 + .to_string() 504 + } else { 505 + panic!("Token exchange failed unexpectedly: {:?}", token_body); 506 + }; 507 + 508 + let res = http_client 509 + .get(format!("{}/xrpc/com.atproto.server.getSession", url)) 510 + .header("Authorization", format!("DPoP {}", access_token)) 511 + .send() 512 + .await 513 + .unwrap(); 514 + 515 + assert_eq!( 516 + res.status(), 517 + StatusCode::UNAUTHORIZED, 518 + "DPoP token without proof should fail" 519 + ); 520 + 521 + let www_auth = res 522 + .headers() 523 + .get("www-authenticate") 524 + .map(|h| h.to_str().unwrap()); 525 + assert!(www_auth.is_some(), "Should have WWW-Authenticate header"); 526 + assert!( 527 + www_auth.unwrap().contains("use_dpop_nonce"), 528 + "WWW-Authenticate should indicate dpop nonce required" 529 + ); 530 + 531 + let nonce = res.headers().get("dpop-nonce").map(|h| h.to_str().unwrap()); 532 + assert!(nonce.is_some(), "Should return DPoP-Nonce header"); 533 + 534 + let body: Value = res.json().await.unwrap(); 535 + assert_eq!(body["error"].as_str().unwrap(), "use_dpop_nonce"); 536 + } 537 + 538 + fn generate_dpop_proof(method: &str, uri: &str, nonce: Option<&str>) -> (Value, String) { 539 + use p256::ecdsa::{SigningKey, signature::Signer}; 540 + use p256::elliptic_curve::rand_core::OsRng; 541 + 542 + let signing_key = SigningKey::random(&mut OsRng); 543 + let verifying_key = signing_key.verifying_key(); 544 + let point = verifying_key.to_encoded_point(false); 545 + let x = URL_SAFE_NO_PAD.encode(point.x().unwrap()); 546 + let y = URL_SAFE_NO_PAD.encode(point.y().unwrap()); 547 + 548 + let jwk = json!({ 549 + "kty": "EC", 550 + "crv": "P-256", 551 + "x": x, 552 + "y": y 553 + }); 554 + 555 + let header = { 556 + let h = json!({ 557 + "typ": "dpop+jwt", 558 + "alg": "ES256", 559 + "jwk": jwk.clone() 560 + }); 561 + h 562 + }; 563 + 564 + let mut payload = json!({ 565 + "jti": uuid::Uuid::new_v4().to_string(), 566 + "htm": method, 567 + "htu": uri, 568 + "iat": Utc::now().timestamp() 569 + }); 570 + if let Some(n) = nonce { 571 + payload["nonce"] = json!(n); 572 + } 573 + 574 + let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); 575 + let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload).unwrap()); 576 + let signing_input = format!("{}.{}", header_b64, payload_b64); 577 + 578 + let signature: p256::ecdsa::Signature = signing_key.sign(signing_input.as_bytes()); 579 + let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes()); 580 + 581 + let proof = format!("{}.{}", signing_input, sig_b64); 582 + (jwk, proof) 583 + } 584 + 585 + #[tokio::test] 586 + async fn test_optional_service_auth_extractor_behavior() { 587 + let url = base_url().await; 588 + let http_client = client(); 589 + let (access_jwt, did) = create_account_and_login(&http_client).await; 590 + 591 + let service_auth_res = http_client 592 + .get(format!("{}/xrpc/com.atproto.server.getServiceAuth", url)) 593 + .bearer_auth(&access_jwt) 594 + .query(&[("aud", "did:web:test.example")]) 595 + .send() 596 + .await 597 + .unwrap(); 598 + assert_eq!(service_auth_res.status(), StatusCode::OK); 599 + let service_body: Value = service_auth_res.json().await.unwrap(); 600 + let service_token = service_body["token"].as_str().unwrap(); 601 + 602 + let no_auth_res = http_client 603 + .get(format!( 604 + "{}/xrpc/com.atproto.sync.getBlob?did={}&cid=bafyreifakecidfornowfakecidfornow1234567", 605 + url, did 606 + )) 607 + .send() 608 + .await 609 + .unwrap(); 610 + assert!( 611 + no_auth_res.status() == StatusCode::NOT_FOUND 612 + || no_auth_res.status() == StatusCode::BAD_REQUEST, 613 + "getBlob with no auth should reach handler (AuthAny optional path) - got {}", 614 + no_auth_res.status() 615 + ); 616 + 617 + let service_auth_blob_res = http_client 618 + .get(format!( 619 + "{}/xrpc/com.atproto.sync.getBlob?did={}&cid=bafyreifakecidfornowfakecidfornow1234567", 620 + url, did 621 + )) 622 + .bearer_auth(service_token) 623 + .send() 624 + .await 625 + .unwrap(); 626 + assert!( 627 + service_auth_blob_res.status() == StatusCode::NOT_FOUND 628 + || service_auth_blob_res.status() == StatusCode::BAD_REQUEST, 629 + "getBlob with service auth should reach handler (AuthAny service path) - got {}", 630 + service_auth_blob_res.status() 631 + ); 632 + 633 + let user_auth_blob_res = http_client 634 + .get(format!( 635 + "{}/xrpc/com.atproto.sync.getBlob?did={}&cid=bafyreifakecidfornowfakecidfornow1234567", 636 + url, did 637 + )) 638 + .bearer_auth(&access_jwt) 639 + .send() 640 + .await 641 + .unwrap(); 642 + assert!( 643 + user_auth_blob_res.status() == StatusCode::NOT_FOUND 644 + || user_auth_blob_res.status() == StatusCode::BAD_REQUEST, 645 + "getBlob with user auth should reach handler (AuthAny user path) - got {}", 646 + user_auth_blob_res.status() 647 + ); 648 + }
+3 -3
crates/tranquil-pds/tests/common/mod.rs
··· 1 - #[cfg(feature = "s3-storage")] 1 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 2 2 use aws_config::BehaviorVersion; 3 - #[cfg(feature = "s3-storage")] 3 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 4 4 use aws_sdk_s3::Client as S3Client; 5 - #[cfg(feature = "s3-storage")] 5 + #[cfg(all(not(feature = "external-infra"), feature = "s3-storage"))] 6 6 use aws_sdk_s3::config::Credentials; 7 7 use chrono::Utc; 8 8 use reqwest::{Client, StatusCode, header};
+8 -2
crates/tranquil-pds/tests/oauth_security.rs
··· 1373 1373 .send() 1374 1374 .await 1375 1375 .unwrap(); 1376 - assert_eq!(token_res.status(), StatusCode::OK, "Token exchange should succeed"); 1376 + assert_eq!( 1377 + token_res.status(), 1378 + StatusCode::OK, 1379 + "Token exchange should succeed" 1380 + ); 1377 1381 let tokens: Value = token_res.json().await.unwrap(); 1378 1382 1379 - let sub = tokens["sub"].as_str().expect("Token response should have sub claim"); 1383 + let sub = tokens["sub"] 1384 + .as_str() 1385 + .expect("Token response should have sub claim"); 1380 1386 1381 1387 assert_eq!( 1382 1388 sub, delegated_did,
+2
crates/tranquil-scopes/Cargo.toml
··· 7 7 [dependencies] 8 8 axum = { workspace = true } 9 9 futures = { workspace = true } 10 + hickory-resolver = { version = "0.24", features = ["tokio-runtime"] } 10 11 reqwest = { workspace = true } 11 12 serde = { workspace = true } 12 13 serde_json = { workspace = true } 13 14 tokio = { workspace = true } 14 15 tracing = { workspace = true } 16 + urlencoding = "2"
+521 -59
crates/tranquil-scopes/src/permission_set.rs
··· 1 + use hickory_resolver::TokioAsyncResolver; 1 2 use reqwest::Client; 2 3 use serde::Deserialize; 3 4 use std::collections::HashMap; ··· 17 18 const CACHE_TTL_SECS: u64 = 3600; 18 19 19 20 #[derive(Debug, Deserialize)] 21 + struct PlcDocument { 22 + service: Vec<PlcService>, 23 + } 24 + 25 + #[derive(Debug, Deserialize)] 26 + struct PlcService { 27 + id: String, 28 + #[serde(rename = "serviceEndpoint")] 29 + service_endpoint: String, 30 + } 31 + 32 + #[derive(Debug, Deserialize)] 33 + struct GetRecordResponse { 34 + value: LexiconDoc, 35 + } 36 + 37 + #[derive(Debug, Deserialize)] 20 38 struct LexiconDoc { 21 39 defs: HashMap<String, LexiconDef>, 22 40 } ··· 31 49 #[derive(Debug, Deserialize)] 32 50 struct PermissionEntry { 33 51 resource: String, 52 + action: Option<Vec<String>>, 34 53 collection: Option<Vec<String>>, 54 + lxm: Option<Vec<String>>, 55 + aud: Option<String>, 35 56 } 36 57 37 58 pub async fn expand_include_scopes(scope_string: &str) -> String { ··· 39 60 .split_whitespace() 40 61 .map(|scope| async move { 41 62 match scope.strip_prefix("include:") { 42 - Some(nsid) => { 43 - let nsid_base = nsid.split('?').next().unwrap_or(nsid); 44 - expand_permission_set(nsid_base).await.unwrap_or_else(|e| { 45 - warn!(nsid = nsid_base, error = %e, "Failed to expand permission set, keeping original"); 46 - scope.to_string() 47 - }) 63 + Some(rest) => { 64 + let (nsid_base, aud) = parse_include_scope(rest); 65 + expand_permission_set(nsid_base, aud) 66 + .await 67 + .unwrap_or_else(|e| { 68 + warn!(nsid = nsid_base, error = %e, "Failed to expand permission set, keeping original"); 69 + scope.to_string() 70 + }) 48 71 } 49 72 None => scope.to_string(), 50 73 } ··· 54 77 futures::future::join_all(futures).await.join(" ") 55 78 } 56 79 57 - async fn expand_permission_set(nsid: &str) -> Result<String, String> { 80 + fn parse_include_scope(rest: &str) -> (&str, Option<&str>) { 81 + rest.split_once('?') 82 + .map(|(nsid, params)| { 83 + let aud = params.split('&').find_map(|p| p.strip_prefix("aud=")); 84 + (nsid, aud) 85 + }) 86 + .unwrap_or((rest, None)) 87 + } 88 + 89 + async fn expand_permission_set(nsid: &str, aud: Option<&str>) -> Result<String, String> { 90 + let cache_key = match aud { 91 + Some(a) => format!("{}?aud={}", nsid, a), 92 + None => nsid.to_string(), 93 + }; 94 + 58 95 { 59 96 let cache = LEXICON_CACHE.read().await; 60 - if let Some(cached) = cache.get(nsid) 97 + if let Some(cached) = cache.get(&cache_key) 61 98 && cached.cached_at.elapsed().as_secs() < CACHE_TTL_SECS 62 99 { 63 100 debug!(nsid, "Using cached permission set expansion"); ··· 65 102 } 66 103 } 67 104 105 + let lexicon = fetch_lexicon_via_atproto(nsid).await?; 106 + 107 + let main_def = lexicon 108 + .defs 109 + .get("main") 110 + .ok_or("Missing 'main' definition in lexicon")?; 111 + 112 + if main_def.def_type != "permission-set" { 113 + return Err(format!( 114 + "Expected permission-set type, got: {}", 115 + main_def.def_type 116 + )); 117 + } 118 + 119 + let permissions = main_def 120 + .permissions 121 + .as_ref() 122 + .ok_or("Missing permissions in permission-set")?; 123 + 124 + let namespace_authority = extract_namespace_authority(nsid); 125 + let expanded = build_expanded_scopes(permissions, aud, &namespace_authority); 126 + 127 + if expanded.is_empty() { 128 + return Err("No valid permissions found in permission-set".to_string()); 129 + } 130 + 131 + { 132 + let mut cache = LEXICON_CACHE.write().await; 133 + cache.insert( 134 + cache_key, 135 + CachedLexicon { 136 + expanded_scope: expanded.clone(), 137 + cached_at: std::time::Instant::now(), 138 + }, 139 + ); 140 + } 141 + 142 + debug!(nsid, expanded = %expanded, "Successfully expanded permission set"); 143 + Ok(expanded) 144 + } 145 + 146 + async fn fetch_lexicon_via_atproto(nsid: &str) -> Result<LexiconDoc, String> { 68 147 let parts: Vec<&str> = nsid.split('.').collect(); 69 148 if parts.len() < 3 { 70 149 return Err(format!("Invalid NSID format: {}", nsid)); 71 150 } 72 151 73 - let domain_parts: Vec<&str> = parts[..2].iter().rev().cloned().collect(); 74 - let domain = domain_parts.join("."); 75 - let path = parts[2..].join("/"); 152 + let authority = parts[..2] 153 + .iter() 154 + .rev() 155 + .cloned() 156 + .collect::<Vec<_>>() 157 + .join("."); 158 + debug!(nsid, authority = %authority, "Resolving lexicon DID authority via DNS"); 76 159 77 - let url = format!("https://{}/lexicons/{}.json", domain, path); 78 - debug!(nsid, url = %url, "Fetching permission set lexicon"); 160 + let did = resolve_lexicon_did_authority(&authority).await?; 161 + debug!(nsid, did = %did, "Resolved lexicon DID authority"); 162 + 163 + let pds_endpoint = resolve_did_to_pds(&did).await?; 164 + debug!(nsid, pds = %pds_endpoint, "Resolved DID to PDS endpoint"); 79 165 80 166 let client = Client::builder() 81 167 .timeout(std::time::Duration::from_secs(10)) 82 168 .build() 83 169 .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 84 170 171 + let url = format!( 172 + "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection=com.atproto.lexicon.schema&rkey={}", 173 + pds_endpoint, 174 + urlencoding::encode(&did), 175 + urlencoding::encode(nsid) 176 + ); 177 + debug!(nsid, url = %url, "Fetching lexicon from PDS"); 178 + 85 179 let response = client 86 180 .get(&url) 87 181 .header("Accept", "application/json") ··· 96 190 )); 97 191 } 98 192 99 - let lexicon: LexiconDoc = response 193 + let record: GetRecordResponse = response 100 194 .json() 101 195 .await 102 - .map_err(|e| format!("Failed to parse lexicon: {}", e))?; 196 + .map_err(|e| format!("Failed to parse lexicon response: {}", e))?; 103 197 104 - let main_def = lexicon 105 - .defs 106 - .get("main") 107 - .ok_or("Missing 'main' definition in lexicon")?; 198 + Ok(record.value) 199 + } 108 200 109 - if main_def.def_type != "permission-set" { 110 - return Err(format!( 111 - "Expected permission-set type, got: {}", 112 - main_def.def_type 113 - )); 201 + async fn resolve_lexicon_did_authority(authority: &str) -> Result<String, String> { 202 + let resolver = TokioAsyncResolver::tokio_from_system_conf() 203 + .map_err(|e| format!("Failed to create DNS resolver: {}", e))?; 204 + 205 + let dns_name = format!("_lexicon.{}", authority); 206 + debug!(dns_name = %dns_name, "Looking up DNS TXT record"); 207 + 208 + let txt_records = resolver 209 + .txt_lookup(&dns_name) 210 + .await 211 + .map_err(|e| format!("DNS lookup failed for {}: {}", dns_name, e))?; 212 + 213 + txt_records 214 + .iter() 215 + .flat_map(|record| record.iter()) 216 + .find_map(|data| { 217 + let txt = String::from_utf8_lossy(data); 218 + txt.strip_prefix("did=").map(|did| did.to_string()) 219 + }) 220 + .ok_or_else(|| format!("No valid did= TXT record found at {}", dns_name)) 221 + } 222 + 223 + async fn resolve_did_to_pds(did: &str) -> Result<String, String> { 224 + let client = Client::builder() 225 + .timeout(std::time::Duration::from_secs(10)) 226 + .build() 227 + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; 228 + 229 + let url = if did.starts_with("did:plc:") { 230 + format!("https://plc.directory/{}", did) 231 + } else if did.starts_with("did:web:") { 232 + let domain = did.strip_prefix("did:web:").unwrap(); 233 + format!("https://{}/.well-known/did.json", domain) 234 + } else { 235 + return Err(format!("Unsupported DID method: {}", did)); 236 + }; 237 + 238 + let response = client 239 + .get(&url) 240 + .header("Accept", "application/json") 241 + .send() 242 + .await 243 + .map_err(|e| format!("Failed to resolve DID: {}", e))?; 244 + 245 + if !response.status().is_success() { 246 + return Err(format!("Failed to resolve DID: HTTP {}", response.status())); 114 247 } 115 248 116 - let permissions = main_def 117 - .permissions 118 - .as_ref() 119 - .ok_or("Missing permissions in permission-set")?; 249 + let doc: PlcDocument = response 250 + .json() 251 + .await 252 + .map_err(|e| format!("Failed to parse DID document: {}", e))?; 120 253 121 - let mut collections: Vec<String> = permissions 254 + doc.service 122 255 .iter() 123 - .filter(|perm| perm.resource == "repo") 124 - .filter_map(|perm| perm.collection.as_ref()) 125 - .flatten() 126 - .cloned() 127 - .collect(); 256 + .find(|s| s.id == "#atproto_pds") 257 + .map(|s| s.service_endpoint.clone()) 258 + .ok_or_else(|| "No #atproto_pds service found in DID document".to_string()) 259 + } 128 260 129 - if collections.is_empty() { 130 - return Err("No repo collections found in permission-set".to_string()); 261 + fn extract_namespace_authority(nsid: &str) -> String { 262 + let parts: Vec<&str> = nsid.split('.').collect(); 263 + if parts.len() >= 2 { 264 + parts[..parts.len() - 1].join(".") 265 + } else { 266 + nsid.to_string() 131 267 } 268 + } 132 269 133 - collections.sort(); 270 + fn is_under_authority(target_nsid: &str, authority: &str) -> bool { 271 + target_nsid.starts_with(authority) 272 + && target_nsid 273 + .chars() 274 + .nth(authority.len()) 275 + .is_some_and(|c| c == '.') 276 + } 277 + 278 + const DEFAULT_ACTIONS: &[&str] = &["create", "update", "delete"]; 279 + 280 + fn build_expanded_scopes( 281 + permissions: &[PermissionEntry], 282 + default_aud: Option<&str>, 283 + namespace_authority: &str, 284 + ) -> String { 285 + let mut scopes: Vec<String> = Vec::new(); 134 286 135 - let collection_params: Vec<String> = collections 287 + permissions 136 288 .iter() 137 - .map(|c| format!("collection={}", c)) 138 - .collect(); 289 + .for_each(|perm| match perm.resource.as_str() { 290 + "repo" => { 291 + if let Some(collections) = &perm.collection { 292 + let actions: Vec<&str> = perm 293 + .action 294 + .as_ref() 295 + .map(|a| a.iter().map(String::as_str).collect()) 296 + .unwrap_or_else(|| DEFAULT_ACTIONS.to_vec()); 139 297 140 - let expanded = format!("repo?{}", collection_params.join("&")); 298 + collections 299 + .iter() 300 + .filter(|coll| is_under_authority(coll, namespace_authority)) 301 + .for_each(|coll| { 302 + actions.iter().for_each(|action| { 303 + scopes.push(format!("repo:{}?action={}", coll, action)); 304 + }); 305 + }); 306 + } 307 + } 308 + "rpc" => { 309 + if let Some(lxms) = &perm.lxm { 310 + let perm_aud = perm.aud.as_deref().or(default_aud); 141 311 142 - { 143 - let mut cache = LEXICON_CACHE.write().await; 144 - cache.insert( 145 - nsid.to_string(), 146 - CachedLexicon { 147 - expanded_scope: expanded.clone(), 148 - cached_at: std::time::Instant::now(), 149 - }, 150 - ); 151 - } 312 + lxms.iter().for_each(|lxm| { 313 + let scope = match perm_aud { 314 + Some(aud) => format!("rpc:{}?aud={}", lxm, aud), 315 + None => format!("rpc:{}", lxm), 316 + }; 317 + scopes.push(scope); 318 + }); 319 + } 320 + } 321 + _ => {} 322 + }); 152 323 153 - debug!(nsid, expanded = %expanded, "Successfully expanded permission set"); 154 - Ok(expanded) 324 + scopes.join(" ") 155 325 } 156 326 157 327 #[cfg(test)] 158 328 mod tests { 329 + use super::*; 330 + 159 331 #[test] 160 - fn test_nsid_to_url() { 332 + fn test_parse_include_scope() { 333 + let (nsid, aud) = parse_include_scope("io.atcr.authFullApp"); 334 + assert_eq!(nsid, "io.atcr.authFullApp"); 335 + assert_eq!(aud, None); 336 + 337 + let (nsid, aud) = parse_include_scope("io.atcr.authFullApp?aud=did:web:api.bsky.app"); 338 + assert_eq!(nsid, "io.atcr.authFullApp"); 339 + assert_eq!(aud, Some("did:web:api.bsky.app")); 340 + } 341 + 342 + #[test] 343 + fn test_parse_include_scope_with_multiple_params() { 344 + let (nsid, aud) = 345 + parse_include_scope("io.atcr.authFullApp?foo=bar&aud=did:web:example.com&baz=qux"); 346 + assert_eq!(nsid, "io.atcr.authFullApp"); 347 + assert_eq!(aud, Some("did:web:example.com")); 348 + } 349 + 350 + #[test] 351 + fn test_extract_namespace_authority() { 352 + assert_eq!( 353 + extract_namespace_authority("io.atcr.authFullApp"), 354 + "io.atcr" 355 + ); 356 + assert_eq!( 357 + extract_namespace_authority("app.bsky.authFullApp"), 358 + "app.bsky" 359 + ); 360 + } 361 + 362 + #[test] 363 + fn test_extract_namespace_authority_deep_nesting() { 364 + assert_eq!( 365 + extract_namespace_authority("io.atcr.sailor.star.collection"), 366 + "io.atcr.sailor.star" 367 + ); 368 + } 369 + 370 + #[test] 371 + fn test_extract_namespace_authority_single_segment() { 372 + assert_eq!(extract_namespace_authority("single"), "single"); 373 + } 374 + 375 + #[test] 376 + fn test_is_under_authority() { 377 + assert!(is_under_authority("io.atcr.manifest", "io.atcr")); 378 + assert!(is_under_authority("io.atcr.sailor.star", "io.atcr")); 379 + assert!(!is_under_authority("app.bsky.feed.post", "io.atcr")); 380 + assert!(!is_under_authority("io.atcr", "io.atcr")); 381 + } 382 + 383 + #[test] 384 + fn test_is_under_authority_prefix_collision() { 385 + assert!(!is_under_authority("io.atcritical.something", "io.atcr")); 386 + assert!(is_under_authority("io.atcr.something", "io.atcr")); 387 + } 388 + 389 + #[test] 390 + fn test_build_expanded_scopes_repo() { 391 + let permissions = vec![PermissionEntry { 392 + resource: "repo".to_string(), 393 + action: Some(vec!["create".to_string(), "delete".to_string()]), 394 + collection: Some(vec![ 395 + "io.atcr.manifest".to_string(), 396 + "io.atcr.sailor.star".to_string(), 397 + "app.bsky.feed.post".to_string(), 398 + ]), 399 + lxm: None, 400 + aud: None, 401 + }]; 402 + 403 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 404 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 405 + assert!(expanded.contains("repo:io.atcr.manifest?action=delete")); 406 + assert!(expanded.contains("repo:io.atcr.sailor.star?action=create")); 407 + assert!(!expanded.contains("app.bsky.feed.post")); 408 + } 409 + 410 + #[test] 411 + fn test_build_expanded_scopes_repo_default_actions() { 412 + let permissions = vec![PermissionEntry { 413 + resource: "repo".to_string(), 414 + action: None, 415 + collection: Some(vec!["io.atcr.manifest".to_string()]), 416 + lxm: None, 417 + aud: None, 418 + }]; 419 + 420 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 421 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 422 + assert!(expanded.contains("repo:io.atcr.manifest?action=update")); 423 + assert!(expanded.contains("repo:io.atcr.manifest?action=delete")); 424 + } 425 + 426 + #[test] 427 + fn test_build_expanded_scopes_rpc() { 428 + let permissions = vec![PermissionEntry { 429 + resource: "rpc".to_string(), 430 + action: None, 431 + collection: None, 432 + lxm: Some(vec![ 433 + "io.atcr.getManifest".to_string(), 434 + "com.atproto.repo.getRecord".to_string(), 435 + ]), 436 + aud: Some("*".to_string()), 437 + }]; 438 + 439 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 440 + assert!(expanded.contains("rpc:io.atcr.getManifest?aud=*")); 441 + assert!(expanded.contains("rpc:com.atproto.repo.getRecord?aud=*")); 442 + } 443 + 444 + #[test] 445 + fn test_build_expanded_scopes_rpc_with_default_aud() { 446 + let permissions = vec![PermissionEntry { 447 + resource: "rpc".to_string(), 448 + action: None, 449 + collection: None, 450 + lxm: Some(vec!["io.atcr.getManifest".to_string()]), 451 + aud: None, 452 + }]; 453 + 454 + let expanded = 455 + build_expanded_scopes(&permissions, Some("did:web:api.example.com"), "io.atcr"); 456 + assert!(expanded.contains("rpc:io.atcr.getManifest?aud=did:web:api.example.com")); 457 + } 458 + 459 + #[test] 460 + fn test_build_expanded_scopes_rpc_no_aud() { 461 + let permissions = vec![PermissionEntry { 462 + resource: "rpc".to_string(), 463 + action: None, 464 + collection: None, 465 + lxm: Some(vec!["io.atcr.getManifest".to_string()]), 466 + aud: None, 467 + }]; 468 + 469 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 470 + assert_eq!(expanded, "rpc:io.atcr.getManifest"); 471 + } 472 + 473 + #[test] 474 + fn test_build_expanded_scopes_mixed_permissions() { 475 + let permissions = vec![ 476 + PermissionEntry { 477 + resource: "repo".to_string(), 478 + action: Some(vec!["create".to_string()]), 479 + collection: Some(vec!["io.atcr.manifest".to_string()]), 480 + lxm: None, 481 + aud: None, 482 + }, 483 + PermissionEntry { 484 + resource: "rpc".to_string(), 485 + action: None, 486 + collection: None, 487 + lxm: Some(vec!["com.atproto.repo.getRecord".to_string()]), 488 + aud: Some("*".to_string()), 489 + }, 490 + ]; 491 + 492 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 493 + assert!(expanded.contains("repo:io.atcr.manifest?action=create")); 494 + assert!(expanded.contains("rpc:com.atproto.repo.getRecord?aud=*")); 495 + } 496 + 497 + #[test] 498 + fn test_build_expanded_scopes_unknown_resource_ignored() { 499 + let permissions = vec![PermissionEntry { 500 + resource: "unknown".to_string(), 501 + action: None, 502 + collection: Some(vec!["io.atcr.manifest".to_string()]), 503 + lxm: None, 504 + aud: None, 505 + }]; 506 + 507 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 508 + assert!(expanded.is_empty()); 509 + } 510 + 511 + #[test] 512 + fn test_build_expanded_scopes_empty_permissions() { 513 + let permissions: Vec<PermissionEntry> = vec![]; 514 + let expanded = build_expanded_scopes(&permissions, None, "io.atcr"); 515 + assert!(expanded.is_empty()); 516 + } 517 + 518 + #[tokio::test] 519 + async fn test_expand_include_scopes_passthrough_non_include() { 520 + let result = expand_include_scopes("atproto transition:generic").await; 521 + assert_eq!(result, "atproto transition:generic"); 522 + } 523 + 524 + #[tokio::test] 525 + async fn test_expand_include_scopes_mixed_with_regular() { 526 + let result = expand_include_scopes("atproto repo:app.bsky.feed.post?action=create").await; 527 + assert!(result.contains("atproto")); 528 + assert!(result.contains("repo:app.bsky.feed.post?action=create")); 529 + } 530 + 531 + #[tokio::test] 532 + async fn test_cache_population_and_retrieval() { 533 + let cache_key = "test.cached.scope"; 534 + let cached_value = "repo:test.cached.collection?action=create"; 535 + 536 + { 537 + let mut cache = LEXICON_CACHE.write().await; 538 + cache.insert( 539 + cache_key.to_string(), 540 + CachedLexicon { 541 + expanded_scope: cached_value.to_string(), 542 + cached_at: std::time::Instant::now(), 543 + }, 544 + ); 545 + } 546 + 547 + let result = expand_permission_set(cache_key, None).await; 548 + assert!(result.is_ok()); 549 + assert_eq!(result.unwrap(), cached_value); 550 + 551 + { 552 + let mut cache = LEXICON_CACHE.write().await; 553 + cache.remove(cache_key); 554 + } 555 + } 556 + 557 + #[tokio::test] 558 + async fn test_cache_with_aud_parameter() { 559 + let nsid = "test.aud.scope"; 560 + let aud = "did:web:example.com"; 561 + let cache_key = format!("{}?aud={}", nsid, aud); 562 + let cached_value = "rpc:test.aud.method?aud=did:web:example.com"; 563 + 564 + { 565 + let mut cache = LEXICON_CACHE.write().await; 566 + cache.insert( 567 + cache_key.clone(), 568 + CachedLexicon { 569 + expanded_scope: cached_value.to_string(), 570 + cached_at: std::time::Instant::now(), 571 + }, 572 + ); 573 + } 574 + 575 + let result = expand_permission_set(nsid, Some(aud)).await; 576 + assert!(result.is_ok()); 577 + assert_eq!(result.unwrap(), cached_value); 578 + 579 + { 580 + let mut cache = LEXICON_CACHE.write().await; 581 + cache.remove(&cache_key); 582 + } 583 + } 584 + 585 + #[tokio::test] 586 + async fn test_expired_cache_triggers_refresh() { 587 + let cache_key = "test.expired.scope"; 588 + 589 + { 590 + let mut cache = LEXICON_CACHE.write().await; 591 + cache.insert( 592 + cache_key.to_string(), 593 + CachedLexicon { 594 + expanded_scope: "old_value".to_string(), 595 + cached_at: std::time::Instant::now() 596 + - std::time::Duration::from_secs(CACHE_TTL_SECS + 1), 597 + }, 598 + ); 599 + } 600 + 601 + let result = expand_permission_set(cache_key, None).await; 602 + assert!(result.is_err()); 603 + 604 + { 605 + let mut cache = LEXICON_CACHE.write().await; 606 + cache.remove(cache_key); 607 + } 608 + } 609 + 610 + #[test] 611 + fn test_nsid_authority_extraction_for_dns() { 161 612 let nsid = "io.atcr.authFullApp"; 162 613 let parts: Vec<&str> = nsid.split('.').collect(); 163 - let domain_parts: Vec<&str> = parts[..2].iter().rev().cloned().collect(); 164 - let domain = domain_parts.join("."); 165 - let path = parts[2..].join("/"); 614 + let authority = parts[..2] 615 + .iter() 616 + .rev() 617 + .cloned() 618 + .collect::<Vec<_>>() 619 + .join("."); 620 + assert_eq!(authority, "atcr.io"); 166 621 167 - assert_eq!(domain, "atcr.io"); 168 - assert_eq!(path, "authFullApp"); 622 + let nsid2 = "app.bsky.feed.post"; 623 + let parts2: Vec<&str> = nsid2.split('.').collect(); 624 + let authority2 = parts2[..2] 625 + .iter() 626 + .rev() 627 + .cloned() 628 + .collect::<Vec<_>>() 629 + .join("."); 630 + assert_eq!(authority2, "bsky.app"); 169 631 } 170 632 }
+38 -3
crates/tranquil-scopes/src/permissions.rs
··· 126 126 return Ok(()); 127 127 } 128 128 129 - let has_permission = self.find_repo_scopes().any(|repo_scope| { 129 + let has_repo_permission = self.find_repo_scopes().any(|repo_scope| { 130 130 repo_scope.actions.contains(&action) 131 131 && match &repo_scope.collection { 132 132 None => true, ··· 140 140 } 141 141 }); 142 142 143 - if has_permission { 143 + if has_repo_permission { 144 144 Ok(()) 145 145 } else { 146 146 Err(ScopeError::InsufficientScope { ··· 181 181 return Ok(()); 182 182 } 183 183 184 + let aud_base = aud.split('#').next().unwrap_or(aud); 185 + 184 186 let has_permission = self.find_rpc_scopes().any(|rpc_scope| { 185 187 let lxm_matches = match &rpc_scope.lxm { 186 188 None => true, ··· 195 197 let aud_matches = match &rpc_scope.aud { 196 198 None => true, 197 199 Some(scope_aud) if scope_aud == "*" => true, 198 - Some(scope_aud) => scope_aud == aud, 200 + Some(scope_aud) => { 201 + let scope_aud_base = scope_aud.split('#').next().unwrap_or(scope_aud); 202 + scope_aud_base == aud_base 203 + } 199 204 }; 200 205 201 206 lxm_matches && aud_matches ··· 520 525 assert!(perms.allows_repo(RepoAction::Update, "any.collection")); 521 526 assert!(perms.allows_blob("image/png")); 522 527 assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 528 + } 529 + 530 + #[test] 531 + fn test_rpc_scope_with_did_fragment() { 532 + let perms = ScopePermissions::from_scope_string(Some( 533 + "rpc:app.bsky.feed.getAuthorFeed?aud=did:web:api.bsky.app#bsky_appview", 534 + )); 535 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 536 + assert!(perms.allows_rpc( 537 + "did:web:api.bsky.app#bsky_appview", 538 + "app.bsky.feed.getAuthorFeed" 539 + )); 540 + assert!(perms.allows_rpc( 541 + "did:web:api.bsky.app#other_service", 542 + "app.bsky.feed.getAuthorFeed" 543 + )); 544 + assert!(!perms.allows_rpc("did:web:other.app", "app.bsky.feed.getAuthorFeed")); 545 + assert!(!perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getTimeline")); 546 + } 547 + 548 + #[test] 549 + fn test_rpc_scope_without_fragment_matches_with_fragment() { 550 + let perms = ScopePermissions::from_scope_string(Some( 551 + "rpc:app.bsky.feed.getAuthorFeed?aud=did:web:api.bsky.app", 552 + )); 553 + assert!(perms.allows_rpc("did:web:api.bsky.app", "app.bsky.feed.getAuthorFeed")); 554 + assert!(perms.allows_rpc( 555 + "did:web:api.bsky.app#bsky_appview", 556 + "app.bsky.feed.getAuthorFeed" 557 + )); 523 558 } 524 559 }
+24 -16
crates/tranquil-storage/src/lib.rs
··· 22 22 const CID_SHARD_PREFIX_LEN: usize = 9; 23 23 24 24 fn split_cid_path(key: &str) -> Option<(&str, &str)> { 25 - let is_cid = key.get(..3).map_or(false, |p| p.eq_ignore_ascii_case("baf")); 26 - (key.len() > CID_SHARD_PREFIX_LEN && is_cid) 27 - .then(|| key.split_at(CID_SHARD_PREFIX_LEN)) 25 + let is_cid = key.get(..3).is_some_and(|p| p.eq_ignore_ascii_case("baf")); 26 + (key.len() > CID_SHARD_PREFIX_LEN && is_cid).then(|| key.split_at(CID_SHARD_PREFIX_LEN)) 28 27 } 29 28 30 29 fn validate_key(key: &str) -> Result<(), StorageError> { ··· 771 770 let cid = "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"; 772 771 assert_eq!( 773 772 split_cid_path(cid), 774 - Some(("bafkreihd", "wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku")) 773 + Some(( 774 + "bafkreihd", 775 + "wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku" 776 + )) 775 777 ); 776 778 } 777 779 ··· 780 782 let cid = "bafyreigdmqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje"; 781 783 assert_eq!( 782 784 split_cid_path(cid), 783 - Some(("bafyreigd", "mqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje")) 785 + Some(( 786 + "bafyreigd", 787 + "mqpykrgxyaxtlafqpqhzrb7qy2rh75nldvfd4tucqmqqme5yje" 788 + )) 784 789 ); 785 790 } 786 791 ··· 810 815 let mixed = "BaFkReIhDwDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu"; 811 816 assert_eq!( 812 817 split_cid_path(upper), 813 - Some(("BAFKREIHD", "WDCEFGH4DQKJV67UZCMW7OJEE6XEDZDETOJUZJEVTENXQUVYKU")) 818 + Some(( 819 + "BAFKREIHD", 820 + "WDCEFGH4DQKJV67UZCMW7OJEE6XEDZDETOJUZJEVTENXQUVYKU" 821 + )) 814 822 ); 815 823 assert_eq!( 816 824 split_cid_path(mixed), 817 - Some(("BaFkReIhD", "wDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu")) 825 + Some(( 826 + "BaFkReIhD", 827 + "wDcEfGh4DqKjV67UzCmW7OjEe6XeDzDeTojUzJevTeNxQuVyKu" 828 + )) 818 829 ); 819 830 } 820 831 ··· 829 840 let base = PathBuf::from("/blobs"); 830 841 let cid = "bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"; 831 842 832 - let expected = PathBuf::from("/blobs/bafkreihd/wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"); 833 - let result = split_cid_path(cid).map_or_else( 834 - || base.join(cid), 835 - |(dir, file)| base.join(dir).join(file), 836 - ); 843 + let expected = 844 + PathBuf::from("/blobs/bafkreihd/wdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku"); 845 + let result = split_cid_path(cid) 846 + .map_or_else(|| base.join(cid), |(dir, file)| base.join(dir).join(file)); 837 847 assert_eq!(result, expected); 838 848 } 839 849 ··· 843 853 let key = "temp/abc123"; 844 854 845 855 let expected = PathBuf::from("/blobs/temp/abc123"); 846 - let result = split_cid_path(key).map_or_else( 847 - || base.join(key), 848 - |(dir, file)| base.join(dir).join(file), 849 - ); 856 + let result = split_cid_path(key) 857 + .map_or_else(|| base.join(key), |(dir, file)| base.join(dir).join(file)); 850 858 assert_eq!(result, expected); 851 859 } 852 860 }
+100 -35
frontend/src/lib/api.ts
··· 16 16 unsafeAsISODate, 17 17 unsafeAsRefreshToken, 18 18 } from "./types/branded.ts"; 19 + import { 20 + createDPoPProofForRequest, 21 + getDPoPNonce, 22 + setDPoPNonce, 23 + } from "./oauth.ts"; 19 24 import type { 20 25 AccountInfo, 21 26 ApiErrorCode, ··· 91 96 } 92 97 } 93 98 94 - let tokenRefreshCallback: (() => Promise<string | null>) | null = null; 99 + let tokenRefreshCallback: (() => Promise<AccessToken | null>) | null = null; 95 100 96 101 export function setTokenRefreshCallback( 97 - callback: () => Promise<string | null>, 102 + callback: () => Promise<AccessToken | null>, 98 103 ) { 99 104 tokenRefreshCallback = callback; 100 105 } 101 106 107 + interface AuthenticatedFetchOptions { 108 + method?: "GET" | "POST"; 109 + token: AccessToken | RefreshToken; 110 + headers?: Record<string, string>; 111 + body?: BodyInit; 112 + } 113 + 114 + async function authenticatedFetch( 115 + url: string, 116 + options: AuthenticatedFetchOptions, 117 + ): Promise<Response> { 118 + const { method = "GET", token, headers = {}, body } = options; 119 + const fullUrl = url.startsWith("http") 120 + ? url 121 + : `${globalThis.location.origin}${url}`; 122 + const dpopProof = await createDPoPProofForRequest(method, fullUrl, token); 123 + const res = await fetch(url, { 124 + method, 125 + headers: { 126 + ...headers, 127 + Authorization: `DPoP ${token}`, 128 + DPoP: dpopProof, 129 + }, 130 + body, 131 + }); 132 + const dpopNonce = res.headers.get("DPoP-Nonce"); 133 + if (dpopNonce) { 134 + setDPoPNonce(dpopNonce); 135 + } 136 + return res; 137 + } 138 + 102 139 interface XrpcOptions { 103 140 method?: "GET" | "POST"; 104 141 params?: Record<string, string>; 105 142 body?: unknown; 106 - token?: string; 143 + token?: AccessToken | RefreshToken; 107 144 skipRetry?: boolean; 145 + skipDpopRetry?: boolean; 108 146 } 109 147 110 148 async function xrpc<T>(method: string, options?: XrpcOptions): Promise<T> { 111 - const { method: httpMethod = "GET", params, body, token, skipRetry } = 112 - options ?? {}; 149 + const { 150 + method: httpMethod = "GET", 151 + params, 152 + body, 153 + token, 154 + skipRetry, 155 + skipDpopRetry, 156 + } = options ?? {}; 113 157 let url = `${API_BASE}/${method}`; 114 158 if (params) { 115 159 const searchParams = new URLSearchParams(params); 116 160 url += `?${searchParams}`; 117 161 } 118 162 const headers: Record<string, string> = {}; 119 - if (token) { 120 - headers["Authorization"] = `Bearer ${token}`; 121 - } 122 163 if (body) { 123 164 headers["Content-Type"] = "application/json"; 124 165 } 125 - const res = await fetch(url, { 126 - method: httpMethod, 127 - headers, 128 - body: body ? JSON.stringify(body) : undefined, 129 - }); 166 + const res = token 167 + ? await authenticatedFetch(url, { 168 + method: httpMethod, 169 + token, 170 + headers, 171 + body: body ? JSON.stringify(body) : undefined, 172 + }) 173 + : await fetch(url, { 174 + method: httpMethod, 175 + headers, 176 + body: body ? JSON.stringify(body) : undefined, 177 + }); 130 178 if (!res.ok) { 131 179 const errData = await res.json().catch(() => ({ 132 180 error: "Unknown", ··· 134 182 })); 135 183 if ( 136 184 res.status === 401 && 185 + errData.error === "use_dpop_nonce" && 186 + token && 187 + !skipDpopRetry && 188 + getDPoPNonce() 189 + ) { 190 + return xrpc(method, { ...options, skipDpopRetry: true }); 191 + } 192 + if ( 193 + res.status === 401 && 137 194 (errData.error === "AuthenticationFailed" || 138 - errData.error === "ExpiredToken") && 139 - token && tokenRefreshCallback && !skipRetry 195 + errData.error === "ExpiredToken" || 196 + errData.error === "OAuthExpiredToken") && 197 + token && 198 + tokenRefreshCallback && 199 + !skipRetry 140 200 ) { 141 201 const newToken = await tokenRefreshCallback(); 142 202 if (newToken && newToken !== token) { ··· 536 596 token: AccessToken, 537 597 file: File, 538 598 ): Promise<UploadBlobResponse> { 539 - const res = await fetch("/xrpc/com.atproto.repo.uploadBlob", { 599 + const res = await authenticatedFetch("/xrpc/com.atproto.repo.uploadBlob", { 540 600 method: "POST", 541 - headers: { 542 - "Authorization": `Bearer ${token}`, 543 - "Content-Type": file.type, 544 - }, 601 + token, 602 + headers: { "Content-Type": file.type }, 545 603 body: file, 546 604 }); 547 605 if (!res.ok) { ··· 1084 1142 }, 1085 1143 1086 1144 async getRepo(token: AccessToken, did: Did): Promise<ArrayBuffer> { 1087 - const url = `${API_BASE}/com.atproto.sync.getRepo?did=${ 1088 - encodeURIComponent(did) 1089 - }`; 1090 - const res = await fetch(url, { 1091 - headers: { Authorization: `Bearer ${token}` }, 1092 - }); 1145 + const url = `${API_BASE}/com.atproto.sync.getRepo?did=${encodeURIComponent(did)}`; 1146 + const res = await authenticatedFetch(url, { token }); 1093 1147 if (!res.ok) { 1094 1148 const errData = await res.json().catch(() => ({ 1095 1149 error: "Unknown", ··· 1106 1160 1107 1161 async getBackup(token: AccessToken, id: string): Promise<Blob> { 1108 1162 const url = `${API_BASE}/_backup.getBackup?id=${encodeURIComponent(id)}`; 1109 - const res = await fetch(url, { 1110 - headers: { Authorization: `Bearer ${token}` }, 1111 - }); 1163 + const res = await authenticatedFetch(url, { token }); 1112 1164 if (!res.ok) { 1113 1165 const errData = await res.json().catch(() => ({ 1114 1166 error: "Unknown", ··· 1146 1198 }, 1147 1199 1148 1200 async importRepo(token: AccessToken, car: Uint8Array): Promise<void> { 1149 - const url = `${API_BASE}/com.atproto.repo.importRepo`; 1150 - const res = await fetch(url, { 1201 + const res = await authenticatedFetch(`${API_BASE}/com.atproto.repo.importRepo`, { 1151 1202 method: "POST", 1152 - headers: { 1153 - Authorization: `Bearer ${token}`, 1154 - "Content-Type": "application/vnd.ipld.car", 1155 - }, 1203 + token, 1204 + headers: { "Content-Type": "application/vnd.ipld.car" }, 1156 1205 body: car as unknown as BodyInit, 1157 1206 }); 1158 1207 if (!res.ok) { ··· 1162 1211 })); 1163 1212 throw new ApiError(res.status, errData.error, errData.message); 1164 1213 } 1214 + }, 1215 + 1216 + async establishOAuthSession(token: AccessToken): Promise<{ success: boolean; device_id: string }> { 1217 + const res = await authenticatedFetch("/oauth/establish-session", { 1218 + method: "POST", 1219 + token, 1220 + headers: { "Content-Type": "application/json" }, 1221 + }); 1222 + if (!res.ok) { 1223 + const errData = await res.json().catch(() => ({ 1224 + error: "Unknown", 1225 + message: res.statusText, 1226 + })); 1227 + throw new ApiError(res.status, errData.error, errData.message); 1228 + } 1229 + return res.json(); 1165 1230 }, 1166 1231 }; 1167 1232
+1 -1
frontend/src/lib/auth.svelte.ts
··· 281 281 } 282 282 } 283 283 284 - async function tryRefreshToken(): Promise<string | null> { 284 + async function tryRefreshToken(): Promise<AccessToken | null> { 285 285 if (state.current.kind !== "authenticated") return null; 286 286 const currentSession = state.current.session; 287 287 try {
+18 -1
frontend/src/lib/migration/atproto-client.ts
··· 240 240 }&cid=${encodeURIComponent(cid)}`; 241 241 const headers: Record<string, string> = {}; 242 242 if (this.accessToken) { 243 - headers["Authorization"] = `Bearer ${this.accessToken}`; 243 + if (this.dpopKeyPair) { 244 + headers["Authorization"] = `DPoP ${this.accessToken}`; 245 + const tokenHash = await computeAccessTokenHash(this.accessToken); 246 + const dpopProof = await createDPoPProof( 247 + this.dpopKeyPair, 248 + "GET", 249 + url.split("?")[0], 250 + this.dpopNonce ?? undefined, 251 + tokenHash, 252 + ); 253 + headers["DPoP"] = dpopProof; 254 + } else { 255 + headers["Authorization"] = `Bearer ${this.accessToken}`; 256 + } 244 257 } 245 258 const res = await fetch(url, { headers }); 259 + const newNonce = res.headers.get("DPoP-Nonce"); 260 + if (newNonce) { 261 + this.dpopNonce = newNonce; 262 + } 246 263 if (!res.ok) { 247 264 const err = await res.json().catch(() => ({ 248 265 error: "Unknown",
+3 -1
frontend/src/lib/migration/flow.svelte.ts
··· 88 88 89 89 function setStep(step: InboundStep) { 90 90 state.step = step; 91 - state.error = null; 91 + if (step !== "error") { 92 + state.error = null; 93 + } 92 94 if (step !== "success") { 93 95 saveMigrationState(state); 94 96 updateStep(step);
+3 -1
frontend/src/lib/migration/offline-flow.svelte.ts
··· 177 177 178 178 function setStep(step: OfflineInboundStep) { 179 179 state.step = step; 180 - state.error = null; 180 + if (step !== "error") { 181 + state.error = null; 182 + } 181 183 if (step !== "success") { 182 184 saveOfflineState(state); 183 185 }
+3 -3
frontend/src/lib/oauth.ts
··· 246 246 return base64UrlEncode(hash); 247 247 } 248 248 249 - function getDPoPNonce(): string | null { 249 + export function getDPoPNonce(): string | null { 250 250 return sessionStorage.getItem(DPOP_NONCE_KEY); 251 251 } 252 252 253 - function setDPoPNonce(nonce: string): void { 253 + export function setDPoPNonce(nonce: string): void { 254 254 sessionStorage.setItem(DPOP_NONCE_KEY, nonce); 255 255 } 256 256 257 - function extractDPoPNonceFromResponse(response: Response): void { 257 + export function extractDPoPNonceFromResponse(response: Response): void { 258 258 const nonce = response.headers.get("DPoP-Nonce"); 259 259 if (nonce) { 260 260 setDPoPNonce(nonce);
+5
frontend/src/locales/en.json
··· 779 779 "name": "Manage Account", 780 780 "description": "Manage account settings and preferences" 781 781 } 782 + }, 783 + "unexpectedState": { 784 + "title": "Unexpected State", 785 + "description": "The consent page is in an unexpected state. Please check the browser console for errors.", 786 + "reload": "Reload Page" 782 787 } 783 788 }, 784 789 "accounts": {
+5
frontend/src/locales/fi.json
··· 785 785 "name": "Hallitse tiliรค", 786 786 "description": "Hallitse tilin asetuksia ja asetuksia" 787 787 } 788 + }, 789 + "unexpectedState": { 790 + "title": "Odottamaton tila", 791 + "description": "Suostumussivulla on odottamaton tila. Tarkista selaimen konsoli virheiden varalta.", 792 + "reload": "Lataa sivu uudelleen" 788 793 } 789 794 }, 790 795 "accounts": {
+5
frontend/src/locales/ja.json
··· 778 778 "name": "ใ‚ขใ‚ซใ‚ฆใƒณใƒˆ็ฎก็†", 779 779 "description": "ใ‚ขใ‚ซใ‚ฆใƒณใƒˆ่จญๅฎšใจ่จญๅฎšใ‚’็ฎก็†" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "ไบˆๆœŸใ—ใชใ„็Šถๆ…‹", 784 + "description": "ๅŒๆ„ใƒšใƒผใ‚ธใŒไบˆๆœŸใ—ใชใ„็Šถๆ…‹ใงใ™ใ€‚ใƒ–ใƒฉใ‚ฆใ‚ถใฎใ‚ณใƒณใ‚ฝใƒผใƒซใงใ‚จใƒฉใƒผใ‚’็ขบ่ชใ—ใฆใใ ใ•ใ„ใ€‚", 785 + "reload": "ใƒšใƒผใ‚ธใ‚’ๅ†่ชญใฟ่พผใฟ" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/ko.json
··· 778 778 "name": "๊ณ„์ • ๊ด€๋ฆฌ", 779 779 "description": "๊ณ„์ • ์„ค์ • ๋ฐ ํ™˜๊ฒฝ์„ค์ • ๊ด€๋ฆฌ" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "์˜ˆ๊ธฐ์น˜ ์•Š์€ ์ƒํƒœ", 784 + "description": "๋™์˜ ํŽ˜์ด์ง€๊ฐ€ ์˜ˆ๊ธฐ์น˜ ์•Š์€ ์ƒํƒœ์ž…๋‹ˆ๋‹ค. ๋ธŒ๋ผ์šฐ์ € ์ฝ˜์†”์—์„œ ์˜ค๋ฅ˜๋ฅผ ํ™•์ธํ•˜์„ธ์š”.", 785 + "reload": "ํŽ˜์ด์ง€ ์ƒˆ๋กœ๊ณ ์นจ" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/sv.json
··· 778 778 "name": "Hantera konto", 779 779 "description": "Hantera kontoinstรคllningar och preferenser" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "Ovรคntat tillstรฅnd", 784 + "description": "Samtyckes-sidan รคr i ett ovรคntat tillstรฅnd. Kontrollera webblรคsarens konsol fรถr fel.", 785 + "reload": "Ladda om sidan" 781 786 } 782 787 }, 783 788 "accounts": {
+5
frontend/src/locales/zh.json
··· 778 778 "name": "็ฎก็†่ดฆๆˆท", 779 779 "description": "็ฎก็†่ดฆๆˆท่ฎพ็ฝฎๅ’Œๅๅฅฝ" 780 780 } 781 + }, 782 + "unexpectedState": { 783 + "title": "ๆ„ๅค–็Šถๆ€", 784 + "description": "ๅŒๆ„้กต้ขๅค„ไบŽๆ„ๅค–็Šถๆ€ใ€‚่ฏทๆฃ€ๆŸฅๆต่งˆๅ™จๆŽงๅˆถๅฐไปฅๆŸฅ็œ‹้”™่ฏฏใ€‚", 785 + "reload": "้‡ๆ–ฐๅŠ ่ฝฝ้กต้ข" 781 786 } 782 787 }, 783 788 "accounts": {
+37 -16
frontend/src/routes/Migration.svelte
··· 2 2 import { setSession } from '../lib/auth.svelte' 3 3 import { navigate, routes } from '../lib/router.svelte' 4 4 import { _ } from '../lib/i18n' 5 + import { api } from '../lib/api' 6 + import { startOAuthLogin } from '../lib/oauth' 7 + import { unsafeAsAccessToken } from '../lib/types/branded' 5 8 import { 6 9 createInboundMigrationFlow, 7 10 createOfflineInboundMigrationFlow, ··· 143 146 direction = 'select' 144 147 } 145 148 146 - function handleInboundComplete() { 149 + async function handleInboundComplete() { 147 150 const session = inboundFlow?.getLocalSession() 148 151 if (session) { 149 - setSession({ 150 - did: session.did, 151 - handle: session.handle, 152 - accessJwt: session.accessJwt, 153 - refreshJwt: '', 154 - }) 152 + try { 153 + await api.establishOAuthSession(unsafeAsAccessToken(session.accessJwt)) 154 + clearMigrationState() 155 + await startOAuthLogin(session.handle) 156 + } catch (e) { 157 + console.error('Failed to establish OAuth session, falling back to direct login:', e) 158 + setSession({ 159 + did: session.did, 160 + handle: session.handle, 161 + accessJwt: session.accessJwt, 162 + refreshJwt: '', 163 + }) 164 + navigate(routes.dashboard) 165 + } 166 + } else { 167 + navigate(routes.dashboard) 155 168 } 156 - navigate(routes.dashboard) 157 169 } 158 170 159 - function handleOfflineComplete() { 171 + async function handleOfflineComplete() { 160 172 const session = offlineFlow?.getLocalSession() 161 173 if (session) { 162 - setSession({ 163 - did: session.did, 164 - handle: session.handle, 165 - accessJwt: session.accessJwt, 166 - refreshJwt: '', 167 - }) 174 + try { 175 + await api.establishOAuthSession(unsafeAsAccessToken(session.accessJwt)) 176 + clearOfflineState() 177 + await startOAuthLogin(session.handle) 178 + } catch (e) { 179 + console.error('Failed to establish OAuth session, falling back to direct login:', e) 180 + setSession({ 181 + did: session.did, 182 + handle: session.handle, 183 + accessJwt: session.accessJwt, 184 + refreshJwt: '', 185 + }) 186 + navigate(routes.dashboard) 187 + } 188 + } else { 189 + navigate(routes.dashboard) 168 190 } 169 - navigate(routes.dashboard) 170 191 } 171 192 </script> 172 193
+5 -31
frontend/src/routes/OAuthAccounts.svelte
··· 196 196 display: flex; 197 197 align-items: center; 198 198 padding: var(--space-4); 199 - background: var(--bg-card); 199 + background: var(--bg-secondary); 200 200 border: 1px solid var(--border-color); 201 201 border-radius: var(--radius-xl); 202 202 cursor: pointer; 203 203 text-align: left; 204 204 width: 100%; 205 - transition: border-color var(--transition-fast), box-shadow var(--transition-fast); 205 + transition: border-color var(--transition-fast), background var(--transition-fast); 206 206 } 207 207 208 208 .account-item:hover:not(.disabled) { 209 209 border-color: var(--accent); 210 - box-shadow: var(--shadow-sm); 210 + background: var(--bg-tertiary); 211 211 } 212 212 213 213 .account-item.disabled { ··· 231 231 color: var(--text-secondary); 232 232 } 233 233 234 - button { 235 - padding: var(--space-3); 236 - background: var(--accent); 237 - color: var(--text-inverse); 238 - border: none; 239 - border-radius: var(--radius-md); 240 - font-size: var(--text-base); 241 - cursor: pointer; 242 - } 243 - 244 - button:hover:not(:disabled) { 245 - background: var(--accent-hover); 246 - } 247 - 248 - button:disabled { 249 - opacity: 0.6; 250 - cursor: not-allowed; 251 - } 252 - 253 - button.secondary { 254 - background: transparent; 255 - color: var(--accent); 256 - border: 1px solid var(--accent); 234 + .different-account { 235 + margin-top: var(--space-4); 257 236 width: 100%; 258 - } 259 - 260 - button.secondary:hover:not(:disabled) { 261 - background: var(--accent); 262 - color: var(--text-inverse); 263 237 } 264 238 265 239 .different-account {
+38 -3
frontend/src/routes/OAuthConsent.svelte
··· 65 65 async function fetchConsentData() { 66 66 const requestUri = getRequestUri() 67 67 if (!requestUri) { 68 + console.error('[OAuthConsent] No request_uri in URL') 68 69 error = $_('oauth.error.genericError') 69 70 loading = false 70 71 return ··· 74 75 const response = await fetch(`/oauth/authorize/consent?request_uri=${encodeURIComponent(requestUri)}`) 75 76 if (!response.ok) { 76 77 const data = await response.json() 78 + console.error('[OAuthConsent] Consent fetch failed:', data) 77 79 error = data.error_description || data.error || $_('oauth.error.genericError') 78 80 loading = false 79 81 return 80 82 } 81 83 const data: ConsentData = await response.json() 84 + 85 + if (!data.scopes || !Array.isArray(data.scopes)) { 86 + console.error('[OAuthConsent] Invalid scopes data:', data.scopes) 87 + error = 'Invalid consent data received' 88 + loading = false 89 + return 90 + } 91 + 82 92 consentData = data 83 93 84 94 scopeSelections = Object.fromEntries( ··· 91 101 if (!data.show_consent) { 92 102 await submitConsent() 93 103 } 94 - } catch { 104 + } catch (e) { 105 + console.error('[OAuthConsent] Error during consent fetch:', e) 95 106 error = $_('oauth.error.genericError') 96 107 } finally { 97 108 loading = false ··· 104 115 } 105 116 106 117 async function submitConsent() { 107 - if (!consentData) return 118 + if (!consentData) { 119 + console.error('[OAuthConsent] submitConsent called but no consentData') 120 + return 121 + } 108 122 109 123 submitting = true 110 124 let approvedScopes = Object.entries(scopeSelections) ··· 128 142 129 143 if (!response.ok) { 130 144 const data = await response.json() 145 + console.error('[OAuthConsent] Submit failed:', data) 131 146 error = data.error_description || data.error || $_('oauth.error.genericError') 132 147 submitting = false 133 148 return ··· 136 151 const data = await response.json() 137 152 if (data.redirect_uri) { 138 153 window.location.href = data.redirect_uri 154 + } else { 155 + console.error('[OAuthConsent] No redirect_uri in response') 156 + error = 'Authorization failed - no redirect received' 157 + submitting = false 139 158 } 140 - } catch { 159 + } catch (e) { 160 + console.error('[OAuthConsent] Submit error:', e) 141 161 error = $_('oauth.error.genericError') 142 162 submitting = false 143 163 } ··· 249 269 <div class="spinner"></div> 250 270 <p>{$_('common.loading')}</p> 251 271 </div> 272 + {:else} 273 + <p style="color: var(--text-muted); font-size: 0.875rem;">Loading consent data...</p> 252 274 {/if} 253 275 </div> 254 276 {:else if error} ··· 370 392 </button> 371 393 <button type="button" class="approve-btn" onclick={submitConsent} disabled={submitting}> 372 394 {submitting ? $_('oauth.consent.authorizing') : $_('oauth.consent.authorize')} 395 + </button> 396 + </div> 397 + {:else} 398 + <div class="error-container"> 399 + <h1>{$_('oauth.consent.unexpectedState.title')}</h1> 400 + <p style="color: var(--text-secondary);"> 401 + {$_('oauth.consent.unexpectedState.description')} 402 + </p> 403 + <p style="color: var(--text-muted); font-size: 0.75rem; font-family: monospace;"> 404 + loading={loading}, error={error ? 'set' : 'null'}, consentData={consentData ? 'set' : 'null'}, submitting={submitting} 405 + </p> 406 + <button type="button" onclick={() => window.location.reload()}> 407 + {$_('oauth.consent.unexpectedState.reload')} 373 408 </button> 374 409 </div> 375 410 {/if}

History

6 rounds 1 comment
sign up or login to add to the discussion
3 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
fix: match ref pds permission-levels for some endpoints
expand 0 comments
pull request successfully merged
3 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
fix: match ref pds permission-levels for some endpoints
expand 0 comments
2 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
expand 0 comments
2 commits
expand
fix: oauth consolidation, include-scope improvements
fix: consolidate auth extractors & standardize usage
expand 1 comment

so three things:

crates/tranquil-pds/src/auth/auth_extractor.rs does not seem to be used at all anywhere?

crates/tranquil-pds/src/api/temp.rs feels like it just ... shouldnt excist based on the name

and i still dont really like these extractors. the separation of inter service auth is weird to me. inter-service auth is a form of user auth. it shouldnt be separated out from the other types. the AuthExtractor should just be AuthExtractor(pub AuthenticatedUser).

i also discovered https://docs.rs/axum/0.8.8/axum/extract/trait.OptionalFromRequestParts.html which we should be able to reduce optional vs not optional with just an AuthExtractor vs Option.

principly id want whether or not its required that the account is active or not and whether its an admin account or not to both also be type safe configurations on the extractor. probably with generics of some sort. but i cant think of a specific design i like right now so. if you come up with one feel free to do it. otherwise we can do it later

1 commit
expand
fix: oauth consolidation, include-scope improvements
expand 0 comments
lewis.moe submitted #0
1 commit
expand
fix: oauth consolidation, include-scope improvements
expand 0 comments