Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)]
2use crate::gate::{get_gate, post_gate};
3use crate::mailer::{Mailer, build_mailer_from_env};
4use crate::oauth_provider::sign_in;
5use crate::xrpc::com_atproto_server::{
6 create_account, create_session, describe_server, get_session, update_email,
7};
8use anyhow::Result;
9use axum::{
10 Router,
11 body::Body,
12 handler::Handler,
13 http::{Method, header},
14 middleware as ax_middleware,
15 routing::get,
16 routing::post,
17};
18use axum_template::engine::Engine;
19use handlebars::Handlebars;
20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
21use jacquard_common::types::did::Did;
22use jacquard_identity::{PublicResolver, resolver::PlcSource};
23use rand::Rng;
24use rust_embed::RustEmbed;
25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
27use std::path::Path;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{env, net::SocketAddr};
31use tower_governor::{
32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
33};
34use tower_http::{
35 compression::CompressionLayer,
36 cors::{Any, CorsLayer},
37};
38use tracing::log;
39use tracing_subscriber::{EnvFilter, fmt, prelude::*};
40
41mod auth;
42mod gate;
43pub mod helpers;
44pub mod mailer;
45mod middleware;
46mod oauth_provider;
47mod xrpc;
48
49type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
50
/// Email templates (`*.hbs`) embedded into the binary from the `email_templates/` folder at
/// compile time. When `GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY` is set, `main` loads
/// `two_factor_code.hbs` from that directory instead.
#[derive(RustEmbed)]
#[folder = "email_templates"]
#[include = "*.hbs"]
struct EmailTemplates;

/// HTML page templates (`*.hbs`) embedded into the binary from the `html_templates/` folder at
/// compile time; always registered with Handlebars by `main`.
#[derive(RustEmbed)]
#[folder = "html_templates"]
#[include = "*.hbs"]
struct HtmlTemplates;
60
61/// Mostly the env variables that are used in the app
62#[derive(Clone, Debug)]
63pub struct AppConfig {
64 pds_base_url: String,
65 mailer_from: String,
66 email_subject: String,
67 allow_only_migrations: bool,
68 use_captcha: bool,
69 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use
70 //that need to capture this redirect url for creating an account
71 default_successful_redirect_url: String,
72 pds_service_did: Did<'static>,
73 gate_jwe_key: Vec<u8>,
74 captcha_success_redirects: Vec<String>,
75}
76
77impl AppConfig {
78 pub fn new() -> Self {
79 let pds_base_url =
80 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
81 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS")
82 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
83 //Hack not my favorite, but it does work
84 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS")
85 .map(|val| val.parse::<bool>().unwrap_or(false))
86 .unwrap_or(false);
87
88 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA")
89 .map(|val| val.parse::<bool>().unwrap_or(false))
90 .unwrap_or(false);
91
92 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME
93 let pds_service_did =
94 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") {
95 Ok(pds_hostname) => format!("did:web:{}", pds_hostname),
96 Err(_) => {
97 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file")
98 }
99 });
100
101 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT")
102 .unwrap_or("Sign in to Bluesky".to_string());
103
104 // Load or generate JWE encryption key (32 bytes for AES-256)
105 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY")
106 .ok()
107 .and_then(|key_hex| hex::decode(key_hex).ok())
108 .unwrap_or_else(|| {
109 // Generate a random 32-byte key if not provided
110 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
111 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key));
112 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins).");
113 key
114 });
115
116 if gate_jwe_key.len() != 32 {
117 panic!(
118 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption"
119 );
120 }
121
122 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") {
123 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(),
124 Err(_) => {
125 vec![
126 String::from("https://bsky.app"),
127 String::from("https://pdsmoover.com"),
128 String::from("https://blacksky.community"),
129 String::from("https://tektite.cc"),
130 ]
131 }
132 };
133
134 AppConfig {
135 pds_base_url,
136 mailer_from,
137 email_subject,
138 allow_only_migrations,
139 use_captcha,
140 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT")
141 .unwrap_or("https://bsky.app".to_string()),
142 pds_service_did: pds_service_did
143 .parse()
144 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"),
145 gate_jwe_key,
146 captcha_success_redirects,
147 }
148 }
149}
150
151#[derive(Clone)]
152pub struct AppState {
153 account_pool: SqlitePool,
154 pds_gatekeeper_pool: SqlitePool,
155 reverse_proxy_client: HyperUtilClient,
156 mailer: Arc<Mailer>,
157 template_engine: Engine<Handlebars<'static>>,
158 resolver: Arc<PublicResolver>,
159 handle_cache: auth::HandleCache,
160 app_config: AppConfig,
161}
162
/// Handler for `GET /`: returns an ASCII-art banner followed by a short description and a link
/// to the project's source, with content type `text/plain; charset=utf-8`.
async fn root_handler() -> impl axum::response::IntoResponse {
    // Raw string so the backslashes in the art are taken literally.
    let body = r"

 ...oO _.--X~~OO~~X--._ ...oOO
 _.-~ / \ II / \ ~-._
 [].-~ \ / \||/ \ / ~-.[] ...o
 ...o _ ||/ \ / || \ / \|| _
 (_) |X X || X X| (_)
 _-~-_ ||\ / \ || / \ /|| _-~-_
 ||||| || \ / \ /||\ / \ / || |||||
 | |_|| \ / \ / || \ / \ / ||_| |
 | |~|| X X || X X ||~| |
==============| | || / \ / \ || / \ / \ || | |==============
______________| | || / \ / \||/ \ / \ || | |______________
 . . | | ||/ \ / || \ / \|| | | . .
 / | | |X X || X X| | | / /
 / . | | ||\ / \ || / \ /|| | | . / .
. / | | || \ / \ /||\ / \ / || | | . .
 . . | | || \ / \ / || \ / \ / || | | .
 / | | || X X || X X || | | . / . /
 / . | | || / \ / \ || / \ / \ || | | /
 / | | || / \ / \||/ \ / \ || | | . /
. . . | | ||/ \ / /||\ \ / \|| | | /. .
 | |_|X X / II \ X X|_| | . . /
==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
    ";

    // Short description and source link appended below the art.
    let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";

    let banner = format!(" {body}\n{intro}");

    // Explicit content type so clients render the banner as plain UTF-8 text.
    (
        [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
        banner,
    )
}
199
200#[tokio::main]
201async fn main() -> Result<(), Box<dyn std::error::Error>> {
202 setup_tracing();
203 let pds_env_location =
204 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
205
206 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
207 if let Err(e) = result_of_finding_pds_env {
208 log::error!(
209 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
210 );
211 }
212
213 let pds_root =
214 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
215 let account_db_url = format!("{pds_root}/account.sqlite");
216
217 let account_options = SqliteConnectOptions::new()
218 .journal_mode(SqliteJournalMode::Wal)
219 .filename(account_db_url)
220 .busy_timeout(Duration::from_secs(5));
221
222 let account_pool = SqlitePoolOptions::new()
223 .max_connections(5)
224 .connect_with(account_options)
225 .await?;
226
227 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
228 let options = SqliteConnectOptions::new()
229 .journal_mode(SqliteJournalMode::Wal)
230 .filename(bells_db_url)
231 .create_if_missing(true)
232 .busy_timeout(Duration::from_secs(5));
233 let pds_gatekeeper_pool = SqlitePoolOptions::new()
234 .max_connections(5)
235 .connect_with(options)
236 .await?;
237
238 // Run migrations for the extra database
239 // Note: the migrations are embedded at compile time from the given directory
240 // sqlx
241 sqlx::migrate!("./migrations")
242 .run(&pds_gatekeeper_pool)
243 .await?;
244
245 let client: HyperUtilClient =
246 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
247 .build(HttpConnector::new());
248
249 //Emailer set up
250 let mailer = Arc::new(build_mailer_from_env()?);
251
252 //Email templates setup
253 let mut hbs = Handlebars::new();
254
255 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
256 if let Ok(users_email_directory) = users_email_directory {
257 hbs.register_template_file(
258 "two_factor_code.hbs",
259 format!("{users_email_directory}/two_factor_code.hbs"),
260 )?;
261 } else {
262 let _ = hbs.register_embed_templates::<EmailTemplates>();
263 }
264
265 let _ = hbs.register_embed_templates::<HtmlTemplates>();
266
267 //Reads the PLC source from the pds env's or defaults to ol faithful
268 let plc_source_url =
269 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string());
270 let plc_source = PlcSource::PlcDirectory {
271 base: plc_source_url.parse().unwrap(),
272 };
273 let mut resolver = PublicResolver::default();
274 resolver = resolver.with_plc_source(plc_source.clone());
275
276 let state = AppState {
277 account_pool,
278 pds_gatekeeper_pool,
279 reverse_proxy_client: client,
280 mailer,
281 template_engine: Engine::from(hbs),
282 resolver: Arc::new(resolver),
283 handle_cache: auth::HandleCache::new(),
284 app_config: AppConfig::new(),
285 };
286
287 // Rate limiting
288 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
289 let captcha_governor_conf = GovernorConfigBuilder::default()
290 .per_second(60)
291 .burst_size(5)
292 .key_extractor(SmartIpKeyExtractor)
293 .finish()
294 .expect("failed to create governor config for create session. this should not happen and is a bug");
295
296 // Create a second config with the same settings for the other endpoint
297 let sign_in_governor_conf = GovernorConfigBuilder::default()
298 .per_second(60)
299 .burst_size(5)
300 .key_extractor(SmartIpKeyExtractor)
301 .finish()
302 .expect(
303 "failed to create governor config for sign in. this should not happen and is a bug",
304 );
305
306 let create_account_limiter_time: Option<String> =
307 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
308 let create_account_limiter_burst: Option<String> =
309 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
310
311 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
312 let mut create_account_governor_conf = GovernorConfigBuilder::default();
313 if create_account_limiter_time.is_some() {
314 let time = create_account_limiter_time
315 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
316 .parse::<u64>()
317 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
318 create_account_governor_conf.per_second(time);
319 }
320
321 if create_account_limiter_burst.is_some() {
322 let burst = create_account_limiter_burst
323 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
324 .parse::<u32>()
325 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
326 create_account_governor_conf.burst_size(burst);
327 }
328
329 let create_account_governor_conf = create_account_governor_conf
330 .key_extractor(SmartIpKeyExtractor)
331 .finish().expect(
332 "failed to create governor config for create account. this should not happen and is a bug",
333 );
334
335 let captcha_governor_limiter = captcha_governor_conf.limiter().clone();
336 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
337 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
338
339 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf);
340
341 let interval = Duration::from_secs(60);
342 // a separate background task to clean up
343 std::thread::spawn(move || {
344 loop {
345 std::thread::sleep(interval);
346 captcha_governor_limiter.retain_recent();
347 sign_in_governor_limiter.retain_recent();
348 create_account_governor_limiter.retain_recent();
349 }
350 });
351
352 let cors = CorsLayer::new()
353 .allow_origin(Any)
354 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
355 .allow_headers(Any);
356
357 let mut app = Router::new()
358 .route("/", get(root_handler))
359 .route("/xrpc/com.atproto.server.getSession", get(get_session))
360 .route(
361 "/xrpc/com.atproto.server.describeServer",
362 get(describe_server),
363 )
364 .route(
365 "/xrpc/com.atproto.server.updateEmail",
366 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
367 )
368 .route(
369 "/@atproto/oauth-provider/~api/sign-in",
370 post(sign_in).layer(sign_in_governor_layer.clone()),
371 )
372 .route(
373 "/xrpc/com.atproto.server.createSession",
374 post(create_session.layer(sign_in_governor_layer)),
375 )
376 .route(
377 "/xrpc/com.atproto.server.createAccount",
378 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
379 );
380
381 if state.app_config.use_captcha {
382 app = app.route(
383 "/gate/signup",
384 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))),
385 );
386 }
387
388 let app = app
389 .layer(CompressionLayer::new())
390 .layer(cors)
391 .with_state(state);
392
393 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
394 let port: u16 = env::var("GATEKEEPER_PORT")
395 .ok()
396 .and_then(|s| s.parse().ok())
397 .unwrap_or(8080);
398 let addr: SocketAddr = format!("{host}:{port}")
399 .parse()
400 .expect("valid socket address");
401
402 let listener = tokio::net::TcpListener::bind(addr).await?;
403
404 let server = axum::serve(
405 listener,
406 app.into_make_service_with_connect_info::<SocketAddr>(),
407 )
408 .with_graceful_shutdown(shutdown_signal());
409
410 if let Err(err) = server.await {
411 log::error!("server error:{err}");
412 }
413
414 Ok(())
415}
416
417fn setup_tracing() {
418 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
419 tracing_subscriber::registry()
420 .with(env_filter)
421 .with(fmt::layer())
422 .init();
423}
424
425async fn shutdown_signal() {
426 // Wait for Ctrl+C
427 let ctrl_c = async {
428 tokio::signal::ctrl_c()
429 .await
430 .expect("failed to install Ctrl+C handler");
431 };
432
433 #[cfg(unix)]
434 let terminate = async {
435 use tokio::signal::unix::{SignalKind, signal};
436
437 let mut sigterm =
438 signal(SignalKind::terminate()).expect("failed to install signal handler");
439 sigterm.recv().await;
440 };
441
442 #[cfg(not(unix))]
443 let terminate = std::future::pending::<()>();
444
445 tokio::select! {
446 _ = ctrl_c => {},
447 _ = terminate => {},
448 }
449}