1package server
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "os"
11 "strings"
12 "time"
13
14 "github.com/Azure/go-autorest/autorest/to"
15 "github.com/bluesky-social/indigo/api/atproto"
16 "github.com/bluesky-social/indigo/atproto/syntax"
17 "github.com/bluesky-social/indigo/events"
18 "github.com/bluesky-social/indigo/util"
19 "github.com/bluesky-social/indigo/xrpc"
20 "github.com/go-playground/validator"
21 "github.com/golang-jwt/jwt/v4"
22 "github.com/haileyok/cocoon/identity"
23 "github.com/haileyok/cocoon/internal/helpers"
24 "github.com/haileyok/cocoon/models"
25 "github.com/haileyok/cocoon/plc"
26 "github.com/labstack/echo/v4"
27 "github.com/labstack/echo/v4/middleware"
28 "github.com/lestrrat-go/jwx/v2/jwk"
29 slogecho "github.com/samber/slog-echo"
30 "gorm.io/driver/sqlite"
31 "gorm.io/gorm"
32)
33
// Server bundles the state shared by the HTTP handlers: outbound HTTP client,
// HTTP listener, echo router, database, PLC client, logger, configuration,
// signing key, repo manager, event manager and identity passport.
type Server struct {
	http       *http.Client         // outbound client (util.RobustHTTPClient in New)
	httpd      *http.Server         // serves echo; started by Serve
	echo       *echo.Echo           // router; routes are registered by addRoutes
	db         *gorm.DB             // SQLite database opened in New
	plcClient  *plc.Client          // talks to https://plc.directory
	logger     *slog.Logger         // structured logger
	config     *config              // values copied from Args in New
	privateKey *ecdsa.PrivateKey    // public half verifies session JWTs in handleSessionMiddleware
	repoman    *RepoMan             // constructed from the Server itself in New
	evtman     *events.EventManager // in-memory persisted event manager
	passport   *identity.Passport   // presumably identity resolution, backed by a 10k-entry mem cache — confirm
}
47
// Args holds the parameters accepted by New.
type Args struct {
	Addr            string       // listen address for the HTTP server; required
	DbName          string       // database name; required
	Logger          *slog.Logger // defaults to a text logger on stdout when nil
	Version         string       // server version string, stored in config
	Did             string       // this server's own DID; required and must parse as a DID
	Hostname        string       // public hostname; required
	RotationKeyPath string       // file whose bytes are handed to the PLC client as the rotation key
	JwkPath         string       // file holding the ECDSA private key in JWK form
	ContactEmail    string       // operator contact address; required
	Relays          []string     // hosts sent com.atproto.sync.requestCrawl when Serve starts
}
60
// config is the server's runtime configuration, populated from Args by New.
type config struct {
	Version        string
	Did            string
	Hostname       string
	ContactEmail   string
	EnforcePeering bool // New sets this to false
	Relays         []string
}
69
// CustomValidator adapts a go-playground validator to echo's Validator hook
// (New installs it as e.Validator).
type CustomValidator struct {
	validator *validator.Validate
}
73
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails on at least one field. It wraps the validator's error and
// records the name and failing tag of the first offending field.
type ValidationError struct {
	error
	Field string // name of the first field that failed validation
	Tag   string // validation tag that field failed (e.g. "required")
}
79
// Validate checks i against its struct tags. Field-level failures are reported
// as a ValidationError describing the first failing field (and wrapping the
// full validator error); any other error from the validator is returned as-is.
func (cv *CustomValidator) Validate(i any) error {
	err := cv.validator.Struct(i)
	if err == nil {
		return nil
	}

	var fieldErrs validator.ValidationErrors
	if !errors.As(err, &fieldErrs) || len(fieldErrs) == 0 {
		return err
	}

	head := fieldErrs[0]
	return ValidationError{
		error: err,
		Field: head.Field(),
		Tag:   head.Tag(),
	}
}
97
// handleSessionMiddleware authenticates a request using the session JWT carried
// in its Authorization header ("<scheme> <token>"; the scheme is not checked).
//
// The token must:
//   - be ECDSA-signed by the server's private key,
//   - carry scope "com.atproto.refresh" on refreshSession and
//     "com.atproto.access" on every other route,
//   - be present in the "refresh_tokens" / "tokens" table respectively,
//   - not be expired.
//
// A missing or malformed header yields 401; a bad token yields an InvalidToken
// input error and an expired one ExpiredToken. On success the echo context
// gets "did" (the sub claim, a string), "repo" (from getRepoActorByDid) and
// "token" (the raw JWT) before next runs; an error from next is handed to
// e.Error and the middleware itself returns nil.
func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(e echo.Context) error {
		authheader := e.Request().Header.Get("authorization")
		if authheader == "" {
			return e.JSON(http.StatusUnauthorized, map[string]string{"error": "Unauthorized"})
		}

		// Expect exactly two space-separated parts. A malformed header is a
		// client error, so answer like a missing header instead of with a 500.
		_, tokenstr, found := strings.Cut(authheader, " ")
		if !found || tokenstr == "" || strings.Contains(tokenstr, " ") {
			return e.JSON(http.StatusUnauthorized, map[string]string{"error": "Unauthorized"})
		}

		token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
			if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
				return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
			}

			return s.privateKey.Public(), nil
		})
		if err != nil {
			// The v4 parser validates exp itself, so an expired token surfaces
			// here. Report ExpiredToken only when expiry is the sole problem
			// (i.e. the signature verified), so clients know to refresh.
			var verr *jwt.ValidationError
			if errors.As(err, &verr) && verr.Errors == jwt.ValidationErrorExpired {
				return helpers.InputError(e, to.StringPtr("ExpiredToken"))
			}

			s.logger.Error("error parsing jwt", "error", err)
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok || !token.Valid {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		// Checked assertions: a token lacking these claims (or carrying the
		// wrong type) must be rejected, not panic the handler.
		scope, ok := claims["scope"].(string)
		if !ok {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		did, ok := claims["sub"].(string)
		if !ok || did == "" {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"

		wantScope, table := "com.atproto.access", "tokens"
		if isRefresh {
			wantScope, table = "com.atproto.refresh", "refresh_tokens"
		}

		if scope != wantScope {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		type Result struct {
			Found bool
		}
		var result Result
		// table is one of the two constants above, never request data, so the
		// concatenation cannot inject SQL; the token itself is a bound parameter.
		if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return helpers.InputError(e, to.StringPtr("InvalidToken"))
			}

			s.logger.Error("error getting token from db", "error", err)
			return helpers.ServerError(e, nil)
		}

		if !result.Found {
			return helpers.InputError(e, to.StringPtr("InvalidToken"))
		}

		// Tokens are issued by this server, so a missing exp is a server-side
		// problem. The explicit comparison also guards against a parser
		// configured to skip claim validation.
		exp, ok := claims["exp"].(float64)
		if !ok {
			s.logger.Error("error getting exp from token")
			return helpers.ServerError(e, nil)
		}

		if exp < float64(time.Now().UTC().Unix()) {
			return helpers.InputError(e, to.StringPtr("ExpiredToken"))
		}

		e.Set("did", did)

		repo, err := s.getRepoActorByDid(did)
		if err != nil {
			s.logger.Error("error fetching repo", "error", err)
			return helpers.ServerError(e, nil)
		}
		e.Set("repo", repo)

		e.Set("token", tokenstr)

		if err := next(e); err != nil {
			e.Error(err)
		}

		return nil
	}
}
188
// New validates args and builds a Server. It configures the echo router
// (trailing-slash removal, request logging, CORS, atproto validators), the PLC
// client, the JWT signing key and the SQLite database. It neither registers
// routes nor listens; Serve does that.
//
// Addr, DbName, Did (which must parse as a DID), ContactEmail, Hostname,
// RotationKeyPath and JwkPath are required. JwkPath must contain an ECDSA
// private key in JWK form. When args.Logger is nil, New installs a text logger
// writing to stdout on args itself.
func New(args *Args) (*Server, error) {
	if args == nil {
		return nil, errors.New("args must be set")
	}

	if args.Addr == "" {
		return nil, fmt.Errorf("addr must be set")
	}

	if args.DbName == "" {
		return nil, fmt.Errorf("db name must be set")
	}

	if args.Did == "" {
		return nil, fmt.Errorf("cocoon did must be set")
	}

	if args.ContactEmail == "" {
		return nil, fmt.Errorf("cocoon contact email is required")
	}

	if _, err := syntax.ParseDID(args.Did); err != nil {
		return nil, fmt.Errorf("error parsing cocoon did: %w", err)
	}

	if args.Hostname == "" {
		return nil, fmt.Errorf("cocoon hostname must be set")
	}

	if args.RotationKeyPath == "" {
		return nil, fmt.Errorf("rotation key path must be set")
	}

	if args.JwkPath == "" {
		return nil, fmt.Errorf("jwk path must be set")
	}

	if args.Logger == nil {
		args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
	}

	e := echo.New()

	e.Pre(middleware.RemoveTrailingSlash())
	e.Pre(slogecho.New(args.Logger))
	// NOTE(review): wildcard origins combined with AllowCredentials — confirm how
	// the echo version in use resolves this and that it is the intended policy.
	e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
		AllowOrigins:     []string{"*"},
		AllowHeaders:     []string{"*"},
		AllowMethods:     []string{"*"},
		AllowCredentials: true,
		MaxAge:           100_000_000,
	}))

	vdtor, err := newAtprotoValidator()
	if err != nil {
		return nil, err
	}
	e.Validator = &CustomValidator{validator: vdtor}

	httpd := &http.Server{
		Addr:    args.Addr,
		Handler: e,
		// Bound how long a client may take to send request headers (slowloris).
		// No Read/WriteTimeout is set: some routes (e.g. subscribeRepos) are
		// presumably long-lived streams.
		ReadHeaderTimeout: 30 * time.Second,
	}

	rkbytes, err := os.ReadFile(args.RotationKeyPath)
	if err != nil {
		return nil, fmt.Errorf("reading rotation key %q: %w", args.RotationKeyPath, err)
	}

	h := util.RobustHTTPClient()

	plcClient, err := plc.NewClient(&plc.ClientArgs{
		H:           h,
		Service:     "https://plc.directory",
		PdsHostname: args.Hostname,
		RotationKey: rkbytes,
	})
	if err != nil {
		return nil, fmt.Errorf("creating plc client: %w", err)
	}

	pkey, err := loadJwkECDSAPrivateKey(args.JwkPath)
	if err != nil {
		return nil, err
	}

	// Open the database last so that no earlier failure leaves it open.
	db, err := gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", args.DbName, err)
	}

	s := &Server{
		http:       h,
		httpd:      httpd,
		echo:       e,
		logger:     args.Logger,
		db:         db,
		plcClient:  plcClient,
		privateKey: pkey,
		config: &config{
			Version:        args.Version,
			Did:            args.Did,
			Hostname:       args.Hostname,
			ContactEmail:   args.ContactEmail,
			EnforcePeering: false,
			Relays:         args.Relays,
		},
		evtman:   events.NewEventManager(events.NewMemPersister()),
		passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
	}

	s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it

	return s, nil
}

// newAtprotoValidator returns a struct validator with the atproto-handle,
// atproto-did, atproto-rkey and atproto-nsid tags registered. Each tag accepts
// a field whose string form parses with the matching indigo syntax parser.
func newAtprotoValidator() (*validator.Validate, error) {
	vdtor := validator.New()

	parsers := map[string]func(string) error{
		"atproto-handle": func(v string) error { _, err := syntax.ParseHandle(v); return err },
		"atproto-did":    func(v string) error { _, err := syntax.ParseDID(v); return err },
		"atproto-rkey":   func(v string) error { _, err := syntax.ParseRecordKey(v); return err },
		"atproto-nsid":   func(v string) error { _, err := syntax.ParseNSID(v); return err },
	}

	for tag, parse := range parsers {
		parse := parse // per-iteration copy for Go versions before 1.22
		if err := vdtor.RegisterValidation(tag, func(fl validator.FieldLevel) bool {
			return parse(fl.Field().String()) == nil
		}); err != nil {
			return nil, fmt.Errorf("registering %q validation: %w", tag, err)
		}
	}

	return vdtor, nil
}

// loadJwkECDSAPrivateKey reads a JWK from path and extracts its ECDSA private
// key. It fails if the file is unreadable, is not a JWK, or does not hold an
// ECDSA private key.
func loadJwkECDSAPrivateKey(path string) (*ecdsa.PrivateKey, error) {
	jwkbytes, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading jwk %q: %w", path, err)
	}

	key, err := jwk.ParseKey(jwkbytes)
	if err != nil {
		return nil, fmt.Errorf("parsing jwk %q: %w", path, err)
	}

	var pkey ecdsa.PrivateKey
	if err := key.Raw(&pkey); err != nil {
		return nil, fmt.Errorf("extracting ecdsa private key from jwk %q: %w", path, err)
	}

	return &pkey, nil
}
324
325func (s *Server) addRoutes() {
326 s.echo.GET("/", s.handleRoot)
327 s.echo.GET("/xrpc/_health", s.handleHealth)
328 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
329 s.echo.GET("/robots.txt", s.handleRobots)
330
331 // public
332 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
333 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
334 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
335 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
336 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
337
338 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
339 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
340 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
341 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
342 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
343 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
344 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
345 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
346 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
347 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
348 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
349 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
350
351 // authed
352 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
353 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
354 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
355 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
356
357 // repo
358 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
359 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
360 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
361 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
362
363 // stupid silly endpoints
364 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
365 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
366
367 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
368 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
369}
370
// Serve runs the following steps in order:
//  1. registers routes and auto-migrates the database models,
//  2. starts the HTTP server in the background,
//  3. sends com.atproto.sync.requestCrawl to each configured relay,
//  4. blocks until ctx is cancelled or the server fails.
//
// A failed migration or listener is returned as an error. Crawl-request
// failures are logged and do not stop the server. On cancellation the HTTP
// server is shut down gracefully, waiting at most 10 seconds for in-flight
// requests.
func (s *Server) Serve(ctx context.Context) error {
	s.addRoutes()

	s.logger.Info("migrating...")

	if err := s.db.AutoMigrate(
		&models.Actor{},
		&models.Repo{},
		&models.InviteCode{},
		&models.Token{},
		&models.RefreshToken{},
		&models.Block{},
		&models.Record{},
		&models.Blob{},
		&models.BlobPart{},
	); err != nil {
		return fmt.Errorf("migrating database: %w", err)
	}

	s.logger.Info("starting cocoon")

	// Buffered so the listener goroutine can always deliver its result and exit,
	// even after Serve has stopped receiving.
	serveErr := make(chan error, 1)
	go func() {
		err := s.httpd.ListenAndServe()
		if errors.Is(err, http.ErrServerClosed) {
			// Expected after Shutdown; not a failure.
			err = nil
		}
		serveErr <- err
	}()

	for _, relay := range s.config.Relays {
		cli := xrpc.Client{Host: relay}
		if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
			Hostname: s.config.Hostname,
		}); err != nil {
			s.logger.Error("error requesting crawl", "relay", relay, "error", err)
		}
	}

	select {
	case <-ctx.Done():
	case err := <-serveErr:
		if err != nil {
			return fmt.Errorf("http server: %w", err)
		}
		// The server was closed elsewhere; nothing left to shut down.
		return nil
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.httpd.Shutdown(shutdownCtx); err != nil {
		s.logger.Error("error shutting down http server", "error", err)
	}

	s.logger.Info("shut down")

	return nil
}