Coffee journaling on ATProto (alpha) alpha.arabica.social
coffee
at main 251 lines 7.1 kB view raw
1package middleware 2 3import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "net/http" 8 "strings" 9 "sync" 10 "time" 11) 12 13type cspNonceKeyType struct{} 14 15var cspNonceKey = cspNonceKeyType{} 16 17func generateNonce() (string, error) { 18 b := make([]byte, 16) 19 if _, err := rand.Read(b); err != nil { 20 return "", err 21 } 22 return base64.StdEncoding.EncodeToString(b), nil 23} 24 25func CSPNonceFromContext(ctx context.Context) string { 26 if v, ok := ctx.Value(cspNonceKey).(string); ok { 27 return v 28 } 29 return "" 30} 31 32// SecurityHeadersMiddleware adds security headers to all responses 33func SecurityHeadersMiddleware(next http.Handler) http.Handler { 34 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 nonce, err := generateNonce() 36 if err != nil { 37 http.Error(w, "Internal Server Error", http.StatusInternalServerError) 38 return 39 } 40 41 r = r.WithContext(context.WithValue(r.Context(), cspNonceKey, nonce)) 42 43 // Prevent clickjacking 44 w.Header().Set("X-Frame-Options", "DENY") 45 46 // Prevent MIME type sniffing 47 w.Header().Set("X-Content-Type-Options", "nosniff") 48 49 // XSS protection (legacy but still useful for older browsers) 50 w.Header().Set("X-XSS-Protection", "1; mode=block") 51 52 // Control referrer information 53 w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") 54 55 // Permissions policy - disable unnecessary features 56 w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") 57 58 // Content Security Policy 59 // Allows: self for scripts/styles, inline styles (for Tailwind), inline HTMX/Alpine 60 // Note: unsafe-eval required for Alpine.js standard build (CSP build has CDN MIME type issues) 61 // Note: form-action allows https: for OAuth redirects to external authorization servers 62 // TODO: set nonce/hash on unsafe tags -- needs to be set in elements as well 63 csp := strings.Join([]string{ 64 "default-src 'self'", 65 "script-src 'self' 'unsafe-eval' 'nonce-" + nonce + "'", 66 "style-src 'self' 'unsafe-inline'", // unsafe-inline needed for Tailwind 67 "img-src 'self' https: data:", // Allow external images (avatars) and data URIs 68 "font-src 'self'", 69 "connect-src 'self' https:", // Allow connections to external APIs (OAuth, PDS) 70 "frame-ancestors 'none'", 71 "base-uri 'self'", 72 "form-action 'self' https:", // Allow form submissions to external OAuth servers 73 }, "; ") 74 w.Header().Set("Content-Security-Policy", csp) 75 76 next.ServeHTTP(w, r) 77 }) 78} 79 80// RateLimiter implements a simple per-IP rate limiter using token bucket algorithm 81type RateLimiter struct { 82 mu sync.Mutex 83 visitors map[string]*visitor 84 rate int // requests per window 85 window time.Duration // time window 86 cleanup time.Duration // cleanup interval for old entries 87} 88 89type visitor struct { 90 tokens int 91 lastReset time.Time 92} 93 94// NewRateLimiter creates a new rate limiter 95// rate: number of requests allowed per window 96// window: time window for rate limiting 97func NewRateLimiter(rate int, window time.Duration) *RateLimiter { 98 rl := &RateLimiter{ 99 visitors: make(map[string]*visitor), 100 rate: rate, 101 window: window, 102 cleanup: window * 2, 103 } 104 105 // Start cleanup goroutine 106 go rl.cleanupLoop() 107 108 return rl 109} 110 111func (rl *RateLimiter) cleanupLoop() { 112 ticker := time.NewTicker(rl.cleanup) 113 defer ticker.Stop() 114 115 for range ticker.C { 116 rl.mu.Lock() 117 now := time.Now() 118 for ip, v := range rl.visitors { 119 if now.Sub(v.lastReset) > rl.cleanup { 120 delete(rl.visitors, ip) 121 } 122 } 123 rl.mu.Unlock() 124 } 125} 126 127// Allow checks if a request from the given IP is allowed 128func (rl *RateLimiter) Allow(ip string) bool { 129 rl.mu.Lock() 130 defer rl.mu.Unlock() 131 132 now := time.Now() 133 v, exists := rl.visitors[ip] 134 135 if !exists { 136 rl.visitors[ip] = &visitor{ 137 tokens: rl.rate - 1, // Use one token 138 lastReset: now, 139 } 140 return true 141 } 142 143 // Reset tokens if window has passed 144 if now.Sub(v.lastReset) >= rl.window { 145 v.tokens = rl.rate - 1 146 v.lastReset = now 147 return true 148 } 149 150 // Check if tokens available 151 if v.tokens > 0 { 152 v.tokens-- 153 return true 154 } 155 156 return false 157} 158 159// RateLimitConfig holds configuration for rate limiting different endpoint types 160type RateLimitConfig struct { 161 // AuthLimiter for login/auth endpoints (stricter) 162 AuthLimiter *RateLimiter 163 // APILimiter for general API endpoints 164 APILimiter *RateLimiter 165 // GlobalLimiter for all other requests 166 GlobalLimiter *RateLimiter 167} 168 169// NewDefaultRateLimitConfig creates rate limiters with sensible defaults 170func NewDefaultRateLimitConfig() *RateLimitConfig { 171 return &RateLimitConfig{ 172 AuthLimiter: NewRateLimiter(5, time.Minute), // 5 auth attempts per minute 173 APILimiter: NewRateLimiter(60, time.Minute), // 60 API calls per minute 174 GlobalLimiter: NewRateLimiter(120, time.Minute), // 120 requests per minute 175 } 176} 177 178// RateLimitMiddleware creates a rate limiting middleware 179func RateLimitMiddleware(config *RateLimitConfig) func(http.Handler) http.Handler { 180 return func(next http.Handler) http.Handler { 181 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 182 ip := GetClientIP(r) 183 path := r.URL.Path 184 185 var limiter *RateLimiter 186 187 // Select appropriate limiter based on path 188 switch { 189 case strings.HasPrefix(path, "/auth/") || path == "/login" || path == "/oauth/callback": 190 limiter = config.AuthLimiter 191 case strings.HasPrefix(path, "/api/"): 192 limiter = config.APILimiter 193 default: 194 limiter = config.GlobalLimiter 195 } 196 197 if !limiter.Allow(ip) { 198 w.Header().Set("Retry-After", "60") 199 http.Error(w, "Too many requests", http.StatusTooManyRequests) 200 return 201 } 202 203 next.ServeHTTP(w, r) 204 }) 205 } 206} 207 208 209// RequireHTMXMiddleware ensures that certain API routes are only accessible via HTMX requests. 210// This prevents direct browser access to internal API endpoints that return fragments or JSON. 211// Routes that need to be publicly accessible (like /api/resolve-handle) should not use this middleware. 212func RequireHTMXMiddleware(next http.Handler) http.Handler { 213 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 214 // Check for HTMX request header 215 if r.Header.Get("HX-Request") != "true" { 216 http.NotFound(w, r) 217 return 218 } 219 next.ServeHTTP(w, r) 220 }) 221} 222 223// MaxBodySize limits the size of request bodies 224const ( 225 MaxJSONBodySize = 1 << 20 // 1 MB for JSON requests 226 MaxFormBodySize = 1 << 20 // 1 MB for form submissions 227) 228 229// LimitBodyMiddleware limits request body size to prevent DoS 230func LimitBodyMiddleware(next http.Handler) http.Handler { 231 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 232 if r.Body != nil { 233 contentType := r.Header.Get("Content-Type") 234 var maxSize int64 235 236 switch { 237 case strings.HasPrefix(contentType, "application/json"): 238 maxSize = MaxJSONBodySize 239 case strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), 240 strings.HasPrefix(contentType, "multipart/form-data"): 241 maxSize = MaxFormBodySize 242 default: 243 maxSize = MaxJSONBodySize // Default limit 244 } 245 246 r.Body = http.MaxBytesReader(w, r.Body, maxSize) 247 } 248 249 next.ServeHTTP(w, r) 250 }) 251}