this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9
10pub mod extractor;
11pub mod scope_check;
12pub mod service;
13pub mod token;
14pub mod totp;
15pub mod verification_token;
16pub mod verify;
17pub mod webauthn;
18
19pub use extractor::{
20 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
21 extract_auth_token_from_header, extract_bearer_token_from_header,
22};
23pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
24pub use token::{
25 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
26 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
27 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
28 create_service_token,
29};
30pub use verify::{
31 TokenVerifyError, get_did_from_token, get_jti_from_token, verify_access_token,
32 verify_access_token_typed, verify_refresh_token, verify_token,
33};
34
/// Time-to-live, in seconds, of cached decrypted account keys (`auth:key:{did}` entries):
/// five minutes.
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Time-to-live, in seconds, of cached positive session lookups
/// (`auth:session:{did}:{jti}` entries): one minute.
const SESSION_CACHE_TTL_SECS: u64 = 60;
37
/// Reasons an access token can be rejected by the validators in this module.
///
/// The `Display` form is a fixed error-name string. Note that `TokenExpired` renders as
/// `"ExpiredToken"` rather than as its variant name (presumably to match the XRPC error
/// name; confirm against callers before changing).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account has `deactivated_at` set and the caller did not allow deactivated accounts.
    AccountDeactivated,
    /// The account has a `takedown_ref` set and the caller did not allow taken-down accounts.
    AccountTakedown,
    /// The account's stored key could not be decrypted.
    KeyDecryptionFailed,
    /// No supported token type accepted the token.
    AuthenticationFailed,
    /// The token, or the OAuth grant backing it, is past its expiry.
    TokenExpired,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use TokenValidationError as E;

        let label = match *self {
            E::AccountDeactivated => "AccountDeactivated",
            E::AccountTakedown => "AccountTakedown",
            E::KeyDecryptionFailed => "KeyDecryptionFailed",
            E::AuthenticationFailed => "AuthenticationFailed",
            E::TokenExpired => "ExpiredToken",
        };
        f.write_str(label)
    }
}
58
59pub struct AuthenticatedUser {
60 pub did: String,
61 pub key_bytes: Option<Vec<u8>>,
62 pub is_oauth: bool,
63 pub is_admin: bool,
64 pub scope: Option<String>,
65}
66
67impl AuthenticatedUser {
68 pub fn permissions(&self) -> ScopePermissions {
69 if !self.is_oauth {
70 return ScopePermissions::from_scope_string(Some("atproto"));
71 }
72 ScopePermissions::from_scope_string(self.scope.as_deref())
73 }
74}
75
76pub async fn validate_bearer_token(
77 db: &PgPool,
78 token: &str,
79) -> Result<AuthenticatedUser, TokenValidationError> {
80 validate_bearer_token_with_options_internal(db, None, token, false, false).await
81}
82
83pub async fn validate_bearer_token_allow_deactivated(
84 db: &PgPool,
85 token: &str,
86) -> Result<AuthenticatedUser, TokenValidationError> {
87 validate_bearer_token_with_options_internal(db, None, token, true, false).await
88}
89
90pub async fn validate_bearer_token_cached(
91 db: &PgPool,
92 cache: &Arc<dyn Cache>,
93 token: &str,
94) -> Result<AuthenticatedUser, TokenValidationError> {
95 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
96}
97
98pub async fn validate_bearer_token_cached_allow_deactivated(
99 db: &PgPool,
100 cache: &Arc<dyn Cache>,
101 token: &str,
102) -> Result<AuthenticatedUser, TokenValidationError> {
103 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
104}
105
106pub async fn validate_bearer_token_for_service_auth(
107 db: &PgPool,
108 token: &str,
109) -> Result<AuthenticatedUser, TokenValidationError> {
110 validate_bearer_token_with_options_internal(db, None, token, true, true).await
111}
112
113async fn validate_bearer_token_with_options_internal(
114 db: &PgPool,
115 cache: Option<&Arc<dyn Cache>>,
116 token: &str,
117 allow_deactivated: bool,
118 allow_takendown: bool,
119) -> Result<AuthenticatedUser, TokenValidationError> {
120 let did_from_token = get_did_from_token(token).ok();
121
122 if let Some(ref did) = did_from_token {
123 let key_cache_key = format!("auth:key:{}", did);
124 let mut cached_key: Option<Vec<u8>> = None;
125
126 if let Some(c) = cache {
127 cached_key = c.get_bytes(&key_cache_key).await;
128 if cached_key.is_some() {
129 crate::metrics::record_auth_cache_hit("key");
130 } else {
131 crate::metrics::record_auth_cache_miss("key");
132 }
133 }
134
135 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
136 {
137 let user_status = sqlx::query!(
138 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
139 did
140 )
141 .fetch_optional(db)
142 .await
143 .ok()
144 .flatten();
145
146 match user_status {
147 Some(status) => (
148 Some(key),
149 status.deactivated_at,
150 status.takedown_ref,
151 status.is_admin,
152 ),
153 None => (None, None, None, false),
154 }
155 } else if let Some(user) = sqlx::query!(
156 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
157 FROM users u
158 JOIN user_keys k ON u.id = k.user_id
159 WHERE u.did = $1",
160 did
161 )
162 .fetch_optional(db)
163 .await
164 .ok()
165 .flatten()
166 {
167 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
168 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
169
170 if let Some(c) = cache {
171 let _ = c
172 .set_bytes(
173 &key_cache_key,
174 &key,
175 Duration::from_secs(KEY_CACHE_TTL_SECS),
176 )
177 .await;
178 }
179
180 (
181 Some(key),
182 user.deactivated_at,
183 user.takedown_ref,
184 user.is_admin,
185 )
186 } else {
187 (None, None, None, false)
188 };
189
190 if let Some(decrypted_key) = decrypted_key {
191 if !allow_deactivated && deactivated_at.is_some() {
192 return Err(TokenValidationError::AccountDeactivated);
193 }
194
195 if !allow_takendown && takedown_ref.is_some() {
196 return Err(TokenValidationError::AccountTakedown);
197 }
198
199 match verify_access_token_typed(token, &decrypted_key) {
200 Ok(token_data) => {
201 let jti = &token_data.claims.jti;
202 let session_cache_key = format!("auth:session:{}:{}", did, jti);
203 let mut session_valid = false;
204
205 if let Some(c) = cache {
206 if let Some(cached_value) = c.get(&session_cache_key).await {
207 session_valid = cached_value == "1";
208 crate::metrics::record_auth_cache_hit("session");
209 } else {
210 crate::metrics::record_auth_cache_miss("session");
211 }
212 }
213
214 if !session_valid {
215 let session_exists = sqlx::query_scalar!(
216 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
217 did,
218 jti
219 )
220 .fetch_optional(db)
221 .await
222 .ok()
223 .flatten();
224
225 session_valid = session_exists.is_some();
226
227 if session_valid && let Some(c) = cache {
228 let _ = c
229 .set(
230 &session_cache_key,
231 "1",
232 Duration::from_secs(SESSION_CACHE_TTL_SECS),
233 )
234 .await;
235 }
236 }
237
238 if session_valid {
239 return Ok(AuthenticatedUser {
240 did: did.clone(),
241 key_bytes: Some(decrypted_key),
242 is_oauth: false,
243 is_admin,
244 scope: None,
245 });
246 }
247 }
248 Err(verify::TokenVerifyError::Expired) => {
249 return Err(TokenValidationError::TokenExpired);
250 }
251 Err(verify::TokenVerifyError::Invalid) => {}
252 }
253 }
254 }
255
256 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
257 && let Some(oauth_token) = sqlx::query!(
258 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
259 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
260 FROM oauth_token t
261 JOIN users u ON t.did = u.did
262 LEFT JOIN user_keys k ON u.id = k.user_id
263 WHERE t.token_id = $1"#,
264 oauth_info.token_id
265 )
266 .fetch_optional(db)
267 .await
268 .ok()
269 .flatten()
270 {
271 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
272 return Err(TokenValidationError::AccountDeactivated);
273 }
274
275 if oauth_token.takedown_ref.is_some() {
276 return Err(TokenValidationError::AccountTakedown);
277 }
278
279 let now = chrono::Utc::now();
280 if oauth_token.expires_at > now {
281 let key_bytes = if let (Some(kb), Some(ev)) =
282 (&oauth_token.key_bytes, oauth_token.encryption_version)
283 {
284 crate::config::decrypt_key(kb, Some(ev)).ok()
285 } else {
286 None
287 };
288 return Ok(AuthenticatedUser {
289 did: oauth_token.did,
290 key_bytes,
291 is_oauth: true,
292 is_admin: oauth_token.is_admin,
293 scope: oauth_info.scope,
294 });
295 } else {
296 return Err(TokenValidationError::TokenExpired);
297 }
298 }
299
300 Err(TokenValidationError::AuthenticationFailed)
301}
302
303pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
304 let key_cache_key = format!("auth:key:{}", did);
305 let _ = cache.delete(&key_cache_key).await;
306}
307
308pub async fn validate_token_with_dpop(
309 db: &PgPool,
310 token: &str,
311 is_dpop_token: bool,
312 dpop_proof: Option<&str>,
313 http_method: &str,
314 http_uri: &str,
315 allow_deactivated: bool,
316) -> Result<AuthenticatedUser, TokenValidationError> {
317 if !is_dpop_token {
318 if allow_deactivated {
319 return validate_bearer_token_allow_deactivated(db, token).await;
320 } else {
321 return validate_bearer_token(db, token).await;
322 }
323 }
324 match crate::oauth::verify::verify_oauth_access_token(
325 db,
326 token,
327 dpop_proof,
328 http_method,
329 http_uri,
330 )
331 .await
332 {
333 Ok(result) => {
334 let user_info = sqlx::query!(
335 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
336 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
337 FROM users u
338 LEFT JOIN user_keys k ON u.id = k.user_id
339 WHERE u.did = $1"#,
340 result.did
341 )
342 .fetch_optional(db)
343 .await
344 .ok()
345 .flatten();
346 let Some(user_info) = user_info else {
347 return Err(TokenValidationError::AuthenticationFailed);
348 };
349 if !allow_deactivated && user_info.deactivated_at.is_some() {
350 return Err(TokenValidationError::AccountDeactivated);
351 }
352 if user_info.takedown_ref.is_some() {
353 return Err(TokenValidationError::AccountTakedown);
354 }
355 let key_bytes = if let (Some(kb), Some(ev)) =
356 (&user_info.key_bytes, user_info.encryption_version)
357 {
358 crate::config::decrypt_key(kb, Some(ev)).ok()
359 } else {
360 None
361 };
362 Ok(AuthenticatedUser {
363 did: result.did,
364 key_bytes,
365 is_oauth: true,
366 is_admin: user_info.is_admin,
367 scope: result.scope,
368 })
369 }
370 Err(_) => Err(TokenValidationError::AuthenticationFailed),
371 }
372}
373
374#[derive(Debug, Serialize, Deserialize)]
375pub struct Claims {
376 pub iss: String,
377 pub sub: String,
378 pub aud: String,
379 pub exp: usize,
380 pub iat: usize,
381 #[serde(skip_serializing_if = "Option::is_none")]
382 pub scope: Option<String>,
383 #[serde(skip_serializing_if = "Option::is_none")]
384 pub lxm: Option<String>,
385 pub jti: String,
386}
387
388#[derive(Debug, Serialize, Deserialize)]
389pub struct Header {
390 pub alg: String,
391 pub typ: String,
392}
393
394#[derive(Debug, Serialize, Deserialize)]
395pub struct UnsafeClaims {
396 pub iss: String,
397 pub sub: Option<String>,
398}
399
/// Wrapper around a decoded claim set of type `C`.
pub struct TokenData<C> {
    /// The decoded claims.
    pub claims: C,
}