this repo has no description
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4use std::sync::Arc;
5use std::time::Duration;
6use crate::cache::Cache;
7
8pub mod extractor;
9pub mod token;
10pub mod verify;
11
12pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
13pub use token::{
14 create_access_token, create_refresh_token, create_service_token,
15 create_access_token_with_metadata, create_refresh_token_with_metadata,
16 TokenWithMetadata,
17 TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE,
18 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
19};
20pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
21
/// Lifetime (seconds) of a cached decrypted account key in the auth cache: five minutes.
const KEY_CACHE_TTL_SECS: u64 = 5 * 60;
/// Lifetime (seconds) of a cached positive session-existence check in the auth cache: one minute.
const SESSION_CACHE_TTL_SECS: u64 = 60;
24
/// Reasons a bearer token can be rejected by the `validate_bearer_token*` functions.
///
/// `Display` renders the bare variant name (e.g. `"AccountDeactivated"`), so callers that
/// stringify the error get a stable, machine-friendly identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and the caller did not opt in to deactivated accounts.
    AccountDeactivated,
    /// The account has a takedown reference recorded.
    AccountTakedown,
    /// The account's stored key could not be decrypted.
    KeyDecryptionFailed,
    /// The token was not accepted by any validation path.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(name)
    }
}

// Implementing `Error` lets this type be boxed as `dyn Error` and flow through generic
// error-handling code; `Display` above supplies the message and there is no inner source.
impl std::error::Error for TokenValidationError {}
43
/// Identity returned when a bearer token is accepted.
pub struct AuthenticatedUser {
    /// DID of the account the token belongs to.
    pub did: String,
    /// The account's decrypted key when the token was a session access token;
    /// `None` when the token was accepted through the OAuth path.
    pub key_bytes: Option<Vec<u8>>,
    /// `true` when the token was accepted as an OAuth access token.
    pub is_oauth: bool,
}
49
50pub async fn validate_bearer_token(
51 db: &PgPool,
52 token: &str,
53) -> Result<AuthenticatedUser, TokenValidationError> {
54 validate_bearer_token_with_options_internal(db, None, token, false).await
55}
56
57pub async fn validate_bearer_token_allow_deactivated(
58 db: &PgPool,
59 token: &str,
60) -> Result<AuthenticatedUser, TokenValidationError> {
61 validate_bearer_token_with_options_internal(db, None, token, true).await
62}
63
64pub async fn validate_bearer_token_cached(
65 db: &PgPool,
66 cache: &Arc<dyn Cache>,
67 token: &str,
68) -> Result<AuthenticatedUser, TokenValidationError> {
69 validate_bearer_token_with_options_internal(db, Some(cache), token, false).await
70}
71
72pub async fn validate_bearer_token_cached_allow_deactivated(
73 db: &PgPool,
74 cache: &Arc<dyn Cache>,
75 token: &str,
76) -> Result<AuthenticatedUser, TokenValidationError> {
77 validate_bearer_token_with_options_internal(db, Some(cache), token, true).await
78}
79
80async fn validate_bearer_token_with_options_internal(
81 db: &PgPool,
82 cache: Option<&Arc<dyn Cache>>,
83 token: &str,
84 allow_deactivated: bool,
85) -> Result<AuthenticatedUser, TokenValidationError> {
86 let did_from_token = get_did_from_token(token).ok();
87
88 if let Some(ref did) = did_from_token {
89 let key_cache_key = format!("auth:key:{}", did);
90 let mut cached_key: Option<Vec<u8>> = None;
91
92 if let Some(c) = cache {
93 cached_key = c.get_bytes(&key_cache_key).await;
94 if cached_key.is_some() {
95 crate::metrics::record_auth_cache_hit("key");
96 } else {
97 crate::metrics::record_auth_cache_miss("key");
98 }
99 }
100
101 let (decrypted_key, deactivated_at, takedown_ref) = if let Some(key) = cached_key {
102 let user_status = sqlx::query!(
103 "SELECT deactivated_at, takedown_ref FROM users WHERE did = $1",
104 did
105 )
106 .fetch_optional(db)
107 .await
108 .ok()
109 .flatten();
110
111 match user_status {
112 Some(status) => (Some(key), status.deactivated_at, status.takedown_ref),
113 None => (None, None, None),
114 }
115 } else {
116 if let Some(user) = sqlx::query!(
117 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
118 FROM users u
119 JOIN user_keys k ON u.id = k.user_id
120 WHERE u.did = $1",
121 did
122 )
123 .fetch_optional(db)
124 .await
125 .ok()
126 .flatten()
127 {
128 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
129 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
130
131 if let Some(c) = cache {
132 let _ = c.set_bytes(&key_cache_key, &key, Duration::from_secs(KEY_CACHE_TTL_SECS)).await;
133 }
134
135 (Some(key), user.deactivated_at, user.takedown_ref)
136 } else {
137 (None, None, None)
138 }
139 };
140
141 if let Some(decrypted_key) = decrypted_key {
142 if !allow_deactivated && deactivated_at.is_some() {
143 return Err(TokenValidationError::AccountDeactivated);
144 }
145 if takedown_ref.is_some() {
146 return Err(TokenValidationError::AccountTakedown);
147 }
148
149 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
150 let jti = &token_data.claims.jti;
151 let session_cache_key = format!("auth:session:{}:{}", did, jti);
152 let mut session_valid = false;
153
154 if let Some(c) = cache {
155 if let Some(cached_value) = c.get(&session_cache_key).await {
156 session_valid = cached_value == "1";
157 crate::metrics::record_auth_cache_hit("session");
158 } else {
159 crate::metrics::record_auth_cache_miss("session");
160 }
161 }
162
163 if !session_valid {
164 let session_exists = sqlx::query_scalar!(
165 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
166 did,
167 jti
168 )
169 .fetch_optional(db)
170 .await
171 .ok()
172 .flatten();
173
174 session_valid = session_exists.is_some();
175
176 if session_valid {
177 if let Some(c) = cache {
178 let _ = c.set(&session_cache_key, "1", Duration::from_secs(SESSION_CACHE_TTL_SECS)).await;
179 }
180 }
181 }
182
183 if session_valid {
184 return Ok(AuthenticatedUser {
185 did: did.clone(),
186 key_bytes: Some(decrypted_key),
187 is_oauth: false,
188 });
189 }
190 }
191 }
192 }
193
194 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
195 if let Some(oauth_token) = sqlx::query!(
196 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
197 FROM oauth_token t
198 JOIN users u ON t.did = u.did
199 WHERE t.token_id = $1"#,
200 oauth_info.token_id
201 )
202 .fetch_optional(db)
203 .await
204 .ok()
205 .flatten()
206 {
207 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
208 return Err(TokenValidationError::AccountDeactivated);
209 }
210 if oauth_token.takedown_ref.is_some() {
211 return Err(TokenValidationError::AccountTakedown);
212 }
213
214 let now = chrono::Utc::now();
215 if oauth_token.expires_at > now {
216 return Ok(AuthenticatedUser {
217 did: oauth_token.did,
218 key_bytes: None,
219 is_oauth: true,
220 });
221 }
222 }
223 }
224
225 Err(TokenValidationError::AuthenticationFailed)
226}
227
228pub async fn invalidate_auth_cache(cache: &Arc<dyn Cache>, did: &str) {
229 let key_cache_key = format!("auth:key:{}", did);
230 let _ = cache.delete(&key_cache_key).await;
231}
232
233#[derive(Debug, Serialize, Deserialize)]
234pub struct Claims {
235 pub iss: String,
236 pub sub: String,
237 pub aud: String,
238 pub exp: usize,
239 pub iat: usize,
240 #[serde(skip_serializing_if = "Option::is_none")]
241 pub scope: Option<String>,
242 #[serde(skip_serializing_if = "Option::is_none")]
243 pub lxm: Option<String>,
244 pub jti: String,
245}
246
247#[derive(Debug, Serialize, Deserialize)]
248pub struct Header {
249 pub alg: String,
250 pub typ: String,
251}
252
253#[derive(Debug, Serialize, Deserialize)]
254pub struct UnsafeClaims {
255 pub iss: String,
256 pub sub: Option<String>,
257}
258
/// Decoded token payload wrapper exposing the claims as `.claims`.
///
/// Presumably modelled on a `TokenData`-style type (e.g. `jsonwebtoken::TokenData`, without
/// the header) so existing call sites keep the same shape (confirm).
pub struct TokenData<T> {
    /// The decoded claim set.
    pub claims: T,
}