Coffee journaling on ATProto (alpha)
alpha.arabica.social
coffee
1package middleware
2
import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
14
15func TestSecurityHeadersMiddleware(t *testing.T) {
16 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17 // Verify nonce is available in context
18 nonce := CSPNonceFromContext(r.Context())
19 assert.NotEmpty(t, nonce, "nonce should be set in context")
20 w.WriteHeader(http.StatusOK)
21 })
22
23 wrapped := SecurityHeadersMiddleware(handler)
24 req := httptest.NewRequest(http.MethodGet, "/", nil)
25 rec := httptest.NewRecorder()
26
27 wrapped.ServeHTTP(rec, req)
28
29 assert.Equal(t, http.StatusOK, rec.Code)
30 assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options"))
31 assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options"))
32 assert.Equal(t, "1; mode=block", rec.Header().Get("X-XSS-Protection"))
33 assert.Equal(t, "strict-origin-when-cross-origin", rec.Header().Get("Referrer-Policy"))
34 assert.Equal(t, "geolocation=(), microphone=(), camera=()", rec.Header().Get("Permissions-Policy"))
35
36 csp := rec.Header().Get("Content-Security-Policy")
37 assert.Contains(t, csp, "default-src 'self'")
38 assert.Contains(t, csp, "script-src 'self' 'unsafe-eval' 'nonce-")
39 assert.Contains(t, csp, "frame-ancestors 'none'")
40}
41
42func TestCSPNonceFromContext(t *testing.T) {
43 t.Run("returns nonce when set", func(t *testing.T) {
44 ctx := context.WithValue(context.Background(), cspNonceKey, "test-nonce-123")
45 assert.Equal(t, "test-nonce-123", CSPNonceFromContext(ctx))
46 })
47
48 t.Run("returns empty string when not set", func(t *testing.T) {
49 assert.Equal(t, "", CSPNonceFromContext(context.Background()))
50 })
51
52 t.Run("returns empty string for wrong type", func(t *testing.T) {
53 ctx := context.WithValue(context.Background(), cspNonceKey, 12345)
54 assert.Equal(t, "", CSPNonceFromContext(ctx))
55 })
56}
57
58func TestCSPNonceUniqueness(t *testing.T) {
59 nonces := make(map[string]bool)
60 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
61 nonce := CSPNonceFromContext(r.Context())
62 nonces[nonce] = true
63 w.WriteHeader(http.StatusOK)
64 })
65
66 wrapped := SecurityHeadersMiddleware(handler)
67
68 for i := 0; i < 10; i++ {
69 req := httptest.NewRequest(http.MethodGet, "/", nil)
70 rec := httptest.NewRecorder()
71 wrapped.ServeHTTP(rec, req)
72 }
73
74 assert.Len(t, nonces, 10, "each request should get a unique nonce")
75}
76
77func TestRateLimiter_Allow(t *testing.T) {
78 t.Run("allows requests within limit", func(t *testing.T) {
79 rl := &RateLimiter{
80 visitors: make(map[string]*visitor),
81 rate: 3,
82 window: time.Minute,
83 cleanup: 2 * time.Minute,
84 }
85
86 assert.True(t, rl.Allow("192.168.1.1"))
87 assert.True(t, rl.Allow("192.168.1.1"))
88 assert.True(t, rl.Allow("192.168.1.1"))
89 })
90
91 t.Run("blocks after exceeding limit", func(t *testing.T) {
92 rl := &RateLimiter{
93 visitors: make(map[string]*visitor),
94 rate: 2,
95 window: time.Minute,
96 cleanup: 2 * time.Minute,
97 }
98
99 assert.True(t, rl.Allow("10.0.0.1"))
100 assert.True(t, rl.Allow("10.0.0.1"))
101 assert.False(t, rl.Allow("10.0.0.1"))
102 })
103
104 t.Run("different IPs are independent", func(t *testing.T) {
105 rl := &RateLimiter{
106 visitors: make(map[string]*visitor),
107 rate: 1,
108 window: time.Minute,
109 cleanup: 2 * time.Minute,
110 }
111
112 assert.True(t, rl.Allow("10.0.0.1"))
113 assert.False(t, rl.Allow("10.0.0.1"))
114 assert.True(t, rl.Allow("10.0.0.2"))
115 })
116
117 t.Run("resets after window expires", func(t *testing.T) {
118 rl := &RateLimiter{
119 visitors: make(map[string]*visitor),
120 rate: 1,
121 window: 50 * time.Millisecond,
122 cleanup: 100 * time.Millisecond,
123 }
124
125 assert.True(t, rl.Allow("10.0.0.1"))
126 assert.False(t, rl.Allow("10.0.0.1"))
127
128 time.Sleep(60 * time.Millisecond)
129 assert.True(t, rl.Allow("10.0.0.1"))
130 })
131}
132
133func TestRateLimitMiddleware(t *testing.T) {
134 config := &RateLimitConfig{
135 AuthLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 2, window: time.Minute, cleanup: 2 * time.Minute},
136 APILimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 3, window: time.Minute, cleanup: 2 * time.Minute},
137 GlobalLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 5, window: time.Minute, cleanup: 2 * time.Minute},
138 }
139
140 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141 w.WriteHeader(http.StatusOK)
142 })
143
144 middleware := RateLimitMiddleware(config)
145 wrapped := middleware(handler)
146
147 t.Run("auth endpoints use auth limiter", func(t *testing.T) {
148 for i := 0; i < 2; i++ {
149 req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
150 req.RemoteAddr = "1.1.1.1:1234"
151 rec := httptest.NewRecorder()
152 wrapped.ServeHTTP(rec, req)
153 assert.Equal(t, http.StatusOK, rec.Code)
154 }
155
156 req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
157 req.RemoteAddr = "1.1.1.1:1234"
158 rec := httptest.NewRecorder()
159 wrapped.ServeHTTP(rec, req)
160 assert.Equal(t, http.StatusTooManyRequests, rec.Code)
161 assert.Equal(t, "60", rec.Header().Get("Retry-After"))
162 })
163
164 t.Run("api endpoints use api limiter", func(t *testing.T) {
165 for i := 0; i < 3; i++ {
166 req := httptest.NewRequest(http.MethodGet, "/api/brews", nil)
167 req.RemoteAddr = "2.2.2.2:1234"
168 rec := httptest.NewRecorder()
169 wrapped.ServeHTTP(rec, req)
170 assert.Equal(t, http.StatusOK, rec.Code)
171 }
172
173 req := httptest.NewRequest(http.MethodGet, "/api/brews", nil)
174 req.RemoteAddr = "2.2.2.2:1234"
175 rec := httptest.NewRecorder()
176 wrapped.ServeHTTP(rec, req)
177 assert.Equal(t, http.StatusTooManyRequests, rec.Code)
178 })
179
180 t.Run("other endpoints use global limiter", func(t *testing.T) {
181 for i := 0; i < 5; i++ {
182 req := httptest.NewRequest(http.MethodGet, "/brews", nil)
183 req.RemoteAddr = "3.3.3.3:1234"
184 rec := httptest.NewRecorder()
185 wrapped.ServeHTTP(rec, req)
186 assert.Equal(t, http.StatusOK, rec.Code)
187 }
188
189 req := httptest.NewRequest(http.MethodGet, "/brews", nil)
190 req.RemoteAddr = "3.3.3.3:1234"
191 rec := httptest.NewRecorder()
192 wrapped.ServeHTTP(rec, req)
193 assert.Equal(t, http.StatusTooManyRequests, rec.Code)
194 })
195
196 t.Run("login path uses auth limiter", func(t *testing.T) {
197 for i := 0; i < 2; i++ {
198 req := httptest.NewRequest(http.MethodPost, "/login", nil)
199 req.RemoteAddr = "4.4.4.4:1234"
200 rec := httptest.NewRecorder()
201 wrapped.ServeHTTP(rec, req)
202 assert.Equal(t, http.StatusOK, rec.Code)
203 }
204
205 req := httptest.NewRequest(http.MethodPost, "/login", nil)
206 req.RemoteAddr = "4.4.4.4:1234"
207 rec := httptest.NewRecorder()
208 wrapped.ServeHTTP(rec, req)
209 assert.Equal(t, http.StatusTooManyRequests, rec.Code)
210 })
211}
212
213func TestRequireHTMXMiddleware(t *testing.T) {
214 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215 w.WriteHeader(http.StatusOK)
216 w.Write([]byte("OK"))
217 })
218
219 wrapped := RequireHTMXMiddleware(handler)
220
221 t.Run("allows HTMX requests", func(t *testing.T) {
222 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
223 req.Header.Set("HX-Request", "true")
224 rec := httptest.NewRecorder()
225
226 wrapped.ServeHTTP(rec, req)
227 assert.Equal(t, http.StatusOK, rec.Code)
228 assert.Equal(t, "OK", rec.Body.String())
229 })
230
231 t.Run("blocks non-HTMX requests", func(t *testing.T) {
232 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
233 rec := httptest.NewRecorder()
234
235 wrapped.ServeHTTP(rec, req)
236 assert.Equal(t, http.StatusNotFound, rec.Code)
237 })
238
239 t.Run("blocks wrong HX-Request value", func(t *testing.T) {
240 req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
241 req.Header.Set("HX-Request", "false")
242 rec := httptest.NewRecorder()
243
244 wrapped.ServeHTTP(rec, req)
245 assert.Equal(t, http.StatusNotFound, rec.Code)
246 })
247}
248
249func TestLimitBodyMiddleware(t *testing.T) {
250 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251 // Try to read the body
252 buf := make([]byte, 2<<20) // 2MB buffer
253 _, err := r.Body.Read(buf)
254 if err != nil && err.Error() != "EOF" {
255 http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
256 return
257 }
258 w.WriteHeader(http.StatusOK)
259 })
260
261 wrapped := LimitBodyMiddleware(handler)
262
263 t.Run("allows small JSON body", func(t *testing.T) {
264 body := strings.NewReader(`{"name": "test"}`)
265 req := httptest.NewRequest(http.MethodPost, "/api/test", body)
266 req.Header.Set("Content-Type", "application/json")
267 rec := httptest.NewRecorder()
268
269 wrapped.ServeHTTP(rec, req)
270 assert.Equal(t, http.StatusOK, rec.Code)
271 })
272
273 t.Run("allows small form body", func(t *testing.T) {
274 body := strings.NewReader("name=test&value=123")
275 req := httptest.NewRequest(http.MethodPost, "/api/test", body)
276 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
277 rec := httptest.NewRecorder()
278
279 wrapped.ServeHTTP(rec, req)
280 assert.Equal(t, http.StatusOK, rec.Code)
281 })
282
283 t.Run("handles nil body", func(t *testing.T) {
284 req := httptest.NewRequest(http.MethodGet, "/test", nil)
285 rec := httptest.NewRecorder()
286
287 wrapped.ServeHTTP(rec, req)
288 assert.Equal(t, http.StatusOK, rec.Code)
289 })
290}
291
292func TestGetClientIP(t *testing.T) {
293 tests := []struct {
294 name string
295 xff string
296 xri string
297 remoteAddr string
298 expected string
299 }{
300 {
301 name: "X-Forwarded-For single IP",
302 xff: "203.0.113.50",
303 remoteAddr: "127.0.0.1:1234",
304 expected: "203.0.113.50",
305 },
306 {
307 name: "X-Forwarded-For multiple IPs",
308 xff: "203.0.113.50, 70.41.3.18, 150.172.238.178",
309 remoteAddr: "127.0.0.1:1234",
310 expected: "203.0.113.50",
311 },
312 {
313 name: "X-Forwarded-For with whitespace",
314 xff: " 203.0.113.50 ",
315 remoteAddr: "127.0.0.1:1234",
316 expected: "203.0.113.50",
317 },
318 {
319 name: "X-Real-IP",
320 xri: "198.51.100.178",
321 remoteAddr: "127.0.0.1:1234",
322 expected: "198.51.100.178",
323 },
324 {
325 name: "X-Real-IP with whitespace",
326 xri: " 198.51.100.178 ",
327 remoteAddr: "127.0.0.1:1234",
328 expected: "198.51.100.178",
329 },
330 {
331 name: "X-Forwarded-For takes precedence over X-Real-IP",
332 xff: "203.0.113.50",
333 xri: "198.51.100.178",
334 remoteAddr: "127.0.0.1:1234",
335 expected: "203.0.113.50",
336 },
337 {
338 name: "fallback to RemoteAddr with port",
339 remoteAddr: "192.168.1.1:8080",
340 expected: "192.168.1.1",
341 },
342 {
343 name: "fallback to RemoteAddr without port",
344 remoteAddr: "192.168.1.1",
345 expected: "192.168.1.1",
346 },
347 }
348
349 for _, tt := range tests {
350 t.Run(tt.name, func(t *testing.T) {
351 req := httptest.NewRequest(http.MethodGet, "/", nil)
352 req.RemoteAddr = tt.remoteAddr
353 if tt.xff != "" {
354 req.Header.Set("X-Forwarded-For", tt.xff)
355 }
356 if tt.xri != "" {
357 req.Header.Set("X-Real-IP", tt.xri)
358 }
359
360 got := GetClientIP(req)
361 assert.Equal(t, tt.expected, got)
362 })
363 }
364}
365
366func TestGenerateNonce(t *testing.T) {
367 t.Run("generates base64 string", func(t *testing.T) {
368 nonce, err := generateNonce()
369 require.NoError(t, err)
370 assert.NotEmpty(t, nonce)
371 // Base64 of 16 bytes = 24 chars
372 assert.Len(t, nonce, 24)
373 })
374
375 t.Run("generates unique values", func(t *testing.T) {
376 n1, err := generateNonce()
377 require.NoError(t, err)
378 n2, err := generateNonce()
379 require.NoError(t, err)
380 assert.NotEqual(t, n1, n2)
381 })
382}