Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)]
2use crate::gate::{get_gate, post_gate};
3use crate::mailer::{Mailer, build_mailer_from_env};
4use crate::oauth_provider::sign_in;
5use crate::xrpc::com_atproto_server::{
6 create_account, create_session, describe_server, get_session, update_email,
7};
8use anyhow::Result;
9use axum::{
10 Router,
11 body::Body,
12 handler::Handler,
13 http::{Method, header},
14 middleware as ax_middleware,
15 routing::get,
16 routing::post,
17};
18use axum_template::engine::Engine;
19use handlebars::Handlebars;
20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
21use jacquard_common::types::did::Did;
22use jacquard_identity::{PublicResolver, resolver::PlcSource};
23use rand::Rng;
24use rust_embed::RustEmbed;
25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
27use std::path::Path;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{env, net::SocketAddr};
31use tower_governor::{
32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
33};
34use tower_http::cors::AllowHeaders;
35use tower_http::trace::{DefaultOnRequest, HttpMakeClassifier};
36use tower_http::{
37 compression::CompressionLayer,
38 cors::{Any, CorsLayer},
39 trace::TraceLayer,
40};
41use tracing::{Span, log};
42use tracing_subscriber::{EnvFilter, fmt, prelude::*};
43
44mod admin;
45mod auth;
46mod gate;
47pub mod helpers;
48pub mod mailer;
49mod middleware;
50mod oauth_provider;
51mod xrpc;
52
53type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
54
55#[derive(RustEmbed)]
56#[folder = "email_templates"]
57#[include = "*.hbs"]
58struct EmailTemplates;
59
60#[derive(RustEmbed)]
61#[folder = "html_templates"]
62#[include = "*.hbs"]
63struct HtmlTemplates;
64
65#[derive(RustEmbed)]
66#[folder = "static"]
67pub struct StaticFiles;
68
69/// Mostly the env variables that are used in the app
70#[derive(Clone, Debug)]
71pub struct AppConfig {
72 pds_base_url: String,
73 mailer_from: String,
74 email_subject: String,
75 allow_only_migrations: bool,
76 use_captcha: bool,
77 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use
78 //that need to capture this redirect url for creating an account
79 default_successful_redirect_url: String,
80 pds_service_did: Did<'static>,
81 gate_jwe_key: Vec<u8>,
82 captcha_success_redirects: Vec<String>,
83 // Admin portal config
84 pub pds_admin_password: Option<String>,
85 pub pds_hostname: String,
86 pub admin_session_ttl_hours: u64,
87}
88
89impl AppConfig {
90 pub fn new() -> Self {
91 let pds_base_url =
92 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
93 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS")
94 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
95 //Hack not my favorite, but it does work
96 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS")
97 .map(|val| val.parse::<bool>().unwrap_or(false))
98 .unwrap_or(false);
99
100 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA")
101 .map(|val| val.parse::<bool>().unwrap_or(false))
102 .unwrap_or(false);
103
104 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME
105 let pds_service_did =
106 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") {
107 Ok(pds_hostname) => format!("did:web:{}", pds_hostname),
108 Err(_) => {
109 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file")
110 }
111 });
112
113 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT")
114 .unwrap_or("Sign in to Bluesky".to_string());
115
116 // Load or generate JWE encryption key (32 bytes for AES-256)
117 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY")
118 .ok()
119 .and_then(|key_hex| hex::decode(key_hex).ok())
120 .unwrap_or_else(|| {
121 // Generate a random 32-byte key if not provided
122 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
123 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key));
124 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins).");
125 key
126 });
127
128 if gate_jwe_key.len() != 32 {
129 panic!(
130 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption"
131 );
132 }
133
134 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") {
135 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(),
136 Err(_) => {
137 vec![
138 String::from("https://bsky.app"),
139 String::from("https://pdsmoover.com"),
140 String::from("https://blacksky.community"),
141 String::from("https://tektite.cc"),
142 ]
143 }
144 };
145
146 let pds_hostname = env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
147
148 let pds_admin_password = env::var("PDS_ADMIN_PASSWORD").ok();
149
150 let admin_session_ttl_hours = env::var("GATEKEEPER_ADMIN_SESSION_TTL_HOURS")
151 .ok()
152 .and_then(|v| v.parse().ok())
153 .unwrap_or(24u64);
154
155 AppConfig {
156 pds_base_url,
157 mailer_from,
158 email_subject,
159 allow_only_migrations,
160 use_captcha,
161 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT")
162 .unwrap_or("https://bsky.app".to_string()),
163 pds_service_did: pds_service_did
164 .parse()
165 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"),
166 gate_jwe_key,
167 captcha_success_redirects,
168 pds_admin_password,
169 pds_hostname,
170 admin_session_ttl_hours,
171 }
172 }
173}
174
175#[derive(Clone)]
176pub struct AppState {
177 account_pool: SqlitePool,
178 pds_gatekeeper_pool: SqlitePool,
179 reverse_proxy_client: HyperUtilClient,
180 mailer: Arc<Mailer>,
181 template_engine: Engine<Handlebars<'static>>,
182 resolver: Arc<PublicResolver>,
183 handle_cache: auth::HandleCache,
184 app_config: AppConfig,
185 // Admin portal
186 admin_rbac_config: Option<Arc<admin::rbac::RbacConfig>>,
187 admin_oauth_client: Option<Arc<admin::oauth::AdminOAuthClient>>,
188 cookie_key: axum_extra::extract::cookie::Key,
189}
190
191impl axum::extract::FromRef<AppState> for axum_extra::extract::cookie::Key {
192 fn from_ref(state: &AppState) -> Self {
193 state.cookie_key.clone()
194 }
195}
196
/// Handler for `GET /`: responds with a plain-text (`text/plain; charset=utf-8`) banner,
/// i.e. an ASCII-art picture followed by a short description and a link to the source code.
async fn root_handler() -> impl axum::response::IntoResponse {
    // Raw string so the backslashes in the art are kept literally.
    let body = r"

 ...oO _.--X~~OO~~X--._ ...oOO
 _.-~ / \ II / \ ~-._
 [].-~ \ / \||/ \ / ~-.[] ...o
 ...o _ ||/ \ / || \ / \|| _
 (_) |X X || X X| (_)
 _-~-_ ||\ / \ || / \ /|| _-~-_
 ||||| || \ / \ /||\ / \ / || |||||
 | |_|| \ / \ / || \ / \ / ||_| |
 | |~|| X X || X X ||~| |
==============| | || / \ / \ || / \ / \ || | |==============
______________| | || / \ / \||/ \ / \ || | |______________
 . . | | ||/ \ / || \ / \|| | | . .
 / | | |X X || X X| | | / /
 / . | | ||\ / \ || / \ /|| | | . / .
. / | | || \ / \ /||\ / \ / || | | . .
 . . | | || \ / \ / || \ / \ / || | | .
 / | | || X X || X X || | | . / . /
 / . | | || / \ / \ || / \ / \ || | | /
 / | | || / \ / \||/ \ / \ || | | . /
. . . | | ||/ \ / /||\ \ / \|| | | /. .
 | |_|X X / II \ X X|_| | . . /
==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
 ";

    let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";

    // Art first, then the intro text, joined by a newline.
    let banner = format!(" {body}\n{intro}");

    (
        [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
        banner,
    )
}
233
234#[tokio::main]
235async fn main() -> Result<(), Box<dyn std::error::Error>> {
236 let pds_env_location =
237 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
238
239 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
240 if let Err(e) = result_of_finding_pds_env {
241 log::error!(
242 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
243 );
244 }
245 // Sets up after the pds.env file is loaded
246 setup_tracing();
247
248 let pds_root =
249 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
250 let account_db_url = format!("{pds_root}/account.sqlite");
251
252 let account_options = SqliteConnectOptions::new()
253 .journal_mode(SqliteJournalMode::Wal)
254 .filename(account_db_url)
255 .busy_timeout(Duration::from_secs(5));
256
257 let account_pool = SqlitePoolOptions::new()
258 .max_connections(5)
259 .connect_with(account_options)
260 .await?;
261
262 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
263 let options = SqliteConnectOptions::new()
264 .journal_mode(SqliteJournalMode::Wal)
265 .filename(bells_db_url)
266 .create_if_missing(true)
267 .busy_timeout(Duration::from_secs(5));
268 let pds_gatekeeper_pool = SqlitePoolOptions::new()
269 .max_connections(5)
270 .connect_with(options)
271 .await?;
272
273 // Run migrations for the extra database
274 // Note: the migrations are embedded at compile time from the given directory
275 // sqlx
276 sqlx::migrate!("./migrations")
277 .run(&pds_gatekeeper_pool)
278 .await?;
279
280 let client: HyperUtilClient =
281 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
282 .build(HttpConnector::new());
283
284 //Emailer set up
285 let mailer = Arc::new(build_mailer_from_env()?);
286
287 //Email templates setup
288 let mut hbs = Handlebars::new();
289
290 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
291 if let Ok(users_email_directory) = users_email_directory {
292 hbs.register_template_file(
293 "two_factor_code.hbs",
294 format!("{users_email_directory}/two_factor_code.hbs"),
295 )?;
296 } else {
297 let _ = hbs.register_embed_templates::<EmailTemplates>();
298 }
299
300 let _ = hbs.register_embed_templates::<HtmlTemplates>();
301
302 //Reads the PLC source from the pds env's or defaults to ol faithful
303 let plc_source_url =
304 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string());
305 let plc_source = PlcSource::PlcDirectory {
306 base: plc_source_url.parse().unwrap(),
307 };
308 let mut resolver = PublicResolver::default();
309 resolver = resolver.with_plc_source(plc_source.clone());
310
311 let app_config = AppConfig::new();
312
313 // Admin portal setup (opt-in via GATEKEEPER_ADMIN_RBAC_CONFIG)
314 let admin_rbac_config = env::var("GATEKEEPER_ADMIN_RBAC_CONFIG").ok().map(|path| {
315 let config = admin::rbac::RbacConfig::load_from_file(&path)
316 .unwrap_or_else(|e| panic!("Failed to load RBAC config from {}: {}", path, e));
317 log::info!(
318 "Loaded admin RBAC config from {} ({} members)",
319 path,
320 config.members.len()
321 );
322 Arc::new(config)
323 });
324
325 let admin_oauth_client = if admin_rbac_config.is_some() {
326 match admin::oauth::init_oauth_client(&app_config.pds_hostname, pds_gatekeeper_pool.clone())
327 {
328 Ok(client) => {
329 log::info!(
330 "Admin OAuth client initialized for {}",
331 app_config.pds_hostname
332 );
333 Some(Arc::new(client))
334 }
335 Err(e) => {
336 log::error!(
337 "Failed to initialize admin OAuth client: {}. Admin portal will be disabled.",
338 e
339 );
340 None
341 }
342 }
343 } else {
344 None
345 };
346
347 // Cookie signing key for admin sessions
348 let cookie_key = env::var("GATEKEEPER_ADMIN_COOKIE_SECRET")
349 .ok()
350 .and_then(|hex_str| hex::decode(hex_str).ok())
351 .unwrap_or_else(|| app_config.gate_jwe_key.clone());
352 let cookie_key = {
353 // Key::from requires at least 64 bytes; derive by repeating if needed
354 let mut key_bytes = cookie_key.clone();
355 while key_bytes.len() < 64 {
356 key_bytes.extend_from_slice(&cookie_key);
357 }
358 axum_extra::extract::cookie::Key::from(&key_bytes[..64])
359 };
360
361 let state = AppState {
362 account_pool,
363 pds_gatekeeper_pool,
364 reverse_proxy_client: client,
365 mailer,
366 template_engine: Engine::from(hbs),
367 resolver: Arc::new(resolver),
368 handle_cache: auth::HandleCache::new(),
369 app_config,
370 admin_rbac_config,
371 admin_oauth_client,
372 cookie_key,
373 };
374
375 // Rate limiting
376 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
377 let captcha_governor_conf = GovernorConfigBuilder::default()
378 .per_second(60)
379 .burst_size(5)
380 .key_extractor(SmartIpKeyExtractor)
381 .finish()
382 .expect("failed to create governor config for create session. this should not happen and is a bug");
383
384 // Create a second config with the same settings for the other endpoint
385 let sign_in_governor_conf = GovernorConfigBuilder::default()
386 .per_second(60)
387 .burst_size(5)
388 .key_extractor(SmartIpKeyExtractor)
389 .finish()
390 .expect(
391 "failed to create governor config for sign in. this should not happen and is a bug",
392 );
393
394 let create_account_limiter_time: Option<String> =
395 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
396 let create_account_limiter_burst: Option<String> =
397 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
398
399 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
400 let mut create_account_governor_conf = GovernorConfigBuilder::default();
401 if create_account_limiter_time.is_some() {
402 let time = create_account_limiter_time
403 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
404 .parse::<u64>()
405 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
406 create_account_governor_conf.per_second(time);
407 }
408
409 if create_account_limiter_burst.is_some() {
410 let burst = create_account_limiter_burst
411 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
412 .parse::<u32>()
413 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
414 create_account_governor_conf.burst_size(burst);
415 }
416
417 let create_account_governor_conf = create_account_governor_conf
418 .key_extractor(SmartIpKeyExtractor)
419 .finish().expect(
420 "failed to create governor config for create account. this should not happen and is a bug",
421 );
422
423 let captcha_governor_limiter = captcha_governor_conf.limiter().clone();
424 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
425 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
426
427 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf);
428
429 let interval = Duration::from_secs(60);
430 // a separate background task to clean up
431 std::thread::spawn(move || {
432 loop {
433 std::thread::sleep(interval);
434 captcha_governor_limiter.retain_recent();
435 sign_in_governor_limiter.retain_recent();
436 create_account_governor_limiter.retain_recent();
437 }
438 });
439
440 let cors = CorsLayer::new()
441 .allow_origin(Any)
442 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
443 .allow_headers(AllowHeaders::mirror_request());
444
445 let mut app = Router::new()
446 .route("/", get(root_handler))
447 .route("/xrpc/com.atproto.server.getSession", get(get_session))
448 .route(
449 "/xrpc/com.atproto.server.describeServer",
450 get(describe_server),
451 )
452 .route(
453 "/xrpc/com.atproto.server.updateEmail",
454 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
455 )
456 .route(
457 "/@atproto/oauth-provider/~api/sign-in",
458 post(sign_in).layer(sign_in_governor_layer.clone()),
459 )
460 .route(
461 "/xrpc/com.atproto.server.createSession",
462 post(create_session.layer(sign_in_governor_layer)),
463 )
464 .route(
465 "/xrpc/com.atproto.server.createAccount",
466 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
467 );
468
469 if state.app_config.use_captcha {
470 app = app.route(
471 "/gate/signup",
472 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))),
473 );
474 }
475
476 // Mount admin portal if RBAC config is loaded
477 if state.admin_rbac_config.is_some() {
478 let admin_router = admin::router(state.clone());
479 app = app.nest("/admin", admin_router);
480 log::info!("Admin portal mounted at /admin/");
481 }
482
483 // Background cleanup for admin sessions
484 let admin_enabled = state.admin_rbac_config.is_some();
485 if admin_enabled {
486 let admin_session_ttl_in_mins = state.app_config.admin_session_ttl_hours * 60;
487 let cleanup_pool = state.pds_gatekeeper_pool.clone();
488 tokio::spawn(async move {
489 let mut interval = tokio::time::interval(Duration::from_secs(300));
490 loop {
491 interval.tick().await;
492 if admin_enabled {
493 if let Err(e) = admin::session::cleanup_expired_sessions(&cleanup_pool).await {
494 tracing::error!("Failed to cleanup expired admin sessions: {}", e);
495 }
496 if let Err(e) = admin::store::cleanup_stale_auth_requests(
497 &cleanup_pool,
498 admin_session_ttl_in_mins as i64,
499 )
500 .await
501 {
502 tracing::error!("Failed to cleanup stale OAuth auth requests: {}", e);
503 }
504 }
505 }
506 });
507 }
508
509 let request_logging = env::var("GATEKEEPER_REQUEST_LOGGING")
510 .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
511 .unwrap_or(false);
512
513 if request_logging {
514 app = app.layer(request_trace_layer());
515 }
516
517 let app = app
518 .layer(CompressionLayer::new())
519 .layer(cors)
520 .with_state(state);
521
522 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
523 let port: u16 = env::var("GATEKEEPER_PORT")
524 .ok()
525 .and_then(|s| s.parse().ok())
526 .unwrap_or(8080);
527 let addr: SocketAddr = format!("{host}:{port}")
528 .parse()
529 .expect("valid socket address");
530
531 let listener = tokio::net::TcpListener::bind(addr).await?;
532
533 let server = axum::serve(
534 listener,
535 app.into_make_service_with_connect_info::<SocketAddr>(),
536 )
537 .with_graceful_shutdown(shutdown_signal());
538
539 if let Err(err) = server.await {
540 log::error!("server error:{err}");
541 }
542
543 Ok(())
544}
545
546fn setup_tracing() {
547 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
548 let json = env::var("GATEKEEPER_LOG_FORMAT")
549 .map(|v| v.eq_ignore_ascii_case("json"))
550 .unwrap_or(false);
551
552 if json {
553 tracing_subscriber::registry()
554 .with(env_filter)
555 .with(fmt::layer().json())
556 .init();
557 } else {
558 tracing_subscriber::registry()
559 .with(env_filter)
560 .with(fmt::layer())
561 .init();
562 }
563}
564
565async fn shutdown_signal() {
566 // Wait for Ctrl+C
567 let ctrl_c = async {
568 tokio::signal::ctrl_c()
569 .await
570 .expect("failed to install Ctrl+C handler");
571 };
572
573 #[cfg(unix)]
574 let terminate = async {
575 use tokio::signal::unix::{SignalKind, signal};
576
577 let mut sigterm =
578 signal(SignalKind::terminate()).expect("failed to install signal handler");
579 sigterm.recv().await;
580 };
581
582 #[cfg(not(unix))]
583 let terminate = std::future::pending::<()>();
584
585 tokio::select! {
586 _ = ctrl_c => {},
587 _ = terminate => {},
588 }
589}
590
591fn request_trace_layer() -> TraceLayer<
592 HttpMakeClassifier,
593 impl Fn(&axum::http::Request<Body>) -> Span + Clone,
594 DefaultOnRequest,
595 impl Fn(&axum::http::Response<Body>, Duration, &Span) + Clone,
596> {
597 TraceLayer::new_for_http()
598 .make_span_with(|req: &axum::http::Request<Body>| {
599 let headers = req.headers();
600 tracing::info_span!("request",
601 method = %req.method(),
602 path = %req.uri().path(),
603 headers = %format!("{:?}", headers),
604 )
605 })
606 .on_response(
607 |resp: &axum::http::Response<Body>, latency: Duration, _span: &tracing::Span| {
608 tracing::info!(
609 status = resp.status().as_u16(),
610 latency_ms = latency.as_millis() as u64,
611 "response"
612 );
613 },
614 )
615}