this repo has no description
1use axum::{ 2 Form, Json, 3 extract::{Query, State}, 4 http::HeaderMap, 5 response::{IntoResponse, Redirect, Response}, 6}; 7use chrono::Utc; 8use serde::{Deserialize, Serialize}; 9use urlencoding::encode as url_encode; 10 11use crate::state::AppState; 12use crate::oauth::{Code, DeviceData, DeviceId, OAuthError, SessionId, db}; 13 14fn extract_client_ip(headers: &HeaderMap) -> String { 15 if let Some(forwarded) = headers.get("x-forwarded-for") { 16 if let Ok(value) = forwarded.to_str() { 17 if let Some(first_ip) = value.split(',').next() { 18 return first_ip.trim().to_string(); 19 } 20 } 21 } 22 23 if let Some(real_ip) = headers.get("x-real-ip") { 24 if let Ok(value) = real_ip.to_str() { 25 return value.trim().to_string(); 26 } 27 } 28 29 "0.0.0.0".to_string() 30} 31 32fn extract_user_agent(headers: &HeaderMap) -> Option<String> { 33 headers 34 .get("user-agent") 35 .and_then(|v| v.to_str().ok()) 36 .map(|s| s.to_string()) 37} 38 39#[derive(Debug, Deserialize)] 40pub struct AuthorizeQuery { 41 pub request_uri: Option<String>, 42 pub client_id: Option<String>, 43} 44 45#[derive(Debug, Serialize)] 46pub struct AuthorizeResponse { 47 pub client_id: String, 48 pub client_name: Option<String>, 49 pub scope: Option<String>, 50 pub redirect_uri: String, 51 pub state: Option<String>, 52 pub login_hint: Option<String>, 53} 54 55#[derive(Debug, Deserialize)] 56pub struct AuthorizeSubmit { 57 pub request_uri: String, 58 pub username: String, 59 pub password: String, 60 #[serde(default)] 61 pub remember_device: bool, 62} 63 64pub async fn authorize_get( 65 State(state): State<AppState>, 66 Query(query): Query<AuthorizeQuery>, 67) -> Result<Json<AuthorizeResponse>, OAuthError> { 68 let request_uri = query.request_uri.ok_or_else(|| { 69 OAuthError::InvalidRequest("request_uri is required".to_string()) 70 })?; 71 72 let request_data = db::get_authorization_request(&state.db, &request_uri) 73 .await? 74 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?; 75 76 if request_data.expires_at < Utc::now() { 77 db::delete_authorization_request(&state.db, &request_uri).await?; 78 return Err(OAuthError::InvalidRequest("request_uri has expired".to_string())); 79 } 80 81 Ok(Json(AuthorizeResponse { 82 client_id: request_data.parameters.client_id.clone(), 83 client_name: None, 84 scope: request_data.parameters.scope.clone(), 85 redirect_uri: request_data.parameters.redirect_uri.clone(), 86 state: request_data.parameters.state.clone(), 87 login_hint: request_data.parameters.login_hint.clone(), 88 })) 89} 90 91pub async fn authorize_post( 92 State(state): State<AppState>, 93 headers: HeaderMap, 94 Form(form): Form<AuthorizeSubmit>, 95) -> Result<Response, OAuthError> { 96 let request_data = db::get_authorization_request(&state.db, &form.request_uri) 97 .await? 98 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?; 99 100 if request_data.expires_at < Utc::now() { 101 db::delete_authorization_request(&state.db, &form.request_uri).await?; 102 return Err(OAuthError::InvalidRequest("request_uri has expired".to_string())); 103 } 104 105 let user = sqlx::query!( 106 r#" 107 SELECT did, password_hash, deactivated_at, takedown_ref 108 FROM users 109 WHERE handle = $1 OR email = $1 110 "#, 111 form.username 112 ) 113 .fetch_optional(&state.db) 114 .await 115 .map_err(|e| OAuthError::ServerError(e.to_string()))? 116 .ok_or_else(|| OAuthError::AccessDenied("Invalid credentials".to_string()))?; 117 118 if user.deactivated_at.is_some() { 119 return Err(OAuthError::AccessDenied("Account is deactivated".to_string())); 120 } 121 122 if user.takedown_ref.is_some() { 123 return Err(OAuthError::AccessDenied("Account is taken down".to_string())); 124 } 125 126 let password_valid = bcrypt::verify(&form.password, &user.password_hash) 127 .map_err(|_| OAuthError::ServerError("Password verification failed".to_string()))?; 128 129 if !password_valid { 130 return Err(OAuthError::AccessDenied("Invalid credentials".to_string())); 131 } 132 133 let code = Code::generate(); 134 let mut device_id: Option<String> = None; 135 136 if form.remember_device { 137 let new_device_id = DeviceId::generate(); 138 let device_data = DeviceData { 139 session_id: SessionId::generate().0, 140 user_agent: extract_user_agent(&headers), 141 ip_address: extract_client_ip(&headers), 142 last_seen_at: Utc::now(), 143 }; 144 145 db::create_device(&state.db, &new_device_id.0, &device_data).await?; 146 db::upsert_account_device(&state.db, &user.did, &new_device_id.0).await?; 147 device_id = Some(new_device_id.0); 148 } 149 150 db::update_authorization_request( 151 &state.db, 152 &form.request_uri, 153 &user.did, 154 device_id.as_deref(), 155 &code.0, 156 ) 157 .await?; 158 159 let redirect_uri = &request_data.parameters.redirect_uri; 160 let mut redirect_url = redirect_uri.to_string(); 161 162 let separator = if redirect_url.contains('?') { '&' } else { '?' }; 163 redirect_url.push(separator); 164 redirect_url.push_str(&format!("code={}", url_encode(&code.0))); 165 166 if let Some(state) = &request_data.parameters.state { 167 redirect_url.push_str(&format!("&state={}", url_encode(state))); 168 } 169 170 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 171 redirect_url.push_str(&format!("&iss={}", url_encode(&format!("https://{}", pds_hostname)))); 172 173 Ok(Redirect::temporary(&redirect_url).into_response()) 174} 175 176#[derive(Debug, Serialize)] 177pub struct AuthorizeDenyResponse { 178 pub error: String, 179 pub error_description: String, 180} 181 182pub async fn authorize_deny( 183 State(state): State<AppState>, 184 Form(form): Form<AuthorizeDenyForm>, 185) -> Result<Response, OAuthError> { 186 let request_data = db::get_authorization_request(&state.db, &form.request_uri) 187 .await? 188 .ok_or_else(|| OAuthError::InvalidRequest("Invalid request_uri".to_string()))?; 189 190 db::delete_authorization_request(&state.db, &form.request_uri).await?; 191 192 let redirect_uri = &request_data.parameters.redirect_uri; 193 let mut redirect_url = redirect_uri.to_string(); 194 195 let separator = if redirect_url.contains('?') { '&' } else { '?' }; 196 redirect_url.push(separator); 197 redirect_url.push_str("error=access_denied"); 198 redirect_url.push_str("&error_description=User%20denied%20the%20request"); 199 200 if let Some(state) = &request_data.parameters.state { 201 redirect_url.push_str(&format!("&state={}", url_encode(state))); 202 } 203 204 Ok(Redirect::temporary(&redirect_url).into_response()) 205} 206 207#[derive(Debug, Deserialize)] 208pub struct AuthorizeDenyForm { 209 pub request_uri: String, 210}