this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9
10pub mod extractor;
11pub mod scope_check;
12pub mod service;
13pub mod token;
14pub mod totp;
15pub mod verify;
16pub mod webauthn;
17
18pub use extractor::{
19 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
20 extract_auth_token_from_header, extract_bearer_token_from_header,
21};
22pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
23pub use token::{
24 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
25 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
26 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
27 create_service_token,
28};
29pub use verify::{
30 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
31};
32
/// How long (in seconds) a decrypted account key stays in the auth cache under
/// `auth:key:{did}`: five minutes.
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// How long (in seconds) a positive session lookup stays in the auth cache under
/// `auth:session:{did}:{jti}`: one minute.
const SESSION_CACHE_TTL_SECS: u64 = 60;
35
/// Reasons bearer-token validation can fail.
///
/// The `Display` form of every variant is exactly the variant's name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account has `deactivated_at` set and deactivated accounts were not allowed.
    AccountDeactivated,
    /// The account has `takedown_ref` set and taken-down accounts were not allowed.
    AccountTakedown,
    /// The stored account key could not be decrypted.
    KeyDecryptionFailed,
    /// The token was not accepted by any validation path.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(name)
    }
}
54
55pub struct AuthenticatedUser {
56 pub did: String,
57 pub key_bytes: Option<Vec<u8>>,
58 pub is_oauth: bool,
59 pub is_admin: bool,
60 pub scope: Option<String>,
61}
62
63impl AuthenticatedUser {
64 pub fn permissions(&self) -> ScopePermissions {
65 if !self.is_oauth {
66 return ScopePermissions::from_scope_string(Some("atproto"));
67 }
68 ScopePermissions::from_scope_string(self.scope.as_deref())
69 }
70}
71
72pub async fn validate_bearer_token(
73 db: &PgPool,
74 token: &str,
75) -> Result<AuthenticatedUser, TokenValidationError> {
76 validate_bearer_token_with_options_internal(db, None, token, false, false).await
77}
78
79pub async fn validate_bearer_token_allow_deactivated(
80 db: &PgPool,
81 token: &str,
82) -> Result<AuthenticatedUser, TokenValidationError> {
83 validate_bearer_token_with_options_internal(db, None, token, true, false).await
84}
85
86pub async fn validate_bearer_token_cached(
87 db: &PgPool,
88 cache: &Arc<dyn Cache>,
89 token: &str,
90) -> Result<AuthenticatedUser, TokenValidationError> {
91 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
92}
93
94pub async fn validate_bearer_token_cached_allow_deactivated(
95 db: &PgPool,
96 cache: &Arc<dyn Cache>,
97 token: &str,
98) -> Result<AuthenticatedUser, TokenValidationError> {
99 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
100}
101
102pub async fn validate_bearer_token_for_service_auth(
103 db: &PgPool,
104 token: &str,
105) -> Result<AuthenticatedUser, TokenValidationError> {
106 validate_bearer_token_with_options_internal(db, None, token, true, true).await
107}
108
109async fn validate_bearer_token_with_options_internal(
110 db: &PgPool,
111 cache: Option<&Arc<dyn Cache>>,
112 token: &str,
113 allow_deactivated: bool,
114 allow_takendown: bool,
115) -> Result<AuthenticatedUser, TokenValidationError> {
116 let did_from_token = get_did_from_token(token).ok();
117
118 if let Some(ref did) = did_from_token {
119 let key_cache_key = format!("auth:key:{}", did);
120 let mut cached_key: Option<Vec<u8>> = None;
121
122 if let Some(c) = cache {
123 cached_key = c.get_bytes(&key_cache_key).await;
124 if cached_key.is_some() {
125 crate::metrics::record_auth_cache_hit("key");
126 } else {
127 crate::metrics::record_auth_cache_miss("key");
128 }
129 }
130
131 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
132 {
133 let user_status = sqlx::query!(
134 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
135 did
136 )
137 .fetch_optional(db)
138 .await
139 .ok()
140 .flatten();
141
142 match user_status {
143 Some(status) => (
144 Some(key),
145 status.deactivated_at,
146 status.takedown_ref,
147 status.is_admin,
148 ),
149 None => (None, None, None, false),
150 }
151 } else if let Some(user) = sqlx::query!(
152 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
153 FROM users u
154 JOIN user_keys k ON u.id = k.user_id
155 WHERE u.did = $1",
156 did
157 )
158 .fetch_optional(db)
159 .await
160 .ok()
161 .flatten()
162 {
163 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
164 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
165
166 if let Some(c) = cache {
167 let _ = c
168 .set_bytes(
169 &key_cache_key,
170 &key,
171 Duration::from_secs(KEY_CACHE_TTL_SECS),
172 )
173 .await;
174 }
175
176 (
177 Some(key),
178 user.deactivated_at,
179 user.takedown_ref,
180 user.is_admin,
181 )
182 } else {
183 (None, None, None, false)
184 };
185
186 if let Some(decrypted_key) = decrypted_key {
187 if !allow_deactivated && deactivated_at.is_some() {
188 return Err(TokenValidationError::AccountDeactivated);
189 }
190
191 if !allow_takendown && takedown_ref.is_some() {
192 return Err(TokenValidationError::AccountTakedown);
193 }
194
195 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
196 let jti = &token_data.claims.jti;
197 let session_cache_key = format!("auth:session:{}:{}", did, jti);
198 let mut session_valid = false;
199
200 if let Some(c) = cache {
201 if let Some(cached_value) = c.get(&session_cache_key).await {
202 session_valid = cached_value == "1";
203 crate::metrics::record_auth_cache_hit("session");
204 } else {
205 crate::metrics::record_auth_cache_miss("session");
206 }
207 }
208
209 if !session_valid {
210 let session_exists = sqlx::query_scalar!(
211 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
212 did,
213 jti
214 )
215 .fetch_optional(db)
216 .await
217 .ok()
218 .flatten();
219
220 session_valid = session_exists.is_some();
221
222 if session_valid && let Some(c) = cache {
223 let _ = c
224 .set(
225 &session_cache_key,
226 "1",
227 Duration::from_secs(SESSION_CACHE_TTL_SECS),
228 )
229 .await;
230 }
231 }
232
233 if session_valid {
234 return Ok(AuthenticatedUser {
235 did: did.clone(),
236 key_bytes: Some(decrypted_key),
237 is_oauth: false,
238 is_admin,
239 scope: None,
240 });
241 }
242 }
243 }
244 }
245
246 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
247 && let Some(oauth_token) = sqlx::query!(
248 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
249 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
250 FROM oauth_token t
251 JOIN users u ON t.did = u.did
252 LEFT JOIN user_keys k ON u.id = k.user_id
253 WHERE t.token_id = $1"#,
254 oauth_info.token_id
255 )
256 .fetch_optional(db)
257 .await
258 .ok()
259 .flatten()
260 {
261 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
262 return Err(TokenValidationError::AccountDeactivated);
263 }
264
265 if oauth_token.takedown_ref.is_some() {
266 return Err(TokenValidationError::AccountTakedown);
267 }
268
269 let now = chrono::Utc::now();
270 if oauth_token.expires_at > now {
271 let key_bytes = if let (Some(kb), Some(ev)) =
272 (&oauth_token.key_bytes, oauth_token.encryption_version)
273 {
274 crate::config::decrypt_key(kb, Some(ev)).ok()
275 } else {
276 None
277 };
278 return Ok(AuthenticatedUser {
279 did: oauth_token.did,
280 key_bytes,
281 is_oauth: true,
282 is_admin: oauth_token.is_admin,
283 scope: oauth_info.scope,
284 });
285 }
286 }
287
288 Err(TokenValidationError::AuthenticationFailed)
289}
290
291pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
292 let key_cache_key = format!("auth:key:{}", did);
293 let _ = cache.delete(&key_cache_key).await;
294}
295
296pub async fn validate_token_with_dpop(
297 db: &PgPool,
298 token: &str,
299 is_dpop_token: bool,
300 dpop_proof: Option<&str>,
301 http_method: &str,
302 http_uri: &str,
303 allow_deactivated: bool,
304) -> Result<AuthenticatedUser, TokenValidationError> {
305 if !is_dpop_token {
306 if allow_deactivated {
307 return validate_bearer_token_allow_deactivated(db, token).await;
308 } else {
309 return validate_bearer_token(db, token).await;
310 }
311 }
312 match crate::oauth::verify::verify_oauth_access_token(
313 db,
314 token,
315 dpop_proof,
316 http_method,
317 http_uri,
318 )
319 .await
320 {
321 Ok(result) => {
322 let user_info = sqlx::query!(
323 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
324 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
325 FROM users u
326 LEFT JOIN user_keys k ON u.id = k.user_id
327 WHERE u.did = $1"#,
328 result.did
329 )
330 .fetch_optional(db)
331 .await
332 .ok()
333 .flatten();
334 let Some(user_info) = user_info else {
335 return Err(TokenValidationError::AuthenticationFailed);
336 };
337 if !allow_deactivated && user_info.deactivated_at.is_some() {
338 return Err(TokenValidationError::AccountDeactivated);
339 }
340 if user_info.takedown_ref.is_some() {
341 return Err(TokenValidationError::AccountTakedown);
342 }
343 let key_bytes = if let (Some(kb), Some(ev)) =
344 (&user_info.key_bytes, user_info.encryption_version)
345 {
346 crate::config::decrypt_key(kb, Some(ev)).ok()
347 } else {
348 None
349 };
350 Ok(AuthenticatedUser {
351 did: result.did,
352 key_bytes,
353 is_oauth: true,
354 is_admin: user_info.is_admin,
355 scope: result.scope,
356 })
357 }
358 Err(_) => Err(TokenValidationError::AuthenticationFailed),
359 }
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363pub struct Claims {
364 pub iss: String,
365 pub sub: String,
366 pub aud: String,
367 pub exp: usize,
368 pub iat: usize,
369 #[serde(skip_serializing_if = "Option::is_none")]
370 pub scope: Option<String>,
371 #[serde(skip_serializing_if = "Option::is_none")]
372 pub lxm: Option<String>,
373 pub jti: String,
374}
375
376#[derive(Debug, Serialize, Deserialize)]
377pub struct Header {
378 pub alg: String,
379 pub typ: String,
380}
381
382#[derive(Debug, Serialize, Deserialize)]
383pub struct UnsafeClaims {
384 pub iss: String,
385 pub sub: Option<String>,
386}
387
/// Container for the claims decoded from a token.
///
/// Derives `Debug` like the other token types in this module, so it can be logged or
/// inspected whenever `T: Debug`.
#[derive(Debug)]
pub struct TokenData<T> {
    /// The decoded claims.
    pub claims: T,
}