···19use std::time::Duration;
20use std::{env, net::SocketAddr};
21use tower_governor::GovernorLayer;
22-use tower_governor::governor::GovernorConfigBuilder;
023use tower_http::compression::CompressionLayer;
24use tower_http::cors::{Any, CorsLayer};
25use tracing::log;
···166 .per_second(60)
167 .burst_size(5)
168 .finish()
169- .expect("failed to create governor config. this should not happen and is a bug");
170171 // Create a second config with the same settings for the other endpoint
172 let sign_in_governor_conf = GovernorConfigBuilder::default()
173 .per_second(60)
174 .burst_size(5)
175 .finish()
176- .expect("failed to create governor config. this should not happen and is a bug");
00000000177178- // let create_account_limiter_time: Option<String> =
179- // env::var("GATEKEEPER_CREATE_ACCOUNT_LIMITER_WINDOW").unwrap_or_else(|_| None);
0000000000000000180181 let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
182 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
00000183 let interval = Duration::from_secs(60);
184 // a separate background task to clean up
185 std::thread::spawn(move || {
···187 std::thread::sleep(interval);
188 create_session_governor_limiter.retain_recent();
189 sign_in_governor_limiter.retain_recent();
000190 }
191 });
192···197198 let app = Router::new()
199 .route("/", get(root_handler))
200- .route(
201- "/xrpc/com.atproto.server.getSession",
202- get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
203- )
204 .route(
205 "/xrpc/com.atproto.server.updateEmail",
206 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···213 "/xrpc/com.atproto.server.createSession",
214 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
215 )
0216 .layer(CompressionLayer::new())
217 .layer(cors)
218 .with_state(state);
···19use std::time::Duration;
20use std::{env, net::SocketAddr};
21use tower_governor::GovernorLayer;
22+use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder};
23+use tower_governor::key_extractor::PeerIpKeyExtractor;
24use tower_http::compression::CompressionLayer;
25use tower_http::cors::{Any, CorsLayer};
26use tracing::log;
···167 .per_second(60)
168 .burst_size(5)
169 .finish()
170+ .expect("failed to create governor config for create session. this should not happen and is a bug");
171172 // Create a second config with the same settings for the other endpoint
173 let sign_in_governor_conf = GovernorConfigBuilder::default()
174 .per_second(60)
175 .burst_size(5)
176 .finish()
177+ .expect(
178+ "failed to create governor config for sign in. this should not happen and is a bug",
179+ );
180+181+ let create_account_limiter_time: Option<String> =
182+ env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
183+ let create_account_limiter_burst: Option<String> =
184+ env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
185+ let mut create_account_governor_conf = None;
186187+ if create_account_governor_conf.is_some() && create_account_limiter_time.is_some() {
188+ let time = create_account_limiter_time
189+ .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
190+ .parse::<u64>()
191+ .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
192+ let burst = create_account_limiter_burst
193+ .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
194+ .parse::<u32>()
195+ .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
196+197+ create_account_governor_conf = Some(
198+ GovernorConfigBuilder::default()
199+ .per_second(time)
200+ .burst_size(burst)
201+ .finish()
202+ .expect("failed to create governor config for create account. this should not happen and is a bug"),
203+ )
204+ }
205206 let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
207 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
208+ let create_account_governor_limiter = match create_account_governor_conf {
209+ None => None,
210+ Some(conf) => Some(conf.limiter().clone()),
211+ };
212+213 let interval = Duration::from_secs(60);
214 // a separate background task to clean up
215 std::thread::spawn(move || {
···217 std::thread::sleep(interval);
218 create_session_governor_limiter.retain_recent();
219 sign_in_governor_limiter.retain_recent();
220+ if let Some(ref limiter) = create_account_governor_limiter {
221+ limiter.retain_recent();
222+ }
223 }
224 });
225···230231 let app = Router::new()
232 .route("/", get(root_handler))
233+ .route("/xrpc/com.atproto.server.getSession", get(get_session))
000234 .route(
235 "/xrpc/com.atproto.server.updateEmail",
236 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···243 "/xrpc/com.atproto.server.createSession",
244 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
245 )
246+ .route("/xrpc/com.atproto.server.createAccount")
247 .layer(CompressionLayer::new())
248 .layer(cors)
249 .with_state(state);
+45-32
src/middleware.rs
···35 Some((scheme, token_str)) => {
36 // For Bearer, validate JWT and extract DID from `sub`.
37 // For DPoP, we currently only pass through and do not validate here; insert None DID.
38- // match scheme {
39- // AuthScheme::Bearer => {
40- let token = UntrustedToken::new(&token_str);
41- if token.is_err() {
42- return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
43- .expect("Error creating an error response");
44- }
45- let parsed_token = token.expect("Already checked for error");
46- let claims: Result<Claims<TokenClaims>, ValidationError> =
47- parsed_token.deserialize_claims_unchecked();
48- if claims.is_err() {
49- return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
50- .expect("Error creating an error response");
51- }
000000005253- let key = Hs256Key::new(
54- env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
55- );
56- let token: Result<Token<TokenClaims>, ValidationError> =
57- Hs256.validator(&key).validate(&parsed_token);
58- if token.is_err() {
59- return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
60- .expect("Error creating an error response");
00000000000000061 }
62- let token = token.expect("Already checked for error,");
63- // Not going to worry about expiration since it still goes to the PDS
64- req.extensions_mut()
65- .insert(Did(Some(token.claims().custom.sub.clone())));
66- // }
67- // AuthScheme::DPoP => {
68- // // No DID extraction from DPoP here; leave None
69- // req.extensions_mut().insert(Did(None));
70- // }
71- // }
7273 next.run(req).await
74 }
···35 Some((scheme, token_str)) => {
36 // For Bearer, validate JWT and extract DID from `sub`.
37 // For DPoP, we currently only pass through and do not validate here; insert None DID.
38+ match scheme {
39+ AuthScheme::Bearer => {
40+ let token = UntrustedToken::new(&token_str);
41+ if token.is_err() {
42+ return json_error_response(
43+ StatusCode::BAD_REQUEST,
44+ "TokenRequired",
45+ "",
46+ )
47+ .expect("Error creating an error response");
48+ }
49+ let parsed_token = token.expect("Already checked for error");
50+ let claims: Result<Claims<TokenClaims>, ValidationError> =
51+ parsed_token.deserialize_claims_unchecked();
52+ if claims.is_err() {
53+ return json_error_response(
54+ StatusCode::BAD_REQUEST,
55+ "TokenRequired",
56+ "",
57+ )
58+ .expect("Error creating an error response");
59+ }
6061+ let key = Hs256Key::new(
62+ env::var("PDS_JWT_SECRET")
63+ .expect("PDS_JWT_SECRET not set in the pds.env"),
64+ );
65+ let token: Result<Token<TokenClaims>, ValidationError> =
66+ Hs256.validator(&key).validate(&parsed_token);
67+ if token.is_err() {
68+ return json_error_response(
69+ StatusCode::BAD_REQUEST,
70+ "InvalidToken",
71+ "",
72+ )
73+ .expect("Error creating an error response");
74+ }
75+ let token = token.expect("Already checked for error,");
76+ // Not going to worry about expiration since it still goes to the PDS
77+ req.extensions_mut()
78+ .insert(Did(Some(token.claims().custom.sub.clone())));
79+ }
80+ AuthScheme::DPoP => {
81+ //Not going to worry about oauth email update for now, just always forward to the PDS
82+ req.extensions_mut().insert(Did(None));
83+ }
84 }
00000000008586 next.run(req).await
87 }
+51-46
src/xrpc/com_atproto_server.rs
···147 //If email auth is set it is to either turn on or off 2fa
148 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149150- // Email update asked for
151- if email_auth_update {
152- let email = payload.email.clone();
153- let email_confirmed = sqlx::query_as::<_, (String,)>(
154- "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
155- )
156- .bind(&email)
157- .fetch_optional(&state.account_pool)
158- .await
159- .map_err(|_| StatusCode::BAD_REQUEST)?;
160161- //Since the email is already confirmed we can enable 2fa
162- return match email_confirmed {
163- None => Err(StatusCode::BAD_REQUEST),
164- Some(did_row) => {
165- let _ = sqlx::query(
166- "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
167- )
168- .bind(&did_row.0)
169- .execute(&state.pds_gatekeeper_pool)
170- .await
171- .map_err(|_| StatusCode::BAD_REQUEST)?;
172-173- Ok(StatusCode::OK.into_response())
174- }
175- };
176- }
177-178- // User wants auth turned off
179- if !email_auth_update && !email_auth_not_set {
180- //User wants auth turned off and has a token
181- if let Some(token) = &payload.token {
182- let token_found = sqlx::query_as::<_, (String,)>(
183- "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
184 )
185- .bind(token)
186- .bind(&did.0)
187 .fetch_optional(&state.account_pool)
188 .await
189 .map_err(|_| StatusCode::BAD_REQUEST)?;
190191- if token_found.is_some() {
192- let _ = sqlx::query(
193- "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
00000000000000000000194 )
195- .bind(&did.0)
196- .execute(&state.pds_gatekeeper_pool)
197- .await
198- .map_err(|_| StatusCode::BAD_REQUEST)?;
0199200- return Ok(StatusCode::OK.into_response());
201- } else {
202- return Err(StatusCode::BAD_REQUEST);
0000000000203 }
204 }
205 }
206-207 // Updating the actual email address by sending it on to the PDS
208 let uri = format!(
209 "{}{}",
···147 //If email auth is set it is to either turn on or off 2fa
148 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149150+ //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
151+ //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
152+ let did_is_not_empty = did.0.is_some();
0000000153154+ if did_is_not_empty {
155+ // Email update asked for
156+ if email_auth_update {
157+ let email = payload.email.clone();
158+ let email_confirmed = sqlx::query_as::<_, (String,)>(
159+ "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
00000000000000000160 )
161+ .bind(&email)
0162 .fetch_optional(&state.account_pool)
163 .await
164 .map_err(|_| StatusCode::BAD_REQUEST)?;
165166+ //Since the email is already confirmed we can enable 2fa
167+ return match email_confirmed {
168+ None => Err(StatusCode::BAD_REQUEST),
169+ Some(did_row) => {
170+ let _ = sqlx::query(
171+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
172+ )
173+ .bind(&did_row.0)
174+ .execute(&state.pds_gatekeeper_pool)
175+ .await
176+ .map_err(|_| StatusCode::BAD_REQUEST)?;
177+178+ Ok(StatusCode::OK.into_response())
179+ }
180+ };
181+ }
182+183+ // User wants auth turned off
184+ if !email_auth_update && !email_auth_not_set {
185+ //User wants auth turned off and has a token
186+ if let Some(token) = &payload.token {
187+ let token_found = sqlx::query_as::<_, (String,)>(
188+ "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
189 )
190+ .bind(token)
191+ .bind(&did.0)
192+ .fetch_optional(&state.account_pool)
193+ .await
194+ .map_err(|_| StatusCode::BAD_REQUEST)?;
195196+ return if token_found.is_some() {
197+ let _ = sqlx::query(
198+ "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
199+ )
200+ .bind(&did.0)
201+ .execute(&state.pds_gatekeeper_pool)
202+ .await
203+ .map_err(|_| StatusCode::BAD_REQUEST)?;
204+205+ Ok(StatusCode::OK.into_response())
206+ } else {
207+ Err(StatusCode::BAD_REQUEST)
208+ };
209 }
210 }
211 }
0212 // Updating the actual email address by sending it on to the PDS
213 let uri = format!(
214 "{}{}",