Ensuring at compile-time that we're definitely handling possible early failures in functions
+46
-45
crates/tranquil-db/src/postgres/oauth.rs
+46
-45
crates/tranquil-db/src/postgres/oauth.rs
···
7
7
ScopePreference, TrustedDeviceRow, TwoFactorChallenge,
8
8
};
9
9
use tranquil_oauth::{
10
-
AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, DeviceData, RequestData,
11
-
TokenData,
10
+
AuthorizationRequestParameters, AuthorizedClientData, ClientAuth, Code as OAuthCode,
11
+
DeviceData, DeviceId as OAuthDeviceId, RequestData, SessionId as OAuthSessionId, TokenData,
12
+
TokenId as OAuthTokenId, RefreshToken as OAuthRefreshToken,
12
13
};
13
14
use tranquil_types::{
14
15
AuthorizationCode, ClientId, DPoPProofId, DeviceId, Did, Handle, RefreshToken, RequestId,
···
59
60
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
60
61
RETURNING id
61
62
"#,
62
-
data.did,
63
-
data.token_id,
63
+
data.did.as_str(),
64
+
&data.token_id.0,
64
65
data.created_at,
65
66
data.updated_at,
66
67
data.expires_at,
67
68
data.client_id,
68
69
client_auth_json,
69
-
data.device_id,
70
+
data.device_id.as_ref().map(|d| d.0.as_str()),
70
71
parameters_json,
71
72
data.details,
72
-
data.code,
73
-
data.current_refresh_token,
73
+
data.code.as_ref().map(|c| c.0.as_str()),
74
+
data.current_refresh_token.as_ref().map(|r| r.0.as_str()),
74
75
data.scope,
75
-
data.controller_did,
76
+
data.controller_did.as_ref().map(|d| d.as_str()),
76
77
)
77
78
.fetch_one(&self.pool)
78
79
.await
···
95
96
.map_err(map_sqlx_error)?;
96
97
match row {
97
98
Some(r) => Ok(Some(TokenData {
98
-
did: r.did,
99
-
token_id: r.token_id,
99
+
did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?,
100
+
token_id: OAuthTokenId(r.token_id),
100
101
created_at: r.created_at,
101
102
updated_at: r.updated_at,
102
103
expires_at: r.expires_at,
103
104
client_id: r.client_id,
104
105
client_auth: from_json(r.client_auth)?,
105
-
device_id: r.device_id,
106
+
device_id: r.device_id.map(OAuthDeviceId),
106
107
parameters: from_json(r.parameters)?,
107
108
details: r.details,
108
-
code: r.code,
109
-
current_refresh_token: r.current_refresh_token,
109
+
code: r.code.map(OAuthCode),
110
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
110
111
scope: r.scope,
111
-
controller_did: r.controller_did,
112
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?,
112
113
})),
113
114
None => Ok(None),
114
115
}
···
134
135
Some(r) => Ok(Some((
135
136
r.id,
136
137
TokenData {
137
-
did: r.did,
138
-
token_id: r.token_id,
138
+
did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?,
139
+
token_id: OAuthTokenId(r.token_id),
139
140
created_at: r.created_at,
140
141
updated_at: r.updated_at,
141
142
expires_at: r.expires_at,
142
143
client_id: r.client_id,
143
144
client_auth: from_json(r.client_auth)?,
144
-
device_id: r.device_id,
145
+
device_id: r.device_id.map(OAuthDeviceId),
145
146
parameters: from_json(r.parameters)?,
146
147
details: r.details,
147
-
code: r.code,
148
-
current_refresh_token: r.current_refresh_token,
148
+
code: r.code.map(OAuthCode),
149
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
149
150
scope: r.scope,
150
-
controller_did: r.controller_did,
151
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?,
151
152
},
152
153
))),
153
154
None => Ok(None),
···
176
177
Some(r) => Ok(Some((
177
178
r.id,
178
179
TokenData {
179
-
did: r.did,
180
-
token_id: r.token_id,
180
+
did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?,
181
+
token_id: OAuthTokenId(r.token_id),
181
182
created_at: r.created_at,
182
183
updated_at: r.updated_at,
183
184
expires_at: r.expires_at,
184
185
client_id: r.client_id,
185
186
client_auth: from_json(r.client_auth)?,
186
-
device_id: r.device_id,
187
+
device_id: r.device_id.map(OAuthDeviceId),
187
188
parameters: from_json(r.parameters)?,
188
189
details: r.details,
189
-
code: r.code,
190
-
current_refresh_token: r.current_refresh_token,
190
+
code: r.code.map(OAuthCode),
191
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
191
192
scope: r.scope,
192
-
controller_did: r.controller_did,
193
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?,
193
194
},
194
195
))),
195
196
None => Ok(None),
···
302
303
rows.into_iter()
303
304
.map(|r| {
304
305
Ok(TokenData {
305
-
did: r.did,
306
-
token_id: r.token_id,
306
+
did: r.did.parse().map_err(|_| DbError::Other("Invalid DID in token".into()))?,
307
+
token_id: OAuthTokenId(r.token_id),
307
308
created_at: r.created_at,
308
309
updated_at: r.updated_at,
309
310
expires_at: r.expires_at,
310
311
client_id: r.client_id,
311
312
client_auth: from_json(r.client_auth)?,
312
-
device_id: r.device_id,
313
+
device_id: r.device_id.map(OAuthDeviceId),
313
314
parameters: from_json(r.parameters)?,
314
315
details: r.details,
315
-
code: r.code,
316
-
current_refresh_token: r.current_refresh_token,
316
+
code: r.code.map(OAuthCode),
317
+
current_refresh_token: r.current_refresh_token.map(OAuthRefreshToken),
317
318
scope: r.scope,
318
-
controller_did: r.controller_did,
319
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID".into()))?,
319
320
})
320
321
})
321
322
.collect()
···
407
408
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
408
409
"#,
409
410
request_id.as_str(),
410
-
data.did,
411
-
data.device_id,
411
+
data.did.as_ref().map(|d| d.as_str()),
412
+
data.device_id.as_ref().map(|d| d.0.as_str()),
412
413
data.client_id,
413
414
client_auth_json,
414
415
parameters_json,
415
416
data.expires_at,
416
-
data.code,
417
+
data.code.as_ref().map(|c| c.0.as_str()),
417
418
)
418
419
.execute(&self.pool)
419
420
.await
···
448
449
client_auth,
449
450
parameters,
450
451
expires_at: r.expires_at,
451
-
did: r.did,
452
-
device_id: r.device_id,
453
-
code: r.code,
454
-
controller_did: r.controller_did,
452
+
did: r.did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid DID in DB".into()))?,
453
+
device_id: r.device_id.map(OAuthDeviceId),
454
+
code: r.code.map(OAuthCode),
455
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?,
455
456
}))
456
457
}
457
458
None => Ok(None),
···
534
535
client_auth,
535
536
parameters,
536
537
expires_at: r.expires_at,
537
-
did: r.did,
538
-
device_id: r.device_id,
539
-
code: r.code,
540
-
controller_did: r.controller_did,
538
+
did: r.did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid DID in DB".into()))?,
539
+
device_id: r.device_id.map(OAuthDeviceId),
540
+
code: r.code.map(OAuthCode),
541
+
controller_did: r.controller_did.map(|s| s.parse()).transpose().map_err(|_| DbError::Other("Invalid controller DID in DB".into()))?,
541
542
}))
542
543
}
543
544
None => Ok(None),
···
655
656
VALUES ($1, $2, $3, $4, $5)
656
657
"#,
657
658
device_id.as_str(),
658
-
data.session_id,
659
+
&data.session_id.0,
659
660
data.user_agent,
660
661
data.ip_address,
661
662
data.last_seen_at,
···
679
680
.await
680
681
.map_err(map_sqlx_error)?;
681
682
Ok(row.map(|r| DeviceData {
682
-
session_id: r.session_id,
683
+
session_id: OAuthSessionId(r.session_id),
683
684
user_agent: r.user_agent,
684
685
ip_address: r.ip_address,
685
686
last_seen_at: r.last_seen_at,
+6
-4
crates/tranquil-oauth/src/lib.rs
+6
-4
crates/tranquil-oauth/src/lib.rs
···
10
10
};
11
11
pub use error::OAuthError;
12
12
pub use types::{
13
-
AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
-
AuthorizedClientData, ClientAuth, Code, DPoPClaims, DeviceData, DeviceId, JwkPublicKey, Jwks,
15
-
OAuthClientMetadata, ParResponse, ProtectedResourceMetadata, RefreshToken, RefreshTokenState,
16
-
RequestData, RequestId, SessionId, TokenData, TokenId, TokenRequest, TokenResponse,
13
+
AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
+
AuthorizedClientData, ClientAuth, Code, CodeChallengeMethod, DPoPClaims, DeviceData, DeviceId,
15
+
FlowAuthenticated, FlowAuthorized, FlowExpired, FlowNotAuthenticated, FlowNotAuthorized,
16
+
FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata, ParResponse, Prompt,
17
+
ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId,
18
+
ResponseMode, ResponseType, SessionId, TokenData, TokenId, TokenRequest, TokenResponse,
17
19
};
+250
-141
crates/tranquil-oauth/src/types.rs
+250
-141
crates/tranquil-oauth/src/types.rs
···
1
1
use chrono::{DateTime, Utc};
2
2
use serde::{Deserialize, Serialize};
3
3
use serde_json::Value as JsonValue;
4
+
use tranquil_types::Did;
4
5
5
-
#[derive(Debug, Clone, Serialize, Deserialize)]
6
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
7
+
#[serde(transparent)]
8
+
#[sqlx(transparent)]
6
9
pub struct RequestId(pub String);
7
10
8
-
#[derive(Debug, Clone, Serialize, Deserialize)]
11
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
12
+
#[serde(transparent)]
13
+
#[sqlx(transparent)]
9
14
pub struct TokenId(pub String);
10
15
11
-
#[derive(Debug, Clone, Serialize, Deserialize)]
16
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
17
+
#[serde(transparent)]
18
+
#[sqlx(transparent)]
12
19
pub struct DeviceId(pub String);
13
20
14
-
#[derive(Debug, Clone, Serialize, Deserialize)]
21
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
22
+
#[serde(transparent)]
23
+
#[sqlx(transparent)]
15
24
pub struct SessionId(pub String);
16
25
17
-
#[derive(Debug, Clone, Serialize, Deserialize)]
26
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
27
+
#[serde(transparent)]
28
+
#[sqlx(transparent)]
18
29
pub struct Code(pub String);
19
30
20
-
#[derive(Debug, Clone, Serialize, Deserialize)]
31
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
32
+
#[serde(transparent)]
33
+
#[sqlx(transparent)]
21
34
pub struct RefreshToken(pub String);
22
35
23
36
impl RequestId {
···
82
95
PrivateKeyJwt { client_assertion: String },
83
96
}
84
97
98
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
99
+
#[serde(rename_all = "snake_case")]
100
+
pub enum ResponseType {
101
+
#[default]
102
+
Code,
103
+
}
104
+
105
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
106
+
pub enum CodeChallengeMethod {
107
+
#[default]
108
+
#[serde(rename = "S256")]
109
+
S256,
110
+
#[serde(rename = "plain")]
111
+
Plain,
112
+
}
113
+
114
+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
115
+
#[serde(rename_all = "snake_case")]
116
+
pub enum ResponseMode {
117
+
#[default]
118
+
Query,
119
+
Fragment,
120
+
FormPost,
121
+
}
122
+
123
+
impl ResponseMode {
124
+
pub fn as_str(&self) -> &'static str {
125
+
match self {
126
+
Self::Query => "query",
127
+
Self::Fragment => "fragment",
128
+
Self::FormPost => "form_post",
129
+
}
130
+
}
131
+
}
132
+
133
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
134
+
#[serde(rename_all = "snake_case")]
135
+
pub enum Prompt {
136
+
None,
137
+
Login,
138
+
Consent,
139
+
SelectAccount,
140
+
Create,
141
+
}
142
+
143
+
impl Prompt {
144
+
pub fn as_str(&self) -> &'static str {
145
+
match self {
146
+
Self::None => "none",
147
+
Self::Login => "login",
148
+
Self::Consent => "consent",
149
+
Self::SelectAccount => "select_account",
150
+
Self::Create => "create",
151
+
}
152
+
}
153
+
}
154
+
85
155
#[derive(Debug, Clone, Serialize, Deserialize)]
86
156
pub struct AuthorizationRequestParameters {
87
-
pub response_type: String,
157
+
pub response_type: ResponseType,
88
158
pub client_id: String,
89
159
pub redirect_uri: String,
90
160
pub scope: Option<String>,
91
161
pub state: Option<String>,
92
162
pub code_challenge: String,
93
-
pub code_challenge_method: String,
94
-
pub response_mode: Option<String>,
163
+
pub code_challenge_method: CodeChallengeMethod,
164
+
pub response_mode: Option<ResponseMode>,
95
165
pub login_hint: Option<String>,
96
166
pub dpop_jkt: Option<String>,
97
-
pub prompt: Option<String>,
167
+
pub prompt: Option<Prompt>,
98
168
#[serde(flatten)]
99
169
pub extra: Option<JsonValue>,
100
170
}
···
105
175
pub client_auth: Option<ClientAuth>,
106
176
pub parameters: AuthorizationRequestParameters,
107
177
pub expires_at: DateTime<Utc>,
108
-
pub did: Option<String>,
109
-
pub device_id: Option<String>,
110
-
pub code: Option<String>,
111
-
pub controller_did: Option<String>,
178
+
pub did: Option<Did>,
179
+
pub device_id: Option<DeviceId>,
180
+
pub code: Option<Code>,
181
+
pub controller_did: Option<Did>,
112
182
}
113
183
114
184
#[derive(Debug, Clone)]
115
185
pub struct DeviceData {
116
-
pub session_id: String,
186
+
pub session_id: SessionId,
117
187
pub user_agent: Option<String>,
118
188
pub ip_address: String,
119
189
pub last_seen_at: DateTime<Utc>,
···
121
191
122
192
#[derive(Debug, Clone)]
123
193
pub struct TokenData {
124
-
pub did: String,
125
-
pub token_id: String,
194
+
pub did: Did,
195
+
pub token_id: TokenId,
126
196
pub created_at: DateTime<Utc>,
127
197
pub updated_at: DateTime<Utc>,
128
198
pub expires_at: DateTime<Utc>,
129
199
pub client_id: String,
130
200
pub client_auth: ClientAuth,
131
-
pub device_id: Option<String>,
201
+
pub device_id: Option<DeviceId>,
132
202
pub parameters: AuthorizationRequestParameters,
133
203
pub details: Option<JsonValue>,
134
-
pub code: Option<String>,
135
-
pub current_refresh_token: Option<String>,
204
+
pub code: Option<Code>,
205
+
pub current_refresh_token: Option<RefreshToken>,
136
206
pub scope: Option<String>,
137
-
pub controller_did: Option<String>,
207
+
pub controller_did: Option<Did>,
138
208
}
139
209
140
210
#[derive(Debug, Clone, Serialize, Deserialize)]
···
247
317
pub keys: Vec<JwkPublicKey>,
248
318
}
249
319
250
-
#[derive(Debug, Clone, PartialEq, Eq)]
251
-
pub enum AuthFlowState {
252
-
Pending,
253
-
Authenticated {
254
-
did: String,
255
-
device_id: Option<String>,
256
-
},
257
-
Authorized {
258
-
did: String,
259
-
device_id: Option<String>,
260
-
code: String,
261
-
},
262
-
Expired,
320
+
321
+
#[derive(Debug, Clone)]
322
+
pub struct FlowPending {
323
+
pub parameters: AuthorizationRequestParameters,
324
+
pub client_id: String,
325
+
pub client_auth: Option<ClientAuth>,
326
+
pub expires_at: DateTime<Utc>,
327
+
pub controller_did: Option<Did>,
263
328
}
264
329
265
-
impl AuthFlowState {
266
-
pub fn from_request_data(data: &RequestData) -> Self {
267
-
if data.expires_at < chrono::Utc::now() {
268
-
return AuthFlowState::Expired;
269
-
}
270
-
match (&data.did, &data.code) {
271
-
(Some(did), Some(code)) => AuthFlowState::Authorized {
272
-
did: did.clone(),
273
-
device_id: data.device_id.clone(),
274
-
code: code.clone(),
275
-
},
276
-
(Some(did), None) => AuthFlowState::Authenticated {
277
-
did: did.clone(),
278
-
device_id: data.device_id.clone(),
279
-
},
280
-
(None, _) => AuthFlowState::Pending,
281
-
}
282
-
}
330
+
#[derive(Debug, Clone)]
331
+
pub struct FlowAuthenticated {
332
+
pub parameters: AuthorizationRequestParameters,
333
+
pub client_id: String,
334
+
pub client_auth: Option<ClientAuth>,
335
+
pub expires_at: DateTime<Utc>,
336
+
pub did: Did,
337
+
pub device_id: Option<DeviceId>,
338
+
pub controller_did: Option<Did>,
339
+
}
283
340
284
-
pub fn is_pending(&self) -> bool {
285
-
matches!(self, AuthFlowState::Pending)
286
-
}
341
+
#[derive(Debug, Clone)]
342
+
pub struct FlowAuthorized {
343
+
pub parameters: AuthorizationRequestParameters,
344
+
pub client_id: String,
345
+
pub client_auth: Option<ClientAuth>,
346
+
pub expires_at: DateTime<Utc>,
347
+
pub did: Did,
348
+
pub device_id: Option<DeviceId>,
349
+
pub code: Code,
350
+
pub controller_did: Option<Did>,
351
+
}
287
352
288
-
pub fn is_authenticated(&self) -> bool {
289
-
matches!(self, AuthFlowState::Authenticated { .. })
290
-
}
353
+
#[derive(Debug)]
354
+
pub struct FlowExpired;
355
+
356
+
#[derive(Debug)]
357
+
pub struct FlowNotAuthenticated;
358
+
359
+
#[derive(Debug)]
360
+
pub struct FlowNotAuthorized;
291
361
292
-
pub fn is_authorized(&self) -> bool {
293
-
matches!(self, AuthFlowState::Authorized { .. })
362
+
#[derive(Debug, Clone)]
363
+
pub enum AuthFlow {
364
+
Pending(FlowPending),
365
+
Authenticated(FlowAuthenticated),
366
+
Authorized(FlowAuthorized),
367
+
}
368
+
369
+
#[derive(Debug, Clone)]
370
+
pub enum AuthFlowWithUser {
371
+
Authenticated(FlowAuthenticated),
372
+
Authorized(FlowAuthorized),
373
+
}
374
+
375
+
impl AuthFlow {
376
+
pub fn from_request_data(data: RequestData) -> Result<Self, FlowExpired> {
377
+
if data.expires_at < chrono::Utc::now() {
378
+
return Err(FlowExpired);
379
+
}
380
+
match (data.did, data.code) {
381
+
(None, _) => Ok(AuthFlow::Pending(FlowPending {
382
+
parameters: data.parameters,
383
+
client_id: data.client_id,
384
+
client_auth: data.client_auth,
385
+
expires_at: data.expires_at,
386
+
controller_did: data.controller_did,
387
+
})),
388
+
(Some(did), None) => Ok(AuthFlow::Authenticated(FlowAuthenticated {
389
+
parameters: data.parameters,
390
+
client_id: data.client_id,
391
+
client_auth: data.client_auth,
392
+
expires_at: data.expires_at,
393
+
did,
394
+
device_id: data.device_id,
395
+
controller_did: data.controller_did,
396
+
})),
397
+
(Some(did), Some(code)) => Ok(AuthFlow::Authorized(FlowAuthorized {
398
+
parameters: data.parameters,
399
+
client_id: data.client_id,
400
+
client_auth: data.client_auth,
401
+
expires_at: data.expires_at,
402
+
did,
403
+
device_id: data.device_id,
404
+
code,
405
+
controller_did: data.controller_did,
406
+
})),
407
+
}
294
408
}
295
409
296
-
pub fn is_expired(&self) -> bool {
297
-
matches!(self, AuthFlowState::Expired)
410
+
pub fn require_user(self) -> Result<AuthFlowWithUser, FlowNotAuthenticated> {
411
+
match self {
412
+
AuthFlow::Pending(_) => Err(FlowNotAuthenticated),
413
+
AuthFlow::Authenticated(a) => Ok(AuthFlowWithUser::Authenticated(a)),
414
+
AuthFlow::Authorized(a) => Ok(AuthFlowWithUser::Authorized(a)),
415
+
}
298
416
}
299
417
300
-
pub fn can_authenticate(&self) -> bool {
301
-
matches!(self, AuthFlowState::Pending)
418
+
pub fn require_authorized(self) -> Result<FlowAuthorized, FlowNotAuthorized> {
419
+
match self {
420
+
AuthFlow::Authorized(a) => Ok(a),
421
+
_ => Err(FlowNotAuthorized),
422
+
}
302
423
}
424
+
}
303
425
304
-
pub fn can_authorize(&self) -> bool {
305
-
matches!(self, AuthFlowState::Authenticated { .. })
426
+
impl AuthFlowWithUser {
427
+
pub fn did(&self) -> &Did {
428
+
match self {
429
+
AuthFlowWithUser::Authenticated(a) => &a.did,
430
+
AuthFlowWithUser::Authorized(a) => &a.did,
431
+
}
306
432
}
307
433
308
-
pub fn can_exchange(&self) -> bool {
309
-
matches!(self, AuthFlowState::Authorized { .. })
434
+
pub fn device_id(&self) -> Option<&DeviceId> {
435
+
match self {
436
+
AuthFlowWithUser::Authenticated(a) => a.device_id.as_ref(),
437
+
AuthFlowWithUser::Authorized(a) => a.device_id.as_ref(),
438
+
}
310
439
}
311
440
312
-
pub fn did(&self) -> Option<&str> {
441
+
pub fn parameters(&self) -> &AuthorizationRequestParameters {
313
442
match self {
314
-
AuthFlowState::Authenticated { did, .. } | AuthFlowState::Authorized { did, .. } => {
315
-
Some(did)
316
-
}
317
-
_ => None,
443
+
AuthFlowWithUser::Authenticated(a) => &a.parameters,
444
+
AuthFlowWithUser::Authorized(a) => &a.parameters,
318
445
}
319
446
}
320
447
321
-
pub fn code(&self) -> Option<&str> {
448
+
pub fn client_id(&self) -> &str {
322
449
match self {
323
-
AuthFlowState::Authorized { code, .. } => Some(code),
324
-
_ => None,
450
+
AuthFlowWithUser::Authenticated(a) => &a.client_id,
451
+
AuthFlowWithUser::Authorized(a) => &a.client_id,
325
452
}
326
453
}
327
-
}
328
454
329
-
impl std::fmt::Display for AuthFlowState {
330
-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455
+
pub fn controller_did(&self) -> Option<&Did> {
331
456
match self {
332
-
AuthFlowState::Pending => write!(f, "pending"),
333
-
AuthFlowState::Authenticated { did, .. } => write!(f, "authenticated ({})", did),
334
-
AuthFlowState::Authorized { did, code, .. } => {
335
-
write!(
336
-
f,
337
-
"authorized ({}, code={}...)",
338
-
did,
339
-
&code[..8.min(code.len())]
340
-
)
341
-
}
342
-
AuthFlowState::Expired => write!(f, "expired"),
457
+
AuthFlowWithUser::Authenticated(a) => a.controller_did.as_ref(),
458
+
AuthFlowWithUser::Authorized(a) => a.controller_did.as_ref(),
343
459
}
344
460
}
345
461
}
346
462
463
+
347
464
#[derive(Debug, Clone, PartialEq, Eq)]
348
465
pub enum RefreshTokenState {
349
466
Valid,
···
406
523
use chrono::{Duration, Utc};
407
524
408
525
fn make_request_data(
409
-
did: Option<String>,
410
-
code: Option<String>,
526
+
did: Option<Did>,
527
+
code: Option<Code>,
411
528
expires_in: Duration,
412
529
) -> RequestData {
413
530
RequestData {
414
531
client_id: "test-client".into(),
415
532
client_auth: None,
416
533
parameters: AuthorizationRequestParameters {
417
-
response_type: "code".into(),
534
+
response_type: ResponseType::Code,
418
535
client_id: "test-client".into(),
419
536
redirect_uri: "https://example.com/callback".into(),
420
537
scope: Some("atproto".into()),
421
538
state: None,
422
539
code_challenge: "test".into(),
423
-
code_challenge_method: "S256".into(),
540
+
code_challenge_method: CodeChallengeMethod::S256,
424
541
response_mode: None,
425
542
login_hint: None,
426
543
dpop_jkt: None,
···
435
552
}
436
553
}
437
554
555
+
fn test_did(s: &str) -> Did {
556
+
s.parse().expect("valid test DID")
557
+
}
558
+
559
+
fn test_code(s: &str) -> Code {
560
+
Code(s.to_string())
561
+
}
562
+
438
563
#[test]
439
-
fn test_auth_flow_state_pending() {
564
+
fn test_auth_flow_pending() {
440
565
let data = make_request_data(None, None, Duration::minutes(5));
441
-
let state = AuthFlowState::from_request_data(&data);
442
-
assert!(state.is_pending());
443
-
assert!(!state.is_authenticated());
444
-
assert!(!state.is_authorized());
445
-
assert!(!state.is_expired());
446
-
assert!(state.can_authenticate());
447
-
assert!(!state.can_authorize());
448
-
assert!(!state.can_exchange());
449
-
assert!(state.did().is_none());
450
-
assert!(state.code().is_none());
566
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
567
+
assert!(matches!(flow, AuthFlow::Pending(_)));
568
+
assert!(flow.clone().require_user().is_err());
569
+
assert!(flow.require_authorized().is_err());
451
570
}
452
571
453
572
#[test]
454
-
fn test_auth_flow_state_authenticated() {
455
-
let data = make_request_data(Some("did:plc:test".into()), None, Duration::minutes(5));
456
-
let state = AuthFlowState::from_request_data(&data);
457
-
assert!(!state.is_pending());
458
-
assert!(state.is_authenticated());
459
-
assert!(!state.is_authorized());
460
-
assert!(!state.is_expired());
461
-
assert!(!state.can_authenticate());
462
-
assert!(state.can_authorize());
463
-
assert!(!state.can_exchange());
464
-
assert_eq!(state.did(), Some("did:plc:test"));
465
-
assert!(state.code().is_none());
573
+
fn test_auth_flow_authenticated() {
574
+
let did = test_did("did:plc:test");
575
+
let data = make_request_data(Some(did.clone()), None, Duration::minutes(5));
576
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
577
+
assert!(matches!(flow, AuthFlow::Authenticated(_)));
578
+
let with_user = flow.clone().require_user().expect("should have user");
579
+
assert_eq!(with_user.did(), &did);
580
+
assert!(flow.require_authorized().is_err());
466
581
}
467
582
468
583
#[test]
469
-
fn test_auth_flow_state_authorized() {
584
+
fn test_auth_flow_authorized() {
585
+
let did = test_did("did:plc:test");
586
+
let code = test_code("auth-code-123");
470
587
let data = make_request_data(
471
-
Some("did:plc:test".into()),
472
-
Some("auth-code-123".into()),
588
+
Some(did.clone()),
589
+
Some(code.clone()),
473
590
Duration::minutes(5),
474
591
);
475
-
let state = AuthFlowState::from_request_data(&data);
476
-
assert!(!state.is_pending());
477
-
assert!(!state.is_authenticated());
478
-
assert!(state.is_authorized());
479
-
assert!(!state.is_expired());
480
-
assert!(!state.can_authenticate());
481
-
assert!(!state.can_authorize());
482
-
assert!(state.can_exchange());
483
-
assert_eq!(state.did(), Some("did:plc:test"));
484
-
assert_eq!(state.code(), Some("auth-code-123"));
592
+
let flow = AuthFlow::from_request_data(data).expect("should not be expired");
593
+
assert!(matches!(flow, AuthFlow::Authorized(_)));
594
+
let with_user = flow.clone().require_user().expect("should have user");
595
+
assert_eq!(with_user.did(), &did);
596
+
let authorized = flow.require_authorized().expect("should be authorized");
597
+
assert_eq!(authorized.did, did);
598
+
assert_eq!(authorized.code, code);
485
599
}
486
600
487
601
#[test]
488
-
fn test_auth_flow_state_expired() {
489
-
let data = make_request_data(
490
-
Some("did:plc:test".into()),
491
-
Some("code".into()),
492
-
Duration::minutes(-1),
493
-
);
494
-
let state = AuthFlowState::from_request_data(&data);
495
-
assert!(state.is_expired());
496
-
assert!(!state.can_authenticate());
497
-
assert!(!state.can_authorize());
498
-
assert!(!state.can_exchange());
602
+
fn test_auth_flow_expired() {
603
+
let did = test_did("did:plc:test");
604
+
let code = test_code("code");
605
+
let data = make_request_data(Some(did), Some(code), Duration::minutes(-1));
606
+
let result = AuthFlow::from_request_data(data);
607
+
assert!(result.is_err());
499
608
}
500
609
501
610
#[test]
+4
-10
crates/tranquil-pds/src/api/admin/account/delete.rs
+4
-10
crates/tranquil-pds/src/api/admin/account/delete.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
5
use crate::types::Did;
···
9
9
response::{IntoResponse, Response},
10
10
};
11
11
use serde::Deserialize;
12
-
use tracing::{error, warn};
12
+
use tracing::warn;
13
13
14
14
#[derive(Deserialize)]
15
15
pub struct DeleteAccountInput {
···
26
26
.user_repo
27
27
.get_id_and_handle_by_did(did)
28
28
.await
29
-
.map_err(|e| {
30
-
error!("DB error in delete_account: {:?}", e);
31
-
ApiError::InternalError(None)
32
-
})?
29
+
.log_db_err("in delete_account")?
33
30
.ok_or(ApiError::AccountNotFound)
34
31
.map(|row| (row.id, row.handle))?;
35
32
···
37
34
.user_repo
38
35
.admin_delete_account_complete(user_id, did)
39
36
.await
40
-
.map_err(|e| {
41
-
error!("Failed to delete account {}: {:?}", did, e);
42
-
ApiError::InternalError(Some("Failed to delete account".into()))
43
-
})?;
37
+
.log_db_err("deleting account")?;
44
38
45
39
if let Err(e) =
46
40
crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await
+5
-7
crates/tranquil-pds/src/api/admin/account/email.rs
+5
-7
crates/tranquil-pds/src/api/admin/account/email.rs
···
1
-
use crate::api::error::{ApiError, AtpJson};
1
+
use crate::api::error::{ApiError, AtpJson, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
+
use crate::util::pds_hostname;
4
5
use crate::types::Did;
5
6
use axum::{
6
7
Json,
···
9
10
response::{IntoResponse, Response},
10
11
};
11
12
use serde::{Deserialize, Serialize};
12
-
use tracing::{error, warn};
13
+
use tracing::warn;
13
14
14
15
#[derive(Deserialize)]
15
16
#[serde(rename_all = "camelCase")]
···
39
40
.user_repo
40
41
.get_by_did(&input.recipient_did)
41
42
.await
42
-
.map_err(|e| {
43
-
error!("DB error in send_email: {:?}", e);
44
-
ApiError::InternalError(None)
45
-
})?
43
+
.log_db_err("in send_email")?
46
44
.ok_or(ApiError::AccountNotFound)?;
47
45
48
46
let email = user.email.ok_or(ApiError::NoEmail)?;
49
47
let (user_id, handle) = (user.id, user.handle);
50
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
48
+
let hostname = pds_hostname();
51
49
let subject = input
52
50
.subject
53
51
.clone()
+3
-10
crates/tranquil-pds/src/api/admin/account/info.rs
+3
-10
crates/tranquil-pds/src/api/admin/account/info.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use crate::types::{Did, Handle};
···
10
10
};
11
11
use serde::{Deserialize, Serialize};
12
12
use std::collections::HashMap;
13
-
use tracing::error;
14
13
15
14
#[derive(Deserialize)]
16
15
pub struct GetAccountInfoParams {
···
74
73
.infra_repo
75
74
.get_admin_account_info_by_did(¶ms.did)
76
75
.await
77
-
.map_err(|e| {
78
-
error!("DB error in get_account_info: {:?}", e);
79
-
ApiError::InternalError(None)
80
-
})?
76
+
.log_db_err("in get_account_info")?
81
77
.ok_or(ApiError::AccountNotFound)?;
82
78
83
79
let invited_by = get_invited_by(&state, account.id).await;
···
214
210
.infra_repo
215
211
.get_admin_account_infos_by_dids(&dids_typed)
216
212
.await
217
-
.map_err(|e| {
218
-
error!("Failed to fetch account infos: {:?}", e);
219
-
ApiError::InternalError(None)
220
-
})?;
213
+
.log_db_err("fetching account infos")?;
221
214
222
215
let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect();
223
216
+2
-6
crates/tranquil-pds/src/api/admin/account/search.rs
+2
-6
crates/tranquil-pds/src/api/admin/account/search.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use crate::types::{Did, Handle};
···
9
9
response::{IntoResponse, Response},
10
10
};
11
11
use serde::{Deserialize, Serialize};
12
-
use tracing::error;
13
12
14
13
#[derive(Deserialize)]
15
14
pub struct SearchAccountsParams {
···
66
65
limit + 1,
67
66
)
68
67
.await
69
-
.map_err(|e| {
70
-
error!("DB error in search_accounts: {:?}", e);
71
-
ApiError::InternalError(None)
72
-
})?;
68
+
.log_db_err("in search_accounts")?;
73
69
74
70
let has_more = rows.len() > limit as usize;
75
71
let accounts: Vec<AccountView> = rows
+2
-2
crates/tranquil-pds/src/api/admin/account/update.rs
+2
-2
crates/tranquil-pds/src/api/admin/account/update.rs
···
2
2
use crate::api::error::ApiError;
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
+
use crate::util::pds_hostname_without_port;
5
6
use crate::types::{Did, Handle, PlainPassword};
6
7
use axum::{
7
8
Json,
···
69
70
{
70
71
return Err(ApiError::InvalidHandle(None));
71
72
}
72
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
73
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
73
+
let hostname_for_handles = pds_hostname_without_port();
74
74
let handle = if !input_handle.contains('.') {
75
75
format!("{}.{}", input_handle, hostname_for_handles)
76
76
} else {
+13
-49
crates/tranquil-pds/src/api/admin/config.rs
+13
-49
crates/tranquil-pds/src/api/admin/config.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::auth::{Admin, Auth};
3
3
use crate::state::AppState;
4
4
use axum::{Json, extract::State};
···
56
56
.infra_repo
57
57
.get_server_configs(keys)
58
58
.await
59
-
.map_err(|e| {
60
-
error!("DB error fetching server config: {:?}", e);
61
-
ApiError::InternalError(None)
62
-
})?;
59
+
.log_db_err("fetching server config")?;
63
60
64
61
let config_map: std::collections::HashMap<String, String> = rows.into_iter().collect();
65
62
···
92
89
.infra_repo
93
90
.upsert_server_config("server_name", trimmed)
94
91
.await
95
-
.map_err(|e| {
96
-
error!("DB error upserting server_name: {:?}", e);
97
-
ApiError::InternalError(None)
98
-
})?;
92
+
.log_db_err("upserting server_name")?;
99
93
}
100
94
101
95
if let Some(ref color) = req.primary_color {
···
104
98
.infra_repo
105
99
.delete_server_config("primary_color")
106
100
.await
107
-
.map_err(|e| {
108
-
error!("DB error deleting primary_color: {:?}", e);
109
-
ApiError::InternalError(None)
110
-
})?;
101
+
.log_db_err("deleting primary_color")?;
111
102
} else if is_valid_hex_color(color) {
112
103
state
113
104
.infra_repo
114
105
.upsert_server_config("primary_color", color)
115
106
.await
116
-
.map_err(|e| {
117
-
error!("DB error upserting primary_color: {:?}", e);
118
-
ApiError::InternalError(None)
119
-
})?;
107
+
.log_db_err("upserting primary_color")?;
120
108
} else {
121
109
return Err(ApiError::InvalidRequest(
122
110
"Invalid primary color format (expected #RRGGBB)".into(),
···
130
118
.infra_repo
131
119
.delete_server_config("primary_color_dark")
132
120
.await
133
-
.map_err(|e| {
134
-
error!("DB error deleting primary_color_dark: {:?}", e);
135
-
ApiError::InternalError(None)
136
-
})?;
121
+
.log_db_err("deleting primary_color_dark")?;
137
122
} else if is_valid_hex_color(color) {
138
123
state
139
124
.infra_repo
140
125
.upsert_server_config("primary_color_dark", color)
141
126
.await
142
-
.map_err(|e| {
143
-
error!("DB error upserting primary_color_dark: {:?}", e);
144
-
ApiError::InternalError(None)
145
-
})?;
127
+
.log_db_err("upserting primary_color_dark")?;
146
128
} else {
147
129
return Err(ApiError::InvalidRequest(
148
130
"Invalid primary dark color format (expected #RRGGBB)".into(),
···
156
138
.infra_repo
157
139
.delete_server_config("secondary_color")
158
140
.await
159
-
.map_err(|e| {
160
-
error!("DB error deleting secondary_color: {:?}", e);
161
-
ApiError::InternalError(None)
162
-
})?;
141
+
.log_db_err("deleting secondary_color")?;
163
142
} else if is_valid_hex_color(color) {
164
143
state
165
144
.infra_repo
166
145
.upsert_server_config("secondary_color", color)
167
146
.await
168
-
.map_err(|e| {
169
-
error!("DB error upserting secondary_color: {:?}", e);
170
-
ApiError::InternalError(None)
171
-
})?;
147
+
.log_db_err("upserting secondary_color")?;
172
148
} else {
173
149
return Err(ApiError::InvalidRequest(
174
150
"Invalid secondary color format (expected #RRGGBB)".into(),
···
182
158
.infra_repo
183
159
.delete_server_config("secondary_color_dark")
184
160
.await
185
-
.map_err(|e| {
186
-
error!("DB error deleting secondary_color_dark: {:?}", e);
187
-
ApiError::InternalError(None)
188
-
})?;
161
+
.log_db_err("deleting secondary_color_dark")?;
189
162
} else if is_valid_hex_color(color) {
190
163
state
191
164
.infra_repo
192
165
.upsert_server_config("secondary_color_dark", color)
193
166
.await
194
-
.map_err(|e| {
195
-
error!("DB error upserting secondary_color_dark: {:?}", e);
196
-
ApiError::InternalError(None)
197
-
})?;
167
+
.log_db_err("upserting secondary_color_dark")?;
198
168
} else {
199
169
return Err(ApiError::InvalidRequest(
200
170
"Invalid secondary dark color format (expected #RRGGBB)".into(),
···
235
205
.infra_repo
236
206
.delete_server_config("logo_cid")
237
207
.await
238
-
.map_err(|e| {
239
-
error!("DB error deleting logo_cid: {:?}", e);
240
-
ApiError::InternalError(None)
241
-
})?;
208
+
.log_db_err("deleting logo_cid")?;
242
209
} else {
243
210
state
244
211
.infra_repo
245
212
.upsert_server_config("logo_cid", logo_cid)
246
213
.await
247
-
.map_err(|e| {
248
-
error!("DB error upserting logo_cid: {:?}", e);
249
-
ApiError::InternalError(None)
250
-
})?;
214
+
.log_db_err("upserting logo_cid")?;
251
215
}
252
216
}
253
217
+2
-5
crates/tranquil-pds/src/api/admin/invite.rs
+2
-5
crates/tranquil-pds/src/api/admin/invite.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Admin, Auth};
4
4
use crate::state::AppState;
5
5
use axum::{
···
91
91
.infra_repo
92
92
.list_invite_codes(params.cursor.as_deref(), limit, sort_order)
93
93
.await
94
-
.map_err(|e| {
95
-
error!("DB error fetching invite codes: {:?}", e);
96
-
ApiError::InternalError(None)
97
-
})?;
94
+
.log_db_err("fetching invite codes")?;
98
95
99
96
let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect();
100
97
let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect();
+2
-2
crates/tranquil-pds/src/api/age_assurance.rs
+2
-2
crates/tranquil-pds/src/api/age_assurance.rs
···
33
33
}
34
34
35
35
async fn get_account_created_at(state: &AppState, headers: &HeaderMap) -> Option<String> {
36
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
36
+
let auth_header = crate::util::get_header_str(headers, "Authorization");
37
37
tracing::debug!(?auth_header, "age assurance: extracting token");
38
38
39
39
let extracted = extract_auth_token_from_header(auth_header)?;
40
40
tracing::debug!("age assurance: got token, validating");
41
41
42
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
42
+
let dpop_proof = crate::util::get_header_str(headers, "DPoP");
43
43
let http_uri = "/";
44
44
45
45
let auth_user = match validate_token_with_dpop(
+21
-70
crates/tranquil-pds/src/api/delegation.rs
+21
-70
crates/tranquil-pds/src/api/delegation.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::api::repo::record::utils::create_signed_commit;
3
3
use crate::auth::{Active, Auth};
4
-
use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes};
5
-
use crate::state::{AppState, RateLimitKind};
4
+
use crate::delegation::{
5
+
DelegationActionType, SCOPE_PRESETS, scopes, verify_can_add_controllers,
6
+
verify_can_be_controller, verify_can_control_accounts,
7
+
};
8
+
use crate::rate_limit::{AccountCreationLimit, RateLimited};
9
+
use crate::state::AppState;
6
10
use crate::types::{Did, Handle, Nsid, Rkey};
7
-
use crate::util::extract_client_ip;
11
+
use crate::util::{pds_hostname, pds_hostname_without_port};
8
12
use axum::{
9
13
Json,
10
14
extract::{Query, State},
11
-
http::{HeaderMap, StatusCode},
15
+
http::StatusCode,
12
16
response::{IntoResponse, Response},
13
17
};
14
18
use jacquard_common::types::{integer::LimitedU32, string::Tid};
···
93
97
return Ok(ApiError::ControllerNotFound.into_response());
94
98
}
95
99
96
-
match state.delegation_repo.controls_any_accounts(&auth.did).await {
97
-
Ok(true) => {
98
-
return Ok(ApiError::InvalidDelegation(
99
-
"Cannot add controllers to an account that controls other accounts".into(),
100
-
)
101
-
.into_response());
102
-
}
103
-
Err(e) => {
104
-
tracing::error!("Failed to check delegation status: {:?}", e);
105
-
return Ok(
106
-
ApiError::InternalError(Some("Failed to verify delegation status".into()))
107
-
.into_response(),
108
-
);
109
-
}
110
-
Ok(false) => {}
111
-
}
100
+
let _can_add = match verify_can_add_controllers(&state, &auth).await {
101
+
Ok(proof) => proof,
102
+
Err(response) => return Ok(response),
103
+
};
112
104
113
-
match state
114
-
.delegation_repo
115
-
.has_any_controllers(&input.controller_did)
116
-
.await
117
-
{
118
-
Ok(true) => {
119
-
return Ok(ApiError::InvalidDelegation(
120
-
"Cannot add a controlled account as a controller".into(),
121
-
)
122
-
.into_response());
123
-
}
124
-
Err(e) => {
125
-
tracing::error!("Failed to check controller status: {:?}", e);
126
-
return Ok(
127
-
ApiError::InternalError(Some("Failed to verify controller status".into()))
128
-
.into_response(),
129
-
);
130
-
}
131
-
Ok(false) => {}
105
+
if let Err(response) = verify_can_be_controller(&state, &input.controller_did).await {
106
+
return Ok(response);
132
107
}
133
108
134
109
match state
···
456
431
457
432
pub async fn create_delegated_account(
458
433
State(state): State<AppState>,
459
-
headers: HeaderMap,
434
+
_rate_limit: RateLimited<AccountCreationLimit>,
460
435
auth: Auth<Active>,
461
436
Json(input): Json<CreateDelegatedAccountInput>,
462
437
) -> Result<Response, ApiError> {
463
-
let client_ip = extract_client_ip(&headers);
464
-
if !state
465
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
466
-
.await
467
-
{
468
-
warn!(ip = %client_ip, "Delegated account creation rate limit exceeded");
469
-
return Ok(ApiError::RateLimitExceeded(Some(
470
-
"Too many account creation attempts. Please try again later.".into(),
471
-
))
472
-
.into_response());
473
-
}
474
-
475
438
if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) {
476
439
return Ok(ApiError::InvalidScopes(e).into_response());
477
440
}
478
441
479
-
match state.delegation_repo.has_any_controllers(&auth.did).await {
480
-
Ok(true) => {
481
-
return Ok(ApiError::InvalidDelegation(
482
-
"Cannot create delegated accounts from a controlled account".into(),
483
-
)
484
-
.into_response());
485
-
}
486
-
Err(e) => {
487
-
tracing::error!("Failed to check controller status: {:?}", e);
488
-
return Ok(
489
-
ApiError::InternalError(Some("Failed to verify controller status".into()))
490
-
.into_response(),
491
-
);
492
-
}
493
-
Ok(false) => {}
494
-
}
442
+
let _can_control = match verify_can_control_accounts(&state, &auth).await {
443
+
Ok(proof) => proof,
444
+
Err(response) => return Ok(response),
445
+
};
495
446
496
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
497
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
447
+
let hostname = pds_hostname();
448
+
let hostname_for_handles = pds_hostname_without_port();
498
449
let pds_suffix = format!(".{}", hostname_for_handles);
499
450
500
451
let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) {
+19
crates/tranquil-pds/src/api/error.rs
+19
crates/tranquil-pds/src/api/error.rs
···
694
694
}
695
695
}
696
696
697
+
impl From<crate::rate_limit::UserRateLimitError> for ApiError {
698
+
fn from(e: crate::rate_limit::UserRateLimitError) -> Self {
699
+
Self::RateLimitExceeded(e.message)
700
+
}
701
+
}
702
+
697
703
#[allow(clippy::result_large_err)]
698
704
pub fn parse_did(s: &str) -> Result<tranquil_types::Did, Response> {
699
705
s.parse()
···
756
762
_ => "Invalid request body".to_string(),
757
763
}
758
764
}
765
+
766
+
pub trait DbResultExt<T> {
767
+
fn log_db_err(self, ctx: &str) -> Result<T, ApiError>;
768
+
}
769
+
770
+
impl<T, E: std::fmt::Debug> DbResultExt<T> for Result<T, E> {
771
+
fn log_db_err(self, ctx: &str) -> Result<T, ApiError> {
772
+
self.map_err(|e| {
773
+
tracing::error!("DB error {}: {:?}", ctx, e);
774
+
ApiError::DatabaseError
775
+
})
776
+
}
777
+
}
+16
-41
crates/tranquil-pds/src/api/identity/account.rs
+16
-41
crates/tranquil-pds/src/api/identity/account.rs
···
3
3
use crate::api::repo::record::utils::create_signed_commit;
4
4
use crate::auth::{ServiceTokenVerifier, extract_auth_token_from_header, is_service_token};
5
5
use crate::plc::{PlcClient, create_genesis_operation, signing_key_to_did_key};
6
-
use crate::state::{AppState, RateLimitKind};
6
+
use crate::rate_limit::{AccountCreationLimit, RateLimited};
7
+
use crate::state::AppState;
7
8
use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey};
9
+
use crate::util::{pds_hostname, pds_hostname_without_port};
8
10
use crate::validation::validate_password;
9
11
use axum::{
10
12
Json,
···
22
24
use std::sync::Arc;
23
25
use tracing::{debug, error, info, warn};
24
26
25
-
fn extract_client_ip(headers: &HeaderMap) -> String {
26
-
if let Some(forwarded) = headers.get("x-forwarded-for")
27
-
&& let Ok(value) = forwarded.to_str()
28
-
&& let Some(first_ip) = value.split(',').next()
29
-
{
30
-
return first_ip.trim().to_string();
31
-
}
32
-
if let Some(real_ip) = headers.get("x-real-ip")
33
-
&& let Ok(value) = real_ip.to_str()
34
-
{
35
-
return value.trim().to_string();
36
-
}
37
-
"unknown".to_string()
38
-
}
39
-
40
27
#[derive(Deserialize)]
41
28
#[serde(rename_all = "camelCase")]
42
29
pub struct CreateAccountInput {
···
68
55
69
56
pub async fn create_account(
70
57
State(state): State<AppState>,
58
+
_rate_limit: RateLimited<AccountCreationLimit>,
71
59
headers: HeaderMap,
72
60
Json(input): Json<CreateAccountInput>,
73
61
) -> Response {
···
84
72
} else {
85
73
info!("create_account called");
86
74
}
87
-
let client_ip = extract_client_ip(&headers);
88
-
if !state
89
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
90
-
.await
91
-
{
92
-
warn!(ip = %client_ip, "Account creation rate limit exceeded");
93
-
return ApiError::RateLimitExceeded(Some(
94
-
"Too many account creation attempts. Please try again later.".into(),
95
-
))
96
-
.into_response();
97
-
}
98
75
99
76
let migration_auth = if let Some(extracted) =
100
-
extract_auth_token_from_header(headers.get("Authorization").and_then(|h| h.to_str().ok()))
77
+
extract_auth_token_from_header(crate::util::get_header_str(&headers, "Authorization"))
101
78
{
102
79
let token = extracted.token;
103
80
if is_service_token(&token) {
···
143
120
if (is_migration || is_did_web_byod)
144
121
&& let (Some(provided_did), Some(auth_did)) = (input.did.as_ref(), migration_auth.as_ref())
145
122
{
146
-
if provided_did != auth_did {
123
+
if provided_did != auth_did.as_str() {
147
124
info!(
148
125
"[MIGRATION] createAccount: Service token mismatch - token_did={} provided_did={}",
149
126
auth_did, provided_did
···
164
141
}
165
142
}
166
143
167
-
let hostname_for_validation =
168
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
144
+
let hostname_for_validation = pds_hostname_without_port();
169
145
let pds_suffix = format!(".{}", hostname_for_validation);
170
146
171
147
let validated_short_handle = if !input.handle.contains('.')
···
242
218
_ => return ApiError::InvalidVerificationChannel.into_response(),
243
219
})
244
220
};
245
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
246
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
221
+
let hostname = pds_hostname();
222
+
let hostname_for_handles = pds_hostname_without_port();
247
223
let pds_endpoint = format!("https://{}", hostname);
248
224
let suffix = format!(".{}", hostname_for_handles);
249
225
let handle = if input.handle.ends_with(&suffix) {
···
308
284
}
309
285
if !is_did_web_byod
310
286
&& let Err(e) =
311
-
verify_did_web(d, &hostname, &input.handle, input.signing_key.as_deref()).await
287
+
verify_did_web(d, hostname, &input.handle, input.signing_key.as_deref()).await
312
288
{
313
289
return ApiError::InvalidDid(e).into_response();
314
290
}
···
324
300
if !is_did_web_byod
325
301
&& let Err(e) = verify_did_web(
326
302
d,
327
-
&hostname,
303
+
hostname,
328
304
&input.handle,
329
305
input.signing_key.as_deref(),
330
306
)
···
478
454
error!("Error creating session: {:?}", e);
479
455
return ApiError::InternalError(None).into_response();
480
456
}
481
-
let hostname =
482
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
457
+
let hostname = pds_hostname();
483
458
let verification_required = if let Some(ref user_email) = email {
484
459
let token =
485
460
crate::auth::verification_token::generate_migration_token(&did, user_email);
···
491
466
reactivated.user_id,
492
467
user_email,
493
468
&formatted_token,
494
-
&hostname,
469
+
hostname,
495
470
)
496
471
.await
497
472
{
···
756
731
warn!("Failed to create default profile for {}: {}", did, e);
757
732
}
758
733
}
759
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
734
+
let hostname = pds_hostname();
760
735
if !is_migration {
761
736
if let Some(ref recipient) = verification_recipient {
762
737
let verification_token = crate::auth::verification_token::generate_signup_token(
···
772
747
verification_channel,
773
748
recipient,
774
749
&formatted_token,
775
-
&hostname,
750
+
hostname,
776
751
)
777
752
.await
778
753
{
···
791
766
user_id,
792
767
user_email,
793
768
&formatted_token,
794
-
&hostname,
769
+
hostname,
795
770
)
796
771
.await
797
772
{
+23
-29
crates/tranquil-pds/src/api/identity/did.rs
+23
-29
crates/tranquil-pds/src/api/identity/did.rs
···
1
1
use crate::api::{ApiError, DidResponse, EmptyResponse};
2
2
use crate::auth::{Auth, NotTakendown};
3
3
use crate::plc::signing_key_to_did_key;
4
+
use crate::rate_limit::{HandleUpdateDailyLimit, HandleUpdateLimit, check_user_rate_limit_with_message};
4
5
use crate::state::AppState;
5
6
use crate::types::Handle;
7
+
use crate::util::{get_header_str, pds_hostname, pds_hostname_without_port};
6
8
use axum::{
7
9
Json,
8
10
extract::{Path, Query, State},
···
101
103
}
102
104
103
105
pub async fn well_known_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
104
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
105
-
let host_header = headers
106
-
.get("host")
107
-
.and_then(|h| h.to_str().ok())
108
-
.unwrap_or(&hostname);
106
+
let hostname = pds_hostname();
107
+
let hostname_without_port = pds_hostname_without_port();
108
+
let host_header = get_header_str(&headers, "host").unwrap_or(hostname);
109
109
let host_without_port = host_header.split(':').next().unwrap_or(host_header);
110
-
let hostname_without_port = hostname.split(':').next().unwrap_or(&hostname);
111
110
if host_without_port != hostname_without_port
112
111
&& host_without_port.ends_with(&format!(".{}", hostname_without_port))
113
112
{
114
113
let handle = host_without_port
115
114
.strip_suffix(&format!(".{}", hostname_without_port))
116
115
.unwrap_or(host_without_port);
117
-
return serve_subdomain_did_doc(&state, handle, &hostname).await;
116
+
return serve_subdomain_did_doc(&state, handle, hostname).await;
118
117
}
119
118
let did = if hostname.contains(':') {
120
119
format!("did:web:{}", hostname.replace(':', "%3A"))
···
257
256
}
258
257
259
258
pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response {
260
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
261
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
259
+
let hostname = pds_hostname();
260
+
let hostname_for_handles = pds_hostname_without_port();
262
261
let current_handle = format!("{}.{}", handle, hostname_for_handles);
263
262
let current_handle_typed: Handle = match current_handle.parse() {
264
263
Ok(h) => h,
···
531
530
ApiError::AuthenticationFailed(Some("OAuth tokens cannot get DID credentials".into()))
532
531
})?;
533
532
534
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
533
+
let hostname = pds_hostname();
535
534
let pds_endpoint = format!("https://{}", hostname);
536
535
let signing_key = k256::ecdsa::SigningKey::from_slice(&key_bytes)
537
536
.map_err(|_| ApiError::InternalError(None))?;
···
585
584
return Ok(e);
586
585
}
587
586
let did = auth.did.clone();
588
-
if !state
589
-
.check_rate_limit(crate::state::RateLimitKind::HandleUpdate, &did)
590
-
.await
591
-
{
592
-
return Err(ApiError::RateLimitExceeded(Some(
593
-
"Too many handle updates. Try again later.".into(),
594
-
)));
595
-
}
596
-
if !state
597
-
.check_rate_limit(crate::state::RateLimitKind::HandleUpdateDaily, &did)
598
-
.await
599
-
{
600
-
return Err(ApiError::RateLimitExceeded(Some(
601
-
"Daily handle update limit exceeded.".into(),
602
-
)));
603
-
}
587
+
let _rate_limit = check_user_rate_limit_with_message::<HandleUpdateLimit>(
588
+
&state,
589
+
&did,
590
+
"Too many handle updates. Try again later.",
591
+
)
592
+
.await?;
593
+
let _daily_rate_limit = check_user_rate_limit_with_message::<HandleUpdateDailyLimit>(
594
+
&state,
595
+
&did,
596
+
"Daily handle update limit exceeded.",
597
+
)
598
+
.await?;
604
599
let user_row = state
605
600
.user_repo
606
601
.get_id_and_handle_by_did(&did)
···
639
634
"Inappropriate language in handle".into(),
640
635
)));
641
636
}
642
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
643
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
637
+
let hostname_for_handles = pds_hostname_without_port();
644
638
let suffix = format!(".{}", hostname_for_handles);
645
639
let is_service_domain =
646
640
crate::handle::is_service_domain_handle(&new_handle, hostname_for_handles);
···
772
766
}
773
767
774
768
pub async fn well_known_atproto_did(State(state): State<AppState>, headers: HeaderMap) -> Response {
775
-
let host = match headers.get("host").and_then(|h| h.to_str().ok()) {
769
+
let host = match crate::util::get_header_str(&headers, "host") {
776
770
Some(h) => h,
777
771
None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(),
778
772
};
+7
-12
crates/tranquil-pds/src/api/identity/plc/request.rs
+7
-12
crates/tranquil-pds/src/api/identity/plc/request.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Auth, Permissive};
4
4
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
5
6
use axum::{
6
7
extract::State,
7
8
response::{IntoResponse, Response},
8
9
};
9
10
use chrono::{Duration, Utc};
10
-
use tracing::{error, info, warn};
11
+
use tracing::{info, warn};
11
12
12
13
fn generate_plc_token() -> String {
13
14
crate::util::generate_token_code()
···
28
29
.user_repo
29
30
.get_id_by_did(&auth.did)
30
31
.await
31
-
.map_err(|e| {
32
-
error!("DB error: {:?}", e);
33
-
ApiError::InternalError(None)
34
-
})?
32
+
.log_db_err("fetching user id")?
35
33
.ok_or(ApiError::AccountNotFound)?;
36
34
37
35
let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await;
···
41
39
.infra_repo
42
40
.insert_plc_token(user_id, &plc_token, expires_at)
43
41
.await
44
-
.map_err(|e| {
45
-
error!("Failed to create PLC token: {:?}", e);
46
-
ApiError::InternalError(None)
47
-
})?;
42
+
.log_db_err("creating PLC token")?;
48
43
49
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
44
+
let hostname = pds_hostname();
50
45
if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation(
51
46
state.user_repo.as_ref(),
52
47
state.infra_repo.as_ref(),
53
48
user_id,
54
49
&plc_token,
55
-
&hostname,
50
+
hostname,
56
51
)
57
52
.await
58
53
{
+4
-12
crates/tranquil-pds/src/api/identity/plc/sign.rs
+4
-12
crates/tranquil-pds/src/api/identity/plc/sign.rs
···
1
+
use crate::api::error::DbResultExt;
1
2
use crate::api::ApiError;
2
3
use crate::auth::{Auth, Permissive};
3
4
use crate::circuit_breaker::with_circuit_breaker;
···
64
65
.user_repo
65
66
.get_id_by_did(did)
66
67
.await
67
-
.map_err(|e| {
68
-
error!("DB error: {:?}", e);
69
-
ApiError::InternalError(None)
70
-
})?
68
+
.log_db_err("fetching user id")?
71
69
.ok_or(ApiError::AccountNotFound)?;
72
70
73
71
let token_expiry = state
74
72
.infra_repo
75
73
.get_plc_token_expiry(user_id, token)
76
74
.await
77
-
.map_err(|e| {
78
-
error!("DB error: {:?}", e);
79
-
ApiError::InternalError(None)
80
-
})?
75
+
.log_db_err("fetching PLC token expiry")?
81
76
.ok_or_else(|| ApiError::InvalidToken(Some("Invalid or expired token".into())))?;
82
77
83
78
if Utc::now() > token_expiry {
···
88
83
.user_repo
89
84
.get_user_key_by_id(user_id)
90
85
.await
91
-
.map_err(|e| {
92
-
error!("DB error: {:?}", e);
93
-
ApiError::InternalError(None)
94
-
})?
86
+
.log_db_err("fetching user key")?
95
87
.ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?;
96
88
97
89
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+5
-9
crates/tranquil-pds/src/api/identity/plc/submit.rs
+5
-9
crates/tranquil-pds/src/api/identity/plc/submit.rs
···
1
+
use crate::api::error::DbResultExt;
1
2
use crate::api::{ApiError, EmptyResponse};
2
3
use crate::auth::{Auth, Permissive};
3
4
use crate::circuit_breaker::with_circuit_breaker;
4
5
use crate::plc::{PlcClient, signing_key_to_did_key, validate_plc_operation};
5
6
use crate::state::AppState;
7
+
use crate::util::pds_hostname;
6
8
use axum::{
7
9
Json,
8
10
extract::State,
···
40
42
.map_err(|e| ApiError::InvalidRequest(format!("Invalid operation: {}", e)))?;
41
43
42
44
let op = &input.operation;
43
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
45
+
let hostname = pds_hostname();
44
46
let public_url = format!("https://{}", hostname);
45
47
let user = state
46
48
.user_repo
47
49
.get_id_and_handle_by_did(did)
48
50
.await
49
-
.map_err(|e| {
50
-
error!("DB error: {:?}", e);
51
-
ApiError::InternalError(None)
52
-
})?
51
+
.log_db_err("fetching user")?
53
52
.ok_or(ApiError::AccountNotFound)?;
54
53
55
54
let key_row = state
56
55
.user_repo
57
56
.get_user_key_by_id(user.id)
58
57
.await
59
-
.map_err(|e| {
60
-
error!("DB error: {:?}", e);
61
-
ApiError::InternalError(None)
62
-
})?
58
+
.log_db_err("fetching user key")?
63
59
.ok_or_else(|| ApiError::InternalError(Some("User signing key not found".into())))?;
64
60
65
61
let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version)
+3
-2
crates/tranquil-pds/src/api/notification_prefs.rs
+3
-2
crates/tranquil-pds/src/api/notification_prefs.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::auth::{Active, Auth};
3
3
use crate::state::AppState;
4
+
use crate::util::pds_hostname;
4
5
use axum::{
5
6
Json,
6
7
extract::State,
···
145
146
let formatted_token = crate::auth::verification_token::format_token_for_display(&token);
146
147
147
148
if channel == "email" {
148
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
149
+
let hostname = pds_hostname();
149
150
let handle_str = handle.unwrap_or("user");
150
151
crate::comms::comms_repo::enqueue_email_update(
151
152
state.infra_repo.as_ref(),
···
153
154
identifier,
154
155
handle_str,
155
156
&formatted_token,
156
-
&hostname,
157
+
hostname,
157
158
)
158
159
.await
159
160
.map_err(|e| format!("Failed to enqueue email notification: {}", e))?;
+4
-7
crates/tranquil-pds/src/api/proxy.rs
+4
-7
crates/tranquil-pds/src/api/proxy.rs
···
3
3
use crate::api::error::ApiError;
4
4
use crate::api::proxy_client::proxy_client;
5
5
use crate::state::AppState;
6
+
use crate::util::get_header_str;
6
7
use axum::{
7
8
body::Bytes,
8
9
extract::{RawQuery, Request, State},
···
191
192
.into_response();
192
193
}
193
194
194
-
let Some(proxy_header) = headers
195
-
.get("atproto-proxy")
196
-
.and_then(|h| h.to_str().ok())
197
-
.map(String::from)
198
-
else {
195
+
let Some(proxy_header) = get_header_str(&headers, "atproto-proxy").map(String::from) else {
199
196
return ApiError::InvalidRequest("Missing required atproto-proxy header".into())
200
197
.into_response();
201
198
};
···
217
214
218
215
let mut auth_header_val = headers.get("Authorization").cloned();
219
216
if let Some(extracted) = crate::auth::extract_auth_token_from_header(
220
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
217
+
crate::util::get_header_str(&headers, "Authorization"),
221
218
) {
222
219
let token = extracted.token;
223
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
220
+
let dpop_proof = crate::util::get_header_str(&headers, "DPoP");
224
221
let http_uri = crate::util::build_full_url(&uri.to_string());
225
222
226
223
match crate::auth::validate_token_with_dpop(
+13
-26
crates/tranquil-pds/src/api/repo/blob.rs
+13
-26
crates/tranquil-pds/src/api/repo/blob.rs
···
1
-
use crate::api::error::ApiError;
2
-
use crate::auth::{Auth, AuthAny, NotTakendown, Permissive};
1
+
use crate::api::error::{ApiError, DbResultExt};
2
+
use crate::auth::{Auth, AuthAny, NotTakendown, Permissive, VerifyScope};
3
3
use crate::delegation::DelegationActionType;
4
4
use crate::state::AppState;
5
5
use crate::types::{CidLink, Did};
6
-
use crate::util::get_max_blob_size;
6
+
use crate::util::{get_header_str, get_max_blob_size};
7
7
use axum::body::Body;
8
8
use axum::{
9
9
Json,
···
56
56
if user.status.is_takendown() {
57
57
return Err(ApiError::AccountTakedown);
58
58
}
59
-
let mime_type_for_check = headers
60
-
.get("content-type")
61
-
.and_then(|h| h.to_str().ok())
62
-
.unwrap_or("application/octet-stream");
63
-
if let Err(e) = crate::auth::scope_check::check_blob_scope(
64
-
user.is_oauth(),
65
-
user.scope.as_deref(),
66
-
mime_type_for_check,
67
-
) {
68
-
return Ok(e);
69
-
}
59
+
let mime_type_for_check =
60
+
get_header_str(&headers, "content-type").unwrap_or("application/octet-stream");
61
+
let _scope_proof = match user.verify_blob_upload(mime_type_for_check) {
62
+
Ok(proof) => proof,
63
+
Err(e) => return Ok(e.into_response()),
64
+
};
70
65
(user.did.clone(), user.controller_did.clone())
71
66
}
72
67
};
···
80
75
return Err(ApiError::Forbidden);
81
76
}
82
77
83
-
let client_mime_hint = headers
84
-
.get("content-type")
85
-
.and_then(|h| h.to_str().ok())
86
-
.unwrap_or("application/octet-stream");
78
+
let client_mime_hint =
79
+
get_header_str(&headers, "content-type").unwrap_or("application/octet-stream");
87
80
88
81
let user_id = state
89
82
.user_repo
···
232
225
.user_repo
233
226
.get_by_did(did)
234
227
.await
235
-
.map_err(|e| {
236
-
error!("DB error fetching user: {:?}", e);
237
-
ApiError::InternalError(None)
238
-
})?
228
+
.log_db_err("fetching user")?
239
229
.ok_or(ApiError::InternalError(None))?;
240
230
241
231
let limit = params.limit.unwrap_or(500).clamp(1, 1000);
···
244
234
.blob_repo
245
235
.list_missing_blobs(user.id, cursor, limit + 1)
246
236
.await
247
-
.map_err(|e| {
248
-
error!("DB error fetching missing blobs: {:?}", e);
249
-
ApiError::InternalError(None)
250
-
})?;
237
+
.log_db_err("fetching missing blobs")?;
251
238
252
239
let has_more = missing.len() > limit as usize;
253
240
let blobs: Vec<RecordBlob> = missing
+2
-5
crates/tranquil-pds/src/api/repo/import.rs
+2
-5
crates/tranquil-pds/src/api/repo/import.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::api::repo::record::create_signed_commit;
4
4
use crate::auth::{Auth, NotTakendown};
5
5
use crate::state::AppState;
···
49
49
.user_repo
50
50
.get_by_did(did)
51
51
.await
52
-
.map_err(|e| {
53
-
error!("DB error fetching user: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?
52
+
.log_db_err("fetching user")?
56
53
.ok_or(ApiError::AccountNotFound)?;
57
54
if user.takedown_ref.is_some() {
58
55
return Err(ApiError::AccountTakedown);
+2
-2
crates/tranquil-pds/src/api/repo/meta.rs
+2
-2
crates/tranquil-pds/src/api/repo/meta.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::state::AppState;
3
3
use crate::types::AtIdentifier;
4
+
use crate::util::pds_hostname_without_port;
4
5
use axum::{
5
6
Json,
6
7
extract::{Query, State},
···
18
19
State(state): State<AppState>,
19
20
Query(input): Query<DescribeRepoInput>,
20
21
) -> Response {
21
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
22
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
22
+
let hostname_for_handles = pds_hostname_without_port();
23
23
let user_row = if input.repo.is_did() {
24
24
let did: crate::types::Did = match input.repo.as_str().parse() {
25
25
Ok(d) => d,
+16
-35
crates/tranquil-pds/src/api/repo/record/batch.rs
+16
-35
crates/tranquil-pds/src/api/repo/record/batch.rs
···
1
1
use super::validation::validate_record_with_status;
2
2
use crate::api::error::ApiError;
3
3
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids};
4
-
use crate::auth::{Active, Auth};
4
+
use crate::auth::{Active, Auth, VerifyScope};
5
5
use crate::delegation::DelegationActionType;
6
6
use crate::repo::tracking::TrackingBlockStore;
7
7
use crate::state::AppState;
···
271
271
input.writes.len()
272
272
);
273
273
let did = auth.did.clone();
274
-
let is_oauth = auth.is_oauth();
275
-
let scope = auth.scope.clone();
276
274
let controller_did = auth.controller_did.clone();
277
275
if input.repo.as_str() != did {
278
276
return Err(ApiError::InvalidRepo(
···
310
308
)));
311
309
}
312
310
313
-
let has_custom_scope = scope
314
-
.as_ref()
315
-
.map(|s| s != "com.atproto.access")
316
-
.unwrap_or(false);
317
-
if is_oauth || has_custom_scope {
311
+
{
318
312
use std::collections::HashSet;
319
313
let create_collections: HashSet<&Nsid> = input
320
314
.writes
···
350
344
})
351
345
.collect();
352
346
353
-
let scope_checks = create_collections
354
-
.iter()
355
-
.map(|c| (crate::oauth::RepoAction::Create, c))
356
-
.chain(
357
-
update_collections
358
-
.iter()
359
-
.map(|c| (crate::oauth::RepoAction::Update, c)),
360
-
)
361
-
.chain(
362
-
delete_collections
363
-
.iter()
364
-
.map(|c| (crate::oauth::RepoAction::Delete, c)),
365
-
);
366
-
367
-
if let Some(err) = scope_checks
368
-
.filter_map(|(action, collection)| {
369
-
crate::auth::scope_check::check_repo_scope(
370
-
is_oauth,
371
-
scope.as_deref(),
372
-
action,
373
-
collection,
374
-
)
375
-
.err()
376
-
})
377
-
.next()
378
-
{
379
-
return Ok(err);
347
+
for collection in &create_collections {
348
+
if let Err(e) = auth.verify_repo_create(collection) {
349
+
return Ok(e.into_response());
350
+
}
351
+
}
352
+
for collection in &update_collections {
353
+
if let Err(e) = auth.verify_repo_update(collection) {
354
+
return Ok(e.into_response());
355
+
}
356
+
}
357
+
for collection in &delete_collections {
358
+
if let Err(e) = auth.verify_repo_delete(collection) {
359
+
return Ok(e.into_response());
360
+
}
380
361
}
381
362
}
382
363
+6
-10
crates/tranquil-pds/src/api/repo/record/delete.rs
+6
-10
crates/tranquil-pds/src/api/repo/record/delete.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
3
3
use crate::api::repo::record::write::{CommitInfo, prepare_repo_write};
4
-
use crate::auth::{Active, Auth};
4
+
use crate::auth::{Active, Auth, VerifyScope};
5
5
use crate::delegation::DelegationActionType;
6
6
use crate::repo::tracking::TrackingBlockStore;
7
7
use crate::state::AppState;
···
43
43
auth: Auth<Active>,
44
44
Json(input): Json<DeleteRecordInput>,
45
45
) -> Result<Response, crate::api::error::ApiError> {
46
+
let _scope_proof = match auth.verify_repo_delete(&input.collection) {
47
+
Ok(proof) => proof,
48
+
Err(e) => return Ok(e.into_response()),
49
+
};
50
+
46
51
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
47
52
Ok(res) => res,
48
53
Err(err_res) => return Ok(err_res),
49
54
};
50
55
51
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
52
-
repo_auth.is_oauth,
53
-
repo_auth.scope.as_deref(),
54
-
crate::oauth::RepoAction::Delete,
55
-
&input.collection,
56
-
) {
57
-
return Ok(e);
58
-
}
59
-
60
56
let did = repo_auth.did;
61
57
let user_id = repo_auth.user_id;
62
58
let current_root_cid = repo_auth.current_root_cid;
+3
-4
crates/tranquil-pds/src/api/repo/record/read.rs
+3
-4
crates/tranquil-pds/src/api/repo/record/read.rs
···
1
1
use crate::api::error::ApiError;
2
2
use crate::state::AppState;
3
3
use crate::types::{AtIdentifier, Nsid, Rkey};
4
+
use crate::util::pds_hostname_without_port;
4
5
use axum::{
5
6
Json,
6
7
extract::{Query, State},
···
58
59
_headers: HeaderMap,
59
60
Query(input): Query<GetRecordInput>,
60
61
) -> Response {
61
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
62
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
62
+
let hostname_for_handles = pds_hostname_without_port();
63
63
let user_id_opt = if input.repo.is_did() {
64
64
let did: crate::types::Did = match input.repo.as_str().parse() {
65
65
Ok(d) => d,
···
157
157
State(state): State<AppState>,
158
158
Query(input): Query<ListRecordsInput>,
159
159
) -> Response {
160
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
161
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
160
+
let hostname_for_handles = pds_hostname_without_port();
162
161
let user_id_opt = if input.repo.is_did() {
163
162
let did: crate::types::Did = match input.repo.as_str().parse() {
164
163
Ok(d) => d,
+15
-27
crates/tranquil-pds/src/api/repo/record/write.rs
+15
-27
crates/tranquil-pds/src/api/repo/record/write.rs
···
3
3
use crate::api::repo::record::utils::{
4
4
CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids,
5
5
};
6
-
use crate::auth::{Active, Auth};
6
+
use crate::auth::{Active, Auth, VerifyScope};
7
7
use crate::delegation::DelegationActionType;
8
8
use crate::repo::tracking::TrackingBlockStore;
9
9
use crate::state::AppState;
···
127
127
auth: Auth<Active>,
128
128
Json(input): Json<CreateRecordInput>,
129
129
) -> Result<Response, crate::api::error::ApiError> {
130
+
let _scope_proof = match auth.verify_repo_create(&input.collection) {
131
+
Ok(proof) => proof,
132
+
Err(e) => return Ok(e.into_response()),
133
+
};
134
+
130
135
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
131
136
Ok(res) => res,
132
137
Err(err_res) => return Ok(err_res),
133
138
};
134
139
135
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
136
-
repo_auth.is_oauth,
137
-
repo_auth.scope.as_deref(),
138
-
crate::oauth::RepoAction::Create,
139
-
&input.collection,
140
-
) {
141
-
return Ok(e);
142
-
}
143
-
144
140
let did = repo_auth.did;
145
141
let user_id = repo_auth.user_id;
146
142
let current_root_cid = repo_auth.current_root_cid;
···
434
430
auth: Auth<Active>,
435
431
Json(input): Json<PutRecordInput>,
436
432
) -> Result<Response, crate::api::error::ApiError> {
433
+
let _create_proof = match auth.verify_repo_create(&input.collection) {
434
+
Ok(proof) => proof,
435
+
Err(e) => return Ok(e.into_response()),
436
+
};
437
+
let _update_proof = match auth.verify_repo_update(&input.collection) {
438
+
Ok(proof) => proof,
439
+
Err(e) => return Ok(e.into_response()),
440
+
};
441
+
437
442
let repo_auth = match prepare_repo_write(&state, &auth, &input.repo).await {
438
443
Ok(res) => res,
439
444
Err(err_res) => return Ok(err_res),
440
445
};
441
446
442
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
443
-
repo_auth.is_oauth,
444
-
repo_auth.scope.as_deref(),
445
-
crate::oauth::RepoAction::Create,
446
-
&input.collection,
447
-
) {
448
-
return Ok(e);
449
-
}
450
-
if let Err(e) = crate::auth::scope_check::check_repo_scope(
451
-
repo_auth.is_oauth,
452
-
repo_auth.scope.as_deref(),
453
-
crate::oauth::RepoAction::Update,
454
-
&input.collection,
455
-
) {
456
-
return Ok(e);
457
-
}
458
-
459
447
let did = repo_auth.did;
460
448
let user_id = repo_auth.user_id;
461
449
let current_root_cid = repo_auth.current_root_cid;
+18
-26
crates/tranquil-pds/src/api/server/account_status.rs
+18
-26
crates/tranquil-pds/src/api/server/account_status.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::{Auth, NotTakendown, Permissive};
2
+
use crate::api::error::{ApiError, DbResultExt};
3
+
use crate::auth::{Auth, NotTakendown, Permissive, require_legacy_session_mfa};
4
4
use crate::cache::Cache;
5
5
use crate::plc::PlcClient;
6
6
use crate::state::AppState;
7
7
use crate::types::PlainPassword;
8
+
use crate::util::pds_hostname;
8
9
use axum::{
9
10
Json,
10
11
extract::State,
···
130
131
did: &crate::types::Did,
131
132
with_retry: bool,
132
133
) -> Result<(), ApiError> {
133
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
134
+
let hostname = pds_hostname();
134
135
let expected_endpoint = format!("https://{}", hostname);
135
136
136
137
if did.as_str().starts_with("did:plc:") {
···
219
220
.and_then(|v| v.get("atproto"))
220
221
.and_then(|k| k.as_str());
221
222
222
-
let user_key = user_repo.get_user_key_by_did(did).await.map_err(|e| {
223
-
error!("Failed to fetch user key: {:?}", e);
224
-
ApiError::InternalError(None)
225
-
})?;
223
+
let user_key = user_repo
224
+
.get_user_key_by_did(did)
225
+
.await
226
+
.log_db_err("fetching user key")?;
226
227
227
228
if let Some(key_info) = user_key {
228
229
let key_bytes =
···
523
524
State(state): State<AppState>,
524
525
auth: Auth<NotTakendown>,
525
526
) -> Result<Response, ApiError> {
526
-
let did = &auth.did;
527
-
528
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, did).await {
529
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
530
-
&*state.user_repo,
531
-
&*state.session_repo,
532
-
did,
533
-
)
534
-
.await);
535
-
}
527
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
528
+
Ok(proof) => proof,
529
+
Err(response) => return Ok(response),
530
+
};
536
531
537
532
let user_id = state
538
533
.user_repo
539
-
.get_id_by_did(did)
534
+
.get_id_by_did(session_mfa.did())
540
535
.await
541
536
.ok()
542
537
.flatten()
···
545
540
let expires_at = Utc::now() + Duration::minutes(15);
546
541
state
547
542
.infra_repo
548
-
.create_deletion_request(&confirmation_token, did, expires_at)
543
+
.create_deletion_request(&confirmation_token, session_mfa.did(), expires_at)
549
544
.await
550
-
.map_err(|e| {
551
-
error!("DB error creating deletion token: {:?}", e);
552
-
ApiError::InternalError(None)
553
-
})?;
554
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
545
+
.log_db_err("creating deletion token")?;
546
+
let hostname = pds_hostname();
555
547
if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion(
556
548
state.user_repo.as_ref(),
557
549
state.infra_repo.as_ref(),
558
550
user_id,
559
551
&confirmation_token,
560
-
&hostname,
552
+
hostname,
561
553
)
562
554
.await
563
555
{
564
556
warn!("Failed to enqueue account deletion notification: {:?}", e);
565
557
}
566
-
info!("Account deletion requested for user {}", did);
558
+
info!("Account deletion requested for user {}", session_mfa.did());
567
559
Ok(EmptyResponse::ok().into_response())
568
560
}
569
561
+13
-46
crates/tranquil-pds/src/api/server/app_password.rs
+13
-46
crates/tranquil-pds/src/api/server/app_password.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use crate::auth::{Auth, NotTakendown, Permissive, generate_app_password};
4
4
use crate::delegation::{DelegationActionType, intersect_scopes};
5
-
use crate::state::{AppState, RateLimitKind};
5
+
use crate::rate_limit::{AppPasswordLimit, RateLimited};
6
+
use crate::state::AppState;
6
7
use axum::{
7
8
Json,
8
9
extract::State,
9
-
http::HeaderMap,
10
10
response::{IntoResponse, Response},
11
11
};
12
12
use serde::{Deserialize, Serialize};
13
13
use serde_json::json;
14
-
use tracing::{error, warn};
14
+
use tracing::error;
15
15
use tranquil_db_traits::AppPasswordCreate;
16
16
17
17
#[derive(Serialize)]
···
39
39
.user_repo
40
40
.get_by_did(&auth.did)
41
41
.await
42
-
.map_err(|e| {
43
-
error!("DB error getting user: {:?}", e);
44
-
ApiError::InternalError(None)
45
-
})?
42
+
.log_db_err("getting user")?
46
43
.ok_or(ApiError::AccountNotFound)?;
47
44
48
45
let rows = state
49
46
.session_repo
50
47
.list_app_passwords(user.id)
51
48
.await
52
-
.map_err(|e| {
53
-
error!("DB error listing app passwords: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?;
49
+
.log_db_err("listing app passwords")?;
56
50
let passwords: Vec<AppPassword> = rows
57
51
.iter()
58
52
.map(|row| AppPassword {
···
89
83
90
84
pub async fn create_app_password(
91
85
State(state): State<AppState>,
92
-
headers: HeaderMap,
86
+
_rate_limit: RateLimited<AppPasswordLimit>,
93
87
auth: Auth<NotTakendown>,
94
88
Json(input): Json<CreateAppPasswordInput>,
95
89
) -> Result<Response, ApiError> {
96
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
97
-
if !state
98
-
.check_rate_limit(RateLimitKind::AppPassword, &client_ip)
99
-
.await
100
-
{
101
-
warn!(ip = %client_ip, "App password creation rate limit exceeded");
102
-
return Err(ApiError::RateLimitExceeded(None));
103
-
}
104
-
105
90
let user = state
106
91
.user_repo
107
92
.get_by_did(&auth.did)
108
93
.await
109
-
.map_err(|e| {
110
-
error!("DB error getting user: {:?}", e);
111
-
ApiError::InternalError(None)
112
-
})?
94
+
.log_db_err("getting user")?
113
95
.ok_or(ApiError::AccountNotFound)?;
114
96
115
97
let name = input.name.trim();
···
121
103
.session_repo
122
104
.get_app_password_by_name(user.id, name)
123
105
.await
124
-
.map_err(|e| {
125
-
error!("DB error checking app password: {:?}", e);
126
-
ApiError::InternalError(None)
127
-
})?
106
+
.log_db_err("checking app password")?
128
107
.is_some()
129
108
{
130
109
return Err(ApiError::DuplicateAppPassword);
···
187
166
.session_repo
188
167
.create_app_password(&create_data)
189
168
.await
190
-
.map_err(|e| {
191
-
error!("DB error creating app password: {:?}", e);
192
-
ApiError::InternalError(None)
193
-
})?;
169
+
.log_db_err("creating app password")?;
194
170
195
171
if let Some(ref controller) = controller_did {
196
172
let _ = state
···
234
210
.user_repo
235
211
.get_by_did(&auth.did)
236
212
.await
237
-
.map_err(|e| {
238
-
error!("DB error getting user: {:?}", e);
239
-
ApiError::InternalError(None)
240
-
})?
213
+
.log_db_err("getting user")?
241
214
.ok_or(ApiError::AccountNotFound)?;
242
215
243
216
let name = input.name.trim();
···
255
228
.session_repo
256
229
.delete_sessions_by_app_password(&auth.did, name)
257
230
.await
258
-
.map_err(|e| {
259
-
error!("DB error revoking sessions for app password: {:?}", e);
260
-
ApiError::InternalError(None)
261
-
})?;
231
+
.log_db_err("revoking sessions for app password")?;
262
232
263
233
futures::future::join_all(sessions_to_invalidate.iter().map(|jti| {
264
234
let cache_key = format!("auth:session:{}:{}", &auth.did, jti);
···
273
243
.session_repo
274
244
.delete_app_password(user.id, name)
275
245
.await
276
-
.map_err(|e| {
277
-
error!("DB error revoking app password: {:?}", e);
278
-
ApiError::InternalError(None)
279
-
})?;
246
+
.log_db_err("revoking app password")?;
280
247
281
248
Ok(EmptyResponse::ok().into_response())
282
249
}
+21
-92
crates/tranquil-pds/src/api/server/email.rs
+21
-92
crates/tranquil-pds/src/api/server/email.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, TokenRequiredResponse, VerifiedResponse};
3
3
use crate::auth::{Auth, NotTakendown};
4
-
use crate::state::{AppState, RateLimitKind};
4
+
use crate::rate_limit::{EmailUpdateLimit, RateLimited, VerificationCheckLimit};
5
+
use crate::state::AppState;
6
+
use crate::util::pds_hostname;
5
7
use axum::{
6
8
Json,
7
9
extract::State,
···
44
46
45
47
pub async fn request_email_update(
46
48
State(state): State<AppState>,
47
-
headers: axum::http::HeaderMap,
49
+
_rate_limit: RateLimited<EmailUpdateLimit>,
48
50
auth: Auth<NotTakendown>,
49
51
input: Option<Json<RequestEmailUpdateInput>>,
50
52
) -> Result<Response, ApiError> {
51
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
52
-
if !state
53
-
.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip)
54
-
.await
55
-
{
56
-
warn!(ip = %client_ip, "Email update rate limit exceeded");
57
-
return Err(ApiError::RateLimitExceeded(None));
58
-
}
59
-
60
53
if let Err(e) = crate::auth::scope_check::check_account_scope(
61
54
auth.is_oauth(),
62
55
auth.scope.as_deref(),
···
70
63
.user_repo
71
64
.get_email_info_by_did(&auth.did)
72
65
.await
73
-
.map_err(|e| {
74
-
error!("DB error: {:?}", e);
75
-
ApiError::InternalError(None)
76
-
})?
66
+
.log_db_err("getting email info")?
77
67
.ok_or(ApiError::AccountNotFound)?;
78
68
79
69
let Some(current_email) = user.email else {
···
111
101
}
112
102
}
113
103
114
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
104
+
let hostname = pds_hostname();
115
105
if let Err(e) = crate::comms::comms_repo::enqueue_email_update_token(
116
106
state.user_repo.as_ref(),
117
107
state.infra_repo.as_ref(),
118
108
user.id,
119
109
&code,
120
110
&formatted_code,
121
-
&hostname,
111
+
hostname,
122
112
)
123
113
.await
124
114
{
···
139
129
140
130
pub async fn confirm_email(
141
131
State(state): State<AppState>,
142
-
headers: axum::http::HeaderMap,
132
+
_rate_limit: RateLimited<EmailUpdateLimit>,
143
133
auth: Auth<NotTakendown>,
144
134
Json(input): Json<ConfirmEmailInput>,
145
135
) -> Result<Response, ApiError> {
146
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
147
-
if !state
148
-
.check_rate_limit(RateLimitKind::EmailUpdate, &client_ip)
149
-
.await
150
-
{
151
-
warn!(ip = %client_ip, "Confirm email rate limit exceeded");
152
-
return Err(ApiError::RateLimitExceeded(None));
153
-
}
154
-
155
136
if let Err(e) = crate::auth::scope_check::check_account_scope(
156
137
auth.is_oauth(),
157
138
auth.scope.as_deref(),
···
166
147
.user_repo
167
148
.get_email_info_by_did(did)
168
149
.await
169
-
.map_err(|e| {
170
-
error!("DB error: {:?}", e);
171
-
ApiError::InternalError(None)
172
-
})?
150
+
.log_db_err("getting email info")?
173
151
.ok_or(ApiError::AccountNotFound)?;
174
152
175
153
let Some(ref email) = user.email else {
···
213
191
.user_repo
214
192
.set_email_verified(user.id, true)
215
193
.await
216
-
.map_err(|e| {
217
-
error!("DB error confirming email: {:?}", e);
218
-
ApiError::InternalError(None)
219
-
})?;
194
+
.log_db_err("confirming email")?;
220
195
221
196
info!("Email confirmed for user {}", user.id);
222
197
Ok(EmptyResponse::ok().into_response())
···
250
225
.user_repo
251
226
.get_email_info_by_did(did)
252
227
.await
253
-
.map_err(|e| {
254
-
error!("DB error: {:?}", e);
255
-
ApiError::InternalError(None)
256
-
})?
228
+
.log_db_err("getting email info")?
257
229
.ok_or(ApiError::AccountNotFound)?;
258
230
259
231
let user_id = user.id;
···
325
297
.user_repo
326
298
.update_email(user_id, &new_email)
327
299
.await
328
-
.map_err(|e| {
329
-
error!("DB error updating email: {:?}", e);
330
-
ApiError::InternalError(None)
331
-
})?;
300
+
.log_db_err("updating email")?;
332
301
333
302
let verification_token =
334
303
crate::auth::verification_token::generate_signup_token(did, "email", &new_email);
335
304
let formatted_token =
336
305
crate::auth::verification_token::format_token_for_display(&verification_token);
337
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
306
+
let hostname = pds_hostname();
338
307
if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification(
339
308
state.infra_repo.as_ref(),
340
309
user_id,
341
310
"email",
342
311
&new_email,
343
312
&formatted_token,
344
-
&hostname,
313
+
hostname,
345
314
)
346
315
.await
347
316
{
···
371
340
372
341
pub async fn check_email_verified(
373
342
State(state): State<AppState>,
374
-
headers: axum::http::HeaderMap,
343
+
_rate_limit: RateLimited<VerificationCheckLimit>,
375
344
Json(input): Json<CheckEmailVerifiedInput>,
376
345
) -> Response {
377
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
378
-
if !state
379
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
380
-
.await
381
-
{
382
-
return ApiError::RateLimitExceeded(None).into_response();
383
-
}
384
-
385
346
match state
386
347
.user_repo
387
348
.check_email_verified_by_identifier(&input.identifier)
···
403
364
404
365
pub async fn authorize_email_update(
405
366
State(state): State<AppState>,
406
-
headers: axum::http::HeaderMap,
367
+
_rate_limit: RateLimited<VerificationCheckLimit>,
407
368
axum::extract::Query(query): axum::extract::Query<AuthorizeEmailUpdateQuery>,
408
369
) -> Response {
409
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
410
-
if !state
411
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
412
-
.await
413
-
{
414
-
return ApiError::RateLimitExceeded(None).into_response();
415
-
}
416
-
417
370
let verified = crate::auth::verification_token::verify_token_signature(&query.token);
418
371
419
372
let token_data = match verified {
···
488
441
489
442
info!(did = %did, "Email update authorized via link click");
490
443
491
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
444
+
let hostname = pds_hostname();
492
445
let redirect_url = format!(
493
446
"https://{}/app/verify?type=email-authorize-success",
494
447
hostname
···
499
452
500
453
pub async fn check_email_update_status(
501
454
State(state): State<AppState>,
502
-
headers: axum::http::HeaderMap,
455
+
_rate_limit: RateLimited<VerificationCheckLimit>,
503
456
auth: Auth<NotTakendown>,
504
457
) -> Result<Response, ApiError> {
505
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
506
-
if !state
507
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
508
-
.await
509
-
{
510
-
return Err(ApiError::RateLimitExceeded(None));
511
-
}
512
-
513
458
if let Err(e) = crate::auth::scope_check::check_account_scope(
514
459
auth.is_oauth(),
515
460
auth.scope.as_deref(),
···
549
494
550
495
pub async fn check_email_in_use(
551
496
State(state): State<AppState>,
552
-
headers: axum::http::HeaderMap,
497
+
_rate_limit: RateLimited<VerificationCheckLimit>,
553
498
Json(input): Json<CheckEmailInUseInput>,
554
499
) -> Response {
555
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
556
-
if !state
557
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
558
-
.await
559
-
{
560
-
return ApiError::RateLimitExceeded(None).into_response();
561
-
}
562
-
563
500
let email = input.email.trim().to_lowercase();
564
501
if email.is_empty() {
565
502
return ApiError::InvalidRequest("email is required".into()).into_response();
···
587
524
588
525
pub async fn check_comms_channel_in_use(
589
526
State(state): State<AppState>,
590
-
headers: axum::http::HeaderMap,
527
+
_rate_limit: RateLimited<VerificationCheckLimit>,
591
528
Json(input): Json<CheckCommsChannelInUseInput>,
592
529
) -> Response {
593
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
594
-
if !state
595
-
.check_rate_limit(RateLimitKind::VerificationCheck, &client_ip)
596
-
.await
597
-
{
598
-
return ApiError::RateLimitExceeded(None).into_response();
599
-
}
600
-
601
530
let channel = match input.channel.to_lowercase().as_str() {
602
531
"email" => CommsChannel::Email,
603
532
"discord" => CommsChannel::Discord,
+5
-9
crates/tranquil-pds/src/api/server/invite.rs
+5
-9
crates/tranquil-pds/src/api/server/invite.rs
···
1
+
use crate::api::error::DbResultExt;
1
2
use crate::api::ApiError;
2
3
use crate::auth::{Admin, Auth, NotTakendown};
3
4
use crate::state::AppState;
4
5
use crate::types::Did;
6
+
use crate::util::pds_hostname;
5
7
use axum::{
6
8
Json,
7
9
extract::State,
···
24
26
}
25
27
26
28
fn gen_invite_code() -> String {
27
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
29
+
let hostname = pds_hostname();
28
30
let hostname_prefix = hostname.replace('.', "-");
29
31
format!("{}-{}", hostname_prefix, gen_random_token())
30
32
}
···
121
123
.user_repo
122
124
.get_any_admin_user_id()
123
125
.await
124
-
.map_err(|e| {
125
-
error!("DB error looking up admin user: {:?}", e);
126
-
ApiError::InternalError(None)
127
-
})?
126
+
.log_db_err("looking up admin user")?
128
127
.ok_or_else(|| {
129
128
error!("No admin user found to create invite codes");
130
129
ApiError::InternalError(None)
···
202
201
.infra_repo
203
202
.get_invite_codes_for_account(&auth.did)
204
203
.await
205
-
.map_err(|e| {
206
-
error!("DB error fetching invite codes: {:?}", e);
207
-
ApiError::InternalError(None)
208
-
})?;
204
+
.log_db_err("fetching invite codes")?;
209
205
210
206
let filtered_codes: Vec<_> = codes_info
211
207
.into_iter()
+3
-2
crates/tranquil-pds/src/api/server/meta.rs
+3
-2
crates/tranquil-pds/src/api/server/meta.rs
···
1
1
use crate::state::AppState;
2
+
use crate::util::pds_hostname;
2
3
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};
3
4
use serde_json::json;
4
5
···
30
31
}
31
32
32
33
pub async fn describe_server() -> impl IntoResponse {
33
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
34
+
let pds_hostname = pds_hostname();
34
35
let domains_str =
35
-
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.clone());
36
+
std::env::var("AVAILABLE_USER_DOMAINS").unwrap_or_else(|_| pds_hostname.to_string());
36
37
let domains: Vec<&str> = domains_str.split(',').map(|s| s.trim()).collect();
37
38
let invite_code_required = std::env::var("INVITE_CODE_REQUIRED")
38
39
.map(|v| v == "true" || v == "1")
+6
-13
crates/tranquil-pds/src/api/server/migration.rs
+6
-13
crates/tranquil-pds/src/api/server/migration.rs
···
1
+
use crate::api::error::DbResultExt;
1
2
use crate::api::ApiError;
2
3
use crate::auth::{Active, Auth};
3
4
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
4
6
use axum::{
5
7
Json,
6
8
extract::State,
···
49
51
.user_repo
50
52
.get_user_for_did_doc(&auth.did)
51
53
.await
52
-
.map_err(|e| {
53
-
tracing::error!("DB error getting user: {:?}", e);
54
-
ApiError::InternalError(None)
55
-
})?
54
+
.log_db_err("getting user")?
56
55
.ok_or(ApiError::AccountNotFound)?;
57
56
58
57
if let Some(ref methods) = input.verification_methods {
···
107
106
.user_repo
108
107
.upsert_did_web_overrides(user.id, verification_methods_json, also_known_as)
109
108
.await
110
-
.map_err(|e| {
111
-
tracing::error!("DB error upserting did_web_overrides: {:?}", e);
112
-
ApiError::InternalError(None)
113
-
})?;
109
+
.log_db_err("upserting did_web_overrides")?;
114
110
115
111
if let Some(ref endpoint) = input.service_endpoint {
116
112
let endpoint_clean = endpoint.trim().trim_end_matches('/');
···
118
114
.user_repo
119
115
.update_migrated_to_pds(&auth.did, endpoint_clean)
120
116
.await
121
-
.map_err(|e| {
122
-
tracing::error!("DB error updating service endpoint: {:?}", e);
123
-
ApiError::InternalError(None)
124
-
})?;
117
+
.log_db_err("updating service endpoint")?;
125
118
}
126
119
127
120
let did_doc = build_did_document(&state, &auth.did).await;
···
154
147
}
155
148
156
149
async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value {
157
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
150
+
let hostname = pds_hostname();
158
151
159
152
let user = match state.user_repo.get_user_for_did_doc_build(did).await {
160
153
Ok(Some(row)) => row,
+17
-64
crates/tranquil-pds/src/api/server/passkey_account.rs
+17
-64
crates/tranquil-pds/src/api/server/passkey_account.rs
···
19
19
20
20
use crate::api::repo::record::utils::create_signed_commit;
21
21
use crate::auth::{ServiceTokenVerifier, generate_app_password, is_service_token};
22
-
use crate::state::{AppState, RateLimitKind};
22
+
use crate::rate_limit::{AccountCreationLimit, PasswordResetLimit, RateLimited};
23
+
use crate::state::AppState;
23
24
use crate::types::{Did, Handle, Nsid, PlainPassword, Rkey};
25
+
use crate::util::{pds_hostname, pds_hostname_without_port};
24
26
use crate::validation::validate_password;
25
27
26
-
fn extract_client_ip(headers: &HeaderMap) -> String {
27
-
if let Some(forwarded) = headers.get("x-forwarded-for")
28
-
&& let Ok(value) = forwarded.to_str()
29
-
&& let Some(first_ip) = value.split(',').next()
30
-
{
31
-
return first_ip.trim().to_string();
32
-
}
33
-
if let Some(real_ip) = headers.get("x-real-ip")
34
-
&& let Ok(value) = real_ip.to_str()
35
-
{
36
-
return value.trim().to_string();
37
-
}
38
-
"unknown".to_string()
39
-
}
40
-
41
28
fn generate_setup_token() -> String {
42
29
let mut rng = rand::thread_rng();
43
30
(0..32)
···
80
67
81
68
pub async fn create_passkey_account(
82
69
State(state): State<AppState>,
70
+
_rate_limit: RateLimited<AccountCreationLimit>,
83
71
headers: HeaderMap,
84
72
Json(input): Json<CreatePasskeyAccountInput>,
85
73
) -> Response {
86
-
let client_ip = extract_client_ip(&headers);
87
-
if !state
88
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
89
-
.await
90
-
{
91
-
warn!(ip = %client_ip, "Account creation rate limit exceeded");
92
-
return ApiError::RateLimitExceeded(Some(
93
-
"Too many account creation attempts. Please try again later.".into(),
94
-
))
95
-
.into_response();
96
-
}
97
-
98
74
let byod_auth = if let Some(extracted) = crate::auth::extract_auth_token_from_header(
99
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
75
+
crate::util::get_header_str(&headers, "Authorization"),
100
76
) {
101
77
let token = extracted.token;
102
78
if is_service_token(&token) {
···
135
111
.map(|d| d.starts_with("did:web:"))
136
112
.unwrap_or(false);
137
113
138
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
139
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
114
+
let hostname = pds_hostname();
115
+
let hostname_for_handles = pds_hostname_without_port();
140
116
let pds_suffix = format!(".{}", hostname_for_handles);
141
117
142
118
let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) {
···
268
244
}
269
245
if is_byod_did_web {
270
246
if let Some(ref auth_did) = byod_auth
271
-
&& d != auth_did
247
+
&& d != auth_did.as_str()
272
248
{
273
249
return ApiError::AuthorizationError(format!(
274
250
"Service token issuer {} does not match DID {}",
···
280
256
} else {
281
257
if let Err(e) = crate::api::identity::did::verify_did_web(
282
258
d,
283
-
&hostname,
259
+
hostname,
284
260
&input.handle,
285
261
input.signing_key.as_deref(),
286
262
)
···
296
272
if let Some(ref auth_did) = byod_auth {
297
273
if let Some(ref provided_did) = input.did {
298
274
if provided_did.starts_with("did:plc:") {
299
-
if provided_did != auth_did {
275
+
if provided_did != auth_did.as_str() {
300
276
return ApiError::AuthorizationError(format!(
301
277
"Service token issuer {} does not match DID {}",
302
278
auth_did, provided_did
···
521
497
verification_channel,
522
498
&verification_recipient,
523
499
&formatted_token,
524
-
&hostname,
500
+
hostname,
525
501
)
526
502
.await
527
503
{
···
626
602
return ApiError::InvalidToken(None).into_response();
627
603
}
628
604
629
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
630
-
let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
631
-
Ok(w) => w,
632
-
Err(e) => {
633
-
error!("Failed to create WebAuthn config: {:?}", e);
634
-
return ApiError::InternalError(None).into_response();
635
-
}
636
-
};
605
+
let webauthn = &state.webauthn_config;
637
606
638
607
let reg_state = match state
639
608
.user_repo
···
768
737
return ApiError::InvalidToken(None).into_response();
769
738
}
770
739
771
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
772
-
let webauthn = match crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname) {
773
-
Ok(w) => w,
774
-
Err(e) => {
775
-
error!("Failed to create WebAuthn config: {:?}", e);
776
-
return ApiError::InternalError(None).into_response();
777
-
}
778
-
};
740
+
let webauthn = &state.webauthn_config;
779
741
780
742
let existing_passkeys = state
781
743
.user_repo
···
840
802
841
803
pub async fn request_passkey_recovery(
842
804
State(state): State<AppState>,
843
-
headers: HeaderMap,
805
+
_rate_limit: RateLimited<PasswordResetLimit>,
844
806
Json(input): Json<RequestPasskeyRecoveryInput>,
845
807
) -> Response {
846
-
let client_ip = extract_client_ip(&headers);
847
-
if !state
848
-
.check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
849
-
.await
850
-
{
851
-
return ApiError::RateLimitExceeded(None).into_response();
852
-
}
853
-
854
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
855
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
808
+
let hostname_for_handles = pds_hostname_without_port();
856
809
let identifier = input.email.trim().to_lowercase();
857
810
let identifier = identifier.strip_prefix('@').unwrap_or(&identifier);
858
811
let normalized_handle = if identifier.contains('@') || identifier.contains('.') {
···
890
843
return ApiError::InternalError(None).into_response();
891
844
}
892
845
893
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
846
+
let hostname = pds_hostname();
894
847
let recovery_url = format!(
895
848
"https://{}/app/recover-passkey?did={}&token={}",
896
849
hostname,
···
903
856
state.infra_repo.as_ref(),
904
857
user.id,
905
858
&recovery_url,
906
-
&hostname,
859
+
hostname,
907
860
)
908
861
.await;
909
862
+20
-56
crates/tranquil-pds/src/api/server/passkeys.rs
+20
-56
crates/tranquil-pds/src/api/server/passkeys.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::webauthn::WebAuthnConfig;
4
-
use crate::auth::{Active, Auth};
2
+
use crate::api::error::{ApiError, DbResultExt};
3
+
use crate::auth::{Active, Auth, require_legacy_session_mfa, require_reauth_window};
5
4
use crate::state::AppState;
6
5
use axum::{
7
6
Json,
···
12
11
use tracing::{error, info, warn};
13
12
use webauthn_rs::prelude::*;
14
13
15
-
fn get_webauthn() -> Result<WebAuthnConfig, ApiError> {
16
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
17
-
WebAuthnConfig::new(&hostname).map_err(|e| {
18
-
error!("Failed to create WebAuthn config: {}", e);
19
-
ApiError::InternalError(Some("WebAuthn configuration failed".into()))
20
-
})
21
-
}
22
-
23
14
#[derive(Deserialize)]
24
15
#[serde(rename_all = "camelCase")]
25
16
pub struct StartRegistrationInput {
···
37
28
auth: Auth<Active>,
38
29
Json(input): Json<StartRegistrationInput>,
39
30
) -> Result<Response, ApiError> {
40
-
let webauthn = get_webauthn()?;
31
+
let webauthn = &state.webauthn_config;
41
32
42
33
let handle = state
43
34
.user_repo
44
35
.get_handle_by_did(&auth.did)
45
36
.await
46
-
.map_err(|e| {
47
-
error!("DB error fetching user: {:?}", e);
48
-
ApiError::InternalError(None)
49
-
})?
37
+
.log_db_err("fetching user")?
50
38
.ok_or(ApiError::AccountNotFound)?;
51
39
52
40
let existing_passkeys = state
53
41
.user_repo
54
42
.get_passkeys_for_user(&auth.did)
55
43
.await
56
-
.map_err(|e| {
57
-
error!("DB error fetching existing passkeys: {:?}", e);
58
-
ApiError::InternalError(None)
59
-
})?;
44
+
.log_db_err("fetching existing passkeys")?;
60
45
61
46
let exclude_credentials: Vec<CredentialID> = existing_passkeys
62
47
.iter()
···
81
66
.user_repo
82
67
.save_webauthn_challenge(&auth.did, "registration", &state_json)
83
68
.await
84
-
.map_err(|e| {
85
-
error!("Failed to save registration state: {:?}", e);
86
-
ApiError::InternalError(None)
87
-
})?;
69
+
.log_db_err("saving registration state")?;
88
70
89
71
let options = serde_json::to_value(&ccr).unwrap_or(serde_json::json!({}));
90
72
···
112
94
auth: Auth<Active>,
113
95
Json(input): Json<FinishRegistrationInput>,
114
96
) -> Result<Response, ApiError> {
115
-
let webauthn = get_webauthn()?;
97
+
let webauthn = &state.webauthn_config;
116
98
117
99
let reg_state_json = state
118
100
.user_repo
119
101
.load_webauthn_challenge(&auth.did, "registration")
120
102
.await
121
-
.map_err(|e| {
122
-
error!("DB error loading registration state: {:?}", e);
123
-
ApiError::InternalError(None)
124
-
})?
103
+
.log_db_err("loading registration state")?
125
104
.ok_or(ApiError::NoRegistrationInProgress)?;
126
105
127
106
let reg_state: SecurityKeyRegistration =
···
157
136
input.friendly_name.as_deref(),
158
137
)
159
138
.await
160
-
.map_err(|e| {
161
-
error!("Failed to save passkey: {:?}", e);
162
-
ApiError::InternalError(None)
163
-
})?;
139
+
.log_db_err("saving passkey")?;
164
140
165
141
if let Err(e) = state
166
142
.user_repo
···
208
184
.user_repo
209
185
.get_passkeys_for_user(&auth.did)
210
186
.await
211
-
.map_err(|e| {
212
-
error!("DB error fetching passkeys: {:?}", e);
213
-
ApiError::InternalError(None)
214
-
})?;
187
+
.log_db_err("fetching passkeys")?;
215
188
216
189
let passkey_infos: Vec<PasskeyInfo> = passkeys
217
190
.into_iter()
···
241
214
auth: Auth<Active>,
242
215
Json(input): Json<DeletePasskeyInput>,
243
216
) -> Result<Response, ApiError> {
244
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
245
-
{
246
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
247
-
&*state.user_repo,
248
-
&*state.session_repo,
249
-
&auth.did,
250
-
)
251
-
.await);
252
-
}
217
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
218
+
Ok(proof) => proof,
219
+
Err(response) => return Ok(response),
220
+
};
253
221
254
-
if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await {
255
-
return Ok(crate::api::server::reauth::reauth_required_response(
256
-
&*state.user_repo,
257
-
&*state.session_repo,
258
-
&auth.did,
259
-
)
260
-
.await);
261
-
}
222
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
223
+
Ok(proof) => proof,
224
+
Err(response) => return Ok(response),
225
+
};
262
226
263
227
let id: uuid::Uuid = input.id.parse().map_err(|_| ApiError::InvalidId)?;
264
228
265
-
match state.user_repo.delete_passkey(id, &auth.did).await {
229
+
match state.user_repo.delete_passkey(id, reauth_mfa.did()).await {
266
230
Ok(true) => {
267
-
info!(did = %auth.did, passkey_id = %id, "Passkey deleted");
231
+
info!(did = %session_mfa.did(), passkey_id = %id, "Passkey deleted");
268
232
Ok(EmptyResponse::ok().into_response())
269
233
}
270
234
Ok(false) => Err(ApiError::PasskeyNotFound),
+62
-164
crates/tranquil-pds/src/api/server/password.rs
+62
-164
crates/tranquil-pds/src/api/server/password.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, HasPasswordResponse, SuccessResponse};
3
-
use crate::auth::{Active, Auth};
4
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::auth::{
4
+
Active, Auth, require_legacy_session_mfa, require_reauth_window,
5
+
require_reauth_window_if_available,
6
+
};
7
+
use crate::rate_limit::{PasswordResetLimit, RateLimited, ResetPasswordLimit};
8
+
use crate::state::AppState;
5
9
use crate::types::PlainPassword;
10
+
use crate::util::{pds_hostname, pds_hostname_without_port};
6
11
use crate::validation::validate_password;
7
12
use axum::{
8
13
Json,
9
14
extract::State,
10
-
http::HeaderMap,
11
15
response::{IntoResponse, Response},
12
16
};
13
-
use bcrypt::{DEFAULT_COST, hash, verify};
17
+
use bcrypt::{DEFAULT_COST, hash};
14
18
use chrono::{Duration, Utc};
15
19
use serde::Deserialize;
16
20
use tracing::{error, info, warn};
···
18
22
fn generate_reset_code() -> String {
19
23
crate::util::generate_token_code()
20
24
}
21
-
fn extract_client_ip(headers: &HeaderMap) -> String {
22
-
if let Some(forwarded) = headers.get("x-forwarded-for")
23
-
&& let Ok(value) = forwarded.to_str()
24
-
&& let Some(first_ip) = value.split(',').next()
25
-
{
26
-
return first_ip.trim().to_string();
27
-
}
28
-
if let Some(real_ip) = headers.get("x-real-ip")
29
-
&& let Ok(value) = real_ip.to_str()
30
-
{
31
-
return value.trim().to_string();
32
-
}
33
-
"unknown".to_string()
34
-
}
35
25
36
26
#[derive(Deserialize)]
37
27
pub struct RequestPasswordResetInput {
···
41
31
42
32
pub async fn request_password_reset(
43
33
State(state): State<AppState>,
44
-
headers: HeaderMap,
34
+
_rate_limit: RateLimited<PasswordResetLimit>,
45
35
Json(input): Json<RequestPasswordResetInput>,
46
36
) -> Response {
47
-
let client_ip = extract_client_ip(&headers);
48
-
if !state
49
-
.check_rate_limit(RateLimitKind::PasswordReset, &client_ip)
50
-
.await
51
-
{
52
-
warn!(ip = %client_ip, "Password reset rate limit exceeded");
53
-
return ApiError::RateLimitExceeded(None).into_response();
54
-
}
55
37
let identifier = input.email.trim();
56
38
if identifier.is_empty() {
57
39
return ApiError::InvalidRequest("email or handle is required".into()).into_response();
58
40
}
59
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
60
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
41
+
let hostname_for_handles = pds_hostname_without_port();
61
42
let normalized = identifier.to_lowercase();
62
43
let normalized = normalized.strip_prefix('@').unwrap_or(&normalized);
63
44
let is_email_lookup = normalized.contains('@');
···
101
82
error!("DB error setting reset code: {:?}", e);
102
83
return ApiError::InternalError(None).into_response();
103
84
}
104
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
85
+
let hostname = pds_hostname();
105
86
if let Err(e) = crate::comms::comms_repo::enqueue_password_reset(
106
87
state.user_repo.as_ref(),
107
88
state.infra_repo.as_ref(),
108
89
user_id,
109
90
&code,
110
-
&hostname,
91
+
hostname,
111
92
)
112
93
.await
113
94
{
···
135
116
136
117
pub async fn reset_password(
137
118
State(state): State<AppState>,
138
-
headers: HeaderMap,
119
+
_rate_limit: RateLimited<ResetPasswordLimit>,
139
120
Json(input): Json<ResetPasswordInput>,
140
121
) -> Response {
141
-
let client_ip = extract_client_ip(&headers);
142
-
if !state
143
-
.check_rate_limit(RateLimitKind::ResetPassword, &client_ip)
144
-
.await
145
-
{
146
-
warn!(ip = %client_ip, "Reset password rate limit exceeded");
147
-
return ApiError::RateLimitExceeded(None).into_response();
148
-
}
149
122
let token = input.token.trim();
150
123
let password = &input.password;
151
124
if token.is_empty() {
···
230
203
auth: Auth<Active>,
231
204
Json(input): Json<ChangePasswordInput>,
232
205
) -> Result<Response, ApiError> {
233
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
234
-
{
235
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
236
-
&*state.user_repo,
237
-
&*state.session_repo,
238
-
&auth.did,
239
-
)
240
-
.await);
241
-
}
206
+
use crate::auth::verify_password_mfa;
242
207
243
-
let current_password = &input.current_password;
244
-
let new_password = &input.new_password;
245
-
if current_password.is_empty() {
208
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
209
+
Ok(proof) => proof,
210
+
Err(response) => return Ok(response),
211
+
};
212
+
213
+
if input.current_password.is_empty() {
246
214
return Err(ApiError::InvalidRequest(
247
215
"currentPassword is required".into(),
248
216
));
249
217
}
250
-
if new_password.is_empty() {
218
+
if input.new_password.is_empty() {
251
219
return Err(ApiError::InvalidRequest("newPassword is required".into()));
252
220
}
253
-
if let Err(e) = validate_password(new_password) {
221
+
if let Err(e) = validate_password(&input.new_password) {
254
222
return Err(ApiError::InvalidRequest(e.to_string()));
255
223
}
224
+
225
+
let password_mfa = verify_password_mfa(&state, &auth, &input.current_password).await?;
226
+
256
227
let user = state
257
228
.user_repo
258
-
.get_id_and_password_hash_by_did(&auth.did)
229
+
.get_id_and_password_hash_by_did(password_mfa.did())
259
230
.await
260
-
.map_err(|e| {
261
-
error!("DB error in change_password: {:?}", e);
262
-
ApiError::InternalError(None)
263
-
})?
231
+
.log_db_err("in change_password")?
264
232
.ok_or(ApiError::AccountNotFound)?;
265
233
266
-
let (user_id, password_hash) = (user.id, user.password_hash);
267
-
let valid = verify(current_password, &password_hash).map_err(|e| {
268
-
error!("Password verification error: {:?}", e);
269
-
ApiError::InternalError(None)
270
-
})?;
271
-
if !valid {
272
-
return Err(ApiError::InvalidPassword(
273
-
"Current password is incorrect".into(),
274
-
));
275
-
}
276
-
let new_password_clone = new_password.to_string();
234
+
let new_password_clone = input.new_password.to_string();
277
235
let new_hash = tokio::task::spawn_blocking(move || hash(new_password_clone, DEFAULT_COST))
278
236
.await
279
237
.map_err(|e| {
···
287
245
288
246
state
289
247
.user_repo
290
-
.update_password_hash(user_id, &new_hash)
248
+
.update_password_hash(user.id, &new_hash)
291
249
.await
292
-
.map_err(|e| {
293
-
error!("DB error updating password: {:?}", e);
294
-
ApiError::InternalError(None)
295
-
})?;
250
+
.log_db_err("updating password")?;
296
251
297
-
info!(did = %&auth.did, "Password changed successfully");
252
+
info!(did = %session_mfa.did(), "Password changed successfully");
298
253
Ok(EmptyResponse::ok().into_response())
299
254
}
300
255
···
302
257
State(state): State<AppState>,
303
258
auth: Auth<Active>,
304
259
) -> Result<Response, ApiError> {
305
-
match state.user_repo.has_password_by_did(&auth.did).await {
306
-
Ok(Some(has)) => Ok(HasPasswordResponse::response(has).into_response()),
307
-
Ok(None) => Err(ApiError::AccountNotFound),
308
-
Err(e) => {
309
-
error!("DB error: {:?}", e);
310
-
Err(ApiError::InternalError(None))
311
-
}
312
-
}
260
+
let has = state
261
+
.user_repo
262
+
.has_password_by_did(&auth.did)
263
+
.await
264
+
.log_db_err("checking password status")?
265
+
.ok_or(ApiError::AccountNotFound)?;
266
+
Ok(HasPasswordResponse::response(has).into_response())
313
267
}
314
268
315
269
pub async fn remove_password(
316
270
State(state): State<AppState>,
317
271
auth: Auth<Active>,
318
272
) -> Result<Response, ApiError> {
319
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
320
-
{
321
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
322
-
&*state.user_repo,
323
-
&*state.session_repo,
324
-
&auth.did,
325
-
)
326
-
.await);
327
-
}
273
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
274
+
Ok(proof) => proof,
275
+
Err(response) => return Ok(response),
276
+
};
328
277
329
-
if crate::api::server::reauth::check_reauth_required_cached(
330
-
&*state.session_repo,
331
-
&state.cache,
332
-
&auth.did,
333
-
)
334
-
.await
335
-
{
336
-
return Ok(crate::api::server::reauth::reauth_required_response(
337
-
&*state.user_repo,
338
-
&*state.session_repo,
339
-
&auth.did,
340
-
)
341
-
.await);
342
-
}
278
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
279
+
Ok(proof) => proof,
280
+
Err(response) => return Ok(response),
281
+
};
343
282
344
283
let has_passkeys = state
345
284
.user_repo
346
-
.has_passkeys(&auth.did)
285
+
.has_passkeys(reauth_mfa.did())
347
286
.await
348
287
.unwrap_or(false);
349
288
if !has_passkeys {
···
354
293
355
294
let user = state
356
295
.user_repo
357
-
.get_password_info_by_did(&auth.did)
296
+
.get_password_info_by_did(reauth_mfa.did())
358
297
.await
359
-
.map_err(|e| {
360
-
error!("DB error: {:?}", e);
361
-
ApiError::InternalError(None)
362
-
})?
298
+
.log_db_err("getting password info")?
363
299
.ok_or(ApiError::AccountNotFound)?;
364
300
365
301
if user.password_hash.is_none() {
···
372
308
.user_repo
373
309
.remove_user_password(user.id)
374
310
.await
375
-
.map_err(|e| {
376
-
error!("DB error removing password: {:?}", e);
377
-
ApiError::InternalError(None)
378
-
})?;
311
+
.log_db_err("removing password")?;
379
312
380
-
info!(did = %&auth.did, "Password removed - account is now passkey-only");
313
+
info!(did = %session_mfa.did(), "Password removed - account is now passkey-only");
381
314
Ok(SuccessResponse::ok().into_response())
382
315
}
383
316
···
392
325
auth: Auth<Active>,
393
326
Json(input): Json<SetPasswordInput>,
394
327
) -> Result<Response, ApiError> {
395
-
let has_password = state
396
-
.user_repo
397
-
.has_password_by_did(&auth.did)
398
-
.await
399
-
.ok()
400
-
.flatten()
401
-
.unwrap_or(false);
402
-
let has_passkeys = state
403
-
.user_repo
404
-
.has_passkeys(&auth.did)
405
-
.await
406
-
.unwrap_or(false);
407
-
let has_totp = state
408
-
.user_repo
409
-
.has_totp_enabled(&auth.did)
410
-
.await
411
-
.unwrap_or(false);
412
-
413
-
let has_any_reauth_method = has_password || has_passkeys || has_totp;
414
-
415
-
if has_any_reauth_method
416
-
&& crate::api::server::reauth::check_reauth_required_cached(
417
-
&*state.session_repo,
418
-
&state.cache,
419
-
&auth.did,
420
-
)
421
-
.await
422
-
{
423
-
return Ok(crate::api::server::reauth::reauth_required_response(
424
-
&*state.user_repo,
425
-
&*state.session_repo,
426
-
&auth.did,
427
-
)
428
-
.await);
429
-
}
328
+
let reauth_mfa = match require_reauth_window_if_available(&state, &auth).await {
329
+
Ok(proof) => proof,
330
+
Err(response) => return Ok(response),
331
+
};
430
332
431
333
let new_password = &input.new_password;
432
334
if new_password.is_empty() {
···
436
338
return Err(ApiError::InvalidRequest(e.to_string()));
437
339
}
438
340
341
+
let did = reauth_mfa.as_ref().map(|m| m.did()).unwrap_or(&auth.did);
342
+
439
343
let user = state
440
344
.user_repo
441
-
.get_password_info_by_did(&auth.did)
345
+
.get_password_info_by_did(did)
442
346
.await
443
-
.map_err(|e| {
444
-
error!("DB error: {:?}", e);
445
-
ApiError::InternalError(None)
446
-
})?
347
+
.log_db_err("getting password info")?
447
348
.ok_or(ApiError::AccountNotFound)?;
448
349
449
350
if user.password_hash.is_some() {
···
468
369
.user_repo
469
370
.set_new_user_password(user.id, &new_hash)
470
371
.await
471
-
.map_err(|e| {
472
-
error!("DB error setting password: {:?}", e);
473
-
ApiError::InternalError(None)
474
-
})?;
372
+
.log_db_err("setting password")?;
475
373
476
-
info!(did = %&auth.did, "Password set for passkey-only account");
374
+
info!(did = %did, "Password set for passkey-only account");
477
375
Ok(SuccessResponse::ok().into_response())
478
376
}
+21
-58
crates/tranquil-pds/src/api/server/reauth.rs
+21
-58
crates/tranquil-pds/src/api/server/reauth.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use axum::{
3
3
Json,
4
4
extract::State,
···
11
11
use tranquil_db_traits::{SessionRepository, UserRepository};
12
12
13
13
use crate::auth::{Active, Auth};
14
-
use crate::state::{AppState, RateLimitKind};
14
+
use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message};
15
+
use crate::state::AppState;
15
16
use crate::types::PlainPassword;
16
17
17
-
const REAUTH_WINDOW_SECONDS: i64 = 300;
18
+
pub const REAUTH_WINDOW_SECONDS: i64 = 300;
18
19
19
20
#[derive(Serialize)]
20
21
#[serde(rename_all = "camelCase")]
···
32
33
.session_repo
33
34
.get_last_reauth_at(&auth.did)
34
35
.await
35
-
.map_err(|e| {
36
-
error!("DB error: {:?}", e);
37
-
ApiError::InternalError(None)
38
-
})?;
36
+
.log_db_err("getting last reauth")?;
39
37
40
38
let reauth_required = is_reauth_required(last_reauth_at);
41
39
let available_methods =
···
70
68
.user_repo
71
69
.get_password_hash_by_did(&auth.did)
72
70
.await
73
-
.map_err(|e| {
74
-
error!("DB error: {:?}", e);
75
-
ApiError::InternalError(None)
76
-
})?
71
+
.log_db_err("fetching password hash")?
77
72
.ok_or(ApiError::AccountNotFound)?;
78
73
79
74
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
···
97
92
98
93
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
99
94
.await
100
-
.map_err(|e| {
101
-
error!("DB error updating reauth: {:?}", e);
102
-
ApiError::InternalError(None)
103
-
})?;
95
+
.log_db_err("updating reauth")?;
104
96
105
97
info!(did = %&auth.did, "Re-auth successful via password");
106
98
Ok(Json(ReauthResponse { reauthed_at }).into_response())
···
117
109
auth: Auth<Active>,
118
110
Json(input): Json<TotpReauthInput>,
119
111
) -> Result<Response, ApiError> {
120
-
if !state
121
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
122
-
.await
123
-
{
124
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
125
-
return Err(ApiError::RateLimitExceeded(Some(
126
-
"Too many verification attempts. Please try again in a few minutes.".into(),
127
-
)));
128
-
}
112
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
113
+
&state,
114
+
&auth.did,
115
+
"Too many verification attempts. Please try again in a few minutes.",
116
+
)
117
+
.await?;
129
118
130
119
let valid =
131
120
crate::api::server::totp::verify_totp_or_backup_for_user(&state, &auth.did, &input.code)
···
140
129
141
130
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
142
131
.await
143
-
.map_err(|e| {
144
-
error!("DB error updating reauth: {:?}", e);
145
-
ApiError::InternalError(None)
146
-
})?;
132
+
.log_db_err("updating reauth")?;
147
133
148
134
info!(did = %&auth.did, "Re-auth successful via TOTP");
149
135
Ok(Json(ReauthResponse { reauthed_at }).into_response())
···
159
145
State(state): State<AppState>,
160
146
auth: Auth<Active>,
161
147
) -> Result<Response, ApiError> {
162
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
163
-
164
148
let stored_passkeys = state
165
149
.user_repo
166
150
.get_passkeys_for_user(&auth.did)
167
151
.await
168
-
.map_err(|e| {
169
-
error!("Failed to get passkeys: {:?}", e);
170
-
ApiError::InternalError(None)
171
-
})?;
152
+
.log_db_err("getting passkeys")?;
172
153
173
154
if stored_passkeys.is_empty() {
174
155
return Err(ApiError::NoPasskeys);
···
185
166
)));
186
167
}
187
168
188
-
let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| {
189
-
error!("Failed to create WebAuthn config: {:?}", e);
190
-
ApiError::InternalError(None)
191
-
})?;
169
+
let webauthn = &state.webauthn_config;
192
170
193
171
let (rcr, auth_state) = webauthn.start_authentication(passkeys).map_err(|e| {
194
172
error!("Failed to start passkey authentication: {:?}", e);
···
204
182
.user_repo
205
183
.save_webauthn_challenge(&auth.did, "authentication", &state_json)
206
184
.await
207
-
.map_err(|e| {
208
-
error!("Failed to save authentication state: {:?}", e);
209
-
ApiError::InternalError(None)
210
-
})?;
185
+
.log_db_err("saving authentication state")?;
211
186
212
187
let options = serde_json::to_value(&rcr).unwrap_or(serde_json::json!({}));
213
188
Ok(Json(PasskeyReauthStartResponse { options }).into_response())
···
224
199
auth: Auth<Active>,
225
200
Json(input): Json<PasskeyReauthFinishInput>,
226
201
) -> Result<Response, ApiError> {
227
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
228
-
229
202
let auth_state_json = state
230
203
.user_repo
231
204
.load_webauthn_challenge(&auth.did, "authentication")
232
205
.await
233
-
.map_err(|e| {
234
-
error!("Failed to load authentication state: {:?}", e);
235
-
ApiError::InternalError(None)
236
-
})?
206
+
.log_db_err("loading authentication state")?
237
207
.ok_or(ApiError::NoChallengeInProgress)?;
238
208
239
209
let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication =
···
248
218
ApiError::InvalidCredential
249
219
})?;
250
220
251
-
let webauthn = crate::auth::webauthn::WebAuthnConfig::new(&pds_hostname).map_err(|e| {
252
-
error!("Failed to create WebAuthn config: {:?}", e);
253
-
ApiError::InternalError(None)
254
-
})?;
255
-
256
-
let auth_result = webauthn
221
+
let auth_result = state
222
+
.webauthn_config
257
223
.finish_authentication(&credential, &auth_state)
258
224
.map_err(|e| {
259
225
warn!(did = %&auth.did, "Passkey re-auth failed: {:?}", e);
···
287
253
288
254
let reauthed_at = update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.did)
289
255
.await
290
-
.map_err(|e| {
291
-
error!("DB error updating reauth: {:?}", e);
292
-
ApiError::InternalError(None)
293
-
})?;
256
+
.log_db_err("updating reauth")?;
294
257
295
258
info!(did = %&auth.did, "Re-auth successful via passkey");
296
259
Ok(Json(ReauthResponse { reauthed_at }).into_response())
+2
-2
crates/tranquil-pds/src/api/server/service_auth.rs
+2
-2
crates/tranquil-pds/src/api/server/service_auth.rs
···
51
51
headers: axum::http::HeaderMap,
52
52
Query(params): Query<GetServiceAuthParams>,
53
53
) -> Response {
54
-
let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
55
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
54
+
let auth_header = crate::util::get_header_str(&headers, "Authorization");
55
+
let dpop_proof = crate::util::get_header_str(&headers, "DPoP");
56
56
info!(
57
57
has_auth_header = auth_header.is_some(),
58
58
has_dpop_proof = dpop_proof.is_some(),
+46
-120
crates/tranquil-pds/src/api/server/session.rs
+46
-120
crates/tranquil-pds/src/api/server/session.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::api::{EmptyResponse, SuccessResponse};
3
-
use crate::auth::{Active, Auth, Permissive};
4
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::auth::{Active, Auth, Permissive, require_legacy_session_mfa, require_reauth_window};
4
+
use crate::rate_limit::{LoginLimit, RateLimited, RefreshSessionLimit};
5
+
use crate::state::AppState;
5
6
use crate::types::{AccountState, Did, Handle, PlainPassword};
7
+
use crate::util::{pds_hostname, pds_hostname_without_port};
6
8
use axum::{
7
9
Json,
8
10
extract::State,
···
15
17
use tracing::{error, info, warn};
16
18
use tranquil_types::TokenId;
17
19
18
-
fn extract_client_ip(headers: &HeaderMap) -> String {
19
-
if let Some(forwarded) = headers.get("x-forwarded-for")
20
-
&& let Ok(value) = forwarded.to_str()
21
-
&& let Some(first_ip) = value.split(',').next()
22
-
{
23
-
return first_ip.trim().to_string();
24
-
}
25
-
if let Some(real_ip) = headers.get("x-real-ip")
26
-
&& let Ok(value) = real_ip.to_str()
27
-
{
28
-
return value.trim().to_string();
29
-
}
30
-
"unknown".to_string()
31
-
}
32
-
33
20
fn normalize_handle(identifier: &str, pds_hostname: &str) -> String {
34
21
let identifier = identifier.trim();
35
22
if identifier.contains('@') || identifier.starts_with("did:") {
···
75
62
76
63
pub async fn create_session(
77
64
State(state): State<AppState>,
78
-
headers: HeaderMap,
65
+
rate_limit: RateLimited<LoginLimit>,
79
66
Json(input): Json<CreateSessionInput>,
80
67
) -> Response {
68
+
let client_ip = rate_limit.client_ip();
81
69
info!(
82
70
"create_session called with identifier: {}",
83
71
input.identifier
84
72
);
85
-
let client_ip = extract_client_ip(&headers);
86
-
if !state
87
-
.check_rate_limit(RateLimitKind::Login, &client_ip)
88
-
.await
89
-
{
90
-
warn!(ip = %client_ip, "Login rate limit exceeded");
91
-
return ApiError::RateLimitExceeded(None).into_response();
92
-
}
93
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
94
-
let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname);
73
+
let pds_host = pds_hostname();
74
+
let hostname_for_handles = pds_hostname_without_port();
95
75
let normalized_identifier = normalize_handle(&input.identifier, hostname_for_handles);
96
76
info!(
97
77
"Normalized identifier: {} -> {}",
···
246
226
ip = %client_ip,
247
227
"Legacy login on TOTP-enabled account - sending notification"
248
228
);
249
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
229
+
let hostname = pds_hostname();
250
230
if let Err(e) = crate::comms::comms_repo::enqueue_legacy_login(
251
231
state.user_repo.as_ref(),
252
232
state.infra_repo.as_ref(),
253
233
row.id,
254
-
&hostname,
255
-
&client_ip,
234
+
hostname,
235
+
client_ip,
256
236
row.preferred_comms_channel,
257
237
)
258
238
.await
···
260
240
error!("Failed to queue legacy login notification: {:?}", e);
261
241
}
262
242
}
263
-
let handle = full_handle(&row.handle, &pds_hostname);
243
+
let handle = full_handle(&row.handle, pds_host);
264
244
let is_active = account_state.is_active();
265
245
let status = account_state.status_for_session().map(String::from);
266
246
Json(CreateSessionOutput {
···
299
279
tranquil_db_traits::CommsChannel::Telegram => ("telegram", row.telegram_verified),
300
280
tranquil_db_traits::CommsChannel::Signal => ("signal", row.signal_verified),
301
281
};
302
-
let pds_hostname =
303
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
304
-
let handle = full_handle(&row.handle, &pds_hostname);
282
+
let pds_hostname = pds_hostname();
283
+
let handle = full_handle(&row.handle, pds_hostname);
305
284
let account_state = AccountState::from_db_fields(
306
285
row.deactivated_at,
307
286
row.takedown_ref.clone(),
···
353
332
_auth: Auth<Active>,
354
333
) -> Result<Response, ApiError> {
355
334
let extracted = crate::auth::extract_auth_token_from_header(
356
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
335
+
crate::util::get_header_str(&headers, "Authorization"),
357
336
)
358
337
.ok_or(ApiError::AuthenticationRequired)?;
359
338
let jti = crate::auth::get_jti_from_token(&extracted.token)
···
374
353
375
354
pub async fn refresh_session(
376
355
State(state): State<AppState>,
356
+
_rate_limit: RateLimited<RefreshSessionLimit>,
377
357
headers: axum::http::HeaderMap,
378
358
) -> Response {
379
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
380
-
if !state
381
-
.check_rate_limit(RateLimitKind::RefreshSession, &client_ip)
382
-
.await
383
-
{
384
-
tracing::warn!(ip = %client_ip, "Refresh session rate limit exceeded");
385
-
return ApiError::RateLimitExceeded(None).into_response();
386
-
}
387
359
let extracted = match crate::auth::extract_auth_token_from_header(
388
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
360
+
crate::util::get_header_str(&headers, "Authorization"),
389
361
) {
390
362
Some(t) => t,
391
363
None => return ApiError::AuthenticationRequired.into_response(),
···
509
481
tranquil_db_traits::CommsChannel::Telegram => ("telegram", u.telegram_verified),
510
482
tranquil_db_traits::CommsChannel::Signal => ("signal", u.signal_verified),
511
483
};
512
-
let pds_hostname =
513
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
514
-
let handle = full_handle(&u.handle, &pds_hostname);
484
+
let pds_hostname = pds_hostname();
485
+
let handle = full_handle(&u.handle, pds_hostname);
515
486
let account_state =
516
487
AccountState::from_db_fields(u.deactivated_at, u.takedown_ref.clone(), None, None);
517
488
let mut response = json!({
···
675
646
return ApiError::InternalError(None).into_response();
676
647
}
677
648
678
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
649
+
let hostname = pds_hostname();
679
650
if let Err(e) = crate::comms::comms_repo::enqueue_welcome(
680
651
state.user_repo.as_ref(),
681
652
state.infra_repo.as_ref(),
682
653
row.id,
683
-
&hostname,
654
+
hostname,
684
655
)
685
656
.await
686
657
{
···
756
727
let formatted_token =
757
728
crate::auth::verification_token::format_token_for_display(&verification_token);
758
729
759
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
730
+
let hostname = pds_hostname();
760
731
if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification(
761
732
state.infra_repo.as_ref(),
762
733
row.id,
763
734
channel_str,
764
735
&recipient,
765
736
&formatted_token,
766
-
&hostname,
737
+
hostname,
767
738
)
768
739
.await
769
740
{
···
804
775
.session_repo
805
776
.list_sessions_by_did(&auth.did)
806
777
.await
807
-
.map_err(|e| {
808
-
error!("DB error fetching JWT sessions: {:?}", e);
809
-
ApiError::InternalError(None)
810
-
})?;
778
+
.log_db_err("fetching JWT sessions")?;
811
779
812
780
let oauth_rows = state
813
781
.oauth_repo
814
782
.list_sessions_by_did(&auth.did)
815
783
.await
816
-
.map_err(|e| {
817
-
error!("DB error fetching OAuth sessions: {:?}", e);
818
-
ApiError::InternalError(None)
819
-
})?;
784
+
.log_db_err("fetching OAuth sessions")?;
820
785
821
786
let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo {
822
787
id: format!("jwt:{}", row.id),
···
876
841
.session_repo
877
842
.get_session_access_jti_by_id(session_id, &auth.did)
878
843
.await
879
-
.map_err(|e| {
880
-
error!("DB error in revoke_session: {:?}", e);
881
-
ApiError::InternalError(None)
882
-
})?
844
+
.log_db_err("in revoke_session")?
883
845
.ok_or(ApiError::SessionNotFound)?;
884
846
state
885
847
.session_repo
886
848
.delete_session_by_id(session_id)
887
849
.await
888
-
.map_err(|e| {
889
-
error!("DB error deleting session: {:?}", e);
890
-
ApiError::InternalError(None)
891
-
})?;
850
+
.log_db_err("deleting session")?;
892
851
let cache_key = format!("auth:session:{}:{}", &auth.did, access_jti);
893
852
if let Err(e) = state.cache.delete(&cache_key).await {
894
853
warn!("Failed to invalidate session cache: {:?}", e);
···
902
861
.oauth_repo
903
862
.delete_session_by_id(session_id, &auth.did)
904
863
.await
905
-
.map_err(|e| {
906
-
error!("DB error deleting OAuth session: {:?}", e);
907
-
ApiError::InternalError(None)
908
-
})?;
864
+
.log_db_err("deleting OAuth session")?;
909
865
if deleted == 0 {
910
866
return Err(ApiError::SessionNotFound);
911
867
}
···
932
888
.session_repo
933
889
.delete_sessions_by_did(&auth.did)
934
890
.await
935
-
.map_err(|e| {
936
-
error!("DB error revoking JWT sessions: {:?}", e);
937
-
ApiError::InternalError(None)
938
-
})?;
891
+
.log_db_err("revoking JWT sessions")?;
939
892
let jti_typed = TokenId::from(jti.clone());
940
893
state
941
894
.oauth_repo
942
895
.delete_sessions_by_did_except(&auth.did, &jti_typed)
943
896
.await
944
-
.map_err(|e| {
945
-
error!("DB error revoking OAuth sessions: {:?}", e);
946
-
ApiError::InternalError(None)
947
-
})?;
897
+
.log_db_err("revoking OAuth sessions")?;
948
898
} else {
949
899
state
950
900
.session_repo
951
901
.delete_sessions_by_did_except_jti(&auth.did, &jti)
952
902
.await
953
-
.map_err(|e| {
954
-
error!("DB error revoking JWT sessions: {:?}", e);
955
-
ApiError::InternalError(None)
956
-
})?;
903
+
.log_db_err("revoking JWT sessions")?;
957
904
state
958
905
.oauth_repo
959
906
.delete_sessions_by_did(&auth.did)
960
907
.await
961
-
.map_err(|e| {
962
-
error!("DB error revoking OAuth sessions: {:?}", e);
963
-
ApiError::InternalError(None)
964
-
})?;
908
+
.log_db_err("revoking OAuth sessions")?;
965
909
}
966
910
967
911
info!(did = %&auth.did, "All other sessions revoked");
···
983
927
.user_repo
984
928
.get_legacy_login_pref(&auth.did)
985
929
.await
986
-
.map_err(|e| {
987
-
error!("DB error: {:?}", e);
988
-
ApiError::InternalError(None)
989
-
})?
930
+
.log_db_err("getting legacy login pref")?
990
931
.ok_or(ApiError::AccountNotFound)?;
991
932
Ok(Json(LegacyLoginPreferenceOutput {
992
933
allow_legacy_login: pref.allow_legacy_login,
···
1006
947
auth: Auth<Active>,
1007
948
Json(input): Json<UpdateLegacyLoginInput>,
1008
949
) -> Result<Response, ApiError> {
1009
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
1010
-
{
1011
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
1012
-
&*state.user_repo,
1013
-
&*state.session_repo,
1014
-
&auth.did,
1015
-
)
1016
-
.await);
1017
-
}
950
+
let session_mfa = match require_legacy_session_mfa(&state, &auth).await {
951
+
Ok(proof) => proof,
952
+
Err(response) => return Ok(response),
953
+
};
1018
954
1019
-
if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.did).await {
1020
-
return Ok(crate::api::server::reauth::reauth_required_response(
1021
-
&*state.user_repo,
1022
-
&*state.session_repo,
1023
-
&auth.did,
1024
-
)
1025
-
.await);
1026
-
}
955
+
let reauth_mfa = match require_reauth_window(&state, &auth).await {
956
+
Ok(proof) => proof,
957
+
Err(response) => return Ok(response),
958
+
};
1027
959
1028
960
let updated = state
1029
961
.user_repo
1030
-
.update_legacy_login(&auth.did, input.allow_legacy_login)
962
+
.update_legacy_login(reauth_mfa.did(), input.allow_legacy_login)
1031
963
.await
1032
-
.map_err(|e| {
1033
-
error!("DB error: {:?}", e);
1034
-
ApiError::InternalError(None)
1035
-
})?;
964
+
.log_db_err("updating legacy login")?;
1036
965
if !updated {
1037
966
return Err(ApiError::AccountNotFound);
1038
967
}
1039
968
info!(
1040
-
did = %&auth.did,
969
+
did = %session_mfa.did(),
1041
970
allow_legacy_login = input.allow_legacy_login,
1042
971
"Legacy login preference updated"
1043
972
);
···
1071
1000
.user_repo
1072
1001
.update_locale(&auth.did, &input.preferred_locale)
1073
1002
.await
1074
-
.map_err(|e| {
1075
-
error!("DB error updating locale: {:?}", e);
1076
-
ApiError::InternalError(None)
1077
-
})?;
1003
+
.log_db_err("updating locale")?;
1078
1004
if !updated {
1079
1005
return Err(ApiError::AccountNotFound);
1080
1006
}
+45
-148
crates/tranquil-pds/src/api/server/totp.rs
+45
-148
crates/tranquil-pds/src/api/server/totp.rs
···
1
1
use crate::api::EmptyResponse;
2
-
use crate::api::error::ApiError;
3
-
use crate::auth::{Active, Auth};
2
+
use crate::api::error::{ApiError, DbResultExt};
4
3
use crate::auth::{
5
-
decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes, generate_qr_png_base64,
6
-
generate_totp_secret, generate_totp_uri, hash_backup_code, is_backup_code_format,
7
-
verify_backup_code, verify_totp_code,
4
+
Active, Auth, decrypt_totp_secret, encrypt_totp_secret, generate_backup_codes,
5
+
generate_qr_png_base64, generate_totp_secret, generate_totp_uri, hash_backup_code,
6
+
is_backup_code_format, require_legacy_session_mfa, verify_backup_code, verify_password_mfa,
7
+
verify_totp_code, verify_totp_mfa,
8
8
};
9
-
use crate::state::{AppState, RateLimitKind};
9
+
use crate::rate_limit::{TotpVerifyLimit, check_user_rate_limit_with_message};
10
+
use crate::state::AppState;
10
11
use crate::types::PlainPassword;
12
+
use crate::util::pds_hostname;
11
13
use axum::{
12
14
Json,
13
15
extract::State,
···
45
47
.user_repo
46
48
.get_handle_by_did(&auth.did)
47
49
.await
48
-
.map_err(|e| {
49
-
error!("DB error fetching handle: {:?}", e);
50
-
ApiError::InternalError(None)
51
-
})?
50
+
.log_db_err("fetching handle")?
52
51
.ok_or(ApiError::AccountNotFound)?;
53
52
54
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
55
-
let uri = generate_totp_uri(&secret, &handle, &hostname);
53
+
let hostname = pds_hostname();
54
+
let uri = generate_totp_uri(&secret, &handle, hostname);
56
55
57
-
let qr_code = generate_qr_png_base64(&secret, &handle, &hostname).map_err(|e| {
56
+
let qr_code = generate_qr_png_base64(&secret, &handle, hostname).map_err(|e| {
58
57
error!("Failed to generate QR code: {:?}", e);
59
58
ApiError::InternalError(Some("Failed to generate QR code".into()))
60
59
})?;
···
68
67
.user_repo
69
68
.upsert_totp_secret(&auth.did, &encrypted_secret, ENCRYPTION_VERSION)
70
69
.await
71
-
.map_err(|e| {
72
-
error!("Failed to store TOTP secret: {:?}", e);
73
-
ApiError::InternalError(None)
74
-
})?;
70
+
.log_db_err("storing TOTP secret")?;
75
71
76
72
let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret);
77
73
···
101
97
auth: Auth<Active>,
102
98
Json(input): Json<EnableTotpInput>,
103
99
) -> Result<Response, ApiError> {
104
-
if !state
105
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
106
-
.await
107
-
{
108
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
109
-
return Err(ApiError::RateLimitExceeded(None));
110
-
}
100
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
101
+
&state,
102
+
&auth.did,
103
+
"Too many verification attempts. Please try again in a few minutes.",
104
+
)
105
+
.await?;
111
106
112
107
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
113
108
Ok(Some(row)) => row,
···
152
147
.user_repo
153
148
.enable_totp_with_backup_codes(&auth.did, &backup_hashes)
154
149
.await
155
-
.map_err(|e| {
156
-
error!("Failed to enable TOTP: {:?}", e);
157
-
ApiError::InternalError(None)
158
-
})?;
150
+
.log_db_err("enabling TOTP")?;
159
151
160
152
info!(did = %&auth.did, "TOTP enabled with {} backup codes", backup_codes.len());
161
153
···
173
165
auth: Auth<Active>,
174
166
Json(input): Json<DisableTotpInput>,
175
167
) -> Result<Response, ApiError> {
176
-
if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.did).await
177
-
{
178
-
return Ok(crate::api::server::reauth::legacy_mfa_required_response(
179
-
&*state.user_repo,
180
-
&*state.session_repo,
181
-
&auth.did,
182
-
)
183
-
.await);
184
-
}
185
-
186
-
if !state
187
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
188
-
.await
189
-
{
190
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
191
-
return Err(ApiError::RateLimitExceeded(None));
192
-
}
193
-
194
-
let password_hash = state
195
-
.user_repo
196
-
.get_password_hash_by_did(&auth.did)
197
-
.await
198
-
.map_err(|e| {
199
-
error!("DB error fetching user: {:?}", e);
200
-
ApiError::InternalError(None)
201
-
})?
202
-
.ok_or(ApiError::AccountNotFound)?;
203
-
204
-
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
205
-
if !password_valid {
206
-
return Err(ApiError::InvalidPassword("Password is incorrect".into()));
207
-
}
208
-
209
-
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
210
-
Ok(Some(row)) if row.verified => row,
211
-
Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled),
212
-
Err(e) => {
213
-
error!("DB error fetching TOTP: {:?}", e);
214
-
return Err(ApiError::InternalError(None));
215
-
}
168
+
let _session_mfa = match require_legacy_session_mfa(&state, &auth).await {
169
+
Ok(proof) => proof,
170
+
Err(response) => return Ok(response),
216
171
};
217
172
218
-
let code = input.code.trim();
219
-
let code_valid = if is_backup_code_format(code) {
220
-
verify_backup_code_for_user(&state, &auth.did, code).await
221
-
} else {
222
-
let secret = decrypt_totp_secret(
223
-
&totp_record.secret_encrypted,
224
-
totp_record.encryption_version,
225
-
)
226
-
.map_err(|e| {
227
-
error!("Failed to decrypt TOTP secret: {:?}", e);
228
-
ApiError::InternalError(None)
229
-
})?;
230
-
verify_totp_code(&secret, code)
231
-
};
173
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
174
+
&state,
175
+
&auth.did,
176
+
"Too many verification attempts. Please try again in a few minutes.",
177
+
)
178
+
.await?;
232
179
233
-
if !code_valid {
234
-
return Err(ApiError::InvalidCode(Some(
235
-
"Invalid verification code".into(),
236
-
)));
237
-
}
180
+
let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?;
181
+
let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?;
238
182
239
183
state
240
184
.user_repo
241
-
.delete_totp_and_backup_codes(&auth.did)
185
+
.delete_totp_and_backup_codes(totp_mfa.did())
242
186
.await
243
-
.map_err(|e| {
244
-
error!("Failed to delete TOTP: {:?}", e);
245
-
ApiError::InternalError(None)
246
-
})?;
187
+
.log_db_err("deleting TOTP")?;
247
188
248
-
info!(did = %&auth.did, "TOTP disabled");
189
+
info!(did = %password_mfa.did(), "TOTP disabled (verified via {} and {})", password_mfa.method(), totp_mfa.method());
249
190
250
191
Ok(EmptyResponse::ok().into_response())
251
192
}
···
275
216
.user_repo
276
217
.count_unused_backup_codes(&auth.did)
277
218
.await
278
-
.map_err(|e| {
279
-
error!("DB error counting backup codes: {:?}", e);
280
-
ApiError::InternalError(None)
281
-
})?;
219
+
.log_db_err("counting backup codes")?;
282
220
283
221
Ok(Json(GetTotpStatusResponse {
284
222
enabled,
···
305
243
auth: Auth<Active>,
306
244
Json(input): Json<RegenerateBackupCodesInput>,
307
245
) -> Result<Response, ApiError> {
308
-
if !state
309
-
.check_rate_limit(RateLimitKind::TotpVerify, &auth.did)
310
-
.await
311
-
{
312
-
warn!(did = %&auth.did, "TOTP verification rate limit exceeded");
313
-
return Err(ApiError::RateLimitExceeded(None));
314
-
}
315
-
316
-
let password_hash = state
317
-
.user_repo
318
-
.get_password_hash_by_did(&auth.did)
319
-
.await
320
-
.map_err(|e| {
321
-
error!("DB error fetching user: {:?}", e);
322
-
ApiError::InternalError(None)
323
-
})?
324
-
.ok_or(ApiError::AccountNotFound)?;
325
-
326
-
let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false);
327
-
if !password_valid {
328
-
return Err(ApiError::InvalidPassword("Password is incorrect".into()));
329
-
}
330
-
331
-
let totp_record = match state.user_repo.get_totp_record(&auth.did).await {
332
-
Ok(Some(row)) if row.verified => row,
333
-
Ok(Some(_)) | Ok(None) => return Err(ApiError::TotpNotEnabled),
334
-
Err(e) => {
335
-
error!("DB error fetching TOTP: {:?}", e);
336
-
return Err(ApiError::InternalError(None));
337
-
}
338
-
};
339
-
340
-
let secret = decrypt_totp_secret(
341
-
&totp_record.secret_encrypted,
342
-
totp_record.encryption_version,
246
+
let _rate_limit = check_user_rate_limit_with_message::<TotpVerifyLimit>(
247
+
&state,
248
+
&auth.did,
249
+
"Too many verification attempts. Please try again in a few minutes.",
343
250
)
344
-
.map_err(|e| {
345
-
error!("Failed to decrypt TOTP secret: {:?}", e);
346
-
ApiError::InternalError(None)
347
-
})?;
251
+
.await?;
348
252
349
-
let code = input.code.trim();
350
-
if !verify_totp_code(&secret, code) {
351
-
return Err(ApiError::InvalidCode(Some(
352
-
"Invalid verification code".into(),
353
-
)));
354
-
}
253
+
let password_mfa = verify_password_mfa(&state, &auth, &input.password).await?;
254
+
let totp_mfa = verify_totp_mfa(&state, &auth, &input.code).await?;
355
255
356
256
let backup_codes = generate_backup_codes();
357
257
let backup_hashes: Vec<_> = backup_codes
···
365
265
366
266
state
367
267
.user_repo
368
-
.replace_backup_codes(&auth.did, &backup_hashes)
268
+
.replace_backup_codes(totp_mfa.did(), &backup_hashes)
369
269
.await
370
-
.map_err(|e| {
371
-
error!("Failed to regenerate backup codes: {:?}", e);
372
-
ApiError::InternalError(None)
373
-
})?;
270
+
.log_db_err("replacing backup codes")?;
374
271
375
-
info!(did = %&auth.did, "Backup codes regenerated");
272
+
info!(did = %password_mfa.did(), "Backup codes regenerated (verified via {} and {})", password_mfa.method(), totp_mfa.method());
376
273
377
274
Ok(Json(RegenerateBackupCodesResponse { backup_codes }).into_response())
378
275
}
+4
-13
crates/tranquil-pds/src/api/server/trusted_devices.rs
+4
-13
crates/tranquil-pds/src/api/server/trusted_devices.rs
···
1
1
use crate::api::SuccessResponse;
2
-
use crate::api::error::ApiError;
2
+
use crate::api::error::{ApiError, DbResultExt};
3
3
use axum::{
4
4
Json,
5
5
extract::State,
···
79
79
.oauth_repo
80
80
.list_trusted_devices(&auth.did)
81
81
.await
82
-
.map_err(|e| {
83
-
error!("DB error: {:?}", e);
84
-
ApiError::InternalError(None)
85
-
})?;
82
+
.log_db_err("listing trusted devices")?;
86
83
87
84
let devices = rows
88
85
.into_iter()
···
134
131
.oauth_repo
135
132
.revoke_device_trust(&device_id)
136
133
.await
137
-
.map_err(|e| {
138
-
error!("DB error: {:?}", e);
139
-
ApiError::InternalError(None)
140
-
})?;
134
+
.log_db_err("revoking device trust")?;
141
135
142
136
info!(did = %&auth.did, device_id = %input.device_id, "Trusted device revoked");
143
137
Ok(SuccessResponse::ok().into_response())
···
175
169
.oauth_repo
176
170
.update_device_friendly_name(&device_id, input.friendly_name.as_deref())
177
171
.await
178
-
.map_err(|e| {
179
-
error!("DB error: {:?}", e);
180
-
ApiError::InternalError(None)
181
-
})?;
172
+
.log_db_err("updating device friendly name")?;
182
173
183
174
info!(did = %auth.did, device_id = %input.device_id, "Trusted device updated");
184
175
Ok(SuccessResponse::ok().into_response())
+3
-2
crates/tranquil-pds/src/api/server/verify_email.rs
+3
-2
crates/tranquil-pds/src/api/server/verify_email.rs
···
5
5
use tracing::{info, warn};
6
6
7
7
use crate::state::AppState;
8
+
use crate::util::pds_hostname;
8
9
9
10
#[derive(Deserialize)]
10
11
#[serde(rename_all = "camelCase")]
···
70
71
return Ok(Json(ResendMigrationVerificationOutput { sent: true }));
71
72
}
72
73
73
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
74
+
let hostname = pds_hostname();
74
75
let token = crate::auth::verification_token::generate_migration_token(&user.did, &email);
75
76
let formatted_token = crate::auth::verification_token::format_token_for_display(&token);
76
77
···
80
81
user.id,
81
82
&email,
82
83
&formatted_token,
83
-
&hostname,
84
+
hostname,
84
85
)
85
86
.await
86
87
{
+14
-47
crates/tranquil-pds/src/api/server/verify_token.rs
+14
-47
crates/tranquil-pds/src/api/server/verify_token.rs
···
1
-
use crate::api::error::ApiError;
1
+
use crate::api::error::{ApiError, DbResultExt};
2
2
use crate::types::Did;
3
3
use axum::{Json, extract::State};
4
4
use serde::{Deserialize, Serialize};
5
-
use tracing::{error, info, warn};
5
+
use tracing::{info, warn};
6
6
7
7
use crate::auth::verification_token::{
8
8
VerificationPurpose, normalize_token_input, verify_token_signature,
···
81
81
.user_repo
82
82
.get_verification_info(&did_typed)
83
83
.await
84
-
.map_err(|e| {
85
-
warn!(error = ?e, "Database error during migration verification");
86
-
ApiError::InternalError(None)
87
-
})?
84
+
.log_db_err("during migration verification")?
88
85
.ok_or(ApiError::AccountNotFound)?;
89
86
90
87
if user.email.as_ref().map(|e| e.to_lowercase()) != Some(identifier.to_string()) {
···
96
93
.user_repo
97
94
.set_email_verified_flag(user.id)
98
95
.await
99
-
.map_err(|e| {
100
-
warn!(error = ?e, "Failed to update email_verified status");
101
-
ApiError::InternalError(None)
102
-
})?;
96
+
.log_db_err("updating email_verified status")?;
103
97
}
104
98
105
99
info!(did = %did, "Migration email verified successfully");
···
125
119
.user_repo
126
120
.get_id_by_did(&did_typed)
127
121
.await
128
-
.map_err(|_| ApiError::InternalError(None))?
122
+
.log_db_err("fetching user id")?
129
123
.ok_or(ApiError::AccountNotFound)?;
130
124
131
125
match channel {
···
134
128
.user_repo
135
129
.verify_email_channel(user_id, identifier)
136
130
.await
137
-
.map_err(|e| {
138
-
error!("Failed to update email channel: {:?}", e);
139
-
ApiError::InternalError(None)
140
-
})?;
131
+
.log_db_err("updating email channel")?;
141
132
if !success {
142
133
return Err(ApiError::EmailTaken);
143
134
}
···
147
138
.user_repo
148
139
.verify_discord_channel(user_id, identifier)
149
140
.await
150
-
.map_err(|e| {
151
-
error!("Failed to update discord channel: {:?}", e);
152
-
ApiError::InternalError(None)
153
-
})?;
141
+
.log_db_err("updating discord channel")?;
154
142
}
155
143
"telegram" => {
156
144
state
157
145
.user_repo
158
146
.verify_telegram_channel(user_id, identifier)
159
147
.await
160
-
.map_err(|e| {
161
-
error!("Failed to update telegram channel: {:?}", e);
162
-
ApiError::InternalError(None)
163
-
})?;
148
+
.log_db_err("updating telegram channel")?;
164
149
}
165
150
"signal" => {
166
151
state
167
152
.user_repo
168
153
.verify_signal_channel(user_id, identifier)
169
154
.await
170
-
.map_err(|e| {
171
-
error!("Failed to update signal channel: {:?}", e);
172
-
ApiError::InternalError(None)
173
-
})?;
155
+
.log_db_err("updating signal channel")?;
174
156
}
175
157
_ => {
176
158
return Err(ApiError::InvalidChannel);
···
200
182
.user_repo
201
183
.get_verification_info(&did_typed)
202
184
.await
203
-
.map_err(|e| {
204
-
warn!(error = ?e, "Database error during signup verification");
205
-
ApiError::InternalError(None)
206
-
})?
185
+
.log_db_err("during signup verification")?
207
186
.ok_or(ApiError::AccountNotFound)?;
208
187
209
188
let is_verified = user.email_verified
···
226
205
.user_repo
227
206
.set_email_verified_flag(user.id)
228
207
.await
229
-
.map_err(|e| {
230
-
warn!(error = ?e, "Failed to update email verified status");
231
-
ApiError::InternalError(None)
232
-
})?;
208
+
.log_db_err("updating email verified status")?;
233
209
}
234
210
"discord" => {
235
211
state
236
212
.user_repo
237
213
.set_discord_verified_flag(user.id)
238
214
.await
239
-
.map_err(|e| {
240
-
warn!(error = ?e, "Failed to update discord verified status");
241
-
ApiError::InternalError(None)
242
-
})?;
215
+
.log_db_err("updating discord verified status")?;
243
216
}
244
217
"telegram" => {
245
218
state
246
219
.user_repo
247
220
.set_telegram_verified_flag(user.id)
248
221
.await
249
-
.map_err(|e| {
250
-
warn!(error = ?e, "Failed to update telegram verified status");
251
-
ApiError::InternalError(None)
252
-
})?;
222
+
.log_db_err("updating telegram verified status")?;
253
223
}
254
224
"signal" => {
255
225
state
256
226
.user_repo
257
227
.set_signal_verified_flag(user.id)
258
228
.await
259
-
.map_err(|e| {
260
-
warn!(error = ?e, "Failed to update signal verified status");
261
-
ApiError::InternalError(None)
262
-
})?;
229
+
.log_db_err("updating signal verified status")?;
263
230
}
264
231
_ => {
265
232
return Err(ApiError::InvalidChannel);
+22
-18
crates/tranquil-pds/src/auth/extractor.rs
+22
-18
crates/tranquil-pds/src/auth/extractor.rs
···
9
9
10
10
use super::{
11
11
AccountStatus, AuthSource, AuthenticatedUser, ServiceTokenClaims, ServiceTokenVerifier,
12
-
is_service_token, validate_bearer_token_for_service_auth,
12
+
is_service_token, scope_verified::VerifyScope, validate_bearer_token_for_service_auth,
13
13
};
14
14
use crate::api::error::ApiError;
15
15
use crate::oauth::scopes::{RepoAction, ScopePermissions};
···
293
293
return Ok(ExtractedAuth::Service(claims));
294
294
}
295
295
296
-
let dpop_proof = parts.headers.get("DPoP").and_then(|h| h.to_str().ok());
296
+
let dpop_proof = crate::util::get_header_str(&parts.headers, "DPoP");
297
297
let method = parts.method.as_str();
298
298
let uri = build_full_url(&parts.uri.to_string());
299
299
···
358
358
}
359
359
}
360
360
361
+
impl<P: AuthPolicy> AsRef<AuthenticatedUser> for Auth<P> {
362
+
fn as_ref(&self) -> &AuthenticatedUser {
363
+
&self.0
364
+
}
365
+
}
366
+
367
+
impl<P: AuthPolicy> VerifyScope for Auth<P> {
368
+
fn needs_scope_check(&self) -> bool {
369
+
self.0.is_oauth()
370
+
}
371
+
372
+
fn permissions(&self) -> ScopePermissions {
373
+
self.0.permissions()
374
+
}
375
+
}
376
+
361
377
impl<P: AuthPolicy> FromRequestParts<AppState> for Auth<P> {
362
378
type Rejection = AuthError;
363
379
···
418
434
) -> Result<Self, Self::Rejection> {
419
435
match extract_auth_internal(parts, state).await? {
420
436
ExtractedAuth::Service(claims) => {
421
-
let did: Did = claims
422
-
.iss
423
-
.parse()
424
-
.map_err(|_| AuthError::AuthenticationFailed)?;
437
+
let did = claims.iss.clone();
425
438
Ok(ServiceAuth { did, claims })
426
439
}
427
440
ExtractedAuth::User(_) => Err(AuthError::AuthenticationFailed),
···
438
451
) -> Result<Option<Self>, Self::Rejection> {
439
452
match extract_auth_internal(parts, state).await {
440
453
Ok(ExtractedAuth::Service(claims)) => {
441
-
let did: Did = claims
442
-
.iss
443
-
.parse()
444
-
.map_err(|_| AuthError::AuthenticationFailed)?;
454
+
let did = claims.iss.clone();
445
455
Ok(Some(ServiceAuth { did, claims }))
446
456
}
447
457
Ok(ExtractedAuth::User(_)) => Err(AuthError::AuthenticationFailed),
···
503
513
Ok(AuthAny::User(Auth(user, PhantomData)))
504
514
}
505
515
ExtractedAuth::Service(claims) => {
506
-
let did: Did = claims
507
-
.iss
508
-
.parse()
509
-
.map_err(|_| AuthError::AuthenticationFailed)?;
516
+
let did = claims.iss.clone();
510
517
Ok(AuthAny::Service(ServiceAuth { did, claims }))
511
518
}
512
519
}
···
526
533
Ok(Some(AuthAny::User(Auth(user, PhantomData))))
527
534
}
528
535
Ok(ExtractedAuth::Service(claims)) => {
529
-
let did: Did = claims
530
-
.iss
531
-
.parse()
532
-
.map_err(|_| AuthError::AuthenticationFailed)?;
536
+
let did = claims.iss.clone();
533
537
Ok(Some(AuthAny::Service(ServiceAuth { did, claims })))
534
538
}
535
539
Err(AuthError::MissingToken) => Ok(None),
+223
crates/tranquil-pds/src/auth/mfa_verified.rs
+223
crates/tranquil-pds/src/auth/mfa_verified.rs
···
1
+
use axum::response::Response;
2
+
3
+
use super::AuthenticatedUser;
4
+
use crate::state::AppState;
5
+
use crate::types::Did;
6
+
7
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8
+
pub enum MfaMethod {
9
+
Totp,
10
+
Passkey,
11
+
Password,
12
+
RecoveryCode,
13
+
SessionReauth,
14
+
}
15
+
16
+
impl MfaMethod {
17
+
pub fn as_str(&self) -> &'static str {
18
+
match self {
19
+
Self::Totp => "totp",
20
+
Self::Passkey => "passkey",
21
+
Self::Password => "password",
22
+
Self::RecoveryCode => "recovery_code",
23
+
Self::SessionReauth => "session_reauth",
24
+
}
25
+
}
26
+
}
27
+
28
+
impl std::fmt::Display for MfaMethod {
29
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30
+
write!(f, "{}", self.as_str())
31
+
}
32
+
}
33
+
34
+
pub struct MfaVerified<'a> {
35
+
user: &'a AuthenticatedUser,
36
+
method: MfaMethod,
37
+
}
38
+
39
+
impl<'a> MfaVerified<'a> {
40
+
fn new(user: &'a AuthenticatedUser, method: MfaMethod) -> Self {
41
+
Self { user, method }
42
+
}
43
+
44
+
pub(crate) fn from_totp(user: &'a AuthenticatedUser) -> Self {
45
+
Self::new(user, MfaMethod::Totp)
46
+
}
47
+
48
+
pub(crate) fn from_password(user: &'a AuthenticatedUser) -> Self {
49
+
Self::new(user, MfaMethod::Password)
50
+
}
51
+
52
+
pub(crate) fn from_recovery_code(user: &'a AuthenticatedUser) -> Self {
53
+
Self::new(user, MfaMethod::RecoveryCode)
54
+
}
55
+
56
+
pub(crate) fn from_session_reauth(user: &'a AuthenticatedUser) -> Self {
57
+
Self::new(user, MfaMethod::SessionReauth)
58
+
}
59
+
60
+
pub fn user(&self) -> &AuthenticatedUser {
61
+
self.user
62
+
}
63
+
64
+
pub fn did(&self) -> &Did {
65
+
&self.user.did
66
+
}
67
+
68
+
pub fn method(&self) -> MfaMethod {
69
+
self.method
70
+
}
71
+
}
72
+
73
+
pub async fn require_legacy_session_mfa<'a>(
74
+
state: &AppState,
75
+
user: &'a AuthenticatedUser,
76
+
) -> Result<MfaVerified<'a>, Response> {
77
+
use crate::api::server::reauth::{check_legacy_session_mfa, legacy_mfa_required_response};
78
+
79
+
if check_legacy_session_mfa(&*state.session_repo, &user.did).await {
80
+
Ok(MfaVerified::from_session_reauth(user))
81
+
} else {
82
+
Err(legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
83
+
}
84
+
}
85
+
86
+
pub async fn require_reauth_window<'a>(
87
+
state: &AppState,
88
+
user: &'a AuthenticatedUser,
89
+
) -> Result<MfaVerified<'a>, Response> {
90
+
use chrono::Utc;
91
+
use crate::api::server::reauth::{REAUTH_WINDOW_SECONDS, reauth_required_response};
92
+
93
+
let status = state.session_repo.get_session_mfa_status(&user.did).await.ok().flatten();
94
+
95
+
match status {
96
+
Some(s) => {
97
+
if let Some(last_reauth) = s.last_reauth_at {
98
+
let elapsed = Utc::now().signed_duration_since(last_reauth);
99
+
if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS {
100
+
return Ok(MfaVerified::from_session_reauth(user));
101
+
}
102
+
}
103
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
104
+
}
105
+
None => {
106
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
107
+
}
108
+
}
109
+
}
110
+
111
+
pub async fn require_reauth_window_if_available<'a>(
112
+
state: &AppState,
113
+
user: &'a AuthenticatedUser,
114
+
) -> Result<Option<MfaVerified<'a>>, Response> {
115
+
use crate::api::server::reauth::{check_reauth_required_cached, reauth_required_response};
116
+
117
+
let has_password = state
118
+
.user_repo
119
+
.has_password_by_did(&user.did)
120
+
.await
121
+
.ok()
122
+
.flatten()
123
+
.unwrap_or(false);
124
+
let has_passkeys = state
125
+
.user_repo
126
+
.has_passkeys(&user.did)
127
+
.await
128
+
.unwrap_or(false);
129
+
let has_totp = state
130
+
.user_repo
131
+
.has_totp_enabled(&user.did)
132
+
.await
133
+
.unwrap_or(false);
134
+
135
+
let has_any_reauth_method = has_password || has_passkeys || has_totp;
136
+
137
+
if !has_any_reauth_method {
138
+
return Ok(None);
139
+
}
140
+
141
+
if check_reauth_required_cached(&*state.session_repo, &state.cache, &user.did).await {
142
+
Err(reauth_required_response(&*state.user_repo, &*state.session_repo, &user.did).await)
143
+
} else {
144
+
Ok(Some(MfaVerified::from_session_reauth(user)))
145
+
}
146
+
}
147
+
148
+
pub async fn verify_password_mfa<'a>(
149
+
state: &AppState,
150
+
user: &'a AuthenticatedUser,
151
+
password: &str,
152
+
) -> Result<MfaVerified<'a>, crate::api::error::ApiError> {
153
+
let hash = state
154
+
.user_repo
155
+
.get_password_hash_by_did(&user.did)
156
+
.await
157
+
.ok()
158
+
.flatten();
159
+
160
+
match hash {
161
+
Some(h) => {
162
+
if bcrypt::verify(password, &h).unwrap_or(false) {
163
+
Ok(MfaVerified::from_password(user))
164
+
} else {
165
+
Err(crate::api::error::ApiError::InvalidPassword(
166
+
"Password is incorrect".into(),
167
+
))
168
+
}
169
+
}
170
+
None => Err(crate::api::error::ApiError::AccountNotFound),
171
+
}
172
+
}
173
+
174
+
pub async fn verify_totp_mfa<'a>(
175
+
state: &AppState,
176
+
user: &'a AuthenticatedUser,
177
+
code: &str,
178
+
) -> Result<MfaVerified<'a>, crate::api::error::ApiError> {
179
+
use crate::auth::{decrypt_totp_secret, is_backup_code_format, verify_totp_code};
180
+
181
+
let code = code.trim();
182
+
183
+
if is_backup_code_format(code) {
184
+
let backup_codes = state.user_repo.get_unused_backup_codes(&user.did).await.ok().unwrap_or_default();
185
+
let code_upper = code.to_uppercase();
186
+
187
+
let matched = backup_codes
188
+
.iter()
189
+
.find(|row| crate::auth::verify_backup_code(&code_upper, &row.code_hash));
190
+
191
+
return match matched {
192
+
Some(row) => {
193
+
let _ = state.user_repo.mark_backup_code_used(row.id).await;
194
+
Ok(MfaVerified::from_recovery_code(user))
195
+
}
196
+
None => Err(crate::api::error::ApiError::InvalidCode(Some(
197
+
"Invalid backup code".into(),
198
+
))),
199
+
};
200
+
}
201
+
202
+
let totp_record = match state.user_repo.get_totp_record(&user.did).await {
203
+
Ok(Some(row)) if row.verified => row,
204
+
_ => {
205
+
return Err(crate::api::error::ApiError::TotpNotEnabled);
206
+
}
207
+
};
208
+
209
+
let secret = decrypt_totp_secret(
210
+
&totp_record.secret_encrypted,
211
+
totp_record.encryption_version,
212
+
)
213
+
.map_err(|_| crate::api::error::ApiError::InternalError(None))?;
214
+
215
+
if verify_totp_code(&secret, code) {
216
+
let _ = state.user_repo.update_totp_last_used(&user.did).await;
217
+
Ok(MfaVerified::from_totp(user))
218
+
} else {
219
+
Err(crate::api::error::ApiError::InvalidCode(Some(
220
+
"Invalid verification code".into(),
221
+
)))
222
+
}
223
+
}
+10
crates/tranquil-pds/src/auth/mod.rs
+10
crates/tranquil-pds/src/auth/mod.rs
···
11
11
use tranquil_db_traits::OAuthRepository;
12
12
13
13
pub mod extractor;
14
+
pub mod mfa_verified;
14
15
pub mod scope_check;
16
+
pub mod scope_verified;
15
17
pub mod service;
16
18
pub mod verification_token;
17
19
pub mod webauthn;
···
20
22
Active, Admin, AnyUser, Auth, AuthAny, AuthError, AuthPolicy, ExtractedToken, NotTakendown,
21
23
Permissive, ServiceAuth, extract_auth_token_from_header, extract_bearer_token_from_header,
22
24
};
25
+
pub use mfa_verified::{
26
+
MfaMethod, MfaVerified, require_legacy_session_mfa, require_reauth_window,
27
+
require_reauth_window_if_available, verify_password_mfa, verify_totp_mfa,
28
+
};
29
+
pub use scope_verified::{
30
+
AccountManage, AccountRead, BlobUpload, IdentityAccess, RepoCreate, RepoDelete, RepoUpdate,
31
+
RpcCall, ScopeAction, ScopeVerificationError, ScopeVerified, VerifyScope,
32
+
};
23
33
pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
24
34
25
35
pub use tranquil_auth::{
+277
crates/tranquil-pds/src/auth/scope_verified.rs
+277
crates/tranquil-pds/src/auth/scope_verified.rs
···
1
+
use std::marker::PhantomData;
2
+
3
+
use axum::response::{IntoResponse, Response};
4
+
5
+
use crate::api::error::ApiError;
6
+
use crate::oauth::scopes::{AccountAction, AccountAttr, IdentityAttr, RepoAction, ScopePermissions};
7
+
8
+
use super::AuthenticatedUser;
9
+
10
+
#[derive(Debug)]
11
+
pub struct ScopeVerificationError {
12
+
message: String,
13
+
}
14
+
15
+
impl ScopeVerificationError {
16
+
pub fn new(message: impl Into<String>) -> Self {
17
+
Self {
18
+
message: message.into(),
19
+
}
20
+
}
21
+
22
+
pub fn message(&self) -> &str {
23
+
&self.message
24
+
}
25
+
}
26
+
27
+
impl std::fmt::Display for ScopeVerificationError {
28
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29
+
write!(f, "{}", self.message)
30
+
}
31
+
}
32
+
33
+
impl std::error::Error for ScopeVerificationError {}
34
+
35
+
impl IntoResponse for ScopeVerificationError {
36
+
fn into_response(self) -> Response {
37
+
ApiError::InsufficientScope(Some(self.message)).into_response()
38
+
}
39
+
}
40
+
41
+
mod private {
42
+
pub trait Sealed {}
43
+
}
44
+
45
+
pub trait ScopeAction: private::Sealed {}
46
+
47
+
pub struct RepoCreate;
48
+
pub struct RepoUpdate;
49
+
pub struct RepoDelete;
50
+
pub struct BlobUpload;
51
+
pub struct RpcCall;
52
+
pub struct AccountRead;
53
+
pub struct AccountManage;
54
+
pub struct IdentityAccess;
55
+
56
+
impl private::Sealed for RepoCreate {}
57
+
impl private::Sealed for RepoUpdate {}
58
+
impl private::Sealed for RepoDelete {}
59
+
impl private::Sealed for BlobUpload {}
60
+
impl private::Sealed for RpcCall {}
61
+
impl private::Sealed for AccountRead {}
62
+
impl private::Sealed for AccountManage {}
63
+
impl private::Sealed for IdentityAccess {}
64
+
65
+
impl ScopeAction for RepoCreate {}
66
+
impl ScopeAction for RepoUpdate {}
67
+
impl ScopeAction for RepoDelete {}
68
+
impl ScopeAction for BlobUpload {}
69
+
impl ScopeAction for RpcCall {}
70
+
impl ScopeAction for AccountRead {}
71
+
impl ScopeAction for AccountManage {}
72
+
impl ScopeAction for IdentityAccess {}
73
+
74
+
pub struct ScopeVerified<'a, A: ScopeAction> {
75
+
user: &'a AuthenticatedUser,
76
+
_action: PhantomData<A>,
77
+
}
78
+
79
+
impl<'a, A: ScopeAction> ScopeVerified<'a, A> {
80
+
pub fn user(&self) -> &AuthenticatedUser {
81
+
self.user
82
+
}
83
+
84
+
pub fn did(&self) -> &crate::types::Did {
85
+
&self.user.did
86
+
}
87
+
88
+
pub fn is_admin(&self) -> bool {
89
+
self.user.is_admin
90
+
}
91
+
92
+
pub fn controller_did(&self) -> Option<&crate::types::Did> {
93
+
self.user.controller_did.as_ref()
94
+
}
95
+
}
96
+
97
+
pub trait VerifyScope {
98
+
fn needs_scope_check(&self) -> bool;
99
+
fn permissions(&self) -> ScopePermissions;
100
+
101
+
fn verify_repo_create<'a>(
102
+
&'a self,
103
+
collection: &str,
104
+
) -> Result<ScopeVerified<'a, RepoCreate>, ScopeVerificationError>
105
+
where
106
+
Self: AsRef<AuthenticatedUser>,
107
+
{
108
+
if !self.needs_scope_check() {
109
+
return Ok(ScopeVerified {
110
+
user: self.as_ref(),
111
+
_action: PhantomData,
112
+
});
113
+
}
114
+
self.permissions()
115
+
.assert_repo(RepoAction::Create, collection)
116
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
117
+
Ok(ScopeVerified {
118
+
user: self.as_ref(),
119
+
_action: PhantomData,
120
+
})
121
+
}
122
+
123
+
fn verify_repo_update<'a>(
124
+
&'a self,
125
+
collection: &str,
126
+
) -> Result<ScopeVerified<'a, RepoUpdate>, ScopeVerificationError>
127
+
where
128
+
Self: AsRef<AuthenticatedUser>,
129
+
{
130
+
if !self.needs_scope_check() {
131
+
return Ok(ScopeVerified {
132
+
user: self.as_ref(),
133
+
_action: PhantomData,
134
+
});
135
+
}
136
+
self.permissions()
137
+
.assert_repo(RepoAction::Update, collection)
138
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
139
+
Ok(ScopeVerified {
140
+
user: self.as_ref(),
141
+
_action: PhantomData,
142
+
})
143
+
}
144
+
145
+
fn verify_repo_delete<'a>(
146
+
&'a self,
147
+
collection: &str,
148
+
) -> Result<ScopeVerified<'a, RepoDelete>, ScopeVerificationError>
149
+
where
150
+
Self: AsRef<AuthenticatedUser>,
151
+
{
152
+
if !self.needs_scope_check() {
153
+
return Ok(ScopeVerified {
154
+
user: self.as_ref(),
155
+
_action: PhantomData,
156
+
});
157
+
}
158
+
self.permissions()
159
+
.assert_repo(RepoAction::Delete, collection)
160
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
161
+
Ok(ScopeVerified {
162
+
user: self.as_ref(),
163
+
_action: PhantomData,
164
+
})
165
+
}
166
+
167
+
fn verify_blob_upload<'a>(
168
+
&'a self,
169
+
mime_type: &str,
170
+
) -> Result<ScopeVerified<'a, BlobUpload>, ScopeVerificationError>
171
+
where
172
+
Self: AsRef<AuthenticatedUser>,
173
+
{
174
+
if !self.needs_scope_check() {
175
+
return Ok(ScopeVerified {
176
+
user: self.as_ref(),
177
+
_action: PhantomData,
178
+
});
179
+
}
180
+
self.permissions()
181
+
.assert_blob(mime_type)
182
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
183
+
Ok(ScopeVerified {
184
+
user: self.as_ref(),
185
+
_action: PhantomData,
186
+
})
187
+
}
188
+
189
+
fn verify_rpc<'a>(
190
+
&'a self,
191
+
aud: &str,
192
+
lxm: &str,
193
+
) -> Result<ScopeVerified<'a, RpcCall>, ScopeVerificationError>
194
+
where
195
+
Self: AsRef<AuthenticatedUser>,
196
+
{
197
+
if !self.needs_scope_check() {
198
+
return Ok(ScopeVerified {
199
+
user: self.as_ref(),
200
+
_action: PhantomData,
201
+
});
202
+
}
203
+
self.permissions()
204
+
.assert_rpc(aud, lxm)
205
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
206
+
Ok(ScopeVerified {
207
+
user: self.as_ref(),
208
+
_action: PhantomData,
209
+
})
210
+
}
211
+
212
+
fn verify_account_read<'a>(
213
+
&'a self,
214
+
attr: AccountAttr,
215
+
) -> Result<ScopeVerified<'a, AccountRead>, ScopeVerificationError>
216
+
where
217
+
Self: AsRef<AuthenticatedUser>,
218
+
{
219
+
if !self.needs_scope_check() {
220
+
return Ok(ScopeVerified {
221
+
user: self.as_ref(),
222
+
_action: PhantomData,
223
+
});
224
+
}
225
+
self.permissions()
226
+
.assert_account(attr, AccountAction::Read)
227
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
228
+
Ok(ScopeVerified {
229
+
user: self.as_ref(),
230
+
_action: PhantomData,
231
+
})
232
+
}
233
+
234
+
fn verify_account_manage<'a>(
235
+
&'a self,
236
+
attr: AccountAttr,
237
+
) -> Result<ScopeVerified<'a, AccountManage>, ScopeVerificationError>
238
+
where
239
+
Self: AsRef<AuthenticatedUser>,
240
+
{
241
+
if !self.needs_scope_check() {
242
+
return Ok(ScopeVerified {
243
+
user: self.as_ref(),
244
+
_action: PhantomData,
245
+
});
246
+
}
247
+
self.permissions()
248
+
.assert_account(attr, AccountAction::Manage)
249
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
250
+
Ok(ScopeVerified {
251
+
user: self.as_ref(),
252
+
_action: PhantomData,
253
+
})
254
+
}
255
+
256
+
fn verify_identity<'a>(
257
+
&'a self,
258
+
attr: IdentityAttr,
259
+
) -> Result<ScopeVerified<'a, IdentityAccess>, ScopeVerificationError>
260
+
where
261
+
Self: AsRef<AuthenticatedUser>,
262
+
{
263
+
if !self.needs_scope_check() {
264
+
return Ok(ScopeVerified {
265
+
user: self.as_ref(),
266
+
_action: PhantomData,
267
+
});
268
+
}
269
+
self.permissions()
270
+
.assert_identity(attr)
271
+
.map_err(|e| ScopeVerificationError::new(e.to_string()))?;
272
+
Ok(ScopeVerified {
273
+
user: self.as_ref(),
274
+
_action: PhantomData,
275
+
})
276
+
}
277
+
}
+10
-8
crates/tranquil-pds/src/auth/service.rs
+10
-8
crates/tranquil-pds/src/auth/service.rs
···
1
+
use crate::types::Did;
2
+
use crate::util::pds_hostname;
1
3
use anyhow::{Result, anyhow};
2
4
use base64::Engine as _;
3
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
···
42
44
43
45
#[derive(Debug, Clone, Serialize, Deserialize)]
44
46
pub struct ServiceTokenClaims {
45
-
pub iss: String,
47
+
pub iss: Did,
46
48
#[serde(default)]
47
-
pub sub: Option<String>,
48
-
pub aud: String,
49
+
pub sub: Option<Did>,
50
+
pub aud: Did,
49
51
pub exp: usize,
50
52
#[serde(default)]
51
53
pub iat: Option<usize>,
···
56
58
}
57
59
58
60
impl ServiceTokenClaims {
59
-
pub fn subject(&self) -> &str {
60
-
self.sub.as_deref().unwrap_or(&self.iss)
61
+
pub fn subject(&self) -> &Did {
62
+
self.sub.as_ref().unwrap_or(&self.iss)
61
63
}
62
64
}
63
65
···
79
81
.unwrap_or_else(|_| "https://plc.directory".to_string());
80
82
81
83
let pds_hostname =
82
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
84
+
pds_hostname();
83
85
let pds_did = format!("did:web:{}", pds_hostname);
84
86
85
87
let client = Client::builder()
···
130
132
return Err(anyhow!("Token expired"));
131
133
}
132
134
133
-
if claims.aud != self.pds_did {
135
+
if claims.aud.as_str() != self.pds_did {
134
136
return Err(anyhow!(
135
137
"Invalid audience: expected {}, got {}",
136
138
self.pds_did,
···
154
156
}
155
157
}
156
158
157
-
let did = &claims.iss;
159
+
let did = claims.iss.as_str();
158
160
let public_key = self.resolve_signing_key(did).await?;
159
161
160
162
let signature_bytes = URL_SAFE_NO_PAD
+1
-1
crates/tranquil-pds/src/auth/webauthn.rs
+1
-1
crates/tranquil-pds/src/auth/webauthn.rs
···
7
7
8
8
impl WebAuthnConfig {
9
9
pub fn new(hostname: &str) -> Result<Self, String> {
10
-
let rp_id = hostname.to_string();
10
+
let rp_id = hostname.split(':').next().unwrap_or(hostname).to_string();
11
11
let rp_origin = Url::parse(&format!("https://{}", hostname))
12
12
.map_err(|e| format!("Invalid origin URL: {}", e))?;
13
13
+6
-2
crates/tranquil-pds/src/crawlers.rs
+6
-2
crates/tranquil-pds/src/crawlers.rs
···
1
1
use crate::circuit_breaker::CircuitBreaker;
2
2
use crate::sync::firehose::SequencedEvent;
3
+
use crate::util::pds_hostname;
3
4
use reqwest::Client;
4
5
use std::sync::Arc;
5
6
use std::sync::atomic::{AtomicU64, Ordering};
···
40
41
}
41
42
42
43
pub fn from_env() -> Option<Self> {
43
-
let hostname = std::env::var("PDS_HOSTNAME").ok()?;
44
+
let hostname = pds_hostname();
45
+
if hostname == "localhost" {
46
+
return None;
47
+
}
44
48
45
49
let crawler_urls: Vec<String> = std::env::var("CRAWLERS")
46
50
.unwrap_or_default()
···
53
57
return None;
54
58
}
55
59
56
-
Some(Self::new(hostname, crawler_urls))
60
+
Some(Self::new(hostname.to_string(), crawler_urls))
57
61
}
58
62
59
63
fn should_notify(&self) -> bool {
+5
crates/tranquil-pds/src/delegation/mod.rs
+5
crates/tranquil-pds/src/delegation/mod.rs
···
1
+
pub mod roles;
1
2
pub mod scopes;
2
3
4
+
pub use roles::{
5
+
CanAddControllers, CanControlAccounts, verify_can_add_controllers, verify_can_be_controller,
6
+
verify_can_control_accounts,
7
+
};
3
8
pub use scopes::{SCOPE_PRESETS, ScopePreset, intersect_scopes};
4
9
pub use tranquil_db_traits::DelegationActionType;
+88
crates/tranquil-pds/src/delegation/roles.rs
+88
crates/tranquil-pds/src/delegation/roles.rs
···
1
+
use axum::response::{IntoResponse, Response};
2
+
3
+
use crate::api::error::ApiError;
4
+
use crate::auth::AuthenticatedUser;
5
+
use crate::state::AppState;
6
+
use crate::types::Did;
7
+
8
+
pub struct CanAddControllers<'a> {
9
+
user: &'a AuthenticatedUser,
10
+
}
11
+
12
+
pub struct CanControlAccounts<'a> {
13
+
user: &'a AuthenticatedUser,
14
+
}
15
+
16
+
impl<'a> CanAddControllers<'a> {
17
+
pub fn did(&self) -> &Did {
18
+
&self.user.did
19
+
}
20
+
21
+
pub fn user(&self) -> &AuthenticatedUser {
22
+
self.user
23
+
}
24
+
}
25
+
26
+
impl<'a> CanControlAccounts<'a> {
27
+
pub fn did(&self) -> &Did {
28
+
&self.user.did
29
+
}
30
+
31
+
pub fn user(&self) -> &AuthenticatedUser {
32
+
self.user
33
+
}
34
+
}
35
+
36
+
pub async fn verify_can_add_controllers<'a>(
37
+
state: &AppState,
38
+
user: &'a AuthenticatedUser,
39
+
) -> Result<CanAddControllers<'a>, Response> {
40
+
match state.delegation_repo.controls_any_accounts(&user.did).await {
41
+
Ok(true) => Err(ApiError::InvalidDelegation(
42
+
"Cannot add controllers to an account that controls other accounts".into(),
43
+
)
44
+
.into_response()),
45
+
Ok(false) => Ok(CanAddControllers { user }),
46
+
Err(e) => {
47
+
tracing::error!("Failed to check delegation status: {:?}", e);
48
+
Err(ApiError::InternalError(Some("Failed to verify delegation status".into()))
49
+
.into_response())
50
+
}
51
+
}
52
+
}
53
+
54
+
pub async fn verify_can_control_accounts<'a>(
55
+
state: &AppState,
56
+
user: &'a AuthenticatedUser,
57
+
) -> Result<CanControlAccounts<'a>, Response> {
58
+
match state.delegation_repo.has_any_controllers(&user.did).await {
59
+
Ok(true) => Err(ApiError::InvalidDelegation(
60
+
"Cannot create delegated accounts from a controlled account".into(),
61
+
)
62
+
.into_response()),
63
+
Ok(false) => Ok(CanControlAccounts { user }),
64
+
Err(e) => {
65
+
tracing::error!("Failed to check controller status: {:?}", e);
66
+
Err(ApiError::InternalError(Some("Failed to verify controller status".into()))
67
+
.into_response())
68
+
}
69
+
}
70
+
}
71
+
72
+
pub async fn verify_can_be_controller(
73
+
state: &AppState,
74
+
controller_did: &Did,
75
+
) -> Result<(), Response> {
76
+
match state.delegation_repo.has_any_controllers(controller_did).await {
77
+
Ok(true) => Err(ApiError::InvalidDelegation(
78
+
"Cannot add a controlled account as a controller".into(),
79
+
)
80
+
.into_response()),
81
+
Ok(false) => Ok(()),
82
+
Err(e) => {
83
+
tracing::error!("Failed to check controller status: {:?}", e);
84
+
Err(ApiError::InternalError(Some("Failed to verify controller status".into()))
85
+
.into_response())
86
+
}
87
+
}
88
+
}
+32
-64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
+32
-64
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
···
1
1
use crate::auth::{Active, Auth};
2
2
use crate::delegation::DelegationActionType;
3
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::rate_limit::{LoginLimit, OAuthRateLimited, TotpVerifyLimit};
4
+
use crate::state::AppState;
4
5
use crate::types::PlainPassword;
5
6
use crate::util::extract_client_ip;
6
7
use axum::{
7
8
Json,
8
9
extract::State,
9
-
http::{HeaderMap, StatusCode},
10
+
http::HeaderMap,
10
11
response::{IntoResponse, Response},
11
12
};
12
13
use serde::{Deserialize, Serialize};
···
35
36
36
37
pub async fn delegation_auth(
37
38
State(state): State<AppState>,
39
+
rate_limit: OAuthRateLimited<LoginLimit>,
38
40
headers: HeaderMap,
39
41
Json(form): Json<DelegationAuthSubmit>,
40
42
) -> Response {
41
-
let client_ip = extract_client_ip(&headers);
42
-
if !state
43
-
.check_rate_limit(RateLimitKind::Login, &client_ip)
44
-
.await
45
-
{
46
-
return (
47
-
StatusCode::TOO_MANY_REQUESTS,
48
-
Json(DelegationAuthResponse {
49
-
success: false,
50
-
needs_totp: None,
51
-
redirect_uri: None,
52
-
error: Some("Too many login attempts. Please try again later.".to_string()),
53
-
}),
54
-
)
55
-
.into_response();
56
-
}
57
-
43
+
let client_ip = rate_limit.client_ip();
58
44
let request_id = RequestId::from(form.request_uri.clone());
59
45
let request = match state
60
46
.oauth_repo
···
82
68
}
83
69
};
84
70
85
-
let delegated_did_str = match form.delegated_did.as_ref().or(request.did.as_ref()) {
86
-
Some(did) => did.clone(),
87
-
None => {
88
-
return Json(DelegationAuthResponse {
89
-
success: false,
90
-
needs_totp: None,
91
-
redirect_uri: None,
92
-
error: Some("No delegated account selected".to_string()),
93
-
})
94
-
.into_response();
95
-
}
96
-
};
97
-
98
-
let delegated_did: Did = match delegated_did_str.parse() {
99
-
Ok(d) => d,
100
-
Err(_) => {
101
-
return Json(DelegationAuthResponse {
102
-
success: false,
103
-
needs_totp: None,
104
-
redirect_uri: None,
105
-
error: Some("Invalid delegated DID".to_string()),
106
-
})
107
-
.into_response();
71
+
let delegated_did: Did = if let Some(did_str) = form.delegated_did.as_ref() {
72
+
match did_str.parse() {
73
+
Ok(d) => d,
74
+
Err(_) => {
75
+
return Json(DelegationAuthResponse {
76
+
success: false,
77
+
needs_totp: None,
78
+
redirect_uri: None,
79
+
error: Some("Invalid delegated DID".to_string()),
80
+
})
81
+
.into_response();
82
+
}
108
83
}
84
+
} else if let Some(did) = request.did.as_ref() {
85
+
did.clone()
86
+
} else {
87
+
return Json(DelegationAuthResponse {
88
+
success: false,
89
+
needs_totp: None,
90
+
redirect_uri: None,
91
+
error: Some("No delegated account selected".to_string()),
92
+
})
93
+
.into_response();
109
94
};
110
95
111
96
let controller_did: Did = match form.controller_did.parse() {
···
249
234
.into_response();
250
235
}
251
236
252
-
let ip = extract_client_ip(&headers);
253
237
let user_agent = headers
254
238
.get("user-agent")
255
239
.and_then(|v| v.to_str().ok())
···
266
250
"client_id": request.client_id,
267
251
"granted_scopes": grant.granted_scopes
268
252
})),
269
-
Some(&ip),
253
+
Some(client_ip),
270
254
user_agent.as_deref(),
271
255
)
272
256
.await;
···
291
275
292
276
pub async fn delegation_totp_verify(
293
277
State(state): State<AppState>,
278
+
rate_limit: OAuthRateLimited<TotpVerifyLimit>,
294
279
headers: HeaderMap,
295
280
Json(form): Json<DelegationTotpSubmit>,
296
281
) -> Response {
297
-
let client_ip = extract_client_ip(&headers);
298
-
if !state
299
-
.check_rate_limit(RateLimitKind::TotpVerify, &client_ip)
300
-
.await
301
-
{
302
-
return (
303
-
StatusCode::TOO_MANY_REQUESTS,
304
-
Json(DelegationAuthResponse {
305
-
success: false,
306
-
needs_totp: None,
307
-
redirect_uri: None,
308
-
error: Some("Too many verification attempts. Please try again later.".to_string()),
309
-
}),
310
-
)
311
-
.into_response();
312
-
}
313
-
282
+
let client_ip = rate_limit.client_ip();
314
283
let totp_request_id = RequestId::from(form.request_uri.clone());
315
284
let request = match state
316
285
.oauth_repo
···
420
389
.into_response();
421
390
}
422
391
423
-
let ip = extract_client_ip(&headers);
424
392
let user_agent = headers
425
393
.get("user-agent")
426
394
.and_then(|v| v.to_str().ok())
···
437
405
"client_id": request.client_id,
438
406
"granted_scopes": grant.granted_scopes
439
407
})),
440
-
Some(&ip),
408
+
Some(client_ip),
441
409
user_agent.as_deref(),
442
410
)
443
411
.await;
···
564
532
.into_response();
565
533
}
566
534
567
-
let ip = extract_client_ip(&headers);
535
+
let ip = extract_client_ip(&headers, None);
568
536
let user_agent = headers
569
537
.get("user-agent")
570
538
.and_then(|v| v.to_str().ok())
+3
-2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
+3
-2
crates/tranquil-pds/src/oauth/endpoints/metadata.rs
···
1
1
use crate::oauth::jwks::{JwkSet, create_jwk_set};
2
2
use crate::state::AppState;
3
+
use crate::util::pds_hostname;
3
4
use axum::{Json, extract::State};
4
5
use serde::{Deserialize, Serialize};
5
6
···
57
58
pub async fn oauth_protected_resource(
58
59
State(_state): State<AppState>,
59
60
) -> Json<ProtectedResourceMetadata> {
60
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
61
+
let pds_hostname = pds_hostname();
61
62
let public_url = format!("https://{}", pds_hostname);
62
63
Json(ProtectedResourceMetadata {
63
64
resource: public_url.clone(),
···
71
72
pub async fn oauth_authorization_server(
72
73
State(_state): State<AppState>,
73
74
) -> Json<AuthorizationServerMetadata> {
74
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
75
+
let pds_hostname = pds_hostname();
75
76
let issuer = format!("https://{}", pds_hostname);
76
77
Json(AuthorizationServerMetadata {
77
78
issuer: issuer.clone(),
+57
-50
crates/tranquil-pds/src/oauth/endpoints/par.rs
+57
-50
crates/tranquil-pds/src/oauth/endpoints/par.rs
···
1
1
use crate::oauth::{
2
-
AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData,
3
-
RequestId,
2
+
AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, CodeChallengeMethod,
3
+
OAuthError, Prompt, RequestData, RequestId, ResponseMode, ResponseType,
4
4
scopes::{ParsedScope, parse_scope},
5
5
};
6
-
use crate::state::{AppState, RateLimitKind};
6
+
use crate::rate_limit::{OAuthParLimit, OAuthRateLimited};
7
+
use crate::state::AppState;
7
8
use axum::body::Bytes;
8
9
use axum::{Json, extract::State, http::HeaderMap};
9
10
use chrono::{Duration, Utc};
···
49
50
50
51
pub async fn pushed_authorization_request(
51
52
State(state): State<AppState>,
53
+
_rate_limit: OAuthRateLimited<OAuthParLimit>,
52
54
headers: HeaderMap,
53
55
body: Bytes,
54
56
) -> Result<(axum::http::StatusCode, Json<ParResponse>), OAuthError> {
···
70
72
.to_string(),
71
73
));
72
74
};
73
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
74
-
if !state
75
-
.check_rate_limit(RateLimitKind::OAuthPar, &client_ip)
76
-
.await
77
-
{
78
-
tracing::warn!(ip = %client_ip, "OAuth PAR rate limit exceeded");
79
-
return Err(OAuthError::RateLimited);
80
-
}
81
-
if request.response_type != "code" {
82
-
return Err(OAuthError::InvalidRequest(
83
-
"response_type must be 'code'".to_string(),
84
-
));
85
-
}
75
+
let response_type = parse_response_type(&request.response_type)?;
86
76
let code_challenge = request
87
77
.code_challenge
88
78
.as_ref()
89
79
.filter(|s| !s.is_empty())
90
80
.ok_or_else(|| OAuthError::InvalidRequest("code_challenge is required".to_string()))?;
91
-
let code_challenge_method = request.code_challenge_method.as_deref().unwrap_or("");
92
-
if code_challenge_method != "S256" {
93
-
return Err(OAuthError::InvalidRequest(
94
-
"code_challenge_method must be 'S256'".to_string(),
95
-
));
96
-
}
81
+
let code_challenge_method = parse_code_challenge_method(request.code_challenge_method.as_deref())?;
97
82
let client_cache = ClientMetadataCache::new(3600);
98
83
let client_metadata = client_cache.get(&request.client_id).await?;
99
84
client_cache.validate_redirect_uri(&client_metadata, &request.redirect_uri)?;
···
101
86
let validated_scope = validate_scope(&request.scope, &client_metadata)?;
102
87
let request_id = RequestId::generate();
103
88
let expires_at = Utc::now() + Duration::seconds(PAR_EXPIRY_SECONDS);
104
-
let response_mode = match request.response_mode.as_deref() {
105
-
Some("fragment") => Some("fragment".to_string()),
106
-
Some("query") | None => None,
107
-
Some(mode) => {
108
-
return Err(OAuthError::InvalidRequest(format!(
109
-
"Unsupported response_mode: {}",
110
-
mode
111
-
)));
112
-
}
113
-
};
114
-
let prompt = validate_prompt(&request.prompt)?;
89
+
let response_mode = parse_response_mode(request.response_mode.as_deref())?;
90
+
let prompt = parse_prompt(request.prompt.as_deref())?;
115
91
let parameters = AuthorizationRequestParameters {
116
-
response_type: request.response_type,
92
+
response_type,
117
93
client_id: request.client_id.clone(),
118
94
redirect_uri: request.redirect_uri,
119
95
scope: validated_scope,
120
96
state: request.state,
121
97
code_challenge: code_challenge.clone(),
122
-
code_challenge_method: code_challenge_method.to_string(),
98
+
code_challenge_method,
123
99
response_mode,
124
100
login_hint: request.login_hint,
125
101
dpop_jkt: request.dpop_jkt,
···
266
242
false
267
243
}
268
244
269
-
fn validate_prompt(prompt: &Option<String>) -> Result<Option<String>, OAuthError> {
270
-
const VALID_PROMPTS: &[&str] = &["none", "login", "consent", "select_account", "create"];
245
+
fn parse_response_type(value: &str) -> Result<ResponseType, OAuthError> {
246
+
match value {
247
+
"code" => Ok(ResponseType::Code),
248
+
other => Err(OAuthError::InvalidRequest(format!(
249
+
"response_type must be 'code', got '{}'",
250
+
other
251
+
))),
252
+
}
253
+
}
271
254
272
-
match prompt {
273
-
None => Ok(None),
274
-
Some(p) if p.is_empty() => Ok(None),
275
-
Some(p) => {
276
-
if VALID_PROMPTS.contains(&p.as_str()) {
277
-
Ok(Some(p.clone()))
278
-
} else {
279
-
Err(OAuthError::InvalidRequest(format!(
280
-
"Unsupported prompt value: {}",
281
-
p
282
-
)))
283
-
}
284
-
}
255
+
fn parse_code_challenge_method(value: Option<&str>) -> Result<CodeChallengeMethod, OAuthError> {
256
+
match value {
257
+
Some("S256") | None => Ok(CodeChallengeMethod::S256),
258
+
Some("plain") => Err(OAuthError::InvalidRequest(
259
+
"code_challenge_method 'plain' is not allowed, use 'S256'".to_string(),
260
+
)),
261
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
262
+
"Unsupported code_challenge_method: {}",
263
+
other
264
+
))),
265
+
}
266
+
}
267
+
268
+
fn parse_response_mode(value: Option<&str>) -> Result<Option<ResponseMode>, OAuthError> {
269
+
match value {
270
+
None | Some("query") => Ok(None),
271
+
Some("fragment") => Ok(Some(ResponseMode::Fragment)),
272
+
Some("form_post") => Ok(Some(ResponseMode::FormPost)),
273
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
274
+
"Unsupported response_mode: {}",
275
+
other
276
+
))),
277
+
}
278
+
}
279
+
280
+
fn parse_prompt(value: Option<&str>) -> Result<Option<Prompt>, OAuthError> {
281
+
match value {
282
+
None | Some("") => Ok(None),
283
+
Some("none") => Ok(Some(Prompt::None)),
284
+
Some("login") => Ok(Some(Prompt::Login)),
285
+
Some("consent") => Ok(Some(Prompt::Consent)),
286
+
Some("select_account") => Ok(Some(Prompt::SelectAccount)),
287
+
Some("create") => Ok(Some(Prompt::Create)),
288
+
Some(other) => Err(OAuthError::InvalidRequest(format!(
289
+
"Unsupported prompt value: {}",
290
+
other
291
+
))),
285
292
}
286
293
}
+42
-42
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
+42
-42
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
···
3
3
use crate::config::AuthConfig;
4
4
use crate::delegation::intersect_scopes;
5
5
use crate::oauth::{
6
-
AuthFlowState, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken,
7
-
TokenData, TokenId,
6
+
AuthFlow, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, TokenData,
7
+
TokenId,
8
8
db::{enforce_token_limit_for_user, lookup_refresh_token},
9
9
scopes::expand_include_scopes,
10
10
verify_client_auth,
11
11
};
12
12
use crate::state::AppState;
13
+
use crate::util::pds_hostname;
13
14
use axum::Json;
14
15
use axum::http::HeaderMap;
15
16
use chrono::{Duration, Utc};
···
51
52
.map_err(crate::oauth::db_err_to_oauth)?
52
53
.ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
53
54
54
-
let flow_state = AuthFlowState::from_request_data(&auth_request);
55
-
if flow_state.is_expired() {
56
-
return Err(OAuthError::InvalidGrant(
57
-
"Authorization code has expired".to_string(),
58
-
));
59
-
}
60
-
if !flow_state.can_exchange() {
61
-
return Err(OAuthError::InvalidGrant(
62
-
"Authorization not completed".to_string(),
63
-
));
64
-
}
55
+
let flow = AuthFlow::from_request_data(auth_request).map_err(|_| {
56
+
OAuthError::InvalidGrant("Authorization code has expired".to_string())
57
+
})?;
58
+
59
+
let authorized = flow.require_authorized().map_err(|_| {
60
+
OAuthError::InvalidGrant("Authorization not completed".to_string())
61
+
})?;
65
62
66
63
if let Some(request_client_id) = &request.client_auth.client_id
67
-
&& request_client_id != &auth_request.client_id
64
+
&& request_client_id != &authorized.client_id
68
65
{
69
66
return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
70
67
}
71
-
let did = flow_state.did().unwrap().to_string();
68
+
let did = authorized.did.to_string();
72
69
let client_metadata_cache = ClientMetadataCache::new(3600);
73
-
let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?;
70
+
let client_metadata = client_metadata_cache.get(&authorized.client_id).await?;
74
71
let client_auth = if let (Some(assertion), Some(assertion_type)) = (
75
72
&request.client_auth.client_assertion,
76
73
&request.client_auth.client_assertion_type,
···
91
88
ClientAuth::None
92
89
};
93
90
verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
94
-
verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
91
+
verify_pkce(&authorized.parameters.code_challenge, &code_verifier)?;
95
92
if let Some(req_redirect_uri) = &redirect_uri
96
-
&& req_redirect_uri != &auth_request.parameters.redirect_uri
93
+
&& req_redirect_uri != &authorized.parameters.redirect_uri
97
94
{
98
95
return Err(OAuthError::InvalidGrant(
99
96
"redirect_uri mismatch".to_string(),
···
103
100
let config = AuthConfig::get();
104
101
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
105
102
let pds_hostname =
106
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
103
+
pds_hostname();
107
104
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
108
105
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
109
106
if !state
···
116
113
"DPoP proof has already been used".to_string(),
117
114
));
118
115
}
119
-
if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
116
+
if let Some(expected_jkt) = &authorized.parameters.dpop_jkt
120
117
&& result.jkt.as_str() != expected_jkt
121
118
{
122
119
return Err(OAuthError::InvalidDpopProof(
···
124
121
));
125
122
}
126
123
Some(result.jkt.as_str().to_string())
127
-
} else if auth_request.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() {
124
+
} else if authorized.parameters.dpop_jkt.is_some() || client_metadata.requires_dpop() {
128
125
return Err(OAuthError::UseDpopNonce(
129
126
DPoPVerifier::new(AuthConfig::get().dpop_secret().as_bytes()).generate_nonce(),
130
127
));
···
135
132
let refresh_token = RefreshToken::generate();
136
133
let now = Utc::now();
137
134
138
-
let (raw_scope, controller_did) = if let Some(ref controller) = auth_request.controller_did {
135
+
let (raw_scope, controller_did) = if let Some(ref controller) = authorized.controller_did {
139
136
let did_parsed: Did = did
140
137
.parse()
141
138
.map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?;
···
149
146
.ok()
150
147
.flatten();
151
148
let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default();
152
-
let requested = auth_request
149
+
let requested = authorized
153
150
.parameters
154
151
.scope
155
152
.as_deref()
···
157
154
let intersected = intersect_scopes(requested, &granted_scopes);
158
155
(Some(intersected), Some(controller.clone()))
159
156
} else {
160
-
(auth_request.parameters.scope.clone(), None)
157
+
(authorized.parameters.scope.clone(), None)
161
158
};
162
159
163
160
let final_scope = if let Some(ref scope) = raw_scope {
···
177
174
final_scope.as_deref(),
178
175
controller_did.as_deref(),
179
176
)?;
180
-
let stored_client_auth = auth_request.client_auth.unwrap_or(ClientAuth::None);
177
+
let stored_client_auth = authorized.client_auth.unwrap_or(ClientAuth::None);
181
178
let refresh_expiry_days = if matches!(stored_client_auth, ClientAuth::None) {
182
179
REFRESH_TOKEN_EXPIRY_DAYS_PUBLIC
183
180
} else {
184
181
REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL
185
182
};
186
-
let mut stored_parameters = auth_request.parameters.clone();
183
+
let mut stored_parameters = authorized.parameters.clone();
187
184
stored_parameters.dpop_jkt = dpop_jkt.clone();
185
+
let did_typed: Did = did
186
+
.parse()
187
+
.map_err(|_| OAuthError::InvalidRequest("Invalid DID format".to_string()))?;
188
188
let token_data = TokenData {
189
-
did: did.clone(),
190
-
token_id: token_id.0.clone(),
189
+
did: did_typed,
190
+
token_id: token_id.clone(),
191
191
created_at: now,
192
192
updated_at: now,
193
193
expires_at: now + Duration::days(refresh_expiry_days),
194
-
client_id: auth_request.client_id.clone(),
194
+
client_id: authorized.client_id.clone(),
195
195
client_auth: stored_client_auth,
196
-
device_id: auth_request.device_id,
196
+
device_id: authorized.device_id.clone(),
197
197
parameters: stored_parameters,
198
198
details: None,
199
199
code: None,
200
-
current_refresh_token: Some(refresh_token.0.clone()),
200
+
current_refresh_token: Some(refresh_token.clone()),
201
201
scope: final_scope.clone(),
202
202
controller_did: controller_did.clone(),
203
203
};
···
209
209
tracing::info!(
210
210
did = %did,
211
211
token_id = %token_id.0,
212
-
client_id = %auth_request.client_id,
212
+
client_id = %authorized.client_id,
213
213
"Authorization code grant completed, token created"
214
214
);
215
215
tokio::spawn({
···
280
280
);
281
281
let dpop_jkt = token_data.parameters.dpop_jkt.as_deref();
282
282
let access_token = create_access_token_with_delegation(
283
-
&token_data.token_id,
284
-
&token_data.did,
283
+
&token_data.token_id.0,
284
+
token_data.did.as_str(),
285
285
dpop_jkt,
286
286
token_data.scope.as_deref(),
287
-
token_data.controller_did.as_deref(),
287
+
token_data.controller_did.as_ref().map(|d| d.as_str()),
288
288
)?;
289
289
let mut response_headers = HeaderMap::new();
290
290
let config = AuthConfig::get();
···
296
296
access_token,
297
297
token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
298
298
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
299
-
refresh_token: token_data.current_refresh_token,
299
+
refresh_token: token_data.current_refresh_token.map(|r| r.0),
300
300
scope: token_data.scope,
301
-
sub: Some(token_data.did),
301
+
sub: Some(token_data.did.to_string()),
302
302
}),
303
303
));
304
304
}
···
338
338
let config = AuthConfig::get();
339
339
let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
340
340
let pds_hostname =
341
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
341
+
pds_hostname();
342
342
let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
343
343
let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
344
344
if !state
···
385
385
"Refresh token rotated successfully"
386
386
);
387
387
let access_token = create_access_token_with_delegation(
388
-
&token_data.token_id,
389
-
&token_data.did,
388
+
&token_data.token_id.0,
389
+
token_data.did.as_str(),
390
390
dpop_jkt.as_deref(),
391
391
token_data.scope.as_deref(),
392
-
token_data.controller_did.as_deref(),
392
+
token_data.controller_did.as_ref().map(|d| d.as_str()),
393
393
)?;
394
394
let mut response_headers = HeaderMap::new();
395
395
let config = AuthConfig::get();
···
403
403
expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
404
404
refresh_token: Some(new_refresh_token.0),
405
405
scope: token_data.scope,
406
-
sub: Some(token_data.did),
406
+
sub: Some(token_data.did.to_string()),
407
407
}),
408
408
))
409
409
}
+2
-1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
+2
-1
crates/tranquil-pds/src/oauth/endpoints/token/helpers.rs
···
1
1
use crate::config::AuthConfig;
2
2
use crate::oauth::OAuthError;
3
+
use crate::util::pds_hostname;
3
4
use base64::Engine;
4
5
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5
6
use chrono::Utc;
···
51
52
) -> Result<String, OAuthError> {
52
53
use serde_json::json;
53
54
let jti = uuid::Uuid::new_v4().to_string();
54
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
55
+
let pds_hostname = pds_hostname();
55
56
let issuer = format!("https://{}", pds_hostname);
56
57
let now = Utc::now().timestamp();
57
58
let exp = now + ACCESS_TOKEN_EXPIRY_SECONDS;
+8
-22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
+8
-22
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
···
1
1
use super::helpers::extract_token_claims;
2
2
use crate::oauth::OAuthError;
3
-
use crate::state::{AppState, RateLimitKind};
3
+
use crate::rate_limit::{OAuthIntrospectLimit, OAuthRateLimited};
4
+
use crate::state::AppState;
5
+
use crate::util::pds_hostname;
4
6
use axum::extract::State;
5
-
use axum::http::{HeaderMap, StatusCode};
7
+
use axum::http::StatusCode;
6
8
use axum::{Form, Json};
7
9
use chrono::Utc;
8
10
use serde::{Deserialize, Serialize};
···
17
19
18
20
pub async fn revoke_token(
19
21
State(state): State<AppState>,
20
-
headers: HeaderMap,
22
+
_rate_limit: OAuthRateLimited<OAuthIntrospectLimit>,
21
23
Form(request): Form<RevokeRequest>,
22
24
) -> Result<StatusCode, OAuthError> {
23
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
24
-
if !state
25
-
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
26
-
.await
27
-
{
28
-
tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
29
-
return Err(OAuthError::RateLimited);
30
-
}
31
25
if let Some(token) = &request.token {
32
26
let refresh_token = RefreshToken::from(token.clone());
33
27
if let Some((db_id, _)) = state
···
89
83
90
84
pub async fn introspect_token(
91
85
State(state): State<AppState>,
92
-
headers: HeaderMap,
86
+
_rate_limit: OAuthRateLimited<OAuthIntrospectLimit>,
93
87
Form(request): Form<IntrospectRequest>,
94
88
) -> Result<Json<IntrospectResponse>, OAuthError> {
95
-
let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
96
-
if !state
97
-
.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
98
-
.await
99
-
{
100
-
tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
101
-
return Err(OAuthError::RateLimited);
102
-
}
103
89
let inactive_response = IntrospectResponse {
104
90
active: false,
105
91
scope: None,
···
126
112
if token_data.expires_at < Utc::now() {
127
113
return Ok(Json(inactive_response));
128
114
}
129
-
let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
115
+
let pds_hostname = pds_hostname();
130
116
let issuer = format!("https://{}", pds_hostname);
131
117
Ok(Json(IntrospectResponse {
132
118
active: true,
···
141
127
exp: Some(token_info.exp),
142
128
iat: Some(token_info.iat),
143
129
nbf: Some(token_info.iat),
144
-
sub: Some(token_data.did),
130
+
sub: Some(token_data.did.to_string()),
145
131
aud: Some(issuer.clone()),
146
132
iss: Some(issuer),
147
133
jti: Some(token_info.jti),
+3
-26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
+3
-26
crates/tranquil-pds/src/oauth/endpoints/token/mod.rs
···
4
4
mod types;
5
5
6
6
use crate::oauth::OAuthError;
7
-
use crate::state::{AppState, RateLimitKind};
7
+
use crate::rate_limit::{OAuthRateLimited, OAuthTokenLimit};
8
+
use crate::state::AppState;
8
9
use axum::body::Bytes;
9
10
use axum::{Json, extract::State, http::HeaderMap};
10
11
···
17
18
ClientAuthParams, GrantType, TokenGrant, TokenRequest, TokenResponse, ValidatedTokenRequest,
18
19
};
19
20
20
-
fn extract_client_ip(headers: &HeaderMap) -> String {
21
-
if let Some(forwarded) = headers.get("x-forwarded-for")
22
-
&& let Ok(value) = forwarded.to_str()
23
-
&& let Some(first_ip) = value.split(',').next()
24
-
{
25
-
return first_ip.trim().to_string();
26
-
}
27
-
if let Some(real_ip) = headers.get("x-real-ip")
28
-
&& let Ok(value) = real_ip.to_str()
29
-
{
30
-
return value.trim().to_string();
31
-
}
32
-
"unknown".to_string()
33
-
}
34
-
35
21
pub async fn token_endpoint(
36
22
State(state): State<AppState>,
23
+
_rate_limit: OAuthRateLimited<OAuthTokenLimit>,
37
24
headers: HeaderMap,
38
25
body: Bytes,
39
26
) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
···
53
40
.to_string(),
54
41
));
55
42
};
56
-
let client_ip = extract_client_ip(&headers);
57
-
if !state
58
-
.check_rate_limit(RateLimitKind::OAuthToken, &client_ip)
59
-
.await
60
-
{
61
-
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
62
-
return Err(OAuthError::InvalidRequest(
63
-
"Too many requests. Please try again later.".to_string(),
64
-
));
65
-
}
66
43
let dpop_proof = headers
67
44
.get("DPoP")
68
45
.and_then(|v| v.to_str().ok())
+9
-7
crates/tranquil-pds/src/oauth/mod.rs
+9
-7
crates/tranquil-pds/src/oauth/mod.rs
···
10
10
}
11
11
12
12
pub use tranquil_oauth::{
13
-
AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
-
AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, DPoPClaims,
15
-
DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier, DPoPVerifyResult, DeviceData,
16
-
DeviceId, JwkPublicKey, Jwks, OAuthClientMetadata, OAuthError, ParResponse,
17
-
ProtectedResourceMetadata, RefreshToken, RefreshTokenState, RequestData, RequestId, SessionId,
18
-
TokenData, TokenId, TokenRequest, TokenResponse, compute_access_token_hash,
19
-
compute_jwk_thumbprint, verify_client_auth,
13
+
AuthFlow, AuthFlowWithUser, AuthorizationRequestParameters, AuthorizationServerMetadata,
14
+
AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code,
15
+
CodeChallengeMethod, DPoPClaims, DPoPJwk, DPoPProofHeader, DPoPProofPayload, DPoPVerifier,
16
+
DPoPVerifyResult, DeviceData, DeviceId, FlowAuthenticated, FlowAuthorized, FlowExpired,
17
+
FlowNotAuthenticated, FlowNotAuthorized, FlowPending, JwkPublicKey, Jwks, OAuthClientMetadata,
18
+
OAuthError, ParResponse, Prompt, ProtectedResourceMetadata, RefreshToken, RefreshTokenState,
19
+
RequestData, RequestId, ResponseMode, ResponseType, SessionId, TokenData, TokenId,
20
+
TokenRequest, TokenResponse, compute_access_token_hash, compute_jwk_thumbprint,
21
+
verify_client_auth,
20
22
};
21
23
22
24
pub use scopes::{AccountAction, AccountAttr, RepoAction, ScopeError, ScopePermissions};
+18
-14
crates/tranquil-pds/src/oauth/verify.rs
+18
-14
crates/tranquil-pds/src/oauth/verify.rs
···
20
20
use crate::state::AppState;
21
21
22
22
pub struct OAuthTokenInfo {
23
-
pub did: String,
24
-
pub token_id: String,
25
-
pub client_id: String,
23
+
pub did: Did,
24
+
pub token_id: TokenId,
25
+
pub client_id: ClientId,
26
26
pub scope: Option<String>,
27
27
pub dpop_jkt: Option<String>,
28
-
pub controller_did: Option<String>,
28
+
pub controller_did: Option<Did>,
29
29
}
30
30
31
31
pub struct VerifyResult {
···
48
48
has_dpop_proof = dpop_proof.is_some(),
49
49
"Verifying OAuth access token"
50
50
);
51
-
let token_id = TokenId::from(token_info.token_id.clone());
51
+
let token_id = token_info.token_id.clone();
52
52
let token_data = oauth_repo
53
53
.get_token_by_id(&token_id)
54
54
.await
···
154
154
if exp < now {
155
155
return Err(OAuthError::ExpiredToken("Token has expired".to_string()));
156
156
}
157
-
let token_id = payload
157
+
let token_id_str = payload
158
158
.get("sid")
159
159
.and_then(|j| j.as_str())
160
-
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?
161
-
.to_string();
162
-
let did = payload
160
+
.ok_or_else(|| OAuthError::InvalidToken("Missing sid claim".to_string()))?;
161
+
let token_id = TokenId::new(token_id_str);
162
+
let did_str = payload
163
163
.get("sub")
164
164
.and_then(|s| s.as_str())
165
-
.ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?
166
-
.to_string();
165
+
.ok_or_else(|| OAuthError::InvalidToken("Missing sub claim".to_string()))?;
166
+
let did: Did = did_str
167
+
.parse()
168
+
.map_err(|_| OAuthError::InvalidToken("Invalid sub claim (not a valid DID)".to_string()))?;
167
169
let scope = payload
168
170
.get("scope")
169
171
.and_then(|s| s.as_str())
···
173
175
.and_then(|c| c.get("jkt"))
174
176
.and_then(|j| j.as_str())
175
177
.map(|s| s.to_string());
176
-
let client_id = payload
178
+
let client_id_str = payload
177
179
.get("client_id")
178
180
.and_then(|c| c.as_str())
179
-
.map(|s| s.to_string())
180
181
.unwrap_or_default();
182
+
let client_id = ClientId::new(client_id_str);
181
183
let controller_did = payload
182
184
.get("act")
183
185
.and_then(|a| a.get("sub"))
184
186
.and_then(|s| s.as_str())
185
-
.map(|s| s.to_string());
187
+
.map(|s| s.parse::<Did>())
188
+
.transpose()
189
+
.map_err(|_| OAuthError::InvalidToken("Invalid act.sub claim (not a valid DID)".to_string()))?;
186
190
Ok(OAuthTokenInfo {
187
191
did,
188
192
token_id,
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
+272
crates/tranquil-pds/src/rate_limit/extractor.rs
···
1
+
use std::marker::PhantomData;
2
+
3
+
use axum::{
4
+
extract::FromRequestParts,
5
+
http::request::Parts,
6
+
response::{IntoResponse, Response},
7
+
};
8
+
9
+
use crate::api::error::ApiError;
10
+
use crate::oauth::OAuthError;
11
+
use crate::state::{AppState, RateLimitKind};
12
+
use crate::util::extract_client_ip;
13
+
14
+
pub trait RateLimitPolicy: Send + Sync + 'static {
15
+
const KIND: RateLimitKind;
16
+
}
17
+
18
+
pub struct LoginLimit;
19
+
impl RateLimitPolicy for LoginLimit {
20
+
const KIND: RateLimitKind = RateLimitKind::Login;
21
+
}
22
+
23
+
pub struct AccountCreationLimit;
24
+
impl RateLimitPolicy for AccountCreationLimit {
25
+
const KIND: RateLimitKind = RateLimitKind::AccountCreation;
26
+
}
27
+
28
+
pub struct PasswordResetLimit;
29
+
impl RateLimitPolicy for PasswordResetLimit {
30
+
const KIND: RateLimitKind = RateLimitKind::PasswordReset;
31
+
}
32
+
33
+
pub struct ResetPasswordLimit;
34
+
impl RateLimitPolicy for ResetPasswordLimit {
35
+
const KIND: RateLimitKind = RateLimitKind::ResetPassword;
36
+
}
37
+
38
+
pub struct RefreshSessionLimit;
39
+
impl RateLimitPolicy for RefreshSessionLimit {
40
+
const KIND: RateLimitKind = RateLimitKind::RefreshSession;
41
+
}
42
+
43
+
pub struct OAuthTokenLimit;
44
+
impl RateLimitPolicy for OAuthTokenLimit {
45
+
const KIND: RateLimitKind = RateLimitKind::OAuthToken;
46
+
}
47
+
48
+
pub struct OAuthAuthorizeLimit;
49
+
impl RateLimitPolicy for OAuthAuthorizeLimit {
50
+
const KIND: RateLimitKind = RateLimitKind::OAuthAuthorize;
51
+
}
52
+
53
+
pub struct OAuthParLimit;
54
+
impl RateLimitPolicy for OAuthParLimit {
55
+
const KIND: RateLimitKind = RateLimitKind::OAuthPar;
56
+
}
57
+
58
+
pub struct OAuthIntrospectLimit;
59
+
impl RateLimitPolicy for OAuthIntrospectLimit {
60
+
const KIND: RateLimitKind = RateLimitKind::OAuthIntrospect;
61
+
}
62
+
63
+
pub struct AppPasswordLimit;
64
+
impl RateLimitPolicy for AppPasswordLimit {
65
+
const KIND: RateLimitKind = RateLimitKind::AppPassword;
66
+
}
67
+
68
+
pub struct EmailUpdateLimit;
69
+
impl RateLimitPolicy for EmailUpdateLimit {
70
+
const KIND: RateLimitKind = RateLimitKind::EmailUpdate;
71
+
}
72
+
73
+
pub struct TotpVerifyLimit;
74
+
impl RateLimitPolicy for TotpVerifyLimit {
75
+
const KIND: RateLimitKind = RateLimitKind::TotpVerify;
76
+
}
77
+
78
+
pub struct HandleUpdateLimit;
79
+
impl RateLimitPolicy for HandleUpdateLimit {
80
+
const KIND: RateLimitKind = RateLimitKind::HandleUpdate;
81
+
}
82
+
83
+
pub struct HandleUpdateDailyLimit;
84
+
impl RateLimitPolicy for HandleUpdateDailyLimit {
85
+
const KIND: RateLimitKind = RateLimitKind::HandleUpdateDaily;
86
+
}
87
+
88
+
pub struct VerificationCheckLimit;
89
+
impl RateLimitPolicy for VerificationCheckLimit {
90
+
const KIND: RateLimitKind = RateLimitKind::VerificationCheck;
91
+
}
92
+
93
+
pub struct SsoInitiateLimit;
94
+
impl RateLimitPolicy for SsoInitiateLimit {
95
+
const KIND: RateLimitKind = RateLimitKind::SsoInitiate;
96
+
}
97
+
98
+
pub struct SsoCallbackLimit;
99
+
impl RateLimitPolicy for SsoCallbackLimit {
100
+
const KIND: RateLimitKind = RateLimitKind::SsoCallback;
101
+
}
102
+
103
+
pub struct SsoUnlinkLimit;
104
+
impl RateLimitPolicy for SsoUnlinkLimit {
105
+
const KIND: RateLimitKind = RateLimitKind::SsoUnlink;
106
+
}
107
+
108
+
pub struct OAuthRegisterCompleteLimit;
109
+
impl RateLimitPolicy for OAuthRegisterCompleteLimit {
110
+
const KIND: RateLimitKind = RateLimitKind::OAuthRegisterComplete;
111
+
}
112
+
113
+
pub trait RateLimitRejection: IntoResponse + Send + 'static {
114
+
fn new() -> Self;
115
+
}
116
+
117
+
pub struct ApiRateLimitRejection;
118
+
119
+
impl RateLimitRejection for ApiRateLimitRejection {
120
+
fn new() -> Self {
121
+
Self
122
+
}
123
+
}
124
+
125
+
impl IntoResponse for ApiRateLimitRejection {
126
+
fn into_response(self) -> Response {
127
+
ApiError::RateLimitExceeded(None).into_response()
128
+
}
129
+
}
130
+
131
+
pub struct OAuthRateLimitRejection;
132
+
133
+
impl RateLimitRejection for OAuthRateLimitRejection {
134
+
fn new() -> Self {
135
+
Self
136
+
}
137
+
}
138
+
139
+
impl IntoResponse for OAuthRateLimitRejection {
140
+
fn into_response(self) -> Response {
141
+
OAuthError::RateLimited.into_response()
142
+
}
143
+
}
144
+
145
+
impl From<OAuthRateLimitRejection> for OAuthError {
146
+
fn from(_: OAuthRateLimitRejection) -> Self {
147
+
OAuthError::RateLimited
148
+
}
149
+
}
150
+
151
+
pub struct RateLimitedInner<P: RateLimitPolicy, R: RateLimitRejection> {
152
+
client_ip: String,
153
+
_marker: PhantomData<(P, R)>,
154
+
}
155
+
156
+
impl<P: RateLimitPolicy, R: RateLimitRejection> RateLimitedInner<P, R> {
157
+
pub fn client_ip(&self) -> &str {
158
+
&self.client_ip
159
+
}
160
+
}
161
+
162
+
impl<P: RateLimitPolicy, R: RateLimitRejection> FromRequestParts<AppState>
163
+
for RateLimitedInner<P, R>
164
+
{
165
+
type Rejection = R;
166
+
167
+
async fn from_request_parts(
168
+
parts: &mut Parts,
169
+
state: &AppState,
170
+
) -> Result<Self, Self::Rejection> {
171
+
let client_ip = extract_client_ip(&parts.headers, None);
172
+
173
+
if !state.check_rate_limit(P::KIND, &client_ip).await {
174
+
tracing::warn!(
175
+
ip = %client_ip,
176
+
kind = ?P::KIND,
177
+
"Rate limit exceeded"
178
+
);
179
+
return Err(R::new());
180
+
}
181
+
182
+
Ok(RateLimitedInner {
183
+
client_ip,
184
+
_marker: PhantomData,
185
+
})
186
+
}
187
+
}
188
+
189
+
pub type RateLimited<P> = RateLimitedInner<P, ApiRateLimitRejection>;
190
+
pub type OAuthRateLimited<P> = RateLimitedInner<P, OAuthRateLimitRejection>;
191
+
192
+
#[derive(Debug)]
193
+
pub struct UserRateLimitError {
194
+
pub kind: RateLimitKind,
195
+
pub message: Option<String>,
196
+
}
197
+
198
+
impl UserRateLimitError {
199
+
pub fn new(kind: RateLimitKind) -> Self {
200
+
Self {
201
+
kind,
202
+
message: None,
203
+
}
204
+
}
205
+
206
+
pub fn with_message(kind: RateLimitKind, message: impl Into<String>) -> Self {
207
+
Self {
208
+
kind,
209
+
message: Some(message.into()),
210
+
}
211
+
}
212
+
}
213
+
214
+
impl std::fmt::Display for UserRateLimitError {
215
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216
+
match &self.message {
217
+
Some(msg) => write!(f, "{}", msg),
218
+
None => write!(f, "Rate limit exceeded for {:?}", self.kind),
219
+
}
220
+
}
221
+
}
222
+
223
+
impl std::error::Error for UserRateLimitError {}
224
+
225
+
impl IntoResponse for UserRateLimitError {
226
+
fn into_response(self) -> Response {
227
+
ApiError::RateLimitExceeded(self.message).into_response()
228
+
}
229
+
}
230
+
231
+
pub struct UserRateLimitProof<P: RateLimitPolicy> {
232
+
_marker: PhantomData<P>,
233
+
}
234
+
235
+
impl<P: RateLimitPolicy> UserRateLimitProof<P> {
236
+
fn new() -> Self {
237
+
Self {
238
+
_marker: PhantomData,
239
+
}
240
+
}
241
+
}
242
+
243
+
pub async fn check_user_rate_limit<P: RateLimitPolicy>(
244
+
state: &AppState,
245
+
user_key: &str,
246
+
) -> Result<UserRateLimitProof<P>, UserRateLimitError> {
247
+
if !state.check_rate_limit(P::KIND, user_key).await {
248
+
tracing::warn!(
249
+
key = %user_key,
250
+
kind = ?P::KIND,
251
+
"User rate limit exceeded"
252
+
);
253
+
return Err(UserRateLimitError::new(P::KIND));
254
+
}
255
+
Ok(UserRateLimitProof::new())
256
+
}
257
+
258
+
pub async fn check_user_rate_limit_with_message<P: RateLimitPolicy>(
259
+
state: &AppState,
260
+
user_key: &str,
261
+
error_message: impl Into<String>,
262
+
) -> Result<UserRateLimitProof<P>, UserRateLimitError> {
263
+
if !state.check_rate_limit(P::KIND, user_key).await {
264
+
tracing::warn!(
265
+
key = %user_key,
266
+
kind = ?P::KIND,
267
+
"User rate limit exceeded"
268
+
);
269
+
return Err(UserRateLimitError::with_message(P::KIND, error_message));
270
+
}
271
+
Ok(UserRateLimitProof::new())
272
+
}
+5
-102
crates/tranquil-pds/src/rate_limit.rs
crates/tranquil-pds/src/rate_limit/mod.rs
+5
-102
crates/tranquil-pds/src/rate_limit.rs
crates/tranquil-pds/src/rate_limit/mod.rs
···
1
-
use axum::{
2
-
Json,
3
-
body::Body,
4
-
extract::ConnectInfo,
5
-
http::{HeaderMap, Request, StatusCode},
6
-
middleware::Next,
7
-
response::{IntoResponse, Response},
8
-
};
1
+
mod extractor;
2
+
3
+
pub use extractor::*;
4
+
9
5
use governor::{
10
6
Quota, RateLimiter,
11
7
clock::DefaultClock,
12
8
state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore},
13
9
};
14
-
use std::{net::SocketAddr, num::NonZeroU32, sync::Arc};
10
+
use std::{num::NonZeroU32, sync::Arc};
15
11
16
12
pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
17
13
pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
···
166
162
}
167
163
}
168
164
169
-
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
170
-
if let Some(forwarded) = headers.get("x-forwarded-for")
171
-
&& let Ok(value) = forwarded.to_str()
172
-
&& let Some(first_ip) = value.split(',').next()
173
-
{
174
-
return first_ip.trim().to_string();
175
-
}
176
-
177
-
if let Some(real_ip) = headers.get("x-real-ip")
178
-
&& let Ok(value) = real_ip.to_str()
179
-
{
180
-
return value.trim().to_string();
181
-
}
182
-
183
-
addr.map(|a| a.ip().to_string())
184
-
.unwrap_or_else(|| "unknown".to_string())
185
-
}
186
-
187
-
fn rate_limit_response() -> Response {
188
-
(
189
-
StatusCode::TOO_MANY_REQUESTS,
190
-
Json(serde_json::json!({
191
-
"error": "RateLimitExceeded",
192
-
"message": "Too many requests. Please try again later."
193
-
})),
194
-
)
195
-
.into_response()
196
-
}
197
-
198
-
pub async fn login_rate_limit(
199
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
200
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
201
-
request: Request<Body>,
202
-
next: Next,
203
-
) -> Response {
204
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
205
-
206
-
if limiters.login.check_key(&client_ip).is_err() {
207
-
tracing::warn!(ip = %client_ip, "Login rate limit exceeded");
208
-
return rate_limit_response();
209
-
}
210
-
211
-
next.run(request).await
212
-
}
213
-
214
-
pub async fn oauth_token_rate_limit(
215
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
216
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
217
-
request: Request<Body>,
218
-
next: Next,
219
-
) -> Response {
220
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
221
-
222
-
if limiters.oauth_token.check_key(&client_ip).is_err() {
223
-
tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded");
224
-
return rate_limit_response();
225
-
}
226
-
227
-
next.run(request).await
228
-
}
229
-
230
-
pub async fn password_reset_rate_limit(
231
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
232
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
233
-
request: Request<Body>,
234
-
next: Next,
235
-
) -> Response {
236
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
237
-
238
-
if limiters.password_reset.check_key(&client_ip).is_err() {
239
-
tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded");
240
-
return rate_limit_response();
241
-
}
242
-
243
-
next.run(request).await
244
-
}
245
-
246
-
pub async fn account_creation_rate_limit(
247
-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
248
-
axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>,
249
-
request: Request<Body>,
250
-
next: Next,
251
-
) -> Response {
252
-
let client_ip = extract_client_ip(request.headers(), Some(addr));
253
-
254
-
if limiters.account_creation.check_key(&client_ip).is_err() {
255
-
tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded");
256
-
return rate_limit_response();
257
-
}
258
-
259
-
next.run(request).await
260
-
}
261
-
262
165
#[cfg(test)]
263
166
mod tests {
264
167
use super::*;
+2
-1
crates/tranquil-pds/src/sso/config.rs
+2
-1
crates/tranquil-pds/src/sso/config.rs
···
1
+
use crate::util::pds_hostname;
1
2
use std::sync::OnceLock;
2
3
use tranquil_db_traits::SsoProviderType;
3
4
···
50
51
};
51
52
52
53
if config.is_any_enabled() {
53
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_default();
54
+
let hostname = pds_hostname();
54
55
if hostname.is_empty() || hostname == "localhost" {
55
56
panic!(
56
57
"PDS_HOSTNAME must be set to a valid hostname when SSO is enabled. \
+33
-62
crates/tranquil-pds/src/sso/endpoints.rs
+33
-62
crates/tranquil-pds/src/sso/endpoints.rs
···
13
13
use crate::api::error::ApiError;
14
14
use crate::auth::extractor::extract_bearer_token_from_header;
15
15
use crate::auth::{generate_app_password, validate_bearer_token_cached};
16
-
use crate::rate_limit::extract_client_ip;
17
-
use crate::state::{AppState, RateLimitKind};
16
+
use crate::rate_limit::{
17
+
AccountCreationLimit, RateLimited, SsoCallbackLimit, SsoInitiateLimit, SsoUnlinkLimit,
18
+
check_user_rate_limit_with_message,
19
+
};
20
+
use crate::state::AppState;
21
+
use crate::util::{pds_hostname, pds_hostname_without_port};
18
22
19
23
fn generate_state() -> String {
20
24
use rand::RngCore;
···
71
75
72
76
pub async fn sso_initiate(
73
77
State(state): State<AppState>,
78
+
_rate_limit: RateLimited<SsoInitiateLimit>,
74
79
headers: HeaderMap,
75
80
Json(input): Json<SsoInitiateRequest>,
76
81
) -> Result<Json<SsoInitiateResponse>, ApiError> {
77
-
let client_ip = extract_client_ip(&headers, None);
78
-
if !state
79
-
.check_rate_limit(RateLimitKind::SsoInitiate, &client_ip)
80
-
.await
81
-
{
82
-
tracing::warn!(ip = %client_ip, "SSO initiate rate limit exceeded");
83
-
return Err(ApiError::RateLimitExceeded(None));
84
-
}
85
-
86
82
if input.provider.len() > 20 {
87
83
return Err(ApiError::SsoProviderNotFound);
88
84
}
···
217
213
218
214
pub async fn sso_callback(
219
215
State(state): State<AppState>,
220
-
headers: HeaderMap,
216
+
_rate_limit: RateLimited<SsoCallbackLimit>,
221
217
Query(query): Query<SsoCallbackQuery>,
222
218
) -> Response {
219
+
sso_callback_internal(&state, query).await
220
+
}
221
+
222
+
async fn sso_callback_internal(state: &AppState, query: SsoCallbackQuery) -> Response {
223
223
tracing::debug!(
224
224
has_code = query.code.is_some(),
225
225
has_state = query.state.is_some(),
···
227
227
"SSO callback received"
228
228
);
229
229
230
-
let client_ip = extract_client_ip(&headers, None);
231
-
if !state
232
-
.check_rate_limit(RateLimitKind::SsoCallback, &client_ip)
233
-
.await
234
-
{
235
-
tracing::warn!(ip = %client_ip, "SSO callback rate limit exceeded");
236
-
return redirect_to_error("Too many requests. Please try again later.");
237
-
}
238
-
239
230
if let Some(ref error) = query.error {
240
231
tracing::warn!(
241
232
error = %error,
···
329
320
match auth_state.action.as_str() {
330
321
"login" => {
331
322
handle_sso_login(
332
-
&state,
323
+
state,
333
324
&auth_state.request_uri,
334
325
auth_state.provider,
335
326
&user_info,
···
341
332
Some(d) => d,
342
333
None => return redirect_to_error("Not authenticated"),
343
334
};
344
-
handle_sso_link(&state, did, auth_state.provider, &user_info).await
335
+
handle_sso_link(state, did, auth_state.provider, &user_info).await
345
336
}
346
337
"register" => {
347
338
handle_sso_register(
348
-
&state,
339
+
state,
349
340
&auth_state.request_uri,
350
341
auth_state.provider,
351
342
&user_info,
···
358
349
359
350
pub async fn sso_callback_post(
360
351
State(state): State<AppState>,
361
-
headers: HeaderMap,
352
+
_rate_limit: RateLimited<SsoCallbackLimit>,
362
353
Form(form): Form<SsoCallbackForm>,
363
354
) -> Response {
364
355
tracing::debug!(
···
376
367
error_description: form.error_description,
377
368
};
378
369
379
-
sso_callback(State(state), headers, Query(query)).await
370
+
sso_callback_internal(&state, query).await
380
371
}
381
372
382
373
fn generate_registration_token() -> String {
···
682
673
auth: crate::auth::Auth<crate::auth::Active>,
683
674
Json(input): Json<UnlinkAccountRequest>,
684
675
) -> Result<Json<UnlinkAccountResponse>, ApiError> {
685
-
if !state
686
-
.check_rate_limit(RateLimitKind::SsoUnlink, auth.did.as_str())
687
-
.await
688
-
{
689
-
tracing::warn!(did = %auth.did, "SSO unlink rate limit exceeded");
690
-
return Err(ApiError::RateLimitExceeded(None));
691
-
}
676
+
let _rate_limit = check_user_rate_limit_with_message::<SsoUnlinkLimit>(
677
+
&state,
678
+
auth.did.as_str(),
679
+
"Too many unlink attempts. Please try again later.",
680
+
)
681
+
.await?;
692
682
693
683
let id = uuid::Uuid::parse_str(&input.id).map_err(|_| ApiError::InvalidId)?;
694
684
···
746
736
747
737
pub async fn get_pending_registration(
748
738
State(state): State<AppState>,
749
-
headers: HeaderMap,
739
+
_rate_limit: RateLimited<SsoCallbackLimit>,
750
740
Query(query): Query<PendingRegistrationQuery>,
751
741
) -> Result<Json<PendingRegistrationResponse>, ApiError> {
752
-
let client_ip = extract_client_ip(&headers, None);
753
-
if !state
754
-
.check_rate_limit(RateLimitKind::SsoCallback, &client_ip)
755
-
.await
756
-
{
757
-
tracing::warn!(ip = %client_ip, "SSO pending registration rate limit exceeded");
758
-
return Err(ApiError::RateLimitExceeded(None));
759
-
}
760
-
761
742
if query.token.len() > 100 {
762
743
return Err(ApiError::InvalidRequest("Invalid token".into()));
763
744
}
···
810
791
}
811
792
};
812
793
813
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
814
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
794
+
let hostname_for_handles = pds_hostname_without_port();
815
795
let full_handle = format!("{}.{}", validated, hostname_for_handles);
816
796
let handle_typed = crate::types::Handle::new_unchecked(&full_handle);
817
797
···
866
846
867
847
pub async fn complete_registration(
868
848
State(state): State<AppState>,
869
-
headers: HeaderMap,
849
+
rate_limit: RateLimited<AccountCreationLimit>,
870
850
Json(input): Json<CompleteRegistrationInput>,
871
851
) -> Result<Json<CompleteRegistrationResponse>, ApiError> {
852
+
let client_ip = rate_limit.client_ip();
872
853
use jacquard_common::types::{integer::LimitedU32, string::Tid};
873
854
use jacquard_repo::{mst::Mst, storage::BlockStore};
874
855
use k256::ecdsa::SigningKey;
···
876
857
use serde_json::json;
877
858
use std::sync::Arc;
878
859
879
-
let client_ip = extract_client_ip(&headers, None);
880
-
if !state
881
-
.check_rate_limit(RateLimitKind::AccountCreation, &client_ip)
882
-
.await
883
-
{
884
-
tracing::warn!(ip = %client_ip, "SSO registration rate limit exceeded");
885
-
return Err(ApiError::RateLimitExceeded(None));
886
-
}
887
-
888
860
if input.token.len() > 100 {
889
861
return Err(ApiError::InvalidRequest("Invalid token".into()));
890
862
}
···
899
871
.await?
900
872
.ok_or(ApiError::SsoSessionExpired)?;
901
873
902
-
let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
903
-
let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname);
874
+
let hostname = pds_hostname();
875
+
let hostname_for_handles = pds_hostname_without_port();
904
876
905
877
let handle = match crate::api::validation::validate_short_handle(&input.handle) {
906
878
Ok(h) => format!("{}.{}", h, hostname_for_handles),
···
977
949
let handle_typed = crate::types::Handle::new_unchecked(&handle);
978
950
let reserved = state
979
951
.user_repo
980
-
.reserve_handle(&handle_typed, &client_ip)
952
+
.reserve_handle(&handle_typed, client_ip)
981
953
.await
982
954
.unwrap_or(false);
983
955
···
1315
1287
return Err(ApiError::InternalError(None));
1316
1288
}
1317
1289
1318
-
let hostname =
1319
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
1290
+
let hostname = pds_hostname();
1320
1291
if let Err(e) = crate::comms::comms_repo::enqueue_welcome(
1321
1292
state.user_repo.as_ref(),
1322
1293
state.infra_repo.as_ref(),
1323
1294
user_id.unwrap_or(uuid::Uuid::nil()),
1324
-
&hostname,
1295
+
hostname,
1325
1296
)
1326
1297
.await
1327
1298
{
···
1367
1338
verification_channel,
1368
1339
&verification_recipient,
1369
1340
&formatted_token,
1370
-
&hostname,
1341
+
hostname,
1371
1342
)
1372
1343
.await
1373
1344
{
+9
crates/tranquil-pds/src/state.rs
+9
crates/tranquil-pds/src/state.rs
···
1
1
use crate::appview::DidResolver;
2
+
use crate::auth::webauthn::WebAuthnConfig;
2
3
use crate::cache::{Cache, DistributedRateLimiter, create_cache};
3
4
use crate::circuit_breaker::CircuitBreakers;
4
5
use crate::config::AuthConfig;
···
7
8
use crate::sso::{SsoConfig, SsoManager};
8
9
use crate::storage::{BackupStorage, BlobStorage, create_backup_storage, create_blob_storage};
9
10
use crate::sync::firehose::SequencedEvent;
11
+
use crate::util::pds_hostname;
10
12
use sqlx::PgPool;
11
13
use std::error::Error;
12
14
use std::sync::Arc;
···
41
43
pub did_resolver: Arc<DidResolver>,
42
44
pub sso_repo: Arc<dyn SsoRepository>,
43
45
pub sso_manager: SsoManager,
46
+
pub webauthn_config: Arc<WebAuthnConfig>,
44
47
}
45
48
49
+
#[derive(Debug, Clone, Copy)]
46
50
pub enum RateLimitKind {
47
51
Login,
48
52
AccountCreation,
···
180
184
let did_resolver = Arc::new(DidResolver::new());
181
185
let sso_config = SsoConfig::init();
182
186
let sso_manager = SsoManager::from_config(sso_config);
187
+
let webauthn_config = Arc::new(
188
+
WebAuthnConfig::new(pds_hostname())
189
+
.expect("Failed to create WebAuthn config at startup"),
190
+
);
183
191
184
192
Self {
185
193
user_repo: repos.user.clone(),
···
204
212
distributed_rate_limiter,
205
213
did_resolver,
206
214
sso_manager,
215
+
webauthn_config,
207
216
}
208
217
}
209
218
+2
-2
crates/tranquil-pds/src/sync/deprecated.rs
+2
-2
crates/tranquil-pds/src/sync/deprecated.rs
···
20
20
21
21
async fn check_admin_or_self(state: &AppState, headers: &HeaderMap, did: &str) -> bool {
22
22
let extracted = match crate::auth::extract_auth_token_from_header(
23
-
headers.get("Authorization").and_then(|h| h.to_str().ok()),
23
+
crate::util::get_header_str(headers, "Authorization"),
24
24
) {
25
25
Some(t) => t,
26
26
None => return false,
27
27
};
28
-
let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
28
+
let dpop_proof = crate::util::get_header_str(headers, "DPoP");
29
29
let http_uri = "/";
30
30
match crate::auth::validate_token_with_dpop(
31
31
state.user_repo.as_ref(),
+21
-4
crates/tranquil-pds/src/util.rs
+21
-4
crates/tranquil-pds/src/util.rs
···
4
4
use rand::Rng;
5
5
use serde_json::Value as JsonValue;
6
6
use std::collections::BTreeMap;
7
+
use std::net::SocketAddr;
7
8
use std::str::FromStr;
8
9
use std::sync::OnceLock;
9
10
···
11
12
const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024;
12
13
13
14
static MAX_BLOB_SIZE: OnceLock<usize> = OnceLock::new();
15
+
static PDS_HOSTNAME: OnceLock<String> = OnceLock::new();
16
+
static PDS_HOSTNAME_WITHOUT_PORT: OnceLock<String> = OnceLock::new();
14
17
15
18
pub fn get_max_blob_size() -> usize {
16
19
*MAX_BLOB_SIZE.get_or_init(|| {
···
69
72
.unwrap_or_default()
70
73
}
71
74
72
-
pub fn extract_client_ip(headers: &HeaderMap) -> String {
75
+
pub fn get_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
76
+
headers.get(name).and_then(|h| h.to_str().ok())
77
+
}
78
+
79
+
pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String {
73
80
if let Some(forwarded) = headers.get("x-forwarded-for")
74
81
&& let Ok(value) = forwarded.to_str()
75
82
&& let Some(first_ip) = value.split(',').next()
···
81
88
{
82
89
return value.trim().to_string();
83
90
}
84
-
"unknown".to_string()
91
+
addr.map(|a| a.ip().to_string())
92
+
.unwrap_or_else(|| "unknown".to_string())
93
+
}
94
+
95
+
pub fn pds_hostname() -> &'static str {
96
+
PDS_HOSTNAME.get_or_init(|| {
97
+
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
98
+
})
85
99
}
86
100
87
-
pub fn pds_hostname() -> String {
88
-
std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
101
+
pub fn pds_hostname_without_port() -> &'static str {
102
+
PDS_HOSTNAME_WITHOUT_PORT.get_or_init(|| {
103
+
let hostname = pds_hostname();
104
+
hostname.split(':').next().unwrap_or(hostname).to_string()
105
+
})
89
106
}
90
107
91
108
pub fn pds_public_url() -> String {
History
3 rounds
0 comments
expand 0 comments
pull request successfully merged