Microservice to bring two-factor authentication (2FA) to self-hosted PDSes
at sendmail 449 lines 16 kB view raw
1#![warn(clippy::unwrap_used)] 2use crate::gate::{get_gate, post_gate}; 3use crate::mailer::{Mailer, build_mailer_from_env}; 4use crate::oauth_provider::sign_in; 5use crate::xrpc::com_atproto_server::{ 6 create_account, create_session, describe_server, get_session, update_email, 7}; 8use anyhow::Result; 9use axum::{ 10 Router, 11 body::Body, 12 handler::Handler, 13 http::{Method, header}, 14 middleware as ax_middleware, 15 routing::get, 16 routing::post, 17}; 18use axum_template::engine::Engine; 19use handlebars::Handlebars; 20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; 21use jacquard_common::types::did::Did; 22use jacquard_identity::{PublicResolver, resolver::PlcSource}; 23use rand::Rng; 24use rust_embed::RustEmbed; 25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 27use std::path::Path; 28use std::sync::Arc; 29use std::time::Duration; 30use std::{env, net::SocketAddr}; 31use tower_governor::{ 32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, 33}; 34use tower_http::{ 35 compression::CompressionLayer, 36 cors::{Any, CorsLayer}, 37}; 38use tracing::log; 39use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 40 41mod auth; 42mod gate; 43pub mod helpers; 44pub mod mailer; 45mod middleware; 46mod oauth_provider; 47mod xrpc; 48 49type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 50 51#[derive(RustEmbed)] 52#[folder = "email_templates"] 53#[include = "*.hbs"] 54struct EmailTemplates; 55 56#[derive(RustEmbed)] 57#[folder = "html_templates"] 58#[include = "*.hbs"] 59struct HtmlTemplates; 60 61/// Mostly the env variables that are used in the app 62#[derive(Clone, Debug)] 63pub struct AppConfig { 64 pds_base_url: String, 65 mailer_from: String, 66 email_subject: String, 67 allow_only_migrations: bool, 68 use_captcha: bool, 69 //The url to redirect to after a successful captcha. 
Defaults to https://bsky.app, but you may have another social-app fork you rather your users use 70 //that need to capture this redirect url for creating an account 71 default_successful_redirect_url: String, 72 pds_service_did: Did<'static>, 73 gate_jwe_key: Vec<u8>, 74 captcha_success_redirects: Vec<String>, 75} 76 77impl AppConfig { 78 pub fn new() -> Self { 79 let pds_base_url = 80 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 81 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS") 82 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 83 //Hack not my favorite, but it does work 84 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS") 85 .map(|val| val.parse::<bool>().unwrap_or(false)) 86 .unwrap_or(false); 87 88 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA") 89 .map(|val| val.parse::<bool>().unwrap_or(false)) 90 .unwrap_or(false); 91 92 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME 93 let pds_service_did = 94 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") { 95 Ok(pds_hostname) => format!("did:web:{}", pds_hostname), 96 Err(_) => { 97 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file") 98 } 99 }); 100 101 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") 102 .unwrap_or("Sign in to Bluesky".to_string()); 103 104 // Load or generate JWE encryption key (32 bytes for AES-256) 105 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY") 106 .ok() 107 .and_then(|key_hex| hex::decode(key_hex).ok()) 108 .unwrap_or_else(|| { 109 // Generate a random 32-byte key if not provided 110 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect(); 111 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key)); 112 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. 
Will not also be able to verify tokens between reboots, but they are short lived (5mins)."); 113 key 114 }); 115 116 if gate_jwe_key.len() != 32 { 117 panic!( 118 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption" 119 ); 120 } 121 122 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") { 123 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(), 124 Err(_) => { 125 vec![ 126 String::from("https://bsky.app"), 127 String::from("https://pdsmoover.com"), 128 String::from("https://blacksky.community"), 129 String::from("https://tektite.cc"), 130 ] 131 } 132 }; 133 134 AppConfig { 135 pds_base_url, 136 mailer_from, 137 email_subject, 138 allow_only_migrations, 139 use_captcha, 140 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT") 141 .unwrap_or("https://bsky.app".to_string()), 142 pds_service_did: pds_service_did 143 .parse() 144 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"), 145 gate_jwe_key, 146 captcha_success_redirects, 147 } 148 } 149} 150 151#[derive(Clone)] 152pub struct AppState { 153 account_pool: SqlitePool, 154 pds_gatekeeper_pool: SqlitePool, 155 reverse_proxy_client: HyperUtilClient, 156 mailer: Arc<Mailer>, 157 template_engine: Engine<Handlebars<'static>>, 158 resolver: Arc<PublicResolver>, 159 handle_cache: auth::HandleCache, 160 app_config: AppConfig, 161} 162 163async fn root_handler() -> impl axum::response::IntoResponse { 164 let body = r" 165 166 ...oO _.--X~~OO~~X--._ ...oOO 167 _.-~ / \ II / \ ~-._ 168 [].-~ \ / \||/ \ / ~-.[] ...o 169 ...o _ ||/ \ / || \ / \|| _ 170 (_) |X X || X X| (_) 171 _-~-_ ||\ / \ || / \ /|| _-~-_ 172 ||||| || \ / \ /||\ / \ / || ||||| 173 | |_|| \ / \ / || \ / \ / ||_| | 174 | |~|| X X || X X ||~| | 175==============| | || / \ / \ || / \ / \ || | |============== 176______________| | || / \ / \||/ \ / \ || | |______________ 177 . . | | ||/ \ / || \ / \|| | | . . 
178 / | | |X X || X X| | | / / 179 / . | | ||\ / \ || / \ /|| | | . / . 180. / | | || \ / \ /||\ / \ / || | | . . 181 . . | | || \ / \ / || \ / \ / || | | . 182 / | | || X X || X X || | | . / . / 183 / . | | || / \ / \ || / \ / \ || | | / 184 / | | || / \ / \||/ \ / \ || | | . / 185. . . | | ||/ \ / /||\ \ / \|| | | /. . 186 | |_|X X / II \ X X|_| | . . / 187==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 188 "; 189 190 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 191 192 let banner = format!(" {body}\n{intro}"); 193 194 ( 195 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 196 banner, 197 ) 198} 199 200#[tokio::main] 201async fn main() -> Result<(), Box<dyn std::error::Error>> { 202 setup_tracing(); 203 let pds_env_location = 204 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 205 206 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 207 if let Err(e) = result_of_finding_pds_env { 208 log::error!( 209 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 210 ); 211 } 212 213 let pds_root = 214 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 215 let account_db_url = format!("{pds_root}/account.sqlite"); 216 217 let account_options = SqliteConnectOptions::new() 218 .journal_mode(SqliteJournalMode::Wal) 219 .filename(account_db_url) 220 .busy_timeout(Duration::from_secs(5)); 221 222 let account_pool = SqlitePoolOptions::new() 223 .max_connections(5) 224 .connect_with(account_options) 225 .await?; 226 227 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 228 let options = SqliteConnectOptions::new() 229 .journal_mode(SqliteJournalMode::Wal) 230 .filename(bells_db_url) 231 .create_if_missing(true) 232 .busy_timeout(Duration::from_secs(5)); 233 let pds_gatekeeper_pool = SqlitePoolOptions::new() 234 
.max_connections(5) 235 .connect_with(options) 236 .await?; 237 238 // Run migrations for the extra database 239 // Note: the migrations are embedded at compile time from the given directory 240 // sqlx 241 sqlx::migrate!("./migrations") 242 .run(&pds_gatekeeper_pool) 243 .await?; 244 245 let client: HyperUtilClient = 246 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 247 .build(HttpConnector::new()); 248 249 //Emailer set up 250 let mailer = Arc::new(build_mailer_from_env()?); 251 252 //Email templates setup 253 let mut hbs = Handlebars::new(); 254 255 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 256 if let Ok(users_email_directory) = users_email_directory { 257 hbs.register_template_file( 258 "two_factor_code.hbs", 259 format!("{users_email_directory}/two_factor_code.hbs"), 260 )?; 261 } else { 262 let _ = hbs.register_embed_templates::<EmailTemplates>(); 263 } 264 265 let _ = hbs.register_embed_templates::<HtmlTemplates>(); 266 267 //Reads the PLC source from the pds env's or defaults to ol faithful 268 let plc_source_url = 269 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string()); 270 let plc_source = PlcSource::PlcDirectory { 271 base: plc_source_url.parse().unwrap(), 272 }; 273 let mut resolver = PublicResolver::default(); 274 resolver = resolver.with_plc_source(plc_source.clone()); 275 276 let state = AppState { 277 account_pool, 278 pds_gatekeeper_pool, 279 reverse_proxy_client: client, 280 mailer, 281 template_engine: Engine::from(hbs), 282 resolver: Arc::new(resolver), 283 handle_cache: auth::HandleCache::new(), 284 app_config: AppConfig::new(), 285 }; 286 287 // Rate limiting 288 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 
289 let captcha_governor_conf = GovernorConfigBuilder::default() 290 .per_second(60) 291 .burst_size(5) 292 .key_extractor(SmartIpKeyExtractor) 293 .finish() 294 .expect("failed to create governor config for create session. this should not happen and is a bug"); 295 296 // Create a second config with the same settings for the other endpoint 297 let sign_in_governor_conf = GovernorConfigBuilder::default() 298 .per_second(60) 299 .burst_size(5) 300 .key_extractor(SmartIpKeyExtractor) 301 .finish() 302 .expect( 303 "failed to create governor config for sign in. this should not happen and is a bug", 304 ); 305 306 let create_account_limiter_time: Option<String> = 307 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 308 let create_account_limiter_burst: Option<String> = 309 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 310 311 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 312 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 313 if create_account_limiter_time.is_some() { 314 let time = create_account_limiter_time 315 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 316 .parse::<u64>() 317 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 318 create_account_governor_conf.per_second(time); 319 } 320 321 if create_account_limiter_burst.is_some() { 322 let burst = create_account_limiter_burst 323 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 324 .parse::<u32>() 325 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 326 create_account_governor_conf.burst_size(burst); 327 } 328 329 let create_account_governor_conf = create_account_governor_conf 330 .key_extractor(SmartIpKeyExtractor) 331 .finish().expect( 332 "failed to create governor config for create account. 
this should not happen and is a bug", 333 ); 334 335 let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); 336 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 337 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 338 339 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); 340 341 let interval = Duration::from_secs(60); 342 // a separate background task to clean up 343 std::thread::spawn(move || { 344 loop { 345 std::thread::sleep(interval); 346 captcha_governor_limiter.retain_recent(); 347 sign_in_governor_limiter.retain_recent(); 348 create_account_governor_limiter.retain_recent(); 349 } 350 }); 351 352 let cors = CorsLayer::new() 353 .allow_origin(Any) 354 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 355 .allow_headers(Any); 356 357 let mut app = Router::new() 358 .route("/", get(root_handler)) 359 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 360 .route( 361 "/xrpc/com.atproto.server.describeServer", 362 get(describe_server), 363 ) 364 .route( 365 "/xrpc/com.atproto.server.updateEmail", 366 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 367 ) 368 .route( 369 "/@atproto/oauth-provider/~api/sign-in", 370 post(sign_in).layer(sign_in_governor_layer.clone()), 371 ) 372 .route( 373 "/xrpc/com.atproto.server.createSession", 374 post(create_session.layer(sign_in_governor_layer)), 375 ) 376 .route( 377 "/xrpc/com.atproto.server.createAccount", 378 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 379 ); 380 381 if state.app_config.use_captcha { 382 app = app.route( 383 "/gate/signup", 384 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), 385 ); 386 } 387 388 let app = app 389 .layer(CompressionLayer::new()) 390 .layer(cors) 391 .with_state(state); 392 393 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); 394 let port: u16 = 
env::var("GATEKEEPER_PORT") 395 .ok() 396 .and_then(|s| s.parse().ok()) 397 .unwrap_or(8080); 398 let addr: SocketAddr = format!("{host}:{port}") 399 .parse() 400 .expect("valid socket address"); 401 402 let listener = tokio::net::TcpListener::bind(addr).await?; 403 404 let server = axum::serve( 405 listener, 406 app.into_make_service_with_connect_info::<SocketAddr>(), 407 ) 408 .with_graceful_shutdown(shutdown_signal()); 409 410 if let Err(err) = server.await { 411 log::error!("server error:{err}"); 412 } 413 414 Ok(()) 415} 416 417fn setup_tracing() { 418 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 419 tracing_subscriber::registry() 420 .with(env_filter) 421 .with(fmt::layer()) 422 .init(); 423} 424 425async fn shutdown_signal() { 426 // Wait for Ctrl+C 427 let ctrl_c = async { 428 tokio::signal::ctrl_c() 429 .await 430 .expect("failed to install Ctrl+C handler"); 431 }; 432 433 #[cfg(unix)] 434 let terminate = async { 435 use tokio::signal::unix::{SignalKind, signal}; 436 437 let mut sigterm = 438 signal(SignalKind::terminate()).expect("failed to install signal handler"); 439 sigterm.recv().await; 440 }; 441 442 #[cfg(not(unix))] 443 let terminate = std::future::pending::<()>(); 444 445 tokio::select! { 446 _ = ctrl_c => {}, 447 _ = terminate => {}, 448 } 449}