Coffee journaling on ATProto (alpha) alpha.arabica.social
coffee
at main 382 lines 11 kB view raw
1package middleware 2 3import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "strings" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13) 14 15func TestSecurityHeadersMiddleware(t *testing.T) { 16 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 // Verify nonce is available in context 18 nonce := CSPNonceFromContext(r.Context()) 19 assert.NotEmpty(t, nonce, "nonce should be set in context") 20 w.WriteHeader(http.StatusOK) 21 }) 22 23 wrapped := SecurityHeadersMiddleware(handler) 24 req := httptest.NewRequest(http.MethodGet, "/", nil) 25 rec := httptest.NewRecorder() 26 27 wrapped.ServeHTTP(rec, req) 28 29 assert.Equal(t, http.StatusOK, rec.Code) 30 assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options")) 31 assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options")) 32 assert.Equal(t, "1; mode=block", rec.Header().Get("X-XSS-Protection")) 33 assert.Equal(t, "strict-origin-when-cross-origin", rec.Header().Get("Referrer-Policy")) 34 assert.Equal(t, "geolocation=(), microphone=(), camera=()", rec.Header().Get("Permissions-Policy")) 35 36 csp := rec.Header().Get("Content-Security-Policy") 37 assert.Contains(t, csp, "default-src 'self'") 38 assert.Contains(t, csp, "script-src 'self' 'unsafe-eval' 'nonce-") 39 assert.Contains(t, csp, "frame-ancestors 'none'") 40} 41 42func TestCSPNonceFromContext(t *testing.T) { 43 t.Run("returns nonce when set", func(t *testing.T) { 44 ctx := context.WithValue(context.Background(), cspNonceKey, "test-nonce-123") 45 assert.Equal(t, "test-nonce-123", CSPNonceFromContext(ctx)) 46 }) 47 48 t.Run("returns empty string when not set", func(t *testing.T) { 49 assert.Equal(t, "", CSPNonceFromContext(context.Background())) 50 }) 51 52 t.Run("returns empty string for wrong type", func(t *testing.T) { 53 ctx := context.WithValue(context.Background(), cspNonceKey, 12345) 54 assert.Equal(t, "", CSPNonceFromContext(ctx)) 55 }) 56} 57 58func TestCSPNonceUniqueness(t *testing.T) { 59 nonces := make(map[string]bool) 60 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 nonce := CSPNonceFromContext(r.Context()) 62 nonces[nonce] = true 63 w.WriteHeader(http.StatusOK) 64 }) 65 66 wrapped := SecurityHeadersMiddleware(handler) 67 68 for i := 0; i < 10; i++ { 69 req := httptest.NewRequest(http.MethodGet, "/", nil) 70 rec := httptest.NewRecorder() 71 wrapped.ServeHTTP(rec, req) 72 } 73 74 assert.Len(t, nonces, 10, "each request should get a unique nonce") 75} 76 77func TestRateLimiter_Allow(t *testing.T) { 78 t.Run("allows requests within limit", func(t *testing.T) { 79 rl := &RateLimiter{ 80 visitors: make(map[string]*visitor), 81 rate: 3, 82 window: time.Minute, 83 cleanup: 2 * time.Minute, 84 } 85 86 assert.True(t, rl.Allow("192.168.1.1")) 87 assert.True(t, rl.Allow("192.168.1.1")) 88 assert.True(t, rl.Allow("192.168.1.1")) 89 }) 90 91 t.Run("blocks after exceeding limit", func(t *testing.T) { 92 rl := &RateLimiter{ 93 visitors: make(map[string]*visitor), 94 rate: 2, 95 window: time.Minute, 96 cleanup: 2 * time.Minute, 97 } 98 99 assert.True(t, rl.Allow("10.0.0.1")) 100 assert.True(t, rl.Allow("10.0.0.1")) 101 assert.False(t, rl.Allow("10.0.0.1")) 102 }) 103 104 t.Run("different IPs are independent", func(t *testing.T) { 105 rl := &RateLimiter{ 106 visitors: make(map[string]*visitor), 107 rate: 1, 108 window: time.Minute, 109 cleanup: 2 * time.Minute, 110 } 111 112 assert.True(t, rl.Allow("10.0.0.1")) 113 assert.False(t, rl.Allow("10.0.0.1")) 114 assert.True(t, rl.Allow("10.0.0.2")) 115 }) 116 117 t.Run("resets after window expires", func(t *testing.T) { 118 rl := &RateLimiter{ 119 visitors: make(map[string]*visitor), 120 rate: 1, 121 window: 50 * time.Millisecond, 122 cleanup: 100 * time.Millisecond, 123 } 124 125 assert.True(t, rl.Allow("10.0.0.1")) 126 assert.False(t, rl.Allow("10.0.0.1")) 127 128 time.Sleep(60 * time.Millisecond) 129 assert.True(t, rl.Allow("10.0.0.1")) 130 }) 131} 132 133func TestRateLimitMiddleware(t *testing.T) { 134 config := &RateLimitConfig{ 135 AuthLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 2, window: time.Minute, cleanup: 2 * time.Minute}, 136 APILimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 3, window: time.Minute, cleanup: 2 * time.Minute}, 137 GlobalLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 5, window: time.Minute, cleanup: 2 * time.Minute}, 138 } 139 140 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 141 w.WriteHeader(http.StatusOK) 142 }) 143 144 middleware := RateLimitMiddleware(config) 145 wrapped := middleware(handler) 146 147 t.Run("auth endpoints use auth limiter", func(t *testing.T) { 148 for i := 0; i < 2; i++ { 149 req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) 150 req.RemoteAddr = "1.1.1.1:1234" 151 rec := httptest.NewRecorder() 152 wrapped.ServeHTTP(rec, req) 153 assert.Equal(t, http.StatusOK, rec.Code) 154 } 155 156 req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) 157 req.RemoteAddr = "1.1.1.1:1234" 158 rec := httptest.NewRecorder() 159 wrapped.ServeHTTP(rec, req) 160 assert.Equal(t, http.StatusTooManyRequests, rec.Code) 161 assert.Equal(t, "60", rec.Header().Get("Retry-After")) 162 }) 163 164 t.Run("api endpoints use api limiter", func(t *testing.T) { 165 for i := 0; i < 3; i++ { 166 req := httptest.NewRequest(http.MethodGet, "/api/brews", nil) 167 req.RemoteAddr = "2.2.2.2:1234" 168 rec := httptest.NewRecorder() 169 wrapped.ServeHTTP(rec, req) 170 assert.Equal(t, http.StatusOK, rec.Code) 171 } 172 173 req := httptest.NewRequest(http.MethodGet, "/api/brews", nil) 174 req.RemoteAddr = "2.2.2.2:1234" 175 rec := httptest.NewRecorder() 176 wrapped.ServeHTTP(rec, req) 177 assert.Equal(t, http.StatusTooManyRequests, rec.Code) 178 }) 179 180 t.Run("other endpoints use global limiter", func(t *testing.T) { 181 for i := 0; i < 5; i++ { 182 req := httptest.NewRequest(http.MethodGet, "/brews", nil) 183 req.RemoteAddr = "3.3.3.3:1234" 184 rec := httptest.NewRecorder() 185 wrapped.ServeHTTP(rec, req) 186 assert.Equal(t, http.StatusOK, rec.Code) 187 } 188 189 req := httptest.NewRequest(http.MethodGet, "/brews", nil) 190 req.RemoteAddr = "3.3.3.3:1234" 191 rec := httptest.NewRecorder() 192 wrapped.ServeHTTP(rec, req) 193 assert.Equal(t, http.StatusTooManyRequests, rec.Code) 194 }) 195 196 t.Run("login path uses auth limiter", func(t *testing.T) { 197 for i := 0; i < 2; i++ { 198 req := httptest.NewRequest(http.MethodPost, "/login", nil) 199 req.RemoteAddr = "4.4.4.4:1234" 200 rec := httptest.NewRecorder() 201 wrapped.ServeHTTP(rec, req) 202 assert.Equal(t, http.StatusOK, rec.Code) 203 } 204 205 req := httptest.NewRequest(http.MethodPost, "/login", nil) 206 req.RemoteAddr = "4.4.4.4:1234" 207 rec := httptest.NewRecorder() 208 wrapped.ServeHTTP(rec, req) 209 assert.Equal(t, http.StatusTooManyRequests, rec.Code) 210 }) 211} 212 213func TestRequireHTMXMiddleware(t *testing.T) { 214 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 215 w.WriteHeader(http.StatusOK) 216 w.Write([]byte("OK")) 217 }) 218 219 wrapped := RequireHTMXMiddleware(handler) 220 221 t.Run("allows HTMX requests", func(t *testing.T) { 222 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 223 req.Header.Set("HX-Request", "true") 224 rec := httptest.NewRecorder() 225 226 wrapped.ServeHTTP(rec, req) 227 assert.Equal(t, http.StatusOK, rec.Code) 228 assert.Equal(t, "OK", rec.Body.String()) 229 }) 230 231 t.Run("blocks non-HTMX requests", func(t *testing.T) { 232 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 233 rec := httptest.NewRecorder() 234 235 wrapped.ServeHTTP(rec, req) 236 assert.Equal(t, http.StatusNotFound, rec.Code) 237 }) 238 239 t.Run("blocks wrong HX-Request value", func(t *testing.T) { 240 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 241 req.Header.Set("HX-Request", "false") 242 rec := httptest.NewRecorder() 243 244 wrapped.ServeHTTP(rec, req) 245 assert.Equal(t, http.StatusNotFound, rec.Code) 246 }) 247} 248 249func TestLimitBodyMiddleware(t *testing.T) { 250 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 251 // Try to read the body 252 buf := make([]byte, 2<<20) // 2MB buffer 253 _, err := r.Body.Read(buf) 254 if err != nil && err.Error() != "EOF" { 255 http.Error(w, "body too large", http.StatusRequestEntityTooLarge) 256 return 257 } 258 w.WriteHeader(http.StatusOK) 259 }) 260 261 wrapped := LimitBodyMiddleware(handler) 262 263 t.Run("allows small JSON body", func(t *testing.T) { 264 body := strings.NewReader(`{"name": "test"}`) 265 req := httptest.NewRequest(http.MethodPost, "/api/test", body) 266 req.Header.Set("Content-Type", "application/json") 267 rec := httptest.NewRecorder() 268 269 wrapped.ServeHTTP(rec, req) 270 assert.Equal(t, http.StatusOK, rec.Code) 271 }) 272 273 t.Run("allows small form body", func(t *testing.T) { 274 body := strings.NewReader("name=test&value=123") 275 req := httptest.NewRequest(http.MethodPost, "/api/test", body) 276 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 277 rec := httptest.NewRecorder() 278 279 wrapped.ServeHTTP(rec, req) 280 assert.Equal(t, http.StatusOK, rec.Code) 281 }) 282 283 t.Run("handles nil body", func(t *testing.T) { 284 req := httptest.NewRequest(http.MethodGet, "/test", nil) 285 rec := httptest.NewRecorder() 286 287 wrapped.ServeHTTP(rec, req) 288 assert.Equal(t, http.StatusOK, rec.Code) 289 }) 290} 291 292func TestGetClientIP(t *testing.T) { 293 tests := []struct { 294 name string 295 xff string 296 xri string 297 remoteAddr string 298 expected string 299 }{ 300 { 301 name: "X-Forwarded-For single IP", 302 xff: "203.0.113.50", 303 remoteAddr: "127.0.0.1:1234", 304 expected: "203.0.113.50", 305 }, 306 { 307 name: "X-Forwarded-For multiple IPs", 308 xff: "203.0.113.50, 70.41.3.18, 150.172.238.178", 309 remoteAddr: "127.0.0.1:1234", 310 expected: "203.0.113.50", 311 }, 312 { 313 name: "X-Forwarded-For with whitespace", 314 xff: " 203.0.113.50 ", 315 remoteAddr: "127.0.0.1:1234", 316 expected: "203.0.113.50", 317 }, 318 { 319 name: "X-Real-IP", 320 xri: "198.51.100.178", 321 remoteAddr: "127.0.0.1:1234", 322 expected: "198.51.100.178", 323 }, 324 { 325 name: "X-Real-IP with whitespace", 326 xri: " 198.51.100.178 ", 327 remoteAddr: "127.0.0.1:1234", 328 expected: "198.51.100.178", 329 }, 330 { 331 name: "X-Forwarded-For takes precedence over X-Real-IP", 332 xff: "203.0.113.50", 333 xri: "198.51.100.178", 334 remoteAddr: "127.0.0.1:1234", 335 expected: "203.0.113.50", 336 }, 337 { 338 name: "fallback to RemoteAddr with port", 339 remoteAddr: "192.168.1.1:8080", 340 expected: "192.168.1.1", 341 }, 342 { 343 name: "fallback to RemoteAddr without port", 344 remoteAddr: "192.168.1.1", 345 expected: "192.168.1.1", 346 }, 347 } 348 349 for _, tt := range tests { 350 t.Run(tt.name, func(t *testing.T) { 351 req := httptest.NewRequest(http.MethodGet, "/", nil) 352 req.RemoteAddr = tt.remoteAddr 353 if tt.xff != "" { 354 req.Header.Set("X-Forwarded-For", tt.xff) 355 } 356 if tt.xri != "" { 357 req.Header.Set("X-Real-IP", tt.xri) 358 } 359 360 got := GetClientIP(req) 361 assert.Equal(t, tt.expected, got) 362 }) 363 } 364} 365 366func TestGenerateNonce(t *testing.T) { 367 t.Run("generates base64 string", func(t *testing.T) { 368 nonce, err := generateNonce() 369 require.NoError(t, err) 370 assert.NotEmpty(t, nonce) 371 // Base64 of 16 bytes = 24 chars 372 assert.Len(t, nonce, 24) 373 }) 374 375 t.Run("generates unique values", func(t *testing.T) { 376 n1, err := generateNonce() 377 require.NoError(t, err) 378 n2, err := generateNonce() 379 require.NoError(t, err) 380 assert.NotEqual(t, n1, n2) 381 }) 382}