this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::cache::Cache;
8use crate::oauth::scopes::ScopePermissions;
9
10pub mod extractor;
11pub mod scope_check;
12pub mod service;
13pub mod token;
14pub mod verify;
15
16pub use extractor::{
17 AuthError, BearerAuth, BearerAuthAdmin, BearerAuthAllowDeactivated, ExtractedToken,
18 extract_auth_token_from_header, extract_bearer_token_from_header,
19};
20pub use service::{ServiceTokenClaims, ServiceTokenVerifier, is_service_token};
21pub use token::{
22 SCOPE_ACCESS, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED, SCOPE_REFRESH, TOKEN_TYPE_ACCESS,
23 TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE, TokenWithMetadata, create_access_token,
24 create_access_token_with_metadata, create_refresh_token, create_refresh_token_with_metadata,
25 create_service_token,
26};
27pub use verify::{
28 get_did_from_token, get_jti_from_token, verify_access_token, verify_refresh_token, verify_token,
29};
30
/// Lifetime, in seconds, of a decrypted user key stored in the auth cache (five minutes).
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime, in seconds, of a positive session-token lookup stored in the auth cache.
const SESSION_CACHE_TTL_SECS: u64 = 60;
33
/// Why an access token was rejected by the validators in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and the caller did not allow deactivated accounts.
    AccountDeactivated,
    /// The account carries a takedown reference and the caller did not allow it.
    AccountTakedown,
    /// The account's stored signing key could not be decrypted.
    KeyDecryptionFailed,
    /// The token was not accepted by any validation path.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    /// Renders the bare variant name, e.g. `AccountTakedown`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(label)
    }
}
52
53pub struct AuthenticatedUser {
54 pub did: String,
55 pub key_bytes: Option<Vec<u8>>,
56 pub is_oauth: bool,
57 pub is_admin: bool,
58 pub scope: Option<String>,
59}
60
61impl AuthenticatedUser {
62 pub fn permissions(&self) -> ScopePermissions {
63 if !self.is_oauth {
64 return ScopePermissions::from_scope_string(Some("atproto"));
65 }
66 ScopePermissions::from_scope_string(self.scope.as_deref())
67 }
68}
69
70pub async fn validate_bearer_token(
71 db: &PgPool,
72 token: &str,
73) -> Result<AuthenticatedUser, TokenValidationError> {
74 validate_bearer_token_with_options_internal(db, None, token, false, false).await
75}
76
77pub async fn validate_bearer_token_allow_deactivated(
78 db: &PgPool,
79 token: &str,
80) -> Result<AuthenticatedUser, TokenValidationError> {
81 validate_bearer_token_with_options_internal(db, None, token, true, false).await
82}
83
84pub async fn validate_bearer_token_cached(
85 db: &PgPool,
86 cache: &Arc<dyn Cache>,
87 token: &str,
88) -> Result<AuthenticatedUser, TokenValidationError> {
89 validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await
90}
91
92pub async fn validate_bearer_token_cached_allow_deactivated(
93 db: &PgPool,
94 cache: &Arc<dyn Cache>,
95 token: &str,
96) -> Result<AuthenticatedUser, TokenValidationError> {
97 validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await
98}
99
100pub async fn validate_bearer_token_for_service_auth(
101 db: &PgPool,
102 token: &str,
103) -> Result<AuthenticatedUser, TokenValidationError> {
104 validate_bearer_token_with_options_internal(db, None, token, true, true).await
105}
106
107async fn validate_bearer_token_with_options_internal(
108 db: &PgPool,
109 cache: Option<&Arc<dyn Cache>>,
110 token: &str,
111 allow_deactivated: bool,
112 allow_takendown: bool,
113) -> Result<AuthenticatedUser, TokenValidationError> {
114 let did_from_token = get_did_from_token(token).ok();
115
116 if let Some(ref did) = did_from_token {
117 let key_cache_key = format!("auth:key:{}", did);
118 let mut cached_key: Option<Vec<u8>> = None;
119
120 if let Some(c) = cache {
121 cached_key = c.get_bytes(&key_cache_key).await;
122 if cached_key.is_some() {
123 crate::metrics::record_auth_cache_hit("key");
124 } else {
125 crate::metrics::record_auth_cache_miss("key");
126 }
127 }
128
129 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key
130 {
131 let user_status = sqlx::query!(
132 "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1",
133 did
134 )
135 .fetch_optional(db)
136 .await
137 .ok()
138 .flatten();
139
140 match user_status {
141 Some(status) => (
142 Some(key),
143 status.deactivated_at,
144 status.takedown_ref,
145 status.is_admin,
146 ),
147 None => (None, None, None, false),
148 }
149 } else if let Some(user) = sqlx::query!(
150 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin
151 FROM users u
152 JOIN user_keys k ON u.id = k.user_id
153 WHERE u.did = $1",
154 did
155 )
156 .fetch_optional(db)
157 .await
158 .ok()
159 .flatten()
160 {
161 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
162 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
163
164 if let Some(c) = cache {
165 let _ = c
166 .set_bytes(
167 &key_cache_key,
168 &key,
169 Duration::from_secs(KEY_CACHE_TTL_SECS),
170 )
171 .await;
172 }
173
174 (
175 Some(key),
176 user.deactivated_at,
177 user.takedown_ref,
178 user.is_admin,
179 )
180 } else {
181 (None, None, None, false)
182 };
183
184 if let Some(decrypted_key) = decrypted_key {
185 if !allow_deactivated && deactivated_at.is_some() {
186 return Err(TokenValidationError::AccountDeactivated);
187 }
188
189 if !allow_takendown && takedown_ref.is_some() {
190 return Err(TokenValidationError::AccountTakedown);
191 }
192
193 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
194 let jti = &token_data.claims.jti;
195 let session_cache_key = format!("auth:session:{}:{}", did, jti);
196 let mut session_valid = false;
197
198 if let Some(c) = cache {
199 if let Some(cached_value) = c.get(&session_cache_key).await {
200 session_valid = cached_value == "1";
201 crate::metrics::record_auth_cache_hit("session");
202 } else {
203 crate::metrics::record_auth_cache_miss("session");
204 }
205 }
206
207 if !session_valid {
208 let session_exists = sqlx::query_scalar!(
209 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
210 did,
211 jti
212 )
213 .fetch_optional(db)
214 .await
215 .ok()
216 .flatten();
217
218 session_valid = session_exists.is_some();
219
220 if session_valid && let Some(c) = cache {
221 let _ = c
222 .set(
223 &session_cache_key,
224 "1",
225 Duration::from_secs(SESSION_CACHE_TTL_SECS),
226 )
227 .await;
228 }
229 }
230
231 if session_valid {
232 return Ok(AuthenticatedUser {
233 did: did.clone(),
234 key_bytes: Some(decrypted_key),
235 is_oauth: false,
236 is_admin,
237 scope: None,
238 });
239 }
240 }
241 }
242 }
243
244 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token)
245 && let Some(oauth_token) = sqlx::query!(
246 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin,
247 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
248 FROM oauth_token t
249 JOIN users u ON t.did = u.did
250 LEFT JOIN user_keys k ON u.id = k.user_id
251 WHERE t.token_id = $1"#,
252 oauth_info.token_id
253 )
254 .fetch_optional(db)
255 .await
256 .ok()
257 .flatten()
258 {
259 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
260 return Err(TokenValidationError::AccountDeactivated);
261 }
262
263 if oauth_token.takedown_ref.is_some() {
264 return Err(TokenValidationError::AccountTakedown);
265 }
266
267 let now = chrono::Utc::now();
268 if oauth_token.expires_at > now {
269 let key_bytes = if let (Some(kb), Some(ev)) =
270 (&oauth_token.key_bytes, oauth_token.encryption_version)
271 {
272 crate::config::decrypt_key(kb, Some(ev)).ok()
273 } else {
274 None
275 };
276 return Ok(AuthenticatedUser {
277 did: oauth_token.did,
278 key_bytes,
279 is_oauth: true,
280 is_admin: oauth_token.is_admin,
281 scope: oauth_info.scope,
282 });
283 }
284 }
285
286 Err(TokenValidationError::AuthenticationFailed)
287}
288
289pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
290 let key_cache_key = format!("auth:key:{}", did);
291 let _ = cache.delete(&key_cache_key).await;
292}
293
294pub async fn validate_token_with_dpop(
295 db: &PgPool,
296 token: &str,
297 is_dpop_token: bool,
298 dpop_proof: Option<&str>,
299 http_method: &str,
300 http_uri: &str,
301 allow_deactivated: bool,
302) -> Result<AuthenticatedUser, TokenValidationError> {
303 if !is_dpop_token {
304 if allow_deactivated {
305 return validate_bearer_token_allow_deactivated(db, token).await;
306 } else {
307 return validate_bearer_token(db, token).await;
308 }
309 }
310 match crate::oauth::verify::verify_oauth_access_token(
311 db,
312 token,
313 dpop_proof,
314 http_method,
315 http_uri,
316 )
317 .await
318 {
319 Ok(result) => {
320 let user_info = sqlx::query!(
321 r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin,
322 k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?"
323 FROM users u
324 LEFT JOIN user_keys k ON u.id = k.user_id
325 WHERE u.did = $1"#,
326 result.did
327 )
328 .fetch_optional(db)
329 .await
330 .ok()
331 .flatten();
332 let Some(user_info) = user_info else {
333 return Err(TokenValidationError::AuthenticationFailed);
334 };
335 if !allow_deactivated && user_info.deactivated_at.is_some() {
336 return Err(TokenValidationError::AccountDeactivated);
337 }
338 if user_info.takedown_ref.is_some() {
339 return Err(TokenValidationError::AccountTakedown);
340 }
341 let key_bytes = if let (Some(kb), Some(ev)) =
342 (&user_info.key_bytes, user_info.encryption_version)
343 {
344 crate::config::decrypt_key(kb, Some(ev)).ok()
345 } else {
346 None
347 };
348 Ok(AuthenticatedUser {
349 did: result.did,
350 key_bytes,
351 is_oauth: true,
352 is_admin: user_info.is_admin,
353 scope: result.scope,
354 })
355 }
356 Err(_) => Err(TokenValidationError::AuthenticationFailed),
357 }
358}
359
/// JWT claim set (payload) for the session tokens handled by this module.
///
/// serde serializes fields in declaration order, so reordering these fields changes the
/// encoded token payload.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Issuer.
    pub iss: String,
    /// Subject.
    pub sub: String,
    /// Audience.
    pub aud: String,
    /// Expiry time; presumably Unix seconds (JWT NumericDate) — TODO confirm in token.rs.
    pub exp: usize,
    /// Issued-at time; presumably Unix seconds like `exp` — TODO confirm.
    pub iat: usize,
    /// Optional scope string; omitted from the serialized payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Optional `lxm` claim (presumably the lexicon method a service token is bound to —
    /// confirm); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<String>,
    /// Token ID; the legacy session check matches it against `session_tokens.access_jti`.
    pub jti: String,
}
373
/// JWT header fields read/written by this module.
#[derive(Debug, Serialize, Deserialize)]
pub struct Header {
    /// Signing algorithm identifier.
    pub alg: String,
    /// Token type.
    pub typ: String,
}
379
/// Minimal claim subset, presumably decoded *without* signature verification (per the name),
/// e.g. to discover the issuer/subject before choosing a key — confirm in verify.rs.
/// Values from this struct must not be trusted on their own.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnsafeClaims {
    /// Issuer.
    pub iss: String,
    /// Subject, if present.
    pub sub: Option<String>,
}
385
/// Decoded token contents returned by the verification functions.
///
/// Derives `Debug` (bounded on `T: Debug`) so verified token data can be logged and asserted
/// on like the other public types in this module.
#[derive(Debug)]
pub struct TokenData<T> {
    /// The decoded claim set.
    pub claims: T,
}