this repo has no description
1package server 2 3import ( 4 "context" 5 "crypto/ecdsa" 6 "errors" 7 "fmt" 8 "log/slog" 9 "net/http" 10 "os" 11 "strings" 12 "time" 13 14 "github.com/Azure/go-autorest/autorest/to" 15 "github.com/bluesky-social/indigo/api/atproto" 16 "github.com/bluesky-social/indigo/atproto/syntax" 17 "github.com/bluesky-social/indigo/events" 18 "github.com/bluesky-social/indigo/util" 19 "github.com/bluesky-social/indigo/xrpc" 20 "github.com/go-playground/validator" 21 "github.com/golang-jwt/jwt/v4" 22 "github.com/haileyok/cocoon/identity" 23 "github.com/haileyok/cocoon/internal/helpers" 24 "github.com/haileyok/cocoon/models" 25 "github.com/haileyok/cocoon/plc" 26 "github.com/labstack/echo/v4" 27 "github.com/labstack/echo/v4/middleware" 28 "github.com/lestrrat-go/jwx/v2/jwk" 29 slogecho "github.com/samber/slog-echo" 30 "gorm.io/driver/sqlite" 31 "gorm.io/gorm" 32) 33 34type Server struct { 35 http *http.Client 36 httpd *http.Server 37 echo *echo.Echo 38 db *gorm.DB 39 plcClient *plc.Client 40 logger *slog.Logger 41 config *config 42 privateKey *ecdsa.PrivateKey 43 repoman *RepoMan 44 evtman *events.EventManager 45 passport *identity.Passport 46} 47 48type Args struct { 49 Addr string 50 DbName string 51 Logger *slog.Logger 52 Version string 53 Did string 54 Hostname string 55 RotationKeyPath string 56 JwkPath string 57 ContactEmail string 58 Relays []string 59} 60 61type config struct { 62 Version string 63 Did string 64 Hostname string 65 ContactEmail string 66 EnforcePeering bool 67 Relays []string 68} 69 70type CustomValidator struct { 71 validator *validator.Validate 72} 73 74type ValidationError struct { 75 error 76 Field string 77 Tag string 78} 79 80func (cv *CustomValidator) Validate(i any) error { 81 if err := cv.validator.Struct(i); err != nil { 82 var validateErrors validator.ValidationErrors 83 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 84 first := validateErrors[0] 85 return ValidationError{ 86 error: err, 87 Field: first.Field(), 88 Tag: first.Tag(), 89 } 90 } 91 92 return err 93 } 94 95 return nil 96} 97 98func (s *Server) handleSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 99 return func(e echo.Context) error { 100 authheader := e.Request().Header.Get("authorization") 101 if authheader == "" { 102 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 103 } 104 105 pts := strings.Split(authheader, " ") 106 if len(pts) != 2 { 107 return helpers.ServerError(e, nil) 108 } 109 110 tokenstr := pts[1] 111 112 token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 113 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 114 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 115 } 116 117 return s.privateKey.Public(), nil 118 }) 119 if err != nil { 120 s.logger.Error("error parsing jwt", "error", err) 121 return helpers.InputError(e, to.StringPtr("InvalidToken")) 122 } 123 124 claims, ok := token.Claims.(jwt.MapClaims) 125 if !ok || !token.Valid { 126 return helpers.InputError(e, to.StringPtr("InvalidToken")) 127 } 128 129 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 130 scope := claims["scope"].(string) 131 132 if isRefresh && scope != "com.atproto.refresh" { 133 return helpers.InputError(e, to.StringPtr("InvalidToken")) 134 } else if !isRefresh && scope != "com.atproto.access" { 135 return helpers.InputError(e, to.StringPtr("InvalidToken")) 136 } 137 138 table := "tokens" 139 if isRefresh { 140 table = "refresh_tokens" 141 } 142 143 type Result struct { 144 Found bool 145 } 146 var result Result 147 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", tokenstr).Scan(&result).Error; err != nil { 148 if err == gorm.ErrRecordNotFound { 149 return helpers.InputError(e, to.StringPtr("InvalidToken")) 150 } 151 152 s.logger.Error("error getting token from db", "error", err) 153 return helpers.ServerError(e, nil) 154 } 155 156 if !result.Found { 157 return helpers.InputError(e, to.StringPtr("InvalidToken")) 158 } 159 160 exp, ok := claims["exp"].(float64) 161 if !ok { 162 s.logger.Error("error getting iat from token") 163 return helpers.ServerError(e, nil) 164 } 165 166 if exp < float64(time.Now().UTC().Unix()) { 167 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 168 } 169 170 e.Set("did", claims["sub"]) 171 172 repo, err := s.getRepoActorByDid(claims["sub"].(string)) 173 if err != nil { 174 s.logger.Error("error fetching repo", "error", err) 175 return helpers.ServerError(e, nil) 176 } 177 e.Set("repo", repo) 178 179 e.Set("token", tokenstr) 180 181 if err := next(e); err != nil { 182 e.Error(err) 183 } 184 185 return nil 186 } 187} 188 189func New(args *Args) (*Server, error) { 190 if args.Addr == "" { 191 return nil, fmt.Errorf("addr must be set") 192 } 193 194 if args.DbName == "" { 195 return nil, fmt.Errorf("db name must be set") 196 } 197 198 if args.Did == "" { 199 return nil, fmt.Errorf("cocoon did must be set") 200 } 201 202 if args.ContactEmail == "" { 203 return nil, fmt.Errorf("cocoon contact email is required") 204 } 205 206 if _, err := syntax.ParseDID(args.Did); err != nil { 207 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 208 } 209 210 if args.Hostname == "" { 211 return nil, fmt.Errorf("cocoon hostname must be set") 212 } 213 214 if args.Logger == nil { 215 args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 216 } 217 218 e := echo.New() 219 220 e.Pre(middleware.RemoveTrailingSlash()) 221 e.Pre(slogecho.New(args.Logger)) 222 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 223 AllowOrigins: []string{"*"}, 224 AllowHeaders: []string{"*"}, 225 AllowMethods: []string{"*"}, 226 AllowCredentials: true, 227 MaxAge: 100_000_000, 228 })) 229 230 vdtor := validator.New() 231 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 232 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 233 return false 234 } 235 return true 236 }) 237 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 238 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 239 return false 240 } 241 return true 242 }) 243 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 244 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 245 return false 246 } 247 return true 248 }) 249 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 250 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 251 return false 252 } 253 return true 254 }) 255 256 e.Validator = &CustomValidator{validator: vdtor} 257 258 httpd := &http.Server{ 259 Addr: args.Addr, 260 Handler: e, 261 } 262 263 db, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 264 if err != nil { 265 return nil, err 266 } 267 268 rkbytes, err := os.ReadFile(args.RotationKeyPath) 269 if err != nil { 270 return nil, err 271 } 272 273 h := util.RobustHTTPClient() 274 275 plcClient, err := plc.NewClient(&plc.ClientArgs{ 276 H: h, 277 Service: "https://plc.directory", 278 PdsHostname: args.Hostname, 279 RotationKey: rkbytes, 280 }) 281 if err != nil { 282 return nil, err 283 } 284 285 jwkbytes, err := os.ReadFile(args.JwkPath) 286 if err != nil { 287 return nil, err 288 } 289 290 key, err := jwk.ParseKey(jwkbytes) 291 if err != nil { 292 return nil, err 293 } 294 295 var pkey ecdsa.PrivateKey 296 if err := key.Raw(&pkey); err != nil { 297 return nil, err 298 } 299 300 s := &Server{ 301 http: h, 302 httpd: httpd, 303 echo: e, 304 logger: args.Logger, 305 db: db, 306 plcClient: plcClient, 307 privateKey: &pkey, 308 config: &config{ 309 Version: args.Version, 310 Did: args.Did, 311 Hostname: args.Hostname, 312 ContactEmail: args.ContactEmail, 313 EnforcePeering: false, 314 Relays: args.Relays, 315 }, 316 evtman: events.NewEventManager(events.NewMemPersister()), 317 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 318 } 319 320 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 321 322 return s, nil 323} 324 325func (s *Server) addRoutes() { 326 s.echo.GET("/", s.handleRoot) 327 s.echo.GET("/xrpc/_health", s.handleHealth) 328 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 329 s.echo.GET("/robots.txt", s.handleRobots) 330 331 // public 332 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 333 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 334 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 335 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 336 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 337 338 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 339 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 340 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 341 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 342 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 343 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 344 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 345 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 346 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 347 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 348 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 349 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 350 351 // authed 352 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleSessionMiddleware) 353 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleSessionMiddleware) 354 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleSessionMiddleware) 355 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleSessionMiddleware) 356 357 // repo 358 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleSessionMiddleware) 359 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleSessionMiddleware) 360 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleSessionMiddleware) 361 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleSessionMiddleware) 362 363 // stupid silly endpoints 364 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleSessionMiddleware) 365 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleSessionMiddleware) 366 367 s.echo.GET("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 368 s.echo.POST("/xrpc/*", s.handleProxy, s.handleSessionMiddleware) 369} 370 371func (s *Server) Serve(ctx context.Context) error { 372 s.addRoutes() 373 374 s.logger.Info("migrating...") 375 376 s.db.AutoMigrate( 377 &models.Actor{}, 378 &models.Repo{}, 379 &models.InviteCode{}, 380 &models.Token{}, 381 &models.RefreshToken{}, 382 &models.Block{}, 383 &models.Record{}, 384 &models.Blob{}, 385 &models.BlobPart{}, 386 ) 387 388 s.logger.Info("starting cocoon") 389 390 go func() { 391 if err := s.httpd.ListenAndServe(); err != nil { 392 panic(err) 393 } 394 }() 395 396 for _, relay := range s.config.Relays { 397 cli := xrpc.Client{Host: relay} 398 atproto.SyncRequestCrawl(context.TODO(), &cli, &atproto.SyncRequestCrawl_Input{ 399 Hostname: s.config.Hostname, 400 }) 401 } 402 403 <-ctx.Done() 404 405 fmt.Println("shut down") 406 407 return nil 408}