this repo has no description
1use super::helpers::{create_access_token, verify_pkce};
2use super::types::{TokenRequest, TokenResponse};
3use crate::config::AuthConfig;
4use crate::oauth::{
5 ClientAuth, OAuthError, RefreshToken, TokenData, TokenId,
6 client::{ClientMetadataCache, verify_client_auth},
7 db,
8 dpop::DPoPVerifier,
9};
10use crate::state::AppState;
11use axum::Json;
12use axum::http::HeaderMap;
13use chrono::{Duration, Utc};
14
/// Access-token lifetime in seconds (one hour); reported to clients as `expires_in`.
const ACCESS_TOKEN_EXPIRY_SECONDS: i64 = 60 * 60;
/// Refresh-token record lifetime in days; applied when a token is issued and on every rotation.
const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 60;
17
18pub async fn handle_authorization_code_grant(
19 state: AppState,
20 _headers: HeaderMap,
21 request: TokenRequest,
22 dpop_proof: Option<String>,
23) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
24 let code = request
25 .code
26 .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
27 let code_verifier = request
28 .code_verifier
29 .ok_or_else(|| OAuthError::InvalidRequest("code_verifier is required".to_string()))?;
30 let auth_request = db::consume_authorization_request_by_code(&state.db, &code)
31 .await?
32 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?;
33 if auth_request.expires_at < Utc::now() {
34 return Err(OAuthError::InvalidGrant(
35 "Authorization code has expired".to_string(),
36 ));
37 }
38 if let Some(request_client_id) = &request.client_id
39 && request_client_id != &auth_request.client_id {
40 return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
41 }
42 let did = auth_request
43 .did
44 .ok_or_else(|| OAuthError::InvalidGrant("Authorization not completed".to_string()))?;
45 let client_metadata_cache = ClientMetadataCache::new(3600);
46 let client_metadata = client_metadata_cache.get(&auth_request.client_id).await?;
47 let client_auth = if let (Some(assertion), Some(assertion_type)) =
48 (&request.client_assertion, &request.client_assertion_type)
49 {
50 if assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
51 return Err(OAuthError::InvalidClient(
52 "Unsupported client_assertion_type".to_string(),
53 ));
54 }
55 ClientAuth::PrivateKeyJwt {
56 client_assertion: assertion.clone(),
57 }
58 } else if let Some(secret) = &request.client_secret {
59 ClientAuth::SecretPost {
60 client_secret: secret.clone(),
61 }
62 } else {
63 ClientAuth::None
64 };
65 verify_client_auth(&client_metadata_cache, &client_metadata, &client_auth).await?;
66 verify_pkce(&auth_request.parameters.code_challenge, &code_verifier)?;
67 if let Some(redirect_uri) = &request.redirect_uri
68 && redirect_uri != &auth_request.parameters.redirect_uri {
69 return Err(OAuthError::InvalidGrant(
70 "redirect_uri mismatch".to_string(),
71 ));
72 }
73 let dpop_jkt = if let Some(proof) = &dpop_proof {
74 let config = AuthConfig::get();
75 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
76 let pds_hostname =
77 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
78 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
79 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
80 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
81 return Err(OAuthError::InvalidDpopProof(
82 "DPoP proof has already been used".to_string(),
83 ));
84 }
85 if let Some(expected_jkt) = &auth_request.parameters.dpop_jkt
86 && &result.jkt != expected_jkt {
87 return Err(OAuthError::InvalidDpopProof(
88 "DPoP key binding mismatch".to_string(),
89 ));
90 }
91 Some(result.jkt)
92 } else if auth_request.parameters.dpop_jkt.is_some() {
93 return Err(OAuthError::InvalidRequest(
94 "DPoP proof required for this authorization".to_string(),
95 ));
96 } else {
97 None
98 };
99 let token_id = TokenId::generate();
100 let refresh_token = RefreshToken::generate();
101 let now = Utc::now();
102 let access_token = create_access_token(&token_id.0, &did, dpop_jkt.as_deref())?;
103 let token_data = TokenData {
104 did: did.clone(),
105 token_id: token_id.0.clone(),
106 created_at: now,
107 updated_at: now,
108 expires_at: now + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS),
109 client_id: auth_request.client_id.clone(),
110 client_auth: auth_request.client_auth.unwrap_or(ClientAuth::None),
111 device_id: auth_request.device_id,
112 parameters: auth_request.parameters.clone(),
113 details: None,
114 code: None,
115 current_refresh_token: Some(refresh_token.0.clone()),
116 scope: auth_request.parameters.scope.clone(),
117 };
118 db::create_token(&state.db, &token_data).await?;
119 tokio::spawn({
120 let pool = state.db.clone();
121 let did_clone = did.clone();
122 async move {
123 if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await {
124 tracing::warn!("Failed to enforce token limit for user: {:?}", e);
125 }
126 }
127 });
128 let mut response_headers = HeaderMap::new();
129 let config = AuthConfig::get();
130 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
131 response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap());
132 Ok((
133 response_headers,
134 Json(TokenResponse {
135 access_token,
136 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
137 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
138 refresh_token: Some(refresh_token.0),
139 scope: auth_request.parameters.scope,
140 sub: Some(did),
141 }),
142 ))
143}
144
145pub async fn handle_refresh_token_grant(
146 state: AppState,
147 _headers: HeaderMap,
148 request: TokenRequest,
149 dpop_proof: Option<String>,
150) -> Result<(HeaderMap, Json<TokenResponse>), OAuthError> {
151 let refresh_token_str = request
152 .refresh_token
153 .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
154 if let Some(token_id) = db::check_refresh_token_used(&state.db, &refresh_token_str).await? {
155 db::delete_token_family(&state.db, token_id).await?;
156 return Err(OAuthError::InvalidGrant(
157 "Refresh token reuse detected, token family revoked".to_string(),
158 ));
159 }
160 let (db_id, token_data) = db::get_token_by_refresh_token(&state.db, &refresh_token_str)
161 .await?
162 .ok_or_else(|| OAuthError::InvalidGrant("Invalid refresh token".to_string()))?;
163 if token_data.expires_at < Utc::now() {
164 db::delete_token_family(&state.db, db_id).await?;
165 return Err(OAuthError::InvalidGrant(
166 "Refresh token has expired".to_string(),
167 ));
168 }
169 let dpop_jkt = if let Some(proof) = &dpop_proof {
170 let config = AuthConfig::get();
171 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
172 let pds_hostname =
173 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
174 let token_endpoint = format!("https://{}/oauth/token", pds_hostname);
175 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?;
176 if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? {
177 return Err(OAuthError::InvalidDpopProof(
178 "DPoP proof has already been used".to_string(),
179 ));
180 }
181 if let Some(expected_jkt) = &token_data.parameters.dpop_jkt
182 && &result.jkt != expected_jkt {
183 return Err(OAuthError::InvalidDpopProof(
184 "DPoP key binding mismatch".to_string(),
185 ));
186 }
187 Some(result.jkt)
188 } else if token_data.parameters.dpop_jkt.is_some() {
189 return Err(OAuthError::InvalidRequest(
190 "DPoP proof required".to_string(),
191 ));
192 } else {
193 None
194 };
195 let new_token_id = TokenId::generate();
196 let new_refresh_token = RefreshToken::generate();
197 let new_expires_at = Utc::now() + Duration::days(REFRESH_TOKEN_EXPIRY_DAYS);
198 db::rotate_token(
199 &state.db,
200 db_id,
201 &new_token_id.0,
202 &new_refresh_token.0,
203 new_expires_at,
204 )
205 .await?;
206 let access_token = create_access_token(&new_token_id.0, &token_data.did, dpop_jkt.as_deref())?;
207 let mut response_headers = HeaderMap::new();
208 let config = AuthConfig::get();
209 let verifier = DPoPVerifier::new(config.dpop_secret().as_bytes());
210 response_headers.insert("DPoP-Nonce", verifier.generate_nonce().parse().unwrap());
211 Ok((
212 response_headers,
213 Json(TokenResponse {
214 access_token,
215 token_type: if dpop_jkt.is_some() { "DPoP" } else { "Bearer" }.to_string(),
216 expires_in: ACCESS_TOKEN_EXPIRY_SECONDS as u64,
217 refresh_token: Some(new_refresh_token.0),
218 scope: token_data.scope,
219 sub: Some(token_data.did),
220 }),
221 ))
222}