Microservice that brings two-factor authentication (2FA) to self-hosted PDSes (atproto Personal Data Servers).
Branch: feature/admin-rbac · 615 lines · 22 kB (view raw)
1#![warn(clippy::unwrap_used)] 2use crate::gate::{get_gate, post_gate}; 3use crate::mailer::{Mailer, build_mailer_from_env}; 4use crate::oauth_provider::sign_in; 5use crate::xrpc::com_atproto_server::{ 6 create_account, create_session, describe_server, get_session, update_email, 7}; 8use anyhow::Result; 9use axum::{ 10 Router, 11 body::Body, 12 handler::Handler, 13 http::{Method, header}, 14 middleware as ax_middleware, 15 routing::get, 16 routing::post, 17}; 18use axum_template::engine::Engine; 19use handlebars::Handlebars; 20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; 21use jacquard_common::types::did::Did; 22use jacquard_identity::{PublicResolver, resolver::PlcSource}; 23use rand::Rng; 24use rust_embed::RustEmbed; 25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 27use std::path::Path; 28use std::sync::Arc; 29use std::time::Duration; 30use std::{env, net::SocketAddr}; 31use tower_governor::{ 32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, 33}; 34use tower_http::cors::AllowHeaders; 35use tower_http::trace::{DefaultOnRequest, HttpMakeClassifier}; 36use tower_http::{ 37 compression::CompressionLayer, 38 cors::{Any, CorsLayer}, 39 trace::TraceLayer, 40}; 41use tracing::{Span, log}; 42use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 43 44mod admin; 45mod auth; 46mod gate; 47pub mod helpers; 48pub mod mailer; 49mod middleware; 50mod oauth_provider; 51mod xrpc; 52 53type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 54 55#[derive(RustEmbed)] 56#[folder = "email_templates"] 57#[include = "*.hbs"] 58struct EmailTemplates; 59 60#[derive(RustEmbed)] 61#[folder = "html_templates"] 62#[include = "*.hbs"] 63struct HtmlTemplates; 64 65#[derive(RustEmbed)] 66#[folder = "static"] 67pub struct StaticFiles; 68 69/// Mostly the env variables that are used in the app 70#[derive(Clone, Debug)] 71pub struct 
AppConfig { 72 pds_base_url: String, 73 mailer_from: String, 74 email_subject: String, 75 allow_only_migrations: bool, 76 use_captcha: bool, 77 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use 78 //that need to capture this redirect url for creating an account 79 default_successful_redirect_url: String, 80 pds_service_did: Did<'static>, 81 gate_jwe_key: Vec<u8>, 82 captcha_success_redirects: Vec<String>, 83 // Admin portal config 84 pub pds_admin_password: Option<String>, 85 pub pds_hostname: String, 86 pub admin_session_ttl_hours: u64, 87} 88 89impl AppConfig { 90 pub fn new() -> Self { 91 let pds_base_url = 92 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 93 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS") 94 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 95 //Hack not my favorite, but it does work 96 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS") 97 .map(|val| val.parse::<bool>().unwrap_or(false)) 98 .unwrap_or(false); 99 100 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA") 101 .map(|val| val.parse::<bool>().unwrap_or(false)) 102 .unwrap_or(false); 103 104 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME 105 let pds_service_did = 106 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") { 107 Ok(pds_hostname) => format!("did:web:{}", pds_hostname), 108 Err(_) => { 109 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file") 110 } 111 }); 112 113 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") 114 .unwrap_or("Sign in to Bluesky".to_string()); 115 116 // Load or generate JWE encryption key (32 bytes for AES-256) 117 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY") 118 .ok() 119 .and_then(|key_hex| hex::decode(key_hex).ok()) 120 .unwrap_or_else(|| { 121 // Generate a random 32-byte key 
if not provided 122 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect(); 123 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key)); 124 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins)."); 125 key 126 }); 127 128 if gate_jwe_key.len() != 32 { 129 panic!( 130 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption" 131 ); 132 } 133 134 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") { 135 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(), 136 Err(_) => { 137 vec![ 138 String::from("https://bsky.app"), 139 String::from("https://pdsmoover.com"), 140 String::from("https://blacksky.community"), 141 String::from("https://tektite.cc"), 142 ] 143 } 144 }; 145 146 let pds_hostname = env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 147 148 let pds_admin_password = env::var("PDS_ADMIN_PASSWORD").ok(); 149 150 let admin_session_ttl_hours = env::var("GATEKEEPER_ADMIN_SESSION_TTL_HOURS") 151 .ok() 152 .and_then(|v| v.parse().ok()) 153 .unwrap_or(24u64); 154 155 AppConfig { 156 pds_base_url, 157 mailer_from, 158 email_subject, 159 allow_only_migrations, 160 use_captcha, 161 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT") 162 .unwrap_or("https://bsky.app".to_string()), 163 pds_service_did: pds_service_did 164 .parse() 165 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"), 166 gate_jwe_key, 167 captcha_success_redirects, 168 pds_admin_password, 169 pds_hostname, 170 admin_session_ttl_hours, 171 } 172 } 173} 174 175#[derive(Clone)] 176pub struct AppState { 177 account_pool: SqlitePool, 178 pds_gatekeeper_pool: SqlitePool, 179 reverse_proxy_client: HyperUtilClient, 180 mailer: Arc<Mailer>, 181 template_engine: 
Engine<Handlebars<'static>>, 182 resolver: Arc<PublicResolver>, 183 handle_cache: auth::HandleCache, 184 app_config: AppConfig, 185 // Admin portal 186 admin_rbac_config: Option<Arc<admin::rbac::RbacConfig>>, 187 admin_oauth_client: Option<Arc<admin::oauth::AdminOAuthClient>>, 188 cookie_key: axum_extra::extract::cookie::Key, 189} 190 191impl axum::extract::FromRef<AppState> for axum_extra::extract::cookie::Key { 192 fn from_ref(state: &AppState) -> Self { 193 state.cookie_key.clone() 194 } 195} 196 197async fn root_handler() -> impl axum::response::IntoResponse { 198 let body = r" 199 200 ...oO _.--X~~OO~~X--._ ...oOO 201 _.-~ / \ II / \ ~-._ 202 [].-~ \ / \||/ \ / ~-.[] ...o 203 ...o _ ||/ \ / || \ / \|| _ 204 (_) |X X || X X| (_) 205 _-~-_ ||\ / \ || / \ /|| _-~-_ 206 ||||| || \ / \ /||\ / \ / || ||||| 207 | |_|| \ / \ / || \ / \ / ||_| | 208 | |~|| X X || X X ||~| | 209==============| | || / \ / \ || / \ / \ || | |============== 210______________| | || / \ / \||/ \ / \ || | |______________ 211 . . | | ||/ \ / || \ / \|| | | . . 212 / | | |X X || X X| | | / / 213 / . | | ||\ / \ || / \ /|| | | . / . 214. / | | || \ / \ /||\ / \ / || | | . . 215 . . | | || \ / \ / || \ / \ / || | | . 216 / | | || X X || X X || | | . / . / 217 / . | | || / \ / \ || / \ / \ || | | / 218 / | | || / \ / \||/ \ / \ || | | . / 219. . . | | ||/ \ / /||\ \ / \|| | | /. . 220 | |_|X X / II \ X X|_| | . . 
/ 221==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 222 "; 223 224 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 225 226 let banner = format!(" {body}\n{intro}"); 227 228 ( 229 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 230 banner, 231 ) 232} 233 234#[tokio::main] 235async fn main() -> Result<(), Box<dyn std::error::Error>> { 236 let pds_env_location = 237 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 238 239 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 240 if let Err(e) = result_of_finding_pds_env { 241 log::error!( 242 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 243 ); 244 } 245 // Sets up after the pds.env file is loaded 246 setup_tracing(); 247 248 let pds_root = 249 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 250 let account_db_url = format!("{pds_root}/account.sqlite"); 251 252 let account_options = SqliteConnectOptions::new() 253 .journal_mode(SqliteJournalMode::Wal) 254 .filename(account_db_url) 255 .busy_timeout(Duration::from_secs(5)); 256 257 let account_pool = SqlitePoolOptions::new() 258 .max_connections(5) 259 .connect_with(account_options) 260 .await?; 261 262 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 263 let options = SqliteConnectOptions::new() 264 .journal_mode(SqliteJournalMode::Wal) 265 .filename(bells_db_url) 266 .create_if_missing(true) 267 .busy_timeout(Duration::from_secs(5)); 268 let pds_gatekeeper_pool = SqlitePoolOptions::new() 269 .max_connections(5) 270 .connect_with(options) 271 .await?; 272 273 // Run migrations for the extra database 274 // Note: the migrations are embedded at compile time from the given directory 275 // sqlx 276 sqlx::migrate!("./migrations") 277 .run(&pds_gatekeeper_pool) 278 .await?; 279 280 let client: HyperUtilClient = 
281 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 282 .build(HttpConnector::new()); 283 284 //Emailer set up 285 let mailer = Arc::new(build_mailer_from_env()?); 286 287 //Email templates setup 288 let mut hbs = Handlebars::new(); 289 290 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 291 if let Ok(users_email_directory) = users_email_directory { 292 hbs.register_template_file( 293 "two_factor_code.hbs", 294 format!("{users_email_directory}/two_factor_code.hbs"), 295 )?; 296 } else { 297 let _ = hbs.register_embed_templates::<EmailTemplates>(); 298 } 299 300 let _ = hbs.register_embed_templates::<HtmlTemplates>(); 301 302 //Reads the PLC source from the pds env's or defaults to ol faithful 303 let plc_source_url = 304 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string()); 305 let plc_source = PlcSource::PlcDirectory { 306 base: plc_source_url.parse().unwrap(), 307 }; 308 let mut resolver = PublicResolver::default(); 309 resolver = resolver.with_plc_source(plc_source.clone()); 310 311 let app_config = AppConfig::new(); 312 313 // Admin portal setup (opt-in via GATEKEEPER_ADMIN_RBAC_CONFIG) 314 let admin_rbac_config = env::var("GATEKEEPER_ADMIN_RBAC_CONFIG").ok().map(|path| { 315 let config = admin::rbac::RbacConfig::load_from_file(&path) 316 .unwrap_or_else(|e| panic!("Failed to load RBAC config from {}: {}", path, e)); 317 log::info!( 318 "Loaded admin RBAC config from {} ({} members)", 319 path, 320 config.members.len() 321 ); 322 Arc::new(config) 323 }); 324 325 let admin_oauth_client = if admin_rbac_config.is_some() { 326 match admin::oauth::init_oauth_client(&app_config.pds_hostname, pds_gatekeeper_pool.clone()) 327 { 328 Ok(client) => { 329 log::info!( 330 "Admin OAuth client initialized for {}", 331 app_config.pds_hostname 332 ); 333 Some(Arc::new(client)) 334 } 335 Err(e) => { 336 log::error!( 337 "Failed to initialize admin OAuth client: {}. 
Admin portal will be disabled.", 338 e 339 ); 340 None 341 } 342 } 343 } else { 344 None 345 }; 346 347 // Cookie signing key for admin sessions 348 let cookie_key = env::var("GATEKEEPER_ADMIN_COOKIE_SECRET") 349 .ok() 350 .and_then(|hex_str| hex::decode(hex_str).ok()) 351 .unwrap_or_else(|| app_config.gate_jwe_key.clone()); 352 let cookie_key = { 353 // Key::from requires at least 64 bytes; derive by repeating if needed 354 let mut key_bytes = cookie_key.clone(); 355 while key_bytes.len() < 64 { 356 key_bytes.extend_from_slice(&cookie_key); 357 } 358 axum_extra::extract::cookie::Key::from(&key_bytes[..64]) 359 }; 360 361 let state = AppState { 362 account_pool, 363 pds_gatekeeper_pool, 364 reverse_proxy_client: client, 365 mailer, 366 template_engine: Engine::from(hbs), 367 resolver: Arc::new(resolver), 368 handle_cache: auth::HandleCache::new(), 369 app_config, 370 admin_rbac_config, 371 admin_oauth_client, 372 cookie_key, 373 }; 374 375 // Rate limiting 376 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 377 let captcha_governor_conf = GovernorConfigBuilder::default() 378 .per_second(60) 379 .burst_size(5) 380 .key_extractor(SmartIpKeyExtractor) 381 .finish() 382 .expect("failed to create governor config for create session. this should not happen and is a bug"); 383 384 // Create a second config with the same settings for the other endpoint 385 let sign_in_governor_conf = GovernorConfigBuilder::default() 386 .per_second(60) 387 .burst_size(5) 388 .key_extractor(SmartIpKeyExtractor) 389 .finish() 390 .expect( 391 "failed to create governor config for sign in. 
this should not happen and is a bug", 392 ); 393 394 let create_account_limiter_time: Option<String> = 395 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 396 let create_account_limiter_burst: Option<String> = 397 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 398 399 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 400 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 401 if create_account_limiter_time.is_some() { 402 let time = create_account_limiter_time 403 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 404 .parse::<u64>() 405 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 406 create_account_governor_conf.per_second(time); 407 } 408 409 if create_account_limiter_burst.is_some() { 410 let burst = create_account_limiter_burst 411 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 412 .parse::<u32>() 413 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 414 create_account_governor_conf.burst_size(burst); 415 } 416 417 let create_account_governor_conf = create_account_governor_conf 418 .key_extractor(SmartIpKeyExtractor) 419 .finish().expect( 420 "failed to create governor config for create account. 
this should not happen and is a bug", 421 ); 422 423 let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); 424 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 425 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 426 427 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); 428 429 let interval = Duration::from_secs(60); 430 // a separate background task to clean up 431 std::thread::spawn(move || { 432 loop { 433 std::thread::sleep(interval); 434 captcha_governor_limiter.retain_recent(); 435 sign_in_governor_limiter.retain_recent(); 436 create_account_governor_limiter.retain_recent(); 437 } 438 }); 439 440 let cors = CorsLayer::new() 441 .allow_origin(Any) 442 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 443 .allow_headers(AllowHeaders::mirror_request()); 444 445 let mut app = Router::new() 446 .route("/", get(root_handler)) 447 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 448 .route( 449 "/xrpc/com.atproto.server.describeServer", 450 get(describe_server), 451 ) 452 .route( 453 "/xrpc/com.atproto.server.updateEmail", 454 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 455 ) 456 .route( 457 "/@atproto/oauth-provider/~api/sign-in", 458 post(sign_in).layer(sign_in_governor_layer.clone()), 459 ) 460 .route( 461 "/xrpc/com.atproto.server.createSession", 462 post(create_session.layer(sign_in_governor_layer)), 463 ) 464 .route( 465 "/xrpc/com.atproto.server.createAccount", 466 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 467 ); 468 469 if state.app_config.use_captcha { 470 app = app.route( 471 "/gate/signup", 472 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), 473 ); 474 } 475 476 // Mount admin portal if RBAC config is loaded 477 if state.admin_rbac_config.is_some() { 478 let admin_router = admin::router(state.clone()); 479 app = app.nest("/admin", 
admin_router); 480 log::info!("Admin portal mounted at /admin/"); 481 } 482 483 // Background cleanup for admin sessions 484 let admin_enabled = state.admin_rbac_config.is_some(); 485 if admin_enabled { 486 let admin_session_ttl_in_mins = state.app_config.admin_session_ttl_hours * 60; 487 let cleanup_pool = state.pds_gatekeeper_pool.clone(); 488 tokio::spawn(async move { 489 let mut interval = tokio::time::interval(Duration::from_secs(300)); 490 loop { 491 interval.tick().await; 492 if admin_enabled { 493 if let Err(e) = admin::session::cleanup_expired_sessions(&cleanup_pool).await { 494 tracing::error!("Failed to cleanup expired admin sessions: {}", e); 495 } 496 if let Err(e) = admin::store::cleanup_stale_auth_requests( 497 &cleanup_pool, 498 admin_session_ttl_in_mins as i64, 499 ) 500 .await 501 { 502 tracing::error!("Failed to cleanup stale OAuth auth requests: {}", e); 503 } 504 } 505 } 506 }); 507 } 508 509 let request_logging = env::var("GATEKEEPER_REQUEST_LOGGING") 510 .map(|v| v.eq_ignore_ascii_case("true") || v == "1") 511 .unwrap_or(false); 512 513 if request_logging { 514 app = app.layer(request_trace_layer()); 515 } 516 517 let app = app 518 .layer(CompressionLayer::new()) 519 .layer(cors) 520 .with_state(state); 521 522 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); 523 let port: u16 = env::var("GATEKEEPER_PORT") 524 .ok() 525 .and_then(|s| s.parse().ok()) 526 .unwrap_or(8080); 527 let addr: SocketAddr = format!("{host}:{port}") 528 .parse() 529 .expect("valid socket address"); 530 531 let listener = tokio::net::TcpListener::bind(addr).await?; 532 533 let server = axum::serve( 534 listener, 535 app.into_make_service_with_connect_info::<SocketAddr>(), 536 ) 537 .with_graceful_shutdown(shutdown_signal()); 538 539 if let Err(err) = server.await { 540 log::error!("server error:{err}"); 541 } 542 543 Ok(()) 544} 545 546fn setup_tracing() { 547 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| 
EnvFilter::new("info")); 548 let json = env::var("GATEKEEPER_LOG_FORMAT") 549 .map(|v| v.eq_ignore_ascii_case("json")) 550 .unwrap_or(false); 551 552 if json { 553 tracing_subscriber::registry() 554 .with(env_filter) 555 .with(fmt::layer().json()) 556 .init(); 557 } else { 558 tracing_subscriber::registry() 559 .with(env_filter) 560 .with(fmt::layer()) 561 .init(); 562 } 563} 564 565async fn shutdown_signal() { 566 // Wait for Ctrl+C 567 let ctrl_c = async { 568 tokio::signal::ctrl_c() 569 .await 570 .expect("failed to install Ctrl+C handler"); 571 }; 572 573 #[cfg(unix)] 574 let terminate = async { 575 use tokio::signal::unix::{SignalKind, signal}; 576 577 let mut sigterm = 578 signal(SignalKind::terminate()).expect("failed to install signal handler"); 579 sigterm.recv().await; 580 }; 581 582 #[cfg(not(unix))] 583 let terminate = std::future::pending::<()>(); 584 585 tokio::select! { 586 _ = ctrl_c => {}, 587 _ = terminate => {}, 588 } 589} 590 591fn request_trace_layer() -> TraceLayer< 592 HttpMakeClassifier, 593 impl Fn(&axum::http::Request<Body>) -> Span + Clone, 594 DefaultOnRequest, 595 impl Fn(&axum::http::Response<Body>, Duration, &Span) + Clone, 596> { 597 TraceLayer::new_for_http() 598 .make_span_with(|req: &axum::http::Request<Body>| { 599 let headers = req.headers(); 600 tracing::info_span!("request", 601 method = %req.method(), 602 path = %req.uri().path(), 603 headers = %format!("{:?}", headers), 604 ) 605 }) 606 .on_response( 607 |resp: &axum::http::Response<Body>, latency: Duration, _span: &tracing::Span| { 608 tracing::info!( 609 status = resp.status().as_u16(), 610 latency_ms = latency.as_millis() as u64, 611 "response" 612 ); 613 }, 614 ) 615}