// Forked from smokesignal.events/smokesignal
// (i18n + filtering fork, fluent-templates v2).
1use axum::{
2 extract::{ConnectInfo, Request},
3 http::{HeaderMap, StatusCode},
4 middleware::Next,
5 response::{Response, Json},
6};
7use axum_extra::extract::PrivateCookieJar;
8use reqwest;
9use serde::Deserialize;
10use std::{
11 collections::HashMap,
12 net::{IpAddr, SocketAddr},
13 sync::{Arc, Mutex},
14 time::{Duration, SystemTime},
15};
16use tokio::time::timeout;
17
18use crate::{
19 http::context::WebContext,
20 http::middleware_auth::{WebSession, AUTH_COOKIE_NAME},
21 storage::oauth::web_session_lookup,
22};
23
/// Timezone resolved for a single request.
///
/// `timezone_middleware` inserts one of these into the request extensions so
/// downstream handlers can read it back.
#[derive(Debug, Clone)]
pub struct DetectedTimezone {
    /// Timezone identifier exactly as reported by the winning source
    /// (presumably an IANA name such as `Europe/Paris`; it is not validated here).
    pub timezone: String,
    /// Label of the mechanism that produced `timezone`, e.g. `"header"`,
    /// `"ipapi"` or `"fallback"`.
    pub source: String,
    /// `(latitude, longitude)` when the source reported coordinates.
    pub coordinates: Option<(f64, f64)>,
}
31
32// Simple cache to avoid too many API calls
33type TimezoneCache = Arc<Mutex<HashMap<IpAddr, (DetectedTimezone, SystemTime)>>>;
34
/// Tunable settings for timezone detection.
#[derive(Clone)]
pub struct TimezoneConfig {
    /// How long a cached per-IP detection result stays valid.
    pub cache_duration: Duration,
    /// Time budget handed to each external geolocation lookup.
    pub request_timeout: Duration,
    /// Timezone reported when no detection source yields a result.
    pub fallback_timezone: String,
}

impl Default for TimezoneConfig {
    /// One-hour cache lifetime, three-second request timeout, `UTC` fallback.
    fn default() -> Self {
        let cache_duration = Duration::from_secs(3600);
        let request_timeout = Duration::from_secs(3);
        let fallback_timezone = String::from("UTC");

        Self {
            cache_duration,
            request_timeout,
            fallback_timezone,
        }
    }
}
52
53// Main middleware function
54pub async fn timezone_middleware(
55 mut request: Request<axum::body::Body>,
56 next: Next,
57) -> Result<Response, StatusCode> {
58 // Extract what we need from the request
59 let headers = request.headers().clone();
60 let web_context = request.extensions().get::<WebContext>().cloned();
61 let addr = request.extensions().get::<ConnectInfo<SocketAddr>>().map(|info| info.0);
62
63 // Try to get the real client IP from forwarded headers (for tunnels/proxies)
64 let client_ip = if let Some(socket_addr) = addr {
65 get_client_ip(&headers, socket_addr.ip())
66 } else {
67 // Fallback to localhost if no connection info available
68 std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
69 };
70
71 tracing::debug!("Timezone middleware executed for IP: {} (original: {:?})",
72 client_ip, addr.map(|a| a.ip()));
73
74 let config = TimezoneConfig::default();
75 let cache: TimezoneCache = Arc::new(Mutex::new(HashMap::new()));
76
77 let detected_timezone = detect_timezone_with_cache(
78 client_ip,
79 &headers,
80 &config,
81 &cache,
82 web_context.as_ref(),
83 ).await;
84
85 tracing::debug!("Detected timezone: {} from source: {}",
86 detected_timezone.timezone, detected_timezone.source);
87
88 // Add the detected timezone to request extensions
89 request.extensions_mut().insert(detected_timezone);
90
91 Ok(next.run(request).await)
92}
93
94// Detection with cache
95async fn detect_timezone_with_cache(
96 ip: IpAddr,
97 headers: &HeaderMap,
98 config: &TimezoneConfig,
99 cache: &TimezoneCache,
100 web_context: Option<&WebContext>,
101) -> DetectedTimezone {
102 // Priority 0: User preference from authenticated user's settings
103 if let Some(user_tz) = get_user_timezone_preference(headers, web_context).await {
104 return DetectedTimezone {
105 timezone: user_tz,
106 source: "user_preference".to_string(),
107 coordinates: None,
108 };
109 }
110
111 // Check cache first for IP-based detection
112 if let Ok(cache_lock) = cache.lock() {
113 if let Some((cached_tz, timestamp)) = cache_lock.get(&ip) {
114 if timestamp.elapsed().unwrap_or(Duration::MAX) < config.cache_duration {
115 // Only return cached result if it wasn't from user preference
116 if cached_tz.source != "user_preference" {
117 return cached_tz.clone();
118 }
119 }
120 }
121 }
122
123 // Priority 1: Custom header
124 if let Some(tz_header) = headers.get("x-timezone") {
125 if let Ok(timezone) = tz_header.to_str() {
126 let detected = DetectedTimezone {
127 timezone: timezone.to_string(),
128 source: "header".to_string(),
129 coordinates: None,
130 };
131 cache_timezone(cache, ip, detected.clone());
132 return detected;
133 }
134 }
135
136 // Priority 2: IP detection (only for public IPs)
137 if !is_private_ip(ip) {
138 if let Some(detected) = detect_timezone_by_ip(ip, config).await {
139 cache_timezone(cache, ip, detected.clone());
140 return detected;
141 }
142 }
143
144 // Fallback
145 let fallback = DetectedTimezone {
146 timezone: config.fallback_timezone.clone(),
147 source: "fallback".to_string(),
148 coordinates: None,
149 };
150
151 cache_timezone(cache, ip, fallback.clone());
152 fallback
153}
154
155// Get authenticated user's timezone preference
156async fn get_user_timezone_preference(headers: &HeaderMap, web_context: Option<&WebContext>) -> Option<String> {
157 let web_context = web_context?;
158
159 // Extract session from cookies similar to how Auth middleware does it
160 let cookie_jar = PrivateCookieJar::from_headers(
161 headers,
162 web_context.config.http_cookie_key.as_ref().clone(),
163 );
164
165 let session = cookie_jar
166 .get(AUTH_COOKIE_NAME)
167 .map(|user_cookie| user_cookie.value().to_owned())
168 .and_then(|inner_value| WebSession::try_from(inner_value).ok())?;
169
170 // Look up the user's handle to get their timezone preference
171 match web_session_lookup(
172 &web_context.pool,
173 &session.session_group,
174 Some(&session.did),
175 ).await {
176 Ok((handle, _)) => {
177 // Check if user has a non-default timezone set
178 if !handle.tz.is_empty() && handle.tz != "UTC" {
179 Some(handle.tz)
180 } else {
181 None
182 }
183 }
184 Err(_) => None,
185 }
186}
187
188// IP detection with multiple services
189async fn detect_timezone_by_ip(ip: IpAddr, config: &TimezoneConfig) -> Option<DetectedTimezone> {
190 // Try WorldTimeAPI first
191 if let Some(tz) = get_worldtime_timezone_ip(ip, config.request_timeout).await {
192 return Some(DetectedTimezone {
193 timezone: tz,
194 source: "worldtimeapi".to_string(),
195 coordinates: None,
196 });
197 }
198
199 // Fallback to ipapi.co
200 if let Some((tz, coords)) = get_ipapi_detailed(ip, config.request_timeout).await {
201 return Some(DetectedTimezone {
202 timezone: tz,
203 source: "ipapi".to_string(),
204 coordinates: Some(coords),
205 });
206 }
207
208 // Last fallback to ip-api.com
209 get_ip_api_detailed(ip, config.request_timeout).await
210}
211
212// Individual API services
213async fn get_worldtime_timezone_ip(_ip: IpAddr, timeout_duration: Duration) -> Option<String> {
214 #[derive(Deserialize)]
215 struct WorldTimeResponse {
216 timezone: String,
217 }
218
219 let client = reqwest::Client::new();
220 let response: WorldTimeResponse = timeout(
221 timeout_duration,
222 client.get("http://worldtimeapi.org/api/ip").send()
223 )
224 .await
225 .ok()?
226 .ok()?
227 .json()
228 .await
229 .ok()?;
230
231 Some(response.timezone)
232}
233
234async fn get_ipapi_detailed(ip: IpAddr, timeout_duration: Duration) -> Option<(String, (f64, f64))> {
235 #[derive(Deserialize)]
236 struct IpapiResponse {
237 timezone: Option<String>,
238 latitude: Option<f64>,
239 longitude: Option<f64>,
240 }
241
242 let url = format!("https://ipapi.co/{}/json/", ip);
243 let client = reqwest::Client::new();
244
245 let response: IpapiResponse = timeout(
246 timeout_duration,
247 client.get(&url).send()
248 )
249 .await
250 .ok()?
251 .ok()?
252 .json()
253 .await
254 .ok()?;
255
256 match (response.timezone, response.latitude, response.longitude) {
257 (Some(tz), Some(lat), Some(lon)) => Some((tz, (lat, lon))),
258 _ => None,
259 }
260}
261
262async fn get_ip_api_detailed(ip: IpAddr, timeout_duration: Duration) -> Option<DetectedTimezone> {
263 #[derive(Deserialize)]
264 struct IpApiResponse {
265 timezone: Option<String>,
266 lat: Option<f64>,
267 lon: Option<f64>,
268 status: String,
269 }
270
271 let url = format!("http://ip-api.com/json/{}?fields=timezone,lat,lon,status", ip);
272 let client = reqwest::Client::new();
273
274 let response: IpApiResponse = timeout(
275 timeout_duration,
276 client.get(&url).send()
277 )
278 .await
279 .ok()?
280 .ok()?
281 .json()
282 .await
283 .ok()?;
284
285 if response.status == "success" {
286 Some(DetectedTimezone {
287 timezone: response.timezone.unwrap_or_else(|| "UTC".to_string()),
288 source: "ip-api".to_string(),
289 coordinates: match (response.lat, response.lon) {
290 (Some(lat), Some(lon)) => Some((lat, lon)),
291 _ => None,
292 },
293 })
294 } else {
295 None
296 }
297}
298
299// Utilities
300fn get_client_ip(headers: &HeaderMap, fallback_ip: IpAddr) -> IpAddr {
301 // Check common forwarded headers in order of preference
302 let forwarded_headers = [
303 "cf-connecting-ip", // Cloudflare
304 "x-forwarded-for", // Standard proxy header
305 "x-real-ip", // Nginx proxy
306 "x-client-ip", // Some proxies
307 "x-forwarded", // Less common
308 "forwarded-for", // Older standard
309 "forwarded", // RFC 7239
310 ];
311
312 for header_name in &forwarded_headers {
313 if let Some(header_value) = headers.get(*header_name) {
314 if let Ok(header_str) = header_value.to_str() {
315 // x-forwarded-for can have multiple IPs: "client, proxy1, proxy2"
316 // We want the first (leftmost) IP which is the original client
317 let first_ip = header_str.split(',').next().unwrap_or("").trim();
318
319 if let Ok(ip) = first_ip.parse::<IpAddr>() {
320 // Make sure it's not a private/loopback IP (unless fallback is also private)
321 if !is_private_ip(ip) || is_private_ip(fallback_ip) {
322 tracing::debug!("Found real client IP {} in header {}", ip, header_name);
323 return ip;
324 }
325 }
326 }
327 }
328 }
329
330 // No valid forwarded IP found, use the direct connection IP
331 fallback_ip
332}
333
/// Reports whether `ip` is an address that cannot be geolocated because it is
/// not publicly routable.
///
/// IPv4: private (RFC 1918), loopback, link-local, unspecified (`0.0.0.0`) and
/// carrier-grade NAT (`100.64.0.0/10`).
///
/// IPv6: loopback, unspecified (`::`), multicast, unique-local (`fc00::/7`) and
/// link-local (`fe80::/10`). IPv4-mapped addresses (`::ffff:a.b.c.d`) are
/// classified by their embedded IPv4 address.
///
/// The previous version recognised only loopback and multicast for IPv6, so
/// unique-local, link-local and mapped private addresses were treated as
/// public and forwarded to external geolocation services.
fn is_private_ip(ip: IpAddr) -> bool {
    /// IPv4 classification, shared by plain and IPv4-mapped addresses.
    fn is_non_public_v4(v4: std::net::Ipv4Addr) -> bool {
        let [first, second, ..] = v4.octets();
        v4.is_private()
            || v4.is_loopback()
            || v4.is_link_local()
            || v4.is_unspecified()
            // 100.64.0.0/10, shared address space (carrier-grade NAT).
            || (first == 100 && (second & 0xc0) == 64)
    }

    match ip {
        IpAddr::V4(v4) => is_non_public_v4(v4),
        IpAddr::V6(v6) => {
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_non_public_v4(mapped);
            }

            // Prefix tests are done on the first 16-bit group so they work on
            // toolchains that predate the dedicated std helpers.
            let prefix = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (prefix & 0xfe00) == 0xfc00 // unique-local fc00::/7
                || (prefix & 0xffc0) == 0xfe80 // link-local fe80::/10
        }
    }
}
344
345fn cache_timezone(cache: &TimezoneCache, ip: IpAddr, timezone: DetectedTimezone) {
346 if let Ok(mut cache_lock) = cache.lock() {
347 cache_lock.insert(ip, (timezone, SystemTime::now()));
348 }
349}
350
351// Trait to easily access the detected timezone from a request
352pub trait RequestTimezoneExt {
353 fn timezone(&self) -> Option<&DetectedTimezone>;
354}
355
356impl RequestTimezoneExt for Request {
357 fn timezone(&self) -> Option<&DetectedTimezone> {
358 self.extensions().get::<DetectedTimezone>()
359 }
360}
361
362// Example usage in a handler
363pub async fn example_handler(request: Request) -> Json<serde_json::Value> {
364 let timezone_info = RequestTimezoneExt::timezone(&request).cloned().unwrap_or_else(|| DetectedTimezone {
365 timezone: "Unknown".to_string(),
366 source: "error".to_string(),
367 coordinates: None,
368 });
369
370 Json(serde_json::json!({
371 "timezone": timezone_info.timezone,
372 "source": timezone_info.source,
373 "coordinates": timezone_info.coordinates,
374 "message": format!("Your detected timezone: {}", timezone_info.timezone)
375 }))
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        middleware,
        routing::get,
        Router,
    };
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Wiring check: the middleware and the example handler must type-check
    /// when composed into a router that provides `ConnectInfo<SocketAddr>`.
    #[tokio::test]
    async fn test_timezone_detection() {
        let _app = Router::new()
            .route("/test", get(example_handler))
            .layer(middleware::from_fn(timezone_middleware))
            .into_make_service_with_connect_info::<SocketAddr>();
    }

    #[test]
    fn default_config_has_expected_values() {
        let config = TimezoneConfig::default();
        assert_eq!(config.cache_duration, Duration::from_secs(3600));
        assert_eq!(config.request_timeout, Duration::from_secs(3));
        assert_eq!(config.fallback_timezone, "UTC");
    }

    #[test]
    fn private_and_public_ipv4_are_told_apart() {
        assert!(is_private_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(is_private_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
        assert!(is_private_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)));
        assert!(!is_private_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
    }

    #[test]
    fn client_ip_uses_leftmost_public_forwarded_address() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "8.8.8.8, 10.0.0.1".parse().unwrap());

        let peer = IpAddr::V4(Ipv4Addr::LOCALHOST);
        assert_eq!(
            get_client_ip(&headers, peer),
            IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))
        );
    }

    #[test]
    fn client_ip_rejects_private_forwarded_address_for_public_peer() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "10.0.0.1".parse().unwrap());

        let peer = IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4));
        assert_eq!(get_client_ip(&headers, peer), peer);
    }

    #[test]
    fn client_ip_falls_back_to_peer_without_headers() {
        let headers = HeaderMap::new();
        let peer = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        assert_eq!(get_client_ip(&headers, peer), peer);
    }

    #[test]
    fn cached_entry_can_be_read_back() {
        let cache: TimezoneCache = Arc::new(Mutex::new(HashMap::new()));
        let ip = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

        cache_timezone(
            &cache,
            ip,
            DetectedTimezone {
                timezone: "Europe/Paris".to_string(),
                source: "ipapi".to_string(),
                coordinates: Some((48.85, 2.35)),
            },
        );

        let entries = cache.lock().expect("cache mutex is not poisoned in this test");
        let (stored, _) = entries.get(&ip).expect("entry was just inserted");
        assert_eq!(stored.timezone, "Europe/Paris");
        assert_eq!(stored.source, "ipapi");
    }
}
401}