this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9
10pub mod extractor;
11pub mod scope_check;
12pub mod service;
13pub mod token;
14pub mod totp;
15pub mod verification_token;
16pub mod verify;
17pub mod webauthn;
18
19pub use extractor::{
20 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
21 extract_auth_token_from_header, extract_bearer_token_from_header,
22};
23pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
24pub use token::{
25 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
26 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
27 create_access_token_with_delegation, create_access_token_with_metadata,
28 create_access_token_with_scope_metadata, create_refresh_token,
29 create_refresh_token_with_metadata, create_service_token,
30};
31pub use verify::{
32 TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token,
33 verify_access_token_typed, verify_refresh_token, verify_token,
34};
35
/// Lifetime, in seconds, of a decrypted per-account signing key stored in the auth cache
/// under `auth:key:{did}` (five minutes).
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime, in seconds, of a cached positive session lookup (`auth:session:{did}:{jti}`).
/// While such an entry exists the `session_tokens` table is not consulted.
const SESSION_CACHE_TTL_SECS: u64 = 60;
38
/// Why bearer/OAuth/DPoP token validation was rejected.
///
/// The `Display` form is the bare error name, except that `TokenExpired` renders as
/// `"ExpiredToken"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account has `deactivated_at` set and deactivated accounts were not allowed.
    AccountDeactivated,
    /// The account has a `takedown_ref` and taken-down accounts were not allowed.
    AccountTakedown,
    /// The stored signing key for the account could not be decrypted.
    KeyDecryptionFailed,
    /// Generic failure: no validation path accepted the token.
    AuthenticationFailed,
    /// The token, or the session/OAuth record backing it, has expired.
    TokenExpired,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
            Self::TokenExpired => "ExpiredToken",
        };
        f.write_str(label)
    }
}
59
60pub struct AuthenticatedUser {
61 pub did: String,
62 pub key_bytes: Option<Vec<u8>>,
63 pub is_oauth: bool,
64 pub is_admin: bool,
65 pub is_takendown: bool,
66 pub scope: Option<String>,
67 pub controller_did: Option<String>,
68}
69
70impl AuthenticatedUser {
71 pub fn permissions(&self) -> ScopePermissions {
72 if let Some(ref scope) = self.scope
73 && scope != SCOPE_ACCESS
74 {
75 return ScopePermissions::from_scope_string(Some(scope));
76 }
77 if !self.is_oauth {
78 return ScopePermissions::from_scope_string(Some("atproto"));
79 }
80 ScopePermissions::from_scope_string(self.scope.as_deref())
81 }
82}
83
84pub async fn validate_bearer_token(
85 db: &PgPool,
86 token: &str,
87) -> Result<AuthenticatedUser, TokenValidationError> {
88 validate_bearer_token_with_options_internal(db, None, token, false, false).await
89}
90
91pub async fn validate_bearer_token_allow_deactivated(
92 db: &PgPool,
93 token: &str,
94) -> Result<AuthenticatedUser, TokenValidationError> {
95 validate_bearer_token_with_options_internal(db, None, token, true, false).await
96}
97
98pub async fn validate_bearer_token_cached(
99 db: &PgPool,
100 cache: &Arc<dyn Cache>,
101 token: &str,
102) -> Result<AuthenticatedUser, TokenValidationError> {
103 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
104}
105
106pub async fn validate_bearer_token_cached_allow_deactivated(
107 db: &PgPool,
108 cache: &Arc<dyn Cache>,
109 token: &str,
110) -> Result<AuthenticatedUser, TokenValidationError> {
111 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
112}
113
114pub async fn validate_bearer_token_for_service_auth(
115 db: &PgPool,
116 token: &str,
117) -> Result<AuthenticatedUser, TokenValidationError> {
118 validate_bearer_token_with_options_internal(db, None, token, true, true).await
119}
120
121pub async fn validate_bearer_token_allow_takendown(
122 db: &PgPool,
123 token: &str,
124) -> Result<AuthenticatedUser, TokenValidationError> {
125 validate_bearer_token_with_options_internal(db, None, token, false, true).await
126}
127
128async fn validate_bearer_token_with_options_internal(
129 db: &PgPool,
130 cache: Option<&Arc<dyn Cache>>,
131 token: &str,
132 allow_deactivated: bool,
133 allow_takendown: bool,
134) -> Result<AuthenticatedUser, TokenValidationError> {
135 let did_from_token = get_did_from_token(token).ok();
136
137 if let Some(ref did) = did_from_token {
138 let key_cache_key = format!("auth:key:{}", did);
139 let mut cached_key: Option<Vec<u8>> = None;
140
141 if let Some(c) = cache {
142 cached_key = c.get_bytes(&key_cache_key).await;
143 if cached_key.is_some() {
144 crate::metrics::record_auth_cache_hit("key");
145 } else {
146 crate::metrics::record_auth_cache_miss("key");
147 }
148 }
149
150 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
151 {
152 let user_status = sqlx::query!(
153 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
154 did
155 )
156 .fetch_optional(db)
157 .await
158 .ok()
159 .flatten();
160
161 match user_status {
162 Some(status) => (
163 Some(key),
164 status.deactivated_at,
165 status.takedown_ref,
166 status.is_admin,
167 ),
168 None => (None, None, None, false),
169 }
170 } else if let Some(user) = sqlx::query!(
171 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
172 FROM users u
173 JOIN user_keys k ON u.id = k.user_id
174 WHERE u.did = $1",
175 did
176 )
177 .fetch_optional(db)
178 .await
179 .ok()
180 .flatten()
181 {
182 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
183 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
184
185 if let Some(c) = cache {
186 let _ = c
187 .set_bytes(
188 &key_cache_key,
189 &key,
190 Duration::from_secs(KEY_CACHE_TTL_SECS),
191 )
192 .await;
193 }
194
195 (
196 Some(key),
197 user.deactivated_at,
198 user.takedown_ref,
199 user.is_admin,
200 )
201 } else {
202 (None, None, None, false)
203 };
204
205 if let Some(decrypted_key) = decrypted_key {
206 if !allow_deactivated && deactivated_at.is_some() {
207 return Err(TokenValidationError::AccountDeactivated);
208 }
209
210 if !allow_takendown && takedown_ref.is_some() {
211 return Err(TokenValidationError::AccountTakedown);
212 }
213
214 match verify_access_token_typed(token, &decrypted_key) {
215 Ok(token_data) => {
216 let jti = &token_data.claims.jti;
217 let session_cache_key = format!("auth:session:{}:{}", did, jti);
218 let mut session_valid = false;
219
220 if let Some(c) = cache {
221 if let Some(cached_value) = c.get(&session_cache_key).await {
222 session_valid = cached_value == "1";
223 crate::metrics::record_auth_cache_hit("session");
224 } else {
225 crate::metrics::record_auth_cache_miss("session");
226 }
227 }
228
229 if !session_valid {
230 let session_row = sqlx::query!(
231 "SELECT access_expires_at FROM session_tokens WHERE did = $1 AND access_jti = $2",
232 did,
233 jti
234 )
235 .fetch_optional(db)
236 .await
237 .ok()
238 .flatten();
239
240 if let Some(row) = session_row {
241 if row.access_expires_at > chrono::Utc::now() {
242 session_valid = true;
243 if let Some(c) = cache {
244 let _ = c
245 .set(
246 &session_cache_key,
247 "1",
248 Duration::from_secs(SESSION_CACHE_TTL_SECS),
249 )
250 .await;
251 }
252 } else {
253 return Err(TokenValidationError::TokenExpired);
254 }
255 }
256 }
257
258 if session_valid {
259 let controller_did = token_data.claims.act.as_ref().map(|a| a.sub.clone());
260 return Ok(AuthenticatedUser {
261 did: did.clone(),
262 key_bytes: Some(decrypted_key),
263 is_oauth: false,
264 is_admin,
265 is_takendown: takedown_ref.is_some(),
266 scope: token_data.claims.scope.clone(),
267 controller_did,
268 });
269 }
270 }
271 Err(verify::TokenVerifyError::Expired) => {
272 return Err(TokenValidationError::TokenExpired);
273 }
274 Err(verify::TokenVerifyError::Invalid) => {}
275 }
276 }
277 }
278
279 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
280 && let Some(oauth_token) = sqlx::query!(
281 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
282 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
283 FROM oauth_token t
284 JOIN users u ON t.did = u.did
285 LEFT JOIN user_keys k ON u.id = k.user_id
286 WHERE t.token_id = $1"#,
287 oauth_info.token_id
288 )
289 .fetch_optional(db)
290 .await
291 .ok()
292 .flatten()
293 {
294 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
295 return Err(TokenValidationError::AccountDeactivated);
296 }
297
298 let is_takendown = oauth_token.takedown_ref.is_some();
299 if !allow_takendown && is_takendown {
300 return Err(TokenValidationError::AccountTakedown);
301 }
302
303 let now = chrono::Utc::now();
304 if oauth_token.expires_at > now {
305 let key_bytes = if let (Some(kb), Some(ev)) =
306 (&oauth_token.key_bytes, oauth_token.encryption_version)
307 {
308 crate::config::decrypt_key(kb, Some(ev)).ok()
309 } else {
310 None
311 };
312 return Ok(AuthenticatedUser {
313 did: oauth_token.did,
314 key_bytes,
315 is_oauth: true,
316 is_admin: oauth_token.is_admin,
317 is_takendown,
318 scope: oauth_info.scope,
319 controller_did: oauth_info.controller_did,
320 });
321 } else {
322 return Err(TokenValidationError::TokenExpired);
323 }
324 }
325
326 Err(TokenValidationError::AuthenticationFailed)
327}
328
329pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
330 let key_cache_key = format!("auth:key:{}", did);
331 let _ = cache.delete(&key_cache_key).await;
332}
333
334pub async fn validate_token_with_dpop(
335 db: &PgPool,
336 token: &str,
337 is_dpop_token: bool,
338 dpop_proof: Option<&str>,
339 http_method: &str,
340 http_uri: &str,
341 allow_deactivated: bool,
342) -> Result<AuthenticatedUser, TokenValidationError> {
343 if !is_dpop_token {
344 if allow_deactivated {
345 return validate_bearer_token_allow_deactivated(db, token).await;
346 } else {
347 return validate_bearer_token(db, token).await;
348 }
349 }
350 match crate::oauth::verify::verify_oauth_access_token(
351 db,
352 token,
353 dpop_proof,
354 http_method,
355 http_uri,
356 )
357 .await
358 {
359 Ok(result) => {
360 let user_info = sqlx::query!(
361 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
362 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
363 FROM users u
364 LEFT JOIN user_keys k ON u.id = k.user_id
365 WHERE u.did = $1"#,
366 result.did
367 )
368 .fetch_optional(db)
369 .await
370 .ok()
371 .flatten();
372 let Some(user_info) = user_info else {
373 return Err(TokenValidationError::AuthenticationFailed);
374 };
375 if !allow_deactivated && user_info.deactivated_at.is_some() {
376 return Err(TokenValidationError::AccountDeactivated);
377 }
378 let is_takendown = user_info.takedown_ref.is_some();
379 if is_takendown {
380 return Err(TokenValidationError::AccountTakedown);
381 }
382 let key_bytes = if let (Some(kb), Some(ev)) =
383 (&user_info.key_bytes, user_info.encryption_version)
384 {
385 crate::config::decrypt_key(kb, Some(ev)).ok()
386 } else {
387 None
388 };
389 Ok(AuthenticatedUser {
390 did: result.did,
391 key_bytes,
392 is_oauth: true,
393 is_admin: user_info.is_admin,
394 is_takendown,
395 scope: result.scope,
396 controller_did: None,
397 })
398 }
399 Err(crate::oauth::OAuthError::ExpiredToken(_)) => Err(TokenValidationError::TokenExpired),
400 Err(_) => Err(TokenValidationError::AuthenticationFailed),
401 }
402}
403
/// JWT `act` (actor) claim.
///
/// Its `sub` is surfaced as `AuthenticatedUser::controller_did` when a session JWT is
/// validated, i.e. the DID acting on behalf of the token's subject.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActClaim {
    /// DID of the acting (controller) party.
    pub sub: String,
}
408
/// Claim set of the JWTs issued and verified by this module.
///
/// Field declaration order is the serialized JSON order, so fields must not be reordered
/// casually. `Option` fields marked `skip_serializing_if` are omitted from the payload when
/// `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Issuer (RFC 7519 `iss`).
    pub iss: String,
    /// Subject (RFC 7519 `sub`); presumably the account DID (confirm in `token`).
    pub sub: String,
    /// Audience (RFC 7519 `aud`).
    pub aud: String,
    /// Expiry (RFC 7519 `exp`); presumably seconds since the Unix epoch (confirm in `token`).
    pub exp: usize,
    /// Issued-at (RFC 7519 `iat`); same unit as `exp`.
    pub iat: usize,
    /// Token scope. Copied into `AuthenticatedUser::scope` on validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    // NOTE(review): presumably the lexicon method a service token is bound to; confirm in
    // `service`/`token`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<String>,
    /// Token id. Session JWTs are matched against `session_tokens.access_jti` by this value.
    pub jti: String,
    /// Optional actor claim; its `sub` becomes `AuthenticatedUser::controller_did`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub act: Option<ActClaim>,
}
424
/// JWT header (the `alg` and `typ` members).
#[derive(Debug, Serialize, Deserialize)]
pub struct Header {
    /// Signing algorithm identifier.
    pub alg: String,
    /// Token type string. NOTE(review): the `TOKEN_TYPE_*` constants re-exported from
    /// `token` are presumably the expected values; confirm.
    pub typ: String,
}
430
/// Minimal claim subset (`iss`, optional `sub`).
///
/// NOTE(review): the name suggests it is decoded *before* signature verification (e.g. to
/// pick the verification key); confirm with its users. Values read from it must not be
/// trusted until the token has been verified.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnsafeClaims {
    /// Issuer (RFC 7519 `iss`).
    pub iss: String,
    /// Subject (RFC 7519 `sub`), if present.
    pub sub: Option<String>,
}
436
/// Container for a decoded token's claim set.
pub struct TokenData<C> {
    /// The decoded claims.
    pub claims: C,
}