This repository has no description.
1use axum::http::HeaderMap;
2use axum::Json;
3use chrono::{Duration, Utc};
4use crate::config::AuthConfig;
5use crate::state::AppState;
6use crate::oauth::{
7 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
8 client::{ClientMetadataCache, verify_client_auth},
9 db,
10 dpop::DPoPVerifier,
11};
12use super::types::{TokenRequest, TokenResponse};
13use super::helpers::{create_access_token, verify_pkce};
14
/// Access-token lifetime in seconds (one hour), reported to clients as `expires_in`.
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 60 * 60;
/// Refresh-token lifetime in days; every refresh rotation restarts this window.
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
17
18pub async fn handle_authorization_code_grant(
19 state: AppState,
20 _headers: HeaderMap,
21 request: TokenRequest,
22 dpop_proof: Option<String>,
23) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
24 let code = request
25 .code
26 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
27 let code_verifier = request
28 .code_verifier
29 .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
30 let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
31 .await?
32 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
33 if auth_request.expires_at < Utc::now() {
34 return Err(OAuthError::InvalidGrant("Authorization code has expired".to_string()));
35 }
36 if let Some(request_client_id) = &request.client_id {
37 if request_client_id != &auth_request.client_id {
38 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
39 }
40 }
41 let did = auth_request
42 .did
43 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
44 let client_metadata_cache = ClientMetadataCache::new(3600);
45 let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?;
46 let client_auth = if let (Some(assertion), Some(assertion_type)) = (&request.client_assertion, &request.client_assertion_type) {
47 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
48 return Err(OAuthError::InvalidClient(
49 "Unsupported client_assertion_type".to_string(),
50 ));
51 }
52 ClientAuth::PrivateKeyJwt {
53 client_assertion: assertion.clone(),
54 }
55 } else if let Some(secret) = &request.client_secret {
56 ClientAuth::SecretPost {
57 client_secret: secret.clone(),
58 }
59 } else {
60 ClientAuth::None
61 };
62 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
63 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
64 if let Some(redirect_uri) = &request.redirect_uri {
65 if redirect_uri != &auth_request.parameters.redirect_uri {
66 return Err(OAuthError::InvalidGrant("redirect_uri mismatch".to_string()));
67 }
68 }
69 let dpop_jkt = if let Some(proof) = &dpop_proof {
70 let config = AuthConfig::get();
71 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
72 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
73 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
74 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
75 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
76 return Err(OAuthError::InvalidDpopProof(
77 "DPoP proof has already been used".to_string(),
78 ));
79 }
80 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt {
81 if &result.jkt != expected_jkt {
82 return Err(OAuthError::InvalidDpopProof(
83 "DPoP key binding mismatch".to_string(),
84 ));
85 }
86 }
87 Some(result.jkt)
88 } else if auth_request.parameters.dpop_jkt.is_some() {
89 return Err(OAuthError::InvalidRequest(
90 "DPoP proof required for this authorization".to_string(),
91 ));
92 } else {
93 None
94 };
95 let token_id = TokenId::generate();
96 let refresh_token = RefreshToken::generate();
97 let now = Utc::now();
98 let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
99 let token_data = TokenData {
100 did: did.clone(),
101 token_id: token_id.0.clone(),
102 created_at: now,
103 updated_at: now,
104 expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
105 client_id: auth_request.client_id.clone(),
106 client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
107 device_id: auth_request.device_id,
108 parameters: auth_request.parameters.clone(),
109 details: None,
110 code: None,
111 current_refresh_token: Some(refresh_token.0.clone()),
112 scope: auth_request.parameters.scope.clone(),
113 };
114 db::create_token(&state.db, &token_data).await?;
115 tokio::spawn({
116 let pool = state.db.clone();
117 let did_clone = did.clone();
118 async move {
119 if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
120 tracing::warn!("Failed to enforce token limit for user: {:?}", e);
121 }
122 }
123 });
124 let mut response_headers = HeaderMap::new();
125 let config = AuthConfig::get();
126 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
127 response_headers.insert(
128 "DPoP-Nonce",
129 verifier.generate_nonce().parse().unwrap(),
130 );
131 Ok((
132 response_headers,
133 Json(TokenResponse {
134 access_token,
135 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
136 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
137 refresh_token: Some(refresh_token.0),
138 scope: auth_request.parameters.scope,
139 sub: Some(did),
140 }),
141 ))
142}
143
144pub async fn handle_refresh_token_grant(
145 state: AppState,
146 _headers: HeaderMap,
147 request: TokenRequest,
148 dpop_proof: Option<String>,
149) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
150 let refresh_token_str = request
151 .refresh_token
152 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
153 if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
154 db::delete_token_family(&state.db, token_id).await?;
155 return Err(OAuthError::InvalidGrant(
156 "Refresh token reuse detected, token family revoked".to_string(),
157 ));
158 }
159 let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
160 .await?
161 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
162 if token_data.expires_at < Utc::now() {
163 db::delete_token_family(&state.db, db_id).await?;
164 return Err(OAuthError::InvalidGrant("Refresh token has expired".to_string()));
165 }
166 let dpop_jkt = if let Some(proof) = &dpop_proof {
167 let config = AuthConfig::get();
168 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
169 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
170 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
171 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
172 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
173 return Err(OAuthError::InvalidDpopProof(
174 "DPoP proof has already been used".to_string(),
175 ));
176 }
177 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt {
178 if &result.jkt != expected_jkt {
179 return Err(OAuthError::InvalidDpopProof(
180 "DPoP key binding mismatch".to_string(),
181 ));
182 }
183 }
184 Some(result.jkt)
185 } else if token_data.parameters.dpop_jkt.is_some() {
186 return Err(OAuthError::InvalidRequest(
187 "DPoP proof required".to_string(),
188 ));
189 } else {
190 None
191 };
192 let new_token_id = TokenId::generate();
193 let new_refresh_token = RefreshToken::generate();
194 let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
195 db::rotate_token(
196 &state.db,
197 db_id,
198 &new_token_id.0,
199 &new_refresh_token.0,
200 new_expires_at,
201 )
202 .await?;
203 let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
204 let mut response_headers = HeaderMap::new();
205 let config = AuthConfig::get();
206 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
207 response_headers.insert(
208 "DPoP-Nonce",
209 verifier.generate_nonce().parse().unwrap(),
210 );
211 Ok((
212 response_headers,
213 Json(TokenResponse {
214 access_token,
215 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
216 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
217 refresh_token: Some(new_refresh_token.0),
218 scope: token_data.scope,
219 sub: Some(token_data.did),
220 }),
221 ))
222}