tangled
alpha
login
or
join now
arabica.social
/
arabica
7
fork
atom
Coffee journaling on ATProto (alpha)
alpha.arabica.social
coffee
7
fork
atom
overview
issues
pulls
pipelines
test: add new tests
pdewey.com
3 weeks ago
39d01397
d33958e0
verified
This commit was signed with the committer's
known signature
.
pdewey.com
SSH Key Fingerprint:
SHA256:ePOVkJstqVLchGK8m9/OGQG+aFNHD5XN3xjvW9wKCA4=
+1127
5 changed files
expand all
collapse all
unified
split
internal
database
boltstore
feed_store_test.go
join_store_test.go
middleware
security_test.go
models
models_test.go
web
bff
helpers_test.go
+139
internal/database/boltstore/feed_store_test.go
···
1
1
+
package boltstore
2
2
+
3
3
+
import (
4
4
+
"path/filepath"
5
5
+
"testing"
6
6
+
7
7
+
"github.com/stretchr/testify/assert"
8
8
+
"github.com/stretchr/testify/require"
9
9
+
)
10
10
+
11
11
+
func setupTestFeedStore(t *testing.T) *FeedStore {
12
12
+
tmpDir := t.TempDir()
13
13
+
dbPath := filepath.Join(tmpDir, "test.db")
14
14
+
15
15
+
store, err := Open(Options{Path: dbPath})
16
16
+
require.NoError(t, err)
17
17
+
18
18
+
t.Cleanup(func() {
19
19
+
store.Close()
20
20
+
})
21
21
+
22
22
+
return store.FeedStore()
23
23
+
}
24
24
+
25
25
+
func TestFeedStore_Register(t *testing.T) {
26
26
+
store := setupTestFeedStore(t)
27
27
+
28
28
+
t.Run("register new DID", func(t *testing.T) {
29
29
+
err := store.Register("did:plc:user1")
30
30
+
require.NoError(t, err)
31
31
+
assert.True(t, store.IsRegistered("did:plc:user1"))
32
32
+
})
33
33
+
34
34
+
t.Run("register is idempotent", func(t *testing.T) {
35
35
+
err := store.Register("did:plc:user2")
36
36
+
require.NoError(t, err)
37
37
+
38
38
+
err = store.Register("did:plc:user2")
39
39
+
require.NoError(t, err)
40
40
+
41
41
+
assert.Equal(t, 1, countDID(store, "did:plc:user2"))
42
42
+
})
43
43
+
}
44
44
+
45
45
+
// countDID counts how many times a DID appears in the list (should be 0 or 1).
46
46
+
func countDID(store *FeedStore, did string) int {
47
47
+
count := 0
48
48
+
for _, d := range store.List() {
49
49
+
if d == did {
50
50
+
count++
51
51
+
}
52
52
+
}
53
53
+
return count
54
54
+
}
55
55
+
56
56
+
func TestFeedStore_Unregister(t *testing.T) {
57
57
+
store := setupTestFeedStore(t)
58
58
+
59
59
+
err := store.Register("did:plc:unreg")
60
60
+
require.NoError(t, err)
61
61
+
assert.True(t, store.IsRegistered("did:plc:unreg"))
62
62
+
63
63
+
err = store.Unregister("did:plc:unreg")
64
64
+
require.NoError(t, err)
65
65
+
assert.False(t, store.IsRegistered("did:plc:unreg"))
66
66
+
}
67
67
+
68
68
+
func TestFeedStore_IsRegistered(t *testing.T) {
69
69
+
store := setupTestFeedStore(t)
70
70
+
71
71
+
assert.False(t, store.IsRegistered("did:plc:nobody"))
72
72
+
73
73
+
store.Register("did:plc:somebody")
74
74
+
assert.True(t, store.IsRegistered("did:plc:somebody"))
75
75
+
}
76
76
+
77
77
+
func TestFeedStore_List(t *testing.T) {
78
78
+
store := setupTestFeedStore(t)
79
79
+
80
80
+
t.Run("empty store", func(t *testing.T) {
81
81
+
dids := store.List()
82
82
+
assert.Empty(t, dids)
83
83
+
})
84
84
+
85
85
+
t.Run("multiple registrations", func(t *testing.T) {
86
86
+
store.Register("did:plc:a")
87
87
+
store.Register("did:plc:b")
88
88
+
store.Register("did:plc:c")
89
89
+
90
90
+
dids := store.List()
91
91
+
assert.Len(t, dids, 3)
92
92
+
assert.Contains(t, dids, "did:plc:a")
93
93
+
assert.Contains(t, dids, "did:plc:b")
94
94
+
assert.Contains(t, dids, "did:plc:c")
95
95
+
})
96
96
+
}
97
97
+
98
98
+
func TestFeedStore_ListWithMetadata(t *testing.T) {
99
99
+
store := setupTestFeedStore(t)
100
100
+
101
101
+
store.Register("did:plc:meta1")
102
102
+
store.Register("did:plc:meta2")
103
103
+
104
104
+
users := store.ListWithMetadata()
105
105
+
assert.Len(t, users, 2)
106
106
+
107
107
+
for _, u := range users {
108
108
+
assert.NotEmpty(t, u.DID)
109
109
+
assert.False(t, u.RegisteredAt.IsZero())
110
110
+
}
111
111
+
}
112
112
+
113
113
+
func TestFeedStore_Count(t *testing.T) {
114
114
+
store := setupTestFeedStore(t)
115
115
+
116
116
+
assert.Equal(t, 0, store.Count())
117
117
+
118
118
+
store.Register("did:plc:c1")
119
119
+
assert.Equal(t, 1, store.Count())
120
120
+
121
121
+
store.Register("did:plc:c2")
122
122
+
assert.Equal(t, 2, store.Count())
123
123
+
124
124
+
store.Unregister("did:plc:c1")
125
125
+
assert.Equal(t, 1, store.Count())
126
126
+
}
127
127
+
128
128
+
func TestFeedStore_Clear(t *testing.T) {
129
129
+
store := setupTestFeedStore(t)
130
130
+
131
131
+
store.Register("did:plc:clear1")
132
132
+
store.Register("did:plc:clear2")
133
133
+
assert.Equal(t, 2, store.Count())
134
134
+
135
135
+
err := store.Clear()
136
136
+
require.NoError(t, err)
137
137
+
assert.Equal(t, 0, store.Count())
138
138
+
assert.False(t, store.IsRegistered("did:plc:clear1"))
139
139
+
}
+136
internal/database/boltstore/join_store_test.go
···
1
1
+
package boltstore
2
2
+
3
3
+
import (
4
4
+
"path/filepath"
5
5
+
"testing"
6
6
+
"time"
7
7
+
8
8
+
"github.com/stretchr/testify/assert"
9
9
+
"github.com/stretchr/testify/require"
10
10
+
)
11
11
+
12
12
+
func setupTestJoinStore(t *testing.T) *JoinStore {
13
13
+
tmpDir := t.TempDir()
14
14
+
dbPath := filepath.Join(tmpDir, "test.db")
15
15
+
16
16
+
store, err := Open(Options{Path: dbPath})
17
17
+
require.NoError(t, err)
18
18
+
19
19
+
t.Cleanup(func() {
20
20
+
store.Close()
21
21
+
})
22
22
+
23
23
+
return store.JoinStore()
24
24
+
}
25
25
+
26
26
+
func TestJoinStore_SaveAndGet(t *testing.T) {
27
27
+
store := setupTestJoinStore(t)
28
28
+
29
29
+
req := &JoinRequest{
30
30
+
ID: "join-001",
31
31
+
Email: "user@example.com",
32
32
+
Message: "I love coffee!",
33
33
+
CreatedAt: time.Now().Truncate(time.Millisecond),
34
34
+
IP: "203.0.113.50",
35
35
+
}
36
36
+
37
37
+
err := store.SaveRequest(req)
38
38
+
require.NoError(t, err)
39
39
+
40
40
+
retrieved, err := store.GetRequest("join-001")
41
41
+
require.NoError(t, err)
42
42
+
require.NotNil(t, retrieved)
43
43
+
44
44
+
assert.Equal(t, req.ID, retrieved.ID)
45
45
+
assert.Equal(t, req.Email, retrieved.Email)
46
46
+
assert.Equal(t, req.Message, retrieved.Message)
47
47
+
assert.Equal(t, req.IP, retrieved.IP)
48
48
+
assert.True(t, req.CreatedAt.Equal(retrieved.CreatedAt))
49
49
+
}
50
50
+
51
51
+
func TestJoinStore_GetNotFound(t *testing.T) {
52
52
+
store := setupTestJoinStore(t)
53
53
+
54
54
+
retrieved, err := store.GetRequest("nonexistent")
55
55
+
assert.Error(t, err)
56
56
+
assert.Nil(t, retrieved)
57
57
+
assert.Contains(t, err.Error(), "not found")
58
58
+
}
59
59
+
60
60
+
func TestJoinStore_Delete(t *testing.T) {
61
61
+
store := setupTestJoinStore(t)
62
62
+
63
63
+
req := &JoinRequest{
64
64
+
ID: "join-del",
65
65
+
Email: "delete@example.com",
66
66
+
CreatedAt: time.Now(),
67
67
+
IP: "10.0.0.1",
68
68
+
}
69
69
+
70
70
+
err := store.SaveRequest(req)
71
71
+
require.NoError(t, err)
72
72
+
73
73
+
err = store.DeleteRequest("join-del")
74
74
+
require.NoError(t, err)
75
75
+
76
76
+
retrieved, err := store.GetRequest("join-del")
77
77
+
assert.Error(t, err)
78
78
+
assert.Nil(t, retrieved)
79
79
+
}
80
80
+
81
81
+
func TestJoinStore_DeleteNonexistent(t *testing.T) {
82
82
+
store := setupTestJoinStore(t)
83
83
+
84
84
+
// Deleting a non-existent request should not error
85
85
+
err := store.DeleteRequest("nonexistent")
86
86
+
assert.NoError(t, err)
87
87
+
}
88
88
+
89
89
+
func TestJoinStore_ListRequests(t *testing.T) {
90
90
+
store := setupTestJoinStore(t)
91
91
+
92
92
+
t.Run("empty store", func(t *testing.T) {
93
93
+
requests, err := store.ListRequests()
94
94
+
require.NoError(t, err)
95
95
+
assert.Empty(t, requests)
96
96
+
})
97
97
+
98
98
+
t.Run("multiple requests", func(t *testing.T) {
99
99
+
for i, email := range []string{"a@test.com", "b@test.com", "c@test.com"} {
100
100
+
req := &JoinRequest{
101
101
+
ID: "list-" + string(rune('0'+i)),
102
102
+
Email: email,
103
103
+
CreatedAt: time.Now(),
104
104
+
IP: "10.0.0.1",
105
105
+
}
106
106
+
require.NoError(t, store.SaveRequest(req))
107
107
+
}
108
108
+
109
109
+
requests, err := store.ListRequests()
110
110
+
require.NoError(t, err)
111
111
+
assert.Len(t, requests, 3)
112
112
+
})
113
113
+
}
114
114
+
115
115
+
func TestJoinStore_SaveOverwrites(t *testing.T) {
116
116
+
store := setupTestJoinStore(t)
117
117
+
118
118
+
req := &JoinRequest{
119
119
+
ID: "join-overwrite",
120
120
+
Email: "original@example.com",
121
121
+
CreatedAt: time.Now(),
122
122
+
IP: "10.0.0.1",
123
123
+
}
124
124
+
125
125
+
err := store.SaveRequest(req)
126
126
+
require.NoError(t, err)
127
127
+
128
128
+
// Save again with updated email
129
129
+
req.Email = "updated@example.com"
130
130
+
err = store.SaveRequest(req)
131
131
+
require.NoError(t, err)
132
132
+
133
133
+
retrieved, err := store.GetRequest("join-overwrite")
134
134
+
require.NoError(t, err)
135
135
+
assert.Equal(t, "updated@example.com", retrieved.Email)
136
136
+
}
+382
internal/middleware/security_test.go
···
1
1
+
package middleware
2
2
+
3
3
+
import (
4
4
+
"context"
5
5
+
"net/http"
6
6
+
"net/http/httptest"
7
7
+
"strings"
8
8
+
"testing"
9
9
+
"time"
10
10
+
11
11
+
"github.com/stretchr/testify/assert"
12
12
+
"github.com/stretchr/testify/require"
13
13
+
)
14
14
+
15
15
+
func TestSecurityHeadersMiddleware(t *testing.T) {
16
16
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17
17
+
// Verify nonce is available in context
18
18
+
nonce := CSPNonceFromContext(r.Context())
19
19
+
assert.NotEmpty(t, nonce, "nonce should be set in context")
20
20
+
w.WriteHeader(http.StatusOK)
21
21
+
})
22
22
+
23
23
+
wrapped := SecurityHeadersMiddleware(handler)
24
24
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
25
25
+
rec := httptest.NewRecorder()
26
26
+
27
27
+
wrapped.ServeHTTP(rec, req)
28
28
+
29
29
+
assert.Equal(t, http.StatusOK, rec.Code)
30
30
+
assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options"))
31
31
+
assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options"))
32
32
+
assert.Equal(t, "1; mode=block", rec.Header().Get("X-XSS-Protection"))
33
33
+
assert.Equal(t, "strict-origin-when-cross-origin", rec.Header().Get("Referrer-Policy"))
34
34
+
assert.Equal(t, "geolocation=(), microphone=(), camera=()", rec.Header().Get("Permissions-Policy"))
35
35
+
36
36
+
csp := rec.Header().Get("Content-Security-Policy")
37
37
+
assert.Contains(t, csp, "default-src 'self'")
38
38
+
assert.Contains(t, csp, "script-src 'self' 'unsafe-eval' 'nonce-")
39
39
+
assert.Contains(t, csp, "frame-ancestors 'none'")
40
40
+
}
41
41
+
42
42
+
func TestCSPNonceFromContext(t *testing.T) {
43
43
+
t.Run("returns nonce when set", func(t *testing.T) {
44
44
+
ctx := context.WithValue(context.Background(), cspNonceKey, "test-nonce-123")
45
45
+
assert.Equal(t, "test-nonce-123", CSPNonceFromContext(ctx))
46
46
+
})
47
47
+
48
48
+
t.Run("returns empty string when not set", func(t *testing.T) {
49
49
+
assert.Equal(t, "", CSPNonceFromContext(context.Background()))
50
50
+
})
51
51
+
52
52
+
t.Run("returns empty string for wrong type", func(t *testing.T) {
53
53
+
ctx := context.WithValue(context.Background(), cspNonceKey, 12345)
54
54
+
assert.Equal(t, "", CSPNonceFromContext(ctx))
55
55
+
})
56
56
+
}
57
57
+
58
58
+
func TestCSPNonceUniqueness(t *testing.T) {
59
59
+
nonces := make(map[string]bool)
60
60
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
61
61
+
nonce := CSPNonceFromContext(r.Context())
62
62
+
nonces[nonce] = true
63
63
+
w.WriteHeader(http.StatusOK)
64
64
+
})
65
65
+
66
66
+
wrapped := SecurityHeadersMiddleware(handler)
67
67
+
68
68
+
for i := 0; i < 10; i++ {
69
69
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
70
70
+
rec := httptest.NewRecorder()
71
71
+
wrapped.ServeHTTP(rec, req)
72
72
+
}
73
73
+
74
74
+
assert.Len(t, nonces, 10, "each request should get a unique nonce")
75
75
+
}
76
76
+
77
77
+
func TestRateLimiter_Allow(t *testing.T) {
78
78
+
t.Run("allows requests within limit", func(t *testing.T) {
79
79
+
rl := &RateLimiter{
80
80
+
visitors: make(map[string]*visitor),
81
81
+
rate: 3,
82
82
+
window: time.Minute,
83
83
+
cleanup: 2 * time.Minute,
84
84
+
}
85
85
+
86
86
+
assert.True(t, rl.Allow("192.168.1.1"))
87
87
+
assert.True(t, rl.Allow("192.168.1.1"))
88
88
+
assert.True(t, rl.Allow("192.168.1.1"))
89
89
+
})
90
90
+
91
91
+
t.Run("blocks after exceeding limit", func(t *testing.T) {
92
92
+
rl := &RateLimiter{
93
93
+
visitors: make(map[string]*visitor),
94
94
+
rate: 2,
95
95
+
window: time.Minute,
96
96
+
cleanup: 2 * time.Minute,
97
97
+
}
98
98
+
99
99
+
assert.True(t, rl.Allow("10.0.0.1"))
100
100
+
assert.True(t, rl.Allow("10.0.0.1"))
101
101
+
assert.False(t, rl.Allow("10.0.0.1"))
102
102
+
})
103
103
+
104
104
+
t.Run("different IPs are independent", func(t *testing.T) {
105
105
+
rl := &RateLimiter{
106
106
+
visitors: make(map[string]*visitor),
107
107
+
rate: 1,
108
108
+
window: time.Minute,
109
109
+
cleanup: 2 * time.Minute,
110
110
+
}
111
111
+
112
112
+
assert.True(t, rl.Allow("10.0.0.1"))
113
113
+
assert.False(t, rl.Allow("10.0.0.1"))
114
114
+
assert.True(t, rl.Allow("10.0.0.2"))
115
115
+
})
116
116
+
117
117
+
t.Run("resets after window expires", func(t *testing.T) {
118
118
+
rl := &RateLimiter{
119
119
+
visitors: make(map[string]*visitor),
120
120
+
rate: 1,
121
121
+
window: 50 * time.Millisecond,
122
122
+
cleanup: 100 * time.Millisecond,
123
123
+
}
124
124
+
125
125
+
assert.True(t, rl.Allow("10.0.0.1"))
126
126
+
assert.False(t, rl.Allow("10.0.0.1"))
127
127
+
128
128
+
time.Sleep(60 * time.Millisecond)
129
129
+
assert.True(t, rl.Allow("10.0.0.1"))
130
130
+
})
131
131
+
}
132
132
+
133
133
+
func TestRateLimitMiddleware(t *testing.T) {
134
134
+
config := &RateLimitConfig{
135
135
+
AuthLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 2, window: time.Minute, cleanup: 2 * time.Minute},
136
136
+
APILimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 3, window: time.Minute, cleanup: 2 * time.Minute},
137
137
+
GlobalLimiter: &RateLimiter{visitors: make(map[string]*visitor), rate: 5, window: time.Minute, cleanup: 2 * time.Minute},
138
138
+
}
139
139
+
140
140
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141
141
+
w.WriteHeader(http.StatusOK)
142
142
+
})
143
143
+
144
144
+
middleware := RateLimitMiddleware(config)
145
145
+
wrapped := middleware(handler)
146
146
+
147
147
+
t.Run("auth endpoints use auth limiter", func(t *testing.T) {
148
148
+
for i := 0; i < 2; i++ {
149
149
+
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
150
150
+
req.RemoteAddr = "1.1.1.1:1234"
151
151
+
rec := httptest.NewRecorder()
152
152
+
wrapped.ServeHTTP(rec, req)
153
153
+
assert.Equal(t, http.StatusOK, rec.Code)
154
154
+
}
155
155
+
156
156
+
req := httptest.NewRequest(http.MethodPost, "/auth/login", nil)
157
157
+
req.RemoteAddr = "1.1.1.1:1234"
158
158
+
rec := httptest.NewRecorder()
159
159
+
wrapped.ServeHTTP(rec, req)
160
160
+
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
161
161
+
assert.Equal(t, "60", rec.Header().Get("Retry-After"))
162
162
+
})
163
163
+
164
164
+
t.Run("api endpoints use api limiter", func(t *testing.T) {
165
165
+
for i := 0; i < 3; i++ {
166
166
+
req := httptest.NewRequest(http.MethodGet, "/api/brews", nil)
167
167
+
req.RemoteAddr = "2.2.2.2:1234"
168
168
+
rec := httptest.NewRecorder()
169
169
+
wrapped.ServeHTTP(rec, req)
170
170
+
assert.Equal(t, http.StatusOK, rec.Code)
171
171
+
}
172
172
+
173
173
+
req := httptest.NewRequest(http.MethodGet, "/api/brews", nil)
174
174
+
req.RemoteAddr = "2.2.2.2:1234"
175
175
+
rec := httptest.NewRecorder()
176
176
+
wrapped.ServeHTTP(rec, req)
177
177
+
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
178
178
+
})
179
179
+
180
180
+
t.Run("other endpoints use global limiter", func(t *testing.T) {
181
181
+
for i := 0; i < 5; i++ {
182
182
+
req := httptest.NewRequest(http.MethodGet, "/brews", nil)
183
183
+
req.RemoteAddr = "3.3.3.3:1234"
184
184
+
rec := httptest.NewRecorder()
185
185
+
wrapped.ServeHTTP(rec, req)
186
186
+
assert.Equal(t, http.StatusOK, rec.Code)
187
187
+
}
188
188
+
189
189
+
req := httptest.NewRequest(http.MethodGet, "/brews", nil)
190
190
+
req.RemoteAddr = "3.3.3.3:1234"
191
191
+
rec := httptest.NewRecorder()
192
192
+
wrapped.ServeHTTP(rec, req)
193
193
+
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
194
194
+
})
195
195
+
196
196
+
t.Run("login path uses auth limiter", func(t *testing.T) {
197
197
+
for i := 0; i < 2; i++ {
198
198
+
req := httptest.NewRequest(http.MethodPost, "/login", nil)
199
199
+
req.RemoteAddr = "4.4.4.4:1234"
200
200
+
rec := httptest.NewRecorder()
201
201
+
wrapped.ServeHTTP(rec, req)
202
202
+
assert.Equal(t, http.StatusOK, rec.Code)
203
203
+
}
204
204
+
205
205
+
req := httptest.NewRequest(http.MethodPost, "/login", nil)
206
206
+
req.RemoteAddr = "4.4.4.4:1234"
207
207
+
rec := httptest.NewRecorder()
208
208
+
wrapped.ServeHTTP(rec, req)
209
209
+
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
210
210
+
})
211
211
+
}
212
212
+
213
213
+
func TestRequireHTMXMiddleware(t *testing.T) {
214
214
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215
215
+
w.WriteHeader(http.StatusOK)
216
216
+
w.Write([]byte("OK"))
217
217
+
})
218
218
+
219
219
+
wrapped := RequireHTMXMiddleware(handler)
220
220
+
221
221
+
t.Run("allows HTMX requests", func(t *testing.T) {
222
222
+
req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
223
223
+
req.Header.Set("HX-Request", "true")
224
224
+
rec := httptest.NewRecorder()
225
225
+
226
226
+
wrapped.ServeHTTP(rec, req)
227
227
+
assert.Equal(t, http.StatusOK, rec.Code)
228
228
+
assert.Equal(t, "OK", rec.Body.String())
229
229
+
})
230
230
+
231
231
+
t.Run("blocks non-HTMX requests", func(t *testing.T) {
232
232
+
req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
233
233
+
rec := httptest.NewRecorder()
234
234
+
235
235
+
wrapped.ServeHTTP(rec, req)
236
236
+
assert.Equal(t, http.StatusNotFound, rec.Code)
237
237
+
})
238
238
+
239
239
+
t.Run("blocks wrong HX-Request value", func(t *testing.T) {
240
240
+
req := httptest.NewRequest(http.MethodGet, "/api/partial", nil)
241
241
+
req.Header.Set("HX-Request", "false")
242
242
+
rec := httptest.NewRecorder()
243
243
+
244
244
+
wrapped.ServeHTTP(rec, req)
245
245
+
assert.Equal(t, http.StatusNotFound, rec.Code)
246
246
+
})
247
247
+
}
248
248
+
249
249
+
func TestLimitBodyMiddleware(t *testing.T) {
250
250
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251
251
+
// Try to read the body
252
252
+
buf := make([]byte, 2<<20) // 2MB buffer
253
253
+
_, err := r.Body.Read(buf)
254
254
+
if err != nil && err.Error() != "EOF" {
255
255
+
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
256
256
+
return
257
257
+
}
258
258
+
w.WriteHeader(http.StatusOK)
259
259
+
})
260
260
+
261
261
+
wrapped := LimitBodyMiddleware(handler)
262
262
+
263
263
+
t.Run("allows small JSON body", func(t *testing.T) {
264
264
+
body := strings.NewReader(`{"name": "test"}`)
265
265
+
req := httptest.NewRequest(http.MethodPost, "/api/test", body)
266
266
+
req.Header.Set("Content-Type", "application/json")
267
267
+
rec := httptest.NewRecorder()
268
268
+
269
269
+
wrapped.ServeHTTP(rec, req)
270
270
+
assert.Equal(t, http.StatusOK, rec.Code)
271
271
+
})
272
272
+
273
273
+
t.Run("allows small form body", func(t *testing.T) {
274
274
+
body := strings.NewReader("name=test&value=123")
275
275
+
req := httptest.NewRequest(http.MethodPost, "/api/test", body)
276
276
+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
277
277
+
rec := httptest.NewRecorder()
278
278
+
279
279
+
wrapped.ServeHTTP(rec, req)
280
280
+
assert.Equal(t, http.StatusOK, rec.Code)
281
281
+
})
282
282
+
283
283
+
t.Run("handles nil body", func(t *testing.T) {
284
284
+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
285
285
+
rec := httptest.NewRecorder()
286
286
+
287
287
+
wrapped.ServeHTTP(rec, req)
288
288
+
assert.Equal(t, http.StatusOK, rec.Code)
289
289
+
})
290
290
+
}
291
291
+
292
292
+
func TestGetClientIP(t *testing.T) {
293
293
+
tests := []struct {
294
294
+
name string
295
295
+
xff string
296
296
+
xri string
297
297
+
remoteAddr string
298
298
+
expected string
299
299
+
}{
300
300
+
{
301
301
+
name: "X-Forwarded-For single IP",
302
302
+
xff: "203.0.113.50",
303
303
+
remoteAddr: "127.0.0.1:1234",
304
304
+
expected: "203.0.113.50",
305
305
+
},
306
306
+
{
307
307
+
name: "X-Forwarded-For multiple IPs",
308
308
+
xff: "203.0.113.50, 70.41.3.18, 150.172.238.178",
309
309
+
remoteAddr: "127.0.0.1:1234",
310
310
+
expected: "203.0.113.50",
311
311
+
},
312
312
+
{
313
313
+
name: "X-Forwarded-For with whitespace",
314
314
+
xff: " 203.0.113.50 ",
315
315
+
remoteAddr: "127.0.0.1:1234",
316
316
+
expected: "203.0.113.50",
317
317
+
},
318
318
+
{
319
319
+
name: "X-Real-IP",
320
320
+
xri: "198.51.100.178",
321
321
+
remoteAddr: "127.0.0.1:1234",
322
322
+
expected: "198.51.100.178",
323
323
+
},
324
324
+
{
325
325
+
name: "X-Real-IP with whitespace",
326
326
+
xri: " 198.51.100.178 ",
327
327
+
remoteAddr: "127.0.0.1:1234",
328
328
+
expected: "198.51.100.178",
329
329
+
},
330
330
+
{
331
331
+
name: "X-Forwarded-For takes precedence over X-Real-IP",
332
332
+
xff: "203.0.113.50",
333
333
+
xri: "198.51.100.178",
334
334
+
remoteAddr: "127.0.0.1:1234",
335
335
+
expected: "203.0.113.50",
336
336
+
},
337
337
+
{
338
338
+
name: "fallback to RemoteAddr with port",
339
339
+
remoteAddr: "192.168.1.1:8080",
340
340
+
expected: "192.168.1.1",
341
341
+
},
342
342
+
{
343
343
+
name: "fallback to RemoteAddr without port",
344
344
+
remoteAddr: "192.168.1.1",
345
345
+
expected: "192.168.1.1",
346
346
+
},
347
347
+
}
348
348
+
349
349
+
for _, tt := range tests {
350
350
+
t.Run(tt.name, func(t *testing.T) {
351
351
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
352
352
+
req.RemoteAddr = tt.remoteAddr
353
353
+
if tt.xff != "" {
354
354
+
req.Header.Set("X-Forwarded-For", tt.xff)
355
355
+
}
356
356
+
if tt.xri != "" {
357
357
+
req.Header.Set("X-Real-IP", tt.xri)
358
358
+
}
359
359
+
360
360
+
got := GetClientIP(req)
361
361
+
assert.Equal(t, tt.expected, got)
362
362
+
})
363
363
+
}
364
364
+
}
365
365
+
366
366
+
func TestGenerateNonce(t *testing.T) {
367
367
+
t.Run("generates base64 string", func(t *testing.T) {
368
368
+
nonce, err := generateNonce()
369
369
+
require.NoError(t, err)
370
370
+
assert.NotEmpty(t, nonce)
371
371
+
// Base64 of 16 bytes = 24 chars
372
372
+
assert.Len(t, nonce, 24)
373
373
+
})
374
374
+
375
375
+
t.Run("generates unique values", func(t *testing.T) {
376
376
+
n1, err := generateNonce()
377
377
+
require.NoError(t, err)
378
378
+
n2, err := generateNonce()
379
379
+
require.NoError(t, err)
380
380
+
assert.NotEqual(t, n1, n2)
381
381
+
})
382
382
+
}
+352
internal/models/models_test.go
···
1
1
+
package models
2
2
+
3
3
+
import (
4
4
+
"strings"
5
5
+
"testing"
6
6
+
7
7
+
"github.com/stretchr/testify/assert"
8
8
+
)
9
9
+
10
10
+
func TestCreateBeanRequest_Validate(t *testing.T) {
11
11
+
t.Run("valid request", func(t *testing.T) {
12
12
+
req := &CreateBeanRequest{Name: "Ethiopian Yirgacheffe"}
13
13
+
assert.NoError(t, req.Validate())
14
14
+
})
15
15
+
16
16
+
t.Run("empty name", func(t *testing.T) {
17
17
+
req := &CreateBeanRequest{Name: ""}
18
18
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
19
19
+
})
20
20
+
21
21
+
t.Run("name too long", func(t *testing.T) {
22
22
+
req := &CreateBeanRequest{Name: strings.Repeat("a", MaxNameLength+1)}
23
23
+
assert.ErrorIs(t, req.Validate(), ErrNameTooLong)
24
24
+
})
25
25
+
26
26
+
t.Run("name at max length", func(t *testing.T) {
27
27
+
req := &CreateBeanRequest{Name: strings.Repeat("a", MaxNameLength)}
28
28
+
assert.NoError(t, req.Validate())
29
29
+
})
30
30
+
31
31
+
t.Run("origin too long", func(t *testing.T) {
32
32
+
req := &CreateBeanRequest{
33
33
+
Name: "Bean",
34
34
+
Origin: strings.Repeat("a", MaxOriginLength+1),
35
35
+
}
36
36
+
assert.ErrorIs(t, req.Validate(), ErrOriginTooLong)
37
37
+
})
38
38
+
39
39
+
t.Run("roast level too long", func(t *testing.T) {
40
40
+
req := &CreateBeanRequest{
41
41
+
Name: "Bean",
42
42
+
RoastLevel: strings.Repeat("a", MaxRoastLevelLength+1),
43
43
+
}
44
44
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
45
45
+
})
46
46
+
47
47
+
t.Run("process too long", func(t *testing.T) {
48
48
+
req := &CreateBeanRequest{
49
49
+
Name: "Bean",
50
50
+
Process: strings.Repeat("a", MaxProcessLength+1),
51
51
+
}
52
52
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
53
53
+
})
54
54
+
55
55
+
t.Run("description too long", func(t *testing.T) {
56
56
+
req := &CreateBeanRequest{
57
57
+
Name: "Bean",
58
58
+
Description: strings.Repeat("a", MaxDescriptionLength+1),
59
59
+
}
60
60
+
assert.ErrorIs(t, req.Validate(), ErrDescTooLong)
61
61
+
})
62
62
+
63
63
+
t.Run("all optional fields populated", func(t *testing.T) {
64
64
+
req := &CreateBeanRequest{
65
65
+
Name: "Ethiopian Yirgacheffe",
66
66
+
Origin: "Ethiopia",
67
67
+
RoastLevel: "Light",
68
68
+
Process: "Washed",
69
69
+
Description: "Fruity and floral notes",
70
70
+
RoasterRKey: "abc123",
71
71
+
}
72
72
+
assert.NoError(t, req.Validate())
73
73
+
})
74
74
+
}
75
75
+
76
76
+
func TestUpdateBeanRequest_Validate(t *testing.T) {
77
77
+
t.Run("valid request", func(t *testing.T) {
78
78
+
req := &UpdateBeanRequest{Name: "Updated Bean"}
79
79
+
assert.NoError(t, req.Validate())
80
80
+
})
81
81
+
82
82
+
t.Run("empty name", func(t *testing.T) {
83
83
+
req := &UpdateBeanRequest{Name: ""}
84
84
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
85
85
+
})
86
86
+
87
87
+
t.Run("name too long", func(t *testing.T) {
88
88
+
req := &UpdateBeanRequest{Name: strings.Repeat("a", MaxNameLength+1)}
89
89
+
assert.ErrorIs(t, req.Validate(), ErrNameTooLong)
90
90
+
})
91
91
+
92
92
+
t.Run("origin too long", func(t *testing.T) {
93
93
+
req := &UpdateBeanRequest{
94
94
+
Name: "Bean",
95
95
+
Origin: strings.Repeat("a", MaxOriginLength+1),
96
96
+
}
97
97
+
assert.ErrorIs(t, req.Validate(), ErrOriginTooLong)
98
98
+
})
99
99
+
100
100
+
t.Run("description too long", func(t *testing.T) {
101
101
+
req := &UpdateBeanRequest{
102
102
+
Name: "Bean",
103
103
+
Description: strings.Repeat("a", MaxDescriptionLength+1),
104
104
+
}
105
105
+
assert.ErrorIs(t, req.Validate(), ErrDescTooLong)
106
106
+
})
107
107
+
}
108
108
+
109
109
+
func TestCreateRoasterRequest_Validate(t *testing.T) {
110
110
+
t.Run("valid request", func(t *testing.T) {
111
111
+
req := &CreateRoasterRequest{Name: "Blue Bottle"}
112
112
+
assert.NoError(t, req.Validate())
113
113
+
})
114
114
+
115
115
+
t.Run("empty name", func(t *testing.T) {
116
116
+
req := &CreateRoasterRequest{Name: ""}
117
117
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
118
118
+
})
119
119
+
120
120
+
t.Run("name too long", func(t *testing.T) {
121
121
+
req := &CreateRoasterRequest{Name: strings.Repeat("a", MaxNameLength+1)}
122
122
+
assert.ErrorIs(t, req.Validate(), ErrNameTooLong)
123
123
+
})
124
124
+
125
125
+
t.Run("location too long", func(t *testing.T) {
126
126
+
req := &CreateRoasterRequest{
127
127
+
Name: "Roaster",
128
128
+
Location: strings.Repeat("a", MaxLocationLength+1),
129
129
+
}
130
130
+
assert.ErrorIs(t, req.Validate(), ErrLocationTooLong)
131
131
+
})
132
132
+
133
133
+
t.Run("website too long", func(t *testing.T) {
134
134
+
req := &CreateRoasterRequest{
135
135
+
Name: "Roaster",
136
136
+
Website: strings.Repeat("a", MaxWebsiteLength+1),
137
137
+
}
138
138
+
assert.ErrorIs(t, req.Validate(), ErrWebsiteTooLong)
139
139
+
})
140
140
+
141
141
+
t.Run("all fields at max", func(t *testing.T) {
142
142
+
req := &CreateRoasterRequest{
143
143
+
Name: strings.Repeat("a", MaxNameLength),
144
144
+
Location: strings.Repeat("a", MaxLocationLength),
145
145
+
Website: strings.Repeat("a", MaxWebsiteLength),
146
146
+
}
147
147
+
assert.NoError(t, req.Validate())
148
148
+
})
149
149
+
}
150
150
+
151
151
+
func TestUpdateRoasterRequest_Validate(t *testing.T) {
152
152
+
t.Run("valid request", func(t *testing.T) {
153
153
+
req := &UpdateRoasterRequest{Name: "Updated Roaster"}
154
154
+
assert.NoError(t, req.Validate())
155
155
+
})
156
156
+
157
157
+
t.Run("empty name", func(t *testing.T) {
158
158
+
req := &UpdateRoasterRequest{Name: ""}
159
159
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
160
160
+
})
161
161
+
162
162
+
t.Run("location too long", func(t *testing.T) {
163
163
+
req := &UpdateRoasterRequest{
164
164
+
Name: "Roaster",
165
165
+
Location: strings.Repeat("a", MaxLocationLength+1),
166
166
+
}
167
167
+
assert.ErrorIs(t, req.Validate(), ErrLocationTooLong)
168
168
+
})
169
169
+
170
170
+
t.Run("website too long", func(t *testing.T) {
171
171
+
req := &UpdateRoasterRequest{
172
172
+
Name: "Roaster",
173
173
+
Website: strings.Repeat("a", MaxWebsiteLength+1),
174
174
+
}
175
175
+
assert.ErrorIs(t, req.Validate(), ErrWebsiteTooLong)
176
176
+
})
177
177
+
}
178
178
+
179
179
+
func TestCreateGrinderRequest_Validate(t *testing.T) {
180
180
+
t.Run("valid request", func(t *testing.T) {
181
181
+
req := &CreateGrinderRequest{Name: "Comandante C40"}
182
182
+
assert.NoError(t, req.Validate())
183
183
+
})
184
184
+
185
185
+
t.Run("empty name", func(t *testing.T) {
186
186
+
req := &CreateGrinderRequest{Name: ""}
187
187
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
188
188
+
})
189
189
+
190
190
+
t.Run("name too long", func(t *testing.T) {
191
191
+
req := &CreateGrinderRequest{Name: strings.Repeat("a", MaxNameLength+1)}
192
192
+
assert.ErrorIs(t, req.Validate(), ErrNameTooLong)
193
193
+
})
194
194
+
195
195
+
t.Run("grinder type too long", func(t *testing.T) {
196
196
+
req := &CreateGrinderRequest{
197
197
+
Name: "Grinder",
198
198
+
GrinderType: strings.Repeat("a", MaxGrinderTypeLength+1),
199
199
+
}
200
200
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
201
201
+
})
202
202
+
203
203
+
t.Run("burr type too long", func(t *testing.T) {
204
204
+
req := &CreateGrinderRequest{
205
205
+
Name: "Grinder",
206
206
+
BurrType: strings.Repeat("a", MaxBurrTypeLength+1),
207
207
+
}
208
208
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
209
209
+
})
210
210
+
211
211
+
t.Run("notes too long", func(t *testing.T) {
212
212
+
req := &CreateGrinderRequest{
213
213
+
Name: "Grinder",
214
214
+
Notes: strings.Repeat("a", MaxNotesLength+1),
215
215
+
}
216
216
+
assert.ErrorIs(t, req.Validate(), ErrNotesTooLong)
217
217
+
})
218
218
+
}
219
219
+
220
220
+
func TestUpdateGrinderRequest_Validate(t *testing.T) {
221
221
+
t.Run("valid request", func(t *testing.T) {
222
222
+
req := &UpdateGrinderRequest{Name: "Updated Grinder"}
223
223
+
assert.NoError(t, req.Validate())
224
224
+
})
225
225
+
226
226
+
t.Run("empty name", func(t *testing.T) {
227
227
+
req := &UpdateGrinderRequest{Name: ""}
228
228
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
229
229
+
})
230
230
+
231
231
+
t.Run("grinder type too long", func(t *testing.T) {
232
232
+
req := &UpdateGrinderRequest{
233
233
+
Name: "Grinder",
234
234
+
GrinderType: strings.Repeat("a", MaxGrinderTypeLength+1),
235
235
+
}
236
236
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
237
237
+
})
238
238
+
239
239
+
t.Run("notes too long", func(t *testing.T) {
240
240
+
req := &UpdateGrinderRequest{
241
241
+
Name: "Grinder",
242
242
+
Notes: strings.Repeat("a", MaxNotesLength+1),
243
243
+
}
244
244
+
assert.ErrorIs(t, req.Validate(), ErrNotesTooLong)
245
245
+
})
246
246
+
}
247
247
+
248
248
+
func TestCreateBrewerRequest_Validate(t *testing.T) {
249
249
+
t.Run("valid request", func(t *testing.T) {
250
250
+
req := &CreateBrewerRequest{Name: "V60"}
251
251
+
assert.NoError(t, req.Validate())
252
252
+
})
253
253
+
254
254
+
t.Run("empty name", func(t *testing.T) {
255
255
+
req := &CreateBrewerRequest{Name: ""}
256
256
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
257
257
+
})
258
258
+
259
259
+
t.Run("name too long", func(t *testing.T) {
260
260
+
req := &CreateBrewerRequest{Name: strings.Repeat("a", MaxNameLength+1)}
261
261
+
assert.ErrorIs(t, req.Validate(), ErrNameTooLong)
262
262
+
})
263
263
+
264
264
+
t.Run("brewer type too long", func(t *testing.T) {
265
265
+
req := &CreateBrewerRequest{
266
266
+
Name: "Brewer",
267
267
+
BrewerType: strings.Repeat("a", MaxBrewerTypeLength+1),
268
268
+
}
269
269
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
270
270
+
})
271
271
+
272
272
+
t.Run("description too long", func(t *testing.T) {
273
273
+
req := &CreateBrewerRequest{
274
274
+
Name: "Brewer",
275
275
+
Description: strings.Repeat("a", MaxDescriptionLength+1),
276
276
+
}
277
277
+
assert.ErrorIs(t, req.Validate(), ErrDescTooLong)
278
278
+
})
279
279
+
}
280
280
+
281
281
+
func TestUpdateBrewerRequest_Validate(t *testing.T) {
282
282
+
t.Run("valid request", func(t *testing.T) {
283
283
+
req := &UpdateBrewerRequest{Name: "Updated V60"}
284
284
+
assert.NoError(t, req.Validate())
285
285
+
})
286
286
+
287
287
+
t.Run("empty name", func(t *testing.T) {
288
288
+
req := &UpdateBrewerRequest{Name: ""}
289
289
+
assert.ErrorIs(t, req.Validate(), ErrNameRequired)
290
290
+
})
291
291
+
292
292
+
t.Run("brewer type too long", func(t *testing.T) {
293
293
+
req := &UpdateBrewerRequest{
294
294
+
Name: "Brewer",
295
295
+
BrewerType: strings.Repeat("a", MaxBrewerTypeLength+1),
296
296
+
}
297
297
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
298
298
+
})
299
299
+
300
300
+
t.Run("description too long", func(t *testing.T) {
301
301
+
req := &UpdateBrewerRequest{
302
302
+
Name: "Brewer",
303
303
+
Description: strings.Repeat("a", MaxDescriptionLength+1),
304
304
+
}
305
305
+
assert.ErrorIs(t, req.Validate(), ErrDescTooLong)
306
306
+
})
307
307
+
}
308
308
+
309
309
+
func TestCreateBrewRequest_Validate(t *testing.T) {
310
310
+
t.Run("valid minimal request", func(t *testing.T) {
311
311
+
req := &CreateBrewRequest{}
312
312
+
assert.NoError(t, req.Validate())
313
313
+
})
314
314
+
315
315
+
t.Run("valid full request", func(t *testing.T) {
316
316
+
req := &CreateBrewRequest{
317
317
+
BeanRKey: "abc123",
318
318
+
Method: "Pour Over",
319
319
+
Temperature: 93.5,
320
320
+
WaterAmount: 250,
321
321
+
CoffeeAmount: 15,
322
322
+
TimeSeconds: 210,
323
323
+
GrindSize: "Medium-Fine",
324
324
+
GrinderRKey: "grinder1",
325
325
+
BrewerRKey: "brewer1",
326
326
+
TastingNotes: "Fruity, bright acidity",
327
327
+
Rating: 8,
328
328
+
}
329
329
+
assert.NoError(t, req.Validate())
330
330
+
})
331
331
+
332
332
+
t.Run("method too long", func(t *testing.T) {
333
333
+
req := &CreateBrewRequest{
334
334
+
Method: strings.Repeat("a", MaxMethodLength+1),
335
335
+
}
336
336
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
337
337
+
})
338
338
+
339
339
+
t.Run("grind size too long", func(t *testing.T) {
340
340
+
req := &CreateBrewRequest{
341
341
+
GrindSize: strings.Repeat("a", MaxGrindSizeLength+1),
342
342
+
}
343
343
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
344
344
+
})
345
345
+
346
346
+
t.Run("tasting notes too long", func(t *testing.T) {
347
347
+
req := &CreateBrewRequest{
348
348
+
TastingNotes: strings.Repeat("a", MaxTastingNotesLength+1),
349
349
+
}
350
350
+
assert.ErrorIs(t, req.Validate(), ErrFieldTooLong)
351
351
+
})
352
352
+
}
+118
internal/web/bff/helpers_test.go
···
2
2
3
3
import (
4
4
"testing"
5
5
+
"time"
5
6
6
7
"arabica/internal/models"
7
8
"github.com/stretchr/testify/assert"
···
122
123
})
123
124
}
124
125
}
126
126
+
127
127
+
func TestHasTemp(t *testing.T) {
128
128
+
assert.False(t, HasTemp(0))
129
129
+
assert.False(t, HasTemp(-1))
130
130
+
assert.True(t, HasTemp(0.1))
131
131
+
assert.True(t, HasTemp(93.5))
132
132
+
}
133
133
+
134
134
+
func TestHasValue(t *testing.T) {
135
135
+
assert.False(t, HasValue(0))
136
136
+
assert.False(t, HasValue(-1))
137
137
+
assert.True(t, HasValue(1))
138
138
+
assert.True(t, HasValue(250))
139
139
+
}
140
140
+
141
141
+
func TestSafeAvatarURL(t *testing.T) {
142
142
+
tests := []struct {
143
143
+
name string
144
144
+
input string
145
145
+
expected string
146
146
+
}{
147
147
+
{"empty string", "", ""},
148
148
+
{"trusted bsky CDN", "https://cdn.bsky.app/img/avatar/did:plc:abc/cid@jpeg", "https://cdn.bsky.app/img/avatar/did:plc:abc/cid@jpeg"},
149
149
+
{"trusted av-cdn", "https://av-cdn.bsky.app/img/avatar/abc", "https://av-cdn.bsky.app/img/avatar/abc"},
150
150
+
{"static path", "/static/icon-placeholder.svg", "/static/icon-placeholder.svg"},
151
151
+
{"non-static relative path", "/evil/path", ""},
152
152
+
{"http scheme rejected", "http://cdn.bsky.app/img/avatar/abc", ""},
153
153
+
{"untrusted domain", "https://evil.com/avatar.jpg", ""},
154
154
+
{"javascript scheme", "javascript:alert(1)", ""},
155
155
+
{"data URI rejected", "data:image/svg+xml,<svg></svg>", ""},
156
156
+
{"invalid URL", "://invalid", ""},
157
157
+
{"subdomain of trusted", "https://sub.cdn.bsky.app/avatar.jpg", "https://sub.cdn.bsky.app/avatar.jpg"},
158
158
+
}
159
159
+
160
160
+
for _, tt := range tests {
161
161
+
t.Run(tt.name, func(t *testing.T) {
162
162
+
assert.Equal(t, tt.expected, SafeAvatarURL(tt.input))
163
163
+
})
164
164
+
}
165
165
+
}
166
166
+
167
167
+
func TestSafeWebsiteURL(t *testing.T) {
168
168
+
tests := []struct {
169
169
+
name string
170
170
+
input string
171
171
+
expected string
172
172
+
}{
173
173
+
{"empty string", "", ""},
174
174
+
{"valid https", "https://example.com", "https://example.com"},
175
175
+
{"valid http", "http://example.com", "http://example.com"},
176
176
+
{"javascript scheme", "javascript:alert(1)", ""},
177
177
+
{"ftp scheme", "ftp://files.example.com", ""},
178
178
+
{"no dot in host", "https://localhost", ""},
179
179
+
{"invalid URL", "://invalid", ""},
180
180
+
{"https with path", "https://roaster.coffee/about", "https://roaster.coffee/about"},
181
181
+
}
182
182
+
183
183
+
for _, tt := range tests {
184
184
+
t.Run(tt.name, func(t *testing.T) {
185
185
+
assert.Equal(t, tt.expected, SafeWebsiteURL(tt.input))
186
186
+
})
187
187
+
}
188
188
+
}
189
189
+
190
190
+
func TestEscapeJS(t *testing.T) {
191
191
+
tests := []struct {
192
192
+
name string
193
193
+
input string
194
194
+
expected string
195
195
+
}{
196
196
+
{"empty string", "", ""},
197
197
+
{"no special chars", "hello world", "hello world"},
198
198
+
{"single quotes", "it's a test", "it\\'s a test"},
199
199
+
{"double quotes", `say "hello"`, `say \"hello\"`},
200
200
+
{"newlines", "line1\nline2", "line1\\nline2"},
201
201
+
{"carriage return", "line1\rline2", "line1\\rline2"},
202
202
+
{"tabs", "col1\tcol2", "col1\\tcol2"},
203
203
+
{"backslash", `path\to\file`, `path\\to\\file`},
204
204
+
{"mixed", "it's a \"test\"\nwith\\stuff", "it\\'s a \\\"test\\\"\\nwith\\\\stuff"},
205
205
+
}
206
206
+
207
207
+
for _, tt := range tests {
208
208
+
t.Run(tt.name, func(t *testing.T) {
209
209
+
assert.Equal(t, tt.expected, EscapeJS(tt.input))
210
210
+
})
211
211
+
}
212
212
+
}
213
213
+
214
214
+
func TestFormatTimeAgo(t *testing.T) {
215
215
+
now := time.Now()
216
216
+
217
217
+
tests := []struct {
218
218
+
name string
219
219
+
input time.Time
220
220
+
expected string
221
221
+
}{
222
222
+
{"just now", now.Add(-30 * time.Second), "just now"},
223
223
+
{"1 minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
224
224
+
{"5 minutes ago", now.Add(-5 * time.Minute), "5 minutes ago"},
225
225
+
{"1 hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
226
226
+
{"3 hours ago", now.Add(-3 * time.Hour), "3 hours ago"},
227
227
+
{"yesterday", now.Add(-36 * time.Hour), "yesterday"},
228
228
+
{"3 days ago", now.Add(-3 * 24 * time.Hour), "3 days ago"},
229
229
+
{"1 week ago", now.Add(-8 * 24 * time.Hour), "1 week ago"},
230
230
+
{"3 weeks ago", now.Add(-22 * 24 * time.Hour), "3 weeks ago"},
231
231
+
{"1 month ago", now.Add(-35 * 24 * time.Hour), "1 month ago"},
232
232
+
{"6 months ago", now.Add(-180 * 24 * time.Hour), "6 months ago"},
233
233
+
{"1 year ago", now.Add(-400 * 24 * time.Hour), "1 year ago"},
234
234
+
{"2 years ago", now.Add(-800 * 24 * time.Hour), "2 years ago"},
235
235
+
}
236
236
+
237
237
+
for _, tt := range tests {
238
238
+
t.Run(tt.name, func(t *testing.T) {
239
239
+
assert.Equal(t, tt.expected, FormatTimeAgo(tt.input))
240
240
+
})
241
241
+
}
242
242
+
}