1package server
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "errors"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "net/smtp"
11 "os"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/Azure/go-autorest/autorest/to"
17 "github.com/bluesky-social/indigo/api/atproto"
18 "github.com/bluesky-social/indigo/atproto/syntax"
19 "github.com/bluesky-social/indigo/events"
20 "github.com/bluesky-social/indigo/util"
21 "github.com/bluesky-social/indigo/xrpc"
22 "github.com/domodwyer/mailyak/v3"
23 "github.com/go-playground/validator"
24 "github.com/golang-jwt/jwt/v4"
25 "github.com/haileyok/cocoon/identity"
26 "github.com/haileyok/cocoon/internal/helpers"
27 "github.com/haileyok/cocoon/models"
28 "github.com/haileyok/cocoon/plc"
29 "github.com/labstack/echo/v4"
30 "github.com/labstack/echo/v4/middleware"
31 "github.com/lestrrat-go/jwx/v2/jwk"
32 slogecho "github.com/samber/slog-echo"
33 "gorm.io/driver/sqlite"
34 "gorm.io/gorm"
35)
36
// Server holds the state of a running cocoon instance: the HTTP listener and
// echo router, the database handle, signing key, outbound clients and the
// optional mailer. It is built by New.
type Server struct {
	// http is the outbound client (util.RobustHTTPClient in New); the same
	// client is handed to the PLC client and the identity passport.
	http *http.Client
	// httpd is the listener whose Handler is the echo router.
	httpd *http.Server
	// mail is nil unless New received every SMTP argument.
	mail *mailyak.MailYak
	// mailLk is created together with mail (nil otherwise); presumably it
	// serializes use of the shared MailYak instance — confirm at call sites.
	mailLk *sync.Mutex
	echo   *echo.Echo
	// db is the gorm handle to the SQLite database.
	db        *gorm.DB
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is parsed from the JWK file; its public half is used to
	// verify session JWTs in handleSessionMiddleware.
	privateKey *ecdsa.PrivateKey
	repoman    *RepoMan
	// evtman is constructed with events.NewMemPersister().
	evtman *events.EventManager
	// passport is constructed with identity.NewMemCache(10_000).
	passport *identity.Passport
}
52
// Args are the options accepted by New.
type Args struct {
	// Addr is the listen address for the HTTP server. Required.
	Addr string
	// DbName is the SQLite database name. New rejects an empty value.
	DbName string
	// Logger is optional; New substitutes a text logger on stdout when nil.
	Logger *slog.Logger
	// Version is copied into the server config.
	Version string
	// Did is this server's own DID. Required and must parse as a DID.
	Did string
	// Hostname is this server's public hostname. Required; it is given to the
	// PLC client and sent to relays in requestCrawl.
	Hostname string
	// RotationKeyPath is the file whose raw bytes are used as the PLC
	// rotation key.
	RotationKeyPath string
	// JwkPath is a JWK file holding the ECDSA private key used for session
	// tokens.
	JwkPath string
	// ContactEmail is required.
	ContactEmail string
	// Relays are hosts asked to crawl this server at startup.
	Relays []string

	// SMTP settings. Mail is enabled only when all six are non-empty:
	// SmtpHost:SmtpPort is the server address, SmtpUser/SmtpPass are used for
	// PLAIN auth, and SmtpEmail/SmtpName become the From address and name.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string
}
72
// config holds server settings that New copies from Args.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering bool
	// Relays are the hosts Serve asks to crawl this server on startup.
	Relays []string
	// SmtpEmail and SmtpName are used as the mailer's From address and name.
	SmtpEmail string
	SmtpName  string
}
83
// CustomValidator adapts a go-playground validator to echo; New installs it as
// the router's Validator with the atproto-* tags registered.
type CustomValidator struct {
	validator *validator.Validate
}
87
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails. It embeds the underlying error and records the field name
// and tag of the first failing field.
type ValidationError struct {
	error
	Field string
	Tag   string
}
93
94func (cv *CustomValidator) Validate(i any) error {
95 if err := cv.validator.Struct(i); err != nil {
96 var validateErrors validator.ValidationErrors
97 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
98 first := validateErrors[0]
99 return ValidationError{
100 error: err,
101 Field: first.Field(),
102 Tag: first.Tag(),
103 }
104 }
105
106 return err
107 }
108
109 return nil
110}
111
112func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
113 return func(e echo.Context) error {
114 authheader := e.Request().Header.Get("authorization")
115 if authheader == "" {
116 return e.JSON(401, map[string]string{"error": "Unauthorized"})
117 }
118
119 pts := strings.Split(authheader, " ")
120 if len(pts) != 2 {
121 return helpers.ServerError(e, nil)
122 }
123
124 tokenstr := pts[1]
125
126 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
127 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
128 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
129 }
130
131 return s.privateKey.Public(), nil
132 })
133 if err != nil {
134 s.logger.Error("error parsing jwt", "error", err)
135 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
136 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
137 }
138
139 claims, ok := token.Claims.(jwt.MapClaims)
140 if !ok || !token.Valid {
141 return helpers.InputError(e, to.StringPtr("InvalidToken"))
142 }
143
144 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
145 scope := claims["scope"].(string)
146
147 if isRefresh && scope != "com.atproto.refresh" {
148 return helpers.InputError(e, to.StringPtr("InvalidToken"))
149 } else if !isRefresh && scope != "com.atproto.access" {
150 return helpers.InputError(e, to.StringPtr("InvalidToken"))
151 }
152
153 table := "tokens"
154 if isRefresh {
155 table = "refresh_tokens"
156 }
157
158 type Result struct {
159 Found bool
160 }
161 var result Result
162 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil {
163 if err == gorm.ErrRecordNotFound {
164 return helpers.InputError(e, to.StringPtr("InvalidToken"))
165 }
166
167 s.logger.Error("error getting token from db", "error", err)
168 return helpers.ServerError(e, nil)
169 }
170
171 if !result.Found {
172 return helpers.InputError(e, to.StringPtr("InvalidToken"))
173 }
174
175 exp, ok := claims["exp"].(float64)
176 if !ok {
177 s.logger.Error("error getting iat from token")
178 return helpers.ServerError(e, nil)
179 }
180
181 if exp < float64(time.Now().UTC().Unix()) {
182 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
183 }
184
185 repo, err := s.getRepoActorByDid(claims["sub"].(string))
186 if err != nil {
187 s.logger.Error("error fetching repo", "error", err)
188 return helpers.ServerError(e, nil)
189 }
190
191 e.Set("repo", repo)
192 e.Set("did", claims["sub"])
193 e.Set("token", tokenstr)
194
195 if err := next(e); err != nil {
196 e.Error(err)
197 }
198
199 return nil
200 }
201}
202
203func New(args *Args) (*Server, error) {
204 if args.Addr == "" {
205 return nil, fmt.Errorf("addr must be set")
206 }
207
208 if args.DbName == "" {
209 return nil, fmt.Errorf("db name must be set")
210 }
211
212 if args.Did == "" {
213 return nil, fmt.Errorf("cocoon did must be set")
214 }
215
216 if args.ContactEmail == "" {
217 return nil, fmt.Errorf("cocoon contact email is required")
218 }
219
220 if _, err := syntax.ParseDID(args.Did); err != nil {
221 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
222 }
223
224 if args.Hostname == "" {
225 return nil, fmt.Errorf("cocoon hostname must be set")
226 }
227
228 if args.Logger == nil {
229 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
230 }
231
232 e := echo.New()
233
234 e.Pre(middleware.RemoveTrailingSlash())
235 e.Pre(slogecho.New(args.Logger))
236 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
237 AllowOrigins: []string{"*"},
238 AllowHeaders: []string{"*"},
239 AllowMethods: []string{"*"},
240 AllowCredentials: true,
241 MaxAge: 100_000_000,
242 }))
243
244 vdtor := validator.New()
245 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
246 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
247 return false
248 }
249 return true
250 })
251 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
252 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
253 return false
254 }
255 return true
256 })
257 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
258 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
259 return false
260 }
261 return true
262 })
263 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
264 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
265 return false
266 }
267 return true
268 })
269
270 e.Validator = &CustomValidator{validator: vdtor}
271
272 httpd := &http.Server{
273 Addr: args.Addr,
274 Handler: e,
275 }
276
277 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
278 if err != nil {
279 return nil, err
280 }
281
282 rkbytes, err := os.ReadFile(args.RotationKeyPath)
283 if err != nil {
284 return nil, err
285 }
286
287 h := util.RobustHTTPClient()
288
289 plcClient, err := plc.NewClient(&plc.ClientArgs{
290 H: h,
291 Service: "https://plc.directory",
292 PdsHostname: args.Hostname,
293 RotationKey: rkbytes,
294 })
295 if err != nil {
296 return nil, err
297 }
298
299 jwkbytes, err := os.ReadFile(args.JwkPath)
300 if err != nil {
301 return nil, err
302 }
303
304 key, err := jwk.ParseKey(jwkbytes)
305 if err != nil {
306 return nil, err
307 }
308
309 var pkey ecdsa.PrivateKey
310 if err := key.Raw(&pkey); err != nil {
311 return nil, err
312 }
313
314 s := &Server{
315 http: h,
316 httpd: httpd,
317 echo: e,
318 logger: args.Logger,
319 db: db,
320 plcClient: plcClient,
321 privateKey: &pkey,
322 config: &config{
323 Version: args.Version,
324 Did: args.Did,
325 Hostname: args.Hostname,
326 ContactEmail: args.ContactEmail,
327 EnforcePeering: false,
328 Relays: args.Relays,
329 SmtpName: args.SmtpName,
330 SmtpEmail: args.SmtpEmail,
331 },
332 evtman: events.NewEventManager(events.NewMemPersister()),
333 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
334 }
335
336 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
337
338 // TODO: should validate these args
339 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
340 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
341 } else {
342 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
343 mail.From(s.config.SmtpEmail)
344 mail.FromName(s.config.SmtpName)
345
346 s.mail = mail
347 s.mailLk = &sync.Mutex{}
348 }
349
350 return s, nil
351}
352
353func (s *Server) addRoutes() {
354 // random stuff
355 s.echo.GET("/", s.handleRoot)
356 s.echo.GET("/xrpc/_health", s.handleHealth)
357 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
358 s.echo.GET("/robots.txt", s.handleRobots)
359
360 // public
361 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
362 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
363 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
364 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
365 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
366
367 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
368 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
369 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
370 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
371 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
372 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
373 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
374 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
375 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
376 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
377 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
378 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
379
380 // authed
381 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware)
382 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware)
383 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware)
384 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware)
385 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware)
386 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware)
387 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
388 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware)
389 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware)
390 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware)
391
392 // repo
393 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware)
394 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware)
395 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware)
396 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware)
397 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware)
398
399 // stupid silly endpoints
400 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware)
401 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware)
402
403 // are there any routes that we should be allowing without auth? i dont think so but idk
404 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
405 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware)
406}
407
408func (s *Server) Serve(ctx context.Context) error {
409 s.addRoutes()
410
411 s.logger.Info("migrating...")
412
413 s.db.AutoMigrate(
414 &models.Actor{},
415 &models.Repo{},
416 &models.InviteCode{},
417 &models.Token{},
418 &models.RefreshToken{},
419 &models.Block{},
420 &models.Record{},
421 &models.Blob{},
422 &models.BlobPart{},
423 )
424
425 s.logger.Info("starting cocoon")
426
427 go func() {
428 if err := s.httpd.ListenAndServe(); err != nil {
429 panic(err)
430 }
431 }()
432
433 for _, relay := range s.config.Relays {
434 cli := xrpc.Client{Host: relay}
435 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
436 Hostname: s.config.Hostname,
437 })
438 }
439
440 <-ctx.Done()
441
442 fmt.Println("shut down")
443
444 return nil
445}