Vow, uncensorable PDS written in Go

refactor: use chi

+1732 -1159
+1 -8
go.mod
··· 8 8 github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 9 9 github.com/domodwyer/mailyak/v3 v3.6.2 10 10 github.com/glebarez/sqlite v1.11.0 11 + github.com/go-chi/chi/v5 v5.2.5 11 12 github.com/go-pkgz/expirable-cache/v3 v3.0.0 12 13 github.com/go-playground/validator v9.31.0+incompatible 13 14 github.com/golang-jwt/jwt/v4 v4.5.2 ··· 22 23 github.com/ipfs/go-ipld-cbor v0.1.0 23 24 github.com/ipld/go-car v0.6.1-0.20230509095817-92d28eb23ba4 24 25 github.com/joho/godotenv v1.5.1 25 - github.com/labstack/echo-contrib v0.17.4 26 - github.com/labstack/echo/v4 v4.13.3 27 26 github.com/lestrrat-go/jwx/v2 v2.0.21 28 27 github.com/multiformats/go-multihash v0.2.3 29 28 github.com/prometheus/client_golang v1.23.2 30 - github.com/samber/slog-echo v1.16.1 31 29 github.com/spf13/cobra v1.10.2 32 30 github.com/spf13/viper v1.21.0 33 31 github.com/whyrusleeping/cbor-gen v0.2.1-0.20241030202151-b7a6831be65e ··· 56 54 github.com/gocql/gocql v1.7.0 // indirect 57 55 github.com/gogo/protobuf v1.3.2 // indirect 58 56 github.com/golang/snappy v0.0.4 // indirect 59 - github.com/gorilla/context v1.1.2 // indirect 60 57 github.com/gorilla/securecookie v1.1.2 // indirect 61 58 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect 62 59 github.com/hashicorp/go-cleanhttp v0.5.2 // indirect ··· 87 84 github.com/jinzhu/inflection v1.0.0 // indirect 88 85 github.com/jinzhu/now v1.1.5 // indirect 89 86 github.com/klauspost/cpuid/v2 v2.2.7 // indirect 90 - github.com/labstack/gommon v0.4.2 // indirect 91 87 github.com/leodido/go-urn v1.4.0 // indirect 92 88 github.com/lestrrat-go/blackmagic v1.0.2 // indirect 93 89 github.com/lestrrat-go/httpcc v1.0.1 // indirect ··· 112 108 github.com/prometheus/procfs v0.16.1 // indirect 113 109 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 114 110 github.com/sagikazarmark/locafero v0.11.0 // indirect 115 - github.com/samber/lo v1.49.1 // indirect 116 111 github.com/segmentio/asm v1.2.0 // indirect 117 112 github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect 118 113 github.com/spaolacci/murmur3 v1.1.0 // indirect ··· 120 115 github.com/spf13/cast v1.10.0 // indirect 121 116 github.com/spf13/pflag v1.0.10 // indirect 122 117 github.com/subosito/gotenv v1.6.0 // indirect 123 - github.com/valyala/bytebufferpool v1.0.0 // indirect 124 - github.com/valyala/fasttemplate v1.2.2 // indirect 125 118 gitlab.com/yawning/tuplehash v0.0.0-20230713102510-df83abbf9a02 // indirect 126 119 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect 127 120 go.opentelemetry.io/otel v1.29.0 // indirect
+2 -16
go.sum
··· 49 49 github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= 50 50 github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= 51 51 github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= 52 + github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= 53 + github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= 52 54 github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 53 55 github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 54 56 github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= ··· 91 93 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 92 94 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= 93 95 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 94 - github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= 95 - github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= 96 96 github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= 97 97 github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= 98 98 github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= ··· 218 218 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 219 219 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 220 220 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 221 - github.com/labstack/echo-contrib v0.17.4 h1:g5mfsrJfJTKv+F5uNKCyrjLK7js+ZW6HTjg4FnDxxgk= 222 - github.com/labstack/echo-contrib v0.17.4/go.mod h1:9O7ZPAHUeMGTOAfg80YqQduHzt0CzLak36PZRldYrZ0= 223 - github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= 224 - github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= 225 - github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= 226 - github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= 227 221 github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= 228 222 github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= 229 223 github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= ··· 320 314 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 321 315 github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= 322 316 github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= 323 - github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew= 324 - github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o= 325 - github.com/samber/slog-echo v1.16.1 h1:5Q5IUROkFqKcu/qJM/13AP1d3gd1RS+Q/4EvKQU1fuo= 326 - github.com/samber/slog-echo v1.16.1/go.mod h1:f+B3WR06saRXcaGRZ/I/UPCECDPqTUqadRIf7TmyRhI= 327 317 github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= 328 318 github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= 329 319 github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= ··· 357 347 github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= 358 348 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= 359 349 github.com/urfave/cli v1.22.10/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= 360 - github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 361 - github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 362 - github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= 363 - github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= 364 350 github.com/warpfork/go-testmark v0.12.1 h1:rMgCpJfwy1sJ50x0M0NgyphxYYPMOODIJHhsXyEHU0s= 365 351 github.com/warpfork/go-testmark v0.12.1/go.mod h1:kHwy7wfvGSPh1rQJYKayD4AbtNaeyZdcGi9tNJTaa5Y= 366 352 github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0 h1:GDDkbFiaK8jsSDJfjId/PEGEShv6ugrt4kYsC5UIDaQ=
+23 -16
internal/helpers/helpers.go
··· 3 3 import ( 4 4 crand "crypto/rand" 5 5 "encoding/hex" 6 + "encoding/json" 6 7 "errors" 7 8 "math/rand" 9 + "net/http" 8 10 "net/url" 9 11 10 - "github.com/Azure/go-autorest/autorest/to" 11 - "github.com/labstack/echo/v4" 12 12 "github.com/lestrrat-go/jwx/v2/jwk" 13 13 ) 14 14 ··· 16 16 // /^[A-Z2-7]{5}-[A-Z2-7]{5}$/ 17 17 var letters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") 18 18 19 - func InputError(e echo.Context, custom *string) error { 19 + func writeJSON(w http.ResponseWriter, status int, v any) error { 20 + w.Header().Set("Content-Type", "application/json") 21 + w.WriteHeader(status) 22 + return json.NewEncoder(w).Encode(v) 23 + } 24 + 25 + func InputError(w http.ResponseWriter, custom *string) error { 20 26 msg := "InvalidRequest" 21 27 if custom != nil { 22 28 msg = *custom 23 29 } 24 - return genericError(e, 400, msg) 30 + return genericError(w, http.StatusBadRequest, msg) 25 31 } 26 32 27 - func ServerError(e echo.Context, suffix *string) error { 33 + func ServerError(w http.ResponseWriter, suffix *string) error { 28 34 msg := "Internal server error" 29 35 if suffix != nil { 30 36 msg += ". " + *suffix 31 37 } 32 - return genericError(e, 500, msg) 38 + return genericError(w, http.StatusInternalServerError, msg) 33 39 } 34 40 35 - func UnauthorizedError(e echo.Context, suffix *string) error { 41 + func UnauthorizedError(w http.ResponseWriter, suffix *string) error { 36 42 msg := "Unauthorized" 37 43 if suffix != nil { 38 44 msg += ". " + *suffix 39 45 } 40 - return genericError(e, 401, msg) 46 + return genericError(w, http.StatusUnauthorized, msg) 41 47 } 42 48 43 - func ForbiddenError(e echo.Context, suffix *string) error { 49 + func ForbiddenError(w http.ResponseWriter, suffix *string) error { 44 50 msg := "Forbidden" 45 51 if suffix != nil { 46 52 msg += ". " + *suffix 47 53 } 48 - return genericError(e, 403, msg) 54 + return genericError(w, http.StatusForbidden, msg) 49 55 } 50 56 51 - func InvalidTokenError(e echo.Context) error { 52 - return InputError(e, to.StringPtr("InvalidToken")) 57 + func InvalidTokenError(w http.ResponseWriter) error { 58 + s := "InvalidToken" 59 + return InputError(w, &s) 53 60 } 54 61 55 - func ExpiredTokenError(e echo.Context) error { 62 + func ExpiredTokenError(w http.ResponseWriter) error { 56 63 // WARN: See https://github.com/bluesky-social/atproto/discussions/3319 57 - return e.JSON(400, map[string]string{ 64 + return writeJSON(w, http.StatusBadRequest, map[string]string{ 58 65 "error": "ExpiredToken", 59 66 "message": "*", 60 67 }) 61 68 } 62 69 63 - func genericError(e echo.Context, code int, msg string) error { 64 - return e.JSON(code, map[string]string{ 70 + func genericError(w http.ResponseWriter, code int, msg string) error { 71 + return writeJSON(w, code, map[string]string{ 65 72 "error": msg, 66 73 }) 67 74 }
+9 -9
oauth/provider/middleware.go
··· 1 1 package provider 2 2 3 3 import ( 4 - "github.com/labstack/echo/v4" 4 + "net/http" 5 5 ) 6 6 7 - func (p *Provider) BaseMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 8 - return func(e echo.Context) error { 9 - e.Response().Header().Set("cache-control", "no-store") 10 - e.Response().Header().Set("pragma", "no-cache") 7 + func (p *Provider) BaseMiddleware(next http.Handler) http.Handler { 8 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 9 + w.Header().Set("cache-control", "no-store") 10 + w.Header().Set("pragma", "no-cache") 11 11 12 12 nonce := p.NextNonce() 13 13 if nonce != "" { 14 - e.Response().Header().Set("DPoP-Nonce", nonce) 15 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 14 + w.Header().Set("DPoP-Nonce", nonce) 15 + w.Header().Add("access-control-expose-headers", "DPoP-Nonce") 16 16 } 17 17 18 - return next(e) 19 - } 18 + next.ServeHTTP(w, r) 19 + }) 20 20 }
+12 -10
server/handle_account.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 4 5 "time" 5 6 6 7 "github.com/haileyok/cocoon/oauth" 7 8 "github.com/haileyok/cocoon/oauth/constants" 8 9 "github.com/haileyok/cocoon/oauth/provider" 9 10 "github.com/hako/durafmt" 10 - "github.com/labstack/echo/v4" 11 11 ) 12 12 13 - func (s *Server) handleAccount(e echo.Context) error { 14 - ctx := e.Request().Context() 13 + func (s *Server) handleAccount(w http.ResponseWriter, r *http.Request) { 14 + ctx := r.Context() 15 15 logger := s.logger.With("name", "handleAuth") 16 16 17 - repo, sess, err := s.getSessionRepoOrErr(e) 17 + repo, sess, err := s.getSessionRepoOrErr(r) 18 18 if err != nil { 19 - return e.Redirect(303, "/account/signin") 19 + http.Redirect(w, r, "/account/signin", 303) 20 + return 20 21 } 21 22 22 23 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) ··· 25 26 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 26 27 logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 27 28 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 28 - sess.Save(e.Request(), e.Response()) 29 - return e.Render(200, "account.html", map[string]any{ 30 - "flashes": getFlashesFromSession(e, sess), 29 + sess.Save(r, w) 30 + s.renderTemplate(w, "account.html", map[string]any{ 31 + "flashes": getFlashesFromSession(w, r, sess), 31 32 }) 33 + return 32 34 } 33 35 34 36 var filtered []provider.OauthToken ··· 68 70 }) 69 71 } 70 72 71 - return e.Render(200, "account.html", map[string]any{ 73 + s.renderTemplate(w, "account.html", map[string]any{ 72 74 "Repo": repo, 73 75 "Tokens": tokenInfo, 74 - "flashes": getFlashesFromSession(e, sess), 76 + "flashes": getFlashesFromSession(w, r, sess), 75 77 }) 76 78 }
+21 -14
server/handle_account_revoke.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 - "github.com/labstack/echo/v4" 6 7 ) 7 8 8 9 type AccountRevokeInput struct { 9 10 Token string `form:"token"` 10 11 } 11 12 12 - func (s *Server) handleAccountRevoke(e echo.Context) error { 13 - ctx := e.Request().Context() 14 - logger := s.logger.With("name", "handleAcocuntRevoke") 13 + func (s *Server) handleAccountRevoke(w http.ResponseWriter, r *http.Request) { 14 + ctx := r.Context() 15 + logger := s.logger.With("name", "handleAccountRevoke") 15 16 16 - var req AccountRevokeInput 17 - if err := e.Bind(&req); err != nil { 18 - logger.Error("could not bind account revoke request", "error", err) 19 - return helpers.ServerError(e, nil) 17 + if err := r.ParseForm(); err != nil { 18 + logger.Error("could not parse account revoke form", "error", err) 19 + helpers.ServerError(w, nil) 20 + return 21 + } 22 + 23 + req := AccountRevokeInput{ 24 + Token: r.FormValue("token"), 20 25 } 21 26 22 - repo, sess, err := s.getSessionRepoOrErr(e) 27 + repo, sess, err := s.getSessionRepoOrErr(r) 23 28 if err != nil { 24 - return e.Redirect(303, "/account/signin") 29 + http.Redirect(w, r, "/account/signin", 303) 30 + return 25 31 } 26 32 27 33 if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 28 34 logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 29 35 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") 30 - sess.Save(e.Request(), e.Response()) 31 - return e.Redirect(303, "/account") 36 + sess.Save(r, w) 37 + http.Redirect(w, r, "/account", 303) 38 + return 32 39 } 33 40 34 41 sess.AddFlash("Session successfully revoked!", "success") 35 - sess.Save(e.Request(), e.Response()) 36 - return e.Redirect(303, "/account") 42 + sess.Save(r, w) 43 + http.Redirect(w, r, "/account", 303) 37 44 }
+55 -39
server/handle_account_signin.go
··· 3 3 import ( 4 4 "errors" 5 5 "fmt" 6 + "net/http" 6 7 "strings" 7 8 "time" 8 9 ··· 10 11 "github.com/gorilla/sessions" 11 12 "github.com/haileyok/cocoon/internal/helpers" 12 13 "github.com/haileyok/cocoon/models" 13 - "github.com/labstack/echo-contrib/session" 14 - "github.com/labstack/echo/v4" 15 14 "golang.org/x/crypto/bcrypt" 16 15 "gorm.io/gorm" 17 16 ) ··· 23 22 QueryParams string `form:"query_params"` 24 23 } 25 24 26 - func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 27 - ctx := e.Request().Context() 25 + func (s *Server) getSessionRepoOrErr(r *http.Request) (*models.RepoActor, *sessions.Session, error) { 26 + ctx := r.Context() 28 27 29 - sess, err := session.Get(s.config.SessionCookieKey, e) 28 + sess, err := s.sessions.Get(r, s.config.SessionCookieKey) 30 29 if err != nil { 31 30 return nil, nil, err 32 31 } ··· 44 43 return repo, sess, nil 45 44 } 46 45 47 - func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any { 48 - defer sess.Save(e.Request(), e.Response()) 46 + func getFlashesFromSession(w http.ResponseWriter, r *http.Request, sess *sessions.Session) map[string]any { 47 + defer sess.Save(r, w) 49 48 return map[string]any{ 50 49 "errors": sess.Flashes("error"), 51 50 "successes": sess.Flashes("success"), ··· 53 52 } 54 53 } 55 54 56 - func (s *Server) handleAccountSigninGet(e echo.Context) error { 57 - _, sess, err := s.getSessionRepoOrErr(e) 55 + func (s *Server) handleAccountSigninGet(w http.ResponseWriter, r *http.Request) { 56 + _, sess, err := s.getSessionRepoOrErr(r) 58 57 if err == nil { 59 - return e.Redirect(303, "/account") 58 + http.Redirect(w, r, "/account", 303) 59 + return 60 60 } 61 61 62 - return e.Render(200, "signin.html", map[string]any{ 63 - "flashes": getFlashesFromSession(e, sess), 64 - "QueryParams": e.QueryParams().Encode(), 62 + s.renderTemplate(w, "signin.html", map[string]any{ 63 + "flashes": getFlashesFromSession(w, r, sess), 64 + "QueryParams": r.URL.Query().Encode(), 65 65 }) 66 66 } 67 67 68 - func (s *Server) handleAccountSigninPost(e echo.Context) error { 69 - ctx := e.Request().Context() 68 + func (s *Server) handleAccountSigninPost(w http.ResponseWriter, r *http.Request) { 69 + ctx := r.Context() 70 70 logger := s.logger.With("name", "handleAccountSigninPost") 71 71 72 - var req OauthSigninInput 73 - if err := e.Bind(&req); err != nil { 74 - logger.Error("error binding sign in req", "error", err) 75 - return helpers.ServerError(e, nil) 72 + if err := r.ParseForm(); err != nil { 73 + logger.Error("error parsing sign in form", "error", err) 74 + helpers.ServerError(w, nil) 75 + return 76 + } 77 + 78 + req := OauthSigninInput{ 79 + Username: r.FormValue("username"), 80 + Password: r.FormValue("password"), 81 + AuthFactorToken: r.FormValue("token"), 82 + QueryParams: r.FormValue("query_params"), 76 83 } 77 84 78 - sess, _ := session.Get(s.config.SessionCookieKey, e) 85 + sess, _ := s.sessions.Get(r, s.config.SessionCookieKey) 79 86 80 87 req.Username = strings.ToLower(req.Username) 81 88 var idtype string ··· 109 116 } else { 110 117 sess.AddFlash("Something went wrong!", "error") 111 118 } 112 - sess.Save(e.Request(), e.Response()) 113 - return e.Redirect(303, "/account/signin"+queryParams) 119 + sess.Save(r, w) 120 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 121 + return 114 122 } 115 123 116 124 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { ··· 119 127 } else { 120 128 sess.AddFlash("Something went wrong!", "error") 121 129 } 122 - sess.Save(e.Request(), e.Response()) 123 - return e.Redirect(303, "/account/signin"+queryParams) 130 + sess.Save(r, w) 131 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 132 + return 124 133 } 125 134 126 135 // if repo requires 2FA token and one hasn't been provided, return error prompting for one ··· 128 137 err = s.createAndSendTwoFactorCode(ctx, repo) 129 138 if err != nil { 130 139 sess.AddFlash("Something went wrong!", "error") 131 - sess.Save(e.Request(), e.Response()) 132 - return e.Redirect(303, "/account/signin"+queryParams) 140 + sess.Save(r, w) 141 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 142 + return 133 143 } 134 144 135 145 sess.AddFlash("requires 2FA token", "tokenrequired") 136 - sess.Save(e.Request(), e.Response()) 137 - return e.Redirect(303, "/account/signin"+queryParams) 146 + sess.Save(r, w) 147 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 148 + return 138 149 } 139 150 140 - // if 2FAis required, now check that the one provided is valid 151 + // if 2FA is required, now check that the one provided is valid 141 152 if repo.TwoFactorType != models.TwoFactorTypeNone { 142 153 if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil { 143 154 err = s.createAndSendTwoFactorCode(ctx, repo) 144 155 if err != nil { 145 156 sess.AddFlash("Something went wrong!", "error") 146 - sess.Save(e.Request(), e.Response()) 147 - return e.Redirect(303, "/account/signin"+queryParams) 157 + sess.Save(r, w) 158 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 159 + return 148 160 } 149 161 150 162 sess.AddFlash("requires 2FA token", "tokenrequired") 151 - sess.Save(e.Request(), e.Response()) 152 - return e.Redirect(303, "/account/signin"+queryParams) 163 + sess.Save(r, w) 164 + http.Redirect(w, r, "/account/signin"+queryParams, 303) 165 + return 153 166 } 154 167 155 168 if *repo.TwoFactorCode != req.AuthFactorToken { 156 - return helpers.InvalidTokenError(e) 169 + helpers.InvalidTokenError(w) 170 + return 157 171 } 158 172 159 173 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 160 - return helpers.ExpiredTokenError(e) 174 + helpers.ExpiredTokenError(w) 175 + return 161 176 } 162 177 } 163 178 ··· 170 185 sess.Values = map[any]any{} 171 186 sess.Values["did"] = repo.Repo.Did 172 187 173 - if err := sess.Save(e.Request(), e.Response()); err != nil { 174 - return err 188 + if err := sess.Save(r, w); err != nil { 189 + helpers.ServerError(w, nil) 190 + return 175 191 } 176 192 177 193 if queryParams != "" { 178 - return e.Redirect(303, "/oauth/authorize"+queryParams) 194 + http.Redirect(w, r, "/oauth/authorize"+queryParams, 303) 179 195 } else { 180 - return e.Redirect(303, "/account") 196 + http.Redirect(w, r, "/account", 303) 181 197 } 182 198 }
+12 -10
server/handle_account_signout.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/gorilla/sessions" 5 - "github.com/labstack/echo-contrib/session" 6 - "github.com/labstack/echo/v4" 7 7 ) 8 8 9 - func (s *Server) handleAccountSignout(e echo.Context) error { 10 - sess, err := session.Get(s.config.SessionCookieKey, e) 9 + func (s *Server) handleAccountSignout(w http.ResponseWriter, r *http.Request) { 10 + sess, err := s.sessions.Get(r, s.config.SessionCookieKey) 11 11 if err != nil { 12 - return err 12 + http.Error(w, "session error", http.StatusInternalServerError) 13 + return 13 14 } 14 15 15 16 sess.Options = &sessions.Options{ ··· 20 21 21 22 sess.Values = map[any]any{} 22 23 23 - if err := sess.Save(e.Request(), e.Response()); err != nil { 24 - return err 24 + if err := sess.Save(r, w); err != nil { 25 + http.Error(w, "session save error", http.StatusInternalServerError) 26 + return 25 27 } 26 28 27 - reqUri := e.QueryParam("request_uri") 29 + reqUri := r.URL.Query().Get("request_uri") 28 30 29 31 redirect := "/account/signin" 30 32 if reqUri != "" { 31 - redirect += "?" + e.QueryParams().Encode() 33 + redirect += "?" + r.URL.Query().Encode() 32 34 } 33 35 34 - return e.Redirect(303, redirect) 36 + http.Redirect(w, r, redirect, 303) 35 37 }
+4 -4
server/handle_actor_get_preferences.go
··· 2 2 3 3 import ( 4 4 "encoding/json" 5 + "net/http" 5 6 6 7 "github.com/haileyok/cocoon/models" 7 - "github.com/labstack/echo/v4" 8 8 ) 9 9 10 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 11 12 - func (s *Server) handleActorGetPreferences(e echo.Context) error { 13 - repo := e.Get("repo").(*models.RepoActor) 12 + func (s *Server) handleActorGetPreferences(w http.ResponseWriter, r *http.Request) { 13 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 14 14 15 15 var prefs map[string]any 16 16 err := json.Unmarshal(repo.Preferences, &prefs) ··· 20 20 } 21 21 } 22 22 23 - return e.JSON(200, prefs) 23 + s.writeJSON(w, 200, prefs) 24 24 }
+12 -9
server/handle_actor_put_preferences.go
··· 2 2 3 3 import ( 4 4 "encoding/json" 5 + "net/http" 5 6 6 7 "github.com/haileyok/cocoon/models" 7 - "github.com/labstack/echo/v4" 8 8 ) 9 9 10 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 11 12 - func (s *Server) handleActorPutPreferences(e echo.Context) error { 13 - ctx := e.Request().Context() 12 + func (s *Server) handleActorPutPreferences(w http.ResponseWriter, r *http.Request) { 13 + ctx := r.Context() 14 14 15 - repo := e.Get("repo").(*models.RepoActor) 15 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 16 16 17 17 var prefs map[string]any 18 - if err := json.NewDecoder(e.Request().Body).Decode(&prefs); err != nil { 19 - return err 18 + if err := json.NewDecoder(r.Body).Decode(&prefs); err != nil { 19 + http.Error(w, err.Error(), http.StatusBadRequest) 20 + return 20 21 } 21 22 22 23 b, err := json.Marshal(prefs) 23 24 if err != nil { 24 - return err 25 + http.Error(w, err.Error(), http.StatusInternalServerError) 26 + return 25 27 } 26 28 27 29 if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 28 - return err 30 + http.Error(w, err.Error(), http.StatusInternalServerError) 31 + return 29 32 } 30 33 31 - return nil 34 + w.WriteHeader(http.StatusOK) 32 35 }
+4 -4
server/handle_age_assurance.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 4 5 "time" 5 6 6 7 "github.com/bluesky-social/indigo/util" 7 8 "github.com/haileyok/cocoon/models" 8 - "github.com/labstack/echo/v4" 9 9 ) 10 10 11 - func (s *Server) handleAgeAssurance(e echo.Context) error { 12 - repo := e.Get("repo").(*models.RepoActor) 11 + func (s *Server) handleAgeAssurance(w http.ResponseWriter, r *http.Request) { 12 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 13 13 14 14 resp := map[string]any{ 15 15 "state": map[string]any{ ··· 22 22 }, 23 23 } 24 24 25 - return e.JSON(200, resp) 25 + s.writeJSON(w, 200, resp) 26 26 }
+3 -3
server/handle_health.go
··· 1 1 package server 2 2 3 - import "github.com/labstack/echo/v4" 3 + import "net/http" 4 4 5 - func (s *Server) handleHealth(e echo.Context) error { 6 - return e.JSON(200, map[string]string{ 5 + func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { 6 + s.writeJSON(w, 200, map[string]string{ 7 7 "version": "cocoon " + s.config.Version, 8 8 }) 9 9 }
+9 -7
server/handle_identity_request_plc_operation.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/haileyok/cocoon/internal/helpers" 8 9 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 10 ) 11 11 12 - func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { 13 - ctx := e.Request().Context() 12 + func (s *Server) handleIdentityRequestPlcOperationSignature(w http.ResponseWriter, r *http.Request) { 13 + ctx := r.Context() 14 14 logger := s.logger.With("name", "handleIdentityRequestPlcOperationSignature") 15 15 16 - urepo := e.Get("repo").(*models.RepoActor) 16 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 17 17 18 18 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 19 19 eat := time.Now().Add(10 * time.Minute).UTC() 20 20 21 21 if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 22 22 logger.Error("error updating user", "error", err) 23 - return helpers.ServerError(e, nil) 23 + helpers.ServerError(w, nil) 24 + return 24 25 } 25 26 26 27 if err := s.sendPlcTokenReset(urepo.Email, urepo.Handle, code); err != nil { 27 28 logger.Error("error sending mail", "error", err) 28 - return helpers.ServerError(e, nil) 29 + helpers.ServerError(w, nil) 30 + return 29 31 } 30 32 31 - return e.NoContent(200) 33 + w.WriteHeader(http.StatusOK) 32 34 }
+26 -16
server/handle_identity_sign_plc_operation.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "net/http" 5 7 "strings" 6 8 "time" 7 9 ··· 11 13 "github.com/haileyok/cocoon/internal/helpers" 12 14 "github.com/haileyok/cocoon/models" 13 15 "github.com/haileyok/cocoon/plc" 14 - "github.com/labstack/echo/v4" 15 16 ) 16 17 17 18 type ComAtprotoSignPlcOperationRequest struct { ··· 26 27 Operation plc.Operation `json:"operation"` 27 28 } 28 29 29 - func (s *Server) handleSignPlcOperation(e echo.Context) error { 30 + func (s *Server) handleSignPlcOperation(w http.ResponseWriter, r *http.Request) { 30 31 logger := s.logger.With("name", "handleSignPlcOperation") 31 32 32 - repo := e.Get("repo").(*models.RepoActor) 33 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 33 34 34 35 var req ComAtprotoSignPlcOperationRequest 35 - if err := e.Bind(&req); err != nil { 36 - logger.Error("error binding", "error", err) 37 - return helpers.ServerError(e, nil) 36 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 37 + logger.Error("error decoding", "error", err) 38 + helpers.ServerError(w, nil) 39 + return 38 40 } 39 41 40 42 if !strings.HasPrefix(repo.Repo.Did, "did:plc:") { 41 - return helpers.InputError(e, nil) 43 + helpers.InputError(w, nil) 44 + return 42 45 } 43 46 44 47 if repo.PlcOperationCode == nil || repo.PlcOperationCodeExpiresAt == nil { 45 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 48 + helpers.InputError(w, to.StringPtr("InvalidToken")) 49 + return 46 50 } 47 51 48 52 if *repo.PlcOperationCode != req.Token { 49 - return helpers.InvalidTokenError(e) 53 + helpers.InvalidTokenError(w) 54 + return 50 55 } 51 56 52 57 if time.Now().UTC().After(*repo.PlcOperationCodeExpiresAt) { 53 - return helpers.ExpiredTokenError(e) 58 + helpers.ExpiredTokenError(w) 59 + return 54 60 } 55 61 56 - ctx := context.WithValue(e.Request().Context(), "skip-cache", true) 62 + ctx := context.WithValue(r.Context(), "skip-cache", true) 57 63 log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did) 58 64 if err != nil { 59 65 logger.Error("error fetching doc", "error", err) 60 - return helpers.ServerError(e, nil) 66 + helpers.ServerError(w, nil) 67 + return 61 68 } 62 69 63 70 latest := log[len(log)-1] ··· 86 93 k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 87 94 if err != nil { 88 95 logger.Error("error parsing signing key", "error", err) 89 - return helpers.ServerError(e, nil) 96 + helpers.ServerError(w, nil) 97 + return 90 98 } 91 99 92 100 if err := s.plcClient.SignOp(k, &op); err != nil { 93 101 logger.Error("error signing plc operation", "error", err) 94 - return helpers.ServerError(e, nil) 102 + helpers.ServerError(w, nil) 103 + return 95 104 } 96 105 97 106 if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 98 107 logger.Error("error updating repo", "error", err) 99 - return helpers.ServerError(e, nil) 108 + helpers.ServerError(w, nil) 109 + return 100 110 } 101 111 102 - return e.JSON(200, ComAtprotoSignPlcOperationResponse{ 112 + s.writeJSON(w, 200, ComAtprotoSignPlcOperationResponse{ 103 113 Operation: op, 104 114 }) 105 115 }
+32 -21
server/handle_identity_submit_plc_operation.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "net/http" 5 7 "slices" 6 8 "strings" 7 9 "time" ··· 13 15 "github.com/haileyok/cocoon/internal/helpers" 14 16 "github.com/haileyok/cocoon/models" 15 17 "github.com/haileyok/cocoon/plc" 16 - "github.com/labstack/echo/v4" 17 18 ) 18 19 19 20 type ComAtprotoSubmitPlcOperationRequest struct { 20 21 Operation plc.Operation `json:"operation"` 21 22 } 22 23 23 - func (s *Server) handleSubmitPlcOperation(e echo.Context) error { 24 + func (s *Server) handleSubmitPlcOperation(w http.ResponseWriter, r *http.Request) { 24 25 logger := s.logger.With("name", "handleIdentitySubmitPlcOperation") 25 26 26 - repo := e.Get("repo").(*models.RepoActor) 27 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 27 28 28 29 var req ComAtprotoSubmitPlcOperationRequest 29 - if err := e.Bind(&req); err != nil { 30 - logger.Error("error binding", "error", err) 31 - return helpers.ServerError(e, nil) 30 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 31 + logger.Error("error decoding", "error", err) 32 + helpers.ServerError(w, nil) 33 + return 32 34 } 33 35 34 - if err := e.Validate(req); err != nil { 35 - return helpers.InputError(e, nil) 36 + if err := s.validator.Struct(req); err != nil { 37 + helpers.InputError(w, nil) 38 + return 36 39 } 40 + 37 41 if !strings.HasPrefix(repo.Repo.Did, "did:plc:") { 38 - return helpers.InputError(e, nil) 42 + helpers.InputError(w, nil) 43 + return 39 44 } 40 45 41 46 op := req.Operation ··· 43 48 k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 44 49 if err != nil { 45 50 logger.Error("error parsing key", "error", err) 46 - return helpers.ServerError(e, nil) 51 + helpers.ServerError(w, nil) 52 + return 47 53 } 48 54 required, err := s.plcClient.CreateDidCredentials(k, "", repo.Actor.Handle) 49 55 if err != nil { 50 - logger.Error("error crating did credentials", "error", err) 51 - return helpers.ServerError(e, nil) 56 + logger.Error("error creating did credentials", "error", err) 57 + helpers.ServerError(w, nil) 58 + return 52 59 } 53 60 54 61 for _, expectedKey := range required.RotationKeys { 55 62 if !slices.Contains(op.RotationKeys, expectedKey) { 56 - return helpers.InputError(e, nil) 63 + helpers.InputError(w, nil) 64 + return 57 65 } 58 66 } 59 67 if op.Services["atproto_pds"].Type != "AtprotoPersonalDataServer" { 60 - return helpers.InputError(e, nil) 68 + helpers.InputError(w, nil) 69 + return 61 70 } 62 71 if op.Services["atproto_pds"].Endpoint != required.Services["atproto_pds"].Endpoint { 63 - return helpers.InputError(e, nil) 72 + helpers.InputError(w, nil) 73 + return 64 74 } 65 75 if op.VerificationMethods["atproto"] != required.VerificationMethods["atproto"] { 66 - return helpers.InputError(e, nil) 76 + helpers.InputError(w, nil) 77 + return 67 78 } 68 79 if op.AlsoKnownAs[0] != required.AlsoKnownAs[0] { 69 - return helpers.InputError(e, nil) 80 + helpers.InputError(w, nil) 81 + return 70 82 } 71 83 72 - if err := s.plcClient.SendOperation(e.Request().Context(), repo.Repo.Did, &op); err != nil { 73 - return err 84 + if err := s.plcClient.SendOperation(r.Context(), repo.Repo.Did, &op); err != nil { 85 + helpers.ServerError(w, nil) 86 + return 74 87 } 75 88 76 89 if err := s.passport.BustDoc(context.TODO(), repo.Repo.Did); err != nil { ··· 84 97 Time: time.Now().Format(util.ISO8601), 85 98 }, 86 99 }) 87 - 88 - return nil 89 100 }
+23 -17
server/handle_identity_update_handle.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "net/http" 5 7 "strings" 6 8 "time" 7 9 ··· 14 16 "github.com/haileyok/cocoon/internal/helpers" 15 17 "github.com/haileyok/cocoon/models" 16 18 "github.com/haileyok/cocoon/plc" 17 - "github.com/labstack/echo/v4" 18 19 ) 19 20 20 21 type ComAtprotoIdentityUpdateHandleRequest struct { 21 22 Handle string `json:"handle" validate:"atproto-handle"` 22 23 } 23 24 24 - func (s *Server) handleIdentityUpdateHandle(e echo.Context) error { 25 + func (s *Server) handleIdentityUpdateHandle(w http.ResponseWriter, r *http.Request) { 25 26 logger := s.logger.With("name", "handleIdentityUpdateHandle") 26 27 27 - repo := e.Get("repo").(*models.RepoActor) 28 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 28 29 29 30 var req ComAtprotoIdentityUpdateHandleRequest 30 - if err := e.Bind(&req); err != nil { 31 - logger.Error("error binding", "error", err) 32 - return helpers.ServerError(e, nil) 31 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 32 + logger.Error("error decoding", "error", err) 33 + helpers.ServerError(w, nil) 34 + return 33 35 } 34 36 35 37 req.Handle = strings.ToLower(req.Handle) 36 38 37 - if err := e.Validate(req); err != nil { 38 - return helpers.InputError(e, nil) 39 + if err := s.validator.Struct(req); err != nil { 40 + helpers.InputError(w, nil) 41 + return 39 42 } 40 43 41 - ctx := context.WithValue(e.Request().Context(), "skip-cache", true) 44 + ctx := context.WithValue(r.Context(), "skip-cache", true) 42 45 43 46 if strings.HasPrefix(repo.Repo.Did, "did:plc:") { 44 47 log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did) 45 48 if err != nil { 46 49 logger.Error("error fetching doc", "error", err) 47 - return helpers.ServerError(e, nil) 50 + helpers.ServerError(w, nil) 51 + return 48 52 } 49 53 50 54 latest := log[len(log)-1] ··· 71 75 k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 72 76 if err != nil { 73 77 logger.Error("error parsing signing key", "error", err) 74 - return helpers.ServerError(e, nil) 78 + helpers.ServerError(w, nil) 79 + return 75 80 } 76 81 77 82 if err := s.plcClient.SignOp(k, &op); err != nil { 78 - return err 83 + helpers.ServerError(w, nil) 84 + return 79 85 } 80 86 81 - if err := s.plcClient.SendOperation(e.Request().Context(), repo.Repo.Did, &op); err != nil { 82 - return err 87 + if err := s.plcClient.SendOperation(r.Context(), repo.Repo.Did, &op); err != nil { 88 + helpers.ServerError(w, nil) 89 + return 83 90 } 84 91 } 85 92 ··· 98 105 99 106 if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 100 107 logger.Error("error updating handle in db", "error", err) 101 - return helpers.ServerError(e, nil) 108 + helpers.ServerError(w, nil) 109 + return 102 110 } 103 - 104 - return nil 105 111 }
+27 -19
server/handle_import_repo.go
··· 4 4 "bytes" 5 5 "context" 6 6 "io" 7 + "net/http" 7 8 "slices" 8 9 "strings" 9 10 ··· 13 14 blocks "github.com/ipfs/go-block-format" 14 15 "github.com/ipfs/go-cid" 15 16 "github.com/ipld/go-car" 16 - "github.com/labstack/echo/v4" 17 17 ) 18 18 19 - func (s *Server) handleRepoImportRepo(e echo.Context) error { 20 - ctx := e.Request().Context() 19 + func (s *Server) handleRepoImportRepo(w http.ResponseWriter, r *http.Request) { 20 + ctx := r.Context() 21 21 logger := s.logger.With("name", "handleImportRepo") 22 22 23 - urepo := e.Get("repo").(*models.RepoActor) 23 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 24 24 25 - b, err := io.ReadAll(e.Request().Body) 25 + b, err := io.ReadAll(r.Body) 26 26 if err != nil { 27 27 logger.Error("could not read bytes in import request", "error", err) 28 - return helpers.ServerError(e, nil) 28 + helpers.ServerError(w, nil) 29 + return 29 30 } 30 31 31 32 bs := s.getBlockstore(urepo.Repo.Did) ··· 33 34 cs, err := car.NewCarReader(bytes.NewReader(b)) 34 35 if err != nil { 35 36 logger.Error("could not read car in import request", "error", err) 36 - return helpers.ServerError(e, nil) 37 + helpers.ServerError(w, nil) 38 + return 37 39 } 38 40 39 41 orderedBlocks := []blocks.Block{} 40 42 currBlock, err := cs.Next() 41 43 if err != nil { 42 44 logger.Error("could not get first block from car", "error", err) 43 - return helpers.ServerError(e, nil) 45 + helpers.ServerError(w, nil) 46 + return 44 47 } 45 48 currBlockCt := 1 46 49 ··· 56 59 57 60 if err := bs.PutMany(context.TODO(), orderedBlocks); err != nil { 58 61 logger.Error("could not insert blocks", "error", err) 59 - return helpers.ServerError(e, nil) 62 + helpers.ServerError(w, nil) 63 + return 60 64 } 61 65 62 66 r, err := openRepo(context.TODO(), bs, cs.Header.Roots[0], urepo.Repo.Did) 63 67 if err != nil { 64 68 logger.Error("could not open repo", "error", err) 65 - return helpers.ServerError(e, nil) 69 + helpers.ServerError(w, nil) 70 + return 66 71 } 67 72 68 73 tx := s.db.Begin(ctx) ··· 73 78 pts := strings.Split(string(key), "/") 74 79 nsid := pts[0] 75 80 rkey := pts[1] 76 - cidStr := cid.String() 77 - b, err := bs.Get(context.TODO(), cid) 81 + cidStr := c.String() 82 + blkData, err := bs.Get(context.TODO(), c) 78 83 if err != nil { 79 84 logger.Error("record bytes don't exist in blockstore", "error", err) 80 - return helpers.ServerError(e, nil) 85 + return err 81 86 } 82 87 83 88 rec := models.Record{ ··· 86 91 Nsid: nsid, 87 92 Rkey: rkey, 88 93 Cid: cidStr, 89 - Value: b.RawData(), 94 + Value: blkData.RawData(), 90 95 } 91 96 92 97 if err := tx.Save(rec).Error; err != nil { ··· 96 101 return nil 97 102 }); err != nil { 98 103 tx.Rollback() 99 - logger.Error("record bytes don't exist in blockstore", "error", err) 100 - return helpers.ServerError(e, nil) 104 + logger.Error("error iterating repo blocks", "error", err) 105 + helpers.ServerError(w, nil) 106 + return 101 107 } 102 108 103 109 tx.Commit() ··· 105 111 root, rev, err := commitRepo(context.TODO(), bs, r, urepo.Repo.SigningKey) 106 112 if err != nil { 107 113 logger.Error("error committing", "error", err) 108 - return helpers.ServerError(e, nil) 114 + helpers.ServerError(w, nil) 115 + return 109 116 } 110 117 111 118 if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil { 112 119 logger.Error("error updating repo after commit", "error", err) 113 - return helpers.ServerError(e, nil) 120 + helpers.ServerError(w, nil) 121 + return 114 122 } 115 123 116 - return nil 124 + w.WriteHeader(http.StatusOK) 117 125 }
+6 -7
server/handle_label_query_labels.go
··· 1 1 package server 2 2 3 - import ( 4 - "github.com/labstack/echo/v4" 5 - ) 3 + import "net/http" 6 4 7 5 type Label struct { 8 6 Ver *int `json:"ver,omitempty"` ··· 21 19 Labels []Label `json:"labels"` 22 20 } 23 21 24 - func (s *Server) handleLabelQueryLabels(e echo.Context) error { 25 - svc := e.Request().Header.Get("atproto-proxy") 22 + func (s *Server) handleLabelQueryLabels(w http.ResponseWriter, r *http.Request) { 23 + svc := r.Header.Get("atproto-proxy") 26 24 if svc != "" || s.config.FallbackProxy != "" { 27 - return s.handleProxy(e) 25 + s.handleProxy(w, r) 26 + return 28 27 } 29 28 30 - return e.JSON(200, ComAtprotoLabelQueryLabelsResponse{ 29 + s.writeJSON(w, 200, ComAtprotoLabelQueryLabelsResponse{ 31 30 Cursor: nil, 32 31 Labels: []Label{}, 33 32 })
+95 -53
server/handle_oauth_authorize.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "net/url" 6 7 "strings" 7 8 "time" ··· 11 12 "github.com/haileyok/cocoon/oauth" 12 13 "github.com/haileyok/cocoon/oauth/constants" 13 14 "github.com/haileyok/cocoon/oauth/provider" 14 - "github.com/labstack/echo/v4" 15 15 ) 16 16 17 17 type HandleOauthAuthorizeGetInput struct { 18 18 RequestUri string `query:"request_uri"` 19 19 } 20 20 21 - func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 22 - ctx := e.Request().Context() 21 + func (s *Server) handleOauthAuthorizeGet(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 23 23 24 24 logger := s.logger.With("name", "handleOauthAuthorizeGet") 25 25 26 - var input HandleOauthAuthorizeGetInput 27 - if err := e.Bind(&input); err != nil { 28 - logger.Error("error binding request", "err", err) 29 - return fmt.Errorf("error binding request") 30 - } 26 + requestUri := r.URL.Query().Get("request_uri") 31 27 32 28 var reqId string 33 - if input.RequestUri != "" { 34 - id, err := oauth.DecodeRequestUri(input.RequestUri) 29 + if requestUri != "" { 30 + id, err := oauth.DecodeRequestUri(requestUri) 35 31 if err != nil { 36 - logger.Error("no request uri found in input", "url", e.Request().URL.String()) 37 - return helpers.InputError(e, to.StringPtr("no request uri")) 32 + logger.Error("no request uri found in input", "url", r.URL.String()) 33 + helpers.InputError(w, to.StringPtr("no request uri")) 34 + return 38 35 } 39 36 reqId = id 40 37 } else { 41 - var parRequest provider.ParRequest 42 - if err := e.Bind(&parRequest); err != nil { 43 - s.logger.Error("error binding for standard auth request", "error", err) 44 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 38 + parRequest := provider.ParRequest{ 39 + AuthenticateClientRequestBase: provider.AuthenticateClientRequestBase{ 40 + ClientID: r.URL.Query().Get("client_id"), 41 + }, 42 + ResponseType: r.URL.Query().Get("response_type"), 43 + State: r.URL.Query().Get("state"), 44 + RedirectURI: r.URL.Query().Get("redirect_uri"), 45 + Scope: r.URL.Query().Get("scope"), 46 + CodeChallengeMethod: r.URL.Query().Get("code_challenge_method"), 47 + } 48 + if v := r.URL.Query().Get("code_challenge"); v != "" { 49 + parRequest.CodeChallenge = to.StringPtr(v) 50 + } 51 + if v := r.URL.Query().Get("login_hint"); v != "" { 52 + parRequest.LoginHint = to.StringPtr(v) 53 + } 54 + if v := r.URL.Query().Get("dpop_jkt"); v != "" { 55 + parRequest.DpopJkt = to.StringPtr(v) 56 + } 57 + if v := r.URL.Query().Get("response_mode"); v != "" { 58 + parRequest.ResponseMode = to.StringPtr(v) 45 59 } 46 60 47 - if err := e.Validate(parRequest); err != nil { 61 + if err := s.validator.Struct(parRequest); err != nil { 48 62 // render page for logged out dev 49 63 if s.config.Version == "dev" && parRequest.ClientID == "" { 50 - return e.Render(200, "authorize.html", map[string]any{ 64 + s.renderTemplate(w, "authorize.html", map[string]any{ 51 65 "Scopes": []string{"atproto", "transition:generic"}, 52 66 "AppName": "DEV MODE AUTHORIZATION PAGE", 53 67 "Handle": "paula.cocoon.social", 54 68 "RequestUri": "", 55 69 }) 70 + return 56 71 } 57 - return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters")) 72 + helpers.InputError(w, to.StringPtr("no request uri and invalid parameters")) 73 + return 58 74 } 59 75 60 76 client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{ ··· 62 78 }) 63 79 if err != nil { 64 80 s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err) 65 - return helpers.ServerError(e, to.StringPtr(err.Error())) 81 + helpers.ServerError(w, to.StringPtr(err.Error())) 82 + return 66 83 } 67 84 68 85 if parRequest.DpopJkt == nil { 69 86 if client.Metadata.DpopBoundAccessTokens { 87 + // nothing to do 70 88 } 71 89 } else { 72 90 if !client.Metadata.DpopBoundAccessTokens { 73 91 msg := "dpop bound access tokens are not enabled for this client" 74 - return helpers.InputError(e, &msg) 92 + helpers.InputError(w, &msg) 93 + return 75 94 } 76 95 } 77 96 ··· 88 107 89 108 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 90 109 s.logger.Error("error creating auth request in db", "error", err) 91 - return helpers.ServerError(e, nil) 110 + helpers.ServerError(w, nil) 111 + return 92 112 } 93 113 94 - input.RequestUri = oauth.EncodeRequestUri(id) 114 + requestUri = oauth.EncodeRequestUri(id) 95 115 reqId = id 96 - 97 116 } 98 117 99 - repo, _, err := s.getSessionRepoOrErr(e) 118 + repo, _, err := s.getSessionRepoOrErr(r) 100 119 if err != nil { 101 - return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 120 + http.Redirect(w, r, "/account/signin?"+r.URL.Query().Encode(), 303) 121 + return 102 122 } 103 123 104 124 var req provider.OauthAuthorizationRequest 105 125 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 106 - return helpers.ServerError(e, to.StringPtr(err.Error())) 126 + helpers.ServerError(w, to.StringPtr(err.Error())) 127 + return 107 128 } 108 129 109 - clientId := e.QueryParam("client_id") 130 + clientId := r.URL.Query().Get("client_id") 110 131 if clientId != req.ClientId { 111 - return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request")) 132 + helpers.InputError(w, to.StringPtr("client id does not match the client id for the supplied request")) 133 + return 112 134 } 113 135 114 - client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId) 136 + client, err := s.oauthProvider.ClientManager.GetClient(r.Context(), req.ClientId) 115 137 if err != nil { 116 - return helpers.ServerError(e, to.StringPtr(err.Error())) 138 + helpers.ServerError(w, to.StringPtr(err.Error())) 139 + return 117 140 } 118 141 119 142 scopes := strings.Split(req.Parameters.Scope, " ") ··· 122 145 data := map[string]any{ 123 146 "Scopes": scopes, 124 147 "AppName": appName, 125 - "RequestUri": input.RequestUri, 126 - "QueryParams": e.QueryParams().Encode(), 148 + "RequestUri": requestUri, 149 + "QueryParams": r.URL.Query().Encode(), 127 150 "Handle": repo.Actor.Handle, 128 151 } 129 152 130 - return e.Render(200, "authorize.html", data) 153 + s.renderTemplate(w, "authorize.html", data) 131 154 } 132 155 133 156 type OauthAuthorizePostRequest struct { ··· 135 158 AcceptOrRejct string `form:"accept_or_reject"` 136 159 } 137 160 138 - func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 139 - ctx := e.Request().Context() 161 + func (s *Server) handleOauthAuthorizePost(w http.ResponseWriter, r *http.Request) { 162 + ctx := r.Context() 140 163 logger := s.logger.With("name", "handleOauthAuthorizePost") 141 164 142 - repo, _, err := s.getSessionRepoOrErr(e) 165 + repo, _, err := s.getSessionRepoOrErr(r) 143 166 if err != nil { 144 - return e.Redirect(303, "/account/signin") 167 + http.Redirect(w, r, "/account/signin", 303) 168 + return 145 169 } 146 170 147 - var req OauthAuthorizePostRequest 148 - if err := e.Bind(&req); err != nil { 149 - logger.Error("error binding authorize post request", "error", err) 150 - return helpers.InputError(e, nil) 171 + if err := r.ParseForm(); err != nil { 172 + logger.Error("error parsing authorize post form", "error", err) 173 + helpers.InputError(w, nil) 174 + return 175 + } 176 + 177 + req := OauthAuthorizePostRequest{ 178 + RequestUri: r.FormValue("request_uri"), 179 + AcceptOrRejct: r.FormValue("accept_or_reject"), 151 180 } 152 181 153 182 reqId, err := oauth.DecodeRequestUri(req.RequestUri) 154 183 if err != nil { 155 - return helpers.InputError(e, to.StringPtr(err.Error())) 184 + helpers.InputError(w, to.StringPtr(err.Error())) 185 + return 156 186 } 157 187 158 188 var authReq provider.OauthAuthorizationRequest 159 189 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 160 - return helpers.ServerError(e, to.StringPtr(err.Error())) 190 + helpers.ServerError(w, to.StringPtr(err.Error())) 191 + return 161 192 } 162 193 163 - client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId) 194 + client, err := s.oauthProvider.ClientManager.GetClient(r.Context(), authReq.ClientId) 164 195 if err != nil { 165 - return helpers.ServerError(e, to.StringPtr(err.Error())) 196 + helpers.ServerError(w, to.StringPtr(err.Error())) 197 + return 166 198 } 167 199 168 200 // TODO: figure out how im supposed to actually redirect 169 201 if req.AcceptOrRejct == "reject" { 170 - return e.Redirect(303, client.Metadata.ClientURI) 202 + http.Redirect(w, r, client.Metadata.ClientURI, 303) 203 + return 171 204 } 172 205 173 206 if time.Now().After(authReq.ExpiresAt) { 174 - return helpers.InputError(e, to.StringPtr("the request has expired")) 207 + helpers.InputError(w, to.StringPtr("the request has expired")) 208 + return 175 209 } 176 210 177 211 if authReq.Sub != nil || authReq.Code != nil { 178 - return helpers.InputError(e, to.StringPtr("this request was already authorized")) 212 + helpers.InputError(w, to.StringPtr("this request was already authorized")) 213 + return 179 214 } 180 215 181 216 code := oauth.GenerateCode() 182 217 183 - if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 218 + // Use the first non-loopback remote address as the IP 219 + ip := r.RemoteAddr 220 + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { 221 + ip = strings.Split(forwarded, ",")[0] 222 + } 223 + 224 + if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, ip, reqId).Error; err != nil { 184 225 logger.Error("error updating authorization request", "error", err) 185 - return helpers.ServerError(e, nil) 226 + helpers.ServerError(w, nil) 227 + return 186 228 } 187 229 188 230 q := url.Values{} ··· 197 239 hashOrQuestion = "#" 198 240 case "query": 199 241 // do nothing 200 - break 201 242 default: 202 243 if authReq.Parameters.ResponseType != "code" { 203 244 hashOrQuestion = "#" ··· 209 250 } 210 251 } 211 252 212 - return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode()) 253 + _ = fmt.Sprintf // avoid unused import if fmt ends up unused 254 + http.Redirect(w, r, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode(), 303) 213 255 }
+3 -3
server/handle_oauth_jwks.go
··· 1 1 package server 2 2 3 - import "github.com/labstack/echo/v4" 3 + import "net/http" 4 4 5 5 type OauthJwksResponse struct { 6 6 Keys []any `json:"keys"` 7 7 } 8 8 9 9 // TODO: ? 10 - func (s *Server) handleOauthJwks(e echo.Context) error { 11 - return e.JSON(200, OauthJwksResponse{Keys: []any{}}) 10 + func (s *Server) handleOauthJwks(w http.ResponseWriter, r *http.Request) { 11 + s.writeJSON(w, 200, OauthJwksResponse{Keys: []any{}}) 12 12 }
+57 -21
server/handle_oauth_par.go
··· 2 2 3 3 import ( 4 4 "errors" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/Azure/go-autorest/autorest/to" ··· 10 11 "github.com/haileyok/cocoon/oauth/constants" 11 12 "github.com/haileyok/cocoon/oauth/dpop" 12 13 "github.com/haileyok/cocoon/oauth/provider" 13 - "github.com/labstack/echo/v4" 14 14 ) 15 15 16 16 type OauthParResponse struct { ··· 18 18 RequestURI string `json:"request_uri"` 19 19 } 20 20 21 - func (s *Server) handleOauthPar(e echo.Context) error { 22 - ctx := e.Request().Context() 21 + func (s *Server) handleOauthPar(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 23 23 logger := s.logger.With("name", "handleOauthPar") 24 24 25 - var parRequest provider.ParRequest 26 - if err := e.Bind(&parRequest); err != nil { 27 - logger.Error("error binding for par request", "error", err) 28 - return helpers.ServerError(e, nil) 25 + if err := r.ParseForm(); err != nil { 26 + logger.Error("error parsing par request form", "error", err) 27 + helpers.ServerError(w, nil) 28 + return 29 29 } 30 30 31 - if err := e.Validate(parRequest); err != nil { 31 + parRequest := provider.ParRequest{ 32 + AuthenticateClientRequestBase: provider.AuthenticateClientRequestBase{ 33 + ClientID: r.FormValue("client_id"), 34 + }, 35 + ResponseType: r.FormValue("response_type"), 36 + State: r.FormValue("state"), 37 + RedirectURI: r.FormValue("redirect_uri"), 38 + Scope: r.FormValue("scope"), 39 + CodeChallengeMethod: r.FormValue("code_challenge_method"), 40 + } 41 + if v := r.FormValue("code_challenge"); v != "" { 42 + parRequest.CodeChallenge = to.StringPtr(v) 43 + } 44 + if v := r.FormValue("login_hint"); v != "" { 45 + parRequest.LoginHint = to.StringPtr(v) 46 + } 47 + if v := r.FormValue("dpop_jkt"); v != "" { 48 + parRequest.DpopJkt = to.StringPtr(v) 49 + } 50 + if v := r.FormValue("response_mode"); v != "" { 51 + parRequest.ResponseMode = to.StringPtr(v) 52 + } 53 + if v := r.FormValue("client_assertion_type"); v != "" { 54 + parRequest.ClientAssertionType = to.StringPtr(v) 55 + } 56 + if v := r.FormValue("client_assertion"); v != "" { 57 + parRequest.ClientAssertion = to.StringPtr(v) 58 + } 59 + 60 + if err := s.validator.Struct(parRequest); err != nil { 32 61 logger.Error("missing parameters for par request", "error", err) 33 - return helpers.InputError(e, nil) 62 + helpers.InputError(w, nil) 63 + return 34 64 } 35 65 36 66 // TODO: this seems wrong. should be a way to get the entire request url i believe, but this will work for now 37 - dpopProof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, nil) 67 + dpopProof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, "https://"+s.config.Hostname+r.URL.String(), r.Header, nil) 38 68 if err != nil { 39 69 if errors.Is(err, dpop.ErrUseDpopNonce) { 40 70 nonce := s.oauthProvider.NextNonce() 41 71 if nonce != "" { 42 - e.Response().Header().Set("DPoP-Nonce", nonce) 43 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 72 + w.Header().Set("DPoP-Nonce", nonce) 73 + w.Header().Add("access-control-expose-headers", "DPoP-Nonce") 44 74 } 45 - logger.Error("nonce error: use_dpop_nonce", "headers", e.Request().Header) 46 - return e.JSON(400, map[string]string{ 75 + logger.Error("nonce error: use_dpop_nonce", "headers", r.Header) 76 + s.writeJSON(w, 400, map[string]string{ 47 77 "error": "use_dpop_nonce", 48 78 }) 79 + return 49 80 } 50 81 logger.Error("error getting dpop proof", "error", err) 51 - return helpers.InputError(e, nil) 82 + helpers.InputError(w, nil) 83 + return 52 84 } 53 85 54 - client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{ 86 + client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{ 55 87 // rfc9449 56 88 // https://github.com/bluesky-social/atproto/blob/main/packages/oauth/oauth-provider/src/oauth-provider.ts#L473 57 89 AllowMissingDpopProof: true, 58 90 }) 59 91 if err != nil { 60 92 logger.Error("error authenticating client", "client_id", parRequest.ClientID, "error", err) 61 - return helpers.InputError(e, to.StringPtr(err.Error())) 93 + helpers.InputError(w, to.StringPtr(err.Error())) 94 + return 62 95 } 63 96 64 97 if parRequest.DpopJkt == nil { ··· 69 102 if !client.Metadata.DpopBoundAccessTokens { 70 103 msg := "dpop bound access tokens are not enabled for this client" 71 104 logger.Error(msg) 72 - return helpers.InputError(e, &msg) 105 + helpers.InputError(w, &msg) 106 + return 73 107 } 74 108 75 109 if dpopProof.JKT != *parRequest.DpopJkt { 76 110 msg := "supplied dpop jkt does not match header dpop jkt" 77 111 logger.Error(msg) 78 - return helpers.InputError(e, &msg) 112 + helpers.InputError(w, &msg) 113 + return 79 114 } 80 115 } 81 116 ··· 92 127 93 128 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 94 129 logger.Error("error creating auth request in db", "error", err) 95 - return helpers.ServerError(e, nil) 130 + helpers.ServerError(w, nil) 131 + return 96 132 } 97 133 98 134 uri := oauth.EncodeRequestUri(id) 99 135 100 - return e.JSON(201, OauthParResponse{ 136 + s.writeJSON(w, 201, OauthParResponse{ 101 137 ExpiresIn: int64(constants.ParExpiresIn.Seconds()), 102 138 RequestURI: uri, 103 139 })
+102 -47
server/handle_oauth_token.go
··· 6 6 "encoding/base64" 7 7 "errors" 8 8 "fmt" 9 + "net/http" 9 10 "slices" 10 11 "time" 11 12 ··· 16 17 "github.com/haileyok/cocoon/oauth/constants" 17 18 "github.com/haileyok/cocoon/oauth/dpop" 18 19 "github.com/haileyok/cocoon/oauth/provider" 19 - "github.com/labstack/echo/v4" 20 20 ) 21 21 22 22 type OauthTokenRequest struct { ··· 37 37 Sub string `json:"sub"` 38 38 } 39 39 40 - func (s *Server) handleOauthToken(e echo.Context) error { 41 - ctx := e.Request().Context() 40 + func (s *Server) handleOauthToken(w http.ResponseWriter, r *http.Request) { 41 + ctx := r.Context() 42 42 logger := s.logger.With("name", "handleOauthToken") 43 43 44 - var req OauthTokenRequest 45 - if err := e.Bind(&req); err != nil { 46 - logger.Error("error binding token request", "error", err) 47 - return helpers.ServerError(e, nil) 44 + if err := r.ParseForm(); err != nil { 45 + logger.Error("error parsing token request form", "error", err) 46 + helpers.ServerError(w, nil) 47 + return 48 + } 49 + 50 + req := OauthTokenRequest{ 51 + GrantType: r.FormValue("grant_type"), 52 + } 53 + if v := r.FormValue("code"); v != "" { 54 + req.Code = to.StringPtr(v) 55 + } 56 + if v := r.FormValue("code_verifier"); v != "" { 57 + req.CodeVerifier = to.StringPtr(v) 58 + } 59 + if v := r.FormValue("redirect_uri"); v != "" { 60 + req.RedirectURI = to.StringPtr(v) 61 + } 62 + if v := r.FormValue("refresh_token"); v != "" { 63 + req.RefreshToken = to.StringPtr(v) 64 + } 65 + if v := r.FormValue("client_assertion_type"); v != "" { 66 + req.ClientAssertionType = to.StringPtr(v) 67 + } 68 + if v := r.FormValue("client_assertion"); v != "" { 69 + req.ClientAssertion = to.StringPtr(v) 70 + } 71 + req.AuthenticateClientRequestBase = provider.AuthenticateClientRequestBase{ 72 + ClientID: r.FormValue("client_id"), 73 + ClientAssertionType: req.ClientAssertionType, 74 + ClientAssertion: req.ClientAssertion, 48 75 } 49 76 50 - proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil) 77 + proof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, r.URL.String(), r.Header, nil) 51 78 if err != nil { 52 79 if errors.Is(err, dpop.ErrUseDpopNonce) { 53 80 nonce := s.oauthProvider.NextNonce() 54 81 if nonce != "" { 55 - e.Response().Header().Set("DPoP-Nonce", nonce) 56 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 82 + w.Header().Set("DPoP-Nonce", nonce) 83 + w.Header().Add("access-control-expose-headers", "DPoP-Nonce") 57 84 } 58 - return e.JSON(400, map[string]string{ 85 + s.writeJSON(w, 400, map[string]string{ 59 86 "error": "use_dpop_nonce", 60 87 }) 88 + return 61 89 } 62 90 logger.Error("error getting dpop proof", "error", err) 63 - return helpers.InputError(e, nil) 91 + helpers.InputError(w, nil) 92 + return 64 93 } 65 94 66 - client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 95 + client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 67 96 AllowMissingDpopProof: true, 68 97 }) 69 98 if err != nil { 70 99 logger.Error("error authenticating client", "client_id", req.ClientID, "error", err) 71 - return helpers.InputError(e, to.StringPtr(err.Error())) 100 + helpers.InputError(w, to.StringPtr(err.Error())) 101 + return 72 102 } 73 103 74 - // TODO: this should come from an oauth provier config 104 + // TODO: this should come from an oauth provider config 75 105 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) { 76 - return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType))) 106 + helpers.InputError(w, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType))) 107 + return 77 108 } 78 109 79 110 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) { 80 - return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType))) 111 + helpers.InputError(w, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType))) 112 + return 81 113 } 82 114 83 115 if req.GrantType == "authorization_code" { 84 116 if req.Code == nil { 85 - return helpers.InputError(e, to.StringPtr(`"code" is required"`)) 117 + helpers.InputError(w, to.StringPtr(`"code" is required"`)) 118 + return 86 119 } 87 120 88 121 var authReq provider.OauthAuthorizationRequest 89 122 // get the lil guy and delete him 90 123 if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 91 124 logger.Error("error finding authorization request", "error", err) 92 - return helpers.ServerError(e, nil) 125 + helpers.ServerError(w, nil) 126 + return 93 127 } 94 128 95 129 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI { 96 - return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`)) 130 + helpers.InputError(w, to.StringPtr(`"redirect_uri" mismatch`)) 131 + return 97 132 } 98 133 99 134 if authReq.Parameters.CodeChallenge != nil { 100 135 if req.CodeVerifier == nil { 101 - return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`)) 136 + helpers.InputError(w, to.StringPtr(`"code_verifier" is required`)) 137 + return 102 138 } 103 139 104 140 if len(*req.CodeVerifier) < 43 { 105 - return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`)) 141 + helpers.InputError(w, to.StringPtr(`"code_verifier" is too short`)) 142 + return 106 143 } 107 144 108 - switch *&authReq.Parameters.CodeChallengeMethod { 145 + switch authReq.Parameters.CodeChallengeMethod { 109 146 case "", "plain": 110 147 if authReq.Parameters.CodeChallenge != req.CodeVerifier { 111 - return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 148 + helpers.InputError(w, to.StringPtr("invalid code_verifier")) 149 + return 112 150 } 113 151 case "S256": 114 152 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge) 115 153 if err != nil { 116 154 logger.Error("error decoding code challenge", "error", err) 117 - return helpers.ServerError(e, nil) 155 + helpers.ServerError(w, nil) 156 + return 118 157 } 119 158 120 159 h := sha256.New() ··· 122 161 compdChal := h.Sum(nil) 123 162 124 163 if !bytes.Equal(inputChal, compdChal) { 125 - return helpers.InputError(e, to.StringPtr("invalid code_verifier")) 164 + helpers.InputError(w, to.StringPtr("invalid code_verifier")) 165 + return 126 166 } 127 167 default: 128 - return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod)) 168 + helpers.InputError(w, to.StringPtr("unsupported code_challenge_method "+authReq.Parameters.CodeChallengeMethod)) 169 + return 129 170 } 130 171 } else if req.CodeVerifier != nil { 131 - return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 172 + helpers.InputError(w, to.StringPtr("code_challenge parameter wasn't provided")) 173 + return 132 174 } 133 175 134 176 repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) 135 177 if err != nil { 136 - helpers.InputError(e, to.StringPtr("unable to find actor")) 178 + helpers.InputError(w, to.StringPtr("unable to find actor")) 179 + return 137 180 } 138 181 139 182 now := time.Now() ··· 159 202 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 160 203 accessString, err := accessToken.SignedString(s.privateKey) 161 204 if err != nil { 162 - return err 205 + helpers.ServerError(w, nil) 206 + return 163 207 } 164 208 165 209 if err := s.db.Create(ctx, &provider.OauthToken{ ··· 175 219 Ip: authReq.Ip, 176 220 }, nil).Error; err != nil { 177 221 logger.Error("error creating token in db", "error", err) 178 - return helpers.ServerError(e, nil) 222 + helpers.ServerError(w, nil) 223 + return 179 224 } 180 225 181 226 // prob not needed ··· 184 229 tokenType = "DPoP" 185 230 } 186 231 187 - e.Response().Header().Set("content-type", "application/json") 188 - 189 - return e.JSON(200, OauthTokenResponse{ 232 + s.writeJSON(w, 200, OauthTokenResponse{ 190 233 AccessToken: accessString, 191 234 RefreshToken: refreshToken, 192 235 TokenType: tokenType, ··· 194 237 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 195 238 Sub: repo.Repo.Did, 196 239 }) 240 + return 197 241 } 198 242 199 243 if req.GrantType == "refresh_token" { 200 244 if req.RefreshToken == nil { 201 - return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`)) 245 + helpers.InputError(w, to.StringPtr(`"refresh_token" is required`)) 246 + return 202 247 } 203 248 204 249 var oauthToken provider.OauthToken 205 250 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 206 251 logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 207 - return helpers.ServerError(e, nil) 252 + helpers.ServerError(w, nil) 253 + return 208 254 } 209 255 210 256 if client.Metadata.ClientID != oauthToken.ClientId { 211 - return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`)) 257 + helpers.InputError(w, to.StringPtr(`"client_id" mismatch`)) 258 + return 212 259 } 213 260 214 261 if clientAuth.Method != oauthToken.ClientAuth.Method { 215 - return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`)) 262 + helpers.InputError(w, to.StringPtr(`"client authentication method mismatch`)) 263 + return 216 264 } 217 265 218 266 if *oauthToken.Parameters.DpopJkt != proof.JKT { 219 - return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 267 + helpers.InputError(w, to.StringPtr("dpop proof does not match expected jkt")) 268 + return 220 269 } 221 270 222 271 ageRes := oauth.GetSessionAgeFromToken(oauthToken) 223 272 224 273 if ageRes.SessionExpired { 225 - return helpers.InputError(e, to.StringPtr("Session expired")) 274 + helpers.InputError(w, to.StringPtr("Session expired")) 275 + return 226 276 } 227 277 228 278 if ageRes.RefreshExpired { 229 - return helpers.InputError(e, to.StringPtr("Refresh token expired")) 279 + helpers.InputError(w, to.StringPtr("Refresh token expired")) 280 + return 230 281 } 231 282 232 283 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil { 233 284 // why? ref impl 234 - return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens")) 285 + helpers.InputError(w, to.StringPtr("dpop jkt is required for dpop bound access tokens")) 286 + return 235 287 } 236 288 237 289 nextTokenId := oauth.GenerateTokenId() ··· 251 303 } 252 304 253 305 if oauthToken.Parameters.DpopJkt != nil { 254 - accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt 306 + accessClaims["cnf"] = oauthToken.Parameters.DpopJkt 255 307 } 256 308 257 309 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims) 258 310 accessString, err := accessToken.SignedString(s.privateKey) 259 311 if err != nil { 260 - return err 312 + helpers.ServerError(w, nil) 313 + return 261 314 } 262 315 263 316 if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 264 317 logger.Error("error updating token", "error", err) 265 - return helpers.ServerError(e, nil) 318 + helpers.ServerError(w, nil) 319 + return 266 320 } 267 321 268 322 // prob not needed ··· 271 325 tokenType = "DPoP" 272 326 } 273 327 274 - return e.JSON(200, OauthTokenResponse{ 328 + s.writeJSON(w, 200, OauthTokenResponse{ 275 329 AccessToken: accessString, 276 330 RefreshToken: nextRefreshToken, 277 331 TokenType: tokenType, ··· 279 333 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()), 280 334 Sub: oauthToken.Sub, 281 335 }) 336 + return 282 337 } 283 338 284 - return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType))) 339 + helpers.InputError(w, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType))) 285 340 }
+36 -27
server/handle_proxy.go
··· 6 6 "encoding/base64" 7 7 "encoding/json" 8 8 "fmt" 9 + "io" 9 10 "net/http" 10 11 "strings" 11 12 "time" ··· 13 14 "github.com/google/uuid" 14 15 "github.com/haileyok/cocoon/internal/helpers" 15 16 "github.com/haileyok/cocoon/models" 16 - "github.com/labstack/echo/v4" 17 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 18 ) 19 19 20 - func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 21 - svc := e.Request().Header.Get("atproto-proxy") 20 + func (s *Server) getAtprotoProxyEndpointFromRequest(r *http.Request) (string, string, error) { 21 + svc := r.Header.Get("atproto-proxy") 22 22 if svc == "" && s.config.FallbackProxy != "" { 23 23 svc = s.config.FallbackProxy 24 24 } ··· 31 31 svcDid := svcPts[0] 32 32 svcId := "#" + svcPts[1] 33 33 34 - doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 34 + doc, err := s.passport.FetchDoc(r.Context(), svcDid) 35 35 if err != nil { 36 36 return "", "", err 37 37 } ··· 46 46 return endpoint, svcDid, nil 47 47 } 48 48 49 - func (s *Server) handleProxy(e echo.Context) error { 49 + func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { 50 50 logger := s.logger.With("handler", "handleProxy") 51 51 52 - repo, isAuthed := e.Get("repo").(*models.RepoActor) 52 + repo, isAuthed := getContextValue[*models.RepoActor](r, contextKeyRepo) 53 53 54 - pts := strings.Split(e.Request().URL.Path, "/") 54 + pts := strings.Split(r.URL.Path, "/") 55 55 if len(pts) != 3 { 56 - return fmt.Errorf("incorrect number of parts") 56 + helpers.ServerError(w, nil) 57 + return 57 58 } 58 59 59 - endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 + endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(r) 60 61 if err != nil { 61 62 logger.Error("could not get atproto proxy", "error", err) 62 - return helpers.ServerError(e, nil) 63 + helpers.ServerError(w, nil) 64 + return 63 65 } 64 66 65 - requrl := e.Request().URL 67 + requrl := *r.URL 66 68 requrl.Host = strings.TrimPrefix(endpoint, "https://") 67 69 requrl.Scheme = "https" 68 70 69 - body := e.Request().Body 70 - if e.Request().Method == "GET" { 71 - body = nil 71 + var body io.Reader 72 + if r.Method != http.MethodGet { 73 + body = r.Body 72 74 } 73 75 74 - req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 76 + req, err := http.NewRequest(r.Method, requrl.String(), body) 75 77 if err != nil { 76 - return err 78 + helpers.ServerError(w, nil) 79 + return 77 80 } 78 81 79 - req.Header = e.Request().Header.Clone() 82 + req.Header = r.Header.Clone() 80 83 81 84 if isAuthed { 82 85 // this is a little dumb. i should probably figure out a better way to do this, and use ··· 91 94 hj, err := json.Marshal(header) 92 95 if err != nil { 93 96 logger.Error("error marshaling header", "error", err) 94 - return helpers.ServerError(e, nil) 97 + helpers.ServerError(w, nil) 98 + return 95 99 } 96 100 97 101 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") ··· 100 104 // underlying feed generator and the app view passes it on. This allows the 101 105 // getFeed implementation to pass in the desired lxm and aud for the token 102 106 // and then just delegate to the general proxying logic 103 - lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string) 107 + lxm, proxyTokenLxmExists := getContextValue[string](r, contextKeyProxyTokenLxm) 104 108 if !proxyTokenLxmExists || lxm == "" { 105 109 lxm = pts[2] 106 110 } 107 - aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string) 111 + aud, proxyTokenAudExists := getContextValue[string](r, contextKeyProxyTokenAud) 108 112 if !proxyTokenAudExists || aud == "" { 109 113 aud = svcDid 110 114 } ··· 118 122 } 119 123 pj, err := json.Marshal(payload) 120 124 if err != nil { 121 - logger.Error("error marashaling payload", "error", err) 122 - return helpers.ServerError(e, nil) 125 + logger.Error("error marshaling payload", "error", err) 126 + helpers.ServerError(w, nil) 127 + return 123 128 } 124 129 125 130 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") ··· 130 135 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 131 136 if err != nil { 132 137 logger.Error("can't load private key", "error", err) 133 - return err 138 + helpers.ServerError(w, nil) 139 + return 134 140 } 135 141 136 142 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 137 143 if err != nil { 138 144 logger.Error("error signing", "error", err) 145 + helpers.ServerError(w, nil) 146 + return 139 147 } 140 148 141 149 rBytes := R.Bytes() ··· 157 165 158 166 resp, err := http.DefaultClient.Do(req) 159 167 if err != nil { 160 - return err 168 + helpers.ServerError(w, nil) 169 + return 161 170 } 162 171 defer resp.Body.Close() 163 172 164 173 for k, v := range resp.Header { 165 - e.Response().Header().Set(k, strings.Join(v, ",")) 174 + w.Header().Set(k, strings.Join(v, ",")) 166 175 } 167 - 168 - return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 176 + w.WriteHeader(resp.StatusCode) 177 + io.Copy(w, resp.Body) 169 178 }
+21 -11
server/handle_proxy_get_feed.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 5 + "net/http" 6 + 4 7 "github.com/Azure/go-autorest/autorest/to" 5 8 "github.com/bluesky-social/indigo/api/atproto" 6 9 "github.com/bluesky-social/indigo/api/bsky" 7 10 "github.com/bluesky-social/indigo/atproto/syntax" 8 11 "github.com/bluesky-social/indigo/xrpc" 9 12 "github.com/haileyok/cocoon/internal/helpers" 10 - "github.com/labstack/echo/v4" 11 13 ) 12 14 13 - func (s *Server) handleProxyBskyFeedGetFeed(e echo.Context) error { 14 - feedUri, err := syntax.ParseATURI(e.QueryParam("feed")) 15 + func (s *Server) handleProxyBskyFeedGetFeed(w http.ResponseWriter, r *http.Request) { 16 + feedUri, err := syntax.ParseATURI(r.URL.Query().Get("feed")) 15 17 if err != nil { 16 - return helpers.InputError(e, to.StringPtr("invalid feed uri")) 18 + helpers.InputError(w, to.StringPtr("invalid feed uri")) 19 + return 17 20 } 18 21 19 - appViewEndpoint, _, err := s.getAtprotoProxyEndpointFromRequest(e) 22 + appViewEndpoint, _, err := s.getAtprotoProxyEndpointFromRequest(r) 20 23 if err != nil { 21 - e.Logger().Error("could not get atproto proxy", "error", err) 22 - return helpers.ServerError(e, nil) 24 + s.logger.Error("could not get atproto proxy", "error", err) 25 + helpers.ServerError(w, nil) 26 + return 23 27 } 24 28 25 29 appViewClient := xrpc.Client{ 26 30 Host: appViewEndpoint, 27 31 } 28 - feedRecord, err := atproto.RepoGetRecord(e.Request().Context(), &appViewClient, "", feedUri.Collection().String(), feedUri.Authority().String(), feedUri.RecordKey().String()) 32 + feedRecord, err := atproto.RepoGetRecord(r.Context(), &appViewClient, "", feedUri.Collection().String(), feedUri.Authority().String(), feedUri.RecordKey().String()) 33 + if err != nil { 34 + s.logger.Error("could not get feed record", "error", err) 35 + helpers.ServerError(w, nil) 36 + return 37 + } 29 38 feedGeneratorDid := feedRecord.Value.Val.(*bsky.FeedGenerator).Did 30 39 31 - e.Set("proxyTokenLxm", "app.bsky.feed.getFeedSkeleton") 32 - e.Set("proxyTokenAud", feedGeneratorDid) 40 + // Inject proxy token overrides into the request context so handleProxy can read them. 41 + ctx := context.WithValue(r.Context(), contextKeyProxyTokenLxm, "app.bsky.feed.getFeedSkeleton") 42 + ctx = context.WithValue(ctx, contextKeyProxyTokenAud, feedGeneratorDid) 33 43 34 - return s.handleProxy(e) 44 + s.handleProxy(w, r.WithContext(ctx)) 35 45 }
+17 -11
server/handle_repo_apply_writes.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/haileyok/cocoon/internal/helpers" 5 8 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 9 ) 8 10 9 11 type ComAtprotoRepoApplyWritesInput struct { ··· 25 27 Results []ApplyWriteResult `json:"results"` 26 28 } 27 29 28 - func (s *Server) handleApplyWrites(e echo.Context) error { 29 - ctx := e.Request().Context() 30 + func (s *Server) handleApplyWrites(w http.ResponseWriter, r *http.Request) { 31 + ctx := r.Context() 30 32 logger := s.logger.With("name", "handleRepoApplyWrites") 31 33 32 34 var req ComAtprotoRepoApplyWritesInput 33 - if err := e.Bind(&req); err != nil { 35 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 34 36 logger.Error("error binding", "error", err) 35 - return helpers.ServerError(e, nil) 37 + helpers.ServerError(w, nil) 38 + return 36 39 } 37 40 38 - if err := e.Validate(req); err != nil { 41 + if err := s.validator.Struct(req); err != nil { 39 42 logger.Error("error validating", "error", err) 40 - return helpers.InputError(e, nil) 43 + helpers.InputError(w, nil) 44 + return 41 45 } 42 46 43 - repo := e.Get("repo").(*models.RepoActor) 47 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 44 48 45 49 if repo.Repo.Did != req.Repo { 46 50 logger.Warn("mismatched repo/auth") 47 - return helpers.InputError(e, nil) 51 + helpers.InputError(w, nil) 52 + return 48 53 } 49 54 50 55 ops := make([]Op, 0, len(req.Writes)) ··· 60 65 results, err := s.repoman.applyWrites(ctx, repo.Repo, ops, req.SwapCommit) 61 66 if err != nil { 62 67 logger.Error("error applying writes", "error", err) 63 - return helpers.ServerError(e, nil) 68 + helpers.ServerError(w, nil) 69 + return 64 70 } 65 71 66 72 commit := *results[0].Commit ··· 69 75 results[i].Commit = nil 70 76 } 71 77 72 - return e.JSON(200, ComAtprotoRepoApplyWritesOutput{ 78 + s.writeJSON(w, http.StatusOK, ComAtprotoRepoApplyWritesOutput{ 73 79 Commit: commit, 74 80 Results: results, 75 81 })
+18 -12
server/handle_repo_create_record.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/haileyok/cocoon/internal/helpers" 5 8 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 9 ) 8 10 9 11 type ComAtprotoRepoCreateRecordInput struct { ··· 16 18 SwapCommit *string `json:"swapCommit"` 17 19 } 18 20 19 - func (s *Server) handleCreateRecord(e echo.Context) error { 20 - ctx := e.Request().Context() 21 + func (s *Server) handleCreateRecord(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 21 23 logger := s.logger.With("name", "handleCreateRecord") 22 24 23 - repo := e.Get("repo").(*models.RepoActor) 25 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 24 26 25 27 var req ComAtprotoRepoCreateRecordInput 26 - if err := e.Bind(&req); err != nil { 27 - logger.Error("error binding", "error", err) 28 - return helpers.ServerError(e, nil) 28 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 29 + logger.Error("error decoding", "error", err) 30 + helpers.ServerError(w, nil) 31 + return 29 32 } 30 33 31 - if err := e.Validate(req); err != nil { 34 + if err := s.validator.Struct(req); err != nil { 32 35 logger.Error("error validating", "error", err) 33 - return helpers.InputError(e, nil) 36 + helpers.InputError(w, nil) 37 + return 34 38 } 35 39 36 40 if repo.Repo.Did != req.Repo { 37 41 logger.Warn("mismatched repo/auth") 38 - return helpers.InputError(e, nil) 42 + helpers.InputError(w, nil) 43 + return 39 44 } 40 45 41 46 optype := OpTypeCreate ··· 55 60 }, req.SwapCommit) 56 61 if err != nil { 57 62 logger.Error("error applying writes", "error", err) 58 - return helpers.ServerError(e, nil) 63 + helpers.ServerError(w, nil) 64 + return 59 65 } 60 66 61 67 results[0].Type = nil 62 68 63 - return e.JSON(200, results[0]) 69 + s.writeJSON(w, 200, results[0]) 64 70 }
+18 -12
server/handle_repo_delete_record.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/haileyok/cocoon/internal/helpers" 5 8 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 9 ) 8 10 9 11 type ComAtprotoRepoDeleteRecordInput struct { ··· 14 16 SwapCommit *string `json:"swapCommit"` 15 17 } 16 18 17 - func (s *Server) handleDeleteRecord(e echo.Context) error { 18 - ctx := e.Request().Context() 19 + func (s *Server) handleDeleteRecord(w http.ResponseWriter, r *http.Request) { 20 + ctx := r.Context() 19 21 logger := s.logger.With("name", "handleDeleteRecord") 20 22 21 - repo := e.Get("repo").(*models.RepoActor) 23 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 22 24 23 25 var req ComAtprotoRepoDeleteRecordInput 24 - if err := e.Bind(&req); err != nil { 25 - logger.Error("error binding", "error", err) 26 - return helpers.ServerError(e, nil) 26 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 27 + logger.Error("error decoding", "error", err) 28 + helpers.ServerError(w, nil) 29 + return 27 30 } 28 31 29 - if err := e.Validate(req); err != nil { 32 + if err := s.validator.Struct(req); err != nil { 30 33 logger.Error("error validating", "error", err) 31 - return helpers.InputError(e, nil) 34 + helpers.InputError(w, nil) 35 + return 32 36 } 33 37 34 38 if repo.Repo.Did != req.Repo { 35 39 logger.Warn("mismatched repo/auth") 36 - return helpers.InputError(e, nil) 40 + helpers.InputError(w, nil) 41 + return 37 42 } 38 43 39 44 results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{ ··· 46 51 }, req.SwapCommit) 47 52 if err != nil { 48 53 logger.Error("error applying writes", "error", err) 49 - return helpers.ServerError(e, nil) 54 + helpers.ServerError(w, nil) 55 + return 50 56 } 51 57 52 58 results[0].Type = nil ··· 54 60 results[0].Cid = nil 55 61 results[0].ValidationStatus = nil 56 62 57 - return e.JSON(200, results[0]) 63 + s.writeJSON(w, 200, results[0]) 58 64 }
+21 -16
server/handle_repo_describe_repo.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 4 5 "strings" 5 6 6 7 "github.com/Azure/go-autorest/autorest/to" 7 8 "github.com/haileyok/cocoon/identity" 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 10 "github.com/haileyok/cocoon/models" 10 - "github.com/labstack/echo/v4" 11 11 "gorm.io/gorm" 12 12 ) 13 13 ··· 19 19 HandleIsCorrect bool `json:"handleIsCorrect"` 20 20 } 21 21 22 - func (s *Server) handleDescribeRepo(e echo.Context) error { 23 - ctx := e.Request().Context() 22 + func (s *Server) handleDescribeRepo(w http.ResponseWriter, r *http.Request) { 23 + ctx := r.Context() 24 24 logger := s.logger.With("name", "handleDescribeRepo") 25 25 26 - did := e.QueryParam("repo") 26 + did := r.URL.Query().Get("repo") 27 27 repo, err := s.getRepoActorByDid(ctx, did) 28 28 if err != nil { 29 29 if err == gorm.ErrRecordNotFound { 30 - return helpers.InputError(e, to.StringPtr("RepoNotFound")) 30 + helpers.InputError(w, to.StringPtr("RepoNotFound")) 31 + return 31 32 } 32 33 33 34 logger.Error("error looking up repo", "error", err) 34 - return helpers.ServerError(e, nil) 35 + helpers.ServerError(w, nil) 36 + return 35 37 } 36 38 37 39 handleIsCorrect := true 38 40 39 - diddoc, err := s.passport.FetchDoc(e.Request().Context(), repo.Repo.Did) 41 + diddoc, err := s.passport.FetchDoc(r.Context(), repo.Repo.Did) 40 42 if err != nil { 41 43 logger.Error("error fetching diddoc", "error", err) 42 - return helpers.ServerError(e, nil) 44 + helpers.ServerError(w, nil) 45 + return 43 46 } 44 47 45 48 dochandle := "" ··· 55 58 } 56 59 57 60 if handleIsCorrect { 58 - resolvedDid, err := s.passport.ResolveHandle(e.Request().Context(), repo.Handle) 61 + resolvedDid, err := s.passport.ResolveHandle(r.Context(), repo.Handle) 59 62 if err != nil { 60 - e.Logger().Error("error resolving handle", "error", err) 61 - return helpers.ServerError(e, nil) 63 + logger.Error("error resolving handle", "error", err) 64 + helpers.ServerError(w, nil) 65 + return 62 66 } 63 67 64 68 if resolvedDid != repo.Repo.Did { ··· 69 73 var records []models.Record 70 74 if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 71 75 logger.Error("error getting collections", "error", err) 72 - return helpers.ServerError(e, nil) 76 + helpers.ServerError(w, nil) 77 + return 73 78 } 74 79 75 - var collections []string = make([]string, 0, len(records)) 76 - for _, r := range records { 77 - collections = append(collections, r.Nsid) 80 + collections := make([]string, 0, len(records)) 81 + for _, rec := range records { 82 + collections = append(collections, rec.Nsid) 78 83 } 79 84 80 - return e.JSON(200, ComAtprotoRepoDescribeRepoResponse{ 85 + s.writeJSON(w, 200, ComAtprotoRepoDescribeRepoResponse{ 81 86 Did: repo.Repo.Did, 82 87 Handle: repo.Handle, 83 88 DidDoc: *diddoc,
+16 -12
server/handle_repo_get_record.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/bluesky-social/indigo/atproto/atdata" 5 7 "github.com/bluesky-social/indigo/atproto/syntax" 6 8 "github.com/haileyok/cocoon/models" 7 - "github.com/labstack/echo/v4" 8 9 ) 9 10 10 11 type ComAtprotoRepoGetRecordResponse struct { ··· 13 14 Value map[string]any `json:"value"` 14 15 } 15 16 16 - func (s *Server) handleRepoGetRecord(e echo.Context) error { 17 - ctx := e.Request().Context() 17 + func (s *Server) handleRepoGetRecord(w http.ResponseWriter, r *http.Request) { 18 + ctx := r.Context() 18 19 19 - repo := e.QueryParam("repo") 20 - collection := e.QueryParam("collection") 21 - rkey := e.QueryParam("rkey") 22 - cidstr := e.QueryParam("cid") 20 + repo := r.URL.Query().Get("repo") 21 + collection := r.URL.Query().Get("collection") 22 + rkey := r.URL.Query().Get("rkey") 23 + cidstr := r.URL.Query().Get("cid") 23 24 24 25 params := []any{repo, collection, rkey} 25 26 cidquery := "" ··· 27 28 if cidstr != "" { 28 29 c, err := syntax.ParseCID(cidstr) 29 30 if err != nil { 30 - return err 31 + http.Error(w, err.Error(), http.StatusBadRequest) 32 + return 31 33 } 32 34 params = append(params, c.String()) 33 35 cidquery = " AND cid = ?" ··· 35 37 36 38 var record models.Record 37 39 if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 38 - // TODO: handle error nicely 39 - return err 40 + http.Error(w, err.Error(), http.StatusInternalServerError) 41 + return 40 42 } 41 43 42 44 val, err := atdata.UnmarshalCBOR(record.Value) 43 45 if err != nil { 44 - return s.handleProxy(e) // TODO: this should be getting handled like...if we don't find it in the db. why doesn't it throw error up there? 46 + // Fall back to proxy if we can't find/decode the record locally 47 + s.handleProxy(w, r) 48 + return 45 49 } 46 50 47 - return e.JSON(200, ComAtprotoRepoGetRecordResponse{ 51 + s.writeJSON(w, 200, ComAtprotoRepoGetRecordResponse{ 48 52 Uri: "at://" + record.Did + "/" + record.Nsid + "/" + record.Rkey, 49 53 Cid: record.Cid, 50 54 Value: val,
+10 -9
server/handle_repo_list_missing_blobs.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "strconv" 6 7 7 8 "github.com/bluesky-social/indigo/atproto/atdata" 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 10 "github.com/haileyok/cocoon/models" 10 11 "github.com/ipfs/go-cid" 11 - "github.com/labstack/echo/v4" 12 12 ) 13 13 14 14 type ComAtprotoRepoListMissingBlobsResponse struct { ··· 21 21 RecordUri string `json:"recordUri"` 22 22 } 23 23 24 - func (s *Server) handleListMissingBlobs(e echo.Context) error { 25 - ctx := e.Request().Context() 26 - logger := s.logger.With("name", "handleListMissingBlos") 24 + func (s *Server) handleListMissingBlobs(w http.ResponseWriter, r *http.Request) { 25 + ctx := r.Context() 26 + logger := s.logger.With("name", "handleListMissingBlobs") 27 27 28 - urepo := e.Get("repo").(*models.RepoActor) 28 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 29 29 30 - limitStr := e.QueryParam("limit") 31 - cursor := e.QueryParam("cursor") 30 + limitStr := r.URL.Query().Get("limit") 31 + cursor := r.URL.Query().Get("cursor") 32 32 33 33 limit := 500 34 34 if limitStr != "" { ··· 40 40 var records []models.Record 41 41 if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 42 42 logger.Error("failed to get records for listMissingBlobs", "error", err) 43 - return helpers.ServerError(e, nil) 43 + helpers.ServerError(w, nil) 44 + return 44 45 } 45 46 46 47 type blobRef struct { ··· 95 96 nextCursor = &lastCid 96 97 } 97 98 98 - return e.JSON(200, ComAtprotoRepoListMissingBlobsResponse{ 99 + s.writeJSON(w, http.StatusOK, ComAtprotoRepoListMissingBlobsResponse{ 99 100 Cursor: nextCursor, 100 101 Blobs: missingBlobs, 101 102 })
+33 -26
server/handle_repo_list_records.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 4 5 "strconv" 5 6 6 7 "github.com/Azure/go-autorest/autorest/to" ··· 8 9 "github.com/bluesky-social/indigo/atproto/syntax" 9 10 "github.com/haileyok/cocoon/internal/helpers" 10 11 "github.com/haileyok/cocoon/models" 11 - "github.com/labstack/echo/v4" 12 12 ) 13 13 14 14 type ComAtprotoRepoListRecordsRequest struct { ··· 30 30 Value map[string]any `json:"value"` 31 31 } 32 32 33 - func getLimitFromContext(e echo.Context, def int) (int, error) { 33 + func getLimitFromRequest(r *http.Request, def int) (int, error) { 34 34 limit := def 35 - limitstr := e.QueryParam("limit") 35 + limitstr := r.URL.Query().Get("limit") 36 36 37 37 if limitstr != "" { 38 38 l64, err := strconv.ParseInt(limitstr, 10, 32) ··· 45 45 return limit, nil 46 46 } 47 47 48 - func (s *Server) handleListRecords(e echo.Context) error { 49 - ctx := e.Request().Context() 48 + func (s *Server) handleListRecords(w http.ResponseWriter, r *http.Request) { 49 + ctx := r.Context() 50 50 logger := s.logger.With("name", "handleListRecords") 51 51 52 - var req ComAtprotoRepoListRecordsRequest 53 - if err := e.Bind(&req); err != nil { 54 - logger.Error("could not bind list records request", "error", err) 55 - return helpers.ServerError(e, nil) 52 + req := ComAtprotoRepoListRecordsRequest{ 53 + Repo: r.URL.Query().Get("repo"), 54 + Collection: r.URL.Query().Get("collection"), 55 + Cursor: r.URL.Query().Get("cursor"), 56 56 } 57 - 58 - if err := e.Validate(req); err != nil { 59 - return helpers.InputError(e, nil) 57 + if v := r.URL.Query().Get("reverse"); v == "true" { 58 + req.Reverse = true 60 59 } 61 60 62 - if req.Limit <= 0 { 63 - req.Limit = 50 64 - } else if req.Limit > 100 { 65 - req.Limit = 100 61 + if err := s.validator.Struct(req); err != nil { 62 + helpers.InputError(w, nil) 63 + return 66 64 } 67 65 68 - limit, err := getLimitFromContext(e, 50) 66 + limit, err := getLimitFromRequest(r, 50) 69 67 if err != nil { 70 - return helpers.InputError(e, nil) 68 + helpers.InputError(w, nil) 69 + return 70 + } 71 + if limit <= 0 { 72 + limit = 50 73 + } else if limit > 100 { 74 + limit = 100 71 75 } 72 76 73 77 sort := "DESC" ··· 83 87 if _, err := syntax.ParseDID(did); err != nil { 84 88 actor, err := s.getActorByHandle(ctx, req.Repo) 85 89 if err != nil { 86 - return helpers.InputError(e, to.StringPtr("RepoNotFound")) 90 + helpers.InputError(w, to.StringPtr("RepoNotFound")) 91 + return 87 92 } 88 93 did = actor.Did 89 94 } ··· 98 103 var records []models.Record 99 104 if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 100 105 logger.Error("error getting records", "error", err) 101 - return helpers.ServerError(e, nil) 106 + helpers.ServerError(w, nil) 107 + return 102 108 } 103 109 104 110 items := []ComAtprotoRepoListRecordsRecordItem{} 105 - for _, r := range records { 106 - val, err := atdata.UnmarshalCBOR(r.Value) 111 + for _, rec := range records { 112 + val, err := atdata.UnmarshalCBOR(rec.Value) 107 113 if err != nil { 108 - return err 114 + helpers.ServerError(w, nil) 115 + return 109 116 } 110 117 111 118 items = append(items, ComAtprotoRepoListRecordsRecordItem{ 112 - Uri: "at://" + r.Did + "/" + r.Nsid + "/" + r.Rkey, 113 - Cid: r.Cid, 119 + Uri: "at://" + rec.Did + "/" + rec.Nsid + "/" + rec.Rkey, 120 + Cid: rec.Cid, 114 121 Value: val, 115 122 }) 116 123 } ··· 120 127 newcursor = to.StringPtr(records[len(records)-1].CreatedAt) 121 128 } 122 129 123 - return e.JSON(200, ComAtprotoRepoListRecordsResponse{ 130 + s.writeJSON(w, 200, ComAtprotoRepoListRecordsResponse{ 124 131 Cursor: newcursor, 125 132 Records: items, 126 133 })
+15 -12
server/handle_repo_list_repos.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/models" 5 7 "github.com/ipfs/go-cid" 6 - "github.com/labstack/echo/v4" 7 8 ) 8 9 9 10 type ComAtprotoSyncListReposResponse struct { ··· 20 21 } 21 22 22 23 // TODO: paginate this bitch 23 - func (s *Server) handleListRepos(e echo.Context) error { 24 - ctx := e.Request().Context() 24 + func (s *Server) handleListRepos(w http.ResponseWriter, r *http.Request) { 25 + ctx := r.Context() 25 26 26 27 var repos []models.Repo 27 28 if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 28 - return err 29 + http.Error(w, err.Error(), http.StatusInternalServerError) 30 + return 29 31 } 30 32 31 33 items := make([]ComAtprotoSyncListReposRepoItem, 0, len(repos)) 32 - for _, r := range repos { 33 - c, err := cid.Cast(r.Root) 34 + for _, repo := range repos { 35 + c, err := cid.Cast(repo.Root) 34 36 if err != nil { 35 - return err 37 + http.Error(w, err.Error(), http.StatusInternalServerError) 38 + return 36 39 } 37 40 38 41 items = append(items, ComAtprotoSyncListReposRepoItem{ 39 - Did: r.Did, 42 + Did: repo.Did, 40 43 Head: c.String(), 41 - Rev: r.Rev, 42 - Active: r.Active(), 43 - Status: r.Status(), 44 + Rev: repo.Rev, 45 + Active: repo.Active(), 46 + Status: repo.Status(), 44 47 }) 45 48 } 46 49 47 - return e.JSON(200, ComAtprotoSyncListReposResponse{ 50 + s.writeJSON(w, 200, ComAtprotoSyncListReposResponse{ 48 51 Cursor: nil, 49 52 Repos: items, 50 53 })
+18 -12
server/handle_repo_put_record.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/haileyok/cocoon/internal/helpers" 5 8 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 9 ) 8 10 9 11 type ComAtprotoRepoPutRecordInput struct { ··· 16 18 SwapCommit *string `json:"swapCommit"` 17 19 } 18 20 19 - func (s *Server) handlePutRecord(e echo.Context) error { 20 - ctx := e.Request().Context() 21 + func (s *Server) handlePutRecord(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 21 23 logger := s.logger.With("name", "handlePutRecord") 22 24 23 - repo := e.Get("repo").(*models.RepoActor) 25 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 24 26 25 27 var req ComAtprotoRepoPutRecordInput 26 - if err := e.Bind(&req); err != nil { 27 - logger.Error("error binding", "error", err) 28 - return helpers.ServerError(e, nil) 28 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 29 + logger.Error("error decoding", "error", err) 30 + helpers.ServerError(w, nil) 31 + return 29 32 } 30 33 31 - if err := e.Validate(req); err != nil { 34 + if err := s.validator.Struct(req); err != nil { 32 35 logger.Error("error validating", "error", err) 33 - return helpers.InputError(e, nil) 36 + helpers.InputError(w, nil) 37 + return 34 38 } 35 39 36 40 if repo.Repo.Did != req.Repo { 37 41 logger.Warn("mismatched repo/auth") 38 - return helpers.InputError(e, nil) 42 + helpers.InputError(w, nil) 43 + return 39 44 } 40 45 41 46 optype := OpTypeCreate ··· 55 60 }, req.SwapCommit) 56 61 if err != nil { 57 62 logger.Error("error applying writes", "error", err) 58 - return helpers.ServerError(e, nil) 63 + helpers.ServerError(w, nil) 64 + return 59 65 } 60 66 61 67 results[0].Type = nil 62 68 63 - return e.JSON(200, results[0]) 69 + s.writeJSON(w, 200, results[0]) 64 70 }
+18 -13
server/handle_repo_upload_blob.go
··· 10 10 "github.com/haileyok/cocoon/internal/helpers" 11 11 "github.com/haileyok/cocoon/models" 12 12 "github.com/ipfs/go-cid" 13 - "github.com/labstack/echo/v4" 14 13 "github.com/multiformats/go-multihash" 15 14 ) 16 15 ··· 29 28 } `json:"blob"` 30 29 } 31 30 32 - func (s *Server) handleRepoUploadBlob(e echo.Context) error { 33 - ctx := e.Request().Context() 31 + func (s *Server) handleRepoUploadBlob(w http.ResponseWriter, r *http.Request) { 32 + ctx := r.Context() 34 33 logger := s.logger.With("name", "handleRepoUploadBlob") 35 34 36 - urepo := e.Get("repo").(*models.RepoActor) 35 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 37 36 38 - mime := e.Request().Header.Get("content-type") 37 + mime := r.Header.Get("content-type") 39 38 if mime == "" { 40 39 mime = "application/octet-stream" 41 40 } ··· 55 54 56 55 if err := s.db.Create(ctx, &blob, nil).Error; err != nil { 57 56 logger.Error("error creating new blob in db", "error", err) 58 - return helpers.ServerError(e, nil) 57 + helpers.ServerError(w, nil) 58 + return 59 59 } 60 60 61 61 read := 0 ··· 65 65 fulldata := new(bytes.Buffer) 66 66 67 67 for { 68 - n, err := io.ReadFull(e.Request().Body, buf) 68 + n, err := io.ReadFull(r.Body, buf) 69 69 if err == io.ErrUnexpectedEOF || err == io.EOF { 70 70 if n == 0 { 71 71 break 72 72 } 73 73 } else if err != nil && err != io.ErrUnexpectedEOF { 74 74 logger.Error("error reading blob", "error", err) 75 - return helpers.ServerError(e, nil) 75 + helpers.ServerError(w, nil) 76 + return 76 77 } 77 78 78 79 data := buf[:n] ··· 88 89 89 90 if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil { 90 91 logger.Error("error adding blob part to db", "error", err) 91 - return helpers.ServerError(e, nil) 92 + helpers.ServerError(w, nil) 93 + return 92 94 } 93 95 } 94 96 part++ ··· 101 103 c, err := cid.NewPrefixV1(cid.Raw, multihash.SHA2_256).Sum(fulldata.Bytes()) 102 104 if err != nil { 103 105 logger.Error("error creating cid prefix", "error", err) 104 - return helpers.ServerError(e, nil) 106 + helpers.ServerError(w, nil) 107 + return 105 108 } 106 109 107 110 if ipfsUpload { 108 111 ipfsCid, err := s.addBlobToIPFS(fulldata.Bytes(), mime) 109 112 if err != nil { 110 113 logger.Error("error adding blob to ipfs", "error", err) 111 - return helpers.ServerError(e, nil) 114 + helpers.ServerError(w, nil) 115 + return 112 116 } 113 117 114 118 // Overwrite the locally computed CID with the one returned by the IPFS ··· 126 130 127 131 if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 128 132 logger.Error("error updating blob", "error", err) 129 - return helpers.ServerError(e, nil) 133 + helpers.ServerError(w, nil) 134 + return 130 135 } 131 136 132 137 resp := ComAtprotoRepoUploadBlobResponse{} ··· 135 140 resp.Blob.MimeType = mime 136 141 resp.Blob.Size = read 137 142 138 - return e.JSON(200, resp) 143 + s.writeJSON(w, 200, resp) 139 144 } 140 145 141 146 // addBlobToIPFS adds raw blob data to the configured IPFS node via the Kubo
+7 -3
server/handle_robots.go
··· 1 1 package server 2 2 3 - import "github.com/labstack/echo/v4" 3 + import ( 4 + "fmt" 5 + "net/http" 6 + ) 4 7 5 - func (s *Server) handleRobots(e echo.Context) error { 6 - return e.String(200, "# Beep boop beep boop\n\n# Crawl me 🥺\nUser-agent: *\nAllow: /") 8 + func (s *Server) handleRobots(w http.ResponseWriter, r *http.Request) { 9 + w.Header().Set("Content-Type", "text/plain") 10 + fmt.Fprint(w, "# Beep boop beep boop\n\n# Crawl me 🥺\nUser-agent: *\nAllow: /") 7 11 }
+32 -28
server/handle_root.go
··· 1 1 package server 2 2 3 - import "github.com/labstack/echo/v4" 3 + import ( 4 + "fmt" 5 + "net/http" 6 + ) 4 7 5 - func (s *Server) handleRoot(e echo.Context) error { 6 - return e.String(200, ` 8 + func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { 9 + w.Header().Set("Content-Type", "text/plain") 10 + fmt.Fprint(w, ` 7 11 8 - ....-*%%%##### 12 + ....-*%%%##### 9 13 .%#+++****#%%%%%%%%%#+:.... 10 - .%+++**++++*%%%%..... 11 - .%+++*****#%%%%#.. %#%... 12 - ***+*****%%%%%... =.. 13 - *****%%%%.. +=++.. 14 - %%%%%... .+----==++. 15 - .-::----===++ 16 - .=-:.------==+++ 17 - +-:::-:----===++.. 18 - =-::-----:-==+++-. 19 - .==*=------==++++. 20 - +-:--=++===*=--++. 21 - +:::--:=++=----=+.. 22 - *::::---=+#----=+. 23 - =::::----=+#---=+.. 24 - .::::----==+=--=+.. 25 - .-::-----==++=-=+.. 26 - -::-----==++===+.. 27 - =::-----==++==++ 28 - +::----:==++=+++ 29 - :-:----:==+++++. 30 - .=:=----=+++++. 31 - +=-=====+++.. 32 - =====++. 33 - =++... 14 + .%+++**++++*%%%%..... 15 + .%+++*****#%%%%#.. %#%... 16 + ***+*****%%%%%... =.. 17 + *****%%%%.. +=++.. 18 + %%%%%... .+----==++. 19 + .-::----===++ 20 + .=-:.------==+++ 21 + +-:::-:----===++.. 22 + =-::-----:-==+++-. 23 + .==*=------==++++. 24 + +-:--=++===*=--++. 25 + +:::--:=++=----=+.. 26 + *::::---=+#----=+. 27 + =::::----=+#---=+.. 28 + .::::----==+=--=+.. 29 + .-::-----==++=-=+.. 30 + -::-----==++===+.. 31 + =::-----==++==++ 32 + +::----:==++=+++ 33 + :-:----:==+++++. 34 + .=:=----=+++++. 35 + +=-=====+++.. 36 + =====++. 37 + =++... 34 38 35 39 36 40 This is an AT Protocol Personal Data Server (aka, an atproto PDS)
+8 -13
server/handle_server_activate_account.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/bluesky-social/indigo/api/atproto" ··· 9 10 "github.com/bluesky-social/indigo/util" 10 11 "github.com/haileyok/cocoon/internal/helpers" 11 12 "github.com/haileyok/cocoon/models" 12 - "github.com/labstack/echo/v4" 13 13 ) 14 14 15 15 type ComAtprotoServerActivateAccountRequest struct { ··· 17 17 DeleteAfter time.Time `json:"deleteAfter"` 18 18 } 19 19 20 - func (s *Server) handleServerActivateAccount(e echo.Context) error { 21 - ctx := e.Request().Context() 20 + func (s *Server) handleServerActivateAccount(w http.ResponseWriter, r *http.Request) { 21 + ctx := r.Context() 22 22 logger := s.logger.With("name", "handleServerActivateAccount") 23 23 24 - var req ComAtprotoServerDeactivateAccountRequest 25 - if err := e.Bind(&req); err != nil { 26 - logger.Error("error binding", "error", err) 27 - return helpers.ServerError(e, nil) 28 - } 29 - 30 - urepo := e.Get("repo").(*models.RepoActor) 24 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 31 25 32 26 if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 33 - logger.Error("error updating account status to deactivated", "error", err) 34 - return helpers.ServerError(e, nil) 27 + logger.Error("error updating account status to activated", "error", err) 28 + helpers.ServerError(w, nil) 29 + return 35 30 } 36 31 37 32 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ ··· 44 39 }, 45 40 }) 46 41 47 - return e.NoContent(200) 42 + w.WriteHeader(http.StatusOK) 48 43 }
+14 -9
server/handle_server_check_account_status.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 7 "github.com/haileyok/cocoon/models" 6 8 "github.com/ipfs/go-cid" 7 - "github.com/labstack/echo/v4" 8 9 ) 9 10 10 11 type ComAtprotoServerCheckAccountStatusResponse struct { ··· 19 20 ImportedBlobs int64 `json:"importedBlobs"` 20 21 } 21 22 22 - func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 - ctx := e.Request().Context() 23 + func (s *Server) handleServerCheckAccountStatus(w http.ResponseWriter, r *http.Request) { 24 + ctx := r.Context() 24 25 logger := s.logger.With("name", "handleServerCheckAccountStatus") 25 26 26 - urepo := e.Get("repo").(*models.RepoActor) 27 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 27 28 28 29 resp := ComAtprotoServerCheckAccountStatusResponse{ 29 30 Activated: true, // TODO: should allow for deactivation etc. ··· 35 36 rootcid, err := cid.Cast(urepo.Root) 36 37 if err != nil { 37 38 logger.Error("error casting cid", "error", err) 38 - return helpers.ServerError(e, nil) 39 + helpers.ServerError(w, nil) 40 + return 39 41 } 40 42 resp.RepoCommit = rootcid.String() 41 43 ··· 46 48 var blockCtResp CountResp 47 49 if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 48 50 logger.Error("error getting block count", "error", err) 49 - return helpers.ServerError(e, nil) 51 + helpers.ServerError(w, nil) 52 + return 50 53 } 51 54 resp.RepoBlocks = blockCtResp.Ct 52 55 53 56 var recCtResp CountResp 54 57 if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 55 58 logger.Error("error getting record count", "error", err) 56 - return helpers.ServerError(e, nil) 59 + helpers.ServerError(w, nil) 60 + return 57 61 } 58 62 resp.IndexedRecords = recCtResp.Ct 59 63 60 64 var blobCtResp CountResp 61 65 if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 62 66 logger.Error("error getting record count", "error", err) 63 - return helpers.ServerError(e, nil) 67 + helpers.ServerError(w, nil) 68 + return 64 69 } 65 70 resp.ExpectedBlobs = blobCtResp.Ct 66 71 67 - return e.JSON(200, resp) 72 + s.writeJSON(w, 200, resp) 68 73 }
+21 -14
server/handle_server_confirm_email.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 4 6 "time" 5 7 6 8 "github.com/Azure/go-autorest/autorest/to" 7 9 "github.com/haileyok/cocoon/internal/helpers" 8 10 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 11 ) 11 12 12 13 type ComAtprotoServerConfirmEmailRequest struct { ··· 14 15 Token string `json:"token" validate:"required"` 15 16 } 16 17 17 - func (s *Server) handleServerConfirmEmail(e echo.Context) error { 18 - ctx := e.Request().Context() 18 + func (s *Server) handleServerConfirmEmail(w http.ResponseWriter, r *http.Request) { 19 + ctx := r.Context() 19 20 logger := s.logger.With("name", "handleServerConfirmEmail") 20 21 21 - urepo := e.Get("repo").(*models.RepoActor) 22 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 22 23 23 24 var req ComAtprotoServerConfirmEmailRequest 24 - if err := e.Bind(&req); err != nil { 25 - logger.Error("error binding", "error", err) 26 - return helpers.ServerError(e, nil) 25 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 26 + logger.Error("error decoding", "error", err) 27 + helpers.ServerError(w, nil) 28 + return 27 29 } 28 30 29 - if err := e.Validate(req); err != nil { 30 - return helpers.InputError(e, nil) 31 + if err := s.validator.Struct(req); err != nil { 32 + helpers.InputError(w, nil) 33 + return 31 34 } 32 35 33 36 if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil { 34 - return helpers.ExpiredTokenError(e) 37 + helpers.ExpiredTokenError(w) 38 + return 35 39 } 36 40 37 41 if *urepo.EmailVerificationCode != req.Token { 38 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 42 + helpers.InputError(w, to.StringPtr("InvalidToken")) 43 + return 39 44 } 40 45 41 46 if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) { 42 - return helpers.ExpiredTokenError(e) 47 + helpers.ExpiredTokenError(w) 48 + return 43 49 } 44 50 45 51 now := time.Now().UTC() 46 52 47 53 if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 48 54 logger.Error("error updating user", "error", err) 49 - return helpers.ServerError(e, nil) 55 + helpers.ServerError(w, nil) 56 + return 50 57 } 51 58 52 - return e.NoContent(200) 59 + w.WriteHeader(http.StatusOK) 53 60 }
+67 -39
server/handle_server_create_account.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 5 6 "errors" 6 7 "fmt" 8 + "net/http" 7 9 "strings" 8 10 "time" 9 11 ··· 17 19 "github.com/bluesky-social/indigo/util" 18 20 "github.com/haileyok/cocoon/internal/helpers" 19 21 "github.com/haileyok/cocoon/models" 20 - "github.com/labstack/echo/v4" 21 22 "golang.org/x/crypto/bcrypt" 22 23 "gorm.io/gorm" 23 24 ) ··· 37 38 Did string `json:"did"` 38 39 } 39 40 40 - func (s *Server) handleCreateAccount(e echo.Context) error { 41 - ctx := e.Request().Context() 41 + func (s *Server) handleCreateAccount(w http.ResponseWriter, r *http.Request) { 42 + ctx := r.Context() 42 43 logger := s.logger.With("name", "handleServerCreateAccount") 43 44 44 45 var request ComAtprotoServerCreateAccountRequest 45 46 46 - if err := e.Bind(&request); err != nil { 47 + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { 47 48 logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err) 48 - return helpers.ServerError(e, nil) 49 + helpers.ServerError(w, nil) 50 + return 49 51 } 50 52 51 53 request.Handle = strings.ToLower(request.Handle) 52 54 53 - if err := e.Validate(request); err != nil { 55 + if err := s.validator.Struct(request); err != nil { 54 56 logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err) 55 57 56 58 var verr ValidationError 57 59 if errors.As(err, &verr) { 58 60 if verr.Field == "Email" { 59 - // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc 60 - return helpers.InputError(e, to.StringPtr("InvalidEmail")) 61 + helpers.InputError(w, to.StringPtr("InvalidEmail")) 62 + return 61 63 } 62 64 63 65 if verr.Field == "Handle" { 64 - return helpers.InputError(e, to.StringPtr("InvalidHandle")) 66 + helpers.InputError(w, to.StringPtr("InvalidHandle")) 67 + return 65 68 } 66 69 67 70 if verr.Field == "Password" { 68 - return helpers.InputError(e, to.StringPtr("InvalidPassword")) 71 + helpers.InputError(w, to.StringPtr("InvalidPassword")) 72 + return 69 73 } 70 74 71 75 if verr.Field == "InviteCode" { 72 - return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 76 + helpers.InputError(w, to.StringPtr("InvalidInviteCode")) 77 + return 73 78 } 74 79 } 75 80 } ··· 78 83 if request.Did != nil { 79 84 signupDid = *request.Did 80 85 81 - token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 86 + token := strings.TrimSpace(strings.Replace(r.Header.Get("authorization"), "Bearer ", "", 1)) 82 87 if token == "" { 83 - return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) 88 + helpers.UnauthorizedError(w, to.StringPtr("must authenticate to use an existing did")) 89 + return 84 90 } 85 - authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount") 91 + authDid, err := s.validateServiceAuth(r.Context(), token, "com.atproto.server.createAccount") 86 92 87 93 if err != nil { 88 94 logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err) 89 - return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token")) 95 + helpers.UnauthorizedError(w, to.StringPtr("invalid authorization token")) 96 + return 90 97 } 91 98 92 99 if authDid != signupDid { 93 - return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did")) 100 + helpers.ForbiddenError(w, to.StringPtr("auth did did not match signup did")) 101 + return 94 102 } 95 103 } 96 104 ··· 98 106 actor, err := s.getActorByHandle(ctx, request.Handle) 99 107 if err != nil && err != gorm.ErrRecordNotFound { 100 108 logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 101 - return helpers.ServerError(e, nil) 109 + helpers.ServerError(w, nil) 110 + return 102 111 } 103 112 if err == nil && actor.Did != signupDid { 104 - return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 113 + helpers.InputError(w, to.StringPtr("HandleNotAvailable")) 114 + return 105 115 } 106 116 107 - if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid { 108 - return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 117 + if did, err := s.passport.ResolveHandle(r.Context(), request.Handle); err == nil && did != signupDid { 118 + helpers.InputError(w, to.StringPtr("HandleNotAvailable")) 119 + return 109 120 } 110 121 111 122 var ic models.InviteCode 112 123 if s.config.RequireInvite { 113 124 if strings.TrimSpace(request.InviteCode) == "" { 114 - return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 125 + helpers.InputError(w, to.StringPtr("InvalidInviteCode")) 126 + return 115 127 } 116 128 117 129 if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 118 130 if err == gorm.ErrRecordNotFound { 119 - return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 131 + helpers.InputError(w, to.StringPtr("InvalidInviteCode")) 132 + return 120 133 } 121 134 logger.Error("error getting invite code from db", "error", err) 122 - return helpers.ServerError(e, nil) 135 + helpers.ServerError(w, nil) 136 + return 123 137 } 124 138 125 139 if ic.RemainingUseCount < 1 { 126 - return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 140 + helpers.InputError(w, to.StringPtr("InvalidInviteCode")) 141 + return 127 142 } 128 143 } 129 144 ··· 131 146 existingRepo, err := s.getRepoByEmail(ctx, request.Email) 132 147 if err != nil && err != gorm.ErrRecordNotFound { 133 148 logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 134 - return helpers.ServerError(e, nil) 149 + helpers.ServerError(w, nil) 150 + return 135 151 } 136 152 if err == nil && existingRepo.Did != signupDid { 137 - return helpers.InputError(e, to.StringPtr("EmailNotAvailable")) 153 + helpers.InputError(w, to.StringPtr("EmailNotAvailable")) 154 + return 138 155 } 139 156 140 157 // TODO: unsupported domains ··· 165 182 k, err = atcrypto.GeneratePrivateKeyK256() 166 183 if err != nil { 167 184 logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err) 168 - return helpers.ServerError(e, nil) 185 + helpers.ServerError(w, nil) 186 + return 169 187 } 170 188 } 171 189 ··· 173 191 did, op, err := s.plcClient.CreateDID(k, "", request.Handle) 174 192 if err != nil { 175 193 logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err) 176 - return helpers.ServerError(e, nil) 194 + helpers.ServerError(w, nil) 195 + return 177 196 } 178 197 179 - if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil { 198 + if err := s.plcClient.SendOperation(r.Context(), did, op); err != nil { 180 199 logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err) 181 - return helpers.ServerError(e, nil) 200 + helpers.ServerError(w, nil) 201 + return 182 202 } 183 203 signupDid = did 184 204 } ··· 186 206 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10) 187 207 if err != nil { 188 208 logger.Error("error hashing password", "error", err) 189 - return helpers.ServerError(e, nil) 209 + helpers.ServerError(w, nil) 210 + return 190 211 } 191 212 192 213 urepo := models.Repo{ ··· 206 227 207 228 if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { 208 229 logger.Error("error inserting new repo", "error", err) 209 - return helpers.ServerError(e, nil) 230 + helpers.ServerError(w, nil) 231 + return 210 232 } 211 233 212 234 if err := s.db.Create(ctx, &actor, nil).Error; err != nil { 213 235 logger.Error("error inserting new actor", "error", err) 214 - return helpers.ServerError(e, nil) 236 + helpers.ServerError(w, nil) 237 + return 215 238 } 216 239 } else { 217 240 if err := s.db.Save(ctx, &actor, nil).Error; err != nil { 218 241 logger.Error("error inserting new actor", "error", err) 219 - return helpers.ServerError(e, nil) 242 + helpers.ServerError(w, nil) 243 + return 220 244 } 221 245 } 222 246 ··· 234 258 root, rev, err := commitRepo(context.TODO(), bs, r, urepo.SigningKey) 235 259 if err != nil { 236 260 logger.Error("error committing", "error", err) 237 - return helpers.ServerError(e, nil) 261 + helpers.ServerError(w, nil) 262 + return 238 263 } 239 264 240 265 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil { 241 266 logger.Error("error updating repo after commit", "error", err) 242 - return helpers.ServerError(e, nil) 267 + helpers.ServerError(w, nil) 268 + return 243 269 } 244 270 245 271 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ ··· 255 281 if s.config.RequireInvite { 256 282 if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 257 283 logger.Error("error decrementing use count", "error", err) 258 - return helpers.ServerError(e, nil) 284 + helpers.ServerError(w, nil) 285 + return 259 286 } 260 287 } 261 288 262 289 sess, err := s.createSession(ctx, &urepo) 263 290 if err != nil { 264 291 logger.Error("error creating new session", "error", err) 265 - return helpers.ServerError(e, nil) 292 + helpers.ServerError(w, nil) 293 + return 266 294 } 267 295 268 296 go func() { ··· 274 302 } 275 303 }() 276 304 277 - return e.JSON(200, ComAtprotoServerCreateAccountResponse{ 305 + s.writeJSON(w, 200, ComAtprotoServerCreateAccountResponse{ 278 306 AccessJwt: sess.AccessToken, 279 307 RefreshJwt: sess.RefreshToken, 280 308 Handle: request.Handle,
+15 -10
server/handle_server_create_invite_code.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/google/uuid" 5 8 "github.com/haileyok/cocoon/internal/helpers" 6 9 "github.com/haileyok/cocoon/models" 7 - "github.com/labstack/echo/v4" 8 10 ) 9 11 10 12 type ComAtprotoServerCreateInviteCodeRequest struct { ··· 16 18 Code string `json:"code"` 17 19 } 18 20 19 - func (s *Server) handleCreateInviteCode(e echo.Context) error { 20 - ctx := e.Request().Context() 21 + func (s *Server) handleCreateInviteCode(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 21 23 logger := s.logger.With("name", "handleServerCreateInviteCode") 22 24 23 25 var req ComAtprotoServerCreateInviteCodeRequest 24 - if err := e.Bind(&req); err != nil { 25 - logger.Error("error binding", "error", err) 26 - return helpers.ServerError(e, nil) 26 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 27 + logger.Error("error decoding", "error", err) 28 + helpers.ServerError(w, nil) 29 + return 27 30 } 28 31 29 - if err := e.Validate(req); err != nil { 32 + if err := s.validator.Struct(req); err != nil { 30 33 logger.Error("error validating", "error", err) 31 - return helpers.InputError(e, nil) 34 + helpers.InputError(w, nil) 35 + return 32 36 } 33 37 34 38 ic := uuid.NewString() ··· 46 50 RemainingUseCount: req.UseCount, 47 51 }, nil).Error; err != nil { 48 52 logger.Error("error creating invite code", "error", err) 49 - return helpers.ServerError(e, nil) 53 + helpers.ServerError(w, nil) 54 + return 50 55 } 51 56 52 - return e.JSON(200, ComAtprotoServerCreateInviteCodeResponse{ 57 + s.writeJSON(w, 200, ComAtprotoServerCreateInviteCodeResponse{ 53 58 Code: ic, 54 59 }) 55 60 }
+15 -10
server/handle_server_create_invite_codes.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 6 + 4 7 "github.com/Azure/go-autorest/autorest/to" 5 8 "github.com/google/uuid" 6 9 "github.com/haileyok/cocoon/internal/helpers" 7 10 "github.com/haileyok/cocoon/models" 8 - "github.com/labstack/echo/v4" 9 11 ) 10 12 11 13 type ComAtprotoServerCreateInviteCodesRequest struct { ··· 21 23 Codes []string `json:"codes"` 22 24 } 23 25 24 - func (s *Server) handleCreateInviteCodes(e echo.Context) error { 25 - ctx := e.Request().Context() 26 + func (s *Server) handleCreateInviteCodes(w http.ResponseWriter, r *http.Request) { 27 + ctx := r.Context() 26 28 logger := s.logger.With("name", "handleServerCreateInviteCodes") 27 29 28 30 var req ComAtprotoServerCreateInviteCodesRequest 29 - if err := e.Bind(&req); err != nil { 30 - logger.Error("error binding", "error", err) 31 - return helpers.ServerError(e, nil) 31 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 32 + logger.Error("error decoding", "error", err) 33 + helpers.ServerError(w, nil) 34 + return 32 35 } 33 36 34 - if err := e.Validate(req); err != nil { 37 + if err := s.validator.Struct(req); err != nil { 35 38 logger.Error("error validating", "error", err) 36 - return helpers.InputError(e, nil) 39 + helpers.InputError(w, nil) 40 + return 37 41 } 38 42 39 43 if req.CodeCount == nil { ··· 59 63 RemainingUseCount: req.UseCount, 60 64 }, nil).Error; err != nil { 61 65 logger.Error("error creating invite code", "error", err) 62 - return helpers.ServerError(e, nil) 66 + helpers.ServerError(w, nil) 67 + return 63 68 } 64 69 } 65 70 ··· 69 74 }) 70 75 } 71 76 72 - return e.JSON(200, codes) 77 + s.writeJSON(w, 200, codes) 73 78 }
+36 -22
server/handle_server_create_session.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 5 6 "errors" 6 7 "fmt" 8 + "net/http" 7 9 "strings" 8 10 "time" 9 11 ··· 11 13 "github.com/bluesky-social/indigo/atproto/syntax" 12 14 "github.com/haileyok/cocoon/internal/helpers" 13 15 "github.com/haileyok/cocoon/models" 14 - "github.com/labstack/echo/v4" 15 16 "golang.org/x/crypto/bcrypt" 16 17 "gorm.io/gorm" 17 18 ) ··· 34 35 Status *string `json:"status,omitempty"` 35 36 } 36 37 37 - func (s *Server) handleCreateSession(e echo.Context) error { 38 - ctx := e.Request().Context() 38 + func (s *Server) handleCreateSession(w http.ResponseWriter, r *http.Request) { 39 + ctx := r.Context() 39 40 logger := s.logger.With("name", "handleServerCreateSession") 40 41 41 42 var req ComAtprotoServerCreateSessionRequest 42 - if err := e.Bind(&req); err != nil { 43 - logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) 44 - return helpers.ServerError(e, nil) 43 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 44 + logger.Error("error decoding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) 45 + helpers.ServerError(w, nil) 46 + return 45 47 } 46 48 47 - if err := e.Validate(req); err != nil { 49 + if err := s.validator.Struct(req); err != nil { 48 50 var verr ValidationError 49 51 if errors.As(err, &verr) { 50 52 if verr.Field == "Identifier" { 51 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 53 + helpers.InputError(w, to.StringPtr("InvalidRequest")) 54 + return 52 55 } 53 56 54 57 if verr.Field == "Password" { 55 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 58 + helpers.InputError(w, to.StringPtr("InvalidRequest")) 59 + return 56 60 } 57 61 } 58 62 } ··· 80 84 81 85 if err != nil { 82 86 if err == gorm.ErrRecordNotFound { 83 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 87 + helpers.InputError(w, to.StringPtr("InvalidRequest")) 88 + return 84 89 } 85 90 86 - logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err) 87 - return helpers.ServerError(e, nil) 91 + logger.Error("error looking up repo", "endpoint", "com.atproto.server.createSession", "error", err) 92 + helpers.ServerError(w, nil) 93 + return 88 94 } 89 95 90 96 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { 91 97 if err != bcrypt.ErrMismatchedHashAndPassword { 92 - logger.Error("erorr comparing hash and password", "error", err) 98 + logger.Error("error comparing hash and password", "error", err) 93 99 } 94 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 100 + helpers.InputError(w, to.StringPtr("InvalidRequest")) 101 + return 95 102 } 96 103 97 104 // if repo requires 2FA token and one hasn't been provided, return error prompting for one ··· 99 106 err = s.createAndSendTwoFactorCode(ctx, repo) 100 107 if err != nil { 101 108 logger.Error("sending 2FA code", "error", err) 102 - return helpers.ServerError(e, nil) 109 + helpers.ServerError(w, nil) 110 + return 103 111 } 104 112 105 - return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 113 + helpers.InputError(w, to.StringPtr("AuthFactorTokenRequired")) 114 + return 106 115 } 107 116 108 117 // if 2FA is required, now check that the one provided is valid ··· 111 120 err = s.createAndSendTwoFactorCode(ctx, repo) 112 121 if err != nil { 113 122 logger.Error("sending 2FA code", "error", err) 114 - return helpers.ServerError(e, nil) 123 + helpers.ServerError(w, nil) 124 + return 115 125 } 116 126 117 - return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 127 + helpers.InputError(w, to.StringPtr("AuthFactorTokenRequired")) 128 + return 118 129 } 119 130 120 131 if *repo.TwoFactorCode != *req.AuthFactorToken { 121 - return helpers.InvalidTokenError(e) 132 + helpers.InvalidTokenError(w) 133 + return 122 134 } 123 135 124 136 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 125 - return helpers.ExpiredTokenError(e) 137 + helpers.ExpiredTokenError(w) 138 + return 126 139 } 127 140 } 128 141 129 142 sess, err := s.createSession(ctx, &repo.Repo) 130 143 if err != nil { 131 144 logger.Error("error creating session", "error", err) 132 - return helpers.ServerError(e, nil) 145 + helpers.ServerError(w, nil) 146 + return 133 147 } 134 148 135 - return e.JSON(200, ComAtprotoServerCreateSessionResponse{ 149 + s.writeJSON(w, 200, ComAtprotoServerCreateSessionResponse{ 136 150 AccessJwt: sess.AccessToken, 137 151 RefreshJwt: sess.RefreshToken, 138 152 Handle: repo.Handle,
+7 -12
server/handle_server_deactivate_account.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/Azure/go-autorest/autorest/to" ··· 10 11 "github.com/bluesky-social/indigo/util" 11 12 "github.com/haileyok/cocoon/internal/helpers" 12 13 "github.com/haileyok/cocoon/models" 13 - "github.com/labstack/echo/v4" 14 14 ) 15 15 16 16 type ComAtprotoServerDeactivateAccountRequest struct { ··· 18 18 DeleteAfter time.Time `json:"deleteAfter"` 19 19 } 20 20 21 - func (s *Server) handleServerDeactivateAccount(e echo.Context) error { 22 - ctx := e.Request().Context() 21 + func (s *Server) handleServerDeactivateAccount(w http.ResponseWriter, r *http.Request) { 22 + ctx := r.Context() 23 23 logger := s.logger.With("name", "handleServerDeactivateAccount") 24 24 25 - var req ComAtprotoServerDeactivateAccountRequest 26 - if err := e.Bind(&req); err != nil { 27 - logger.Error("error binding", "error", err) 28 - return helpers.ServerError(e, nil) 29 - } 30 - 31 - urepo := e.Get("repo").(*models.RepoActor) 25 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 32 26 33 27 if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 34 28 logger.Error("error updating account status to deactivated", "error", err) 35 - return helpers.ServerError(e, nil) 29 + helpers.ServerError(w, nil) 30 + return 36 31 } 37 32 38 33 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ ··· 45 40 }, 46 41 }) 47 42 48 - return e.NoContent(200) 43 + w.WriteHeader(http.StatusOK) 49 44 }
+44 -25
server/handle_server_delete_account.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "net/http" 5 7 "time" 6 8 7 9 "github.com/Azure/go-autorest/autorest/to" ··· 9 11 "github.com/bluesky-social/indigo/events" 10 12 "github.com/bluesky-social/indigo/util" 11 13 "github.com/haileyok/cocoon/internal/helpers" 12 - "github.com/labstack/echo/v4" 13 14 "golang.org/x/crypto/bcrypt" 14 15 ) 15 16 ··· 19 20 Token string `json:"token" validate:"required"` 20 21 } 21 22 22 - func (s *Server) handleServerDeleteAccount(e echo.Context) error { 23 - ctx := e.Request().Context() 23 + func (s *Server) handleServerDeleteAccount(w http.ResponseWriter, r *http.Request) { 24 + ctx := r.Context() 24 25 logger := s.logger.With("name", "handleServerDeleteAccount") 25 26 26 27 var req ComAtprotoServerDeleteAccountRequest 27 - if err := e.Bind(&req); err != nil { 28 - logger.Error("error binding", "error", err) 29 - return helpers.ServerError(e, nil) 28 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 29 + logger.Error("error decoding", "error", err) 30 + helpers.ServerError(w, nil) 31 + return 30 32 } 31 33 32 - if err := e.Validate(&req); err != nil { 34 + if err := s.validator.Struct(&req); err != nil { 33 35 logger.Error("error validating", "error", err) 34 - return helpers.ServerError(e, nil) 36 + helpers.ServerError(w, nil) 37 + return 35 38 } 36 39 37 40 urepo, err := s.getRepoActorByDid(ctx, req.Did) 38 41 if err != nil { 39 42 logger.Error("error getting repo", "error", err) 40 - return echo.NewHTTPError(400, "account not found") 43 + s.writeJSON(w, 400, map[string]string{"error": "account not found"}) 44 + return 41 45 } 42 46 43 47 if err := bcrypt.CompareHashAndPassword([]byte(urepo.Repo.Password), []byte(req.Password)); err != nil { 44 48 logger.Error("password mismatch", "error", err) 45 - return echo.NewHTTPError(401, "Invalid did or password") 49 + s.writeJSON(w, 401, map[string]string{"error": "Invalid did or password"}) 50 + return 46 51 } 47 52 48 53 if urepo.Repo.AccountDeleteCode == nil || urepo.Repo.AccountDeleteCodeExpiresAt == nil { 49 54 logger.Error("no deletion token found for account") 50 - return echo.NewHTTPError(400, map[string]interface{}{ 55 + s.writeJSON(w, 400, map[string]any{ 51 56 "error": "InvalidToken", 52 57 "message": "Token is invalid", 53 58 }) 59 + return 54 60 } 55 61 56 62 if *urepo.Repo.AccountDeleteCode != req.Token { 57 63 logger.Error("deletion token mismatch") 58 - return echo.NewHTTPError(400, map[string]interface{}{ 64 + s.writeJSON(w, 400, map[string]any{ 59 65 "error": "InvalidToken", 60 66 "message": "Token is invalid", 61 67 }) 68 + return 62 69 } 63 70 64 71 if time.Now().UTC().After(*urepo.Repo.AccountDeleteCodeExpiresAt) { 65 72 logger.Error("deletion token expired") 66 - return echo.NewHTTPError(400, map[string]interface{}{ 73 + s.writeJSON(w, 400, map[string]any{ 67 74 "error": "ExpiredToken", 68 75 "message": "Token is expired", 69 76 }) 77 + return 70 78 } 71 79 72 80 tx := s.db.Begin(ctx) 73 81 if tx.Error != nil { 74 82 logger.Error("error starting transaction", "error", tx.Error) 75 - return helpers.ServerError(e, nil) 83 + helpers.ServerError(w, nil) 84 + return 76 85 } 77 86 78 87 status := "error" ··· 86 95 87 96 if err := tx.Exec("DELETE FROM blocks WHERE did = ?", req.Did).Error; err != nil { 88 97 logger.Error("error deleting blocks", "error", err) 89 - return helpers.ServerError(e, nil) 98 + helpers.ServerError(w, nil) 99 + return 90 100 } 91 101 92 102 if err := tx.Exec("DELETE FROM records WHERE did = ?", req.Did).Error; err != nil { 93 103 logger.Error("error deleting records", "error", err) 94 - return helpers.ServerError(e, nil) 104 + helpers.ServerError(w, nil) 105 + return 95 106 } 96 107 97 108 if err := tx.Exec("DELETE FROM blobs WHERE did = ?", req.Did).Error; err != nil { 98 109 logger.Error("error deleting blobs", "error", err) 99 - return helpers.ServerError(e, nil) 110 + helpers.ServerError(w, nil) 111 + return 100 112 } 101 113 102 114 if err := tx.Exec("DELETE FROM tokens WHERE did = ?", req.Did).Error; err != nil { 103 115 logger.Error("error deleting tokens", "error", err) 104 - return helpers.ServerError(e, nil) 116 + helpers.ServerError(w, nil) 117 + return 105 118 } 106 119 107 120 if err := tx.Exec("DELETE FROM refresh_tokens WHERE did = ?", req.Did).Error; err != nil { 108 121 logger.Error("error deleting refresh tokens", "error", err) 109 - return helpers.ServerError(e, nil) 122 + helpers.ServerError(w, nil) 123 + return 110 124 } 111 125 112 126 if err := tx.Exec("DELETE FROM reserved_keys WHERE did = ?", req.Did).Error; err != nil { 113 127 logger.Error("error deleting reserved keys", "error", err) 114 - return helpers.ServerError(e, nil) 128 + helpers.ServerError(w, nil) 129 + return 115 130 } 116 131 117 132 if err := tx.Exec("DELETE FROM invite_codes WHERE did = ?", req.Did).Error; err != nil { 118 133 logger.Error("error deleting invite codes", "error", err) 119 - return helpers.ServerError(e, nil) 134 + helpers.ServerError(w, nil) 135 + return 120 136 } 121 137 122 138 if err := tx.Exec("DELETE FROM actors WHERE did = ?", req.Did).Error; err != nil { 123 139 logger.Error("error deleting actor", "error", err) 124 - return helpers.ServerError(e, nil) 140 + helpers.ServerError(w, nil) 141 + return 125 142 } 126 143 127 144 if err := tx.Exec("DELETE FROM repos WHERE did = ?", req.Did).Error; err != nil { 128 145 logger.Error("error deleting repo", "error", err) 129 - return helpers.ServerError(e, nil) 146 + helpers.ServerError(w, nil) 147 + return 130 148 } 131 149 132 150 status = "ok" 133 151 134 152 if err := tx.Commit().Error; err != nil { 135 153 logger.Error("error committing transaction", "error", err) 136 - return helpers.ServerError(e, nil) 154 + helpers.ServerError(w, nil) 155 + return 137 156 } 138 157 139 158 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ ··· 146 165 }, 147 166 }) 148 167 149 - return e.NoContent(200) 168 + w.WriteHeader(http.StatusOK) 150 169 }
+10 -7
server/handle_server_delete_session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 7 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 8 ) 8 9 9 - func (s *Server) handleDeleteSession(e echo.Context) error { 10 - ctx := e.Request().Context() 10 + func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) { 11 + ctx := r.Context() 11 12 12 - token := e.Get("token").(string) 13 + token, _ := getContextValue[string](r, contextKeyToken) 13 14 14 15 var acctok models.Token 15 16 if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 16 17 s.logger.Error("error deleting access token from db", "error", err) 17 - return helpers.ServerError(e, nil) 18 + helpers.ServerError(w, nil) 19 + return 18 20 } 19 21 20 22 if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 21 23 s.logger.Error("error deleting refresh token from db", "error", err) 22 - return helpers.ServerError(e, nil) 24 + helpers.ServerError(w, nil) 25 + return 23 26 } 24 27 25 - return e.NoContent(200) 28 + w.WriteHeader(http.StatusOK) 26 29 }
+3 -3
server/handle_server_describe_server.go
··· 1 1 package server 2 2 3 - import "github.com/labstack/echo/v4" 3 + import "net/http" 4 4 5 5 type ComAtprotoServerDescribeServerResponseLinks struct { 6 6 PrivacyPolicy *string `json:"privacyPolicy,omitempty"` ··· 20 20 Did string `json:"did"` 21 21 } 22 22 23 - func (s *Server) handleDescribeServer(e echo.Context) error { 24 - return e.JSON(200, ComAtprotoServerDescribeServerResponse{ 23 + func (s *Server) handleDescribeServer(w http.ResponseWriter, r *http.Request) { 24 + s.writeJSON(w, 200, ComAtprotoServerDescribeServerResponse{ 25 25 InviteCodeRequired: s.config.RequireInvite, 26 26 PhoneVerificationRequired: false, 27 27 AvailableUserDomains: []string{"." + s.config.Hostname}, // TODO: more
+30 -19
server/handle_server_get_service_auth.go
··· 6 6 "encoding/base64" 7 7 "encoding/json" 8 8 "fmt" 9 + "net/http" 9 10 "strings" 10 11 "time" 11 12 ··· 13 14 "github.com/google/uuid" 14 15 "github.com/haileyok/cocoon/internal/helpers" 15 16 "github.com/haileyok/cocoon/models" 16 - "github.com/labstack/echo/v4" 17 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 18 ) 19 19 20 20 type ServerGetServiceAuthRequest struct { 21 - Aud string `query:"aud" validate:"required,atproto-did"` 22 - // exp should be a float, as some clients will send a non-integer expiration 21 + Aud string `query:"aud" validate:"required,atproto-did"` 23 22 Exp float64 `query:"exp"` 24 23 Lxm string `query:"lxm"` 25 24 } 26 25 27 - func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 26 + func (s *Server) handleServerGetServiceAuth(w http.ResponseWriter, r *http.Request) { 28 27 logger := s.logger.With("name", "handleServerGetServiceAuth") 29 28 30 - var req ServerGetServiceAuthRequest 31 - if err := e.Bind(&req); err != nil { 32 - logger.Error("could not bind service auth request", "error", err) 33 - return helpers.ServerError(e, nil) 29 + req := ServerGetServiceAuthRequest{ 30 + Aud: r.URL.Query().Get("aud"), 31 + Lxm: r.URL.Query().Get("lxm"), 32 + } 33 + if v := r.URL.Query().Get("exp"); v != "" { 34 + var exp float64 35 + if _, err := fmt.Sscanf(v, "%f", &exp); err == nil { 36 + req.Exp = exp 37 + } 34 38 } 35 39 36 - if err := e.Validate(req); err != nil { 37 - return helpers.InputError(e, nil) 40 + if err := s.validator.Struct(req); err != nil { 41 + helpers.InputError(w, nil) 42 + return 38 43 } 39 44 40 45 exp := int64(req.Exp) ··· 44 49 } 45 50 46 51 if req.Lxm == "com.atproto.server.getServiceAuth" { 47 - return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 52 + helpers.InputError(w, to.StringPtr("may not generate auth tokens recursively")) 53 + return 48 54 } 49 55 50 56 var maxExp int64 ··· 54 60 maxExp = now + 60 55 61 } 56 62 if exp > maxExp { 57 - return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 63 + helpers.InputError(w, to.StringPtr("expiration too big. smoller please")) 64 + return 58 65 } 59 66 60 - repo := e.Get("repo").(*models.RepoActor) 67 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 61 68 62 69 header := map[string]string{ 63 70 "alg": "ES256K", ··· 67 74 hj, err := json.Marshal(header) 68 75 if err != nil { 69 76 logger.Error("error marshaling header", "error", err) 70 - return helpers.ServerError(e, nil) 77 + helpers.ServerError(w, nil) 78 + return 71 79 } 72 80 73 81 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") ··· 84 92 } 85 93 pj, err := json.Marshal(payload) 86 94 if err != nil { 87 - logger.Error("error marashaling payload", "error", err) 88 - return helpers.ServerError(e, nil) 95 + logger.Error("error marshaling payload", "error", err) 96 + helpers.ServerError(w, nil) 97 + return 89 98 } 90 99 91 100 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") ··· 96 105 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 97 106 if err != nil { 98 107 logger.Error("can't load private key", "error", err) 99 - return err 108 + helpers.ServerError(w, nil) 109 + return 100 110 } 101 111 102 112 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 103 113 if err != nil { 104 114 logger.Error("error signing", "error", err) 105 - return helpers.ServerError(e, nil) 115 + helpers.ServerError(w, nil) 116 + return 106 117 } 107 118 108 119 rBytes := R.Bytes() ··· 117 128 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 118 129 token := fmt.Sprintf("%s.%s", input, encsig) 119 130 120 - return e.JSON(200, map[string]string{ 131 + s.writeJSON(w, 200, map[string]string{ 121 132 "token": token, 122 133 }) 123 134 }
+5 -4
server/handle_server_get_session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/models" 5 - "github.com/labstack/echo/v4" 6 7 ) 7 8 8 9 type ComAtprotoServerGetSessionResponse struct { ··· 15 16 Status *string `json:"status,omitempty"` 16 17 } 17 18 18 - func (s *Server) handleGetSession(e echo.Context) error { 19 - repo := e.Get("repo").(*models.RepoActor) 19 + func (s *Server) handleGetSession(w http.ResponseWriter, r *http.Request) { 20 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 20 21 21 - return e.JSON(200, ComAtprotoServerGetSessionResponse{ 22 + s.writeJSON(w, 200, ComAtprotoServerGetSessionResponse{ 22 23 Handle: repo.Handle, 23 24 Did: repo.Repo.Did, 24 25 Email: repo.Email,
+13 -9
server/handle_server_refresh_session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 7 "github.com/haileyok/cocoon/models" 6 - "github.com/labstack/echo/v4" 7 8 ) 8 9 9 10 type ComAtprotoServerRefreshSessionResponse struct { ··· 15 16 Status *string `json:"status,omitempty"` 16 17 } 17 18 18 - func (s *Server) handleRefreshSession(e echo.Context) error { 19 - ctx := e.Request().Context() 19 + func (s *Server) handleRefreshSession(w http.ResponseWriter, r *http.Request) { 20 + ctx := r.Context() 20 21 logger := s.logger.With("name", "handleServerRefreshSession") 21 22 22 - token := e.Get("token").(string) 23 - repo := e.Get("repo").(*models.RepoActor) 23 + token, _ := getContextValue[string](r, contextKeyToken) 24 + repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 24 25 25 26 if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 26 27 logger.Error("error getting refresh token from db", "error", err) 27 - return helpers.ServerError(e, nil) 28 + helpers.ServerError(w, nil) 29 + return 28 30 } 29 31 30 32 if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 31 33 logger.Error("error deleting access token from db", "error", err) 32 - return helpers.ServerError(e, nil) 34 + helpers.ServerError(w, nil) 35 + return 33 36 } 34 37 35 38 sess, err := s.createSession(ctx, &repo.Repo) 36 39 if err != nil { 37 40 logger.Error("error creating new session for refresh", "error", err) 38 - return helpers.ServerError(e, nil) 41 + helpers.ServerError(w, nil) 42 + return 39 43 } 40 44 41 - return e.JSON(200, ComAtprotoServerRefreshSessionResponse{ 45 + s.writeJSON(w, 200, ComAtprotoServerRefreshSessionResponse{ 42 46 AccessJwt: sess.AccessToken, 43 47 RefreshJwt: sess.RefreshToken, 44 48 Handle: repo.Handle,
+7 -6
server/handle_server_request_account_delete.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/haileyok/cocoon/internal/helpers" 8 9 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 10 ) 11 11 12 - func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { 13 - ctx := e.Request().Context() 12 + func (s *Server) handleServerRequestAccountDelete(w http.ResponseWriter, r *http.Request) { 13 + ctx := r.Context() 14 14 logger := s.logger.With("name", "handleServerRequestAccountDelete") 15 15 16 - urepo := e.Get("repo").(*models.RepoActor) 16 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 17 17 18 18 token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 19 19 expiresAt := time.Now().UTC().Add(15 * time.Minute) 20 20 21 21 if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 22 22 logger.Error("error setting deletion token", "error", err) 23 - return helpers.ServerError(e, nil) 23 + helpers.ServerError(w, nil) 24 + return 24 25 } 25 26 26 27 if urepo.Email != "" { ··· 29 30 } 30 31 } 31 32 32 - return e.NoContent(200) 33 + w.WriteHeader(http.StatusOK) 33 34 } 34 35 35 36 func (s *Server) sendAccountDeleteEmail(email, handle, token string) error {
+11 -8
server/handle_server_request_email_confirmation.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/Azure/go-autorest/autorest/to" 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 10 "github.com/haileyok/cocoon/models" 10 - "github.com/labstack/echo/v4" 11 11 ) 12 12 13 - func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { 14 - ctx := e.Request().Context() 13 + func (s *Server) handleServerRequestEmailConfirmation(w http.ResponseWriter, r *http.Request) { 14 + ctx := r.Context() 15 15 logger := s.logger.With("name", "handleServerRequestEmailConfirm") 16 16 17 - urepo := e.Get("repo").(*models.RepoActor) 17 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 18 18 19 19 if urepo.EmailConfirmedAt != nil { 20 - return helpers.InputError(e, to.StringPtr("InvalidRequest")) 20 + helpers.InputError(w, to.StringPtr("InvalidRequest")) 21 + return 21 22 } 22 23 23 24 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) ··· 25 26 26 27 if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 27 28 logger.Error("error updating user", "error", err) 28 - return helpers.ServerError(e, nil) 29 + helpers.ServerError(w, nil) 30 + return 29 31 } 30 32 31 33 if err := s.sendEmailVerification(urepo.Email, urepo.Handle, code); err != nil { 32 34 logger.Error("error sending mail", "error", err) 33 - return helpers.ServerError(e, nil) 35 + helpers.ServerError(w, nil) 36 + return 34 37 } 35 38 36 - return e.NoContent(200) 39 + w.WriteHeader(http.StatusOK) 37 40 }
+9 -7
server/handle_server_request_email_update.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "time" 6 7 7 8 "github.com/haileyok/cocoon/internal/helpers" 8 9 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 10 ) 11 11 12 12 type ComAtprotoRequestEmailUpdateResponse struct { 13 13 TokenRequired bool `json:"tokenRequired"` 14 14 } 15 15 16 - func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { 17 - ctx := e.Request().Context() 16 + func (s *Server) handleServerRequestEmailUpdate(w http.ResponseWriter, r *http.Request) { 17 + ctx := r.Context() 18 18 logger := s.logger.With("name", "handleServerRequestEmailUpdate") 19 19 20 - urepo := e.Get("repo").(*models.RepoActor) 20 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 21 21 22 22 if urepo.EmailConfirmedAt != nil { 23 23 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) ··· 25 25 26 26 if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 27 27 logger.Error("error updating repo", "error", err) 28 - return helpers.ServerError(e, nil) 28 + helpers.ServerError(w, nil) 29 + return 29 30 } 30 31 31 32 if err := s.sendEmailUpdate(urepo.Email, urepo.Handle, code); err != nil { 32 33 logger.Error("error sending email", "error", err) 33 - return helpers.ServerError(e, nil) 34 + helpers.ServerError(w, nil) 35 + return 34 36 } 35 37 } 36 38 37 - return e.JSON(200, ComAtprotoRequestEmailUpdateResponse{ 39 + s.writeJSON(w, 200, ComAtprotoRequestEmailUpdateResponse{ 38 40 TokenRequired: urepo.EmailConfirmedAt != nil, 39 41 }) 40 42 }
+21 -13
server/handle_server_request_password_reset.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 4 5 "fmt" 6 + "net/http" 5 7 "time" 6 8 7 9 "github.com/haileyok/cocoon/internal/helpers" 8 10 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 11 ) 11 12 12 13 type ComAtprotoServerRequestPasswordResetRequest struct { 13 14 Email string `json:"email" validate:"required"` 14 15 } 15 16 16 - func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { 17 - ctx := e.Request().Context() 17 + func (s *Server) handleServerRequestPasswordReset(w http.ResponseWriter, r *http.Request) { 18 + ctx := r.Context() 18 19 logger := s.logger.With("name", "handleServerRequestPasswordReset") 19 20 20 - urepo, ok := e.Get("repo").(*models.RepoActor) 21 - if !ok { 21 + var urepo *models.RepoActor 22 + if repo, ok := getContextValue[*models.RepoActor](r, contextKeyRepo); ok { 23 + urepo = repo 24 + } else { 22 25 var req ComAtprotoServerRequestPasswordResetRequest 23 - if err := e.Bind(&req); err != nil { 24 - return err 26 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 27 + helpers.ServerError(w, nil) 28 + return 25 29 } 26 30 27 - if err := e.Validate(req); err != nil { 28 - return err 31 + if err := s.validator.Struct(req); err != nil { 32 + helpers.InputError(w, nil) 33 + return 29 34 } 30 35 31 36 murepo, err := s.getRepoActorByEmail(ctx, req.Email) 32 37 if err != nil { 33 - return err 38 + helpers.ServerError(w, nil) 39 + return 34 40 } 35 41 36 42 urepo = murepo ··· 41 47 42 48 if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 43 49 logger.Error("error updating repo", "error", err) 44 - return helpers.ServerError(e, nil) 50 + helpers.ServerError(w, nil) 51 + return 45 52 } 46 53 47 54 if err := s.sendPasswordReset(urepo.Email, urepo.Handle, code); err != nil { 48 55 logger.Error("error sending email", "error", err) 49 - return helpers.ServerError(e, nil) 56 + helpers.ServerError(w, nil) 57 + return 50 58 } 51 59 52 - return e.NoContent(200) 60 + w.WriteHeader(http.StatusOK) 53 61 }
+17 -11
server/handle_server_reserve_signing_key.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "encoding/json" 6 + "net/http" 5 7 "time" 6 8 7 9 "github.com/bluesky-social/indigo/atproto/atcrypto" 8 10 "github.com/haileyok/cocoon/internal/helpers" 9 11 "github.com/haileyok/cocoon/models" 10 - "github.com/labstack/echo/v4" 11 12 ) 12 13 13 14 type ServerReserveSigningKeyRequest struct { ··· 18 19 SigningKey string `json:"signingKey"` 19 20 } 20 21 21 - func (s *Server) handleServerReserveSigningKey(e echo.Context) error { 22 - ctx := e.Request().Context() 22 + func (s *Server) handleServerReserveSigningKey(w http.ResponseWriter, r *http.Request) { 23 + ctx := r.Context() 23 24 logger := s.logger.With("name", "handleServerReserveSigningKey") 24 25 25 26 var req ServerReserveSigningKeyRequest 26 - if err := e.Bind(&req); err != nil { 27 - logger.Error("could not bind reserve signing key request", "error", err) 28 - return helpers.ServerError(e, nil) 27 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 28 + logger.Error("could not decode reserve signing key request", "error", err) 29 + helpers.ServerError(w, nil) 30 + return 29 31 } 30 32 31 33 if req.Did != nil && *req.Did != "" { 32 34 var existing models.ReservedKey 33 35 if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 34 - return e.JSON(200, ServerReserveSigningKeyResponse{ 36 + s.writeJSON(w, 200, ServerReserveSigningKeyResponse{ 35 37 SigningKey: existing.KeyDid, 36 38 }) 39 + return 37 40 } 38 41 } 39 42 40 43 k, err := atcrypto.GeneratePrivateKeyK256() 41 44 if err != nil { 42 45 logger.Error("error creating signing key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 43 - return helpers.ServerError(e, nil) 46 + helpers.ServerError(w, nil) 47 + return 44 48 } 45 49 46 50 pubKey, err := k.PublicKey() 47 51 if err != nil { 48 52 logger.Error("error getting public key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 49 - return helpers.ServerError(e, nil) 53 + helpers.ServerError(w, nil) 54 + return 50 55 } 51 56 52 57 keyDid := pubKey.DIDKey() ··· 60 65 61 66 if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil { 62 67 logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 63 - return helpers.ServerError(e, nil) 68 + helpers.ServerError(w, nil) 69 + return 64 70 } 65 71 66 72 logger.Info("reserved signing key", "keyDid", keyDid, "forDid", req.Did) 67 73 68 - return e.JSON(200, ServerReserveSigningKeyResponse{ 74 + s.writeJSON(w, 200, ServerReserveSigningKeyResponse{ 69 75 SigningKey: keyDid, 70 76 }) 71 77 }
+23 -15
server/handle_server_reset_password.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 4 6 "time" 5 7 6 8 "github.com/Azure/go-autorest/autorest/to" 7 9 "github.com/haileyok/cocoon/internal/helpers" 8 10 "github.com/haileyok/cocoon/models" 9 - "github.com/labstack/echo/v4" 10 11 "golang.org/x/crypto/bcrypt" 11 12 ) 12 13 ··· 15 16 Password string `json:"password" validate:"required"` 16 17 } 17 18 18 - func (s *Server) handleServerResetPassword(e echo.Context) error { 19 - ctx := e.Request().Context() 19 + func (s *Server) handleServerResetPassword(w http.ResponseWriter, r *http.Request) { 20 + ctx := r.Context() 20 21 logger := s.logger.With("name", "handleServerResetPassword") 21 22 22 - urepo := e.Get("repo").(*models.RepoActor) 23 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 23 24 24 25 var req ComAtprotoServerResetPasswordRequest 25 - if err := e.Bind(&req); err != nil { 26 - logger.Error("error binding", "error", err) 27 - return helpers.ServerError(e, nil) 26 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 27 + logger.Error("error decoding", "error", err) 28 + helpers.ServerError(w, nil) 29 + return 28 30 } 29 31 30 - if err := e.Validate(req); err != nil { 31 - return helpers.InputError(e, nil) 32 + if err := s.validator.Struct(req); err != nil { 33 + helpers.InputError(w, nil) 34 + return 32 35 } 33 36 34 37 if urepo.PasswordResetCode == nil || urepo.PasswordResetCodeExpiresAt == nil { 35 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 38 + helpers.InputError(w, to.StringPtr("InvalidToken")) 39 + return 36 40 } 37 41 38 42 if *urepo.PasswordResetCode != req.Token { 39 - return helpers.InvalidTokenError(e) 43 + helpers.InvalidTokenError(w) 44 + return 40 45 } 41 46 42 47 if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) { 43 - return helpers.ExpiredTokenError(e) 48 + helpers.ExpiredTokenError(w) 49 + return 44 50 } 45 51 46 52 hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10) 47 53 if err != nil { 48 54 logger.Error("error creating hash", "error", err) 49 - return helpers.ServerError(e, nil) 55 + helpers.ServerError(w, nil) 56 + return 50 57 } 51 58 52 59 if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 53 60 logger.Error("error updating repo", "error", err) 54 - return helpers.ServerError(e, nil) 61 + helpers.ServerError(w, nil) 62 + return 55 63 } 56 64 57 - return e.NoContent(200) 65 + w.WriteHeader(http.StatusOK) 58 66 }
+11 -8
server/handle_server_resolve_handle.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "net/http" 5 6 6 7 "github.com/Azure/go-autorest/autorest/to" 7 8 "github.com/bluesky-social/indigo/atproto/syntax" 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 - "github.com/labstack/echo/v4" 10 10 ) 11 11 12 - func (s *Server) handleResolveHandle(e echo.Context) error { 12 + func (s *Server) handleResolveHandle(w http.ResponseWriter, r *http.Request) { 13 13 logger := s.logger.With("name", "handleServerResolveHandle") 14 14 15 15 type Resp struct { 16 16 Did string `json:"did"` 17 17 } 18 18 19 - handle := e.QueryParam("handle") 19 + handle := r.URL.Query().Get("handle") 20 20 21 21 if handle == "" { 22 - return helpers.InputError(e, to.StringPtr("Handle must be supplied in request.")) 22 + helpers.InputError(w, to.StringPtr("Handle must be supplied in request.")) 23 + return 23 24 } 24 25 25 26 parsed, err := syntax.ParseHandle(handle) 26 27 if err != nil { 27 - return helpers.InputError(e, to.StringPtr("Invalid handle.")) 28 + helpers.InputError(w, to.StringPtr("Invalid handle.")) 29 + return 28 30 } 29 31 30 - ctx := context.WithValue(e.Request().Context(), "skip-cache", true) 32 + ctx := context.WithValue(r.Context(), "skip-cache", true) 31 33 did, err := s.passport.ResolveHandle(ctx, parsed.String()) 32 34 if err != nil { 33 35 logger.Error("error resolving handle", "error", err) 34 - return helpers.ServerError(e, nil) 36 + helpers.ServerError(w, nil) 37 + return 35 38 } 36 39 37 - return e.JSON(200, Resp{ 40 + s.writeJSON(w, 200, Resp{ 38 41 Did: did, 39 42 }) 40 43 }
+24 -16
server/handle_server_update_email.go
··· 1 1 package server 2 2 3 3 import ( 4 + "encoding/json" 5 + "net/http" 4 6 "time" 5 7 6 8 "github.com/haileyok/cocoon/internal/helpers" 7 9 "github.com/haileyok/cocoon/models" 8 - "github.com/labstack/echo/v4" 9 10 ) 10 11 11 12 type ComAtprotoServerUpdateEmailRequest struct { ··· 14 15 Token string `json:"token"` 15 16 } 16 17 17 - func (s *Server) handleServerUpdateEmail(e echo.Context) error { 18 - ctx := e.Request().Context() 18 + func (s *Server) handleServerUpdateEmail(w http.ResponseWriter, r *http.Request) { 19 + ctx := r.Context() 19 20 logger := s.logger.With("name", "handleServerUpdateEmail") 20 21 21 - urepo := e.Get("repo").(*models.RepoActor) 22 + urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo) 22 23 23 24 var req ComAtprotoServerUpdateEmailRequest 24 - if err := e.Bind(&req); err != nil { 25 - logger.Error("error binding", "error", err) 26 - return helpers.ServerError(e, nil) 25 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 26 + logger.Error("error decoding", "error", err) 27 + helpers.ServerError(w, nil) 28 + return 27 29 } 28 30 29 - if err := e.Validate(req); err != nil { 30 - return helpers.InputError(e, nil) 31 + if err := s.validator.Struct(req); err != nil { 32 + helpers.InputError(w, nil) 33 + return 31 34 } 32 35 33 36 // To disable email auth factor a token is required. 34 37 // To enable email auth factor a token is not required. 35 38 // If updating an email address, a token will be sent anyway 36 - if urepo.TwoFactorType != models.TwoFactorTypeNone && req.EmailAuthFactor == false && req.Token == "" { 37 - return helpers.InvalidTokenError(e) 39 + if urepo.TwoFactorType != models.TwoFactorTypeNone && !req.EmailAuthFactor && req.Token == "" { 40 + helpers.InvalidTokenError(w) 41 + return 38 42 } 39 43 40 44 if req.Token != "" { 41 45 if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil { 42 - return helpers.InvalidTokenError(e) 46 + helpers.InvalidTokenError(w) 47 + return 43 48 } 44 49 45 50 if *urepo.EmailUpdateCode != req.Token { 46 - return helpers.InvalidTokenError(e) 51 + helpers.InvalidTokenError(w) 52 + return 47 53 } 48 54 49 55 if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) { 50 - return helpers.ExpiredTokenError(e) 56 + helpers.ExpiredTokenError(w) 57 + return 51 58 } 52 59 } 53 60 ··· 66 73 67 74 if err := s.db.Exec(ctx, query, nil, twoFactorType, req.Email, urepo.Repo.Did).Error; err != nil { 68 75 logger.Error("error updating repo", "error", err) 69 - return helpers.ServerError(e, nil) 76 + helpers.ServerError(w, nil) 77 + return 70 78 } 71 79 72 - return e.NoContent(200) 80 + w.WriteHeader(http.StatusOK) 73 81 }
+30 -19
server/handle_sync_get_blob.go
··· 10 10 "github.com/haileyok/cocoon/internal/helpers" 11 11 "github.com/haileyok/cocoon/models" 12 12 "github.com/ipfs/go-cid" 13 - "github.com/labstack/echo/v4" 14 13 ) 15 14 16 - func (s *Server) handleSyncGetBlob(e echo.Context) error { 17 - ctx := e.Request().Context() 15 + func (s *Server) handleSyncGetBlob(w http.ResponseWriter, r *http.Request) { 16 + ctx := r.Context() 18 17 logger := s.logger.With("name", "handleSyncGetBlob") 19 18 20 - did := e.QueryParam("did") 19 + did := r.URL.Query().Get("did") 21 20 if did == "" { 22 - return helpers.InputError(e, nil) 21 + helpers.InputError(w, nil) 22 + return 23 23 } 24 24 25 - cstr := e.QueryParam("cid") 25 + cstr := r.URL.Query().Get("cid") 26 26 if cstr == "" { 27 - return helpers.InputError(e, nil) 27 + helpers.InputError(w, nil) 28 + return 28 29 } 29 30 30 31 c, err := cid.Parse(cstr) 31 32 if err != nil { 32 - return helpers.InputError(e, nil) 33 + helpers.InputError(w, nil) 34 + return 33 35 } 34 36 35 37 urepo, err := s.getRepoActorByDid(ctx, did) 36 38 if err != nil { 37 39 logger.Error("could not find user for requested blob", "error", err) 38 - return helpers.InputError(e, nil) 40 + helpers.InputError(w, nil) 41 + return 39 42 } 40 43 41 44 status := urepo.Status() 42 45 if status != nil { 43 46 if *status == "deactivated" { 44 - return helpers.InputError(e, to.StringPtr("RepoDeactivated")) 47 + helpers.InputError(w, to.StringPtr("RepoDeactivated")) 48 + return 45 49 } 46 50 } 47 51 48 52 var blob models.Blob 49 53 if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 50 54 logger.Error("error looking up blob", "error", err) 51 - return helpers.ServerError(e, nil) 55 + helpers.ServerError(w, nil) 56 + return 52 57 } 53 58 54 59 buf := new(bytes.Buffer) ··· 58 63 var parts []models.BlobPart 59 64 if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 60 65 logger.Error("error getting blob parts", "error", err) 61 - return helpers.ServerError(e, nil) 66 + helpers.ServerError(w, nil) 67 + return 62 68 } 63 69 64 70 for _, p := range parts { ··· 68 74 case "ipfs": 69 75 if s.ipfsConfig == nil || !s.ipfsConfig.BlobstoreEnabled { 70 76 logger.Error("ipfs storage disabled") 71 - return helpers.ServerError(e, nil) 77 + helpers.ServerError(w, nil) 78 + return 72 79 } 73 80 74 81 // If a public gateway is configured, redirect the client directly to it 75 82 // instead of proxying the content through this server. 76 83 if s.ipfsConfig.GatewayURL != "" { 77 84 redirectURL := fmt.Sprintf("%s/ipfs/%s", s.ipfsConfig.GatewayURL, c.String()) 78 - return e.Redirect(302, redirectURL) 85 + http.Redirect(w, r, redirectURL, http.StatusFound) 86 + return 79 87 } 80 88 81 89 // Otherwise fetch from the local Kubo node via /api/v0/cat and stream ··· 83 91 data, err := s.fetchBlobFromIPFS(c.String()) 84 92 if err != nil { 85 93 logger.Error("error fetching blob from ipfs node", "cid", c.String(), "error", err) 86 - return helpers.ServerError(e, nil) 94 + helpers.ServerError(w, nil) 95 + return 87 96 } 88 97 buf.Write(data) 89 98 90 99 default: 91 100 logger.Error("unknown storage", "storage", blob.Storage) 92 - return helpers.ServerError(e, nil) 101 + helpers.ServerError(w, nil) 102 + return 93 103 } 94 104 95 - e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename="+c.String()) 96 - 97 - return e.Stream(200, "application/octet-stream", buf) 105 + w.Header().Set("Content-Disposition", "attachment; filename="+c.String()) 106 + w.Header().Set("Content-Type", "application/octet-stream") 107 + w.WriteHeader(http.StatusOK) 108 + io.Copy(w, buf) 98 109 } 99 110 100 111 // fetchBlobFromIPFS retrieves blob data for the given CID from the local Kubo
+38 -18
server/handle_sync_get_blocks.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 + "net/http" 5 6 6 7 "github.com/bluesky-social/indigo/carstore" 7 8 "github.com/haileyok/cocoon/internal/helpers" 8 9 "github.com/ipfs/go-cid" 9 10 cbor "github.com/ipfs/go-ipld-cbor" 10 11 "github.com/ipld/go-car" 11 - "github.com/labstack/echo/v4" 12 12 ) 13 13 14 14 type ComAtprotoSyncGetBlocksRequest struct { ··· 16 16 Cids []string `query:"cids"` 17 17 } 18 18 19 - func (s *Server) handleGetBlocks(e echo.Context) error { 20 - ctx := e.Request().Context() 19 + func (s *Server) handleGetBlocks(w http.ResponseWriter, r *http.Request) { 20 + ctx := r.Context() 21 21 logger := s.logger.With("name", "handleSyncGetBlocks") 22 22 23 - var req ComAtprotoSyncGetBlocksRequest 24 - if err := e.Bind(&req); err != nil { 25 - return helpers.InputError(e, nil) 23 + did := r.URL.Query().Get("did") 24 + if did == "" { 25 + helpers.InputError(w, nil) 26 + return 26 27 } 27 28 29 + cidsParam := r.URL.Query()["cids"] 28 30 var cids []cid.Cid 29 31 30 - for _, cs := range req.Cids { 32 + for _, cs := range cidsParam { 31 33 c, err := cid.Cast([]byte(cs)) 32 34 if err != nil { 33 - return err 35 + logger.Error("error parsing cid", "cid", cs, "error", err) 36 + helpers.InputError(w, nil) 37 + return 34 38 } 35 - 36 39 cids = append(cids, c) 37 40 } 38 41 39 - urepo, err := s.getRepoActorByDid(ctx, req.Did) 42 + urepo, err := s.getRepoActorByDid(ctx, did) 40 43 if err != nil { 41 - return helpers.ServerError(e, nil) 44 + logger.Error("could not find repo", "did", did, "error", err) 45 + helpers.ServerError(w, nil) 46 + return 42 47 } 43 48 44 - buf := new(bytes.Buffer) 45 49 rc, err := cid.Cast(urepo.Root) 46 50 if err != nil { 47 - return err 51 + logger.Error("error casting root cid", "error", err) 52 + helpers.ServerError(w, nil) 53 + return 48 54 } 49 55 50 56 hb, err := cbor.DumpObject(&car.CarHeader{ 51 57 Roots: []cid.Cid{rc}, 52 58 Version: 1, 53 59 }) 60 + if err != nil { 61 + logger.Error("error dumping car header", "error", err) 62 + helpers.ServerError(w, nil) 63 + return 64 + } 65 + 66 + buf := new(bytes.Buffer) 54 67 55 68 if _, err := carstore.LdWrite(buf, hb); err != nil { 56 - logger.Error("error writing to car", "error", err) 57 - return helpers.ServerError(e, nil) 69 + logger.Error("error writing car header", "error", err) 70 + helpers.ServerError(w, nil) 71 + return 58 72 } 59 73 60 74 bs := s.getBlockstore(urepo.Repo.Did) ··· 62 76 for _, c := range cids { 63 77 b, err := bs.Get(ctx, c) 64 78 if err != nil { 65 - return err 79 + logger.Error("error getting block", "cid", c.String(), "error", err) 80 + helpers.ServerError(w, nil) 81 + return 66 82 } 67 83 68 84 if _, err := carstore.LdWrite(buf, b.Cid().Bytes(), b.RawData()); err != nil { 69 - return err 85 + logger.Error("error writing block to car", "error", err) 86 + helpers.ServerError(w, nil) 87 + return 70 88 } 71 89 } 72 90 73 - return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes())) 91 + w.Header().Set("Content-Type", "application/vnd.ipld.car") 92 + w.WriteHeader(http.StatusOK) 93 + w.Write(buf.Bytes()) 74 94 }
+15 -8
server/handle_sync_get_latest_commit.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 7 "github.com/ipfs/go-cid" 6 - "github.com/labstack/echo/v4" 7 8 ) 8 9 9 10 type ComAtprotoSyncGetLatestCommitResponse struct { ··· 11 12 Rev string `json:"rev"` 12 13 } 13 14 14 - func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { 15 - ctx := e.Request().Context() 15 + func (s *Server) handleSyncGetLatestCommit(w http.ResponseWriter, r *http.Request) { 16 + ctx := r.Context() 17 + logger := s.logger.With("name", "handleSyncGetLatestCommit") 16 18 17 - did := e.QueryParam("did") 19 + did := r.URL.Query().Get("did") 18 20 if did == "" { 19 - return helpers.InputError(e, nil) 21 + helpers.InputError(w, nil) 22 + return 20 23 } 21 24 22 25 urepo, err := s.getRepoActorByDid(ctx, did) 23 26 if err != nil { 24 - return err 27 + logger.Error("could not find repo", "error", err) 28 + helpers.ServerError(w, nil) 29 + return 25 30 } 26 31 27 32 c, err := cid.Cast(urepo.Root) 28 33 if err != nil { 29 - return err 34 + logger.Error("could not cast root cid", "error", err) 35 + helpers.ServerError(w, nil) 36 + return 30 37 } 31 38 32 - return e.JSON(200, ComAtprotoSyncGetLatestCommitResponse{ 39 + s.writeJSON(w, http.StatusOK, ComAtprotoSyncGetLatestCommitResponse{ 33 40 Cid: c.String(), 34 41 Rev: urepo.Rev, 35 42 })
+24 -12
server/handle_sync_get_record.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 + "net/http" 5 6 6 7 "github.com/bluesky-social/indigo/carstore" 7 8 "github.com/haileyok/cocoon/internal/helpers" ··· 9 10 "github.com/ipfs/go-cid" 10 11 cbor "github.com/ipfs/go-ipld-cbor" 11 12 "github.com/ipld/go-car" 12 - "github.com/labstack/echo/v4" 13 13 ) 14 14 15 - func (s *Server) handleSyncGetRecord(e echo.Context) error { 16 - ctx := e.Request().Context() 15 + func (s *Server) handleSyncGetRecord(w http.ResponseWriter, r *http.Request) { 16 + ctx := r.Context() 17 17 logger := s.logger.With("name", "handleSyncGetRecord") 18 18 19 - did := e.QueryParam("did") 20 - collection := e.QueryParam("collection") 21 - rkey := e.QueryParam("rkey") 19 + did := r.URL.Query().Get("did") 20 + collection := r.URL.Query().Get("collection") 21 + rkey := r.URL.Query().Get("rkey") 22 22 23 23 var urepo models.Repo 24 24 if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 25 25 logger.Error("error getting repo", "error", err) 26 - return helpers.ServerError(e, nil) 26 + helpers.ServerError(w, nil) 27 + return 27 28 } 28 29 29 30 root, blocks, err := s.repoman.getRecordProof(ctx, urepo, collection, rkey) 30 31 if err != nil { 31 - return err 32 + logger.Error("error getting record proof", "error", err) 33 + helpers.ServerError(w, nil) 34 + return 32 35 } 33 36 34 37 buf := new(bytes.Buffer) ··· 37 40 Roots: []cid.Cid{root}, 38 41 Version: 1, 39 42 }) 43 + if err != nil { 44 + logger.Error("error dumping car header", "error", err) 45 + helpers.ServerError(w, nil) 46 + return 47 + } 40 48 41 49 if _, err := carstore.LdWrite(buf, hb); err != nil { 42 50 logger.Error("error writing to car", "error", err) 43 - return helpers.ServerError(e, nil) 51 + helpers.ServerError(w, nil) 52 + return 44 53 } 45 54 46 55 for _, blk := range blocks { 47 56 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 48 - logger.Error("error writing to car", "error", err) 49 - return helpers.ServerError(e, nil) 57 + logger.Error("error writing block to car", "error", err) 58 + helpers.ServerError(w, nil) 59 + return 50 60 } 51 61 } 52 62 53 - return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes())) 63 + w.Header().Set("Content-Type", "application/vnd.ipld.car") 64 + w.WriteHeader(http.StatusOK) 65 + w.Write(buf.Bytes()) 54 66 }
+29 -12
server/handle_sync_get_repo.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 + "net/http" 5 6 6 7 "github.com/bluesky-social/indigo/carstore" 7 8 "github.com/haileyok/cocoon/internal/helpers" ··· 9 10 "github.com/ipfs/go-cid" 10 11 cbor "github.com/ipfs/go-ipld-cbor" 11 12 "github.com/ipld/go-car" 12 - "github.com/labstack/echo/v4" 13 13 ) 14 14 15 - func (s *Server) handleSyncGetRepo(e echo.Context) error { 16 - ctx := e.Request().Context() 15 + func (s *Server) handleSyncGetRepo(w http.ResponseWriter, r *http.Request) { 16 + ctx := r.Context() 17 17 logger := s.logger.With("name", "handleSyncGetRepo") 18 18 19 - did := e.QueryParam("did") 19 + did := r.URL.Query().Get("did") 20 20 if did == "" { 21 - return helpers.InputError(e, nil) 21 + helpers.InputError(w, nil) 22 + return 22 23 } 23 24 24 25 urepo, err := s.getRepoActorByDid(ctx, did) 25 26 if err != nil { 26 - return err 27 + logger.Error("could not find repo", "did", did, "error", err) 28 + helpers.ServerError(w, nil) 29 + return 27 30 } 28 31 29 32 rc, err := cid.Cast(urepo.Root) 30 33 if err != nil { 31 - return err 34 + logger.Error("error casting root cid", "error", err) 35 + helpers.ServerError(w, nil) 36 + return 32 37 } 33 38 34 39 hb, err := cbor.DumpObject(&car.CarHeader{ 35 40 Roots: []cid.Cid{rc}, 36 41 Version: 1, 37 42 }) 43 + if err != nil { 44 + logger.Error("error dumping car header", "error", err) 45 + helpers.ServerError(w, nil) 46 + return 47 + } 38 48 39 49 buf := new(bytes.Buffer) 40 50 41 51 if _, err := carstore.LdWrite(buf, hb); err != nil { 42 - logger.Error("error writing to car", "error", err) 43 - return helpers.ServerError(e, nil) 52 + logger.Error("error writing car header", "error", err) 53 + helpers.ServerError(w, nil) 54 + return 44 55 } 45 56 46 57 var blocks []models.Block 47 58 if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 48 - return err 59 + logger.Error("error getting blocks", "error", err) 60 + helpers.ServerError(w, nil) 61 + return 49 62 } 50 63 51 64 for _, block := range blocks { 52 65 if _, err := carstore.LdWrite(buf, block.Cid, block.Value); err != nil { 53 - return err 66 + logger.Error("error writing block to car", "error", err) 67 + helpers.ServerError(w, nil) 68 + return 54 69 } 55 70 } 56 71 57 - return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes())) 72 + w.Header().Set("Content-Type", "application/vnd.ipld.car") 73 + w.WriteHeader(http.StatusOK) 74 + w.Write(buf.Bytes()) 58 75 }
+12 -7
server/handle_sync_get_repo_status.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + 4 6 "github.com/haileyok/cocoon/internal/helpers" 5 - "github.com/labstack/echo/v4" 6 7 ) 7 8 8 9 type ComAtprotoSyncGetRepoStatusResponse struct { ··· 13 14 } 14 15 15 16 // TODO: make this actually do the right thing 16 - func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { 17 - ctx := e.Request().Context() 17 + func (s *Server) handleSyncGetRepoStatus(w http.ResponseWriter, r *http.Request) { 18 + ctx := r.Context() 19 + logger := s.logger.With("name", "handleSyncGetRepoStatus") 18 20 19 - did := e.QueryParam("did") 21 + did := r.URL.Query().Get("did") 20 22 if did == "" { 21 - return helpers.InputError(e, nil) 23 + helpers.InputError(w, nil) 24 + return 22 25 } 23 26 24 27 urepo, err := s.getRepoActorByDid(ctx, did) 25 28 if err != nil { 26 - return err 29 + logger.Error("could not find repo", "did", did, "error", err) 30 + helpers.ServerError(w, nil) 31 + return 27 32 } 28 33 29 - return e.JSON(200, ComAtprotoSyncGetRepoStatusResponse{ 34 + s.writeJSON(w, http.StatusOK, ComAtprotoSyncGetRepoStatusResponse{ 30 35 Did: urepo.Repo.Did, 31 36 Active: urepo.Active(), 32 37 Status: urepo.Status(),
+23 -14
server/handle_sync_list_blobs.go
··· 1 1 package server 2 2 3 3 import ( 4 + "net/http" 5 + "strconv" 6 + 4 7 "github.com/Azure/go-autorest/autorest/to" 5 8 "github.com/haileyok/cocoon/internal/helpers" 6 9 "github.com/haileyok/cocoon/models" 7 10 "github.com/ipfs/go-cid" 8 - "github.com/labstack/echo/v4" 9 11 ) 10 12 11 13 type ComAtprotoSyncListBlobsResponse struct { ··· 13 15 Cids []string `json:"cids"` 14 16 } 15 17 16 - func (s *Server) handleSyncListBlobs(e echo.Context) error { 17 - ctx := e.Request().Context() 18 + func (s *Server) handleSyncListBlobs(w http.ResponseWriter, r *http.Request) { 19 + ctx := r.Context() 18 20 logger := s.logger.With("name", "handleSyncListBlobs") 19 21 20 - did := e.QueryParam("did") 22 + did := r.URL.Query().Get("did") 21 23 if did == "" { 22 - return helpers.InputError(e, nil) 24 + helpers.InputError(w, nil) 25 + return 23 26 } 24 27 25 28 // TODO: add tid param 26 - cursor := e.QueryParam("cursor") 27 - limit, err := getLimitFromContext(e, 50) 28 - if err != nil { 29 - return helpers.InputError(e, nil) 29 + cursor := r.URL.Query().Get("cursor") 30 + 31 + limit := 50 32 + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { 33 + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 { 34 + limit = l 35 + } 30 36 } 31 37 32 38 cursorquery := "" ··· 41 47 urepo, err := s.getRepoActorByDid(ctx, did) 42 48 if err != nil { 43 49 logger.Error("could not find user for requested blobs", "error", err) 44 - return helpers.InputError(e, nil) 50 + helpers.InputError(w, nil) 51 + return 45 52 } 46 53 47 54 status := urepo.Status() 48 55 if status != nil { 49 56 if *status == "deactivated" { 50 - return helpers.InputError(e, to.StringPtr("RepoDeactivated")) 57 + helpers.InputError(w, to.StringPtr("RepoDeactivated")) 58 + return 51 59 } 52 60 } 53 61 54 62 var blobs []models.Blob 55 63 if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 56 64 logger.Error("error getting records", "error", err) 57 - return helpers.ServerError(e, nil) 65 + helpers.ServerError(w, nil) 66 + return 58 67 } 59 68 60 69 cstrs := make([]string, 0, len(blobs)) ··· 72 81 } 73 82 74 83 var newcursor *string 75 - if len(blobs) == 50 { 84 + if len(blobs) == limit { 76 85 newcursor = &blobs[len(blobs)-1].CreatedAt 77 86 } 78 87 79 - return e.JSON(200, ComAtprotoSyncListBlobsResponse{ 88 + s.writeJSON(w, http.StatusOK, ComAtprotoSyncListBlobsResponse{ 80 89 Cursor: newcursor, 81 90 Cids: cstrs, 82 91 })
+9 -10
server/handle_sync_subscribe_repos.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "net/http" 5 6 "strconv" 6 7 "time" 7 8 ··· 9 10 "github.com/bluesky-social/indigo/lex/util" 10 11 "github.com/btcsuite/websocket" 11 12 "github.com/haileyok/cocoon/metrics" 12 - "github.com/labstack/echo/v4" 13 13 ) 14 14 15 - func (s *Server) handleSyncSubscribeRepos(e echo.Context) error { 16 - ctx, cancel := context.WithCancel(e.Request().Context()) 15 + func (s *Server) handleSyncSubscribeRepos(w http.ResponseWriter, r *http.Request) { 16 + ctx, cancel := context.WithCancel(r.Context()) 17 17 defer cancel() 18 18 19 19 logger := s.logger.With("component", "subscribe-repos-websocket") 20 20 21 - conn, err := websocket.Upgrade(e.Response().Writer, e.Request(), e.Response().Header(), 1<<10, 1<<10) 21 + conn, err := websocket.Upgrade(w, r, w.Header(), 1<<10, 1<<10) 22 22 if err != nil { 23 23 logger.Error("unable to establish websocket with relay", "err", err) 24 - return err 24 + return 25 25 } 26 26 27 - ident := e.RealIP() + "-" + e.Request().UserAgent() 27 + ident := r.RemoteAddr + "-" + r.UserAgent() 28 28 logger = logger.With("ident", ident) 29 29 logger.Info("new connection established") 30 30 31 31 var since *int64 32 - if cursorStr := e.QueryParam("cursor"); cursorStr != "" { 32 + if cursorStr := r.URL.Query().Get("cursor"); cursorStr != "" { 33 33 cursor, err := strconv.ParseInt(cursorStr, 10, 64) 34 34 if err != nil { 35 35 logger.Warn("invalid cursor parameter", "cursor", cursorStr, "err", err) ··· 48 48 return true 49 49 }, since) 50 50 if err != nil { 51 - return err 51 + logger.Error("error subscribing to event manager", "err", err) 52 + return 52 53 } 53 54 defer evtManCancel() 54 55 ··· 134 135 logger.Error("error requesting crawls", "err", err) 135 136 } 136 137 }() 137 - 138 - return nil 139 138 }
+26 -19
server/handle_well_known.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "net/http" 5 6 "strings" 6 7 7 8 "github.com/Azure/go-autorest/autorest/to" 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 - "github.com/labstack/echo/v4" 10 10 "gorm.io/gorm" 11 11 ) 12 12 ··· 50 50 ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported"` 51 51 } 52 52 53 - func (s *Server) handleWellKnown(e echo.Context) error { 54 - return e.JSON(200, map[string]any{ 53 + func (s *Server) handleWellKnown(w http.ResponseWriter, r *http.Request) { 54 + s.writeJSON(w, 200, map[string]any{ 55 55 "@context": []string{ 56 56 "https://www.w3.org/ns/did/v1", 57 57 }, ··· 66 66 }) 67 67 } 68 68 69 - func (s *Server) handleAtprotoDid(e echo.Context) error { 70 - ctx := e.Request().Context() 69 + func (s *Server) handleAtprotoDid(w http.ResponseWriter, r *http.Request) { 70 + ctx := r.Context() 71 71 logger := s.logger.With("name", "handleAtprotoDid") 72 72 73 - host := e.Request().Host 73 + host := r.Host 74 74 if host == "" { 75 - return helpers.InputError(e, to.StringPtr("Invalid handle.")) 75 + helpers.InputError(w, to.StringPtr("Invalid handle.")) 76 + return 76 77 } 77 78 78 79 host = strings.Split(host, ":")[0] 79 80 host = strings.ToLower(strings.TrimSpace(host)) 80 81 81 82 if host == s.config.Hostname { 82 - return e.String(200, s.config.Did) 83 + w.Header().Set("Content-Type", "text/plain") 84 + fmt.Fprint(w, s.config.Did) 85 + return 83 86 } 84 87 85 88 suffix := "." + s.config.Hostname 86 89 if !strings.HasSuffix(host, suffix) { 87 - return e.NoContent(404) 90 + w.WriteHeader(http.StatusNotFound) 91 + return 88 92 } 89 93 90 94 actor, err := s.getActorByHandle(ctx, host) 91 95 if err != nil { 92 96 if err == gorm.ErrRecordNotFound { 93 - return e.NoContent(404) 97 + w.WriteHeader(http.StatusNotFound) 98 + return 94 99 } 95 100 logger.Error("error looking up actor by handle", "error", err) 96 - return helpers.ServerError(e, nil) 101 + helpers.ServerError(w, nil) 102 + return 97 103 } 98 104 99 - return e.String(200, actor.Did) 105 + w.Header().Set("Content-Type", "text/plain") 106 + fmt.Fprint(w, actor.Did) 100 107 } 101 108 102 - func (s *Server) handleOauthProtectedResource(e echo.Context) error { 103 - return e.JSON(200, map[string]any{ 109 + func (s *Server) handleOauthProtectedResource(w http.ResponseWriter, r *http.Request) { 110 + s.writeJSON(w, 200, map[string]any{ 104 111 "resource": "https://" + s.config.Hostname, 105 112 "authorization_servers": []string{ 106 113 "https://" + s.config.Hostname, ··· 111 118 }) 112 119 } 113 120 114 - func (s *Server) handleOauthAuthorizationServer(e echo.Context) error { 115 - return e.JSON(200, OauthAuthorizationMetadata{ 121 + func (s *Server) handleOauthAuthorizationServer(w http.ResponseWriter, r *http.Request) { 122 + s.writeJSON(w, 200, OauthAuthorizationMetadata{ 116 123 Issuer: "https://" + s.config.Hostname, 117 124 RequestParameterSupported: true, 118 125 RequestUriParameterSupported: true, ··· 125 132 CodeChallengeMethodsSupported: []string{"S256"}, 126 133 UILocalesSupported: []string{"en-US"}, 127 134 DisplayValuesSupported: []string{"page", "popup", "touch"}, 128 - RequestObjectSigningAlgValuesSupported: []string{"ES256"}, // only es256 for now... 135 + RequestObjectSigningAlgValuesSupported: []string{"ES256"}, 129 136 AuthorizationResponseISSParameterSupported: true, 130 137 RequestObjectEncryptionAlgValuesSupported: []string{}, 131 138 RequestObjectEncryptionEncValuesSupported: []string{}, ··· 133 140 AuthorizationEndpoint: fmt.Sprintf("https://%s/oauth/authorize", s.config.Hostname), 134 141 TokenEndpoint: fmt.Sprintf("https://%s/oauth/token", s.config.Hostname), 135 142 TokenEndpointAuthMethodsSupported: []string{"none", "private_key_jwt"}, 136 - TokenEndpointAuthSigningAlgValuesSupported: []string{"ES256"}, // Same as above, just es256 143 + TokenEndpointAuthSigningAlgValuesSupported: []string{"ES256"}, 137 144 RevocationEndpoint: fmt.Sprintf("https://%s/oauth/revoke", s.config.Hostname), 138 145 IntrospectionEndpoint: fmt.Sprintf("https://%s/oauth/introspect", s.config.Hostname), 139 146 PushedAuthorizationRequestEndpoint: fmt.Sprintf("https://%s/oauth/par", s.config.Hostname), 140 147 RequirePushedAuthorizationRequests: true, 141 - DpopSigningAlgValuesSupported: []string{"ES256"}, // again same as above 148 + DpopSigningAlgValuesSupported: []string{"ES256"}, 142 149 ProtectedResources: []string{"https://" + s.config.Hostname}, 143 150 ClientIDMetadataDocumentSupported: true, 144 151 })
+128 -78
server/middleware.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "crypto/sha256" 5 6 "encoding/base64" 6 7 "errors" 7 8 "fmt" 9 + "net/http" 8 10 "strings" 9 11 "time" 10 12 ··· 14 16 "github.com/haileyok/cocoon/models" 15 17 "github.com/haileyok/cocoon/oauth/dpop" 16 18 "github.com/haileyok/cocoon/oauth/provider" 17 - "github.com/labstack/echo/v4" 18 19 "gitlab.com/yawning/secp256k1-voi" 19 20 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 21 "gorm.io/gorm" 21 22 ) 22 23 23 - func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 - return func(e echo.Context) error { 25 - username, password, ok := e.Request().BasicAuth() 26 - if !ok || username != "admin" || password != s.config.AdminPassword { 27 - return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 - } 24 + // context keys for values set by middleware 25 + type contextKey string 29 26 30 - if err := next(e); err != nil { 31 - e.Error(err) 32 - } 27 + const ( 28 + contextKeyRepo contextKey = "repo" 29 + contextKeyDid contextKey = "did" 30 + contextKeyToken contextKey = "token" 31 + contextKeyScopes contextKey = "scopes" 32 + 33 + // used by proxy handler to override token fields 34 + contextKeyProxyTokenLxm contextKey = "proxyTokenLxm" 35 + contextKeyProxyTokenAud contextKey = "proxyTokenAud" 36 + ) 33 37 34 - return nil 35 - } 38 + func setContextValue(r *http.Request, key contextKey, value any) *http.Request { 39 + return r.WithContext(context.WithValue(r.Context(), key, value)) 40 + } 41 + 42 + func getContextValue[T any](r *http.Request, key contextKey) (T, bool) { 43 + v, ok := r.Context().Value(key).(T) 44 + return v, ok 36 45 } 37 46 38 - func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 - return func(e echo.Context) error { 40 - ctx := e.Request().Context() 47 + func (s *Server) handleAdminMiddleware(next http.Handler) http.Handler { 48 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 + username, password, ok := r.BasicAuth() 50 + if !ok || username != "admin" || password != s.config.AdminPassword { 51 + helpers.InputError(w, to.StringPtr("Unauthorized")) 52 + return 53 + } 54 + next.ServeHTTP(w, r) 55 + }) 56 + } 57 + 58 + func (s *Server) handleLegacySessionMiddleware(next http.Handler) http.Handler { 59 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 + ctx := r.Context() 41 61 logger := s.logger.With("name", "handleLegacySessionMiddleware") 42 62 43 - authheader := e.Request().Header.Get("authorization") 63 + authheader := r.Header.Get("authorization") 44 64 if authheader == "" { 45 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 65 + s.writeJSON(w, 401, map[string]string{"error": "Unauthorized"}) 66 + return 46 67 } 47 68 48 69 pts := strings.Split(authheader, " ") 49 70 if len(pts) != 2 { 50 - return helpers.ServerError(e, nil) 71 + helpers.ServerError(w, nil) 72 + return 51 73 } 52 74 53 75 // move on to oauth session middleware if this is a dpop token 54 76 if pts[0] == "DPoP" { 55 - return next(e) 77 + next.ServeHTTP(w, r) 78 + return 56 79 } 57 80 58 81 tokenstr := pts[1] 59 82 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 60 83 claims, ok := token.Claims.(jwt.MapClaims) 61 84 if !ok { 62 - return helpers.InvalidTokenError(e) 85 + helpers.InvalidTokenError(w) 86 + return 63 87 } 64 88 65 89 var did string ··· 68 92 // service auth tokens 69 93 lxm, hasLxm := claims["lxm"] 70 94 if hasLxm { 71 - pts := strings.Split(e.Request().URL.String(), "/") 95 + pts := strings.Split(r.URL.String(), "/") 72 96 if lxm != pts[len(pts)-1] { 73 97 logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 74 - return helpers.InputError(e, nil) 98 + helpers.InputError(w, nil) 99 + return 75 100 } 76 101 77 102 maybeDid, ok := claims["iss"].(string) 78 103 if !ok { 79 104 logger.Error("no iss in service auth token", "error", err) 80 - return helpers.InputError(e, nil) 105 + helpers.InputError(w, nil) 106 + return 81 107 } 82 108 did = maybeDid 83 109 84 110 maybeRepo, err := s.getRepoActorByDid(ctx, did) 85 111 if err != nil { 86 112 logger.Error("error fetching repo", "error", err) 87 - return helpers.ServerError(e, nil) 113 + helpers.ServerError(w, nil) 114 + return 88 115 } 89 116 repo = maybeRepo 90 117 } ··· 98 125 }) 99 126 if err != nil { 100 127 logger.Error("error parsing jwt", "error", err) 101 - return helpers.ExpiredTokenError(e) 128 + helpers.ExpiredTokenError(w) 129 + return 102 130 } 103 131 104 132 if !token.Valid { 105 - return helpers.InvalidTokenError(e) 133 + helpers.InvalidTokenError(w) 134 + return 106 135 } 107 136 } else { 108 137 kpts := strings.Split(tokenstr, ".") ··· 111 140 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 112 141 if err != nil { 113 142 logger.Error("error decoding signature bytes", "error", err) 114 - return helpers.ServerError(e, nil) 143 + helpers.ServerError(w, nil) 144 + return 115 145 } 116 146 117 147 if len(sigBytes) != 64 { 118 148 logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 119 - return helpers.ServerError(e, nil) 149 + helpers.ServerError(w, nil) 150 + return 120 151 } 121 152 122 153 rBytes := sigBytes[:32] ··· 128 159 sub, ok := claims["sub"].(string) 129 160 if !ok { 130 161 s.logger.Error("no sub claim in ES256K token and repo not set") 131 - return helpers.InvalidTokenError(e) 162 + helpers.InvalidTokenError(w) 163 + return 132 164 } 133 165 maybeRepo, err := s.getRepoActorByDid(ctx, sub) 134 166 if err != nil { 135 167 s.logger.Error("error fetching repo for ES256K verification", "error", err) 136 - return helpers.ServerError(e, nil) 168 + helpers.ServerError(w, nil) 169 + return 137 170 } 138 171 repo = maybeRepo 139 172 did = sub ··· 142 175 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 143 176 if err != nil { 144 177 logger.Error("can't load private key", "error", err) 145 - return err 178 + helpers.ServerError(w, nil) 179 + return 146 180 } 147 181 148 182 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 149 183 if !ok { 150 184 logger.Error("error getting public key from sk") 151 - return helpers.ServerError(e, nil) 185 + helpers.ServerError(w, nil) 186 + return 152 187 } 153 188 154 189 verified := pubKey.VerifyRaw(hash[:], rr, ss) 155 190 if !verified { 156 191 logger.Error("error verifying", "error", err) 157 - return helpers.ServerError(e, nil) 192 + helpers.ServerError(w, nil) 193 + return 158 194 } 159 195 } 160 196 161 - isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 197 + isRefresh := r.URL.Path == "/xrpc/com.atproto.server.refreshSession" 162 198 scope, _ := claims["scope"].(string) 163 199 164 200 if isRefresh && scope != "com.atproto.refresh" { 165 - return helpers.InvalidTokenError(e) 201 + helpers.InvalidTokenError(w) 202 + return 166 203 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 167 - return helpers.InvalidTokenError(e) 204 + helpers.InvalidTokenError(w) 205 + return 168 206 } 169 207 170 208 table := "tokens" ··· 179 217 var result Result 180 218 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 181 219 if err == gorm.ErrRecordNotFound { 182 - return helpers.InvalidTokenError(e) 220 + helpers.InvalidTokenError(w) 221 + return 183 222 } 184 223 185 224 logger.Error("error getting token from db", "error", err) 186 - return helpers.ServerError(e, nil) 225 + helpers.ServerError(w, nil) 226 + return 187 227 } 188 228 189 229 if !result.Found { 190 - return helpers.InvalidTokenError(e) 230 + helpers.InvalidTokenError(w) 231 + return 191 232 } 192 233 } 193 234 194 235 exp, ok := claims["exp"].(float64) 195 236 if !ok { 196 237 logger.Error("error getting iat from token") 197 - return helpers.ServerError(e, nil) 238 + helpers.ServerError(w, nil) 239 + return 198 240 } 199 241 200 242 if exp < float64(time.Now().UTC().Unix()) { 201 - return helpers.ExpiredTokenError(e) 243 + helpers.ExpiredTokenError(w) 244 + return 202 245 } 203 246 204 247 if repo == nil { 205 248 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 206 249 if err != nil { 207 250 logger.Error("error fetching repo", "error", err) 208 - return helpers.ServerError(e, nil) 251 + helpers.ServerError(w, nil) 252 + return 209 253 } 210 254 repo = maybeRepo 211 255 did = repo.Repo.Did 212 256 } 213 257 214 - e.Set("repo", repo) 215 - e.Set("did", did) 216 - e.Set("token", tokenstr) 258 + r = setContextValue(r, contextKeyRepo, repo) 259 + r = setContextValue(r, contextKeyDid, did) 260 + r = setContextValue(r, contextKeyToken, tokenstr) 217 261 218 - if err := next(e); err != nil { 219 - return helpers.InvalidTokenError(e) 220 - } 221 - 222 - return nil 223 - } 262 + next.ServeHTTP(w, r) 263 + }) 224 264 } 225 265 226 - func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 227 - return func(e echo.Context) error { 228 - ctx := e.Request().Context() 266 + func (s *Server) handleOauthSessionMiddleware(next http.Handler) http.Handler { 267 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 268 + ctx := r.Context() 229 269 logger := s.logger.With("name", "handleOauthSessionMiddleware") 230 270 231 - authheader := e.Request().Header.Get("authorization") 271 + authheader := r.Header.Get("authorization") 232 272 if authheader == "" { 233 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 273 + s.writeJSON(w, 401, map[string]string{"error": "Unauthorized"}) 274 + return 234 275 } 235 276 236 277 pts := strings.Split(authheader, " ") 237 278 if len(pts) != 2 { 238 - return helpers.ServerError(e, nil) 279 + helpers.ServerError(w, nil) 280 + return 239 281 } 240 282 241 283 if pts[0] != "DPoP" { 242 - return next(e) 284 + next.ServeHTTP(w, r) 285 + return 243 286 } 244 287 245 288 accessToken := pts[1] 246 289 247 290 nonce := s.oauthProvider.NextNonce() 248 291 if nonce != "" { 249 - e.Response().Header().Set("DPoP-Nonce", nonce) 250 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 292 + w.Header().Set("DPoP-Nonce", nonce) 293 + w.Header().Add("access-control-expose-headers", "DPoP-Nonce") 251 294 } 252 295 253 - proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 296 + proof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, "https://"+s.config.Hostname+r.URL.String(), r.Header, to.StringPtr(accessToken)) 254 297 if err != nil { 255 298 if errors.Is(err, dpop.ErrUseDpopNonce) { 256 - e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 257 - e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 258 - return e.JSON(401, map[string]string{ 299 + w.Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 300 + w.Header().Add("access-control-expose-headers", "WWW-Authenticate") 301 + s.writeJSON(w, 401, map[string]string{ 259 302 "error": "use_dpop_nonce", 260 303 }) 304 + return 261 305 } 262 306 logger.Error("invalid dpop proof", "error", err) 263 - return helpers.InputError(e, nil) 307 + helpers.InputError(w, nil) 308 + return 264 309 } 265 310 266 311 var oauthToken provider.OauthToken 267 312 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 268 313 logger.Error("error finding access token in db", "error", err) 269 - return helpers.InputError(e, nil) 314 + helpers.InputError(w, nil) 315 + return 270 316 } 271 317 272 318 if oauthToken.Token == "" { 273 - return helpers.InvalidTokenError(e) 319 + helpers.InvalidTokenError(w) 320 + return 274 321 } 275 322 276 323 if *oauthToken.Parameters.DpopJkt != proof.JKT { 277 324 logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 278 - return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 325 + helpers.InputError(w, to.StringPtr("dpop jkt mismatch")) 326 + return 279 327 } 280 328 281 329 if time.Now().After(oauthToken.ExpiresAt) { 282 - e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 283 - e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 284 - return e.JSON(401, map[string]string{ 330 + w.Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 331 + w.Header().Add("access-control-expose-headers", "WWW-Authenticate") 332 + s.writeJSON(w, 401, map[string]string{ 285 333 "error": "invalid_token", 286 334 "error_description": "Token expired", 287 335 }) 336 + return 288 337 } 289 338 290 339 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 291 340 if err != nil { 292 341 logger.Error("could not find actor in db", "error", err) 293 - return helpers.ServerError(e, nil) 342 + helpers.ServerError(w, nil) 343 + return 294 344 } 295 345 296 - e.Set("repo", repo) 297 - e.Set("did", repo.Repo.Did) 298 - e.Set("token", accessToken) 299 - e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 346 + r = setContextValue(r, contextKeyRepo, repo) 347 + r = setContextValue(r, contextKeyDid, repo.Repo.Did) 348 + r = setContextValue(r, contextKeyToken, accessToken) 349 + r = setContextValue(r, contextKeyScopes, strings.Split(oauthToken.Parameters.Scope, " ")) 300 350 301 - return next(e) 302 - } 351 + next.ServeHTTP(w, r) 352 + }) 303 353 }
+159 -118
server/server.go
··· 4 4 "context" 5 5 "crypto/ecdsa" 6 6 "embed" 7 + "encoding/json" 7 8 "errors" 8 9 "fmt" 10 + "html/template" 9 11 "io" 10 12 "log/slog" 11 13 "net/http" ··· 13 15 "os" 14 16 "path/filepath" 15 17 "sync" 16 - "text/template" 17 18 "time" 18 19 19 20 "github.com/bluesky-social/indigo/api/atproto" ··· 23 24 "github.com/bluesky-social/indigo/xrpc" 24 25 "github.com/domodwyer/mailyak/v3" 25 26 "github.com/glebarez/sqlite" 27 + "github.com/go-chi/chi/v5" 28 + chimiddleware "github.com/go-chi/chi/v5/middleware" 26 29 "github.com/go-playground/validator" 27 30 "github.com/gorilla/sessions" 28 31 "github.com/haileyok/cocoon/identity" ··· 35 38 "github.com/haileyok/cocoon/oauth/provider" 36 39 "github.com/haileyok/cocoon/plc" 37 40 "github.com/ipfs/go-cid" 38 - "github.com/labstack/echo-contrib/echoprometheus" 39 - echo_session "github.com/labstack/echo-contrib/session" 40 - "github.com/labstack/echo/v4" 41 - "github.com/labstack/echo/v4/middleware" 42 - slogecho "github.com/samber/slog-echo" 41 + "github.com/prometheus/client_golang/prometheus/promhttp" 42 + 43 43 "gorm.io/gorm" 44 44 ) 45 45 ··· 76 76 } 77 77 78 78 type Server struct { 79 - http *http.Client 80 - httpd *http.Server 81 - mail *mailyak.MailYak 82 - mailLk *sync.Mutex 83 - echo *echo.Echo 84 - db *db.DB 85 - plcClient *plc.Client 86 - logger *slog.Logger 87 - config *config 88 - privateKey *ecdsa.PrivateKey 89 - repoman *RepoMan 90 - oauthProvider *provider.Provider 91 - evtman *events.EventManager 92 - passport *identity.Passport 93 - fallbackProxy string 79 + http *http.Client 80 + httpd *http.Server 81 + mail *mailyak.MailYak 82 + mailLk *sync.Mutex 83 + router *chi.Mux 84 + db *db.DB 85 + plcClient *plc.Client 86 + logger *slog.Logger 87 + config *config 88 + privateKey *ecdsa.PrivateKey 89 + repoman *RepoMan 90 + oauthProvider *provider.Provider 91 + evtman *events.EventManager 92 + passport *identity.Passport 93 + fallbackProxy string 94 + sessions *sessions.CookieStore 95 + validator *validator.Validate 96 + templateRenderer *TemplateRenderer 94 97 95 98 lastRequestCrawl time.Time 96 99 requestCrawlMu sync.Mutex ··· 190 193 absPath, _ := filepath.Abs("server/templates/*.html") 191 194 if s.config.Version == "dev" { 192 195 tmpl := template.Must(template.ParseGlob(absPath)) 193 - s.echo.Renderer = &TemplateRenderer{ 196 + s.templateRenderer = &TemplateRenderer{ 194 197 templates: tmpl, 195 198 isDev: true, 196 199 templatePath: absPath, 197 200 } 198 201 } else { 199 202 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 200 - s.echo.Renderer = &TemplateRenderer{ 203 + s.templateRenderer = &TemplateRenderer{ 201 204 templates: tmpl, 202 205 isDev: false, 203 206 } 204 207 } 205 208 } 206 209 207 - func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 210 + func (t *TemplateRenderer) Render(w io.Writer, name string, data any) error { 208 211 if t.isDev { 209 212 tmpl, err := template.ParseGlob(t.templatePath) 210 213 if err != nil { ··· 213 216 t.templates = tmpl 214 217 } 215 218 216 - if viewContext, isMap := data.(map[string]any); isMap { 217 - viewContext["reverse"] = c.Echo().Reverse 218 - } 219 + return t.templates.ExecuteTemplate(w, name, data) 220 + } 221 + 222 + // renderTemplate is a convenience method on the server that renders a named 223 + // HTML template to the given ResponseWriter with a 200 status. 224 + func (s *Server) renderTemplate(w http.ResponseWriter, name string, data any) error { 225 + w.Header().Set("Content-Type", "text/html; charset=utf-8") 226 + return s.templateRenderer.Render(w, name, data) 227 + } 219 228 220 - return t.templates.ExecuteTemplate(w, name, data) 229 + // writeJSON writes a JSON-encoded value with the given status code. 230 + func (s *Server) writeJSON(w http.ResponseWriter, status int, v any) { 231 + w.Header().Set("Content-Type", "application/json") 232 + w.WriteHeader(status) 233 + if err := json.NewEncoder(w).Encode(v); err != nil { 234 + s.logger.Error("failed to encode JSON response", "error", err) 235 + } 221 236 } 222 237 223 238 func New(args *Args) (*Server, error) { ··· 259 274 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 260 275 } 261 276 262 - e := echo.New() 277 + r := chi.NewRouter() 263 278 264 - e.Pre(middleware.RemoveTrailingSlash()) 265 - e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 266 - e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 267 - e.Use(echoprometheus.NewMiddleware("cocoon")) 268 - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 269 - AllowOrigins: []string{"*"}, 270 - AllowHeaders: []string{"*"}, 271 - AllowMethods: []string{"*"}, 272 - AllowCredentials: true, 273 - MaxAge: 100_000_000, 274 - })) 279 + r.Use(chimiddleware.StripSlashes) 280 + r.Use(func(next http.Handler) http.Handler { 281 + logger := args.Logger.With("component", "http") 282 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 283 + next.ServeHTTP(w, r) 284 + logger.Info("request", "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr) 285 + }) 286 + }) 287 + r.Use(corsMiddleware) 275 288 276 289 vdtor := validator.New() 277 290 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { ··· 299 312 return true 300 313 }) 301 314 302 - e.Validator = &CustomValidator{validator: vdtor} 303 - 304 315 httpd := &http.Server{ 305 316 Addr: args.Addr, 306 - Handler: e, 317 + Handler: r, 307 318 // shitty defaults but okay for now, needed for import repo 308 319 ReadTimeout: 5 * time.Minute, 309 320 WriteTimeout: 5 * time.Minute, ··· 370 381 return nil, fmt.Errorf("failed to create event persister: %w", err) 371 382 } 372 383 384 + cookieStore := sessions.NewCookieStore([]byte(args.SessionSecret)) 385 + 373 386 s := &Server{ 374 387 http: h, 375 388 httpd: httpd, 376 - echo: e, 389 + router: r, 377 390 logger: args.Logger, 378 391 db: dbw, 379 392 plcClient: plcClient, 380 393 privateKey: &pkey, 394 + sessions: cookieStore, 395 + validator: vdtor, 381 396 config: &config{ 382 397 Version: args.Version, 383 398 Did: args.Did, ··· 438 453 return s, nil 439 454 } 440 455 456 + // corsMiddleware adds permissive CORS headers to every response. 457 + func corsMiddleware(next http.Handler) http.Handler { 458 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 459 + w.Header().Set("Access-Control-Allow-Origin", "*") 460 + w.Header().Set("Access-Control-Allow-Headers", "*") 461 + w.Header().Set("Access-Control-Allow-Methods", "*") 462 + w.Header().Set("Access-Control-Allow-Credentials", "true") 463 + w.Header().Set("Access-Control-Max-Age", "100000000") 464 + if r.Method == http.MethodOptions { 465 + w.WriteHeader(http.StatusNoContent) 466 + return 467 + } 468 + next.ServeHTTP(w, r) 469 + }) 470 + } 471 + 441 472 func (s *Server) addRoutes() { 473 + r := s.router 474 + 442 475 // static 443 476 if s.config.Version == "dev" { 444 - s.echo.Static("/static", "server/static") 477 + r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir("server/static")))) 445 478 } else { 446 - s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 479 + r.Handle("/static/*", http.FileServer(http.FS(staticFS))) 447 480 } 448 481 482 + // metrics 483 + r.Handle("/metrics", promhttp.Handler()) 484 + 449 485 // random stuff 450 - s.echo.GET("/", s.handleRoot) 451 - s.echo.GET("/xrpc/_health", s.handleHealth) 452 - s.echo.GET("/.well-known/did.json", s.handleWellKnown) 453 - s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 454 - s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 455 - s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 456 - s.echo.GET("/robots.txt", s.handleRobots) 486 + r.Get("/", s.handleRoot) 487 + r.Get("/xrpc/_health", s.handleHealth) 488 + r.Get("/.well-known/did.json", s.handleWellKnown) 489 + r.Get("/.well-known/atproto-did", s.handleAtprotoDid) 490 + r.Get("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 491 + r.Get("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 492 + r.Get("/robots.txt", s.handleRobots) 457 493 458 494 // public 459 - s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 460 - s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 461 - s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 462 - s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 463 - s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 495 + r.Get("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 496 + r.Post("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 497 + r.Post("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 498 + r.Get("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 499 + r.Post("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 464 500 465 - s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 466 - s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 467 - s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 468 - s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 469 - s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 470 - s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 471 - s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 472 - s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 473 - s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 474 - s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 475 - s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 476 - s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 501 + r.Get("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 502 + r.Get("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 503 + r.Get("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 504 + r.Get("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 505 + r.Get("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 506 + r.Get("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 507 + r.Get("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 508 + r.Get("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 509 + r.Get("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 510 + r.Get("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 511 + r.Get("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 512 + r.Get("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 477 513 478 514 // labels 479 - s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 515 + r.Get("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 480 516 481 517 // account 482 - s.echo.GET("/account", s.handleAccount) 483 - s.echo.POST("/account/revoke", s.handleAccountRevoke) 484 - s.echo.GET("/account/signin", s.handleAccountSigninGet) 485 - s.echo.POST("/account/signin", s.handleAccountSigninPost) 486 - s.echo.GET("/account/signout", s.handleAccountSignout) 518 + r.Get("/account", s.handleAccount) 519 + r.Post("/account/revoke", s.handleAccountRevoke) 520 + r.Get("/account/signin", s.handleAccountSigninGet) 521 + r.Post("/account/signin", s.handleAccountSigninPost) 522 + r.Get("/account/signout", s.handleAccountSignout) 487 523 488 524 // oauth account 489 - s.echo.GET("/oauth/jwks", s.handleOauthJwks) 490 - s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 491 - s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 525 + r.Get("/oauth/jwks", s.handleOauthJwks) 526 + r.Get("/oauth/authorize", s.handleOauthAuthorizeGet) 527 + r.Post("/oauth/authorize", s.handleOauthAuthorizePost) 492 528 493 - // oauth authorization 494 - s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 495 - s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 529 + // oauth authorization (with BaseMiddleware) 530 + r.With(s.oauthProvider.BaseMiddleware).Post("/oauth/par", s.handleOauthPar) 531 + r.With(s.oauthProvider.BaseMiddleware).Post("/oauth/token", s.handleOauthToken) 496 532 497 533 // authed 498 - s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 499 - s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 500 - s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 501 - s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 502 - s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 503 - s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 504 - s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 505 - s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 506 - s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 507 - s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 508 - s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 509 - s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 510 - s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 511 - s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 512 - s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 513 - s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 514 - s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 515 - s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 516 - s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 517 - s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 534 + authed := func(h http.HandlerFunc) http.Handler { 535 + return s.handleLegacySessionMiddleware(s.handleOauthSessionMiddleware(h)) 536 + } 537 + 538 + r.Get("/xrpc/com.atproto.server.getSession", authed(s.handleGetSession).ServeHTTP) 539 + r.Post("/xrpc/com.atproto.server.refreshSession", authed(s.handleRefreshSession).ServeHTTP) 540 + r.Post("/xrpc/com.atproto.server.deleteSession", authed(s.handleDeleteSession).ServeHTTP) 541 + r.Get("/xrpc/com.atproto.identity.getRecommendedDidCredentials", authed(s.handleGetRecommendedDidCredentials).ServeHTTP) 542 + r.Post("/xrpc/com.atproto.identity.updateHandle", authed(s.handleIdentityUpdateHandle).ServeHTTP) 543 + r.Post("/xrpc/com.atproto.identity.requestPlcOperationSignature", authed(s.handleIdentityRequestPlcOperationSignature).ServeHTTP) 544 + r.Post("/xrpc/com.atproto.identity.signPlcOperation", authed(s.handleSignPlcOperation).ServeHTTP) 545 + r.Post("/xrpc/com.atproto.identity.submitPlcOperation", authed(s.handleSubmitPlcOperation).ServeHTTP) 546 + r.Post("/xrpc/com.atproto.server.confirmEmail", authed(s.handleServerConfirmEmail).ServeHTTP) 547 + r.Post("/xrpc/com.atproto.server.requestEmailConfirmation", authed(s.handleServerRequestEmailConfirmation).ServeHTTP) 548 + r.Post("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 549 + r.Post("/xrpc/com.atproto.server.requestEmailUpdate", authed(s.handleServerRequestEmailUpdate).ServeHTTP) 550 + r.Post("/xrpc/com.atproto.server.resetPassword", authed(s.handleServerResetPassword).ServeHTTP) 551 + r.Post("/xrpc/com.atproto.server.updateEmail", authed(s.handleServerUpdateEmail).ServeHTTP) 552 + r.Get("/xrpc/com.atproto.server.getServiceAuth", authed(s.handleServerGetServiceAuth).ServeHTTP) 553 + r.Get("/xrpc/com.atproto.server.checkAccountStatus", authed(s.handleServerCheckAccountStatus).ServeHTTP) 554 + r.Post("/xrpc/com.atproto.server.deactivateAccount", authed(s.handleServerDeactivateAccount).ServeHTTP) 555 + r.Post("/xrpc/com.atproto.server.activateAccount", authed(s.handleServerActivateAccount).ServeHTTP) 556 + r.Post("/xrpc/com.atproto.server.requestAccountDelete", authed(s.handleServerRequestAccountDelete).ServeHTTP) 557 + r.Post("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 518 558 519 559 // repo 520 - s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 521 - s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 522 - s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 523 - s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 524 - s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 525 - s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 526 - s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 560 + r.Get("/xrpc/com.atproto.repo.listMissingBlobs", authed(s.handleListMissingBlobs).ServeHTTP) 561 + r.Post("/xrpc/com.atproto.repo.createRecord", authed(s.handleCreateRecord).ServeHTTP) 562 + r.Post("/xrpc/com.atproto.repo.putRecord", authed(s.handlePutRecord).ServeHTTP) 563 + r.Post("/xrpc/com.atproto.repo.deleteRecord", authed(s.handleDeleteRecord).ServeHTTP) 564 + r.Post("/xrpc/com.atproto.repo.applyWrites", authed(s.handleApplyWrites).ServeHTTP) 565 + r.Post("/xrpc/com.atproto.repo.uploadBlob", authed(s.handleRepoUploadBlob).ServeHTTP) 566 + r.Post("/xrpc/com.atproto.repo.importRepo", authed(s.handleRepoImportRepo).ServeHTTP) 527 567 528 568 // stupid silly endpoints 529 - s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 530 - s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 531 - s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 532 - s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 + r.Get("/xrpc/app.bsky.actor.getPreferences", authed(s.handleActorGetPreferences).ServeHTTP) 570 + r.Post("/xrpc/app.bsky.actor.putPreferences", authed(s.handleActorPutPreferences).ServeHTTP) 571 + r.Get("/xrpc/app.bsky.feed.getFeed", authed(s.handleProxyBskyFeedGetFeed).ServeHTTP) 572 + r.Get("/xrpc/app.bsky.ageassurance.getState", authed(s.handleAgeAssurance).ServeHTTP) 573 + 533 574 // admin routes 534 - s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 535 - s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 575 + r.With(s.handleAdminMiddleware).Post("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode) 576 + r.With(s.handleAdminMiddleware).Post("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes) 536 577 537 - // are there any routes that we should be allowing without auth? i dont think so but idk 538 - s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 539 - s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 578 + // catch-all proxy (authed) 579 + r.Get("/xrpc/*", authed(s.handleProxy).ServeHTTP) 580 + r.Post("/xrpc/*", authed(s.handleProxy).ServeHTTP) 540 581 } 541 582 542 583 func (s *Server) Serve(ctx context.Context) error {