An atproto PDS written in Go
at main 243 lines 5.6 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "sync" 8 "time" 9 10 "github.com/bluesky-social/indigo/api/atproto" 11 "github.com/bluesky-social/indigo/events" 12 indigomodels "github.com/bluesky-social/indigo/models" 13 cbg "github.com/whyrusleeping/cbor-gen" 14 "gorm.io/gorm" 15 16 "github.com/haileyok/cocoon/models" 17) 18 19type DbPersister struct { 20 Db *gorm.DB 21 22 Lk sync.Mutex 23 Seq int64 24 25 Broadcast func(*events.XRPCStreamEvent) 26 27 // how long do we actually want to keep these things around 28 Retention time.Duration 29} 30 31func NewDbPersister(db *gorm.DB, retention time.Duration) (*DbPersister, error) { 32 if err := db.AutoMigrate(&models.EventRecord{}); err != nil { 33 return nil, fmt.Errorf("failed to migrate EventRecord: %w", err) 34 } 35 36 if retention == 0 { 37 retention = 72 * time.Hour 38 } 39 40 p := &DbPersister{ 41 Db: db, 42 Retention: retention, 43 } 44 45 // kind of hacky. we will try and get the latest one from the db, but if it doesn't exist...well we have a problem 46 // because the relay will already have _some_ value > 0 set as a cursor, we'll want to just set this to some high value 47 // we'll just grab a current unix timestamp and set that as the cursor 48 var lastEvent models.EventRecord 49 if err := db.Order("seq desc").Limit(1).First(&lastEvent).Error; err != nil { 50 if err != gorm.ErrRecordNotFound { 51 return nil, fmt.Errorf("failed to get last event seq: %w", err) 52 } 53 p.Seq = time.Now().Unix() 54 } else { 55 p.Seq = lastEvent.Seq 56 } 57 58 go p.cleanupRoutine() 59 60 return p, nil 61} 62 63func (p *DbPersister) SetEventBroadcaster(brc func(*events.XRPCStreamEvent)) { 64 p.Broadcast = brc 65} 66 67func (p *DbPersister) Persist(ctx context.Context, e *events.XRPCStreamEvent) error { 68 p.Lk.Lock() 69 defer p.Lk.Unlock() 70 71 p.Seq++ 72 seq := p.Seq 73 74 var did string 75 var evtType string 76 77 switch { 78 case e.RepoCommit != nil: 79 e.RepoCommit.Seq = seq 80 did = e.RepoCommit.Repo 81 evtType = "commit" 82 case e.RepoSync != nil: 83 e.RepoSync.Seq = seq 84 did = e.RepoSync.Did 85 evtType = "sync" 86 case e.RepoIdentity != nil: 87 e.RepoIdentity.Seq = seq 88 did = e.RepoIdentity.Did 89 evtType = "identity" 90 case e.RepoAccount != nil: 91 e.RepoAccount.Seq = seq 92 did = e.RepoAccount.Did 93 evtType = "account" 94 default: 95 return fmt.Errorf("unknown event type") 96 } 97 98 data, err := serializeEvent(e) 99 if err != nil { 100 return fmt.Errorf("failed to serialize event: %w", err) 101 } 102 103 rec := &models.EventRecord{ 104 Seq: seq, 105 CreatedAt: time.Now(), 106 Did: did, 107 Type: evtType, 108 Data: data, 109 } 110 111 if err := p.Db.Create(rec).Error; err != nil { 112 return fmt.Errorf("failed to persist event: %w", err) 113 } 114 115 if p.Broadcast != nil { 116 p.Broadcast(e) 117 } 118 119 return nil 120} 121 122func (p *DbPersister) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error { 123 const pageSize = 500 124 125 cursor := since 126 for { 127 var records []models.EventRecord 128 if err := p.Db.WithContext(ctx). 129 Where("seq > ?", cursor). 130 Order("seq asc"). 131 Limit(pageSize). 132 Find(&records).Error; err != nil { 133 return fmt.Errorf("failed to query events: %w", err) 134 } 135 136 if len(records) == 0 { 137 return nil 138 } 139 140 for _, rec := range records { 141 evt, err := deserializeEvent(rec.Type, rec.Data) 142 if err != nil { 143 return fmt.Errorf("failed to deserialize event %d: %w", rec.Seq, err) 144 } 145 146 if err := cb(evt); err != nil { 147 return err 148 } 149 150 cursor = rec.Seq 151 } 152 153 if len(records) < pageSize { 154 return nil 155 } 156 } 157} 158 159func (p *DbPersister) TakeDownRepo(ctx context.Context, uid indigomodels.Uid) error { 160 return nil 161} 162 163func (p *DbPersister) Flush(ctx context.Context) error { 164 return nil 165} 166 167func (p *DbPersister) Shutdown(ctx context.Context) error { 168 return nil 169} 170 171func (p *DbPersister) cleanupRoutine() { 172 ticker := time.NewTicker(time.Hour) 173 defer ticker.Stop() 174 175 for range ticker.C { 176 cutoff := time.Now().Add(-p.Retention) 177 if err := p.Db.Where("created_at < ?", cutoff).Delete(&models.EventRecord{}).Error; err != nil { 178 continue 179 } 180 } 181} 182 183func serializeEvent(e *events.XRPCStreamEvent) ([]byte, error) { 184 buf := new(bytes.Buffer) 185 cw := cbg.NewCborWriter(buf) 186 187 switch { 188 case e.RepoCommit != nil: 189 if err := e.RepoCommit.MarshalCBOR(cw); err != nil { 190 return nil, err 191 } 192 case e.RepoSync != nil: 193 if err := e.RepoSync.MarshalCBOR(cw); err != nil { 194 return nil, err 195 } 196 case e.RepoIdentity != nil: 197 if err := e.RepoIdentity.MarshalCBOR(cw); err != nil { 198 return nil, err 199 } 200 case e.RepoAccount != nil: 201 if err := e.RepoAccount.MarshalCBOR(cw); err != nil { 202 return nil, err 203 } 204 default: 205 return nil, fmt.Errorf("unknown event type") 206 } 207 208 return buf.Bytes(), nil 209} 210 211func deserializeEvent(evtType string, data []byte) (*events.XRPCStreamEvent, error) { 212 r := bytes.NewReader(data) 213 cr := cbg.NewCborReader(r) 214 215 switch evtType { 216 case "commit": 217 evt := &atproto.SyncSubscribeRepos_Commit{} 218 if err := evt.UnmarshalCBOR(cr); err != nil { 219 return nil, err 220 } 221 return &events.XRPCStreamEvent{RepoCommit: evt}, nil 222 case "sync": 223 evt := &atproto.SyncSubscribeRepos_Sync{} 224 if err := evt.UnmarshalCBOR(cr); err != nil { 225 return nil, err 226 } 227 return &events.XRPCStreamEvent{RepoSync: evt}, nil 228 case "identity": 229 evt := &atproto.SyncSubscribeRepos_Identity{} 230 if err := evt.UnmarshalCBOR(cr); err != nil { 231 return nil, err 232 } 233 return &events.XRPCStreamEvent{RepoIdentity: evt}, nil 234 case "account": 235 evt := &atproto.SyncSubscribeRepos_Account{} 236 if err := evt.UnmarshalCBOR(cr); err != nil { 237 return nil, err 238 } 239 return &events.XRPCStreamEvent{RepoAccount: evt}, nil 240 default: 241 return nil, fmt.Errorf("unknown event type: %s", evtType) 242 } 243}