Forked from
hailey.at/cocoon
An atproto PDS written in Go
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 "github.com/labstack/echo-contrib/echoprometheus"
43 echo_session "github.com/labstack/echo-contrib/session"
44 "github.com/labstack/echo/v4"
45 "github.com/labstack/echo/v4/middleware"
46 slogecho "github.com/samber/slog-echo"
47 "gorm.io/driver/postgres"
48 "gorm.io/driver/sqlite"
49 "gorm.io/gorm"
50)
51
const (
	// AccountSessionMaxAge is the maximum age of an account session: 30 days.
	// NOTE(review): the old trailing comment said "one week", which did not
	// match the value; confirm 30 days is the intended lifetime.
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
55
// S3Config describes the S3-compatible object storage the server may use.
type S3Config struct {
	// BackupsEnabled turns on the periodic SQLite backup upload.
	BackupsEnabled bool
	// BlobstoreEnabled presumably stores blobs in S3 instead of locally — TODO confirm against callers.
	BlobstoreEnabled bool

	// Endpoint overrides the AWS endpoint (e.g. for a self-hosted S3-compatible
	// service); when set, path-style addressing is used for backups.
	Endpoint string
	// Region, Bucket, AccessKey and SecretKey must all be non-empty for backups to run.
	Region    string
	Bucket    string
	AccessKey string
	SecretKey string

	// CDNUrl is presumably a public base URL for serving stored objects — verify against callers.
	CDNUrl string
}
66
67type Server struct {
68 http *http.Client
69 httpd *http.Server
70 mail *mailyak.MailYak
71 mailLk *sync.Mutex
72 echo *echo.Echo
73 db *db.DB
74 plcClient *plc.Client
75 logger *slog.Logger
76 config *config
77 privateKey *ecdsa.PrivateKey
78 repoman *RepoMan
79 oauthProvider *provider.Provider
80 evtman *events.EventManager
81 passport *identity.Passport
82 fallbackProxy string
83
84 lastRequestCrawl time.Time
85 requestCrawlMu sync.Mutex
86
87 dbName string
88 dbType string
89 s3Config *S3Config
90}
91
92type Args struct {
93 Logger *slog.Logger
94
95 LogLevel slog.Level
96 Addr string
97 DbName string
98 DbType string
99 DatabaseURL string
100 Version string
101 Did string
102 Hostname string
103 RotationKeyPath string
104 JwkPath string
105 ContactEmail string
106 Relays []string
107 AdminPassword string
108 RequireInvite bool
109
110 SmtpUser string
111 SmtpPass string
112 SmtpHost string
113 SmtpPort string
114 SmtpEmail string
115 SmtpName string
116
117 S3Config *S3Config
118
119 SessionSecret string
120 SessionCookieKey string
121
122 BlockstoreVariant BlockstoreVariant
123 FallbackProxy string
124}
125
126type config struct {
127 LogLevel slog.Level
128 Version string
129 Did string
130 Hostname string
131 ContactEmail string
132 EnforcePeering bool
133 Relays []string
134 AdminPassword string
135 RequireInvite bool
136 SmtpEmail string
137 SmtpName string
138 SessionCookieKey string
139 BlockstoreVariant BlockstoreVariant
140 FallbackProxy string
141}
142
143type CustomValidator struct {
144 validator *validator.Validate
145}
146
147type ValidationError struct {
148 error
149 Field string
150 Tag string
151}
152
153func (cv *CustomValidator) Validate(i any) error {
154 if err := cv.validator.Struct(i); err != nil {
155 var validateErrors validator.ValidationErrors
156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
157 first := validateErrors[0]
158 return ValidationError{
159 error: err,
160 Field: first.Field(),
161 Tag: first.Tag(),
162 }
163 }
164
165 return err
166 }
167
168 return nil
169}
170
171//go:embed templates/*
172var templateFS embed.FS
173
174//go:embed static/*
175var staticFS embed.FS
176
// TemplateRenderer is the echo Renderer for the server's HTML pages.
type TemplateRenderer struct {
	// templates is the parsed template set used for rendering.
	templates *template.Template
	// isDev makes Render re-parse templates from templatePath on each call.
	isDev bool
	// templatePath is the glob re-parsed in dev mode (empty otherwise).
	templatePath string
}
182
183func (s *Server) loadTemplates() {
184 absPath, _ := filepath.Abs("server/templates/*.html")
185 if s.config.Version == "dev" {
186 tmpl := template.Must(template.ParseGlob(absPath))
187 s.echo.Renderer = &TemplateRenderer{
188 templates: tmpl,
189 isDev: true,
190 templatePath: absPath,
191 }
192 } else {
193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
194 s.echo.Renderer = &TemplateRenderer{
195 templates: tmpl,
196 isDev: false,
197 }
198 }
199}
200
201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
202 if t.isDev {
203 tmpl, err := template.ParseGlob(t.templatePath)
204 if err != nil {
205 return err
206 }
207 t.templates = tmpl
208 }
209
210 if viewContext, isMap := data.(map[string]any); isMap {
211 viewContext["reverse"] = c.Echo().Reverse
212 }
213
214 return t.templates.ExecuteTemplate(w, name, data)
215}
216
// filteredHandler wraps another slog.Handler and additionally drops records
// below a minimum level. The wrapped handler's own filtering still applies.
type filteredHandler struct {
	level   slog.Level   // minimum level allowed through
	handler slog.Handler // destination for records that pass
}

// Enabled reports whether a record at lvl clears this wrapper's minimum and is
// also enabled by the wrapped handler.
func (fh *filteredHandler) Enabled(ctx context.Context, lvl slog.Level) bool {
	if lvl < fh.level {
		return false
	}
	return fh.handler.Enabled(ctx, lvl)
}

// Handle forwards the record to the wrapped handler unchanged.
func (fh *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
	return fh.handler.Handle(ctx, r)
}

// WithAttrs returns a filtered handler, at the same minimum level, around the
// wrapped handler extended with attrs.
func (fh *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return fh.rewrap(fh.handler.WithAttrs(attrs))
}

// WithGroup returns a filtered handler, at the same minimum level, around the
// wrapped handler opened on group name.
func (fh *filteredHandler) WithGroup(name string) slog.Handler {
	return fh.rewrap(fh.handler.WithGroup(name))
}

// rewrap builds a new filteredHandler with this handler's level around next.
func (fh *filteredHandler) rewrap(next slog.Handler) slog.Handler {
	return &filteredHandler{level: fh.level, handler: next}
}
237
238func New(args *Args) (*Server, error) {
239 if args.Logger == nil {
240 args.Logger = slog.Default()
241 }
242
243 if args.LogLevel != 0 {
244 args.Logger = slog.New(&filteredHandler{
245 level: args.LogLevel,
246 handler: args.Logger.Handler(),
247 })
248 }
249
250 logger := args.Logger.With("name", "New")
251
252 if args.Addr == "" {
253 return nil, fmt.Errorf("addr must be set")
254 }
255
256 if args.DbName == "" {
257 return nil, fmt.Errorf("db name must be set")
258 }
259
260 if args.Did == "" {
261 return nil, fmt.Errorf("cocoon did must be set")
262 }
263
264 if args.ContactEmail == "" {
265 return nil, fmt.Errorf("cocoon contact email is required")
266 }
267
268 if _, err := syntax.ParseDID(args.Did); err != nil {
269 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
270 }
271
272 if args.Hostname == "" {
273 return nil, fmt.Errorf("cocoon hostname must be set")
274 }
275
276 if args.AdminPassword == "" {
277 return nil, fmt.Errorf("admin password must be set")
278 }
279
280 if args.SessionSecret == "" {
281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
282 }
283
284 e := echo.New()
285
286 e.Pre(middleware.RemoveTrailingSlash())
287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
289 e.Use(echoprometheus.NewMiddleware("cocoon"))
290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
291 AllowOrigins: []string{"*"},
292 AllowHeaders: []string{"*"},
293 AllowMethods: []string{"*"},
294 AllowCredentials: true,
295 MaxAge: 100_000_000,
296 }))
297
298 vdtor := validator.New()
299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
301 return false
302 }
303 return true
304 })
305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
307 return false
308 }
309 return true
310 })
311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
313 return false
314 }
315 return true
316 })
317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
318 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
319 return false
320 }
321 return true
322 })
323
324 e.Validator = &CustomValidator{validator: vdtor}
325
326 httpd := &http.Server{
327 Addr: args.Addr,
328 Handler: e,
329 // shitty defaults but okay for now, needed for import repo
330 ReadTimeout: 5 * time.Minute,
331 WriteTimeout: 5 * time.Minute,
332 IdleTimeout: 5 * time.Minute,
333 }
334
335 dbType := args.DbType
336 if dbType == "" {
337 dbType = "sqlite"
338 }
339
340 var gdb *gorm.DB
341 var err error
342 switch dbType {
343 case "postgres":
344 if args.DatabaseURL == "" {
345 return nil, fmt.Errorf("database-url must be set when using postgres")
346 }
347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
348 if err != nil {
349 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
350 }
351 logger.Info("connected to PostgreSQL database")
352 default:
353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
354 if err != nil {
355 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
356 }
357 gdb.Exec("PRAGMA journal_mode=WAL")
358 gdb.Exec("PRAGMA synchronous=NORMAL")
359
360 logger.Info("connected to SQLite database", "path", args.DbName)
361 }
362 dbw := db.NewDB(gdb)
363
364 rkbytes, err := os.ReadFile(args.RotationKeyPath)
365 if err != nil {
366 return nil, err
367 }
368
369 h := util.RobustHTTPClient()
370
371 plcClient, err := plc.NewClient(&plc.ClientArgs{
372 H: h,
373 Service: "https://plc.directory",
374 PdsHostname: args.Hostname,
375 RotationKey: rkbytes,
376 })
377 if err != nil {
378 return nil, err
379 }
380
381 jwkbytes, err := os.ReadFile(args.JwkPath)
382 if err != nil {
383 return nil, err
384 }
385
386 key, err := helpers.ParseJWKFromBytes(jwkbytes)
387 if err != nil {
388 return nil, err
389 }
390
391 var pkey ecdsa.PrivateKey
392 if err := key.Raw(&pkey); err != nil {
393 return nil, err
394 }
395
396 oauthCli := &http.Client{
397 Timeout: 10 * time.Second,
398 }
399
400 var nonceSecret []byte
401 maybeSecret, err := os.ReadFile("nonce.secret")
402 if err != nil && !os.IsNotExist(err) {
403 logger.Error("error attempting to read nonce secret", "error", err)
404 } else {
405 nonceSecret = maybeSecret
406 }
407
408 s := &Server{
409 http: h,
410 httpd: httpd,
411 echo: e,
412 logger: args.Logger,
413 db: dbw,
414 plcClient: plcClient,
415 privateKey: &pkey,
416 config: &config{
417 LogLevel: args.LogLevel,
418 Version: args.Version,
419 Did: args.Did,
420 Hostname: args.Hostname,
421 ContactEmail: args.ContactEmail,
422 EnforcePeering: false,
423 Relays: args.Relays,
424 AdminPassword: args.AdminPassword,
425 RequireInvite: args.RequireInvite,
426 SmtpName: args.SmtpName,
427 SmtpEmail: args.SmtpEmail,
428 SessionCookieKey: args.SessionCookieKey,
429 BlockstoreVariant: args.BlockstoreVariant,
430 FallbackProxy: args.FallbackProxy,
431 },
432 evtman: events.NewEventManager(events.NewMemPersister()),
433 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
434
435 dbName: args.DbName,
436 dbType: dbType,
437 s3Config: args.S3Config,
438
439 oauthProvider: provider.NewProvider(provider.Args{
440 Hostname: args.Hostname,
441 ClientManagerArgs: client.ManagerArgs{
442 Cli: oauthCli,
443 Logger: args.Logger.With("component", "oauth-client-manager"),
444 },
445 DpopManagerArgs: dpop.ManagerArgs{
446 NonceSecret: nonceSecret,
447 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
448 OnNonceSecretCreated: func(newNonce []byte) {
449 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
450 logger.Error("error writing new nonce secret", "error", err)
451 }
452 },
453 Logger: args.Logger.With("component", "dpop-manager"),
454 Hostname: args.Hostname,
455 },
456 }),
457 }
458
459 s.loadTemplates()
460
461 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
462
463 // TODO: should validate these args
464 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
465 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
466 } else {
467 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
468 mail.From(s.config.SmtpEmail)
469 mail.FromName(s.config.SmtpName)
470
471 s.mail = mail
472 s.mailLk = &sync.Mutex{}
473 }
474
475 return s, nil
476}
477
478func (s *Server) addRoutes() {
479 // static
480 if s.config.Version == "dev" {
481 s.echo.Static("/static", "server/static")
482 } else {
483 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
484 }
485
486 // random stuff
487 s.echo.GET("/", s.handleRoot)
488 s.echo.GET("/xrpc/_health", s.handleHealth)
489 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
490 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
491 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
492 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
493 s.echo.GET("/robots.txt", s.handleRobots)
494
495 // public
496 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
497 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
498 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
499 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
500 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
501
502 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
503 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
504 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
505 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
506 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
507 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
508 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
509 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
510 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
511 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
512 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
513 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
514
515 // labels
516 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
517
518 // account
519 s.echo.GET("/account", s.handleAccount)
520 s.echo.POST("/account/revoke", s.handleAccountRevoke)
521 s.echo.GET("/account/signin", s.handleAccountSigninGet)
522 s.echo.POST("/account/signin", s.handleAccountSigninPost)
523 s.echo.GET("/account/signout", s.handleAccountSignout)
524
525 // oauth account
526 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
527 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
528 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
529
530 // oauth authorization
531 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
532 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
533
534 // authed
535 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
536 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
537 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
538 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
539 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
540 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
541 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
542 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
543 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
544 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
545 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
546 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
547 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
548 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
549 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
550 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
551 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
552 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
553 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
554 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
555
556 // repo
557 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
558 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
559 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
560 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
561 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
562 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
563 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
564
565 // stupid silly endpoints
566 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
567 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
568 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
570 // admin routes
571 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
572 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
573
574 // are there any routes that we should be allowing without auth? i dont think so but idk
575 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
576 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
577}
578
579func (s *Server) Serve(ctx context.Context) error {
580 logger := s.logger.With("name", "Serve")
581
582 s.addRoutes()
583
584 logger.Info("migrating...")
585
586 s.db.AutoMigrate(
587 &models.Actor{},
588 &models.Repo{},
589 &models.InviteCode{},
590 &models.Token{},
591 &models.RefreshToken{},
592 &models.Block{},
593 &models.Record{},
594 &models.Blob{},
595 &models.BlobPart{},
596 &models.ReservedKey{},
597 &provider.OauthToken{},
598 &provider.OauthAuthorizationRequest{},
599 )
600
601 logger.Info("starting cocoon")
602
603 go func() {
604 if err := s.httpd.ListenAndServe(); err != nil {
605 panic(err)
606 }
607 }()
608
609 go s.backupRoutine()
610
611 go func() {
612 if err := s.requestCrawl(ctx); err != nil {
613 logger.Error("error requesting crawls", "err", err)
614 }
615 }()
616
617 <-ctx.Done()
618
619 fmt.Println("shut down")
620
621 return nil
622}
623
624func (s *Server) requestCrawl(ctx context.Context) error {
625 logger := s.logger.With("component", "request-crawl")
626 s.requestCrawlMu.Lock()
627 defer s.requestCrawlMu.Unlock()
628
629 logger.Info("requesting crawl with configured relays")
630
631 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
632 return fmt.Errorf("a crawl request has already been made within the last minute")
633 }
634
635 for _, relay := range s.config.Relays {
636 logger := logger.With("relay", relay)
637 logger.Info("requesting crawl from relay")
638 cli := xrpc.Client{Host: relay}
639 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
640 Hostname: s.config.Hostname,
641 }); err != nil {
642 logger.Error("error requesting crawl", "err", err)
643 } else {
644 logger.Info("crawl requested successfully")
645 }
646 }
647
648 s.lastRequestCrawl = time.Now()
649
650 return nil
651}
652
653func (s *Server) doBackup() {
654 logger := s.logger.With("name", "doBackup")
655
656 if s.dbType == "postgres" {
657 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
658 return
659 }
660
661 start := time.Now()
662
663 logger.Info("beginning backup to s3...")
664
665 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
666 defer os.Remove(tmpFile)
667
668 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
669 logger.Error("error creating tmp backup file", "err", err)
670 return
671 }
672
673 backupData, err := os.ReadFile(tmpFile)
674 if err != nil {
675 logger.Error("error reading tmp backup file", "err", err)
676 return
677 }
678
679 logger.Info("sending to s3...")
680
681 currTime := time.Now().Format("2006-01-02_15-04-05")
682 key := "cocoon-backup-" + currTime + ".db"
683
684 config := &aws.Config{
685 Region: aws.String(s.s3Config.Region),
686 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
687 }
688
689 if s.s3Config.Endpoint != "" {
690 config.Endpoint = aws.String(s.s3Config.Endpoint)
691 config.S3ForcePathStyle = aws.Bool(true)
692 }
693
694 sess, err := session.NewSession(config)
695 if err != nil {
696 logger.Error("error creating s3 session", "err", err)
697 return
698 }
699
700 svc := s3.New(sess)
701
702 if _, err := svc.PutObject(&s3.PutObjectInput{
703 Bucket: aws.String(s.s3Config.Bucket),
704 Key: aws.String(key),
705 Body: bytes.NewReader(backupData),
706 }); err != nil {
707 logger.Error("error uploading file to s3", "err", err)
708 return
709 }
710
711 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
712
713 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
714}
715
716func (s *Server) backupRoutine() {
717 logger := s.logger.With("name", "backupRoutine")
718
719 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
720 return
721 }
722
723 if s.s3Config.Region == "" {
724 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
725 return
726 }
727
728 if s.s3Config.Bucket == "" {
729 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
730 return
731 }
732
733 if s.s3Config.AccessKey == "" {
734 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
735 return
736 }
737
738 if s.s3Config.SecretKey == "" {
739 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
740 return
741 }
742
743 shouldBackupNow := false
744 lastBackupStr, err := os.ReadFile("last-backup.txt")
745 if err != nil {
746 shouldBackupNow = true
747 } else {
748 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
749 if err != nil {
750 shouldBackupNow = true
751 } else if time.Since(lastBackup).Seconds() > 3600 {
752 shouldBackupNow = true
753 }
754 }
755
756 if shouldBackupNow {
757 go s.doBackup()
758 }
759
760 ticker := time.NewTicker(time.Hour)
761 for range ticker.C {
762 go s.doBackup()
763 }
764}
765
766func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
767 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
768 return err
769 }
770
771 return nil
772}