This repository has no description.
1use axum::{
2 Form, Json,
3 extract::{Query, State},
4 http::HeaderMap,
5 response::{IntoResponse, Redirect, Response},
6};
7use chrono::Utc;
8use serde::{Deserialize, Serialize};
9use urlencoding::encode as url_encode;
10
11use crate::state::AppState;
12use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db};
13
14fn extract_client_ip(headers: &HeaderMap) -> String {
15 if let Some(forwarded) = headers.get("x-forwarded-for") {
16 if let Ok(value) = forwarded.to_str() {
17 if let Some(first_ip) = value.split(',').next() {
18 return first_ip.trim().to_string();
19 }
20 }
21 }
22
23 if let Some(real_ip) = headers.get("x-real-ip") {
24 if let Ok(value) = real_ip.to_str() {
25 return value.trim().to_string();
26 }
27 }
28
29 "0.0.0.0".to_string()
30}
31
32fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
33 headers
34 .get("user-agent")
35 .and_then(|v| v.to_str().ok())
36 .map(|s| s.to_string())
37}
38
39#[derive(Debug, Deserialize)]
40pub struct AuthorizeQuery {
41 pub request_uri: Option<String>,
42 pub client_id: Option<String>,
43}
44
45#[derive(Debug, Serialize)]
46pub struct AuthorizeResponse {
47 pub client_id: String,
48 pub client_name: Option<String>,
49 pub scope: Option<String>,
50 pub redirect_uri: String,
51 pub state: Option<String>,
52 pub login_hint: Option<String>,
53}
54
55#[derive(Debug, Deserialize)]
56pub struct AuthorizeSubmit {
57 pub request_uri: String,
58 pub username: String,
59 pub password: String,
60 #[serde(default)]
61 pub remember_device: bool,
62}
63
64pub async fn authorize_get(
65 State(state): State<AppState>,
66 Query(query): Query<AuthorizeQuery>,
67) -> Result<Json<AuthorizeResponse>, OAuthError> {
68 let request_uri = query.request_uri.ok_or_else(|| {
69 OAuthError::InvalidRequest("request_uri is required".to_string())
70 })?;
71
72 let request_data = db::get_authorization_request(&state.db, &request_uri)
73 .await?
74 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
75
76 if request_data.expires_at < Utc::now() {
77 db::delete_authorization_request(&state.db, &request_uri).await?;
78 return Err(OAuthError::InvalidRequest("request_uri has expired".to_string()));
79 }
80
81 Ok(Json(AuthorizeResponse {
82 client_id: request_data.parameters.client_id.clone(),
83 client_name: None,
84 scope: request_data.parameters.scope.clone(),
85 redirect_uri: request_data.parameters.redirect_uri.clone(),
86 state: request_data.parameters.state.clone(),
87 login_hint: request_data.parameters.login_hint.clone(),
88 }))
89}
90
91pub async fn authorize_post(
92 State(state): State<AppState>,
93 headers: HeaderMap,
94 Form(form): Form<AuthorizeSubmit>,
95) -> Result<Response, OAuthError> {
96 let request_data = db::get_authorization_request(&state.db, &form.request_uri)
97 .await?
98 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?;
99
100 if request_data.expires_at < Utc::now() {
101 db::delete_authorization_request(&state.db, &form.request_uri).await?;
102 return Err(OAuthError::InvalidRequest("request_uri has expired".to_string()));
103 }
104
105 let user = sqlx::query!(
106 r#"
107 SELECT did, password_hash, deactivated_at, takedown_ref
108 FROM users
109 WHERE handle = $1 OR email = $1
110 "#,
111 form.username
112 )
113 .fetch_optional(&state.db)
114 .await
115 .map_err(|e| OAuthError::ServerError(e.to_string()))?
116 .ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?;
117
118 if user.deactivated_at.is_some() {
119 return Err(OAuthError::AccessDenied("Account is deactivated".to_string()));
120 }
121
122 if user.takedown_ref.is_some() {
123 return Err(OAuthError::AccessDenied("Account is taken down".to_string()));
124 }
125
126 let password_valid = bcrypt::verify(&form.password, &user.password_hash)
127 .map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?;
128
129 if !password_valid {
130 return Err(OAuthError::AccessDenied("Invalid credentials".to_string()));
131 }
132
133 let code = Code::generate();
134 let mut device_id: Option<String> = None;
135
136 if form.remember_device {
137 let new_device_id = DeviceId::generate();
138 let device_data = DeviceData {
139 session_id: SessionId::generate().0,
140 user_agent: extract_user_agent(&headers),
141 ip_address: extract_client_ip(&headers),
142 last_seen_at: Utc::now(),
143 };
144
145 db::create_device(&state.db, &new_device_id.0, &device_data).await?;
146 db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?;
147 device_id = Some(new_device_id.0);
148 }
149
150 db::update_authorization_request(
151 &state.db,
152 &form.request_uri,
153 &user.did,
154 device_id.as_deref(),
155 &code.0,
156 )
157 .await?;
158
159 let redirect_uri = &request_data.parameters.redirect_uri;
160 let mut redirect_url = redirect_uri.to_string();
161
162 let separator = if redirect_url.contains('?') { '&' } else { '?' };
163 redirect_url.push(separator);
164 redirect_url.push_str(&format!("code={}", url_encode(&code.0)));
165
166 if let Some(state) = &request_data.parameters.state {
167 redirect_url.push_str(&format!("&state={}", url_encode(state)));
168 }
169
170 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
171 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname))));
172
173 Ok(Redirect::temporary(&redirect_url).into_response())
174}
175
176#[derive(Debug, Serialize)]
177pub struct AuthorizeDenyResponse {
178 pub error: String,
179 pub error_description: String,
180}
181
182pub async fn authorize_deny(
183 State(state): State<AppState>,
184 Form(form): Form<AuthorizeDenyForm>,
185) -> Result<Response, OAuthError> {
186 let request_data = db::get_authorization_request(&state.db, &form.request_uri)
187 .await?
188 .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?;
189
190 db::delete_authorization_request(&state.db, &form.request_uri).await?;
191
192 let redirect_uri = &request_data.parameters.redirect_uri;
193 let mut redirect_url = redirect_uri.to_string();
194
195 let separator = if redirect_url.contains('?') { '&' } else { '?' };
196 redirect_url.push(separator);
197 redirect_url.push_str("error=access_denied");
198 redirect_url.push_str("&error_description=User%20denied%20the%20request");
199
200 if let Some(state) = &request_data.parameters.state {
201 redirect_url.push_str(&format!("&state={}", url_encode(state)));
202 }
203
204 Ok(Redirect::temporary(&redirect_url).into_response())
205}
206
207#[derive(Debug, Deserialize)]
208pub struct AuthorizeDenyForm {
209 pub request_uri: String,
210}