···1+# PDS gatekeeper
2+3+A microservice that sits on the same server as the PDS to add some of the security that the entryway does.
4+5+
6+7+PDS gatekeeper works by overriding some of the PDS endpoints inside your Caddyfile to provide gatekeeping to certain
8+endpoints. Mainly, the ability to have 2FA on a self hosted PDS like it does on a Bluesky mushroom(PDS). Most of the
9+logic of these endpoints still happens on the PDS via a proxied request, just some are gatekept.
10+11+# Features
12+13+## 2FA
14+15+- [x] Ability to turn on/off 2FA
16+- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on
17+- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email.
18+- [ ] generate a 2FA code
19+- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet)
20+- [ ] oauth endpoint gatekeeping
21+22+## Captcha on Create Account
23+24+Future feature?
25+26+# Setup
27+28+Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
29+But I want to run it locally on my own PDS first to test run it a bit.
30+31+Example Caddyfile (mostly so I don't lose it for now. Will have a better one in the future)
32+33+```caddyfile
34+http://localhost {
35+36+ @gatekeeper {
37+ path /xrpc/com.atproto.server.getSession
38+ path /xrpc/com.atproto.server.updateEmail
39+ path /xrpc/com.atproto.server.createSession
40+ }
41+42+ handle @gatekeeper {
43+ reverse_proxy http://localhost:8080
44+ }
45+46+ reverse_proxy /* http://localhost:3000
47+}
48+49+```
+5
build.rs
···00000
···1+// generated by `sqlx migrate build-script`
2+fn main() {
3+ // trigger recompilation when a new migration is added
4+ println!("cargo:rerun-if-changed=migrations");
5+}
···1+-- Add migration script here
2+CREATE TABLE IF NOT EXISTS two_factor_accounts
3+(
4+ did VARCHAR PRIMARY KEY,
5+ required INT2 NOT NULL
6+);
+3
migrations_bells_and_whistles/.keep
···000
···1+# This directory holds SQLx migrations for the bells_and_whistles.sqlite database.
2+# It is intentionally empty for now; running `sqlx::migrate!` will still ensure the
3+# migrations table exists and succeed with zero migrations.
+177-51
src/main.rs
···0000000000000000001use std::{env, net::SocketAddr};
2-use axum::{extract::State, routing::get, Json, Router};
3-// use dotenvy::dotenv;
4-use serde::Serialize;
5-use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
6-use tracing::{error, info, log};
7-use tracing_subscriber::{fmt, prelude::*, EnvFilter};
89mod xrpc;
10000000011#[derive(Clone)]
12struct AppState {
13- pool: SqlitePool,
00000014}
1516-#[derive(Serialize)]
17-struct HealthResponse {
18- status: &'static str,
19-}
2021-#[derive(Serialize)]
22-struct DbPingResponse {
23- db: &'static str,
24- value: i64,
000000000000000000000000000025}
2627#[tokio::main]
28async fn main() -> Result<(), Box<dyn std::error::Error>> {
29 setup_tracing();
30 //TODO prod
31- // dotenvy::from_path(Path::new("/pds.env"))?;
32- // let pds_root = env::var("PDS_DATA_DIRECTORY")?;
33- let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
34 let account_db_url = format!("{}/account.sqlite", pds_root);
35 log::info!("accounts_db_url: {}", account_db_url);
36- let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
37- .ok()
38- .and_then(|s| s.parse().ok())
39- .unwrap_or(5);
4041- //TODO may need to add journal_mode=WAL ?
42- let pool = SqlitePoolOptions::new()
43- .max_connections(max_connections)
44- .connect(&account_db_url)
00045 .await?;
4647- let state = AppState { pool };
00000000000000000000000000000000000000000000000000000000000000004849 let app = Router::new()
50- .route("/health", get(health))
51- .route("/db/ping", get(db_ping))
000000000000052 .with_state(state);
5354 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
55- let port: u16 = env::var("PORT").ok().and_then(|s| s.parse().ok()).unwrap_or(8080);
56- let addr: SocketAddr = format!("{host}:{port}").parse().expect("valid socket address");
57-58- info!(%addr, %account_db_url, "starting server");
0005960 let listener = tokio::net::TcpListener::bind(addr).await?;
6162- let server = axum::serve(listener, app).with_graceful_shutdown(shutdown_signal());
00006364 if let Err(err) = server.await {
65 error!(error = %err, "server error");
···68 Ok(())
69}
7071-async fn health() -> Json<HealthResponse> {
72- Json(HealthResponse { status: "ok" })
73-}
74-75-async fn db_ping(State(state): State<AppState>) -> Result<Json<DbPingResponse>, axum::http::StatusCode> {
76- // Run a DB-agnostic ping that doesn't depend on user tables.
77- // In SQLite, SELECT 1 returns a single row with value 1.
78- let v: i64 = sqlx::query_scalar("SELECT 1")
79- .fetch_one(&state.pool)
80- .await
81- .map_err(|_| axum::http::StatusCode::SERVICE_UNAVAILABLE)?;
82-83- Ok(Json(DbPingResponse { db: "ok", value: v }))
84-}
85-86fn setup_tracing() {
87 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
88 tracing_subscriber::registry()
···101102 #[cfg(unix)]
103 let terminate = async {
104- use tokio::signal::unix::{signal, SignalKind};
105106- let mut sigterm = signal(SignalKind::terminate()).expect("failed to install signal handler");
0107 sigterm.recv().await;
108 };
109
···1+use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
2+use axum::middleware as ax_middleware;
3+mod middleware;
4+use axum::body::Body;
5+use axum::handler::Handler;
6+use axum::http::{Method, header};
7+use axum::routing::post;
8+use axum::{Router, routing::get};
9+use axum_template::engine::Engine;
10+use handlebars::Handlebars;
11+use hyper_util::client::legacy::connect::HttpConnector;
12+use hyper_util::rt::TokioExecutor;
13+use lettre::{AsyncSmtpTransport, Tokio1Executor};
14+use rust_embed::RustEmbed;
15+use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
16+use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
17+use std::path::Path;
18+use std::time::Duration;
19use std::{env, net::SocketAddr};
20+use tower_governor::GovernorLayer;
21+use tower_governor::governor::GovernorConfigBuilder;
22+use tower_http::compression::CompressionLayer;
23+use tower_http::cors::{Any, CorsLayer};
24+use tracing::{error, log};
25+use tracing_subscriber::{EnvFilter, fmt, prelude::*};
2627mod xrpc;
2829+type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
30+31+#[derive(RustEmbed)]
32+#[folder = "email_templates"]
33+#[include = "*.hbs"]
34+struct EmailTemplates;
35+36#[derive(Clone)]
37struct AppState {
38+ account_pool: SqlitePool,
39+ pds_gatekeeper_pool: SqlitePool,
40+ reverse_proxy_client: HyperUtilClient,
41+ pds_base_url: String,
42+ mailer: AsyncSmtpTransport<Tokio1Executor>,
43+ mailer_from: String,
44+ template_engine: Engine<Handlebars<'static>>,
45}
4647+async fn root_handler() -> impl axum::response::IntoResponse {
48+ let body = r"
004950+ ...oO _.--X~~OO~~X--._ ...oOO
51+ _.-~ / \ II / \ ~-._
52+ [].-~ \ / \||/ \ / ~-.[] ...o
53+ ...o _ ||/ \ / || \ / \|| _
54+ (_) |X X || X X| (_)
55+ _-~-_ ||\ / \ || / \ /|| _-~-_
56+ ||||| || \ / \ /||\ / \ / || |||||
57+ | |_|| \ / \ / || \ / \ / ||_| |
58+ | |~|| X X || X X ||~| |
59+==============| | || / \ / \ || / \ / \ || | |==============
60+______________| | || / \ / \||/ \ / \ || | |______________
61+ . . | | ||/ \ / || \ / \|| | | . .
62+ / | | |X X || X X| | | / /
63+ / . | | ||\ / \ || / \ /|| | | . / .
64+. / | | || \ / \ /||\ / \ / || | | . .
65+ . . | | || \ / \ / || \ / \ / || | | .
66+ / | | || X X || X X || | | . / . /
67+ / . | | || / \ / \ || / \ / \ || | | /
68+ / | | || / \ / \||/ \ / \ || | | . /
69+. . . | | ||/ \ / /||\ \ / \|| | | /. .
70+ | |_|X X / II \ X X|_| | . . /
71+==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
72+ ";
73+74+ let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
75+76+ let banner = format!(" {}\n{}", body, intro);
77+78+ (
79+ [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
80+ banner,
81+ )
82}
8384#[tokio::main]
85async fn main() -> Result<(), Box<dyn std::error::Error>> {
86 setup_tracing();
87 //TODO prod
88+ dotenvy::from_path(Path::new("./pds.env"))?;
89+ let pds_root = env::var("PDS_DATA_DIRECTORY")?;
90+ // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
91 let account_db_url = format!("{}/account.sqlite", pds_root);
92 log::info!("accounts_db_url: {}", account_db_url);
00009394+ let account_options = SqliteConnectOptions::new()
95+ .journal_mode(SqliteJournalMode::Wal)
96+ .filename(account_db_url);
97+98+ let account_pool = SqlitePoolOptions::new()
99+ .max_connections(5)
100+ .connect_with(account_options)
101 .await?;
102103+ let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
104+ let options = SqliteConnectOptions::new()
105+ .journal_mode(SqliteJournalMode::Wal)
106+ .filename(bells_db_url)
107+ .create_if_missing(true);
108+ let pds_gatekeeper_pool = SqlitePoolOptions::new()
109+ .max_connections(5)
110+ .connect_with(options)
111+ .await?;
112+113+ // Run migrations for the bells_and_whistles database
114+ // Note: the migrations are embedded at compile time from the given directory
115+ // sqlx
116+ sqlx::migrate!("./migrations")
117+ .run(&pds_gatekeeper_pool)
118+ .await?;
119+120+ let client: HyperUtilClient =
121+ hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
122+ .build(HttpConnector::new());
123+124+ //Emailer set up
125+ let smtp_url =
126+ env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
127+ let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS")
128+ .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
129+ let mailer: AsyncSmtpTransport<Tokio1Executor> =
130+ AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
131+ //Email templates setup
132+ let mut hbs = Handlebars::new();
133+ let _ = hbs.register_embed_templates::<EmailTemplates>();
134+135+ let state = AppState {
136+ account_pool,
137+ pds_gatekeeper_pool,
138+ reverse_proxy_client: client,
139+ //TODO should be env prob
140+ pds_base_url: "http://localhost:3000".to_string(),
141+ mailer,
142+ mailer_from: sent_from,
143+ template_engine: Engine::from(hbs),
144+ };
145+146+ // Rate limiting
147+ //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
148+ let governor_conf = GovernorConfigBuilder::default()
149+ .per_second(60)
150+ .burst_size(5)
151+ .finish()
152+ .unwrap();
153+ let governor_limiter = governor_conf.limiter().clone();
154+ let interval = Duration::from_secs(60);
155+ // a separate background task to clean up
156+ std::thread::spawn(move || {
157+ loop {
158+ std::thread::sleep(interval);
159+ tracing::info!("rate limiting storage size: {}", governor_limiter.len());
160+ governor_limiter.retain_recent();
161+ }
162+ });
163+164+ let cors = CorsLayer::new()
165+ .allow_origin(Any)
166+ .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
167+ .allow_headers(Any);
168169 let app = Router::new()
170+ .route("/", get(root_handler))
171+ .route(
172+ "/xrpc/com.atproto.server.getSession",
173+ get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
174+ )
175+ .route(
176+ "/xrpc/com.atproto.server.updateEmail",
177+ post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
178+ )
179+ .route(
180+ "/xrpc/com.atproto.server.createSession",
181+ post(create_session.layer(GovernorLayer::new(governor_conf))),
182+ )
183+ .layer(CompressionLayer::new())
184+ .layer(cors)
185 .with_state(state);
186187 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
188+ let port: u16 = env::var("PORT")
189+ .ok()
190+ .and_then(|s| s.parse().ok())
191+ .unwrap_or(8080);
192+ let addr: SocketAddr = format!("{host}:{port}")
193+ .parse()
194+ .expect("valid socket address");
195196 let listener = tokio::net::TcpListener::bind(addr).await?;
197198+ let server = axum::serve(
199+ listener,
200+ app.into_make_service_with_connect_info::<SocketAddr>(),
201+ )
202+ .with_graceful_shutdown(shutdown_signal());
203204 if let Err(err) = server.await {
205 error!(error = %err, "server error");
···208 Ok(())
209}
210000000000000000211fn setup_tracing() {
212 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
213 tracing_subscriber::registry()
···226227 #[cfg(unix)]
228 let terminate = async {
229+ use tokio::signal::unix::{SignalKind, signal};
230231+ let mut sigterm =
232+ signal(SignalKind::terminate()).expect("failed to install signal handler");
233 sigterm.recv().await;
234 };
235