···1919use std::time::Duration;
2020use std::{env, net::SocketAddr};
2121use tower_governor::GovernorLayer;
2222-use tower_governor::governor::GovernorConfigBuilder;
2222+use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder};
2323+use tower_governor::key_extractor::PeerIpKeyExtractor;
2324use tower_http::compression::CompressionLayer;
2425use tower_http::cors::{Any, CorsLayer};
2526use tracing::log;
···166167 .per_second(60)
167168 .burst_size(5)
168169 .finish()
169169- .expect("failed to create governor config. this should not happen and is a bug");
170170+ .expect("failed to create governor config for create session. this should not happen and is a bug");
170171171172 // Create a second config with the same settings for the other endpoint
172173 let sign_in_governor_conf = GovernorConfigBuilder::default()
173174 .per_second(60)
174175 .burst_size(5)
175176 .finish()
176176- .expect("failed to create governor config. this should not happen and is a bug");
177177+ .expect(
178178+ "failed to create governor config for sign in. this should not happen and is a bug",
179179+ );
180180+181181+ let create_account_limiter_time: Option<String> =
182182+ env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
183183+ let create_account_limiter_burst: Option<String> =
184184+ env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
185185+ let mut create_account_governor_conf = None;
177186178178- // let create_account_limiter_time: Option<String> =
179179- // env::var("GATEKEEPER_CREATE_ACCOUNT_LIMITER_WINDOW").unwrap_or_else(|_| None);
187187+ if create_account_governor_conf.is_some() && create_account_limiter_time.is_some() {
188188+ let time = create_account_limiter_time
189189+ .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
190190+ .parse::<u64>()
191191+ .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
192192+ let burst = create_account_limiter_burst
193193+ .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
194194+ .parse::<u32>()
195195+ .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
196196+197197+ create_account_governor_conf = Some(
198198+ GovernorConfigBuilder::default()
199199+ .per_second(time)
200200+ .burst_size(burst)
201201+ .finish()
202202+ .expect("failed to create governor config for create account. this should not happen and is a bug"),
203203+ )
204204+ }
180205181206 let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
182207 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
208208+ let create_account_governor_limiter = match create_account_governor_conf {
209209+ None => None,
210210+ Some(conf) => Some(conf.limiter().clone()),
211211+ };
212212+183213 let interval = Duration::from_secs(60);
184214 // a separate background task to clean up
185215 std::thread::spawn(move || {
···187217 std::thread::sleep(interval);
188218 create_session_governor_limiter.retain_recent();
189219 sign_in_governor_limiter.retain_recent();
220220+ if let Some(ref limiter) = create_account_governor_limiter {
221221+ limiter.retain_recent();
222222+ }
190223 }
191224 });
192225···197230198231 let app = Router::new()
199232 .route("/", get(root_handler))
200200- .route(
201201- "/xrpc/com.atproto.server.getSession",
202202- get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
203203- )
233233+ .route("/xrpc/com.atproto.server.getSession", get(get_session))
204234 .route(
205235 "/xrpc/com.atproto.server.updateEmail",
206236 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···213243 "/xrpc/com.atproto.server.createSession",
214244 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
215245 )
246246+ .route("/xrpc/com.atproto.server.createAccount")
216247 .layer(CompressionLayer::new())
217248 .layer(cors)
218249 .with_state(state);
+45-32
src/middleware.rs
···3535 Some((scheme, token_str)) => {
3636 // For Bearer, validate JWT and extract DID from `sub`.
3737 // For DPoP, we currently only pass through and do not validate here; insert None DID.
3838- // match scheme {
3939- // AuthScheme::Bearer => {
4040- let token = UntrustedToken::new(&token_str);
4141- if token.is_err() {
4242- return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
4343- .expect("Error creating an error response");
4444- }
4545- let parsed_token = token.expect("Already checked for error");
4646- let claims: Result<Claims<TokenClaims>, ValidationError> =
4747- parsed_token.deserialize_claims_unchecked();
4848- if claims.is_err() {
4949- return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
5050- .expect("Error creating an error response");
5151- }
3838+ match scheme {
3939+ AuthScheme::Bearer => {
4040+ let token = UntrustedToken::new(&token_str);
4141+ if token.is_err() {
4242+ return json_error_response(
4343+ StatusCode::BAD_REQUEST,
4444+ "TokenRequired",
4545+ "",
4646+ )
4747+ .expect("Error creating an error response");
4848+ }
4949+ let parsed_token = token.expect("Already checked for error");
5050+ let claims: Result<Claims<TokenClaims>, ValidationError> =
5151+ parsed_token.deserialize_claims_unchecked();
5252+ if claims.is_err() {
5353+ return json_error_response(
5454+ StatusCode::BAD_REQUEST,
5555+ "TokenRequired",
5656+ "",
5757+ )
5858+ .expect("Error creating an error response");
5959+ }
52605353- let key = Hs256Key::new(
5454- env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
5555- );
5656- let token: Result<Token<TokenClaims>, ValidationError> =
5757- Hs256.validator(&key).validate(&parsed_token);
5858- if token.is_err() {
5959- return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
6060- .expect("Error creating an error response");
6161+ let key = Hs256Key::new(
6262+ env::var("PDS_JWT_SECRET")
6363+ .expect("PDS_JWT_SECRET not set in the pds.env"),
6464+ );
6565+ let token: Result<Token<TokenClaims>, ValidationError> =
6666+ Hs256.validator(&key).validate(&parsed_token);
6767+ if token.is_err() {
6868+ return json_error_response(
6969+ StatusCode::BAD_REQUEST,
7070+ "InvalidToken",
7171+ "",
7272+ )
7373+ .expect("Error creating an error response");
7474+ }
7575+ let token = token.expect("Already checked for error,");
7676+ // Not going to worry about expiration since it still goes to the PDS
7777+ req.extensions_mut()
7878+ .insert(Did(Some(token.claims().custom.sub.clone())));
7979+ }
8080+ AuthScheme::DPoP => {
8181+ //Not going to worry about oauth email update for now, just always forward to the PDS
8282+ req.extensions_mut().insert(Did(None));
8383+ }
6184 }
6262- let token = token.expect("Already checked for error,");
6363- // Not going to worry about expiration since it still goes to the PDS
6464- req.extensions_mut()
6565- .insert(Did(Some(token.claims().custom.sub.clone())));
6666- // }
6767- // AuthScheme::DPoP => {
6868- // // No DID extraction from DPoP here; leave None
6969- // req.extensions_mut().insert(Did(None));
7070- // }
7171- // }
72857386 next.run(req).await
7487 }
+51-46
src/xrpc/com_atproto_server.rs
···147147 //If email auth is set it is to either turn on or off 2fa
148148 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149149150150- // Email update asked for
151151- if email_auth_update {
152152- let email = payload.email.clone();
153153- let email_confirmed = sqlx::query_as::<_, (String,)>(
154154- "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
155155- )
156156- .bind(&email)
157157- .fetch_optional(&state.account_pool)
158158- .await
159159- .map_err(|_| StatusCode::BAD_REQUEST)?;
150150+ //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
151151+ //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
152152+ let did_is_not_empty = did.0.is_some();
160153161161- //Since the email is already confirmed we can enable 2fa
162162- return match email_confirmed {
163163- None => Err(StatusCode::BAD_REQUEST),
164164- Some(did_row) => {
165165- let _ = sqlx::query(
166166- "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
167167- )
168168- .bind(&did_row.0)
169169- .execute(&state.pds_gatekeeper_pool)
170170- .await
171171- .map_err(|_| StatusCode::BAD_REQUEST)?;
172172-173173- Ok(StatusCode::OK.into_response())
174174- }
175175- };
176176- }
177177-178178- // User wants auth turned off
179179- if !email_auth_update && !email_auth_not_set {
180180- //User wants auth turned off and has a token
181181- if let Some(token) = &payload.token {
182182- let token_found = sqlx::query_as::<_, (String,)>(
183183- "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
154154+ if did_is_not_empty {
155155+ // Email update asked for
156156+ if email_auth_update {
157157+ let email = payload.email.clone();
158158+ let email_confirmed = sqlx::query_as::<_, (String,)>(
159159+ "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
184160 )
185185- .bind(token)
186186- .bind(&did.0)
161161+ .bind(&email)
187162 .fetch_optional(&state.account_pool)
188163 .await
189164 .map_err(|_| StatusCode::BAD_REQUEST)?;
190165191191- if token_found.is_some() {
192192- let _ = sqlx::query(
193193- "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
166166+ //Since the email is already confirmed we can enable 2fa
167167+ return match email_confirmed {
168168+ None => Err(StatusCode::BAD_REQUEST),
169169+ Some(did_row) => {
170170+ let _ = sqlx::query(
171171+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
172172+ )
173173+ .bind(&did_row.0)
174174+ .execute(&state.pds_gatekeeper_pool)
175175+ .await
176176+ .map_err(|_| StatusCode::BAD_REQUEST)?;
177177+178178+ Ok(StatusCode::OK.into_response())
179179+ }
180180+ };
181181+ }
182182+183183+ // User wants auth turned off
184184+ if !email_auth_update && !email_auth_not_set {
185185+ //User wants auth turned off and has a token
186186+ if let Some(token) = &payload.token {
187187+ let token_found = sqlx::query_as::<_, (String,)>(
188188+ "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
194189 )
195195- .bind(&did.0)
196196- .execute(&state.pds_gatekeeper_pool)
197197- .await
198198- .map_err(|_| StatusCode::BAD_REQUEST)?;
190190+ .bind(token)
191191+ .bind(&did.0)
192192+ .fetch_optional(&state.account_pool)
193193+ .await
194194+ .map_err(|_| StatusCode::BAD_REQUEST)?;
199195200200- return Ok(StatusCode::OK.into_response());
201201- } else {
202202- return Err(StatusCode::BAD_REQUEST);
196196+ return if token_found.is_some() {
197197+ let _ = sqlx::query(
198198+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
199199+ )
200200+ .bind(&did.0)
201201+ .execute(&state.pds_gatekeeper_pool)
202202+ .await
203203+ .map_err(|_| StatusCode::BAD_REQUEST)?;
204204+205205+ Ok(StatusCode::OK.into_response())
206206+ } else {
207207+ Err(StatusCode::BAD_REQUEST)
208208+ };
203209 }
204210 }
205211 }
206206-207212 // Updating the actual email address by sending it on to the PDS
208213 let uri = format!(
209214 "{}{}",