1package server
2
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/subtle"
	"embed"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/smtp"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"text/template"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/golang-jwt/jwt/v4"
	"github.com/gorilla/sessions"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/db"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/client_manager"
	"github.com/haileyok/cocoon/oauth/constants"
	"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/haileyok/cocoon/plc"
	echo_session "github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	slogecho "github.com/samber/slog-echo"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
51
const (
	// AccountSessionMaxAge is the maximum account session age: 30 days.
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
55
// S3Config describes where database backups are uploaded.
//
// Backups only run when BackupsEnabled is true and Region, Bucket, AccessKey
// and SecretKey are all non-empty. Endpoint is optional; when it is set the
// S3 client is pointed at that endpoint with path-style addressing.
type S3Config struct {
	BackupsEnabled           bool
	Endpoint, Region, Bucket string
	AccessKey, SecretKey     string
}
64
// Server holds everything the cocoon request handlers share: the HTTP
// listener and echo router, the database handle, signing keys, and the
// supporting clients and managers wired up by New.
type Server struct {
	// http is the shared outbound HTTP client (also handed to the PLC client
	// and the identity passport in New).
	http *http.Client
	// httpd is the listener that Serve starts.
	httpd *http.Server
	// mail and mailLk are only set when every SMTP argument was provided;
	// otherwise both are nil. mailLk presumably serializes use of mail.
	mail   *mailyak.MailYak
	mailLk *sync.Mutex
	// echo is the router; routes are registered by addRoutes.
	echo *echo.Echo
	db   *db.DB
	// plcClient talks to https://plc.directory using the rotation key.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is parsed from the JWK file and verifies legacy session JWTs.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is backed by an in-memory persister.
	evtman *events.EventManager
	// passport resolves identities through an in-memory cache.
	passport *identity.Passport

	// dbName is the database file path that doBackup copies.
	dbName string
	// s3Config holds backup settings; nil disables backups.
	s3Config *S3Config
}
84
85type Args struct {
86 Addr string
87 DbName string
88 Logger *slog.Logger
89 Version string
90 Did string
91 Hostname string
92 RotationKeyPath string
93 JwkPath string
94 ContactEmail string
95 Relays []string
96 AdminPassword string
97
98 SmtpUser string
99 SmtpPass string
100 SmtpHost string
101 SmtpPort string
102 SmtpEmail string
103 SmtpName string
104
105 S3Config *S3Config
106
107 SessionSecret string
108}
109
// config is the runtime configuration copied out of Args by New and
// consulted by the handlers. EnforcePeering is not taken from Args.
type config struct {
	Version, Did, Hostname, ContactEmail string
	EnforcePeering                       bool
	Relays                               []string
	AdminPassword, SmtpEmail, SmtpName   string
}
121
// CustomValidator wraps a go-playground validator so it can be installed as
// the echo instance's Validator (New assigns it to e.Validator).
type CustomValidator struct {
	validator *validator.Validate
}
125
// ValidationError is what CustomValidator.Validate returns when struct
// validation fails with field errors. The embedded error is the complete
// validator error; Field and Tag name the first failing field and the
// validation tag it failed.
type ValidationError struct {
	error
	Field string
	Tag   string
}
131
132func (cv *CustomValidator) Validate(i any) error {
133 if err := cv.validator.Struct(i); err != nil {
134 var validateErrors validator.ValidationErrors
135 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
136 first := validateErrors[0]
137 return ValidationError{
138 error: err,
139 Field: first.Field(),
140 Tag: first.Tag(),
141 }
142 }
143
144 return err
145 }
146
147 return nil
148}
149
// templateFS holds the templates compiled into the binary; loadTemplates
// parses templates/*.html from it when Version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the static assets compiled into the binary; addRoutes serves
// it under /static when Version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
155
// TemplateRenderer is the echo renderer installed by loadTemplates.
//
// NOTE(review): this file imports text/template, which does no HTML
// escaping. If any template interpolates user- or client-supplied values,
// html/template should be used instead to avoid XSS — confirm.
type TemplateRenderer struct {
	// templates is the parsed template set Render executes.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the on-disk glob; only set in dev mode.
	templatePath string
}
161
162func (s *Server) loadTemplates() {
163 absPath, _ := filepath.Abs("server/templates/*.html")
164 if s.config.Version == "dev" {
165 tmpl := template.Must(template.ParseGlob(absPath))
166 s.echo.Renderer = &TemplateRenderer{
167 templates: tmpl,
168 isDev: true,
169 templatePath: absPath,
170 }
171 } else {
172 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
173 s.echo.Renderer = &TemplateRenderer{
174 templates: tmpl,
175 isDev: false,
176 }
177 }
178}
179
180func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
181 if t.isDev {
182 tmpl, err := template.ParseGlob(t.templatePath)
183 if err != nil {
184 return err
185 }
186 t.templates = tmpl
187 }
188
189 if viewContext, isMap := data.(map[string]any); isMap {
190 viewContext["reverse"] = c.Echo().Reverse
191 }
192
193 return t.templates.ExecuteTemplate(w, name, data)
194}
195
196func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
197 return func(e echo.Context) error {
198 username, password, ok := e.Request().BasicAuth()
199 if !ok || username != "admin" || password != s.config.AdminPassword {
200 return helpers.InputError(e, to.StringPtr("Unauthorized"))
201 }
202
203 if err := next(e); err != nil {
204 e.Error(err)
205 }
206
207 return nil
208 }
209}
210
211func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
212 return func(e echo.Context) error {
213 authheader := e.Request().Header.Get("authorization")
214 if authheader == "" {
215 return e.JSON(401, map[string]string{"error": "Unauthorized"})
216 }
217
218 pts := strings.Split(authheader, " ")
219 if len(pts) != 2 {
220 return helpers.ServerError(e, nil)
221 }
222
223 if pts[0] == "DPoP" {
224 return next(e)
225 }
226
227 tokenstr := pts[1]
228
229 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
230 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
231 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
232 }
233
234 return s.privateKey.Public(), nil
235 })
236 if err != nil {
237 s.logger.Error("error parsing jwt", "error", err)
238 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319
239 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
240 }
241
242 claims, ok := token.Claims.(jwt.MapClaims)
243 if !ok || !token.Valid {
244 return helpers.InputError(e, to.StringPtr("InvalidToken"))
245 }
246
247 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
248 scope := claims["scope"].(string)
249
250 if isRefresh && scope != "com.atproto.refresh" {
251 return helpers.InputError(e, to.StringPtr("InvalidToken"))
252 } else if !isRefresh && scope != "com.atproto.access" {
253 return helpers.InputError(e, to.StringPtr("InvalidToken"))
254 }
255
256 table := "tokens"
257 if isRefresh {
258 table = "refresh_tokens"
259 }
260
261 type Result struct {
262 Found bool
263 }
264 var result Result
265 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
266 if err == gorm.ErrRecordNotFound {
267 return helpers.InputError(e, to.StringPtr("InvalidToken"))
268 }
269
270 s.logger.Error("error getting token from db", "error", err)
271 return helpers.ServerError(e, nil)
272 }
273
274 if !result.Found {
275 return helpers.InputError(e, to.StringPtr("InvalidToken"))
276 }
277
278 exp, ok := claims["exp"].(float64)
279 if !ok {
280 s.logger.Error("error getting iat from token")
281 return helpers.ServerError(e, nil)
282 }
283
284 if exp < float64(time.Now().UTC().Unix()) {
285 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
286 }
287
288 repo, err := s.getRepoActorByDid(claims["sub"].(string))
289 if err != nil {
290 s.logger.Error("error fetching repo", "error", err)
291 return helpers.ServerError(e, nil)
292 }
293
294 e.Set("repo", repo)
295 e.Set("did", claims["sub"])
296 e.Set("token", tokenstr)
297
298 if err := next(e); err != nil {
299 e.Error(err)
300 }
301
302 return nil
303 }
304}
305
306func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
307 return func(e echo.Context) error {
308 authheader := e.Request().Header.Get("authorization")
309 if authheader == "" {
310 return e.JSON(401, map[string]string{"error": "Unauthorized"})
311 }
312
313 pts := strings.Split(authheader, " ")
314 if len(pts) != 2 {
315 return helpers.ServerError(e, nil)
316 }
317
318 if pts[0] != "DPoP" {
319 return next(e)
320 }
321
322 accessToken := pts[1]
323
324 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
325 if err != nil {
326 s.logger.Error("invalid dpop proof", "error", err)
327 return helpers.InputError(e, to.StringPtr(err.Error()))
328 }
329
330 var oauthToken provider.OauthToken
331 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
332 s.logger.Error("error finding access token in db", "error", err)
333 return helpers.InputError(e, nil)
334 }
335
336 if oauthToken.Token == "" {
337 return helpers.InputError(e, to.StringPtr("InvalidToken"))
338 }
339
340 if *oauthToken.Parameters.DpopJkt != proof.JKT {
341 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
342 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
343 }
344
345 if time.Now().After(oauthToken.ExpiresAt) {
346 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
347 }
348
349 repo, err := s.getRepoActorByDid(oauthToken.Sub)
350 if err != nil {
351 s.logger.Error("could not find actor in db", "error", err)
352 return helpers.ServerError(e, nil)
353 }
354
355 nonce := s.oauthProvider.NextNonce()
356 if nonce != "" {
357 e.Response().Header().Set("DPoP-Nonce", nonce)
358 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
359 }
360
361 e.Set("repo", repo)
362 e.Set("did", repo.Repo.Did)
363 e.Set("token", accessToken)
364 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
365
366 return next(e)
367 }
368}
369
370func New(args *Args) (*Server, error) {
371 if args.Addr == "" {
372 return nil, fmt.Errorf("addr must be set")
373 }
374
375 if args.DbName == "" {
376 return nil, fmt.Errorf("db name must be set")
377 }
378
379 if args.Did == "" {
380 return nil, fmt.Errorf("cocoon did must be set")
381 }
382
383 if args.ContactEmail == "" {
384 return nil, fmt.Errorf("cocoon contact email is required")
385 }
386
387 if _, err := syntax.ParseDID(args.Did); err != nil {
388 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
389 }
390
391 if args.Hostname == "" {
392 return nil, fmt.Errorf("cocoon hostname must be set")
393 }
394
395 if args.AdminPassword == "" {
396 return nil, fmt.Errorf("admin password must be set")
397 }
398
399 if args.Logger == nil {
400 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
401 }
402
403 if args.SessionSecret == "" {
404 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
405 }
406
407 e := echo.New()
408
409 e.Pre(middleware.RemoveTrailingSlash())
410 e.Pre(slogecho.New(args.Logger))
411 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
412 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
413 AllowOrigins: []string{"*"},
414 AllowHeaders: []string{"*"},
415 AllowMethods: []string{"*"},
416 AllowCredentials: true,
417 MaxAge: 100_000_000,
418 }))
419
420 vdtor := validator.New()
421 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
422 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
423 return false
424 }
425 return true
426 })
427 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
428 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
429 return false
430 }
431 return true
432 })
433 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
434 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
435 return false
436 }
437 return true
438 })
439 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
440 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
441 return false
442 }
443 return true
444 })
445
446 e.Validator = &CustomValidator{validator: vdtor}
447
448 httpd := &http.Server{
449 Addr: args.Addr,
450 Handler: e,
451 // shitty defaults but okay for now, needed for import repo
452 ReadTimeout: 5 * time.Minute,
453 WriteTimeout: 5 * time.Minute,
454 IdleTimeout: 5 * time.Minute,
455 }
456
457 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
458 if err != nil {
459 return nil, err
460 }
461 dbw := db.NewDB(gdb)
462
463 rkbytes, err := os.ReadFile(args.RotationKeyPath)
464 if err != nil {
465 return nil, err
466 }
467
468 h := util.RobustHTTPClient()
469
470 plcClient, err := plc.NewClient(&plc.ClientArgs{
471 H: h,
472 Service: "https://plc.directory",
473 PdsHostname: args.Hostname,
474 RotationKey: rkbytes,
475 })
476 if err != nil {
477 return nil, err
478 }
479
480 jwkbytes, err := os.ReadFile(args.JwkPath)
481 if err != nil {
482 return nil, err
483 }
484
485 key, err := helpers.ParseJWKFromBytes(jwkbytes)
486 if err != nil {
487 return nil, err
488 }
489
490 var pkey ecdsa.PrivateKey
491 if err := key.Raw(&pkey); err != nil {
492 return nil, err
493 }
494
495 oauthCli := &http.Client{
496 Timeout: 10 * time.Second,
497 }
498
499 var nonceSecret []byte
500 maybeSecret, err := os.ReadFile("nonce.secret")
501 if err != nil && !os.IsNotExist(err) {
502 args.Logger.Error("error attempting to read nonce secret", "error", err)
503 } else {
504 nonceSecret = maybeSecret
505 }
506
507 s := &Server{
508 http: h,
509 httpd: httpd,
510 echo: e,
511 logger: args.Logger,
512 db: dbw,
513 plcClient: plcClient,
514 privateKey: &pkey,
515 config: &config{
516 Version: args.Version,
517 Did: args.Did,
518 Hostname: args.Hostname,
519 ContactEmail: args.ContactEmail,
520 EnforcePeering: false,
521 Relays: args.Relays,
522 AdminPassword: args.AdminPassword,
523 SmtpName: args.SmtpName,
524 SmtpEmail: args.SmtpEmail,
525 },
526 evtman: events.NewEventManager(events.NewMemPersister()),
527 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
528
529 dbName: args.DbName,
530 s3Config: args.S3Config,
531
532 oauthProvider: provider.NewProvider(provider.Args{
533 Hostname: args.Hostname,
534 ClientManagerArgs: client_manager.Args{
535 Cli: oauthCli,
536 Logger: args.Logger,
537 },
538 DpopManagerArgs: dpop_manager.Args{
539 NonceSecret: nonceSecret,
540 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
541 OnNonceSecretCreated: func(newNonce []byte) {
542 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
543 args.Logger.Error("error writing new nonce secret", "error", err)
544 }
545 },
546 Logger: args.Logger,
547 Hostname: args.Hostname,
548 },
549 }),
550 }
551
552 s.loadTemplates()
553
554 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
555
556 // TODO: should validate these args
557 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
558 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.")
559 } else {
560 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
561 mail.From(s.config.SmtpEmail)
562 mail.FromName(s.config.SmtpName)
563
564 s.mail = mail
565 s.mailLk = &sync.Mutex{}
566 }
567
568 return s, nil
569}
570
571func (s *Server) addRoutes() {
572 // static
573 if s.config.Version == "dev" {
574 s.echo.Static("/static", "server/static")
575 } else {
576 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
577 }
578
579 // random stuff
580 s.echo.GET("/", s.handleRoot)
581 s.echo.GET("/xrpc/_health", s.handleHealth)
582 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
583 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
584 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
585 s.echo.GET("/robots.txt", s.handleRobots)
586
587 // public
588 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
589 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
590 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
591 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
592 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
593
594 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
595 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
596 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
597 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
598 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
599 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
600 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
601 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
602 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
603 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
604 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
605 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
606
607 // account
608 s.echo.GET("/account", s.handleAccount)
609 s.echo.POST("/account/revoke", s.handleAccountRevoke)
610 s.echo.GET("/account/signin", s.handleAccountSigninGet)
611 s.echo.POST("/account/signin", s.handleAccountSigninPost)
612 s.echo.GET("/account/signout", s.handleAccountSignout)
613
614 // oauth account
615 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
616 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
617 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
618
619 // oauth authorization
620 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
621 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
622
623 // authed
624 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
625 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
626 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
627 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
628 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
629 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
630 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
631 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
632 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
633 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
634 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
635
636 // repo
637 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
638 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
639 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
640 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
641 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
642 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
643
644 // stupid silly endpoints
645 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
646 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
647
648 // are there any routes that we should be allowing without auth? i dont think so but idk
649 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
650 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
651
652 // admin routes
653 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
654 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
655}
656
657func (s *Server) Serve(ctx context.Context) error {
658 s.addRoutes()
659
660 s.logger.Info("migrating...")
661
662 s.db.AutoMigrate(
663 &models.Actor{},
664 &models.Repo{},
665 &models.InviteCode{},
666 &models.Token{},
667 &models.RefreshToken{},
668 &models.Block{},
669 &models.Record{},
670 &models.Blob{},
671 &models.BlobPart{},
672 &provider.OauthToken{},
673 &provider.OauthAuthorizationRequest{},
674 )
675
676 s.logger.Info("starting cocoon")
677
678 go func() {
679 if err := s.httpd.ListenAndServe(); err != nil {
680 panic(err)
681 }
682 }()
683
684 go s.backupRoutine()
685
686 for _, relay := range s.config.Relays {
687 cli := xrpc.Client{Host: relay}
688 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
689 Hostname: s.config.Hostname,
690 })
691 }
692
693 <-ctx.Done()
694
695 fmt.Println("shut down")
696
697 return nil
698}
699
700func (s *Server) doBackup() {
701 start := time.Now()
702
703 s.logger.Info("beginning backup to s3...")
704
705 var buf bytes.Buffer
706 if err := func() error {
707 s.logger.Info("reading database bytes...")
708 s.db.Lock()
709 defer s.db.Unlock()
710
711 sf, err := os.Open(s.dbName)
712 if err != nil {
713 return fmt.Errorf("error opening database for backup: %w", err)
714 }
715 defer sf.Close()
716
717 if _, err := io.Copy(&buf, sf); err != nil {
718 return fmt.Errorf("error reading bytes of backup db: %w", err)
719 }
720
721 return nil
722 }(); err != nil {
723 s.logger.Error("error backing up database", "error", err)
724 return
725 }
726
727 if err := func() error {
728 s.logger.Info("sending to s3...")
729
730 currTime := time.Now().Format("2006-01-02_15-04-05")
731 key := "cocoon-backup-" + currTime + ".db"
732
733 config := &aws.Config{
734 Region: aws.String(s.s3Config.Region),
735 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
736 }
737
738 if s.s3Config.Endpoint != "" {
739 config.Endpoint = aws.String(s.s3Config.Endpoint)
740 config.S3ForcePathStyle = aws.Bool(true)
741 }
742
743 sess, err := session.NewSession(config)
744 if err != nil {
745 return err
746 }
747
748 svc := s3.New(sess)
749
750 if _, err := svc.PutObject(&s3.PutObjectInput{
751 Bucket: aws.String(s.s3Config.Bucket),
752 Key: aws.String(key),
753 Body: bytes.NewReader(buf.Bytes()),
754 }); err != nil {
755 return fmt.Errorf("error uploading file to s3: %w", err)
756 }
757
758 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
759
760 return nil
761 }(); err != nil {
762 s.logger.Error("error uploading database backup", "error", err)
763 return
764 }
765
766 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
767}
768
769func (s *Server) backupRoutine() {
770 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
771 return
772 }
773
774 if s.s3Config.Region == "" {
775 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
776 return
777 }
778
779 if s.s3Config.Bucket == "" {
780 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
781 return
782 }
783
784 if s.s3Config.AccessKey == "" {
785 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
786 return
787 }
788
789 if s.s3Config.SecretKey == "" {
790 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
791 return
792 }
793
794 shouldBackupNow := false
795 lastBackupStr, err := os.ReadFile("last-backup.txt")
796 if err != nil {
797 shouldBackupNow = true
798 } else {
799 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
800 if err != nil {
801 shouldBackupNow = true
802 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
803 shouldBackupNow = true
804 }
805 }
806
807 if shouldBackupNow {
808 go s.doBackup()
809 }
810
811 ticker := time.NewTicker(time.Hour)
812 for range ticker.C {
813 go s.doBackup()
814 }
815}