Coffee journaling on ATProto (alpha) alpha.arabica.social
coffee

test: add new tests

pdewey.com 39d01397 d33958e0

verified
+1127
+139
internal/database/boltstore/feed_store_test.go
··· 1 + package boltstore 2 + 3 + import ( 4 + "path/filepath" 5 + "testing" 6 + 7 + "github.com/stretchr/testify/assert" 8 + "github.com/stretchr/testify/require" 9 + ) 10 + 11 + func setupTestFeedStore(t *testing.T) *FeedStore { 12 + tmpDir := t.TempDir() 13 + dbPath := filepath.Join(tmpDir, "test.db") 14 + 15 + store, err := Open(Options{Path: dbPath}) 16 + require.NoError(t, err) 17 + 18 + t.Cleanup(func() { 19 + store.Close() 20 + }) 21 + 22 + return store.FeedStore() 23 + } 24 + 25 + func TestFeedStore_Register(t *testing.T) { 26 + store := setupTestFeedStore(t) 27 + 28 + t.Run("register new DID", func(t *testing.T) { 29 + err := store.Register("did:plc:user1") 30 + require.NoError(t, err) 31 + assert.True(t, store.IsRegistered("did:plc:user1")) 32 + }) 33 + 34 + t.Run("register is idempotent", func(t *testing.T) { 35 + err := store.Register("did:plc:user2") 36 + require.NoError(t, err) 37 + 38 + err = store.Register("did:plc:user2") 39 + require.NoError(t, err) 40 + 41 + assert.Equal(t, 1, countDID(store, "did:plc:user2")) 42 + }) 43 + } 44 + 45 + // countDID counts how many times a DID appears in the list (should be 0 or 1). 46 + func countDID(store *FeedStore, did string) int { 47 + count := 0 48 + for _, d := range store.List() { 49 + if d == did { 50 + count++ 51 + } 52 + } 53 + return count 54 + } 55 + 56 + func TestFeedStore_Unregister(t *testing.T) { 57 + store := setupTestFeedStore(t) 58 + 59 + err := store.Register("did:plc:unreg") 60 + require.NoError(t, err) 61 + assert.True(t, store.IsRegistered("did:plc:unreg")) 62 + 63 + err = store.Unregister("did:plc:unreg") 64 + require.NoError(t, err) 65 + assert.False(t, store.IsRegistered("did:plc:unreg")) 66 + } 67 + 68 + func TestFeedStore_IsRegistered(t *testing.T) { 69 + store := setupTestFeedStore(t) 70 + 71 + assert.False(t, store.IsRegistered("did:plc:nobody")) 72 + 73 + store.Register("did:plc:somebody") 74 + assert.True(t, store.IsRegistered("did:plc:somebody")) 75 + } 76 + 77 + func TestFeedStore_List(t *testing.T) { 78 + store := setupTestFeedStore(t) 79 + 80 + t.Run("empty store", func(t *testing.T) { 81 + dids := store.List() 82 + assert.Empty(t, dids) 83 + }) 84 + 85 + t.Run("multiple registrations", func(t *testing.T) { 86 + store.Register("did:plc:a") 87 + store.Register("did:plc:b") 88 + store.Register("did:plc:c") 89 + 90 + dids := store.List() 91 + assert.Len(t, dids, 3) 92 + assert.Contains(t, dids, "did:plc:a") 93 + assert.Contains(t, dids, "did:plc:b") 94 + assert.Contains(t, dids, "did:plc:c") 95 + }) 96 + } 97 + 98 + func TestFeedStore_ListWithMetadata(t *testing.T) { 99 + store := setupTestFeedStore(t) 100 + 101 + store.Register("did:plc:meta1") 102 + store.Register("did:plc:meta2") 103 + 104 + users := store.ListWithMetadata() 105 + assert.Len(t, users, 2) 106 + 107 + for _, u := range users { 108 + assert.NotEmpty(t, u.DID) 109 + assert.False(t, u.RegisteredAt.IsZero()) 110 + } 111 + } 112 + 113 + func TestFeedStore_Count(t *testing.T) { 114 + store := setupTestFeedStore(t) 115 + 116 + assert.Equal(t, 0, store.Count()) 117 + 118 + store.Register("did:plc:c1") 119 + assert.Equal(t, 1, store.Count()) 120 + 121 + store.Register("did:plc:c2") 122 + assert.Equal(t, 2, store.Count()) 123 + 124 + store.Unregister("did:plc:c1") 125 + assert.Equal(t, 1, store.Count()) 126 + } 127 + 128 + func TestFeedStore_Clear(t *testing.T) { 129 + store := setupTestFeedStore(t) 130 + 131 + store.Register("did:plc:clear1") 132 + store.Register("did:plc:clear2") 133 + assert.Equal(t, 2, store.Count()) 134 + 135 + err := store.Clear() 136 + require.NoError(t, err) 137 + assert.Equal(t, 0, store.Count()) 138 + assert.False(t, store.IsRegistered("did:plc:clear1")) 139 + }
+136
internal/database/boltstore/join_store_test.go
··· 1 + package boltstore 2 + 3 + import ( 4 + "path/filepath" 5 + "testing" 6 + "time" 7 + 8 + "github.com/stretchr/testify/assert" 9 + "github.com/stretchr/testify/require" 10 + ) 11 + 12 + func setupTestJoinStore(t *testing.T) *JoinStore { 13 + tmpDir := t.TempDir() 14 + dbPath := filepath.Join(tmpDir, "test.db") 15 + 16 + store, err := Open(Options{Path: dbPath}) 17 + require.NoError(t, err) 18 + 19 + t.Cleanup(func() { 20 + store.Close() 21 + }) 22 + 23 + return store.JoinStore() 24 + } 25 + 26 + func TestJoinStore_SaveAndGet(t *testing.T) { 27 + store := setupTestJoinStore(t) 28 + 29 + req := &JoinRequest{ 30 + ID: "join-001", 31 + Email: "user@example.com", 32 + Message: "I love coffee!", 33 + CreatedAt: time.Now().Truncate(time.Millisecond), 34 + IP: "203.0.113.50", 35 + } 36 + 37 + err := store.SaveRequest(req) 38 + require.NoError(t, err) 39 + 40 + retrieved, err := store.GetRequest("join-001") 41 + require.NoError(t, err) 42 + require.NotNil(t, retrieved) 43 + 44 + assert.Equal(t, req.ID, retrieved.ID) 45 + assert.Equal(t, req.Email, retrieved.Email) 46 + assert.Equal(t, req.Message, retrieved.Message) 47 + assert.Equal(t, req.IP, retrieved.IP) 48 + assert.True(t, req.CreatedAt.Equal(retrieved.CreatedAt)) 49 + } 50 + 51 + func TestJoinStore_GetNotFound(t *testing.T) { 52 + store := setupTestJoinStore(t) 53 + 54 + retrieved, err := store.GetRequest("nonexistent") 55 + assert.Error(t, err) 56 + assert.Nil(t, retrieved) 57 + assert.Contains(t, err.Error(), "not found") 58 + } 59 + 60 + func TestJoinStore_Delete(t *testing.T) { 61 + store := setupTestJoinStore(t) 62 + 63 + req := &JoinRequest{ 64 + ID: "join-del", 65 + Email: "delete@example.com", 66 + CreatedAt: time.Now(), 67 + IP: "10.0.0.1", 68 + } 69 + 70 + err := store.SaveRequest(req) 71 + require.NoError(t, err) 72 + 73 + err = store.DeleteRequest("join-del") 74 + require.NoError(t, err) 75 + 76 + retrieved, err := store.GetRequest("join-del") 77 + assert.Error(t, err) 78 + assert.Nil(t, retrieved) 79 + } 80 + 81 + func TestJoinStore_DeleteNonexistent(t *testing.T) { 82 + store := setupTestJoinStore(t) 83 + 84 + // Deleting a non-existent request should not error 85 + err := store.DeleteRequest("nonexistent") 86 + assert.NoError(t, err) 87 + } 88 + 89 + func TestJoinStore_ListRequests(t *testing.T) { 90 + store := setupTestJoinStore(t) 91 + 92 + t.Run("empty store", func(t *testing.T) { 93 + requests, err := store.ListRequests() 94 + require.NoError(t, err) 95 + assert.Empty(t, requests) 96 + }) 97 + 98 + t.Run("multiple requests", func(t *testing.T) { 99 + for i, email := range []string{"a@test.com", "b@test.com", "c@test.com"} { 100 + req := &JoinRequest{ 101 + ID: "list-" + string(rune('0'+i)), 102 + Email: email, 103 + CreatedAt: time.Now(), 104 + IP: "10.0.0.1", 105 + } 106 + require.NoError(t, store.SaveRequest(req)) 107 + } 108 + 109 + requests, err := store.ListRequests() 110 + require.NoError(t, err) 111 + assert.Len(t, requests, 3) 112 + }) 113 + } 114 + 115 + func TestJoinStore_SaveOverwrites(t *testing.T) { 116 + store := setupTestJoinStore(t) 117 + 118 + req := &JoinRequest{ 119 + ID: "join-overwrite", 120 + Email: "original@example.com", 121 + CreatedAt: time.Now(), 122 + IP: "10.0.0.1", 123 + } 124 + 125 + err := store.SaveRequest(req) 126 + require.NoError(t, err) 127 + 128 + // Save again with updated email 129 + req.Email = "updated@example.com" 130 + err = store.SaveRequest(req) 131 + require.NoError(t, err) 132 + 133 + retrieved, err := store.GetRequest("join-overwrite") 134 + require.NoError(t, err) 135 + assert.Equal(t, "updated@example.com", retrieved.Email) 136 + }
+382
internal/middleware/security_test.go
··· 1 + package middleware 2 + 3 + import ( 4 + "context" 5 + "net/http" 6 + "net/http/httptest" 7 + "strings" 8 + "testing" 9 + "time" 10 + 11 + "github.com/stretchr/testify/assert" 12 + "github.com/stretchr/testify/require" 13 + ) 14 + 15 + func TestSecurityHeadersMiddleware(t *testing.T) { 16 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 + // Verify nonce is available in context 18 + nonce := CSPNonceFromContext(r.Context()) 19 + assert.NotEmpty(t, nonce, "nonce should be set in context") 20 + w.WriteHeader(http.StatusOK) 21 + }) 22 + 23 + wrapped := SecurityHeadersMiddleware(handler) 24 + req := httptest.NewRequest(http.MethodGet, "/", nil) 25 + rec := httptest.NewRecorder() 26 + 27 + wrapped.ServeHTTP(rec, req) 28 + 29 + assert.Equal(t, http.StatusOK, rec.Code) 30 + assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options")) 31 + assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options")) 32 + assert.Equal(t, "1; mode=block", rec.Header().Get("X-XSS-Protection")) 33 + assert.Equal(t, "strict-origin-when-cross-origin", rec.Header().Get("Referrer-Policy")) 34 + assert.Equal(t, "geolocation=(), microphone=(), camera=()", rec.Header().Get("Permissions-Policy")) 35 + 36 + csp := rec.Header().Get("Content-Security-Policy") 37 + assert.Contains(t, csp, "default-src 'self'") 38 + assert.Contains(t, csp, "script-src 'self' 'unsafe-eval' 'nonce-") 39 + assert.Contains(t, csp, "frame-ancestors 'none'") 40 + } 41 + 42 + func TestCSPNonceFromContext(t *testing.T) { 43 + t.Run("returns nonce when set", func(t *testing.T) { 44 + ctx := context.WithValue(context.Background(), cspNonceKey, "test-nonce-123") 45 + assert.Equal(t, "test-nonce-123", CSPNonceFromContext(ctx)) 46 + }) 47 + 48 + t.Run("returns empty string when not set", func(t *testing.T) { 49 + assert.Equal(t, "", CSPNonceFromContext(context.Background())) 50 + }) 51 + 52 + t.Run("returns empty string for wrong type", func(t *testing.T) { 53 + ctx := context.WithValue(context.Background(), cspNonceKey, 12345) 54 + assert.Equal(t, "", CSPNonceFromContext(ctx)) 55 + }) 56 + } 57 + 58 + func TestCSPNonceUniqueness(t *testing.T) { 59 + nonces := make(map[string]bool) 60 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 + nonce := CSPNonceFromContext(r.Context()) 62 + nonces[nonce] = true 63 + w.WriteHeader(http.StatusOK) 64 + }) 65 + 66 + wrapped := SecurityHeadersMiddleware(handler) 67 + 68 + for i := 0; i < 10; i++ { 69 + req := httptest.NewRequest(http.MethodGet, "/", nil) 70 + rec := httptest.NewRecorder() 71 + wrapped.ServeHTTP(rec, req) 72 + } 73 + 74 + assert.Len(t, nonces, 10, "each request should get a unique nonce") 75 + } 76 + 77 + func TestRateLimiter_Allow(t *testing.T) { 78 + t.Run("allows requests within limit", func(t *testing.T) { 79 + rl := &RateLimiter{ 80 + visitors: make(map[string]*visitor), 81 + rate: 3, 82 + window: time.Minute, 83 + cleanup: 2 * time.Minute, 84 + } 85 + 86 + assert.True(t, rl.Allow("192.168.1.1")) 87 + assert.True(t, rl.Allow("192.168.1.1")) 88 + assert.True(t, rl.Allow("192.168.1.1")) 89 + }) 90 + 91 + t.Run("blocks after exceeding limit", func(t *testing.T) { 92 + rl := &RateLimiter{ 93 + visitors: make(map[string]*visitor), 94 + rate: 2, 95 + window: time.Minute, 96 + cleanup: 2 * time.Minute, 97 + } 98 + 99 + assert.True(t, rl.Allow("10.0.0.1")) 100 + assert.True(t, rl.Allow("10.0.0.1")) 101 + assert.False(t, rl.Allow("10.0.0.1")) 102 + }) 103 + 104 + t.Run("different IPs are independent", func(t *testing.T) { 105 + rl := &RateLimiter{ 106 + visitors: make(map[string]*visitor), 107 + rate: 1, 108 + window: time.Minute, 109 + cleanup: 2 * time.Minute, 110 + } 111 + 112 + assert.True(t, rl.Allow("10.0.0.1")) 113 + assert.False(t, rl.Allow("10.0.0.1")) 114 + assert.True(t, rl.Allow("10.0.0.2")) 115 + }) 116 + 117 + t.Run("resets after window expires", func(t *testing.T) { 118 + rl := &RateLimiter{ 119 + visitors: make(map[string]*visitor), 120 + rate: 1, 121 + window: 50 * time.Millisecond, 122 + cleanup: 100 * time.Millisecond, 123 + } 124 + 125 + assert.True(t, rl.Allow("10.0.0.1")) 126 + assert.False(t, rl.Allow("10.0.0.1")) 127 + 128 + time.Sleep(60 * time.Millisecond) 129 + assert.True(t, rl.Allow("10.0.0.1")) 130 + }) 131 + } 132 + 133 + func TestRateLimitMiddleware(t *testing.T) { 134 + config := &RateLimitConfig{ 135 + AuthLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 2, window: time.Minute, cleanup: 2 * time.Minute}, 136 + APILimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 3, window: time.Minute, cleanup: 2 * time.Minute}, 137 + GlobalLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 5, window: time.Minute, cleanup: 2 * time.Minute}, 138 + } 139 + 140 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 141 + w.WriteHeader(http.StatusOK) 142 + }) 143 + 144 + middleware := RateLimitMiddleware(config) 145 + wrapped := middleware(handler) 146 + 147 + t.Run("auth endpoints use auth limiter", func(t *testing.T) { 148 + for i := 0; i < 2; i++ { 149 + req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) 150 + req.RemoteAddr = "1.1.1.1:1234" 151 + rec := httptest.NewRecorder() 152 + wrapped.ServeHTTP(rec, req) 153 + assert.Equal(t, http.StatusOK, rec.Code) 154 + } 155 + 156 + req := httptest.NewRequest(http.MethodPost, "/auth/login", nil) 157 + req.RemoteAddr = "1.1.1.1:1234" 158 + rec := httptest.NewRecorder() 159 + wrapped.ServeHTTP(rec, req) 160 + assert.Equal(t, http.StatusTooManyRequests, rec.Code) 161 + assert.Equal(t, "60", rec.Header().Get("Retry-After")) 162 + }) 163 + 164 + t.Run("api endpoints use api limiter", func(t *testing.T) { 165 + for i := 0; i < 3; i++ { 166 + req := httptest.NewRequest(http.MethodGet, "/api/brews", nil) 167 + req.RemoteAddr = "2.2.2.2:1234" 168 + rec := httptest.NewRecorder() 169 + wrapped.ServeHTTP(rec, req) 170 + assert.Equal(t, http.StatusOK, rec.Code) 171 + } 172 + 173 + req := httptest.NewRequest(http.MethodGet, "/api/brews", nil) 174 + req.RemoteAddr = "2.2.2.2:1234" 175 + rec := httptest.NewRecorder() 176 + wrapped.ServeHTTP(rec, req) 177 + assert.Equal(t, http.StatusTooManyRequests, rec.Code) 178 + }) 179 + 180 + t.Run("other endpoints use global limiter", func(t *testing.T) { 181 + for i := 0; i < 5; i++ { 182 + req := httptest.NewRequest(http.MethodGet, "/brews", nil) 183 + req.RemoteAddr = "3.3.3.3:1234" 184 + rec := httptest.NewRecorder() 185 + wrapped.ServeHTTP(rec, req) 186 + assert.Equal(t, http.StatusOK, rec.Code) 187 + } 188 + 189 + req := httptest.NewRequest(http.MethodGet, "/brews", nil) 190 + req.RemoteAddr = "3.3.3.3:1234" 191 + rec := httptest.NewRecorder() 192 + wrapped.ServeHTTP(rec, req) 193 + assert.Equal(t, http.StatusTooManyRequests, rec.Code) 194 + }) 195 + 196 + t.Run("login path uses auth limiter", func(t *testing.T) { 197 + for i := 0; i < 2; i++ { 198 + req := httptest.NewRequest(http.MethodPost, "/login", nil) 199 + req.RemoteAddr = "4.4.4.4:1234" 200 + rec := httptest.NewRecorder() 201 + wrapped.ServeHTTP(rec, req) 202 + assert.Equal(t, http.StatusOK, rec.Code) 203 + } 204 + 205 + req := httptest.NewRequest(http.MethodPost, "/login", nil) 206 + req.RemoteAddr = "4.4.4.4:1234" 207 + rec := httptest.NewRecorder() 208 + wrapped.ServeHTTP(rec, req) 209 + assert.Equal(t, http.StatusTooManyRequests, rec.Code) 210 + }) 211 + } 212 + 213 + func TestRequireHTMXMiddleware(t *testing.T) { 214 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 215 + w.WriteHeader(http.StatusOK) 216 + w.Write([]byte("OK")) 217 + }) 218 + 219 + wrapped := RequireHTMXMiddleware(handler) 220 + 221 + t.Run("allows HTMX requests", func(t *testing.T) { 222 + req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 223 + req.Header.Set("HX-Request", "true") 224 + rec := httptest.NewRecorder() 225 + 226 + wrapped.ServeHTTP(rec, req) 227 + assert.Equal(t, http.StatusOK, rec.Code) 228 + assert.Equal(t, "OK", rec.Body.String()) 229 + }) 230 + 231 + t.Run("blocks non-HTMX requests", func(t *testing.T) { 232 + req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 233 + rec := httptest.NewRecorder() 234 + 235 + wrapped.ServeHTTP(rec, req) 236 + assert.Equal(t, http.StatusNotFound, rec.Code) 237 + }) 238 + 239 + t.Run("blocks wrong HX-Request value", func(t *testing.T) { 240 + req := httptest.NewRequest(http.MethodGet, "/api/partial", nil) 241 + req.Header.Set("HX-Request", "false") 242 + rec := httptest.NewRecorder() 243 + 244 + wrapped.ServeHTTP(rec, req) 245 + assert.Equal(t, http.StatusNotFound, rec.Code) 246 + }) 247 + } 248 + 249 + func TestLimitBodyMiddleware(t *testing.T) { 250 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 251 + // Try to read the body 252 + buf := make([]byte, 2<<20) // 2MB buffer 253 + _, err := r.Body.Read(buf) 254 + if err != nil && err.Error() != "EOF" { 255 + http.Error(w, "body too large", http.StatusRequestEntityTooLarge) 256 + return 257 + } 258 + w.WriteHeader(http.StatusOK) 259 + }) 260 + 261 + wrapped := LimitBodyMiddleware(handler) 262 + 263 + t.Run("allows small JSON body", func(t *testing.T) { 264 + body := strings.NewReader(`{"name": "test"}`) 265 + req := httptest.NewRequest(http.MethodPost, "/api/test", body) 266 + req.Header.Set("Content-Type", "application/json") 267 + rec := httptest.NewRecorder() 268 + 269 + wrapped.ServeHTTP(rec, req) 270 + assert.Equal(t, http.StatusOK, rec.Code) 271 + }) 272 + 273 + t.Run("allows small form body", func(t *testing.T) { 274 + body := strings.NewReader("name=test&value=123") 275 + req := httptest.NewRequest(http.MethodPost, "/api/test", body) 276 + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 277 + rec := httptest.NewRecorder() 278 + 279 + wrapped.ServeHTTP(rec, req) 280 + assert.Equal(t, http.StatusOK, rec.Code) 281 + }) 282 + 283 + t.Run("handles nil body", func(t *testing.T) { 284 + req := httptest.NewRequest(http.MethodGet, "/test", nil) 285 + rec := httptest.NewRecorder() 286 + 287 + wrapped.ServeHTTP(rec, req) 288 + assert.Equal(t, http.StatusOK, rec.Code) 289 + }) 290 + } 291 + 292 + func TestGetClientIP(t *testing.T) { 293 + tests := []struct { 294 + name string 295 + xff string 296 + xri string 297 + remoteAddr string 298 + expected string 299 + }{ 300 + { 301 + name: "X-Forwarded-For single IP", 302 + xff: "203.0.113.50", 303 + remoteAddr: "127.0.0.1:1234", 304 + expected: "203.0.113.50", 305 + }, 306 + { 307 + name: "X-Forwarded-For multiple IPs", 308 + xff: "203.0.113.50, 70.41.3.18, 150.172.238.178", 309 + remoteAddr: "127.0.0.1:1234", 310 + expected: "203.0.113.50", 311 + }, 312 + { 313 + name: "X-Forwarded-For with whitespace", 314 + xff: " 203.0.113.50 ", 315 + remoteAddr: "127.0.0.1:1234", 316 + expected: "203.0.113.50", 317 + }, 318 + { 319 + name: "X-Real-IP", 320 + xri: "198.51.100.178", 321 + remoteAddr: "127.0.0.1:1234", 322 + expected: "198.51.100.178", 323 + }, 324 + { 325 + name: "X-Real-IP with whitespace", 326 + xri: " 198.51.100.178 ", 327 + remoteAddr: "127.0.0.1:1234", 328 + expected: "198.51.100.178", 329 + }, 330 + { 331 + name: "X-Forwarded-For takes precedence over X-Real-IP", 332 + xff: "203.0.113.50", 333 + xri: "198.51.100.178", 334 + remoteAddr: "127.0.0.1:1234", 335 + expected: "203.0.113.50", 336 + }, 337 + { 338 + name: "fallback to RemoteAddr with port", 339 + remoteAddr: "192.168.1.1:8080", 340 + expected: "192.168.1.1", 341 + }, 342 + { 343 + name: "fallback to RemoteAddr without port", 344 + remoteAddr: "192.168.1.1", 345 + expected: "192.168.1.1", 346 + }, 347 + } 348 + 349 + for _, tt := range tests { 350 + t.Run(tt.name, func(t *testing.T) { 351 + req := httptest.NewRequest(http.MethodGet, "/", nil) 352 + req.RemoteAddr = tt.remoteAddr 353 + if tt.xff != "" { 354 + req.Header.Set("X-Forwarded-For", tt.xff) 355 + } 356 + if tt.xri != "" { 357 + req.Header.Set("X-Real-IP", tt.xri) 358 + } 359 + 360 + got := GetClientIP(req) 361 + assert.Equal(t, tt.expected, got) 362 + }) 363 + } 364 + } 365 + 366 + func TestGenerateNonce(t *testing.T) { 367 + t.Run("generates base64 string", func(t *testing.T) { 368 + nonce, err := generateNonce() 369 + require.NoError(t, err) 370 + assert.NotEmpty(t, nonce) 371 + // Base64 of 16 bytes = 24 chars 372 + assert.Len(t, nonce, 24) 373 + }) 374 + 375 + t.Run("generates unique values", func(t *testing.T) { 376 + n1, err := generateNonce() 377 + require.NoError(t, err) 378 + n2, err := generateNonce() 379 + require.NoError(t, err) 380 + assert.NotEqual(t, n1, n2) 381 + }) 382 + }
+352
internal/models/models_test.go
··· 1 + package models 2 + 3 + import ( 4 + "strings" 5 + "testing" 6 + 7 + "github.com/stretchr/testify/assert" 8 + ) 9 + 10 + func TestCreateBeanRequest_Validate(t *testing.T) { 11 + t.Run("valid request", func(t *testing.T) { 12 + req := &CreateBeanRequest{Name: "Ethiopian Yirgacheffe"} 13 + assert.NoError(t, req.Validate()) 14 + }) 15 + 16 + t.Run("empty name", func(t *testing.T) { 17 + req := &CreateBeanRequest{Name: ""} 18 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 19 + }) 20 + 21 + t.Run("name too long", func(t *testing.T) { 22 + req := &CreateBeanRequest{Name: strings.Repeat("a", MaxNameLength+1)} 23 + assert.ErrorIs(t, req.Validate(), ErrNameTooLong) 24 + }) 25 + 26 + t.Run("name at max length", func(t *testing.T) { 27 + req := &CreateBeanRequest{Name: strings.Repeat("a", MaxNameLength)} 28 + assert.NoError(t, req.Validate()) 29 + }) 30 + 31 + t.Run("origin too long", func(t *testing.T) { 32 + req := &CreateBeanRequest{ 33 + Name: "Bean", 34 + Origin: strings.Repeat("a", MaxOriginLength+1), 35 + } 36 + assert.ErrorIs(t, req.Validate(), ErrOriginTooLong) 37 + }) 38 + 39 + t.Run("roast level too long", func(t *testing.T) { 40 + req := &CreateBeanRequest{ 41 + Name: "Bean", 42 + RoastLevel: strings.Repeat("a", MaxRoastLevelLength+1), 43 + } 44 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 45 + }) 46 + 47 + t.Run("process too long", func(t *testing.T) { 48 + req := &CreateBeanRequest{ 49 + Name: "Bean", 50 + Process: strings.Repeat("a", MaxProcessLength+1), 51 + } 52 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 53 + }) 54 + 55 + t.Run("description too long", func(t *testing.T) { 56 + req := &CreateBeanRequest{ 57 + Name: "Bean", 58 + Description: strings.Repeat("a", MaxDescriptionLength+1), 59 + } 60 + assert.ErrorIs(t, req.Validate(), ErrDescTooLong) 61 + }) 62 + 63 + t.Run("all optional fields populated", func(t *testing.T) { 64 + req := &CreateBeanRequest{ 65 + Name: "Ethiopian Yirgacheffe", 66 + Origin: "Ethiopia", 67 + RoastLevel: "Light", 68 + Process: "Washed", 69 + Description: "Fruity and floral notes", 70 + RoasterRKey: "abc123", 71 + } 72 + assert.NoError(t, req.Validate()) 73 + }) 74 + } 75 + 76 + func TestUpdateBeanRequest_Validate(t *testing.T) { 77 + t.Run("valid request", func(t *testing.T) { 78 + req := &UpdateBeanRequest{Name: "Updated Bean"} 79 + assert.NoError(t, req.Validate()) 80 + }) 81 + 82 + t.Run("empty name", func(t *testing.T) { 83 + req := &UpdateBeanRequest{Name: ""} 84 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 85 + }) 86 + 87 + t.Run("name too long", func(t *testing.T) { 88 + req := &UpdateBeanRequest{Name: strings.Repeat("a", MaxNameLength+1)} 89 + assert.ErrorIs(t, req.Validate(), ErrNameTooLong) 90 + }) 91 + 92 + t.Run("origin too long", func(t *testing.T) { 93 + req := &UpdateBeanRequest{ 94 + Name: "Bean", 95 + Origin: strings.Repeat("a", MaxOriginLength+1), 96 + } 97 + assert.ErrorIs(t, req.Validate(), ErrOriginTooLong) 98 + }) 99 + 100 + t.Run("description too long", func(t *testing.T) { 101 + req := &UpdateBeanRequest{ 102 + Name: "Bean", 103 + Description: strings.Repeat("a", MaxDescriptionLength+1), 104 + } 105 + assert.ErrorIs(t, req.Validate(), ErrDescTooLong) 106 + }) 107 + } 108 + 109 + func TestCreateRoasterRequest_Validate(t *testing.T) { 110 + t.Run("valid request", func(t *testing.T) { 111 + req := &CreateRoasterRequest{Name: "Blue Bottle"} 112 + assert.NoError(t, req.Validate()) 113 + }) 114 + 115 + t.Run("empty name", func(t *testing.T) { 116 + req := &CreateRoasterRequest{Name: ""} 117 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 118 + }) 119 + 120 + t.Run("name too long", func(t *testing.T) { 121 + req := &CreateRoasterRequest{Name: strings.Repeat("a", MaxNameLength+1)} 122 + assert.ErrorIs(t, req.Validate(), ErrNameTooLong) 123 + }) 124 + 125 + t.Run("location too long", func(t *testing.T) { 126 + req := &CreateRoasterRequest{ 127 + Name: "Roaster", 128 + Location: strings.Repeat("a", MaxLocationLength+1), 129 + } 130 + assert.ErrorIs(t, req.Validate(), ErrLocationTooLong) 131 + }) 132 + 133 + t.Run("website too long", func(t *testing.T) { 134 + req := &CreateRoasterRequest{ 135 + Name: "Roaster", 136 + Website: strings.Repeat("a", MaxWebsiteLength+1), 137 + } 138 + assert.ErrorIs(t, req.Validate(), ErrWebsiteTooLong) 139 + }) 140 + 141 + t.Run("all fields at max", func(t *testing.T) { 142 + req := &CreateRoasterRequest{ 143 + Name: strings.Repeat("a", MaxNameLength), 144 + Location: strings.Repeat("a", MaxLocationLength), 145 + Website: strings.Repeat("a", MaxWebsiteLength), 146 + } 147 + assert.NoError(t, req.Validate()) 148 + }) 149 + } 150 + 151 + func TestUpdateRoasterRequest_Validate(t *testing.T) { 152 + t.Run("valid request", func(t *testing.T) { 153 + req := &UpdateRoasterRequest{Name: "Updated Roaster"} 154 + assert.NoError(t, req.Validate()) 155 + }) 156 + 157 + t.Run("empty name", func(t *testing.T) { 158 + req := &UpdateRoasterRequest{Name: ""} 159 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 160 + }) 161 + 162 + t.Run("location too long", func(t *testing.T) { 163 + req := &UpdateRoasterRequest{ 164 + Name: "Roaster", 165 + Location: strings.Repeat("a", MaxLocationLength+1), 166 + } 167 + assert.ErrorIs(t, req.Validate(), ErrLocationTooLong) 168 + }) 169 + 170 + t.Run("website too long", func(t *testing.T) { 171 + req := &UpdateRoasterRequest{ 172 + Name: "Roaster", 173 + Website: strings.Repeat("a", MaxWebsiteLength+1), 174 + } 175 + assert.ErrorIs(t, req.Validate(), ErrWebsiteTooLong) 176 + }) 177 + } 178 + 179 + func TestCreateGrinderRequest_Validate(t *testing.T) { 180 + t.Run("valid request", func(t *testing.T) { 181 + req := &CreateGrinderRequest{Name: "Comandante C40"} 182 + assert.NoError(t, req.Validate()) 183 + }) 184 + 185 + t.Run("empty name", func(t *testing.T) { 186 + req := &CreateGrinderRequest{Name: ""} 187 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 188 + }) 189 + 190 + t.Run("name too long", func(t *testing.T) { 191 + req := &CreateGrinderRequest{Name: strings.Repeat("a", MaxNameLength+1)} 192 + assert.ErrorIs(t, req.Validate(), ErrNameTooLong) 193 + }) 194 + 195 + t.Run("grinder type too long", func(t *testing.T) { 196 + req := &CreateGrinderRequest{ 197 + Name: "Grinder", 198 + GrinderType: strings.Repeat("a", MaxGrinderTypeLength+1), 199 + } 200 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 201 + }) 202 + 203 + t.Run("burr type too long", func(t *testing.T) { 204 + req := &CreateGrinderRequest{ 205 + Name: "Grinder", 206 + BurrType: strings.Repeat("a", MaxBurrTypeLength+1), 207 + } 208 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 209 + }) 210 + 211 + t.Run("notes too long", func(t *testing.T) { 212 + req := &CreateGrinderRequest{ 213 + Name: "Grinder", 214 + Notes: strings.Repeat("a", MaxNotesLength+1), 215 + } 216 + assert.ErrorIs(t, req.Validate(), ErrNotesTooLong) 217 + }) 218 + } 219 + 220 + func TestUpdateGrinderRequest_Validate(t *testing.T) { 221 + t.Run("valid request", func(t *testing.T) { 222 + req := &UpdateGrinderRequest{Name: "Updated Grinder"} 223 + assert.NoError(t, req.Validate()) 224 + }) 225 + 226 + t.Run("empty name", func(t *testing.T) { 227 + req := &UpdateGrinderRequest{Name: ""} 228 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 229 + }) 230 + 231 + t.Run("grinder type too long", func(t *testing.T) { 232 + req := &UpdateGrinderRequest{ 233 + Name: "Grinder", 234 + GrinderType: strings.Repeat("a", MaxGrinderTypeLength+1), 235 + } 236 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 237 + }) 238 + 239 + t.Run("notes too long", func(t *testing.T) { 240 + req := &UpdateGrinderRequest{ 241 + Name: "Grinder", 242 + Notes: strings.Repeat("a", MaxNotesLength+1), 243 + } 244 + assert.ErrorIs(t, req.Validate(), ErrNotesTooLong) 245 + }) 246 + } 247 + 248 + func TestCreateBrewerRequest_Validate(t *testing.T) { 249 + t.Run("valid request", func(t *testing.T) { 250 + req := &CreateBrewerRequest{Name: "V60"} 251 + assert.NoError(t, req.Validate()) 252 + }) 253 + 254 + t.Run("empty name", func(t *testing.T) { 255 + req := &CreateBrewerRequest{Name: ""} 256 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 257 + }) 258 + 259 + t.Run("name too long", func(t *testing.T) { 260 + req := &CreateBrewerRequest{Name: strings.Repeat("a", MaxNameLength+1)} 261 + assert.ErrorIs(t, req.Validate(), ErrNameTooLong) 262 + }) 263 + 264 + t.Run("brewer type too long", func(t *testing.T) { 265 + req := &CreateBrewerRequest{ 266 + Name: "Brewer", 267 + BrewerType: strings.Repeat("a", MaxBrewerTypeLength+1), 268 + } 269 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 270 + }) 271 + 272 + t.Run("description too long", func(t *testing.T) { 273 + req := &CreateBrewerRequest{ 274 + Name: "Brewer", 275 + Description: strings.Repeat("a", MaxDescriptionLength+1), 276 + } 277 + assert.ErrorIs(t, req.Validate(), ErrDescTooLong) 278 + }) 279 + } 280 + 281 + func TestUpdateBrewerRequest_Validate(t *testing.T) { 282 + t.Run("valid request", func(t *testing.T) { 283 + req := &UpdateBrewerRequest{Name: "Updated V60"} 284 + assert.NoError(t, req.Validate()) 285 + }) 286 + 287 + t.Run("empty name", func(t *testing.T) { 288 + req := &UpdateBrewerRequest{Name: ""} 289 + assert.ErrorIs(t, req.Validate(), ErrNameRequired) 290 + }) 291 + 292 + t.Run("brewer type too long", func(t *testing.T) { 293 + req := &UpdateBrewerRequest{ 294 + Name: "Brewer", 295 + BrewerType: strings.Repeat("a", MaxBrewerTypeLength+1), 296 + } 297 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 298 + }) 299 + 300 + t.Run("description too long", func(t *testing.T) { 301 + req := &UpdateBrewerRequest{ 302 + Name: "Brewer", 303 + Description: strings.Repeat("a", MaxDescriptionLength+1), 304 + } 305 + assert.ErrorIs(t, req.Validate(), ErrDescTooLong) 306 + }) 307 + } 308 + 309 + func TestCreateBrewRequest_Validate(t *testing.T) { 310 + t.Run("valid minimal request", func(t *testing.T) { 311 + req := &CreateBrewRequest{} 312 + assert.NoError(t, req.Validate()) 313 + }) 314 + 315 + t.Run("valid full request", func(t *testing.T) { 316 + req := &CreateBrewRequest{ 317 + BeanRKey: "abc123", 318 + Method: "Pour Over", 319 + Temperature: 93.5, 320 + WaterAmount: 250, 321 + CoffeeAmount: 15, 322 + TimeSeconds: 210, 323 + GrindSize: "Medium-Fine", 324 + GrinderRKey: "grinder1", 325 + BrewerRKey: "brewer1", 326 + TastingNotes: "Fruity, bright acidity", 327 + Rating: 8, 328 + } 329 + assert.NoError(t, req.Validate()) 330 + }) 331 + 332 + t.Run("method too long", func(t *testing.T) { 333 + req := &CreateBrewRequest{ 334 + Method: strings.Repeat("a", MaxMethodLength+1), 335 + } 336 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 337 + }) 338 + 339 + t.Run("grind size too long", func(t *testing.T) { 340 + req := &CreateBrewRequest{ 341 + GrindSize: strings.Repeat("a", MaxGrindSizeLength+1), 342 + } 343 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 344 + }) 345 + 346 + t.Run("tasting notes too long", func(t *testing.T) { 347 + req := &CreateBrewRequest{ 348 + TastingNotes: strings.Repeat("a", MaxTastingNotesLength+1), 349 + } 350 + assert.ErrorIs(t, req.Validate(), ErrFieldTooLong) 351 + }) 352 + }
+118
internal/web/bff/helpers_test.go
··· 2 2 3 3 import ( 4 4 "testing" 5 + "time" 5 6 6 7 "arabica/internal/models" 7 8 "github.com/stretchr/testify/assert" ··· 122 123 }) 123 124 } 124 125 } 126 + 127 + func TestHasTemp(t *testing.T) { 128 + assert.False(t, HasTemp(0)) 129 + assert.False(t, HasTemp(-1)) 130 + assert.True(t, HasTemp(0.1)) 131 + assert.True(t, HasTemp(93.5)) 132 + } 133 + 134 + func TestHasValue(t *testing.T) { 135 + assert.False(t, HasValue(0)) 136 + assert.False(t, HasValue(-1)) 137 + assert.True(t, HasValue(1)) 138 + assert.True(t, HasValue(250)) 139 + } 140 + 141 + func TestSafeAvatarURL(t *testing.T) { 142 + tests := []struct { 143 + name string 144 + input string 145 + expected string 146 + }{ 147 + {"empty string", "", ""}, 148 + {"trusted bsky CDN", "https://cdn.bsky.app/img/avatar/did:plc:abc/cid@jpeg", "https://cdn.bsky.app/img/avatar/did:plc:abc/cid@jpeg"}, 149 + {"trusted av-cdn", "https://av-cdn.bsky.app/img/avatar/abc", "https://av-cdn.bsky.app/img/avatar/abc"}, 150 + {"static path", "/static/icon-placeholder.svg", "/static/icon-placeholder.svg"}, 151 + {"non-static relative path", "/evil/path", ""}, 152 + {"http scheme rejected", "http://cdn.bsky.app/img/avatar/abc", ""}, 153 + {"untrusted domain", "https://evil.com/avatar.jpg", ""}, 154 + {"javascript scheme", "javascript:alert(1)", ""}, 155 + {"data URI rejected", "data:image/svg+xml,<svg></svg>", ""}, 156 + {"invalid URL", "://invalid", ""}, 157 + {"subdomain of trusted", "https://sub.cdn.bsky.app/avatar.jpg", "https://sub.cdn.bsky.app/avatar.jpg"}, 158 + } 159 + 160 + for _, tt := range tests { 161 + t.Run(tt.name, func(t *testing.T) { 162 + assert.Equal(t, tt.expected, SafeAvatarURL(tt.input)) 163 + }) 164 + } 165 + } 166 + 167 + func TestSafeWebsiteURL(t *testing.T) { 168 + tests := []struct { 169 + name string 170 + input string 171 + expected string 172 + }{ 173 + {"empty string", "", ""}, 174 + {"valid https", "https://example.com", "https://example.com"}, 175 + {"valid http", "http://example.com", "http://example.com"}, 176 + {"javascript scheme", "javascript:alert(1)", ""}, 177 + {"ftp scheme", "ftp://files.example.com", ""}, 178 + {"no dot in host", "https://localhost", ""}, 179 + {"invalid URL", "://invalid", ""}, 180 + {"https with path", "https://roaster.coffee/about", "https://roaster.coffee/about"}, 181 + } 182 + 183 + for _, tt := range tests { 184 + t.Run(tt.name, func(t *testing.T) { 185 + assert.Equal(t, tt.expected, SafeWebsiteURL(tt.input)) 186 + }) 187 + } 188 + } 189 + 190 + func TestEscapeJS(t *testing.T) { 191 + tests := []struct { 192 + name string 193 + input string 194 + expected string 195 + }{ 196 + {"empty string", "", ""}, 197 + {"no special chars", "hello world", "hello world"}, 198 + {"single quotes", "it's a test", "it\\'s a test"}, 199 + {"double quotes", `say "hello"`, `say \"hello\"`}, 200 + {"newlines", "line1\nline2", "line1\\nline2"}, 201 + {"carriage return", "line1\rline2", "line1\\rline2"}, 202 + {"tabs", "col1\tcol2", "col1\\tcol2"}, 203 + {"backslash", `path\to\file`, `path\\to\\file`}, 204 + {"mixed", "it's a \"test\"\nwith\\stuff", "it\\'s a \\\"test\\\"\\nwith\\\\stuff"}, 205 + } 206 + 207 + for _, tt := range tests { 208 + t.Run(tt.name, func(t *testing.T) { 209 + assert.Equal(t, tt.expected, EscapeJS(tt.input)) 210 + }) 211 + } 212 + } 213 + 214 + func TestFormatTimeAgo(t *testing.T) { 215 + now := time.Now() 216 + 217 + tests := []struct { 218 + name string 219 + input time.Time 220 + expected string 221 + }{ 222 + {"just now", now.Add(-30 * time.Second), "just now"}, 223 + {"1 minute ago", now.Add(-1 * time.Minute), "1 minute ago"}, 224 + {"5 minutes ago", now.Add(-5 * time.Minute), "5 minutes ago"}, 225 + {"1 hour ago", now.Add(-1 * time.Hour), "1 hour ago"}, 226 + {"3 hours ago", now.Add(-3 * time.Hour), "3 hours ago"}, 227 + {"yesterday", now.Add(-36 * time.Hour), "yesterday"}, 228 + {"3 days ago", now.Add(-3 * 24 * time.Hour), "3 days ago"}, 229 + {"1 week ago", now.Add(-8 * 24 * time.Hour), "1 week ago"}, 230 + {"3 weeks ago", now.Add(-22 * 24 * time.Hour), "3 weeks ago"}, 231 + {"1 month ago", now.Add(-35 * 24 * time.Hour), "1 month ago"}, 232 + {"6 months ago", now.Add(-180 * 24 * time.Hour), "6 months ago"}, 233 + {"1 year ago", now.Add(-400 * 24 * time.Hour), "1 year ago"}, 234 + {"2 years ago", now.Add(-800 * 24 * time.Hour), "2 years ago"}, 235 + } 236 + 237 + for _, tt := range tests { 238 + t.Run(tt.name, func(t *testing.T) { 239 + assert.Equal(t, tt.expected, FormatTimeAgo(tt.input)) 240 + }) 241 + } 242 + }