This repository has no description.
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9
10pub mod extractor;
11pub mod scope_check;
12pub mod service;
13pub mod token;
14pub mod totp;
15pub mod verification_token;
16pub mod verify;
17pub mod webauthn;
18
19pub use extractor::{
20 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
21 extract_auth_token_from_header, extract_bearer_token_from_header,
22};
23pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
24pub use token::{
25 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
26 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
27 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
28 create_service_token,
29};
30pub use verify::{
31 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
32};
33
/// Lifetime, in seconds, of a decrypted account key cached under `auth:key:{did}` (5 minutes).
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime, in seconds, of a positive session lookup cached under `auth:session:{did}:{jti}`.
const SESSION_CACHE_TTL_SECS: u64 = 60;
36
/// Reasons a token can be rejected by the validators in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account has `deactivated_at` set and the caller did not allow deactivated accounts.
    AccountDeactivated,
    /// The account has a `takedown_ref` set and the caller did not allow it.
    AccountTakedown,
    /// The account's stored key could not be decrypted.
    KeyDecryptionFailed,
    /// No validation path accepted the token.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    /// Writes the bare variant name, e.g. `AccountTakedown`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            TokenValidationError::AccountDeactivated => "AccountDeactivated",
            TokenValidationError::AccountTakedown => "AccountTakedown",
            TokenValidationError::KeyDecryptionFailed => "KeyDecryptionFailed",
            TokenValidationError::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(name)
    }
}
55
/// An account whose token was accepted by one of the validators in this module.
///
/// Deliberately has no `Debug` derive here; note that `key_bytes` holds
/// decrypted key material.
pub struct AuthenticatedUser {
    /// DID of the authenticated account.
    pub did: String,
    /// Key bytes decrypted with `crate::config::decrypt_key`, when a key was
    /// found and decryption succeeded.
    pub key_bytes: Option<Vec<u8>>,
    /// `true` when the token was accepted through an OAuth path.
    pub is_oauth: bool,
    /// Copied from `users.is_admin`.
    pub is_admin: bool,
    /// OAuth scope string; `None` for session-JWT logins.
    pub scope: Option<String>,
}
63
64impl AuthenticatedUser {
65 pub fn permissions(&self) -> ScopePermissions {
66 if !self.is_oauth {
67 return ScopePermissions::from_scope_string(Some("atproto"));
68 }
69 ScopePermissions::from_scope_string(self.scope.as_deref())
70 }
71}
72
73pub async fn validate_bearer_token(
74 db: &PgPool,
75 token: &str,
76) -> Result<AuthenticatedUser, TokenValidationError> {
77 validate_bearer_token_with_options_internal(db, None, token, false, false).await
78}
79
80pub async fn validate_bearer_token_allow_deactivated(
81 db: &PgPool,
82 token: &str,
83) -> Result<AuthenticatedUser, TokenValidationError> {
84 validate_bearer_token_with_options_internal(db, None, token, true, false).await
85}
86
87pub async fn validate_bearer_token_cached(
88 db: &PgPool,
89 cache: &Arc<dyn Cache>,
90 token: &str,
91) -> Result<AuthenticatedUser, TokenValidationError> {
92 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
93}
94
95pub async fn validate_bearer_token_cached_allow_deactivated(
96 db: &PgPool,
97 cache: &Arc<dyn Cache>,
98 token: &str,
99) -> Result<AuthenticatedUser, TokenValidationError> {
100 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
101}
102
103pub async fn validate_bearer_token_for_service_auth(
104 db: &PgPool,
105 token: &str,
106) -> Result<AuthenticatedUser, TokenValidationError> {
107 validate_bearer_token_with_options_internal(db, None, token, true, true).await
108}
109
110async fn validate_bearer_token_with_options_internal(
111 db: &PgPool,
112 cache: Option<&Arc<dyn Cache>>,
113 token: &str,
114 allow_deactivated: bool,
115 allow_takendown: bool,
116) -> Result<AuthenticatedUser, TokenValidationError> {
117 let did_from_token = get_did_from_token(token).ok();
118
119 if let Some(ref did) = did_from_token {
120 let key_cache_key = format!("auth:key:{}", did);
121 let mut cached_key: Option<Vec<u8>> = None;
122
123 if let Some(c) = cache {
124 cached_key = c.get_bytes(&key_cache_key).await;
125 if cached_key.is_some() {
126 crate::metrics::record_auth_cache_hit("key");
127 } else {
128 crate::metrics::record_auth_cache_miss("key");
129 }
130 }
131
132 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
133 {
134 let user_status = sqlx::query!(
135 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
136 did
137 )
138 .fetch_optional(db)
139 .await
140 .ok()
141 .flatten();
142
143 match user_status {
144 Some(status) => (
145 Some(key),
146 status.deactivated_at,
147 status.takedown_ref,
148 status.is_admin,
149 ),
150 None => (None, None, None, false),
151 }
152 } else if let Some(user) = sqlx::query!(
153 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
154 FROM users u
155 JOIN user_keys k ON u.id = k.user_id
156 WHERE u.did = $1",
157 did
158 )
159 .fetch_optional(db)
160 .await
161 .ok()
162 .flatten()
163 {
164 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
165 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
166
167 if let Some(c) = cache {
168 let _ = c
169 .set_bytes(
170 &key_cache_key,
171 &key,
172 Duration::from_secs(KEY_CACHE_TTL_SECS),
173 )
174 .await;
175 }
176
177 (
178 Some(key),
179 user.deactivated_at,
180 user.takedown_ref,
181 user.is_admin,
182 )
183 } else {
184 (None, None, None, false)
185 };
186
187 if let Some(decrypted_key) = decrypted_key {
188 if !allow_deactivated && deactivated_at.is_some() {
189 return Err(TokenValidationError::AccountDeactivated);
190 }
191
192 if !allow_takendown && takedown_ref.is_some() {
193 return Err(TokenValidationError::AccountTakedown);
194 }
195
196 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
197 let jti = &token_data.claims.jti;
198 let session_cache_key = format!("auth:session:{}:{}", did, jti);
199 let mut session_valid = false;
200
201 if let Some(c) = cache {
202 if let Some(cached_value) = c.get(&session_cache_key).await {
203 session_valid = cached_value == "1";
204 crate::metrics::record_auth_cache_hit("session");
205 } else {
206 crate::metrics::record_auth_cache_miss("session");
207 }
208 }
209
210 if !session_valid {
211 let session_exists = sqlx::query_scalar!(
212 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
213 did,
214 jti
215 )
216 .fetch_optional(db)
217 .await
218 .ok()
219 .flatten();
220
221 session_valid = session_exists.is_some();
222
223 if session_valid && let Some(c) = cache {
224 let _ = c
225 .set(
226 &session_cache_key,
227 "1",
228 Duration::from_secs(SESSION_CACHE_TTL_SECS),
229 )
230 .await;
231 }
232 }
233
234 if session_valid {
235 return Ok(AuthenticatedUser {
236 did: did.clone(),
237 key_bytes: Some(decrypted_key),
238 is_oauth: false,
239 is_admin,
240 scope: None,
241 });
242 }
243 }
244 }
245 }
246
247 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
248 && let Some(oauth_token) = sqlx::query!(
249 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
250 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
251 FROM oauth_token t
252 JOIN users u ON t.did = u.did
253 LEFT JOIN user_keys k ON u.id = k.user_id
254 WHERE t.token_id = $1"#,
255 oauth_info.token_id
256 )
257 .fetch_optional(db)
258 .await
259 .ok()
260 .flatten()
261 {
262 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
263 return Err(TokenValidationError::AccountDeactivated);
264 }
265
266 if oauth_token.takedown_ref.is_some() {
267 return Err(TokenValidationError::AccountTakedown);
268 }
269
270 let now = chrono::Utc::now();
271 if oauth_token.expires_at > now {
272 let key_bytes = if let (Some(kb), Some(ev)) =
273 (&oauth_token.key_bytes, oauth_token.encryption_version)
274 {
275 crate::config::decrypt_key(kb, Some(ev)).ok()
276 } else {
277 None
278 };
279 return Ok(AuthenticatedUser {
280 did: oauth_token.did,
281 key_bytes,
282 is_oauth: true,
283 is_admin: oauth_token.is_admin,
284 scope: oauth_info.scope,
285 });
286 }
287 }
288
289 Err(TokenValidationError::AuthenticationFailed)
290}
291
292pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
293 let key_cache_key = format!("auth:key:{}", did);
294 let _ = cache.delete(&key_cache_key).await;
295}
296
297pub async fn validate_token_with_dpop(
298 db: &PgPool,
299 token: &str,
300 is_dpop_token: bool,
301 dpop_proof: Option<&str>,
302 http_method: &str,
303 http_uri: &str,
304 allow_deactivated: bool,
305) -> Result<AuthenticatedUser, TokenValidationError> {
306 if !is_dpop_token {
307 if allow_deactivated {
308 return validate_bearer_token_allow_deactivated(db, token).await;
309 } else {
310 return validate_bearer_token(db, token).await;
311 }
312 }
313 match crate::oauth::verify::verify_oauth_access_token(
314 db,
315 token,
316 dpop_proof,
317 http_method,
318 http_uri,
319 )
320 .await
321 {
322 Ok(result) => {
323 let user_info = sqlx::query!(
324 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
325 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
326 FROM users u
327 LEFT JOIN user_keys k ON u.id = k.user_id
328 WHERE u.did = $1"#,
329 result.did
330 )
331 .fetch_optional(db)
332 .await
333 .ok()
334 .flatten();
335 let Some(user_info) = user_info else {
336 return Err(TokenValidationError::AuthenticationFailed);
337 };
338 if !allow_deactivated && user_info.deactivated_at.is_some() {
339 return Err(TokenValidationError::AccountDeactivated);
340 }
341 if user_info.takedown_ref.is_some() {
342 return Err(TokenValidationError::AccountTakedown);
343 }
344 let key_bytes = if let (Some(kb), Some(ev)) =
345 (&user_info.key_bytes, user_info.encryption_version)
346 {
347 crate::config::decrypt_key(kb, Some(ev)).ok()
348 } else {
349 None
350 };
351 Ok(AuthenticatedUser {
352 did: result.did,
353 key_bytes,
354 is_oauth: true,
355 is_admin: user_info.is_admin,
356 scope: result.scope,
357 })
358 }
359 Err(_) => Err(TokenValidationError::AuthenticationFailed),
360 }
361}
362
363#[derive(Debug, Serialize, Deserialize)]
364pub struct Claims {
365 pub iss: String,
366 pub sub: String,
367 pub aud: String,
368 pub exp: usize,
369 pub iat: usize,
370 #[serde(skip_serializing_if = "Option::is_none")]
371 pub scope: Option<String>,
372 #[serde(skip_serializing_if = "Option::is_none")]
373 pub lxm: Option<String>,
374 pub jti: String,
375}
376
377#[derive(Debug, Serialize, Deserialize)]
378pub struct Header {
379 pub alg: String,
380 pub typ: String,
381}
382
383#[derive(Debug, Serialize, Deserialize)]
384pub struct UnsafeClaims {
385 pub iss: String,
386 pub sub: Option<String>,
387}
388
/// Container for a decoded claim set of type `T`.
///
/// NOTE(review): presumably returned by the token verification helpers; confirm.
pub struct TokenData<T> {
    /// The decoded claims.
    pub claims: T,
}