Coffee journaling on ATProto (alpha) alpha.arabica.social
coffee

feat: security improvements

pdewey.com a41637c9 8d528ce9

verified
+704 -46
+26 -1
internal/atproto/nsid.go
··· 1 1 package atproto 2 2 3 - import "fmt" 3 + import ( 4 + "fmt" 5 + "regexp" 6 + ) 4 7 5 8 // NSID (Namespaced Identifier) constants for Arabica lexicons. 6 9 // The domain is reversed following ATProto conventions: arabica.social -> social.arabica ··· 15 18 NSIDBrewer = NSIDBase + ".brewer" 16 19 NSIDGrinder = NSIDBase + ".grinder" 17 20 NSIDRoaster = NSIDBase + ".roaster" 21 + 22 + // MaxRKeyLength is the maximum allowed length for a record key 23 + MaxRKeyLength = 512 18 24 ) 25 + 26 + // rkeyRegex validates AT Protocol record keys (rkeys). 27 + // Valid rkeys contain only alphanumeric characters, hyphens, underscores, colons, and periods. 28 + // They must start with an alphanumeric character and be 1-512 characters long. 29 + // TIDs are the most common format: 13 lowercase base32 characters (e.g., "3kfk4slgu6s2h"). 30 + var rkeyRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._:-]{0,511}$`) 31 + 32 + // ValidateRKey checks if an rkey is valid according to AT Protocol spec. 33 + // Returns true if valid, false otherwise. 34 + func ValidateRKey(rkey string) bool { 35 + if rkey == "" || len(rkey) > MaxRKeyLength { 36 + return false 37 + } 38 + // Reserved rkeys that should not be used 39 + if rkey == "." || rkey == ".." { 40 + return false 41 + } 42 + return rkeyRegex.MatchString(rkey) 43 + } 19 44 20 45 // BuildATURI constructs an AT-URI from a DID, collection NSID, and record key 21 46 func BuildATURI(did, collection, rkey string) string {
+46
internal/atproto/nsid_test.go
··· 1 1 package atproto 2 2 3 3 import ( 4 + "strings" 4 5 "testing" 5 6 ) 6 7 ··· 27 28 t.Run(tt.name, func(t *testing.T) { 28 29 if tt.got != tt.expected { 29 30 t.Errorf("%s = %q, want %q", tt.name, tt.got, tt.expected) 31 + } 32 + }) 33 + } 34 + } 35 + 36 + func TestValidateRKey(t *testing.T) { 37 + tests := []struct { 38 + name string 39 + rkey string 40 + valid bool 41 + }{ 42 + // Valid rkeys 43 + {"TID format", "3kfk4slgu6s2h", true}, 44 + {"short alphanumeric", "abc123", true}, 45 + {"single char", "a", true}, 46 + {"with hyphen", "my-record", true}, 47 + {"with underscore", "my_record", true}, 48 + {"with period", "my.record", true}, 49 + {"with colon", "my:record", true}, 50 + {"mixed valid chars", "a1-b2_c3.d4:e5", true}, 51 + {"uppercase", "ABC123", true}, 52 + {"mixed case", "AbC123xYz", true}, 53 + 54 + // Invalid rkeys 55 + {"empty string", "", false}, 56 + {"starts with hyphen", "-abc", false}, 57 + {"starts with underscore", "_abc", false}, 58 + {"starts with period", ".abc", false}, 59 + {"starts with colon", ":abc", false}, 60 + {"reserved dot", ".", false}, 61 + {"reserved dotdot", "..", false}, 62 + {"contains slash", "abc/def", false}, 63 + {"contains space", "abc def", false}, 64 + {"contains at", "abc@def", false}, 65 + {"contains hash", "abc#def", false}, 66 + {"contains question", "abc?def", false}, 67 + {"too long", strings.Repeat("a", 513), false}, 68 + {"max length valid", strings.Repeat("a", 512), true}, 69 + } 70 + 71 + for _, tt := range tests { 72 + t.Run(tt.name, func(t *testing.T) { 73 + got := ValidateRKey(tt.rkey) 74 + if got != tt.valid { 75 + t.Errorf("ValidateRKey(%q) = %v, want %v", tt.rkey, got, tt.valid) 30 76 } 31 77 }) 32 78 }
+87
internal/atproto/public_client.go
··· 3 3 import ( 4 4 "context" 5 5 "encoding/json" 6 + "errors" 6 7 "fmt" 8 + "net" 7 9 "net/http" 8 10 "net/url" 9 11 "strings" ··· 18 20 PLCDirectoryURL = "https://plc.directory" 19 21 ) 20 22 23 + // ErrSSRFBlocked is returned when a potential SSRF attack is blocked 24 + var ErrSSRFBlocked = errors.New("request blocked: potential SSRF detected") 25 + 26 + // isPrivateIP checks if an IP address is in a private/internal range 27 + func isPrivateIP(ip net.IP) bool { 28 + if ip == nil { 29 + return false 30 + } 31 + 32 + // Check for loopback 33 + if ip.IsLoopback() { 34 + return true 35 + } 36 + 37 + // Check for link-local 38 + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { 39 + return true 40 + } 41 + 42 + // Check for private ranges 43 + if ip.IsPrivate() { 44 + return true 45 + } 46 + 47 + // Check for unspecified (0.0.0.0) 48 + if ip.IsUnspecified() { 49 + return true 50 + } 51 + 52 + // Check for cloud metadata endpoint (169.254.169.254) 53 + if ip.Equal(net.ParseIP("169.254.169.254")) { 54 + return true 55 + } 56 + 57 + return false 58 + } 59 + 60 + // validateDomain checks if a domain is safe to connect to (not internal/private) 61 + func validateDomain(domain string) error { 62 + // Block obviously dangerous patterns 63 + if domain == "localhost" || strings.HasSuffix(domain, ".local") { 64 + return ErrSSRFBlocked 65 + } 66 + 67 + // Check for IP addresses embedded in the domain 68 + if ip := net.ParseIP(domain); ip != nil { 69 + if isPrivateIP(ip) { 70 + return ErrSSRFBlocked 71 + } 72 + } 73 + 74 + // Resolve the domain and check all IPs 75 + ips, err := net.LookupIP(domain) 76 + if err != nil { 77 + // If we can't resolve it, let the HTTP request fail later 78 + return nil 79 + } 80 + 81 + for _, ip := range ips { 82 + if isPrivateIP(ip) { 83 + return ErrSSRFBlocked 84 + } 85 + } 86 + 87 + return nil 88 + } 89 + 21 90 // PublicClient provides unauthenticated access to public ATProto APIs 22 91 type PublicClient struct { 23 92 baseURL string ··· 89 158 } 90 159 } else if strings.HasPrefix(did, "did:web:") { 91 160 // Web DID - the domain is the PDS 161 + // Validate domain to prevent SSRF attacks 92 162 domain := strings.TrimPrefix(did, "did:web:") 163 + // Handle percent-encoded colons for ports (e.g., did:web:example.com%3A8080) 164 + domain = strings.ReplaceAll(domain, "%3A", ":") 165 + 166 + // Extract just the host part (without path) 167 + if idx := strings.Index(domain, "/"); idx != -1 { 168 + domain = domain[:idx] 169 + } 170 + 171 + // Validate the domain is safe 172 + host := domain 173 + if hostPart, _, err := net.SplitHostPort(domain); err == nil { 174 + host = hostPart 175 + } 176 + if err := validateDomain(host); err != nil { 177 + return "", err 178 + } 179 + 93 180 pdsEndpoint = "https://" + domain 94 181 } 95 182
+2 -1
internal/handlers/auth.go
··· 160 160 161 161 w.Header().Set("Content-Type", "application/json") 162 162 if err := json.NewEncoder(w).Encode(metadata); err != nil { 163 - http.Error(w, err.Error(), http.StatusInternalServerError) 163 + log.Error().Err(err).Msg("Failed to encode client metadata") 164 + http.Error(w, "Failed to encode response", http.StatusInternalServerError) 164 165 return 165 166 } 166 167 }
+142 -39
internal/handlers/handlers.go
··· 55 55 } 56 56 } 57 57 58 + // validateRKey validates and returns an rkey from a path parameter. 59 + // Returns the rkey if valid, or writes an error response and returns empty string if invalid. 60 + func validateRKey(w http.ResponseWriter, rkey string) string { 61 + if rkey == "" { 62 + http.Error(w, "Record key is required", http.StatusBadRequest) 63 + return "" 64 + } 65 + if !atproto.ValidateRKey(rkey) { 66 + http.Error(w, "Invalid record key format", http.StatusBadRequest) 67 + return "" 68 + } 69 + return rkey 70 + } 71 + 72 + // validateOptionalRKey validates an optional rkey from form data. 73 + // Returns an error message if invalid, empty string if valid or empty. 74 + func validateOptionalRKey(rkey, fieldName string) string { 75 + if rkey == "" { 76 + return "" 77 + } 78 + if !atproto.ValidateRKey(rkey) { 79 + return fieldName + " has invalid format" 80 + } 81 + return "" 82 + } 83 + 58 84 // getAtprotoStore creates a user-scoped atproto store from the request context. 59 85 // Returns the store and true if authenticated, or nil and false if not authenticated. 60 86 func (h *Handler) getAtprotoStore(r *http.Request) (database.Store, bool) { ··· 224 250 225 251 // Show edit brew form 226 252 func (h *Handler) HandleBrewEdit(w http.ResponseWriter, r *http.Request) { 227 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 253 + rkey := validateRKey(w, r.PathValue("id")) 254 + if rkey == "" { 255 + return 256 + } 228 257 229 258 // Require authentication 230 259 store, authenticated := h.getAtprotoStore(r) ··· 377 406 beanRKey := r.FormValue("bean_rkey") 378 407 if beanRKey == "" { 379 408 http.Error(w, "Bean selection is required", http.StatusBadRequest) 409 + return 410 + } 411 + if !atproto.ValidateRKey(beanRKey) { 412 + http.Error(w, "Invalid bean selection", http.StatusBadRequest) 413 + return 414 + } 415 + 416 + // Validate optional rkeys 417 + grinderRKey := r.FormValue("grinder_rkey") 418 + if errMsg := validateOptionalRKey(grinderRKey, "Grinder selection"); errMsg != "" { 419 + http.Error(w, errMsg, http.StatusBadRequest) 420 + return 421 + } 422 + brewerRKey := r.FormValue("brewer_rkey") 423 + if errMsg := validateOptionalRKey(brewerRKey, "Brewer selection"); errMsg != "" { 424 + http.Error(w, errMsg, http.StatusBadRequest) 380 425 return 381 426 } 382 427 ··· 388 433 CoffeeAmount: coffeeAmount, 389 434 TimeSeconds: timeSeconds, 390 435 GrindSize: r.FormValue("grind_size"), 391 - GrinderRKey: r.FormValue("grinder_rkey"), 392 - BrewerRKey: r.FormValue("brewer_rkey"), 436 + GrinderRKey: grinderRKey, 437 + BrewerRKey: brewerRKey, 393 438 TastingNotes: r.FormValue("tasting_notes"), 394 439 Rating: rating, 395 440 Pours: pours, ··· 409 454 410 455 // Update existing brew 411 456 func (h *Handler) HandleBrewUpdate(w http.ResponseWriter, r *http.Request) { 412 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 457 + rkey := validateRKey(w, r.PathValue("id")) 458 + if rkey == "" { 459 + return 460 + } 413 461 414 462 // Require authentication 415 463 store, authenticated := h.getAtprotoStore(r) ··· 436 484 http.Error(w, "Bean selection is required", http.StatusBadRequest) 437 485 return 438 486 } 487 + if !atproto.ValidateRKey(beanRKey) { 488 + http.Error(w, "Invalid bean selection", http.StatusBadRequest) 489 + return 490 + } 491 + 492 + // Validate optional rkeys 493 + grinderRKey := r.FormValue("grinder_rkey") 494 + if errMsg := validateOptionalRKey(grinderRKey, "Grinder selection"); errMsg != "" { 495 + http.Error(w, errMsg, http.StatusBadRequest) 496 + return 497 + } 498 + brewerRKey := r.FormValue("brewer_rkey") 499 + if errMsg := validateOptionalRKey(brewerRKey, "Brewer selection"); errMsg != "" { 500 + http.Error(w, errMsg, http.StatusBadRequest) 501 + return 502 + } 439 503 440 504 req := &models.CreateBrewRequest{ 441 505 BeanRKey: beanRKey, ··· 445 509 CoffeeAmount: coffeeAmount, 446 510 TimeSeconds: timeSeconds, 447 511 GrindSize: r.FormValue("grind_size"), 448 - GrinderRKey: r.FormValue("grinder_rkey"), 449 - BrewerRKey: r.FormValue("brewer_rkey"), 512 + GrinderRKey: grinderRKey, 513 + BrewerRKey: brewerRKey, 450 514 TastingNotes: r.FormValue("tasting_notes"), 451 515 Rating: rating, 452 516 Pours: pours, ··· 466 530 467 531 // Delete brew 468 532 func (h *Handler) HandleBrewDelete(w http.ResponseWriter, r *http.Request) { 469 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 533 + rkey := validateRKey(w, r.PathValue("id")) 534 + if rkey == "" { 535 + return 536 + } 470 537 471 538 // Require authentication 472 539 store, authenticated := h.getAtprotoStore(r) ··· 594 661 return 595 662 } 596 663 597 - // Validate required fields 598 - if req.Name == "" { 599 - http.Error(w, "Bean name is required", http.StatusBadRequest) 664 + // Validate request 665 + if err := req.Validate(); err != nil { 666 + http.Error(w, err.Error(), http.StatusBadRequest) 667 + return 668 + } 669 + 670 + // Validate optional roaster rkey 671 + if errMsg := validateOptionalRKey(req.RoasterRKey, "Roaster selection"); errMsg != "" { 672 + http.Error(w, errMsg, http.StatusBadRequest) 600 673 return 601 674 } 602 675 ··· 628 701 return 629 702 } 630 703 631 - // Validate required fields 632 - if req.Name == "" { 633 - http.Error(w, "Roaster name is required", http.StatusBadRequest) 704 + // Validate request 705 + if err := req.Validate(); err != nil { 706 + http.Error(w, err.Error(), http.StatusBadRequest) 634 707 return 635 708 } 636 709 ··· 667 740 668 741 // Bean update/delete handlers 669 742 func (h *Handler) HandleBeanUpdate(w http.ResponseWriter, r *http.Request) { 670 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 743 + rkey := validateRKey(w, r.PathValue("id")) 744 + if rkey == "" { 745 + return 746 + } 671 747 672 748 // Require authentication 673 749 store, authenticated := h.getAtprotoStore(r) ··· 682 758 return 683 759 } 684 760 685 - // Validate required fields 686 - if req.Name == "" { 687 - http.Error(w, "Bean name is required", http.StatusBadRequest) 761 + // Validate request 762 + if err := req.Validate(); err != nil { 763 + http.Error(w, err.Error(), http.StatusBadRequest) 764 + return 765 + } 766 + 767 + // Validate optional roaster rkey 768 + if errMsg := validateOptionalRKey(req.RoasterRKey, "Roaster selection"); errMsg != "" { 769 + http.Error(w, errMsg, http.StatusBadRequest) 688 770 return 689 771 } 690 772 ··· 708 790 } 709 791 710 792 func (h *Handler) HandleBeanDelete(w http.ResponseWriter, r *http.Request) { 711 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 793 + rkey := validateRKey(w, r.PathValue("id")) 794 + if rkey == "" { 795 + return 796 + } 712 797 713 798 // Require authentication 714 799 store, authenticated := h.getAtprotoStore(r) ··· 728 813 729 814 // Roaster update/delete handlers 730 815 func (h *Handler) HandleRoasterUpdate(w http.ResponseWriter, r *http.Request) { 731 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 816 + rkey := validateRKey(w, r.PathValue("id")) 817 + if rkey == "" { 818 + return 819 + } 732 820 733 821 // Require authentication 734 822 store, authenticated := h.getAtprotoStore(r) ··· 743 831 return 744 832 } 745 833 746 - // Validate required fields 747 - if req.Name == "" { 748 - http.Error(w, "Roaster name is required", http.StatusBadRequest) 834 + // Validate request 835 + if err := req.Validate(); err != nil { 836 + http.Error(w, err.Error(), http.StatusBadRequest) 749 837 return 750 838 } 751 839 ··· 769 857 } 770 858 771 859 func (h *Handler) HandleRoasterDelete(w http.ResponseWriter, r *http.Request) { 772 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 860 + rkey := validateRKey(w, r.PathValue("id")) 861 + if rkey == "" { 862 + return 863 + } 773 864 774 865 // Require authentication 775 866 store, authenticated := h.getAtprotoStore(r) ··· 802 893 return 803 894 } 804 895 805 - // Validate required fields 806 - if req.Name == "" { 807 - http.Error(w, "Grinder name is required", http.StatusBadRequest) 896 + // Validate request 897 + if err := req.Validate(); err != nil { 898 + http.Error(w, err.Error(), http.StatusBadRequest) 808 899 return 809 900 } 810 901 ··· 822 913 } 823 914 824 915 func (h *Handler) HandleGrinderUpdate(w http.ResponseWriter, r *http.Request) { 825 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 916 + rkey := validateRKey(w, r.PathValue("id")) 917 + if rkey == "" { 918 + return 919 + } 826 920 827 921 // Require authentication 828 922 store, authenticated := h.getAtprotoStore(r) ··· 837 931 return 838 932 } 839 933 840 - // Validate required fields 841 - if req.Name == "" { 842 - http.Error(w, "Grinder name is required", http.StatusBadRequest) 934 + // Validate request 935 + if err := req.Validate(); err != nil { 936 + http.Error(w, err.Error(), http.StatusBadRequest) 843 937 return 844 938 } 845 939 ··· 863 957 } 864 958 865 959 func (h *Handler) HandleGrinderDelete(w http.ResponseWriter, r *http.Request) { 866 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 960 + rkey := validateRKey(w, r.PathValue("id")) 961 + if rkey == "" { 962 + return 963 + } 867 964 868 965 // Require authentication 869 966 store, authenticated := h.getAtprotoStore(r) ··· 896 993 return 897 994 } 898 995 899 - // Validate required fields 900 - if req.Name == "" { 901 - http.Error(w, "Brewer name is required", http.StatusBadRequest) 996 + // Validate request 997 + if err := req.Validate(); err != nil { 998 + http.Error(w, err.Error(), http.StatusBadRequest) 902 999 return 903 1000 } 904 1001 ··· 916 1013 } 917 1014 918 1015 func (h *Handler) HandleBrewerUpdate(w http.ResponseWriter, r *http.Request) { 919 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 1016 + rkey := validateRKey(w, r.PathValue("id")) 1017 + if rkey == "" { 1018 + return 1019 + } 920 1020 921 1021 // Require authentication 922 1022 store, authenticated := h.getAtprotoStore(r) ··· 931 1031 return 932 1032 } 933 1033 934 - // Validate required fields 935 - if req.Name == "" { 936 - http.Error(w, "Brewer name is required", http.StatusBadRequest) 1034 + // Validate request 1035 + if err := req.Validate(); err != nil { 1036 + http.Error(w, err.Error(), http.StatusBadRequest) 937 1037 return 938 1038 } 939 1039 ··· 957 1057 } 958 1058 959 1059 func (h *Handler) HandleBrewerDelete(w http.ResponseWriter, r *http.Request) { 960 - rkey := r.PathValue("id") // URL still uses "id" path param but value is now rkey 1060 + rkey := validateRKey(w, r.PathValue("id")) 1061 + if rkey == "" { 1062 + return 1063 + } 961 1064 962 1065 // Require authentication 963 1066 store, authenticated := h.getAtprotoStore(r)
+205
internal/middleware/security.go
··· 1 + package middleware 2 + 3 + import ( 4 + "net/http" 5 + "strings" 6 + "sync" 7 + "time" 8 + ) 9 + 10 + // SecurityHeadersMiddleware adds security headers to all responses 11 + func SecurityHeadersMiddleware(next http.Handler) http.Handler { 12 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 + // Prevent clickjacking 14 + w.Header().Set("X-Frame-Options", "DENY") 15 + 16 + // Prevent MIME type sniffing 17 + w.Header().Set("X-Content-Type-Options", "nosniff") 18 + 19 + // XSS protection (legacy but still useful for older browsers) 20 + w.Header().Set("X-XSS-Protection", "1; mode=block") 21 + 22 + // Control referrer information 23 + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") 24 + 25 + // Permissions policy - disable unnecessary features 26 + w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") 27 + 28 + // Content Security Policy 29 + // Allows: self for scripts/styles, inline styles (for Tailwind), unpkg for HTMX/Alpine 30 + csp := strings.Join([]string{ 31 + "default-src 'self'", 32 + "script-src 'self' https://unpkg.com", 33 + "style-src 'self' 'unsafe-inline'", // unsafe-inline needed for Tailwind 34 + "img-src 'self' https: data:", // Allow external images (avatars) and data URIs 35 + "font-src 'self'", 36 + "connect-src 'self'", 37 + "frame-ancestors 'none'", 38 + "base-uri 'self'", 39 + "form-action 'self'", 40 + }, "; ") 41 + w.Header().Set("Content-Security-Policy", csp) 42 + 43 + next.ServeHTTP(w, r) 44 + }) 45 + } 46 + 47 + // RateLimiter implements a simple per-IP rate limiter using token bucket algorithm 48 + type RateLimiter struct { 49 + mu sync.Mutex 50 + visitors map[string]*visitor 51 + rate int // requests per window 52 + window time.Duration // time window 53 + cleanup time.Duration // cleanup interval for old entries 54 + } 55 + 56 + type visitor struct { 57 + tokens int 58 + lastReset time.Time 59 + } 60 + 61 + // NewRateLimiter creates a new rate limiter 62 + // rate: number of requests allowed per window 63 + // window: time window for rate limiting 64 + func NewRateLimiter(rate int, window time.Duration) *RateLimiter { 65 + rl := &RateLimiter{ 66 + visitors: make(map[string]*visitor), 67 + rate: rate, 68 + window: window, 69 + cleanup: window * 2, 70 + } 71 + 72 + // Start cleanup goroutine 73 + go rl.cleanupLoop() 74 + 75 + return rl 76 + } 77 + 78 + func (rl *RateLimiter) cleanupLoop() { 79 + ticker := time.NewTicker(rl.cleanup) 80 + defer ticker.Stop() 81 + 82 + for range ticker.C { 83 + rl.mu.Lock() 84 + now := time.Now() 85 + for ip, v := range rl.visitors { 86 + if now.Sub(v.lastReset) > rl.cleanup { 87 + delete(rl.visitors, ip) 88 + } 89 + } 90 + rl.mu.Unlock() 91 + } 92 + } 93 + 94 + // Allow checks if a request from the given IP is allowed 95 + func (rl *RateLimiter) Allow(ip string) bool { 96 + rl.mu.Lock() 97 + defer rl.mu.Unlock() 98 + 99 + now := time.Now() 100 + v, exists := rl.visitors[ip] 101 + 102 + if !exists { 103 + rl.visitors[ip] = &visitor{ 104 + tokens: rl.rate - 1, // Use one token 105 + lastReset: now, 106 + } 107 + return true 108 + } 109 + 110 + // Reset tokens if window has passed 111 + if now.Sub(v.lastReset) >= rl.window { 112 + v.tokens = rl.rate - 1 113 + v.lastReset = now 114 + return true 115 + } 116 + 117 + // Check if tokens available 118 + if v.tokens > 0 { 119 + v.tokens-- 120 + return true 121 + } 122 + 123 + return false 124 + } 125 + 126 + // RateLimitConfig holds configuration for rate limiting different endpoint types 127 + type RateLimitConfig struct { 128 + // AuthLimiter for login/auth endpoints (stricter) 129 + AuthLimiter *RateLimiter 130 + // APILimiter for general API endpoints 131 + APILimiter *RateLimiter 132 + // GlobalLimiter for all other requests 133 + GlobalLimiter *RateLimiter 134 + } 135 + 136 + // NewDefaultRateLimitConfig creates rate limiters with sensible defaults 137 + func NewDefaultRateLimitConfig() *RateLimitConfig { 138 + return &RateLimitConfig{ 139 + AuthLimiter: NewRateLimiter(5, time.Minute), // 5 auth attempts per minute 140 + APILimiter: NewRateLimiter(60, time.Minute), // 60 API calls per minute 141 + GlobalLimiter: NewRateLimiter(120, time.Minute), // 120 requests per minute 142 + } 143 + } 144 + 145 + // RateLimitMiddleware creates a rate limiting middleware 146 + func RateLimitMiddleware(config *RateLimitConfig) func(http.Handler) http.Handler { 147 + return func(next http.Handler) http.Handler { 148 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 149 + ip := getClientIP(r) 150 + path := r.URL.Path 151 + 152 + var limiter *RateLimiter 153 + 154 + // Select appropriate limiter based on path 155 + switch { 156 + case strings.HasPrefix(path, "/auth/") || path == "/login" || path == "/oauth/callback": 157 + limiter = config.AuthLimiter 158 + case strings.HasPrefix(path, "/api/"): 159 + limiter = config.APILimiter 160 + default: 161 + limiter = config.GlobalLimiter 162 + } 163 + 164 + if !limiter.Allow(ip) { 165 + w.Header().Set("Retry-After", "60") 166 + http.Error(w, "Too many requests", http.StatusTooManyRequests) 167 + return 168 + } 169 + 170 + next.ServeHTTP(w, r) 171 + }) 172 + } 173 + } 174 + 175 + // getClientIP is defined in logging.go 176 + 177 + // MaxBodySize limits the size of request bodies 178 + const ( 179 + MaxJSONBodySize = 1 << 20 // 1 MB for JSON requests 180 + MaxFormBodySize = 1 << 20 // 1 MB for form submissions 181 + ) 182 + 183 + // LimitBodyMiddleware limits request body size to prevent DoS 184 + func LimitBodyMiddleware(next http.Handler) http.Handler { 185 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 186 + if r.Body != nil { 187 + contentType := r.Header.Get("Content-Type") 188 + var maxSize int64 189 + 190 + switch { 191 + case strings.HasPrefix(contentType, "application/json"): 192 + maxSize = MaxJSONBodySize 193 + case strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), 194 + strings.HasPrefix(contentType, "multipart/form-data"): 195 + maxSize = MaxFormBodySize 196 + default: 197 + maxSize = MaxJSONBodySize // Default limit 198 + } 199 + 200 + r.Body = http.MaxBytesReader(w, r.Body, maxSize) 201 + } 202 + 203 + next.ServeHTTP(w, r) 204 + }) 205 + }
+180 -1
internal/models/models.go
··· 1 1 package models 2 2 3 - import "time" 3 + import ( 4 + "errors" 5 + "time" 6 + ) 7 + 8 + // Field length limits for validation 9 + const ( 10 + MaxNameLength = 200 11 + MaxLocationLength = 200 12 + MaxWebsiteLength = 500 13 + MaxDescriptionLength = 2000 14 + MaxNotesLength = 2000 15 + MaxOriginLength = 200 16 + MaxRoastLevelLength = 100 17 + MaxProcessLength = 100 18 + MaxMethodLength = 100 19 + MaxGrindSizeLength = 100 20 + MaxGrinderTypeLength = 50 21 + MaxBurrTypeLength = 50 22 + ) 23 + 24 + // Validation errors 25 + var ( 26 + ErrNameRequired = errors.New("name is required") 27 + ErrNameTooLong = errors.New("name is too long") 28 + ErrLocationTooLong = errors.New("location is too long") 29 + ErrWebsiteTooLong = errors.New("website is too long") 30 + ErrDescTooLong = errors.New("description is too long") 31 + ErrNotesTooLong = errors.New("notes is too long") 32 + ErrOriginTooLong = errors.New("origin is too long") 33 + ErrFieldTooLong = errors.New("field value is too long") 34 + ) 4 35 5 36 type Bean struct { 6 37 RKey string `json:"rkey"` // Record key (AT Protocol or stringified ID for SQLite) ··· 142 173 Name string `json:"name"` 143 174 Description string `json:"description"` 144 175 } 176 + 177 + // Validate checks that all fields are within acceptable limits 178 + func (r *CreateBeanRequest) Validate() error { 179 + if r.Name == "" { 180 + return ErrNameRequired 181 + } 182 + if len(r.Name) > MaxNameLength { 183 + return ErrNameTooLong 184 + } 185 + if len(r.Origin) > MaxOriginLength { 186 + return ErrOriginTooLong 187 + } 188 + if len(r.RoastLevel) > MaxRoastLevelLength { 189 + return ErrFieldTooLong 190 + } 191 + if len(r.Process) > MaxProcessLength { 192 + return ErrFieldTooLong 193 + } 194 + if len(r.Description) > MaxDescriptionLength { 195 + return ErrDescTooLong 196 + } 197 + return nil 198 + } 199 + 200 + // Validate checks that all fields are within acceptable limits 201 + func (r *UpdateBeanRequest) Validate() error { 202 + if r.Name == "" { 203 + return ErrNameRequired 204 + } 205 + if len(r.Name) > MaxNameLength { 206 + return ErrNameTooLong 207 + } 208 + if len(r.Origin) > MaxOriginLength { 209 + return ErrOriginTooLong 210 + } 211 + if len(r.RoastLevel) > MaxRoastLevelLength { 212 + return ErrFieldTooLong 213 + } 214 + if len(r.Process) > MaxProcessLength { 215 + return ErrFieldTooLong 216 + } 217 + if len(r.Description) > MaxDescriptionLength { 218 + return ErrDescTooLong 219 + } 220 + return nil 221 + } 222 + 223 + // Validate checks that all fields are within acceptable limits 224 + func (r *CreateRoasterRequest) Validate() error { 225 + if r.Name == "" { 226 + return ErrNameRequired 227 + } 228 + if len(r.Name) > MaxNameLength { 229 + return ErrNameTooLong 230 + } 231 + if len(r.Location) > MaxLocationLength { 232 + return ErrLocationTooLong 233 + } 234 + if len(r.Website) > MaxWebsiteLength { 235 + return ErrWebsiteTooLong 236 + } 237 + return nil 238 + } 239 + 240 + // Validate checks that all fields are within acceptable limits 241 + func (r *UpdateRoasterRequest) Validate() error { 242 + if r.Name == "" { 243 + return ErrNameRequired 244 + } 245 + if len(r.Name) > MaxNameLength { 246 + return ErrNameTooLong 247 + } 248 + if len(r.Location) > MaxLocationLength { 249 + return ErrLocationTooLong 250 + } 251 + if len(r.Website) > MaxWebsiteLength { 252 + return ErrWebsiteTooLong 253 + } 254 + return nil 255 + } 256 + 257 + // Validate checks that all fields are within acceptable limits 258 + func (r *CreateGrinderRequest) Validate() error { 259 + if r.Name == "" { 260 + return ErrNameRequired 261 + } 262 + if len(r.Name) > MaxNameLength { 263 + return ErrNameTooLong 264 + } 265 + if len(r.GrinderType) > MaxGrinderTypeLength { 266 + return ErrFieldTooLong 267 + } 268 + if len(r.BurrType) > MaxBurrTypeLength { 269 + return ErrFieldTooLong 270 + } 271 + if len(r.Notes) > MaxNotesLength { 272 + return ErrNotesTooLong 273 + } 274 + return nil 275 + } 276 + 277 + // Validate checks that all fields are within acceptable limits 278 + func (r *UpdateGrinderRequest) Validate() error { 279 + if r.Name == "" { 280 + return ErrNameRequired 281 + } 282 + if len(r.Name) > MaxNameLength { 283 + return ErrNameTooLong 284 + } 285 + if len(r.GrinderType) > MaxGrinderTypeLength { 286 + return ErrFieldTooLong 287 + } 288 + if len(r.BurrType) > MaxBurrTypeLength { 289 + return ErrFieldTooLong 290 + } 291 + if len(r.Notes) > MaxNotesLength { 292 + return ErrNotesTooLong 293 + } 294 + return nil 295 + } 296 + 297 + // Validate checks that all fields are within acceptable limits 298 + func (r *CreateBrewerRequest) Validate() error { 299 + if r.Name == "" { 300 + return ErrNameRequired 301 + } 302 + if len(r.Name) > MaxNameLength { 303 + return ErrNameTooLong 304 + } 305 + if len(r.Description) > MaxDescriptionLength { 306 + return ErrDescTooLong 307 + } 308 + return nil 309 + } 310 + 311 + // Validate checks that all fields are within acceptable limits 312 + func (r *UpdateBrewerRequest) Validate() error { 313 + if r.Name == "" { 314 + return ErrNameRequired 315 + } 316 + if len(r.Name) > MaxNameLength { 317 + return ErrNameTooLong 318 + } 319 + if len(r.Description) > MaxDescriptionLength { 320 + return ErrDescTooLong 321 + } 322 + return nil 323 + }
+16 -4
internal/routing/routing.go
··· 86 86 // Catch-all 404 handler - must be last, catches any unmatched routes 87 87 mux.HandleFunc("/", h.HandleNotFound) 88 88 89 - // Apply middleware in order (last added is executed first) 90 - // 1. Apply OAuth middleware to add auth context to all requests 91 - handler := cfg.OAuthManager.AuthMiddleware(mux) 89 + // Apply middleware in order (outermost first, innermost last) 90 + var handler http.Handler = mux 92 91 93 - // 2. Apply logging middleware (wraps everything) 92 + // 1. Limit request body size (innermost - runs first on request) 93 + handler = middleware.LimitBodyMiddleware(handler) 94 + 95 + // 2. Apply OAuth middleware to add auth context 96 + handler = cfg.OAuthManager.AuthMiddleware(handler) 97 + 98 + // 3. Apply rate limiting 99 + rateLimitConfig := middleware.NewDefaultRateLimitConfig() 100 + handler = middleware.RateLimitMiddleware(rateLimitConfig)(handler) 101 + 102 + // 4. Apply security headers 103 + handler = middleware.SecurityHeadersMiddleware(handler) 104 + 105 + // 5. Apply logging middleware (outermost - wraps everything) 94 106 handler = middleware.LoggingMiddleware(cfg.Logger)(handler) 95 107 96 108 return handler