package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/api/atproto"
	atdata "github.com/bluesky-social/indigo/atproto/data"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/carstore"
	"github.com/bluesky-social/indigo/events"
	lexutil "github.com/bluesky-social/indigo/lex/util"
	"github.com/bluesky-social/indigo/repo"
	"github.com/haileyok/cocoon/internal/db"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/recording_blockstore"
	blocks "github.com/ipfs/go-block-format"
	"github.com/ipfs/go-cid"
	cbor "github.com/ipfs/go-ipld-cbor"
	"github.com/ipld/go-car"
	"gorm.io/gorm/clause"
)

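// RepoMan manages a user's repository: applying writes, producing merkle
// proofs for records, and keeping blob reference counts in sync.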
type RepoMan struct {
	db    *db.DB
	s     *Server
	clock *syntax.TIDClock
}

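// NewRepoMan constructs a RepoMan bound to the server's database and a fresh TID clock.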
func NewRepoMan(s *Server) *RepoMan {
	clock := syntax.NewTIDClock(0)

	return &RepoMan{
		s:     s,
		db:    s.db,
		clock: &clock,
	}
}

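// OpType identifies the kind of write being applied. The values match the
// $type strings of the com.atproto.repo.applyWrites input union.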
type OpType string

var (
	OpTypeCreate = OpType("com.atproto.repo.applyWrites#create")
	OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update")
	OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete")
)

func (ot OpType) String() string {
	return string(ot)
}

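// Op is a single write in an applyWrites batch. Updates and deletes must carry
// an rkey; creates may omit it and have one generated. As a rough, illustrative
// example (field names per this struct, not necessarily the exact wire format),
// a create op might look like:
//
//	{
//	  "$type": "com.atproto.repo.applyWrites#create",
//	  "collection": "app.bsky.feed.post",
//	  "record": {"$type": "app.bsky.feed.post", "text": "hello", "createdAt": "2024-01-01T00:00:00Z"}
//	}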
type Op struct {
	Type       OpType          `json:"$type"`
	Collection string          `json:"collection"`
	Rkey       *string         `json:"rkey,omitempty"`
	Validate   *bool           `json:"validate,omitempty"`
	SwapRecord *string         `json:"swapRecord,omitempty"`
	Record     *MarshalableMap `json:"record,omitempty"`
}

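// MarshalableMap is a schema-less record body that knows how to encode itself as CBOR.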
type MarshalableMap map[string]any

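// FirehoseOp describes a single repo operation by record path, action, and CID.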
type FirehoseOp struct {
	Cid    cid.Cid
	Path   string
	Action string
}

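// MarshalCBOR lets MarshalableMap be passed to repo.PutRecord and repo.UpdateRecord,
// which expect a CBOR marshaler; the actual encoding is delegated to atdata.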
func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
	data, err := atdata.MarshalCBOR(*mm)
	if err != nil {
		return err
	}

	if _, err := w.Write(data); err != nil {
		return err
	}

	return nil
}

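// ApplyWriteResult is the per-op result returned from applyWrites, mirroring
// the com.atproto.repo.applyWrites result union.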
type ApplyWriteResult struct {
	Type             *string     `json:"$type,omitempty"`
	Uri              *string     `json:"uri,omitempty"`
	Cid              *string     `json:"cid,omitempty"`
	Commit           *RepoCommit `json:"commit,omitempty"`
	ValidationStatus *string     `json:"validationStatus,omitempty"`
}

type RepoCommit struct {
	Cid string `json:"cid"`
	Rev string `json:"rev"`
}

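// applyWrites applies a batch of creates, updates, and deletes to the user's repo,
// commits the result, emits a #commit event on the firehose, and keeps the records
// table and blob reference counts in sync.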
// TODO: make use of swapCommit
func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
	rootcid, err := cid.Cast(urepo.Root)
	if err != nil {
		return nil, err
	}

	dbs := rm.s.getBlockstore(urepo.Did)
	bs := recording_blockstore.New(dbs)
	r, err := repo.OpenRepo(ctx, bs, rootcid)
	if err != nil {
		return nil, err
	}

	var results []ApplyWriteResult

	entries := make([]models.Record, 0, len(writes))
	for i, op := range writes {
		// updates and deletes must supply an rkey
		if op.Type != OpTypeCreate && op.Rkey == nil {
			return nil, fmt.Errorf("invalid rkey")
		} else if op.Type == OpTypeCreate && op.Rkey != nil {
			// convert this op to an update if the rkey already exists
			_, _, err := r.GetRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey))
			if err == nil {
				op.Type = OpTypeUpdate
			}
		} else if op.Rkey == nil {
			// creates that don't supply an rkey have one generated for them
			op.Rkey = to.StringPtr(rm.clock.Next().String())
			writes[i].Rkey = op.Rkey
		}

		// validate that the record key is actually valid
		_, err := syntax.ParseRecordKey(*op.Rkey)
		if err != nil {
			return nil, err
		}

		switch op.Type {
		case OpTypeCreate:
			// HACK: this fixes some type conversions, mainly around integers.
			// first we convert the record to json bytes
			b, err := json.Marshal(*op.Record)
			if err != nil {
				return nil, err
			}
			// then we use atdata.UnmarshalJSON to convert it back into a map
			out, err := atdata.UnmarshalJSON(b)
			if err != nil {
				return nil, err
			}
			// finally we can cast it to a MarshalableMap
			mm := MarshalableMap(out)

			// HACK: if the record doesn't contain a $type, set one based on the op's collection
			// (unclear whether this is still necessary)
			if mm["$type"] == "" {
				mm["$type"] = op.Collection
			}

			nc, err := r.PutRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm)
			if err != nil {
				return nil, err
			}

			d, err := atdata.MarshalCBOR(mm)
			if err != nil {
				return nil, err
			}

			entries = append(entries, models.Record{
				Did:       urepo.Did,
				CreatedAt: rm.clock.Next().String(),
				Nsid:      op.Collection,
				Rkey:      *op.Rkey,
				Cid:       nc.String(),
				Value:     d,
			})

			results = append(results, ApplyWriteResult{
				Type:             to.StringPtr(OpTypeCreate.String()),
				Uri:              to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
				Cid:              to.StringPtr(nc.String()),
				ValidationStatus: to.StringPtr("valid"), // TODO: validation isn't actually performed yet, so this may not be true
			})
		case OpTypeDelete:
			// try to find the old record in the database
			var old models.Record
			if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
				return nil, err
			}

			// NOTE: this entry intentionally has no Cid. further down, a missing cid is what
			// signals that the entry is a delete rather than a create/update. that convention
			// is confusing and should eventually be replaced with an explicit field or type.
			entries = append(entries, models.Record{
				Did:   urepo.Did,
				Nsid:  op.Collection,
				Rkey:  *op.Rkey,
				Value: old.Value,
			})

			// delete the record from the repo
			err := r.DeleteRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey))
			if err != nil {
				return nil, err
			}

			// add a result for the delete
			results = append(results, ApplyWriteResult{
				Type: to.StringPtr(OpTypeDelete.String()),
			})
		case OpTypeUpdate:
			// HACK: same json round-trip as above to fix up types
			b, err := json.Marshal(*op.Record)
			if err != nil {
				return nil, err
			}
			out, err := atdata.UnmarshalJSON(b)
			if err != nil {
				return nil, err
			}
			mm := MarshalableMap(out)

			nc, err := r.UpdateRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm)
			if err != nil {
				return nil, err
			}

			d, err := atdata.MarshalCBOR(mm)
			if err != nil {
				return nil, err
			}

			entries = append(entries, models.Record{
				Did:       urepo.Did,
				CreatedAt: rm.clock.Next().String(),
				Nsid:      op.Collection,
				Rkey:      *op.Rkey,
				Cid:       nc.String(),
				Value:     d,
			})

			results = append(results, ApplyWriteResult{
				Type:             to.StringPtr(OpTypeUpdate.String()),
				Uri:              to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
				Cid:              to.StringPtr(nc.String()),
				ValidationStatus: to.StringPtr("valid"), // TODO: validation isn't actually performed yet, so this may not be true
			})
		}
	}

	// commit and get the new root
	newroot, rev, err := r.Commit(ctx, urepo.SignFor)
	if err != nil {
		return nil, err
	}

	// create a buffer to dump the CAR slice for this commit into
	buf := new(bytes.Buffer)

	// first write the car header to the buffer
	hb, err := cbor.DumpObject(&car.CarHeader{
		Roots:   []cid.Cid{newroot},
		Version: 1,
	})
	if err != nil {
		return nil, err
	}
	if _, err := carstore.LdWrite(buf, hb); err != nil {
		return nil, err
	}

	// get a diff of the changes to the repo
	diffops, err := r.DiffSince(ctx, rootcid)
	if err != nil {
		return nil, err
	}

	// create the repo ops for the given diff
	ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops))
	for _, op := range diffops {
		var c cid.Cid
		switch op.Op {
		case "add", "mut":
			kind := "create"
			if op.Op == "mut" {
				kind = "update"
			}

			c = op.NewCid
			ll := lexutil.LexLink(op.NewCid)
			ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
				Action: kind,
				Path:   op.Rpath,
				Cid:    &ll,
			})

		case "del":
			c = op.OldCid
			ll := lexutil.LexLink(op.OldCid)
			ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{
				Action: "delete",
				Path:   op.Rpath,
				Cid:    nil,
				Prev:   &ll,
			})
		}

		blk, err := dbs.Get(ctx, c)
		if err != nil {
			return nil, err
		}

		// write the block to the buffer
		if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
			return nil, err
		}
	}

	// write the writelog to the buffer
	for _, op := range bs.GetWriteLog() {
		if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
			return nil, err
		}
	}

	// persist the records and track which blobs they reference
	var blobs []lexutil.LexLink
	for _, entry := range entries {
		var cids []cid.Cid
		// a non-empty cid means this entry is a create or update
		if entry.Cid != "" {
			if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
				Columns:   []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
				UpdateAll: true,
			}}).Error; err != nil {
				return nil, err
			}

			// increment the refs for any blobs this record points at
			cids, err = rm.incrementBlobRefs(urepo, entry.Value)
			if err != nil {
				return nil, err
			}
		} else {
			// an empty cid signals a delete (see the note in the delete case above). this only
			// works because the primary key is did + nsid + rkey; it would be clearer to use a
			// separate type or an explicit flag for this.
			if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
				return nil, err
			}

			// decrement the refs for any blobs the deleted record pointed at
			cids, err = rm.decrementBlobRefs(urepo, entry.Value)
			if err != nil {
				return nil, err
			}
		}

		// collect the referenced blobs for the firehose event
		for _, c := range cids {
			blobs = append(blobs, lexutil.LexLink(c))
		}
	}

	// NOTE: using the request ctx here seems risky since it's unclear whether this runs
	// synchronously, so use a background context instead
	rm.s.evtman.AddEvent(context.Background(), &events.XRPCStreamEvent{
		RepoCommit: &atproto.SyncSubscribeRepos_Commit{
			Repo:   urepo.Did,
			Blocks: buf.Bytes(),
			Blobs:  blobs,
			Rev:    rev,
			Since:  &urepo.Rev,
			Commit: lexutil.LexLink(newroot),
			Time:   time.Now().Format(time.RFC3339Nano),
			Ops:    ops,
			TooBig: false,
		},
	})

	if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil {
		return nil, err
	}

	// turn each op's $type into its corresponding result type and attach the new commit
	for i := range results {
		results[i].Type = to.StringPtr(*results[i].Type + "Result")
		results[i].Commit = &RepoCommit{
			Cid: newroot.String(),
			Rev: rev,
		}
	}

	return results, nil
}

// getRecordProof builds a merkle proof for a record. to do that, we read the record out of
// the blockstore and record every block we touch along the way: the blockstore is wrapped
// in a recording blockstore, and its read log is returned as the proof.
func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
	c, err := cid.Cast(urepo.Root)
	if err != nil {
		return cid.Undef, nil, err
	}

	dbs := rm.s.getBlockstore(urepo.Did)
	bs := recording_blockstore.New(dbs)

	r, err := repo.OpenRepo(ctx, bs, c)
	if err != nil {
		return cid.Undef, nil, err
	}

	_, _, err = r.GetRecordBytes(ctx, fmt.Sprintf("%s/%s", collection, rkey))
	if err != nil {
		return cid.Undef, nil, err
	}

	return c, bs.GetReadLog(), nil
}

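// incrementBlobRefs bumps the ref_count of every blob referenced by the given record cbor
// and returns the referenced cids.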
func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
	cids, err := getBlobCidsFromCbor(cbor)
	if err != nil {
		return nil, err
	}

	for _, c := range cids {
		if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
			return nil, err
		}
	}

	return cids, nil
}

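// decrementBlobRefs drops the ref_count of every blob referenced by the given record cbor,
// deleting blob rows (and their parts) once the count hits zero, and returns the referenced cids.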
func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
	cids, err := getBlobCidsFromCbor(cbor)
	if err != nil {
		return nil, err
	}

	for _, c := range cids {
		var res struct {
			ID    uint
			Count int
		}
		// alias ref_count as count so it scans into the Count field
		if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count AS count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
			return nil, err
		}

		// TODO: this does _not_ handle deletion of blobs that live in s3 storage! we need to
		// look up which storage backend the blob is in and clean up s3 as well.
		if res.Count == 0 {
			if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
				return nil, err
			}
			if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
				return nil, err
			}
		}
	}

	return cids, nil
}

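// getBlobCidsFromCbor walks a decoded record and collects the cids of any blob references it finds.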
// to be honest, we could just store both the cbor and non-cbor forms in the entries slice
// above to avoid an extra unmarshal here. this works for now though.
func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
	var cids []cid.Cid

	decoded, err := atdata.UnmarshalCBOR(cbor)
	if err != nil {
		return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
	}

	var deepiter func(any) error
	deepiter = func(item any) error {
		switch val := item.(type) {
		case map[string]any:
			if val["$type"] == "blob" {
				if ref, ok := val["ref"].(string); ok {
					c, err := cid.Parse(ref)
					if err != nil {
						return err
					}
					cids = append(cids, c)
				}
			}
			// recurse into every value so nested blobs are found too
			for _, v := range val {
				if err := deepiter(v); err != nil {
					return err
				}
			}
		case []any:
			for _, v := range val {
				if err := deepiter(v); err != nil {
					return err
				}
			}
		}

		return nil
	}

	if err := deepiter(decoded); err != nil {
		return nil, err
	}

	return cids, nil
}