forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1use crate::AppState;
2use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::header::CONTENT_TYPE;
6use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
7use axum::response::{IntoResponse, Response};
8use axum::{Json, extract};
9use serde::{Deserialize, Serialize};
10use tracing::log;
11
12#[derive(Serialize, Deserialize, Clone)]
13pub struct SignInRequest {
14 pub username: String,
15 pub password: String,
16 pub remember: bool,
17 pub locale: String,
18 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
19 pub email_otp: Option<String>,
20}
21
22pub async fn sign_in(
23 State(state): State<AppState>,
24 headers: HeaderMap,
25 Json(mut payload): extract::Json<SignInRequest>,
26) -> Result<Response<Body>, StatusCode> {
27 let identifier = payload.username.clone();
28 let password = payload.password.clone();
29 let auth_factor_token = payload.email_otp.clone();
30
31 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
32 Ok(result) => match result {
33 AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
34 StatusCode::BAD_REQUEST,
35 "invalid_request",
36 "Invalid identifier or password",
37 ),
38 AuthResult::TwoFactorRequired(masked_email) => {
39 // Email sending step can be handled here if needed in the future.
40
41 // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"}
42 let body_str = match serde_json::to_string(&serde_json::json!({
43 "error": "second_authentication_factor_required",
44 "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
45 "type": "emailOtp",
46 "hint": masked_email,
47 })) {
48 Ok(s) => s,
49 Err(_) => return Err(StatusCode::BAD_REQUEST),
50 };
51
52 Response::builder()
53 .status(StatusCode::BAD_REQUEST)
54 .header(CONTENT_TYPE, "application/json")
55 .body(Body::from(body_str))
56 .map_err(|_| StatusCode::BAD_REQUEST)
57 }
58 AuthResult::ProxyThrough => {
59 //No 2FA or already passed
60 let uri = format!(
61 "{}{}",
62 state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
63 );
64
65 let mut req = axum::http::Request::post(uri);
66 if let Some(req_headers) = req.headers_mut() {
67 // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
68 copy_filtered_headers(&headers, req_headers);
69 //Setting the content type to application/json manually
70 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
71 }
72
73 //Clears the email_otp because the pds will reject a request with it.
74 payload.email_otp = None;
75 let payload_bytes =
76 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
77
78 let req = req
79 .body(Body::from(payload_bytes))
80 .map_err(|_| StatusCode::BAD_REQUEST)?;
81
82 let proxied = state
83 .reverse_proxy_client
84 .request(req)
85 .await
86 .map_err(|_| StatusCode::BAD_REQUEST)?
87 .into_response();
88
89 Ok(proxied)
90 }
91 //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
92 AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
93 StatusCode::BAD_REQUEST,
94 "invalid_request",
95 "Unable to sign-in due to an unexpected server error",
96 ),
97 },
98 Err(err) => {
99 log::error!(
100 "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
101 );
102 oauth_json_error_response(
103 StatusCode::BAD_REQUEST,
104 "pds_gatekeeper_error",
105 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
106 )
107 }
108 }
109}
110
111fn is_disallowed_header(name: &HeaderName) -> bool {
112 // possible problematic headers with proxying
113 matches!(
114 name.as_str(),
115 "connection"
116 | "keep-alive"
117 | "proxy-authenticate"
118 | "proxy-authorization"
119 | "te"
120 | "trailer"
121 | "transfer-encoding"
122 | "upgrade"
123 | "host"
124 | "content-length"
125 | "content-encoding"
126 | "expect"
127 | "accept-encoding"
128 )
129}
130
131fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
132 for (name, value) in src.iter() {
133 if is_disallowed_header(name) {
134 continue;
135 }
136 // Only copy valid headers
137 if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
138 dst.insert(name.clone(), hv);
139 }
140 }
141}