Live video on the AT Protocol

statedb: migrate server jwks

authored by

Eli Mallon and committed by
Eli Mallon
e73b0984 89e58ff3

+141 -80
-45
pkg/atproto/jwks.go
··· 1 - package atproto 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "os" 7 - 8 - "github.com/lestrrat-go/jwx/v2/jwk" 9 - oauth_helpers "github.com/streamplace/atproto-oauth-golang/helpers" 10 - "stream.place/streamplace/pkg/log" 11 - ) 12 - 13 - func EnsureJWK(ctx context.Context, fPath string) (jwk.Key, error) { 14 - var key jwk.Key 15 - _, err := os.Stat(fPath) 16 - if err == nil { 17 - b, err := os.ReadFile(fPath) 18 - if err != nil { 19 - return nil, err 20 - } 21 - key, err = jwk.ParseKey(b) 22 - if err != nil { 23 - return nil, err 24 - } 25 - } else if os.IsNotExist(err) { 26 - key, err = oauth_helpers.GenerateKey(nil) 27 - if err != nil { 28 - return nil, err 29 - } 30 - 31 - b, err := json.Marshal(key) 32 - if err != nil { 33 - return nil, err 34 - } 35 - 36 - if err := os.WriteFile(fPath, b, 0600); err != nil { 37 - return nil, err 38 - } 39 - log.Log(ctx, "generated JWK", "path", fPath) 40 - } else { 41 - return nil, err 42 - } 43 - 44 - return key, nil 45 - }
+2 -4
pkg/cmd/streamplace.go
··· 325 325 } 326 326 } 327 327 328 - jwkPath := cli.DataFilePath([]string{"jwk.json"}) 329 - jwk, err := atproto.EnsureJWK(ctx, jwkPath) 328 + jwk, err := statefulDB.EnsureJWK(ctx, "jwk") 330 329 if err != nil { 331 330 return err 332 331 } 333 332 cli.JWK = jwk 334 333 335 - accessJWKPath := cli.DataFilePath([]string{"access-jwk.json"}) 336 - accessJWK, err := atproto.EnsureJWK(ctx, accessJWKPath) 334 + accessJWK, err := statefulDB.EnsureJWK(ctx, "access-jwk") 337 335 if err != nil { 338 336 return err 339 337 }
+34
pkg/statedb/config.go
··· 1 + package statedb 2 + 3 + import ( 4 + "errors" 5 + "time" 6 + 7 + "gorm.io/gorm" 8 + ) 9 + 10 + type Config struct { 11 + Key string `gorm:"column:key;primarykey"` 12 + Value []byte `gorm:"column:value"` 13 + CreatedAt time.Time `gorm:"column:created_at"` 14 + UpdatedAt time.Time `gorm:"column:updated_at"` 15 + } 16 + 17 + func (state *StatefulDB) GetConfig(key string) (*Config, error) { 18 + var config Config 19 + if err := state.DB.Where("key = ?", key).First(&config).Error; err != nil { 20 + if errors.Is(err, gorm.ErrRecordNotFound) { 21 + return nil, nil 22 + } 23 + return nil, err 24 + } 25 + return &config, nil 26 + } 27 + 28 + func (state *StatefulDB) PutConfig(key string, value []byte) error { 29 + config := Config{ 30 + Key: key, 31 + Value: value, 32 + } 33 + return state.DB.Save(&config).Error 34 + }
+73
pkg/statedb/jwks.go
··· 1 + package statedb 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "fmt" 7 + "os" 8 + 9 + "github.com/lestrrat-go/jwx/v2/jwk" 10 + oauth_helpers "github.com/streamplace/atproto-oauth-golang/helpers" 11 + "stream.place/streamplace/pkg/log" 12 + ) 13 + 14 + func (state *StatefulDB) EnsureJWK(ctx context.Context, name string) (jwk.Key, error) { 15 + var key jwk.Key 16 + 17 + conf, err := state.GetConfig(name) 18 + if err != nil { 19 + return nil, err 20 + } 21 + 22 + // happy path: we found the jwk in the database, use that 23 + if conf != nil { 24 + key, err = jwk.ParseKey(conf.Value) 25 + if err != nil { 26 + return nil, err 27 + } 28 + return key, nil 29 + } 30 + 31 + // migration path: maybe we have an old one on disk. 32 + key, _ = state.getOldJWK(ctx, name) 33 + 34 + // new path: found neither, generate a new one 35 + if key == nil { 36 + log.Warn(ctx, "no JWK found, generating new one", "name", name) 37 + key, err = oauth_helpers.GenerateKey(nil) 38 + if err != nil { 39 + return nil, fmt.Errorf("failed to generate JWK: %w", err) 40 + } 41 + } 42 + 43 + b, err := json.Marshal(key) 44 + if err != nil { 45 + return nil, fmt.Errorf("failed to marshal JWK: %w", err) 46 + } 47 + err = state.PutConfig(name, b) 48 + if err != nil { 49 + return nil, fmt.Errorf("failed to save JWK: %w", err) 50 + } 51 + 52 + return key, nil 53 + } 54 + 55 + // migration for the old one we stored on disk 56 + func (state *StatefulDB) getOldJWK(ctx context.Context, name string) (jwk.Key, error) { 57 + var key jwk.Key 58 + jwkPath := state.CLI.DataFilePath([]string{name + ".json"}) 59 + _, err := os.Stat(jwkPath) 60 + if err == nil { 61 + b, err := os.ReadFile(jwkPath) 62 + if err != nil { 63 + return nil, err 64 + } 65 + key, err = jwk.ParseKey(b) 66 + if err != nil { 67 + return nil, err 68 + } 69 + log.Warn(ctx, "found old JWK on disk, migrating to stateful database", "path", jwkPath) 70 + return key, nil 71 + } 72 + return nil, nil 73 + }
+13 -16
pkg/statedb/notification.go
··· 3 3 import ( 4 4 "fmt" 5 5 "time" 6 - 7 - "gorm.io/gorm" 8 6 ) 9 7 10 8 type Notification struct { 11 - Token string `gorm:"primarykey"` 12 - RepoDID string `json:"repoDID,omitempty" gorm:"column:repo_did;index"` 13 - CreatedAt time.Time 14 - UpdatedAt time.Time 15 - DeletedAt gorm.DeletedAt `gorm:"index"` 9 + Token string `gorm:"column:token;primarykey"` 10 + RepoDID string `json:"repoDID,omitempty" gorm:"column:repo_did;index"` 11 + CreatedAt time.Time `gorm:"column:created_at"` 12 + UpdatedAt time.Time `gorm:"column:updated_at"` 16 13 } 17 14 18 - func (db *StatefulDB) CreateNotification(token string, repoDID string) error { 15 + func (state *StatefulDB) CreateNotification(token string, repoDID string) error { 19 16 not := Notification{ 20 17 Token: token, 21 18 } 22 19 if repoDID != "" { 23 20 not.RepoDID = repoDID 24 21 } 25 - err := db.DB.Save(&not).Error 22 + err := state.DB.Save(&not).Error 26 23 if err != nil { 27 24 return err 28 25 } 29 26 return nil 30 27 } 31 28 32 - func (db *StatefulDB) ListNotifications() ([]Notification, error) { 29 + func (state *StatefulDB) ListNotifications() ([]Notification, error) { 33 30 nots := []Notification{} 34 - err := db.DB.Find(&nots).Error 31 + err := state.DB.Find(&nots).Error 35 32 if err != nil { 36 33 return nil, fmt.Errorf("error retrieving notifications: %w", err) 37 34 } 38 35 return nots, nil 39 36 } 40 37 41 - func (db *StatefulDB) ListUserNotifications(userDID string) ([]Notification, error) { 38 + func (state *StatefulDB) ListUserNotifications(userDID string) ([]Notification, error) { 42 39 nots := []Notification{} 43 - err := db.DB.Where("repo_did = ?", userDID).Find(&nots).Error 40 + err := state.DB.Where("repo_did = ?", userDID).Find(&nots).Error 44 41 if err != nil { 45 42 return nil, fmt.Errorf("error retrieving notifications: %w", err) 46 43 } ··· 48 45 } 49 46 50 47 // todo fixme we don't have followers in this database 51 - func (db *StatefulDB) GetFollowersNotificationTokens(userDID string) ([]string, error) { 48 + func (state *StatefulDB) GetFollowersNotificationTokens(userDID string) ([]string, error) { 52 49 var tokens []string 53 50 54 - err := db.DB.Model(&Notification{}). 51 + err := state.DB.Model(&Notification{}). 55 52 Distinct("notifications.token"). 56 53 Joins("JOIN follows ON follows.user_did = notifications.repo_did"). 57 54 Where("follows.subject_did = ?", userDID). ··· 63 60 } 64 61 65 62 // also you prolly wanna get one for yourself 66 - nots, err := db.ListUserNotifications(userDID) 63 + nots, err := state.ListUserNotifications(userDID) 67 64 if err != nil { 68 65 return nil, fmt.Errorf("error retrieving user notifications: %w", err) 69 66 }
+10 -10
pkg/statedb/oauth_session.go
··· 7 7 "gorm.io/gorm" 8 8 ) 9 9 10 - func (db *StatefulDB) CreateOAuthSession(id string, session *oatproxy.OAuthSession) error { 11 - return db.DB.Create(session).Error 10 + func (state *StatefulDB) CreateOAuthSession(id string, session *oatproxy.OAuthSession) error { 11 + return state.DB.Create(session).Error 12 12 } 13 13 14 - func (db *StatefulDB) LoadOAuthSession(id string) (*oatproxy.OAuthSession, error) { 14 + func (state *StatefulDB) LoadOAuthSession(id string) (*oatproxy.OAuthSession, error) { 15 15 var session oatproxy.OAuthSession 16 - if err := db.DB.Where("downstream_dpop_jkt = ?", id).First(&session).Error; err != nil { 16 + if err := state.DB.Where("downstream_dpop_jkt = ?", id).First(&session).Error; err != nil { 17 17 if errors.Is(err, gorm.ErrRecordNotFound) { 18 18 return nil, nil 19 19 } ··· 22 22 return &session, nil 23 23 } 24 24 25 - func (db *StatefulDB) UpdateOAuthSession(id string, session *oatproxy.OAuthSession) error { 26 - res := db.DB.Model(&oatproxy.OAuthSession{}).Where("downstream_dpop_jkt = ?", id).Updates(session) 25 + func (state *StatefulDB) UpdateOAuthSession(id string, session *oatproxy.OAuthSession) error { 26 + res := state.DB.Model(&oatproxy.OAuthSession{}).Where("downstream_dpop_jkt = ?", id).Updates(session) 27 27 if res.Error != nil { 28 28 return res.Error 29 29 } ··· 33 33 return nil 34 34 } 35 35 36 - func (db *StatefulDB) ListOAuthSessions() ([]oatproxy.OAuthSession, error) { 36 + func (state *StatefulDB) ListOAuthSessions() ([]oatproxy.OAuthSession, error) { 37 37 var sessions []oatproxy.OAuthSession 38 - if err := db.DB.Find(&sessions).Error; err != nil { 38 + if err := state.DB.Find(&sessions).Error; err != nil { 39 39 return nil, err 40 40 } 41 41 return sessions, nil 42 42 } 43 43 44 - func (db *StatefulDB) GetSessionByDID(did string) (*oatproxy.OAuthSession, error) { 44 + func (state *StatefulDB) GetSessionByDID(did string) (*oatproxy.OAuthSession, error) { 45 45 var session oatproxy.OAuthSession 46 - if err := db.DB.Where("repo_did = ? AND revoked_at IS NULL", did).Order("updated_at DESC").First(&session).Error; err != nil { 46 + if err := state.DB.Where("repo_did = ? AND revoked_at IS NULL", did).Order("updated_at DESC").First(&session).Error; err != nil { 47 47 return nil, err 48 48 } 49 49 return &session, nil
+9 -5
pkg/statedb/statedb.go
··· 23 23 CLI *config.CLI 24 24 } 25 25 26 + // list tables here so we can migrate them 27 + var StatefulDBModels = []any{ 28 + oatproxy.OAuthSession{}, 29 + Notification{}, 30 + Config{}, 31 + } 32 + 26 33 var NoPostgresDatabaseCode = "3D000" 27 34 28 35 // Stateful database for storing private streamplace state ··· 68 75 } 69 76 sqlDB.SetMaxOpenConns(1) 70 77 } 71 - for _, model := range []any{ 72 - oatproxy.OAuthSession{}, 73 - Notification{}, 74 - } { 78 + for _, model := range StatefulDBModels { 75 79 err = db.AutoMigrate(model) 76 80 if err != nil { 77 81 return nil, err 78 82 } 79 83 } 80 - return &StatefulDB{DB: db}, nil 84 + return &StatefulDB{DB: db, CLI: cli}, nil 81 85 } 82 86 83 87 func openDB(dial gorm.Dialector) (*gorm.DB, error) {