use axum::{ extract::{ConnectInfo, Request}, http::{HeaderMap, StatusCode}, middleware::Next, response::{Response, Json}, }; use axum_extra::extract::PrivateCookieJar; use reqwest; use serde::Deserialize; use std::{ collections::HashMap, net::{IpAddr, SocketAddr}, sync::{Arc, Mutex}, time::{Duration, SystemTime}, }; use tokio::time::timeout; use crate::{ http::context::WebContext, http::middleware_auth::{WebSession, AUTH_COOKIE_NAME}, storage::oauth::web_session_lookup, }; // Extension to store the detected timezone #[derive(Clone, Debug)] pub struct DetectedTimezone { pub timezone: String, pub source: String, pub coordinates: Option<(f64, f64)>, } // Simple cache to avoid too many API calls type TimezoneCache = Arc>>; // Middleware configuration #[derive(Clone)] pub struct TimezoneConfig { pub cache_duration: Duration, pub request_timeout: Duration, pub fallback_timezone: String, } impl Default for TimezoneConfig { fn default() -> Self { Self { cache_duration: Duration::from_secs(3600), // 1 hour request_timeout: Duration::from_secs(3), fallback_timezone: "UTC".to_string(), } } } // Main middleware function pub async fn timezone_middleware( mut request: Request, next: Next, ) -> Result { // Extract what we need from the request let headers = request.headers().clone(); let web_context = request.extensions().get::().cloned(); let addr = request.extensions().get::>().map(|info| info.0); // Try to get the real client IP from forwarded headers (for tunnels/proxies) let client_ip = if let Some(socket_addr) = addr { get_client_ip(&headers, socket_addr.ip()) } else { // Fallback to localhost if no connection info available std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)) }; tracing::debug!("Timezone middleware executed for IP: {} (original: {:?})", client_ip, addr.map(|a| a.ip())); let config = TimezoneConfig::default(); let cache: TimezoneCache = Arc::new(Mutex::new(HashMap::new())); let detected_timezone = detect_timezone_with_cache( client_ip, &headers, &config, &cache, web_context.as_ref(), ).await; tracing::debug!("Detected timezone: {} from source: {}", detected_timezone.timezone, detected_timezone.source); // Add the detected timezone to request extensions request.extensions_mut().insert(detected_timezone); Ok(next.run(request).await) } // Detection with cache async fn detect_timezone_with_cache( ip: IpAddr, headers: &HeaderMap, config: &TimezoneConfig, cache: &TimezoneCache, web_context: Option<&WebContext>, ) -> DetectedTimezone { // Priority 0: User preference from authenticated user's settings if let Some(user_tz) = get_user_timezone_preference(headers, web_context).await { return DetectedTimezone { timezone: user_tz, source: "user_preference".to_string(), coordinates: None, }; } // Check cache first for IP-based detection if let Ok(cache_lock) = cache.lock() { if let Some((cached_tz, timestamp)) = cache_lock.get(&ip) { if timestamp.elapsed().unwrap_or(Duration::MAX) < config.cache_duration { // Only return cached result if it wasn't from user preference if cached_tz.source != "user_preference" { return cached_tz.clone(); } } } } // Priority 1: Custom header if let Some(tz_header) = headers.get("x-timezone") { if let Ok(timezone) = tz_header.to_str() { let detected = DetectedTimezone { timezone: timezone.to_string(), source: "header".to_string(), coordinates: None, }; cache_timezone(cache, ip, detected.clone()); return detected; } } // Priority 2: IP detection (only for public IPs) if !is_private_ip(ip) { if let Some(detected) = detect_timezone_by_ip(ip, config).await { cache_timezone(cache, ip, detected.clone()); return detected; } } // Fallback let fallback = DetectedTimezone { timezone: config.fallback_timezone.clone(), source: "fallback".to_string(), coordinates: None, }; cache_timezone(cache, ip, fallback.clone()); fallback } // Get authenticated user's timezone preference async fn get_user_timezone_preference(headers: &HeaderMap, web_context: Option<&WebContext>) -> Option { let web_context = web_context?; // Extract session from cookies similar to how Auth middleware does it let cookie_jar = PrivateCookieJar::from_headers( headers, web_context.config.http_cookie_key.as_ref().clone(), ); let session = cookie_jar .get(AUTH_COOKIE_NAME) .map(|user_cookie| user_cookie.value().to_owned()) .and_then(|inner_value| WebSession::try_from(inner_value).ok())?; // Look up the user's handle to get their timezone preference match web_session_lookup( &web_context.pool, &session.session_group, Some(&session.did), ).await { Ok((handle, _)) => { // Check if user has a non-default timezone set if !handle.tz.is_empty() && handle.tz != "UTC" { Some(handle.tz) } else { None } } Err(_) => None, } } // IP detection with multiple services async fn detect_timezone_by_ip(ip: IpAddr, config: &TimezoneConfig) -> Option { // Try WorldTimeAPI first if let Some(tz) = get_worldtime_timezone_ip(ip, config.request_timeout).await { return Some(DetectedTimezone { timezone: tz, source: "worldtimeapi".to_string(), coordinates: None, }); } // Fallback to ipapi.co if let Some((tz, coords)) = get_ipapi_detailed(ip, config.request_timeout).await { return Some(DetectedTimezone { timezone: tz, source: "ipapi".to_string(), coordinates: Some(coords), }); } // Last fallback to ip-api.com get_ip_api_detailed(ip, config.request_timeout).await } // Individual API services async fn get_worldtime_timezone_ip(_ip: IpAddr, timeout_duration: Duration) -> Option { #[derive(Deserialize)] struct WorldTimeResponse { timezone: String, } let client = reqwest::Client::new(); let response: WorldTimeResponse = timeout( timeout_duration, client.get("http://worldtimeapi.org/api/ip").send() ) .await .ok()? .ok()? .json() .await .ok()?; Some(response.timezone) } async fn get_ipapi_detailed(ip: IpAddr, timeout_duration: Duration) -> Option<(String, (f64, f64))> { #[derive(Deserialize)] struct IpapiResponse { timezone: Option, latitude: Option, longitude: Option, } let url = format!("https://ipapi.co/{}/json/", ip); let client = reqwest::Client::new(); let response: IpapiResponse = timeout( timeout_duration, client.get(&url).send() ) .await .ok()? .ok()? .json() .await .ok()?; match (response.timezone, response.latitude, response.longitude) { (Some(tz), Some(lat), Some(lon)) => Some((tz, (lat, lon))), _ => None, } } async fn get_ip_api_detailed(ip: IpAddr, timeout_duration: Duration) -> Option { #[derive(Deserialize)] struct IpApiResponse { timezone: Option, lat: Option, lon: Option, status: String, } let url = format!("http://ip-api.com/json/{}?fields=timezone,lat,lon,status", ip); let client = reqwest::Client::new(); let response: IpApiResponse = timeout( timeout_duration, client.get(&url).send() ) .await .ok()? .ok()? .json() .await .ok()?; if response.status == "success" { Some(DetectedTimezone { timezone: response.timezone.unwrap_or_else(|| "UTC".to_string()), source: "ip-api".to_string(), coordinates: match (response.lat, response.lon) { (Some(lat), Some(lon)) => Some((lat, lon)), _ => None, }, }) } else { None } } // Utilities fn get_client_ip(headers: &HeaderMap, fallback_ip: IpAddr) -> IpAddr { // Check common forwarded headers in order of preference let forwarded_headers = [ "cf-connecting-ip", // Cloudflare "x-forwarded-for", // Standard proxy header "x-real-ip", // Nginx proxy "x-client-ip", // Some proxies "x-forwarded", // Less common "forwarded-for", // Older standard "forwarded", // RFC 7239 ]; for header_name in &forwarded_headers { if let Some(header_value) = headers.get(*header_name) { if let Ok(header_str) = header_value.to_str() { // x-forwarded-for can have multiple IPs: "client, proxy1, proxy2" // We want the first (leftmost) IP which is the original client let first_ip = header_str.split(',').next().unwrap_or("").trim(); if let Ok(ip) = first_ip.parse::() { // Make sure it's not a private/loopback IP (unless fallback is also private) if !is_private_ip(ip) || is_private_ip(fallback_ip) { tracing::debug!("Found real client IP {} in header {}", ip, header_name); return ip; } } } } } // No valid forwarded IP found, use the direct connection IP fallback_ip } fn is_private_ip(ip: IpAddr) -> bool { match ip { IpAddr::V4(ipv4) => { ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local() } IpAddr::V6(ipv6) => { ipv6.is_loopback() || ipv6.is_multicast() } } } fn cache_timezone(cache: &TimezoneCache, ip: IpAddr, timezone: DetectedTimezone) { if let Ok(mut cache_lock) = cache.lock() { cache_lock.insert(ip, (timezone, SystemTime::now())); } } // Trait to easily access the detected timezone from a request pub trait RequestTimezoneExt { fn timezone(&self) -> Option<&DetectedTimezone>; } impl RequestTimezoneExt for Request { fn timezone(&self) -> Option<&DetectedTimezone> { self.extensions().get::() } } // Example usage in a handler pub async fn example_handler(request: Request) -> Json { let timezone_info = RequestTimezoneExt::timezone(&request).cloned().unwrap_or_else(|| DetectedTimezone { timezone: "Unknown".to_string(), source: "error".to_string(), coordinates: None, }); Json(serde_json::json!({ "timezone": timezone_info.timezone, "source": timezone_info.source, "coordinates": timezone_info.coordinates, "message": format!("Your detected timezone: {}", timezone_info.timezone) })) } #[cfg(test)] mod tests { use super::*; use axum::{ middleware, routing::get, Router, }; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; #[tokio::test] async fn test_timezone_detection() { let _app = Router::new() .route("/test", get(example_handler)) .layer(middleware::from_fn(timezone_middleware)) .into_make_service_with_connect_info::(); // Test with a simulated public IP let _addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 80); // Here you could add more detailed tests assert!(true); // Placeholder test } }