forked from
hailey.at/cocoon
An atproto PDS written in Go
1package server
2
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/events"
	indigomodels "github.com/bluesky-social/indigo/models"
	cbg "github.com/whyrusleeping/cbor-gen"
	"gorm.io/gorm"

	"github.com/haileyok/cocoon/models"
)
18
// DbPersister stores firehose events as models.EventRecord rows via GORM and
// assigns each persisted event the next value of a sequence counter.
type DbPersister struct {
	// Db is the database the event rows are written to and read from.
	Db *gorm.DB

	// Lk is held by Persist for its whole increment/store/broadcast sequence,
	// serializing concurrent Persist calls.
	Lk sync.Mutex
	// Seq is the last sequence number handed out; Persist increments it
	// under Lk.
	Seq int64

	// Broadcast, if non-nil, is called by Persist with each event after it
	// has been written to the database.
	Broadcast func(*events.XRPCStreamEvent)

	// Retention is how long events are kept: cleanupRoutine deletes rows whose
	// created_at is older than now minus Retention. NewDbPersister substitutes
	// a default when it is given 0.
	Retention time.Duration
}
30
31func NewDbPersister(db *gorm.DB, retention time.Duration) (*DbPersister, error) {
32 if err := db.AutoMigrate(&models.EventRecord{}); err != nil {
33 return nil, fmt.Errorf("failed to migrate EventRecord: %w", err)
34 }
35
36 if retention == 0 {
37 retention = 72 * time.Hour
38 }
39
40 p := &DbPersister{
41 Db: db,
42 Retention: retention,
43 }
44
45 // kind of hacky. we will try and get the latest one from the db, but if it doesn't exist...well we have a problem
46 // because the relay will already have _some_ value > 0 set as a cursor, we'll want to just set this to some high value
47 // we'll just grab a current unix timestamp and set that as the cursor
48 var lastEvent models.EventRecord
49 if err := db.Order("seq desc").Limit(1).First(&lastEvent).Error; err != nil {
50 if err != gorm.ErrRecordNotFound {
51 return nil, fmt.Errorf("failed to get last event seq: %w", err)
52 }
53 p.Seq = time.Now().Unix()
54 } else {
55 p.Seq = lastEvent.Seq
56 }
57
58 go p.cleanupRoutine()
59
60 return p, nil
61}
62
63func (p *DbPersister) SetEventBroadcaster(brc func(*events.XRPCStreamEvent)) {
64 p.Broadcast = brc
65}
66
67func (p *DbPersister) Persist(ctx context.Context, e *events.XRPCStreamEvent) error {
68 p.Lk.Lock()
69 defer p.Lk.Unlock()
70
71 p.Seq++
72 seq := p.Seq
73
74 var did string
75 var evtType string
76
77 switch {
78 case e.RepoCommit != nil:
79 e.RepoCommit.Seq = seq
80 did = e.RepoCommit.Repo
81 evtType = "commit"
82 case e.RepoSync != nil:
83 e.RepoSync.Seq = seq
84 did = e.RepoSync.Did
85 evtType = "sync"
86 case e.RepoIdentity != nil:
87 e.RepoIdentity.Seq = seq
88 did = e.RepoIdentity.Did
89 evtType = "identity"
90 case e.RepoAccount != nil:
91 e.RepoAccount.Seq = seq
92 did = e.RepoAccount.Did
93 evtType = "account"
94 default:
95 return fmt.Errorf("unknown event type")
96 }
97
98 data, err := serializeEvent(e)
99 if err != nil {
100 return fmt.Errorf("failed to serialize event: %w", err)
101 }
102
103 rec := &models.EventRecord{
104 Seq: seq,
105 CreatedAt: time.Now(),
106 Did: did,
107 Type: evtType,
108 Data: data,
109 }
110
111 if err := p.Db.Create(rec).Error; err != nil {
112 return fmt.Errorf("failed to persist event: %w", err)
113 }
114
115 if p.Broadcast != nil {
116 p.Broadcast(e)
117 }
118
119 return nil
120}
121
122func (p *DbPersister) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error {
123 const pageSize = 500
124
125 cursor := since
126 for {
127 var records []models.EventRecord
128 if err := p.Db.WithContext(ctx).
129 Where("seq > ?", cursor).
130 Order("seq asc").
131 Limit(pageSize).
132 Find(&records).Error; err != nil {
133 return fmt.Errorf("failed to query events: %w", err)
134 }
135
136 if len(records) == 0 {
137 return nil
138 }
139
140 for _, rec := range records {
141 evt, err := deserializeEvent(rec.Type, rec.Data)
142 if err != nil {
143 return fmt.Errorf("failed to deserialize event %d: %w", rec.Seq, err)
144 }
145
146 if err := cb(evt); err != nil {
147 return err
148 }
149
150 cursor = rec.Seq
151 }
152
153 if len(records) < pageSize {
154 return nil
155 }
156 }
157}
158
// TakeDownRepo is a no-op: this persister does not remove or hide any stored
// events for uid, and it always returns nil.
// NOTE(review): presumably present only to satisfy indigo's event persistence
// interface — confirm whether takedowns should purge stored events.
func (p *DbPersister) TakeDownRepo(ctx context.Context, uid indigomodels.Uid) error {
	return nil
}
162
// Flush is a no-op that always returns nil; Persist inserts each event
// directly, so this persister has nothing buffered to write out.
func (p *DbPersister) Flush(ctx context.Context) error {
	return nil
}
166
// Shutdown is a no-op that always returns nil.
// NOTE(review): it does not stop the cleanupRoutine goroutine started by
// NewDbPersister, which keeps running until the process exits.
func (p *DbPersister) Shutdown(ctx context.Context) error {
	return nil
}
170
171func (p *DbPersister) cleanupRoutine() {
172 ticker := time.NewTicker(time.Hour)
173 defer ticker.Stop()
174
175 for range ticker.C {
176 cutoff := time.Now().Add(-p.Retention)
177 if err := p.Db.Where("created_at < ?", cutoff).Delete(&models.EventRecord{}).Error; err != nil {
178 continue
179 }
180 }
181}
182
183func serializeEvent(e *events.XRPCStreamEvent) ([]byte, error) {
184 buf := new(bytes.Buffer)
185 cw := cbg.NewCborWriter(buf)
186
187 switch {
188 case e.RepoCommit != nil:
189 if err := e.RepoCommit.MarshalCBOR(cw); err != nil {
190 return nil, err
191 }
192 case e.RepoSync != nil:
193 if err := e.RepoSync.MarshalCBOR(cw); err != nil {
194 return nil, err
195 }
196 case e.RepoIdentity != nil:
197 if err := e.RepoIdentity.MarshalCBOR(cw); err != nil {
198 return nil, err
199 }
200 case e.RepoAccount != nil:
201 if err := e.RepoAccount.MarshalCBOR(cw); err != nil {
202 return nil, err
203 }
204 default:
205 return nil, fmt.Errorf("unknown event type")
206 }
207
208 return buf.Bytes(), nil
209}
210
211func deserializeEvent(evtType string, data []byte) (*events.XRPCStreamEvent, error) {
212 r := bytes.NewReader(data)
213 cr := cbg.NewCborReader(r)
214
215 switch evtType {
216 case "commit":
217 evt := &atproto.SyncSubscribeRepos_Commit{}
218 if err := evt.UnmarshalCBOR(cr); err != nil {
219 return nil, err
220 }
221 return &events.XRPCStreamEvent{RepoCommit: evt}, nil
222 case "sync":
223 evt := &atproto.SyncSubscribeRepos_Sync{}
224 if err := evt.UnmarshalCBOR(cr); err != nil {
225 return nil, err
226 }
227 return &events.XRPCStreamEvent{RepoSync: evt}, nil
228 case "identity":
229 evt := &atproto.SyncSubscribeRepos_Identity{}
230 if err := evt.UnmarshalCBOR(cr); err != nil {
231 return nil, err
232 }
233 return &events.XRPCStreamEvent{RepoIdentity: evt}, nil
234 case "account":
235 evt := &atproto.SyncSubscribeRepos_Account{}
236 if err := evt.UnmarshalCBOR(cr); err != nil {
237 return nil, err
238 }
239 return &events.XRPCStreamEvent{RepoAccount: evt}, nil
240 default:
241 return nil, fmt.Errorf("unknown event type: %s", evtType)
242 }
243}