Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

DRAFT: Better code quality via type safety #5

merged opened by lewis.moe targeting main from fix/code-quality-in-general

Ensuring at compile-time that we're definitely handling possible early failures in functions

Labels

None yet.

assignee

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mdbo7zq5ae22
+2048 -2117
Diff #0
+46 -45
crates/tranquil-db/src/postgres/oauth.rs
··· 7 7 ScopePreference, TrustedDeviceRow, TwoFactorChallenge, 8 8 }; 9 9 use tranquil_oauth::{ 10 - AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, DeviceData, RequestData, 11 - TokenData, 10 + AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, Code as OAuthCode, 11 + DeviceData, DeviceId as OAuthDeviceId, RequestData, SessionId as OAuthSessionId, TokenData, 12 + TokenId as OAuthTokenId, RefreshToken as OAuthRefreshToken, 12 13 }; 13 14 use tranquil_types::{ 14 15 AuthorizationCode, ClientId, DPoPProofId, DeviceId, Did, Handle, RefreshToken, RequestId, ··· 59 60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 60 61 RETURNING id 61 62 "#, 62 - data.did, 63 - data.token_id, 63 + data.did.as_str(), 64 + &data.token_id.0, 64 65 data.created_at, 65 66 data.updated_at, 66 67 data.expires_at, 67 68 data.client_id, 68 69 client_auth_json, 69 - data.device_id, 70 + data.device_id.as_ref().map(|d| d.0.as_str()), 70 71 parameters_json, 71 72 data.details, 72 - data.code, 73 - data.current_refresh_token, 73 + data.code.as_ref().map(|c| c.0.as_str()), 74 + data.current_refresh_token.as_ref().map(|r| r.0.as_str()), 74 75 data.scope, 75 - data.controller_did, 76 + data.controller_did.as_ref().map(|d| d.as_str()), 76 77 ) 77 78 .fetch_one(&self.pool) 78 79 .await ··· 95 96 .map_err(map_sqlx_error)?; 96 97 match row { 97 98 Some(r) => Ok(Some(TokenData { 98 - did: r.did, 99 - token_id: r.token_id, 99 + did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?, 100 + token_id: OAuthTokenId(r.token_id), 100 101 created_at: r.created_at, 101 102 updated_at: r.updated_at, 102 103 expires_at: r.expires_at, 103 104 client_id: r.client_id, 104 105 client_auth: from_json(r.client_auth)?, 105 - device_id: r.device_id, 106 + device_id: r.device_id.map(OAuthDeviceId), 106 107 parameters: from_json(r.parameters)?, 107 108 details: r.details, 108 - code: r.code, 109 - current_refresh_token: r.current_refresh_token, 109 + code: r.code.map(OAuthCode), 110 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 110 111 scope: r.scope, 111 - controller_did: r.controller_did, 112 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?, 112 113 })), 113 114 None => Ok(None), 114 115 } ··· 134 135 Some(r) => Ok(Some(( 135 136 r.id, 136 137 TokenData { 137 - did: r.did, 138 - token_id: r.token_id, 138 + did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?, 139 + token_id: OAuthTokenId(r.token_id), 139 140 created_at: r.created_at, 140 141 updated_at: r.updated_at, 141 142 expires_at: r.expires_at, 142 143 client_id: r.client_id, 143 144 client_auth: from_json(r.client_auth)?, 144 - device_id: r.device_id, 145 + device_id: r.device_id.map(OAuthDeviceId), 145 146 parameters: from_json(r.parameters)?, 146 147 details: r.details, 147 - code: r.code, 148 - current_refresh_token: r.current_refresh_token, 148 + code: r.code.map(OAuthCode), 149 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 149 150 scope: r.scope, 150 - controller_did: r.controller_did, 151 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?, 151 152 }, 152 153 ))), 153 154 None => Ok(None), ··· 176 177 Some(r) => Ok(Some(( 177 178 r.id, 178 179 TokenData { 179 - did: r.did, 180 - token_id: r.token_id, 180 + did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?, 181 + token_id: OAuthTokenId(r.token_id), 181 182 created_at: r.created_at, 182 183 updated_at: r.updated_at, 183 184 expires_at: r.expires_at, 184 185 client_id: r.client_id, 185 186 client_auth: from_json(r.client_auth)?, 186 - device_id: r.device_id, 187 + device_id: r.device_id.map(OAuthDeviceId), 187 188 parameters: from_json(r.parameters)?, 188 189 details: r.details, 189 - code: r.code, 190 - current_refresh_token: r.current_refresh_token, 190 + code: r.code.map(OAuthCode), 191 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 191 192 scope: r.scope, 192 - controller_did: r.controller_did, 193 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?, 193 194 }, 194 195 ))), 195 196 None => Ok(None), ··· 302 303 rows.into_iter() 303 304 .map(|r| { 304 305 Ok(TokenData { 305 - did: r.did, 306 - token_id: r.token_id, 306 + did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?, 307 + token_id: OAuthTokenId(r.token_id), 307 308 created_at: r.created_at, 308 309 updated_at: r.updated_at, 309 310 expires_at: r.expires_at, 310 311 client_id: r.client_id, 311 312 client_auth: from_json(r.client_auth)?, 312 - device_id: r.device_id, 313 + device_id: r.device_id.map(OAuthDeviceId), 313 314 parameters: from_json(r.parameters)?, 314 315 details: r.details, 315 - code: r.code, 316 - current_refresh_token: r.current_refresh_token, 316 + code: r.code.map(OAuthCode), 317 + current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken), 317 318 scope: r.scope, 318 - controller_did: r.controller_did, 319 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?, 319 320 }) 320 321 }) 321 322 .collect() ··· 407 408 VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 408 409 "#, 409 410 request_id.as_str(), 410 - data.did, 411 - data.device_id, 411 + data.did.as_ref().map(|d| d.as_str()), 412 + data.device_id.as_ref().map(|d| d.0.as_str()), 412 413 data.client_id, 413 414 client_auth_json, 414 415 parameters_json, 415 416 data.expires_at, 416 - data.code, 417 + data.code.as_ref().map(|c| c.0.as_str()), 417 418 ) 418 419 .execute(&self.pool) 419 420 .await ··· 448 449 client_auth, 449 450 parameters, 450 451 expires_at: r.expires_at, 451 - did: r.did, 452 - device_id: r.device_id, 453 - code: r.code, 454 - controller_did: r.controller_did, 452 + did: r.did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid DID in DB".into()))?, 453 + device_id: r.device_id.map(OAuthDeviceId), 454 + code: r.code.map(OAuthCode), 455 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?, 455 456 })) 456 457 } 457 458 None => Ok(None), ··· 534 535 client_auth, 535 536 parameters, 536 537 expires_at: r.expires_at, 537 - did: r.did, 538 - device_id: r.device_id, 539 - code: r.code, 540 - controller_did: r.controller_did, 538 + did: r.did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid DID in DB".into()))?, 539 + device_id: r.device_id.map(OAuthDeviceId), 540 + code: r.code.map(OAuthCode), 541 + controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?, 541 542 })) 542 543 } 543 544 None => Ok(None), ··· 655 656 VALUES ($1, $2, $3, $4, $5) 656 657 "#, 657 658 device_id.as_str(), 658 - data.session_id, 659 + &data.session_id.0, 659 660 data.user_agent, 660 661 data.ip_address, 661 662 data.last_seen_at, ··· 679 680 .await 680 681 .map_err(map_sqlx_error)?; 681 682 Ok(row.map(|r| DeviceData { 682 - session_id: r.session_id, 683 + session_id: OAuthSessionId(r.session_id), 683 684 user_agent: r.user_agent, 684 685 ip_address: r.ip_address, 685 686 last_seen_at: r.last_seen_at,
+6 -4
crates/tranquil-oauth/src/lib.rs
··· 10 10 }; 11 11 pub use error::OAuthError; 12 12 pub use types::{ 13 - AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 - AuthorizedClientData, ClientAuth, Code, DPoPClaims, DeviceData, DeviceId, JwkPublicKey, Jwks, 15 - OAuthClientMetadata, ParResponse, ProtectedResourceMetadata, RefreshToken, RefreshTokenState, 16 - RequestData, RequestId, SessionId, TokenData, TokenId, TokenRequest, TokenResponse, 13 + AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 + AuthorizedClientData, ClientAuth, Code, CodeChallengeMethod, DPoPClaims, DeviceData, DeviceId, 15 + FlowAuthenticated, FlowAuthorized, FlowExpired, FlowNotAuthenticated, FlowNotAuthorized, 16 + FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, ParResponse, Prompt, 17 + ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, 18 + ResponseMode, ResponseType, SessionId, TokenData, TokenId, TokenRequest, TokenResponse, 17 19 };
+250 -141
crates/tranquil-oauth/src/types.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use serde::{Deserialize, Serialize}; 3 3 use serde_json::Value as JsonValue; 4 + use tranquil_types::Did; 4 5 5 - #[derive(Debug, Clone, Serialize, Deserialize)] 6 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 7 + #[serde(transparent)] 8 + #[sqlx(transparent)] 6 9 pub struct RequestId(pub String); 7 10 8 - #[derive(Debug, Clone, Serialize, Deserialize)] 11 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 12 + #[serde(transparent)] 13 + #[sqlx(transparent)] 9 14 pub struct TokenId(pub String); 10 15 11 - #[derive(Debug, Clone, Serialize, Deserialize)] 16 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 17 + #[serde(transparent)] 18 + #[sqlx(transparent)] 12 19 pub struct DeviceId(pub String); 13 20 14 - #[derive(Debug, Clone, Serialize, Deserialize)] 21 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 22 + #[serde(transparent)] 23 + #[sqlx(transparent)] 15 24 pub struct SessionId(pub String); 16 25 17 - #[derive(Debug, Clone, Serialize, Deserialize)] 26 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 27 + #[serde(transparent)] 28 + #[sqlx(transparent)] 18 29 pub struct Code(pub String); 19 30 20 - #[derive(Debug, Clone, Serialize, Deserialize)] 31 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 32 + #[serde(transparent)] 33 + #[sqlx(transparent)] 21 34 pub struct RefreshToken(pub String); 22 35 23 36 impl RequestId { ··· 82 95 PrivateKeyJwt { client_assertion: String }, 83 96 } 84 97 98 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 99 + #[serde(rename_all = "snake_case")] 100 + pub enum ResponseType { 101 + #[default] 102 + Code, 103 + } 104 + 105 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 106 + pub enum CodeChallengeMethod { 107 + #[default] 108 + #[serde(rename = "S256")] 109 + S256, 110 + #[serde(rename = "plain")] 111 + Plain, 112 + } 113 + 114 + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] 115 + #[serde(rename_all = "snake_case")] 116 + pub enum ResponseMode { 117 + #[default] 118 + Query, 119 + Fragment, 120 + FormPost, 121 + } 122 + 123 + impl ResponseMode { 124 + pub fn as_str(&self) -> &'static str { 125 + match self { 126 + Self::Query => "query", 127 + Self::Fragment => "fragment", 128 + Self::FormPost => "form_post", 129 + } 130 + } 131 + } 132 + 133 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] 134 + #[serde(rename_all = "snake_case")] 135 + pub enum Prompt { 136 + None, 137 + Login, 138 + Consent, 139 + SelectAccount, 140 + Create, 141 + } 142 + 143 + impl Prompt { 144 + pub fn as_str(&self) -> &'static str { 145 + match self { 146 + Self::None => "none", 147 + Self::Login => "login", 148 + Self::Consent => "consent", 149 + Self::SelectAccount => "select_account", 150 + Self::Create => "create", 151 + } 152 + } 153 + } 154 + 85 155 #[derive(Debug, Clone, Serialize, Deserialize)] 86 156 pub struct AuthorizationRequestParameters { 87 - pub response_type: String, 157 + pub response_type: ResponseType, 88 158 pub client_id: String, 89 159 pub redirect_uri: String, 90 160 pub scope: Option<String>, 91 161 pub state: Option<String>, 92 162 pub code_challenge: String, 93 - pub code_challenge_method: String, 94 - pub response_mode: Option<String>, 163 + pub code_challenge_method: CodeChallengeMethod, 164 + pub response_mode: Option<ResponseMode>, 95 165 pub login_hint: Option<String>, 96 166 pub dpop_jkt: Option<String>, 97 - pub prompt: Option<String>, 167 + pub prompt: Option<Prompt>, 98 168 #[serde(flatten)] 99 169 pub extra: Option<JsonValue>, 100 170 } ··· 105 175 pub client_auth: Option<ClientAuth>, 106 176 pub parameters: AuthorizationRequestParameters, 107 177 pub expires_at: DateTime<Utc>, 108 - pub did: Option<String>, 109 - pub device_id: Option<String>, 110 - pub code: Option<String>, 111 - pub controller_did: Option<String>, 178 + pub did: Option<Did>, 179 + pub device_id: Option<DeviceId>, 180 + pub code: Option<Code>, 181 + pub controller_did: Option<Did>, 112 182 } 113 183 114 184 #[derive(Debug, Clone)] 115 185 pub struct DeviceData { 116 - pub session_id: String, 186 + pub session_id: SessionId, 117 187 pub user_agent: Option<String>, 118 188 pub ip_address: String, 119 189 pub last_seen_at: DateTime<Utc>, ··· 121 191 122 192 #[derive(Debug, Clone)] 123 193 pub struct TokenData { 124 - pub did: String, 125 - pub token_id: String, 194 + pub did: Did, 195 + pub token_id: TokenId, 126 196 pub created_at: DateTime<Utc>, 127 197 pub updated_at: DateTime<Utc>, 128 198 pub expires_at: DateTime<Utc>, 129 199 pub client_id: String, 130 200 pub client_auth: ClientAuth, 131 - pub device_id: Option<String>, 201 + pub device_id: Option<DeviceId>, 132 202 pub parameters: AuthorizationRequestParameters, 133 203 pub details: Option<JsonValue>, 134 - pub code: Option<String>, 135 - pub current_refresh_token: Option<String>, 204 + pub code: Option<Code>, 205 + pub current_refresh_token: Option<RefreshToken>, 136 206 pub scope: Option<String>, 137 - pub controller_did: Option<String>, 207 + pub controller_did: Option<Did>, 138 208 } 139 209 140 210 #[derive(Debug, Clone, Serialize, Deserialize)] ··· 247 317 pub keys: Vec<JwkPublicKey>, 248 318 } 249 319 250 - #[derive(Debug, Clone, PartialEq, Eq)] 251 - pub enum AuthFlowState { 252 - Pending, 253 - Authenticated { 254 - did: String, 255 - device_id: Option<String>, 256 - }, 257 - Authorized { 258 - did: String, 259 - device_id: Option<String>, 260 - code: String, 261 - }, 262 - Expired, 320 + 321 + #[derive(Debug, Clone)] 322 + pub struct FlowPending { 323 + pub parameters: AuthorizationRequestParameters, 324 + pub client_id: String, 325 + pub client_auth: Option<ClientAuth>, 326 + pub expires_at: DateTime<Utc>, 327 + pub controller_did: Option<Did>, 263 328 } 264 329 265 - impl AuthFlowState { 266 - pub fn from_request_data(data: &RequestData) -> Self { 267 - if data.expires_at < chrono::Utc::now() { 268 - return AuthFlowState::Expired; 269 - } 270 - match (&data.did, &data.code) { 271 - (Some(did), Some(code)) => AuthFlowState::Authorized { 272 - did: did.clone(), 273 - device_id: data.device_id.clone(), 274 - code: code.clone(), 275 - }, 276 - (Some(did), None) => AuthFlowState::Authenticated { 277 - did: did.clone(), 278 - device_id: data.device_id.clone(), 279 - }, 280 - (None, _) => AuthFlowState::Pending, 281 - } 282 - } 330 + #[derive(Debug, Clone)] 331 + pub struct FlowAuthenticated { 332 + pub parameters: AuthorizationRequestParameters, 333 + pub client_id: String, 334 + pub client_auth: Option<ClientAuth>, 335 + pub expires_at: DateTime<Utc>, 336 + pub did: Did, 337 + pub device_id: Option<DeviceId>, 338 + pub controller_did: Option<Did>, 339 + } 283 340 284 - pub fn is_pending(&self) -> bool { 285 - matches!(self, AuthFlowState::Pending) 286 - } 341 + #[derive(Debug, Clone)] 342 + pub struct FlowAuthorized { 343 + pub parameters: AuthorizationRequestParameters, 344 + pub client_id: String, 345 + pub client_auth: Option<ClientAuth>, 346 + pub expires_at: DateTime<Utc>, 347 + pub did: Did, 348 + pub device_id: Option<DeviceId>, 349 + pub code: Code, 350 + pub controller_did: Option<Did>, 351 + } 287 352 288 - pub fn is_authenticated(&self) -> bool { 289 - matches!(self, AuthFlowState::Authenticated { .. }) 290 - } 353 + #[derive(Debug)] 354 + pub struct FlowExpired; 355 + 356 + #[derive(Debug)] 357 + pub struct FlowNotAuthenticated; 358 + 359 + #[derive(Debug)] 360 + pub struct FlowNotAuthorized; 291 361 292 - pub fn is_authorized(&self) -> bool { 293 - matches!(self, AuthFlowState::Authorized { .. }) 362 + #[derive(Debug, Clone)] 363 + pub enum AuthFlow { 364 + Pending(FlowPending), 365 + Authenticated(FlowAuthenticated), 366 + Authorized(FlowAuthorized), 367 + } 368 + 369 + #[derive(Debug, Clone)] 370 + pub enum AuthFlowWithUser { 371 + Authenticated(FlowAuthenticated), 372 + Authorized(FlowAuthorized), 373 + } 374 + 375 + impl AuthFlow { 376 + pub fn from_request_data(data: RequestData) -> Result<Self, FlowExpired> { 377 + if data.expires_at < chrono::Utc::now() { 378 + return Err(FlowExpired); 379 + } 380 + match (data.did, data.code) { 381 + (None, _) => Ok(AuthFlow::Pending(FlowPending { 382 + parameters: data.parameters, 383 + client_id: data.client_id, 384 + client_auth: data.client_auth, 385 + expires_at: data.expires_at, 386 + controller_did: data.controller_did, 387 + })), 388 + (Some(did), None) => Ok(AuthFlow::Authenticated(FlowAuthenticated { 389 + parameters: data.parameters, 390 + client_id: data.client_id, 391 + client_auth: data.client_auth, 392 + expires_at: data.expires_at, 393 + did, 394 + device_id: data.device_id, 395 + controller_did: data.controller_did, 396 + })), 397 + (Some(did), Some(code)) => Ok(AuthFlow::Authorized(FlowAuthorized { 398 + parameters: data.parameters, 399 + client_id: data.client_id, 400 + client_auth: data.client_auth, 401 + expires_at: data.expires_at, 402 + did, 403 + device_id: data.device_id, 404 + code, 405 + controller_did: data.controller_did, 406 + })), 407 + } 294 408 } 295 409 296 - pub fn is_expired(&self) -> bool { 297 - matches!(self, AuthFlowState::Expired) 410 + pub fn require_user(self) -> Result<AuthFlowWithUser, FlowNotAuthenticated> { 411 + match self { 412 + AuthFlow::Pending(_) => Err(FlowNotAuthenticated), 413 + AuthFlow::Authenticated(a) => Ok(AuthFlowWithUser::Authenticated(a)), 414 + AuthFlow::Authorized(a) => Ok(AuthFlowWithUser::Authorized(a)), 415 + } 298 416 } 299 417 300 - pub fn can_authenticate(&self) -> bool { 301 - matches!(self, AuthFlowState::Pending) 418 + pub fn require_authorized(self) -> Result<FlowAuthorized, FlowNotAuthorized> { 419 + match self { 420 + AuthFlow::Authorized(a) => Ok(a), 421 + _ => Err(FlowNotAuthorized), 422 + } 302 423 } 424 + } 303 425 304 - pub fn can_authorize(&self) -> bool { 305 - matches!(self, AuthFlowState::Authenticated { .. }) 426 + impl AuthFlowWithUser { 427 + pub fn did(&self) -> &Did { 428 + match self { 429 + AuthFlowWithUser::Authenticated(a) => &a.did, 430 + AuthFlowWithUser::Authorized(a) => &a.did, 431 + } 306 432 } 307 433 308 - pub fn can_exchange(&self) -> bool { 309 - matches!(self, AuthFlowState::Authorized { .. }) 434 + pub fn device_id(&self) -> Option<&DeviceId> { 435 + match self { 436 + AuthFlowWithUser::Authenticated(a) => a.device_id.as_ref(), 437 + AuthFlowWithUser::Authorized(a) => a.device_id.as_ref(), 438 + } 310 439 } 311 440 312 - pub fn did(&self) -> Option<&str> { 441 + pub fn parameters(&self) -> &AuthorizationRequestParameters { 313 442 match self { 314 - AuthFlowState::Authenticated { did, .. } | AuthFlowState::Authorized { did, .. } => { 315 - Some(did) 316 - } 317 - _ => None, 443 + AuthFlowWithUser::Authenticated(a) => &a.parameters, 444 + AuthFlowWithUser::Authorized(a) => &a.parameters, 318 445 } 319 446 } 320 447 321 - pub fn code(&self) -> Option<&str> { 448 + pub fn client_id(&self) -> &str { 322 449 match self { 323 - AuthFlowState::Authorized { code, .. } => Some(code), 324 - _ => None, 450 + AuthFlowWithUser::Authenticated(a) => &a.client_id, 451 + AuthFlowWithUser::Authorized(a) => &a.client_id, 325 452 } 326 453 } 327 - } 328 454 329 - impl std::fmt::Display for AuthFlowState { 330 - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 455 + pub fn controller_did(&self) -> Option<&Did> { 331 456 match self { 332 - AuthFlowState::Pending => write!(f, "pending"), 333 - AuthFlowState::Authenticated { did, .. } => write!(f, "authenticated ({})", did), 334 - AuthFlowState::Authorized { did, code, .. } => { 335 - write!( 336 - f, 337 - "authorized ({}, code={}...)", 338 - did, 339 - &code[..8.min(code.len())] 340 - ) 341 - } 342 - AuthFlowState::Expired => write!(f, "expired"), 457 + AuthFlowWithUser::Authenticated(a) => a.controller_did.as_ref(), 458 + AuthFlowWithUser::Authorized(a) => a.controller_did.as_ref(), 343 459 } 344 460 } 345 461 } 346 462 463 + 347 464 #[derive(Debug, Clone, PartialEq, Eq)] 348 465 pub enum RefreshTokenState { 349 466 Valid, ··· 406 523 use chrono::{Duration, Utc}; 407 524 408 525 fn make_request_data( 409 - did: Option<String>, 410 - code: Option<String>, 526 + did: Option<Did>, 527 + code: Option<Code>, 411 528 expires_in: Duration, 412 529 ) -> RequestData { 413 530 RequestData { 414 531 client_id: "test-client".into(), 415 532 client_auth: None, 416 533 parameters: AuthorizationRequestParameters { 417 - response_type: "code".into(), 534 + response_type: ResponseType::Code, 418 535 client_id: "test-client".into(), 419 536 redirect_uri: "https://example.com/callback".into(), 420 537 scope: Some("atproto".into()), 421 538 state: None, 422 539 code_challenge: "test".into(), 423 - code_challenge_method: "S256".into(), 540 + code_challenge_method: CodeChallengeMethod::S256, 424 541 response_mode: None, 425 542 login_hint: None, 426 543 dpop_jkt: None, ··· 435 552 } 436 553 } 437 554 555 + fn test_did(s: &str) -> Did { 556 + s.parse().expect("valid test DID") 557 + } 558 + 559 + fn test_code(s: &str) -> Code { 560 + Code(s.to_string()) 561 + } 562 + 438 563 #[test] 439 - fn test_auth_flow_state_pending() { 564 + fn test_auth_flow_pending() { 440 565 let data = make_request_data(None, None, Duration::minutes(5)); 441 - let state = AuthFlowState::from_request_data(&data); 442 - assert!(state.is_pending()); 443 - assert!(!state.is_authenticated()); 444 - assert!(!state.is_authorized()); 445 - assert!(!state.is_expired()); 446 - assert!(state.can_authenticate()); 447 - assert!(!state.can_authorize()); 448 - assert!(!state.can_exchange()); 449 - assert!(state.did().is_none()); 450 - assert!(state.code().is_none()); 566 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 567 + assert!(matches!(flow, AuthFlow::Pending(_))); 568 + assert!(flow.clone().require_user().is_err()); 569 + assert!(flow.require_authorized().is_err()); 451 570 } 452 571 453 572 #[test] 454 - fn test_auth_flow_state_authenticated() { 455 - let data = make_request_data(Some("did:plc:test".into()), None, Duration::minutes(5)); 456 - let state = AuthFlowState::from_request_data(&data); 457 - assert!(!state.is_pending()); 458 - assert!(state.is_authenticated()); 459 - assert!(!state.is_authorized()); 460 - assert!(!state.is_expired()); 461 - assert!(!state.can_authenticate()); 462 - assert!(state.can_authorize()); 463 - assert!(!state.can_exchange()); 464 - assert_eq!(state.did(), Some("did:plc:test")); 465 - assert!(state.code().is_none()); 573 + fn test_auth_flow_authenticated() { 574 + let did = test_did("did:plc:test"); 575 + let data = make_request_data(Some(did.clone()), None, Duration::minutes(5)); 576 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 577 + assert!(matches!(flow, AuthFlow::Authenticated(_))); 578 + let with_user = flow.clone().require_user().expect("should have user"); 579 + assert_eq!(with_user.did(), &did); 580 + assert!(flow.require_authorized().is_err()); 466 581 } 467 582 468 583 #[test] 469 - fn test_auth_flow_state_authorized() { 584 + fn test_auth_flow_authorized() { 585 + let did = test_did("did:plc:test"); 586 + let code = test_code("auth-code-123"); 470 587 let data = make_request_data( 471 - Some("did:plc:test".into()), 472 - Some("auth-code-123".into()), 588 + Some(did.clone()), 589 + Some(code.clone()), 473 590 Duration::minutes(5), 474 591 ); 475 - let state = AuthFlowState::from_request_data(&data); 476 - assert!(!state.is_pending()); 477 - assert!(!state.is_authenticated()); 478 - assert!(state.is_authorized()); 479 - assert!(!state.is_expired()); 480 - assert!(!state.can_authenticate()); 481 - assert!(!state.can_authorize()); 482 - assert!(state.can_exchange()); 483 - assert_eq!(state.did(), Some("did:plc:test")); 484 - assert_eq!(state.code(), Some("auth-code-123")); 592 + let flow = AuthFlow::from_request_data(data).expect("should not be expired"); 593 + assert!(matches!(flow, AuthFlow::Authorized(_))); 594 + let with_user = flow.clone().require_user().expect("should have user"); 595 + assert_eq!(with_user.did(), &did); 596 + let authorized = flow.require_authorized().expect("should be authorized"); 597 + assert_eq!(authorized.did, did); 598 + assert_eq!(authorized.code, code); 485 599 } 486 600 487 601 #[test] 488 - fn test_auth_flow_state_expired() { 489 - let data = make_request_data( 490 - Some("did:plc:test".into()), 491 - Some("code".into()), 492 - Duration::minutes(-1), 493 - ); 494 - let state = AuthFlowState::from_request_data(&data); 495 - assert!(state.is_expired()); 496 - assert!(!state.can_authenticate()); 497 - assert!(!state.can_authorize()); 498 - assert!(!state.can_exchange()); 602 + fn test_auth_flow_expired() { 603 + let did = test_did("did:plc:test"); 604 + let code = test_code("code"); 605 + let data = make_request_data(Some(did), Some(code), Duration::minutes(-1)); 606 + let result = AuthFlow::from_request_data(data); 607 + assert!(result.is_err()); 499 608 } 500 609 501 610 #[test]
+4 -10
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use crate::types::Did; ··· 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use serde::Deserialize; 12 - use tracing::{error, warn}; 12 + use tracing::warn; 13 13 14 14 #[derive(Deserialize)] 15 15 pub struct DeleteAccountInput { ··· 26 26 .user_repo 27 27 .get_id_and_handle_by_did(did) 28 28 .await 29 - .map_err(|e| { 30 - error!("DB error in delete_account: {:?}", e); 31 - ApiError::InternalError(None) 32 - })? 29 + .log_db_err("in delete_account")? 33 30 .ok_or(ApiError::AccountNotFound) 34 31 .map(|row| (row.id, row.handle))?; 35 32 ··· 37 34 .user_repo 38 35 .admin_delete_account_complete(user_id, did) 39 36 .await 40 - .map_err(|e| { 41 - error!("Failed to delete account {}: {:?}", did, e); 42 - ApiError::InternalError(Some("Failed to delete account".into())) 43 - })?; 37 + .log_db_err("deleting account")?; 44 38 45 39 if let Err(e) = 46 40 crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await
+5 -7
crates/tranquil-pds/src/api/admin/account/email.rs
··· 1 - use crate::api::error::{ApiError, AtpJson}; 1 + use crate::api::error::{ApiError, AtpJson, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 + use crate::util::pds_hostname; 4 5 use crate::types::Did; 5 6 use axum::{ 6 7 Json, ··· 9 10 response::{IntoResponse, Response}, 10 11 }; 11 12 use serde::{Deserialize, Serialize}; 12 - use tracing::{error, warn}; 13 + use tracing::warn; 13 14 14 15 #[derive(Deserialize)] 15 16 #[serde(rename_all = "camelCase")] ··· 39 40 .user_repo 40 41 .get_by_did(&input.recipient_did) 41 42 .await 42 - .map_err(|e| { 43 - error!("DB error in send_email: {:?}", e); 44 - ApiError::InternalError(None) 45 - })? 43 + .log_db_err("in send_email")? 46 44 .ok_or(ApiError::AccountNotFound)?; 47 45 48 46 let email = user.email.ok_or(ApiError::NoEmail)?; 49 47 let (user_id, handle) = (user.id, user.handle); 50 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 48 + let hostname = pds_hostname(); 51 49 let subject = input 52 50 .subject 53 51 .clone()
+3 -10
crates/tranquil-pds/src/api/admin/account/info.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; ··· 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 12 use std::collections::HashMap; 13 - use tracing::error; 14 13 15 14 #[derive(Deserialize)] 16 15 pub struct GetAccountInfoParams { ··· 74 73 .infra_repo 75 74 .get_admin_account_info_by_did(&params.did) 76 75 .await 77 - .map_err(|e| { 78 - error!("DB error in get_account_info: {:?}", e); 79 - ApiError::InternalError(None) 80 - })? 76 + .log_db_err("in get_account_info")? 81 77 .ok_or(ApiError::AccountNotFound)?; 82 78 83 79 let invited_by = get_invited_by(&state, account.id).await; ··· 214 210 .infra_repo 215 211 .get_admin_account_infos_by_dids(&dids_typed) 216 212 .await 217 - .map_err(|e| { 218 - error!("Failed to fetch account infos: {:?}", e); 219 - ApiError::InternalError(None) 220 - })?; 213 + .log_db_err("fetching account infos")?; 221 214 222 215 let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect(); 223 216
+2 -6
crates/tranquil-pds/src/api/admin/account/search.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use crate::types::{Did, Handle}; ··· 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 - use tracing::error; 13 12 14 13 #[derive(Deserialize)] 15 14 pub struct SearchAccountsParams { ··· 66 65 limit + 1, 67 66 ) 68 67 .await 69 - .map_err(|e| { 70 - error!("DB error in search_accounts: {:?}", e); 71 - ApiError::InternalError(None) 72 - })?; 68 + .log_db_err("in search_accounts")?; 73 69 74 70 let has_more = rows.len() > limit as usize; 75 71 let accounts: Vec<AccountView> = rows
+2 -2
crates/tranquil-pds/src/api/admin/account/update.rs
··· 2 2 use crate::api::error::ApiError; 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 + use crate::util::pds_hostname_without_port; 5 6 use crate::types::{Did, Handle, PlainPassword}; 6 7 use axum::{ 7 8 Json, ··· 69 70 { 70 71 return Err(ApiError::InvalidHandle(None)); 71 72 } 72 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 73 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 73 + let hostname_for_handles = pds_hostname_without_port(); 74 74 let handle = if !input_handle.contains('.') { 75 75 format!("{}.{}", input_handle, hostname_for_handles) 76 76 } else {
+13 -49
crates/tranquil-pds/src/api/admin/config.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::auth::{Admin, Auth}; 3 3 use crate::state::AppState; 4 4 use axum::{Json, extract::State}; ··· 56 56 .infra_repo 57 57 .get_server_configs(keys) 58 58 .await 59 - .map_err(|e| { 60 - error!("DB error fetching server config: {:?}", e); 61 - ApiError::InternalError(None) 62 - })?; 59 + .log_db_err("fetching server config")?; 63 60 64 61 let config_map: std::collections::HashMap<String, String> = rows.into_iter().collect(); 65 62 ··· 92 89 .infra_repo 93 90 .upsert_server_config("server_name", trimmed) 94 91 .await 95 - .map_err(|e| { 96 - error!("DB error upserting server_name: {:?}", e); 97 - ApiError::InternalError(None) 98 - })?; 92 + .log_db_err("upserting server_name")?; 99 93 } 100 94 101 95 if let Some(ref color) = req.primary_color { ··· 104 98 .infra_repo 105 99 .delete_server_config("primary_color") 106 100 .await 107 - .map_err(|e| { 108 - error!("DB error deleting primary_color: {:?}", e); 109 - ApiError::InternalError(None) 110 - })?; 101 + .log_db_err("deleting primary_color")?; 111 102 } else if is_valid_hex_color(color) { 112 103 state 113 104 .infra_repo 114 105 .upsert_server_config("primary_color", color) 115 106 .await 116 - .map_err(|e| { 117 - error!("DB error upserting primary_color: {:?}", e); 118 - ApiError::InternalError(None) 119 - })?; 107 + .log_db_err("upserting primary_color")?; 120 108 } else { 121 109 return Err(ApiError::InvalidRequest( 122 110 "Invalid primary color format (expected #RRGGBB)".into(), ··· 130 118 .infra_repo 131 119 .delete_server_config("primary_color_dark") 132 120 .await 133 - .map_err(|e| { 134 - error!("DB error deleting primary_color_dark: {:?}", e); 135 - ApiError::InternalError(None) 136 - })?; 121 + .log_db_err("deleting primary_color_dark")?; 137 122 } else if is_valid_hex_color(color) { 138 123 state 139 124 .infra_repo 140 125 .upsert_server_config("primary_color_dark", color) 141 126 .await 142 - .map_err(|e| { 143 - error!("DB error upserting primary_color_dark: {:?}", e); 144 - ApiError::InternalError(None) 145 - })?; 127 + .log_db_err("upserting primary_color_dark")?; 146 128 } else { 147 129 return Err(ApiError::InvalidRequest( 148 130 "Invalid primary dark color format (expected #RRGGBB)".into(), ··· 156 138 .infra_repo 157 139 .delete_server_config("secondary_color") 158 140 .await 159 - .map_err(|e| { 160 - error!("DB error deleting secondary_color: {:?}", e); 161 - ApiError::InternalError(None) 162 - })?; 141 + .log_db_err("deleting secondary_color")?; 163 142 } else if is_valid_hex_color(color) { 164 143 state 165 144 .infra_repo 166 145 .upsert_server_config("secondary_color", color) 167 146 .await 168 - .map_err(|e| { 169 - error!("DB error upserting secondary_color: {:?}", e); 170 - ApiError::InternalError(None) 171 - })?; 147 + .log_db_err("upserting secondary_color")?; 172 148 } else { 173 149 return Err(ApiError::InvalidRequest( 174 150 "Invalid secondary color format (expected #RRGGBB)".into(), ··· 182 158 .infra_repo 183 159 .delete_server_config("secondary_color_dark") 184 160 .await 185 - .map_err(|e| { 186 - error!("DB error deleting secondary_color_dark: {:?}", e); 187 - ApiError::InternalError(None) 188 - })?; 161 + .log_db_err("deleting secondary_color_dark")?; 189 162 } else if is_valid_hex_color(color) { 190 163 state 191 164 .infra_repo 192 165 .upsert_server_config("secondary_color_dark", color) 193 166 .await 194 - .map_err(|e| { 195 - error!("DB error upserting secondary_color_dark: {:?}", e); 196 - ApiError::InternalError(None) 197 - })?; 167 + .log_db_err("upserting secondary_color_dark")?; 198 168 } else { 199 169 return Err(ApiError::InvalidRequest( 200 170 "Invalid secondary dark color format (expected #RRGGBB)".into(), ··· 235 205 .infra_repo 236 206 .delete_server_config("logo_cid") 237 207 .await 238 - .map_err(|e| { 239 - error!("DB error deleting logo_cid: {:?}", e); 240 - ApiError::InternalError(None) 241 - })?; 208 + .log_db_err("deleting logo_cid")?; 242 209 } else { 243 210 state 244 211 .infra_repo 245 212 .upsert_server_config("logo_cid", logo_cid) 246 213 .await 247 - .map_err(|e| { 248 - error!("DB error upserting logo_cid: {:?}", e); 249 - ApiError::InternalError(None) 250 - })?; 214 + .log_db_err("upserting logo_cid")?; 251 215 } 252 216 } 253 217
+2 -5
crates/tranquil-pds/src/api/admin/invite.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Admin, Auth}; 4 4 use crate::state::AppState; 5 5 use axum::{ ··· 91 91 .infra_repo 92 92 .list_invite_codes(params.cursor.as_deref(), limit, sort_order) 93 93 .await 94 - .map_err(|e| { 95 - error!("DB error fetching invite codes: {:?}", e); 96 - ApiError::InternalError(None) 97 - })?; 94 + .log_db_err("fetching invite codes")?; 98 95 99 96 let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect(); 100 97 let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect();
+2 -2
crates/tranquil-pds/src/api/age_assurance.rs
··· 33 33 } 34 34 35 35 async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option<String> { 36 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 36 + let auth_header = crate::util::get_header_str(headers, "Authorization"); 37 37 tracing::debug!(?auth_header, "age assurance: extracting token"); 38 38 39 39 let extracted = extract_auth_token_from_header(auth_header)?; 40 40 tracing::debug!("age assurance: got token, validating"); 41 41 42 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 42 + let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 43 43 let http_uri = "/"; 44 44 45 45 let auth_user = match validate_token_with_dpop(
+21 -70
crates/tranquil-pds/src/api/delegation.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::create_signed_commit; 3 3 use crate::auth::{Active, Auth}; 4 - use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes}; 5 - use crate::state::{AppState, RateLimitKind}; 4 + use crate::delegation::{ 5 + DelegationActionType, SCOPE_PRESETS, scopes, verify_can_add_controllers, 6 + verify_can_be_controller, verify_can_control_accounts, 7 + }; 8 + use crate::rate_limit::{AccountCreationLimit, RateLimited}; 9 + use crate::state::AppState; 6 10 use crate::types::{Did, Handle, Nsid, Rkey}; 7 - use crate::util::extract_client_ip; 11 + use crate::util::{pds_hostname, pds_hostname_without_port}; 8 12 use axum::{ 9 13 Json, 10 14 extract::{Query, State}, 11 - http::{HeaderMap, StatusCode}, 15 + http::StatusCode, 12 16 response::{IntoResponse, Response}, 13 17 }; 14 18 use jacquard_common::types::{integer::LimitedU32, string::Tid}; ··· 93 97 return Ok(ApiError::ControllerNotFound.into_response()); 94 98 } 95 99 96 - match state.delegation_repo.controls_any_accounts(&auth.did).await { 97 - Ok(true) => { 98 - return Ok(ApiError::InvalidDelegation( 99 - "Cannot add controllers to an account that controls other accounts".into(), 100 - ) 101 - .into_response()); 102 - } 103 - Err(e) => { 104 - tracing::error!("Failed to check delegation status: {:?}", e); 105 - return Ok( 106 - ApiError::InternalError(Some("Failed to verify delegation status".into())) 107 - .into_response(), 108 - ); 109 - } 110 - Ok(false) => {} 111 - } 100 + let _can_add = match verify_can_add_controllers(&state, &auth).await { 101 + Ok(proof) => proof, 102 + Err(response) => return Ok(response), 103 + }; 112 104 113 - match state 114 - .delegation_repo 115 - .has_any_controllers(&input.controller_did) 116 - .await 117 - { 118 - Ok(true) => { 119 - return Ok(ApiError::InvalidDelegation( 120 - "Cannot add a controlled account as a controller".into(), 121 - ) 122 - .into_response()); 123 - } 124 - Err(e) => { 125 - tracing::error!("Failed to check controller status: {:?}", e); 126 - return Ok( 127 - ApiError::InternalError(Some("Failed to verify controller status".into())) 128 - .into_response(), 129 - ); 130 - } 131 - Ok(false) => {} 105 + if let Err(response) = verify_can_be_controller(&state, &input.controller_did).await { 106 + return Ok(response); 132 107 } 133 108 134 109 match state ··· 456 431 457 432 pub async fn create_delegated_account( 458 433 State(state): State<AppState>, 459 - headers: HeaderMap, 434 + _rate_limit: RateLimited<AccountCreationLimit>, 460 435 auth: Auth<Active>, 461 436 Json(input): Json<CreateDelegatedAccountInput>, 462 437 ) -> Result<Response, ApiError> { 463 - let client_ip = extract_client_ip(&headers); 464 - if !state 465 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 466 - .await 467 - { 468 - warn!(ip = %client_ip, "Delegated account creation rate limit exceeded"); 469 - return Ok(ApiError::RateLimitExceeded(Some( 470 - "Too many account creation attempts. Please try again later.".into(), 471 - )) 472 - .into_response()); 473 - } 474 - 475 438 if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) { 476 439 return Ok(ApiError::InvalidScopes(e).into_response()); 477 440 } 478 441 479 - match state.delegation_repo.has_any_controllers(&auth.did).await { 480 - Ok(true) => { 481 - return Ok(ApiError::InvalidDelegation( 482 - "Cannot create delegated accounts from a controlled account".into(), 483 - ) 484 - .into_response()); 485 - } 486 - Err(e) => { 487 - tracing::error!("Failed to check controller status: {:?}", e); 488 - return Ok( 489 - ApiError::InternalError(Some("Failed to verify controller status".into())) 490 - .into_response(), 491 - ); 492 - } 493 - Ok(false) => {} 494 - } 442 + let _can_control = match verify_can_control_accounts(&state, &auth).await { 443 + Ok(proof) => proof, 444 + Err(response) => return Ok(response), 445 + }; 495 446 496 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 497 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 447 + let hostname = pds_hostname(); 448 + let hostname_for_handles = pds_hostname_without_port(); 498 449 let pds_suffix = format!(".{}", hostname_for_handles); 499 450 500 451 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) {
+19
crates/tranquil-pds/src/api/error.rs
··· 694 694 } 695 695 } 696 696 697 + impl From<crate::rate_limit::UserRateLimitError> for ApiError { 698 + fn from(e: crate::rate_limit::UserRateLimitError) -> Self { 699 + Self::RateLimitExceeded(e.message) 700 + } 701 + } 702 + 697 703 #[allow(clippy::result_large_err)] 698 704 pub fn parse_did(s: &str) -> Result<tranquil_types::Did, Response> { 699 705 s.parse() ··· 756 762 _ => "Invalid request body".to_string(), 757 763 } 758 764 } 765 + 766 + pub trait DbResultExt<T> { 767 + fn log_db_err(self, ctx: &str) -> Result<T, ApiError>; 768 + } 769 + 770 + impl<T, E: std::fmt::Debug> DbResultExt<T> for Result<T, E> { 771 + fn log_db_err(self, ctx: &str) -> Result<T, ApiError> { 772 + self.map_err(|e| { 773 + tracing::error!("DB error {}: {:?}", ctx, e); 774 + ApiError::DatabaseError 775 + }) 776 + } 777 + }
+16 -41
crates/tranquil-pds/src/api/identity/account.rs
··· 3 3 use crate::api::repo::record::utils::create_signed_commit; 4 4 use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token}; 5 5 use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key}; 6 - use crate::state::{AppState, RateLimitKind}; 6 + use crate::rate_limit::{AccountCreationLimit, RateLimited}; 7 + use crate::state::AppState; 7 8 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 9 + use crate::util::{pds_hostname, pds_hostname_without_port}; 8 10 use crate::validation::validate_password; 9 11 use axum::{ 10 12 Json, ··· 22 24 use std::sync::Arc; 23 25 use tracing::{debug, error, info, warn}; 24 26 25 - fn extract_client_ip(headers: &HeaderMap) -> String { 26 - if let Some(forwarded) = headers.get("x-forwarded-for") 27 - && let Ok(value) = forwarded.to_str() 28 - && let Some(first_ip) = value.split(',').next() 29 - { 30 - return first_ip.trim().to_string(); 31 - } 32 - if let Some(real_ip) = headers.get("x-real-ip") 33 - && let Ok(value) = real_ip.to_str() 34 - { 35 - return value.trim().to_string(); 36 - } 37 - "unknown".to_string() 38 - } 39 - 40 27 #[derive(Deserialize)] 41 28 #[serde(rename_all = "camelCase")] 42 29 pub struct CreateAccountInput { ··· 68 55 69 56 pub async fn create_account( 70 57 State(state): State<AppState>, 58 + _rate_limit: RateLimited<AccountCreationLimit>, 71 59 headers: HeaderMap, 72 60 Json(input): Json<CreateAccountInput>, 73 61 ) -> Response { ··· 84 72 } else { 85 73 info!("create_account called"); 86 74 } 87 - let client_ip = extract_client_ip(&headers); 88 - if !state 89 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 90 - .await 91 - { 92 - warn!(ip = %client_ip, "Account creation rate limit exceeded"); 93 - return ApiError::RateLimitExceeded(Some( 94 - "Too many account creation attempts. Please try again later.".into(), 95 - )) 96 - .into_response(); 97 - } 98 75 99 76 let migration_auth = if let Some(extracted) = 100 - extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok())) 77 + extract_auth_token_from_header(crate::util::get_header_str(&headers, "Authorization")) 101 78 { 102 79 let token = extracted.token; 103 80 if is_service_token(&token) { ··· 143 120 if (is_migration || is_did_web_byod) 144 121 && let (Some(provided_did), Some(auth_did)) = (input.did.as_ref(), migration_auth.as_ref()) 145 122 { 146 - if provided_did != auth_did { 123 + if provided_did != auth_did.as_str() { 147 124 info!( 148 125 "[MIGRATION] createAccount: Service token mismatch - token_did={} provided_did={}", 149 126 auth_did, provided_did ··· 164 141 } 165 142 } 166 143 167 - let hostname_for_validation = 168 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 144 + let hostname_for_validation = pds_hostname_without_port(); 169 145 let pds_suffix = format!(".{}", hostname_for_validation); 170 146 171 147 let validated_short_handle = if !input.handle.contains('.') ··· 242 218 _ => return ApiError::InvalidVerificationChannel.into_response(), 243 219 }) 244 220 }; 245 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 246 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 221 + let hostname = pds_hostname(); 222 + let hostname_for_handles = pds_hostname_without_port(); 247 223 let pds_endpoint = format!("https://{}", hostname); 248 224 let suffix = format!(".{}", hostname_for_handles); 249 225 let handle = if input.handle.ends_with(&suffix) { ··· 308 284 } 309 285 if !is_did_web_byod 310 286 && let Err(e) = 311 - verify_did_web(d, &hostname, &input.handle, input.signing_key.as_deref()).await 287 + verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()).await 312 288 { 313 289 return ApiError::InvalidDid(e).into_response(); 314 290 } ··· 324 300 if !is_did_web_byod 325 301 && let Err(e) = verify_did_web( 326 302 d, 327 - &hostname, 303 + hostname, 328 304 &input.handle, 329 305 input.signing_key.as_deref(), 330 306 ) ··· 478 454 error!("Error creating session: {:?}", e); 479 455 return ApiError::InternalError(None).into_response(); 480 456 } 481 - let hostname = 482 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 457 + let hostname = pds_hostname(); 483 458 let verification_required = if let Some(ref user_email) = email { 484 459 let token = 485 460 crate::auth::verification_token::generate_migration_token(&did, user_email); ··· 491 466 reactivated.user_id, 492 467 user_email, 493 468 &formatted_token, 494 - &hostname, 469 + hostname, 495 470 ) 496 471 .await 497 472 { ··· 756 731 warn!("Failed to create default profile for {}: {}", did, e); 757 732 } 758 733 } 759 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 734 + let hostname = pds_hostname(); 760 735 if !is_migration { 761 736 if let Some(ref recipient) = verification_recipient { 762 737 let verification_token = crate::auth::verification_token::generate_signup_token( ··· 772 747 verification_channel, 773 748 recipient, 774 749 &formatted_token, 775 - &hostname, 750 + hostname, 776 751 ) 777 752 .await 778 753 { ··· 791 766 user_id, 792 767 user_email, 793 768 &formatted_token, 794 - &hostname, 769 + hostname, 795 770 ) 796 771 .await 797 772 {
+23 -29
crates/tranquil-pds/src/api/identity/did.rs
··· 1 1 use crate::api::{ApiError, DidResponse, EmptyResponse}; 2 2 use crate::auth::{Auth, NotTakendown}; 3 3 use crate::plc::signing_key_to_did_key; 4 + use crate::rate_limit::{HandleUpdateDailyLimit, HandleUpdateLimit, check_user_rate_limit_with_message}; 4 5 use crate::state::AppState; 5 6 use crate::types::Handle; 7 + use crate::util::{get_header_str, pds_hostname, pds_hostname_without_port}; 6 8 use axum::{ 7 9 Json, 8 10 extract::{Path, Query, State}, ··· 101 103 } 102 104 103 105 pub async fn well_known_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 104 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 105 - let host_header = headers 106 - .get("host") 107 - .and_then(|h| h.to_str().ok()) 108 - .unwrap_or(&hostname); 106 + let hostname = pds_hostname(); 107 + let hostname_without_port = pds_hostname_without_port(); 108 + let host_header = get_header_str(&headers, "host").unwrap_or(hostname); 109 109 let host_without_port = host_header.split(':').next().unwrap_or(host_header); 110 - let hostname_without_port = hostname.split(':').next().unwrap_or(&hostname); 111 110 if host_without_port != hostname_without_port 112 111 && host_without_port.ends_with(&format!(".{}", hostname_without_port)) 113 112 { 114 113 let handle = host_without_port 115 114 .strip_suffix(&format!(".{}", hostname_without_port)) 116 115 .unwrap_or(host_without_port); 117 - return serve_subdomain_did_doc(&state, handle, &hostname).await; 116 + return serve_subdomain_did_doc(&state, handle, hostname).await; 118 117 } 119 118 let did = if hostname.contains(':') { 120 119 format!("did:web:{}", hostname.replace(':', "%3A")) ··· 257 256 } 258 257 259 258 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 260 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 261 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 259 + let hostname = pds_hostname(); 260 + let hostname_for_handles = pds_hostname_without_port(); 262 261 let current_handle = format!("{}.{}", handle, hostname_for_handles); 263 262 let current_handle_typed: Handle = match current_handle.parse() { 264 263 Ok(h) => h, ··· 531 530 ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into())) 532 531 })?; 533 532 534 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 533 + let hostname = pds_hostname(); 535 534 let pds_endpoint = format!("https://{}", hostname); 536 535 let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes) 537 536 .map_err(|_| ApiError::InternalError(None))?; ··· 585 584 return Ok(e); 586 585 } 587 586 let did = auth.did.clone(); 588 - if !state 589 - .check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did) 590 - .await 591 - { 592 - return Err(ApiError::RateLimitExceeded(Some( 593 - "Too many handle updates. Try again later.".into(), 594 - ))); 595 - } 596 - if !state 597 - .check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did) 598 - .await 599 - { 600 - return Err(ApiError::RateLimitExceeded(Some( 601 - "Daily handle update limit exceeded.".into(), 602 - ))); 603 - } 587 + let _rate_limit = check_user_rate_limit_with_message::<HandleUpdateLimit>( 588 + &state, 589 + &did, 590 + "Too many handle updates. Try again later.", 591 + ) 592 + .await?; 593 + let _daily_rate_limit = check_user_rate_limit_with_message::<HandleUpdateDailyLimit>( 594 + &state, 595 + &did, 596 + "Daily handle update limit exceeded.", 597 + ) 598 + .await?; 604 599 let user_row = state 605 600 .user_repo 606 601 .get_id_and_handle_by_did(&did) ··· 639 634 "Inappropriate language in handle".into(), 640 635 ))); 641 636 } 642 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 643 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 637 + let hostname_for_handles = pds_hostname_without_port(); 644 638 let suffix = format!(".{}", hostname_for_handles); 645 639 let is_service_domain = 646 640 crate::handle::is_service_domain_handle(&new_handle, hostname_for_handles); ··· 772 766 } 773 767 774 768 pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response { 775 - let host = match headers.get("host").and_then(|h| h.to_str().ok()) { 769 + let host = match crate::util::get_header_str(&headers, "host") { 776 770 Some(h) => h, 777 771 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 778 772 };
+7 -12
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Auth, Permissive}; 4 4 use crate::state::AppState; 5 + use crate::util::pds_hostname; 5 6 use axum::{ 6 7 extract::State, 7 8 response::{IntoResponse, Response}, 8 9 }; 9 10 use chrono::{Duration, Utc}; 10 - use tracing::{error, info, warn}; 11 + use tracing::{info, warn}; 11 12 12 13 fn generate_plc_token() -> String { 13 14 crate::util::generate_token_code() ··· 28 29 .user_repo 29 30 .get_id_by_did(&auth.did) 30 31 .await 31 - .map_err(|e| { 32 - error!("DB error: {:?}", e); 33 - ApiError::InternalError(None) 34 - })? 32 + .log_db_err("fetching user id")? 35 33 .ok_or(ApiError::AccountNotFound)?; 36 34 37 35 let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await; ··· 41 39 .infra_repo 42 40 .insert_plc_token(user_id, &plc_token, expires_at) 43 41 .await 44 - .map_err(|e| { 45 - error!("Failed to create PLC token: {:?}", e); 46 - ApiError::InternalError(None) 47 - })?; 42 + .log_db_err("creating PLC token")?; 48 43 49 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 44 + let hostname = pds_hostname(); 50 45 if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation( 51 46 state.user_repo.as_ref(), 52 47 state.infra_repo.as_ref(), 53 48 user_id, 54 49 &plc_token, 55 - &hostname, 50 + hostname, 56 51 ) 57 52 .await 58 53 {
+4 -12
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 1 + use crate::api::error::DbResultExt; 1 2 use crate::api::ApiError; 2 3 use crate::auth::{Auth, Permissive}; 3 4 use crate::circuit_breaker::with_circuit_breaker; ··· 64 65 .user_repo 65 66 .get_id_by_did(did) 66 67 .await 67 - .map_err(|e| { 68 - error!("DB error: {:?}", e); 69 - ApiError::InternalError(None) 70 - })? 68 + .log_db_err("fetching user id")? 71 69 .ok_or(ApiError::AccountNotFound)?; 72 70 73 71 let token_expiry = state 74 72 .infra_repo 75 73 .get_plc_token_expiry(user_id, token) 76 74 .await 77 - .map_err(|e| { 78 - error!("DB error: {:?}", e); 79 - ApiError::InternalError(None) 80 - })? 75 + .log_db_err("fetching PLC token expiry")? 81 76 .ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?; 82 77 83 78 if Utc::now() > token_expiry { ··· 88 83 .user_repo 89 84 .get_user_key_by_id(user_id) 90 85 .await 91 - .map_err(|e| { 92 - error!("DB error: {:?}", e); 93 - ApiError::InternalError(None) 94 - })? 86 + .log_db_err("fetching user key")? 95 87 .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 96 88 97 89 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+5 -9
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 1 + use crate::api::error::DbResultExt; 1 2 use crate::api::{ApiError, EmptyResponse}; 2 3 use crate::auth::{Auth, Permissive}; 3 4 use crate::circuit_breaker::with_circuit_breaker; 4 5 use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation}; 5 6 use crate::state::AppState; 7 + use crate::util::pds_hostname; 6 8 use axum::{ 7 9 Json, 8 10 extract::State, ··· 40 42 .map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?; 41 43 42 44 let op = &input.operation; 43 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 45 + let hostname = pds_hostname(); 44 46 let public_url = format!("https://{}", hostname); 45 47 let user = state 46 48 .user_repo 47 49 .get_id_and_handle_by_did(did) 48 50 .await 49 - .map_err(|e| { 50 - error!("DB error: {:?}", e); 51 - ApiError::InternalError(None) 52 - })? 51 + .log_db_err("fetching user")? 53 52 .ok_or(ApiError::AccountNotFound)?; 54 53 55 54 let key_row = state 56 55 .user_repo 57 56 .get_user_key_by_id(user.id) 58 57 .await 59 - .map_err(|e| { 60 - error!("DB error: {:?}", e); 61 - ApiError::InternalError(None) 62 - })? 58 + .log_db_err("fetching user key")? 63 59 .ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?; 64 60 65 61 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+3 -2
crates/tranquil-pds/src/api/notification_prefs.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::auth::{Active, Auth}; 3 3 use crate::state::AppState; 4 + use crate::util::pds_hostname; 4 5 use axum::{ 5 6 Json, 6 7 extract::State, ··· 145 146 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 146 147 147 148 if channel == "email" { 148 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 149 + let hostname = pds_hostname(); 149 150 let handle_str = handle.unwrap_or("user"); 150 151 crate::comms::comms_repo::enqueue_email_update( 151 152 state.infra_repo.as_ref(), ··· 153 154 identifier, 154 155 handle_str, 155 156 &formatted_token, 156 - &hostname, 157 + hostname, 157 158 ) 158 159 .await 159 160 .map_err(|e| format!("Failed to enqueue email notification: {}", e))?;
+4 -7
crates/tranquil-pds/src/api/proxy.rs
··· 3 3 use crate::api::error::ApiError; 4 4 use crate::api::proxy_client::proxy_client; 5 5 use crate::state::AppState; 6 + use crate::util::get_header_str; 6 7 use axum::{ 7 8 body::Bytes, 8 9 extract::{RawQuery, Request, State}, ··· 191 192 .into_response(); 192 193 } 193 194 194 - let Some(proxy_header) = headers 195 - .get("atproto-proxy") 196 - .and_then(|h| h.to_str().ok()) 197 - .map(String::from) 198 - else { 195 + let Some(proxy_header) = get_header_str(&headers, "atproto-proxy").map(String::from) else { 199 196 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 200 197 .into_response(); 201 198 }; ··· 217 214 218 215 let mut auth_header_val = headers.get("Authorization").cloned(); 219 216 if let Some(extracted) = crate::auth::extract_auth_token_from_header( 220 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 217 + crate::util::get_header_str(&headers, "Authorization"), 221 218 ) { 222 219 let token = extracted.token; 223 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 220 + let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 224 221 let http_uri = crate::util::build_full_url(&uri.to_string()); 225 222 226 223 match crate::auth::validate_token_with_dpop(
+13 -26
crates/tranquil-pds/src/api/repo/blob.rs
··· 1 - use crate::api::error::ApiError; 2 - use crate::auth::{Auth, AuthAny, NotTakendown, Permissive}; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 + use crate::auth::{Auth, AuthAny, NotTakendown, Permissive, VerifyScope}; 3 3 use crate::delegation::DelegationActionType; 4 4 use crate::state::AppState; 5 5 use crate::types::{CidLink, Did}; 6 - use crate::util::get_max_blob_size; 6 + use crate::util::{get_header_str, get_max_blob_size}; 7 7 use axum::body::Body; 8 8 use axum::{ 9 9 Json, ··· 56 56 if user.status.is_takendown() { 57 57 return Err(ApiError::AccountTakedown); 58 58 } 59 - let mime_type_for_check = headers 60 - .get("content-type") 61 - .and_then(|h| h.to_str().ok()) 62 - .unwrap_or("application/octet-stream"); 63 - if let Err(e) = crate::auth::scope_check::check_blob_scope( 64 - user.is_oauth(), 65 - user.scope.as_deref(), 66 - mime_type_for_check, 67 - ) { 68 - return Ok(e); 69 - } 59 + let mime_type_for_check = 60 + get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 61 + let _scope_proof = match user.verify_blob_upload(mime_type_for_check) { 62 + Ok(proof) => proof, 63 + Err(e) => return Ok(e.into_response()), 64 + }; 70 65 (user.did.clone(), user.controller_did.clone()) 71 66 } 72 67 }; ··· 80 75 return Err(ApiError::Forbidden); 81 76 } 82 77 83 - let client_mime_hint = headers 84 - .get("content-type") 85 - .and_then(|h| h.to_str().ok()) 86 - .unwrap_or("application/octet-stream"); 78 + let client_mime_hint = 79 + get_header_str(&headers, "content-type").unwrap_or("application/octet-stream"); 87 80 88 81 let user_id = state 89 82 .user_repo ··· 232 225 .user_repo 233 226 .get_by_did(did) 234 227 .await 235 - .map_err(|e| { 236 - error!("DB error fetching user: {:?}", e); 237 - ApiError::InternalError(None) 238 - })? 228 + .log_db_err("fetching user")? 239 229 .ok_or(ApiError::InternalError(None))?; 240 230 241 231 let limit = params.limit.unwrap_or(500).clamp(1, 1000); ··· 244 234 .blob_repo 245 235 .list_missing_blobs(user.id, cursor, limit + 1) 246 236 .await 247 - .map_err(|e| { 248 - error!("DB error fetching missing blobs: {:?}", e); 249 - ApiError::InternalError(None) 250 - })?; 237 + .log_db_err("fetching missing blobs")?; 251 238 252 239 let has_more = missing.len() > limit as usize; 253 240 let blobs: Vec<RecordBlob> = missing
+2 -5
crates/tranquil-pds/src/api/repo/import.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::api::repo::record::create_signed_commit; 4 4 use crate::auth::{Auth, NotTakendown}; 5 5 use crate::state::AppState; ··· 49 49 .user_repo 50 50 .get_by_did(did) 51 51 .await 52 - .map_err(|e| { 53 - error!("DB error fetching user: {:?}", e); 54 - ApiError::InternalError(None) 55 - })? 52 + .log_db_err("fetching user")? 56 53 .ok_or(ApiError::AccountNotFound)?; 57 54 if user.takedown_ref.is_some() { 58 55 return Err(ApiError::AccountTakedown);
+2 -2
crates/tranquil-pds/src/api/repo/meta.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 3 use crate::types::AtIdentifier; 4 + use crate::util::pds_hostname_without_port; 4 5 use axum::{ 5 6 Json, 6 7 extract::{Query, State}, ··· 18 19 State(state): State<AppState>, 19 20 Query(input): Query<DescribeRepoInput>, 20 21 ) -> Response { 21 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 22 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 22 + let hostname_for_handles = pds_hostname_without_port(); 23 23 let user_row = if input.repo.is_did() { 24 24 let did: crate::types::Did = match input.repo.as_str().parse() { 25 25 Ok(d) => d,
+16 -35
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record_with_status; 2 2 use crate::api::error::ApiError; 3 3 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 4 - use crate::auth::{Active, Auth}; 4 + use crate::auth::{Active, Auth, VerifyScope}; 5 5 use crate::delegation::DelegationActionType; 6 6 use crate::repo::tracking::TrackingBlockStore; 7 7 use crate::state::AppState; ··· 271 271 input.writes.len() 272 272 ); 273 273 let did = auth.did.clone(); 274 - let is_oauth = auth.is_oauth(); 275 - let scope = auth.scope.clone(); 276 274 let controller_did = auth.controller_did.clone(); 277 275 if input.repo.as_str() != did { 278 276 return Err(ApiError::InvalidRepo( ··· 310 308 ))); 311 309 } 312 310 313 - let has_custom_scope = scope 314 - .as_ref() 315 - .map(|s| s != "com.atproto.access") 316 - .unwrap_or(false); 317 - if is_oauth || has_custom_scope { 311 + { 318 312 use std::collections::HashSet; 319 313 let create_collections: HashSet<&Nsid> = input 320 314 .writes ··· 350 344 }) 351 345 .collect(); 352 346 353 - let scope_checks = create_collections 354 - .iter() 355 - .map(|c| (crate::oauth::RepoAction::Create, c)) 356 - .chain( 357 - update_collections 358 - .iter() 359 - .map(|c| (crate::oauth::RepoAction::Update, c)), 360 - ) 361 - .chain( 362 - delete_collections 363 - .iter() 364 - .map(|c| (crate::oauth::RepoAction::Delete, c)), 365 - ); 366 - 367 - if let Some(err) = scope_checks 368 - .filter_map(|(action, collection)| { 369 - crate::auth::scope_check::check_repo_scope( 370 - is_oauth, 371 - scope.as_deref(), 372 - action, 373 - collection, 374 - ) 375 - .err() 376 - }) 377 - .next() 378 - { 379 - return Ok(err); 347 + for collection in &create_collections { 348 + if let Err(e) = auth.verify_repo_create(collection) { 349 + return Ok(e.into_response()); 350 + } 351 + } 352 + for collection in &update_collections { 353 + if let Err(e) = auth.verify_repo_update(collection) { 354 + return Ok(e.into_response()); 355 + } 356 + } 357 + for collection in &delete_collections { 358 + if let Err(e) = auth.verify_repo_delete(collection) { 359 + return Ok(e.into_response()); 360 + } 380 361 } 381 362 } 382 363
+6 -10
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 - use crate::auth::{Active, Auth}; 4 + use crate::auth::{Active, Auth, VerifyScope}; 5 5 use crate::delegation::DelegationActionType; 6 6 use crate::repo::tracking::TrackingBlockStore; 7 7 use crate::state::AppState; ··· 43 43 auth: Auth<Active>, 44 44 Json(input): Json<DeleteRecordInput>, 45 45 ) -> Result<Response, crate::api::error::ApiError> { 46 + let _scope_proof = match auth.verify_repo_delete(&input.collection) { 47 + Ok(proof) => proof, 48 + Err(e) => return Ok(e.into_response()), 49 + }; 50 + 46 51 let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 47 52 Ok(res) => res, 48 53 Err(err_res) => return Ok(err_res), 49 54 }; 50 55 51 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 52 - repo_auth.is_oauth, 53 - repo_auth.scope.as_deref(), 54 - crate::oauth::RepoAction::Delete, 55 - &input.collection, 56 - ) { 57 - return Ok(e); 58 - } 59 - 60 56 let did = repo_auth.did; 61 57 let user_id = repo_auth.user_id; 62 58 let current_root_cid = repo_auth.current_root_cid;
+3 -4
crates/tranquil-pds/src/api/repo/record/read.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::state::AppState; 3 3 use crate::types::{AtIdentifier, Nsid, Rkey}; 4 + use crate::util::pds_hostname_without_port; 4 5 use axum::{ 5 6 Json, 6 7 extract::{Query, State}, ··· 58 59 _headers: HeaderMap, 59 60 Query(input): Query<GetRecordInput>, 60 61 ) -> Response { 61 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 62 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 62 + let hostname_for_handles = pds_hostname_without_port(); 63 63 let user_id_opt = if input.repo.is_did() { 64 64 let did: crate::types::Did = match input.repo.as_str().parse() { 65 65 Ok(d) => d, ··· 157 157 State(state): State<AppState>, 158 158 Query(input): Query<ListRecordsInput>, 159 159 ) -> Response { 160 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 161 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 160 + let hostname_for_handles = pds_hostname_without_port(); 162 161 let user_id_opt = if input.repo.is_did() { 163 162 let did: crate::types::Did = match input.repo.as_str().parse() { 164 163 Ok(d) => d,
+15 -27
crates/tranquil-pds/src/api/repo/record/write.rs
··· 3 3 use crate::api::repo::record::utils::{ 4 4 CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 5 5 }; 6 - use crate::auth::{Active, Auth}; 6 + use crate::auth::{Active, Auth, VerifyScope}; 7 7 use crate::delegation::DelegationActionType; 8 8 use crate::repo::tracking::TrackingBlockStore; 9 9 use crate::state::AppState; ··· 127 127 auth: Auth<Active>, 128 128 Json(input): Json<CreateRecordInput>, 129 129 ) -> Result<Response, crate::api::error::ApiError> { 130 + let _scope_proof = match auth.verify_repo_create(&input.collection) { 131 + Ok(proof) => proof, 132 + Err(e) => return Ok(e.into_response()), 133 + }; 134 + 130 135 let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 131 136 Ok(res) => res, 132 137 Err(err_res) => return Ok(err_res), 133 138 }; 134 139 135 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 136 - repo_auth.is_oauth, 137 - repo_auth.scope.as_deref(), 138 - crate::oauth::RepoAction::Create, 139 - &input.collection, 140 - ) { 141 - return Ok(e); 142 - } 143 - 144 140 let did = repo_auth.did; 145 141 let user_id = repo_auth.user_id; 146 142 let current_root_cid = repo_auth.current_root_cid; ··· 434 430 auth: Auth<Active>, 435 431 Json(input): Json<PutRecordInput>, 436 432 ) -> Result<Response, crate::api::error::ApiError> { 433 + let _create_proof = match auth.verify_repo_create(&input.collection) { 434 + Ok(proof) => proof, 435 + Err(e) => return Ok(e.into_response()), 436 + }; 437 + let _update_proof = match auth.verify_repo_update(&input.collection) { 438 + Ok(proof) => proof, 439 + Err(e) => return Ok(e.into_response()), 440 + }; 441 + 437 442 let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await { 438 443 Ok(res) => res, 439 444 Err(err_res) => return Ok(err_res), 440 445 }; 441 446 442 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 443 - repo_auth.is_oauth, 444 - repo_auth.scope.as_deref(), 445 - crate::oauth::RepoAction::Create, 446 - &input.collection, 447 - ) { 448 - return Ok(e); 449 - } 450 - if let Err(e) = crate::auth::scope_check::check_repo_scope( 451 - repo_auth.is_oauth, 452 - repo_auth.scope.as_deref(), 453 - crate::oauth::RepoAction::Update, 454 - &input.collection, 455 - ) { 456 - return Ok(e); 457 - } 458 - 459 447 let did = repo_auth.did; 460 448 let user_id = repo_auth.user_id; 461 449 let current_root_cid = repo_auth.current_root_cid;
+18 -26
crates/tranquil-pds/src/api/server/account_status.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::{Auth, NotTakendown, Permissive}; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 + use crate::auth::{Auth, NotTakendown, Permissive, require_legacy_session_mfa}; 4 4 use crate::cache::Cache; 5 5 use crate::plc::PlcClient; 6 6 use crate::state::AppState; 7 7 use crate::types::PlainPassword; 8 + use crate::util::pds_hostname; 8 9 use axum::{ 9 10 Json, 10 11 extract::State, ··· 130 131 did: &crate::types::Did, 131 132 with_retry: bool, 132 133 ) -> Result<(), ApiError> { 133 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 134 + let hostname = pds_hostname(); 134 135 let expected_endpoint = format!("https://{}", hostname); 135 136 136 137 if did.as_str().starts_with("did:plc:") { ··· 219 220 .and_then(|v| v.get("atproto")) 220 221 .and_then(|k| k.as_str()); 221 222 222 - let user_key = user_repo.get_user_key_by_did(did).await.map_err(|e| { 223 - error!("Failed to fetch user key: {:?}", e); 224 - ApiError::InternalError(None) 225 - })?; 223 + let user_key = user_repo 224 + .get_user_key_by_did(did) 225 + .await 226 + .log_db_err("fetching user key")?; 226 227 227 228 if let Some(key_info) = user_key { 228 229 let key_bytes = ··· 523 524 State(state): State<AppState>, 524 525 auth: Auth<NotTakendown>, 525 526 ) -> Result<Response, ApiError> { 526 - let did = &auth.did; 527 - 528 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await { 529 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 530 - &*state.user_repo, 531 - &*state.session_repo, 532 - did, 533 - ) 534 - .await); 535 - } 527 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 528 + Ok(proof) => proof, 529 + Err(response) => return Ok(response), 530 + }; 536 531 537 532 let user_id = state 538 533 .user_repo 539 - .get_id_by_did(did) 534 + .get_id_by_did(session_mfa.did()) 540 535 .await 541 536 .ok() 542 537 .flatten() ··· 545 540 let expires_at = Utc::now() + Duration::minutes(15); 546 541 state 547 542 .infra_repo 548 - .create_deletion_request(&confirmation_token, did, expires_at) 543 + .create_deletion_request(&confirmation_token, session_mfa.did(), expires_at) 549 544 .await 550 - .map_err(|e| { 551 - error!("DB error creating deletion token: {:?}", e); 552 - ApiError::InternalError(None) 553 - })?; 554 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 545 + .log_db_err("creating deletion token")?; 546 + let hostname = pds_hostname(); 555 547 if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion( 556 548 state.user_repo.as_ref(), 557 549 state.infra_repo.as_ref(), 558 550 user_id, 559 551 &confirmation_token, 560 - &hostname, 552 + hostname, 561 553 ) 562 554 .await 563 555 { 564 556 warn!("Failed to enqueue account deletion notification: {:?}", e); 565 557 } 566 - info!("Account deletion requested for user {}", did); 558 + info!("Account deletion requested for user {}", session_mfa.did()); 567 559 Ok(EmptyResponse::ok().into_response()) 568 560 } 569 561
+13 -46
crates/tranquil-pds/src/api/server/app_password.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use crate::auth::{Auth, NotTakendown, Permissive, generate_app_password}; 4 4 use crate::delegation::{DelegationActionType, intersect_scopes}; 5 - use crate::state::{AppState, RateLimitKind}; 5 + use crate::rate_limit::{AppPasswordLimit, RateLimited}; 6 + use crate::state::AppState; 6 7 use axum::{ 7 8 Json, 8 9 extract::State, 9 - http::HeaderMap, 10 10 response::{IntoResponse, Response}, 11 11 }; 12 12 use serde::{Deserialize, Serialize}; 13 13 use serde_json::json; 14 - use tracing::{error, warn}; 14 + use tracing::error; 15 15 use tranquil_db_traits::AppPasswordCreate; 16 16 17 17 #[derive(Serialize)] ··· 39 39 .user_repo 40 40 .get_by_did(&auth.did) 41 41 .await 42 - .map_err(|e| { 43 - error!("DB error getting user: {:?}", e); 44 - ApiError::InternalError(None) 45 - })? 42 + .log_db_err("getting user")? 46 43 .ok_or(ApiError::AccountNotFound)?; 47 44 48 45 let rows = state 49 46 .session_repo 50 47 .list_app_passwords(user.id) 51 48 .await 52 - .map_err(|e| { 53 - error!("DB error listing app passwords: {:?}", e); 54 - ApiError::InternalError(None) 55 - })?; 49 + .log_db_err("listing app passwords")?; 56 50 let passwords: Vec<AppPassword> = rows 57 51 .iter() 58 52 .map(|row| AppPassword { ··· 89 83 90 84 pub async fn create_app_password( 91 85 State(state): State<AppState>, 92 - headers: HeaderMap, 86 + _rate_limit: RateLimited<AppPasswordLimit>, 93 87 auth: Auth<NotTakendown>, 94 88 Json(input): Json<CreateAppPasswordInput>, 95 89 ) -> Result<Response, ApiError> { 96 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 97 - if !state 98 - .check_rate_limit(RateLimitKind::AppPassword, &client_ip) 99 - .await 100 - { 101 - warn!(ip = %client_ip, "App password creation rate limit exceeded"); 102 - return Err(ApiError::RateLimitExceeded(None)); 103 - } 104 - 105 90 let user = state 106 91 .user_repo 107 92 .get_by_did(&auth.did) 108 93 .await 109 - .map_err(|e| { 110 - error!("DB error getting user: {:?}", e); 111 - ApiError::InternalError(None) 112 - })? 94 + .log_db_err("getting user")? 113 95 .ok_or(ApiError::AccountNotFound)?; 114 96 115 97 let name = input.name.trim(); ··· 121 103 .session_repo 122 104 .get_app_password_by_name(user.id, name) 123 105 .await 124 - .map_err(|e| { 125 - error!("DB error checking app password: {:?}", e); 126 - ApiError::InternalError(None) 127 - })? 106 + .log_db_err("checking app password")? 128 107 .is_some() 129 108 { 130 109 return Err(ApiError::DuplicateAppPassword); ··· 187 166 .session_repo 188 167 .create_app_password(&create_data) 189 168 .await 190 - .map_err(|e| { 191 - error!("DB error creating app password: {:?}", e); 192 - ApiError::InternalError(None) 193 - })?; 169 + .log_db_err("creating app password")?; 194 170 195 171 if let Some(ref controller) = controller_did { 196 172 let _ = state ··· 234 210 .user_repo 235 211 .get_by_did(&auth.did) 236 212 .await 237 - .map_err(|e| { 238 - error!("DB error getting user: {:?}", e); 239 - ApiError::InternalError(None) 240 - })? 213 + .log_db_err("getting user")? 241 214 .ok_or(ApiError::AccountNotFound)?; 242 215 243 216 let name = input.name.trim(); ··· 255 228 .session_repo 256 229 .delete_sessions_by_app_password(&auth.did, name) 257 230 .await 258 - .map_err(|e| { 259 - error!("DB error revoking sessions for app password: {:?}", e); 260 - ApiError::InternalError(None) 261 - })?; 231 + .log_db_err("revoking sessions for app password")?; 262 232 263 233 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 264 234 let cache_key = format!("auth:session:{}:{}", &auth.did, jti); ··· 273 243 .session_repo 274 244 .delete_app_password(user.id, name) 275 245 .await 276 - .map_err(|e| { 277 - error!("DB error revoking app password: {:?}", e); 278 - ApiError::InternalError(None) 279 - })?; 246 + .log_db_err("revoking app password")?; 280 247 281 248 Ok(EmptyResponse::ok().into_response()) 282 249 }
+21 -92
crates/tranquil-pds/src/api/server/email.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse}; 3 3 use crate::auth::{Auth, NotTakendown}; 4 - use crate::state::{AppState, RateLimitKind}; 4 + use crate::rate_limit::{EmailUpdateLimit, RateLimited, VerificationCheckLimit}; 5 + use crate::state::AppState; 6 + use crate::util::pds_hostname; 5 7 use axum::{ 6 8 Json, 7 9 extract::State, ··· 44 46 45 47 pub async fn request_email_update( 46 48 State(state): State<AppState>, 47 - headers: axum::http::HeaderMap, 49 + _rate_limit: RateLimited<EmailUpdateLimit>, 48 50 auth: Auth<NotTakendown>, 49 51 input: Option<Json<RequestEmailUpdateInput>>, 50 52 ) -> Result<Response, ApiError> { 51 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 52 - if !state 53 - .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 54 - .await 55 - { 56 - warn!(ip = %client_ip, "Email update rate limit exceeded"); 57 - return Err(ApiError::RateLimitExceeded(None)); 58 - } 59 - 60 53 if let Err(e) = crate::auth::scope_check::check_account_scope( 61 54 auth.is_oauth(), 62 55 auth.scope.as_deref(), ··· 70 63 .user_repo 71 64 .get_email_info_by_did(&auth.did) 72 65 .await 73 - .map_err(|e| { 74 - error!("DB error: {:?}", e); 75 - ApiError::InternalError(None) 76 - })? 66 + .log_db_err("getting email info")? 77 67 .ok_or(ApiError::AccountNotFound)?; 78 68 79 69 let Some(current_email) = user.email else { ··· 111 101 } 112 102 } 113 103 114 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 104 + let hostname = pds_hostname(); 115 105 if let Err(e) = crate::comms::comms_repo::enqueue_email_update_token( 116 106 state.user_repo.as_ref(), 117 107 state.infra_repo.as_ref(), 118 108 user.id, 119 109 &code, 120 110 &formatted_code, 121 - &hostname, 111 + hostname, 122 112 ) 123 113 .await 124 114 { ··· 139 129 140 130 pub async fn confirm_email( 141 131 State(state): State<AppState>, 142 - headers: axum::http::HeaderMap, 132 + _rate_limit: RateLimited<EmailUpdateLimit>, 143 133 auth: Auth<NotTakendown>, 144 134 Json(input): Json<ConfirmEmailInput>, 145 135 ) -> Result<Response, ApiError> { 146 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 147 - if !state 148 - .check_rate_limit(RateLimitKind::EmailUpdate, &client_ip) 149 - .await 150 - { 151 - warn!(ip = %client_ip, "Confirm email rate limit exceeded"); 152 - return Err(ApiError::RateLimitExceeded(None)); 153 - } 154 - 155 136 if let Err(e) = crate::auth::scope_check::check_account_scope( 156 137 auth.is_oauth(), 157 138 auth.scope.as_deref(), ··· 166 147 .user_repo 167 148 .get_email_info_by_did(did) 168 149 .await 169 - .map_err(|e| { 170 - error!("DB error: {:?}", e); 171 - ApiError::InternalError(None) 172 - })? 150 + .log_db_err("getting email info")? 173 151 .ok_or(ApiError::AccountNotFound)?; 174 152 175 153 let Some(ref email) = user.email else { ··· 213 191 .user_repo 214 192 .set_email_verified(user.id, true) 215 193 .await 216 - .map_err(|e| { 217 - error!("DB error confirming email: {:?}", e); 218 - ApiError::InternalError(None) 219 - })?; 194 + .log_db_err("confirming email")?; 220 195 221 196 info!("Email confirmed for user {}", user.id); 222 197 Ok(EmptyResponse::ok().into_response()) ··· 250 225 .user_repo 251 226 .get_email_info_by_did(did) 252 227 .await 253 - .map_err(|e| { 254 - error!("DB error: {:?}", e); 255 - ApiError::InternalError(None) 256 - })? 228 + .log_db_err("getting email info")? 257 229 .ok_or(ApiError::AccountNotFound)?; 258 230 259 231 let user_id = user.id; ··· 325 297 .user_repo 326 298 .update_email(user_id, &new_email) 327 299 .await 328 - .map_err(|e| { 329 - error!("DB error updating email: {:?}", e); 330 - ApiError::InternalError(None) 331 - })?; 300 + .log_db_err("updating email")?; 332 301 333 302 let verification_token = 334 303 crate::auth::verification_token::generate_signup_token(did, "email", &new_email); 335 304 let formatted_token = 336 305 crate::auth::verification_token::format_token_for_display(&verification_token); 337 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 306 + let hostname = pds_hostname(); 338 307 if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 339 308 state.infra_repo.as_ref(), 340 309 user_id, 341 310 "email", 342 311 &new_email, 343 312 &formatted_token, 344 - &hostname, 313 + hostname, 345 314 ) 346 315 .await 347 316 { ··· 371 340 372 341 pub async fn check_email_verified( 373 342 State(state): State<AppState>, 374 - headers: axum::http::HeaderMap, 343 + _rate_limit: RateLimited<VerificationCheckLimit>, 375 344 Json(input): Json<CheckEmailVerifiedInput>, 376 345 ) -> Response { 377 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 378 - if !state 379 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 380 - .await 381 - { 382 - return ApiError::RateLimitExceeded(None).into_response(); 383 - } 384 - 385 346 match state 386 347 .user_repo 387 348 .check_email_verified_by_identifier(&input.identifier) ··· 403 364 404 365 pub async fn authorize_email_update( 405 366 State(state): State<AppState>, 406 - headers: axum::http::HeaderMap, 367 + _rate_limit: RateLimited<VerificationCheckLimit>, 407 368 axum::extract::Query(query): axum::extract::Query<AuthorizeEmailUpdateQuery>, 408 369 ) -> Response { 409 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 410 - if !state 411 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 412 - .await 413 - { 414 - return ApiError::RateLimitExceeded(None).into_response(); 415 - } 416 - 417 370 let verified = crate::auth::verification_token::verify_token_signature(&query.token); 418 371 419 372 let token_data = match verified { ··· 488 441 489 442 info!(did = %did, "Email update authorized via link click"); 490 443 491 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 444 + let hostname = pds_hostname(); 492 445 let redirect_url = format!( 493 446 "https://{}/app/verify?type=email-authorize-success", 494 447 hostname ··· 499 452 500 453 pub async fn check_email_update_status( 501 454 State(state): State<AppState>, 502 - headers: axum::http::HeaderMap, 455 + _rate_limit: RateLimited<VerificationCheckLimit>, 503 456 auth: Auth<NotTakendown>, 504 457 ) -> Result<Response, ApiError> { 505 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 506 - if !state 507 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 508 - .await 509 - { 510 - return Err(ApiError::RateLimitExceeded(None)); 511 - } 512 - 513 458 if let Err(e) = crate::auth::scope_check::check_account_scope( 514 459 auth.is_oauth(), 515 460 auth.scope.as_deref(), ··· 549 494 550 495 pub async fn check_email_in_use( 551 496 State(state): State<AppState>, 552 - headers: axum::http::HeaderMap, 497 + _rate_limit: RateLimited<VerificationCheckLimit>, 553 498 Json(input): Json<CheckEmailInUseInput>, 554 499 ) -> Response { 555 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 556 - if !state 557 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 558 - .await 559 - { 560 - return ApiError::RateLimitExceeded(None).into_response(); 561 - } 562 - 563 500 let email = input.email.trim().to_lowercase(); 564 501 if email.is_empty() { 565 502 return ApiError::InvalidRequest("email is required".into()).into_response(); ··· 587 524 588 525 pub async fn check_comms_channel_in_use( 589 526 State(state): State<AppState>, 590 - headers: axum::http::HeaderMap, 527 + _rate_limit: RateLimited<VerificationCheckLimit>, 591 528 Json(input): Json<CheckCommsChannelInUseInput>, 592 529 ) -> Response { 593 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 594 - if !state 595 - .check_rate_limit(RateLimitKind::VerificationCheck, &client_ip) 596 - .await 597 - { 598 - return ApiError::RateLimitExceeded(None).into_response(); 599 - } 600 - 601 530 let channel = match input.channel.to_lowercase().as_str() { 602 531 "email" => CommsChannel::Email, 603 532 "discord" => CommsChannel::Discord,
+5 -9
crates/tranquil-pds/src/api/server/invite.rs
··· 1 + use crate::api::error::DbResultExt; 1 2 use crate::api::ApiError; 2 3 use crate::auth::{Admin, Auth, NotTakendown}; 3 4 use crate::state::AppState; 4 5 use crate::types::Did; 6 + use crate::util::pds_hostname; 5 7 use axum::{ 6 8 Json, 7 9 extract::State, ··· 24 26 } 25 27 26 28 fn gen_invite_code() -> String { 27 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 29 + let hostname = pds_hostname(); 28 30 let hostname_prefix = hostname.replace('.', "-"); 29 31 format!("{}-{}", hostname_prefix, gen_random_token()) 30 32 } ··· 121 123 .user_repo 122 124 .get_any_admin_user_id() 123 125 .await 124 - .map_err(|e| { 125 - error!("DB error looking up admin user: {:?}", e); 126 - ApiError::InternalError(None) 127 - })? 126 + .log_db_err("looking up admin user")? 128 127 .ok_or_else(|| { 129 128 error!("No admin user found to create invite codes"); 130 129 ApiError::InternalError(None) ··· 202 201 .infra_repo 203 202 .get_invite_codes_for_account(&auth.did) 204 203 .await 205 - .map_err(|e| { 206 - error!("DB error fetching invite codes: {:?}", e); 207 - ApiError::InternalError(None) 208 - })?; 204 + .log_db_err("fetching invite codes")?; 209 205 210 206 let filtered_codes: Vec<_> = codes_info 211 207 .into_iter()
+3 -2
crates/tranquil-pds/src/api/server/meta.rs
··· 1 1 use crate::state::AppState; 2 + use crate::util::pds_hostname; 2 3 use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; 3 4 use serde_json::json; 4 5 ··· 30 31 } 31 32 32 33 pub async fn describe_server() -> impl IntoResponse { 33 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 34 + let pds_hostname = pds_hostname(); 34 35 let domains_str = 35 - std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.clone()); 36 + std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.to_string()); 36 37 let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect(); 37 38 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 38 39 .map(|v| v == "true" || v == "1")
+6 -13
crates/tranquil-pds/src/api/server/migration.rs
··· 1 + use crate::api::error::DbResultExt; 1 2 use crate::api::ApiError; 2 3 use crate::auth::{Active, Auth}; 3 4 use crate::state::AppState; 5 + use crate::util::pds_hostname; 4 6 use axum::{ 5 7 Json, 6 8 extract::State, ··· 49 51 .user_repo 50 52 .get_user_for_did_doc(&auth.did) 51 53 .await 52 - .map_err(|e| { 53 - tracing::error!("DB error getting user: {:?}", e); 54 - ApiError::InternalError(None) 55 - })? 54 + .log_db_err("getting user")? 56 55 .ok_or(ApiError::AccountNotFound)?; 57 56 58 57 if let Some(ref methods) = input.verification_methods { ··· 107 106 .user_repo 108 107 .upsert_did_web_overrides(user.id, verification_methods_json, also_known_as) 109 108 .await 110 - .map_err(|e| { 111 - tracing::error!("DB error upserting did_web_overrides: {:?}", e); 112 - ApiError::InternalError(None) 113 - })?; 109 + .log_db_err("upserting did_web_overrides")?; 114 110 115 111 if let Some(ref endpoint) = input.service_endpoint { 116 112 let endpoint_clean = endpoint.trim().trim_end_matches('/'); ··· 118 114 .user_repo 119 115 .update_migrated_to_pds(&auth.did, endpoint_clean) 120 116 .await 121 - .map_err(|e| { 122 - tracing::error!("DB error updating service endpoint: {:?}", e); 123 - ApiError::InternalError(None) 124 - })?; 117 + .log_db_err("updating service endpoint")?; 125 118 } 126 119 127 120 let did_doc = build_did_document(&state, &auth.did).await; ··· 154 147 } 155 148 156 149 async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value { 157 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 150 + let hostname = pds_hostname(); 158 151 159 152 let user = match state.user_repo.get_user_for_did_doc_build(did).await { 160 153 Ok(Some(row)) => row,
+17 -64
crates/tranquil-pds/src/api/server/passkey_account.rs
··· 19 19 20 20 use crate::api::repo::record::utils::create_signed_commit; 21 21 use crate::auth::{ServiceTokenVerifier, generate_app_password, is_service_token}; 22 - use crate::state::{AppState, RateLimitKind}; 22 + use crate::rate_limit::{AccountCreationLimit, PasswordResetLimit, RateLimited}; 23 + use crate::state::AppState; 23 24 use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey}; 25 + use crate::util::{pds_hostname, pds_hostname_without_port}; 24 26 use crate::validation::validate_password; 25 27 26 - fn extract_client_ip(headers: &HeaderMap) -> String { 27 - if let Some(forwarded) = headers.get("x-forwarded-for") 28 - && let Ok(value) = forwarded.to_str() 29 - && let Some(first_ip) = value.split(',').next() 30 - { 31 - return first_ip.trim().to_string(); 32 - } 33 - if let Some(real_ip) = headers.get("x-real-ip") 34 - && let Ok(value) = real_ip.to_str() 35 - { 36 - return value.trim().to_string(); 37 - } 38 - "unknown".to_string() 39 - } 40 - 41 28 fn generate_setup_token() -> String { 42 29 let mut rng = rand::thread_rng(); 43 30 (0..32) ··· 80 67 81 68 pub async fn create_passkey_account( 82 69 State(state): State<AppState>, 70 + _rate_limit: RateLimited<AccountCreationLimit>, 83 71 headers: HeaderMap, 84 72 Json(input): Json<CreatePasskeyAccountInput>, 85 73 ) -> Response { 86 - let client_ip = extract_client_ip(&headers); 87 - if !state 88 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 89 - .await 90 - { 91 - warn!(ip = %client_ip, "Account creation rate limit exceeded"); 92 - return ApiError::RateLimitExceeded(Some( 93 - "Too many account creation attempts. Please try again later.".into(), 94 - )) 95 - .into_response(); 96 - } 97 - 98 74 let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header( 99 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 75 + crate::util::get_header_str(&headers, "Authorization"), 100 76 ) { 101 77 let token = extracted.token; 102 78 if is_service_token(&token) { ··· 135 111 .map(|d| d.starts_with("did:web:")) 136 112 .unwrap_or(false); 137 113 138 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 139 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 114 + let hostname = pds_hostname(); 115 + let hostname_for_handles = pds_hostname_without_port(); 140 116 let pds_suffix = format!(".{}", hostname_for_handles); 141 117 142 118 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) { ··· 268 244 } 269 245 if is_byod_did_web { 270 246 if let Some(ref auth_did) = byod_auth 271 - && d != auth_did 247 + && d != auth_did.as_str() 272 248 { 273 249 return ApiError::AuthorizationError(format!( 274 250 "Service token issuer {} does not match DID {}", ··· 280 256 } else { 281 257 if let Err(e) = crate::api::identity::did::verify_did_web( 282 258 d, 283 - &hostname, 259 + hostname, 284 260 &input.handle, 285 261 input.signing_key.as_deref(), 286 262 ) ··· 296 272 if let Some(ref auth_did) = byod_auth { 297 273 if let Some(ref provided_did) = input.did { 298 274 if provided_did.starts_with("did:plc:") { 299 - if provided_did != auth_did { 275 + if provided_did != auth_did.as_str() { 300 276 return ApiError::AuthorizationError(format!( 301 277 "Service token issuer {} does not match DID {}", 302 278 auth_did, provided_did ··· 521 497 verification_channel, 522 498 &verification_recipient, 523 499 &formatted_token, 524 - &hostname, 500 + hostname, 525 501 ) 526 502 .await 527 503 { ··· 626 602 return ApiError::InvalidToken(None).into_response(); 627 603 } 628 604 629 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 630 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 631 - Ok(w) => w, 632 - Err(e) => { 633 - error!("Failed to create WebAuthn config: {:?}", e); 634 - return ApiError::InternalError(None).into_response(); 635 - } 636 - }; 605 + let webauthn = &state.webauthn_config; 637 606 638 607 let reg_state = match state 639 608 .user_repo ··· 768 737 return ApiError::InvalidToken(None).into_response(); 769 738 } 770 739 771 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 772 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 773 - Ok(w) => w, 774 - Err(e) => { 775 - error!("Failed to create WebAuthn config: {:?}", e); 776 - return ApiError::InternalError(None).into_response(); 777 - } 778 - }; 740 + let webauthn = &state.webauthn_config; 779 741 780 742 let existing_passkeys = state 781 743 .user_repo ··· 840 802 841 803 pub async fn request_passkey_recovery( 842 804 State(state): State<AppState>, 843 - headers: HeaderMap, 805 + _rate_limit: RateLimited<PasswordResetLimit>, 844 806 Json(input): Json<RequestPasskeyRecoveryInput>, 845 807 ) -> Response { 846 - let client_ip = extract_client_ip(&headers); 847 - if !state 848 - .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 849 - .await 850 - { 851 - return ApiError::RateLimitExceeded(None).into_response(); 852 - } 853 - 854 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 855 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 808 + let hostname_for_handles = pds_hostname_without_port(); 856 809 let identifier = input.email.trim().to_lowercase(); 857 810 let identifier = identifier.strip_prefix('@').unwrap_or(&identifier); 858 811 let normalized_handle = if identifier.contains('@') || identifier.contains('.') { ··· 890 843 return ApiError::InternalError(None).into_response(); 891 844 } 892 845 893 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 846 + let hostname = pds_hostname(); 894 847 let recovery_url = format!( 895 848 "https://{}/app/recover-passkey?did={}&token={}", 896 849 hostname, ··· 903 856 state.infra_repo.as_ref(), 904 857 user.id, 905 858 &recovery_url, 906 - &hostname, 859 + hostname, 907 860 ) 908 861 .await; 909 862
+20 -56
crates/tranquil-pds/src/api/server/passkeys.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::webauthn::WebAuthnConfig; 4 - use crate::auth::{Active, Auth}; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 + use crate::auth::{Active, Auth, require_legacy_session_mfa, require_reauth_window}; 5 4 use crate::state::AppState; 6 5 use axum::{ 7 6 Json, ··· 12 11 use tracing::{error, info, warn}; 13 12 use webauthn_rs::prelude::*; 14 13 15 - fn get_webauthn() -> Result<WebAuthnConfig, ApiError> { 16 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 17 - WebAuthnConfig::new(&hostname).map_err(|e| { 18 - error!("Failed to create WebAuthn config: {}", e); 19 - ApiError::InternalError(Some("WebAuthn configuration failed".into())) 20 - }) 21 - } 22 - 23 14 #[derive(Deserialize)] 24 15 #[serde(rename_all = "camelCase")] 25 16 pub struct StartRegistrationInput { ··· 37 28 auth: Auth<Active>, 38 29 Json(input): Json<StartRegistrationInput>, 39 30 ) -> Result<Response, ApiError> { 40 - let webauthn = get_webauthn()?; 31 + let webauthn = &state.webauthn_config; 41 32 42 33 let handle = state 43 34 .user_repo 44 35 .get_handle_by_did(&auth.did) 45 36 .await 46 - .map_err(|e| { 47 - error!("DB error fetching user: {:?}", e); 48 - ApiError::InternalError(None) 49 - })? 37 + .log_db_err("fetching user")? 50 38 .ok_or(ApiError::AccountNotFound)?; 51 39 52 40 let existing_passkeys = state 53 41 .user_repo 54 42 .get_passkeys_for_user(&auth.did) 55 43 .await 56 - .map_err(|e| { 57 - error!("DB error fetching existing passkeys: {:?}", e); 58 - ApiError::InternalError(None) 59 - })?; 44 + .log_db_err("fetching existing passkeys")?; 60 45 61 46 let exclude_credentials: Vec<CredentialID> = existing_passkeys 62 47 .iter() ··· 81 66 .user_repo 82 67 .save_webauthn_challenge(&auth.did, "registration", &state_json) 83 68 .await 84 - .map_err(|e| { 85 - error!("Failed to save registration state: {:?}", e); 86 - ApiError::InternalError(None) 87 - })?; 69 + .log_db_err("saving registration state")?; 88 70 89 71 let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({})); 90 72 ··· 112 94 auth: Auth<Active>, 113 95 Json(input): Json<FinishRegistrationInput>, 114 96 ) -> Result<Response, ApiError> { 115 - let webauthn = get_webauthn()?; 97 + let webauthn = &state.webauthn_config; 116 98 117 99 let reg_state_json = state 118 100 .user_repo 119 101 .load_webauthn_challenge(&auth.did, "registration") 120 102 .await 121 - .map_err(|e| { 122 - error!("DB error loading registration state: {:?}", e); 123 - ApiError::InternalError(None) 124 - })? 103 + .log_db_err("loading registration state")? 125 104 .ok_or(ApiError::NoRegistrationInProgress)?; 126 105 127 106 let reg_state: SecurityKeyRegistration = ··· 157 136 input.friendly_name.as_deref(), 158 137 ) 159 138 .await 160 - .map_err(|e| { 161 - error!("Failed to save passkey: {:?}", e); 162 - ApiError::InternalError(None) 163 - })?; 139 + .log_db_err("saving passkey")?; 164 140 165 141 if let Err(e) = state 166 142 .user_repo ··· 208 184 .user_repo 209 185 .get_passkeys_for_user(&auth.did) 210 186 .await 211 - .map_err(|e| { 212 - error!("DB error fetching passkeys: {:?}", e); 213 - ApiError::InternalError(None) 214 - })?; 187 + .log_db_err("fetching passkeys")?; 215 188 216 189 let passkey_infos: Vec<PasskeyInfo> = passkeys 217 190 .into_iter() ··· 241 214 auth: Auth<Active>, 242 215 Json(input): Json<DeletePasskeyInput>, 243 216 ) -> Result<Response, ApiError> { 244 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 245 - { 246 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 247 - &*state.user_repo, 248 - &*state.session_repo, 249 - &auth.did, 250 - ) 251 - .await); 252 - } 217 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 218 + Ok(proof) => proof, 219 + Err(response) => return Ok(response), 220 + }; 253 221 254 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 255 - return Ok(crate::api::server::reauth::reauth_required_response( 256 - &*state.user_repo, 257 - &*state.session_repo, 258 - &auth.did, 259 - ) 260 - .await); 261 - } 222 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 223 + Ok(proof) => proof, 224 + Err(response) => return Ok(response), 225 + }; 262 226 263 227 let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?; 264 228 265 - match state.user_repo.delete_passkey(id, &auth.did).await { 229 + match state.user_repo.delete_passkey(id, reauth_mfa.did()).await { 266 230 Ok(true) => { 267 - info!(did = %auth.did, passkey_id = %id, "Passkey deleted"); 231 + info!(did = %session_mfa.did(), passkey_id = %id, "Passkey deleted"); 268 232 Ok(EmptyResponse::ok().into_response()) 269 233 } 270 234 Ok(false) => Err(ApiError::PasskeyNotFound),
+62 -164
crates/tranquil-pds/src/api/server/password.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse}; 3 - use crate::auth::{Active, Auth}; 4 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::auth::{ 4 + Active, Auth, require_legacy_session_mfa, require_reauth_window, 5 + require_reauth_window_if_available, 6 + }; 7 + use crate::rate_limit::{PasswordResetLimit, RateLimited, ResetPasswordLimit}; 8 + use crate::state::AppState; 5 9 use crate::types::PlainPassword; 10 + use crate::util::{pds_hostname, pds_hostname_without_port}; 6 11 use crate::validation::validate_password; 7 12 use axum::{ 8 13 Json, 9 14 extract::State, 10 - http::HeaderMap, 11 15 response::{IntoResponse, Response}, 12 16 }; 13 - use bcrypt::{DEFAULT_COST, hash, verify}; 17 + use bcrypt::{DEFAULT_COST, hash}; 14 18 use chrono::{Duration, Utc}; 15 19 use serde::Deserialize; 16 20 use tracing::{error, info, warn}; ··· 18 22 fn generate_reset_code() -> String { 19 23 crate::util::generate_token_code() 20 24 } 21 - fn extract_client_ip(headers: &HeaderMap) -> String { 22 - if let Some(forwarded) = headers.get("x-forwarded-for") 23 - && let Ok(value) = forwarded.to_str() 24 - && let Some(first_ip) = value.split(',').next() 25 - { 26 - return first_ip.trim().to_string(); 27 - } 28 - if let Some(real_ip) = headers.get("x-real-ip") 29 - && let Ok(value) = real_ip.to_str() 30 - { 31 - return value.trim().to_string(); 32 - } 33 - "unknown".to_string() 34 - } 35 25 36 26 #[derive(Deserialize)] 37 27 pub struct RequestPasswordResetInput { ··· 41 31 42 32 pub async fn request_password_reset( 43 33 State(state): State<AppState>, 44 - headers: HeaderMap, 34 + _rate_limit: RateLimited<PasswordResetLimit>, 45 35 Json(input): Json<RequestPasswordResetInput>, 46 36 ) -> Response { 47 - let client_ip = extract_client_ip(&headers); 48 - if !state 49 - .check_rate_limit(RateLimitKind::PasswordReset, &client_ip) 50 - .await 51 - { 52 - warn!(ip = %client_ip, "Password reset rate limit exceeded"); 53 - return ApiError::RateLimitExceeded(None).into_response(); 54 - } 55 37 let identifier = input.email.trim(); 56 38 if identifier.is_empty() { 57 39 return ApiError::InvalidRequest("email or handle is required".into()).into_response(); 58 40 } 59 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 60 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 41 + let hostname_for_handles = pds_hostname_without_port(); 61 42 let normalized = identifier.to_lowercase(); 62 43 let normalized = normalized.strip_prefix('@').unwrap_or(&normalized); 63 44 let is_email_lookup = normalized.contains('@'); ··· 101 82 error!("DB error setting reset code: {:?}", e); 102 83 return ApiError::InternalError(None).into_response(); 103 84 } 104 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 85 + let hostname = pds_hostname(); 105 86 if let Err(e) = crate::comms::comms_repo::enqueue_password_reset( 106 87 state.user_repo.as_ref(), 107 88 state.infra_repo.as_ref(), 108 89 user_id, 109 90 &code, 110 - &hostname, 91 + hostname, 111 92 ) 112 93 .await 113 94 { ··· 135 116 136 117 pub async fn reset_password( 137 118 State(state): State<AppState>, 138 - headers: HeaderMap, 119 + _rate_limit: RateLimited<ResetPasswordLimit>, 139 120 Json(input): Json<ResetPasswordInput>, 140 121 ) -> Response { 141 - let client_ip = extract_client_ip(&headers); 142 - if !state 143 - .check_rate_limit(RateLimitKind::ResetPassword, &client_ip) 144 - .await 145 - { 146 - warn!(ip = %client_ip, "Reset password rate limit exceeded"); 147 - return ApiError::RateLimitExceeded(None).into_response(); 148 - } 149 122 let token = input.token.trim(); 150 123 let password = &input.password; 151 124 if token.is_empty() { ··· 230 203 auth: Auth<Active>, 231 204 Json(input): Json<ChangePasswordInput>, 232 205 ) -> Result<Response, ApiError> { 233 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 234 - { 235 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 236 - &*state.user_repo, 237 - &*state.session_repo, 238 - &auth.did, 239 - ) 240 - .await); 241 - } 206 + use crate::auth::verify_password_mfa; 242 207 243 - let current_password = &input.current_password; 244 - let new_password = &input.new_password; 245 - if current_password.is_empty() { 208 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 209 + Ok(proof) => proof, 210 + Err(response) => return Ok(response), 211 + }; 212 + 213 + if input.current_password.is_empty() { 246 214 return Err(ApiError::InvalidRequest( 247 215 "currentPassword is required".into(), 248 216 )); 249 217 } 250 - if new_password.is_empty() { 218 + if input.new_password.is_empty() { 251 219 return Err(ApiError::InvalidRequest("newPassword is required".into())); 252 220 } 253 - if let Err(e) = validate_password(new_password) { 221 + if let Err(e) = validate_password(&input.new_password) { 254 222 return Err(ApiError::InvalidRequest(e.to_string())); 255 223 } 224 + 225 + let password_mfa = verify_password_mfa(&state, &auth, &input.current_password).await?; 226 + 256 227 let user = state 257 228 .user_repo 258 - .get_id_and_password_hash_by_did(&auth.did) 229 + .get_id_and_password_hash_by_did(password_mfa.did()) 259 230 .await 260 - .map_err(|e| { 261 - error!("DB error in change_password: {:?}", e); 262 - ApiError::InternalError(None) 263 - })? 231 + .log_db_err("in change_password")? 264 232 .ok_or(ApiError::AccountNotFound)?; 265 233 266 - let (user_id, password_hash) = (user.id, user.password_hash); 267 - let valid = verify(current_password, &password_hash).map_err(|e| { 268 - error!("Password verification error: {:?}", e); 269 - ApiError::InternalError(None) 270 - })?; 271 - if !valid { 272 - return Err(ApiError::InvalidPassword( 273 - "Current password is incorrect".into(), 274 - )); 275 - } 276 - let new_password_clone = new_password.to_string(); 234 + let new_password_clone = input.new_password.to_string(); 277 235 let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST)) 278 236 .await 279 237 .map_err(|e| { ··· 287 245 288 246 state 289 247 .user_repo 290 - .update_password_hash(user_id, &new_hash) 248 + .update_password_hash(user.id, &new_hash) 291 249 .await 292 - .map_err(|e| { 293 - error!("DB error updating password: {:?}", e); 294 - ApiError::InternalError(None) 295 - })?; 250 + .log_db_err("updating password")?; 296 251 297 - info!(did = %&auth.did, "Password changed successfully"); 252 + info!(did = %session_mfa.did(), "Password changed successfully"); 298 253 Ok(EmptyResponse::ok().into_response()) 299 254 } 300 255 ··· 302 257 State(state): State<AppState>, 303 258 auth: Auth<Active>, 304 259 ) -> Result<Response, ApiError> { 305 - match state.user_repo.has_password_by_did(&auth.did).await { 306 - Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()), 307 - Ok(None) => Err(ApiError::AccountNotFound), 308 - Err(e) => { 309 - error!("DB error: {:?}", e); 310 - Err(ApiError::InternalError(None)) 311 - } 312 - } 260 + let has = state 261 + .user_repo 262 + .has_password_by_did(&auth.did) 263 + .await 264 + .log_db_err("checking password status")? 265 + .ok_or(ApiError::AccountNotFound)?; 266 + Ok(HasPasswordResponse::response(has).into_response()) 313 267 } 314 268 315 269 pub async fn remove_password( 316 270 State(state): State<AppState>, 317 271 auth: Auth<Active>, 318 272 ) -> Result<Response, ApiError> { 319 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 320 - { 321 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 322 - &*state.user_repo, 323 - &*state.session_repo, 324 - &auth.did, 325 - ) 326 - .await); 327 - } 273 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 274 + Ok(proof) => proof, 275 + Err(response) => return Ok(response), 276 + }; 328 277 329 - if crate::api::server::reauth::check_reauth_required_cached( 330 - &*state.session_repo, 331 - &state.cache, 332 - &auth.did, 333 - ) 334 - .await 335 - { 336 - return Ok(crate::api::server::reauth::reauth_required_response( 337 - &*state.user_repo, 338 - &*state.session_repo, 339 - &auth.did, 340 - ) 341 - .await); 342 - } 278 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 279 + Ok(proof) => proof, 280 + Err(response) => return Ok(response), 281 + }; 343 282 344 283 let has_passkeys = state 345 284 .user_repo 346 - .has_passkeys(&auth.did) 285 + .has_passkeys(reauth_mfa.did()) 347 286 .await 348 287 .unwrap_or(false); 349 288 if !has_passkeys { ··· 354 293 355 294 let user = state 356 295 .user_repo 357 - .get_password_info_by_did(&auth.did) 296 + .get_password_info_by_did(reauth_mfa.did()) 358 297 .await 359 - .map_err(|e| { 360 - error!("DB error: {:?}", e); 361 - ApiError::InternalError(None) 362 - })? 298 + .log_db_err("getting password info")? 363 299 .ok_or(ApiError::AccountNotFound)?; 364 300 365 301 if user.password_hash.is_none() { ··· 372 308 .user_repo 373 309 .remove_user_password(user.id) 374 310 .await 375 - .map_err(|e| { 376 - error!("DB error removing password: {:?}", e); 377 - ApiError::InternalError(None) 378 - })?; 311 + .log_db_err("removing password")?; 379 312 380 - info!(did = %&auth.did, "Password removed - account is now passkey-only"); 313 + info!(did = %session_mfa.did(), "Password removed - account is now passkey-only"); 381 314 Ok(SuccessResponse::ok().into_response()) 382 315 } 383 316 ··· 392 325 auth: Auth<Active>, 393 326 Json(input): Json<SetPasswordInput>, 394 327 ) -> Result<Response, ApiError> { 395 - let has_password = state 396 - .user_repo 397 - .has_password_by_did(&auth.did) 398 - .await 399 - .ok() 400 - .flatten() 401 - .unwrap_or(false); 402 - let has_passkeys = state 403 - .user_repo 404 - .has_passkeys(&auth.did) 405 - .await 406 - .unwrap_or(false); 407 - let has_totp = state 408 - .user_repo 409 - .has_totp_enabled(&auth.did) 410 - .await 411 - .unwrap_or(false); 412 - 413 - let has_any_reauth_method = has_password || has_passkeys || has_totp; 414 - 415 - if has_any_reauth_method 416 - && crate::api::server::reauth::check_reauth_required_cached( 417 - &*state.session_repo, 418 - &state.cache, 419 - &auth.did, 420 - ) 421 - .await 422 - { 423 - return Ok(crate::api::server::reauth::reauth_required_response( 424 - &*state.user_repo, 425 - &*state.session_repo, 426 - &auth.did, 427 - ) 428 - .await); 429 - } 328 + let reauth_mfa = match require_reauth_window_if_available(&state, &auth).await { 329 + Ok(proof) => proof, 330 + Err(response) => return Ok(response), 331 + }; 430 332 431 333 let new_password = &input.new_password; 432 334 if new_password.is_empty() { ··· 436 338 return Err(ApiError::InvalidRequest(e.to_string())); 437 339 } 438 340 341 + let did = reauth_mfa.as_ref().map(|m| m.did()).unwrap_or(&auth.did); 342 + 439 343 let user = state 440 344 .user_repo 441 - .get_password_info_by_did(&auth.did) 345 + .get_password_info_by_did(did) 442 346 .await 443 - .map_err(|e| { 444 - error!("DB error: {:?}", e); 445 - ApiError::InternalError(None) 446 - })? 347 + .log_db_err("getting password info")? 447 348 .ok_or(ApiError::AccountNotFound)?; 448 349 449 350 if user.password_hash.is_some() { ··· 468 369 .user_repo 469 370 .set_new_user_password(user.id, &new_hash) 470 371 .await 471 - .map_err(|e| { 472 - error!("DB error setting password: {:?}", e); 473 - ApiError::InternalError(None) 474 - })?; 372 + .log_db_err("setting password")?; 475 373 476 - info!(did = %&auth.did, "Password set for passkey-only account"); 374 + info!(did = %did, "Password set for passkey-only account"); 477 375 Ok(SuccessResponse::ok().into_response()) 478 376 }
+21 -58
crates/tranquil-pds/src/api/server/reauth.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use axum::{ 3 3 Json, 4 4 extract::State, ··· 11 11 use tranquil_db_traits::{SessionRepository, UserRepository}; 12 12 13 13 use crate::auth::{Active, Auth}; 14 - use crate::state::{AppState, RateLimitKind}; 14 + use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message}; 15 + use crate::state::AppState; 15 16 use crate::types::PlainPassword; 16 17 17 - const REAUTH_WINDOW_SECONDS: i64 = 300; 18 + pub const REAUTH_WINDOW_SECONDS: i64 = 300; 18 19 19 20 #[derive(Serialize)] 20 21 #[serde(rename_all = "camelCase")] ··· 32 33 .session_repo 33 34 .get_last_reauth_at(&auth.did) 34 35 .await 35 - .map_err(|e| { 36 - error!("DB error: {:?}", e); 37 - ApiError::InternalError(None) 38 - })?; 36 + .log_db_err("getting last reauth")?; 39 37 40 38 let reauth_required = is_reauth_required(last_reauth_at); 41 39 let available_methods = ··· 70 68 .user_repo 71 69 .get_password_hash_by_did(&auth.did) 72 70 .await 73 - .map_err(|e| { 74 - error!("DB error: {:?}", e); 75 - ApiError::InternalError(None) 76 - })? 71 + .log_db_err("fetching password hash")? 77 72 .ok_or(ApiError::AccountNotFound)?; 78 73 79 74 let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); ··· 97 92 98 93 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 99 94 .await 100 - .map_err(|e| { 101 - error!("DB error updating reauth: {:?}", e); 102 - ApiError::InternalError(None) 103 - })?; 95 + .log_db_err("updating reauth")?; 104 96 105 97 info!(did = %&auth.did, "Re-auth successful via password"); 106 98 Ok(Json(ReauthResponse { reauthed_at }).into_response()) ··· 117 109 auth: Auth<Active>, 118 110 Json(input): Json<TotpReauthInput>, 119 111 ) -> Result<Response, ApiError> { 120 - if !state 121 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 122 - .await 123 - { 124 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 125 - return Err(ApiError::RateLimitExceeded(Some( 126 - "Too many verification attempts. Please try again in a few minutes.".into(), 127 - ))); 128 - } 112 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 113 + &state, 114 + &auth.did, 115 + "Too many verification attempts. Please try again in a few minutes.", 116 + ) 117 + .await?; 129 118 130 119 let valid = 131 120 crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.did, &input.code) ··· 140 129 141 130 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 142 131 .await 143 - .map_err(|e| { 144 - error!("DB error updating reauth: {:?}", e); 145 - ApiError::InternalError(None) 146 - })?; 132 + .log_db_err("updating reauth")?; 147 133 148 134 info!(did = %&auth.did, "Re-auth successful via TOTP"); 149 135 Ok(Json(ReauthResponse { reauthed_at }).into_response()) ··· 159 145 State(state): State<AppState>, 160 146 auth: Auth<Active>, 161 147 ) -> Result<Response, ApiError> { 162 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 163 - 164 148 let stored_passkeys = state 165 149 .user_repo 166 150 .get_passkeys_for_user(&auth.did) 167 151 .await 168 - .map_err(|e| { 169 - error!("Failed to get passkeys: {:?}", e); 170 - ApiError::InternalError(None) 171 - })?; 152 + .log_db_err("getting passkeys")?; 172 153 173 154 if stored_passkeys.is_empty() { 174 155 return Err(ApiError::NoPasskeys); ··· 185 166 ))); 186 167 } 187 168 188 - let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 189 - error!("Failed to create WebAuthn config: {:?}", e); 190 - ApiError::InternalError(None) 191 - })?; 169 + let webauthn = &state.webauthn_config; 192 170 193 171 let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| { 194 172 error!("Failed to start passkey authentication: {:?}", e); ··· 204 182 .user_repo 205 183 .save_webauthn_challenge(&auth.did, "authentication", &state_json) 206 184 .await 207 - .map_err(|e| { 208 - error!("Failed to save authentication state: {:?}", e); 209 - ApiError::InternalError(None) 210 - })?; 185 + .log_db_err("saving authentication state")?; 211 186 212 187 let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({})); 213 188 Ok(Json(PasskeyReauthStartResponse { options }).into_response()) ··· 224 199 auth: Auth<Active>, 225 200 Json(input): Json<PasskeyReauthFinishInput>, 226 201 ) -> Result<Response, ApiError> { 227 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 228 - 229 202 let auth_state_json = state 230 203 .user_repo 231 204 .load_webauthn_challenge(&auth.did, "authentication") 232 205 .await 233 - .map_err(|e| { 234 - error!("Failed to load authentication state: {:?}", e); 235 - ApiError::InternalError(None) 236 - })? 206 + .log_db_err("loading authentication state")? 237 207 .ok_or(ApiError::NoChallengeInProgress)?; 238 208 239 209 let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = ··· 248 218 ApiError::InvalidCredential 249 219 })?; 250 220 251 - let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| { 252 - error!("Failed to create WebAuthn config: {:?}", e); 253 - ApiError::InternalError(None) 254 - })?; 255 - 256 - let auth_result = webauthn 221 + let auth_result = state 222 + .webauthn_config 257 223 .finish_authentication(&credential, &auth_state) 258 224 .map_err(|e| { 259 225 warn!(did = %&auth.did, "Passkey re-auth failed: {:?}", e); ··· 287 253 288 254 let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did) 289 255 .await 290 - .map_err(|e| { 291 - error!("DB error updating reauth: {:?}", e); 292 - ApiError::InternalError(None) 293 - })?; 256 + .log_db_err("updating reauth")?; 294 257 295 258 info!(did = %&auth.did, "Re-auth successful via passkey"); 296 259 Ok(Json(ReauthResponse { reauthed_at }).into_response())
+2 -2
crates/tranquil-pds/src/api/server/service_auth.rs
··· 51 51 headers: axum::http::HeaderMap, 52 52 Query(params): Query<GetServiceAuthParams>, 53 53 ) -> Response { 54 - let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 55 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 54 + let auth_header = crate::util::get_header_str(&headers, "Authorization"); 55 + let dpop_proof = crate::util::get_header_str(&headers, "DPoP"); 56 56 info!( 57 57 has_auth_header = auth_header.is_some(), 58 58 has_dpop_proof = dpop_proof.is_some(),
+46 -120
crates/tranquil-pds/src/api/server/session.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::api::{EmptyResponse, SuccessResponse}; 3 - use crate::auth::{Active, Auth, Permissive}; 4 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::auth::{Active, Auth, Permissive, require_legacy_session_mfa, require_reauth_window}; 4 + use crate::rate_limit::{LoginLimit, RateLimited, RefreshSessionLimit}; 5 + use crate::state::AppState; 5 6 use crate::types::{AccountState, Did, Handle, PlainPassword}; 7 + use crate::util::{pds_hostname, pds_hostname_without_port}; 6 8 use axum::{ 7 9 Json, 8 10 extract::State, ··· 15 17 use tracing::{error, info, warn}; 16 18 use tranquil_types::TokenId; 17 19 18 - fn extract_client_ip(headers: &HeaderMap) -> String { 19 - if let Some(forwarded) = headers.get("x-forwarded-for") 20 - && let Ok(value) = forwarded.to_str() 21 - && let Some(first_ip) = value.split(',').next() 22 - { 23 - return first_ip.trim().to_string(); 24 - } 25 - if let Some(real_ip) = headers.get("x-real-ip") 26 - && let Ok(value) = real_ip.to_str() 27 - { 28 - return value.trim().to_string(); 29 - } 30 - "unknown".to_string() 31 - } 32 - 33 20 fn normalize_handle(identifier: &str, pds_hostname: &str) -> String { 34 21 let identifier = identifier.trim(); 35 22 if identifier.contains('@') || identifier.starts_with("did:") { ··· 75 62 76 63 pub async fn create_session( 77 64 State(state): State<AppState>, 78 - headers: HeaderMap, 65 + rate_limit: RateLimited<LoginLimit>, 79 66 Json(input): Json<CreateSessionInput>, 80 67 ) -> Response { 68 + let client_ip = rate_limit.client_ip(); 81 69 info!( 82 70 "create_session called with identifier: {}", 83 71 input.identifier 84 72 ); 85 - let client_ip = extract_client_ip(&headers); 86 - if !state 87 - .check_rate_limit(RateLimitKind::Login, &client_ip) 88 - .await 89 - { 90 - warn!(ip = %client_ip, "Login rate limit exceeded"); 91 - return ApiError::RateLimitExceeded(None).into_response(); 92 - } 93 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 94 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 73 + let pds_host = pds_hostname(); 74 + let hostname_for_handles = pds_hostname_without_port(); 95 75 let normalized_identifier = normalize_handle(&input.identifier, hostname_for_handles); 96 76 info!( 97 77 "Normalized identifier: {} -> {}", ··· 246 226 ip = %client_ip, 247 227 "Legacy login on TOTP-enabled account - sending notification" 248 228 ); 249 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 229 + let hostname = pds_hostname(); 250 230 if let Err(e) = crate::comms::comms_repo::enqueue_legacy_login( 251 231 state.user_repo.as_ref(), 252 232 state.infra_repo.as_ref(), 253 233 row.id, 254 - &hostname, 255 - &client_ip, 234 + hostname, 235 + client_ip, 256 236 row.preferred_comms_channel, 257 237 ) 258 238 .await ··· 260 240 error!("Failed to queue legacy login notification: {:?}", e); 261 241 } 262 242 } 263 - let handle = full_handle(&row.handle, &pds_hostname); 243 + let handle = full_handle(&row.handle, pds_host); 264 244 let is_active = account_state.is_active(); 265 245 let status = account_state.status_for_session().map(String::from); 266 246 Json(CreateSessionOutput { ··· 299 279 tranquil_db_traits::CommsChannel::Telegram => ("telegram", row.telegram_verified), 300 280 tranquil_db_traits::CommsChannel::Signal => ("signal", row.signal_verified), 301 281 }; 302 - let pds_hostname = 303 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 304 - let handle = full_handle(&row.handle, &pds_hostname); 282 + let pds_hostname = pds_hostname(); 283 + let handle = full_handle(&row.handle, pds_hostname); 305 284 let account_state = AccountState::from_db_fields( 306 285 row.deactivated_at, 307 286 row.takedown_ref.clone(), ··· 353 332 _auth: Auth<Active>, 354 333 ) -> Result<Response, ApiError> { 355 334 let extracted = crate::auth::extract_auth_token_from_header( 356 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 335 + crate::util::get_header_str(&headers, "Authorization"), 357 336 ) 358 337 .ok_or(ApiError::AuthenticationRequired)?; 359 338 let jti = crate::auth::get_jti_from_token(&extracted.token) ··· 374 353 375 354 pub async fn refresh_session( 376 355 State(state): State<AppState>, 356 + _rate_limit: RateLimited<RefreshSessionLimit>, 377 357 headers: axum::http::HeaderMap, 378 358 ) -> Response { 379 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 380 - if !state 381 - .check_rate_limit(RateLimitKind::RefreshSession, &client_ip) 382 - .await 383 - { 384 - tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded"); 385 - return ApiError::RateLimitExceeded(None).into_response(); 386 - } 387 359 let extracted = match crate::auth::extract_auth_token_from_header( 388 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 360 + crate::util::get_header_str(&headers, "Authorization"), 389 361 ) { 390 362 Some(t) => t, 391 363 None => return ApiError::AuthenticationRequired.into_response(), ··· 509 481 tranquil_db_traits::CommsChannel::Telegram => ("telegram", u.telegram_verified), 510 482 tranquil_db_traits::CommsChannel::Signal => ("signal", u.signal_verified), 511 483 }; 512 - let pds_hostname = 513 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 514 - let handle = full_handle(&u.handle, &pds_hostname); 484 + let pds_hostname = pds_hostname(); 485 + let handle = full_handle(&u.handle, pds_hostname); 515 486 let account_state = 516 487 AccountState::from_db_fields(u.deactivated_at, u.takedown_ref.clone(), None, None); 517 488 let mut response = json!({ ··· 675 646 return ApiError::InternalError(None).into_response(); 676 647 } 677 648 678 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 649 + let hostname = pds_hostname(); 679 650 if let Err(e) = crate::comms::comms_repo::enqueue_welcome( 680 651 state.user_repo.as_ref(), 681 652 state.infra_repo.as_ref(), 682 653 row.id, 683 - &hostname, 654 + hostname, 684 655 ) 685 656 .await 686 657 { ··· 756 727 let formatted_token = 757 728 crate::auth::verification_token::format_token_for_display(&verification_token); 758 729 759 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 730 + let hostname = pds_hostname(); 760 731 if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 761 732 state.infra_repo.as_ref(), 762 733 row.id, 763 734 channel_str, 764 735 &recipient, 765 736 &formatted_token, 766 - &hostname, 737 + hostname, 767 738 ) 768 739 .await 769 740 { ··· 804 775 .session_repo 805 776 .list_sessions_by_did(&auth.did) 806 777 .await 807 - .map_err(|e| { 808 - error!("DB error fetching JWT sessions: {:?}", e); 809 - ApiError::InternalError(None) 810 - })?; 778 + .log_db_err("fetching JWT sessions")?; 811 779 812 780 let oauth_rows = state 813 781 .oauth_repo 814 782 .list_sessions_by_did(&auth.did) 815 783 .await 816 - .map_err(|e| { 817 - error!("DB error fetching OAuth sessions: {:?}", e); 818 - ApiError::InternalError(None) 819 - })?; 784 + .log_db_err("fetching OAuth sessions")?; 820 785 821 786 let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 822 787 id: format!("jwt:{}", row.id), ··· 876 841 .session_repo 877 842 .get_session_access_jti_by_id(session_id, &auth.did) 878 843 .await 879 - .map_err(|e| { 880 - error!("DB error in revoke_session: {:?}", e); 881 - ApiError::InternalError(None) 882 - })? 844 + .log_db_err("in revoke_session")? 883 845 .ok_or(ApiError::SessionNotFound)?; 884 846 state 885 847 .session_repo 886 848 .delete_session_by_id(session_id) 887 849 .await 888 - .map_err(|e| { 889 - error!("DB error deleting session: {:?}", e); 890 - ApiError::InternalError(None) 891 - })?; 850 + .log_db_err("deleting session")?; 892 851 let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti); 893 852 if let Err(e) = state.cache.delete(&cache_key).await { 894 853 warn!("Failed to invalidate session cache: {:?}", e); ··· 902 861 .oauth_repo 903 862 .delete_session_by_id(session_id, &auth.did) 904 863 .await 905 - .map_err(|e| { 906 - error!("DB error deleting OAuth session: {:?}", e); 907 - ApiError::InternalError(None) 908 - })?; 864 + .log_db_err("deleting OAuth session")?; 909 865 if deleted == 0 { 910 866 return Err(ApiError::SessionNotFound); 911 867 } ··· 932 888 .session_repo 933 889 .delete_sessions_by_did(&auth.did) 934 890 .await 935 - .map_err(|e| { 936 - error!("DB error revoking JWT sessions: {:?}", e); 937 - ApiError::InternalError(None) 938 - })?; 891 + .log_db_err("revoking JWT sessions")?; 939 892 let jti_typed = TokenId::from(jti.clone()); 940 893 state 941 894 .oauth_repo 942 895 .delete_sessions_by_did_except(&auth.did, &jti_typed) 943 896 .await 944 - .map_err(|e| { 945 - error!("DB error revoking OAuth sessions: {:?}", e); 946 - ApiError::InternalError(None) 947 - })?; 897 + .log_db_err("revoking OAuth sessions")?; 948 898 } else { 949 899 state 950 900 .session_repo 951 901 .delete_sessions_by_did_except_jti(&auth.did, &jti) 952 902 .await 953 - .map_err(|e| { 954 - error!("DB error revoking JWT sessions: {:?}", e); 955 - ApiError::InternalError(None) 956 - })?; 903 + .log_db_err("revoking JWT sessions")?; 957 904 state 958 905 .oauth_repo 959 906 .delete_sessions_by_did(&auth.did) 960 907 .await 961 - .map_err(|e| { 962 - error!("DB error revoking OAuth sessions: {:?}", e); 963 - ApiError::InternalError(None) 964 - })?; 908 + .log_db_err("revoking OAuth sessions")?; 965 909 } 966 910 967 911 info!(did = %&auth.did, "All other sessions revoked"); ··· 983 927 .user_repo 984 928 .get_legacy_login_pref(&auth.did) 985 929 .await 986 - .map_err(|e| { 987 - error!("DB error: {:?}", e); 988 - ApiError::InternalError(None) 989 - })? 930 + .log_db_err("getting legacy login pref")? 990 931 .ok_or(ApiError::AccountNotFound)?; 991 932 Ok(Json(LegacyLoginPreferenceOutput { 992 933 allow_legacy_login: pref.allow_legacy_login, ··· 1006 947 auth: Auth<Active>, 1007 948 Json(input): Json<UpdateLegacyLoginInput>, 1008 949 ) -> Result<Response, ApiError> { 1009 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 1010 - { 1011 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 1012 - &*state.user_repo, 1013 - &*state.session_repo, 1014 - &auth.did, 1015 - ) 1016 - .await); 1017 - } 950 + let session_mfa = match require_legacy_session_mfa(&state, &auth).await { 951 + Ok(proof) => proof, 952 + Err(response) => return Ok(response), 953 + }; 1018 954 1019 - if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await { 1020 - return Ok(crate::api::server::reauth::reauth_required_response( 1021 - &*state.user_repo, 1022 - &*state.session_repo, 1023 - &auth.did, 1024 - ) 1025 - .await); 1026 - } 955 + let reauth_mfa = match require_reauth_window(&state, &auth).await { 956 + Ok(proof) => proof, 957 + Err(response) => return Ok(response), 958 + }; 1027 959 1028 960 let updated = state 1029 961 .user_repo 1030 - .update_legacy_login(&auth.did, input.allow_legacy_login) 962 + .update_legacy_login(reauth_mfa.did(), input.allow_legacy_login) 1031 963 .await 1032 - .map_err(|e| { 1033 - error!("DB error: {:?}", e); 1034 - ApiError::InternalError(None) 1035 - })?; 964 + .log_db_err("updating legacy login")?; 1036 965 if !updated { 1037 966 return Err(ApiError::AccountNotFound); 1038 967 } 1039 968 info!( 1040 - did = %&auth.did, 969 + did = %session_mfa.did(), 1041 970 allow_legacy_login = input.allow_legacy_login, 1042 971 "Legacy login preference updated" 1043 972 ); ··· 1071 1000 .user_repo 1072 1001 .update_locale(&auth.did, &input.preferred_locale) 1073 1002 .await 1074 - .map_err(|e| { 1075 - error!("DB error updating locale: {:?}", e); 1076 - ApiError::InternalError(None) 1077 - })?; 1003 + .log_db_err("updating locale")?; 1078 1004 if !updated { 1079 1005 return Err(ApiError::AccountNotFound); 1080 1006 }
+45 -148
crates/tranquil-pds/src/api/server/totp.rs
··· 1 1 use crate::api::EmptyResponse; 2 - use crate::api::error::ApiError; 3 - use crate::auth::{Active, Auth}; 2 + use crate::api::error::{ApiError, DbResultExt}; 4 3 use crate::auth::{ 5 - decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64, 6 - generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format, 7 - verify_backup_code, verify_totp_code, 4 + Active, Auth, decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, 5 + generate_qr_png_base64, generate_totp_secret, generate_totp_uri, hash_backup_code, 6 + is_backup_code_format, require_legacy_session_mfa, verify_backup_code, verify_password_mfa, 7 + verify_totp_code, verify_totp_mfa, 8 8 }; 9 - use crate::state::{AppState, RateLimitKind}; 9 + use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message}; 10 + use crate::state::AppState; 10 11 use crate::types::PlainPassword; 12 + use crate::util::pds_hostname; 11 13 use axum::{ 12 14 Json, 13 15 extract::State, ··· 45 47 .user_repo 46 48 .get_handle_by_did(&auth.did) 47 49 .await 48 - .map_err(|e| { 49 - error!("DB error fetching handle: {:?}", e); 50 - ApiError::InternalError(None) 51 - })? 50 + .log_db_err("fetching handle")? 52 51 .ok_or(ApiError::AccountNotFound)?; 53 52 54 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 55 - let uri = generate_totp_uri(&secret, &handle, &hostname); 53 + let hostname = pds_hostname(); 54 + let uri = generate_totp_uri(&secret, &handle, hostname); 56 55 57 - let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| { 56 + let qr_code = generate_qr_png_base64(&secret, &handle, hostname).map_err(|e| { 58 57 error!("Failed to generate QR code: {:?}", e); 59 58 ApiError::InternalError(Some("Failed to generate QR code".into())) 60 59 })?; ··· 68 67 .user_repo 69 68 .upsert_totp_secret(&auth.did, &encrypted_secret, ENCRYPTION_VERSION) 70 69 .await 71 - .map_err(|e| { 72 - error!("Failed to store TOTP secret: {:?}", e); 73 - ApiError::InternalError(None) 74 - })?; 70 + .log_db_err("storing TOTP secret")?; 75 71 76 72 let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); 77 73 ··· 101 97 auth: Auth<Active>, 102 98 Json(input): Json<EnableTotpInput>, 103 99 ) -> Result<Response, ApiError> { 104 - if !state 105 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 106 - .await 107 - { 108 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 109 - return Err(ApiError::RateLimitExceeded(None)); 110 - } 100 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 101 + &state, 102 + &auth.did, 103 + "Too many verification attempts. Please try again in a few minutes.", 104 + ) 105 + .await?; 111 106 112 107 let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 113 108 Ok(Some(row)) => row, ··· 152 147 .user_repo 153 148 .enable_totp_with_backup_codes(&auth.did, &backup_hashes) 154 149 .await 155 - .map_err(|e| { 156 - error!("Failed to enable TOTP: {:?}", e); 157 - ApiError::InternalError(None) 158 - })?; 150 + .log_db_err("enabling TOTP")?; 159 151 160 152 info!(did = %&auth.did, "TOTP enabled with {} backup codes", backup_codes.len()); 161 153 ··· 173 165 auth: Auth<Active>, 174 166 Json(input): Json<DisableTotpInput>, 175 167 ) -> Result<Response, ApiError> { 176 - if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await 177 - { 178 - return Ok(crate::api::server::reauth::legacy_mfa_required_response( 179 - &*state.user_repo, 180 - &*state.session_repo, 181 - &auth.did, 182 - ) 183 - .await); 184 - } 185 - 186 - if !state 187 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 188 - .await 189 - { 190 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 191 - return Err(ApiError::RateLimitExceeded(None)); 192 - } 193 - 194 - let password_hash = state 195 - .user_repo 196 - .get_password_hash_by_did(&auth.did) 197 - .await 198 - .map_err(|e| { 199 - error!("DB error fetching user: {:?}", e); 200 - ApiError::InternalError(None) 201 - })? 202 - .ok_or(ApiError::AccountNotFound)?; 203 - 204 - let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 205 - if !password_valid { 206 - return Err(ApiError::InvalidPassword("Password is incorrect".into())); 207 - } 208 - 209 - let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 210 - Ok(Some(row)) if row.verified => row, 211 - Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 212 - Err(e) => { 213 - error!("DB error fetching TOTP: {:?}", e); 214 - return Err(ApiError::InternalError(None)); 215 - } 168 + let _session_mfa = match require_legacy_session_mfa(&state, &auth).await { 169 + Ok(proof) => proof, 170 + Err(response) => return Ok(response), 216 171 }; 217 172 218 - let code = input.code.trim(); 219 - let code_valid = if is_backup_code_format(code) { 220 - verify_backup_code_for_user(&state, &auth.did, code).await 221 - } else { 222 - let secret = decrypt_totp_secret( 223 - &totp_record.secret_encrypted, 224 - totp_record.encryption_version, 225 - ) 226 - .map_err(|e| { 227 - error!("Failed to decrypt TOTP secret: {:?}", e); 228 - ApiError::InternalError(None) 229 - })?; 230 - verify_totp_code(&secret, code) 231 - }; 173 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 174 + &state, 175 + &auth.did, 176 + "Too many verification attempts. Please try again in a few minutes.", 177 + ) 178 + .await?; 232 179 233 - if !code_valid { 234 - return Err(ApiError::InvalidCode(Some( 235 - "Invalid verification code".into(), 236 - ))); 237 - } 180 + let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?; 181 + let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?; 238 182 239 183 state 240 184 .user_repo 241 - .delete_totp_and_backup_codes(&auth.did) 185 + .delete_totp_and_backup_codes(totp_mfa.did()) 242 186 .await 243 - .map_err(|e| { 244 - error!("Failed to delete TOTP: {:?}", e); 245 - ApiError::InternalError(None) 246 - })?; 187 + .log_db_err("deleting TOTP")?; 247 188 248 - info!(did = %&auth.did, "TOTP disabled"); 189 + info!(did = %password_mfa.did(), "TOTP disabled (verified via {} and {})", password_mfa.method(), totp_mfa.method()); 249 190 250 191 Ok(EmptyResponse::ok().into_response()) 251 192 } ··· 275 216 .user_repo 276 217 .count_unused_backup_codes(&auth.did) 277 218 .await 278 - .map_err(|e| { 279 - error!("DB error counting backup codes: {:?}", e); 280 - ApiError::InternalError(None) 281 - })?; 219 + .log_db_err("counting backup codes")?; 282 220 283 221 Ok(Json(GetTotpStatusResponse { 284 222 enabled, ··· 305 243 auth: Auth<Active>, 306 244 Json(input): Json<RegenerateBackupCodesInput>, 307 245 ) -> Result<Response, ApiError> { 308 - if !state 309 - .check_rate_limit(RateLimitKind::TotpVerify, &auth.did) 310 - .await 311 - { 312 - warn!(did = %&auth.did, "TOTP verification rate limit exceeded"); 313 - return Err(ApiError::RateLimitExceeded(None)); 314 - } 315 - 316 - let password_hash = state 317 - .user_repo 318 - .get_password_hash_by_did(&auth.did) 319 - .await 320 - .map_err(|e| { 321 - error!("DB error fetching user: {:?}", e); 322 - ApiError::InternalError(None) 323 - })? 324 - .ok_or(ApiError::AccountNotFound)?; 325 - 326 - let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 327 - if !password_valid { 328 - return Err(ApiError::InvalidPassword("Password is incorrect".into())); 329 - } 330 - 331 - let totp_record = match state.user_repo.get_totp_record(&auth.did).await { 332 - Ok(Some(row)) if row.verified => row, 333 - Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled), 334 - Err(e) => { 335 - error!("DB error fetching TOTP: {:?}", e); 336 - return Err(ApiError::InternalError(None)); 337 - } 338 - }; 339 - 340 - let secret = decrypt_totp_secret( 341 - &totp_record.secret_encrypted, 342 - totp_record.encryption_version, 246 + let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>( 247 + &state, 248 + &auth.did, 249 + "Too many verification attempts. Please try again in a few minutes.", 343 250 ) 344 - .map_err(|e| { 345 - error!("Failed to decrypt TOTP secret: {:?}", e); 346 - ApiError::InternalError(None) 347 - })?; 251 + .await?; 348 252 349 - let code = input.code.trim(); 350 - if !verify_totp_code(&secret, code) { 351 - return Err(ApiError::InvalidCode(Some( 352 - "Invalid verification code".into(), 353 - ))); 354 - } 253 + let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?; 254 + let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?; 355 255 356 256 let backup_codes = generate_backup_codes(); 357 257 let backup_hashes: Vec<_> = backup_codes ··· 365 265 366 266 state 367 267 .user_repo 368 - .replace_backup_codes(&auth.did, &backup_hashes) 268 + .replace_backup_codes(totp_mfa.did(), &backup_hashes) 369 269 .await 370 - .map_err(|e| { 371 - error!("Failed to regenerate backup codes: {:?}", e); 372 - ApiError::InternalError(None) 373 - })?; 270 + .log_db_err("replacing backup codes")?; 374 271 375 - info!(did = %&auth.did, "Backup codes regenerated"); 272 + info!(did = %password_mfa.did(), "Backup codes regenerated (verified via {} and {})", password_mfa.method(), totp_mfa.method()); 376 273 377 274 Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response()) 378 275 }
+4 -13
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 1 1 use crate::api::SuccessResponse; 2 - use crate::api::error::ApiError; 2 + use crate::api::error::{ApiError, DbResultExt}; 3 3 use axum::{ 4 4 Json, 5 5 extract::State, ··· 79 79 .oauth_repo 80 80 .list_trusted_devices(&auth.did) 81 81 .await 82 - .map_err(|e| { 83 - error!("DB error: {:?}", e); 84 - ApiError::InternalError(None) 85 - })?; 82 + .log_db_err("listing trusted devices")?; 86 83 87 84 let devices = rows 88 85 .into_iter() ··· 134 131 .oauth_repo 135 132 .revoke_device_trust(&device_id) 136 133 .await 137 - .map_err(|e| { 138 - error!("DB error: {:?}", e); 139 - ApiError::InternalError(None) 140 - })?; 134 + .log_db_err("revoking device trust")?; 141 135 142 136 info!(did = %&auth.did, device_id = %input.device_id, "Trusted device revoked"); 143 137 Ok(SuccessResponse::ok().into_response()) ··· 175 169 .oauth_repo 176 170 .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 177 171 .await 178 - .map_err(|e| { 179 - error!("DB error: {:?}", e); 180 - ApiError::InternalError(None) 181 - })?; 172 + .log_db_err("updating device friendly name")?; 182 173 183 174 info!(did = %auth.did, device_id = %input.device_id, "Trusted device updated"); 184 175 Ok(SuccessResponse::ok().into_response())
+3 -2
crates/tranquil-pds/src/api/server/verify_email.rs
··· 5 5 use tracing::{info, warn}; 6 6 7 7 use crate::state::AppState; 8 + use crate::util::pds_hostname; 8 9 9 10 #[derive(Deserialize)] 10 11 #[serde(rename_all = "camelCase")] ··· 70 71 return Ok(Json(ResendMigrationVerificationOutput { sent: true })); 71 72 } 72 73 73 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 74 + let hostname = pds_hostname(); 74 75 let token = crate::auth::verification_token::generate_migration_token(&user.did, &email); 75 76 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 76 77 ··· 80 81 user.id, 81 82 &email, 82 83 &formatted_token, 83 - &hostname, 84 + hostname, 84 85 ) 85 86 .await 86 87 {
+14 -47
crates/tranquil-pds/src/api/server/verify_token.rs
··· 1 - use crate::api::error::ApiError; 1 + use crate::api::error::{ApiError, DbResultExt}; 2 2 use crate::types::Did; 3 3 use axum::{Json, extract::State}; 4 4 use serde::{Deserialize, Serialize}; 5 - use tracing::{error, info, warn}; 5 + use tracing::{info, warn}; 6 6 7 7 use crate::auth::verification_token::{ 8 8 VerificationPurpose, normalize_token_input, verify_token_signature, ··· 81 81 .user_repo 82 82 .get_verification_info(&did_typed) 83 83 .await 84 - .map_err(|e| { 85 - warn!(error = ?e, "Database error during migration verification"); 86 - ApiError::InternalError(None) 87 - })? 84 + .log_db_err("during migration verification")? 88 85 .ok_or(ApiError::AccountNotFound)?; 89 86 90 87 if user.email.as_ref().map(|e| e.to_lowercase()) != Some(identifier.to_string()) { ··· 96 93 .user_repo 97 94 .set_email_verified_flag(user.id) 98 95 .await 99 - .map_err(|e| { 100 - warn!(error = ?e, "Failed to update email_verified status"); 101 - ApiError::InternalError(None) 102 - })?; 96 + .log_db_err("updating email_verified status")?; 103 97 } 104 98 105 99 info!(did = %did, "Migration email verified successfully"); ··· 125 119 .user_repo 126 120 .get_id_by_did(&did_typed) 127 121 .await 128 - .map_err(|_| ApiError::InternalError(None))? 122 + .log_db_err("fetching user id")? 129 123 .ok_or(ApiError::AccountNotFound)?; 130 124 131 125 match channel { ··· 134 128 .user_repo 135 129 .verify_email_channel(user_id, identifier) 136 130 .await 137 - .map_err(|e| { 138 - error!("Failed to update email channel: {:?}", e); 139 - ApiError::InternalError(None) 140 - })?; 131 + .log_db_err("updating email channel")?; 141 132 if !success { 142 133 return Err(ApiError::EmailTaken); 143 134 } ··· 147 138 .user_repo 148 139 .verify_discord_channel(user_id, identifier) 149 140 .await 150 - .map_err(|e| { 151 - error!("Failed to update discord channel: {:?}", e); 152 - ApiError::InternalError(None) 153 - })?; 141 + .log_db_err("updating discord channel")?; 154 142 } 155 143 "telegram" => { 156 144 state 157 145 .user_repo 158 146 .verify_telegram_channel(user_id, identifier) 159 147 .await 160 - .map_err(|e| { 161 - error!("Failed to update telegram channel: {:?}", e); 162 - ApiError::InternalError(None) 163 - })?; 148 + .log_db_err("updating telegram channel")?; 164 149 } 165 150 "signal" => { 166 151 state 167 152 .user_repo 168 153 .verify_signal_channel(user_id, identifier) 169 154 .await 170 - .map_err(|e| { 171 - error!("Failed to update signal channel: {:?}", e); 172 - ApiError::InternalError(None) 173 - })?; 155 + .log_db_err("updating signal channel")?; 174 156 } 175 157 _ => { 176 158 return Err(ApiError::InvalidChannel); ··· 200 182 .user_repo 201 183 .get_verification_info(&did_typed) 202 184 .await 203 - .map_err(|e| { 204 - warn!(error = ?e, "Database error during signup verification"); 205 - ApiError::InternalError(None) 206 - })? 185 + .log_db_err("during signup verification")? 207 186 .ok_or(ApiError::AccountNotFound)?; 208 187 209 188 let is_verified = user.email_verified ··· 226 205 .user_repo 227 206 .set_email_verified_flag(user.id) 228 207 .await 229 - .map_err(|e| { 230 - warn!(error = ?e, "Failed to update email verified status"); 231 - ApiError::InternalError(None) 232 - })?; 208 + .log_db_err("updating email verified status")?; 233 209 } 234 210 "discord" => { 235 211 state 236 212 .user_repo 237 213 .set_discord_verified_flag(user.id) 238 214 .await 239 - .map_err(|e| { 240 - warn!(error = ?e, "Failed to update discord verified status"); 241 - ApiError::InternalError(None) 242 - })?; 215 + .log_db_err("updating discord verified status")?; 243 216 } 244 217 "telegram" => { 245 218 state 246 219 .user_repo 247 220 .set_telegram_verified_flag(user.id) 248 221 .await 249 - .map_err(|e| { 250 - warn!(error = ?e, "Failed to update telegram verified status"); 251 - ApiError::InternalError(None) 252 - })?; 222 + .log_db_err("updating telegram verified status")?; 253 223 } 254 224 "signal" => { 255 225 state 256 226 .user_repo 257 227 .set_signal_verified_flag(user.id) 258 228 .await 259 - .map_err(|e| { 260 - warn!(error = ?e, "Failed to update signal verified status"); 261 - ApiError::InternalError(None) 262 - })?; 229 + .log_db_err("updating signal verified status")?; 263 230 } 264 231 _ => { 265 232 return Err(ApiError::InvalidChannel);
+22 -18
crates/tranquil-pds/src/auth/extractor.rs
··· 9 9 10 10 use super::{ 11 11 AccountStatus, AuthSource, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier, 12 - is_service_token, validate_bearer_token_for_service_auth, 12 + is_service_token, scope_verified::VerifyScope, validate_bearer_token_for_service_auth, 13 13 }; 14 14 use crate::api::error::ApiError; 15 15 use crate::oauth::scopes::{RepoAction, ScopePermissions}; ··· 293 293 return Ok(ExtractedAuth::Service(claims)); 294 294 } 295 295 296 - let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok()); 296 + let dpop_proof = crate::util::get_header_str(&parts.headers, "DPoP"); 297 297 let method = parts.method.as_str(); 298 298 let uri = build_full_url(&parts.uri.to_string()); 299 299 ··· 358 358 } 359 359 } 360 360 361 + impl<P: AuthPolicy> AsRef<AuthenticatedUser> for Auth<P> { 362 + fn as_ref(&self) -> &AuthenticatedUser { 363 + &self.0 364 + } 365 + } 366 + 367 + impl<P: AuthPolicy> VerifyScope for Auth<P> { 368 + fn needs_scope_check(&self) -> bool { 369 + self.0.is_oauth() 370 + } 371 + 372 + fn permissions(&self) -> ScopePermissions { 373 + self.0.permissions() 374 + } 375 + } 376 + 361 377 impl<P: AuthPolicy> FromRequestParts<AppState> for Auth<P> { 362 378 type Rejection = AuthError; 363 379 ··· 418 434 ) -> Result<Self, Self::Rejection> { 419 435 match extract_auth_internal(parts, state).await? { 420 436 ExtractedAuth::Service(claims) => { 421 - let did: Did = claims 422 - .iss 423 - .parse() 424 - .map_err(|_| AuthError::AuthenticationFailed)?; 437 + let did = claims.iss.clone(); 425 438 Ok(ServiceAuth { did, claims }) 426 439 } 427 440 ExtractedAuth::User(_) => Err(AuthError::AuthenticationFailed), ··· 438 451 ) -> Result<Option<Self>, Self::Rejection> { 439 452 match extract_auth_internal(parts, state).await { 440 453 Ok(ExtractedAuth::Service(claims)) => { 441 - let did: Did = claims 442 - .iss 443 - .parse() 444 - .map_err(|_| AuthError::AuthenticationFailed)?; 454 + let did = claims.iss.clone(); 445 455 Ok(Some(ServiceAuth { did, claims })) 446 456 } 447 457 Ok(ExtractedAuth::User(_)) => Err(AuthError::AuthenticationFailed), ··· 503 513 Ok(AuthAny::User(Auth(user, PhantomData))) 504 514 } 505 515 ExtractedAuth::Service(claims) => { 506 - let did: Did = claims 507 - .iss 508 - .parse() 509 - .map_err(|_| AuthError::AuthenticationFailed)?; 516 + let did = claims.iss.clone(); 510 517 Ok(AuthAny::Service(ServiceAuth { did, claims })) 511 518 } 512 519 } ··· 526 533 Ok(Some(AuthAny::User(Auth(user, PhantomData)))) 527 534 } 528 535 Ok(ExtractedAuth::Service(claims)) => { 529 - let did: Did = claims 530 - .iss 531 - .parse() 532 - .map_err(|_| AuthError::AuthenticationFailed)?; 536 + let did = claims.iss.clone(); 533 537 Ok(Some(AuthAny::Service(ServiceAuth { did, claims }))) 534 538 } 535 539 Err(AuthError::MissingToken) => Ok(None),
+223
crates/tranquil-pds/src/auth/mfa_verified.rs
··· 1 + use axum::response::Response; 2 + 3 + use super::AuthenticatedUser; 4 + use crate::state::AppState; 5 + use crate::types::Did; 6 + 7 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 8 + pub enum MfaMethod { 9 + Totp, 10 + Passkey, 11 + Password, 12 + RecoveryCode, 13 + SessionReauth, 14 + } 15 + 16 + impl MfaMethod { 17 + pub fn as_str(&self) -> &'static str { 18 + match self { 19 + Self::Totp => "totp", 20 + Self::Passkey => "passkey", 21 + Self::Password => "password", 22 + Self::RecoveryCode => "recovery_code", 23 + Self::SessionReauth => "session_reauth", 24 + } 25 + } 26 + } 27 + 28 + impl std::fmt::Display for MfaMethod { 29 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 30 + write!(f, "{}", self.as_str()) 31 + } 32 + } 33 + 34 + pub struct MfaVerified<'a> { 35 + user: &'a AuthenticatedUser, 36 + method: MfaMethod, 37 + } 38 + 39 + impl<'a> MfaVerified<'a> { 40 + fn new(user: &'a AuthenticatedUser, method: MfaMethod) -> Self { 41 + Self { user, method } 42 + } 43 + 44 + pub(crate) fn from_totp(user: &'a AuthenticatedUser) -> Self { 45 + Self::new(user, MfaMethod::Totp) 46 + } 47 + 48 + pub(crate) fn from_password(user: &'a AuthenticatedUser) -> Self { 49 + Self::new(user, MfaMethod::Password) 50 + } 51 + 52 + pub(crate) fn from_recovery_code(user: &'a AuthenticatedUser) -> Self { 53 + Self::new(user, MfaMethod::RecoveryCode) 54 + } 55 + 56 + pub(crate) fn from_session_reauth(user: &'a AuthenticatedUser) -> Self { 57 + Self::new(user, MfaMethod::SessionReauth) 58 + } 59 + 60 + pub fn user(&self) -> &AuthenticatedUser { 61 + self.user 62 + } 63 + 64 + pub fn did(&self) -> &Did { 65 + &self.user.did 66 + } 67 + 68 + pub fn method(&self) -> MfaMethod { 69 + self.method 70 + } 71 + } 72 + 73 + pub async fn require_legacy_session_mfa<'a>( 74 + state: &AppState, 75 + user: &'a AuthenticatedUser, 76 + ) -> Result<MfaVerified<'a>, Response> { 77 + use crate::api::server::reauth::{check_legacy_session_mfa, legacy_mfa_required_response}; 78 + 79 + if check_legacy_session_mfa(&*state.session_repo, &user.did).await { 80 + Ok(MfaVerified::from_session_reauth(user)) 81 + } else { 82 + Err(legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 83 + } 84 + } 85 + 86 + pub async fn require_reauth_window<'a>( 87 + state: &AppState, 88 + user: &'a AuthenticatedUser, 89 + ) -> Result<MfaVerified<'a>, Response> { 90 + use chrono::Utc; 91 + use crate::api::server::reauth::{REAUTH_WINDOW_SECONDS, reauth_required_response}; 92 + 93 + let status = state.session_repo.get_session_mfa_status(&user.did).await.ok().flatten(); 94 + 95 + match status { 96 + Some(s) => { 97 + if let Some(last_reauth) = s.last_reauth_at { 98 + let elapsed = Utc::now().signed_duration_since(last_reauth); 99 + if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 100 + return Ok(MfaVerified::from_session_reauth(user)); 101 + } 102 + } 103 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 104 + } 105 + None => { 106 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 107 + } 108 + } 109 + } 110 + 111 + pub async fn require_reauth_window_if_available<'a>( 112 + state: &AppState, 113 + user: &'a AuthenticatedUser, 114 + ) -> Result<Option<MfaVerified<'a>>, Response> { 115 + use crate::api::server::reauth::{check_reauth_required_cached, reauth_required_response}; 116 + 117 + let has_password = state 118 + .user_repo 119 + .has_password_by_did(&user.did) 120 + .await 121 + .ok() 122 + .flatten() 123 + .unwrap_or(false); 124 + let has_passkeys = state 125 + .user_repo 126 + .has_passkeys(&user.did) 127 + .await 128 + .unwrap_or(false); 129 + let has_totp = state 130 + .user_repo 131 + .has_totp_enabled(&user.did) 132 + .await 133 + .unwrap_or(false); 134 + 135 + let has_any_reauth_method = has_password || has_passkeys || has_totp; 136 + 137 + if !has_any_reauth_method { 138 + return Ok(None); 139 + } 140 + 141 + if check_reauth_required_cached(&*state.session_repo, &state.cache, &user.did).await { 142 + Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await) 143 + } else { 144 + Ok(Some(MfaVerified::from_session_reauth(user))) 145 + } 146 + } 147 + 148 + pub async fn verify_password_mfa<'a>( 149 + state: &AppState, 150 + user: &'a AuthenticatedUser, 151 + password: &str, 152 + ) -> Result<MfaVerified<'a>, crate::api::error::ApiError> { 153 + let hash = state 154 + .user_repo 155 + .get_password_hash_by_did(&user.did) 156 + .await 157 + .ok() 158 + .flatten(); 159 + 160 + match hash { 161 + Some(h) => { 162 + if bcrypt::verify(password, &h).unwrap_or(false) { 163 + Ok(MfaVerified::from_password(user)) 164 + } else { 165 + Err(crate::api::error::ApiError::InvalidPassword( 166 + "Password is incorrect".into(), 167 + )) 168 + } 169 + } 170 + None => Err(crate::api::error::ApiError::AccountNotFound), 171 + } 172 + } 173 + 174 + pub async fn verify_totp_mfa<'a>( 175 + state: &AppState, 176 + user: &'a AuthenticatedUser, 177 + code: &str, 178 + ) -> Result<MfaVerified<'a>, crate::api::error::ApiError> { 179 + use crate::auth::{decrypt_totp_secret, is_backup_code_format, verify_totp_code}; 180 + 181 + let code = code.trim(); 182 + 183 + if is_backup_code_format(code) { 184 + let backup_codes = state.user_repo.get_unused_backup_codes(&user.did).await.ok().unwrap_or_default(); 185 + let code_upper = code.to_uppercase(); 186 + 187 + let matched = backup_codes 188 + .iter() 189 + .find(|row| crate::auth::verify_backup_code(&code_upper, &row.code_hash)); 190 + 191 + return match matched { 192 + Some(row) => { 193 + let _ = state.user_repo.mark_backup_code_used(row.id).await; 194 + Ok(MfaVerified::from_recovery_code(user)) 195 + } 196 + None => Err(crate::api::error::ApiError::InvalidCode(Some( 197 + "Invalid backup code".into(), 198 + ))), 199 + }; 200 + } 201 + 202 + let totp_record = match state.user_repo.get_totp_record(&user.did).await { 203 + Ok(Some(row)) if row.verified => row, 204 + _ => { 205 + return Err(crate::api::error::ApiError::TotpNotEnabled); 206 + } 207 + }; 208 + 209 + let secret = decrypt_totp_secret( 210 + &totp_record.secret_encrypted, 211 + totp_record.encryption_version, 212 + ) 213 + .map_err(|_| crate::api::error::ApiError::InternalError(None))?; 214 + 215 + if verify_totp_code(&secret, code) { 216 + let _ = state.user_repo.update_totp_last_used(&user.did).await; 217 + Ok(MfaVerified::from_totp(user)) 218 + } else { 219 + Err(crate::api::error::ApiError::InvalidCode(Some( 220 + "Invalid verification code".into(), 221 + ))) 222 + } 223 + }
+10
crates/tranquil-pds/src/auth/mod.rs
··· 11 11 use tranquil_db_traits::OAuthRepository; 12 12 13 13 pub mod extractor; 14 + pub mod mfa_verified; 14 15 pub mod scope_check; 16 + pub mod scope_verified; 15 17 pub mod service; 16 18 pub mod verification_token; 17 19 pub mod webauthn; ··· 20 22 Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown, 21 23 Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header, 22 24 }; 25 + pub use mfa_verified::{ 26 + MfaMethod, MfaVerified, require_legacy_session_mfa, require_reauth_window, 27 + require_reauth_window_if_available, verify_password_mfa, verify_totp_mfa, 28 + }; 29 + pub use scope_verified::{ 30 + AccountManage, AccountRead, BlobUpload, IdentityAccess, RepoCreate, RepoDelete, RepoUpdate, 31 + RpcCall, ScopeAction, ScopeVerificationError, ScopeVerified, VerifyScope, 32 + }; 23 33 pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token}; 24 34 25 35 pub use tranquil_auth::{
+277
crates/tranquil-pds/src/auth/scope_verified.rs
··· 1 + use std::marker::PhantomData; 2 + 3 + use axum::response::{IntoResponse, Response}; 4 + 5 + use crate::api::error::ApiError; 6 + use crate::oauth::scopes::{AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions}; 7 + 8 + use super::AuthenticatedUser; 9 + 10 + #[derive(Debug)] 11 + pub struct ScopeVerificationError { 12 + message: String, 13 + } 14 + 15 + impl ScopeVerificationError { 16 + pub fn new(message: impl Into<String>) -> Self { 17 + Self { 18 + message: message.into(), 19 + } 20 + } 21 + 22 + pub fn message(&self) -> &str { 23 + &self.message 24 + } 25 + } 26 + 27 + impl std::fmt::Display for ScopeVerificationError { 28 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 29 + write!(f, "{}", self.message) 30 + } 31 + } 32 + 33 + impl std::error::Error for ScopeVerificationError {} 34 + 35 + impl IntoResponse for ScopeVerificationError { 36 + fn into_response(self) -> Response { 37 + ApiError::InsufficientScope(Some(self.message)).into_response() 38 + } 39 + } 40 + 41 + mod private { 42 + pub trait Sealed {} 43 + } 44 + 45 + pub trait ScopeAction: private::Sealed {} 46 + 47 + pub struct RepoCreate; 48 + pub struct RepoUpdate; 49 + pub struct RepoDelete; 50 + pub struct BlobUpload; 51 + pub struct RpcCall; 52 + pub struct AccountRead; 53 + pub struct AccountManage; 54 + pub struct IdentityAccess; 55 + 56 + impl private::Sealed for RepoCreate {} 57 + impl private::Sealed for RepoUpdate {} 58 + impl private::Sealed for RepoDelete {} 59 + impl private::Sealed for BlobUpload {} 60 + impl private::Sealed for RpcCall {} 61 + impl private::Sealed for AccountRead {} 62 + impl private::Sealed for AccountManage {} 63 + impl private::Sealed for IdentityAccess {} 64 + 65 + impl ScopeAction for RepoCreate {} 66 + impl ScopeAction for RepoUpdate {} 67 + impl ScopeAction for RepoDelete {} 68 + impl ScopeAction for BlobUpload {} 69 + impl ScopeAction for RpcCall {} 70 + impl ScopeAction for AccountRead {} 71 + impl ScopeAction for AccountManage {} 72 + impl ScopeAction for IdentityAccess {} 73 + 74 + pub struct ScopeVerified<'a, A: ScopeAction> { 75 + user: &'a AuthenticatedUser, 76 + _action: PhantomData<A>, 77 + } 78 + 79 + impl<'a, A: ScopeAction> ScopeVerified<'a, A> { 80 + pub fn user(&self) -> &AuthenticatedUser { 81 + self.user 82 + } 83 + 84 + pub fn did(&self) -> &crate::types::Did { 85 + &self.user.did 86 + } 87 + 88 + pub fn is_admin(&self) -> bool { 89 + self.user.is_admin 90 + } 91 + 92 + pub fn controller_did(&self) -> Option<&crate::types::Did> { 93 + self.user.controller_did.as_ref() 94 + } 95 + } 96 + 97 + pub trait VerifyScope { 98 + fn needs_scope_check(&self) -> bool; 99 + fn permissions(&self) -> ScopePermissions; 100 + 101 + fn verify_repo_create<'a>( 102 + &'a self, 103 + collection: &str, 104 + ) -> Result<ScopeVerified<'a, RepoCreate>, ScopeVerificationError> 105 + where 106 + Self: AsRef<AuthenticatedUser>, 107 + { 108 + if !self.needs_scope_check() { 109 + return Ok(ScopeVerified { 110 + user: self.as_ref(), 111 + _action: PhantomData, 112 + }); 113 + } 114 + self.permissions() 115 + .assert_repo(RepoAction::Create, collection) 116 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 117 + Ok(ScopeVerified { 118 + user: self.as_ref(), 119 + _action: PhantomData, 120 + }) 121 + } 122 + 123 + fn verify_repo_update<'a>( 124 + &'a self, 125 + collection: &str, 126 + ) -> Result<ScopeVerified<'a, RepoUpdate>, ScopeVerificationError> 127 + where 128 + Self: AsRef<AuthenticatedUser>, 129 + { 130 + if !self.needs_scope_check() { 131 + return Ok(ScopeVerified { 132 + user: self.as_ref(), 133 + _action: PhantomData, 134 + }); 135 + } 136 + self.permissions() 137 + .assert_repo(RepoAction::Update, collection) 138 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 139 + Ok(ScopeVerified { 140 + user: self.as_ref(), 141 + _action: PhantomData, 142 + }) 143 + } 144 + 145 + fn verify_repo_delete<'a>( 146 + &'a self, 147 + collection: &str, 148 + ) -> Result<ScopeVerified<'a, RepoDelete>, ScopeVerificationError> 149 + where 150 + Self: AsRef<AuthenticatedUser>, 151 + { 152 + if !self.needs_scope_check() { 153 + return Ok(ScopeVerified { 154 + user: self.as_ref(), 155 + _action: PhantomData, 156 + }); 157 + } 158 + self.permissions() 159 + .assert_repo(RepoAction::Delete, collection) 160 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 161 + Ok(ScopeVerified { 162 + user: self.as_ref(), 163 + _action: PhantomData, 164 + }) 165 + } 166 + 167 + fn verify_blob_upload<'a>( 168 + &'a self, 169 + mime_type: &str, 170 + ) -> Result<ScopeVerified<'a, BlobUpload>, ScopeVerificationError> 171 + where 172 + Self: AsRef<AuthenticatedUser>, 173 + { 174 + if !self.needs_scope_check() { 175 + return Ok(ScopeVerified { 176 + user: self.as_ref(), 177 + _action: PhantomData, 178 + }); 179 + } 180 + self.permissions() 181 + .assert_blob(mime_type) 182 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 183 + Ok(ScopeVerified { 184 + user: self.as_ref(), 185 + _action: PhantomData, 186 + }) 187 + } 188 + 189 + fn verify_rpc<'a>( 190 + &'a self, 191 + aud: &str, 192 + lxm: &str, 193 + ) -> Result<ScopeVerified<'a, RpcCall>, ScopeVerificationError> 194 + where 195 + Self: AsRef<AuthenticatedUser>, 196 + { 197 + if !self.needs_scope_check() { 198 + return Ok(ScopeVerified { 199 + user: self.as_ref(), 200 + _action: PhantomData, 201 + }); 202 + } 203 + self.permissions() 204 + .assert_rpc(aud, lxm) 205 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 206 + Ok(ScopeVerified { 207 + user: self.as_ref(), 208 + _action: PhantomData, 209 + }) 210 + } 211 + 212 + fn verify_account_read<'a>( 213 + &'a self, 214 + attr: AccountAttr, 215 + ) -> Result<ScopeVerified<'a, AccountRead>, ScopeVerificationError> 216 + where 217 + Self: AsRef<AuthenticatedUser>, 218 + { 219 + if !self.needs_scope_check() { 220 + return Ok(ScopeVerified { 221 + user: self.as_ref(), 222 + _action: PhantomData, 223 + }); 224 + } 225 + self.permissions() 226 + .assert_account(attr, AccountAction::Read) 227 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 228 + Ok(ScopeVerified { 229 + user: self.as_ref(), 230 + _action: PhantomData, 231 + }) 232 + } 233 + 234 + fn verify_account_manage<'a>( 235 + &'a self, 236 + attr: AccountAttr, 237 + ) -> Result<ScopeVerified<'a, AccountManage>, ScopeVerificationError> 238 + where 239 + Self: AsRef<AuthenticatedUser>, 240 + { 241 + if !self.needs_scope_check() { 242 + return Ok(ScopeVerified { 243 + user: self.as_ref(), 244 + _action: PhantomData, 245 + }); 246 + } 247 + self.permissions() 248 + .assert_account(attr, AccountAction::Manage) 249 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 250 + Ok(ScopeVerified { 251 + user: self.as_ref(), 252 + _action: PhantomData, 253 + }) 254 + } 255 + 256 + fn verify_identity<'a>( 257 + &'a self, 258 + attr: IdentityAttr, 259 + ) -> Result<ScopeVerified<'a, IdentityAccess>, ScopeVerificationError> 260 + where 261 + Self: AsRef<AuthenticatedUser>, 262 + { 263 + if !self.needs_scope_check() { 264 + return Ok(ScopeVerified { 265 + user: self.as_ref(), 266 + _action: PhantomData, 267 + }); 268 + } 269 + self.permissions() 270 + .assert_identity(attr) 271 + .map_err(|e| ScopeVerificationError::new(e.to_string()))?; 272 + Ok(ScopeVerified { 273 + user: self.as_ref(), 274 + _action: PhantomData, 275 + }) 276 + } 277 + }
+10 -8
crates/tranquil-pds/src/auth/service.rs
··· 1 + use crate::types::Did; 2 + use crate::util::pds_hostname; 1 3 use anyhow::{Result, anyhow}; 2 4 use base64::Engine as _; 3 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; ··· 42 44 43 45 #[derive(Debug, Clone, Serialize, Deserialize)] 44 46 pub struct ServiceTokenClaims { 45 - pub iss: String, 47 + pub iss: Did, 46 48 #[serde(default)] 47 - pub sub: Option<String>, 48 - pub aud: String, 49 + pub sub: Option<Did>, 50 + pub aud: Did, 49 51 pub exp: usize, 50 52 #[serde(default)] 51 53 pub iat: Option<usize>, ··· 56 58 } 57 59 58 60 impl ServiceTokenClaims { 59 - pub fn subject(&self) -> &str { 60 - self.sub.as_deref().unwrap_or(&self.iss) 61 + pub fn subject(&self) -> &Did { 62 + self.sub.as_ref().unwrap_or(&self.iss) 61 63 } 62 64 } 63 65 ··· 79 81 .unwrap_or_else(|_| "https://plc.directory".to_string()); 80 82 81 83 let pds_hostname = 82 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 84 + pds_hostname(); 83 85 let pds_did = format!("did:web:{}", pds_hostname); 84 86 85 87 let client = Client::builder() ··· 130 132 return Err(anyhow!("Token expired")); 131 133 } 132 134 133 - if claims.aud != self.pds_did { 135 + if claims.aud.as_str() != self.pds_did { 134 136 return Err(anyhow!( 135 137 "Invalid audience: expected {}, got {}", 136 138 self.pds_did, ··· 154 156 } 155 157 } 156 158 157 - let did = &claims.iss; 159 + let did = claims.iss.as_str(); 158 160 let public_key = self.resolve_signing_key(did).await?; 159 161 160 162 let signature_bytes = URL_SAFE_NO_PAD
+1 -1
crates/tranquil-pds/src/auth/webauthn.rs
··· 7 7 8 8 impl WebAuthnConfig { 9 9 pub fn new(hostname: &str) -> Result<Self, String> { 10 - let rp_id = hostname.to_string(); 10 + let rp_id = hostname.split(':').next().unwrap_or(hostname).to_string(); 11 11 let rp_origin = Url::parse(&format!("https://{}", hostname)) 12 12 .map_err(|e| format!("Invalid origin URL: {}", e))?; 13 13
+6 -2
crates/tranquil-pds/src/crawlers.rs
··· 1 1 use crate::circuit_breaker::CircuitBreaker; 2 2 use crate::sync::firehose::SequencedEvent; 3 + use crate::util::pds_hostname; 3 4 use reqwest::Client; 4 5 use std::sync::Arc; 5 6 use std::sync::atomic::{AtomicU64, Ordering}; ··· 40 41 } 41 42 42 43 pub fn from_env() -> Option<Self> { 43 - let hostname = std::env::var("PDS_HOSTNAME").ok()?; 44 + let hostname = pds_hostname(); 45 + if hostname == "localhost" { 46 + return None; 47 + } 44 48 45 49 let crawler_urls: Vec<String> = std::env::var("CRAWLERS") 46 50 .unwrap_or_default() ··· 53 57 return None; 54 58 } 55 59 56 - Some(Self::new(hostname, crawler_urls)) 60 + Some(Self::new(hostname.to_string(), crawler_urls)) 57 61 } 58 62 59 63 fn should_notify(&self) -> bool {
+5
crates/tranquil-pds/src/delegation/mod.rs
··· 1 + pub mod roles; 1 2 pub mod scopes; 2 3 4 + pub use roles::{ 5 + CanAddControllers, CanControlAccounts, verify_can_add_controllers, verify_can_be_controller, 6 + verify_can_control_accounts, 7 + }; 3 8 pub use scopes::{SCOPE_PRESETS, ScopePreset, intersect_scopes}; 4 9 pub use tranquil_db_traits::DelegationActionType;
+88
crates/tranquil-pds/src/delegation/roles.rs
··· 1 + use axum::response::{IntoResponse, Response}; 2 + 3 + use crate::api::error::ApiError; 4 + use crate::auth::AuthenticatedUser; 5 + use crate::state::AppState; 6 + use crate::types::Did; 7 + 8 + pub struct CanAddControllers<'a> { 9 + user: &'a AuthenticatedUser, 10 + } 11 + 12 + pub struct CanControlAccounts<'a> { 13 + user: &'a AuthenticatedUser, 14 + } 15 + 16 + impl<'a> CanAddControllers<'a> { 17 + pub fn did(&self) -> &Did { 18 + &self.user.did 19 + } 20 + 21 + pub fn user(&self) -> &AuthenticatedUser { 22 + self.user 23 + } 24 + } 25 + 26 + impl<'a> CanControlAccounts<'a> { 27 + pub fn did(&self) -> &Did { 28 + &self.user.did 29 + } 30 + 31 + pub fn user(&self) -> &AuthenticatedUser { 32 + self.user 33 + } 34 + } 35 + 36 + pub async fn verify_can_add_controllers<'a>( 37 + state: &AppState, 38 + user: &'a AuthenticatedUser, 39 + ) -> Result<CanAddControllers<'a>, Response> { 40 + match state.delegation_repo.controls_any_accounts(&user.did).await { 41 + Ok(true) => Err(ApiError::InvalidDelegation( 42 + "Cannot add controllers to an account that controls other accounts".into(), 43 + ) 44 + .into_response()), 45 + Ok(false) => Ok(CanAddControllers { user }), 46 + Err(e) => { 47 + tracing::error!("Failed to check delegation status: {:?}", e); 48 + Err(ApiError::InternalError(Some("Failed to verify delegation status".into())) 49 + .into_response()) 50 + } 51 + } 52 + } 53 + 54 + pub async fn verify_can_control_accounts<'a>( 55 + state: &AppState, 56 + user: &'a AuthenticatedUser, 57 + ) -> Result<CanControlAccounts<'a>, Response> { 58 + match state.delegation_repo.has_any_controllers(&user.did).await { 59 + Ok(true) => Err(ApiError::InvalidDelegation( 60 + "Cannot create delegated accounts from a controlled account".into(), 61 + ) 62 + .into_response()), 63 + Ok(false) => Ok(CanControlAccounts { user }), 64 + Err(e) => { 65 + tracing::error!("Failed to check controller status: {:?}", e); 66 + Err(ApiError::InternalError(Some("Failed to verify controller status".into())) 67 + .into_response()) 68 + } 69 + } 70 + } 71 + 72 + pub async fn verify_can_be_controller( 73 + state: &AppState, 74 + controller_did: &Did, 75 + ) -> Result<(), Response> { 76 + match state.delegation_repo.has_any_controllers(controller_did).await { 77 + Ok(true) => Err(ApiError::InvalidDelegation( 78 + "Cannot add a controlled account as a controller".into(), 79 + ) 80 + .into_response()), 81 + Ok(false) => Ok(()), 82 + Err(e) => { 83 + tracing::error!("Failed to check controller status: {:?}", e); 84 + Err(ApiError::InternalError(Some("Failed to verify controller status".into())) 85 + .into_response()) 86 + } 87 + } 88 + }
+94 -255
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 1 use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 2 use crate::oauth::{ 3 - AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, 3 + AuthFlow, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, Prompt, SessionId, 4 4 db::should_show_consent, scopes::expand_include_scopes, 5 5 }; 6 - use crate::state::{AppState, RateLimitKind}; 6 + use crate::rate_limit::{ 7 + OAuthAuthorizeLimit, OAuthRateLimited, OAuthRegisterCompleteLimit, TotpVerifyLimit, 8 + check_user_rate_limit, 9 + }; 10 + use crate::state::AppState; 7 11 use crate::types::{Did, Handle, PlainPassword}; 12 + use crate::util::{extract_client_ip, pds_hostname, pds_hostname_without_port}; 8 13 use axum::{ 9 14 Json, 10 15 extract::{Query, State}, ··· 79 84 || s.starts_with("include:") 80 85 } 81 86 82 - fn validate_auth_flow_state( 83 - flow_state: &AuthFlowState, 84 - require_authenticated: bool, 85 - ) -> Option<Response> { 86 - if flow_state.is_expired() { 87 - return Some(json_error( 88 - StatusCode::BAD_REQUEST, 89 - "invalid_request", 90 - "Authorization request has expired", 91 - )); 92 - } 93 - if require_authenticated && flow_state.is_pending() { 94 - return Some(json_error( 95 - StatusCode::FORBIDDEN, 96 - "access_denied", 97 - "Not authenticated", 98 - )); 99 - } 100 - None 101 - } 102 - 103 87 fn extract_device_cookie(headers: &HeaderMap) -> Option<String> { 104 88 headers 105 89 .get("cookie") ··· 113 97 }) 114 98 } 115 99 116 - fn extract_client_ip(headers: &HeaderMap) -> String { 117 - if let Some(forwarded) = headers.get("x-forwarded-for") 118 - && let Ok(value) = forwarded.to_str() 119 - && let Some(first_ip) = value.split(',').next() 120 - { 121 - return first_ip.trim().to_string(); 122 - } 123 - if let Some(real_ip) = headers.get("x-real-ip") 124 - && let Ok(value) = real_ip.to_str() 125 - { 126 - return value.trim().to_string(); 127 - } 128 - "0.0.0.0".to_string() 129 - } 130 - 131 100 fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 132 101 headers 133 102 .get("user-agent") ··· 282 251 283 252 if let Some(ref login_hint) = request_data.parameters.login_hint { 284 253 tracing::info!(login_hint = %login_hint, "Checking login_hint for delegation"); 285 - let pds_hostname = 286 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 287 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 254 + let hostname_for_handles = pds_hostname_without_port(); 288 255 let normalized = if login_hint.contains('@') || login_hint.starts_with("did:") { 289 256 login_hint.clone() 290 257 } else if !login_hint.contains('.') { ··· 340 307 tracing::info!("No login_hint in request"); 341 308 } 342 309 343 - if request_data.parameters.prompt.as_deref() == Some("create") { 310 + if request_data.parameters.prompt == Some(Prompt::Create) { 344 311 return redirect_see_other(&format!( 345 312 "/app/oauth/register?request_uri={}", 346 313 url_encode(&request_uri) ··· 485 452 486 453 pub async fn authorize_post( 487 454 State(state): State<AppState>, 455 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 488 456 headers: HeaderMap, 489 457 Json(form): Json<AuthorizeSubmit>, 490 458 ) -> Response { 491 459 let json_response = wants_json(&headers); 492 - let client_ip = extract_client_ip(&headers); 493 - if !state 494 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 495 - .await 496 - { 497 - tracing::warn!(ip = %client_ip, "OAuth authorize rate limit exceeded"); 498 - if json_response { 499 - return ( 500 - axum::http::StatusCode::TOO_MANY_REQUESTS, 501 - Json(serde_json::json!({ 502 - "error": "RateLimitExceeded", 503 - "error_description": "Too many login attempts. Please try again later." 504 - })), 505 - ) 506 - .into_response(); 507 - } 508 - return redirect_to_frontend_error( 509 - "RateLimitExceeded", 510 - "Too many login attempts. Please try again later.", 511 - ); 512 - } 513 460 let form_request_id = RequestId::from(form.request_uri.clone()); 514 461 let request_data = match state 515 462 .oauth_repo ··· 584 531 url_encode(error_msg) 585 532 )) 586 533 }; 587 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 588 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 534 + let hostname_for_handles = pds_hostname_without_port(); 589 535 let normalized_username = form.username.trim(); 590 536 let normalized_username = normalized_username 591 537 .strip_prefix('@') ··· 600 546 tracing::debug!( 601 547 original_username = %form.username, 602 548 normalized_username = %normalized_username, 603 - pds_hostname = %pds_hostname, 549 + pds_hostname = %pds_hostname(), 604 550 "Normalized username for lookup" 605 551 ); 606 552 let user = match state ··· 748 694 .await 749 695 { 750 696 Ok(challenge) => { 751 - let hostname = 752 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 697 + let hostname = pds_hostname(); 753 698 if let Err(e) = enqueue_2fa_code( 754 699 state.user_repo.as_ref(), 755 700 state.infra_repo.as_ref(), 756 701 user.id, 757 702 &challenge.code, 758 - &hostname, 703 + hostname, 759 704 ) 760 705 .await 761 706 { ··· 792 737 } else { 793 738 let new_id = DeviceId::generate(); 794 739 let device_data = DeviceData { 795 - session_id: SessionId::generate().0, 740 + session_id: SessionId::generate(), 796 741 user_agent: extract_user_agent(&headers), 797 - ip_address: extract_client_ip(&headers), 742 + ip_address: extract_client_ip(&headers, None), 798 743 last_seen_at: Utc::now(), 799 744 }; 800 745 let new_device_id_typed = DeviceIdType::from(new_id.0.clone()); ··· 888 833 &request_data.parameters.redirect_uri, 889 834 &code.0, 890 835 request_data.parameters.state.as_deref(), 891 - request_data.parameters.response_mode.as_deref(), 836 + request_data.parameters.response_mode.map(|m| m.as_str()), 892 837 ); 893 838 if let Some(cookie) = new_cookie { 894 839 ( ··· 905 850 &request_data.parameters.redirect_uri, 906 851 &code.0, 907 852 request_data.parameters.state.as_deref(), 908 - request_data.parameters.response_mode.as_deref(), 853 + request_data.parameters.response_mode.map(|m| m.as_str()), 909 854 ); 910 855 if let Some(cookie) = new_cookie { 911 856 ( ··· 1068 1013 .await 1069 1014 { 1070 1015 Ok(challenge) => { 1071 - let hostname = 1072 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1016 + let hostname = pds_hostname(); 1073 1017 if let Err(e) = enqueue_2fa_code( 1074 1018 state.user_repo.as_ref(), 1075 1019 state.infra_repo.as_ref(), 1076 1020 user.id, 1077 1021 &challenge.code, 1078 - &hostname, 1022 + hostname, 1079 1023 ) 1080 1024 .await 1081 1025 { ··· 1169 1113 &request_data.parameters.redirect_uri, 1170 1114 &code.0, 1171 1115 request_data.parameters.state.as_deref(), 1172 - request_data.parameters.response_mode.as_deref(), 1116 + request_data.parameters.response_mode.map(|m| m.as_str()), 1173 1117 ); 1174 1118 Json(serde_json::json!({ 1175 1119 "redirect_uri": redirect_url ··· 1193 1137 '?' 1194 1138 }; 1195 1139 redirect_url.push(separator); 1196 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1140 + let pds_host = pds_hostname(); 1197 1141 redirect_url.push_str(&format!( 1198 1142 "iss={}", 1199 - url_encode(&format!("https://{}", pds_hostname)) 1143 + url_encode(&format!("https://{}", pds_host)) 1200 1144 )); 1201 1145 if let Some(req_state) = state { 1202 1146 redirect_url.push_str(&format!("&state={}", url_encode(req_state))); ··· 1211 1155 state: Option<&str>, 1212 1156 response_mode: Option<&str>, 1213 1157 ) -> String { 1214 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1158 + let pds_host = pds_hostname(); 1215 1159 let mut url = format!( 1216 1160 "https://{}/oauth/authorize/redirect?redirect_uri={}&code={}", 1217 - pds_hostname, 1161 + pds_host, 1218 1162 url_encode(redirect_uri), 1219 1163 url_encode(code) 1220 1164 ); ··· 1459 1403 ); 1460 1404 } 1461 1405 }; 1462 - let flow_state = AuthFlowState::from_request_data(&request_data); 1463 - 1464 - if let Some(err_response) = validate_auth_flow_state(&flow_state, true) { 1465 - if flow_state.is_expired() { 1406 + let flow_with_user = match AuthFlow::from_request_data(request_data.clone()) { 1407 + Ok(flow) => match flow.require_user() { 1408 + Ok(u) => u, 1409 + Err(_) => { 1410 + return json_error( 1411 + StatusCode::FORBIDDEN, 1412 + "access_denied", 1413 + "Not authenticated", 1414 + ); 1415 + } 1416 + }, 1417 + Err(_) => { 1466 1418 let _ = state 1467 1419 .oauth_repo 1468 1420 .delete_authorization_request(&consent_request_id) 1469 1421 .await; 1470 - } 1471 - return err_response; 1472 - } 1473 - 1474 - let did_str = flow_state.did().unwrap().to_string(); 1475 - let did: Did = match did_str.parse() { 1476 - Ok(d) => d, 1477 - Err(_) => { 1478 1422 return json_error( 1479 1423 StatusCode::BAD_REQUEST, 1480 1424 "invalid_request", 1481 - "Invalid DID format in request.", 1425 + "Authorization request has expired", 1482 1426 ); 1483 1427 } 1484 1428 }; 1429 + 1430 + let did = flow_with_user.did().clone(); 1485 1431 let client_cache = ClientMetadataCache::new(3600); 1486 1432 let client_metadata = client_cache 1487 1433 .get(&request_data.parameters.client_id) ··· 1635 1581 logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()), 1636 1582 scopes, 1637 1583 show_consent, 1638 - did: did_str, 1584 + did: did.to_string(), 1639 1585 handle: account_handle, 1640 1586 is_delegation, 1641 1587 controller_did: controller_did_resp, ··· 1676 1622 ); 1677 1623 } 1678 1624 }; 1679 - let flow_state = AuthFlowState::from_request_data(&request_data); 1680 - 1681 - if flow_state.is_expired() { 1682 - let _ = state 1683 - .oauth_repo 1684 - .delete_authorization_request(&consent_post_request_id) 1685 - .await; 1686 - return json_error( 1687 - StatusCode::BAD_REQUEST, 1688 - "invalid_request", 1689 - "Authorization request has expired", 1690 - ); 1691 - } 1692 - if flow_state.is_pending() { 1693 - return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1694 - } 1695 - 1696 - let did_str = flow_state.did().unwrap().to_string(); 1697 - let did: Did = match did_str.parse() { 1698 - Ok(d) => d, 1625 + let flow_with_user = match AuthFlow::from_request_data(request_data.clone()) { 1626 + Ok(flow) => match flow.require_user() { 1627 + Ok(u) => u, 1628 + Err(_) => { 1629 + return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1630 + } 1631 + }, 1699 1632 Err(_) => { 1633 + let _ = state 1634 + .oauth_repo 1635 + .delete_authorization_request(&consent_post_request_id) 1636 + .await; 1700 1637 return json_error( 1701 1638 StatusCode::BAD_REQUEST, 1702 1639 "invalid_request", 1703 - "Invalid DID format", 1640 + "Authorization request has expired", 1704 1641 ); 1705 1642 } 1706 1643 }; 1644 + 1645 + let did = flow_with_user.did().clone(); 1707 1646 let original_scope_str = request_data 1708 1647 .parameters 1709 1648 .scope ··· 1799 1738 let consent_post_device_id = request_data 1800 1739 .device_id 1801 1740 .as_ref() 1802 - .map(|d| DeviceIdType::from(d.clone())); 1741 + .map(|d| DeviceIdType::from(d.0.clone())); 1803 1742 let consent_post_code = AuthorizationCode::from(code.0.clone()); 1804 1743 if state 1805 1744 .oauth_repo ··· 1823 1762 redirect_uri, 1824 1763 &code.0, 1825 1764 request_data.parameters.state.as_deref(), 1826 - request_data.parameters.response_mode.as_deref(), 1765 + request_data.parameters.response_mode.map(|m| m.as_str()), 1827 1766 ); 1828 1767 tracing::info!( 1829 1768 intermediate_url = %intermediate_url, ··· 1835 1774 1836 1775 pub async fn authorize_2fa_post( 1837 1776 State(state): State<AppState>, 1777 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 1838 1778 headers: HeaderMap, 1839 1779 Json(form): Json<Authorize2faSubmit>, 1840 1780 ) -> Response { ··· 1848 1788 ) 1849 1789 .into_response() 1850 1790 }; 1851 - let client_ip = extract_client_ip(&headers); 1852 - if !state 1853 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 1854 - .await 1855 - { 1856 - tracing::warn!(ip = %client_ip, "OAuth 2FA rate limit exceeded"); 1857 - return json_error( 1858 - StatusCode::TOO_MANY_REQUESTS, 1859 - "RateLimitExceeded", 1860 - "Too many attempts. Please try again later.", 1861 - ); 1862 - } 1863 1791 let twofa_post_request_id = RequestId::from(form.request_uri.clone()); 1864 1792 let request_data = match state 1865 1793 .oauth_repo ··· 1956 1884 &request_data.parameters.redirect_uri, 1957 1885 &code.0, 1958 1886 request_data.parameters.state.as_deref(), 1959 - request_data.parameters.response_mode.as_deref(), 1887 + request_data.parameters.response_mode.map(|m| m.as_str()), 1960 1888 ); 1961 1889 return Json(serde_json::json!({ 1962 1890 "redirect_uri": redirect_url ··· 1990 1918 "No 2FA challenge found. Please start over.", 1991 1919 ); 1992 1920 } 1993 - if !state 1994 - .check_rate_limit(RateLimitKind::TotpVerify, &did) 1995 - .await 1996 - { 1997 - tracing::warn!(did = %did, "TOTP verification rate limit exceeded"); 1998 - return json_error( 1999 - StatusCode::TOO_MANY_REQUESTS, 2000 - "RateLimitExceeded", 2001 - "Too many verification attempts. Please try again in a few minutes.", 2002 - ); 2003 - } 1921 + let _rate_proof = match check_user_rate_limit::<TotpVerifyLimit>(&state, &did).await { 1922 + Ok(proof) => proof, 1923 + Err(_) => { 1924 + return json_error( 1925 + StatusCode::TOO_MANY_REQUESTS, 1926 + "RateLimitExceeded", 1927 + "Too many verification attempts. Please try again in a few minutes.", 1928 + ) 1929 + } 1930 + }; 2004 1931 let totp_valid = 2005 1932 crate::api::server::verify_totp_or_backup_for_user(&state, &did, &form.code).await; 2006 1933 if !totp_valid { ··· 2065 1992 &request_data.parameters.redirect_uri, 2066 1993 &code.0, 2067 1994 request_data.parameters.state.as_deref(), 2068 - request_data.parameters.response_mode.as_deref(), 1995 + request_data.parameters.response_mode.map(|m| m.as_str()), 2069 1996 ); 2070 1997 Json(serde_json::json!({ 2071 1998 "redirect_uri": redirect_url ··· 2089 2016 State(state): State<AppState>, 2090 2017 Query(query): Query<CheckPasskeysQuery>, 2091 2018 ) -> Response { 2092 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2093 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2019 + let hostname_for_handles = pds_hostname_without_port(); 2094 2020 let normalized_identifier = query.identifier.trim(); 2095 2021 let normalized_identifier = normalized_identifier 2096 2022 .strip_prefix('@') ··· 2131 2057 State(state): State<AppState>, 2132 2058 Query(query): Query<CheckPasskeysQuery>, 2133 2059 ) -> Response { 2134 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2135 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2060 + let hostname_for_handles = pds_hostname_without_port(); 2136 2061 let identifier = query.identifier.trim(); 2137 2062 let identifier = identifier.strip_prefix('@').unwrap_or(identifier); 2138 2063 let normalized_identifier = if identifier.contains('@') || identifier.starts_with("did:") { ··· 2200 2125 2201 2126 pub async fn passkey_start( 2202 2127 State(state): State<AppState>, 2203 - headers: HeaderMap, 2128 + _rate_limit: OAuthRateLimited<OAuthAuthorizeLimit>, 2204 2129 Json(form): Json<PasskeyStartInput>, 2205 2130 ) -> Response { 2206 - let client_ip = extract_client_ip(&headers); 2207 - 2208 - if !state 2209 - .check_rate_limit(RateLimitKind::OAuthAuthorize, &client_ip) 2210 - .await 2211 - { 2212 - tracing::warn!(ip = %client_ip, "OAuth passkey rate limit exceeded"); 2213 - return ( 2214 - StatusCode::TOO_MANY_REQUESTS, 2215 - Json(serde_json::json!({ 2216 - "error": "RateLimitExceeded", 2217 - "error_description": "Too many login attempts. Please try again later." 2218 - })), 2219 - ) 2220 - .into_response(); 2221 - } 2222 - 2223 2131 let passkey_start_request_id = RequestId::from(form.request_uri.clone()); 2224 2132 let request_data = match state 2225 2133 .oauth_repo ··· 2264 2172 .into_response(); 2265 2173 } 2266 2174 2267 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2268 - let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 2175 + let hostname_for_handles = pds_hostname_without_port(); 2269 2176 let normalized_username = form.identifier.trim(); 2270 2177 let normalized_username = normalized_username 2271 2178 .strip_prefix('@') ··· 2386 2293 .into_response(); 2387 2294 } 2388 2295 2389 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2390 - Ok(w) => w, 2391 - Err(e) => { 2392 - tracing::error!(error = %e, "Failed to create WebAuthn config"); 2393 - return ( 2394 - StatusCode::INTERNAL_SERVER_ERROR, 2395 - Json(serde_json::json!({ 2396 - "error": "server_error", 2397 - "error_description": "WebAuthn configuration failed." 2398 - })), 2399 - ) 2400 - .into_response(); 2401 - } 2402 - }; 2403 - 2404 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 2296 + let (rcr, auth_state) = match state.webauthn_config.start_authentication(passkeys) { 2405 2297 Ok(result) => result, 2406 2298 Err(e) => { 2407 2299 tracing::error!(error = %e, "Failed to start passkey authentication"); ··· 2680 2572 } 2681 2573 }; 2682 2574 2683 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2684 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2685 - Ok(w) => w, 2686 - Err(e) => { 2687 - tracing::error!(error = %e, "Failed to create WebAuthn config"); 2688 - return ( 2689 - StatusCode::INTERNAL_SERVER_ERROR, 2690 - Json(serde_json::json!({ 2691 - "error": "server_error", 2692 - "error_description": "WebAuthn configuration failed." 2693 - })), 2694 - ) 2695 - .into_response(); 2696 - } 2697 - }; 2698 - 2699 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 2575 + let auth_result = match state 2576 + .webauthn_config 2577 + .finish_authentication(&credential, &auth_state) 2578 + { 2700 2579 Ok(r) => r, 2701 2580 Err(e) => { 2702 2581 tracing::warn!(error = %e, did = %did, "Failed to verify passkey authentication"); ··· 2769 2648 .await 2770 2649 { 2771 2650 Ok(challenge) => { 2772 - let hostname = 2773 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2651 + let hostname = pds_hostname(); 2774 2652 if let Err(e) = enqueue_2fa_code( 2775 2653 state.user_repo.as_ref(), 2776 2654 state.infra_repo.as_ref(), 2777 2655 user.id, 2778 2656 &challenge.code, 2779 - &hostname, 2657 + hostname, 2780 2658 ) 2781 2659 .await 2782 2660 { ··· 2859 2737 &request_data.parameters.redirect_uri, 2860 2738 &code.0, 2861 2739 request_data.parameters.state.as_deref(), 2862 - request_data.parameters.response_mode.as_deref(), 2740 + request_data.parameters.response_mode.map(|m| m.as_str()), 2863 2741 ); 2864 2742 2865 2743 Json(serde_json::json!({ ··· 2884 2762 State(state): State<AppState>, 2885 2763 Query(query): Query<AuthorizePasskeyQuery>, 2886 2764 ) -> Response { 2887 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2888 - 2889 2765 let auth_passkey_start_request_id = RequestId::from(query.request_uri.clone()); 2890 2766 let request_data = match state 2891 2767 .oauth_repo ··· 2994 2870 .into_response(); 2995 2871 } 2996 2872 2997 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 2998 - Ok(w) => w, 2999 - Err(e) => { 3000 - tracing::error!("Failed to create WebAuthn config: {:?}", e); 3001 - return ( 3002 - StatusCode::INTERNAL_SERVER_ERROR, 3003 - Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 3004 - ) 3005 - .into_response(); 3006 - } 3007 - }; 3008 - 3009 - let (rcr, auth_state) = match webauthn.start_authentication(passkeys) { 2873 + let (rcr, auth_state) = match state.webauthn_config.start_authentication(passkeys) { 3010 2874 Ok(result) => result, 3011 2875 Err(e) => { 3012 2876 tracing::error!("Failed to start passkey authentication: {:?}", e); ··· 3063 2927 headers: HeaderMap, 3064 2928 Json(form): Json<AuthorizePasskeySubmit>, 3065 2929 ) -> Response { 3066 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2930 + let pds_hostname = pds_hostname(); 3067 2931 let passkey_finish_request_id = RequestId::from(form.request_uri.clone()); 3068 2932 3069 2933 let request_data = match state ··· 3193 3057 } 3194 3058 }; 3195 3059 3196 - let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) { 3197 - Ok(w) => w, 3198 - Err(e) => { 3199 - tracing::error!("Failed to create WebAuthn config: {:?}", e); 3200 - return ( 3201 - StatusCode::INTERNAL_SERVER_ERROR, 3202 - Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 3203 - ) 3204 - .into_response(); 3205 - } 3206 - }; 3207 - 3208 - let auth_result = match webauthn.finish_authentication(&credential, &auth_state) { 3060 + let auth_result = match state 3061 + .webauthn_config 3062 + .finish_authentication(&credential, &auth_state) 3063 + { 3209 3064 Ok(r) => r, 3210 3065 Err(e) => { 3211 3066 tracing::warn!("Passkey authentication failed: {:?}", e); ··· 3292 3147 state.infra_repo.as_ref(), 3293 3148 user.id, 3294 3149 &challenge.code, 3295 - &pds_hostname, 3150 + pds_hostname, 3296 3151 ) 3297 3152 .await 3298 3153 { ··· 3347 3202 3348 3203 pub async fn register_complete( 3349 3204 State(state): State<AppState>, 3350 - headers: HeaderMap, 3205 + _rate_limit: OAuthRateLimited<OAuthRegisterCompleteLimit>, 3351 3206 Json(form): Json<RegisterCompleteInput>, 3352 3207 ) -> Response { 3353 - let client_ip = extract_client_ip(&headers); 3354 - 3355 - if !state 3356 - .check_rate_limit(RateLimitKind::OAuthRegisterComplete, &client_ip) 3357 - .await 3358 - { 3359 - return ( 3360 - StatusCode::TOO_MANY_REQUESTS, 3361 - Json(serde_json::json!({ 3362 - "error": "RateLimitExceeded", 3363 - "error_description": "Too many attempts. Please try again later." 3364 - })), 3365 - ) 3366 - .into_response(); 3367 - } 3368 - 3369 3208 let did = Did::from(form.did.clone()); 3370 3209 3371 3210 let request_id = RequestId::from(form.request_uri.clone()); ··· 3417 3256 .into_response(); 3418 3257 } 3419 3258 3420 - if request_data.parameters.prompt.as_deref() != Some("create") { 3259 + if request_data.parameters.prompt != Some(Prompt::Create) { 3421 3260 tracing::warn!( 3422 3261 request_uri = %form.request_uri, 3423 3262 prompt = ?request_data.parameters.prompt, ··· 3636 3475 &request_data.parameters.redirect_uri, 3637 3476 &code.0, 3638 3477 request_data.parameters.state.as_deref(), 3639 - request_data.parameters.response_mode.as_deref(), 3478 + request_data.parameters.response_mode.map(|m| m.as_str()), 3640 3479 ); 3641 3480 Json(serde_json::json!({"redirect_uri": redirect_url})).into_response() 3642 3481 } ··· 3662 3501 None => { 3663 3502 let new_id = DeviceId::generate(); 3664 3503 let device_data = DeviceData { 3665 - session_id: SessionId::generate().0, 3504 + session_id: SessionId::generate(), 3666 3505 user_agent: extract_user_agent(&headers), 3667 - ip_address: extract_client_ip(&headers), 3506 + ip_address: extract_client_ip(&headers, None), 3668 3507 last_seen_at: Utc::now(), 3669 3508 }; 3670 3509 let device_typed = DeviceIdType::from(new_id.0.clone());
+32 -64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
··· 1 1 use crate::auth::{Active, Auth}; 2 2 use crate::delegation::DelegationActionType; 3 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::rate_limit::{LoginLimit, OAuthRateLimited, TotpVerifyLimit}; 4 + use crate::state::AppState; 4 5 use crate::types::PlainPassword; 5 6 use crate::util::extract_client_ip; 6 7 use axum::{ 7 8 Json, 8 9 extract::State, 9 - http::{HeaderMap, StatusCode}, 10 + http::HeaderMap, 10 11 response::{IntoResponse, Response}, 11 12 }; 12 13 use serde::{Deserialize, Serialize}; ··· 35 36 36 37 pub async fn delegation_auth( 37 38 State(state): State<AppState>, 39 + rate_limit: OAuthRateLimited<LoginLimit>, 38 40 headers: HeaderMap, 39 41 Json(form): Json<DelegationAuthSubmit>, 40 42 ) -> Response { 41 - let client_ip = extract_client_ip(&headers); 42 - if !state 43 - .check_rate_limit(RateLimitKind::Login, &client_ip) 44 - .await 45 - { 46 - return ( 47 - StatusCode::TOO_MANY_REQUESTS, 48 - Json(DelegationAuthResponse { 49 - success: false, 50 - needs_totp: None, 51 - redirect_uri: None, 52 - error: Some("Too many login attempts. Please try again later.".to_string()), 53 - }), 54 - ) 55 - .into_response(); 56 - } 57 - 43 + let client_ip = rate_limit.client_ip(); 58 44 let request_id = RequestId::from(form.request_uri.clone()); 59 45 let request = match state 60 46 .oauth_repo ··· 82 68 } 83 69 }; 84 70 85 - let delegated_did_str = match form.delegated_did.as_ref().or(request.did.as_ref()) { 86 - Some(did) => did.clone(), 87 - None => { 88 - return Json(DelegationAuthResponse { 89 - success: false, 90 - needs_totp: None, 91 - redirect_uri: None, 92 - error: Some("No delegated account selected".to_string()), 93 - }) 94 - .into_response(); 95 - } 96 - }; 97 - 98 - let delegated_did: Did = match delegated_did_str.parse() { 99 - Ok(d) => d, 100 - Err(_) => { 101 - return Json(DelegationAuthResponse { 102 - success: false, 103 - needs_totp: None, 104 - redirect_uri: None, 105 - error: Some("Invalid delegated DID".to_string()), 106 - }) 107 - .into_response(); 71 + let delegated_did: Did = if let Some(did_str) = form.delegated_did.as_ref() { 72 + match did_str.parse() { 73 + Ok(d) => d, 74 + Err(_) => { 75 + return Json(DelegationAuthResponse { 76 + success: false, 77 + needs_totp: None, 78 + redirect_uri: None, 79 + error: Some("Invalid delegated DID".to_string()), 80 + }) 81 + .into_response(); 82 + } 108 83 } 84 + } else if let Some(did) = request.did.as_ref() { 85 + did.clone() 86 + } else { 87 + return Json(DelegationAuthResponse { 88 + success: false, 89 + needs_totp: None, 90 + redirect_uri: None, 91 + error: Some("No delegated account selected".to_string()), 92 + }) 93 + .into_response(); 109 94 }; 110 95 111 96 let controller_did: Did = match form.controller_did.parse() { ··· 249 234 .into_response(); 250 235 } 251 236 252 - let ip = extract_client_ip(&headers); 253 237 let user_agent = headers 254 238 .get("user-agent") 255 239 .and_then(|v| v.to_str().ok()) ··· 266 250 "client_id": request.client_id, 267 251 "granted_scopes": grant.granted_scopes 268 252 })), 269 - Some(&ip), 253 + Some(client_ip), 270 254 user_agent.as_deref(), 271 255 ) 272 256 .await; ··· 291 275 292 276 pub async fn delegation_totp_verify( 293 277 State(state): State<AppState>, 278 + rate_limit: OAuthRateLimited<TotpVerifyLimit>, 294 279 headers: HeaderMap, 295 280 Json(form): Json<DelegationTotpSubmit>, 296 281 ) -> Response { 297 - let client_ip = extract_client_ip(&headers); 298 - if !state 299 - .check_rate_limit(RateLimitKind::TotpVerify, &client_ip) 300 - .await 301 - { 302 - return ( 303 - StatusCode::TOO_MANY_REQUESTS, 304 - Json(DelegationAuthResponse { 305 - success: false, 306 - needs_totp: None, 307 - redirect_uri: None, 308 - error: Some("Too many verification attempts. Please try again later.".to_string()), 309 - }), 310 - ) 311 - .into_response(); 312 - } 313 - 282 + let client_ip = rate_limit.client_ip(); 314 283 let totp_request_id = RequestId::from(form.request_uri.clone()); 315 284 let request = match state 316 285 .oauth_repo ··· 420 389 .into_response(); 421 390 } 422 391 423 - let ip = extract_client_ip(&headers); 424 392 let user_agent = headers 425 393 .get("user-agent") 426 394 .and_then(|v| v.to_str().ok()) ··· 437 405 "client_id": request.client_id, 438 406 "granted_scopes": grant.granted_scopes 439 407 })), 440 - Some(&ip), 408 + Some(client_ip), 441 409 user_agent.as_deref(), 442 410 ) 443 411 .await; ··· 564 532 .into_response(); 565 533 } 566 534 567 - let ip = extract_client_ip(&headers); 535 + let ip = extract_client_ip(&headers, None); 568 536 let user_agent = headers 569 537 .get("user-agent") 570 538 .and_then(|v| v.to_str().ok())
+3 -2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
··· 1 1 use crate::oauth::jwks::{JwkSet, create_jwk_set}; 2 2 use crate::state::AppState; 3 + use crate::util::pds_hostname; 3 4 use axum::{Json, extract::State}; 4 5 use serde::{Deserialize, Serialize}; 5 6 ··· 57 58 pub async fn oauth_protected_resource( 58 59 State(_state): State<AppState>, 59 60 ) -> Json<ProtectedResourceMetadata> { 60 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 61 + let pds_hostname = pds_hostname(); 61 62 let public_url = format!("https://{}", pds_hostname); 62 63 Json(ProtectedResourceMetadata { 63 64 resource: public_url.clone(), ··· 71 72 pub async fn oauth_authorization_server( 72 73 State(_state): State<AppState>, 73 74 ) -> Json<AuthorizationServerMetadata> { 74 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 75 + let pds_hostname = pds_hostname(); 75 76 let issuer = format!("https://{}", pds_hostname); 76 77 Json(AuthorizationServerMetadata { 77 78 issuer: issuer.clone(),
+57 -50
crates/tranquil-pds/src/oauth/endpoints/par.rs
··· 1 1 use crate::oauth::{ 2 - AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData, 3 - RequestId, 2 + AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, CodeChallengeMethod, 3 + OAuthError, Prompt, RequestData, RequestId, ResponseMode, ResponseType, 4 4 scopes::{ParsedScope, parse_scope}, 5 5 }; 6 - use crate::state::{AppState, RateLimitKind}; 6 + use crate::rate_limit::{OAuthParLimit, OAuthRateLimited}; 7 + use crate::state::AppState; 7 8 use axum::body::Bytes; 8 9 use axum::{Json, extract::State, http::HeaderMap}; 9 10 use chrono::{Duration, Utc}; ··· 49 50 50 51 pub async fn pushed_authorization_request( 51 52 State(state): State<AppState>, 53 + _rate_limit: OAuthRateLimited<OAuthParLimit>, 52 54 headers: HeaderMap, 53 55 body: Bytes, 54 56 ) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> { ··· 70 72 .to_string(), 71 73 )); 72 74 }; 73 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 74 - if !state 75 - .check_rate_limit(RateLimitKind::OAuthPar, &client_ip) 76 - .await 77 - { 78 - tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded"); 79 - return Err(OAuthError::RateLimited); 80 - } 81 - if request.response_type != "code" { 82 - return Err(OAuthError::InvalidRequest( 83 - "response_type must be 'code'".to_string(), 84 - )); 85 - } 75 + let response_type = parse_response_type(&request.response_type)?; 86 76 let code_challenge = request 87 77 .code_challenge 88 78 .as_ref() 89 79 .filter(|s| !s.is_empty()) 90 80 .ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?; 91 - let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or(""); 92 - if code_challenge_method != "S256" { 93 - return Err(OAuthError::InvalidRequest( 94 - "code_challenge_method must be 'S256'".to_string(), 95 - )); 96 - } 81 + let code_challenge_method = parse_code_challenge_method(request.code_challenge_method.as_deref())?; 97 82 let client_cache = ClientMetadataCache::new(3600); 98 83 let client_metadata = client_cache.get(&request.client_id).await?; 99 84 client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?; ··· 101 86 let validated_scope = validate_scope(&request.scope, &client_metadata)?; 102 87 let request_id = RequestId::generate(); 103 88 let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS); 104 - let response_mode = match request.response_mode.as_deref() { 105 - Some("fragment") => Some("fragment".to_string()), 106 - Some("query") | None => None, 107 - Some(mode) => { 108 - return Err(OAuthError::InvalidRequest(format!( 109 - "Unsupported response_mode: {}", 110 - mode 111 - ))); 112 - } 113 - }; 114 - let prompt = validate_prompt(&request.prompt)?; 89 + let response_mode = parse_response_mode(request.response_mode.as_deref())?; 90 + let prompt = parse_prompt(request.prompt.as_deref())?; 115 91 let parameters = AuthorizationRequestParameters { 116 - response_type: request.response_type, 92 + response_type, 117 93 client_id: request.client_id.clone(), 118 94 redirect_uri: request.redirect_uri, 119 95 scope: validated_scope, 120 96 state: request.state, 121 97 code_challenge: code_challenge.clone(), 122 - code_challenge_method: code_challenge_method.to_string(), 98 + code_challenge_method, 123 99 response_mode, 124 100 login_hint: request.login_hint, 125 101 dpop_jkt: request.dpop_jkt, ··· 266 242 false 267 243 } 268 244 269 - fn validate_prompt(prompt: &Option<String>) -> Result<Option<String>, OAuthError> { 270 - const VALID_PROMPTS: &[&str] = &["none", "login", "consent", "select_account", "create"]; 245 + fn parse_response_type(value: &str) -> Result<ResponseType, OAuthError> { 246 + match value { 247 + "code" => Ok(ResponseType::Code), 248 + other => Err(OAuthError::InvalidRequest(format!( 249 + "response_type must be 'code', got '{}'", 250 + other 251 + ))), 252 + } 253 + } 271 254 272 - match prompt { 273 - None => Ok(None), 274 - Some(p) if p.is_empty() => Ok(None), 275 - Some(p) => { 276 - if VALID_PROMPTS.contains(&p.as_str()) { 277 - Ok(Some(p.clone())) 278 - } else { 279 - Err(OAuthError::InvalidRequest(format!( 280 - "Unsupported prompt value: {}", 281 - p 282 - ))) 283 - } 284 - } 255 + fn parse_code_challenge_method(value: Option<&str>) -> Result<CodeChallengeMethod, OAuthError> { 256 + match value { 257 + Some("S256") | None => Ok(CodeChallengeMethod::S256), 258 + Some("plain") => Err(OAuthError::InvalidRequest( 259 + "code_challenge_method 'plain' is not allowed, use 'S256'".to_string(), 260 + )), 261 + Some(other) => Err(OAuthError::InvalidRequest(format!( 262 + "Unsupported code_challenge_method: {}", 263 + other 264 + ))), 265 + } 266 + } 267 + 268 + fn parse_response_mode(value: Option<&str>) -> Result<Option<ResponseMode>, OAuthError> { 269 + match value { 270 + None | Some("query") => Ok(None), 271 + Some("fragment") => Ok(Some(ResponseMode::Fragment)), 272 + Some("form_post") => Ok(Some(ResponseMode::FormPost)), 273 + Some(other) => Err(OAuthError::InvalidRequest(format!( 274 + "Unsupported response_mode: {}", 275 + other 276 + ))), 277 + } 278 + } 279 + 280 + fn parse_prompt(value: Option<&str>) -> Result<Option<Prompt>, OAuthError> { 281 + match value { 282 + None | Some("") => Ok(None), 283 + Some("none") => Ok(Some(Prompt::None)), 284 + Some("login") => Ok(Some(Prompt::Login)), 285 + Some("consent") => Ok(Some(Prompt::Consent)), 286 + Some("select_account") => Ok(Some(Prompt::SelectAccount)), 287 + Some("create") => Ok(Some(Prompt::Create)), 288 + Some(other) => Err(OAuthError::InvalidRequest(format!( 289 + "Unsupported prompt value: {}", 290 + other 291 + ))), 285 292 } 286 293 }
+42 -42
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
··· 3 3 use crate::config::AuthConfig; 4 4 use crate::delegation::intersect_scopes; 5 5 use crate::oauth::{ 6 - AuthFlowState, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, 7 - TokenData, TokenId, 6 + AuthFlow, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, TokenData, 7 + TokenId, 8 8 db::{enforce_token_limit_for_user, lookup_refresh_token}, 9 9 scopes::expand_include_scopes, 10 10 verify_client_auth, 11 11 }; 12 12 use crate::state::AppState; 13 + use crate::util::pds_hostname; 13 14 use axum::Json; 14 15 use axum::http::HeaderMap; 15 16 use chrono::{Duration, Utc}; ··· 51 52 .map_err(crate::oauth::db_err_to_oauth)? 52 53 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 53 54 54 - let flow_state = AuthFlowState::from_request_data(&auth_request); 55 - if flow_state.is_expired() { 56 - return Err(OAuthError::InvalidGrant( 57 - "Authorization code has expired".to_string(), 58 - )); 59 - } 60 - if !flow_state.can_exchange() { 61 - return Err(OAuthError::InvalidGrant( 62 - "Authorization not completed".to_string(), 63 - )); 64 - } 55 + let flow = AuthFlow::from_request_data(auth_request).map_err(|_| { 56 + OAuthError::InvalidGrant("Authorization code has expired".to_string()) 57 + })?; 58 + 59 + let authorized = flow.require_authorized().map_err(|_| { 60 + OAuthError::InvalidGrant("Authorization not completed".to_string()) 61 + })?; 65 62 66 63 if let Some(request_client_id) = &request.client_auth.client_id 67 - && request_client_id != &auth_request.client_id 64 + && request_client_id != &authorized.client_id 68 65 { 69 66 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string())); 70 67 } 71 - let did = flow_state.did().unwrap().to_string(); 68 + let did = authorized.did.to_string(); 72 69 let client_metadata_cache = ClientMetadataCache::new(3600); 73 - let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?; 70 + let client_metadata = client_metadata_cache.get(&authorized.client_id).await?; 74 71 let client_auth = if let (Some(assertion), Some(assertion_type)) = ( 75 72 &request.client_auth.client_assertion, 76 73 &request.client_auth.client_assertion_type, ··· 91 88 ClientAuth::None 92 89 }; 93 90 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?; 94 - verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?; 91 + verify_pkce(&authorized.parameters.code_challenge, &code_verifier)?; 95 92 if let Some(req_redirect_uri) = &redirect_uri 96 - && req_redirect_uri != &auth_request.parameters.redirect_uri 93 + && req_redirect_uri != &authorized.parameters.redirect_uri 97 94 { 98 95 return Err(OAuthError::InvalidGrant( 99 96 "redirect_uri mismatch".to_string(), ··· 103 100 let config = AuthConfig::get(); 104 101 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 105 102 let pds_hostname = 106 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 103 + pds_hostname(); 107 104 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 108 105 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 109 106 if !state ··· 116 113 "DPoP proof has already been used".to_string(), 117 114 )); 118 115 } 119 - if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt 116 + if let Some(expected_jkt) = &authorized.parameters.dpop_jkt 120 117 && result.jkt.as_str() != expected_jkt 121 118 { 122 119 return Err(OAuthError::InvalidDpopProof( ··· 124 121 )); 125 122 } 126 123 Some(result.jkt.as_str().to_string()) 127 - } else if auth_request.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() { 124 + } else if authorized.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() { 128 125 return Err(OAuthError::UseDpopNonce( 129 126 DPoPVerifier::new(AuthConfig::get().dpop_secret().as_bytes()).generate_nonce(), 130 127 )); ··· 135 132 let refresh_token = RefreshToken::generate(); 136 133 let now = Utc::now(); 137 134 138 - let (raw_scope, controller_did) = if let Some(ref controller) = auth_request.controller_did { 135 + let (raw_scope, controller_did) = if let Some(ref controller) = authorized.controller_did { 139 136 let did_parsed: Did = did 140 137 .parse() 141 138 .map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?; ··· 149 146 .ok() 150 147 .flatten(); 151 148 let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default(); 152 - let requested = auth_request 149 + let requested = authorized 153 150 .parameters 154 151 .scope 155 152 .as_deref() ··· 157 154 let intersected = intersect_scopes(requested, &granted_scopes); 158 155 (Some(intersected), Some(controller.clone())) 159 156 } else { 160 - (auth_request.parameters.scope.clone(), None) 157 + (authorized.parameters.scope.clone(), None) 161 158 }; 162 159 163 160 let final_scope = if let Some(ref scope) = raw_scope { ··· 177 174 final_scope.as_deref(), 178 175 controller_did.as_deref(), 179 176 )?; 180 - let stored_client_auth = auth_request.client_auth.unwrap_or(ClientAuth::None); 177 + let stored_client_auth = authorized.client_auth.unwrap_or(ClientAuth::None); 181 178 let refresh_expiry_days = if matches!(stored_client_auth, ClientAuth::None) { 182 179 REFRESH_TOKEN_EXPIRY_DAYS_PUBLIC 183 180 } else { 184 181 REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL 185 182 }; 186 - let mut stored_parameters = auth_request.parameters.clone(); 183 + let mut stored_parameters = authorized.parameters.clone(); 187 184 stored_parameters.dpop_jkt = dpop_jkt.clone(); 185 + let did_typed: Did = did 186 + .parse() 187 + .map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?; 188 188 let token_data = TokenData { 189 - did: did.clone(), 190 - token_id: token_id.0.clone(), 189 + did: did_typed, 190 + token_id: token_id.clone(), 191 191 created_at: now, 192 192 updated_at: now, 193 193 expires_at: now + Duration::days(refresh_expiry_days), 194 - client_id: auth_request.client_id.clone(), 194 + client_id: authorized.client_id.clone(), 195 195 client_auth: stored_client_auth, 196 - device_id: auth_request.device_id, 196 + device_id: authorized.device_id.clone(), 197 197 parameters: stored_parameters, 198 198 details: None, 199 199 code: None, 200 - current_refresh_token: Some(refresh_token.0.clone()), 200 + current_refresh_token: Some(refresh_token.clone()), 201 201 scope: final_scope.clone(), 202 202 controller_did: controller_did.clone(), 203 203 }; ··· 209 209 tracing::info!( 210 210 did = %did, 211 211 token_id = %token_id.0, 212 - client_id = %auth_request.client_id, 212 + client_id = %authorized.client_id, 213 213 "Authorization code grant completed, token created" 214 214 ); 215 215 tokio::spawn({ ··· 280 280 ); 281 281 let dpop_jkt = token_data.parameters.dpop_jkt.as_deref(); 282 282 let access_token = create_access_token_with_delegation( 283 - &token_data.token_id, 284 - &token_data.did, 283 + &token_data.token_id.0, 284 + token_data.did.as_str(), 285 285 dpop_jkt, 286 286 token_data.scope.as_deref(), 287 - token_data.controller_did.as_deref(), 287 + token_data.controller_did.as_ref().map(|d| d.as_str()), 288 288 )?; 289 289 let mut response_headers = HeaderMap::new(); 290 290 let config = AuthConfig::get(); ··· 296 296 access_token, 297 297 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(), 298 298 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 299 - refresh_token: token_data.current_refresh_token, 299 + refresh_token: token_data.current_refresh_token.map(|r| r.0), 300 300 scope: token_data.scope, 301 - sub: Some(token_data.did), 301 + sub: Some(token_data.did.to_string()), 302 302 }), 303 303 )); 304 304 } ··· 338 338 let config = AuthConfig::get(); 339 339 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes()); 340 340 let pds_hostname = 341 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 341 + pds_hostname(); 342 342 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 343 343 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 344 344 if !state ··· 385 385 "Refresh token rotated successfully" 386 386 ); 387 387 let access_token = create_access_token_with_delegation( 388 - &token_data.token_id, 389 - &token_data.did, 388 + &token_data.token_id.0, 389 + token_data.did.as_str(), 390 390 dpop_jkt.as_deref(), 391 391 token_data.scope.as_deref(), 392 - token_data.controller_did.as_deref(), 392 + token_data.controller_did.as_ref().map(|d| d.as_str()), 393 393 )?; 394 394 let mut response_headers = HeaderMap::new(); 395 395 let config = AuthConfig::get(); ··· 403 403 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64, 404 404 refresh_token: Some(new_refresh_token.0), 405 405 scope: token_data.scope, 406 - sub: Some(token_data.did), 406 + sub: Some(token_data.did.to_string()), 407 407 }), 408 408 )) 409 409 }
+2 -1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
··· 1 1 use crate::config::AuthConfig; 2 2 use crate::oauth::OAuthError; 3 + use crate::util::pds_hostname; 3 4 use base64::Engine; 4 5 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 5 6 use chrono::Utc; ··· 51 52 ) -> Result<String, OAuthError> { 52 53 use serde_json::json; 53 54 let jti = uuid::Uuid::new_v4().to_string(); 54 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 55 + let pds_hostname = pds_hostname(); 55 56 let issuer = format!("https://{}", pds_hostname); 56 57 let now = Utc::now().timestamp(); 57 58 let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
+8 -22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
··· 1 1 use super::helpers::extract_token_claims; 2 2 use crate::oauth::OAuthError; 3 - use crate::state::{AppState, RateLimitKind}; 3 + use crate::rate_limit::{OAuthIntrospectLimit, OAuthRateLimited}; 4 + use crate::state::AppState; 5 + use crate::util::pds_hostname; 4 6 use axum::extract::State; 5 - use axum::http::{HeaderMap, StatusCode}; 7 + use axum::http::StatusCode; 6 8 use axum::{Form, Json}; 7 9 use chrono::Utc; 8 10 use serde::{Deserialize, Serialize}; ··· 17 19 18 20 pub async fn revoke_token( 19 21 State(state): State<AppState>, 20 - headers: HeaderMap, 22 + _rate_limit: OAuthRateLimited<OAuthIntrospectLimit>, 21 23 Form(request): Form<RevokeRequest>, 22 24 ) -> Result<StatusCode, OAuthError> { 23 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 24 - if !state 25 - .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 26 - .await 27 - { 28 - tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded"); 29 - return Err(OAuthError::RateLimited); 30 - } 31 25 if let Some(token) = &request.token { 32 26 let refresh_token = RefreshToken::from(token.clone()); 33 27 if let Some((db_id, _)) = state ··· 89 83 90 84 pub async fn introspect_token( 91 85 State(state): State<AppState>, 92 - headers: HeaderMap, 86 + _rate_limit: OAuthRateLimited<OAuthIntrospectLimit>, 93 87 Form(request): Form<IntrospectRequest>, 94 88 ) -> Result<Json<IntrospectResponse>, OAuthError> { 95 - let client_ip = crate::rate_limit::extract_client_ip(&headers, None); 96 - if !state 97 - .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip) 98 - .await 99 - { 100 - tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded"); 101 - return Err(OAuthError::RateLimited); 102 - } 103 89 let inactive_response = IntrospectResponse { 104 90 active: false, 105 91 scope: None, ··· 126 112 if token_data.expires_at < Utc::now() { 127 113 return Ok(Json(inactive_response)); 128 114 } 129 - let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 115 + let pds_hostname = pds_hostname(); 130 116 let issuer = format!("https://{}", pds_hostname); 131 117 Ok(Json(IntrospectResponse { 132 118 active: true, ··· 141 127 exp: Some(token_info.exp), 142 128 iat: Some(token_info.iat), 143 129 nbf: Some(token_info.iat), 144 - sub: Some(token_data.did), 130 + sub: Some(token_data.did.to_string()), 145 131 aud: Some(issuer.clone()), 146 132 iss: Some(issuer), 147 133 jti: Some(token_info.jti),
+3 -26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
··· 4 4 mod types; 5 5 6 6 use crate::oauth::OAuthError; 7 - use crate::state::{AppState, RateLimitKind}; 7 + use crate::rate_limit::{OAuthRateLimited, OAuthTokenLimit}; 8 + use crate::state::AppState; 8 9 use axum::body::Bytes; 9 10 use axum::{Json, extract::State, http::HeaderMap}; 10 11 ··· 17 18 ClientAuthParams, GrantType, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest, 18 19 }; 19 20 20 - fn extract_client_ip(headers: &HeaderMap) -> String { 21 - if let Some(forwarded) = headers.get("x-forwarded-for") 22 - && let Ok(value) = forwarded.to_str() 23 - && let Some(first_ip) = value.split(',').next() 24 - { 25 - return first_ip.trim().to_string(); 26 - } 27 - if let Some(real_ip) = headers.get("x-real-ip") 28 - && let Ok(value) = real_ip.to_str() 29 - { 30 - return value.trim().to_string(); 31 - } 32 - "unknown".to_string() 33 - } 34 - 35 21 pub async fn token_endpoint( 36 22 State(state): State<AppState>, 23 + _rate_limit: OAuthRateLimited<OAuthTokenLimit>, 37 24 headers: HeaderMap, 38 25 body: Bytes, 39 26 ) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> { ··· 53 40 .to_string(), 54 41 )); 55 42 }; 56 - let client_ip = extract_client_ip(&headers); 57 - if !state 58 - .check_rate_limit(RateLimitKind::OAuthToken, &client_ip) 59 - .await 60 - { 61 - tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 62 - return Err(OAuthError::InvalidRequest( 63 - "Too many requests. Please try again later.".to_string(), 64 - )); 65 - } 66 43 let dpop_proof = headers 67 44 .get("DPoP") 68 45 .and_then(|v| v.to_str().ok())
+9 -7
crates/tranquil-pds/src/oauth/mod.rs
··· 10 10 } 11 11 12 12 pub use tranquil_oauth::{ 13 - AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 - AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, DPoPClaims, 15 - DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, DPoPVerifyResult, DeviceData, 16 - DeviceId, JwkPublicKey, Jwks, OAuthClientMetadata, OAuthError, ParResponse, 17 - ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, SessionId, 18 - TokenData, TokenId, TokenRequest, TokenResponse, compute_access_token_hash, 19 - compute_jwk_thumbprint, verify_client_auth, 13 + AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata, 14 + AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, 15 + CodeChallengeMethod, DPoPClaims, DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, 16 + DPoPVerifyResult, DeviceData, DeviceId, FlowAuthenticated, FlowAuthorized, FlowExpired, 17 + FlowNotAuthenticated, FlowNotAuthorized, FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, 18 + OAuthError, ParResponse, Prompt, ProtectedResourceMetadata, RefreshToken, RefreshTokenState, 19 + RequestData, RequestId, ResponseMode, ResponseType, SessionId, TokenData, TokenId, 20 + TokenRequest, TokenResponse, compute_access_token_hash, compute_jwk_thumbprint, 21 + verify_client_auth, 20 22 }; 21 23 22 24 pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions};
+18 -14
crates/tranquil-pds/src/oauth/verify.rs
··· 20 20 use crate::state::AppState; 21 21 22 22 pub struct OAuthTokenInfo { 23 - pub did: String, 24 - pub token_id: String, 25 - pub client_id: String, 23 + pub did: Did, 24 + pub token_id: TokenId, 25 + pub client_id: ClientId, 26 26 pub scope: Option<String>, 27 27 pub dpop_jkt: Option<String>, 28 - pub controller_did: Option<String>, 28 + pub controller_did: Option<Did>, 29 29 } 30 30 31 31 pub struct VerifyResult { ··· 48 48 has_dpop_proof = dpop_proof.is_some(), 49 49 "Verifying OAuth access token" 50 50 ); 51 - let token_id = TokenId::from(token_info.token_id.clone()); 51 + let token_id = token_info.token_id.clone(); 52 52 let token_data = oauth_repo 53 53 .get_token_by_id(&token_id) 54 54 .await ··· 154 154 if exp < now { 155 155 return Err(OAuthError::ExpiredToken("Token has expired".to_string())); 156 156 } 157 - let token_id = payload 157 + let token_id_str = payload 158 158 .get("sid") 159 159 .and_then(|j| j.as_str()) 160 - .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))? 161 - .to_string(); 162 - let did = payload 160 + .ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?; 161 + let token_id = TokenId::new(token_id_str); 162 + let did_str = payload 163 163 .get("sub") 164 164 .and_then(|s| s.as_str()) 165 - .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))? 166 - .to_string(); 165 + .ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?; 166 + let did: Did = did_str 167 + .parse() 168 + .map_err(|_| OAuthError::InvalidToken("Invalid sub claim (not a valid DID)".to_string()))?; 167 169 let scope = payload 168 170 .get("scope") 169 171 .and_then(|s| s.as_str()) ··· 173 175 .and_then(|c| c.get("jkt")) 174 176 .and_then(|j| j.as_str()) 175 177 .map(|s| s.to_string()); 176 - let client_id = payload 178 + let client_id_str = payload 177 179 .get("client_id") 178 180 .and_then(|c| c.as_str()) 179 - .map(|s| s.to_string()) 180 181 .unwrap_or_default(); 182 + let client_id = ClientId::new(client_id_str); 181 183 let controller_did = payload 182 184 .get("act") 183 185 .and_then(|a| a.get("sub")) 184 186 .and_then(|s| s.as_str()) 185 - .map(|s| s.to_string()); 187 + .map(|s| s.parse::<Did>()) 188 + .transpose() 189 + .map_err(|_| OAuthError::InvalidToken("Invalid act.sub claim (not a valid DID)".to_string()))?; 186 190 Ok(OAuthTokenInfo { 187 191 did, 188 192 token_id,
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
··· 1 + use std::marker::PhantomData; 2 + 3 + use axum::{ 4 + extract::FromRequestParts, 5 + http::request::Parts, 6 + response::{IntoResponse, Response}, 7 + }; 8 + 9 + use crate::api::error::ApiError; 10 + use crate::oauth::OAuthError; 11 + use crate::state::{AppState, RateLimitKind}; 12 + use crate::util::extract_client_ip; 13 + 14 + pub trait RateLimitPolicy: Send + Sync + 'static { 15 + const KIND: RateLimitKind; 16 + } 17 + 18 + pub struct LoginLimit; 19 + impl RateLimitPolicy for LoginLimit { 20 + const KIND: RateLimitKind = RateLimitKind::Login; 21 + } 22 + 23 + pub struct AccountCreationLimit; 24 + impl RateLimitPolicy for AccountCreationLimit { 25 + const KIND: RateLimitKind = RateLimitKind::AccountCreation; 26 + } 27 + 28 + pub struct PasswordResetLimit; 29 + impl RateLimitPolicy for PasswordResetLimit { 30 + const KIND: RateLimitKind = RateLimitKind::PasswordReset; 31 + } 32 + 33 + pub struct ResetPasswordLimit; 34 + impl RateLimitPolicy for ResetPasswordLimit { 35 + const KIND: RateLimitKind = RateLimitKind::ResetPassword; 36 + } 37 + 38 + pub struct RefreshSessionLimit; 39 + impl RateLimitPolicy for RefreshSessionLimit { 40 + const KIND: RateLimitKind = RateLimitKind::RefreshSession; 41 + } 42 + 43 + pub struct OAuthTokenLimit; 44 + impl RateLimitPolicy for OAuthTokenLimit { 45 + const KIND: RateLimitKind = RateLimitKind::OAuthToken; 46 + } 47 + 48 + pub struct OAuthAuthorizeLimit; 49 + impl RateLimitPolicy for OAuthAuthorizeLimit { 50 + const KIND: RateLimitKind = RateLimitKind::OAuthAuthorize; 51 + } 52 + 53 + pub struct OAuthParLimit; 54 + impl RateLimitPolicy for OAuthParLimit { 55 + const KIND: RateLimitKind = RateLimitKind::OAuthPar; 56 + } 57 + 58 + pub struct OAuthIntrospectLimit; 59 + impl RateLimitPolicy for OAuthIntrospectLimit { 60 + const KIND: RateLimitKind = RateLimitKind::OAuthIntrospect; 61 + } 62 + 63 + pub struct AppPasswordLimit; 64 + impl RateLimitPolicy for AppPasswordLimit { 65 + const KIND: RateLimitKind = RateLimitKind::AppPassword; 66 + } 67 + 68 + pub struct EmailUpdateLimit; 69 + impl RateLimitPolicy for EmailUpdateLimit { 70 + const KIND: RateLimitKind = RateLimitKind::EmailUpdate; 71 + } 72 + 73 + pub struct TotpVerifyLimit; 74 + impl RateLimitPolicy for TotpVerifyLimit { 75 + const KIND: RateLimitKind = RateLimitKind::TotpVerify; 76 + } 77 + 78 + pub struct HandleUpdateLimit; 79 + impl RateLimitPolicy for HandleUpdateLimit { 80 + const KIND: RateLimitKind = RateLimitKind::HandleUpdate; 81 + } 82 + 83 + pub struct HandleUpdateDailyLimit; 84 + impl RateLimitPolicy for HandleUpdateDailyLimit { 85 + const KIND: RateLimitKind = RateLimitKind::HandleUpdateDaily; 86 + } 87 + 88 + pub struct VerificationCheckLimit; 89 + impl RateLimitPolicy for VerificationCheckLimit { 90 + const KIND: RateLimitKind = RateLimitKind::VerificationCheck; 91 + } 92 + 93 + pub struct SsoInitiateLimit; 94 + impl RateLimitPolicy for SsoInitiateLimit { 95 + const KIND: RateLimitKind = RateLimitKind::SsoInitiate; 96 + } 97 + 98 + pub struct SsoCallbackLimit; 99 + impl RateLimitPolicy for SsoCallbackLimit { 100 + const KIND: RateLimitKind = RateLimitKind::SsoCallback; 101 + } 102 + 103 + pub struct SsoUnlinkLimit; 104 + impl RateLimitPolicy for SsoUnlinkLimit { 105 + const KIND: RateLimitKind = RateLimitKind::SsoUnlink; 106 + } 107 + 108 + pub struct OAuthRegisterCompleteLimit; 109 + impl RateLimitPolicy for OAuthRegisterCompleteLimit { 110 + const KIND: RateLimitKind = RateLimitKind::OAuthRegisterComplete; 111 + } 112 + 113 + pub trait RateLimitRejection: IntoResponse + Send + 'static { 114 + fn new() -> Self; 115 + } 116 + 117 + pub struct ApiRateLimitRejection; 118 + 119 + impl RateLimitRejection for ApiRateLimitRejection { 120 + fn new() -> Self { 121 + Self 122 + } 123 + } 124 + 125 + impl IntoResponse for ApiRateLimitRejection { 126 + fn into_response(self) -> Response { 127 + ApiError::RateLimitExceeded(None).into_response() 128 + } 129 + } 130 + 131 + pub struct OAuthRateLimitRejection; 132 + 133 + impl RateLimitRejection for OAuthRateLimitRejection { 134 + fn new() -> Self { 135 + Self 136 + } 137 + } 138 + 139 + impl IntoResponse for OAuthRateLimitRejection { 140 + fn into_response(self) -> Response { 141 + OAuthError::RateLimited.into_response() 142 + } 143 + } 144 + 145 + impl From<OAuthRateLimitRejection> for OAuthError { 146 + fn from(_: OAuthRateLimitRejection) -> Self { 147 + OAuthError::RateLimited 148 + } 149 + } 150 + 151 + pub struct RateLimitedInner<P: RateLimitPolicy, R: RateLimitRejection> { 152 + client_ip: String, 153 + _marker: PhantomData<(P, R)>, 154 + } 155 + 156 + impl<P: RateLimitPolicy, R: RateLimitRejection> RateLimitedInner<P, R> { 157 + pub fn client_ip(&self) -> &str { 158 + &self.client_ip 159 + } 160 + } 161 + 162 + impl<P: RateLimitPolicy, R: RateLimitRejection> FromRequestParts<AppState> 163 + for RateLimitedInner<P, R> 164 + { 165 + type Rejection = R; 166 + 167 + async fn from_request_parts( 168 + parts: &mut Parts, 169 + state: &AppState, 170 + ) -> Result<Self, Self::Rejection> { 171 + let client_ip = extract_client_ip(&parts.headers, None); 172 + 173 + if !state.check_rate_limit(P::KIND, &client_ip).await { 174 + tracing::warn!( 175 + ip = %client_ip, 176 + kind = ?P::KIND, 177 + "Rate limit exceeded" 178 + ); 179 + return Err(R::new()); 180 + } 181 + 182 + Ok(RateLimitedInner { 183 + client_ip, 184 + _marker: PhantomData, 185 + }) 186 + } 187 + } 188 + 189 + pub type RateLimited<P> = RateLimitedInner<P, ApiRateLimitRejection>; 190 + pub type OAuthRateLimited<P> = RateLimitedInner<P, OAuthRateLimitRejection>; 191 + 192 + #[derive(Debug)] 193 + pub struct UserRateLimitError { 194 + pub kind: RateLimitKind, 195 + pub message: Option<String>, 196 + } 197 + 198 + impl UserRateLimitError { 199 + pub fn new(kind: RateLimitKind) -> Self { 200 + Self { 201 + kind, 202 + message: None, 203 + } 204 + } 205 + 206 + pub fn with_message(kind: RateLimitKind, message: impl Into<String>) -> Self { 207 + Self { 208 + kind, 209 + message: Some(message.into()), 210 + } 211 + } 212 + } 213 + 214 + impl std::fmt::Display for UserRateLimitError { 215 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 216 + match &self.message { 217 + Some(msg) => write!(f, "{}", msg), 218 + None => write!(f, "Rate limit exceeded for {:?}", self.kind), 219 + } 220 + } 221 + } 222 + 223 + impl std::error::Error for UserRateLimitError {} 224 + 225 + impl IntoResponse for UserRateLimitError { 226 + fn into_response(self) -> Response { 227 + ApiError::RateLimitExceeded(self.message).into_response() 228 + } 229 + } 230 + 231 + pub struct UserRateLimitProof<P: RateLimitPolicy> { 232 + _marker: PhantomData<P>, 233 + } 234 + 235 + impl<P: RateLimitPolicy> UserRateLimitProof<P> { 236 + fn new() -> Self { 237 + Self { 238 + _marker: PhantomData, 239 + } 240 + } 241 + } 242 + 243 + pub async fn check_user_rate_limit<P: RateLimitPolicy>( 244 + state: &AppState, 245 + user_key: &str, 246 + ) -> Result<UserRateLimitProof<P>, UserRateLimitError> { 247 + if !state.check_rate_limit(P::KIND, user_key).await { 248 + tracing::warn!( 249 + key = %user_key, 250 + kind = ?P::KIND, 251 + "User rate limit exceeded" 252 + ); 253 + return Err(UserRateLimitError::new(P::KIND)); 254 + } 255 + Ok(UserRateLimitProof::new()) 256 + } 257 + 258 + pub async fn check_user_rate_limit_with_message<P: RateLimitPolicy>( 259 + state: &AppState, 260 + user_key: &str, 261 + error_message: impl Into<String>, 262 + ) -> Result<UserRateLimitProof<P>, UserRateLimitError> { 263 + if !state.check_rate_limit(P::KIND, user_key).await { 264 + tracing::warn!( 265 + key = %user_key, 266 + kind = ?P::KIND, 267 + "User rate limit exceeded" 268 + ); 269 + return Err(UserRateLimitError::with_message(P::KIND, error_message)); 270 + } 271 + Ok(UserRateLimitProof::new()) 272 + }
+5 -102
crates/tranquil-pds/src/rate_limit.rs crates/tranquil-pds/src/rate_limit/mod.rs
··· 1 - use axum::{ 2 - Json, 3 - body::Body, 4 - extract::ConnectInfo, 5 - http::{HeaderMap, Request, StatusCode}, 6 - middleware::Next, 7 - response::{IntoResponse, Response}, 8 - }; 1 + mod extractor; 2 + 3 + pub use extractor::*; 4 + 9 5 use governor::{ 10 6 Quota, RateLimiter, 11 7 clock::DefaultClock, 12 8 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13 9 }; 14 - use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 10 + use std::{num::NonZeroU32, sync::Arc}; 15 11 16 12 pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17 13 pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; ··· 166 162 } 167 163 } 168 164 169 - pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 170 - if let Some(forwarded) = headers.get("x-forwarded-for") 171 - && let Ok(value) = forwarded.to_str() 172 - && let Some(first_ip) = value.split(',').next() 173 - { 174 - return first_ip.trim().to_string(); 175 - } 176 - 177 - if let Some(real_ip) = headers.get("x-real-ip") 178 - && let Ok(value) = real_ip.to_str() 179 - { 180 - return value.trim().to_string(); 181 - } 182 - 183 - addr.map(|a| a.ip().to_string()) 184 - .unwrap_or_else(|| "unknown".to_string()) 185 - } 186 - 187 - fn rate_limit_response() -> Response { 188 - ( 189 - StatusCode::TOO_MANY_REQUESTS, 190 - Json(serde_json::json!({ 191 - "error": "RateLimitExceeded", 192 - "message": "Too many requests. Please try again later." 193 - })), 194 - ) 195 - .into_response() 196 - } 197 - 198 - pub async fn login_rate_limit( 199 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 200 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 201 - request: Request<Body>, 202 - next: Next, 203 - ) -> Response { 204 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 205 - 206 - if limiters.login.check_key(&client_ip).is_err() { 207 - tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 208 - return rate_limit_response(); 209 - } 210 - 211 - next.run(request).await 212 - } 213 - 214 - pub async fn oauth_token_rate_limit( 215 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 216 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 217 - request: Request<Body>, 218 - next: Next, 219 - ) -> Response { 220 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 221 - 222 - if limiters.oauth_token.check_key(&client_ip).is_err() { 223 - tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 224 - return rate_limit_response(); 225 - } 226 - 227 - next.run(request).await 228 - } 229 - 230 - pub async fn password_reset_rate_limit( 231 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 232 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 233 - request: Request<Body>, 234 - next: Next, 235 - ) -> Response { 236 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 237 - 238 - if limiters.password_reset.check_key(&client_ip).is_err() { 239 - tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 240 - return rate_limit_response(); 241 - } 242 - 243 - next.run(request).await 244 - } 245 - 246 - pub async fn account_creation_rate_limit( 247 - ConnectInfo(addr): ConnectInfo<SocketAddr>, 248 - axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 249 - request: Request<Body>, 250 - next: Next, 251 - ) -> Response { 252 - let client_ip = extract_client_ip(request.headers(), Some(addr)); 253 - 254 - if limiters.account_creation.check_key(&client_ip).is_err() { 255 - tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 256 - return rate_limit_response(); 257 - } 258 - 259 - next.run(request).await 260 - } 261 - 262 165 #[cfg(test)] 263 166 mod tests { 264 167 use super::*;
+2 -1
crates/tranquil-pds/src/sso/config.rs
··· 1 + use crate::util::pds_hostname; 1 2 use std::sync::OnceLock; 2 3 use tranquil_db_traits::SsoProviderType; 3 4 ··· 50 51 }; 51 52 52 53 if config.is_any_enabled() { 53 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_default(); 54 + let hostname = pds_hostname(); 54 55 if hostname.is_empty() || hostname == "localhost" { 55 56 panic!( 56 57 "PDS_HOSTNAME must be set to a valid hostname when SSO is enabled. \
+33 -62
crates/tranquil-pds/src/sso/endpoints.rs
··· 13 13 use crate::api::error::ApiError; 14 14 use crate::auth::extractor::extract_bearer_token_from_header; 15 15 use crate::auth::{generate_app_password, validate_bearer_token_cached}; 16 - use crate::rate_limit::extract_client_ip; 17 - use crate::state::{AppState, RateLimitKind}; 16 + use crate::rate_limit::{ 17 + AccountCreationLimit, RateLimited, SsoCallbackLimit, SsoInitiateLimit, SsoUnlinkLimit, 18 + check_user_rate_limit_with_message, 19 + }; 20 + use crate::state::AppState; 21 + use crate::util::{pds_hostname, pds_hostname_without_port}; 18 22 19 23 fn generate_state() -> String { 20 24 use rand::RngCore; ··· 71 75 72 76 pub async fn sso_initiate( 73 77 State(state): State<AppState>, 78 + _rate_limit: RateLimited<SsoInitiateLimit>, 74 79 headers: HeaderMap, 75 80 Json(input): Json<SsoInitiateRequest>, 76 81 ) -> Result<Json<SsoInitiateResponse>, ApiError> { 77 - let client_ip = extract_client_ip(&headers, None); 78 - if !state 79 - .check_rate_limit(RateLimitKind::SsoInitiate, &client_ip) 80 - .await 81 - { 82 - tracing::warn!(ip = %client_ip, "SSO initiate rate limit exceeded"); 83 - return Err(ApiError::RateLimitExceeded(None)); 84 - } 85 - 86 82 if input.provider.len() > 20 { 87 83 return Err(ApiError::SsoProviderNotFound); 88 84 } ··· 217 213 218 214 pub async fn sso_callback( 219 215 State(state): State<AppState>, 220 - headers: HeaderMap, 216 + _rate_limit: RateLimited<SsoCallbackLimit>, 221 217 Query(query): Query<SsoCallbackQuery>, 222 218 ) -> Response { 219 + sso_callback_internal(&state, query).await 220 + } 221 + 222 + async fn sso_callback_internal(state: &AppState, query: SsoCallbackQuery) -> Response { 223 223 tracing::debug!( 224 224 has_code = query.code.is_some(), 225 225 has_state = query.state.is_some(), ··· 227 227 "SSO callback received" 228 228 ); 229 229 230 - let client_ip = extract_client_ip(&headers, None); 231 - if !state 232 - .check_rate_limit(RateLimitKind::SsoCallback, &client_ip) 233 - .await 234 - { 235 - tracing::warn!(ip = %client_ip, "SSO callback rate limit exceeded"); 236 - return redirect_to_error("Too many requests. Please try again later."); 237 - } 238 - 239 230 if let Some(ref error) = query.error { 240 231 tracing::warn!( 241 232 error = %error, ··· 329 320 match auth_state.action.as_str() { 330 321 "login" => { 331 322 handle_sso_login( 332 - &state, 323 + state, 333 324 &auth_state.request_uri, 334 325 auth_state.provider, 335 326 &user_info, ··· 341 332 Some(d) => d, 342 333 None => return redirect_to_error("Not authenticated"), 343 334 }; 344 - handle_sso_link(&state, did, auth_state.provider, &user_info).await 335 + handle_sso_link(state, did, auth_state.provider, &user_info).await 345 336 } 346 337 "register" => { 347 338 handle_sso_register( 348 - &state, 339 + state, 349 340 &auth_state.request_uri, 350 341 auth_state.provider, 351 342 &user_info, ··· 358 349 359 350 pub async fn sso_callback_post( 360 351 State(state): State<AppState>, 361 - headers: HeaderMap, 352 + _rate_limit: RateLimited<SsoCallbackLimit>, 362 353 Form(form): Form<SsoCallbackForm>, 363 354 ) -> Response { 364 355 tracing::debug!( ··· 376 367 error_description: form.error_description, 377 368 }; 378 369 379 - sso_callback(State(state), headers, Query(query)).await 370 + sso_callback_internal(&state, query).await 380 371 } 381 372 382 373 fn generate_registration_token() -> String { ··· 682 673 auth: crate::auth::Auth<crate::auth::Active>, 683 674 Json(input): Json<UnlinkAccountRequest>, 684 675 ) -> Result<Json<UnlinkAccountResponse>, ApiError> { 685 - if !state 686 - .check_rate_limit(RateLimitKind::SsoUnlink, auth.did.as_str()) 687 - .await 688 - { 689 - tracing::warn!(did = %auth.did, "SSO unlink rate limit exceeded"); 690 - return Err(ApiError::RateLimitExceeded(None)); 691 - } 676 + let _rate_limit = check_user_rate_limit_with_message::<SsoUnlinkLimit>( 677 + &state, 678 + auth.did.as_str(), 679 + "Too many unlink attempts. Please try again later.", 680 + ) 681 + .await?; 692 682 693 683 let id = uuid::Uuid::parse_str(&input.id).map_err(|_| ApiError::InvalidId)?; 694 684 ··· 746 736 747 737 pub async fn get_pending_registration( 748 738 State(state): State<AppState>, 749 - headers: HeaderMap, 739 + _rate_limit: RateLimited<SsoCallbackLimit>, 750 740 Query(query): Query<PendingRegistrationQuery>, 751 741 ) -> Result<Json<PendingRegistrationResponse>, ApiError> { 752 - let client_ip = extract_client_ip(&headers, None); 753 - if !state 754 - .check_rate_limit(RateLimitKind::SsoCallback, &client_ip) 755 - .await 756 - { 757 - tracing::warn!(ip = %client_ip, "SSO pending registration rate limit exceeded"); 758 - return Err(ApiError::RateLimitExceeded(None)); 759 - } 760 - 761 742 if query.token.len() > 100 { 762 743 return Err(ApiError::InvalidRequest("Invalid token".into())); 763 744 } ··· 810 791 } 811 792 }; 812 793 813 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 814 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 794 + let hostname_for_handles = pds_hostname_without_port(); 815 795 let full_handle = format!("{}.{}", validated, hostname_for_handles); 816 796 let handle_typed = crate::types::Handle::new_unchecked(&full_handle); 817 797 ··· 866 846 867 847 pub async fn complete_registration( 868 848 State(state): State<AppState>, 869 - headers: HeaderMap, 849 + rate_limit: RateLimited<AccountCreationLimit>, 870 850 Json(input): Json<CompleteRegistrationInput>, 871 851 ) -> Result<Json<CompleteRegistrationResponse>, ApiError> { 852 + let client_ip = rate_limit.client_ip(); 872 853 use jacquard_common::types::{integer::LimitedU32, string::Tid}; 873 854 use jacquard_repo::{mst::Mst, storage::BlockStore}; 874 855 use k256::ecdsa::SigningKey; ··· 876 857 use serde_json::json; 877 858 use std::sync::Arc; 878 859 879 - let client_ip = extract_client_ip(&headers, None); 880 - if !state 881 - .check_rate_limit(RateLimitKind::AccountCreation, &client_ip) 882 - .await 883 - { 884 - tracing::warn!(ip = %client_ip, "SSO registration rate limit exceeded"); 885 - return Err(ApiError::RateLimitExceeded(None)); 886 - } 887 - 888 860 if input.token.len() > 100 { 889 861 return Err(ApiError::InvalidRequest("Invalid token".into())); 890 862 } ··· 899 871 .await? 900 872 .ok_or(ApiError::SsoSessionExpired)?; 901 873 902 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 903 - let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 874 + let hostname = pds_hostname(); 875 + let hostname_for_handles = pds_hostname_without_port(); 904 876 905 877 let handle = match crate::api::validation::validate_short_handle(&input.handle) { 906 878 Ok(h) => format!("{}.{}", h, hostname_for_handles), ··· 977 949 let handle_typed = crate::types::Handle::new_unchecked(&handle); 978 950 let reserved = state 979 951 .user_repo 980 - .reserve_handle(&handle_typed, &client_ip) 952 + .reserve_handle(&handle_typed, client_ip) 981 953 .await 982 954 .unwrap_or(false); 983 955 ··· 1315 1287 return Err(ApiError::InternalError(None)); 1316 1288 } 1317 1289 1318 - let hostname = 1319 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1290 + let hostname = pds_hostname(); 1320 1291 if let Err(e) = crate::comms::comms_repo::enqueue_welcome( 1321 1292 state.user_repo.as_ref(), 1322 1293 state.infra_repo.as_ref(), 1323 1294 user_id.unwrap_or(uuid::Uuid::nil()), 1324 - &hostname, 1295 + hostname, 1325 1296 ) 1326 1297 .await 1327 1298 { ··· 1367 1338 verification_channel, 1368 1339 &verification_recipient, 1369 1340 &formatted_token, 1370 - &hostname, 1341 + hostname, 1371 1342 ) 1372 1343 .await 1373 1344 {
+9
crates/tranquil-pds/src/state.rs
··· 1 1 use crate::appview::DidResolver; 2 + use crate::auth::webauthn::WebAuthnConfig; 2 3 use crate::cache::{Cache, DistributedRateLimiter, create_cache}; 3 4 use crate::circuit_breaker::CircuitBreakers; 4 5 use crate::config::AuthConfig; ··· 7 8 use crate::sso::{SsoConfig, SsoManager}; 8 9 use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage}; 9 10 use crate::sync::firehose::SequencedEvent; 11 + use crate::util::pds_hostname; 10 12 use sqlx::PgPool; 11 13 use std::error::Error; 12 14 use std::sync::Arc; ··· 41 43 pub did_resolver: Arc<DidResolver>, 42 44 pub sso_repo: Arc<dyn SsoRepository>, 43 45 pub sso_manager: SsoManager, 46 + pub webauthn_config: Arc<WebAuthnConfig>, 44 47 } 45 48 49 + #[derive(Debug, Clone, Copy)] 46 50 pub enum RateLimitKind { 47 51 Login, 48 52 AccountCreation, ··· 180 184 let did_resolver = Arc::new(DidResolver::new()); 181 185 let sso_config = SsoConfig::init(); 182 186 let sso_manager = SsoManager::from_config(sso_config); 187 + let webauthn_config = Arc::new( 188 + WebAuthnConfig::new(pds_hostname()) 189 + .expect("Failed to create WebAuthn config at startup"), 190 + ); 183 191 184 192 Self { 185 193 user_repo: repos.user.clone(), ··· 204 212 distributed_rate_limiter, 205 213 did_resolver, 206 214 sso_manager, 215 + webauthn_config, 207 216 } 208 217 } 209 218
+2 -2
crates/tranquil-pds/src/sync/deprecated.rs
··· 20 20 21 21 async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool { 22 22 let extracted = match crate::auth::extract_auth_token_from_header( 23 - headers.get("Authorization").and_then(|h| h.to_str().ok()), 23 + crate::util::get_header_str(headers, "Authorization"), 24 24 ) { 25 25 Some(t) => t, 26 26 None => return false, 27 27 }; 28 - let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 28 + let dpop_proof = crate::util::get_header_str(headers, "DPoP"); 29 29 let http_uri = "/"; 30 30 match crate::auth::validate_token_with_dpop( 31 31 state.user_repo.as_ref(),
+21 -4
crates/tranquil-pds/src/util.rs
··· 4 4 use rand::Rng; 5 5 use serde_json::Value as JsonValue; 6 6 use std::collections::BTreeMap; 7 + use std::net::SocketAddr; 7 8 use std::str::FromStr; 8 9 use std::sync::OnceLock; 9 10 ··· 11 12 const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; 12 13 13 14 static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new(); 15 + static PDS_HOSTNAME: OnceLock<String> = OnceLock::new(); 16 + static PDS_HOSTNAME_WITHOUT_PORT: OnceLock<String> = OnceLock::new(); 14 17 15 18 pub fn get_max_blob_size() -> usize { 16 19 *MAX_BLOB_SIZE.get_or_init(|| { ··· 69 72 .unwrap_or_default() 70 73 } 71 74 72 - pub fn extract_client_ip(headers: &HeaderMap) -> String { 75 + pub fn get_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { 76 + headers.get(name).and_then(|h| h.to_str().ok()) 77 + } 78 + 79 + pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 73 80 if let Some(forwarded) = headers.get("x-forwarded-for") 74 81 && let Ok(value) = forwarded.to_str() 75 82 && let Some(first_ip) = value.split(',').next() ··· 81 88 { 82 89 return value.trim().to_string(); 83 90 } 84 - "unknown".to_string() 91 + addr.map(|a| a.ip().to_string()) 92 + .unwrap_or_else(|| "unknown".to_string()) 93 + } 94 + 95 + pub fn pds_hostname() -> &'static str { 96 + PDS_HOSTNAME.get_or_init(|| { 97 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 98 + }) 85 99 } 86 100 87 - pub fn pds_hostname() -> String { 88 - std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 101 + pub fn pds_hostname_without_port() -> &'static str { 102 + PDS_HOSTNAME_WITHOUT_PORT.get_or_init(|| { 103 + let hostname = pds_hostname(); 104 + hostname.split(':').next().unwrap_or(hostname).to_string() 105 + }) 89 106 } 90 107 91 108 pub fn pds_public_url() -> String {

History

3 rounds 0 comments
sign up or login to add to the discussion
1 commit
expand
fix: better type-safety
expand 0 comments
pull request successfully merged
1 commit
expand
fix: better type-safety
expand 0 comments
lewis.moe submitted #0
1 commit
expand
fix: better type-safety
expand 0 comments