tangled
alpha
login
or
join now
julien.rbrt.fr
/
vow
forked from
hailey.at/cocoon
0
fork
atom
Vow, uncensorable PDS written in Go
0
fork
atom
overview
issues
pulls
pipelines
refactor: use chi
julien.rbrt.fr
3 days ago
704386bc
5ce7fa86
+1732
-1159
70 changed files
expand all
collapse all
unified
split
go.mod
go.sum
internal
helpers
helpers.go
oauth
provider
middleware.go
server
handle_account.go
handle_account_revoke.go
handle_account_signin.go
handle_account_signout.go
handle_actor_get_preferences.go
handle_actor_put_preferences.go
handle_age_assurance.go
handle_health.go
handle_identity_get_recommended_did_credentials.go
handle_identity_request_plc_operation.go
handle_identity_sign_plc_operation.go
handle_identity_submit_plc_operation.go
handle_identity_update_handle.go
handle_import_repo.go
handle_label_query_labels.go
handle_oauth_authorize.go
handle_oauth_jwks.go
handle_oauth_par.go
handle_oauth_token.go
handle_proxy.go
handle_proxy_get_feed.go
handle_repo_apply_writes.go
handle_repo_create_record.go
handle_repo_delete_record.go
handle_repo_describe_repo.go
handle_repo_get_record.go
handle_repo_list_missing_blobs.go
handle_repo_list_records.go
handle_repo_list_repos.go
handle_repo_put_record.go
handle_repo_upload_blob.go
handle_robots.go
handle_root.go
handle_server_activate_account.go
handle_server_check_account_status.go
handle_server_confirm_email.go
handle_server_create_account.go
handle_server_create_invite_code.go
handle_server_create_invite_codes.go
handle_server_create_session.go
handle_server_deactivate_account.go
handle_server_delete_account.go
handle_server_delete_session.go
handle_server_describe_server.go
handle_server_get_service_auth.go
handle_server_get_session.go
handle_server_refresh_session.go
handle_server_request_account_delete.go
handle_server_request_email_confirmation.go
handle_server_request_email_update.go
handle_server_request_password_reset.go
handle_server_reserve_signing_key.go
handle_server_reset_password.go
handle_server_resolve_handle.go
handle_server_update_email.go
handle_sync_get_blob.go
handle_sync_get_blocks.go
handle_sync_get_latest_commit.go
handle_sync_get_record.go
handle_sync_get_repo.go
handle_sync_get_repo_status.go
handle_sync_list_blobs.go
handle_sync_subscribe_repos.go
handle_well_known.go
middleware.go
server.go
+1
-8
go.mod
···
8
8
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792
9
9
github.com/domodwyer/mailyak/v3 v3.6.2
10
10
github.com/glebarez/sqlite v1.11.0
11
11
+
github.com/go-chi/chi/v5 v5.2.5
11
12
github.com/go-pkgz/expirable-cache/v3 v3.0.0
12
13
github.com/go-playground/validator v9.31.0+incompatible
13
14
github.com/golang-jwt/jwt/v4 v4.5.2
···
22
23
github.com/ipfs/go-ipld-cbor v0.1.0
23
24
github.com/ipld/go-car v0.6.1-0.20230509095817-92d28eb23ba4
24
25
github.com/joho/godotenv v1.5.1
25
25
-
github.com/labstack/echo-contrib v0.17.4
26
26
-
github.com/labstack/echo/v4 v4.13.3
27
26
github.com/lestrrat-go/jwx/v2 v2.0.21
28
27
github.com/multiformats/go-multihash v0.2.3
29
28
github.com/prometheus/client_golang v1.23.2
30
30
-
github.com/samber/slog-echo v1.16.1
31
29
github.com/spf13/cobra v1.10.2
32
30
github.com/spf13/viper v1.21.0
33
31
github.com/whyrusleeping/cbor-gen v0.2.1-0.20241030202151-b7a6831be65e
···
56
54
github.com/gocql/gocql v1.7.0 // indirect
57
55
github.com/gogo/protobuf v1.3.2 // indirect
58
56
github.com/golang/snappy v0.0.4 // indirect
59
59
-
github.com/gorilla/context v1.1.2 // indirect
60
57
github.com/gorilla/securecookie v1.1.2 // indirect
61
58
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
62
59
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
···
87
84
github.com/jinzhu/inflection v1.0.0 // indirect
88
85
github.com/jinzhu/now v1.1.5 // indirect
89
86
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
90
90
-
github.com/labstack/gommon v0.4.2 // indirect
91
87
github.com/leodido/go-urn v1.4.0 // indirect
92
88
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
93
89
github.com/lestrrat-go/httpcc v1.0.1 // indirect
···
112
108
github.com/prometheus/procfs v0.16.1 // indirect
113
109
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
114
110
github.com/sagikazarmark/locafero v0.11.0 // indirect
115
115
-
github.com/samber/lo v1.49.1 // indirect
116
111
github.com/segmentio/asm v1.2.0 // indirect
117
112
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
118
113
github.com/spaolacci/murmur3 v1.1.0 // indirect
···
120
115
github.com/spf13/cast v1.10.0 // indirect
121
116
github.com/spf13/pflag v1.0.10 // indirect
122
117
github.com/subosito/gotenv v1.6.0 // indirect
123
123
-
github.com/valyala/bytebufferpool v1.0.0 // indirect
124
124
-
github.com/valyala/fasttemplate v1.2.2 // indirect
125
118
gitlab.com/yawning/tuplehash v0.0.0-20230713102510-df83abbf9a02 // indirect
126
119
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
127
120
go.opentelemetry.io/otel v1.29.0 // indirect
+2
-16
go.sum
···
49
49
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
50
50
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
51
51
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
52
52
+
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
53
53
+
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
52
54
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
53
55
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
54
56
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
···
91
93
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
92
94
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
93
95
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
94
94
-
github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o=
95
95
-
github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM=
96
96
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
97
97
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
98
98
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
···
218
218
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
219
219
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
220
220
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
221
221
-
github.com/labstack/echo-contrib v0.17.4 h1:g5mfsrJfJTKv+F5uNKCyrjLK7js+ZW6HTjg4FnDxxgk=
222
222
-
github.com/labstack/echo-contrib v0.17.4/go.mod h1:9O7ZPAHUeMGTOAfg80YqQduHzt0CzLak36PZRldYrZ0=
223
223
-
github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY=
224
224
-
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
225
225
-
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
226
226
-
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
227
221
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
228
222
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
229
223
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
···
320
314
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
321
315
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
322
316
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
323
323
-
github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew=
324
324
-
github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o=
325
325
-
github.com/samber/slog-echo v1.16.1 h1:5Q5IUROkFqKcu/qJM/13AP1d3gd1RS+Q/4EvKQU1fuo=
326
326
-
github.com/samber/slog-echo v1.16.1/go.mod h1:f+B3WR06saRXcaGRZ/I/UPCECDPqTUqadRIf7TmyRhI=
327
317
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
328
318
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
329
319
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
···
357
347
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
358
348
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
359
349
github.com/urfave/cli v1.22.10/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
360
360
-
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
361
361
-
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
362
362
-
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
363
363
-
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
364
350
github.com/warpfork/go-testmark v0.12.1 h1:rMgCpJfwy1sJ50x0M0NgyphxYYPMOODIJHhsXyEHU0s=
365
351
github.com/warpfork/go-testmark v0.12.1/go.mod h1:kHwy7wfvGSPh1rQJYKayD4AbtNaeyZdcGi9tNJTaa5Y=
366
352
github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0 h1:GDDkbFiaK8jsSDJfjId/PEGEShv6ugrt4kYsC5UIDaQ=
+23
-16
internal/helpers/helpers.go
···
3
3
import (
4
4
crand "crypto/rand"
5
5
"encoding/hex"
6
6
+
"encoding/json"
6
7
"errors"
7
8
"math/rand"
9
9
+
"net/http"
8
10
"net/url"
9
11
10
10
-
"github.com/Azure/go-autorest/autorest/to"
11
11
-
"github.com/labstack/echo/v4"
12
12
"github.com/lestrrat-go/jwx/v2/jwk"
13
13
)
14
14
···
16
16
// /^[A-Z2-7]{5}-[A-Z2-7]{5}$/
17
17
var letters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
18
18
19
19
-
func InputError(e echo.Context, custom *string) error {
19
19
+
func writeJSON(w http.ResponseWriter, status int, v any) error {
20
20
+
w.Header().Set("Content-Type", "application/json")
21
21
+
w.WriteHeader(status)
22
22
+
return json.NewEncoder(w).Encode(v)
23
23
+
}
24
24
+
25
25
+
func InputError(w http.ResponseWriter, custom *string) error {
20
26
msg := "InvalidRequest"
21
27
if custom != nil {
22
28
msg = *custom
23
29
}
24
24
-
return genericError(e, 400, msg)
30
30
+
return genericError(w, http.StatusBadRequest, msg)
25
31
}
26
32
27
27
-
func ServerError(e echo.Context, suffix *string) error {
33
33
+
func ServerError(w http.ResponseWriter, suffix *string) error {
28
34
msg := "Internal server error"
29
35
if suffix != nil {
30
36
msg += ". " + *suffix
31
37
}
32
32
-
return genericError(e, 500, msg)
38
38
+
return genericError(w, http.StatusInternalServerError, msg)
33
39
}
34
40
35
35
-
func UnauthorizedError(e echo.Context, suffix *string) error {
41
41
+
func UnauthorizedError(w http.ResponseWriter, suffix *string) error {
36
42
msg := "Unauthorized"
37
43
if suffix != nil {
38
44
msg += ". " + *suffix
39
45
}
40
40
-
return genericError(e, 401, msg)
46
46
+
return genericError(w, http.StatusUnauthorized, msg)
41
47
}
42
48
43
43
-
func ForbiddenError(e echo.Context, suffix *string) error {
49
49
+
func ForbiddenError(w http.ResponseWriter, suffix *string) error {
44
50
msg := "Forbidden"
45
51
if suffix != nil {
46
52
msg += ". " + *suffix
47
53
}
48
48
-
return genericError(e, 403, msg)
54
54
+
return genericError(w, http.StatusForbidden, msg)
49
55
}
50
56
51
51
-
func InvalidTokenError(e echo.Context) error {
52
52
-
return InputError(e, to.StringPtr("InvalidToken"))
57
57
+
func InvalidTokenError(w http.ResponseWriter) error {
58
58
+
s := "InvalidToken"
59
59
+
return InputError(w, &s)
53
60
}
54
61
55
55
-
func ExpiredTokenError(e echo.Context) error {
62
62
+
func ExpiredTokenError(w http.ResponseWriter) error {
56
63
// WARN: See https://github.com/bluesky-social/atproto/discussions/3319
57
57
-
return e.JSON(400, map[string]string{
64
64
+
return writeJSON(w, http.StatusBadRequest, map[string]string{
58
65
"error": "ExpiredToken",
59
66
"message": "*",
60
67
})
61
68
}
62
69
63
63
-
func genericError(e echo.Context, code int, msg string) error {
64
64
-
return e.JSON(code, map[string]string{
70
70
+
func genericError(w http.ResponseWriter, code int, msg string) error {
71
71
+
return writeJSON(w, code, map[string]string{
65
72
"error": msg,
66
73
})
67
74
}
+9
-9
oauth/provider/middleware.go
···
1
1
package provider
2
2
3
3
import (
4
4
-
"github.com/labstack/echo/v4"
4
4
+
"net/http"
5
5
)
6
6
7
7
-
func (p *Provider) BaseMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
8
8
-
return func(e echo.Context) error {
9
9
-
e.Response().Header().Set("cache-control", "no-store")
10
10
-
e.Response().Header().Set("pragma", "no-cache")
7
7
+
func (p *Provider) BaseMiddleware(next http.Handler) http.Handler {
8
8
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
9
9
+
w.Header().Set("cache-control", "no-store")
10
10
+
w.Header().Set("pragma", "no-cache")
11
11
12
12
nonce := p.NextNonce()
13
13
if nonce != "" {
14
14
-
e.Response().Header().Set("DPoP-Nonce", nonce)
15
15
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
14
14
+
w.Header().Set("DPoP-Nonce", nonce)
15
15
+
w.Header().Add("access-control-expose-headers", "DPoP-Nonce")
16
16
}
17
17
18
18
-
return next(e)
19
19
-
}
18
18
+
next.ServeHTTP(w, r)
19
19
+
})
20
20
}
+12
-10
server/handle_account.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
4
5
"time"
5
6
6
7
"github.com/haileyok/cocoon/oauth"
7
8
"github.com/haileyok/cocoon/oauth/constants"
8
9
"github.com/haileyok/cocoon/oauth/provider"
9
10
"github.com/hako/durafmt"
10
10
-
"github.com/labstack/echo/v4"
11
11
)
12
12
13
13
-
func (s *Server) handleAccount(e echo.Context) error {
14
14
-
ctx := e.Request().Context()
13
13
+
func (s *Server) handleAccount(w http.ResponseWriter, r *http.Request) {
14
14
+
ctx := r.Context()
15
15
logger := s.logger.With("name", "handleAuth")
16
16
17
17
-
repo, sess, err := s.getSessionRepoOrErr(e)
17
17
+
repo, sess, err := s.getSessionRepoOrErr(r)
18
18
if err != nil {
19
19
-
return e.Redirect(303, "/account/signin")
19
19
+
http.Redirect(w, r, "/account/signin", 303)
20
20
+
return
20
21
}
21
22
22
23
oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
···
25
26
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
26
27
logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
27
28
sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
28
28
-
sess.Save(e.Request(), e.Response())
29
29
-
return e.Render(200, "account.html", map[string]any{
30
30
-
"flashes": getFlashesFromSession(e, sess),
29
29
+
sess.Save(r, w)
30
30
+
s.renderTemplate(w, "account.html", map[string]any{
31
31
+
"flashes": getFlashesFromSession(w, r, sess),
31
32
})
33
33
+
return
32
34
}
33
35
34
36
var filtered []provider.OauthToken
···
68
70
})
69
71
}
70
72
71
71
-
return e.Render(200, "account.html", map[string]any{
73
73
+
s.renderTemplate(w, "account.html", map[string]any{
72
74
"Repo": repo,
73
75
"Tokens": tokenInfo,
74
74
-
"flashes": getFlashesFromSession(e, sess),
76
76
+
"flashes": getFlashesFromSession(w, r, sess),
75
77
})
76
78
}
+21
-14
server/handle_account_revoke.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
5
-
"github.com/labstack/echo/v4"
6
7
)
7
8
8
9
type AccountRevokeInput struct {
9
10
Token string `form:"token"`
10
11
}
11
12
12
12
-
func (s *Server) handleAccountRevoke(e echo.Context) error {
13
13
-
ctx := e.Request().Context()
14
14
-
logger := s.logger.With("name", "handleAcocuntRevoke")
13
13
+
func (s *Server) handleAccountRevoke(w http.ResponseWriter, r *http.Request) {
14
14
+
ctx := r.Context()
15
15
+
logger := s.logger.With("name", "handleAccountRevoke")
15
16
16
16
-
var req AccountRevokeInput
17
17
-
if err := e.Bind(&req); err != nil {
18
18
-
logger.Error("could not bind account revoke request", "error", err)
19
19
-
return helpers.ServerError(e, nil)
17
17
+
if err := r.ParseForm(); err != nil {
18
18
+
logger.Error("could not parse account revoke form", "error", err)
19
19
+
helpers.ServerError(w, nil)
20
20
+
return
21
21
+
}
22
22
+
23
23
+
req := AccountRevokeInput{
24
24
+
Token: r.FormValue("token"),
20
25
}
21
26
22
22
-
repo, sess, err := s.getSessionRepoOrErr(e)
27
27
+
repo, sess, err := s.getSessionRepoOrErr(r)
23
28
if err != nil {
24
24
-
return e.Redirect(303, "/account/signin")
29
29
+
http.Redirect(w, r, "/account/signin", 303)
30
30
+
return
25
31
}
26
32
27
33
if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
28
34
logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err)
29
35
sess.AddFlash("Unable to revoke session. See server logs for more details.", "error")
30
30
-
sess.Save(e.Request(), e.Response())
31
31
-
return e.Redirect(303, "/account")
36
36
+
sess.Save(r, w)
37
37
+
http.Redirect(w, r, "/account", 303)
38
38
+
return
32
39
}
33
40
34
41
sess.AddFlash("Session successfully revoked!", "success")
35
35
-
sess.Save(e.Request(), e.Response())
36
36
-
return e.Redirect(303, "/account")
42
42
+
sess.Save(r, w)
43
43
+
http.Redirect(w, r, "/account", 303)
37
44
}
+55
-39
server/handle_account_signin.go
···
3
3
import (
4
4
"errors"
5
5
"fmt"
6
6
+
"net/http"
6
7
"strings"
7
8
"time"
8
9
···
10
11
"github.com/gorilla/sessions"
11
12
"github.com/haileyok/cocoon/internal/helpers"
12
13
"github.com/haileyok/cocoon/models"
13
13
-
"github.com/labstack/echo-contrib/session"
14
14
-
"github.com/labstack/echo/v4"
15
14
"golang.org/x/crypto/bcrypt"
16
15
"gorm.io/gorm"
17
16
)
···
23
22
QueryParams string `form:"query_params"`
24
23
}
25
24
26
26
-
func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
27
27
-
ctx := e.Request().Context()
25
25
+
func (s *Server) getSessionRepoOrErr(r *http.Request) (*models.RepoActor, *sessions.Session, error) {
26
26
+
ctx := r.Context()
28
27
29
29
-
sess, err := session.Get(s.config.SessionCookieKey, e)
28
28
+
sess, err := s.sessions.Get(r, s.config.SessionCookieKey)
30
29
if err != nil {
31
30
return nil, nil, err
32
31
}
···
44
43
return repo, sess, nil
45
44
}
46
45
47
47
-
func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
48
48
-
defer sess.Save(e.Request(), e.Response())
46
46
+
func getFlashesFromSession(w http.ResponseWriter, r *http.Request, sess *sessions.Session) map[string]any {
47
47
+
defer sess.Save(r, w)
49
48
return map[string]any{
50
49
"errors": sess.Flashes("error"),
51
50
"successes": sess.Flashes("success"),
···
53
52
}
54
53
}
55
54
56
56
-
func (s *Server) handleAccountSigninGet(e echo.Context) error {
57
57
-
_, sess, err := s.getSessionRepoOrErr(e)
55
55
+
func (s *Server) handleAccountSigninGet(w http.ResponseWriter, r *http.Request) {
56
56
+
_, sess, err := s.getSessionRepoOrErr(r)
58
57
if err == nil {
59
59
-
return e.Redirect(303, "/account")
58
58
+
http.Redirect(w, r, "/account", 303)
59
59
+
return
60
60
}
61
61
62
62
-
return e.Render(200, "signin.html", map[string]any{
63
63
-
"flashes": getFlashesFromSession(e, sess),
64
64
-
"QueryParams": e.QueryParams().Encode(),
62
62
+
s.renderTemplate(w, "signin.html", map[string]any{
63
63
+
"flashes": getFlashesFromSession(w, r, sess),
64
64
+
"QueryParams": r.URL.Query().Encode(),
65
65
})
66
66
}
67
67
68
68
-
func (s *Server) handleAccountSigninPost(e echo.Context) error {
69
69
-
ctx := e.Request().Context()
68
68
+
func (s *Server) handleAccountSigninPost(w http.ResponseWriter, r *http.Request) {
69
69
+
ctx := r.Context()
70
70
logger := s.logger.With("name", "handleAccountSigninPost")
71
71
72
72
-
var req OauthSigninInput
73
73
-
if err := e.Bind(&req); err != nil {
74
74
-
logger.Error("error binding sign in req", "error", err)
75
75
-
return helpers.ServerError(e, nil)
72
72
+
if err := r.ParseForm(); err != nil {
73
73
+
logger.Error("error parsing sign in form", "error", err)
74
74
+
helpers.ServerError(w, nil)
75
75
+
return
76
76
+
}
77
77
+
78
78
+
req := OauthSigninInput{
79
79
+
Username: r.FormValue("username"),
80
80
+
Password: r.FormValue("password"),
81
81
+
AuthFactorToken: r.FormValue("token"),
82
82
+
QueryParams: r.FormValue("query_params"),
76
83
}
77
84
78
78
-
sess, _ := session.Get(s.config.SessionCookieKey, e)
85
85
+
sess, _ := s.sessions.Get(r, s.config.SessionCookieKey)
79
86
80
87
req.Username = strings.ToLower(req.Username)
81
88
var idtype string
···
109
116
} else {
110
117
sess.AddFlash("Something went wrong!", "error")
111
118
}
112
112
-
sess.Save(e.Request(), e.Response())
113
113
-
return e.Redirect(303, "/account/signin"+queryParams)
119
119
+
sess.Save(r, w)
120
120
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
121
121
+
return
114
122
}
115
123
116
124
if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
···
119
127
} else {
120
128
sess.AddFlash("Something went wrong!", "error")
121
129
}
122
122
-
sess.Save(e.Request(), e.Response())
123
123
-
return e.Redirect(303, "/account/signin"+queryParams)
130
130
+
sess.Save(r, w)
131
131
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
132
132
+
return
124
133
}
125
134
126
135
// if repo requires 2FA token and one hasn't been provided, return error prompting for one
···
128
137
err = s.createAndSendTwoFactorCode(ctx, repo)
129
138
if err != nil {
130
139
sess.AddFlash("Something went wrong!", "error")
131
131
-
sess.Save(e.Request(), e.Response())
132
132
-
return e.Redirect(303, "/account/signin"+queryParams)
140
140
+
sess.Save(r, w)
141
141
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
142
142
+
return
133
143
}
134
144
135
145
sess.AddFlash("requires 2FA token", "tokenrequired")
136
136
-
sess.Save(e.Request(), e.Response())
137
137
-
return e.Redirect(303, "/account/signin"+queryParams)
146
146
+
sess.Save(r, w)
147
147
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
148
148
+
return
138
149
}
139
150
140
140
-
// if 2FAis required, now check that the one provided is valid
151
151
+
// if 2FA is required, now check that the one provided is valid
141
152
if repo.TwoFactorType != models.TwoFactorTypeNone {
142
153
if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil {
143
154
err = s.createAndSendTwoFactorCode(ctx, repo)
144
155
if err != nil {
145
156
sess.AddFlash("Something went wrong!", "error")
146
146
-
sess.Save(e.Request(), e.Response())
147
147
-
return e.Redirect(303, "/account/signin"+queryParams)
157
157
+
sess.Save(r, w)
158
158
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
159
159
+
return
148
160
}
149
161
150
162
sess.AddFlash("requires 2FA token", "tokenrequired")
151
151
-
sess.Save(e.Request(), e.Response())
152
152
-
return e.Redirect(303, "/account/signin"+queryParams)
163
163
+
sess.Save(r, w)
164
164
+
http.Redirect(w, r, "/account/signin"+queryParams, 303)
165
165
+
return
153
166
}
154
167
155
168
if *repo.TwoFactorCode != req.AuthFactorToken {
156
156
-
return helpers.InvalidTokenError(e)
169
169
+
helpers.InvalidTokenError(w)
170
170
+
return
157
171
}
158
172
159
173
if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
160
160
-
return helpers.ExpiredTokenError(e)
174
174
+
helpers.ExpiredTokenError(w)
175
175
+
return
161
176
}
162
177
}
163
178
···
170
185
sess.Values = map[any]any{}
171
186
sess.Values["did"] = repo.Repo.Did
172
187
173
173
-
if err := sess.Save(e.Request(), e.Response()); err != nil {
174
174
-
return err
188
188
+
if err := sess.Save(r, w); err != nil {
189
189
+
helpers.ServerError(w, nil)
190
190
+
return
175
191
}
176
192
177
193
if queryParams != "" {
178
178
-
return e.Redirect(303, "/oauth/authorize"+queryParams)
194
194
+
http.Redirect(w, r, "/oauth/authorize"+queryParams, 303)
179
195
} else {
180
180
-
return e.Redirect(303, "/account")
196
196
+
http.Redirect(w, r, "/account", 303)
181
197
}
182
198
}
+12
-10
server/handle_account_signout.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/gorilla/sessions"
5
5
-
"github.com/labstack/echo-contrib/session"
6
6
-
"github.com/labstack/echo/v4"
7
7
)
8
8
9
9
-
func (s *Server) handleAccountSignout(e echo.Context) error {
10
10
-
sess, err := session.Get(s.config.SessionCookieKey, e)
9
9
+
func (s *Server) handleAccountSignout(w http.ResponseWriter, r *http.Request) {
10
10
+
sess, err := s.sessions.Get(r, s.config.SessionCookieKey)
11
11
if err != nil {
12
12
-
return err
12
12
+
http.Error(w, "session error", http.StatusInternalServerError)
13
13
+
return
13
14
}
14
15
15
16
sess.Options = &sessions.Options{
···
20
21
21
22
sess.Values = map[any]any{}
22
23
23
23
-
if err := sess.Save(e.Request(), e.Response()); err != nil {
24
24
-
return err
24
24
+
if err := sess.Save(r, w); err != nil {
25
25
+
http.Error(w, "session save error", http.StatusInternalServerError)
26
26
+
return
25
27
}
26
28
27
27
-
reqUri := e.QueryParam("request_uri")
29
29
+
reqUri := r.URL.Query().Get("request_uri")
28
30
29
31
redirect := "/account/signin"
30
32
if reqUri != "" {
31
31
-
redirect += "?" + e.QueryParams().Encode()
33
33
+
redirect += "?" + r.URL.Query().Encode()
32
34
}
33
35
34
34
-
return e.Redirect(303, redirect)
36
36
+
http.Redirect(w, r, redirect, 303)
35
37
}
+4
-4
server/handle_actor_get_preferences.go
···
2
2
3
3
import (
4
4
"encoding/json"
5
5
+
"net/http"
5
6
6
7
"github.com/haileyok/cocoon/models"
7
7
-
"github.com/labstack/echo/v4"
8
8
)
9
9
10
10
// This is kinda lame. Not great to implement app.bsky in the pds, but alas
11
11
12
12
-
func (s *Server) handleActorGetPreferences(e echo.Context) error {
13
13
-
repo := e.Get("repo").(*models.RepoActor)
12
12
+
func (s *Server) handleActorGetPreferences(w http.ResponseWriter, r *http.Request) {
13
13
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
14
14
15
15
var prefs map[string]any
16
16
err := json.Unmarshal(repo.Preferences, &prefs)
···
20
20
}
21
21
}
22
22
23
23
-
return e.JSON(200, prefs)
23
23
+
s.writeJSON(w, 200, prefs)
24
24
}
+12
-9
server/handle_actor_put_preferences.go
···
2
2
3
3
import (
4
4
"encoding/json"
5
5
+
"net/http"
5
6
6
7
"github.com/haileyok/cocoon/models"
7
7
-
"github.com/labstack/echo/v4"
8
8
)
9
9
10
10
// This is kinda lame. Not great to implement app.bsky in the pds, but alas
11
11
12
12
-
func (s *Server) handleActorPutPreferences(e echo.Context) error {
13
13
-
ctx := e.Request().Context()
12
12
+
func (s *Server) handleActorPutPreferences(w http.ResponseWriter, r *http.Request) {
13
13
+
ctx := r.Context()
14
14
15
15
-
repo := e.Get("repo").(*models.RepoActor)
15
15
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
16
16
17
17
var prefs map[string]any
18
18
-
if err := json.NewDecoder(e.Request().Body).Decode(&prefs); err != nil {
19
19
-
return err
18
18
+
if err := json.NewDecoder(r.Body).Decode(&prefs); err != nil {
19
19
+
http.Error(w, err.Error(), http.StatusBadRequest)
20
20
+
return
20
21
}
21
22
22
23
b, err := json.Marshal(prefs)
23
24
if err != nil {
24
24
-
return err
25
25
+
http.Error(w, err.Error(), http.StatusInternalServerError)
26
26
+
return
25
27
}
26
28
27
29
if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
28
28
-
return err
30
30
+
http.Error(w, err.Error(), http.StatusInternalServerError)
31
31
+
return
29
32
}
30
33
31
31
-
return nil
34
34
+
w.WriteHeader(http.StatusOK)
32
35
}
+4
-4
server/handle_age_assurance.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
4
5
"time"
5
6
6
7
"github.com/bluesky-social/indigo/util"
7
8
"github.com/haileyok/cocoon/models"
8
8
-
"github.com/labstack/echo/v4"
9
9
)
10
10
11
11
-
func (s *Server) handleAgeAssurance(e echo.Context) error {
12
12
-
repo := e.Get("repo").(*models.RepoActor)
11
11
+
func (s *Server) handleAgeAssurance(w http.ResponseWriter, r *http.Request) {
12
12
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
13
13
14
14
resp := map[string]any{
15
15
"state": map[string]any{
···
22
22
},
23
23
}
24
24
25
25
-
return e.JSON(200, resp)
25
25
+
s.writeJSON(w, 200, resp)
26
26
}
+3
-3
server/handle_health.go
···
1
1
package server
2
2
3
3
-
import "github.com/labstack/echo/v4"
3
3
+
import "net/http"
4
4
5
5
-
func (s *Server) handleHealth(e echo.Context) error {
6
6
-
return e.JSON(200, map[string]string{
5
5
+
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
6
6
+
s.writeJSON(w, 200, map[string]string{
7
7
"version": "cocoon " + s.config.Version,
8
8
})
9
9
}
+10
-7
server/handle_identity_get_recommended_did_credentials.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/bluesky-social/indigo/atproto/atcrypto"
5
7
"github.com/haileyok/cocoon/internal/helpers"
6
8
"github.com/haileyok/cocoon/models"
7
7
-
"github.com/labstack/echo/v4"
8
9
)
9
10
10
10
-
func (s *Server) handleGetRecommendedDidCredentials(e echo.Context) error {
11
11
+
func (s *Server) handleGetRecommendedDidCredentials(w http.ResponseWriter, r *http.Request) {
11
12
logger := s.logger.With("name", "handleIdentityGetRecommendedDidCredentials")
12
13
13
13
-
repo := e.Get("repo").(*models.RepoActor)
14
14
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
14
15
k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey)
15
16
if err != nil {
16
17
logger.Error("error parsing key", "error", err)
17
17
-
return helpers.ServerError(e, nil)
18
18
+
helpers.ServerError(w, nil)
19
19
+
return
18
20
}
19
21
creds, err := s.plcClient.CreateDidCredentials(k, "", repo.Actor.Handle)
20
22
if err != nil {
21
21
-
logger.Error("error crating did credentials", "error", err)
22
22
-
return helpers.ServerError(e, nil)
23
23
+
logger.Error("error creating did credentials", "error", err)
24
24
+
helpers.ServerError(w, nil)
25
25
+
return
23
26
}
24
27
25
25
-
return e.JSON(200, creds)
28
28
+
s.writeJSON(w, 200, creds)
26
29
}
+9
-7
server/handle_identity_request_plc_operation.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/haileyok/cocoon/internal/helpers"
8
9
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
10
)
11
11
12
12
-
func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error {
13
13
-
ctx := e.Request().Context()
12
12
+
func (s *Server) handleIdentityRequestPlcOperationSignature(w http.ResponseWriter, r *http.Request) {
13
13
+
ctx := r.Context()
14
14
logger := s.logger.With("name", "handleIdentityRequestPlcOperationSignature")
15
15
16
16
-
urepo := e.Get("repo").(*models.RepoActor)
16
16
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
17
17
18
18
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
19
19
eat := time.Now().Add(10 * time.Minute).UTC()
20
20
21
21
if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
22
22
logger.Error("error updating user", "error", err)
23
23
-
return helpers.ServerError(e, nil)
23
23
+
helpers.ServerError(w, nil)
24
24
+
return
24
25
}
25
26
26
27
if err := s.sendPlcTokenReset(urepo.Email, urepo.Handle, code); err != nil {
27
28
logger.Error("error sending mail", "error", err)
28
28
-
return helpers.ServerError(e, nil)
29
29
+
helpers.ServerError(w, nil)
30
30
+
return
29
31
}
30
32
31
31
-
return e.NoContent(200)
33
33
+
w.WriteHeader(http.StatusOK)
32
34
}
+26
-16
server/handle_identity_sign_plc_operation.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
6
6
+
"net/http"
5
7
"strings"
6
8
"time"
7
9
···
11
13
"github.com/haileyok/cocoon/internal/helpers"
12
14
"github.com/haileyok/cocoon/models"
13
15
"github.com/haileyok/cocoon/plc"
14
14
-
"github.com/labstack/echo/v4"
15
16
)
16
17
17
18
type ComAtprotoSignPlcOperationRequest struct {
···
26
27
Operation plc.Operation `json:"operation"`
27
28
}
28
29
29
29
-
func (s *Server) handleSignPlcOperation(e echo.Context) error {
30
30
+
func (s *Server) handleSignPlcOperation(w http.ResponseWriter, r *http.Request) {
30
31
logger := s.logger.With("name", "handleSignPlcOperation")
31
32
32
32
-
repo := e.Get("repo").(*models.RepoActor)
33
33
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
33
34
34
35
var req ComAtprotoSignPlcOperationRequest
35
35
-
if err := e.Bind(&req); err != nil {
36
36
-
logger.Error("error binding", "error", err)
37
37
-
return helpers.ServerError(e, nil)
36
36
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
37
37
+
logger.Error("error decoding", "error", err)
38
38
+
helpers.ServerError(w, nil)
39
39
+
return
38
40
}
39
41
40
42
if !strings.HasPrefix(repo.Repo.Did, "did:plc:") {
41
41
-
return helpers.InputError(e, nil)
43
43
+
helpers.InputError(w, nil)
44
44
+
return
42
45
}
43
46
44
47
if repo.PlcOperationCode == nil || repo.PlcOperationCodeExpiresAt == nil {
45
45
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
48
48
+
helpers.InputError(w, to.StringPtr("InvalidToken"))
49
49
+
return
46
50
}
47
51
48
52
if *repo.PlcOperationCode != req.Token {
49
49
-
return helpers.InvalidTokenError(e)
53
53
+
helpers.InvalidTokenError(w)
54
54
+
return
50
55
}
51
56
52
57
if time.Now().UTC().After(*repo.PlcOperationCodeExpiresAt) {
53
53
-
return helpers.ExpiredTokenError(e)
58
58
+
helpers.ExpiredTokenError(w)
59
59
+
return
54
60
}
55
61
56
56
-
ctx := context.WithValue(e.Request().Context(), "skip-cache", true)
62
62
+
ctx := context.WithValue(r.Context(), "skip-cache", true)
57
63
log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did)
58
64
if err != nil {
59
65
logger.Error("error fetching doc", "error", err)
60
60
-
return helpers.ServerError(e, nil)
66
66
+
helpers.ServerError(w, nil)
67
67
+
return
61
68
}
62
69
63
70
latest := log[len(log)-1]
···
86
93
k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey)
87
94
if err != nil {
88
95
logger.Error("error parsing signing key", "error", err)
89
89
-
return helpers.ServerError(e, nil)
96
96
+
helpers.ServerError(w, nil)
97
97
+
return
90
98
}
91
99
92
100
if err := s.plcClient.SignOp(k, &op); err != nil {
93
101
logger.Error("error signing plc operation", "error", err)
94
94
-
return helpers.ServerError(e, nil)
102
102
+
helpers.ServerError(w, nil)
103
103
+
return
95
104
}
96
105
97
106
if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil {
98
107
logger.Error("error updating repo", "error", err)
99
99
-
return helpers.ServerError(e, nil)
108
108
+
helpers.ServerError(w, nil)
109
109
+
return
100
110
}
101
111
102
102
-
return e.JSON(200, ComAtprotoSignPlcOperationResponse{
112
112
+
s.writeJSON(w, 200, ComAtprotoSignPlcOperationResponse{
103
113
Operation: op,
104
114
})
105
115
}
+32
-21
server/handle_identity_submit_plc_operation.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
6
6
+
"net/http"
5
7
"slices"
6
8
"strings"
7
9
"time"
···
13
15
"github.com/haileyok/cocoon/internal/helpers"
14
16
"github.com/haileyok/cocoon/models"
15
17
"github.com/haileyok/cocoon/plc"
16
16
-
"github.com/labstack/echo/v4"
17
18
)
18
19
19
20
type ComAtprotoSubmitPlcOperationRequest struct {
20
21
Operation plc.Operation `json:"operation"`
21
22
}
22
23
23
23
-
func (s *Server) handleSubmitPlcOperation(e echo.Context) error {
24
24
+
func (s *Server) handleSubmitPlcOperation(w http.ResponseWriter, r *http.Request) {
24
25
logger := s.logger.With("name", "handleIdentitySubmitPlcOperation")
25
26
26
26
-
repo := e.Get("repo").(*models.RepoActor)
27
27
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
27
28
28
29
var req ComAtprotoSubmitPlcOperationRequest
29
29
-
if err := e.Bind(&req); err != nil {
30
30
-
logger.Error("error binding", "error", err)
31
31
-
return helpers.ServerError(e, nil)
30
30
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
31
31
+
logger.Error("error decoding", "error", err)
32
32
+
helpers.ServerError(w, nil)
33
33
+
return
32
34
}
33
35
34
34
-
if err := e.Validate(req); err != nil {
35
35
-
return helpers.InputError(e, nil)
36
36
+
if err := s.validator.Struct(req); err != nil {
37
37
+
helpers.InputError(w, nil)
38
38
+
return
36
39
}
40
40
+
37
41
if !strings.HasPrefix(repo.Repo.Did, "did:plc:") {
38
38
-
return helpers.InputError(e, nil)
42
42
+
helpers.InputError(w, nil)
43
43
+
return
39
44
}
40
45
41
46
op := req.Operation
···
43
48
k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey)
44
49
if err != nil {
45
50
logger.Error("error parsing key", "error", err)
46
46
-
return helpers.ServerError(e, nil)
51
51
+
helpers.ServerError(w, nil)
52
52
+
return
47
53
}
48
54
required, err := s.plcClient.CreateDidCredentials(k, "", repo.Actor.Handle)
49
55
if err != nil {
50
50
-
logger.Error("error crating did credentials", "error", err)
51
51
-
return helpers.ServerError(e, nil)
56
56
+
logger.Error("error creating did credentials", "error", err)
57
57
+
helpers.ServerError(w, nil)
58
58
+
return
52
59
}
53
60
54
61
for _, expectedKey := range required.RotationKeys {
55
62
if !slices.Contains(op.RotationKeys, expectedKey) {
56
56
-
return helpers.InputError(e, nil)
63
63
+
helpers.InputError(w, nil)
64
64
+
return
57
65
}
58
66
}
59
67
if op.Services["atproto_pds"].Type != "AtprotoPersonalDataServer" {
60
60
-
return helpers.InputError(e, nil)
68
68
+
helpers.InputError(w, nil)
69
69
+
return
61
70
}
62
71
if op.Services["atproto_pds"].Endpoint != required.Services["atproto_pds"].Endpoint {
63
63
-
return helpers.InputError(e, nil)
72
72
+
helpers.InputError(w, nil)
73
73
+
return
64
74
}
65
75
if op.VerificationMethods["atproto"] != required.VerificationMethods["atproto"] {
66
66
-
return helpers.InputError(e, nil)
76
76
+
helpers.InputError(w, nil)
77
77
+
return
67
78
}
68
79
if op.AlsoKnownAs[0] != required.AlsoKnownAs[0] {
69
69
-
return helpers.InputError(e, nil)
80
80
+
helpers.InputError(w, nil)
81
81
+
return
70
82
}
71
83
72
72
-
if err := s.plcClient.SendOperation(e.Request().Context(), repo.Repo.Did, &op); err != nil {
73
73
-
return err
84
84
+
if err := s.plcClient.SendOperation(r.Context(), repo.Repo.Did, &op); err != nil {
85
85
+
helpers.ServerError(w, nil)
86
86
+
return
74
87
}
75
88
76
89
if err := s.passport.BustDoc(context.TODO(), repo.Repo.Did); err != nil {
···
84
97
Time: time.Now().Format(util.ISO8601),
85
98
},
86
99
})
87
87
-
88
88
-
return nil
89
100
}
+23
-17
server/handle_identity_update_handle.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
6
6
+
"net/http"
5
7
"strings"
6
8
"time"
7
9
···
14
16
"github.com/haileyok/cocoon/internal/helpers"
15
17
"github.com/haileyok/cocoon/models"
16
18
"github.com/haileyok/cocoon/plc"
17
17
-
"github.com/labstack/echo/v4"
18
19
)
19
20
20
21
type ComAtprotoIdentityUpdateHandleRequest struct {
21
22
Handle string `json:"handle" validate:"atproto-handle"`
22
23
}
23
24
24
24
-
func (s *Server) handleIdentityUpdateHandle(e echo.Context) error {
25
25
+
func (s *Server) handleIdentityUpdateHandle(w http.ResponseWriter, r *http.Request) {
25
26
logger := s.logger.With("name", "handleIdentityUpdateHandle")
26
27
27
27
-
repo := e.Get("repo").(*models.RepoActor)
28
28
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
28
29
29
30
var req ComAtprotoIdentityUpdateHandleRequest
30
30
-
if err := e.Bind(&req); err != nil {
31
31
-
logger.Error("error binding", "error", err)
32
32
-
return helpers.ServerError(e, nil)
31
31
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
32
32
+
logger.Error("error decoding", "error", err)
33
33
+
helpers.ServerError(w, nil)
34
34
+
return
33
35
}
34
36
35
37
req.Handle = strings.ToLower(req.Handle)
36
38
37
37
-
if err := e.Validate(req); err != nil {
38
38
-
return helpers.InputError(e, nil)
39
39
+
if err := s.validator.Struct(req); err != nil {
40
40
+
helpers.InputError(w, nil)
41
41
+
return
39
42
}
40
43
41
41
-
ctx := context.WithValue(e.Request().Context(), "skip-cache", true)
44
44
+
ctx := context.WithValue(r.Context(), "skip-cache", true)
42
45
43
46
if strings.HasPrefix(repo.Repo.Did, "did:plc:") {
44
47
log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did)
45
48
if err != nil {
46
49
logger.Error("error fetching doc", "error", err)
47
47
-
return helpers.ServerError(e, nil)
50
50
+
helpers.ServerError(w, nil)
51
51
+
return
48
52
}
49
53
50
54
latest := log[len(log)-1]
···
71
75
k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey)
72
76
if err != nil {
73
77
logger.Error("error parsing signing key", "error", err)
74
74
-
return helpers.ServerError(e, nil)
78
78
+
helpers.ServerError(w, nil)
79
79
+
return
75
80
}
76
81
77
82
if err := s.plcClient.SignOp(k, &op); err != nil {
78
78
-
return err
83
83
+
helpers.ServerError(w, nil)
84
84
+
return
79
85
}
80
86
81
81
-
if err := s.plcClient.SendOperation(e.Request().Context(), repo.Repo.Did, &op); err != nil {
82
82
-
return err
87
87
+
if err := s.plcClient.SendOperation(r.Context(), repo.Repo.Did, &op); err != nil {
88
88
+
helpers.ServerError(w, nil)
89
89
+
return
83
90
}
84
91
}
85
92
···
98
105
99
106
if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil {
100
107
logger.Error("error updating handle in db", "error", err)
101
101
-
return helpers.ServerError(e, nil)
108
108
+
helpers.ServerError(w, nil)
109
109
+
return
102
110
}
103
103
-
104
104
-
return nil
105
111
}
+27
-19
server/handle_import_repo.go
···
4
4
"bytes"
5
5
"context"
6
6
"io"
7
7
+
"net/http"
7
8
"slices"
8
9
"strings"
9
10
···
13
14
blocks "github.com/ipfs/go-block-format"
14
15
"github.com/ipfs/go-cid"
15
16
"github.com/ipld/go-car"
16
16
-
"github.com/labstack/echo/v4"
17
17
)
18
18
19
19
-
func (s *Server) handleRepoImportRepo(e echo.Context) error {
20
20
-
ctx := e.Request().Context()
19
19
+
func (s *Server) handleRepoImportRepo(w http.ResponseWriter, r *http.Request) {
20
20
+
ctx := r.Context()
21
21
logger := s.logger.With("name", "handleImportRepo")
22
22
23
23
-
urepo := e.Get("repo").(*models.RepoActor)
23
23
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
24
24
25
25
-
b, err := io.ReadAll(e.Request().Body)
25
25
+
b, err := io.ReadAll(r.Body)
26
26
if err != nil {
27
27
logger.Error("could not read bytes in import request", "error", err)
28
28
-
return helpers.ServerError(e, nil)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
29
30
}
30
31
31
32
bs := s.getBlockstore(urepo.Repo.Did)
···
33
34
cs, err := car.NewCarReader(bytes.NewReader(b))
34
35
if err != nil {
35
36
logger.Error("could not read car in import request", "error", err)
36
36
-
return helpers.ServerError(e, nil)
37
37
+
helpers.ServerError(w, nil)
38
38
+
return
37
39
}
38
40
39
41
orderedBlocks := []blocks.Block{}
40
42
currBlock, err := cs.Next()
41
43
if err != nil {
42
44
logger.Error("could not get first block from car", "error", err)
43
43
-
return helpers.ServerError(e, nil)
45
45
+
helpers.ServerError(w, nil)
46
46
+
return
44
47
}
45
48
currBlockCt := 1
46
49
···
56
59
57
60
if err := bs.PutMany(context.TODO(), orderedBlocks); err != nil {
58
61
logger.Error("could not insert blocks", "error", err)
59
59
-
return helpers.ServerError(e, nil)
62
62
+
helpers.ServerError(w, nil)
63
63
+
return
60
64
}
61
65
62
66
r, err := openRepo(context.TODO(), bs, cs.Header.Roots[0], urepo.Repo.Did)
63
67
if err != nil {
64
68
logger.Error("could not open repo", "error", err)
65
65
-
return helpers.ServerError(e, nil)
69
69
+
helpers.ServerError(w, nil)
70
70
+
return
66
71
}
67
72
68
73
tx := s.db.Begin(ctx)
···
73
78
pts := strings.Split(string(key), "/")
74
79
nsid := pts[0]
75
80
rkey := pts[1]
76
76
-
cidStr := cid.String()
77
77
-
b, err := bs.Get(context.TODO(), cid)
81
81
+
cidStr := c.String()
82
82
+
blkData, err := bs.Get(context.TODO(), c)
78
83
if err != nil {
79
84
logger.Error("record bytes don't exist in blockstore", "error", err)
80
80
-
return helpers.ServerError(e, nil)
85
85
+
return err
81
86
}
82
87
83
88
rec := models.Record{
···
86
91
Nsid: nsid,
87
92
Rkey: rkey,
88
93
Cid: cidStr,
89
89
-
Value: b.RawData(),
94
94
+
Value: blkData.RawData(),
90
95
}
91
96
92
97
if err := tx.Save(rec).Error; err != nil {
···
96
101
return nil
97
102
}); err != nil {
98
103
tx.Rollback()
99
99
-
logger.Error("record bytes don't exist in blockstore", "error", err)
100
100
-
return helpers.ServerError(e, nil)
104
104
+
logger.Error("error iterating repo blocks", "error", err)
105
105
+
helpers.ServerError(w, nil)
106
106
+
return
101
107
}
102
108
103
109
tx.Commit()
···
105
111
root, rev, err := commitRepo(context.TODO(), bs, r, urepo.Repo.SigningKey)
106
112
if err != nil {
107
113
logger.Error("error committing", "error", err)
108
108
-
return helpers.ServerError(e, nil)
114
114
+
helpers.ServerError(w, nil)
115
115
+
return
109
116
}
110
117
111
118
if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil {
112
119
logger.Error("error updating repo after commit", "error", err)
113
113
-
return helpers.ServerError(e, nil)
120
120
+
helpers.ServerError(w, nil)
121
121
+
return
114
122
}
115
123
116
116
-
return nil
124
124
+
w.WriteHeader(http.StatusOK)
117
125
}
+6
-7
server/handle_label_query_labels.go
···
1
1
package server
2
2
3
3
-
import (
4
4
-
"github.com/labstack/echo/v4"
5
5
-
)
3
3
+
import "net/http"
6
4
7
5
type Label struct {
8
6
Ver *int `json:"ver,omitempty"`
···
21
19
Labels []Label `json:"labels"`
22
20
}
23
21
24
24
-
func (s *Server) handleLabelQueryLabels(e echo.Context) error {
25
25
-
svc := e.Request().Header.Get("atproto-proxy")
22
22
+
func (s *Server) handleLabelQueryLabels(w http.ResponseWriter, r *http.Request) {
23
23
+
svc := r.Header.Get("atproto-proxy")
26
24
if svc != "" || s.config.FallbackProxy != "" {
27
27
-
return s.handleProxy(e)
25
25
+
s.handleProxy(w, r)
26
26
+
return
28
27
}
29
28
30
30
-
return e.JSON(200, ComAtprotoLabelQueryLabelsResponse{
29
29
+
s.writeJSON(w, 200, ComAtprotoLabelQueryLabelsResponse{
31
30
Cursor: nil,
32
31
Labels: []Label{},
33
32
})
+95
-53
server/handle_oauth_authorize.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"net/url"
6
7
"strings"
7
8
"time"
···
11
12
"github.com/haileyok/cocoon/oauth"
12
13
"github.com/haileyok/cocoon/oauth/constants"
13
14
"github.com/haileyok/cocoon/oauth/provider"
14
14
-
"github.com/labstack/echo/v4"
15
15
)
16
16
17
17
type HandleOauthAuthorizeGetInput struct {
18
18
RequestUri string `query:"request_uri"`
19
19
}
20
20
21
21
-
func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
22
22
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handleOauthAuthorizeGet(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
23
23
24
24
logger := s.logger.With("name", "handleOauthAuthorizeGet")
25
25
26
26
-
var input HandleOauthAuthorizeGetInput
27
27
-
if err := e.Bind(&input); err != nil {
28
28
-
logger.Error("error binding request", "err", err)
29
29
-
return fmt.Errorf("error binding request")
30
30
-
}
26
26
+
requestUri := r.URL.Query().Get("request_uri")
31
27
32
28
var reqId string
33
33
-
if input.RequestUri != "" {
34
34
-
id, err := oauth.DecodeRequestUri(input.RequestUri)
29
29
+
if requestUri != "" {
30
30
+
id, err := oauth.DecodeRequestUri(requestUri)
35
31
if err != nil {
36
36
-
logger.Error("no request uri found in input", "url", e.Request().URL.String())
37
37
-
return helpers.InputError(e, to.StringPtr("no request uri"))
32
32
+
logger.Error("no request uri found in input", "url", r.URL.String())
33
33
+
helpers.InputError(w, to.StringPtr("no request uri"))
34
34
+
return
38
35
}
39
36
reqId = id
40
37
} else {
41
41
-
var parRequest provider.ParRequest
42
42
-
if err := e.Bind(&parRequest); err != nil {
43
43
-
s.logger.Error("error binding for standard auth request", "error", err)
44
44
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
38
38
+
parRequest := provider.ParRequest{
39
39
+
AuthenticateClientRequestBase: provider.AuthenticateClientRequestBase{
40
40
+
ClientID: r.URL.Query().Get("client_id"),
41
41
+
},
42
42
+
ResponseType: r.URL.Query().Get("response_type"),
43
43
+
State: r.URL.Query().Get("state"),
44
44
+
RedirectURI: r.URL.Query().Get("redirect_uri"),
45
45
+
Scope: r.URL.Query().Get("scope"),
46
46
+
CodeChallengeMethod: r.URL.Query().Get("code_challenge_method"),
47
47
+
}
48
48
+
if v := r.URL.Query().Get("code_challenge"); v != "" {
49
49
+
parRequest.CodeChallenge = to.StringPtr(v)
50
50
+
}
51
51
+
if v := r.URL.Query().Get("login_hint"); v != "" {
52
52
+
parRequest.LoginHint = to.StringPtr(v)
53
53
+
}
54
54
+
if v := r.URL.Query().Get("dpop_jkt"); v != "" {
55
55
+
parRequest.DpopJkt = to.StringPtr(v)
56
56
+
}
57
57
+
if v := r.URL.Query().Get("response_mode"); v != "" {
58
58
+
parRequest.ResponseMode = to.StringPtr(v)
45
59
}
46
60
47
47
-
if err := e.Validate(parRequest); err != nil {
61
61
+
if err := s.validator.Struct(parRequest); err != nil {
48
62
// render page for logged out dev
49
63
if s.config.Version == "dev" && parRequest.ClientID == "" {
50
50
-
return e.Render(200, "authorize.html", map[string]any{
64
64
+
s.renderTemplate(w, "authorize.html", map[string]any{
51
65
"Scopes": []string{"atproto", "transition:generic"},
52
66
"AppName": "DEV MODE AUTHORIZATION PAGE",
53
67
"Handle": "paula.cocoon.social",
54
68
"RequestUri": "",
55
69
})
70
70
+
return
56
71
}
57
57
-
return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters"))
72
72
+
helpers.InputError(w, to.StringPtr("no request uri and invalid parameters"))
73
73
+
return
58
74
}
59
75
60
76
client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{
···
62
78
})
63
79
if err != nil {
64
80
s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err)
65
65
-
return helpers.ServerError(e, to.StringPtr(err.Error()))
81
81
+
helpers.ServerError(w, to.StringPtr(err.Error()))
82
82
+
return
66
83
}
67
84
68
85
if parRequest.DpopJkt == nil {
69
86
if client.Metadata.DpopBoundAccessTokens {
87
87
+
// nothing to do
70
88
}
71
89
} else {
72
90
if !client.Metadata.DpopBoundAccessTokens {
73
91
msg := "dpop bound access tokens are not enabled for this client"
74
74
-
return helpers.InputError(e, &msg)
92
92
+
helpers.InputError(w, &msg)
93
93
+
return
75
94
}
76
95
}
77
96
···
88
107
89
108
if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
90
109
s.logger.Error("error creating auth request in db", "error", err)
91
91
-
return helpers.ServerError(e, nil)
110
110
+
helpers.ServerError(w, nil)
111
111
+
return
92
112
}
93
113
94
94
-
input.RequestUri = oauth.EncodeRequestUri(id)
114
114
+
requestUri = oauth.EncodeRequestUri(id)
95
115
reqId = id
96
96
-
97
116
}
98
117
99
99
-
repo, _, err := s.getSessionRepoOrErr(e)
118
118
+
repo, _, err := s.getSessionRepoOrErr(r)
100
119
if err != nil {
101
101
-
return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
120
120
+
http.Redirect(w, r, "/account/signin?"+r.URL.Query().Encode(), 303)
121
121
+
return
102
122
}
103
123
104
124
var req provider.OauthAuthorizationRequest
105
125
if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
106
106
-
return helpers.ServerError(e, to.StringPtr(err.Error()))
126
126
+
helpers.ServerError(w, to.StringPtr(err.Error()))
127
127
+
return
107
128
}
108
129
109
109
-
clientId := e.QueryParam("client_id")
130
130
+
clientId := r.URL.Query().Get("client_id")
110
131
if clientId != req.ClientId {
111
111
-
return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request"))
132
132
+
helpers.InputError(w, to.StringPtr("client id does not match the client id for the supplied request"))
133
133
+
return
112
134
}
113
135
114
114
-
client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId)
136
136
+
client, err := s.oauthProvider.ClientManager.GetClient(r.Context(), req.ClientId)
115
137
if err != nil {
116
116
-
return helpers.ServerError(e, to.StringPtr(err.Error()))
138
138
+
helpers.ServerError(w, to.StringPtr(err.Error()))
139
139
+
return
117
140
}
118
141
119
142
scopes := strings.Split(req.Parameters.Scope, " ")
···
122
145
data := map[string]any{
123
146
"Scopes": scopes,
124
147
"AppName": appName,
125
125
-
"RequestUri": input.RequestUri,
126
126
-
"QueryParams": e.QueryParams().Encode(),
148
148
+
"RequestUri": requestUri,
149
149
+
"QueryParams": r.URL.Query().Encode(),
127
150
"Handle": repo.Actor.Handle,
128
151
}
129
152
130
130
-
return e.Render(200, "authorize.html", data)
153
153
+
s.renderTemplate(w, "authorize.html", data)
131
154
}
132
155
133
156
type OauthAuthorizePostRequest struct {
···
135
158
AcceptOrRejct string `form:"accept_or_reject"`
136
159
}
137
160
138
138
-
func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
139
139
-
ctx := e.Request().Context()
161
161
+
func (s *Server) handleOauthAuthorizePost(w http.ResponseWriter, r *http.Request) {
162
162
+
ctx := r.Context()
140
163
logger := s.logger.With("name", "handleOauthAuthorizePost")
141
164
142
142
-
repo, _, err := s.getSessionRepoOrErr(e)
165
165
+
repo, _, err := s.getSessionRepoOrErr(r)
143
166
if err != nil {
144
144
-
return e.Redirect(303, "/account/signin")
167
167
+
http.Redirect(w, r, "/account/signin", 303)
168
168
+
return
145
169
}
146
170
147
147
-
var req OauthAuthorizePostRequest
148
148
-
if err := e.Bind(&req); err != nil {
149
149
-
logger.Error("error binding authorize post request", "error", err)
150
150
-
return helpers.InputError(e, nil)
171
171
+
if err := r.ParseForm(); err != nil {
172
172
+
logger.Error("error parsing authorize post form", "error", err)
173
173
+
helpers.InputError(w, nil)
174
174
+
return
175
175
+
}
176
176
+
177
177
+
req := OauthAuthorizePostRequest{
178
178
+
RequestUri: r.FormValue("request_uri"),
179
179
+
AcceptOrRejct: r.FormValue("accept_or_reject"),
151
180
}
152
181
153
182
reqId, err := oauth.DecodeRequestUri(req.RequestUri)
154
183
if err != nil {
155
155
-
return helpers.InputError(e, to.StringPtr(err.Error()))
184
184
+
helpers.InputError(w, to.StringPtr(err.Error()))
185
185
+
return
156
186
}
157
187
158
188
var authReq provider.OauthAuthorizationRequest
159
189
if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
160
160
-
return helpers.ServerError(e, to.StringPtr(err.Error()))
190
190
+
helpers.ServerError(w, to.StringPtr(err.Error()))
191
191
+
return
161
192
}
162
193
163
163
-
client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId)
194
194
+
client, err := s.oauthProvider.ClientManager.GetClient(r.Context(), authReq.ClientId)
164
195
if err != nil {
165
165
-
return helpers.ServerError(e, to.StringPtr(err.Error()))
196
196
+
helpers.ServerError(w, to.StringPtr(err.Error()))
197
197
+
return
166
198
}
167
199
168
200
// TODO: figure out how im supposed to actually redirect
169
201
if req.AcceptOrRejct == "reject" {
170
170
-
return e.Redirect(303, client.Metadata.ClientURI)
202
202
+
http.Redirect(w, r, client.Metadata.ClientURI, 303)
203
203
+
return
171
204
}
172
205
173
206
if time.Now().After(authReq.ExpiresAt) {
174
174
-
return helpers.InputError(e, to.StringPtr("the request has expired"))
207
207
+
helpers.InputError(w, to.StringPtr("the request has expired"))
208
208
+
return
175
209
}
176
210
177
211
if authReq.Sub != nil || authReq.Code != nil {
178
178
-
return helpers.InputError(e, to.StringPtr("this request was already authorized"))
212
212
+
helpers.InputError(w, to.StringPtr("this request was already authorized"))
213
213
+
return
179
214
}
180
215
181
216
code := oauth.GenerateCode()
182
217
183
183
-
if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
218
218
+
// Use the first non-loopback remote address as the IP
219
219
+
ip := r.RemoteAddr
220
220
+
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
221
221
+
ip = strings.Split(forwarded, ",")[0]
222
222
+
}
223
223
+
224
224
+
if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, ip, reqId).Error; err != nil {
184
225
logger.Error("error updating authorization request", "error", err)
185
185
-
return helpers.ServerError(e, nil)
226
226
+
helpers.ServerError(w, nil)
227
227
+
return
186
228
}
187
229
188
230
q := url.Values{}
···
197
239
hashOrQuestion = "#"
198
240
case "query":
199
241
// do nothing
200
200
-
break
201
242
default:
202
243
if authReq.Parameters.ResponseType != "code" {
203
244
hashOrQuestion = "#"
···
209
250
}
210
251
}
211
252
212
212
-
return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
253
253
+
_ = fmt.Sprintf // avoid unused import if fmt ends up unused
254
254
+
http.Redirect(w, r, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode(), 303)
213
255
}
+3
-3
server/handle_oauth_jwks.go
···
1
1
package server
2
2
3
3
-
import "github.com/labstack/echo/v4"
3
3
+
import "net/http"
4
4
5
5
type OauthJwksResponse struct {
6
6
Keys []any `json:"keys"`
7
7
}
8
8
9
9
// TODO: ?
10
10
-
func (s *Server) handleOauthJwks(e echo.Context) error {
11
11
-
return e.JSON(200, OauthJwksResponse{Keys: []any{}})
10
10
+
func (s *Server) handleOauthJwks(w http.ResponseWriter, r *http.Request) {
11
11
+
s.writeJSON(w, 200, OauthJwksResponse{Keys: []any{}})
12
12
}
+57
-21
server/handle_oauth_par.go
···
2
2
3
3
import (
4
4
"errors"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/Azure/go-autorest/autorest/to"
···
10
11
"github.com/haileyok/cocoon/oauth/constants"
11
12
"github.com/haileyok/cocoon/oauth/dpop"
12
13
"github.com/haileyok/cocoon/oauth/provider"
13
13
-
"github.com/labstack/echo/v4"
14
14
)
15
15
16
16
type OauthParResponse struct {
···
18
18
RequestURI string `json:"request_uri"`
19
19
}
20
20
21
21
-
func (s *Server) handleOauthPar(e echo.Context) error {
22
22
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handleOauthPar(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
23
23
logger := s.logger.With("name", "handleOauthPar")
24
24
25
25
-
var parRequest provider.ParRequest
26
26
-
if err := e.Bind(&parRequest); err != nil {
27
27
-
logger.Error("error binding for par request", "error", err)
28
28
-
return helpers.ServerError(e, nil)
25
25
+
if err := r.ParseForm(); err != nil {
26
26
+
logger.Error("error parsing par request form", "error", err)
27
27
+
helpers.ServerError(w, nil)
28
28
+
return
29
29
}
30
30
31
31
-
if err := e.Validate(parRequest); err != nil {
31
31
+
parRequest := provider.ParRequest{
32
32
+
AuthenticateClientRequestBase: provider.AuthenticateClientRequestBase{
33
33
+
ClientID: r.FormValue("client_id"),
34
34
+
},
35
35
+
ResponseType: r.FormValue("response_type"),
36
36
+
State: r.FormValue("state"),
37
37
+
RedirectURI: r.FormValue("redirect_uri"),
38
38
+
Scope: r.FormValue("scope"),
39
39
+
CodeChallengeMethod: r.FormValue("code_challenge_method"),
40
40
+
}
41
41
+
if v := r.FormValue("code_challenge"); v != "" {
42
42
+
parRequest.CodeChallenge = to.StringPtr(v)
43
43
+
}
44
44
+
if v := r.FormValue("login_hint"); v != "" {
45
45
+
parRequest.LoginHint = to.StringPtr(v)
46
46
+
}
47
47
+
if v := r.FormValue("dpop_jkt"); v != "" {
48
48
+
parRequest.DpopJkt = to.StringPtr(v)
49
49
+
}
50
50
+
if v := r.FormValue("response_mode"); v != "" {
51
51
+
parRequest.ResponseMode = to.StringPtr(v)
52
52
+
}
53
53
+
if v := r.FormValue("client_assertion_type"); v != "" {
54
54
+
parRequest.ClientAssertionType = to.StringPtr(v)
55
55
+
}
56
56
+
if v := r.FormValue("client_assertion"); v != "" {
57
57
+
parRequest.ClientAssertion = to.StringPtr(v)
58
58
+
}
59
59
+
60
60
+
if err := s.validator.Struct(parRequest); err != nil {
32
61
logger.Error("missing parameters for par request", "error", err)
33
33
-
return helpers.InputError(e, nil)
62
62
+
helpers.InputError(w, nil)
63
63
+
return
34
64
}
35
65
36
66
// TODO: this seems wrong. should be a way to get the entire request url i believe, but this will work for now
37
37
-
dpopProof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, nil)
67
67
+
dpopProof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, "https://"+s.config.Hostname+r.URL.String(), r.Header, nil)
38
68
if err != nil {
39
69
if errors.Is(err, dpop.ErrUseDpopNonce) {
40
70
nonce := s.oauthProvider.NextNonce()
41
71
if nonce != "" {
42
42
-
e.Response().Header().Set("DPoP-Nonce", nonce)
43
43
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
72
72
+
w.Header().Set("DPoP-Nonce", nonce)
73
73
+
w.Header().Add("access-control-expose-headers", "DPoP-Nonce")
44
74
}
45
45
-
logger.Error("nonce error: use_dpop_nonce", "headers", e.Request().Header)
46
46
-
return e.JSON(400, map[string]string{
75
75
+
logger.Error("nonce error: use_dpop_nonce", "headers", r.Header)
76
76
+
s.writeJSON(w, 400, map[string]string{
47
77
"error": "use_dpop_nonce",
48
78
})
79
79
+
return
49
80
}
50
81
logger.Error("error getting dpop proof", "error", err)
51
51
-
return helpers.InputError(e, nil)
82
82
+
helpers.InputError(w, nil)
83
83
+
return
52
84
}
53
85
54
54
-
client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{
86
86
+
client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{
55
87
// rfc9449
56
88
// https://github.com/bluesky-social/atproto/blob/main/packages/oauth/oauth-provider/src/oauth-provider.ts#L473
57
89
AllowMissingDpopProof: true,
58
90
})
59
91
if err != nil {
60
92
logger.Error("error authenticating client", "client_id", parRequest.ClientID, "error", err)
61
61
-
return helpers.InputError(e, to.StringPtr(err.Error()))
93
93
+
helpers.InputError(w, to.StringPtr(err.Error()))
94
94
+
return
62
95
}
63
96
64
97
if parRequest.DpopJkt == nil {
···
69
102
if !client.Metadata.DpopBoundAccessTokens {
70
103
msg := "dpop bound access tokens are not enabled for this client"
71
104
logger.Error(msg)
72
72
-
return helpers.InputError(e, &msg)
105
105
+
helpers.InputError(w, &msg)
106
106
+
return
73
107
}
74
108
75
109
if dpopProof.JKT != *parRequest.DpopJkt {
76
110
msg := "supplied dpop jkt does not match header dpop jkt"
77
111
logger.Error(msg)
78
78
-
return helpers.InputError(e, &msg)
112
112
+
helpers.InputError(w, &msg)
113
113
+
return
79
114
}
80
115
}
81
116
···
92
127
93
128
if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
94
129
logger.Error("error creating auth request in db", "error", err)
95
95
-
return helpers.ServerError(e, nil)
130
130
+
helpers.ServerError(w, nil)
131
131
+
return
96
132
}
97
133
98
134
uri := oauth.EncodeRequestUri(id)
99
135
100
100
-
return e.JSON(201, OauthParResponse{
136
136
+
s.writeJSON(w, 201, OauthParResponse{
101
137
ExpiresIn: int64(constants.ParExpiresIn.Seconds()),
102
138
RequestURI: uri,
103
139
})
+102
-47
server/handle_oauth_token.go
···
6
6
"encoding/base64"
7
7
"errors"
8
8
"fmt"
9
9
+
"net/http"
9
10
"slices"
10
11
"time"
11
12
···
16
17
"github.com/haileyok/cocoon/oauth/constants"
17
18
"github.com/haileyok/cocoon/oauth/dpop"
18
19
"github.com/haileyok/cocoon/oauth/provider"
19
19
-
"github.com/labstack/echo/v4"
20
20
)
21
21
22
22
type OauthTokenRequest struct {
···
37
37
Sub string `json:"sub"`
38
38
}
39
39
40
40
-
func (s *Server) handleOauthToken(e echo.Context) error {
41
41
-
ctx := e.Request().Context()
40
40
+
func (s *Server) handleOauthToken(w http.ResponseWriter, r *http.Request) {
41
41
+
ctx := r.Context()
42
42
logger := s.logger.With("name", "handleOauthToken")
43
43
44
44
-
var req OauthTokenRequest
45
45
-
if err := e.Bind(&req); err != nil {
46
46
-
logger.Error("error binding token request", "error", err)
47
47
-
return helpers.ServerError(e, nil)
44
44
+
if err := r.ParseForm(); err != nil {
45
45
+
logger.Error("error parsing token request form", "error", err)
46
46
+
helpers.ServerError(w, nil)
47
47
+
return
48
48
+
}
49
49
+
50
50
+
req := OauthTokenRequest{
51
51
+
GrantType: r.FormValue("grant_type"),
52
52
+
}
53
53
+
if v := r.FormValue("code"); v != "" {
54
54
+
req.Code = to.StringPtr(v)
55
55
+
}
56
56
+
if v := r.FormValue("code_verifier"); v != "" {
57
57
+
req.CodeVerifier = to.StringPtr(v)
58
58
+
}
59
59
+
if v := r.FormValue("redirect_uri"); v != "" {
60
60
+
req.RedirectURI = to.StringPtr(v)
61
61
+
}
62
62
+
if v := r.FormValue("refresh_token"); v != "" {
63
63
+
req.RefreshToken = to.StringPtr(v)
64
64
+
}
65
65
+
if v := r.FormValue("client_assertion_type"); v != "" {
66
66
+
req.ClientAssertionType = to.StringPtr(v)
67
67
+
}
68
68
+
if v := r.FormValue("client_assertion"); v != "" {
69
69
+
req.ClientAssertion = to.StringPtr(v)
70
70
+
}
71
71
+
req.AuthenticateClientRequestBase = provider.AuthenticateClientRequestBase{
72
72
+
ClientID: r.FormValue("client_id"),
73
73
+
ClientAssertionType: req.ClientAssertionType,
74
74
+
ClientAssertion: req.ClientAssertion,
48
75
}
49
76
50
50
-
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
77
77
+
proof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, r.URL.String(), r.Header, nil)
51
78
if err != nil {
52
79
if errors.Is(err, dpop.ErrUseDpopNonce) {
53
80
nonce := s.oauthProvider.NextNonce()
54
81
if nonce != "" {
55
55
-
e.Response().Header().Set("DPoP-Nonce", nonce)
56
56
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
82
82
+
w.Header().Set("DPoP-Nonce", nonce)
83
83
+
w.Header().Add("access-control-expose-headers", "DPoP-Nonce")
57
84
}
58
58
-
return e.JSON(400, map[string]string{
85
85
+
s.writeJSON(w, 400, map[string]string{
59
86
"error": "use_dpop_nonce",
60
87
})
88
88
+
return
61
89
}
62
90
logger.Error("error getting dpop proof", "error", err)
63
63
-
return helpers.InputError(e, nil)
91
91
+
helpers.InputError(w, nil)
92
92
+
return
64
93
}
65
94
66
66
-
client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
95
95
+
client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
67
96
AllowMissingDpopProof: true,
68
97
})
69
98
if err != nil {
70
99
logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
71
71
-
return helpers.InputError(e, to.StringPtr(err.Error()))
100
100
+
helpers.InputError(w, to.StringPtr(err.Error()))
101
101
+
return
72
102
}
73
103
74
74
-
// TODO: this should come from an oauth provier config
104
104
+
// TODO: this should come from an oauth provider config
75
105
if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
76
76
-
return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
106
106
+
helpers.InputError(w, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
107
107
+
return
77
108
}
78
109
79
110
if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
80
80
-
return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
111
111
+
helpers.InputError(w, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
112
112
+
return
81
113
}
82
114
83
115
if req.GrantType == "authorization_code" {
84
116
if req.Code == nil {
85
85
-
return helpers.InputError(e, to.StringPtr(`"code" is required"`))
117
117
+
helpers.InputError(w, to.StringPtr(`"code" is required"`))
118
118
+
return
86
119
}
87
120
88
121
var authReq provider.OauthAuthorizationRequest
89
122
// get the lil guy and delete him
90
123
if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
91
124
logger.Error("error finding authorization request", "error", err)
92
92
-
return helpers.ServerError(e, nil)
125
125
+
helpers.ServerError(w, nil)
126
126
+
return
93
127
}
94
128
95
129
if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
96
96
-
return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
130
130
+
helpers.InputError(w, to.StringPtr(`"redirect_uri" mismatch`))
131
131
+
return
97
132
}
98
133
99
134
if authReq.Parameters.CodeChallenge != nil {
100
135
if req.CodeVerifier == nil {
101
101
-
return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
136
136
+
helpers.InputError(w, to.StringPtr(`"code_verifier" is required`))
137
137
+
return
102
138
}
103
139
104
140
if len(*req.CodeVerifier) < 43 {
105
105
-
return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
141
141
+
helpers.InputError(w, to.StringPtr(`"code_verifier" is too short`))
142
142
+
return
106
143
}
107
144
108
108
-
switch *&authReq.Parameters.CodeChallengeMethod {
145
145
+
switch authReq.Parameters.CodeChallengeMethod {
109
146
case "", "plain":
110
147
if authReq.Parameters.CodeChallenge != req.CodeVerifier {
111
111
-
return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
148
148
+
helpers.InputError(w, to.StringPtr("invalid code_verifier"))
149
149
+
return
112
150
}
113
151
case "S256":
114
152
inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
115
153
if err != nil {
116
154
logger.Error("error decoding code challenge", "error", err)
117
117
-
return helpers.ServerError(e, nil)
155
155
+
helpers.ServerError(w, nil)
156
156
+
return
118
157
}
119
158
120
159
h := sha256.New()
···
122
161
compdChal := h.Sum(nil)
123
162
124
163
if !bytes.Equal(inputChal, compdChal) {
125
125
-
return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
164
164
+
helpers.InputError(w, to.StringPtr("invalid code_verifier"))
165
165
+
return
126
166
}
127
167
default:
128
128
-
return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
168
168
+
helpers.InputError(w, to.StringPtr("unsupported code_challenge_method "+authReq.Parameters.CodeChallengeMethod))
169
169
+
return
129
170
}
130
171
} else if req.CodeVerifier != nil {
131
131
-
return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
172
172
+
helpers.InputError(w, to.StringPtr("code_challenge parameter wasn't provided"))
173
173
+
return
132
174
}
133
175
134
176
repo, err := s.getRepoActorByDid(ctx, *authReq.Sub)
135
177
if err != nil {
136
136
-
helpers.InputError(e, to.StringPtr("unable to find actor"))
178
178
+
helpers.InputError(w, to.StringPtr("unable to find actor"))
179
179
+
return
137
180
}
138
181
139
182
now := time.Now()
···
159
202
accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
160
203
accessString, err := accessToken.SignedString(s.privateKey)
161
204
if err != nil {
162
162
-
return err
205
205
+
helpers.ServerError(w, nil)
206
206
+
return
163
207
}
164
208
165
209
if err := s.db.Create(ctx, &provider.OauthToken{
···
175
219
Ip: authReq.Ip,
176
220
}, nil).Error; err != nil {
177
221
logger.Error("error creating token in db", "error", err)
178
178
-
return helpers.ServerError(e, nil)
222
222
+
helpers.ServerError(w, nil)
223
223
+
return
179
224
}
180
225
181
226
// prob not needed
···
184
229
tokenType = "DPoP"
185
230
}
186
231
187
187
-
e.Response().Header().Set("content-type", "application/json")
188
188
-
189
189
-
return e.JSON(200, OauthTokenResponse{
232
232
+
s.writeJSON(w, 200, OauthTokenResponse{
190
233
AccessToken: accessString,
191
234
RefreshToken: refreshToken,
192
235
TokenType: tokenType,
···
194
237
ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
195
238
Sub: repo.Repo.Did,
196
239
})
240
240
+
return
197
241
}
198
242
199
243
if req.GrantType == "refresh_token" {
200
244
if req.RefreshToken == nil {
201
201
-
return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
245
245
+
helpers.InputError(w, to.StringPtr(`"refresh_token" is required`))
246
246
+
return
202
247
}
203
248
204
249
var oauthToken provider.OauthToken
205
250
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
206
251
logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
207
207
-
return helpers.ServerError(e, nil)
252
252
+
helpers.ServerError(w, nil)
253
253
+
return
208
254
}
209
255
210
256
if client.Metadata.ClientID != oauthToken.ClientId {
211
211
-
return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
257
257
+
helpers.InputError(w, to.StringPtr(`"client_id" mismatch`))
258
258
+
return
212
259
}
213
260
214
261
if clientAuth.Method != oauthToken.ClientAuth.Method {
215
215
-
return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
262
262
+
helpers.InputError(w, to.StringPtr(`"client authentication method mismatch`))
263
263
+
return
216
264
}
217
265
218
266
if *oauthToken.Parameters.DpopJkt != proof.JKT {
219
219
-
return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
267
267
+
helpers.InputError(w, to.StringPtr("dpop proof does not match expected jkt"))
268
268
+
return
220
269
}
221
270
222
271
ageRes := oauth.GetSessionAgeFromToken(oauthToken)
223
272
224
273
if ageRes.SessionExpired {
225
225
-
return helpers.InputError(e, to.StringPtr("Session expired"))
274
274
+
helpers.InputError(w, to.StringPtr("Session expired"))
275
275
+
return
226
276
}
227
277
228
278
if ageRes.RefreshExpired {
229
229
-
return helpers.InputError(e, to.StringPtr("Refresh token expired"))
279
279
+
helpers.InputError(w, to.StringPtr("Refresh token expired"))
280
280
+
return
230
281
}
231
282
232
283
if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
233
284
// why? ref impl
234
234
-
return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
285
285
+
helpers.InputError(w, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
286
286
+
return
235
287
}
236
288
237
289
nextTokenId := oauth.GenerateTokenId()
···
251
303
}
252
304
253
305
if oauthToken.Parameters.DpopJkt != nil {
254
254
-
accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
306
306
+
accessClaims["cnf"] = oauthToken.Parameters.DpopJkt
255
307
}
256
308
257
309
accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
258
310
accessString, err := accessToken.SignedString(s.privateKey)
259
311
if err != nil {
260
260
-
return err
312
312
+
helpers.ServerError(w, nil)
313
313
+
return
261
314
}
262
315
263
316
if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
264
317
logger.Error("error updating token", "error", err)
265
265
-
return helpers.ServerError(e, nil)
318
318
+
helpers.ServerError(w, nil)
319
319
+
return
266
320
}
267
321
268
322
// prob not needed
···
271
325
tokenType = "DPoP"
272
326
}
273
327
274
274
-
return e.JSON(200, OauthTokenResponse{
328
328
+
s.writeJSON(w, 200, OauthTokenResponse{
275
329
AccessToken: accessString,
276
330
RefreshToken: nextRefreshToken,
277
331
TokenType: tokenType,
···
279
333
ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
280
334
Sub: oauthToken.Sub,
281
335
})
336
336
+
return
282
337
}
283
338
284
284
-
return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
339
339
+
helpers.InputError(w, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
285
340
}
+36
-27
server/handle_proxy.go
···
6
6
"encoding/base64"
7
7
"encoding/json"
8
8
"fmt"
9
9
+
"io"
9
10
"net/http"
10
11
"strings"
11
12
"time"
···
13
14
"github.com/google/uuid"
14
15
"github.com/haileyok/cocoon/internal/helpers"
15
16
"github.com/haileyok/cocoon/models"
16
16
-
"github.com/labstack/echo/v4"
17
17
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18
18
)
19
19
20
20
-
func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
21
21
-
svc := e.Request().Header.Get("atproto-proxy")
20
20
+
func (s *Server) getAtprotoProxyEndpointFromRequest(r *http.Request) (string, string, error) {
21
21
+
svc := r.Header.Get("atproto-proxy")
22
22
if svc == "" && s.config.FallbackProxy != "" {
23
23
svc = s.config.FallbackProxy
24
24
}
···
31
31
svcDid := svcPts[0]
32
32
svcId := "#" + svcPts[1]
33
33
34
34
-
doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
34
34
+
doc, err := s.passport.FetchDoc(r.Context(), svcDid)
35
35
if err != nil {
36
36
return "", "", err
37
37
}
···
46
46
return endpoint, svcDid, nil
47
47
}
48
48
49
49
-
func (s *Server) handleProxy(e echo.Context) error {
49
49
+
func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) {
50
50
logger := s.logger.With("handler", "handleProxy")
51
51
52
52
-
repo, isAuthed := e.Get("repo").(*models.RepoActor)
52
52
+
repo, isAuthed := getContextValue[*models.RepoActor](r, contextKeyRepo)
53
53
54
54
-
pts := strings.Split(e.Request().URL.Path, "/")
54
54
+
pts := strings.Split(r.URL.Path, "/")
55
55
if len(pts) != 3 {
56
56
-
return fmt.Errorf("incorrect number of parts")
56
56
+
helpers.ServerError(w, nil)
57
57
+
return
57
58
}
58
59
59
59
-
endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
60
60
+
endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(r)
60
61
if err != nil {
61
62
logger.Error("could not get atproto proxy", "error", err)
62
62
-
return helpers.ServerError(e, nil)
63
63
+
helpers.ServerError(w, nil)
64
64
+
return
63
65
}
64
66
65
65
-
requrl := e.Request().URL
67
67
+
requrl := *r.URL
66
68
requrl.Host = strings.TrimPrefix(endpoint, "https://")
67
69
requrl.Scheme = "https"
68
70
69
69
-
body := e.Request().Body
70
70
-
if e.Request().Method == "GET" {
71
71
-
body = nil
71
71
+
var body io.Reader
72
72
+
if r.Method != http.MethodGet {
73
73
+
body = r.Body
72
74
}
73
75
74
74
-
req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
76
76
+
req, err := http.NewRequest(r.Method, requrl.String(), body)
75
77
if err != nil {
76
76
-
return err
78
78
+
helpers.ServerError(w, nil)
79
79
+
return
77
80
}
78
81
79
79
-
req.Header = e.Request().Header.Clone()
82
82
+
req.Header = r.Header.Clone()
80
83
81
84
if isAuthed {
82
85
// this is a little dumb. i should probably figure out a better way to do this, and use
···
91
94
hj, err := json.Marshal(header)
92
95
if err != nil {
93
96
logger.Error("error marshaling header", "error", err)
94
94
-
return helpers.ServerError(e, nil)
97
97
+
helpers.ServerError(w, nil)
98
98
+
return
95
99
}
96
100
97
101
encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
···
100
104
// underlying feed generator and the app view passes it on. This allows the
101
105
// getFeed implementation to pass in the desired lxm and aud for the token
102
106
// and then just delegate to the general proxying logic
103
103
-
lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string)
107
107
+
lxm, proxyTokenLxmExists := getContextValue[string](r, contextKeyProxyTokenLxm)
104
108
if !proxyTokenLxmExists || lxm == "" {
105
109
lxm = pts[2]
106
110
}
107
107
-
aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string)
111
111
+
aud, proxyTokenAudExists := getContextValue[string](r, contextKeyProxyTokenAud)
108
112
if !proxyTokenAudExists || aud == "" {
109
113
aud = svcDid
110
114
}
···
118
122
}
119
123
pj, err := json.Marshal(payload)
120
124
if err != nil {
121
121
-
logger.Error("error marashaling payload", "error", err)
122
122
-
return helpers.ServerError(e, nil)
125
125
+
logger.Error("error marshaling payload", "error", err)
126
126
+
helpers.ServerError(w, nil)
127
127
+
return
123
128
}
124
129
125
130
encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
···
130
135
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
131
136
if err != nil {
132
137
logger.Error("can't load private key", "error", err)
133
133
-
return err
138
138
+
helpers.ServerError(w, nil)
139
139
+
return
134
140
}
135
141
136
142
R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
137
143
if err != nil {
138
144
logger.Error("error signing", "error", err)
145
145
+
helpers.ServerError(w, nil)
146
146
+
return
139
147
}
140
148
141
149
rBytes := R.Bytes()
···
157
165
158
166
resp, err := http.DefaultClient.Do(req)
159
167
if err != nil {
160
160
-
return err
168
168
+
helpers.ServerError(w, nil)
169
169
+
return
161
170
}
162
171
defer resp.Body.Close()
163
172
164
173
for k, v := range resp.Header {
165
165
-
e.Response().Header().Set(k, strings.Join(v, ","))
174
174
+
w.Header().Set(k, strings.Join(v, ","))
166
175
}
167
167
-
168
168
-
return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
176
176
+
w.WriteHeader(resp.StatusCode)
177
177
+
io.Copy(w, resp.Body)
169
178
}
+21
-11
server/handle_proxy_get_feed.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"context"
5
5
+
"net/http"
6
6
+
4
7
"github.com/Azure/go-autorest/autorest/to"
5
8
"github.com/bluesky-social/indigo/api/atproto"
6
9
"github.com/bluesky-social/indigo/api/bsky"
7
10
"github.com/bluesky-social/indigo/atproto/syntax"
8
11
"github.com/bluesky-social/indigo/xrpc"
9
12
"github.com/haileyok/cocoon/internal/helpers"
10
10
-
"github.com/labstack/echo/v4"
11
13
)
12
14
13
13
-
func (s *Server) handleProxyBskyFeedGetFeed(e echo.Context) error {
14
14
-
feedUri, err := syntax.ParseATURI(e.QueryParam("feed"))
15
15
+
func (s *Server) handleProxyBskyFeedGetFeed(w http.ResponseWriter, r *http.Request) {
16
16
+
feedUri, err := syntax.ParseATURI(r.URL.Query().Get("feed"))
15
17
if err != nil {
16
16
-
return helpers.InputError(e, to.StringPtr("invalid feed uri"))
18
18
+
helpers.InputError(w, to.StringPtr("invalid feed uri"))
19
19
+
return
17
20
}
18
21
19
19
-
appViewEndpoint, _, err := s.getAtprotoProxyEndpointFromRequest(e)
22
22
+
appViewEndpoint, _, err := s.getAtprotoProxyEndpointFromRequest(r)
20
23
if err != nil {
21
21
-
e.Logger().Error("could not get atproto proxy", "error", err)
22
22
-
return helpers.ServerError(e, nil)
24
24
+
s.logger.Error("could not get atproto proxy", "error", err)
25
25
+
helpers.ServerError(w, nil)
26
26
+
return
23
27
}
24
28
25
29
appViewClient := xrpc.Client{
26
30
Host: appViewEndpoint,
27
31
}
28
28
-
feedRecord, err := atproto.RepoGetRecord(e.Request().Context(), &appViewClient, "", feedUri.Collection().String(), feedUri.Authority().String(), feedUri.RecordKey().String())
32
32
+
feedRecord, err := atproto.RepoGetRecord(r.Context(), &appViewClient, "", feedUri.Collection().String(), feedUri.Authority().String(), feedUri.RecordKey().String())
33
33
+
if err != nil {
34
34
+
s.logger.Error("could not get feed record", "error", err)
35
35
+
helpers.ServerError(w, nil)
36
36
+
return
37
37
+
}
29
38
feedGeneratorDid := feedRecord.Value.Val.(*bsky.FeedGenerator).Did
30
39
31
31
-
e.Set("proxyTokenLxm", "app.bsky.feed.getFeedSkeleton")
32
32
-
e.Set("proxyTokenAud", feedGeneratorDid)
40
40
+
// Inject proxy token overrides into the request context so handleProxy can read them.
41
41
+
ctx := context.WithValue(r.Context(), contextKeyProxyTokenLxm, "app.bsky.feed.getFeedSkeleton")
42
42
+
ctx = context.WithValue(ctx, contextKeyProxyTokenAud, feedGeneratorDid)
33
43
34
34
-
return s.handleProxy(e)
44
44
+
s.handleProxy(w, r.WithContext(ctx))
35
45
}
+17
-11
server/handle_repo_apply_writes.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/haileyok/cocoon/internal/helpers"
5
8
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
9
)
8
10
9
11
type ComAtprotoRepoApplyWritesInput struct {
···
25
27
Results []ApplyWriteResult `json:"results"`
26
28
}
27
29
28
28
-
func (s *Server) handleApplyWrites(e echo.Context) error {
29
29
-
ctx := e.Request().Context()
30
30
+
func (s *Server) handleApplyWrites(w http.ResponseWriter, r *http.Request) {
31
31
+
ctx := r.Context()
30
32
logger := s.logger.With("name", "handleRepoApplyWrites")
31
33
32
34
var req ComAtprotoRepoApplyWritesInput
33
33
-
if err := e.Bind(&req); err != nil {
35
35
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
34
36
logger.Error("error binding", "error", err)
35
35
-
return helpers.ServerError(e, nil)
37
37
+
helpers.ServerError(w, nil)
38
38
+
return
36
39
}
37
40
38
38
-
if err := e.Validate(req); err != nil {
41
41
+
if err := s.validator.Struct(req); err != nil {
39
42
logger.Error("error validating", "error", err)
40
40
-
return helpers.InputError(e, nil)
43
43
+
helpers.InputError(w, nil)
44
44
+
return
41
45
}
42
46
43
43
-
repo := e.Get("repo").(*models.RepoActor)
47
47
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
44
48
45
49
if repo.Repo.Did != req.Repo {
46
50
logger.Warn("mismatched repo/auth")
47
47
-
return helpers.InputError(e, nil)
51
51
+
helpers.InputError(w, nil)
52
52
+
return
48
53
}
49
54
50
55
ops := make([]Op, 0, len(req.Writes))
···
60
65
results, err := s.repoman.applyWrites(ctx, repo.Repo, ops, req.SwapCommit)
61
66
if err != nil {
62
67
logger.Error("error applying writes", "error", err)
63
63
-
return helpers.ServerError(e, nil)
68
68
+
helpers.ServerError(w, nil)
69
69
+
return
64
70
}
65
71
66
72
commit := *results[0].Commit
···
69
75
results[i].Commit = nil
70
76
}
71
77
72
72
-
return e.JSON(200, ComAtprotoRepoApplyWritesOutput{
78
78
+
s.writeJSON(w, http.StatusOK, ComAtprotoRepoApplyWritesOutput{
73
79
Commit: commit,
74
80
Results: results,
75
81
})
+18
-12
server/handle_repo_create_record.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/haileyok/cocoon/internal/helpers"
5
8
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
9
)
8
10
9
11
type ComAtprotoRepoCreateRecordInput struct {
···
16
18
SwapCommit *string `json:"swapCommit"`
17
19
}
18
20
19
19
-
func (s *Server) handleCreateRecord(e echo.Context) error {
20
20
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handleCreateRecord(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
21
23
logger := s.logger.With("name", "handleCreateRecord")
22
24
23
23
-
repo := e.Get("repo").(*models.RepoActor)
25
25
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
24
26
25
27
var req ComAtprotoRepoCreateRecordInput
26
26
-
if err := e.Bind(&req); err != nil {
27
27
-
logger.Error("error binding", "error", err)
28
28
-
return helpers.ServerError(e, nil)
28
28
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
29
29
+
logger.Error("error decoding", "error", err)
30
30
+
helpers.ServerError(w, nil)
31
31
+
return
29
32
}
30
33
31
31
-
if err := e.Validate(req); err != nil {
34
34
+
if err := s.validator.Struct(req); err != nil {
32
35
logger.Error("error validating", "error", err)
33
33
-
return helpers.InputError(e, nil)
36
36
+
helpers.InputError(w, nil)
37
37
+
return
34
38
}
35
39
36
40
if repo.Repo.Did != req.Repo {
37
41
logger.Warn("mismatched repo/auth")
38
38
-
return helpers.InputError(e, nil)
42
42
+
helpers.InputError(w, nil)
43
43
+
return
39
44
}
40
45
41
46
optype := OpTypeCreate
···
55
60
}, req.SwapCommit)
56
61
if err != nil {
57
62
logger.Error("error applying writes", "error", err)
58
58
-
return helpers.ServerError(e, nil)
63
63
+
helpers.ServerError(w, nil)
64
64
+
return
59
65
}
60
66
61
67
results[0].Type = nil
62
68
63
63
-
return e.JSON(200, results[0])
69
69
+
s.writeJSON(w, 200, results[0])
64
70
}
+18
-12
server/handle_repo_delete_record.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/haileyok/cocoon/internal/helpers"
5
8
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
9
)
8
10
9
11
type ComAtprotoRepoDeleteRecordInput struct {
···
14
16
SwapCommit *string `json:"swapCommit"`
15
17
}
16
18
17
17
-
func (s *Server) handleDeleteRecord(e echo.Context) error {
18
18
-
ctx := e.Request().Context()
19
19
+
func (s *Server) handleDeleteRecord(w http.ResponseWriter, r *http.Request) {
20
20
+
ctx := r.Context()
19
21
logger := s.logger.With("name", "handleDeleteRecord")
20
22
21
21
-
repo := e.Get("repo").(*models.RepoActor)
23
23
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
22
24
23
25
var req ComAtprotoRepoDeleteRecordInput
24
24
-
if err := e.Bind(&req); err != nil {
25
25
-
logger.Error("error binding", "error", err)
26
26
-
return helpers.ServerError(e, nil)
26
26
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
27
27
+
logger.Error("error decoding", "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
27
30
}
28
31
29
29
-
if err := e.Validate(req); err != nil {
32
32
+
if err := s.validator.Struct(req); err != nil {
30
33
logger.Error("error validating", "error", err)
31
31
-
return helpers.InputError(e, nil)
34
34
+
helpers.InputError(w, nil)
35
35
+
return
32
36
}
33
37
34
38
if repo.Repo.Did != req.Repo {
35
39
logger.Warn("mismatched repo/auth")
36
36
-
return helpers.InputError(e, nil)
40
40
+
helpers.InputError(w, nil)
41
41
+
return
37
42
}
38
43
39
44
results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{
···
46
51
}, req.SwapCommit)
47
52
if err != nil {
48
53
logger.Error("error applying writes", "error", err)
49
49
-
return helpers.ServerError(e, nil)
54
54
+
helpers.ServerError(w, nil)
55
55
+
return
50
56
}
51
57
52
58
results[0].Type = nil
···
54
60
results[0].Cid = nil
55
61
results[0].ValidationStatus = nil
56
62
57
57
-
return e.JSON(200, results[0])
63
63
+
s.writeJSON(w, 200, results[0])
58
64
}
+21
-16
server/handle_repo_describe_repo.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
4
5
"strings"
5
6
6
7
"github.com/Azure/go-autorest/autorest/to"
7
8
"github.com/haileyok/cocoon/identity"
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
10
"github.com/haileyok/cocoon/models"
10
10
-
"github.com/labstack/echo/v4"
11
11
"gorm.io/gorm"
12
12
)
13
13
···
19
19
HandleIsCorrect bool `json:"handleIsCorrect"`
20
20
}
21
21
22
22
-
func (s *Server) handleDescribeRepo(e echo.Context) error {
23
23
-
ctx := e.Request().Context()
22
22
+
func (s *Server) handleDescribeRepo(w http.ResponseWriter, r *http.Request) {
23
23
+
ctx := r.Context()
24
24
logger := s.logger.With("name", "handleDescribeRepo")
25
25
26
26
-
did := e.QueryParam("repo")
26
26
+
did := r.URL.Query().Get("repo")
27
27
repo, err := s.getRepoActorByDid(ctx, did)
28
28
if err != nil {
29
29
if err == gorm.ErrRecordNotFound {
30
30
-
return helpers.InputError(e, to.StringPtr("RepoNotFound"))
30
30
+
helpers.InputError(w, to.StringPtr("RepoNotFound"))
31
31
+
return
31
32
}
32
33
33
34
logger.Error("error looking up repo", "error", err)
34
34
-
return helpers.ServerError(e, nil)
35
35
+
helpers.ServerError(w, nil)
36
36
+
return
35
37
}
36
38
37
39
handleIsCorrect := true
38
40
39
39
-
diddoc, err := s.passport.FetchDoc(e.Request().Context(), repo.Repo.Did)
41
41
+
diddoc, err := s.passport.FetchDoc(r.Context(), repo.Repo.Did)
40
42
if err != nil {
41
43
logger.Error("error fetching diddoc", "error", err)
42
42
-
return helpers.ServerError(e, nil)
44
44
+
helpers.ServerError(w, nil)
45
45
+
return
43
46
}
44
47
45
48
dochandle := ""
···
55
58
}
56
59
57
60
if handleIsCorrect {
58
58
-
resolvedDid, err := s.passport.ResolveHandle(e.Request().Context(), repo.Handle)
61
61
+
resolvedDid, err := s.passport.ResolveHandle(r.Context(), repo.Handle)
59
62
if err != nil {
60
60
-
e.Logger().Error("error resolving handle", "error", err)
61
61
-
return helpers.ServerError(e, nil)
63
63
+
logger.Error("error resolving handle", "error", err)
64
64
+
helpers.ServerError(w, nil)
65
65
+
return
62
66
}
63
67
64
68
if resolvedDid != repo.Repo.Did {
···
69
73
var records []models.Record
70
74
if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
71
75
logger.Error("error getting collections", "error", err)
72
72
-
return helpers.ServerError(e, nil)
76
76
+
helpers.ServerError(w, nil)
77
77
+
return
73
78
}
74
79
75
75
-
var collections []string = make([]string, 0, len(records))
76
76
-
for _, r := range records {
77
77
-
collections = append(collections, r.Nsid)
80
80
+
collections := make([]string, 0, len(records))
81
81
+
for _, rec := range records {
82
82
+
collections = append(collections, rec.Nsid)
78
83
}
79
84
80
80
-
return e.JSON(200, ComAtprotoRepoDescribeRepoResponse{
85
85
+
s.writeJSON(w, 200, ComAtprotoRepoDescribeRepoResponse{
81
86
Did: repo.Repo.Did,
82
87
Handle: repo.Handle,
83
88
DidDoc: *diddoc,
+16
-12
server/handle_repo_get_record.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/bluesky-social/indigo/atproto/atdata"
5
7
"github.com/bluesky-social/indigo/atproto/syntax"
6
8
"github.com/haileyok/cocoon/models"
7
7
-
"github.com/labstack/echo/v4"
8
9
)
9
10
10
11
type ComAtprotoRepoGetRecordResponse struct {
···
13
14
Value map[string]any `json:"value"`
14
15
}
15
16
16
16
-
func (s *Server) handleRepoGetRecord(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
17
17
+
func (s *Server) handleRepoGetRecord(w http.ResponseWriter, r *http.Request) {
18
18
+
ctx := r.Context()
18
19
19
19
-
repo := e.QueryParam("repo")
20
20
-
collection := e.QueryParam("collection")
21
21
-
rkey := e.QueryParam("rkey")
22
22
-
cidstr := e.QueryParam("cid")
20
20
+
repo := r.URL.Query().Get("repo")
21
21
+
collection := r.URL.Query().Get("collection")
22
22
+
rkey := r.URL.Query().Get("rkey")
23
23
+
cidstr := r.URL.Query().Get("cid")
23
24
24
25
params := []any{repo, collection, rkey}
25
26
cidquery := ""
···
27
28
if cidstr != "" {
28
29
c, err := syntax.ParseCID(cidstr)
29
30
if err != nil {
30
30
-
return err
31
31
+
http.Error(w, err.Error(), http.StatusBadRequest)
32
32
+
return
31
33
}
32
34
params = append(params, c.String())
33
35
cidquery = " AND cid = ?"
···
35
37
36
38
var record models.Record
37
39
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
38
38
-
// TODO: handle error nicely
39
39
-
return err
40
40
+
http.Error(w, err.Error(), http.StatusInternalServerError)
41
41
+
return
40
42
}
41
43
42
44
val, err := atdata.UnmarshalCBOR(record.Value)
43
45
if err != nil {
44
44
-
return s.handleProxy(e) // TODO: this should be getting handled like...if we don't find it in the db. why doesn't it throw error up there?
46
46
+
// Fall back to proxy if we can't find/decode the record locally
47
47
+
s.handleProxy(w, r)
48
48
+
return
45
49
}
46
50
47
47
-
return e.JSON(200, ComAtprotoRepoGetRecordResponse{
51
51
+
s.writeJSON(w, 200, ComAtprotoRepoGetRecordResponse{
48
52
Uri: "at://" + record.Did + "/" + record.Nsid + "/" + record.Rkey,
49
53
Cid: record.Cid,
50
54
Value: val,
+10
-9
server/handle_repo_list_missing_blobs.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"strconv"
6
7
7
8
"github.com/bluesky-social/indigo/atproto/atdata"
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
10
"github.com/haileyok/cocoon/models"
10
11
"github.com/ipfs/go-cid"
11
11
-
"github.com/labstack/echo/v4"
12
12
)
13
13
14
14
type ComAtprotoRepoListMissingBlobsResponse struct {
···
21
21
RecordUri string `json:"recordUri"`
22
22
}
23
23
24
24
-
func (s *Server) handleListMissingBlobs(e echo.Context) error {
25
25
-
ctx := e.Request().Context()
26
26
-
logger := s.logger.With("name", "handleListMissingBlos")
24
24
+
func (s *Server) handleListMissingBlobs(w http.ResponseWriter, r *http.Request) {
25
25
+
ctx := r.Context()
26
26
+
logger := s.logger.With("name", "handleListMissingBlobs")
27
27
28
28
-
urepo := e.Get("repo").(*models.RepoActor)
28
28
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
29
29
30
30
-
limitStr := e.QueryParam("limit")
31
31
-
cursor := e.QueryParam("cursor")
30
30
+
limitStr := r.URL.Query().Get("limit")
31
31
+
cursor := r.URL.Query().Get("cursor")
32
32
33
33
limit := 500
34
34
if limitStr != "" {
···
40
40
var records []models.Record
41
41
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
42
42
logger.Error("failed to get records for listMissingBlobs", "error", err)
43
43
-
return helpers.ServerError(e, nil)
43
43
+
helpers.ServerError(w, nil)
44
44
+
return
44
45
}
45
46
46
47
type blobRef struct {
···
95
96
nextCursor = &lastCid
96
97
}
97
98
98
98
-
return e.JSON(200, ComAtprotoRepoListMissingBlobsResponse{
99
99
+
s.writeJSON(w, http.StatusOK, ComAtprotoRepoListMissingBlobsResponse{
99
100
Cursor: nextCursor,
100
101
Blobs: missingBlobs,
101
102
})
+33
-26
server/handle_repo_list_records.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
4
5
"strconv"
5
6
6
7
"github.com/Azure/go-autorest/autorest/to"
···
8
9
"github.com/bluesky-social/indigo/atproto/syntax"
9
10
"github.com/haileyok/cocoon/internal/helpers"
10
11
"github.com/haileyok/cocoon/models"
11
11
-
"github.com/labstack/echo/v4"
12
12
)
13
13
14
14
type ComAtprotoRepoListRecordsRequest struct {
···
30
30
Value map[string]any `json:"value"`
31
31
}
32
32
33
33
-
func getLimitFromContext(e echo.Context, def int) (int, error) {
33
33
+
func getLimitFromRequest(r *http.Request, def int) (int, error) {
34
34
limit := def
35
35
-
limitstr := e.QueryParam("limit")
35
35
+
limitstr := r.URL.Query().Get("limit")
36
36
37
37
if limitstr != "" {
38
38
l64, err := strconv.ParseInt(limitstr, 10, 32)
···
45
45
return limit, nil
46
46
}
47
47
48
48
-
func (s *Server) handleListRecords(e echo.Context) error {
49
49
-
ctx := e.Request().Context()
48
48
+
func (s *Server) handleListRecords(w http.ResponseWriter, r *http.Request) {
49
49
+
ctx := r.Context()
50
50
logger := s.logger.With("name", "handleListRecords")
51
51
52
52
-
var req ComAtprotoRepoListRecordsRequest
53
53
-
if err := e.Bind(&req); err != nil {
54
54
-
logger.Error("could not bind list records request", "error", err)
55
55
-
return helpers.ServerError(e, nil)
52
52
+
req := ComAtprotoRepoListRecordsRequest{
53
53
+
Repo: r.URL.Query().Get("repo"),
54
54
+
Collection: r.URL.Query().Get("collection"),
55
55
+
Cursor: r.URL.Query().Get("cursor"),
56
56
}
57
57
-
58
58
-
if err := e.Validate(req); err != nil {
59
59
-
return helpers.InputError(e, nil)
57
57
+
if v := r.URL.Query().Get("reverse"); v == "true" {
58
58
+
req.Reverse = true
60
59
}
61
60
62
62
-
if req.Limit <= 0 {
63
63
-
req.Limit = 50
64
64
-
} else if req.Limit > 100 {
65
65
-
req.Limit = 100
61
61
+
if err := s.validator.Struct(req); err != nil {
62
62
+
helpers.InputError(w, nil)
63
63
+
return
66
64
}
67
65
68
68
-
limit, err := getLimitFromContext(e, 50)
66
66
+
limit, err := getLimitFromRequest(r, 50)
69
67
if err != nil {
70
70
-
return helpers.InputError(e, nil)
68
68
+
helpers.InputError(w, nil)
69
69
+
return
70
70
+
}
71
71
+
if limit <= 0 {
72
72
+
limit = 50
73
73
+
} else if limit > 100 {
74
74
+
limit = 100
71
75
}
72
76
73
77
sort := "DESC"
···
83
87
if _, err := syntax.ParseDID(did); err != nil {
84
88
actor, err := s.getActorByHandle(ctx, req.Repo)
85
89
if err != nil {
86
86
-
return helpers.InputError(e, to.StringPtr("RepoNotFound"))
90
90
+
helpers.InputError(w, to.StringPtr("RepoNotFound"))
91
91
+
return
87
92
}
88
93
did = actor.Did
89
94
}
···
98
103
var records []models.Record
99
104
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
100
105
logger.Error("error getting records", "error", err)
101
101
-
return helpers.ServerError(e, nil)
106
106
+
helpers.ServerError(w, nil)
107
107
+
return
102
108
}
103
109
104
110
items := []ComAtprotoRepoListRecordsRecordItem{}
105
105
-
for _, r := range records {
106
106
-
val, err := atdata.UnmarshalCBOR(r.Value)
111
111
+
for _, rec := range records {
112
112
+
val, err := atdata.UnmarshalCBOR(rec.Value)
107
113
if err != nil {
108
108
-
return err
114
114
+
helpers.ServerError(w, nil)
115
115
+
return
109
116
}
110
117
111
118
items = append(items, ComAtprotoRepoListRecordsRecordItem{
112
112
-
Uri: "at://" + r.Did + "/" + r.Nsid + "/" + r.Rkey,
113
113
-
Cid: r.Cid,
119
119
+
Uri: "at://" + rec.Did + "/" + rec.Nsid + "/" + rec.Rkey,
120
120
+
Cid: rec.Cid,
114
121
Value: val,
115
122
})
116
123
}
···
120
127
newcursor = to.StringPtr(records[len(records)-1].CreatedAt)
121
128
}
122
129
123
123
-
return e.JSON(200, ComAtprotoRepoListRecordsResponse{
130
130
+
s.writeJSON(w, 200, ComAtprotoRepoListRecordsResponse{
124
131
Cursor: newcursor,
125
132
Records: items,
126
133
})
+15
-12
server/handle_repo_list_repos.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/models"
5
7
"github.com/ipfs/go-cid"
6
6
-
"github.com/labstack/echo/v4"
7
8
)
8
9
9
10
type ComAtprotoSyncListReposResponse struct {
···
20
21
}
21
22
22
23
// TODO: paginate this bitch
23
23
-
func (s *Server) handleListRepos(e echo.Context) error {
24
24
-
ctx := e.Request().Context()
24
24
+
func (s *Server) handleListRepos(w http.ResponseWriter, r *http.Request) {
25
25
+
ctx := r.Context()
25
26
26
27
var repos []models.Repo
27
28
if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
28
28
-
return err
29
29
+
http.Error(w, err.Error(), http.StatusInternalServerError)
30
30
+
return
29
31
}
30
32
31
33
items := make([]ComAtprotoSyncListReposRepoItem, 0, len(repos))
32
32
-
for _, r := range repos {
33
33
-
c, err := cid.Cast(r.Root)
34
34
+
for _, repo := range repos {
35
35
+
c, err := cid.Cast(repo.Root)
34
36
if err != nil {
35
35
-
return err
37
37
+
http.Error(w, err.Error(), http.StatusInternalServerError)
38
38
+
return
36
39
}
37
40
38
41
items = append(items, ComAtprotoSyncListReposRepoItem{
39
39
-
Did: r.Did,
42
42
+
Did: repo.Did,
40
43
Head: c.String(),
41
41
-
Rev: r.Rev,
42
42
-
Active: r.Active(),
43
43
-
Status: r.Status(),
44
44
+
Rev: repo.Rev,
45
45
+
Active: repo.Active(),
46
46
+
Status: repo.Status(),
44
47
})
45
48
}
46
49
47
47
-
return e.JSON(200, ComAtprotoSyncListReposResponse{
50
50
+
s.writeJSON(w, 200, ComAtprotoSyncListReposResponse{
48
51
Cursor: nil,
49
52
Repos: items,
50
53
})
+18
-12
server/handle_repo_put_record.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/haileyok/cocoon/internal/helpers"
5
8
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
9
)
8
10
9
11
type ComAtprotoRepoPutRecordInput struct {
···
16
18
SwapCommit *string `json:"swapCommit"`
17
19
}
18
20
19
19
-
func (s *Server) handlePutRecord(e echo.Context) error {
20
20
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handlePutRecord(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
21
23
logger := s.logger.With("name", "handlePutRecord")
22
24
23
23
-
repo := e.Get("repo").(*models.RepoActor)
25
25
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
24
26
25
27
var req ComAtprotoRepoPutRecordInput
26
26
-
if err := e.Bind(&req); err != nil {
27
27
-
logger.Error("error binding", "error", err)
28
28
-
return helpers.ServerError(e, nil)
28
28
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
29
29
+
logger.Error("error decoding", "error", err)
30
30
+
helpers.ServerError(w, nil)
31
31
+
return
29
32
}
30
33
31
31
-
if err := e.Validate(req); err != nil {
34
34
+
if err := s.validator.Struct(req); err != nil {
32
35
logger.Error("error validating", "error", err)
33
33
-
return helpers.InputError(e, nil)
36
36
+
helpers.InputError(w, nil)
37
37
+
return
34
38
}
35
39
36
40
if repo.Repo.Did != req.Repo {
37
41
logger.Warn("mismatched repo/auth")
38
38
-
return helpers.InputError(e, nil)
42
42
+
helpers.InputError(w, nil)
43
43
+
return
39
44
}
40
45
41
46
optype := OpTypeCreate
···
55
60
}, req.SwapCommit)
56
61
if err != nil {
57
62
logger.Error("error applying writes", "error", err)
58
58
-
return helpers.ServerError(e, nil)
63
63
+
helpers.ServerError(w, nil)
64
64
+
return
59
65
}
60
66
61
67
results[0].Type = nil
62
68
63
63
-
return e.JSON(200, results[0])
69
69
+
s.writeJSON(w, 200, results[0])
64
70
}
+18
-13
server/handle_repo_upload_blob.go
···
10
10
"github.com/haileyok/cocoon/internal/helpers"
11
11
"github.com/haileyok/cocoon/models"
12
12
"github.com/ipfs/go-cid"
13
13
-
"github.com/labstack/echo/v4"
14
13
"github.com/multiformats/go-multihash"
15
14
)
16
15
···
29
28
} `json:"blob"`
30
29
}
31
30
32
32
-
func (s *Server) handleRepoUploadBlob(e echo.Context) error {
33
33
-
ctx := e.Request().Context()
31
31
+
func (s *Server) handleRepoUploadBlob(w http.ResponseWriter, r *http.Request) {
32
32
+
ctx := r.Context()
34
33
logger := s.logger.With("name", "handleRepoUploadBlob")
35
34
36
36
-
urepo := e.Get("repo").(*models.RepoActor)
35
35
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
37
36
38
38
-
mime := e.Request().Header.Get("content-type")
37
37
+
mime := r.Header.Get("content-type")
39
38
if mime == "" {
40
39
mime = "application/octet-stream"
41
40
}
···
55
54
56
55
if err := s.db.Create(ctx, &blob, nil).Error; err != nil {
57
56
logger.Error("error creating new blob in db", "error", err)
58
58
-
return helpers.ServerError(e, nil)
57
57
+
helpers.ServerError(w, nil)
58
58
+
return
59
59
}
60
60
61
61
read := 0
···
65
65
fulldata := new(bytes.Buffer)
66
66
67
67
for {
68
68
-
n, err := io.ReadFull(e.Request().Body, buf)
68
68
+
n, err := io.ReadFull(r.Body, buf)
69
69
if err == io.ErrUnexpectedEOF || err == io.EOF {
70
70
if n == 0 {
71
71
break
72
72
}
73
73
} else if err != nil && err != io.ErrUnexpectedEOF {
74
74
logger.Error("error reading blob", "error", err)
75
75
-
return helpers.ServerError(e, nil)
75
75
+
helpers.ServerError(w, nil)
76
76
+
return
76
77
}
77
78
78
79
data := buf[:n]
···
88
89
89
90
if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil {
90
91
logger.Error("error adding blob part to db", "error", err)
91
91
-
return helpers.ServerError(e, nil)
92
92
+
helpers.ServerError(w, nil)
93
93
+
return
92
94
}
93
95
}
94
96
part++
···
101
103
c, err := cid.NewPrefixV1(cid.Raw, multihash.SHA2_256).Sum(fulldata.Bytes())
102
104
if err != nil {
103
105
logger.Error("error creating cid prefix", "error", err)
104
104
-
return helpers.ServerError(e, nil)
106
106
+
helpers.ServerError(w, nil)
107
107
+
return
105
108
}
106
109
107
110
if ipfsUpload {
108
111
ipfsCid, err := s.addBlobToIPFS(fulldata.Bytes(), mime)
109
112
if err != nil {
110
113
logger.Error("error adding blob to ipfs", "error", err)
111
111
-
return helpers.ServerError(e, nil)
114
114
+
helpers.ServerError(w, nil)
115
115
+
return
112
116
}
113
117
114
118
// Overwrite the locally computed CID with the one returned by the IPFS
···
126
130
127
131
if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
128
132
logger.Error("error updating blob", "error", err)
129
129
-
return helpers.ServerError(e, nil)
133
133
+
helpers.ServerError(w, nil)
134
134
+
return
130
135
}
131
136
132
137
resp := ComAtprotoRepoUploadBlobResponse{}
···
135
140
resp.Blob.MimeType = mime
136
141
resp.Blob.Size = read
137
142
138
138
-
return e.JSON(200, resp)
143
143
+
s.writeJSON(w, 200, resp)
139
144
}
140
145
141
146
// addBlobToIPFS adds raw blob data to the configured IPFS node via the Kubo
+7
-3
server/handle_robots.go
···
1
1
package server
2
2
3
3
-
import "github.com/labstack/echo/v4"
3
3
+
import (
4
4
+
"fmt"
5
5
+
"net/http"
6
6
+
)
4
7
5
5
-
func (s *Server) handleRobots(e echo.Context) error {
6
6
-
return e.String(200, "# Beep boop beep boop\n\n# Crawl me 🥺\nUser-agent: *\nAllow: /")
8
8
+
func (s *Server) handleRobots(w http.ResponseWriter, r *http.Request) {
9
9
+
w.Header().Set("Content-Type", "text/plain")
10
10
+
fmt.Fprint(w, "# Beep boop beep boop\n\n# Crawl me 🥺\nUser-agent: *\nAllow: /")
7
11
}
+32
-28
server/handle_root.go
···
1
1
package server
2
2
3
3
-
import "github.com/labstack/echo/v4"
3
3
+
import (
4
4
+
"fmt"
5
5
+
"net/http"
6
6
+
)
4
7
5
5
-
func (s *Server) handleRoot(e echo.Context) error {
6
6
-
return e.String(200, `
8
8
+
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
9
9
+
w.Header().Set("Content-Type", "text/plain")
10
10
+
fmt.Fprint(w, `
7
11
8
8
-
....-*%%%#####
12
12
+
....-*%%%#####
9
13
.%#+++****#%%%%%%%%%#+:....
10
10
-
.%+++**++++*%%%%.....
11
11
-
.%+++*****#%%%%#.. %#%...
12
12
-
***+*****%%%%%... =..
13
13
-
*****%%%%.. +=++..
14
14
-
%%%%%... .+----==++.
15
15
-
.-::----===++
16
16
-
.=-:.------==+++
17
17
-
+-:::-:----===++..
18
18
-
=-::-----:-==+++-.
19
19
-
.==*=------==++++.
20
20
-
+-:--=++===*=--++.
21
21
-
+:::--:=++=----=+..
22
22
-
*::::---=+#----=+.
23
23
-
=::::----=+#---=+..
24
24
-
.::::----==+=--=+..
25
25
-
.-::-----==++=-=+..
26
26
-
-::-----==++===+..
27
27
-
=::-----==++==++
28
28
-
+::----:==++=+++
29
29
-
:-:----:==+++++.
30
30
-
.=:=----=+++++.
31
31
-
+=-=====+++..
32
32
-
=====++.
33
33
-
=++...
14
14
+
.%+++**++++*%%%%.....
15
15
+
.%+++*****#%%%%#.. %#%...
16
16
+
***+*****%%%%%... =..
17
17
+
*****%%%%.. +=++..
18
18
+
%%%%%... .+----==++.
19
19
+
.-::----===++
20
20
+
.=-:.------==+++
21
21
+
+-:::-:----===++..
22
22
+
=-::-----:-==+++-.
23
23
+
.==*=------==++++.
24
24
+
+-:--=++===*=--++.
25
25
+
+:::--:=++=----=+..
26
26
+
*::::---=+#----=+.
27
27
+
=::::----=+#---=+..
28
28
+
.::::----==+=--=+..
29
29
+
.-::-----==++=-=+..
30
30
+
-::-----==++===+..
31
31
+
=::-----==++==++
32
32
+
+::----:==++=+++
33
33
+
:-:----:==+++++.
34
34
+
.=:=----=+++++.
35
35
+
+=-=====+++..
36
36
+
=====++.
37
37
+
=++...
34
38
35
39
36
40
This is an AT Protocol Personal Data Server (aka, an atproto PDS)
+8
-13
server/handle_server_activate_account.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/bluesky-social/indigo/api/atproto"
···
9
10
"github.com/bluesky-social/indigo/util"
10
11
"github.com/haileyok/cocoon/internal/helpers"
11
12
"github.com/haileyok/cocoon/models"
12
12
-
"github.com/labstack/echo/v4"
13
13
)
14
14
15
15
type ComAtprotoServerActivateAccountRequest struct {
···
17
17
DeleteAfter time.Time `json:"deleteAfter"`
18
18
}
19
19
20
20
-
func (s *Server) handleServerActivateAccount(e echo.Context) error {
21
21
-
ctx := e.Request().Context()
20
20
+
func (s *Server) handleServerActivateAccount(w http.ResponseWriter, r *http.Request) {
21
21
+
ctx := r.Context()
22
22
logger := s.logger.With("name", "handleServerActivateAccount")
23
23
24
24
-
var req ComAtprotoServerDeactivateAccountRequest
25
25
-
if err := e.Bind(&req); err != nil {
26
26
-
logger.Error("error binding", "error", err)
27
27
-
return helpers.ServerError(e, nil)
28
28
-
}
29
29
-
30
30
-
urepo := e.Get("repo").(*models.RepoActor)
24
24
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
31
25
32
26
if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
33
33
-
logger.Error("error updating account status to deactivated", "error", err)
34
34
-
return helpers.ServerError(e, nil)
27
27
+
logger.Error("error updating account status to activated", "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
35
30
}
36
31
37
32
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
···
44
39
},
45
40
})
46
41
47
47
-
return e.NoContent(200)
42
42
+
w.WriteHeader(http.StatusOK)
48
43
}
+14
-9
server/handle_server_check_account_status.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
7
"github.com/haileyok/cocoon/models"
6
8
"github.com/ipfs/go-cid"
7
7
-
"github.com/labstack/echo/v4"
8
9
)
9
10
10
11
type ComAtprotoServerCheckAccountStatusResponse struct {
···
19
20
ImportedBlobs int64 `json:"importedBlobs"`
20
21
}
21
22
22
22
-
func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
23
23
-
ctx := e.Request().Context()
23
23
+
func (s *Server) handleServerCheckAccountStatus(w http.ResponseWriter, r *http.Request) {
24
24
+
ctx := r.Context()
24
25
logger := s.logger.With("name", "handleServerCheckAccountStatus")
25
26
26
26
-
urepo := e.Get("repo").(*models.RepoActor)
27
27
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
27
28
28
29
resp := ComAtprotoServerCheckAccountStatusResponse{
29
30
Activated: true, // TODO: should allow for deactivation etc.
···
35
36
rootcid, err := cid.Cast(urepo.Root)
36
37
if err != nil {
37
38
logger.Error("error casting cid", "error", err)
38
38
-
return helpers.ServerError(e, nil)
39
39
+
helpers.ServerError(w, nil)
40
40
+
return
39
41
}
40
42
resp.RepoCommit = rootcid.String()
41
43
···
46
48
var blockCtResp CountResp
47
49
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
48
50
logger.Error("error getting block count", "error", err)
49
49
-
return helpers.ServerError(e, nil)
51
51
+
helpers.ServerError(w, nil)
52
52
+
return
50
53
}
51
54
resp.RepoBlocks = blockCtResp.Ct
52
55
53
56
var recCtResp CountResp
54
57
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
55
58
logger.Error("error getting record count", "error", err)
56
56
-
return helpers.ServerError(e, nil)
59
59
+
helpers.ServerError(w, nil)
60
60
+
return
57
61
}
58
62
resp.IndexedRecords = recCtResp.Ct
59
63
60
64
var blobCtResp CountResp
61
65
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
62
66
logger.Error("error getting record count", "error", err)
63
63
-
return helpers.ServerError(e, nil)
67
67
+
helpers.ServerError(w, nil)
68
68
+
return
64
69
}
65
70
resp.ExpectedBlobs = blobCtResp.Ct
66
71
67
67
-
return e.JSON(200, resp)
72
72
+
s.writeJSON(w, 200, resp)
68
73
}
+21
-14
server/handle_server_confirm_email.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
4
6
"time"
5
7
6
8
"github.com/Azure/go-autorest/autorest/to"
7
9
"github.com/haileyok/cocoon/internal/helpers"
8
10
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
11
)
11
12
12
13
type ComAtprotoServerConfirmEmailRequest struct {
···
14
15
Token string `json:"token" validate:"required"`
15
16
}
16
17
17
17
-
func (s *Server) handleServerConfirmEmail(e echo.Context) error {
18
18
-
ctx := e.Request().Context()
18
18
+
func (s *Server) handleServerConfirmEmail(w http.ResponseWriter, r *http.Request) {
19
19
+
ctx := r.Context()
19
20
logger := s.logger.With("name", "handleServerConfirmEmail")
20
21
21
21
-
urepo := e.Get("repo").(*models.RepoActor)
22
22
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
22
23
23
24
var req ComAtprotoServerConfirmEmailRequest
24
24
-
if err := e.Bind(&req); err != nil {
25
25
-
logger.Error("error binding", "error", err)
26
26
-
return helpers.ServerError(e, nil)
25
25
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
26
26
+
logger.Error("error decoding", "error", err)
27
27
+
helpers.ServerError(w, nil)
28
28
+
return
27
29
}
28
30
29
29
-
if err := e.Validate(req); err != nil {
30
30
-
return helpers.InputError(e, nil)
31
31
+
if err := s.validator.Struct(req); err != nil {
32
32
+
helpers.InputError(w, nil)
33
33
+
return
31
34
}
32
35
33
36
if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil {
34
34
-
return helpers.ExpiredTokenError(e)
37
37
+
helpers.ExpiredTokenError(w)
38
38
+
return
35
39
}
36
40
37
41
if *urepo.EmailVerificationCode != req.Token {
38
38
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
42
42
+
helpers.InputError(w, to.StringPtr("InvalidToken"))
43
43
+
return
39
44
}
40
45
41
46
if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) {
42
42
-
return helpers.ExpiredTokenError(e)
47
47
+
helpers.ExpiredTokenError(w)
48
48
+
return
43
49
}
44
50
45
51
now := time.Now().UTC()
46
52
47
53
if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
48
54
logger.Error("error updating user", "error", err)
49
49
-
return helpers.ServerError(e, nil)
55
55
+
helpers.ServerError(w, nil)
56
56
+
return
50
57
}
51
58
52
52
-
return e.NoContent(200)
59
59
+
w.WriteHeader(http.StatusOK)
53
60
}
+67
-39
server/handle_server_create_account.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
5
6
"errors"
6
7
"fmt"
8
8
+
"net/http"
7
9
"strings"
8
10
"time"
9
11
···
17
19
"github.com/bluesky-social/indigo/util"
18
20
"github.com/haileyok/cocoon/internal/helpers"
19
21
"github.com/haileyok/cocoon/models"
20
20
-
"github.com/labstack/echo/v4"
21
22
"golang.org/x/crypto/bcrypt"
22
23
"gorm.io/gorm"
23
24
)
···
37
38
Did string `json:"did"`
38
39
}
39
40
40
40
-
func (s *Server) handleCreateAccount(e echo.Context) error {
41
41
-
ctx := e.Request().Context()
41
41
+
func (s *Server) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
42
42
+
ctx := r.Context()
42
43
logger := s.logger.With("name", "handleServerCreateAccount")
43
44
44
45
var request ComAtprotoServerCreateAccountRequest
45
46
46
46
-
if err := e.Bind(&request); err != nil {
47
47
+
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
47
48
logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err)
48
48
-
return helpers.ServerError(e, nil)
49
49
+
helpers.ServerError(w, nil)
50
50
+
return
49
51
}
50
52
51
53
request.Handle = strings.ToLower(request.Handle)
52
54
53
53
-
if err := e.Validate(request); err != nil {
55
55
+
if err := s.validator.Struct(request); err != nil {
54
56
logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err)
55
57
56
58
var verr ValidationError
57
59
if errors.As(err, &verr) {
58
60
if verr.Field == "Email" {
59
59
-
// TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc
60
60
-
return helpers.InputError(e, to.StringPtr("InvalidEmail"))
61
61
+
helpers.InputError(w, to.StringPtr("InvalidEmail"))
62
62
+
return
61
63
}
62
64
63
65
if verr.Field == "Handle" {
64
64
-
return helpers.InputError(e, to.StringPtr("InvalidHandle"))
66
66
+
helpers.InputError(w, to.StringPtr("InvalidHandle"))
67
67
+
return
65
68
}
66
69
67
70
if verr.Field == "Password" {
68
68
-
return helpers.InputError(e, to.StringPtr("InvalidPassword"))
71
71
+
helpers.InputError(w, to.StringPtr("InvalidPassword"))
72
72
+
return
69
73
}
70
74
71
75
if verr.Field == "InviteCode" {
72
72
-
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
76
76
+
helpers.InputError(w, to.StringPtr("InvalidInviteCode"))
77
77
+
return
73
78
}
74
79
}
75
80
}
···
78
83
if request.Did != nil {
79
84
signupDid = *request.Did
80
85
81
81
-
token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
86
86
+
token := strings.TrimSpace(strings.Replace(r.Header.Get("authorization"), "Bearer ", "", 1))
82
87
if token == "" {
83
83
-
return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
88
88
+
helpers.UnauthorizedError(w, to.StringPtr("must authenticate to use an existing did"))
89
89
+
return
84
90
}
85
85
-
authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount")
91
91
+
authDid, err := s.validateServiceAuth(r.Context(), token, "com.atproto.server.createAccount")
86
92
87
93
if err != nil {
88
94
logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err)
89
89
-
return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token"))
95
95
+
helpers.UnauthorizedError(w, to.StringPtr("invalid authorization token"))
96
96
+
return
90
97
}
91
98
92
99
if authDid != signupDid {
93
93
-
return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did"))
100
100
+
helpers.ForbiddenError(w, to.StringPtr("auth did did not match signup did"))
101
101
+
return
94
102
}
95
103
}
96
104
···
98
106
actor, err := s.getActorByHandle(ctx, request.Handle)
99
107
if err != nil && err != gorm.ErrRecordNotFound {
100
108
logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
101
101
-
return helpers.ServerError(e, nil)
109
109
+
helpers.ServerError(w, nil)
110
110
+
return
102
111
}
103
112
if err == nil && actor.Did != signupDid {
104
104
-
return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
113
113
+
helpers.InputError(w, to.StringPtr("HandleNotAvailable"))
114
114
+
return
105
115
}
106
116
107
107
-
if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid {
108
108
-
return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
117
117
+
if did, err := s.passport.ResolveHandle(r.Context(), request.Handle); err == nil && did != signupDid {
118
118
+
helpers.InputError(w, to.StringPtr("HandleNotAvailable"))
119
119
+
return
109
120
}
110
121
111
122
var ic models.InviteCode
112
123
if s.config.RequireInvite {
113
124
if strings.TrimSpace(request.InviteCode) == "" {
114
114
-
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
125
125
+
helpers.InputError(w, to.StringPtr("InvalidInviteCode"))
126
126
+
return
115
127
}
116
128
117
129
if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
118
130
if err == gorm.ErrRecordNotFound {
119
119
-
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
131
131
+
helpers.InputError(w, to.StringPtr("InvalidInviteCode"))
132
132
+
return
120
133
}
121
134
logger.Error("error getting invite code from db", "error", err)
122
122
-
return helpers.ServerError(e, nil)
135
135
+
helpers.ServerError(w, nil)
136
136
+
return
123
137
}
124
138
125
139
if ic.RemainingUseCount < 1 {
126
126
-
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
140
140
+
helpers.InputError(w, to.StringPtr("InvalidInviteCode"))
141
141
+
return
127
142
}
128
143
}
129
144
···
131
146
existingRepo, err := s.getRepoByEmail(ctx, request.Email)
132
147
if err != nil && err != gorm.ErrRecordNotFound {
133
148
logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
134
134
-
return helpers.ServerError(e, nil)
149
149
+
helpers.ServerError(w, nil)
150
150
+
return
135
151
}
136
152
if err == nil && existingRepo.Did != signupDid {
137
137
-
return helpers.InputError(e, to.StringPtr("EmailNotAvailable"))
153
153
+
helpers.InputError(w, to.StringPtr("EmailNotAvailable"))
154
154
+
return
138
155
}
139
156
140
157
// TODO: unsupported domains
···
165
182
k, err = atcrypto.GeneratePrivateKeyK256()
166
183
if err != nil {
167
184
logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
168
168
-
return helpers.ServerError(e, nil)
185
185
+
helpers.ServerError(w, nil)
186
186
+
return
169
187
}
170
188
}
171
189
···
173
191
did, op, err := s.plcClient.CreateDID(k, "", request.Handle)
174
192
if err != nil {
175
193
logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err)
176
176
-
return helpers.ServerError(e, nil)
194
194
+
helpers.ServerError(w, nil)
195
195
+
return
177
196
}
178
197
179
179
-
if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil {
198
198
+
if err := s.plcClient.SendOperation(r.Context(), did, op); err != nil {
180
199
logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err)
181
181
-
return helpers.ServerError(e, nil)
200
200
+
helpers.ServerError(w, nil)
201
201
+
return
182
202
}
183
203
signupDid = did
184
204
}
···
186
206
hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10)
187
207
if err != nil {
188
208
logger.Error("error hashing password", "error", err)
189
189
-
return helpers.ServerError(e, nil)
209
209
+
helpers.ServerError(w, nil)
210
210
+
return
190
211
}
191
212
192
213
urepo := models.Repo{
···
206
227
207
228
if err := s.db.Create(ctx, &urepo, nil).Error; err != nil {
208
229
logger.Error("error inserting new repo", "error", err)
209
209
-
return helpers.ServerError(e, nil)
230
230
+
helpers.ServerError(w, nil)
231
231
+
return
210
232
}
211
233
212
234
if err := s.db.Create(ctx, &actor, nil).Error; err != nil {
213
235
logger.Error("error inserting new actor", "error", err)
214
214
-
return helpers.ServerError(e, nil)
236
236
+
helpers.ServerError(w, nil)
237
237
+
return
215
238
}
216
239
} else {
217
240
if err := s.db.Save(ctx, &actor, nil).Error; err != nil {
218
241
logger.Error("error inserting new actor", "error", err)
219
219
-
return helpers.ServerError(e, nil)
242
242
+
helpers.ServerError(w, nil)
243
243
+
return
220
244
}
221
245
}
222
246
···
234
258
root, rev, err := commitRepo(context.TODO(), bs, r, urepo.SigningKey)
235
259
if err != nil {
236
260
logger.Error("error committing", "error", err)
237
237
-
return helpers.ServerError(e, nil)
261
261
+
helpers.ServerError(w, nil)
262
262
+
return
238
263
}
239
264
240
265
if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
241
266
logger.Error("error updating repo after commit", "error", err)
242
242
-
return helpers.ServerError(e, nil)
267
267
+
helpers.ServerError(w, nil)
268
268
+
return
243
269
}
244
270
245
271
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
···
255
281
if s.config.RequireInvite {
256
282
if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
257
283
logger.Error("error decrementing use count", "error", err)
258
258
-
return helpers.ServerError(e, nil)
284
284
+
helpers.ServerError(w, nil)
285
285
+
return
259
286
}
260
287
}
261
288
262
289
sess, err := s.createSession(ctx, &urepo)
263
290
if err != nil {
264
291
logger.Error("error creating new session", "error", err)
265
265
-
return helpers.ServerError(e, nil)
292
292
+
helpers.ServerError(w, nil)
293
293
+
return
266
294
}
267
295
268
296
go func() {
···
274
302
}
275
303
}()
276
304
277
277
-
return e.JSON(200, ComAtprotoServerCreateAccountResponse{
305
305
+
s.writeJSON(w, 200, ComAtprotoServerCreateAccountResponse{
278
306
AccessJwt: sess.AccessToken,
279
307
RefreshJwt: sess.RefreshToken,
280
308
Handle: request.Handle,
+15
-10
server/handle_server_create_invite_code.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/google/uuid"
5
8
"github.com/haileyok/cocoon/internal/helpers"
6
9
"github.com/haileyok/cocoon/models"
7
7
-
"github.com/labstack/echo/v4"
8
10
)
9
11
10
12
type ComAtprotoServerCreateInviteCodeRequest struct {
···
16
18
Code string `json:"code"`
17
19
}
18
20
19
19
-
func (s *Server) handleCreateInviteCode(e echo.Context) error {
20
20
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handleCreateInviteCode(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
21
23
logger := s.logger.With("name", "handleServerCreateInviteCode")
22
24
23
25
var req ComAtprotoServerCreateInviteCodeRequest
24
24
-
if err := e.Bind(&req); err != nil {
25
25
-
logger.Error("error binding", "error", err)
26
26
-
return helpers.ServerError(e, nil)
26
26
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
27
27
+
logger.Error("error decoding", "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
27
30
}
28
31
29
29
-
if err := e.Validate(req); err != nil {
32
32
+
if err := s.validator.Struct(req); err != nil {
30
33
logger.Error("error validating", "error", err)
31
31
-
return helpers.InputError(e, nil)
34
34
+
helpers.InputError(w, nil)
35
35
+
return
32
36
}
33
37
34
38
ic := uuid.NewString()
···
46
50
RemainingUseCount: req.UseCount,
47
51
}, nil).Error; err != nil {
48
52
logger.Error("error creating invite code", "error", err)
49
49
-
return helpers.ServerError(e, nil)
53
53
+
helpers.ServerError(w, nil)
54
54
+
return
50
55
}
51
56
52
52
-
return e.JSON(200, ComAtprotoServerCreateInviteCodeResponse{
57
57
+
s.writeJSON(w, 200, ComAtprotoServerCreateInviteCodeResponse{
53
58
Code: ic,
54
59
})
55
60
}
+15
-10
server/handle_server_create_invite_codes.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
6
6
+
4
7
"github.com/Azure/go-autorest/autorest/to"
5
8
"github.com/google/uuid"
6
9
"github.com/haileyok/cocoon/internal/helpers"
7
10
"github.com/haileyok/cocoon/models"
8
8
-
"github.com/labstack/echo/v4"
9
11
)
10
12
11
13
type ComAtprotoServerCreateInviteCodesRequest struct {
···
21
23
Codes []string `json:"codes"`
22
24
}
23
25
24
24
-
func (s *Server) handleCreateInviteCodes(e echo.Context) error {
25
25
-
ctx := e.Request().Context()
26
26
+
func (s *Server) handleCreateInviteCodes(w http.ResponseWriter, r *http.Request) {
27
27
+
ctx := r.Context()
26
28
logger := s.logger.With("name", "handleServerCreateInviteCodes")
27
29
28
30
var req ComAtprotoServerCreateInviteCodesRequest
29
29
-
if err := e.Bind(&req); err != nil {
30
30
-
logger.Error("error binding", "error", err)
31
31
-
return helpers.ServerError(e, nil)
31
31
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
32
32
+
logger.Error("error decoding", "error", err)
33
33
+
helpers.ServerError(w, nil)
34
34
+
return
32
35
}
33
36
34
34
-
if err := e.Validate(req); err != nil {
37
37
+
if err := s.validator.Struct(req); err != nil {
35
38
logger.Error("error validating", "error", err)
36
36
-
return helpers.InputError(e, nil)
39
39
+
helpers.InputError(w, nil)
40
40
+
return
37
41
}
38
42
39
43
if req.CodeCount == nil {
···
59
63
RemainingUseCount: req.UseCount,
60
64
}, nil).Error; err != nil {
61
65
logger.Error("error creating invite code", "error", err)
62
62
-
return helpers.ServerError(e, nil)
66
66
+
helpers.ServerError(w, nil)
67
67
+
return
63
68
}
64
69
}
65
70
···
69
74
})
70
75
}
71
76
72
72
-
return e.JSON(200, codes)
77
77
+
s.writeJSON(w, 200, codes)
73
78
}
+36
-22
server/handle_server_create_session.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
5
6
"errors"
6
7
"fmt"
8
8
+
"net/http"
7
9
"strings"
8
10
"time"
9
11
···
11
13
"github.com/bluesky-social/indigo/atproto/syntax"
12
14
"github.com/haileyok/cocoon/internal/helpers"
13
15
"github.com/haileyok/cocoon/models"
14
14
-
"github.com/labstack/echo/v4"
15
16
"golang.org/x/crypto/bcrypt"
16
17
"gorm.io/gorm"
17
18
)
···
34
35
Status *string `json:"status,omitempty"`
35
36
}
36
37
37
37
-
func (s *Server) handleCreateSession(e echo.Context) error {
38
38
-
ctx := e.Request().Context()
38
38
+
func (s *Server) handleCreateSession(w http.ResponseWriter, r *http.Request) {
39
39
+
ctx := r.Context()
39
40
logger := s.logger.With("name", "handleServerCreateSession")
40
41
41
42
var req ComAtprotoServerCreateSessionRequest
42
42
-
if err := e.Bind(&req); err != nil {
43
43
-
logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
44
44
-
return helpers.ServerError(e, nil)
43
43
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
44
44
+
logger.Error("error decoding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
45
45
+
helpers.ServerError(w, nil)
46
46
+
return
45
47
}
46
48
47
47
-
if err := e.Validate(req); err != nil {
49
49
+
if err := s.validator.Struct(req); err != nil {
48
50
var verr ValidationError
49
51
if errors.As(err, &verr) {
50
52
if verr.Field == "Identifier" {
51
51
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
53
53
+
helpers.InputError(w, to.StringPtr("InvalidRequest"))
54
54
+
return
52
55
}
53
56
54
57
if verr.Field == "Password" {
55
55
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
58
58
+
helpers.InputError(w, to.StringPtr("InvalidRequest"))
59
59
+
return
56
60
}
57
61
}
58
62
}
···
80
84
81
85
if err != nil {
82
86
if err == gorm.ErrRecordNotFound {
83
83
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
87
87
+
helpers.InputError(w, to.StringPtr("InvalidRequest"))
88
88
+
return
84
89
}
85
90
86
86
-
logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
87
87
-
return helpers.ServerError(e, nil)
91
91
+
logger.Error("error looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
92
92
+
helpers.ServerError(w, nil)
93
93
+
return
88
94
}
89
95
90
96
if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
91
97
if err != bcrypt.ErrMismatchedHashAndPassword {
92
92
-
logger.Error("erorr comparing hash and password", "error", err)
98
98
+
logger.Error("error comparing hash and password", "error", err)
93
99
}
94
94
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
100
100
+
helpers.InputError(w, to.StringPtr("InvalidRequest"))
101
101
+
return
95
102
}
96
103
97
104
// if repo requires 2FA token and one hasn't been provided, return error prompting for one
···
99
106
err = s.createAndSendTwoFactorCode(ctx, repo)
100
107
if err != nil {
101
108
logger.Error("sending 2FA code", "error", err)
102
102
-
return helpers.ServerError(e, nil)
109
109
+
helpers.ServerError(w, nil)
110
110
+
return
103
111
}
104
112
105
105
-
return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
113
113
+
helpers.InputError(w, to.StringPtr("AuthFactorTokenRequired"))
114
114
+
return
106
115
}
107
116
108
117
// if 2FA is required, now check that the one provided is valid
···
111
120
err = s.createAndSendTwoFactorCode(ctx, repo)
112
121
if err != nil {
113
122
logger.Error("sending 2FA code", "error", err)
114
114
-
return helpers.ServerError(e, nil)
123
123
+
helpers.ServerError(w, nil)
124
124
+
return
115
125
}
116
126
117
117
-
return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
127
127
+
helpers.InputError(w, to.StringPtr("AuthFactorTokenRequired"))
128
128
+
return
118
129
}
119
130
120
131
if *repo.TwoFactorCode != *req.AuthFactorToken {
121
121
-
return helpers.InvalidTokenError(e)
132
132
+
helpers.InvalidTokenError(w)
133
133
+
return
122
134
}
123
135
124
136
if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
125
125
-
return helpers.ExpiredTokenError(e)
137
137
+
helpers.ExpiredTokenError(w)
138
138
+
return
126
139
}
127
140
}
128
141
129
142
sess, err := s.createSession(ctx, &repo.Repo)
130
143
if err != nil {
131
144
logger.Error("error creating session", "error", err)
132
132
-
return helpers.ServerError(e, nil)
145
145
+
helpers.ServerError(w, nil)
146
146
+
return
133
147
}
134
148
135
135
-
return e.JSON(200, ComAtprotoServerCreateSessionResponse{
149
149
+
s.writeJSON(w, 200, ComAtprotoServerCreateSessionResponse{
136
150
AccessJwt: sess.AccessToken,
137
151
RefreshJwt: sess.RefreshToken,
138
152
Handle: repo.Handle,
+7
-12
server/handle_server_deactivate_account.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/Azure/go-autorest/autorest/to"
···
10
11
"github.com/bluesky-social/indigo/util"
11
12
"github.com/haileyok/cocoon/internal/helpers"
12
13
"github.com/haileyok/cocoon/models"
13
13
-
"github.com/labstack/echo/v4"
14
14
)
15
15
16
16
type ComAtprotoServerDeactivateAccountRequest struct {
···
18
18
DeleteAfter time.Time `json:"deleteAfter"`
19
19
}
20
20
21
21
-
func (s *Server) handleServerDeactivateAccount(e echo.Context) error {
22
22
-
ctx := e.Request().Context()
21
21
+
func (s *Server) handleServerDeactivateAccount(w http.ResponseWriter, r *http.Request) {
22
22
+
ctx := r.Context()
23
23
logger := s.logger.With("name", "handleServerDeactivateAccount")
24
24
25
25
-
var req ComAtprotoServerDeactivateAccountRequest
26
26
-
if err := e.Bind(&req); err != nil {
27
27
-
logger.Error("error binding", "error", err)
28
28
-
return helpers.ServerError(e, nil)
29
29
-
}
30
30
-
31
31
-
urepo := e.Get("repo").(*models.RepoActor)
25
25
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
32
26
33
27
if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil {
34
28
logger.Error("error updating account status to deactivated", "error", err)
35
35
-
return helpers.ServerError(e, nil)
29
29
+
helpers.ServerError(w, nil)
30
30
+
return
36
31
}
37
32
38
33
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
···
45
40
},
46
41
})
47
42
48
48
-
return e.NoContent(200)
43
43
+
w.WriteHeader(http.StatusOK)
49
44
}
+44
-25
server/handle_server_delete_account.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
6
6
+
"net/http"
5
7
"time"
6
8
7
9
"github.com/Azure/go-autorest/autorest/to"
···
9
11
"github.com/bluesky-social/indigo/events"
10
12
"github.com/bluesky-social/indigo/util"
11
13
"github.com/haileyok/cocoon/internal/helpers"
12
12
-
"github.com/labstack/echo/v4"
13
14
"golang.org/x/crypto/bcrypt"
14
15
)
15
16
···
19
20
Token string `json:"token" validate:"required"`
20
21
}
21
22
22
22
-
func (s *Server) handleServerDeleteAccount(e echo.Context) error {
23
23
-
ctx := e.Request().Context()
23
23
+
func (s *Server) handleServerDeleteAccount(w http.ResponseWriter, r *http.Request) {
24
24
+
ctx := r.Context()
24
25
logger := s.logger.With("name", "handleServerDeleteAccount")
25
26
26
27
var req ComAtprotoServerDeleteAccountRequest
27
27
-
if err := e.Bind(&req); err != nil {
28
28
-
logger.Error("error binding", "error", err)
29
29
-
return helpers.ServerError(e, nil)
28
28
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
29
29
+
logger.Error("error decoding", "error", err)
30
30
+
helpers.ServerError(w, nil)
31
31
+
return
30
32
}
31
33
32
32
-
if err := e.Validate(&req); err != nil {
34
34
+
if err := s.validator.Struct(&req); err != nil {
33
35
logger.Error("error validating", "error", err)
34
34
-
return helpers.ServerError(e, nil)
36
36
+
helpers.ServerError(w, nil)
37
37
+
return
35
38
}
36
39
37
40
urepo, err := s.getRepoActorByDid(ctx, req.Did)
38
41
if err != nil {
39
42
logger.Error("error getting repo", "error", err)
40
40
-
return echo.NewHTTPError(400, "account not found")
43
43
+
s.writeJSON(w, 400, map[string]string{"error": "account not found"})
44
44
+
return
41
45
}
42
46
43
47
if err := bcrypt.CompareHashAndPassword([]byte(urepo.Repo.Password), []byte(req.Password)); err != nil {
44
48
logger.Error("password mismatch", "error", err)
45
45
-
return echo.NewHTTPError(401, "Invalid did or password")
49
49
+
s.writeJSON(w, 401, map[string]string{"error": "Invalid did or password"})
50
50
+
return
46
51
}
47
52
48
53
if urepo.Repo.AccountDeleteCode == nil || urepo.Repo.AccountDeleteCodeExpiresAt == nil {
49
54
logger.Error("no deletion token found for account")
50
50
-
return echo.NewHTTPError(400, map[string]interface{}{
55
55
+
s.writeJSON(w, 400, map[string]any{
51
56
"error": "InvalidToken",
52
57
"message": "Token is invalid",
53
58
})
59
59
+
return
54
60
}
55
61
56
62
if *urepo.Repo.AccountDeleteCode != req.Token {
57
63
logger.Error("deletion token mismatch")
58
58
-
return echo.NewHTTPError(400, map[string]interface{}{
64
64
+
s.writeJSON(w, 400, map[string]any{
59
65
"error": "InvalidToken",
60
66
"message": "Token is invalid",
61
67
})
68
68
+
return
62
69
}
63
70
64
71
if time.Now().UTC().After(*urepo.Repo.AccountDeleteCodeExpiresAt) {
65
72
logger.Error("deletion token expired")
66
66
-
return echo.NewHTTPError(400, map[string]interface{}{
73
73
+
s.writeJSON(w, 400, map[string]any{
67
74
"error": "ExpiredToken",
68
75
"message": "Token is expired",
69
76
})
77
77
+
return
70
78
}
71
79
72
80
tx := s.db.Begin(ctx)
73
81
if tx.Error != nil {
74
82
logger.Error("error starting transaction", "error", tx.Error)
75
75
-
return helpers.ServerError(e, nil)
83
83
+
helpers.ServerError(w, nil)
84
84
+
return
76
85
}
77
86
78
87
status := "error"
···
86
95
87
96
if err := tx.Exec("DELETE FROM blocks WHERE did = ?", req.Did).Error; err != nil {
88
97
logger.Error("error deleting blocks", "error", err)
89
89
-
return helpers.ServerError(e, nil)
98
98
+
helpers.ServerError(w, nil)
99
99
+
return
90
100
}
91
101
92
102
if err := tx.Exec("DELETE FROM records WHERE did = ?", req.Did).Error; err != nil {
93
103
logger.Error("error deleting records", "error", err)
94
94
-
return helpers.ServerError(e, nil)
104
104
+
helpers.ServerError(w, nil)
105
105
+
return
95
106
}
96
107
97
108
if err := tx.Exec("DELETE FROM blobs WHERE did = ?", req.Did).Error; err != nil {
98
109
logger.Error("error deleting blobs", "error", err)
99
99
-
return helpers.ServerError(e, nil)
110
110
+
helpers.ServerError(w, nil)
111
111
+
return
100
112
}
101
113
102
114
if err := tx.Exec("DELETE FROM tokens WHERE did = ?", req.Did).Error; err != nil {
103
115
logger.Error("error deleting tokens", "error", err)
104
104
-
return helpers.ServerError(e, nil)
116
116
+
helpers.ServerError(w, nil)
117
117
+
return
105
118
}
106
119
107
120
if err := tx.Exec("DELETE FROM refresh_tokens WHERE did = ?", req.Did).Error; err != nil {
108
121
logger.Error("error deleting refresh tokens", "error", err)
109
109
-
return helpers.ServerError(e, nil)
122
122
+
helpers.ServerError(w, nil)
123
123
+
return
110
124
}
111
125
112
126
if err := tx.Exec("DELETE FROM reserved_keys WHERE did = ?", req.Did).Error; err != nil {
113
127
logger.Error("error deleting reserved keys", "error", err)
114
114
-
return helpers.ServerError(e, nil)
128
128
+
helpers.ServerError(w, nil)
129
129
+
return
115
130
}
116
131
117
132
if err := tx.Exec("DELETE FROM invite_codes WHERE did = ?", req.Did).Error; err != nil {
118
133
logger.Error("error deleting invite codes", "error", err)
119
119
-
return helpers.ServerError(e, nil)
134
134
+
helpers.ServerError(w, nil)
135
135
+
return
120
136
}
121
137
122
138
if err := tx.Exec("DELETE FROM actors WHERE did = ?", req.Did).Error; err != nil {
123
139
logger.Error("error deleting actor", "error", err)
124
124
-
return helpers.ServerError(e, nil)
140
140
+
helpers.ServerError(w, nil)
141
141
+
return
125
142
}
126
143
127
144
if err := tx.Exec("DELETE FROM repos WHERE did = ?", req.Did).Error; err != nil {
128
145
logger.Error("error deleting repo", "error", err)
129
129
-
return helpers.ServerError(e, nil)
146
146
+
helpers.ServerError(w, nil)
147
147
+
return
130
148
}
131
149
132
150
status = "ok"
133
151
134
152
if err := tx.Commit().Error; err != nil {
135
153
logger.Error("error committing transaction", "error", err)
136
136
-
return helpers.ServerError(e, nil)
154
154
+
helpers.ServerError(w, nil)
155
155
+
return
137
156
}
138
157
139
158
s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
···
146
165
},
147
166
})
148
167
149
149
-
return e.NoContent(200)
168
168
+
w.WriteHeader(http.StatusOK)
150
169
}
+10
-7
server/handle_server_delete_session.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
7
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
8
)
8
9
9
9
-
func (s *Server) handleDeleteSession(e echo.Context) error {
10
10
-
ctx := e.Request().Context()
10
10
+
func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) {
11
11
+
ctx := r.Context()
11
12
12
12
-
token := e.Get("token").(string)
13
13
+
token, _ := getContextValue[string](r, contextKeyToken)
13
14
14
15
var acctok models.Token
15
16
if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil {
16
17
s.logger.Error("error deleting access token from db", "error", err)
17
17
-
return helpers.ServerError(e, nil)
18
18
+
helpers.ServerError(w, nil)
19
19
+
return
18
20
}
19
21
20
22
if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil {
21
23
s.logger.Error("error deleting refresh token from db", "error", err)
22
22
-
return helpers.ServerError(e, nil)
24
24
+
helpers.ServerError(w, nil)
25
25
+
return
23
26
}
24
27
25
25
-
return e.NoContent(200)
28
28
+
w.WriteHeader(http.StatusOK)
26
29
}
+3
-3
server/handle_server_describe_server.go
···
1
1
package server
2
2
3
3
-
import "github.com/labstack/echo/v4"
3
3
+
import "net/http"
4
4
5
5
type ComAtprotoServerDescribeServerResponseLinks struct {
6
6
PrivacyPolicy *string `json:"privacyPolicy,omitempty"`
···
20
20
Did string `json:"did"`
21
21
}
22
22
23
23
-
func (s *Server) handleDescribeServer(e echo.Context) error {
24
24
-
return e.JSON(200, ComAtprotoServerDescribeServerResponse{
23
23
+
func (s *Server) handleDescribeServer(w http.ResponseWriter, r *http.Request) {
24
24
+
s.writeJSON(w, 200, ComAtprotoServerDescribeServerResponse{
25
25
InviteCodeRequired: s.config.RequireInvite,
26
26
PhoneVerificationRequired: false,
27
27
AvailableUserDomains: []string{"." + s.config.Hostname}, // TODO: more
+30
-19
server/handle_server_get_service_auth.go
···
6
6
"encoding/base64"
7
7
"encoding/json"
8
8
"fmt"
9
9
+
"net/http"
9
10
"strings"
10
11
"time"
11
12
···
13
14
"github.com/google/uuid"
14
15
"github.com/haileyok/cocoon/internal/helpers"
15
16
"github.com/haileyok/cocoon/models"
16
16
-
"github.com/labstack/echo/v4"
17
17
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18
18
)
19
19
20
20
type ServerGetServiceAuthRequest struct {
21
21
-
Aud string `query:"aud" validate:"required,atproto-did"`
22
22
-
// exp should be a float, as some clients will send a non-integer expiration
21
21
+
Aud string `query:"aud" validate:"required,atproto-did"`
23
22
Exp float64 `query:"exp"`
24
23
Lxm string `query:"lxm"`
25
24
}
26
25
27
27
-
func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
26
26
+
func (s *Server) handleServerGetServiceAuth(w http.ResponseWriter, r *http.Request) {
28
27
logger := s.logger.With("name", "handleServerGetServiceAuth")
29
28
30
30
-
var req ServerGetServiceAuthRequest
31
31
-
if err := e.Bind(&req); err != nil {
32
32
-
logger.Error("could not bind service auth request", "error", err)
33
33
-
return helpers.ServerError(e, nil)
29
29
+
req := ServerGetServiceAuthRequest{
30
30
+
Aud: r.URL.Query().Get("aud"),
31
31
+
Lxm: r.URL.Query().Get("lxm"),
32
32
+
}
33
33
+
if v := r.URL.Query().Get("exp"); v != "" {
34
34
+
var exp float64
35
35
+
if _, err := fmt.Sscanf(v, "%f", &exp); err == nil {
36
36
+
req.Exp = exp
37
37
+
}
34
38
}
35
39
36
36
-
if err := e.Validate(req); err != nil {
37
37
-
return helpers.InputError(e, nil)
40
40
+
if err := s.validator.Struct(req); err != nil {
41
41
+
helpers.InputError(w, nil)
42
42
+
return
38
43
}
39
44
40
45
exp := int64(req.Exp)
···
44
49
}
45
50
46
51
if req.Lxm == "com.atproto.server.getServiceAuth" {
47
47
-
return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively"))
52
52
+
helpers.InputError(w, to.StringPtr("may not generate auth tokens recursively"))
53
53
+
return
48
54
}
49
55
50
56
var maxExp int64
···
54
60
maxExp = now + 60
55
61
}
56
62
if exp > maxExp {
57
57
-
return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
63
63
+
helpers.InputError(w, to.StringPtr("expiration too big. smoller please"))
64
64
+
return
58
65
}
59
66
60
60
-
repo := e.Get("repo").(*models.RepoActor)
67
67
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
61
68
62
69
header := map[string]string{
63
70
"alg": "ES256K",
···
67
74
hj, err := json.Marshal(header)
68
75
if err != nil {
69
76
logger.Error("error marshaling header", "error", err)
70
70
-
return helpers.ServerError(e, nil)
77
77
+
helpers.ServerError(w, nil)
78
78
+
return
71
79
}
72
80
73
81
encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
···
84
92
}
85
93
pj, err := json.Marshal(payload)
86
94
if err != nil {
87
87
-
logger.Error("error marashaling payload", "error", err)
88
88
-
return helpers.ServerError(e, nil)
95
95
+
logger.Error("error marshaling payload", "error", err)
96
96
+
helpers.ServerError(w, nil)
97
97
+
return
89
98
}
90
99
91
100
encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
···
96
105
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
97
106
if err != nil {
98
107
logger.Error("can't load private key", "error", err)
99
99
-
return err
108
108
+
helpers.ServerError(w, nil)
109
109
+
return
100
110
}
101
111
102
112
R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
103
113
if err != nil {
104
114
logger.Error("error signing", "error", err)
105
105
-
return helpers.ServerError(e, nil)
115
115
+
helpers.ServerError(w, nil)
116
116
+
return
106
117
}
107
118
108
119
rBytes := R.Bytes()
···
117
128
encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
118
129
token := fmt.Sprintf("%s.%s", input, encsig)
119
130
120
120
-
return e.JSON(200, map[string]string{
131
131
+
s.writeJSON(w, 200, map[string]string{
121
132
"token": token,
122
133
})
123
134
}
+5
-4
server/handle_server_get_session.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/models"
5
5
-
"github.com/labstack/echo/v4"
6
7
)
7
8
8
9
type ComAtprotoServerGetSessionResponse struct {
···
15
16
Status *string `json:"status,omitempty"`
16
17
}
17
18
18
18
-
func (s *Server) handleGetSession(e echo.Context) error {
19
19
-
repo := e.Get("repo").(*models.RepoActor)
19
19
+
func (s *Server) handleGetSession(w http.ResponseWriter, r *http.Request) {
20
20
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
20
21
21
21
-
return e.JSON(200, ComAtprotoServerGetSessionResponse{
22
22
+
s.writeJSON(w, 200, ComAtprotoServerGetSessionResponse{
22
23
Handle: repo.Handle,
23
24
Did: repo.Repo.Did,
24
25
Email: repo.Email,
+13
-9
server/handle_server_refresh_session.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
7
"github.com/haileyok/cocoon/models"
6
6
-
"github.com/labstack/echo/v4"
7
8
)
8
9
9
10
type ComAtprotoServerRefreshSessionResponse struct {
···
15
16
Status *string `json:"status,omitempty"`
16
17
}
17
18
18
18
-
func (s *Server) handleRefreshSession(e echo.Context) error {
19
19
-
ctx := e.Request().Context()
19
19
+
func (s *Server) handleRefreshSession(w http.ResponseWriter, r *http.Request) {
20
20
+
ctx := r.Context()
20
21
logger := s.logger.With("name", "handleServerRefreshSession")
21
22
22
22
-
token := e.Get("token").(string)
23
23
-
repo := e.Get("repo").(*models.RepoActor)
23
23
+
token, _ := getContextValue[string](r, contextKeyToken)
24
24
+
repo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
24
25
25
26
if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil {
26
27
logger.Error("error getting refresh token from db", "error", err)
27
27
-
return helpers.ServerError(e, nil)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
28
30
}
29
31
30
32
if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil {
31
33
logger.Error("error deleting access token from db", "error", err)
32
32
-
return helpers.ServerError(e, nil)
34
34
+
helpers.ServerError(w, nil)
35
35
+
return
33
36
}
34
37
35
38
sess, err := s.createSession(ctx, &repo.Repo)
36
39
if err != nil {
37
40
logger.Error("error creating new session for refresh", "error", err)
38
38
-
return helpers.ServerError(e, nil)
41
41
+
helpers.ServerError(w, nil)
42
42
+
return
39
43
}
40
44
41
41
-
return e.JSON(200, ComAtprotoServerRefreshSessionResponse{
45
45
+
s.writeJSON(w, 200, ComAtprotoServerRefreshSessionResponse{
42
46
AccessJwt: sess.AccessToken,
43
47
RefreshJwt: sess.RefreshToken,
44
48
Handle: repo.Handle,
+7
-6
server/handle_server_request_account_delete.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/haileyok/cocoon/internal/helpers"
8
9
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
10
)
11
11
12
12
-
func (s *Server) handleServerRequestAccountDelete(e echo.Context) error {
13
13
-
ctx := e.Request().Context()
12
12
+
func (s *Server) handleServerRequestAccountDelete(w http.ResponseWriter, r *http.Request) {
13
13
+
ctx := r.Context()
14
14
logger := s.logger.With("name", "handleServerRequestAccountDelete")
15
15
16
16
-
urepo := e.Get("repo").(*models.RepoActor)
16
16
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
17
17
18
18
token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
19
19
expiresAt := time.Now().UTC().Add(15 * time.Minute)
20
20
21
21
if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil {
22
22
logger.Error("error setting deletion token", "error", err)
23
23
-
return helpers.ServerError(e, nil)
23
23
+
helpers.ServerError(w, nil)
24
24
+
return
24
25
}
25
26
26
27
if urepo.Email != "" {
···
29
30
}
30
31
}
31
32
32
32
-
return e.NoContent(200)
33
33
+
w.WriteHeader(http.StatusOK)
33
34
}
34
35
35
36
func (s *Server) sendAccountDeleteEmail(email, handle, token string) error {
+11
-8
server/handle_server_request_email_confirmation.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/Azure/go-autorest/autorest/to"
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
10
"github.com/haileyok/cocoon/models"
10
10
-
"github.com/labstack/echo/v4"
11
11
)
12
12
13
13
-
func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error {
14
14
-
ctx := e.Request().Context()
13
13
+
func (s *Server) handleServerRequestEmailConfirmation(w http.ResponseWriter, r *http.Request) {
14
14
+
ctx := r.Context()
15
15
logger := s.logger.With("name", "handleServerRequestEmailConfirm")
16
16
17
17
-
urepo := e.Get("repo").(*models.RepoActor)
17
17
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
18
18
19
19
if urepo.EmailConfirmedAt != nil {
20
20
-
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
20
20
+
helpers.InputError(w, to.StringPtr("InvalidRequest"))
21
21
+
return
21
22
}
22
23
23
24
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
···
25
26
26
27
if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
27
28
logger.Error("error updating user", "error", err)
28
28
-
return helpers.ServerError(e, nil)
29
29
+
helpers.ServerError(w, nil)
30
30
+
return
29
31
}
30
32
31
33
if err := s.sendEmailVerification(urepo.Email, urepo.Handle, code); err != nil {
32
34
logger.Error("error sending mail", "error", err)
33
33
-
return helpers.ServerError(e, nil)
35
35
+
helpers.ServerError(w, nil)
36
36
+
return
34
37
}
35
38
36
36
-
return e.NoContent(200)
39
39
+
w.WriteHeader(http.StatusOK)
37
40
}
+9
-7
server/handle_server_request_email_update.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"time"
6
7
7
8
"github.com/haileyok/cocoon/internal/helpers"
8
9
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
10
)
11
11
12
12
type ComAtprotoRequestEmailUpdateResponse struct {
13
13
TokenRequired bool `json:"tokenRequired"`
14
14
}
15
15
16
16
-
func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
16
16
+
func (s *Server) handleServerRequestEmailUpdate(w http.ResponseWriter, r *http.Request) {
17
17
+
ctx := r.Context()
18
18
logger := s.logger.With("name", "handleServerRequestEmailUpdate")
19
19
20
20
-
urepo := e.Get("repo").(*models.RepoActor)
20
20
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
21
21
22
22
if urepo.EmailConfirmedAt != nil {
23
23
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
···
25
25
26
26
if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
27
27
logger.Error("error updating repo", "error", err)
28
28
-
return helpers.ServerError(e, nil)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
29
30
}
30
31
31
32
if err := s.sendEmailUpdate(urepo.Email, urepo.Handle, code); err != nil {
32
33
logger.Error("error sending email", "error", err)
33
33
-
return helpers.ServerError(e, nil)
34
34
+
helpers.ServerError(w, nil)
35
35
+
return
34
36
}
35
37
}
36
38
37
37
-
return e.JSON(200, ComAtprotoRequestEmailUpdateResponse{
39
39
+
s.writeJSON(w, 200, ComAtprotoRequestEmailUpdateResponse{
38
40
TokenRequired: urepo.EmailConfirmedAt != nil,
39
41
})
40
42
}
+21
-13
server/handle_server_request_password_reset.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
4
5
"fmt"
6
6
+
"net/http"
5
7
"time"
6
8
7
9
"github.com/haileyok/cocoon/internal/helpers"
8
10
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
11
)
11
12
12
13
type ComAtprotoServerRequestPasswordResetRequest struct {
13
14
Email string `json:"email" validate:"required"`
14
15
}
15
16
16
16
-
func (s *Server) handleServerRequestPasswordReset(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
17
17
+
func (s *Server) handleServerRequestPasswordReset(w http.ResponseWriter, r *http.Request) {
18
18
+
ctx := r.Context()
18
19
logger := s.logger.With("name", "handleServerRequestPasswordReset")
19
20
20
20
-
urepo, ok := e.Get("repo").(*models.RepoActor)
21
21
-
if !ok {
21
21
+
var urepo *models.RepoActor
22
22
+
if repo, ok := getContextValue[*models.RepoActor](r, contextKeyRepo); ok {
23
23
+
urepo = repo
24
24
+
} else {
22
25
var req ComAtprotoServerRequestPasswordResetRequest
23
23
-
if err := e.Bind(&req); err != nil {
24
24
-
return err
26
26
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
27
27
+
helpers.ServerError(w, nil)
28
28
+
return
25
29
}
26
30
27
27
-
if err := e.Validate(req); err != nil {
28
28
-
return err
31
31
+
if err := s.validator.Struct(req); err != nil {
32
32
+
helpers.InputError(w, nil)
33
33
+
return
29
34
}
30
35
31
36
murepo, err := s.getRepoActorByEmail(ctx, req.Email)
32
37
if err != nil {
33
33
-
return err
38
38
+
helpers.ServerError(w, nil)
39
39
+
return
34
40
}
35
41
36
42
urepo = murepo
···
41
47
42
48
if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
43
49
logger.Error("error updating repo", "error", err)
44
44
-
return helpers.ServerError(e, nil)
50
50
+
helpers.ServerError(w, nil)
51
51
+
return
45
52
}
46
53
47
54
if err := s.sendPasswordReset(urepo.Email, urepo.Handle, code); err != nil {
48
55
logger.Error("error sending email", "error", err)
49
49
-
return helpers.ServerError(e, nil)
56
56
+
helpers.ServerError(w, nil)
57
57
+
return
50
58
}
51
59
52
52
-
return e.NoContent(200)
60
60
+
w.WriteHeader(http.StatusOK)
53
61
}
+17
-11
server/handle_server_reserve_signing_key.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"encoding/json"
6
6
+
"net/http"
5
7
"time"
6
8
7
9
"github.com/bluesky-social/indigo/atproto/atcrypto"
8
10
"github.com/haileyok/cocoon/internal/helpers"
9
11
"github.com/haileyok/cocoon/models"
10
10
-
"github.com/labstack/echo/v4"
11
12
)
12
13
13
14
type ServerReserveSigningKeyRequest struct {
···
18
19
SigningKey string `json:"signingKey"`
19
20
}
20
21
21
21
-
func (s *Server) handleServerReserveSigningKey(e echo.Context) error {
22
22
-
ctx := e.Request().Context()
22
22
+
func (s *Server) handleServerReserveSigningKey(w http.ResponseWriter, r *http.Request) {
23
23
+
ctx := r.Context()
23
24
logger := s.logger.With("name", "handleServerReserveSigningKey")
24
25
25
26
var req ServerReserveSigningKeyRequest
26
26
-
if err := e.Bind(&req); err != nil {
27
27
-
logger.Error("could not bind reserve signing key request", "error", err)
28
28
-
return helpers.ServerError(e, nil)
27
27
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
28
28
+
logger.Error("could not decode reserve signing key request", "error", err)
29
29
+
helpers.ServerError(w, nil)
30
30
+
return
29
31
}
30
32
31
33
if req.Did != nil && *req.Did != "" {
32
34
var existing models.ReservedKey
33
35
if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" {
34
34
-
return e.JSON(200, ServerReserveSigningKeyResponse{
36
36
+
s.writeJSON(w, 200, ServerReserveSigningKeyResponse{
35
37
SigningKey: existing.KeyDid,
36
38
})
39
39
+
return
37
40
}
38
41
}
39
42
40
43
k, err := atcrypto.GeneratePrivateKeyK256()
41
44
if err != nil {
42
45
logger.Error("error creating signing key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err)
43
43
-
return helpers.ServerError(e, nil)
46
46
+
helpers.ServerError(w, nil)
47
47
+
return
44
48
}
45
49
46
50
pubKey, err := k.PublicKey()
47
51
if err != nil {
48
52
logger.Error("error getting public key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err)
49
49
-
return helpers.ServerError(e, nil)
53
53
+
helpers.ServerError(w, nil)
54
54
+
return
50
55
}
51
56
52
57
keyDid := pubKey.DIDKey()
···
60
65
61
66
if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil {
62
67
logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err)
63
63
-
return helpers.ServerError(e, nil)
68
68
+
helpers.ServerError(w, nil)
69
69
+
return
64
70
}
65
71
66
72
logger.Info("reserved signing key", "keyDid", keyDid, "forDid", req.Did)
67
73
68
68
-
return e.JSON(200, ServerReserveSigningKeyResponse{
74
74
+
s.writeJSON(w, 200, ServerReserveSigningKeyResponse{
69
75
SigningKey: keyDid,
70
76
})
71
77
}
+23
-15
server/handle_server_reset_password.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
4
6
"time"
5
7
6
8
"github.com/Azure/go-autorest/autorest/to"
7
9
"github.com/haileyok/cocoon/internal/helpers"
8
10
"github.com/haileyok/cocoon/models"
9
9
-
"github.com/labstack/echo/v4"
10
11
"golang.org/x/crypto/bcrypt"
11
12
)
12
13
···
15
16
Password string `json:"password" validate:"required"`
16
17
}
17
18
18
18
-
func (s *Server) handleServerResetPassword(e echo.Context) error {
19
19
-
ctx := e.Request().Context()
19
19
+
func (s *Server) handleServerResetPassword(w http.ResponseWriter, r *http.Request) {
20
20
+
ctx := r.Context()
20
21
logger := s.logger.With("name", "handleServerResetPassword")
21
22
22
22
-
urepo := e.Get("repo").(*models.RepoActor)
23
23
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
23
24
24
25
var req ComAtprotoServerResetPasswordRequest
25
25
-
if err := e.Bind(&req); err != nil {
26
26
-
logger.Error("error binding", "error", err)
27
27
-
return helpers.ServerError(e, nil)
26
26
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
27
27
+
logger.Error("error decoding", "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
28
30
}
29
31
30
30
-
if err := e.Validate(req); err != nil {
31
31
-
return helpers.InputError(e, nil)
32
32
+
if err := s.validator.Struct(req); err != nil {
33
33
+
helpers.InputError(w, nil)
34
34
+
return
32
35
}
33
36
34
37
if urepo.PasswordResetCode == nil || urepo.PasswordResetCodeExpiresAt == nil {
35
35
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
38
38
+
helpers.InputError(w, to.StringPtr("InvalidToken"))
39
39
+
return
36
40
}
37
41
38
42
if *urepo.PasswordResetCode != req.Token {
39
39
-
return helpers.InvalidTokenError(e)
43
43
+
helpers.InvalidTokenError(w)
44
44
+
return
40
45
}
41
46
42
47
if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) {
43
43
-
return helpers.ExpiredTokenError(e)
48
48
+
helpers.ExpiredTokenError(w)
49
49
+
return
44
50
}
45
51
46
52
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10)
47
53
if err != nil {
48
54
logger.Error("error creating hash", "error", err)
49
49
-
return helpers.ServerError(e, nil)
55
55
+
helpers.ServerError(w, nil)
56
56
+
return
50
57
}
51
58
52
59
if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil {
53
60
logger.Error("error updating repo", "error", err)
54
54
-
return helpers.ServerError(e, nil)
61
61
+
helpers.ServerError(w, nil)
62
62
+
return
55
63
}
56
64
57
57
-
return e.NoContent(200)
65
65
+
w.WriteHeader(http.StatusOK)
58
66
}
+11
-8
server/handle_server_resolve_handle.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"net/http"
5
6
6
7
"github.com/Azure/go-autorest/autorest/to"
7
8
"github.com/bluesky-social/indigo/atproto/syntax"
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
9
-
"github.com/labstack/echo/v4"
10
10
)
11
11
12
12
-
func (s *Server) handleResolveHandle(e echo.Context) error {
12
12
+
func (s *Server) handleResolveHandle(w http.ResponseWriter, r *http.Request) {
13
13
logger := s.logger.With("name", "handleServerResolveHandle")
14
14
15
15
type Resp struct {
16
16
Did string `json:"did"`
17
17
}
18
18
19
19
-
handle := e.QueryParam("handle")
19
19
+
handle := r.URL.Query().Get("handle")
20
20
21
21
if handle == "" {
22
22
-
return helpers.InputError(e, to.StringPtr("Handle must be supplied in request."))
22
22
+
helpers.InputError(w, to.StringPtr("Handle must be supplied in request."))
23
23
+
return
23
24
}
24
25
25
26
parsed, err := syntax.ParseHandle(handle)
26
27
if err != nil {
27
27
-
return helpers.InputError(e, to.StringPtr("Invalid handle."))
28
28
+
helpers.InputError(w, to.StringPtr("Invalid handle."))
29
29
+
return
28
30
}
29
31
30
30
-
ctx := context.WithValue(e.Request().Context(), "skip-cache", true)
32
32
+
ctx := context.WithValue(r.Context(), "skip-cache", true)
31
33
did, err := s.passport.ResolveHandle(ctx, parsed.String())
32
34
if err != nil {
33
35
logger.Error("error resolving handle", "error", err)
34
34
-
return helpers.ServerError(e, nil)
36
36
+
helpers.ServerError(w, nil)
37
37
+
return
35
38
}
36
39
37
37
-
return e.JSON(200, Resp{
40
40
+
s.writeJSON(w, 200, Resp{
38
41
Did: did,
39
42
})
40
43
}
+24
-16
server/handle_server_update_email.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"encoding/json"
5
5
+
"net/http"
4
6
"time"
5
7
6
8
"github.com/haileyok/cocoon/internal/helpers"
7
9
"github.com/haileyok/cocoon/models"
8
8
-
"github.com/labstack/echo/v4"
9
10
)
10
11
11
12
type ComAtprotoServerUpdateEmailRequest struct {
···
14
15
Token string `json:"token"`
15
16
}
16
17
17
17
-
func (s *Server) handleServerUpdateEmail(e echo.Context) error {
18
18
-
ctx := e.Request().Context()
18
18
+
func (s *Server) handleServerUpdateEmail(w http.ResponseWriter, r *http.Request) {
19
19
+
ctx := r.Context()
19
20
logger := s.logger.With("name", "handleServerUpdateEmail")
20
21
21
21
-
urepo := e.Get("repo").(*models.RepoActor)
22
22
+
urepo, _ := getContextValue[*models.RepoActor](r, contextKeyRepo)
22
23
23
24
var req ComAtprotoServerUpdateEmailRequest
24
24
-
if err := e.Bind(&req); err != nil {
25
25
-
logger.Error("error binding", "error", err)
26
26
-
return helpers.ServerError(e, nil)
25
25
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
26
26
+
logger.Error("error decoding", "error", err)
27
27
+
helpers.ServerError(w, nil)
28
28
+
return
27
29
}
28
30
29
29
-
if err := e.Validate(req); err != nil {
30
30
-
return helpers.InputError(e, nil)
31
31
+
if err := s.validator.Struct(req); err != nil {
32
32
+
helpers.InputError(w, nil)
33
33
+
return
31
34
}
32
35
33
36
// To disable email auth factor a token is required.
34
37
// To enable email auth factor a token is not required.
35
38
// If updating an email address, a token will be sent anyway
36
36
-
if urepo.TwoFactorType != models.TwoFactorTypeNone && req.EmailAuthFactor == false && req.Token == "" {
37
37
-
return helpers.InvalidTokenError(e)
39
39
+
if urepo.TwoFactorType != models.TwoFactorTypeNone && !req.EmailAuthFactor && req.Token == "" {
40
40
+
helpers.InvalidTokenError(w)
41
41
+
return
38
42
}
39
43
40
44
if req.Token != "" {
41
45
if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil {
42
42
-
return helpers.InvalidTokenError(e)
46
46
+
helpers.InvalidTokenError(w)
47
47
+
return
43
48
}
44
49
45
50
if *urepo.EmailUpdateCode != req.Token {
46
46
-
return helpers.InvalidTokenError(e)
51
51
+
helpers.InvalidTokenError(w)
52
52
+
return
47
53
}
48
54
49
55
if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) {
50
50
-
return helpers.ExpiredTokenError(e)
56
56
+
helpers.ExpiredTokenError(w)
57
57
+
return
51
58
}
52
59
}
53
60
···
66
73
67
74
if err := s.db.Exec(ctx, query, nil, twoFactorType, req.Email, urepo.Repo.Did).Error; err != nil {
68
75
logger.Error("error updating repo", "error", err)
69
69
-
return helpers.ServerError(e, nil)
76
76
+
helpers.ServerError(w, nil)
77
77
+
return
70
78
}
71
79
72
72
-
return e.NoContent(200)
80
80
+
w.WriteHeader(http.StatusOK)
73
81
}
+30
-19
server/handle_sync_get_blob.go
···
10
10
"github.com/haileyok/cocoon/internal/helpers"
11
11
"github.com/haileyok/cocoon/models"
12
12
"github.com/ipfs/go-cid"
13
13
-
"github.com/labstack/echo/v4"
14
13
)
15
14
16
16
-
func (s *Server) handleSyncGetBlob(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
15
15
+
func (s *Server) handleSyncGetBlob(w http.ResponseWriter, r *http.Request) {
16
16
+
ctx := r.Context()
18
17
logger := s.logger.With("name", "handleSyncGetBlob")
19
18
20
20
-
did := e.QueryParam("did")
19
19
+
did := r.URL.Query().Get("did")
21
20
if did == "" {
22
22
-
return helpers.InputError(e, nil)
21
21
+
helpers.InputError(w, nil)
22
22
+
return
23
23
}
24
24
25
25
-
cstr := e.QueryParam("cid")
25
25
+
cstr := r.URL.Query().Get("cid")
26
26
if cstr == "" {
27
27
-
return helpers.InputError(e, nil)
27
27
+
helpers.InputError(w, nil)
28
28
+
return
28
29
}
29
30
30
31
c, err := cid.Parse(cstr)
31
32
if err != nil {
32
32
-
return helpers.InputError(e, nil)
33
33
+
helpers.InputError(w, nil)
34
34
+
return
33
35
}
34
36
35
37
urepo, err := s.getRepoActorByDid(ctx, did)
36
38
if err != nil {
37
39
logger.Error("could not find user for requested blob", "error", err)
38
38
-
return helpers.InputError(e, nil)
40
40
+
helpers.InputError(w, nil)
41
41
+
return
39
42
}
40
43
41
44
status := urepo.Status()
42
45
if status != nil {
43
46
if *status == "deactivated" {
44
44
-
return helpers.InputError(e, to.StringPtr("RepoDeactivated"))
47
47
+
helpers.InputError(w, to.StringPtr("RepoDeactivated"))
48
48
+
return
45
49
}
46
50
}
47
51
48
52
var blob models.Blob
49
53
if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil {
50
54
logger.Error("error looking up blob", "error", err)
51
51
-
return helpers.ServerError(e, nil)
55
55
+
helpers.ServerError(w, nil)
56
56
+
return
52
57
}
53
58
54
59
buf := new(bytes.Buffer)
···
58
63
var parts []models.BlobPart
59
64
if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
60
65
logger.Error("error getting blob parts", "error", err)
61
61
-
return helpers.ServerError(e, nil)
66
66
+
helpers.ServerError(w, nil)
67
67
+
return
62
68
}
63
69
64
70
for _, p := range parts {
···
68
74
case "ipfs":
69
75
if s.ipfsConfig == nil || !s.ipfsConfig.BlobstoreEnabled {
70
76
logger.Error("ipfs storage disabled")
71
71
-
return helpers.ServerError(e, nil)
77
77
+
helpers.ServerError(w, nil)
78
78
+
return
72
79
}
73
80
74
81
// If a public gateway is configured, redirect the client directly to it
75
82
// instead of proxying the content through this server.
76
83
if s.ipfsConfig.GatewayURL != "" {
77
84
redirectURL := fmt.Sprintf("%s/ipfs/%s", s.ipfsConfig.GatewayURL, c.String())
78
78
-
return e.Redirect(302, redirectURL)
85
85
+
http.Redirect(w, r, redirectURL, http.StatusFound)
86
86
+
return
79
87
}
80
88
81
89
// Otherwise fetch from the local Kubo node via /api/v0/cat and stream
···
83
91
data, err := s.fetchBlobFromIPFS(c.String())
84
92
if err != nil {
85
93
logger.Error("error fetching blob from ipfs node", "cid", c.String(), "error", err)
86
86
-
return helpers.ServerError(e, nil)
94
94
+
helpers.ServerError(w, nil)
95
95
+
return
87
96
}
88
97
buf.Write(data)
89
98
90
99
default:
91
100
logger.Error("unknown storage", "storage", blob.Storage)
92
92
-
return helpers.ServerError(e, nil)
101
101
+
helpers.ServerError(w, nil)
102
102
+
return
93
103
}
94
104
95
95
-
e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename="+c.String())
96
96
-
97
97
-
return e.Stream(200, "application/octet-stream", buf)
105
105
+
w.Header().Set("Content-Disposition", "attachment; filename="+c.String())
106
106
+
w.Header().Set("Content-Type", "application/octet-stream")
107
107
+
w.WriteHeader(http.StatusOK)
108
108
+
io.Copy(w, buf)
98
109
}
99
110
100
111
// fetchBlobFromIPFS retrieves blob data for the given CID from the local Kubo
+38
-18
server/handle_sync_get_blocks.go
···
2
2
3
3
import (
4
4
"bytes"
5
5
+
"net/http"
5
6
6
7
"github.com/bluesky-social/indigo/carstore"
7
8
"github.com/haileyok/cocoon/internal/helpers"
8
9
"github.com/ipfs/go-cid"
9
10
cbor "github.com/ipfs/go-ipld-cbor"
10
11
"github.com/ipld/go-car"
11
11
-
"github.com/labstack/echo/v4"
12
12
)
13
13
14
14
type ComAtprotoSyncGetBlocksRequest struct {
···
16
16
Cids []string `query:"cids"`
17
17
}
18
18
19
19
-
func (s *Server) handleGetBlocks(e echo.Context) error {
20
20
-
ctx := e.Request().Context()
19
19
+
func (s *Server) handleGetBlocks(w http.ResponseWriter, r *http.Request) {
20
20
+
ctx := r.Context()
21
21
logger := s.logger.With("name", "handleSyncGetBlocks")
22
22
23
23
-
var req ComAtprotoSyncGetBlocksRequest
24
24
-
if err := e.Bind(&req); err != nil {
25
25
-
return helpers.InputError(e, nil)
23
23
+
did := r.URL.Query().Get("did")
24
24
+
if did == "" {
25
25
+
helpers.InputError(w, nil)
26
26
+
return
26
27
}
27
28
29
29
+
cidsParam := r.URL.Query()["cids"]
28
30
var cids []cid.Cid
29
31
30
30
-
for _, cs := range req.Cids {
32
32
+
for _, cs := range cidsParam {
31
33
c, err := cid.Cast([]byte(cs))
32
34
if err != nil {
33
33
-
return err
35
35
+
logger.Error("error parsing cid", "cid", cs, "error", err)
36
36
+
helpers.InputError(w, nil)
37
37
+
return
34
38
}
35
35
-
36
39
cids = append(cids, c)
37
40
}
38
41
39
39
-
urepo, err := s.getRepoActorByDid(ctx, req.Did)
42
42
+
urepo, err := s.getRepoActorByDid(ctx, did)
40
43
if err != nil {
41
41
-
return helpers.ServerError(e, nil)
44
44
+
logger.Error("could not find repo", "did", did, "error", err)
45
45
+
helpers.ServerError(w, nil)
46
46
+
return
42
47
}
43
48
44
44
-
buf := new(bytes.Buffer)
45
49
rc, err := cid.Cast(urepo.Root)
46
50
if err != nil {
47
47
-
return err
51
51
+
logger.Error("error casting root cid", "error", err)
52
52
+
helpers.ServerError(w, nil)
53
53
+
return
48
54
}
49
55
50
56
hb, err := cbor.DumpObject(&car.CarHeader{
51
57
Roots: []cid.Cid{rc},
52
58
Version: 1,
53
59
})
60
60
+
if err != nil {
61
61
+
logger.Error("error dumping car header", "error", err)
62
62
+
helpers.ServerError(w, nil)
63
63
+
return
64
64
+
}
65
65
+
66
66
+
buf := new(bytes.Buffer)
54
67
55
68
if _, err := carstore.LdWrite(buf, hb); err != nil {
56
56
-
logger.Error("error writing to car", "error", err)
57
57
-
return helpers.ServerError(e, nil)
69
69
+
logger.Error("error writing car header", "error", err)
70
70
+
helpers.ServerError(w, nil)
71
71
+
return
58
72
}
59
73
60
74
bs := s.getBlockstore(urepo.Repo.Did)
···
62
76
for _, c := range cids {
63
77
b, err := bs.Get(ctx, c)
64
78
if err != nil {
65
65
-
return err
79
79
+
logger.Error("error getting block", "cid", c.String(), "error", err)
80
80
+
helpers.ServerError(w, nil)
81
81
+
return
66
82
}
67
83
68
84
if _, err := carstore.LdWrite(buf, b.Cid().Bytes(), b.RawData()); err != nil {
69
69
-
return err
85
85
+
logger.Error("error writing block to car", "error", err)
86
86
+
helpers.ServerError(w, nil)
87
87
+
return
70
88
}
71
89
}
72
90
73
73
-
return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes()))
91
91
+
w.Header().Set("Content-Type", "application/vnd.ipld.car")
92
92
+
w.WriteHeader(http.StatusOK)
93
93
+
w.Write(buf.Bytes())
74
94
}
+15
-8
server/handle_sync_get_latest_commit.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
7
"github.com/ipfs/go-cid"
6
6
-
"github.com/labstack/echo/v4"
7
8
)
8
9
9
10
type ComAtprotoSyncGetLatestCommitResponse struct {
···
11
12
Rev string `json:"rev"`
12
13
}
13
14
14
14
-
func (s *Server) handleSyncGetLatestCommit(e echo.Context) error {
15
15
-
ctx := e.Request().Context()
15
15
+
func (s *Server) handleSyncGetLatestCommit(w http.ResponseWriter, r *http.Request) {
16
16
+
ctx := r.Context()
17
17
+
logger := s.logger.With("name", "handleSyncGetLatestCommit")
16
18
17
17
-
did := e.QueryParam("did")
19
19
+
did := r.URL.Query().Get("did")
18
20
if did == "" {
19
19
-
return helpers.InputError(e, nil)
21
21
+
helpers.InputError(w, nil)
22
22
+
return
20
23
}
21
24
22
25
urepo, err := s.getRepoActorByDid(ctx, did)
23
26
if err != nil {
24
24
-
return err
27
27
+
logger.Error("could not find repo", "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
25
30
}
26
31
27
32
c, err := cid.Cast(urepo.Root)
28
33
if err != nil {
29
29
-
return err
34
34
+
logger.Error("could not cast root cid", "error", err)
35
35
+
helpers.ServerError(w, nil)
36
36
+
return
30
37
}
31
38
32
32
-
return e.JSON(200, ComAtprotoSyncGetLatestCommitResponse{
39
39
+
s.writeJSON(w, http.StatusOK, ComAtprotoSyncGetLatestCommitResponse{
33
40
Cid: c.String(),
34
41
Rev: urepo.Rev,
35
42
})
+24
-12
server/handle_sync_get_record.go
···
2
2
3
3
import (
4
4
"bytes"
5
5
+
"net/http"
5
6
6
7
"github.com/bluesky-social/indigo/carstore"
7
8
"github.com/haileyok/cocoon/internal/helpers"
···
9
10
"github.com/ipfs/go-cid"
10
11
cbor "github.com/ipfs/go-ipld-cbor"
11
12
"github.com/ipld/go-car"
12
12
-
"github.com/labstack/echo/v4"
13
13
)
14
14
15
15
-
func (s *Server) handleSyncGetRecord(e echo.Context) error {
16
16
-
ctx := e.Request().Context()
15
15
+
func (s *Server) handleSyncGetRecord(w http.ResponseWriter, r *http.Request) {
16
16
+
ctx := r.Context()
17
17
logger := s.logger.With("name", "handleSyncGetRecord")
18
18
19
19
-
did := e.QueryParam("did")
20
20
-
collection := e.QueryParam("collection")
21
21
-
rkey := e.QueryParam("rkey")
19
19
+
did := r.URL.Query().Get("did")
20
20
+
collection := r.URL.Query().Get("collection")
21
21
+
rkey := r.URL.Query().Get("rkey")
22
22
23
23
var urepo models.Repo
24
24
if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil {
25
25
logger.Error("error getting repo", "error", err)
26
26
-
return helpers.ServerError(e, nil)
26
26
+
helpers.ServerError(w, nil)
27
27
+
return
27
28
}
28
29
29
30
root, blocks, err := s.repoman.getRecordProof(ctx, urepo, collection, rkey)
30
31
if err != nil {
31
31
-
return err
32
32
+
logger.Error("error getting record proof", "error", err)
33
33
+
helpers.ServerError(w, nil)
34
34
+
return
32
35
}
33
36
34
37
buf := new(bytes.Buffer)
···
37
40
Roots: []cid.Cid{root},
38
41
Version: 1,
39
42
})
43
43
+
if err != nil {
44
44
+
logger.Error("error dumping car header", "error", err)
45
45
+
helpers.ServerError(w, nil)
46
46
+
return
47
47
+
}
40
48
41
49
if _, err := carstore.LdWrite(buf, hb); err != nil {
42
50
logger.Error("error writing to car", "error", err)
43
43
-
return helpers.ServerError(e, nil)
51
51
+
helpers.ServerError(w, nil)
52
52
+
return
44
53
}
45
54
46
55
for _, blk := range blocks {
47
56
if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
48
48
-
logger.Error("error writing to car", "error", err)
49
49
-
return helpers.ServerError(e, nil)
57
57
+
logger.Error("error writing block to car", "error", err)
58
58
+
helpers.ServerError(w, nil)
59
59
+
return
50
60
}
51
61
}
52
62
53
53
-
return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes()))
63
63
+
w.Header().Set("Content-Type", "application/vnd.ipld.car")
64
64
+
w.WriteHeader(http.StatusOK)
65
65
+
w.Write(buf.Bytes())
54
66
}
+29
-12
server/handle_sync_get_repo.go
···
2
2
3
3
import (
4
4
"bytes"
5
5
+
"net/http"
5
6
6
7
"github.com/bluesky-social/indigo/carstore"
7
8
"github.com/haileyok/cocoon/internal/helpers"
···
9
10
"github.com/ipfs/go-cid"
10
11
cbor "github.com/ipfs/go-ipld-cbor"
11
12
"github.com/ipld/go-car"
12
12
-
"github.com/labstack/echo/v4"
13
13
)
14
14
15
15
-
func (s *Server) handleSyncGetRepo(e echo.Context) error {
16
16
-
ctx := e.Request().Context()
15
15
+
func (s *Server) handleSyncGetRepo(w http.ResponseWriter, r *http.Request) {
16
16
+
ctx := r.Context()
17
17
logger := s.logger.With("name", "handleSyncGetRepo")
18
18
19
19
-
did := e.QueryParam("did")
19
19
+
did := r.URL.Query().Get("did")
20
20
if did == "" {
21
21
-
return helpers.InputError(e, nil)
21
21
+
helpers.InputError(w, nil)
22
22
+
return
22
23
}
23
24
24
25
urepo, err := s.getRepoActorByDid(ctx, did)
25
26
if err != nil {
26
26
-
return err
27
27
+
logger.Error("could not find repo", "did", did, "error", err)
28
28
+
helpers.ServerError(w, nil)
29
29
+
return
27
30
}
28
31
29
32
rc, err := cid.Cast(urepo.Root)
30
33
if err != nil {
31
31
-
return err
34
34
+
logger.Error("error casting root cid", "error", err)
35
35
+
helpers.ServerError(w, nil)
36
36
+
return
32
37
}
33
38
34
39
hb, err := cbor.DumpObject(&car.CarHeader{
35
40
Roots: []cid.Cid{rc},
36
41
Version: 1,
37
42
})
43
43
+
if err != nil {
44
44
+
logger.Error("error dumping car header", "error", err)
45
45
+
helpers.ServerError(w, nil)
46
46
+
return
47
47
+
}
38
48
39
49
buf := new(bytes.Buffer)
40
50
41
51
if _, err := carstore.LdWrite(buf, hb); err != nil {
42
42
-
logger.Error("error writing to car", "error", err)
43
43
-
return helpers.ServerError(e, nil)
52
52
+
logger.Error("error writing car header", "error", err)
53
53
+
helpers.ServerError(w, nil)
54
54
+
return
44
55
}
45
56
46
57
var blocks []models.Block
47
58
if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil {
48
48
-
return err
59
59
+
logger.Error("error getting blocks", "error", err)
60
60
+
helpers.ServerError(w, nil)
61
61
+
return
49
62
}
50
63
51
64
for _, block := range blocks {
52
65
if _, err := carstore.LdWrite(buf, block.Cid, block.Value); err != nil {
53
53
-
return err
66
66
+
logger.Error("error writing block to car", "error", err)
67
67
+
helpers.ServerError(w, nil)
68
68
+
return
54
69
}
55
70
}
56
71
57
57
-
return e.Stream(200, "application/vnd.ipld.car", bytes.NewReader(buf.Bytes()))
72
72
+
w.Header().Set("Content-Type", "application/vnd.ipld.car")
73
73
+
w.WriteHeader(http.StatusOK)
74
74
+
w.Write(buf.Bytes())
58
75
}
+12
-7
server/handle_sync_get_repo_status.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
4
6
"github.com/haileyok/cocoon/internal/helpers"
5
5
-
"github.com/labstack/echo/v4"
6
7
)
7
8
8
9
type ComAtprotoSyncGetRepoStatusResponse struct {
···
13
14
}
14
15
15
16
// TODO: make this actually do the right thing
16
16
-
func (s *Server) handleSyncGetRepoStatus(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
17
17
+
func (s *Server) handleSyncGetRepoStatus(w http.ResponseWriter, r *http.Request) {
18
18
+
ctx := r.Context()
19
19
+
logger := s.logger.With("name", "handleSyncGetRepoStatus")
18
20
19
19
-
did := e.QueryParam("did")
21
21
+
did := r.URL.Query().Get("did")
20
22
if did == "" {
21
21
-
return helpers.InputError(e, nil)
23
23
+
helpers.InputError(w, nil)
24
24
+
return
22
25
}
23
26
24
27
urepo, err := s.getRepoActorByDid(ctx, did)
25
28
if err != nil {
26
26
-
return err
29
29
+
logger.Error("could not find repo", "did", did, "error", err)
30
30
+
helpers.ServerError(w, nil)
31
31
+
return
27
32
}
28
33
29
29
-
return e.JSON(200, ComAtprotoSyncGetRepoStatusResponse{
34
34
+
s.writeJSON(w, http.StatusOK, ComAtprotoSyncGetRepoStatusResponse{
30
35
Did: urepo.Repo.Did,
31
36
Active: urepo.Active(),
32
37
Status: urepo.Status(),
+23
-14
server/handle_sync_list_blobs.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"net/http"
5
5
+
"strconv"
6
6
+
4
7
"github.com/Azure/go-autorest/autorest/to"
5
8
"github.com/haileyok/cocoon/internal/helpers"
6
9
"github.com/haileyok/cocoon/models"
7
10
"github.com/ipfs/go-cid"
8
8
-
"github.com/labstack/echo/v4"
9
11
)
10
12
11
13
type ComAtprotoSyncListBlobsResponse struct {
···
13
15
Cids []string `json:"cids"`
14
16
}
15
17
16
16
-
func (s *Server) handleSyncListBlobs(e echo.Context) error {
17
17
-
ctx := e.Request().Context()
18
18
+
func (s *Server) handleSyncListBlobs(w http.ResponseWriter, r *http.Request) {
19
19
+
ctx := r.Context()
18
20
logger := s.logger.With("name", "handleSyncListBlobs")
19
21
20
20
-
did := e.QueryParam("did")
22
22
+
did := r.URL.Query().Get("did")
21
23
if did == "" {
22
22
-
return helpers.InputError(e, nil)
24
24
+
helpers.InputError(w, nil)
25
25
+
return
23
26
}
24
27
25
28
// TODO: add tid param
26
26
-
cursor := e.QueryParam("cursor")
27
27
-
limit, err := getLimitFromContext(e, 50)
28
28
-
if err != nil {
29
29
-
return helpers.InputError(e, nil)
29
29
+
cursor := r.URL.Query().Get("cursor")
30
30
+
31
31
+
limit := 50
32
32
+
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
33
33
+
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 {
34
34
+
limit = l
35
35
+
}
30
36
}
31
37
32
38
cursorquery := ""
···
41
47
urepo, err := s.getRepoActorByDid(ctx, did)
42
48
if err != nil {
43
49
logger.Error("could not find user for requested blobs", "error", err)
44
44
-
return helpers.InputError(e, nil)
50
50
+
helpers.InputError(w, nil)
51
51
+
return
45
52
}
46
53
47
54
status := urepo.Status()
48
55
if status != nil {
49
56
if *status == "deactivated" {
50
50
-
return helpers.InputError(e, to.StringPtr("RepoDeactivated"))
57
57
+
helpers.InputError(w, to.StringPtr("RepoDeactivated"))
58
58
+
return
51
59
}
52
60
}
53
61
54
62
var blobs []models.Blob
55
63
if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil {
56
64
logger.Error("error getting records", "error", err)
57
57
-
return helpers.ServerError(e, nil)
65
65
+
helpers.ServerError(w, nil)
66
66
+
return
58
67
}
59
68
60
69
cstrs := make([]string, 0, len(blobs))
···
72
81
}
73
82
74
83
var newcursor *string
75
75
-
if len(blobs) == 50 {
84
84
+
if len(blobs) == limit {
76
85
newcursor = &blobs[len(blobs)-1].CreatedAt
77
86
}
78
87
79
79
-
return e.JSON(200, ComAtprotoSyncListBlobsResponse{
88
88
+
s.writeJSON(w, http.StatusOK, ComAtprotoSyncListBlobsResponse{
80
89
Cursor: newcursor,
81
90
Cids: cstrs,
82
91
})
+9
-10
server/handle_sync_subscribe_repos.go
···
2
2
3
3
import (
4
4
"context"
5
5
+
"net/http"
5
6
"strconv"
6
7
"time"
7
8
···
9
10
"github.com/bluesky-social/indigo/lex/util"
10
11
"github.com/btcsuite/websocket"
11
12
"github.com/haileyok/cocoon/metrics"
12
12
-
"github.com/labstack/echo/v4"
13
13
)
14
14
15
15
-
func (s *Server) handleSyncSubscribeRepos(e echo.Context) error {
16
16
-
ctx, cancel := context.WithCancel(e.Request().Context())
15
15
+
func (s *Server) handleSyncSubscribeRepos(w http.ResponseWriter, r *http.Request) {
16
16
+
ctx, cancel := context.WithCancel(r.Context())
17
17
defer cancel()
18
18
19
19
logger := s.logger.With("component", "subscribe-repos-websocket")
20
20
21
21
-
conn, err := websocket.Upgrade(e.Response().Writer, e.Request(), e.Response().Header(), 1<<10, 1<<10)
21
21
+
conn, err := websocket.Upgrade(w, r, w.Header(), 1<<10, 1<<10)
22
22
if err != nil {
23
23
logger.Error("unable to establish websocket with relay", "err", err)
24
24
-
return err
24
24
+
return
25
25
}
26
26
27
27
-
ident := e.RealIP() + "-" + e.Request().UserAgent()
27
27
+
ident := r.RemoteAddr + "-" + r.UserAgent()
28
28
logger = logger.With("ident", ident)
29
29
logger.Info("new connection established")
30
30
31
31
var since *int64
32
32
-
if cursorStr := e.QueryParam("cursor"); cursorStr != "" {
32
32
+
if cursorStr := r.URL.Query().Get("cursor"); cursorStr != "" {
33
33
cursor, err := strconv.ParseInt(cursorStr, 10, 64)
34
34
if err != nil {
35
35
logger.Warn("invalid cursor parameter", "cursor", cursorStr, "err", err)
···
48
48
return true
49
49
}, since)
50
50
if err != nil {
51
51
-
return err
51
51
+
logger.Error("error subscribing to event manager", "err", err)
52
52
+
return
52
53
}
53
54
defer evtManCancel()
54
55
···
134
135
logger.Error("error requesting crawls", "err", err)
135
136
}
136
137
}()
137
137
-
138
138
-
return nil
139
138
}
+26
-19
server/handle_well_known.go
···
2
2
3
3
import (
4
4
"fmt"
5
5
+
"net/http"
5
6
"strings"
6
7
7
8
"github.com/Azure/go-autorest/autorest/to"
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
9
-
"github.com/labstack/echo/v4"
10
10
"gorm.io/gorm"
11
11
)
12
12
···
50
50
ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported"`
51
51
}
52
52
53
53
-
func (s *Server) handleWellKnown(e echo.Context) error {
54
54
-
return e.JSON(200, map[string]any{
53
53
+
func (s *Server) handleWellKnown(w http.ResponseWriter, r *http.Request) {
54
54
+
s.writeJSON(w, 200, map[string]any{
55
55
"@context": []string{
56
56
"https://www.w3.org/ns/did/v1",
57
57
},
···
66
66
})
67
67
}
68
68
69
69
-
func (s *Server) handleAtprotoDid(e echo.Context) error {
70
70
-
ctx := e.Request().Context()
69
69
+
func (s *Server) handleAtprotoDid(w http.ResponseWriter, r *http.Request) {
70
70
+
ctx := r.Context()
71
71
logger := s.logger.With("name", "handleAtprotoDid")
72
72
73
73
-
host := e.Request().Host
73
73
+
host := r.Host
74
74
if host == "" {
75
75
-
return helpers.InputError(e, to.StringPtr("Invalid handle."))
75
75
+
helpers.InputError(w, to.StringPtr("Invalid handle."))
76
76
+
return
76
77
}
77
78
78
79
host = strings.Split(host, ":")[0]
79
80
host = strings.ToLower(strings.TrimSpace(host))
80
81
81
82
if host == s.config.Hostname {
82
82
-
return e.String(200, s.config.Did)
83
83
+
w.Header().Set("Content-Type", "text/plain")
84
84
+
fmt.Fprint(w, s.config.Did)
85
85
+
return
83
86
}
84
87
85
88
suffix := "." + s.config.Hostname
86
89
if !strings.HasSuffix(host, suffix) {
87
87
-
return e.NoContent(404)
90
90
+
w.WriteHeader(http.StatusNotFound)
91
91
+
return
88
92
}
89
93
90
94
actor, err := s.getActorByHandle(ctx, host)
91
95
if err != nil {
92
96
if err == gorm.ErrRecordNotFound {
93
93
-
return e.NoContent(404)
97
97
+
w.WriteHeader(http.StatusNotFound)
98
98
+
return
94
99
}
95
100
logger.Error("error looking up actor by handle", "error", err)
96
96
-
return helpers.ServerError(e, nil)
101
101
+
helpers.ServerError(w, nil)
102
102
+
return
97
103
}
98
104
99
99
-
return e.String(200, actor.Did)
105
105
+
w.Header().Set("Content-Type", "text/plain")
106
106
+
fmt.Fprint(w, actor.Did)
100
107
}
101
108
102
102
-
func (s *Server) handleOauthProtectedResource(e echo.Context) error {
103
103
-
return e.JSON(200, map[string]any{
109
109
+
func (s *Server) handleOauthProtectedResource(w http.ResponseWriter, r *http.Request) {
110
110
+
s.writeJSON(w, 200, map[string]any{
104
111
"resource": "https://" + s.config.Hostname,
105
112
"authorization_servers": []string{
106
113
"https://" + s.config.Hostname,
···
111
118
})
112
119
}
113
120
114
114
-
func (s *Server) handleOauthAuthorizationServer(e echo.Context) error {
115
115
-
return e.JSON(200, OauthAuthorizationMetadata{
121
121
+
func (s *Server) handleOauthAuthorizationServer(w http.ResponseWriter, r *http.Request) {
122
122
+
s.writeJSON(w, 200, OauthAuthorizationMetadata{
116
123
Issuer: "https://" + s.config.Hostname,
117
124
RequestParameterSupported: true,
118
125
RequestUriParameterSupported: true,
···
125
132
CodeChallengeMethodsSupported: []string{"S256"},
126
133
UILocalesSupported: []string{"en-US"},
127
134
DisplayValuesSupported: []string{"page", "popup", "touch"},
128
128
-
RequestObjectSigningAlgValuesSupported: []string{"ES256"}, // only es256 for now...
135
135
+
RequestObjectSigningAlgValuesSupported: []string{"ES256"},
129
136
AuthorizationResponseISSParameterSupported: true,
130
137
RequestObjectEncryptionAlgValuesSupported: []string{},
131
138
RequestObjectEncryptionEncValuesSupported: []string{},
···
133
140
AuthorizationEndpoint: fmt.Sprintf("https://%s/oauth/authorize", s.config.Hostname),
134
141
TokenEndpoint: fmt.Sprintf("https://%s/oauth/token", s.config.Hostname),
135
142
TokenEndpointAuthMethodsSupported: []string{"none", "private_key_jwt"},
136
136
-
TokenEndpointAuthSigningAlgValuesSupported: []string{"ES256"}, // Same as above, just es256
143
143
+
TokenEndpointAuthSigningAlgValuesSupported: []string{"ES256"},
137
144
RevocationEndpoint: fmt.Sprintf("https://%s/oauth/revoke", s.config.Hostname),
138
145
IntrospectionEndpoint: fmt.Sprintf("https://%s/oauth/introspect", s.config.Hostname),
139
146
PushedAuthorizationRequestEndpoint: fmt.Sprintf("https://%s/oauth/par", s.config.Hostname),
140
147
RequirePushedAuthorizationRequests: true,
141
141
-
DpopSigningAlgValuesSupported: []string{"ES256"}, // again same as above
148
148
+
DpopSigningAlgValuesSupported: []string{"ES256"},
142
149
ProtectedResources: []string{"https://" + s.config.Hostname},
143
150
ClientIDMetadataDocumentSupported: true,
144
151
})
+128
-78
server/middleware.go
···
1
1
package server
2
2
3
3
import (
4
4
+
"context"
4
5
"crypto/sha256"
5
6
"encoding/base64"
6
7
"errors"
7
8
"fmt"
9
9
+
"net/http"
8
10
"strings"
9
11
"time"
10
12
···
14
16
"github.com/haileyok/cocoon/models"
15
17
"github.com/haileyok/cocoon/oauth/dpop"
16
18
"github.com/haileyok/cocoon/oauth/provider"
17
17
-
"github.com/labstack/echo/v4"
18
19
"gitlab.com/yawning/secp256k1-voi"
19
20
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
20
21
"gorm.io/gorm"
21
22
)
22
23
23
23
-
func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
24
24
-
return func(e echo.Context) error {
25
25
-
username, password, ok := e.Request().BasicAuth()
26
26
-
if !ok || username != "admin" || password != s.config.AdminPassword {
27
27
-
return helpers.InputError(e, to.StringPtr("Unauthorized"))
28
28
-
}
24
24
+
// context keys for values set by middleware
25
25
+
type contextKey string
29
26
30
30
-
if err := next(e); err != nil {
31
31
-
e.Error(err)
32
32
-
}
27
27
+
const (
28
28
+
contextKeyRepo contextKey = "repo"
29
29
+
contextKeyDid contextKey = "did"
30
30
+
contextKeyToken contextKey = "token"
31
31
+
contextKeyScopes contextKey = "scopes"
32
32
+
33
33
+
// used by proxy handler to override token fields
34
34
+
contextKeyProxyTokenLxm contextKey = "proxyTokenLxm"
35
35
+
contextKeyProxyTokenAud contextKey = "proxyTokenAud"
36
36
+
)
33
37
34
34
-
return nil
35
35
-
}
38
38
+
func setContextValue(r *http.Request, key contextKey, value any) *http.Request {
39
39
+
return r.WithContext(context.WithValue(r.Context(), key, value))
40
40
+
}
41
41
+
42
42
+
func getContextValue[T any](r *http.Request, key contextKey) (T, bool) {
43
43
+
v, ok := r.Context().Value(key).(T)
44
44
+
return v, ok
36
45
}
37
46
38
38
-
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
39
39
-
return func(e echo.Context) error {
40
40
-
ctx := e.Request().Context()
47
47
+
func (s *Server) handleAdminMiddleware(next http.Handler) http.Handler {
48
48
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49
49
+
username, password, ok := r.BasicAuth()
50
50
+
if !ok || username != "admin" || password != s.config.AdminPassword {
51
51
+
helpers.InputError(w, to.StringPtr("Unauthorized"))
52
52
+
return
53
53
+
}
54
54
+
next.ServeHTTP(w, r)
55
55
+
})
56
56
+
}
57
57
+
58
58
+
func (s *Server) handleLegacySessionMiddleware(next http.Handler) http.Handler {
59
59
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60
60
+
ctx := r.Context()
41
61
logger := s.logger.With("name", "handleLegacySessionMiddleware")
42
62
43
43
-
authheader := e.Request().Header.Get("authorization")
63
63
+
authheader := r.Header.Get("authorization")
44
64
if authheader == "" {
45
45
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
65
65
+
s.writeJSON(w, 401, map[string]string{"error": "Unauthorized"})
66
66
+
return
46
67
}
47
68
48
69
pts := strings.Split(authheader, " ")
49
70
if len(pts) != 2 {
50
50
-
return helpers.ServerError(e, nil)
71
71
+
helpers.ServerError(w, nil)
72
72
+
return
51
73
}
52
74
53
75
// move on to oauth session middleware if this is a dpop token
54
76
if pts[0] == "DPoP" {
55
55
-
return next(e)
77
77
+
next.ServeHTTP(w, r)
78
78
+
return
56
79
}
57
80
58
81
tokenstr := pts[1]
59
82
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
60
83
claims, ok := token.Claims.(jwt.MapClaims)
61
84
if !ok {
62
62
-
return helpers.InvalidTokenError(e)
85
85
+
helpers.InvalidTokenError(w)
86
86
+
return
63
87
}
64
88
65
89
var did string
···
68
92
// service auth tokens
69
93
lxm, hasLxm := claims["lxm"]
70
94
if hasLxm {
71
71
-
pts := strings.Split(e.Request().URL.String(), "/")
95
95
+
pts := strings.Split(r.URL.String(), "/")
72
96
if lxm != pts[len(pts)-1] {
73
97
logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
74
74
-
return helpers.InputError(e, nil)
98
98
+
helpers.InputError(w, nil)
99
99
+
return
75
100
}
76
101
77
102
maybeDid, ok := claims["iss"].(string)
78
103
if !ok {
79
104
logger.Error("no iss in service auth token", "error", err)
80
80
-
return helpers.InputError(e, nil)
105
105
+
helpers.InputError(w, nil)
106
106
+
return
81
107
}
82
108
did = maybeDid
83
109
84
110
maybeRepo, err := s.getRepoActorByDid(ctx, did)
85
111
if err != nil {
86
112
logger.Error("error fetching repo", "error", err)
87
87
-
return helpers.ServerError(e, nil)
113
113
+
helpers.ServerError(w, nil)
114
114
+
return
88
115
}
89
116
repo = maybeRepo
90
117
}
···
98
125
})
99
126
if err != nil {
100
127
logger.Error("error parsing jwt", "error", err)
101
101
-
return helpers.ExpiredTokenError(e)
128
128
+
helpers.ExpiredTokenError(w)
129
129
+
return
102
130
}
103
131
104
132
if !token.Valid {
105
105
-
return helpers.InvalidTokenError(e)
133
133
+
helpers.InvalidTokenError(w)
134
134
+
return
106
135
}
107
136
} else {
108
137
kpts := strings.Split(tokenstr, ".")
···
111
140
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
112
141
if err != nil {
113
142
logger.Error("error decoding signature bytes", "error", err)
114
114
-
return helpers.ServerError(e, nil)
143
143
+
helpers.ServerError(w, nil)
144
144
+
return
115
145
}
116
146
117
147
if len(sigBytes) != 64 {
118
148
logger.Error("incorrect sigbytes length", "length", len(sigBytes))
119
119
-
return helpers.ServerError(e, nil)
149
149
+
helpers.ServerError(w, nil)
150
150
+
return
120
151
}
121
152
122
153
rBytes := sigBytes[:32]
···
128
159
sub, ok := claims["sub"].(string)
129
160
if !ok {
130
161
s.logger.Error("no sub claim in ES256K token and repo not set")
131
131
-
return helpers.InvalidTokenError(e)
162
162
+
helpers.InvalidTokenError(w)
163
163
+
return
132
164
}
133
165
maybeRepo, err := s.getRepoActorByDid(ctx, sub)
134
166
if err != nil {
135
167
s.logger.Error("error fetching repo for ES256K verification", "error", err)
136
136
-
return helpers.ServerError(e, nil)
168
168
+
helpers.ServerError(w, nil)
169
169
+
return
137
170
}
138
171
repo = maybeRepo
139
172
did = sub
···
142
175
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
143
176
if err != nil {
144
177
logger.Error("can't load private key", "error", err)
145
145
-
return err
178
178
+
helpers.ServerError(w, nil)
179
179
+
return
146
180
}
147
181
148
182
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
149
183
if !ok {
150
184
logger.Error("error getting public key from sk")
151
151
-
return helpers.ServerError(e, nil)
185
185
+
helpers.ServerError(w, nil)
186
186
+
return
152
187
}
153
188
154
189
verified := pubKey.VerifyRaw(hash[:], rr, ss)
155
190
if !verified {
156
191
logger.Error("error verifying", "error", err)
157
157
-
return helpers.ServerError(e, nil)
192
192
+
helpers.ServerError(w, nil)
193
193
+
return
158
194
}
159
195
}
160
196
161
161
-
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
197
197
+
isRefresh := r.URL.Path == "/xrpc/com.atproto.server.refreshSession"
162
198
scope, _ := claims["scope"].(string)
163
199
164
200
if isRefresh && scope != "com.atproto.refresh" {
165
165
-
return helpers.InvalidTokenError(e)
201
201
+
helpers.InvalidTokenError(w)
202
202
+
return
166
203
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
167
167
-
return helpers.InvalidTokenError(e)
204
204
+
helpers.InvalidTokenError(w)
205
205
+
return
168
206
}
169
207
170
208
table := "tokens"
···
179
217
var result Result
180
218
if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
181
219
if err == gorm.ErrRecordNotFound {
182
182
-
return helpers.InvalidTokenError(e)
220
220
+
helpers.InvalidTokenError(w)
221
221
+
return
183
222
}
184
223
185
224
logger.Error("error getting token from db", "error", err)
186
186
-
return helpers.ServerError(e, nil)
225
225
+
helpers.ServerError(w, nil)
226
226
+
return
187
227
}
188
228
189
229
if !result.Found {
190
190
-
return helpers.InvalidTokenError(e)
230
230
+
helpers.InvalidTokenError(w)
231
231
+
return
191
232
}
192
233
}
193
234
194
235
exp, ok := claims["exp"].(float64)
195
236
if !ok {
196
237
logger.Error("error getting iat from token")
197
197
-
return helpers.ServerError(e, nil)
238
238
+
helpers.ServerError(w, nil)
239
239
+
return
198
240
}
199
241
200
242
if exp < float64(time.Now().UTC().Unix()) {
201
201
-
return helpers.ExpiredTokenError(e)
243
243
+
helpers.ExpiredTokenError(w)
244
244
+
return
202
245
}
203
246
204
247
if repo == nil {
205
248
maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string))
206
249
if err != nil {
207
250
logger.Error("error fetching repo", "error", err)
208
208
-
return helpers.ServerError(e, nil)
251
251
+
helpers.ServerError(w, nil)
252
252
+
return
209
253
}
210
254
repo = maybeRepo
211
255
did = repo.Repo.Did
212
256
}
213
257
214
214
-
e.Set("repo", repo)
215
215
-
e.Set("did", did)
216
216
-
e.Set("token", tokenstr)
258
258
+
r = setContextValue(r, contextKeyRepo, repo)
259
259
+
r = setContextValue(r, contextKeyDid, did)
260
260
+
r = setContextValue(r, contextKeyToken, tokenstr)
217
261
218
218
-
if err := next(e); err != nil {
219
219
-
return helpers.InvalidTokenError(e)
220
220
-
}
221
221
-
222
222
-
return nil
223
223
-
}
262
262
+
next.ServeHTTP(w, r)
263
263
+
})
224
264
}
225
265
226
226
-
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
227
227
-
return func(e echo.Context) error {
228
228
-
ctx := e.Request().Context()
266
266
+
func (s *Server) handleOauthSessionMiddleware(next http.Handler) http.Handler {
267
267
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
268
268
+
ctx := r.Context()
229
269
logger := s.logger.With("name", "handleOauthSessionMiddleware")
230
270
231
231
-
authheader := e.Request().Header.Get("authorization")
271
271
+
authheader := r.Header.Get("authorization")
232
272
if authheader == "" {
233
233
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
273
273
+
s.writeJSON(w, 401, map[string]string{"error": "Unauthorized"})
274
274
+
return
234
275
}
235
276
236
277
pts := strings.Split(authheader, " ")
237
278
if len(pts) != 2 {
238
238
-
return helpers.ServerError(e, nil)
279
279
+
helpers.ServerError(w, nil)
280
280
+
return
239
281
}
240
282
241
283
if pts[0] != "DPoP" {
242
242
-
return next(e)
284
284
+
next.ServeHTTP(w, r)
285
285
+
return
243
286
}
244
287
245
288
accessToken := pts[1]
246
289
247
290
nonce := s.oauthProvider.NextNonce()
248
291
if nonce != "" {
249
249
-
e.Response().Header().Set("DPoP-Nonce", nonce)
250
250
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
292
292
+
w.Header().Set("DPoP-Nonce", nonce)
293
293
+
w.Header().Add("access-control-expose-headers", "DPoP-Nonce")
251
294
}
252
295
253
253
-
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
296
296
+
proof, err := s.oauthProvider.DpopManager.CheckProof(r.Method, "https://"+s.config.Hostname+r.URL.String(), r.Header, to.StringPtr(accessToken))
254
297
if err != nil {
255
298
if errors.Is(err, dpop.ErrUseDpopNonce) {
256
256
-
e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`)
257
257
-
e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
258
258
-
return e.JSON(401, map[string]string{
299
299
+
w.Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`)
300
300
+
w.Header().Add("access-control-expose-headers", "WWW-Authenticate")
301
301
+
s.writeJSON(w, 401, map[string]string{
259
302
"error": "use_dpop_nonce",
260
303
})
304
304
+
return
261
305
}
262
306
logger.Error("invalid dpop proof", "error", err)
263
263
-
return helpers.InputError(e, nil)
307
307
+
helpers.InputError(w, nil)
308
308
+
return
264
309
}
265
310
266
311
var oauthToken provider.OauthToken
267
312
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
268
313
logger.Error("error finding access token in db", "error", err)
269
269
-
return helpers.InputError(e, nil)
314
314
+
helpers.InputError(w, nil)
315
315
+
return
270
316
}
271
317
272
318
if oauthToken.Token == "" {
273
273
-
return helpers.InvalidTokenError(e)
319
319
+
helpers.InvalidTokenError(w)
320
320
+
return
274
321
}
275
322
276
323
if *oauthToken.Parameters.DpopJkt != proof.JKT {
277
324
logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
278
278
-
return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
325
325
+
helpers.InputError(w, to.StringPtr("dpop jkt mismatch"))
326
326
+
return
279
327
}
280
328
281
329
if time.Now().After(oauthToken.ExpiresAt) {
282
282
-
e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`)
283
283
-
e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
284
284
-
return e.JSON(401, map[string]string{
330
330
+
w.Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`)
331
331
+
w.Header().Add("access-control-expose-headers", "WWW-Authenticate")
332
332
+
s.writeJSON(w, 401, map[string]string{
285
333
"error": "invalid_token",
286
334
"error_description": "Token expired",
287
335
})
336
336
+
return
288
337
}
289
338
290
339
repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub)
291
340
if err != nil {
292
341
logger.Error("could not find actor in db", "error", err)
293
293
-
return helpers.ServerError(e, nil)
342
342
+
helpers.ServerError(w, nil)
343
343
+
return
294
344
}
295
345
296
296
-
e.Set("repo", repo)
297
297
-
e.Set("did", repo.Repo.Did)
298
298
-
e.Set("token", accessToken)
299
299
-
e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
346
346
+
r = setContextValue(r, contextKeyRepo, repo)
347
347
+
r = setContextValue(r, contextKeyDid, repo.Repo.Did)
348
348
+
r = setContextValue(r, contextKeyToken, accessToken)
349
349
+
r = setContextValue(r, contextKeyScopes, strings.Split(oauthToken.Parameters.Scope, " "))
300
350
301
301
-
return next(e)
302
302
-
}
351
351
+
next.ServeHTTP(w, r)
352
352
+
})
303
353
}
+159
-118
server/server.go
···
4
4
"context"
5
5
"crypto/ecdsa"
6
6
"embed"
7
7
+
"encoding/json"
7
8
"errors"
8
9
"fmt"
10
10
+
"html/template"
9
11
"io"
10
12
"log/slog"
11
13
"net/http"
···
13
15
"os"
14
16
"path/filepath"
15
17
"sync"
16
16
-
"text/template"
17
18
"time"
18
19
19
20
"github.com/bluesky-social/indigo/api/atproto"
···
23
24
"github.com/bluesky-social/indigo/xrpc"
24
25
"github.com/domodwyer/mailyak/v3"
25
26
"github.com/glebarez/sqlite"
27
27
+
"github.com/go-chi/chi/v5"
28
28
+
chimiddleware "github.com/go-chi/chi/v5/middleware"
26
29
"github.com/go-playground/validator"
27
30
"github.com/gorilla/sessions"
28
31
"github.com/haileyok/cocoon/identity"
···
35
38
"github.com/haileyok/cocoon/oauth/provider"
36
39
"github.com/haileyok/cocoon/plc"
37
40
"github.com/ipfs/go-cid"
38
38
-
"github.com/labstack/echo-contrib/echoprometheus"
39
39
-
echo_session "github.com/labstack/echo-contrib/session"
40
40
-
"github.com/labstack/echo/v4"
41
41
-
"github.com/labstack/echo/v4/middleware"
42
42
-
slogecho "github.com/samber/slog-echo"
41
41
+
"github.com/prometheus/client_golang/prometheus/promhttp"
42
42
+
43
43
"gorm.io/gorm"
44
44
)
45
45
···
76
76
}
77
77
78
78
type Server struct {
79
79
-
http *http.Client
80
80
-
httpd *http.Server
81
81
-
mail *mailyak.MailYak
82
82
-
mailLk *sync.Mutex
83
83
-
echo *echo.Echo
84
84
-
db *db.DB
85
85
-
plcClient *plc.Client
86
86
-
logger *slog.Logger
87
87
-
config *config
88
88
-
privateKey *ecdsa.PrivateKey
89
89
-
repoman *RepoMan
90
90
-
oauthProvider *provider.Provider
91
91
-
evtman *events.EventManager
92
92
-
passport *identity.Passport
93
93
-
fallbackProxy string
79
79
+
http *http.Client
80
80
+
httpd *http.Server
81
81
+
mail *mailyak.MailYak
82
82
+
mailLk *sync.Mutex
83
83
+
router *chi.Mux
84
84
+
db *db.DB
85
85
+
plcClient *plc.Client
86
86
+
logger *slog.Logger
87
87
+
config *config
88
88
+
privateKey *ecdsa.PrivateKey
89
89
+
repoman *RepoMan
90
90
+
oauthProvider *provider.Provider
91
91
+
evtman *events.EventManager
92
92
+
passport *identity.Passport
93
93
+
fallbackProxy string
94
94
+
sessions *sessions.CookieStore
95
95
+
validator *validator.Validate
96
96
+
templateRenderer *TemplateRenderer
94
97
95
98
lastRequestCrawl time.Time
96
99
requestCrawlMu sync.Mutex
···
190
193
absPath, _ := filepath.Abs("server/templates/*.html")
191
194
if s.config.Version == "dev" {
192
195
tmpl := template.Must(template.ParseGlob(absPath))
193
193
-
s.echo.Renderer = &TemplateRenderer{
196
196
+
s.templateRenderer = &TemplateRenderer{
194
197
templates: tmpl,
195
198
isDev: true,
196
199
templatePath: absPath,
197
200
}
198
201
} else {
199
202
tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
200
200
-
s.echo.Renderer = &TemplateRenderer{
203
203
+
s.templateRenderer = &TemplateRenderer{
201
204
templates: tmpl,
202
205
isDev: false,
203
206
}
204
207
}
205
208
}
206
209
207
207
-
func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
210
210
+
func (t *TemplateRenderer) Render(w io.Writer, name string, data any) error {
208
211
if t.isDev {
209
212
tmpl, err := template.ParseGlob(t.templatePath)
210
213
if err != nil {
···
213
216
t.templates = tmpl
214
217
}
215
218
216
216
-
if viewContext, isMap := data.(map[string]any); isMap {
217
217
-
viewContext["reverse"] = c.Echo().Reverse
218
218
-
}
219
219
+
return t.templates.ExecuteTemplate(w, name, data)
220
220
+
}
221
221
+
222
222
+
// renderTemplate is a convenience method on the server that renders a named
223
223
+
// HTML template to the given ResponseWriter with a 200 status.
224
224
+
func (s *Server) renderTemplate(w http.ResponseWriter, name string, data any) error {
225
225
+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
226
226
+
return s.templateRenderer.Render(w, name, data)
227
227
+
}
219
228
220
220
-
return t.templates.ExecuteTemplate(w, name, data)
229
229
+
// writeJSON writes a JSON-encoded value with the given status code.
230
230
+
func (s *Server) writeJSON(w http.ResponseWriter, status int, v any) {
231
231
+
w.Header().Set("Content-Type", "application/json")
232
232
+
w.WriteHeader(status)
233
233
+
if err := json.NewEncoder(w).Encode(v); err != nil {
234
234
+
s.logger.Error("failed to encode JSON response", "error", err)
235
235
+
}
221
236
}
222
237
223
238
func New(args *Args) (*Server, error) {
···
259
274
panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
260
275
}
261
276
262
262
-
e := echo.New()
277
277
+
r := chi.NewRouter()
263
278
264
264
-
e.Pre(middleware.RemoveTrailingSlash())
265
265
-
e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
266
266
-
e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
267
267
-
e.Use(echoprometheus.NewMiddleware("cocoon"))
268
268
-
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
269
269
-
AllowOrigins: []string{"*"},
270
270
-
AllowHeaders: []string{"*"},
271
271
-
AllowMethods: []string{"*"},
272
272
-
AllowCredentials: true,
273
273
-
MaxAge: 100_000_000,
274
274
-
}))
279
279
+
r.Use(chimiddleware.StripSlashes)
280
280
+
r.Use(func(next http.Handler) http.Handler {
281
281
+
logger := args.Logger.With("component", "http")
282
282
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
283
283
+
next.ServeHTTP(w, r)
284
284
+
logger.Info("request", "method", r.Method, "path", r.URL.Path, "remote_addr", r.RemoteAddr)
285
285
+
})
286
286
+
})
287
287
+
r.Use(corsMiddleware)
275
288
276
289
vdtor := validator.New()
277
290
vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
···
299
312
return true
300
313
})
301
314
302
302
-
e.Validator = &CustomValidator{validator: vdtor}
303
303
-
304
315
httpd := &http.Server{
305
316
Addr: args.Addr,
306
306
-
Handler: e,
317
317
+
Handler: r,
307
318
// shitty defaults but okay for now, needed for import repo
308
319
ReadTimeout: 5 * time.Minute,
309
320
WriteTimeout: 5 * time.Minute,
···
370
381
return nil, fmt.Errorf("failed to create event persister: %w", err)
371
382
}
372
383
384
384
+
cookieStore := sessions.NewCookieStore([]byte(args.SessionSecret))
385
385
+
373
386
s := &Server{
374
387
http: h,
375
388
httpd: httpd,
376
376
-
echo: e,
389
389
+
router: r,
377
390
logger: args.Logger,
378
391
db: dbw,
379
392
plcClient: plcClient,
380
393
privateKey: &pkey,
394
394
+
sessions: cookieStore,
395
395
+
validator: vdtor,
381
396
config: &config{
382
397
Version: args.Version,
383
398
Did: args.Did,
···
438
453
return s, nil
439
454
}
440
455
456
456
+
// corsMiddleware adds permissive CORS headers to every response.
457
457
+
func corsMiddleware(next http.Handler) http.Handler {
458
458
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
459
459
+
w.Header().Set("Access-Control-Allow-Origin", "*")
460
460
+
w.Header().Set("Access-Control-Allow-Headers", "*")
461
461
+
w.Header().Set("Access-Control-Allow-Methods", "*")
462
462
+
w.Header().Set("Access-Control-Allow-Credentials", "true")
463
463
+
w.Header().Set("Access-Control-Max-Age", "100000000")
464
464
+
if r.Method == http.MethodOptions {
465
465
+
w.WriteHeader(http.StatusNoContent)
466
466
+
return
467
467
+
}
468
468
+
next.ServeHTTP(w, r)
469
469
+
})
470
470
+
}
471
471
+
441
472
func (s *Server) addRoutes() {
473
473
+
r := s.router
474
474
+
442
475
// static
443
476
if s.config.Version == "dev" {
444
444
-
s.echo.Static("/static", "server/static")
477
477
+
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir("server/static"))))
445
478
} else {
446
446
-
s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
479
479
+
r.Handle("/static/*", http.FileServer(http.FS(staticFS)))
447
480
}
448
481
482
482
+
// metrics
483
483
+
r.Handle("/metrics", promhttp.Handler())
484
484
+
449
485
// random stuff
450
450
-
s.echo.GET("/", s.handleRoot)
451
451
-
s.echo.GET("/xrpc/_health", s.handleHealth)
452
452
-
s.echo.GET("/.well-known/did.json", s.handleWellKnown)
453
453
-
s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
454
454
-
s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
455
455
-
s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
456
456
-
s.echo.GET("/robots.txt", s.handleRobots)
486
486
+
r.Get("/", s.handleRoot)
487
487
+
r.Get("/xrpc/_health", s.handleHealth)
488
488
+
r.Get("/.well-known/did.json", s.handleWellKnown)
489
489
+
r.Get("/.well-known/atproto-did", s.handleAtprotoDid)
490
490
+
r.Get("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
491
491
+
r.Get("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
492
492
+
r.Get("/robots.txt", s.handleRobots)
457
493
458
494
// public
459
459
-
s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
460
460
-
s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
461
461
-
s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
462
462
-
s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
463
463
-
s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
495
495
+
r.Get("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
496
496
+
r.Post("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
497
497
+
r.Post("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
498
498
+
r.Get("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
499
499
+
r.Post("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
464
500
465
465
-
s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
466
466
-
s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
467
467
-
s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
468
468
-
s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
469
469
-
s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
470
470
-
s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
471
471
-
s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
472
472
-
s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
473
473
-
s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
474
474
-
s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
475
475
-
s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
476
476
-
s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
501
501
+
r.Get("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
502
502
+
r.Get("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
503
503
+
r.Get("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
504
504
+
r.Get("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
505
505
+
r.Get("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
506
506
+
r.Get("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
507
507
+
r.Get("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
508
508
+
r.Get("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
509
509
+
r.Get("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
510
510
+
r.Get("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
511
511
+
r.Get("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
512
512
+
r.Get("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
477
513
478
514
// labels
479
479
-
s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
515
515
+
r.Get("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
480
516
481
517
// account
482
482
-
s.echo.GET("/account", s.handleAccount)
483
483
-
s.echo.POST("/account/revoke", s.handleAccountRevoke)
484
484
-
s.echo.GET("/account/signin", s.handleAccountSigninGet)
485
485
-
s.echo.POST("/account/signin", s.handleAccountSigninPost)
486
486
-
s.echo.GET("/account/signout", s.handleAccountSignout)
518
518
+
r.Get("/account", s.handleAccount)
519
519
+
r.Post("/account/revoke", s.handleAccountRevoke)
520
520
+
r.Get("/account/signin", s.handleAccountSigninGet)
521
521
+
r.Post("/account/signin", s.handleAccountSigninPost)
522
522
+
r.Get("/account/signout", s.handleAccountSignout)
487
523
488
524
// oauth account
489
489
-
s.echo.GET("/oauth/jwks", s.handleOauthJwks)
490
490
-
s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
491
491
-
s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
525
525
+
r.Get("/oauth/jwks", s.handleOauthJwks)
526
526
+
r.Get("/oauth/authorize", s.handleOauthAuthorizeGet)
527
527
+
r.Post("/oauth/authorize", s.handleOauthAuthorizePost)
492
528
493
493
-
// oauth authorization
494
494
-
s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
495
495
-
s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
529
529
+
// oauth authorization (with BaseMiddleware)
530
530
+
r.With(s.oauthProvider.BaseMiddleware).Post("/oauth/par", s.handleOauthPar)
531
531
+
r.With(s.oauthProvider.BaseMiddleware).Post("/oauth/token", s.handleOauthToken)
496
532
497
533
// authed
498
498
-
s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
499
499
-
s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
500
500
-
s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
501
501
-
s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
502
502
-
s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
503
503
-
s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
504
504
-
s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
505
505
-
s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
506
506
-
s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
507
507
-
s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
508
508
-
s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
509
509
-
s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
510
510
-
s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
511
511
-
s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
512
512
-
s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
513
513
-
s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
514
514
-
s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
515
515
-
s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
516
516
-
s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
517
517
-
s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
534
534
+
authed := func(h http.HandlerFunc) http.Handler {
535
535
+
return s.handleLegacySessionMiddleware(s.handleOauthSessionMiddleware(h))
536
536
+
}
537
537
+
538
538
+
r.Get("/xrpc/com.atproto.server.getSession", authed(s.handleGetSession).ServeHTTP)
539
539
+
r.Post("/xrpc/com.atproto.server.refreshSession", authed(s.handleRefreshSession).ServeHTTP)
540
540
+
r.Post("/xrpc/com.atproto.server.deleteSession", authed(s.handleDeleteSession).ServeHTTP)
541
541
+
r.Get("/xrpc/com.atproto.identity.getRecommendedDidCredentials", authed(s.handleGetRecommendedDidCredentials).ServeHTTP)
542
542
+
r.Post("/xrpc/com.atproto.identity.updateHandle", authed(s.handleIdentityUpdateHandle).ServeHTTP)
543
543
+
r.Post("/xrpc/com.atproto.identity.requestPlcOperationSignature", authed(s.handleIdentityRequestPlcOperationSignature).ServeHTTP)
544
544
+
r.Post("/xrpc/com.atproto.identity.signPlcOperation", authed(s.handleSignPlcOperation).ServeHTTP)
545
545
+
r.Post("/xrpc/com.atproto.identity.submitPlcOperation", authed(s.handleSubmitPlcOperation).ServeHTTP)
546
546
+
r.Post("/xrpc/com.atproto.server.confirmEmail", authed(s.handleServerConfirmEmail).ServeHTTP)
547
547
+
r.Post("/xrpc/com.atproto.server.requestEmailConfirmation", authed(s.handleServerRequestEmailConfirmation).ServeHTTP)
548
548
+
r.Post("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
549
549
+
r.Post("/xrpc/com.atproto.server.requestEmailUpdate", authed(s.handleServerRequestEmailUpdate).ServeHTTP)
550
550
+
r.Post("/xrpc/com.atproto.server.resetPassword", authed(s.handleServerResetPassword).ServeHTTP)
551
551
+
r.Post("/xrpc/com.atproto.server.updateEmail", authed(s.handleServerUpdateEmail).ServeHTTP)
552
552
+
r.Get("/xrpc/com.atproto.server.getServiceAuth", authed(s.handleServerGetServiceAuth).ServeHTTP)
553
553
+
r.Get("/xrpc/com.atproto.server.checkAccountStatus", authed(s.handleServerCheckAccountStatus).ServeHTTP)
554
554
+
r.Post("/xrpc/com.atproto.server.deactivateAccount", authed(s.handleServerDeactivateAccount).ServeHTTP)
555
555
+
r.Post("/xrpc/com.atproto.server.activateAccount", authed(s.handleServerActivateAccount).ServeHTTP)
556
556
+
r.Post("/xrpc/com.atproto.server.requestAccountDelete", authed(s.handleServerRequestAccountDelete).ServeHTTP)
557
557
+
r.Post("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
518
558
519
559
// repo
520
520
-
s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
521
521
-
s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
522
522
-
s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
523
523
-
s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
524
524
-
s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
525
525
-
s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
526
526
-
s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
560
560
+
r.Get("/xrpc/com.atproto.repo.listMissingBlobs", authed(s.handleListMissingBlobs).ServeHTTP)
561
561
+
r.Post("/xrpc/com.atproto.repo.createRecord", authed(s.handleCreateRecord).ServeHTTP)
562
562
+
r.Post("/xrpc/com.atproto.repo.putRecord", authed(s.handlePutRecord).ServeHTTP)
563
563
+
r.Post("/xrpc/com.atproto.repo.deleteRecord", authed(s.handleDeleteRecord).ServeHTTP)
564
564
+
r.Post("/xrpc/com.atproto.repo.applyWrites", authed(s.handleApplyWrites).ServeHTTP)
565
565
+
r.Post("/xrpc/com.atproto.repo.uploadBlob", authed(s.handleRepoUploadBlob).ServeHTTP)
566
566
+
r.Post("/xrpc/com.atproto.repo.importRepo", authed(s.handleRepoImportRepo).ServeHTTP)
527
567
528
568
// stupid silly endpoints
529
529
-
s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
530
530
-
s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
531
531
-
s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
532
532
-
s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569
569
+
r.Get("/xrpc/app.bsky.actor.getPreferences", authed(s.handleActorGetPreferences).ServeHTTP)
570
570
+
r.Post("/xrpc/app.bsky.actor.putPreferences", authed(s.handleActorPutPreferences).ServeHTTP)
571
571
+
r.Get("/xrpc/app.bsky.feed.getFeed", authed(s.handleProxyBskyFeedGetFeed).ServeHTTP)
572
572
+
r.Get("/xrpc/app.bsky.ageassurance.getState", authed(s.handleAgeAssurance).ServeHTTP)
573
573
+
533
574
// admin routes
534
534
-
s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
535
535
-
s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
575
575
+
r.With(s.handleAdminMiddleware).Post("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode)
576
576
+
r.With(s.handleAdminMiddleware).Post("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes)
536
577
537
537
-
// are there any routes that we should be allowing without auth? i dont think so but idk
538
538
-
s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
539
539
-
s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
578
578
+
// catch-all proxy (authed)
579
579
+
r.Get("/xrpc/*", authed(s.handleProxy).ServeHTTP)
580
580
+
r.Post("/xrpc/*", authed(s.handleProxy).ServeHTTP)
540
581
}
541
582
542
583
func (s *Server) Serve(ctx context.Context) error {