Coffee journaling on ATProto (alpha)
alpha.arabica.social
coffee
1package middleware
2
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
12
// cspNonceKeyType is an unexported, zero-size type used solely as a context
// key, so the CSP nonce entry cannot collide with keys set by other packages.
type cspNonceKeyType struct{}

// cspNonceKey is the context key under which the per-request CSP nonce is
// stored by SecurityHeadersMiddleware and looked up by CSPNonceFromContext.
var cspNonceKey cspNonceKeyType
16
// generateNonce reads 16 bytes from crypto/rand and returns them encoded with
// standard (padded) base64. On a read failure it returns the empty string and
// the underlying error.
func generateNonce() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(raw[:])
	return encoded, nil
}
24
25func CSPNonceFromContext(ctx context.Context) string {
26 if v, ok := ctx.Value(cspNonceKey).(string); ok {
27 return v
28 }
29 return ""
30}
31
32// SecurityHeadersMiddleware adds security headers to all responses
33func SecurityHeadersMiddleware(next http.Handler) http.Handler {
34 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
35 nonce, err := generateNonce()
36 if err != nil {
37 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
38 return
39 }
40
41 r = r.WithContext(context.WithValue(r.Context(), cspNonceKey, nonce))
42
43 // Prevent clickjacking
44 w.Header().Set("X-Frame-Options", "DENY")
45
46 // Prevent MIME type sniffing
47 w.Header().Set("X-Content-Type-Options", "nosniff")
48
49 // XSS protection (legacy but still useful for older browsers)
50 w.Header().Set("X-XSS-Protection", "1; mode=block")
51
52 // Control referrer information
53 w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
54
55 // Permissions policy - disable unnecessary features
56 w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
57
58 // Content Security Policy
59 // Allows: self for scripts/styles, inline styles (for Tailwind), inline HTMX/Alpine
60 // Note: unsafe-eval required for Alpine.js standard build (CSP build has CDN MIME type issues)
61 // Note: form-action allows https: for OAuth redirects to external authorization servers
62 // TODO: set nonce/hash on unsafe tags -- needs to be set in elements as well
63 csp := strings.Join([]string{
64 "default-src 'self'",
65 "script-src 'self' 'unsafe-eval' 'nonce-" + nonce + "'",
66 "style-src 'self' 'unsafe-inline'", // unsafe-inline needed for Tailwind
67 "img-src 'self' https: data:", // Allow external images (avatars) and data URIs
68 "font-src 'self'",
69 "connect-src 'self' https:", // Allow connections to external APIs (OAuth, PDS)
70 "frame-ancestors 'none'",
71 "base-uri 'self'",
72 "form-action 'self' https:", // Allow form submissions to external OAuth servers
73 }, "; ")
74 w.Header().Set("Content-Security-Policy", csp)
75
76 next.ServeHTTP(w, r)
77 })
78}
79
// RateLimiter is a simple per-IP rate limiter using a fixed-window counter:
// each visitor may make `rate` requests per `window`, and the allowance is
// refilled in full the first time the visitor is seen after the window has
// elapsed.
//
// Create one with NewRateLimiter. Because it embeds a sync.Mutex it must not
// be copied after first use. Call Stop to end the background cleanup goroutine
// once the limiter is no longer needed.
type RateLimiter struct {
	mu       sync.Mutex
	visitors map[string]*visitor
	rate     int           // requests per window
	window   time.Duration // time window
	cleanup  time.Duration // cleanup interval for old entries

	done     chan struct{} // closed by Stop to terminate cleanupLoop
	stopOnce sync.Once     // makes Stop safe to call more than once
}

// visitor holds the remaining allowance of one client for its current window.
type visitor struct {
	tokens    int       // requests still allowed in the current window
	lastReset time.Time // when the current window started
}

// NewRateLimiter creates a new rate limiter and starts its cleanup goroutine.
//
//   - rate: number of requests allowed per window; a value <= 0 denies every
//     request.
//   - window: time window for rate limiting.
//
// Stale visitors are purged every 2*window. If that interval is not positive
// (window <= 0, or the multiplication overflowed) a one-minute interval is used
// instead, because time.NewTicker panics on a non-positive duration and that
// panic would take down the whole process from the background goroutine.
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
	cleanup := window * 2
	if cleanup <= 0 {
		cleanup = time.Minute
	}

	rl := &RateLimiter{
		visitors: make(map[string]*visitor),
		rate:     rate,
		window:   window,
		cleanup:  cleanup,
		done:     make(chan struct{}),
	}

	// Start cleanup goroutine; it runs until Stop is called.
	go rl.cleanupLoop()

	return rl
}

// Stop terminates the background cleanup goroutine. It is safe to call Stop
// multiple times and from multiple goroutines. Allow keeps working after Stop,
// but stale visitors are no longer purged.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		// done is nil only if the limiter was not built by NewRateLimiter.
		if rl.done != nil {
			close(rl.done)
		}
	})
}

// cleanupLoop periodically removes visitors whose window started more than one
// cleanup interval ago. It returns when Stop is called.
func (rl *RateLimiter) cleanupLoop() {
	ticker := time.NewTicker(rl.cleanup)
	defer ticker.Stop()

	for {
		select {
		case <-rl.done:
			return
		case <-ticker.C:
			rl.mu.Lock()
			now := time.Now()
			for ip, v := range rl.visitors {
				if now.Sub(v.lastReset) > rl.cleanup {
					delete(rl.visitors, ip)
				}
			}
			rl.mu.Unlock()
		}
	}
}

// Allow reports whether a request from the given IP is allowed, consuming one
// unit of that IP's allowance when it is.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	v, exists := rl.visitors[ip]

	switch {
	case !exists:
		// First request from this IP: open a window with a full allowance.
		v = &visitor{tokens: rl.rate, lastReset: now}
		rl.visitors[ip] = v
	case now.Sub(v.lastReset) >= rl.window:
		// The previous window has passed: start a new one.
		v.tokens = rl.rate
		v.lastReset = now
	}

	// Single consumption path, so a non-positive rate never lets a request
	// through (previously the first request of every window was always
	// allowed, even with rate <= 0).
	if v.tokens <= 0 {
		return false
	}
	v.tokens--
	return true
}
158
159// RateLimitConfig holds configuration for rate limiting different endpoint types
160type RateLimitConfig struct {
161 // AuthLimiter for login/auth endpoints (stricter)
162 AuthLimiter *RateLimiter
163 // APILimiter for general API endpoints
164 APILimiter *RateLimiter
165 // GlobalLimiter for all other requests
166 GlobalLimiter *RateLimiter
167}
168
169// NewDefaultRateLimitConfig creates rate limiters with sensible defaults
170func NewDefaultRateLimitConfig() *RateLimitConfig {
171 return &RateLimitConfig{
172 AuthLimiter: NewRateLimiter(5, time.Minute), // 5 auth attempts per minute
173 APILimiter: NewRateLimiter(60, time.Minute), // 60 API calls per minute
174 GlobalLimiter: NewRateLimiter(120, time.Minute), // 120 requests per minute
175 }
176}
177
178// RateLimitMiddleware creates a rate limiting middleware
179func RateLimitMiddleware(config *RateLimitConfig) func(http.Handler) http.Handler {
180 return func(next http.Handler) http.Handler {
181 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
182 ip := GetClientIP(r)
183 path := r.URL.Path
184
185 var limiter *RateLimiter
186
187 // Select appropriate limiter based on path
188 switch {
189 case strings.HasPrefix(path, "/auth/") || path == "/login" || path == "/oauth/callback":
190 limiter = config.AuthLimiter
191 case strings.HasPrefix(path, "/api/"):
192 limiter = config.APILimiter
193 default:
194 limiter = config.GlobalLimiter
195 }
196
197 if !limiter.Allow(ip) {
198 w.Header().Set("Retry-After", "60")
199 http.Error(w, "Too many requests", http.StatusTooManyRequests)
200 return
201 }
202
203 next.ServeHTTP(w, r)
204 })
205 }
206}
207
208
// RequireHTMXMiddleware only lets a request through to next when its
// HX-Request header is exactly "true"; any other request is answered with
// 404 Not Found.
//
// It is meant for internal API routes that return fragments or JSON and should
// not be opened directly in a browser. Routes that must stay publicly reachable
// (like /api/resolve-handle) should not use this middleware.
func RequireHTMXMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if fromHTMX := r.Header.Get("HX-Request") == "true"; fromHTMX {
			next.ServeHTTP(w, r)
			return
		}
		// Respond as if the route did not exist.
		http.NotFound(w, r)
	})
}
222
// Upper bounds, in bytes, on request body sizes.
const (
	// MaxJSONBodySize is the limit for JSON requests (1 MB).
	MaxJSONBodySize = 1024 * 1024
	// MaxFormBodySize is the limit for form submissions (1 MB).
	MaxFormBodySize = 1024 * 1024
)
228
229// LimitBodyMiddleware limits request body size to prevent DoS
230func LimitBodyMiddleware(next http.Handler) http.Handler {
231 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
232 if r.Body != nil {
233 contentType := r.Header.Get("Content-Type")
234 var maxSize int64
235
236 switch {
237 case strings.HasPrefix(contentType, "application/json"):
238 maxSize = MaxJSONBodySize
239 case strings.HasPrefix(contentType, "application/x-www-form-urlencoded"),
240 strings.HasPrefix(contentType, "multipart/form-data"):
241 maxSize = MaxFormBodySize
242 default:
243 maxSize = MaxJSONBodySize // Default limit
244 }
245
246 r.Body = http.MaxBytesReader(w, r.Body, maxSize)
247 }
248
249 next.ServeHTTP(w, r)
250 })
251}