this repo has no description
1use super::helpers::extract_token_claims;
2use crate::oauth::{OAuthError, db};
3use crate::state::{AppState, RateLimitKind};
4use axum::extract::State;
5use axum::http::{HeaderMap, StatusCode};
6use axum::{Form, Json};
7use chrono::Utc;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Deserialize)]
11pub struct RevokeRequest {
12 pub token: Option<String>,
13 #[serde(default)]
14 pub token_type_hint: Option<String>,
15}
16
17pub async fn revoke_token(
18 State(state): State<AppState>,
19 headers: HeaderMap,
20 Form(request): Form<RevokeRequest>,
21) -> Result<StatusCode, OAuthError> {
22 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
23 if !state
24 .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
25 .await
26 {
27 tracing::warn!(ip = %client_ip, "OAuth revoke rate limit exceeded");
28 return Err(OAuthError::RateLimited);
29 }
30 if let Some(token) = &request.token {
31 if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? {
32 db::delete_token_family(&state.db, db_id).await?;
33 } else {
34 db::delete_token(&state.db, token).await?;
35 }
36 }
37 Ok(StatusCode::OK)
38}
39
40#[derive(Debug, Deserialize)]
41pub struct IntrospectRequest {
42 pub token: String,
43 #[serde(default)]
44 pub token_type_hint: Option<String>,
45}
46
47#[derive(Debug, Serialize)]
48pub struct IntrospectResponse {
49 pub active: bool,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 pub scope: Option<String>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 pub client_id: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 pub username: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub token_type: Option<String>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub exp: Option<i64>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub iat: Option<i64>,
62 #[serde(skip_serializing_if = "Option::is_none")]
63 pub nbf: Option<i64>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 pub sub: Option<String>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub aud: Option<String>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 pub iss: Option<String>,
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub jti: Option<String>,
72}
73
74pub async fn introspect_token(
75 State(state): State<AppState>,
76 headers: HeaderMap,
77 Form(request): Form<IntrospectRequest>,
78) -> Result<Json<IntrospectResponse>, OAuthError> {
79 let client_ip = crate::rate_limit::extract_client_ip(&headers, None);
80 if !state
81 .check_rate_limit(RateLimitKind::OAuthIntrospect, &client_ip)
82 .await
83 {
84 tracing::warn!(ip = %client_ip, "OAuth introspect rate limit exceeded");
85 return Err(OAuthError::RateLimited);
86 }
87 let inactive_response = IntrospectResponse {
88 active: false,
89 scope: None,
90 client_id: None,
91 username: None,
92 token_type: None,
93 exp: None,
94 iat: None,
95 nbf: None,
96 sub: None,
97 aud: None,
98 iss: None,
99 jti: None,
100 };
101 let token_info = match extract_token_claims(&request.token) {
102 Ok(info) => info,
103 Err(_) => return Ok(Json(inactive_response)),
104 };
105 let token_data = match db::get_token_by_id(&state.db, &token_info.sid).await {
106 Ok(Some(data)) => data,
107 _ => return Ok(Json(inactive_response)),
108 };
109 if token_data.expires_at < Utc::now() {
110 return Ok(Json(inactive_response));
111 }
112 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
113 let issuer = format!("https://{}", pds_hostname);
114 Ok(Json(IntrospectResponse {
115 active: true,
116 scope: token_data.scope,
117 client_id: Some(token_data.client_id),
118 username: None,
119 token_type: if token_data.parameters.dpop_jkt.is_some() {
120 Some("DPoP".to_string())
121 } else {
122 Some("Bearer".to_string())
123 },
124 exp: Some(token_info.exp),
125 iat: Some(token_info.iat),
126 nbf: Some(token_info.iat),
127 sub: Some(token_data.did),
128 aud: Some(issuer.clone()),
129 iss: Some(issuer),
130 jti: Some(token_info.jti),
131 }))
132}