This repository has no description.
1use serde::{Deserialize, Serialize};
2use sqlx::PgPool;
3use std::fmt;
4
5pub mod extractor;
6pub mod token;
7pub mod verify;
8
9pub use extractor::{BearerAuth, BearerAuthAllowDeactivated, AuthError, extract_bearer_token_from_header};
10pub use token::{
11 create_access_token, create_refresh_token, create_service_token,
12 create_access_token_with_metadata, create_refresh_token_with_metadata,
13 TokenWithMetadata,
14 TOKEN_TYPE_ACCESS, TOKEN_TYPE_REFRESH, TOKEN_TYPE_SERVICE,
15 SCOPE_ACCESS, SCOPE_REFRESH, SCOPE_APP_PASS, SCOPE_APP_PASS_PRIVILEGED,
16};
17pub use verify::{get_did_from_token, get_jti_from_token, verify_token, verify_access_token, verify_refresh_token};
18
/// Reasons a bearer token can be rejected by the validation functions in this module.
///
/// The `Display` form of every variant is exactly the variant's name
/// (e.g. `"AccountDeactivated"`), so it can be used as a stable error code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenValidationError {
    /// The account is deactivated and deactivated accounts were not allowed.
    AccountDeactivated,
    /// The account has a takedown reference recorded against it.
    AccountTakedown,
    /// The account's stored signing key could not be decrypted.
    KeyDecryptionFailed,
    /// The token was not accepted by any supported authentication mechanism.
    AuthenticationFailed,
}

impl fmt::Display for TokenValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::AccountDeactivated => "AccountDeactivated",
            Self::AccountTakedown => "AccountTakedown",
            Self::KeyDecryptionFailed => "KeyDecryptionFailed",
            Self::AuthenticationFailed => "AuthenticationFailed",
        };
        f.write_str(name)
    }
}

// Lets callers box this error (`Box<dyn Error>`) or propagate it through
// generic error-handling code like any other standard error type.
impl std::error::Error for TokenValidationError {}
37
/// The account identity resolved from a successfully validated bearer token.
pub struct AuthenticatedUser {
    /// DID of the account the token belongs to.
    pub did: String,
    /// The account's decrypted signing key. The validators in this module set it
    /// to `Some` for session (JWT) tokens and to `None` for OAuth tokens.
    ///
    /// NOTE(review): this is decrypted key material; take care not to add a
    /// `Debug` impl that would print it into logs.
    pub key_bytes: Option<Vec<u8>>,
    /// `true` when the token was accepted as an OAuth access token rather than
    /// a session JWT.
    pub is_oauth: bool,
}
43
44pub async fn validate_bearer_token(
45 db: &PgPool,
46 token: &str,
47) -> Result<AuthenticatedUser, TokenValidationError> {
48 validate_bearer_token_with_options(db, token, false).await
49}
50
51pub async fn validate_bearer_token_allow_deactivated(
52 db: &PgPool,
53 token: &str,
54) -> Result<AuthenticatedUser, TokenValidationError> {
55 validate_bearer_token_with_options(db, token, true).await
56}
57
58async fn validate_bearer_token_with_options(
59 db: &PgPool,
60 token: &str,
61 allow_deactivated: bool,
62) -> Result<AuthenticatedUser, TokenValidationError> {
63 let did_from_token = get_did_from_token(token).ok();
64
65 if let Some(ref did) = did_from_token {
66 if let Some(user) = sqlx::query!(
67 "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref
68 FROM users u
69 JOIN user_keys k ON u.id = k.user_id
70 WHERE u.did = $1",
71 did
72 )
73 .fetch_optional(db)
74 .await
75 .ok()
76 .flatten()
77 {
78 if !allow_deactivated && user.deactivated_at.is_some() {
79 return Err(TokenValidationError::AccountDeactivated);
80 }
81 if user.takedown_ref.is_some() {
82 return Err(TokenValidationError::AccountTakedown);
83 }
84
85 let decrypted_key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version)
86 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?;
87
88 if let Ok(token_data) = verify_access_token(token, &decrypted_key) {
89 let session_exists = sqlx::query_scalar!(
90 "SELECT 1 as one FROM session_tokens WHERE did = $1 AND access_jti = $2 AND access_expires_at > NOW()",
91 did,
92 token_data.claims.jti
93 )
94 .fetch_optional(db)
95 .await
96 .ok()
97 .flatten();
98
99 if session_exists.is_some() {
100 return Ok(AuthenticatedUser {
101 did: did.clone(),
102 key_bytes: Some(decrypted_key),
103 is_oauth: false,
104 });
105 }
106 }
107 }
108 }
109
110 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) {
111 if let Some(oauth_token) = sqlx::query!(
112 r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref
113 FROM oauth_token t
114 JOIN users u ON t.did = u.did
115 WHERE t.token_id = $1"#,
116 oauth_info.token_id
117 )
118 .fetch_optional(db)
119 .await
120 .ok()
121 .flatten()
122 {
123 if !allow_deactivated && oauth_token.deactivated_at.is_some() {
124 return Err(TokenValidationError::AccountDeactivated);
125 }
126 if oauth_token.takedown_ref.is_some() {
127 return Err(TokenValidationError::AccountTakedown);
128 }
129
130 let now = chrono::Utc::now();
131 if oauth_token.expires_at > now {
132 return Ok(AuthenticatedUser {
133 did: oauth_token.did,
134 key_bytes: None,
135 is_oauth: true,
136 });
137 }
138 }
139 }
140
141 Err(TokenValidationError::AuthenticationFailed)
142}
143
144#[derive(Debug, Serialize, Deserialize)]
145pub struct Claims {
146 pub iss: String,
147 pub sub: String,
148 pub aud: String,
149 pub exp: usize,
150 pub iat: usize,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 pub scope: Option<String>,
153 #[serde(skip_serializing_if = "Option::is_none")]
154 pub lxm: Option<String>,
155 pub jti: String,
156}
157
158#[derive(Debug, Serialize, Deserialize)]
159pub struct Header {
160 pub alg: String,
161 pub typ: String,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
165pub struct UnsafeClaims {
166 pub iss: String,
167 pub sub: Option<String>,
168}
169
/// Wrapper around the decoded claims of a token.
///
/// It is a stand-in for the `TokenData` type found in JWT libraries, kept for
/// compatibility and structure; it carries only the claims.
pub struct TokenData<T> {
    /// The decoded claims.
    pub claims: T,
}