// A very experimental PLC implementation which uses BFT consensus for decentralization.
1package store
2
3import (
4 "bufio"
5 "bytes"
6 "crypto/sha256"
7 "encoding/binary"
8 "errors"
9 "fmt"
10 "io"
11 "math"
12 "os"
13 "path/filepath"
14 "slices"
15 "strconv"
16 "strings"
17 "sync"
18 "time"
19
20 cmttypes "github.com/cometbft/cometbft/types"
21 "github.com/cosmos/iavl"
22 "github.com/gbl08ma/stacktrace"
23 "github.com/klauspost/compress/zstd"
24 "github.com/samber/lo"
25 "tangled.org/gbl08ma.com/didplcbft/transaction"
26)
27
// Snapshot format constants
const (
	// SnapshotChunkSize is the fixed size of each snapshot chunk transferred between peers.
	SnapshotChunkSize = 10 * 1024 * 1024 // 10 MB
	// SnapshotChunkHashSize is the length in bytes of a SHA-256 chunk hash.
	SnapshotChunkHashSize = 32
	// SnapshotFormatVersion is the current snapshot file format version.
	SnapshotFormatVersion = 2
	// SnapshotFileMagic identifies snapshot files. It is 18 bytes long; the
	// fixed header offsets used elsewhere in this file depend on that length.
	SnapshotFileMagic = "didplcbft-snapshot"
)
35
// BlockHeaderGetter is a function type for retrieving block headers by block
// height (e.g. from the consensus engine's block store).
type BlockHeaderGetter func(height int64) (cmttypes.Header, error)
38
// Sentinel errors returned by snapshot import/apply operations. Compare with
// errors.Is.
var (
	// ErrMalformedChunk is returned when a snapshot chunk is invalid.
	ErrMalformedChunk = errors.New("malformed chunk")

	// ErrInvalidMetadata is returned when snapshot metadata is invalid.
	ErrInvalidMetadata = errors.New("invalid metadata")

	// ErrTreeHashMismatch is returned when the imported tree hash doesn't match expected.
	ErrTreeHashMismatch = errors.New("tree hash mismatch")

	// ErrIndexEntryCountMismatch is returned when the imported index entry count doesn't match expected.
	ErrIndexEntryCountMismatch = errors.New("index entry count mismatch")
)
50
// SnapshotStore provides snapshot creation and import functionality.
// It is stateless; all methods operate only on the arguments they receive.
type SnapshotStore struct{}

// Snapshot is the shared, stateless SnapshotStore instance.
var Snapshot = &SnapshotStore{}
55
// Export creates a snapshot of the current state and writes it to the provided writer.
// writerSeeker must support seeking: the snapshot header contains size/count
// fields that are back-patched after the compressed body is written.
// blockHeaderGetter supplies the recent block headers embedded in the snapshot,
// updateValidatorsBlockInterval selects the validator epoch to export, and
// numRecentBlockHeaders bounds how many trailing headers are included.
func (s *SnapshotStore) Export(writerSeeker io.WriteSeeker, tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64) error {
	return writeSnapshot(writerSeeker, tx, blockHeaderGetter, updateValidatorsBlockInterval, numRecentBlockHeaders)
}
60
// WriteChunkHashes calculates and writes chunk hashes for a snapshot file.
// The snapshot file is split into SnapshotChunkSize pieces and the SHA-256
// hash of each piece is written to w in order (the ".chunksums" content).
func (s *SnapshotStore) WriteChunkHashes(snapshotFile io.ReadSeeker, w io.Writer) error {
	return writeChunkHashes(snapshotFile, w)
}
65
66// CreateSnapshotApplier creates a new SnapshotApplier for applying a snapshot
67func (s *SnapshotStore) CreateSnapshotApplier(tree *iavl.MutableTree, txFactory *transaction.Factory, treeVersion int64, expectedFinalHash []byte, expectedChunkHashes [][]byte) (*SnapshotApplier, error) {
68 writeTx, err := txFactory.ReadWorking(time.Now()).UpgradeForIndexOnly()
69 if err != nil {
70 return nil, stacktrace.Propagate(err)
71 }
72
73 sa := &SnapshotApplier{
74 writeIndex: writeTx,
75 tree: tree,
76 treeVersion: treeVersion,
77 expectedFinalHash: expectedFinalHash,
78 expectedChunkHashes: expectedChunkHashes,
79 }
80
81 sa.importer, err = sa.tree.Import(treeVersion)
82 if err != nil {
83 return nil, stacktrace.Propagate(err)
84 }
85
86 sa.compressImporter = iavl.NewCompressImporter(sa.importer)
87
88 return sa, nil
89}
90
// writeSnapshot serializes the current state into the snapshot file format:
// a fixed header (magic, format version, tree version, tree hash, plus three
// reserved size/count fields) followed by a single zstd-compressed section
// containing the index entries and then the tree nodes. writerSeeker must
// support seeking so the reserved header fields can be back-patched once the
// compressed section size and the entry/node counts are known.
func writeSnapshot(writerSeeker io.WriteSeeker, tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64) error {
	it, ok := transaction.UnderlyingImmutableTree(tx.Tree())
	if !ok {
		return stacktrace.NewError("expected immutable tree")
	}

	// bytes written before the reserved fields; used later to seek back to them
	writtenUntilReservedFields := 0

	bw := bufio.NewWriter(writerSeeker)

	// file magic and version
	c, err := bw.Write([]byte(SnapshotFileMagic))
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	// 5 zero bytes of reserved space followed by the 1-byte format version
	c, err = bw.Write([]byte{0, 0, 0, 0, 0, byte(SnapshotFormatVersion)})
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	// tree version, 8 bytes big-endian
	b := make([]byte, 8)
	binary.BigEndian.PutUint64(b, uint64(it.Version()))
	c, err = bw.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	// tree hash (32 bytes)
	c, err = bw.Write(it.Hash())
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	// reserve space for writing:
	// - 8 bytes for compressed section size in bytes
	// - 8 bytes for number of index entries
	// - 8 bytes for number of nodes
	sizeOfReservedFields := 8 * 3
	b = make([]byte, sizeOfReservedFields)
	_, err = bw.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	// everything after the header is a single zstd stream
	zstdw, err := zstd.NewWriter(bw, zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
	if err != nil {
		return stacktrace.Propagate(err)
	}

	numIndexEntries, err := exportIndexEntries(tx, blockHeaderGetter, updateValidatorsBlockInterval, numRecentBlockHeaders, it.Version(), zstdw)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	numNodes, err := exportNodes(it, zstdw)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	err = zstdw.Close()
	if err != nil {
		return stacktrace.Propagate(err)
	}

	err = bw.Flush()
	if err != nil {
		return stacktrace.Propagate(err)
	}

	// find total compressed section size
	offset, err := writerSeeker.Seek(0, io.SeekCurrent)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	compressedSectionSize := offset - int64(writtenUntilReservedFields) - int64(sizeOfReservedFields)

	// seek back and fill in the reserved header fields

	offset, err = writerSeeker.Seek(int64(writtenUntilReservedFields), io.SeekStart)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	if offset != int64(writtenUntilReservedFields) {
		return stacktrace.NewError("unexpected seek result")
	}

	// back-patch the three reserved fields now that their values are known,
	// writing directly (unbuffered) to the underlying writer
	b = make([]byte, sizeOfReservedFields)
	binary.BigEndian.PutUint64(b, uint64(compressedSectionSize))
	binary.BigEndian.PutUint64(b[8:], uint64(numIndexEntries))
	binary.BigEndian.PutUint64(b[16:], uint64(numNodes))
	_, err = writerSeeker.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	return nil
}
192
193func exportIndexEntries(tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64, treeVersion int64, w io.Writer) (int64, error) {
194 indexDB := tx.IndexDB()
195 numDIDEntries, err := exportIndexDIDEntries(indexDB, treeVersion, w)
196 if err != nil {
197 return 0, stacktrace.Propagate(err)
198 }
199
200 numValidatorParticipationEntries, err := exportIndexValidatorParticipation(indexDB, updateValidatorsBlockInterval, treeVersion, w)
201 if err != nil {
202 return 0, stacktrace.Propagate(err)
203 }
204
205 numRecentBlockHeadersExported, err := exportRecentBlockHeaders(blockHeaderGetter, numRecentBlockHeaders, treeVersion, w)
206 if err != nil {
207 return 0, stacktrace.Propagate(err)
208 }
209
210 return numDIDEntries + numValidatorParticipationEntries + numRecentBlockHeadersExported, nil
211}
212
213func exportIndexDIDEntries(indexDB transaction.IndexReader, treeVersion int64, w io.Writer) (int64, error) {
214 didLogKeyStart := make([]byte, IndexDIDLogKeyLength)
215 didLogKeyStart[0] = IndexDIDLogKeyPrefix
216 didLogKeyEnd := slices.Repeat([]byte{0xff}, IndexDIDLogKeyLength)
217 didLogKeyEnd[0] = IndexDIDLogKeyPrefix
218
219 iterator, err := indexDB.Iterator(didLogKeyStart, didLogKeyEnd)
220 if err != nil {
221 return 0, stacktrace.Propagate(err)
222 }
223 defer iterator.Close()
224
225 numEntries := int64(0)
226 for iterator.Valid() {
227 key := iterator.Key()
228 value := iterator.Value()
229
230 validFromHeight, validToHeight := UnmarshalDIDLogValue(value)
231 if uint64(treeVersion) >= validFromHeight && uint64(treeVersion) <= validToHeight {
232 header := make([]byte, 4+4)
233 binary.BigEndian.PutUint32(header, uint32(len(key)))
234 binary.BigEndian.PutUint32(header[4:], uint32(len(value)))
235
236 _, err = w.Write(header)
237 if err != nil {
238 return 0, stacktrace.Propagate(err)
239 }
240
241 _, err = w.Write(key)
242 if err != nil {
243 return 0, stacktrace.Propagate(err)
244 }
245
246 _, err = w.Write(value)
247 if err != nil {
248 return 0, stacktrace.Propagate(err)
249 }
250
251 numEntries++
252 }
253
254 iterator.Next()
255 }
256 return numEntries, nil
257}
258
259func exportIndexValidatorParticipation(indexDB transaction.IndexReader, updateValidatorsBlockInterval uint64, treeVersion int64, w io.Writer) (int64, error) {
260 epochHeight := uint64(treeVersion) - uint64(treeVersion)%updateValidatorsBlockInterval
261 startKey := MarshalValidatorVotingActivityKey(uint64(epochHeight), make([]byte, AddressLength))
262 endKey := MarshalValidatorVotingActivityKey(uint64(epochHeight), slices.Repeat([]byte{0xff}, AddressLength))
263
264 iterator, err := indexDB.Iterator(startKey, endKey)
265 if err != nil {
266 return 0, stacktrace.Propagate(err)
267 }
268 defer iterator.Close()
269
270 numEntries := int64(0)
271 for iterator.Valid() {
272 key := iterator.Key()
273 value := iterator.Value()
274
275 header := make([]byte, 4+4)
276 binary.BigEndian.PutUint32(header, uint32(len(key)))
277 binary.BigEndian.PutUint32(header[4:], uint32(len(value)))
278
279 _, err = w.Write(header)
280 if err != nil {
281 return 0, stacktrace.Propagate(err)
282 }
283
284 _, err = w.Write(key)
285 if err != nil {
286 return 0, stacktrace.Propagate(err)
287 }
288
289 _, err = w.Write(value)
290 if err != nil {
291 return 0, stacktrace.Propagate(err)
292 }
293
294 numEntries++
295
296 iterator.Next()
297 }
298
299 if numEntries == 0 {
300 // there should always be at least one active validator
301 return 0, stacktrace.NewError("unexpectedly missing index entries for validator voting participation - treeVersion may be too old to export")
302 }
303
304 return numEntries, nil
305}
306
307func exportRecentBlockHeaders(blockHeaderGetter BlockHeaderGetter, numRecentBlockHeaders int64, treeVersion int64, w io.Writer) (int64, error) {
308 startHeight := treeVersion - numRecentBlockHeaders + 1 // plus one because we want to include the block at treeVersion
309 if startHeight < 1 {
310 startHeight = 1
311 }
312
313 numExportedBlockHeaders := int64(0)
314 for height := startHeight; height <= treeVersion; height++ {
315 blockHeader, err := blockHeaderGetter(height)
316 if err != nil {
317 return 0, stacktrace.Propagate(err)
318 }
319
320 blockHeaderProto := blockHeader.ToProto()
321 blockHeaderBytes, err := blockHeaderProto.Marshal()
322 if err != nil {
323 return 0, stacktrace.Propagate(err)
324 }
325
326 key := make([]byte, 1+8)
327 key[0] = IndexBlockHeaderKeyPrefix
328 binary.BigEndian.PutUint64(key[1:], uint64(height))
329
330 header := make([]byte, 4+4)
331 binary.BigEndian.PutUint32(header, uint32(len(key)))
332 binary.BigEndian.PutUint32(header[4:], uint32(len(blockHeaderBytes)))
333
334 _, err = w.Write(header)
335 if err != nil {
336 return 0, stacktrace.Propagate(err)
337 }
338
339 _, err = w.Write(key)
340 if err != nil {
341 return 0, stacktrace.Propagate(err)
342 }
343
344 _, err = w.Write(blockHeaderBytes)
345 if err != nil {
346 return 0, stacktrace.Propagate(err)
347 }
348
349 numExportedBlockHeaders++
350 }
351
352 return numExportedBlockHeaders, nil
353}
354
355func exportNodes(it *iavl.ImmutableTree, w io.Writer) (int64, error) {
356 exporter, err := it.Export()
357 if err != nil {
358 return 0, stacktrace.Propagate(err)
359 }
360 defer exporter.Close()
361 cexporter := iavl.NewCompressExporter(exporter)
362
363 // 1 byte for node height
364 // 8 bytes for node version
365 // 4 bytes for node key length (0xffffffff if node key is nil)
366 // 4 bytes for node value length (0xffffffff if node value is nil)
367 nodeHeaderBuffer := make([]byte, 1+8+4+4)
368
369 numNodes := int64(0)
370 for {
371 node, err := cexporter.Next()
372 if errors.Is(err, iavl.ErrorExportDone) {
373 break
374 }
375 if err != nil {
376 return 0, stacktrace.Propagate(err)
377 }
378
379 nodeHeaderBuffer[0] = byte(node.Height)
380
381 binary.BigEndian.PutUint64(nodeHeaderBuffer[1:], uint64(node.Version))
382
383 // nil node values are different from 0-byte values
384 binary.BigEndian.PutUint32(nodeHeaderBuffer[9:13], lo.Ternary(node.Key != nil, uint32(len(node.Key)), math.MaxUint32))
385 binary.BigEndian.PutUint32(nodeHeaderBuffer[13:17], lo.Ternary(node.Value != nil, uint32(len(node.Value)), math.MaxUint32))
386
387 _, err = w.Write(nodeHeaderBuffer)
388 if err != nil {
389 return 0, stacktrace.Propagate(err)
390 }
391
392 _, err = w.Write(node.Key)
393 if err != nil {
394 return 0, stacktrace.Propagate(err)
395 }
396
397 _, err = w.Write(node.Value)
398 if err != nil {
399 return 0, stacktrace.Propagate(err)
400 }
401 numNodes++
402 }
403
404 return numNodes, nil
405}
406
407func writeChunkHashes(snapshotFile io.ReadSeeker, w io.Writer) error {
408 bw := bufio.NewWriter(w)
409 defer bw.Flush()
410
411 _, err := snapshotFile.Seek(0, io.SeekStart)
412 if err != nil {
413 return stacktrace.Propagate(err)
414 }
415
416 buf := make([]byte, SnapshotChunkSize)
417 for {
418 n, err := io.ReadFull(snapshotFile, buf)
419 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
420 return stacktrace.Propagate(err)
421 }
422 if n == 0 {
423 break
424 }
425
426 hash := sha256.Sum256(buf[:n])
427 c, err := w.Write(hash[:])
428 if err != nil {
429 return stacktrace.Propagate(err)
430 }
431
432 if c != SnapshotChunkHashSize {
433 return stacktrace.NewError("unexpected chunk hash size")
434 }
435
436 if n < SnapshotChunkSize {
437 break
438 }
439 }
440
441 return nil
442}
443
// SnapshotApplier handles applying snapshot chunks. Chunks are fed in order
// via Apply; chunk 0 carries the snapshot header and triggers setup of the
// decompression pipeline and the background importer goroutine.
type SnapshotApplier struct {
	writeIndex          transaction.WriteIndex // index transaction the imported entries are written into
	tree                *iavl.MutableTree
	treeVersion         int64    // tree version the snapshot claims to restore
	expectedFinalHash   []byte   // tree hash the imported tree must produce
	expectedChunkHashes [][]byte // per-chunk SHA-256 hashes; also defines the chunk count

	// chunk bytes are written into pipeWriter; the importer goroutine reads
	// the decompressed stream from zstdReader (fed by pipeReader)
	pipeWriter *io.PipeWriter
	pipeReader *io.PipeReader
	zstdReader io.ReadCloser

	importer         *iavl.Importer
	compressImporter iavl.NodeImporter
	importerWg       sync.WaitGroup // tracks the streamingImporter goroutine

	// actual vs snapshot-declared counts; compared after the last chunk
	numImportedNodes        int
	claimedNodeCount        int
	numImportedIndexEntries int
	claimedIndexEntryCount  int
	done                    bool // set once the snapshot has been fully applied and verified
}
466
// initApplier sets up the decompression pipeline (pipe feeding a zstd reader)
// and starts the background goroutine that imports the decompressed stream.
// Called by Apply when chunk 0 arrives.
//
// NOTE(review): if zstd.NewReader fails, this returns silently with
// a.pipeWriter left nil, so the subsequent pipeWriter.Write in Apply would
// panic — consider surfacing the error to the caller.
func (a *SnapshotApplier) initApplier() {
	pipeReader, pipeWriter := io.Pipe()

	zstdReader, err := zstd.NewReader(pipeReader)
	if err != nil {
		_ = pipeReader.Close()
		_ = pipeWriter.Close()
		return
	}

	a.pipeWriter = pipeWriter
	a.pipeReader = pipeReader
	a.zstdReader = zstdReader.IOReadCloser()

	// importerWg tracks this goroutine; closeCommons waits on it
	a.importerWg.Go(a.streamingImporter)
}
483
484func (a *SnapshotApplier) Apply(chunkIndex int, chunkBytes []byte) error {
485 if len(chunkBytes) > SnapshotChunkSize {
486 return stacktrace.Propagate(ErrMalformedChunk, "chunk too large")
487 }
488 hash := sha256.Sum256(chunkBytes)
489 if !bytes.Equal(a.expectedChunkHashes[chunkIndex], hash[:]) {
490 return stacktrace.Propagate(ErrMalformedChunk, "hash mismatch")
491 }
492
493 if chunkIndex == 0 {
494 if len(chunkBytes) < 88 {
495 return stacktrace.Propagate(ErrMalformedChunk, "chunk too small")
496 }
497
498 if string(chunkBytes[0:18]) != SnapshotFileMagic {
499 return stacktrace.Propagate(ErrMalformedChunk, "invalid file magic")
500 }
501
502 if binary.BigEndian.Uint32(chunkBytes[20:]) != SnapshotFormatVersion {
503 return stacktrace.Propagate(ErrMalformedChunk, "invalid snapshot format")
504 }
505
506 if binary.BigEndian.Uint64(chunkBytes[24:]) != uint64(a.treeVersion) {
507 return stacktrace.Propagate(ErrMalformedChunk, "mismatched tree version")
508 }
509
510 if !bytes.Equal(chunkBytes[32:64], a.expectedFinalHash) {
511 return stacktrace.Propagate(ErrMalformedChunk, "mismatched declared tree hash")
512 }
513
514 declaredFileSize := 88 + binary.BigEndian.Uint64(chunkBytes[64:])
515 minExpectedSize := uint64((len(a.expectedChunkHashes) - 1) * SnapshotChunkSize)
516 maxExpectedSize := uint64(len(a.expectedChunkHashes) * SnapshotChunkSize)
517 if declaredFileSize < minExpectedSize ||
518 declaredFileSize > maxExpectedSize {
519 return stacktrace.Propagate(ErrMalformedChunk, "unexpected compressed section length")
520 }
521
522 a.claimedIndexEntryCount = int(binary.BigEndian.Uint64(chunkBytes[72:]))
523 a.claimedNodeCount = int(binary.BigEndian.Uint64(chunkBytes[80:]))
524
525 // move to the start of the compressed portion
526 chunkBytes = chunkBytes[88:]
527
528 a.initApplier()
529 }
530
531 _, err := a.pipeWriter.Write(chunkBytes)
532 if err != nil {
533 return stacktrace.Propagate(err)
534 }
535
536 isLastChunk := chunkIndex == len(a.expectedChunkHashes)-1
537 if isLastChunk {
538 _ = a.pipeWriter.Close()
539 // wait for importer to finish reading and importing everything
540 a.importerWg.Wait()
541
542 if a.numImportedIndexEntries != a.claimedIndexEntryCount {
543 return stacktrace.Propagate(ErrIndexEntryCountMismatch, "imported index entry count mismatch")
544 }
545
546 if a.numImportedNodes != a.claimedNodeCount {
547 return stacktrace.Propagate(ErrTreeHashMismatch, "imported node count mismatch")
548 }
549
550 err := a.writeIndex.Commit()
551 if err != nil {
552 return stacktrace.Propagate(err)
553 }
554
555 err = a.importer.Commit()
556 if err != nil {
557 if strings.Contains(err.Error(), "invalid node structure") {
558 return stacktrace.Propagate(errors.Join(ErrMalformedChunk, err))
559 }
560 return stacktrace.Propagate(err)
561 }
562
563 err = a.closeCommons()
564 if err != nil {
565 return stacktrace.Propagate(err)
566 }
567 a.done = true
568
569 if !bytes.Equal(a.tree.Hash(), a.expectedFinalHash) {
570 return stacktrace.Propagate(ErrTreeHashMismatch)
571 }
572 }
573
574 return nil
575}
576
// streamingImporter runs on its own goroutine and consumes the decompressed
// snapshot stream from a.zstdReader: first a.claimedIndexEntryCount
// length-prefixed index entries, then tree nodes until EOF. It counts what it
// imports into numImportedIndexEntries/numImportedNodes; Apply compares those
// counters (and the final tree hash) against the snapshot's declared values,
// which is why this function can return silently on malformed input — any
// shortfall is detected after the last chunk.
func (a *SnapshotApplier) streamingImporter() {
	for {
		if a.numImportedIndexEntries < a.claimedIndexEntryCount {
			// index entry header: 4-byte key length + 4-byte value length, big-endian
			entryHeader := make([]byte, 4+4)
			n, err := io.ReadFull(a.zstdReader, entryHeader)
			if err != nil || n != 8 {
				// err may be EOF here, which is expected
				return
			}

			// validate lengths against sensible limits to prevent OOM DoS by malicious third parties
			keyLength := binary.BigEndian.Uint32(entryHeader[0:4])
			valueLength := binary.BigEndian.Uint32(entryHeader[4:8])
			if keyLength > 1024*1024 || valueLength > 1024*1024 {
				return
			}

			key := make([]byte, keyLength)

			n, err = io.ReadFull(a.zstdReader, key)
			if err != nil || n != len(key) {
				// this shouldn't happen unless the data is corrupt
				// we can return silently here because since we didn't import all nodes, the tree hash won't match anyway
				return
			}

			value := make([]byte, valueLength)
			n, err = io.ReadFull(a.zstdReader, value)
			if err != nil || n != len(value) {
				return
			}

			err = a.writeIndex.IndexDB().Set(key, value)
			if err != nil {
				// we can return silently here because since we didn't import all nodes, the tree hash won't match anyway
				return
			}

			a.numImportedIndexEntries++
		} else {
			// tree node header: 1-byte height + 8-byte version + 4-byte key
			// length + 4-byte value length; length 0xffffffff means nil
			nodeHeader := make([]byte, 9+4+4)
			n, err := io.ReadFull(a.zstdReader, nodeHeader)
			if err != nil || n != 9+4+4 {
				// err may be EOF here, which is expected
				return
			}

			// validate lengths against sensible limits to prevent OOM DoS by malicious third parties
			keyLength := binary.BigEndian.Uint32(nodeHeader[9:13])
			var key []byte
			if keyLength != 0xffffffff {
				if keyLength > 1024*1024 {
					return
				}
				key = make([]byte, keyLength)

				n, err = io.ReadFull(a.zstdReader, key)
				if err != nil || n != len(key) {
					// this shouldn't happen unless the data is corrupt
					// we can return silently here because since we didn't import all nodes, the tree hash won't match anyway
					return
				}
			}

			valueLength := binary.BigEndian.Uint32(nodeHeader[13:17])
			var value []byte
			if valueLength != 0xffffffff {
				if valueLength > 1024*1024 {
					return
				}
				value = make([]byte, valueLength)
				n, err = io.ReadFull(a.zstdReader, value)
				if err != nil || n != len(value) {
					return
				}
			}

			err = a.compressImporter.Add(&iavl.ExportNode{
				Height:  int8(nodeHeader[0]),
				Version: int64(binary.BigEndian.Uint64(nodeHeader[1:9])),
				Key:     key,
				Value:   value,
			})
			if err != nil {
				// this shouldn't happen unless the data is corrupt
				// we can return silently here because since we didn't import all nodes, the tree hash won't match anyway
				return
			}
			a.numImportedNodes++
		}
	}
}
669
670// Abort cancels the snapshot application
671func (a *SnapshotApplier) Abort() error {
672 err := a.closeCommons()
673 if err != nil {
674 return stacktrace.Propagate(err)
675 }
676
677 if a.tree != nil {
678 err = a.tree.DeleteVersionsFrom(0)
679 if err != nil {
680 return stacktrace.Propagate(err)
681 }
682 }
683
684 return nil
685}
686
687func (a *SnapshotApplier) closeCommons() error {
688 if a.zstdReader != nil {
689 err := a.zstdReader.Close()
690 if err != nil {
691 return stacktrace.Propagate(err)
692 }
693 }
694
695 if a.pipeReader != nil {
696 err := a.pipeReader.Close()
697 if err != nil {
698 return stacktrace.Propagate(err)
699 }
700 }
701
702 if a.pipeWriter != nil {
703 err := a.pipeWriter.Close()
704 if err != nil {
705 return stacktrace.Propagate(err)
706 }
707 }
708
709 if a.writeIndex != nil {
710 err := a.writeIndex.Rollback()
711 if err != nil {
712 return stacktrace.Propagate(err)
713 }
714 }
715
716 if a.importerWg != (sync.WaitGroup{}) {
717 a.importerWg.Wait()
718 }
719
720 a.importer.Close()
721
722 return nil
723}
724
// Done returns true if the snapshot has been fully applied, committed and
// hash-verified (set at the end of the last successful Apply call).
func (a *SnapshotApplier) Done() bool {
	return a.done
}
729
// ReadSnapshotMetadata reads metadata from a snapshot file: the height
// (parsed from the filename and cross-checked against the file header), the
// format version, the number of chunks (derived from the sibling .chunksums
// file), the declared tree hash, and the raw chunksums bytes.
func (s *SnapshotStore) ReadSnapshotMetadata(filename string) (height uint64, format uint32, chunks uint32, hash []byte, chunksumsData []byte, err error) {
	// Extract height from filename pattern: %020d.snapshot
	base := filepath.Base(filename)
	if !strings.HasSuffix(base, ".snapshot") {
		return 0, 0, 0, nil, nil, stacktrace.NewError("invalid snapshot filename format: %s", filename)
	}
	heightStr := strings.TrimSuffix(base, ".snapshot")
	height, err = strconv.ParseUint(heightStr, 10, 64)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to parse height from filename: %s", filename)
	}

	// Open and read snapshot file header
	f, err := os.Open(filename)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to open snapshot file")
	}
	defer f.Close()

	// Read file magic (18 bytes)
	magic := make([]byte, 18)
	_, err = io.ReadFull(f, magic)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read file magic")
	}
	if string(magic) != SnapshotFileMagic {
		return 0, 0, 0, nil, nil, stacktrace.NewError("invalid file magic")
	}

	// Read version bytes (6 bytes: 5 reserved zero bytes + 1-byte version)
	versionBytes := make([]byte, 6)
	_, err = io.ReadFull(f, versionBytes)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read version bytes")
	}
	// interpret the last 4 bytes as the format version (the version itself
	// lives in the final byte)
	format = binary.BigEndian.Uint32(versionBytes[2:])

	// Read height (8 bytes, big-endian) and verify it matches the filename
	heightBytes := make([]byte, 8)
	_, err = io.ReadFull(f, heightBytes)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read height")
	}
	fileHeight := int64(binary.BigEndian.Uint64(heightBytes))
	if fileHeight != int64(height) {
		return 0, 0, 0, nil, nil, stacktrace.NewError("height mismatch")
	}

	// Read tree hash (32 bytes)
	hash = make([]byte, 32)
	_, err = io.ReadFull(f, hash)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read tree hash")
	}

	// Read corresponding chunksums file (one 32-byte hash per chunk)
	chunksumsFilename := strings.TrimSuffix(filename, ".snapshot") + ".chunksums"
	chunksumsData, err = os.ReadFile(chunksumsFilename)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read chunksums file")
	}

	// Calculate number of chunks
	chunks = uint32(len(chunksumsData) / SnapshotChunkHashSize)

	return height, format, chunks, hash, chunksumsData, nil
}
798
799// LoadSnapshotChunk reads a chunk from a snapshot file at the given height and chunk index
800func (s *SnapshotStore) LoadSnapshotChunk(snapshotDirectory string, height uint64, chunkIndex int) ([]byte, error) {
801 snapshotFilename := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.snapshot", height))
802
803 // Open the snapshot file
804 f, err := os.Open(snapshotFilename)
805 if err != nil {
806 return nil, stacktrace.Propagate(err, "failed to open snapshot file: %s", snapshotFilename)
807 }
808 defer f.Close()
809
810 // Calculate the offset for the requested chunk (start from beginning of file, including header)
811 offset := int64(chunkIndex) * int64(SnapshotChunkSize)
812 _, err = f.Seek(offset, io.SeekStart)
813 if err != nil {
814 return nil, stacktrace.Propagate(err, "failed to seek to chunk offset")
815 }
816
817 // Read up to SnapshotChunkSize bytes
818 chunkData := make([]byte, SnapshotChunkSize)
819 n, err := f.Read(chunkData)
820 if err != nil && err != io.EOF {
821 return nil, stacktrace.Propagate(err, "failed to read chunk data")
822 }
823
824 // If we read less than SnapshotChunkSize, trim the slice
825 if n < SnapshotChunkSize {
826 chunkData = chunkData[:n]
827 }
828
829 return chunkData, nil
830}
831
832func (s *SnapshotStore) PruneOldSnapshots(snapshotDirectory string, retentionCount int) (int, error) {
833 // Get list of all snapshots sorted by height (newest first)
834 files, err := filepath.Glob(filepath.Join(snapshotDirectory, "*.snapshot"))
835 if err != nil {
836 return 0, stacktrace.Propagate(err, "failed to list snapshots")
837 }
838
839 if len(files) <= retentionCount {
840 return 0, nil
841 }
842
843 // Extract heights from filenames and sort
844 heights := make([]uint64, 0, len(files))
845 for _, f := range files {
846 base := filepath.Base(f)
847 heightStr := strings.TrimSuffix(base, ".snapshot")
848 h, err := strconv.ParseUint(heightStr, 10, 64)
849 if err != nil {
850 continue
851 }
852 heights = append(heights, h)
853 }
854
855 slices.SortFunc(heights, func(a, b uint64) int {
856 return int(int64(b) - int64(a)) // Sort descending (newest first)
857 })
858
859 // Delete snapshots beyond retention count
860 toDelete := heights[retentionCount:]
861 for _, h := range toDelete {
862 snapshotFile := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.snapshot", h))
863 chunksumsFile := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.chunksums", h))
864
865 if err := os.Remove(snapshotFile); err != nil && !errors.Is(err, os.ErrNotExist) {
866 return 0, stacktrace.Propagate(err, "failed to delete old snapshot file: %s", snapshotFile)
867 }
868
869 if err := os.Remove(chunksumsFile); err != nil && !errors.Is(err, os.ErrNotExist) {
870 return 0, stacktrace.Propagate(err, "failed to delete old chunksums file: %s", chunksumsFile)
871 }
872 }
873
874 return len(toDelete), nil
875}
876
877func (s *SnapshotStore) MostRecentSnapshotHeight(snapshotDirectory string) (uint64, error) {
878 files, err := filepath.Glob(filepath.Join(snapshotDirectory, "*.snapshot"))
879 if err != nil {
880 return 0, stacktrace.Propagate(err, "failed to list snapshots")
881 }
882
883 var maxHeight uint64
884 for _, f := range files {
885 base := filepath.Base(f)
886 heightStr := strings.TrimSuffix(base, ".snapshot")
887 h, err := strconv.ParseUint(heightStr, 10, 64)
888 if err != nil {
889 continue
890 }
891 if h > maxHeight {
892 maxHeight = h
893 }
894 }
895
896 return maxHeight, nil
897}