this repo has no description
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "strings" 17 "sync" 18 "text/template" 19 "time" 20 21 "github.com/Azure/go-autorest/autorest/to" 22 "github.com/aws/aws-sdk-go/aws" 23 "github.com/aws/aws-sdk-go/aws/credentials" 24 "github.com/aws/aws-sdk-go/aws/session" 25 "github.com/aws/aws-sdk-go/service/s3" 26 "github.com/bluesky-social/indigo/api/atproto" 27 "github.com/bluesky-social/indigo/atproto/syntax" 28 "github.com/bluesky-social/indigo/events" 29 "github.com/bluesky-social/indigo/util" 30 "github.com/bluesky-social/indigo/xrpc" 31 "github.com/domodwyer/mailyak/v3" 32 "github.com/go-playground/validator" 33 "github.com/golang-jwt/jwt/v4" 34 "github.com/gorilla/sessions" 35 "github.com/haileyok/cocoon/identity" 36 "github.com/haileyok/cocoon/internal/db" 37 "github.com/haileyok/cocoon/internal/helpers" 38 "github.com/haileyok/cocoon/models" 39 "github.com/haileyok/cocoon/oauth/client_manager" 40 "github.com/haileyok/cocoon/oauth/constants" 41 "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 42 "github.com/haileyok/cocoon/oauth/provider" 43 "github.com/haileyok/cocoon/plc" 44 echo_session "github.com/labstack/echo-contrib/session" 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50) 51 52const ( 53 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 54) 55 56type S3Config struct { 57 BackupsEnabled bool 58 Endpoint string 59 Region string 60 Bucket string 61 AccessKey string 62 SecretKey string 63} 64 65type Server struct { 66 http *http.Client 67 httpd *http.Server 68 mail *mailyak.MailYak 69 mailLk *sync.Mutex 70 echo *echo.Echo 71 db *db.DB 72 plcClient *plc.Client 73 logger *slog.Logger 74 config *config 75 privateKey *ecdsa.PrivateKey 76 repoman *RepoMan 77 oauthProvider *provider.Provider 78 evtman *events.EventManager 79 passport *identity.Passport 80 81 dbName string 82 s3Config *S3Config 83} 84 85type Args struct { 86 Addr string 87 DbName string 88 Logger *slog.Logger 89 Version string 90 Did string 91 Hostname string 92 RotationKeyPath string 93 JwkPath string 94 ContactEmail string 95 Relays []string 96 AdminPassword string 97 98 SmtpUser string 99 SmtpPass string 100 SmtpHost string 101 SmtpPort string 102 SmtpEmail string 103 SmtpName string 104 105 S3Config *S3Config 106 107 SessionSecret string 108} 109 110type config struct { 111 Version string 112 Did string 113 Hostname string 114 ContactEmail string 115 EnforcePeering bool 116 Relays []string 117 AdminPassword string 118 SmtpEmail string 119 SmtpName string 120} 121 122type CustomValidator struct { 123 validator *validator.Validate 124} 125 126type ValidationError struct { 127 error 128 Field string 129 Tag string 130} 131 132func (cv *CustomValidator) Validate(i any) error { 133 if err := cv.validator.Struct(i); err != nil { 134 var validateErrors validator.ValidationErrors 135 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 136 first := validateErrors[0] 137 return ValidationError{ 138 error: err, 139 Field: first.Field(), 140 Tag: first.Tag(), 141 } 142 } 143 144 return err 145 } 146 147 return nil 148} 149 150//go:embed templates/* 151var templateFS embed.FS 152 153//go:embed static/* 154var staticFS embed.FS 155 156type TemplateRenderer struct { 157 templates *template.Template 158 isDev bool 159 templatePath string 160} 161 162func (s *Server) loadTemplates() { 163 absPath, _ := filepath.Abs("server/templates/*.html") 164 if s.config.Version == "dev" { 165 tmpl := template.Must(template.ParseGlob(absPath)) 166 s.echo.Renderer = &TemplateRenderer{ 167 templates: tmpl, 168 isDev: true, 169 templatePath: absPath, 170 } 171 } else { 172 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 173 s.echo.Renderer = &TemplateRenderer{ 174 templates: tmpl, 175 isDev: false, 176 } 177 } 178} 179 180func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 181 if t.isDev { 182 tmpl, err := template.ParseGlob(t.templatePath) 183 if err != nil { 184 return err 185 } 186 t.templates = tmpl 187 } 188 189 if viewContext, isMap := data.(map[string]any); isMap { 190 viewContext["reverse"] = c.Echo().Reverse 191 } 192 193 return t.templates.ExecuteTemplate(w, name, data) 194} 195 196func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 197 return func(e echo.Context) error { 198 username, password, ok := e.Request().BasicAuth() 199 if !ok || username != "admin" || password != s.config.AdminPassword { 200 return helpers.InputError(e, to.StringPtr("Unauthorized")) 201 } 202 203 if err := next(e); err != nil { 204 e.Error(err) 205 } 206 207 return nil 208 } 209} 210 211func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 212 return func(e echo.Context) error { 213 authheader := e.Request().Header.Get("authorization") 214 if authheader == "" { 215 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 216 } 217 218 pts := strings.Split(authheader, " ") 219 if len(pts) != 2 { 220 return helpers.ServerError(e, nil) 221 } 222 223 if pts[0] == "DPoP" { 224 return next(e) 225 } 226 227 tokenstr := pts[1] 228 229 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 230 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 231 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 232 } 233 234 return s.privateKey.Public(), nil 235 }) 236 if err != nil { 237 s.logger.Error("error parsing jwt", "error", err) 238 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 239 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 240 } 241 242 claims, ok := token.Claims.(jwt.MapClaims) 243 if !ok || !token.Valid { 244 return helpers.InputError(e, to.StringPtr("InvalidToken")) 245 } 246 247 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 248 scope := claims["scope"].(string) 249 250 if isRefresh && scope != "com.atproto.refresh" { 251 return helpers.InputError(e, to.StringPtr("InvalidToken")) 252 } else if !isRefresh && scope != "com.atproto.access" { 253 return helpers.InputError(e, to.StringPtr("InvalidToken")) 254 } 255 256 table := "tokens" 257 if isRefresh { 258 table = "refresh_tokens" 259 } 260 261 type Result struct { 262 Found bool 263 } 264 var result Result 265 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 266 if err == gorm.ErrRecordNotFound { 267 return helpers.InputError(e, to.StringPtr("InvalidToken")) 268 } 269 270 s.logger.Error("error getting token from db", "error", err) 271 return helpers.ServerError(e, nil) 272 } 273 274 if !result.Found { 275 return helpers.InputError(e, to.StringPtr("InvalidToken")) 276 } 277 278 exp, ok := claims["exp"].(float64) 279 if !ok { 280 s.logger.Error("error getting iat from token") 281 return helpers.ServerError(e, nil) 282 } 283 284 if exp < float64(time.Now().UTC().Unix()) { 285 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 286 } 287 288 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 289 if err != nil { 290 s.logger.Error("error fetching repo", "error", err) 291 return helpers.ServerError(e, nil) 292 } 293 294 e.Set("repo", repo) 295 e.Set("did", claims["sub"]) 296 e.Set("token", tokenstr) 297 298 if err := next(e); err != nil { 299 e.Error(err) 300 } 301 302 return nil 303 } 304} 305 306func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 307 return func(e echo.Context) error { 308 authheader := e.Request().Header.Get("authorization") 309 if authheader == "" { 310 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 311 } 312 313 pts := strings.Split(authheader, " ") 314 if len(pts) != 2 { 315 return helpers.ServerError(e, nil) 316 } 317 318 if pts[0] != "DPoP" { 319 return next(e) 320 } 321 322 accessToken := pts[1] 323 324 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 325 if err != nil { 326 s.logger.Error("invalid dpop proof", "error", err) 327 return helpers.InputError(e, to.StringPtr(err.Error())) 328 } 329 330 var oauthToken provider.OauthToken 331 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 332 s.logger.Error("error finding access token in db", "error", err) 333 return helpers.InputError(e, nil) 334 } 335 336 if oauthToken.Token == "" { 337 return helpers.InputError(e, to.StringPtr("InvalidToken")) 338 } 339 340 if *oauthToken.Parameters.DpopJkt != proof.JKT { 341 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 342 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 343 } 344 345 if time.Now().After(oauthToken.ExpiresAt) { 346 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 347 } 348 349 repo, err := s.getRepoActorByDid(oauthToken.Sub) 350 if err != nil { 351 s.logger.Error("could not find actor in db", "error", err) 352 return helpers.ServerError(e, nil) 353 } 354 355 nonce := s.oauthProvider.NextNonce() 356 if nonce != "" { 357 e.Response().Header().Set("DPoP-Nonce", nonce) 358 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 359 } 360 361 e.Set("repo", repo) 362 e.Set("did", repo.Repo.Did) 363 e.Set("token", accessToken) 364 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 365 366 return next(e) 367 } 368} 369 370func New(args *Args) (*Server, error) { 371 if args.Addr == "" { 372 return nil, fmt.Errorf("addr must be set") 373 } 374 375 if args.DbName == "" { 376 return nil, fmt.Errorf("db name must be set") 377 } 378 379 if args.Did == "" { 380 return nil, fmt.Errorf("cocoon did must be set") 381 } 382 383 if args.ContactEmail == "" { 384 return nil, fmt.Errorf("cocoon contact email is required") 385 } 386 387 if _, err := syntax.ParseDID(args.Did); err != nil { 388 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 389 } 390 391 if args.Hostname == "" { 392 return nil, fmt.Errorf("cocoon hostname must be set") 393 } 394 395 if args.AdminPassword == "" { 396 return nil, fmt.Errorf("admin password must be set") 397 } 398 399 if args.Logger == nil { 400 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 401 } 402 403 if args.SessionSecret == "" { 404 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 405 } 406 407 e := echo.New() 408 409 e.Pre(middleware.RemoveTrailingSlash()) 410 e.Pre(slogecho.New(args.Logger)) 411 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 412 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 413 AllowOrigins: []string{"*"}, 414 AllowHeaders: []string{"*"}, 415 AllowMethods: []string{"*"}, 416 AllowCredentials: true, 417 MaxAge: 100_000_000, 418 })) 419 420 vdtor := validator.New() 421 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 422 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 423 return false 424 } 425 return true 426 }) 427 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 428 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 429 return false 430 } 431 return true 432 }) 433 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 434 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 435 return false 436 } 437 return true 438 }) 439 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 440 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 441 return false 442 } 443 return true 444 }) 445 446 e.Validator = &CustomValidator{validator: vdtor} 447 448 httpd := &http.Server{ 449 Addr: args.Addr, 450 Handler: e, 451 // shitty defaults but okay for now, needed for import repo 452 ReadTimeout: 5 * time.Minute, 453 WriteTimeout: 5 * time.Minute, 454 IdleTimeout: 5 * time.Minute, 455 } 456 457 gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 458 if err != nil { 459 return nil, err 460 } 461 dbw := db.NewDB(gdb) 462 463 rkbytes, err := os.ReadFile(args.RotationKeyPath) 464 if err != nil { 465 return nil, err 466 } 467 468 h := util.RobustHTTPClient() 469 470 plcClient, err := plc.NewClient(&plc.ClientArgs{ 471 H: h, 472 Service: "https://plc.directory", 473 PdsHostname: args.Hostname, 474 RotationKey: rkbytes, 475 }) 476 if err != nil { 477 return nil, err 478 } 479 480 jwkbytes, err := os.ReadFile(args.JwkPath) 481 if err != nil { 482 return nil, err 483 } 484 485 key, err := helpers.ParseJWKFromBytes(jwkbytes) 486 if err != nil { 487 return nil, err 488 } 489 490 var pkey ecdsa.PrivateKey 491 if err := key.Raw(&pkey); err != nil { 492 return nil, err 493 } 494 495 oauthCli := &http.Client{ 496 Timeout: 10 * time.Second, 497 } 498 499 var nonceSecret []byte 500 maybeSecret, err := os.ReadFile("nonce.secret") 501 if err != nil && !os.IsNotExist(err) { 502 args.Logger.Error("error attempting to read nonce secret", "error", err) 503 } else { 504 nonceSecret = maybeSecret 505 } 506 507 s := &Server{ 508 http: h, 509 httpd: httpd, 510 echo: e, 511 logger: args.Logger, 512 db: dbw, 513 plcClient: plcClient, 514 privateKey: &pkey, 515 config: &config{ 516 Version: args.Version, 517 Did: args.Did, 518 Hostname: args.Hostname, 519 ContactEmail: args.ContactEmail, 520 EnforcePeering: false, 521 Relays: args.Relays, 522 AdminPassword: args.AdminPassword, 523 SmtpName: args.SmtpName, 524 SmtpEmail: args.SmtpEmail, 525 }, 526 evtman: events.NewEventManager(events.NewMemPersister()), 527 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 528 529 dbName: args.DbName, 530 s3Config: args.S3Config, 531 532 oauthProvider: provider.NewProvider(provider.Args{ 533 Hostname: args.Hostname, 534 ClientManagerArgs: client_manager.Args{ 535 Cli: oauthCli, 536 Logger: args.Logger, 537 }, 538 DpopManagerArgs: dpop_manager.Args{ 539 NonceSecret: nonceSecret, 540 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 541 OnNonceSecretCreated: func(newNonce []byte) { 542 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 543 args.Logger.Error("error writing new nonce secret", "error", err) 544 } 545 }, 546 Logger: args.Logger, 547 Hostname: args.Hostname, 548 }, 549 }), 550 } 551 552 s.loadTemplates() 553 554 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 555 556 // TODO: should validate these args 557 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 558 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 559 } else { 560 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 561 mail.From(s.config.SmtpEmail) 562 mail.FromName(s.config.SmtpName) 563 564 s.mail = mail 565 s.mailLk = &sync.Mutex{} 566 } 567 568 return s, nil 569} 570 571func (s *Server) addRoutes() { 572 // static 573 if s.config.Version == "dev" { 574 s.echo.Static("/static", "server/static") 575 } else { 576 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 577 } 578 579 // random stuff 580 s.echo.GET("/", s.handleRoot) 581 s.echo.GET("/xrpc/_health", s.handleHealth) 582 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 583 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 584 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 585 s.echo.GET("/robots.txt", s.handleRobots) 586 587 // public 588 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 589 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 590 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 591 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 592 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 593 594 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 595 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 596 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 597 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 598 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 599 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 600 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 601 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 602 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 603 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 604 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 605 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 606 607 // account 608 s.echo.GET("/account", s.handleAccount) 609 s.echo.POST("/account/revoke", s.handleAccountRevoke) 610 s.echo.GET("/account/signin", s.handleAccountSigninGet) 611 s.echo.POST("/account/signin", s.handleAccountSigninPost) 612 s.echo.GET("/account/signout", s.handleAccountSignout) 613 614 // oauth account 615 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 616 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 617 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 618 619 // oauth authorization 620 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 621 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 622 623 // authed 624 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 625 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 626 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 627 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 628 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 629 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 630 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 631 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 632 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 633 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 634 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 635 636 // repo 637 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 638 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 639 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 640 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 641 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 642 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 643 644 // stupid silly endpoints 645 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 646 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 647 648 // are there any routes that we should be allowing without auth? i dont think so but idk 649 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 650 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 651 652 // admin routes 653 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 654 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 655} 656 657func (s *Server) Serve(ctx context.Context) error { 658 s.addRoutes() 659 660 s.logger.Info("migrating...") 661 662 s.db.AutoMigrate( 663 &models.Actor{}, 664 &models.Repo{}, 665 &models.InviteCode{}, 666 &models.Token{}, 667 &models.RefreshToken{}, 668 &models.Block{}, 669 &models.Record{}, 670 &models.Blob{}, 671 &models.BlobPart{}, 672 &provider.OauthToken{}, 673 &provider.OauthAuthorizationRequest{}, 674 ) 675 676 s.logger.Info("starting cocoon") 677 678 go func() { 679 if err := s.httpd.ListenAndServe(); err != nil { 680 panic(err) 681 } 682 }() 683 684 go s.backupRoutine() 685 686 for _, relay := range s.config.Relays { 687 cli := xrpc.Client{Host: relay} 688 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 689 Hostname: s.config.Hostname, 690 }) 691 } 692 693 <-ctx.Done() 694 695 fmt.Println("shut down") 696 697 return nil 698} 699 700func (s *Server) doBackup() { 701 start := time.Now() 702 703 s.logger.Info("beginning backup to s3...") 704 705 var buf bytes.Buffer 706 if err := func() error { 707 s.logger.Info("reading database bytes...") 708 s.db.Lock() 709 defer s.db.Unlock() 710 711 sf, err := os.Open(s.dbName) 712 if err != nil { 713 return fmt.Errorf("error opening database for backup: %w", err) 714 } 715 defer sf.Close() 716 717 if _, err := io.Copy(&buf, sf); err != nil { 718 return fmt.Errorf("error reading bytes of backup db: %w", err) 719 } 720 721 return nil 722 }(); err != nil { 723 s.logger.Error("error backing up database", "error", err) 724 return 725 } 726 727 if err := func() error { 728 s.logger.Info("sending to s3...") 729 730 currTime := time.Now().Format("2006-01-02_15-04-05") 731 key := "cocoon-backup-" + currTime + ".db" 732 733 config := &aws.Config{ 734 Region: aws.String(s.s3Config.Region), 735 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 736 } 737 738 if s.s3Config.Endpoint != "" { 739 config.Endpoint = aws.String(s.s3Config.Endpoint) 740 config.S3ForcePathStyle = aws.Bool(true) 741 } 742 743 sess, err := session.NewSession(config) 744 if err != nil { 745 return err 746 } 747 748 svc := s3.New(sess) 749 750 if _, err := svc.PutObject(&s3.PutObjectInput{ 751 Bucket: aws.String(s.s3Config.Bucket), 752 Key: aws.String(key), 753 Body: bytes.NewReader(buf.Bytes()), 754 }); err != nil { 755 return fmt.Errorf("error uploading file to s3: %w", err) 756 } 757 758 s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 759 760 return nil 761 }(); err != nil { 762 s.logger.Error("error uploading database backup", "error", err) 763 return 764 } 765 766 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 767} 768 769func (s *Server) backupRoutine() { 770 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 771 return 772 } 773 774 if s.s3Config.Region == "" { 775 s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 776 return 777 } 778 779 if s.s3Config.Bucket == "" { 780 s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 781 return 782 } 783 784 if s.s3Config.AccessKey == "" { 785 s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 786 return 787 } 788 789 if s.s3Config.SecretKey == "" { 790 s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 791 return 792 } 793 794 shouldBackupNow := false 795 lastBackupStr, err := os.ReadFile("last-backup.txt") 796 if err != nil { 797 shouldBackupNow = true 798 } else { 799 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 800 if err != nil { 801 shouldBackupNow = true 802 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 803 shouldBackupNow = true 804 } 805 } 806 807 if shouldBackupNow { 808 go s.doBackup() 809 } 810 811 ticker := time.NewTicker(time.Hour) 812 for range ticker.C { 813 go s.doBackup() 814 } 815}