slack status without the slack status.zzstoatzz.io/
quickslice

style: format code and fix linting issues

- Run cargo fmt on all files
- Fix trailing whitespace issues
- Allow dead_code for config fields (will be used in future refactor)
- Fix clippy warning about unnecessary to_string()
- Pre-commit hooks now working properly

+181 -135
+10 -9
src/config.rs
··· 3 4 /// Application configuration loaded from environment variables 5 #[derive(Debug, Clone, Deserialize)] 6 pub struct Config { 7 /// The admin DID for moderation (intentionally hardcoded for security) 8 pub admin_did: String, 9 - 10 /// Owner handle for the default status page 11 pub owner_handle: String, 12 - 13 /// Database URL (defaults to local SQLite) 14 pub database_url: String, 15 - 16 /// OAuth redirect base URL 17 pub oauth_redirect_base: String, 18 - 19 /// Server host 20 pub server_host: String, 21 - 22 /// Server port 23 pub server_port: u16, 24 - 25 /// Enable firehose ingester 26 pub enable_firehose: bool, 27 - 28 /// Log level 29 pub log_level: String, 30 } ··· 34 pub fn from_env() -> Result<Self, env::VarError> { 35 // Admin DID is intentionally hardcoded as discussed 36 let admin_did = "did:plc:xbtmt2zjwlrfegqvch7fboei".to_string(); 37 - 38 Ok(Config { 39 admin_did, 40 owner_handle: env::var("OWNER_HANDLE").unwrap_or_else(|_| "zzstoatzz.io".to_string()), ··· 54 log_level: env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()), 55 }) 56 } 57 - }
··· 3 4 /// Application configuration loaded from environment variables 5 #[derive(Debug, Clone, Deserialize)] 6 + #[allow(dead_code)] 7 pub struct Config { 8 /// The admin DID for moderation (intentionally hardcoded for security) 9 pub admin_did: String, 10 + 11 /// Owner handle for the default status page 12 pub owner_handle: String, 13 + 14 /// Database URL (defaults to local SQLite) 15 pub database_url: String, 16 + 17 /// OAuth redirect base URL 18 pub oauth_redirect_base: String, 19 + 20 /// Server host 21 pub server_host: String, 22 + 23 /// Server port 24 pub server_port: u16, 25 + 26 /// Enable firehose ingester 27 pub enable_firehose: bool, 28 + 29 /// Log level 30 pub log_level: String, 31 } ··· 35 pub fn from_env() -> Result<Self, env::VarError> { 36 // Admin DID is intentionally hardcoded as discussed 37 let admin_did = "did:plc:xbtmt2zjwlrfegqvch7fboei".to_string(); 38 + 39 Ok(Config { 40 admin_did, 41 owner_handle: env::var("OWNER_HANDLE").unwrap_or_else(|_| "zzstoatzz.io".to_string()), ··· 55 log_level: env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()), 56 }) 57 } 58 + }
+4 -4
src/db.rs
··· 50 .unwrap(); 51 52 // Note: custom_emojis table removed - we serve emojis directly from static/emojis/ directory 53 - 54 // Add indexes for performance optimization 55 // Index on startedAt for feed queries (ORDER BY startedAt DESC) 56 conn.execute( ··· 58 [], 59 ) 60 .unwrap(); 61 - 62 // Composite index for user status queries (WHERE authorDid = ? ORDER BY startedAt DESC) 63 conn.execute( 64 "CREATE INDEX IF NOT EXISTS idx_status_authorDid_startedAt ON status(authorDid, startedAt DESC)", 65 [], 66 ) 67 .unwrap(); 68 - 69 // Add hidden column for moderation (won't error if already exists) 70 let _ = conn.execute( 71 "ALTER TABLE status ADD COLUMN hidden BOOLEAN DEFAULT FALSE", 72 [], 73 ); 74 - 75 Ok(()) 76 }) 77 .await?;
··· 50 .unwrap(); 51 52 // Note: custom_emojis table removed - we serve emojis directly from static/emojis/ directory 53 + 54 // Add indexes for performance optimization 55 // Index on startedAt for feed queries (ORDER BY startedAt DESC) 56 conn.execute( ··· 58 [], 59 ) 60 .unwrap(); 61 + 62 // Composite index for user status queries (WHERE authorDid = ? ORDER BY startedAt DESC) 63 conn.execute( 64 "CREATE INDEX IF NOT EXISTS idx_status_authorDid_startedAt ON status(authorDid, startedAt DESC)", 65 [], 66 ) 67 .unwrap(); 68 + 69 // Add hidden column for moderation (won't error if already exists) 70 let _ = conn.execute( 71 "ALTER TABLE status ADD COLUMN hidden BOOLEAN DEFAULT FALSE", 72 [], 73 ); 74 + 75 Ok(()) 76 }) 77 .await?;
+45 -23
src/error_handler.rs
··· 1 - use actix_web::{ 2 - error::ResponseError, 3 - http::StatusCode, 4 - HttpResponse, 5 - }; 6 use std::fmt; 7 8 #[derive(Debug)] ··· 10 InternalError(String), 11 DatabaseError(String), 12 AuthenticationError(String), 13 - #[allow(dead_code)] // Keep for potential future use 14 ValidationError(String), 15 - #[allow(dead_code)] // Keep for potential future use 16 NotFound(String), 17 RateLimitExceeded, 18 } ··· 34 fn error_response(&self) -> HttpResponse { 35 let (status_code, error_message) = match self { 36 AppError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()), 37 - AppError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error occurred".to_string()), 38 AppError::AuthenticationError(msg) => (StatusCode::UNAUTHORIZED, msg.clone()), 39 AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg.clone()), 40 AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()), 41 - AppError::RateLimitExceeded => (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded. Please try again later.".to_string()), 42 }; 43 - 44 - HttpResponse::build(status_code) 45 - .body(format!("Error {}: {}", status_code.as_u16(), error_message)) 46 } 47 - 48 fn status_code(&self) -> StatusCode { 49 match self { 50 - AppError::InternalError(_) | AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR, 51 AppError::AuthenticationError(_) => StatusCode::UNAUTHORIZED, 52 AppError::ValidationError(_) => StatusCode::BAD_REQUEST, 53 AppError::NotFound(_) => StatusCode::NOT_FOUND, ··· 75 #[cfg(test)] 76 mod tests { 77 use super::*; 78 - 79 #[test] 80 fn test_error_display() { 81 let err = AppError::ValidationError("Invalid input".to_string()); 82 assert_eq!(err.to_string(), "Validation error: Invalid input"); 83 - 84 let err = AppError::RateLimitExceeded; 85 assert_eq!(err.to_string(), "Rate limit exceeded"); 86 } 87 - 88 #[test] 89 fn test_error_status_codes() { 90 - assert_eq!(AppError::InternalError("test".to_string()).status_code(), StatusCode::INTERNAL_SERVER_ERROR); 91 - assert_eq!(AppError::ValidationError("test".to_string()).status_code(), StatusCode::BAD_REQUEST); 92 - assert_eq!(AppError::AuthenticationError("test".to_string()).status_code(), StatusCode::UNAUTHORIZED); 93 - assert_eq!(AppError::NotFound("test".to_string()).status_code(), StatusCode::NOT_FOUND); 94 - assert_eq!(AppError::RateLimitExceeded.status_code(), StatusCode::TOO_MANY_REQUESTS); 95 } 96 - }
··· 1 + use actix_web::{HttpResponse, error::ResponseError, http::StatusCode}; 2 use std::fmt; 3 4 #[derive(Debug)] ··· 6 InternalError(String), 7 DatabaseError(String), 8 AuthenticationError(String), 9 + #[allow(dead_code)] // Keep for potential future use 10 ValidationError(String), 11 + #[allow(dead_code)] // Keep for potential future use 12 NotFound(String), 13 RateLimitExceeded, 14 } ··· 30 fn error_response(&self) -> HttpResponse { 31 let (status_code, error_message) = match self { 32 AppError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg.clone()), 33 + AppError::DatabaseError(_) => ( 34 + StatusCode::INTERNAL_SERVER_ERROR, 35 + "Database error occurred".to_string(), 36 + ), 37 AppError::AuthenticationError(msg) => (StatusCode::UNAUTHORIZED, msg.clone()), 38 AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg.clone()), 39 AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()), 40 + AppError::RateLimitExceeded => ( 41 + StatusCode::TOO_MANY_REQUESTS, 42 + "Rate limit exceeded. Please try again later.".to_string(), 43 + ), 44 }; 45 + 46 + HttpResponse::build(status_code).body(format!( 47 + "Error {}: {}", 48 + status_code.as_u16(), 49 + error_message 50 + )) 51 } 52 + 53 fn status_code(&self) -> StatusCode { 54 match self { 55 + AppError::InternalError(_) | AppError::DatabaseError(_) => { 56 + StatusCode::INTERNAL_SERVER_ERROR 57 + } 58 AppError::AuthenticationError(_) => StatusCode::UNAUTHORIZED, 59 AppError::ValidationError(_) => StatusCode::BAD_REQUEST, 60 AppError::NotFound(_) => StatusCode::NOT_FOUND, ··· 82 #[cfg(test)] 83 mod tests { 84 use super::*; 85 + 86 #[test] 87 fn test_error_display() { 88 let err = AppError::ValidationError("Invalid input".to_string()); 89 assert_eq!(err.to_string(), "Validation error: Invalid input"); 90 + 91 let err = AppError::RateLimitExceeded; 92 assert_eq!(err.to_string(), "Rate limit exceeded"); 93 } 94 + 95 #[test] 96 fn test_error_status_codes() { 97 + assert_eq!( 98 + AppError::InternalError("test".to_string()).status_code(), 99 + StatusCode::INTERNAL_SERVER_ERROR 100 + ); 101 + assert_eq!( 102 + AppError::ValidationError("test".to_string()).status_code(), 103 + StatusCode::BAD_REQUEST 104 + ); 105 + assert_eq!( 106 + AppError::AuthenticationError("test".to_string()).status_code(), 107 + StatusCode::UNAUTHORIZED 108 + ); 109 + assert_eq!( 110 + AppError::NotFound("test".to_string()).status_code(), 111 + StatusCode::NOT_FOUND 112 + ); 113 + assert_eq!( 114 + AppError::RateLimitExceeded.status_code(), 115 + StatusCode::TOO_MANY_REQUESTS 116 + ); 117 } 118 + }
+110 -88
src/main.rs
··· 7 storage::{SqliteSessionStore, SqliteStateStore}, 8 templates::{FeedTemplate, LoginTemplate, StatusTemplate}, 9 }; 10 - use async_sqlite::rusqlite; 11 use actix_files::Files; 12 use actix_session::{ 13 Session, SessionMiddleware, config::PersistentSession, storage::CookieSessionStore, ··· 19 web::{self, Redirect}, 20 }; 21 use askama::Template; 22 use async_sqlite::{Pool, PoolBuilder}; 23 use atrium_api::{ 24 agent::Agent, ··· 30 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, 31 }; 32 use atrium_oauth::{ 33 - AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, 34 - AuthMethod, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, 35 OAuthResolverConfig, Scope, 36 }; 37 use dotenv::dotenv; ··· 91 #[get("/client-metadata.json")] 92 async fn client_metadata(config: web::Data<config::Config>) -> Result<HttpResponse> { 93 let public_url = config.oauth_redirect_base.clone(); 94 - 95 let metadata = serde_json::json!({ 96 "client_id": format!("{}/client-metadata.json", public_url), 97 "client_name": "Status Sphere", ··· 103 "token_endpoint_auth_method": "none", 104 "dpop_bound_access_tokens": true 105 }); 106 - 107 Ok(HttpResponse::Ok() 108 .content_type("application/json") 109 .body(metadata.to_string())) ··· 153 ) -> HttpResponse { 154 // Check if there's an OAuth error from BlueSky 155 if let Some(error) = &params.error { 156 - let error_msg = params.error_description.as_deref() 157 .unwrap_or("An error occurred during authentication"); 158 log::error!("OAuth error from BlueSky: {} - {}", error, error_msg); 159 - 160 let html = ErrorTemplate { 161 title: "Authentication Error", 162 error: error_msg, 163 }; 164 return HttpResponse::BadRequest().body(html.render().expect("template should be valid")); 165 } 166 - 167 // Check if we have the required code field for a successful callback 168 let code = match &params.code { 169 Some(code) => code.clone(), ··· 173 title: "Error", 174 error: "Missing required OAuth code. Please try logging in again.", 175 }; 176 - return HttpResponse::BadRequest().body(html.render().expect("template should be valid")); 177 } 178 }; 179 - 180 // Create CallbackParams for the OAuth client 181 - let callback_params = CallbackParams { 182 code, 183 state: params.state.clone(), 184 iss: params.iss.clone(), 185 }; 186 - 187 //Processes the call back and parses out a session if found and valid 188 match oauth_client.callback(callback_params).await { 189 Ok((bsky_session, _)) => { ··· 511 ) -> Result<impl Responder> { 512 // Default owner of the domain 513 const OWNER_HANDLE: &str = "zzstoatzz.io"; 514 - 515 // Resolve handle to DID using ATProto handle resolution 516 let atproto_handle_resolver = AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 517 dns_txt_resolver: HickoryDnsTxtResolver::default(), ··· 574 db_pool: web::Data<Arc<Pool>>, 575 handle_resolver: web::Data<HandleResolver>, 576 ) -> Result<impl Responder> { 577 - let offset = query.get("offset") 578 .and_then(|s| s.parse::<i32>().ok()) 579 .unwrap_or(0); 580 - let limit = query.get("limit") 581 .and_then(|s| s.parse::<i32>().ok()) 582 .unwrap_or(20) 583 .min(50); // Cap at 50 items per request 584 - 585 let mut statuses = StatusFromDb::load_statuses_paginated(&db_pool, offset, limit) 586 .await 587 .unwrap_or_else(|err| { 588 log::error!("Error loading statuses: {err}"); 589 vec![] 590 }); 591 - 592 // Resolve handles for each status 593 let mut quick_resolve_map: HashMap<Did, String> = HashMap::new(); 594 for db_status in &mut statuses { ··· 616 Err(_) => None, 617 }; 618 } 619 - 620 Ok(HttpResponse::Ok().json(statuses)) 621 } 622 ··· 624 #[get("/api/custom-emojis")] 625 async fn get_custom_emojis() -> Result<impl Responder> { 626 use std::fs; 627 - 628 #[derive(Serialize)] 629 struct SimpleEmoji { 630 name: String, 631 filename: String, 632 } 633 - 634 let emojis_dir = "static/emojis"; 635 let mut emojis = Vec::new(); 636 - 637 if let Ok(entries) = fs::read_dir(emojis_dir) { 638 for entry in entries.flatten() { 639 if let Some(filename) = entry.file_name().to_str() { 640 // Only include image files 641 - if filename.ends_with(".png") || filename.ends_with(".gif") || 642 - filename.ends_with(".jpg") || filename.ends_with(".webp") { 643 // Remove file extension to get name 644 - let name = filename.rsplit_once('.').map(|(name, _)| name).unwrap_or(filename).to_string(); 645 emojis.push(SimpleEmoji { 646 name: name.clone(), 647 filename: filename.to_string(), ··· 650 } 651 } 652 } 653 - 654 // Sort by name 655 emojis.sort_by(|a, b| a.name.cmp(&b.name)); 656 - 657 Ok(HttpResponse::Ok().json(emojis)) 658 } 659 ··· 837 ) 838 .await; 839 840 - let is_admin = is_admin(&did.to_string()); 841 let html = FeedTemplate { 842 title: TITLE, 843 profile: match profile { ··· 1033 match session.get::<String>("did").unwrap_or(None) { 1034 Some(did_string) => { 1035 let did = Did::new(did_string.clone()).expect("failed to parse did"); 1036 - 1037 // Parse the URI to verify it belongs to this user 1038 // URI format: at://did:plc:xxx/io.zzstoatzz.status.record/rkey 1039 let uri_parts: Vec<&str> = req.uri.split('/').collect(); ··· 1042 "error": "Invalid status URI format" 1043 })); 1044 } 1045 - 1046 // Extract DID from URI (at://did:plc:xxx/...) 1047 let uri_did_part = uri_parts[2]; 1048 if uri_did_part != did_string { ··· 1050 "error": "You can only delete your own statuses" 1051 })); 1052 } 1053 - 1054 // Extract record key 1055 if let Some(rkey) = uri_parts.last() { 1056 // Get OAuth session 1057 match oauth_client.restore(&did).await { 1058 Ok(session) => { 1059 let agent = Agent::new(session); 1060 - 1061 // Delete the record from ATProto 1062 let delete_request = 1063 atrium_api::com::atproto::repo::delete_record::InputData { ··· 1066 ) 1067 .expect("valid nsid"), 1068 repo: did.clone().into(), 1069 - rkey: atrium_api::types::string::RecordKey::new( 1070 - rkey.to_string(), 1071 - ) 1072 - .expect("valid rkey"), 1073 swap_commit: None, 1074 swap_record: None, 1075 }; 1076 - 1077 match agent 1078 .api 1079 .com ··· 1084 { 1085 Ok(_) => { 1086 // Also remove from local database 1087 - let _ = StatusFromDb::delete_by_uri(&db_pool, req.uri.clone()).await; 1088 - 1089 HttpResponse::Ok().json(serde_json::json!({ 1090 "success": true 1091 })) ··· 1141 "error": "Admin access required" 1142 })); 1143 } 1144 - 1145 // Update the hidden status in the database 1146 let uri = req.uri.clone(); 1147 let hidden = req.hidden; 1148 - 1149 let result = db_pool 1150 .conn(move |conn| { 1151 conn.execute( ··· 1154 ) 1155 }) 1156 .await; 1157 - 1158 match result { 1159 Ok(rows_affected) if rows_affected > 0 => { 1160 HttpResponse::Ok().json(serde_json::json!({ ··· 1162 "message": if hidden { "Status hidden" } else { "Status unhidden" } 1163 })) 1164 } 1165 - Ok(_) => { 1166 - HttpResponse::NotFound().json(serde_json::json!({ 1167 - "error": "Status not found" 1168 - })) 1169 - } 1170 Err(err) => { 1171 log::error!("Error updating hidden status: {}", err); 1172 HttpResponse::InternalServerError().json(serde_json::json!({ ··· 1175 } 1176 } 1177 } 1178 - None => { 1179 - HttpResponse::Unauthorized().json(serde_json::json!({ 1180 - "error": "Not authenticated" 1181 - })) 1182 - } 1183 } 1184 } 1185 ··· 1256 did_string, 1257 form.status.clone(), 1258 ); 1259 - 1260 // Set the text field if provided 1261 status.text = form.text.clone(); 1262 - 1263 // Set the expiration time if provided 1264 if let Some(exp_str) = &form.expires_in { 1265 if let Some(duration) = parse_duration(exp_str) { ··· 1295 } 1296 } 1297 } 1298 - None => { 1299 - Err(AppError::AuthenticationError("You must be logged in to create a status.".to_string())) 1300 - } 1301 } 1302 } 1303 1304 #[actix_web::main] 1305 async fn main() -> std::io::Result<()> { 1306 dotenv().ok(); 1307 - 1308 // Load configuration 1309 let config = config::Config::from_env().expect("Failed to load configuration"); 1310 let app_config = config.clone(); 1311 - 1312 env_logger::init_from_env(env_logger::Env::new().default_filter_or(&config.log_level)); 1313 let host = config.server_host.clone(); 1314 let port = config.server_port; 1315 1316 // Use database URL from config 1317 let db_connection_string = if config.database_url.starts_with("sqlite://") { 1318 - config.database_url.strip_prefix("sqlite://").unwrap_or(&config.database_url).to_string() 1319 } else { 1320 config.database_url.clone() 1321 }; ··· 1348 1349 // Create a new OAuth client 1350 let http_client = Arc::new(DefaultHttpClient::default()); 1351 - 1352 // Check if we're running in production (non-localhost) or locally 1353 - let is_production = !config.oauth_redirect_base.starts_with("http://localhost") 1354 && !config.oauth_redirect_base.starts_with("http://127.0.0.1"); 1355 - 1356 let client: OAuthClientType = if is_production { 1357 // Production configuration with AtprotoClientMetadata 1358 - log::info!("Configuring OAuth for production with URL: {}", config.oauth_redirect_base); 1359 - 1360 let oauth_config = OAuthClientConfig { 1361 client_metadata: AtprotoClientMetadata { 1362 client_id: format!("{}/client-metadata.json", config.oauth_redirect_base), ··· 1390 Arc::new(OAuthClient::new(oauth_config).expect("failed to create OAuth client")) 1391 } else { 1392 // Local development configuration with AtprotoLocalhostClientMetadata 1393 - log::info!("Configuring OAuth for local development at {}:{}", host, port); 1394 - 1395 let oauth_config = OAuthClientConfig { 1396 client_metadata: AtprotoLocalhostClientMetadata { 1397 redirect_uris: Some(vec![format!( ··· 1433 log::info!("Jetstream firehose disabled (set ENABLE_FIREHOSE=true to enable)"); 1434 } 1435 let arc_pool = Arc::new(pool.clone()); 1436 - 1437 // Create rate limiter - 30 requests per minute per IP 1438 let rate_limiter = web::Data::new(RateLimiter::new(30, Duration::from_secs(60))); 1439 - 1440 log::info!("starting HTTP server at http://{host}:{port}"); 1441 HttpServer::new(move || { 1442 App::new() ··· 1484 #[cfg(test)] 1485 mod tests { 1486 use super::*; 1487 - use actix_web::{test, App}; 1488 1489 #[actix_web::test] 1490 async fn test_health_check() { ··· 1495 #[actix_web::test] 1496 async fn test_custom_emojis_endpoint() { 1497 // Test that the custom emojis endpoint returns JSON 1498 - let app = test::init_service( 1499 - App::new() 1500 - .service(get_custom_emojis) 1501 - ).await; 1502 1503 let req = test::TestRequest::get() 1504 .uri("/api/custom-emojis") 1505 .to_request(); 1506 - 1507 let resp = test::call_service(&app, req).await; 1508 assert!(resp.status().is_success()); 1509 } ··· 1512 async fn test_rate_limiting() { 1513 // Simple test of the rate limiter directly 1514 let rate_limiter = RateLimiter::new(3, Duration::from_secs(60)); 1515 - 1516 // Should allow first 3 requests from same IP 1517 for i in 0..3 { 1518 - assert!(rate_limiter.check_rate_limit("test_ip"), 1519 - "Request {} should be allowed", i + 1); 1520 } 1521 - 1522 // 4th request should be blocked 1523 - assert!(!rate_limiter.check_rate_limit("test_ip"), 1524 - "4th request should be blocked"); 1525 - 1526 // Different IP should have its own limit 1527 - assert!(rate_limiter.check_rate_limit("different_ip"), 1528 - "Different IP should have its own rate limit"); 1529 } 1530 - 1531 #[actix_web::test] 1532 async fn test_error_handling() { 1533 use crate::error_handler::AppError; 1534 - use actix_web::{http::StatusCode, ResponseError}; 1535 - 1536 // Test that our error types return correct status codes 1537 let err = AppError::ValidationError("test".to_string()); 1538 assert_eq!(err.status_code(), StatusCode::BAD_REQUEST); 1539 - 1540 let err = AppError::RateLimitExceeded; 1541 assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS); 1542 - 1543 let err = AppError::AuthenticationError("test".to_string()); 1544 assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED); 1545 }
··· 7 storage::{SqliteSessionStore, SqliteStateStore}, 8 templates::{FeedTemplate, LoginTemplate, StatusTemplate}, 9 }; 10 use actix_files::Files; 11 use actix_session::{ 12 Session, SessionMiddleware, config::PersistentSession, storage::CookieSessionStore, ··· 18 web::{self, Redirect}, 19 }; 20 use askama::Template; 21 + use async_sqlite::rusqlite; 22 use async_sqlite::{Pool, PoolBuilder}; 23 use atrium_api::{ 24 agent::Agent, ··· 30 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, 31 }; 32 use atrium_oauth::{ 33 + AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, AuthorizeOptions, 34 + CallbackParams, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, 35 OAuthResolverConfig, Scope, 36 }; 37 use dotenv::dotenv; ··· 91 #[get("/client-metadata.json")] 92 async fn client_metadata(config: web::Data<config::Config>) -> Result<HttpResponse> { 93 let public_url = config.oauth_redirect_base.clone(); 94 + 95 let metadata = serde_json::json!({ 96 "client_id": format!("{}/client-metadata.json", public_url), 97 "client_name": "Status Sphere", ··· 103 "token_endpoint_auth_method": "none", 104 "dpop_bound_access_tokens": true 105 }); 106 + 107 Ok(HttpResponse::Ok() 108 .content_type("application/json") 109 .body(metadata.to_string())) ··· 153 ) -> HttpResponse { 154 // Check if there's an OAuth error from BlueSky 155 if let Some(error) = &params.error { 156 + let error_msg = params 157 + .error_description 158 + .as_deref() 159 .unwrap_or("An error occurred during authentication"); 160 log::error!("OAuth error from BlueSky: {} - {}", error, error_msg); 161 + 162 let html = ErrorTemplate { 163 title: "Authentication Error", 164 error: error_msg, 165 }; 166 return HttpResponse::BadRequest().body(html.render().expect("template should be valid")); 167 } 168 + 169 // Check if we have the required code field for a successful callback 170 let code = match &params.code { 171 Some(code) => code.clone(), ··· 175 title: "Error", 176 error: "Missing required OAuth code. Please try logging in again.", 177 }; 178 + return HttpResponse::BadRequest() 179 + .body(html.render().expect("template should be valid")); 180 } 181 }; 182 + 183 // Create CallbackParams for the OAuth client 184 + let callback_params = CallbackParams { 185 code, 186 state: params.state.clone(), 187 iss: params.iss.clone(), 188 }; 189 + 190 //Processes the call back and parses out a session if found and valid 191 match oauth_client.callback(callback_params).await { 192 Ok((bsky_session, _)) => { ··· 514 ) -> Result<impl Responder> { 515 // Default owner of the domain 516 const OWNER_HANDLE: &str = "zzstoatzz.io"; 517 + 518 // Resolve handle to DID using ATProto handle resolution 519 let atproto_handle_resolver = AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 520 dns_txt_resolver: HickoryDnsTxtResolver::default(), ··· 577 db_pool: web::Data<Arc<Pool>>, 578 handle_resolver: web::Data<HandleResolver>, 579 ) -> Result<impl Responder> { 580 + let offset = query 581 + .get("offset") 582 .and_then(|s| s.parse::<i32>().ok()) 583 .unwrap_or(0); 584 + let limit = query 585 + .get("limit") 586 .and_then(|s| s.parse::<i32>().ok()) 587 .unwrap_or(20) 588 .min(50); // Cap at 50 items per request 589 + 590 let mut statuses = StatusFromDb::load_statuses_paginated(&db_pool, offset, limit) 591 .await 592 .unwrap_or_else(|err| { 593 log::error!("Error loading statuses: {err}"); 594 vec![] 595 }); 596 + 597 // Resolve handles for each status 598 let mut quick_resolve_map: HashMap<Did, String> = HashMap::new(); 599 for db_status in &mut statuses { ··· 621 Err(_) => None, 622 }; 623 } 624 + 625 Ok(HttpResponse::Ok().json(statuses)) 626 } 627 ··· 629 #[get("/api/custom-emojis")] 630 async fn get_custom_emojis() -> Result<impl Responder> { 631 use std::fs; 632 + 633 #[derive(Serialize)] 634 struct SimpleEmoji { 635 name: String, 636 filename: String, 637 } 638 + 639 let emojis_dir = "static/emojis"; 640 let mut emojis = Vec::new(); 641 + 642 if let Ok(entries) = fs::read_dir(emojis_dir) { 643 for entry in entries.flatten() { 644 if let Some(filename) = entry.file_name().to_str() { 645 // Only include image files 646 + if filename.ends_with(".png") 647 + || filename.ends_with(".gif") 648 + || filename.ends_with(".jpg") 649 + || filename.ends_with(".webp") 650 + { 651 // Remove file extension to get name 652 + let name = filename 653 + .rsplit_once('.') 654 + .map(|(name, _)| name) 655 + .unwrap_or(filename) 656 + .to_string(); 657 emojis.push(SimpleEmoji { 658 name: name.clone(), 659 filename: filename.to_string(), ··· 662 } 663 } 664 } 665 + 666 // Sort by name 667 emojis.sort_by(|a, b| a.name.cmp(&b.name)); 668 + 669 Ok(HttpResponse::Ok().json(emojis)) 670 } 671 ··· 849 ) 850 .await; 851 852 + let is_admin = is_admin(did.as_str()); 853 let html = FeedTemplate { 854 title: TITLE, 855 profile: match profile { ··· 1045 match session.get::<String>("did").unwrap_or(None) { 1046 Some(did_string) => { 1047 let did = Did::new(did_string.clone()).expect("failed to parse did"); 1048 + 1049 // Parse the URI to verify it belongs to this user 1050 // URI format: at://did:plc:xxx/io.zzstoatzz.status.record/rkey 1051 let uri_parts: Vec<&str> = req.uri.split('/').collect(); ··· 1054 "error": "Invalid status URI format" 1055 })); 1056 } 1057 + 1058 // Extract DID from URI (at://did:plc:xxx/...) 1059 let uri_did_part = uri_parts[2]; 1060 if uri_did_part != did_string { ··· 1062 "error": "You can only delete your own statuses" 1063 })); 1064 } 1065 + 1066 // Extract record key 1067 if let Some(rkey) = uri_parts.last() { 1068 // Get OAuth session 1069 match oauth_client.restore(&did).await { 1070 Ok(session) => { 1071 let agent = Agent::new(session); 1072 + 1073 // Delete the record from ATProto 1074 let delete_request = 1075 atrium_api::com::atproto::repo::delete_record::InputData { ··· 1078 ) 1079 .expect("valid nsid"), 1080 repo: did.clone().into(), 1081 + rkey: atrium_api::types::string::RecordKey::new(rkey.to_string()) 1082 + .expect("valid rkey"), 1083 swap_commit: None, 1084 swap_record: None, 1085 }; 1086 + 1087 match agent 1088 .api 1089 .com ··· 1094 { 1095 Ok(_) => { 1096 // Also remove from local database 1097 + let _ = 1098 + StatusFromDb::delete_by_uri(&db_pool, req.uri.clone()).await; 1099 + 1100 HttpResponse::Ok().json(serde_json::json!({ 1101 "success": true 1102 })) ··· 1152 "error": "Admin access required" 1153 })); 1154 } 1155 + 1156 // Update the hidden status in the database 1157 let uri = req.uri.clone(); 1158 let hidden = req.hidden; 1159 + 1160 let result = db_pool 1161 .conn(move |conn| { 1162 conn.execute( ··· 1165 ) 1166 }) 1167 .await; 1168 + 1169 match result { 1170 Ok(rows_affected) if rows_affected > 0 => { 1171 HttpResponse::Ok().json(serde_json::json!({ ··· 1173 "message": if hidden { "Status hidden" } else { "Status unhidden" } 1174 })) 1175 } 1176 + Ok(_) => HttpResponse::NotFound().json(serde_json::json!({ 1177 + "error": "Status not found" 1178 + })), 1179 Err(err) => { 1180 log::error!("Error updating hidden status: {}", err); 1181 HttpResponse::InternalServerError().json(serde_json::json!({ ··· 1184 } 1185 } 1186 } 1187 + None => HttpResponse::Unauthorized().json(serde_json::json!({ 1188 + "error": "Not authenticated" 1189 + })), 1190 } 1191 } 1192 ··· 1263 did_string, 1264 form.status.clone(), 1265 ); 1266 + 1267 // Set the text field if provided 1268 status.text = form.text.clone(); 1269 + 1270 // Set the expiration time if provided 1271 if let Some(exp_str) = &form.expires_in { 1272 if let Some(duration) = parse_duration(exp_str) { ··· 1302 } 1303 } 1304 } 1305 + None => Err(AppError::AuthenticationError( 1306 + "You must be logged in to create a status.".to_string(), 1307 + )), 1308 } 1309 } 1310 1311 #[actix_web::main] 1312 async fn main() -> std::io::Result<()> { 1313 dotenv().ok(); 1314 + 1315 // Load configuration 1316 let config = config::Config::from_env().expect("Failed to load configuration"); 1317 let app_config = config.clone(); 1318 + 1319 env_logger::init_from_env(env_logger::Env::new().default_filter_or(&config.log_level)); 1320 let host = config.server_host.clone(); 1321 let port = config.server_port; 1322 1323 // Use database URL from config 1324 let db_connection_string = if config.database_url.starts_with("sqlite://") { 1325 + config 1326 + .database_url 1327 + .strip_prefix("sqlite://") 1328 + .unwrap_or(&config.database_url) 1329 + .to_string() 1330 } else { 1331 config.database_url.clone() 1332 }; ··· 1359 1360 // Create a new OAuth client 1361 let http_client = Arc::new(DefaultHttpClient::default()); 1362 + 1363 // Check if we're running in production (non-localhost) or locally 1364 + let is_production = !config.oauth_redirect_base.starts_with("http://localhost") 1365 && !config.oauth_redirect_base.starts_with("http://127.0.0.1"); 1366 + 1367 let client: OAuthClientType = if is_production { 1368 // Production configuration with AtprotoClientMetadata 1369 + log::info!( 1370 + "Configuring OAuth for production with URL: {}", 1371 + config.oauth_redirect_base 1372 + ); 1373 + 1374 let oauth_config = OAuthClientConfig { 1375 client_metadata: AtprotoClientMetadata { 1376 client_id: format!("{}/client-metadata.json", config.oauth_redirect_base), ··· 1404 Arc::new(OAuthClient::new(oauth_config).expect("failed to create OAuth client")) 1405 } else { 1406 // Local development configuration with AtprotoLocalhostClientMetadata 1407 + log::info!( 1408 + "Configuring OAuth for local development at {}:{}", 1409 + host, 1410 + port 1411 + ); 1412 + 1413 let oauth_config = OAuthClientConfig { 1414 client_metadata: AtprotoLocalhostClientMetadata { 1415 redirect_uris: Some(vec![format!( ··· 1451 log::info!("Jetstream firehose disabled (set ENABLE_FIREHOSE=true to enable)"); 1452 } 1453 let arc_pool = Arc::new(pool.clone()); 1454 + 1455 // Create rate limiter - 30 requests per minute per IP 1456 let rate_limiter = web::Data::new(RateLimiter::new(30, Duration::from_secs(60))); 1457 + 1458 log::info!("starting HTTP server at http://{host}:{port}"); 1459 HttpServer::new(move || { 1460 App::new() ··· 1502 #[cfg(test)] 1503 mod tests { 1504 use super::*; 1505 + use actix_web::{App, test}; 1506 1507 #[actix_web::test] 1508 async fn test_health_check() { ··· 1513 #[actix_web::test] 1514 async fn test_custom_emojis_endpoint() { 1515 // Test that the custom emojis endpoint returns JSON 1516 + let app = test::init_service(App::new().service(get_custom_emojis)).await; 1517 1518 let req = test::TestRequest::get() 1519 .uri("/api/custom-emojis") 1520 .to_request(); 1521 + 1522 let resp = test::call_service(&app, req).await; 1523 assert!(resp.status().is_success()); 1524 } ··· 1527 async fn test_rate_limiting() { 1528 // Simple test of the rate limiter directly 1529 let rate_limiter = RateLimiter::new(3, Duration::from_secs(60)); 1530 + 1531 // Should allow first 3 requests from same IP 1532 for i in 0..3 { 1533 + assert!( 1534 + rate_limiter.check_rate_limit("test_ip"), 1535 + "Request {} should be allowed", 1536 + i + 1 1537 + ); 1538 } 1539 + 1540 // 4th request should be blocked 1541 + assert!( 1542 + !rate_limiter.check_rate_limit("test_ip"), 1543 + "4th request should be blocked" 1544 + ); 1545 + 1546 // Different IP should have its own limit 1547 + assert!( 1548 + rate_limiter.check_rate_limit("different_ip"), 1549 + "Different IP should have its own rate limit" 1550 + ); 1551 } 1552 + 1553 #[actix_web::test] 1554 async fn test_error_handling() { 1555 use crate::error_handler::AppError; 1556 + use actix_web::{ResponseError, http::StatusCode}; 1557 + 1558 // Test that our error types return correct status codes 1559 let err = AppError::ValidationError("test".to_string()); 1560 assert_eq!(err.status_code(), StatusCode::BAD_REQUEST); 1561 + 1562 let err = AppError::RateLimitExceeded; 1563 assert_eq!(err.status_code(), StatusCode::TOO_MANY_REQUESTS); 1564 + 1565 let err = AppError::AuthenticationError("test".to_string()); 1566 assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED); 1567 }
+12 -11
src/rate_limiter.rs
··· 27 pub fn check_rate_limit(&self, key: &str) -> bool { 28 let mut buckets = self.buckets.lock().unwrap(); 29 let now = Instant::now(); 30 - 31 let bucket = buckets.entry(key.to_string()).or_insert(TokenBucket { 32 tokens: self.max_tokens, 33 last_refill: now, ··· 35 36 // Refill tokens based on elapsed time 37 let elapsed = now.duration_since(bucket.last_refill); 38 - let tokens_to_add = (elapsed.as_secs_f64() / self.refill_rate.as_secs_f64() * self.max_tokens as f64) as u32; 39 - 40 if tokens_to_add > 0 { 41 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.max_tokens); 42 bucket.last_refill = now; ··· 68 #[test] 69 fn test_rate_limiter_basic() { 70 let limiter = RateLimiter::new(5, Duration::from_secs(1)); 71 - 72 // Should allow first 5 requests 73 for _ in 0..5 { 74 assert!(limiter.check_rate_limit("test_client")); 75 } 76 - 77 // 6th request should be blocked 78 assert!(!limiter.check_rate_limit("test_client")); 79 } ··· 81 #[test] 82 fn test_rate_limiter_refill() { 83 let limiter = RateLimiter::new(2, Duration::from_millis(100)); 84 - 85 // Use up tokens 86 assert!(limiter.check_rate_limit("test_client")); 87 assert!(limiter.check_rate_limit("test_client")); 88 assert!(!limiter.check_rate_limit("test_client")); 89 - 90 // Wait for refill 91 thread::sleep(Duration::from_millis(150)); 92 - 93 // Should have tokens again 94 assert!(limiter.check_rate_limit("test_client")); 95 } ··· 97 #[test] 98 fn test_rate_limiter_different_clients() { 99 let limiter = RateLimiter::new(1, Duration::from_secs(1)); 100 - 101 // Different clients should have separate buckets 102 assert!(limiter.check_rate_limit("client1")); 103 assert!(limiter.check_rate_limit("client2")); 104 - 105 // But same client should be limited 106 assert!(!limiter.check_rate_limit("client1")); 107 assert!(!limiter.check_rate_limit("client2")); 108 } 109 - }
··· 27 pub fn check_rate_limit(&self, key: &str) -> bool { 28 let mut buckets = self.buckets.lock().unwrap(); 29 let now = Instant::now(); 30 + 31 let bucket = buckets.entry(key.to_string()).or_insert(TokenBucket { 32 tokens: self.max_tokens, 33 last_refill: now, ··· 35 36 // Refill tokens based on elapsed time 37 let elapsed = now.duration_since(bucket.last_refill); 38 + let tokens_to_add = (elapsed.as_secs_f64() / self.refill_rate.as_secs_f64() 39 + * self.max_tokens as f64) as u32; 40 + 41 if tokens_to_add > 0 { 42 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.max_tokens); 43 bucket.last_refill = now; ··· 69 #[test] 70 fn test_rate_limiter_basic() { 71 let limiter = RateLimiter::new(5, Duration::from_secs(1)); 72 + 73 // Should allow first 5 requests 74 for _ in 0..5 { 75 assert!(limiter.check_rate_limit("test_client")); 76 } 77 + 78 // 6th request should be blocked 79 assert!(!limiter.check_rate_limit("test_client")); 80 } ··· 82 #[test] 83 fn test_rate_limiter_refill() { 84 let limiter = RateLimiter::new(2, Duration::from_millis(100)); 85 + 86 // Use up tokens 87 assert!(limiter.check_rate_limit("test_client")); 88 assert!(limiter.check_rate_limit("test_client")); 89 assert!(!limiter.check_rate_limit("test_client")); 90 + 91 // Wait for refill 92 thread::sleep(Duration::from_millis(150)); 93 + 94 // Should have tokens again 95 assert!(limiter.check_rate_limit("test_client")); 96 } ··· 98 #[test] 99 fn test_rate_limiter_different_clients() { 100 let limiter = RateLimiter::new(1, Duration::from_secs(1)); 101 + 102 // Different clients should have separate buckets 103 assert!(limiter.check_rate_limit("client1")); 104 assert!(limiter.check_rate_limit("client2")); 105 + 106 // But same client should be limited 107 assert!(!limiter.check_rate_limit("client1")); 108 assert!(!limiter.check_rate_limit("client2")); 109 } 110 + }