this repo has no description
1use axum::{Form, Json};
2use axum::extract::State;
3use axum::http::{HeaderMap, StatusCode};
4use chrono::Utc;
5use serde::{Deserialize, Serialize};
6
7use crate::state::{AppState, RateLimitKind};
8use crate::oauth::{OAuthError, db};
9
10use super::helpers::extract_token_claims;
11
12#[derive(Debug, Deserialize)]
13pub struct RevokeRequest {
14 pub token: Option<String>,
15 #[serde(default)]
16 pub token_type_hint: Option<String>,
17}
18
19pub async fn revoke_token(
20 State(state): State<AppState>,
21 headers: HeaderMap,
22 Form(request): Form<RevokeRequest>,
23) -> Result<StatusCode, OAuthError> {
24 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
25 if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await {
26 tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
27 return Err(OAuthError::RateLimited);
28 }
29
30 if let Some(token) = &request.token {
31 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
32 db::delete_token_family(&state.db, db_id).await?;
33 } else {
34 db::delete_token(&state.db, token).await?;
35 }
36 }
37
38 Ok(StatusCode::OK)
39}
40
41#[derive(Debug, Deserialize)]
42pub struct IntrospectRequest {
43 pub token: String,
44 #[serde(default)]
45 pub token_type_hint: Option<String>,
46}
47
48#[derive(Debug, Serialize)]
49pub struct IntrospectResponse {
50 pub active: bool,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub scope: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub client_id: Option<String>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub username: Option<String>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub token_type: Option<String>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub exp: Option<i64>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub iat: Option<i64>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub nbf: Option<i64>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub sub: Option<String>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub aud: Option<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub iss: Option<String>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub jti: Option<String>,
73}
74
75pub async fn introspect_token(
76 State(state): State<AppState>,
77 headers: HeaderMap,
78 Form(request): Form<IntrospectRequest>,
79) -> Result<Json<IntrospectResponse>, OAuthError> {
80 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
81 if !state.check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip).await {
82 tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
83 return Err(OAuthError::RateLimited);
84 }
85
86 let inactive_response = IntrospectResponse {
87 active: false,
88 scope: None,
89 client_id: None,
90 username: None,
91 token_type: None,
92 exp: None,
93 iat: None,
94 nbf: None,
95 sub: None,
96 aud: None,
97 iss: None,
98 jti: None,
99 };
100
101 let token_info = match extract_token_claims(&request.token) {
102 Ok(info) => info,
103 Err(_) => return Ok(Json(inactive_response)),
104 };
105
106 let token_data = match db::get_token_by_id(&state.db, &token_info.jti).await {
107 Ok(Some(data)) => data,
108 _ => return Ok(Json(inactive_response)),
109 };
110
111 if token_data.expires_at < Utc::now() {
112 return Ok(Json(inactive_response));
113 }
114
115 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
116 let issuer = format!("https://{}", pds_hostname);
117
118 Ok(Json(IntrospectResponse {
119 active: true,
120 scope: token_data.scope,
121 client_id: Some(token_data.client_id),
122 username: None,
123 token_type: if token_data.parameters.dpop_jkt.is_some() {
124 Some("DPoP".to_string())
125 } else {
126 Some("Bearer".to_string())
127 },
128 exp: Some(token_info.exp),
129 iat: Some(token_info.iat),
130 nbf: Some(token_info.iat),
131 sub: Some(token_data.did),
132 aud: Some(issuer.clone()),
133 iss: Some(issuer),
134 jti: Some(token_info.jti),
135 }))
136}