this repo has no description
1package server 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "errors" 7 "fmt" 8 "log/slog" 9 "net/http" 10 "net/smtp" 11 "os" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/Azure/go-autorest/autorest/to" 17 "github.com/bluesky-social/indigo/api/atproto" 18 "github.com/bluesky-social/indigo/atproto/syntax" 19 "github.com/bluesky-social/indigo/events" 20 "github.com/bluesky-social/indigo/util" 21 "github.com/bluesky-social/indigo/xrpc" 22 "github.com/domodwyer/mailyak/v3" 23 "github.com/go-playground/validator" 24 "github.com/golang-jwt/jwt/v4" 25 "github.com/haileyok/cocoon/identity" 26 "github.com/haileyok/cocoon/internal/helpers" 27 "github.com/haileyok/cocoon/models" 28 "github.com/haileyok/cocoon/plc" 29 "github.com/labstack/echo/v4" 30 "github.com/labstack/echo/v4/middleware" 31 "github.com/lestrrat-go/jwx/v2/jwk" 32 slogecho "github.com/samber/slog-echo" 33 "gorm.io/driver/sqlite" 34 "gorm.io/gorm" 35) 36 37type Server struct { 38 http *http.Client 39 httpd *http.Server 40 mail *mailyak.MailYak 41 mailLk *sync.Mutex 42 echo *echo.Echo 43 db *gorm.DB 44 plcClient *plc.Client 45 logger *slog.Logger 46 config *config 47 privateKey *ecdsa.PrivateKey 48 repoman *RepoMan 49 evtman *events.EventManager 50 passport *identity.Passport 51} 52 53type Args struct { 54 Addr string 55 DbName string 56 Logger *slog.Logger 57 Version string 58 Did string 59 Hostname string 60 RotationKeyPath string 61 JwkPath string 62 ContactEmail string 63 Relays []string 64 65 SmtpUser string 66 SmtpPass string 67 SmtpHost string 68 SmtpPort string 69 SmtpEmail string 70 SmtpName string 71} 72 73type config struct { 74 Version string 75 Did string 76 Hostname string 77 ContactEmail string 78 EnforcePeering bool 79 Relays []string 80 SmtpEmail string 81 SmtpName string 82} 83 84type CustomValidator struct { 85 validator *validator.Validate 86} 87 88type ValidationError struct { 89 error 90 Field string 91 Tag string 92} 93 94func (cv *CustomValidator) Validate(i any) error { 95 if err := cv.validator.Struct(i); err != nil { 96 var validateErrors validator.ValidationErrors 97 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 98 first := validateErrors[0] 99 return ValidationError{ 100 error: err, 101 Field: first.Field(), 102 Tag: first.Tag(), 103 } 104 } 105 106 return err 107 } 108 109 return nil 110} 111 112func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 113 return func(e echo.Context) error { 114 authheader := e.Request().Header.Get("authorization") 115 if authheader == "" { 116 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 117 } 118 119 pts := strings.Split(authheader, " ") 120 if len(pts) != 2 { 121 return helpers.ServerError(e, nil) 122 } 123 124 tokenstr := pts[1] 125 126 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 127 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 128 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 129 } 130 131 return s.privateKey.Public(), nil 132 }) 133 if err != nil { 134 s.logger.Error("error parsing jwt", "error", err) 135 // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 136 return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 137 } 138 139 claims, ok := token.Claims.(jwt.MapClaims) 140 if !ok || !token.Valid { 141 return helpers.InputError(e, to.StringPtr("InvalidToken")) 142 } 143 144 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 145 scope := claims["scope"].(string) 146 147 if isRefresh && scope != "com.atproto.refresh" { 148 return helpers.InputError(e, to.StringPtr("InvalidToken")) 149 } else if !isRefresh && scope != "com.atproto.access" { 150 return helpers.InputError(e, to.StringPtr("InvalidToken")) 151 } 152 153 table := "tokens" 154 if isRefresh { 155 table = "refresh_tokens" 156 } 157 158 type Result struct { 159 Found bool 160 } 161 var result Result 162 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil { 163 if err == gorm.ErrRecordNotFound { 164 return helpers.InputError(e, to.StringPtr("InvalidToken")) 165 } 166 167 s.logger.Error("error getting token from db", "error", err) 168 return helpers.ServerError(e, nil) 169 } 170 171 if !result.Found { 172 return helpers.InputError(e, to.StringPtr("InvalidToken")) 173 } 174 175 exp, ok := claims["exp"].(float64) 176 if !ok { 177 s.logger.Error("error getting iat from token") 178 return helpers.ServerError(e, nil) 179 } 180 181 if exp < float64(time.Now().UTC().Unix()) { 182 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 183 } 184 185 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 186 if err != nil { 187 s.logger.Error("error fetching repo", "error", err) 188 return helpers.ServerError(e, nil) 189 } 190 191 e.Set("repo", repo) 192 e.Set("did", claims["sub"]) 193 e.Set("token", tokenstr) 194 195 if err := next(e); err != nil { 196 e.Error(err) 197 } 198 199 return nil 200 } 201} 202 203func New(args *Args) (*Server, error) { 204 if args.Addr == "" { 205 return nil, fmt.Errorf("addr must be set") 206 } 207 208 if args.DbName == "" { 209 return nil, fmt.Errorf("db name must be set") 210 } 211 212 if args.Did == "" { 213 return nil, fmt.Errorf("cocoon did must be set") 214 } 215 216 if args.ContactEmail == "" { 217 return nil, fmt.Errorf("cocoon contact email is required") 218 } 219 220 if _, err := syntax.ParseDID(args.Did); err != nil { 221 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 222 } 223 224 if args.Hostname == "" { 225 return nil, fmt.Errorf("cocoon hostname must be set") 226 } 227 228 if args.Logger == nil { 229 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 230 } 231 232 e := echo.New() 233 234 e.Pre(middleware.RemoveTrailingSlash()) 235 e.Pre(slogecho.New(args.Logger)) 236 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 237 AllowOrigins: []string{"*"}, 238 AllowHeaders: []string{"*"}, 239 AllowMethods: []string{"*"}, 240 AllowCredentials: true, 241 MaxAge: 100_000_000, 242 })) 243 244 vdtor := validator.New() 245 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 246 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 247 return false 248 } 249 return true 250 }) 251 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 252 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 253 return false 254 } 255 return true 256 }) 257 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 258 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 259 return false 260 } 261 return true 262 }) 263 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 264 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 265 return false 266 } 267 return true 268 }) 269 270 e.Validator = &CustomValidator{validator: vdtor} 271 272 httpd := &http.Server{ 273 Addr: args.Addr, 274 Handler: e, 275 } 276 277 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 278 if err != nil { 279 return nil, err 280 } 281 282 rkbytes, err := os.ReadFile(args.RotationKeyPath) 283 if err != nil { 284 return nil, err 285 } 286 287 h := util.RobustHTTPClient() 288 289 plcClient, err := plc.NewClient(&plc.ClientArgs{ 290 H: h, 291 Service: "https://plc.directory", 292 PdsHostname: args.Hostname, 293 RotationKey: rkbytes, 294 }) 295 if err != nil { 296 return nil, err 297 } 298 299 jwkbytes, err := os.ReadFile(args.JwkPath) 300 if err != nil { 301 return nil, err 302 } 303 304 key, err := jwk.ParseKey(jwkbytes) 305 if err != nil { 306 return nil, err 307 } 308 309 var pkey ecdsa.PrivateKey 310 if err := key.Raw(&pkey); err != nil { 311 return nil, err 312 } 313 314 s := &Server{ 315 http: h, 316 httpd: httpd, 317 echo: e, 318 logger: args.Logger, 319 db: db, 320 plcClient: plcClient, 321 privateKey: &pkey, 322 config: &config{ 323 Version: args.Version, 324 Did: args.Did, 325 Hostname: args.Hostname, 326 ContactEmail: args.ContactEmail, 327 EnforcePeering: false, 328 Relays: args.Relays, 329 SmtpName: args.SmtpName, 330 SmtpEmail: args.SmtpEmail, 331 }, 332 evtman: events.NewEventManager(events.NewMemPersister()), 333 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 334 } 335 336 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 337 338 // TODO: should validate these args 339 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 340 args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 341 } else { 342 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 343 mail.From(s.config.SmtpEmail) 344 mail.FromName(s.config.SmtpName) 345 346 s.mail = mail 347 s.mailLk = &sync.Mutex{} 348 } 349 350 return s, nil 351} 352 353func (s *Server) addRoutes() { 354 // random stuff 355 s.echo.GET("/", s.handleRoot) 356 s.echo.GET("/xrpc/_health", s.handleHealth) 357 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 358 s.echo.GET("/robots.txt", s.handleRobots) 359 360 // public 361 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 362 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 363 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 364 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 365 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 366 367 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 368 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 369 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 370 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 371 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 372 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 373 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 374 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 375 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 376 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 377 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 378 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 379 380 // authed 381 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware) 382 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware) 383 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware) 384 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware) 385 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleSessionMiddleware) 386 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleSessionMiddleware) 387 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 388 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleSessionMiddleware) 389 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleSessionMiddleware) 390 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleSessionMiddleware) 391 392 // repo 393 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware) 394 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware) 395 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleSessionMiddleware) 396 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware) 397 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware) 398 399 // stupid silly endpoints 400 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware) 401 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware) 402 403 // are there any routes that we should be allowing without auth? i dont think so but idk 404 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 405 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 406} 407 408func (s *Server) Serve(ctx context.Context) error { 409 s.addRoutes() 410 411 s.logger.Info("migrating...") 412 413 s.db.AutoMigrate( 414 &models.Actor{}, 415 &models.Repo{}, 416 &models.InviteCode{}, 417 &models.Token{}, 418 &models.RefreshToken{}, 419 &models.Block{}, 420 &models.Record{}, 421 &models.Blob{}, 422 &models.BlobPart{}, 423 ) 424 425 s.logger.Info("starting cocoon") 426 427 go func() { 428 if err := s.httpd.ListenAndServe(); err != nil { 429 panic(err) 430 } 431 }() 432 433 for _, relay := range s.config.Relays { 434 cli := xrpc.Client{Host: relay} 435 atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 436 Hostname: s.config.Hostname, 437 }) 438 } 439 440 <-ctx.Done() 441 442 fmt.Println("shut down") 443 444 return nil 445}