slack status without the slack status.zzstoatzz.io/
quickslice

style: format code and fix linting issues

- Run cargo fmt on all files
- Fix trailing whitespace issues
- Allow dead_code for config fields (will be used in future refactor)
- Fix clippy warning about unnecessary to_string()
- Pre-commit hooks now working properly

+181 -135
+10 -9
src/config.rs
··· 3 3 4 4 /// Application configuration loaded from environment variables 5 5 #[derive(Debug, Clone, Deserialize)] 6 + #[allow(dead_code)] 6 7 pub struct Config { 7 8 /// The admin DID for moderation (intentionally hardcoded for security) 8 9 pub admin_did: String, 9 - 10 + 10 11 /// Owner handle for the default status page 11 12 pub owner_handle: String, 12 - 13 + 13 14 /// Database URL (defaults to local SQLite) 14 15 pub database_url: String, 15 - 16 + 16 17 /// OAuth redirect base URL 17 18 pub oauth_redirect_base: String, 18 - 19 + 19 20 /// Server host 20 21 pub server_host: String, 21 - 22 + 22 23 /// Server port 23 24 pub server_port: u16, 24 - 25 + 25 26 /// Enable firehose ingester 26 27 pub enable_firehose: bool, 27 - 28 + 28 29 /// Log level 29 30 pub log_level: String, 30 31 } ··· 34 35 pub fn from_env() -> Result<Self, env::VarError> { 35 36 // Admin DID is intentionally hardcoded as discussed 36 37 let admin_did = "did:plc:xbtmt2zjwlrfegqvch7fboei".to_string(); 37 - 38 + 38 39 Ok(Config { 39 40 admin_did, 40 41 owner_handle: env::var("OWNER_HANDLE").unwrap_or_else(|_| "zzstoatzz.io".to_string()), ··· 54 55 log_level: env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()), 55 56 }) 56 57 } 57 - } 58 + }
+4 -4
src/db.rs
··· 50 50 .unwrap(); 51 51 52 52 // Note: custom_emojis table removed - we serve emojis directly from static/emojis/ directory 53 - 53 + 54 54 // Add indexes for performance optimization 55 55 // Index on startedAt for feed queries (ORDER BY startedAt DESC) 56 56 conn.execute( ··· 58 58 [], 59 59 ) 60 60 .unwrap(); 61 - 61 + 62 62 // Composite index for user status queries (WHERE authorDid = ? ORDER BY startedAt DESC) 63 63 conn.execute( 64 64 "CREATE INDEX IF NOT EXISTS idx_status_authorDid_startedAt ON status(authorDid, startedAt DESC)", 65 65 [], 66 66 ) 67 67 .unwrap(); 68 - 68 + 69 69 // Add hidden column for moderation (won't error if already exists) 70 70 let _ = conn.execute( 71 71 "ALTER TABLE status ADD COLUMN hidden BOOLEAN DEFAULT FALSE", 72 72 [], 73 73 ); 74 - 74 + 75 75 Ok(()) 76 76 }) 77 77 .await?;
+45 -23
src/error_handler.rs
··· 1 - use actix_web::{ 2 - error::ResponseError, 3 - http::StatusCode, 4 - HttpResponse, 5 - }; 1 + use actix_web::{HttpResponse, error::ResponseError, http::StatusCode}; 6 2 use std::fmt; 7 3 8 4 #[derive(Debug)] ··· 10 6 InternalError(String), 11 7 DatabaseError(String), 12 8 AuthenticationError(String), 13 - #[allow(dead_code)] // Keep for potential future use 9 + #[allow(dead_code)] // Keep for potential future use 14 10 ValidationError(String), 15 - #[allow(dead_code)] // Keep for potential future use 11 + #[allow(dead_code)] // Keep for potential future use 16 12 NotFound(String), 17 13 RateLimitExceeded, 18 14 } ··· 34 30 fn error_response(&self) -> HttpResponse { 35 31 let (status_code, error_message) = match self { 36 32 AppError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()), 37 - AppError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error occurred".to_string()), 33 + AppError::DatabaseError(_) => ( 34 + StatusCode::INTERNAL_SERVER_ERROR, 35 + "Database error occurred".to_string(), 36 + ), 38 37 AppError::AuthenticationError(msg) => (StatusCode::UNAUTHORIZED, msg.clone()), 39 38 AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg.clone()), 40 39 AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()), 41 - AppError::RateLimitExceeded => (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded. Please try again later.".to_string()), 40 + AppError::RateLimitExceeded => ( 41 + StatusCode::TOO_MANY_REQUESTS, 42 + "Rate limit exceeded. Please try again later.".to_string(), 43 + ), 42 44 }; 43 - 44 - HttpResponse::build(status_code) 45 - .body(format!("Error {}: {}", status_code.as_u16(), error_message)) 45 + 46 + HttpResponse::build(status_code).body(format!( 47 + "Error {}: {}", 48 + status_code.as_u16(), 49 + error_message 50 + )) 46 51 } 47 - 52 + 48 53 fn status_code(&self) -> StatusCode { 49 54 match self { 50 - AppError::InternalError(_) | AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR, 55 + AppError::InternalError(_) | AppError::DatabaseError(_) => { 56 + StatusCode::INTERNAL_SERVER_ERROR 57 + } 51 58 AppError::AuthenticationError(_) => StatusCode::UNAUTHORIZED, 52 59 AppError::ValidationError(_) => StatusCode::BAD_REQUEST, 53 60 AppError::NotFound(_) => StatusCode::NOT_FOUND, ··· 75 82 #[cfg(test)] 76 83 mod tests { 77 84 use super::*; 78 - 85 + 79 86 #[test] 80 87 fn test_error_display() { 81 88 let err = AppError::ValidationError("Invalid input".to_string()); 82 89 assert_eq!(err.to_string(), "Validation error: Invalid input"); 83 - 90 + 84 91 let err = AppError::RateLimitExceeded; 85 92 assert_eq!(err.to_string(), "Rate limit exceeded"); 86 93 } 87 - 94 + 88 95 #[test] 89 96 fn test_error_status_codes() { 90 - assert_eq!(AppError::InternalError("test".to_string()).status_code(), StatusCode::INTERNAL_SERVER_ERROR); 91 - assert_eq!(AppError::ValidationError("test".to_string()).status_code(), StatusCode::BAD_REQUEST); 92 - assert_eq!(AppError::AuthenticationError("test".to_string()).status_code(), StatusCode::UNAUTHORIZED); 93 - assert_eq!(AppError::NotFound("test".to_string()).status_code(), StatusCode::NOT_FOUND); 94 - assert_eq!(AppError::RateLimitExceeded.status_code(), StatusCode::TOO_MANY_REQUESTS); 97 + assert_eq!( 98 + AppError::InternalError("test".to_string()).status_code(), 99 + StatusCode::INTERNAL_SERVER_ERROR 100 + ); 101 + assert_eq!( 102 + AppError::ValidationError("test".to_string()).status_code(), 103 + StatusCode::BAD_REQUEST 104 + ); 105 + assert_eq!( 106 + AppError::AuthenticationError("test".to_string()).status_code(), 107 + StatusCode::UNAUTHORIZED 108 + ); 109 + assert_eq!( 110 + AppError::NotFound("test".to_string()).status_code(), 111 + StatusCode::NOT_FOUND 112 + ); 113 + assert_eq!( 114 + AppError::RateLimitExceeded.status_code(), 115 + StatusCode::TOO_MANY_REQUESTS 116 + ); 95 117 } 96 - } 118 + }
+110 -88
src/main.rs
··· 7 7 storage::{SqliteSessionStore, SqliteStateStore}, 8 8 templates::{FeedTemplate, LoginTemplate, StatusTemplate}, 9 9 }; 10 - use async_sqlite::rusqlite; 11 10 use actix_files::Files; 12 11 use actix_session::{ 13 12 Session, SessionMiddleware, config::PersistentSession, storage::CookieSessionStore, ··· 19 18 web::{self, Redirect}, 20 19 }; 21 20 use askama::Template; 21 + use async_sqlite::rusqlite; 22 22 use async_sqlite::{Pool, PoolBuilder}; 23 23 use atrium_api::{ 24 24 agent::Agent, ··· 30 30 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, 31 31 }; 32 32 use atrium_oauth::{ 33 - AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, 34 - AuthMethod, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, 33 + AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, AuthorizeOptions, 34 + CallbackParams, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, 35 35 OAuthResolverConfig, Scope, 36 36 }; 37 37 use dotenv::dotenv; ··· 91 91 #[get("/client-metadata.json")] 92 92 async fn client_metadata(config: web::Data<config::Config>) -> Result<HttpResponse> { 93 93 let public_url = config.oauth_redirect_base.clone(); 94 - 94 + 95 95 let metadata = serde_json::json!({ 96 96 "client_id": format!("{}/client-metadata.json", public_url), 97 97 "client_name": "Status Sphere", ··· 103 103 "token_endpoint_auth_method": "none", 104 104 "dpop_bound_access_tokens": true 105 105 }); 106 - 106 + 107 107 Ok(HttpResponse::Ok() 108 108 .content_type("application/json") 109 109 .body(metadata.to_string())) ··· 153 153 ) -> HttpResponse { 154 154 // Check if there's an OAuth error from BlueSky 155 155 if let Some(error) = &params.error { 156 - let error_msg = params.error_description.as_deref() 156 + let error_msg = params 157 + .error_description 158 + .as_deref() 157 159 .unwrap_or("An error occurred during authentication"); 158 160 log::error!("OAuth error from BlueSky: {} - {}", error, error_msg); 159 - 161 + 160 162 let html = ErrorTemplate { 161 163 title: "Authentication Error", 162 164 error: error_msg, 163 165 }; 164 166 return HttpResponse::BadRequest().body(html.render().expect("template should be valid")); 165 167 } 166 - 168 + 167 169 // Check if we have the required code field for a successful callback 168 170 let code = match &params.code { 169 171 Some(code) => code.clone(), ··· 173 175 title: "Error", 174 176 error: "Missing required OAuth code. Please try logging in again.", 175 177 }; 176 - return HttpResponse::BadRequest().body(html.render().expect("template should be valid")); 178 + return HttpResponse::BadRequest() 179 + .body(html.render().expect("template should be valid")); 177 180 } 178 181 }; 179 - 182 + 180 183 // Create CallbackParams for the OAuth client 181 - let callback_params = CallbackParams { 184 + let callback_params = CallbackParams { 182 185 code, 183 186 state: params.state.clone(), 184 187 iss: params.iss.clone(), 185 188 }; 186 - 189 + 187 190 //Processes the call back and parses out a session if found and valid 188 191 match oauth_client.callback(callback_params).await { 189 192 Ok((bsky_session, _)) => { ··· 511 514 ) -> Result<impl Responder> { 512 515 // Default owner of the domain 513 516 const OWNER_HANDLE: &str = "zzstoatzz.io"; 514 - 517 + 515 518 // Resolve handle to DID using ATProto handle resolution 516 519 let atproto_handle_resolver = AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 517 520 dns_txt_resolver: HickoryDnsTxtResolver::default(), ··· 574 577 db_pool: web::Data<Arc<Pool>>, 575 578 handle_resolver: web::Data<HandleResolver>, 576 579 ) -> Result<impl Responder> { 577 - let offset = query.get("offset") 580 + let offset = query 581 + .get("offset") 578 582 .and_then(|s| s.parse::<i32>().ok()) 579 583 .unwrap_or(0); 580 - let limit = query.get("limit") 584 + let limit = query 585 + .get("limit") 581 586 .and_then(|s| s.parse::<i32>().ok()) 582 587 .unwrap_or(20) 583 588 .min(50); // Cap at 50 items per request 584 - 589 + 585 590 let mut statuses = StatusFromDb::load_statuses_paginated(&db_pool, offset, limit) 586 591 .await 587 592 .unwrap_or_else(|err| { 588 593 log::error!("Error loading statuses: {err}"); 589 594 vec![] 590 595 }); 591 - 596 + 592 597 // Resolve handles for each status 593 598 let mut quick_resolve_map: HashMap<Did, String> = HashMap::new(); 594 599 for db_status in &mut statuses { ··· 616 621 Err(_) => None, 617 622 }; 618 623 } 619 - 624 + 620 625 Ok(HttpResponse::Ok().json(statuses)) 621 626 } 622 627 ··· 624 629 #[get("/api/custom-emojis")] 625 630 async fn get_custom_emojis() -> Result<impl Responder> { 626 631 use std::fs; 627 - 632 + 628 633 #[derive(Serialize)] 629 634 struct SimpleEmoji { 630 635 name: String, 631 636 filename: String, 632 637 } 633 - 638 + 634 639 let emojis_dir = "static/emojis"; 635 640 let mut emojis = Vec::new(); 636 - 641 + 637 642 if let Ok(entries) = fs::read_dir(emojis_dir) { 638 643 for entry in entries.flatten() { 639 644 if let Some(filename) = entry.file_name().to_str() { 640 645 // Only include image files 641 - if filename.ends_with(".png") || filename.ends_with(".gif") || 642 - filename.ends_with(".jpg") || filename.ends_with(".webp") { 646 + if filename.ends_with(".png") 647 + || filename.ends_with(".gif") 648 + || filename.ends_with(".jpg") 649 + || filename.ends_with(".webp") 650 + { 643 651 // Remove file extension to get name 644 - let name = filename.rsplit_once('.').map(|(name, _)| name).unwrap_or(filename).to_string(); 652 + let name = filename 653 + .rsplit_once('.') 654 + .map(|(name, _)| name) 655 + .unwrap_or(filename) 656 + .to_string(); 645 657 emojis.push(SimpleEmoji { 646 658 name: name.clone(), 647 659 filename: filename.to_string(), ··· 650 662 } 651 663 } 652 664 } 653 - 665 + 654 666 // Sort by name 655 667 emojis.sort_by(|a, b| a.name.cmp(&b.name)); 656 - 668 + 657 669 Ok(HttpResponse::Ok().json(emojis)) 658 670 } 659 671 ··· 837 849 ) 838 850 .await; 839 851 840 - let is_admin = is_admin(&did.to_string()); 852 + let is_admin = is_admin(did.as_str()); 841 853 let html = FeedTemplate { 842 854 title: TITLE, 843 855 profile: match profile { ··· 1033 1045 match session.get::<String>("did").unwrap_or(None) { 1034 1046 Some(did_string) => { 1035 1047 let did = Did::new(did_string.clone()).expect("failed to parse did"); 1036 - 1048 + 1037 1049 // Parse the URI to verify it belongs to this user 1038 1050 // URI format: at://did:plc:xxx/io.zzstoatzz.status.record/rkey 1039 1051 let uri_parts: Vec<&str> = req.uri.split('/').collect(); ··· 1042 1054 "error": "Invalid status URI format" 1043 1055 })); 1044 1056 } 1045 - 1057 + 1046 1058 // Extract DID from URI (at://did:plc:xxx/...) 1047 1059 let uri_did_part = uri_parts[2]; 1048 1060 if uri_did_part != did_string { ··· 1050 1062 "error": "You can only delete your own statuses" 1051 1063 })); 1052 1064 } 1053 - 1065 + 1054 1066 // Extract record key 1055 1067 if let Some(rkey) = uri_parts.last() { 1056 1068 // Get OAuth session 1057 1069 match oauth_client.restore(&did).await { 1058 1070 Ok(session) => { 1059 1071 let agent = Agent::new(session); 1060 - 1072 + 1061 1073 // Delete the record from ATProto 1062 1074 let delete_request = 1063 1075 atrium_api::com::atproto::repo::delete_record::InputData { ··· 1066 1078 ) 1067 1079 .expect("valid nsid"), 1068 1080 repo: did.clone().into(), 1069 - rkey: atrium_api::types::string::RecordKey::new( 1070 - rkey.to_string(), 1071 - ) 1072 - .expect("valid rkey"), 1081 + rkey: atrium_api::types::string::RecordKey::new(rkey.to_string()) 1082 + .expect("valid rkey"), 1073 1083 swap_commit: None, 1074 1084 swap_record: None, 1075 1085 }; 1076 - 1086 + 1077 1087 match agent 1078 1088 .api 1079 1089 .com ··· 1084 1094 { 1085 1095 Ok(_) => { 1086 1096 // Also remove from local database 1087 - let _ = StatusFromDb::delete_by_uri(&db_pool, req.uri.clone()).await; 1088 - 1097 + let _ = 1098 + StatusFromDb::delete_by_uri(&db_pool, req.uri.clone()).await; 1099 + 1089 1100 HttpResponse::Ok().json(serde_json::json!({ 1090 1101 "success": true 1091 1102 })) ··· 1141 1152 "error": "Admin access required" 1142 1153 })); 1143 1154 } 1144 - 1155 + 1145 1156 // Update the hidden status in the database 1146 1157 let uri = req.uri.clone(); 1147 1158 let hidden = req.hidden; 1148 - 1159 + 1149 1160 let result = db_pool 1150 1161 .conn(move |conn| { 1151 1162 conn.execute( ··· 1154 1165 ) 1155 1166 }) 1156 1167 .await; 1157 - 1168 + 1158 1169 match result { 1159 1170 Ok(rows_affected) if rows_affected > 0 => { 1160 1171 HttpResponse::Ok().json(serde_json::json!({ ··· 1162 1173 "message": if hidden { "Status hidden" } else { "Status unhidden" } 1163 1174 })) 1164 1175 } 1165 - Ok(_) => { 1166 - HttpResponse::NotFound().json(serde_json::json!({ 1167 - "error": "Status not found" 1168 - })) 1169 - } 1176 + Ok(_) => HttpResponse::NotFound().json(serde_json::json!({ 1177 + "error": "Status not found" 1178 + })), 1170 1179 Err(err) => { 1171 1180 log::error!("Error updating hidden status: {}", err); 1172 1181 HttpResponse::InternalServerError().json(serde_json::json!({ ··· 1175 1184 } 1176 1185 } 1177 1186 } 1178 - None => { 1179 - HttpResponse::Unauthorized().json(serde_json::json!({ 1180 - "error": "Not authenticated" 1181 - })) 1182 - } 1187 + None => HttpResponse::Unauthorized().json(serde_json::json!({ 1188 + "error": "Not authenticated" 1189 + })), 1183 1190 } 1184 1191 } 1185 1192 ··· 1256 1263 did_string, 1257 1264 form.status.clone(), 1258 1265 ); 1259 - 1266 + 1260 1267 // Set the text field if provided 1261 1268 status.text = form.text.clone(); 1262 - 1269 + 1263 1270 // Set the expiration time if provided 1264 1271 if let Some(exp_str) = &form.expires_in { 1265 1272 if let Some(duration) = parse_duration(exp_str) { ··· 1295 1302 } 1296 1303 } 1297 1304 } 1298 - None => { 1299 - Err(AppError::AuthenticationError("You must be logged in to create a status.".to_string())) 1300 - } 1305 + None => Err(AppError::AuthenticationError( 1306 + "You must be logged in to create a status.".to_string(), 1307 + )), 1301 1308 } 1302 1309 } 1303 1310 1304 1311 #[actix_web::main] 1305 1312 async fn main() -> std::io::Result<()> { 1306 1313 dotenv().ok(); 1307 - 1314 + 1308 1315 // Load configuration 1309 1316 let config = config::Config::from_env().expect("Failed to load configuration"); 1310 1317 let app_config = config.clone(); 1311 - 1318 + 1312 1319 env_logger::init_from_env(env_logger::Env::new().default_filter_or(&config.log_level)); 1313 1320 let host = config.server_host.clone(); 1314 1321 let port = config.server_port; 1315 1322 1316 1323 // Use database URL from config 1317 1324 let db_connection_string = if config.database_url.starts_with("sqlite://") { 1318 - config.database_url.strip_prefix("sqlite://").unwrap_or(&config.database_url).to_string() 1325 + config 1326 + .database_url 1327 + .strip_prefix("sqlite://") 1328 + .unwrap_or(&config.database_url) 1329 + .to_string() 1319 1330 } else { 1320 1331 config.database_url.clone() 1321 1332 }; ··· 1348 1359 1349 1360 // Create a new OAuth client 1350 1361 let http_client = Arc::new(DefaultHttpClient::default()); 1351 - 1362 + 1352 1363 // Check if we're running in production (non-localhost) or locally 1353 - let is_production = !config.oauth_redirect_base.starts_with("http://localhost") 1364 + let is_production = !config.oauth_redirect_base.starts_with("http://localhost") 1354 1365 && !config.oauth_redirect_base.starts_with("http://127.0.0.1"); 1355 - 1366 + 1356 1367 let client: OAuthClientType = if is_production { 1357 1368 // Production configuration with AtprotoClientMetadata 1358 - log::info!("Configuring OAuth for production with URL: {}", config.oauth_redirect_base); 1359 - 1369 + log::info!( 1370 + "Configuring OAuth for production with URL: {}", 1371 + config.oauth_redirect_base 1372 + ); 1373 + 1360 1374 let oauth_config = OAuthClientConfig { 1361 1375 client_metadata: AtprotoClientMetadata { 1362 1376 client_id: format!("{}/client-metadata.json", config.oauth_redirect_base), ··· 1390 1404 Arc::new(OAuthClient::new(oauth_config).expect("failed to create OAuth client")) 1391 1405 } else { 1392 1406 // Local development configuration with AtprotoLocalhostClientMetadata 1393 - log::info!("Configuring OAuth for local development at {}:{}", host, port); 1394 - 1407 + log::info!( 1408 + "Configuring OAuth for local development at {}:{}", 1409 + host, 1410 + port 1411 + ); 1412 + 1395 1413 let oauth_config = OAuthClientConfig { 1396 1414 client_metadata: AtprotoLocalhostClientMetadata { 1397 1415 redirect_uris: Some(vec![format!( ··· 1433 1451 log::info!("Jetstream firehose disabled (set ENABLE_FIREHOSE=true to enable)"); 1434 1452 } 1435 1453 let arc_pool = Arc::new(pool.clone()); 1436 - 1454 + 1437 1455 // Create rate limiter - 30 requests per minute per IP 1438 1456 let rate_limiter = web::Data::new(RateLimiter::new(30, Duration::from_secs(60))); 1439 - 1457 + 1440 1458 log::info!("starting HTTP server at http://{host}:{port}"); 1441 1459 HttpServer::new(move || { 1442 1460 App::new() ··· 1484 1502 #[cfg(test)] 1485 1503 mod tests { 1486 1504 use super::*; 1487 - use actix_web::{test, App}; 1505 + use actix_web::{App, test}; 1488 1506 1489 1507 #[actix_web::test] 1490 1508 async fn test_health_check() { ··· 1495 1513 #[actix_web::test] 1496 1514 async fn test_custom_emojis_endpoint() { 1497 1515 // Test that the custom emojis endpoint returns JSON 1498 - let app = test::init_service( 1499 - App::new() 1500 - .service(get_custom_emojis) 1501 - ).await; 1516 + let app = test::init_service(App::new().service(get_custom_emojis)).await; 1502 1517 1503 1518 let req = test::TestRequest::get() 1504 1519 .uri("/api/custom-emojis") 1505 1520 .to_request(); 1506 - 1521 + 1507 1522 let resp = test::call_service(&app, req).await; 1508 1523 assert!(resp.status().is_success()); 1509 1524 } ··· 1512 1527 async fn test_rate_limiting() { 1513 1528 // Simple test of the rate limiter directly 1514 1529 let rate_limiter = RateLimiter::new(3, Duration::from_secs(60)); 1515 - 1530 + 1516 1531 // Should allow first 3 requests from same IP 1517 1532 for i in 0..3 { 1518 - assert!(rate_limiter.check_rate_limit("test_ip"), 1519 - "Request {} should be allowed", i + 1); 1533 + assert!( 1534 + rate_limiter.check_rate_limit("test_ip"), 1535 + "Request {} should be allowed", 1536 + i + 1 1537 + ); 1520 1538 } 1521 - 1539 + 1522 1540 // 4th request should be blocked 1523 - assert!(!rate_limiter.check_rate_limit("test_ip"), 1524 - "4th request should be blocked"); 1525 - 1541 + assert!( 1542 + !rate_limiter.check_rate_limit("test_ip"), 1543 + "4th request should be blocked" 1544 + ); 1545 + 1526 1546 // Different IP should have its own limit 1527 - assert!(rate_limiter.check_rate_limit("different_ip"), 1528 - "Different IP should have its own rate limit"); 1547 + assert!( 1548 + rate_limiter.check_rate_limit("different_ip"), 1549 + "Different IP should have its own rate limit" 1550 + ); 1529 1551 } 1530 - 1552 + 1531 1553 #[actix_web::test] 1532 1554 async fn test_error_handling() { 1533 1555 use crate::error_handler::AppError; 1534 - use actix_web::{http::StatusCode, ResponseError}; 1535 - 1556 + use actix_web::{ResponseError, http::StatusCode}; 1557 + 1536 1558 // Test that our error types return correct status codes 1537 1559 let err = AppError::ValidationError("test".to_string()); 1538 1560 assert_eq!(err.status_code(), StatusCode::BAD_REQUEST); 1539 - 1561 + 1540 1562 let err = AppError::RateLimitExceeded; 1541 1563 assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS); 1542 - 1564 + 1543 1565 let err = AppError::AuthenticationError("test".to_string()); 1544 1566 assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED); 1545 1567 }
+12 -11
src/rate_limiter.rs
··· 27 27 pub fn check_rate_limit(&self, key: &str) -> bool { 28 28 let mut buckets = self.buckets.lock().unwrap(); 29 29 let now = Instant::now(); 30 - 30 + 31 31 let bucket = buckets.entry(key.to_string()).or_insert(TokenBucket { 32 32 tokens: self.max_tokens, 33 33 last_refill: now, ··· 35 35 36 36 // Refill tokens based on elapsed time 37 37 let elapsed = now.duration_since(bucket.last_refill); 38 - let tokens_to_add = (elapsed.as_secs_f64() / self.refill_rate.as_secs_f64() * self.max_tokens as f64) as u32; 39 - 38 + let tokens_to_add = (elapsed.as_secs_f64() / self.refill_rate.as_secs_f64() 39 + * self.max_tokens as f64) as u32; 40 + 40 41 if tokens_to_add > 0 { 41 42 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.max_tokens); 42 43 bucket.last_refill = now; ··· 68 69 #[test] 69 70 fn test_rate_limiter_basic() { 70 71 let limiter = RateLimiter::new(5, Duration::from_secs(1)); 71 - 72 + 72 73 // Should allow first 5 requests 73 74 for _ in 0..5 { 74 75 assert!(limiter.check_rate_limit("test_client")); 75 76 } 76 - 77 + 77 78 // 6th request should be blocked 78 79 assert!(!limiter.check_rate_limit("test_client")); 79 80 } ··· 81 82 #[test] 82 83 fn test_rate_limiter_refill() { 83 84 let limiter = RateLimiter::new(2, Duration::from_millis(100)); 84 - 85 + 85 86 // Use up tokens 86 87 assert!(limiter.check_rate_limit("test_client")); 87 88 assert!(limiter.check_rate_limit("test_client")); 88 89 assert!(!limiter.check_rate_limit("test_client")); 89 - 90 + 90 91 // Wait for refill 91 92 thread::sleep(Duration::from_millis(150)); 92 - 93 + 93 94 // Should have tokens again 94 95 assert!(limiter.check_rate_limit("test_client")); 95 96 } ··· 97 98 #[test] 98 99 fn test_rate_limiter_different_clients() { 99 100 let limiter = RateLimiter::new(1, Duration::from_secs(1)); 100 - 101 + 101 102 // Different clients should have separate buckets 102 103 assert!(limiter.check_rate_limit("client1")); 103 104 assert!(limiter.check_rate_limit("client2")); 104 - 105 + 105 106 // But same client should be limited 106 107 assert!(!limiter.check_rate_limit("client1")); 107 108 assert!(!limiter.check_rate_limit("client2")); 108 109 } 109 - } 110 + }