An atproto PDS written in Go
at commit efb83a84054d7539b552d08eee82db9ecd9ddb4e (772 lines, 24 kB)
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 "github.com/ipfs/go-cid" 42 "github.com/labstack/echo-contrib/echoprometheus" 43 echo_session "github.com/labstack/echo-contrib/session" 44 "github.com/labstack/echo/v4" 45 "github.com/labstack/echo/v4/middleware" 46 slogecho "github.com/samber/slog-echo" 47 "gorm.io/driver/postgres" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50) 51 52const ( 53 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 54) 55 56type S3Config struct { 57 BackupsEnabled bool 58 BlobstoreEnabled bool 59 Endpoint string 60 Region string 61 Bucket string 62 AccessKey string 63 SecretKey string 64 CDNUrl string 65} 66 67type Server struct { 68 http *http.Client 69 httpd *http.Server 70 mail *mailyak.MailYak 71 mailLk *sync.Mutex 72 echo *echo.Echo 73 db *db.DB 74 plcClient *plc.Client 75 logger *slog.Logger 76 config *config 77 privateKey *ecdsa.PrivateKey 78 repoman *RepoMan 79 
oauthProvider *provider.Provider 80 evtman *events.EventManager 81 passport *identity.Passport 82 fallbackProxy string 83 84 lastRequestCrawl time.Time 85 requestCrawlMu sync.Mutex 86 87 dbName string 88 dbType string 89 s3Config *S3Config 90} 91 92type Args struct { 93 Logger *slog.Logger 94 95 LogLevel slog.Level 96 Addr string 97 DbName string 98 DbType string 99 DatabaseURL string 100 Version string 101 Did string 102 Hostname string 103 RotationKeyPath string 104 JwkPath string 105 ContactEmail string 106 Relays []string 107 AdminPassword string 108 RequireInvite bool 109 110 SmtpUser string 111 SmtpPass string 112 SmtpHost string 113 SmtpPort string 114 SmtpEmail string 115 SmtpName string 116 117 S3Config *S3Config 118 119 SessionSecret string 120 SessionCookieKey string 121 122 BlockstoreVariant BlockstoreVariant 123 FallbackProxy string 124} 125 126type config struct { 127 LogLevel slog.Level 128 Version string 129 Did string 130 Hostname string 131 ContactEmail string 132 EnforcePeering bool 133 Relays []string 134 AdminPassword string 135 RequireInvite bool 136 SmtpEmail string 137 SmtpName string 138 SessionCookieKey string 139 BlockstoreVariant BlockstoreVariant 140 FallbackProxy string 141} 142 143type CustomValidator struct { 144 validator *validator.Validate 145} 146 147type ValidationError struct { 148 error 149 Field string 150 Tag string 151} 152 153func (cv *CustomValidator) Validate(i any) error { 154 if err := cv.validator.Struct(i); err != nil { 155 var validateErrors validator.ValidationErrors 156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 157 first := validateErrors[0] 158 return ValidationError{ 159 error: err, 160 Field: first.Field(), 161 Tag: first.Tag(), 162 } 163 } 164 165 return err 166 } 167 168 return nil 169} 170 171//go:embed templates/* 172var templateFS embed.FS 173 174//go:embed static/* 175var staticFS embed.FS 176 177type TemplateRenderer struct { 178 templates *template.Template 179 isDev bool 180 
templatePath string 181} 182 183func (s *Server) loadTemplates() { 184 absPath, _ := filepath.Abs("server/templates/*.html") 185 if s.config.Version == "dev" { 186 tmpl := template.Must(template.ParseGlob(absPath)) 187 s.echo.Renderer = &TemplateRenderer{ 188 templates: tmpl, 189 isDev: true, 190 templatePath: absPath, 191 } 192 } else { 193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 194 s.echo.Renderer = &TemplateRenderer{ 195 templates: tmpl, 196 isDev: false, 197 } 198 } 199} 200 201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 202 if t.isDev { 203 tmpl, err := template.ParseGlob(t.templatePath) 204 if err != nil { 205 return err 206 } 207 t.templates = tmpl 208 } 209 210 if viewContext, isMap := data.(map[string]any); isMap { 211 viewContext["reverse"] = c.Echo().Reverse 212 } 213 214 return t.templates.ExecuteTemplate(w, name, data) 215} 216 217type filteredHandler struct { 218 level slog.Level 219 handler slog.Handler 220} 221 222func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 223 return level >= h.level && h.handler.Enabled(ctx, level) 224} 225 226func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 227 return h.handler.Handle(ctx, r) 228} 229 230func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 231 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 232} 233 234func (h *filteredHandler) WithGroup(name string) slog.Handler { 235 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 236} 237 238func New(args *Args) (*Server, error) { 239 if args.Logger == nil { 240 args.Logger = slog.Default() 241 } 242 243 if args.LogLevel != 0 { 244 args.Logger = slog.New(&filteredHandler{ 245 level: args.LogLevel, 246 handler: args.Logger.Handler(), 247 }) 248 } 249 250 logger := args.Logger.With("name", "New") 251 252 if args.Addr == "" { 253 return nil, fmt.Errorf("addr must 
be set") 254 } 255 256 if args.DbName == "" { 257 return nil, fmt.Errorf("db name must be set") 258 } 259 260 if args.Did == "" { 261 return nil, fmt.Errorf("cocoon did must be set") 262 } 263 264 if args.ContactEmail == "" { 265 return nil, fmt.Errorf("cocoon contact email is required") 266 } 267 268 if _, err := syntax.ParseDID(args.Did); err != nil { 269 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 270 } 271 272 if args.Hostname == "" { 273 return nil, fmt.Errorf("cocoon hostname must be set") 274 } 275 276 if args.AdminPassword == "" { 277 return nil, fmt.Errorf("admin password must be set") 278 } 279 280 if args.SessionSecret == "" { 281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 282 } 283 284 e := echo.New() 285 286 e.Pre(middleware.RemoveTrailingSlash()) 287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 289 e.Use(echoprometheus.NewMiddleware("cocoon")) 290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 291 AllowOrigins: []string{"*"}, 292 AllowHeaders: []string{"*"}, 293 AllowMethods: []string{"*"}, 294 AllowCredentials: true, 295 MaxAge: 100_000_000, 296 })) 297 298 vdtor := validator.New() 299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 301 return false 302 } 303 return true 304 }) 305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 307 return false 308 } 309 return true 310 }) 311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 313 return false 314 } 315 return true 316 }) 317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 318 if _, err := syntax.ParseNSID(fl.Field().String()); err != 
nil { 319 return false 320 } 321 return true 322 }) 323 324 e.Validator = &CustomValidator{validator: vdtor} 325 326 httpd := &http.Server{ 327 Addr: args.Addr, 328 Handler: e, 329 // shitty defaults but okay for now, needed for import repo 330 ReadTimeout: 5 * time.Minute, 331 WriteTimeout: 5 * time.Minute, 332 IdleTimeout: 5 * time.Minute, 333 } 334 335 dbType := args.DbType 336 if dbType == "" { 337 dbType = "sqlite" 338 } 339 340 var gdb *gorm.DB 341 var err error 342 switch dbType { 343 case "postgres": 344 if args.DatabaseURL == "" { 345 return nil, fmt.Errorf("database-url must be set when using postgres") 346 } 347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 348 if err != nil { 349 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 350 } 351 logger.Info("connected to PostgreSQL database") 352 default: 353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 354 if err != nil { 355 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 356 } 357 gdb.Exec("PRAGMA journal_mode=WAL") 358 gdb.Exec("PRAGMA synchronous=NORMAL") 359 360 logger.Info("connected to SQLite database", "path", args.DbName) 361 } 362 dbw := db.NewDB(gdb) 363 364 rkbytes, err := os.ReadFile(args.RotationKeyPath) 365 if err != nil { 366 return nil, err 367 } 368 369 h := util.RobustHTTPClient() 370 371 plcClient, err := plc.NewClient(&plc.ClientArgs{ 372 H: h, 373 Service: "https://plc.directory", 374 PdsHostname: args.Hostname, 375 RotationKey: rkbytes, 376 }) 377 if err != nil { 378 return nil, err 379 } 380 381 jwkbytes, err := os.ReadFile(args.JwkPath) 382 if err != nil { 383 return nil, err 384 } 385 386 key, err := helpers.ParseJWKFromBytes(jwkbytes) 387 if err != nil { 388 return nil, err 389 } 390 391 var pkey ecdsa.PrivateKey 392 if err := key.Raw(&pkey); err != nil { 393 return nil, err 394 } 395 396 oauthCli := &http.Client{ 397 Timeout: 10 * time.Second, 398 } 399 400 var nonceSecret []byte 401 maybeSecret, err := 
os.ReadFile("nonce.secret") 402 if err != nil && !os.IsNotExist(err) { 403 logger.Error("error attempting to read nonce secret", "error", err) 404 } else { 405 nonceSecret = maybeSecret 406 } 407 408 s := &Server{ 409 http: h, 410 httpd: httpd, 411 echo: e, 412 logger: args.Logger, 413 db: dbw, 414 plcClient: plcClient, 415 privateKey: &pkey, 416 config: &config{ 417 LogLevel: args.LogLevel, 418 Version: args.Version, 419 Did: args.Did, 420 Hostname: args.Hostname, 421 ContactEmail: args.ContactEmail, 422 EnforcePeering: false, 423 Relays: args.Relays, 424 AdminPassword: args.AdminPassword, 425 RequireInvite: args.RequireInvite, 426 SmtpName: args.SmtpName, 427 SmtpEmail: args.SmtpEmail, 428 SessionCookieKey: args.SessionCookieKey, 429 BlockstoreVariant: args.BlockstoreVariant, 430 FallbackProxy: args.FallbackProxy, 431 }, 432 evtman: events.NewEventManager(events.NewMemPersister()), 433 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 434 435 dbName: args.DbName, 436 dbType: dbType, 437 s3Config: args.S3Config, 438 439 oauthProvider: provider.NewProvider(provider.Args{ 440 Hostname: args.Hostname, 441 ClientManagerArgs: client.ManagerArgs{ 442 Cli: oauthCli, 443 Logger: args.Logger.With("component", "oauth-client-manager"), 444 }, 445 DpopManagerArgs: dpop.ManagerArgs{ 446 NonceSecret: nonceSecret, 447 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 448 OnNonceSecretCreated: func(newNonce []byte) { 449 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 450 logger.Error("error writing new nonce secret", "error", err) 451 } 452 }, 453 Logger: args.Logger.With("component", "dpop-manager"), 454 Hostname: args.Hostname, 455 }, 456 }), 457 } 458 459 s.loadTemplates() 460 461 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 462 463 // TODO: should validate these args 464 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" 
{ 465 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 466 } else { 467 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 468 mail.From(s.config.SmtpEmail) 469 mail.FromName(s.config.SmtpName) 470 471 s.mail = mail 472 s.mailLk = &sync.Mutex{} 473 } 474 475 return s, nil 476} 477 478func (s *Server) addRoutes() { 479 // static 480 if s.config.Version == "dev" { 481 s.echo.Static("/static", "server/static") 482 } else { 483 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 484 } 485 486 // random stuff 487 s.echo.GET("/", s.handleRoot) 488 s.echo.GET("/xrpc/_health", s.handleHealth) 489 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 490 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 491 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 492 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 493 s.echo.GET("/robots.txt", s.handleRobots) 494 495 // public 496 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 497 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 498 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 499 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 500 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 501 502 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 503 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 504 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 505 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 506 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 507 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 508 
s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 509 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 510 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 511 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 512 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 513 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 514 515 // labels 516 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 517 518 // account 519 s.echo.GET("/account", s.handleAccount) 520 s.echo.POST("/account/revoke", s.handleAccountRevoke) 521 s.echo.GET("/account/signin", s.handleAccountSigninGet) 522 s.echo.POST("/account/signin", s.handleAccountSigninPost) 523 s.echo.GET("/account/signout", s.handleAccountSignout) 524 525 // oauth account 526 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 527 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 528 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 529 530 // oauth authorization 531 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 532 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 533 534 // authed 535 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 536 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 537 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 538 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 539 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, 
s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 540 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 541 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 542 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 543 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 544 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 545 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 546 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 547 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 548 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 549 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 550 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 551 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 552 
s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 553 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 554 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 555 556 // repo 557 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 558 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 559 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 560 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 561 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 562 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 563 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 564 565 // stupid silly endpoints 566 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 567 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 568 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, 
s.handleOauthSessionMiddleware) 570 // admin routes 571 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 572 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 573 574 // are there any routes that we should be allowing without auth? i dont think so but idk 575 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 576 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 577} 578 579func (s *Server) Serve(ctx context.Context) error { 580 logger := s.logger.With("name", "Serve") 581 582 s.addRoutes() 583 584 logger.Info("migrating...") 585 586 s.db.AutoMigrate( 587 &models.Actor{}, 588 &models.Repo{}, 589 &models.InviteCode{}, 590 &models.Token{}, 591 &models.RefreshToken{}, 592 &models.Block{}, 593 &models.Record{}, 594 &models.Blob{}, 595 &models.BlobPart{}, 596 &models.ReservedKey{}, 597 &provider.OauthToken{}, 598 &provider.OauthAuthorizationRequest{}, 599 ) 600 601 logger.Info("starting cocoon") 602 603 go func() { 604 if err := s.httpd.ListenAndServe(); err != nil { 605 panic(err) 606 } 607 }() 608 609 go s.backupRoutine() 610 611 go func() { 612 if err := s.requestCrawl(ctx); err != nil { 613 logger.Error("error requesting crawls", "err", err) 614 } 615 }() 616 617 <-ctx.Done() 618 619 fmt.Println("shut down") 620 621 return nil 622} 623 624func (s *Server) requestCrawl(ctx context.Context) error { 625 logger := s.logger.With("component", "request-crawl") 626 s.requestCrawlMu.Lock() 627 defer s.requestCrawlMu.Unlock() 628 629 logger.Info("requesting crawl with configured relays") 630 631 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 632 return fmt.Errorf("a crawl request has already been made within the last minute") 633 } 634 635 for _, relay := range s.config.Relays { 636 logger := logger.With("relay", relay) 637 logger.Info("requesting 
crawl from relay") 638 cli := xrpc.Client{Host: relay} 639 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 640 Hostname: s.config.Hostname, 641 }); err != nil { 642 logger.Error("error requesting crawl", "err", err) 643 } else { 644 logger.Info("crawl requested successfully") 645 } 646 } 647 648 s.lastRequestCrawl = time.Now() 649 650 return nil 651} 652 653func (s *Server) doBackup() { 654 logger := s.logger.With("name", "doBackup") 655 656 if s.dbType == "postgres" { 657 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 658 return 659 } 660 661 start := time.Now() 662 663 logger.Info("beginning backup to s3...") 664 665 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 666 defer os.Remove(tmpFile) 667 668 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 669 logger.Error("error creating tmp backup file", "err", err) 670 return 671 } 672 673 backupData, err := os.ReadFile(tmpFile) 674 if err != nil { 675 logger.Error("error reading tmp backup file", "err", err) 676 return 677 } 678 679 logger.Info("sending to s3...") 680 681 currTime := time.Now().Format("2006-01-02_15-04-05") 682 key := "cocoon-backup-" + currTime + ".db" 683 684 config := &aws.Config{ 685 Region: aws.String(s.s3Config.Region), 686 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 687 } 688 689 if s.s3Config.Endpoint != "" { 690 config.Endpoint = aws.String(s.s3Config.Endpoint) 691 config.S3ForcePathStyle = aws.Bool(true) 692 } 693 694 sess, err := session.NewSession(config) 695 if err != nil { 696 logger.Error("error creating s3 session", "err", err) 697 return 698 } 699 700 svc := s3.New(sess) 701 702 if _, err := svc.PutObject(&s3.PutObjectInput{ 703 Bucket: aws.String(s.s3Config.Bucket), 704 Key: aws.String(key), 705 Body: bytes.NewReader(backupData), 706 }); err != 
nil { 707 logger.Error("error uploading file to s3", "err", err) 708 return 709 } 710 711 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 712 713 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 714} 715 716func (s *Server) backupRoutine() { 717 logger := s.logger.With("name", "backupRoutine") 718 719 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 720 return 721 } 722 723 if s.s3Config.Region == "" { 724 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 725 return 726 } 727 728 if s.s3Config.Bucket == "" { 729 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 730 return 731 } 732 733 if s.s3Config.AccessKey == "" { 734 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 735 return 736 } 737 738 if s.s3Config.SecretKey == "" { 739 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 740 return 741 } 742 743 shouldBackupNow := false 744 lastBackupStr, err := os.ReadFile("last-backup.txt") 745 if err != nil { 746 shouldBackupNow = true 747 } else { 748 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 749 if err != nil { 750 shouldBackupNow = true 751 } else if time.Since(lastBackup).Seconds() > 3600 { 752 shouldBackupNow = true 753 } 754 } 755 756 if shouldBackupNow { 757 go s.doBackup() 758 } 759 760 ticker := time.NewTicker(time.Hour) 761 for range ticker.C { 762 go s.doBackup() 763 } 764} 765 766func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 767 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 768 return err 769 } 770 771 return nil 772}