···11[package]
22-name = "pds_bells_and_whistles"
22+name = "pds_gatekeeper"
33version = "0.1.0"
44edition = "2024"
5566[dependencies]
77-axum = { version = "0.7", features = ["macros", "json"] }
88-tokio = { version = "1.39", features = ["rt-multi-thread", "macros", "signal"] }
99-sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite"] }
77+axum = { version = "0.8.4", features = ["macros", "json"] }
88+tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] }
99+sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] }
1010dotenvy = "0.15.7"
1111serde = { version = "1.0", features = ["derive"] }
1212serde_json = "1.0"
1313tracing = "0.1"
1414tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
1515+hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] }
1616+tower-http = { version = "0.6", features = ["cors", "compression-zstd"] }
1717+tower_governor = "0.8.0"
1818+hex = "0.4"
1919+jwt-compact = { version = "0.8.0", features = ["es256k"] }
2020+scrypt = "0.11"
2121+lettre = { version = "0.11.18", features = ["tokio1", "pool", "tokio1-native-tls"] }
2222+handlebars = { version = "6.3.2", features = ["rust-embed"] }
2323+rust-embed = "8.7.2"
2424+axum-template = { version = "3.0.0", features = ["handlebars"] }
+49
README.md
···11+# PDS gatekeeper
22+33+A microservice that sits on the same server as the PDS to add some of the security that the entryway does.
44+55+
66+77+PDS gatekeeper works by overriding some of the PDS endpoints inside your Caddyfile to provide gatekeeping to certain
88+endpoints. Mainly, the ability to have 2FA on a self hosted PDS like it does on a Bluesky mushroom(PDS). Most of the
99+logic of these endpoints still happens on the PDS via a proxied request, just some are gatekept.
1010+1111+# Features
1212+1313+## 2FA
1414+1515+- [x] Ability to turn on/off 2FA
1616+- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on
1717+- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email.
1818+- [ ] generate a 2FA code
1919+- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet)
2020+- [ ] oauth endpoint gatekeeping
2121+2222+## Captcha on Create Account
2323+2424+Future feature?
2525+2626+# Setup
2727+2828+Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
2929+But I want to run it locally on my own PDS first to test run it a bit.
3030+3131+Example Caddyfile (mostly so I don't lose it for now. Will have a better one in the future)
3232+3333+```caddyfile
3434+http://localhost {
3535+3636+ @gatekeeper {
3737+ path /xrpc/com.atproto.server.getSession
3838+ path /xrpc/com.atproto.server.updateEmail
3939+ path /xrpc/com.atproto.server.createSession
4040+ }
4141+4242+ handle @gatekeeper {
4343+ reverse_proxy http://localhost:8080
4444+ }
4545+4646+ reverse_proxy /* http://localhost:3000
4747+}
4848+4949+```
+5
build.rs
···11+// generated by `sqlx migrate build-script`
22+fn main() {
33+ // trigger recompilation when a new migration is added
44+ println!("cargo:rerun-if-changed=migrations");
55+}
···11+-- Add migration script here
22+CREATE TABLE IF NOT EXISTS two_factor_accounts
33+(
44+ did VARCHAR PRIMARY KEY,
55+ required INT2 NOT NULL
66+);
+3
migrations_bells_and_whistles/.keep
···11+# This directory holds SQLx migrations for the bells_and_whistles.sqlite database.
22+# It is intentionally empty for now; running `sqlx::migrate!` will still ensure the
33+# migrations table exists and succeed with zero migrations.
+177-51
src/main.rs
···11+use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
22+use axum::middleware as ax_middleware;
33+mod middleware;
44+use axum::body::Body;
55+use axum::handler::Handler;
66+use axum::http::{Method, header};
77+use axum::routing::post;
88+use axum::{Router, routing::get};
99+use axum_template::engine::Engine;
1010+use handlebars::Handlebars;
1111+use hyper_util::client::legacy::connect::HttpConnector;
1212+use hyper_util::rt::TokioExecutor;
1313+use lettre::{AsyncSmtpTransport, Tokio1Executor};
1414+use rust_embed::RustEmbed;
1515+use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
1616+use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
1717+use std::path::Path;
1818+use std::time::Duration;
119use std::{env, net::SocketAddr};
22-use axum::{extract::State, routing::get, Json, Router};
33-// use dotenvy::dotenv;
44-use serde::Serialize;
55-use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
66-use tracing::{error, info, log};
77-use tracing_subscriber::{fmt, prelude::*, EnvFilter};
2020+use tower_governor::GovernorLayer;
2121+use tower_governor::governor::GovernorConfigBuilder;
2222+use tower_http::compression::CompressionLayer;
2323+use tower_http::cors::{Any, CorsLayer};
2424+use tracing::{error, log};
2525+use tracing_subscriber::{EnvFilter, fmt, prelude::*};
826927mod xrpc;
10282929+type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
3030+3131+#[derive(RustEmbed)]
3232+#[folder = "email_templates"]
3333+#[include = "*.hbs"]
3434+struct EmailTemplates;
3535+1136#[derive(Clone)]
1237struct AppState {
1313- pool: SqlitePool,
3838+ account_pool: SqlitePool,
3939+ pds_gatekeeper_pool: SqlitePool,
4040+ reverse_proxy_client: HyperUtilClient,
4141+ pds_base_url: String,
4242+ mailer: AsyncSmtpTransport<Tokio1Executor>,
4343+ mailer_from: String,
4444+ template_engine: Engine<Handlebars<'static>>,
1445}
15461616-#[derive(Serialize)]
1717-struct HealthResponse {
1818- status: &'static str,
1919-}
4747+async fn root_handler() -> impl axum::response::IntoResponse {
4848+ let body = r"
20492121-#[derive(Serialize)]
2222-struct DbPingResponse {
2323- db: &'static str,
2424- value: i64,
5050+ ...oO _.--X~~OO~~X--._ ...oOO
5151+ _.-~ / \ II / \ ~-._
5252+ [].-~ \ / \||/ \ / ~-.[] ...o
5353+ ...o _ ||/ \ / || \ / \|| _
5454+ (_) |X X || X X| (_)
5555+ _-~-_ ||\ / \ || / \ /|| _-~-_
5656+ ||||| || \ / \ /||\ / \ / || |||||
5757+ | |_|| \ / \ / || \ / \ / ||_| |
5858+ | |~|| X X || X X ||~| |
5959+==============| | || / \ / \ || / \ / \ || | |==============
6060+______________| | || / \ / \||/ \ / \ || | |______________
6161+ . . | | ||/ \ / || \ / \|| | | . .
6262+ / | | |X X || X X| | | / /
6363+ / . | | ||\ / \ || / \ /|| | | . / .
6464+. / | | || \ / \ /||\ / \ / || | | . .
6565+ . . | | || \ / \ / || \ / \ / || | | .
6666+ / | | || X X || X X || | | . / . /
6767+ / . | | || / \ / \ || / \ / \ || | | /
6868+ / | | || / \ / \||/ \ / \ || | | . /
6969+. . . | | ||/ \ / /||\ \ / \|| | | /. .
7070+ | |_|X X / II \ X X|_| | . . /
7171+==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
7272+ ";
7373+7474+ let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
7575+7676+ let banner = format!(" {}\n{}", body, intro);
7777+7878+ (
7979+ [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
8080+ banner,
8181+ )
2582}
26832784#[tokio::main]
2885async fn main() -> Result<(), Box<dyn std::error::Error>> {
2986 setup_tracing();
3087 //TODO prod
3131- // dotenvy::from_path(Path::new("/pds.env"))?;
3232- // let pds_root = env::var("PDS_DATA_DIRECTORY")?;
3333- let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
8888+ dotenvy::from_path(Path::new("./pds.env"))?;
8989+ let pds_root = env::var("PDS_DATA_DIRECTORY")?;
9090+ // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
3491 let account_db_url = format!("{}/account.sqlite", pds_root);
3592 log::info!("accounts_db_url: {}", account_db_url);
3636- let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
3737- .ok()
3838- .and_then(|s| s.parse().ok())
3939- .unwrap_or(5);
40934141- //TODO may need to add journal_mode=WAL ?
4242- let pool = SqlitePoolOptions::new()
4343- .max_connections(max_connections)
4444- .connect(&account_db_url)
9494+ let account_options = SqliteConnectOptions::new()
9595+ .journal_mode(SqliteJournalMode::Wal)
9696+ .filename(account_db_url);
9797+9898+ let account_pool = SqlitePoolOptions::new()
9999+ .max_connections(5)
100100+ .connect_with(account_options)
45101 .await?;
461024747- let state = AppState { pool };
103103+ let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
104104+ let options = SqliteConnectOptions::new()
105105+ .journal_mode(SqliteJournalMode::Wal)
106106+ .filename(bells_db_url)
107107+ .create_if_missing(true);
108108+ let pds_gatekeeper_pool = SqlitePoolOptions::new()
109109+ .max_connections(5)
110110+ .connect_with(options)
111111+ .await?;
112112+113113+ // Run migrations for the bells_and_whistles database
114114+ // Note: the migrations are embedded at compile time from the given directory
115115+ // sqlx
116116+ sqlx::migrate!("./migrations")
117117+ .run(&pds_gatekeeper_pool)
118118+ .await?;
119119+120120+ let client: HyperUtilClient =
121121+ hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
122122+ .build(HttpConnector::new());
123123+124124+ //Emailer set up
125125+ let smtp_url =
126126+ env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
127127+ let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS")
128128+ .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
129129+ let mailer: AsyncSmtpTransport<Tokio1Executor> =
130130+ AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
131131+ //Email templates setup
132132+ let mut hbs = Handlebars::new();
133133+ let _ = hbs.register_embed_templates::<EmailTemplates>();
134134+135135+ let state = AppState {
136136+ account_pool,
137137+ pds_gatekeeper_pool,
138138+ reverse_proxy_client: client,
139139+ //TODO should be env prob
140140+ pds_base_url: "http://localhost:3000".to_string(),
141141+ mailer,
142142+ mailer_from: sent_from,
143143+ template_engine: Engine::from(hbs),
144144+ };
145145+146146+ // Rate limiting
147147+ //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
148148+ let governor_conf = GovernorConfigBuilder::default()
149149+ .per_second(60)
150150+ .burst_size(5)
151151+ .finish()
152152+ .unwrap();
153153+ let governor_limiter = governor_conf.limiter().clone();
154154+ let interval = Duration::from_secs(60);
155155+ // a separate background task to clean up
156156+ std::thread::spawn(move || {
157157+ loop {
158158+ std::thread::sleep(interval);
159159+ tracing::info!("rate limiting storage size: {}", governor_limiter.len());
160160+ governor_limiter.retain_recent();
161161+ }
162162+ });
163163+164164+ let cors = CorsLayer::new()
165165+ .allow_origin(Any)
166166+ .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
167167+ .allow_headers(Any);
4816849169 let app = Router::new()
5050- .route("/health", get(health))
5151- .route("/db/ping", get(db_ping))
170170+ .route("/", get(root_handler))
171171+ .route(
172172+ "/xrpc/com.atproto.server.getSession",
173173+ get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
174174+ )
175175+ .route(
176176+ "/xrpc/com.atproto.server.updateEmail",
177177+ post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
178178+ )
179179+ .route(
180180+ "/xrpc/com.atproto.server.createSession",
181181+ post(create_session.layer(GovernorLayer::new(governor_conf))),
182182+ )
183183+ .layer(CompressionLayer::new())
184184+ .layer(cors)
52185 .with_state(state);
5318654187 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
5555- let port: u16 = env::var("PORT").ok().and_then(|s| s.parse().ok()).unwrap_or(8080);
5656- let addr: SocketAddr = format!("{host}:{port}").parse().expect("valid socket address");
5757-5858- info!(%addr, %account_db_url, "starting server");
188188+ let port: u16 = env::var("PORT")
189189+ .ok()
190190+ .and_then(|s| s.parse().ok())
191191+ .unwrap_or(8080);
192192+ let addr: SocketAddr = format!("{host}:{port}")
193193+ .parse()
194194+ .expect("valid socket address");
5919560196 let listener = tokio::net::TcpListener::bind(addr).await?;
611976262- let server = axum::serve(listener, app).with_graceful_shutdown(shutdown_signal());
198198+ let server = axum::serve(
199199+ listener,
200200+ app.into_make_service_with_connect_info::<SocketAddr>(),
201201+ )
202202+ .with_graceful_shutdown(shutdown_signal());
6320364204 if let Err(err) = server.await {
65205 error!(error = %err, "server error");
···68208 Ok(())
69209}
702107171-async fn health() -> Json<HealthResponse> {
7272- Json(HealthResponse { status: "ok" })
7373-}
7474-7575-async fn db_ping(State(state): State<AppState>) -> Result<Json<DbPingResponse>, axum::http::StatusCode> {
7676- // Run a DB-agnostic ping that doesn't depend on user tables.
7777- // In SQLite, SELECT 1 returns a single row with value 1.
7878- let v: i64 = sqlx::query_scalar("SELECT 1")
7979- .fetch_one(&state.pool)
8080- .await
8181- .map_err(|_| axum::http::StatusCode::SERVICE_UNAVAILABLE)?;
8282-8383- Ok(Json(DbPingResponse { db: "ok", value: v }))
8484-}
8585-86211fn setup_tracing() {
87212 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
88213 tracing_subscriber::registry()
···101226102227 #[cfg(unix)]
103228 let terminate = async {
104104- use tokio::signal::unix::{signal, SignalKind};
229229+ use tokio::signal::unix::{SignalKind, signal};
105230106106- let mut sigterm = signal(SignalKind::terminate()).expect("failed to install signal handler");
231231+ let mut sigterm =
232232+ signal(SignalKind::terminate()).expect("failed to install signal handler");
107233 sigterm.recv().await;
108234 };
109235
+97
src/middleware.rs
···11+use crate::xrpc::helpers::json_error_response;
22+use axum::extract::Request;
33+use axum::http::{HeaderMap, StatusCode};
44+use axum::middleware::Next;
55+use axum::response::IntoResponse;
66+use jwt_compact::alg::{Hs256, Hs256Key};
77+use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
88+use serde::{Deserialize, Serialize};
99+use std::env;
1010+1111+#[derive(Clone, Debug)]
1212+pub struct Did(pub Option<String>);
1313+1414+#[derive(Serialize, Deserialize)]
1515+pub struct TokenClaims {
1616+ pub sub: String,
1717+}
1818+1919+pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
2020+ let token = extract_bearer(req.headers());
2121+2222+ match token {
2323+ Ok(token) => {
2424+ match token {
2525+ None => {
2626+ return json_error_response(
2727+ StatusCode::BAD_REQUEST,
2828+ "TokenRequired",
2929+ "",
3030+ ).unwrap();
3131+ }
3232+ Some(token) => {
3333+ let token = UntrustedToken::new(&token);
3434+ //Doing weird unwraps cause I can't do Result for middleware?
3535+ if token.is_err() {
3636+ return json_error_response(
3737+ StatusCode::BAD_REQUEST,
3838+ "TokenRequired",
3939+ "",
4040+ ).unwrap();
4141+ }
4242+ let parsed_token = token.unwrap();
4343+ let claims: Result<Claims<TokenClaims>, ValidationError> =
4444+ parsed_token.deserialize_claims_unchecked();
4545+ if claims.is_err() {
4646+ return json_error_response(
4747+ StatusCode::BAD_REQUEST,
4848+ "TokenRequired",
4949+ "",
5050+ ).unwrap();
5151+ }
5252+5353+ let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap());
5454+ let token: Result<Token<TokenClaims>, ValidationError> =
5555+ Hs256.validator(&key).validate(&parsed_token);
5656+ if token.is_err() {
5757+ return json_error_response(
5858+ StatusCode::BAD_REQUEST,
5959+ "InvalidToken",
6060+ "",
6161+ ).unwrap();
6262+ }
6363+ let token = token.unwrap();
6464+ //Not going to worry about expiration since it still goes to the PDS
6565+6666+ req.extensions_mut()
6767+ .insert(Did(Some(token.claims().custom.sub.clone())));
6868+ next.run(req).await
6969+ }
7070+ }
7171+ }
7272+ Err(_) => {
7373+ return json_error_response(
7474+ StatusCode::BAD_REQUEST,
7575+ "InvalidToken",
7676+ "",
7777+ ).unwrap();
7878+ }
7979+ }
8080+}
8181+8282+fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> {
8383+ match headers.get(axum::http::header::AUTHORIZATION) {
8484+ None => Ok(None),
8585+ Some(hv) => match hv.to_str() {
8686+ Err(_) => Err("Authorization header is not valid".into()),
8787+ Ok(s) => {
8888+ // Accept forms like: "Bearer <token>" (case-sensitive for the scheme here)
8989+ let mut parts = s.splitn(2, ' ');
9090+ match (parts.next(), parts.next()) {
9191+ (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())),
9292+ _ => Err("Authorization header must be in format 'Bearer <token>'".into()),
9393+ }
9494+ }
9595+ },
9696+ }
9797+}
+396-15
src/xrpc/com_atproto_server.rs
···11+use crate::AppState;
22+use crate::middleware::Did;
33+use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json};
44+use axum::body::Body;
15use axum::extract::State;
22-use axum::{extract, Json};
33-use serde::Deserialize;
44-use crate::{AppState, DbPingResponse};
66+use axum::http::{HeaderMap, StatusCode};
77+use axum::response::{IntoResponse, Response};
88+use axum::{Extension, Json, debug_handler, extract, extract::Request};
99+use axum_template::TemplateEngine;
1010+use lettre::message::{MultiPart, SinglePart, header};
1111+use lettre::{AsyncTransport, Message};
1212+use serde::{Deserialize, Serialize};
1313+use serde_json;
1414+use serde_json::Value;
1515+use serde_json::value::Map;
1616+use tracing::log;
5171818+#[derive(Serialize, Deserialize, Debug, Clone)]
1919+#[serde(rename_all = "camelCase")]
2020+enum AccountStatus {
2121+ Takendown,
2222+ Suspended,
2323+ Deactivated,
2424+}
62577-#[derive(Deserialize)]
88-struct CreateSessionRequest {
2626+#[derive(Serialize, Deserialize, Debug, Clone)]
2727+#[serde(rename_all = "camelCase")]
2828+struct GetSessionResponse {
2929+ handle: String,
3030+ did: String,
3131+ #[serde(skip_serializing_if = "Option::is_none")]
3232+ email: Option<String>,
3333+ #[serde(skip_serializing_if = "Option::is_none")]
3434+ email_confirmed: Option<bool>,
3535+ #[serde(skip_serializing_if = "Option::is_none")]
3636+ email_auth_factor: Option<bool>,
3737+ #[serde(skip_serializing_if = "Option::is_none")]
3838+ did_doc: Option<String>,
3939+ #[serde(skip_serializing_if = "Option::is_none")]
4040+ active: Option<bool>,
4141+ #[serde(skip_serializing_if = "Option::is_none")]
4242+ status: Option<AccountStatus>,
4343+}
4444+4545+#[derive(Serialize, Deserialize, Debug, Clone)]
4646+#[serde(rename_all = "camelCase")]
4747+pub struct UpdateEmailResponse {
4848+ email: String,
4949+ #[serde(skip_serializing_if = "Option::is_none")]
5050+ email_auth_factor: Option<bool>,
5151+ #[serde(skip_serializing_if = "Option::is_none")]
5252+ token: Option<String>,
5353+}
5454+5555+#[allow(dead_code)]
5656+#[derive(Deserialize, Serialize)]
5757+#[serde(rename_all = "camelCase")]
5858+pub struct CreateSessionRequest {
959 identifier: String,
1060 password: String,
1111- #[serde(rename = "authFactorToken")]
1261 auth_factor_token: String,
1313- #[serde(rename = "allowTakendown")]
1462 allow_takendown: bool,
1563}
16641717-async fn create_session(State(state): State<AppState>, extract::Json(payload): extract::Json<CreateSessionRequest>) -> Result<Json<DbPingResponse>, axum::http::StatusCode> {
1818- // Run a DB-agnostic ping that doesn't depend on user tables.
1919- // In SQLite, SELECT 1 returns a single row with value 1.
2020- let v: i64 = sqlx::query_scalar("SELECT 1")
2121- .fetch_one(&state.pool)
6565+pub enum AuthResult {
6666+ WrongIdentityOrPassword,
6767+ TwoFactorRequired,
6868+ TwoFactorFailed,
6969+ /// User does not have 2FA enabled, or passes it
7070+ ProxyThrough,
7171+}
7272+7373+pub enum IdentifierType {
7474+ Email,
7575+ DID,
7676+ Handle,
7777+}
7878+7979+impl IdentifierType {
8080+ fn what_is_it(identifier: String) -> Self {
8181+ if identifier.contains("@") {
8282+ IdentifierType::Email
8383+ } else if identifier.contains("did:") {
8484+ IdentifierType::DID
8585+ } else {
8686+ IdentifierType::Handle
8787+ }
8888+ }
8989+}
9090+9191+async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> {
9292+ // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
9393+ let mut parts = password_scrypt.splitn(2, ':');
9494+ let salt = match parts.next() {
9595+ Some(s) if !s.is_empty() => s,
9696+ _ => return Ok(false),
9797+ };
9898+ let stored_hash_hex = match parts.next() {
9999+ Some(h) if !h.is_empty() => h,
100100+ _ => return Ok(false),
101101+ };
102102+103103+ //Sets up scrypt to mimic node's scrypt
104104+ let params = match scrypt::Params::new(14, 8, 1, 64) {
105105+ Ok(p) => p,
106106+ Err(_) => return Ok(false),
107107+ };
108108+ let mut derived = [0u8; 64];
109109+ if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived).is_err() {
110110+ return Ok(false);
111111+ }
112112+113113+ let stored_bytes = match hex::decode(stored_hash_hex) {
114114+ Ok(b) => b,
115115+ Err(e) => {
116116+ log::error!("Error decoding stored hash: {}", e);
117117+ return Ok(false);
118118+ }
119119+ };
120120+121121+ Ok(derived.as_slice() == stored_bytes.as_slice())
122122+}
123123+124124+async fn preauth_check(
125125+ state: &AppState,
126126+ identifier: &str,
127127+ password: &str,
128128+) -> Result<AuthResult, StatusCode> {
129129+ // Determine identifier type
130130+ let id_type = IdentifierType::what_is_it(identifier.to_string());
131131+132132+ // Query account DB for did and passwordScrypt based on identifier type
133133+ let account_row: Option<(String, String, String)> = match id_type {
134134+ IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>(
135135+ "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1",
136136+ )
137137+ .bind(identifier)
138138+ .fetch_optional(&state.account_pool)
139139+ .await
140140+ .map_err(|_| StatusCode::BAD_REQUEST)?,
141141+ IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>(
142142+ "SELECT account.did, account.passwordScrypt, account.email
143143+ FROM actor
144144+ LEFT JOIN account ON actor.did = account.did
145145+ where actor.handle =? LIMIT 1",
146146+ )
147147+ .bind(identifier)
148148+ .fetch_optional(&state.account_pool)
149149+ .await
150150+ .map_err(|_| StatusCode::BAD_REQUEST)?,
151151+ IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>(
152152+ "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1",
153153+ )
154154+ .bind(identifier)
155155+ .fetch_optional(&state.account_pool)
156156+ .await
157157+ .map_err(|_| StatusCode::BAD_REQUEST)?,
158158+ };
159159+160160+ if let Some((did, password_scrypt, email)) = account_row {
161161+ // Check two-factor requirement for this DID in the gatekeeper DB
162162+ let required_opt = sqlx::query_as::<_, (u8,)>(
163163+ "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
164164+ )
165165+ .bind(&did)
166166+ .fetch_optional(&state.pds_gatekeeper_pool)
167167+ .await
168168+ .map_err(|_| StatusCode::BAD_REQUEST)?;
169169+170170+ let two_factor_required = match required_opt {
171171+ Some(row) => row.0 != 0,
172172+ None => false,
173173+ };
174174+175175+ if two_factor_required {
176176+ // Verify password before proceeding to 2FA email step
177177+ let verified = verify_password(password, &password_scrypt).await?;
178178+ if !verified {
179179+ return Ok(AuthResult::WrongIdentityOrPassword);
180180+ }
181181+ let mut email_data = Map::new();
182182+ //TODO these need real values
183183+ let token = "test".to_string();
184184+ let handle = "baileytownsend.dev".to_string();
185185+ email_data.insert("token".to_string(), Value::from(token.clone()));
186186+ email_data.insert("handle".to_string(), Value::from(handle.clone()));
187187+ //TODO bad unwrap
188188+ let email_body = state
189189+ .template_engine
190190+ .render("two_factor_code.hbs", email_data)
191191+ .unwrap();
192192+193193+ let email = Message::builder()
194194+ //TODO prob get the proper type in the state
195195+ .from(state.mailer_from.parse().unwrap())
196196+ .to(email.parse().unwrap())
197197+ .subject("Sign in to Bluesky")
198198+ .multipart(
199199+ MultiPart::alternative() // This is composed of two parts.
200200+ .singlepart(
201201+ SinglePart::builder()
202202+ .header(header::ContentType::TEXT_PLAIN)
203203+ .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback.
204204+ )
205205+ .singlepart(
206206+ SinglePart::builder()
207207+ .header(header::ContentType::TEXT_HTML)
208208+ .body(email_body),
209209+ ),
210210+ )
211211+ //TODO bad
212212+ .unwrap();
213213+ return match state.mailer.send(email).await {
214214+ Ok(_) => Ok(AuthResult::TwoFactorRequired),
215215+ Err(err) => {
216216+ log::error!("Error sending the 2FA email: {}", err);
217217+ Err(StatusCode::BAD_REQUEST)
218218+ }
219219+ };
220220+ }
221221+ }
222222+223223+ // No local 2FA requirement (or account not found)
224224+ Ok(AuthResult::ProxyThrough)
225225+}
226226+227227+pub async fn create_session(
228228+ State(state): State<AppState>,
229229+ headers: HeaderMap,
230230+ Json(payload): extract::Json<CreateSessionRequest>,
231231+) -> Result<Response<Body>, StatusCode> {
232232+ let identifier = payload.identifier.clone();
233233+ let password = payload.password.clone();
234234+235235+ // Run the shared pre-auth logic to validate and check 2FA requirement
236236+ match preauth_check(&state, &identifier, &password).await? {
237237+ AuthResult::WrongIdentityOrPassword => json_error_response(
238238+ StatusCode::UNAUTHORIZED,
239239+ "AuthenticationRequired",
240240+ "Invalid identifier or password",
241241+ ),
242242+ AuthResult::TwoFactorRequired => {
243243+ // Email sending step can be handled here if needed in the future.
244244+ json_error_response(
245245+ StatusCode::UNAUTHORIZED,
246246+ "AuthFactorTokenRequired",
247247+ "A sign in code has been sent to your email address",
248248+ )
249249+ }
250250+ AuthResult::TwoFactorFailed => {
251251+ //Not sure what the errors are for this response is yet
252252+ json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER")
253253+ }
254254+ AuthResult::ProxyThrough => {
255255+ //No 2FA or already passed
256256+ let uri = format!(
257257+ "{}{}",
258258+ state.pds_base_url, "/xrpc/com.atproto.server.createSession"
259259+ );
260260+261261+ let mut req = axum::http::Request::post(uri);
262262+ if let Some(req_headers) = req.headers_mut() {
263263+ req_headers.extend(headers.clone());
264264+ }
265265+266266+ let payload_bytes =
267267+ serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
268268+ let req = req
269269+ .body(Body::from(payload_bytes))
270270+ .map_err(|_| StatusCode::BAD_REQUEST)?;
271271+272272+ let proxied = state
273273+ .reverse_proxy_client
274274+ .request(req)
275275+ .await
276276+ .map_err(|_| StatusCode::BAD_REQUEST)?
277277+ .into_response();
278278+279279+ Ok(proxied)
280280+ }
281281+ }
282282+}
283283+284284+#[debug_handler]
285285+pub async fn update_email(
286286+ State(state): State<AppState>,
287287+ Extension(did): Extension<Did>,
288288+ headers: HeaderMap,
289289+ Json(payload): extract::Json<UpdateEmailResponse>,
290290+) -> Result<Response<Body>, StatusCode> {
291291+ //If email auth is not set at all it is a update email address
292292+ let email_auth_not_set = payload.email_auth_factor.is_none();
293293+ //If email aurth is set it is to either turn on or off 2fa
294294+ let email_auth_update = payload.email_auth_factor.unwrap_or(false);
295295+296296+ // Email update asked for
297297+ if email_auth_update {
298298+ let email = payload.email.clone();
299299+ let email_confirmed = sqlx::query_as::<_, (String,)>(
300300+ "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
301301+ )
302302+ .bind(&email)
303303+ .fetch_optional(&state.account_pool)
304304+ .await
305305+ .map_err(|_| StatusCode::BAD_REQUEST)?;
306306+307307+ //Since the email is already confirmed we can enable 2fa
308308+ return match email_confirmed {
309309+ None => Err(StatusCode::BAD_REQUEST),
310310+ Some(did_row) => {
311311+ let _ = sqlx::query(
312312+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
313313+ )
314314+ .bind(&did_row.0)
315315+ .execute(&state.pds_gatekeeper_pool)
316316+ .await
317317+ .map_err(|_| StatusCode::BAD_REQUEST)?;
318318+319319+ Ok(StatusCode::OK.into_response())
320320+ }
321321+ };
322322+ }
323323+324324+ // User wants auth turned off
325325+ if !email_auth_update && !email_auth_not_set {
326326+ //User wants auth turned off and has a token
327327+ if let Some(token) = &payload.token {
328328+ let token_found = sqlx::query_as::<_, (String,)>(
329329+ "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
330330+ )
331331+ .bind(token)
332332+ .bind(&did.0)
333333+ .fetch_optional(&state.account_pool)
334334+ .await
335335+ .map_err(|_| StatusCode::BAD_REQUEST)?;
336336+337337+ if token_found.is_some() {
338338+ let _ = sqlx::query(
339339+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
340340+ )
341341+ .bind(&did.0)
342342+ .execute(&state.pds_gatekeeper_pool)
343343+ .await
344344+ .map_err(|_| StatusCode::BAD_REQUEST)?;
345345+346346+ return Ok(StatusCode::OK.into_response());
347347+ } else {
348348+ return Err(StatusCode::BAD_REQUEST);
349349+ }
350350+ }
351351+ }
352352+353353+ // Updating the acutal email address
354354+ let uri = format!(
355355+ "{}{}",
356356+ state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
357357+ );
358358+ let mut req = axum::http::Request::post(uri);
359359+ if let Some(req_headers) = req.headers_mut() {
360360+ req_headers.extend(headers.clone());
361361+ }
362362+363363+ let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
364364+ let req = req
365365+ .body(Body::from(payload_bytes))
366366+ .map_err(|_| StatusCode::BAD_REQUEST)?;
367367+368368+ let proxied = state
369369+ .reverse_proxy_client
370370+ .request(req)
22371 .await
2323- .map_err(|_| axum::http::StatusCode::SERVICE_UNAVAILABLE)?;
372372+ .map_err(|_| StatusCode::BAD_REQUEST)?
373373+ .into_response();
374374+375375+ Ok(proxied)
376376+}
377377+378378+pub async fn get_session(
379379+ State(state): State<AppState>,
380380+ req: Request,
381381+) -> Result<Response<Body>, StatusCode> {
382382+ match proxy_get_json::<GetSessionResponse>(&state, req, "/xrpc/com.atproto.server.getSession")
383383+ .await?
384384+ {
385385+ ProxiedResult::Parsed {
386386+ value: mut session, ..
387387+ } => {
388388+ let did = session.did.clone();
389389+ let required_opt = sqlx::query_as::<_, (u8,)>(
390390+ "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
391391+ )
392392+ .bind(&did)
393393+ .fetch_optional(&state.pds_gatekeeper_pool)
394394+ .await
395395+ .map_err(|_| StatusCode::BAD_REQUEST)?;
396396+397397+ let email_auth_factor = match required_opt {
398398+ Some(row) => row.0 != 0,
399399+ None => false,
400400+ };
244012525- Ok(Json(DbPingResponse { db: "ok", value: v }))
2626-}402402+ session.email_auth_factor = Some(email_auth_factor);
403403+ Ok(Json(session).into_response())
404404+ }
405405+ ProxiedResult::Passthrough(resp) => Ok(resp),
406406+ }
407407+}
+150
src/xrpc/helpers.rs
···11+use axum::body::{Body, to_bytes};
22+use axum::extract::Request;
33+use axum::http::{HeaderMap, Method, StatusCode, Uri};
44+use axum::http::header::CONTENT_TYPE;
55+use axum::response::{IntoResponse, Response};
66+use serde::de::DeserializeOwned;
77+use tracing::error;
88+99+use crate::AppState;
1010+1111+/// The result of a proxied call that attempts to parse JSON.
1212+pub enum ProxiedResult<T> {
1313+ /// Successfully parsed JSON body along with original response headers.
1414+ Parsed { value: T, _headers: HeaderMap },
1515+ /// Could not or should not parse: return the original (or rebuilt) response as-is.
1616+ Passthrough(Response<Body>),
1717+}
1818+1919+/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
2020+/// the successful response body as JSON into `T`.
2121+///
2222+/// Behavior:
2323+/// - If the proxied response is non-200, returns Passthrough with the original response.
2424+/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
2525+/// - If parsing succeeds, returns Parsed { value, headers }.
2626+pub async fn proxy_get_json<T>(
2727+ state: &AppState,
2828+ mut req: Request,
2929+ path: &str,
3030+) -> Result<ProxiedResult<T>, StatusCode>
3131+where
3232+ T: DeserializeOwned,
3333+{
3434+ let uri = format!("{}{}", state.pds_base_url, path);
3535+ *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
3636+3737+ let result = state
3838+ .reverse_proxy_client
3939+ .request(req)
4040+ .await
4141+ .map_err(|_| StatusCode::BAD_REQUEST)?
4242+ .into_response();
4343+4444+ if result.status() != StatusCode::OK {
4545+ return Ok(ProxiedResult::Passthrough(result));
4646+ }
4747+4848+ let response_headers = result.headers().clone();
4949+ let body = result.into_body();
5050+ let body_bytes = to_bytes(body, usize::MAX)
5151+ .await
5252+ .map_err(|_| StatusCode::BAD_REQUEST)?;
5353+5454+ match serde_json::from_slice::<T>(&body_bytes) {
5555+ Ok(value) => Ok(ProxiedResult::Parsed {
5656+ value,
5757+ _headers: response_headers,
5858+ }),
5959+ Err(err) => {
6060+ error!(%err, "failed to parse proxied JSON response; returning original body");
6161+ let mut builder = Response::builder().status(StatusCode::OK);
6262+ if let Some(headers) = builder.headers_mut() {
6363+ *headers = response_headers;
6464+ }
6565+ let resp = builder
6666+ .body(Body::from(body_bytes))
6767+ .map_err(|_| StatusCode::BAD_REQUEST)?;
6868+ Ok(ProxiedResult::Passthrough(resp))
6969+ }
7070+ }
7171+}
7272+7373+/// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse
7474+/// the successful response body as JSON into `T`.
7575+///
7676+/// Behavior mirrors `proxy_get_json`:
7777+/// - If the proxied response is non-200, returns Passthrough with the original response.
7878+/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
7979+/// - If parsing succeeds, returns Parsed { value, headers }.
8080+pub async fn _proxy_post_json<T>(
8181+ state: &AppState,
8282+ mut req: Request,
8383+ path: &str,
8484+) -> Result<ProxiedResult<T>, StatusCode>
8585+where
8686+ T: DeserializeOwned,
8787+{
8888+ let uri = format!("{}{}", state.pds_base_url, path);
8989+ *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
9090+ *req.method_mut() = Method::POST;
9191+9292+ let result = state
9393+ .reverse_proxy_client
9494+ .request(req)
9595+ .await
9696+ .map_err(|_| StatusCode::BAD_REQUEST)?
9797+ .into_response();
9898+9999+ if result.status() != StatusCode::OK {
100100+ return Ok(ProxiedResult::Passthrough(result));
101101+ }
102102+103103+ let response_headers = result.headers().clone();
104104+ let body = result.into_body();
105105+ let body_bytes = to_bytes(body, usize::MAX)
106106+ .await
107107+ .map_err(|_| StatusCode::BAD_REQUEST)?;
108108+109109+ match serde_json::from_slice::<T>(&body_bytes) {
110110+ Ok(value) => Ok(ProxiedResult::Parsed {
111111+ value,
112112+ _headers: response_headers,
113113+ }),
114114+ Err(err) => {
115115+ error!(%err, "failed to parse proxied JSON response (POST); returning original body");
116116+ let mut builder = Response::builder().status(StatusCode::OK);
117117+ if let Some(headers) = builder.headers_mut() {
118118+ *headers = response_headers;
119119+ }
120120+ let resp = builder
121121+ .body(Body::from(body_bytes))
122122+ .map_err(|_| StatusCode::BAD_REQUEST)?;
123123+ Ok(ProxiedResult::Passthrough(resp))
124124+ }
125125+ }
126126+}
127127+128128+129129+/// Build a JSON error response with the required Content-Type header
130130+/// Content-Type: application/json;charset=utf-8
131131+/// Body shape: { "error": string, "message": string }
132132+pub fn json_error_response(
133133+ status: StatusCode,
134134+ error: impl Into<String>,
135135+ message: impl Into<String>,
136136+) -> Result<Response<Body>, StatusCode> {
137137+ let body_str = match serde_json::to_string(&serde_json::json!({
138138+ "error": error.into(),
139139+ "message": message.into(),
140140+ })) {
141141+ Ok(s) => s,
142142+ Err(_) => return Err(StatusCode::BAD_REQUEST),
143143+ };
144144+145145+ Response::builder()
146146+ .status(status)
147147+ .header(CONTENT_TYPE, "application/json;charset=utf-8")
148148+ .body(Body::from(body_str))
149149+ .map_err(|_| StatusCode::BAD_REQUEST)
150150+}
+2-1
src/xrpc/mod.rs
···11-mod com_atproto_server;11+pub mod com_atproto_server;
22+pub mod helpers;