A very experimental PLC implementation which uses BFT consensus for decentralization
package store

import (
	"bufio"
	"bytes"
	"cmp"
	"crypto/sha256"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	cmttypes "github.com/cometbft/cometbft/types"
	"github.com/cosmos/iavl"
	"github.com/gbl08ma/stacktrace"
	"github.com/klauspost/compress/zstd"
	"github.com/samber/lo"
	"tangled.org/gbl08ma.com/didplcbft/transaction"
)

// Snapshot format constants
const (
	SnapshotChunkSize     = 10 * 1024 * 1024 // 10 MB
	SnapshotChunkHashSize = 32
	SnapshotFormatVersion = 2
	SnapshotFileMagic     = "didplcbft-snapshot"
)

// BlockHeaderGetter is a function type for retrieving block headers
type BlockHeaderGetter func(height int64) (cmttypes.Header, error)

// ErrMalformedChunk is returned when a snapshot chunk is invalid
var ErrMalformedChunk = errors.New("malformed chunk")

// ErrInvalidMetadata is returned when snapshot metadata is invalid
var ErrInvalidMetadata = errors.New("invalid metadata")

// ErrTreeHashMismatch is returned when the imported tree hash doesn't match the expected hash
var ErrTreeHashMismatch = errors.New("tree hash mismatch")

// ErrIndexEntryCountMismatch is returned when the imported index entry count doesn't match the expected count
var ErrIndexEntryCountMismatch = errors.New("index entry count mismatch")

// SnapshotStore provides snapshot creation and import functionality
type SnapshotStore struct{}

var Snapshot = &SnapshotStore{}

// Export creates a snapshot of the current state and writes it to the provided writer
func (s *SnapshotStore) Export(writerSeeker io.WriteSeeker, tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64) error {
	return writeSnapshot(writerSeeker, tx, blockHeaderGetter, updateValidatorsBlockInterval, numRecentBlockHeaders)
}

// WriteChunkHashes calculates and writes chunk hashes for a snapshot file
func (s *SnapshotStore) WriteChunkHashes(snapshotFile io.ReadSeeker, w io.Writer) error {
	return writeChunkHashes(snapshotFile, w)
}

// CreateSnapshotApplier creates a new SnapshotApplier for applying a snapshot
func (s *SnapshotStore) CreateSnapshotApplier(tree *iavl.MutableTree, txFactory *transaction.Factory, treeVersion int64, expectedFinalHash []byte, expectedChunkHashes [][]byte) (*SnapshotApplier, error) {
	writeTx, err := txFactory.ReadWorking(time.Now()).UpgradeForIndexOnly()
	if err != nil {
		return nil, stacktrace.Propagate(err)
	}

	sa := &SnapshotApplier{
		writeIndex:          writeTx,
		tree:                tree,
		treeVersion:         treeVersion,
		expectedFinalHash:   expectedFinalHash,
		expectedChunkHashes: expectedChunkHashes,
	}

	sa.importer, err = sa.tree.Import(treeVersion)
	if err != nil {
		return nil, stacktrace.Propagate(err)
	}

	sa.compressImporter = iavl.NewCompressImporter(sa.importer)

	return sa, nil
}
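
// Snapshot file layout, as produced by writeSnapshot and verified by
// SnapshotApplier.Apply (all integers big-endian):
//
//	offset  0: file magic, 18 bytes ("didplcbft-snapshot")
//	offset 18: 6 version bytes (5 zero bytes followed by SnapshotFormatVersion;
//	           readers interpret the last 4 as a uint32)
//	offset 24: tree version (height), 8 bytes
//	offset 32: tree hash, 32 bytes
//	offset 64: compressed section size in bytes, 8 bytes (reserved, backfilled)
//	offset 72: number of index entries, 8 bytes (reserved, backfilled)
//	offset 80: number of tree nodes, 8 bytes (reserved, backfilled)
//	offset 88: zstd-compressed section (index entries followed by tree nodes)
//
// The three reserved fields are only known once the compressed section has
// been written, which is why writeSnapshot seeks back to fill them in.
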
func writeSnapshot(writerSeeker io.WriteSeeker, tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64) error {
	it, ok := transaction.UnderlyingImmutableTree(tx.Tree())
	if !ok {
		return stacktrace.NewError("expected immutable tree")
	}

	writtenUntilReservedFields := 0

	bw := bufio.NewWriter(writerSeeker)

	// file magic and version
	c, err := bw.Write([]byte(SnapshotFileMagic))
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	c, err = bw.Write([]byte{0, 0, 0, 0, 0, byte(SnapshotFormatVersion)})
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	b := make([]byte, 8)
	binary.BigEndian.PutUint64(b, uint64(it.Version()))
	c, err = bw.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	c, err = bw.Write(it.Hash())
	if err != nil {
		return stacktrace.Propagate(err)
	}
	writtenUntilReservedFields += c

	// reserve space for writing:
	// - 8 bytes for compressed section size in bytes
	// - 8 bytes for number of index entries
	// - 8 bytes for number of nodes
	sizeOfReservedFields := 8 * 3
	b = make([]byte, sizeOfReservedFields)
	_, err = bw.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	zstdw, err := zstd.NewWriter(bw, zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
	if err != nil {
		return stacktrace.Propagate(err)
	}

	numIndexEntries, err := exportIndexEntries(tx, blockHeaderGetter, updateValidatorsBlockInterval, numRecentBlockHeaders, it.Version(), zstdw)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	numNodes, err := exportNodes(it, zstdw)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	err = zstdw.Close()
	if err != nil {
		return stacktrace.Propagate(err)
	}

	err = bw.Flush()
	if err != nil {
		return stacktrace.Propagate(err)
	}

	// find total compressed section size
	offset, err := writerSeeker.Seek(0, io.SeekCurrent)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	compressedSectionSize := offset - int64(writtenUntilReservedFields) - int64(sizeOfReservedFields)

	// seek back and fill in the reserved header fields
	offset, err = writerSeeker.Seek(int64(writtenUntilReservedFields), io.SeekStart)
	if err != nil {
		return stacktrace.Propagate(err)
	}
	if offset != int64(writtenUntilReservedFields) {
		return stacktrace.NewError("unexpected seek result")
	}

	b = make([]byte, sizeOfReservedFields)
	binary.BigEndian.PutUint64(b, uint64(compressedSectionSize))
	binary.BigEndian.PutUint64(b[8:], uint64(numIndexEntries))
	binary.BigEndian.PutUint64(b[16:], uint64(numNodes))
	_, err = writerSeeker.Write(b)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	return nil
}

func exportIndexEntries(tx transaction.Read, blockHeaderGetter BlockHeaderGetter, updateValidatorsBlockInterval uint64, numRecentBlockHeaders int64, treeVersion int64, w io.Writer) (int64, error) {
	indexDB := tx.IndexDB()
	numDIDEntries, err := exportIndexDIDEntries(indexDB, treeVersion, w)
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}

	numValidatorParticipationEntries, err := exportIndexValidatorParticipation(indexDB, updateValidatorsBlockInterval, treeVersion, w)
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}

	numRecentBlockHeadersExported, err := exportRecentBlockHeaders(blockHeaderGetter, numRecentBlockHeaders, treeVersion, w)
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}

	return numDIDEntries + numValidatorParticipationEntries + numRecentBlockHeadersExported, nil
}
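
// Each exported index entry (and block header) is written to the compressed
// section as a length-prefixed key/value pair:
//
//	4 bytes: key length (big-endian uint32)
//	4 bytes: value length (big-endian uint32)
//	key bytes, then value bytes
//
// The import side (streamingImporter) reads back exactly this framing.
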
func exportIndexDIDEntries(indexDB transaction.IndexReader, treeVersion int64, w io.Writer) (int64, error) {
	didLogKeyStart := make([]byte, IndexDIDLogKeyLength)
	didLogKeyStart[0] = IndexDIDLogKeyPrefix
	didLogKeyEnd := slices.Repeat([]byte{0xff}, IndexDIDLogKeyLength)
	didLogKeyEnd[0] = IndexDIDLogKeyPrefix

	iterator, err := indexDB.Iterator(didLogKeyStart, didLogKeyEnd)
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}
	defer iterator.Close()

	numEntries := int64(0)
	for iterator.Valid() {
		key := iterator.Key()
		value := iterator.Value()

		validFromHeight, validToHeight := UnmarshalDIDLogValue(value)
		if uint64(treeVersion) >= validFromHeight && uint64(treeVersion) <= validToHeight {
			header := make([]byte, 4+4)
			binary.BigEndian.PutUint32(header, uint32(len(key)))
			binary.BigEndian.PutUint32(header[4:], uint32(len(value)))

			_, err = w.Write(header)
			if err != nil {
				return 0, stacktrace.Propagate(err)
			}

			_, err = w.Write(key)
			if err != nil {
				return 0, stacktrace.Propagate(err)
			}

			_, err = w.Write(value)
			if err != nil {
				return 0, stacktrace.Propagate(err)
			}

			numEntries++
		}

		iterator.Next()
	}
	return numEntries, nil
}

func exportIndexValidatorParticipation(indexDB transaction.IndexReader, updateValidatorsBlockInterval uint64, treeVersion int64, w io.Writer) (int64, error) {
	epochHeight := uint64(treeVersion) - uint64(treeVersion)%updateValidatorsBlockInterval
	startKey := MarshalValidatorVotingActivityKey(epochHeight, make([]byte, AddressLength))
	endKey := MarshalValidatorVotingActivityKey(epochHeight, slices.Repeat([]byte{0xff}, AddressLength))

	iterator, err := indexDB.Iterator(startKey, endKey)
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}
	defer iterator.Close()

	numEntries := int64(0)
	for iterator.Valid() {
		key := iterator.Key()
		value := iterator.Value()

		header := make([]byte, 4+4)
		binary.BigEndian.PutUint32(header, uint32(len(key)))
		binary.BigEndian.PutUint32(header[4:], uint32(len(value)))

		_, err = w.Write(header)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(key)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(value)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		numEntries++

		iterator.Next()
	}

	if numEntries == 0 {
		// there should always be at least one active validator
		return 0, stacktrace.NewError("unexpectedly missing index entries for validator voting participation - treeVersion may be too old to export")
	}

	return numEntries, nil
}
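
// Recent block headers are serialized as protobuf and exported with the same
// length-prefixed framing as index entries, keyed under
// IndexBlockHeaderKeyPrefix, so the import side can restore them into the
// index database without a separate code path.
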
func exportRecentBlockHeaders(blockHeaderGetter BlockHeaderGetter, numRecentBlockHeaders int64, treeVersion int64, w io.Writer) (int64, error) {
	startHeight := treeVersion - numRecentBlockHeaders + 1 // plus one because we want to include the block at treeVersion
	if startHeight < 1 {
		startHeight = 1
	}

	numExportedBlockHeaders := int64(0)
	for height := startHeight; height <= treeVersion; height++ {
		blockHeader, err := blockHeaderGetter(height)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		blockHeaderProto := blockHeader.ToProto()
		blockHeaderBytes, err := blockHeaderProto.Marshal()
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		key := make([]byte, 1+8)
		key[0] = IndexBlockHeaderKeyPrefix
		binary.BigEndian.PutUint64(key[1:], uint64(height))

		header := make([]byte, 4+4)
		binary.BigEndian.PutUint32(header, uint32(len(key)))
		binary.BigEndian.PutUint32(header[4:], uint32(len(blockHeaderBytes)))

		_, err = w.Write(header)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(key)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(blockHeaderBytes)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		numExportedBlockHeaders++
	}

	return numExportedBlockHeaders, nil
}

func exportNodes(it *iavl.ImmutableTree, w io.Writer) (int64, error) {
	exporter, err := it.Export()
	if err != nil {
		return 0, stacktrace.Propagate(err)
	}
	defer exporter.Close()
	cexporter := iavl.NewCompressExporter(exporter)

	// 1 byte for node height
	// 8 bytes for node version
	// 4 bytes for node key length (0xffffffff if node key is nil)
	// 4 bytes for node value length (0xffffffff if node value is nil)
	nodeHeaderBuffer := make([]byte, 1+8+4+4)

	numNodes := int64(0)
	for {
		node, err := cexporter.Next()
		if errors.Is(err, iavl.ErrorExportDone) {
			break
		}
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		nodeHeaderBuffer[0] = byte(node.Height)

		binary.BigEndian.PutUint64(nodeHeaderBuffer[1:], uint64(node.Version))

		// nil node values are different from 0-byte values
		binary.BigEndian.PutUint32(nodeHeaderBuffer[9:13], lo.Ternary(node.Key != nil, uint32(len(node.Key)), math.MaxUint32))
		binary.BigEndian.PutUint32(nodeHeaderBuffer[13:17], lo.Ternary(node.Value != nil, uint32(len(node.Value)), math.MaxUint32))

		_, err = w.Write(nodeHeaderBuffer)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(node.Key)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}

		_, err = w.Write(node.Value)
		if err != nil {
			return 0, stacktrace.Propagate(err)
		}
		numNodes++
	}

	return numNodes, nil
}

func writeChunkHashes(snapshotFile io.ReadSeeker, w io.Writer) error {
	bw := bufio.NewWriter(w)
	defer bw.Flush()

	_, err := snapshotFile.Seek(0, io.SeekStart)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	buf := make([]byte, SnapshotChunkSize)
	for {
		n, err := io.ReadFull(snapshotFile, buf)
		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
			return stacktrace.Propagate(err)
		}
		if n == 0 {
			break
		}

		// write through the buffered writer so the deferred Flush covers it
		hash := sha256.Sum256(buf[:n])
		c, err := bw.Write(hash[:])
		if err != nil {
			return stacktrace.Propagate(err)
		}

		if c != SnapshotChunkHashSize {
			return stacktrace.NewError("unexpected chunk hash size")
		}

		if n < SnapshotChunkSize {
			break
		}
	}

	return nil
}

// SnapshotApplier handles applying snapshot chunks
type SnapshotApplier struct {
	writeIndex          transaction.WriteIndex
	tree                *iavl.MutableTree
	treeVersion         int64
	expectedFinalHash   []byte
	expectedChunkHashes [][]byte

	pipeWriter *io.PipeWriter
	pipeReader *io.PipeReader
	zstdReader io.ReadCloser

	importer         *iavl.Importer
	compressImporter iavl.NodeImporter
	importerWg       sync.WaitGroup

	numImportedNodes        int
	claimedNodeCount        int
	numImportedIndexEntries int
	claimedIndexEntryCount  int
	done                    bool
}
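
// The applier decompresses and imports chunks as they arrive: Apply feeds raw
// chunk bytes into an io.Pipe, a zstd reader wraps the pipe's read end, and
// streamingImporter consumes the decompressed stream on a background
// goroutine. This keeps memory usage bounded at roughly one chunk plus
// decoder buffers instead of buffering the whole snapshot.
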
func (a *SnapshotApplier) initApplier() error {
	pipeReader, pipeWriter := io.Pipe()

	zstdReader, err := zstd.NewReader(pipeReader)
	if err != nil {
		_ = pipeReader.Close()
		_ = pipeWriter.Close()
		return stacktrace.Propagate(err)
	}

	a.pipeWriter = pipeWriter
	a.pipeReader = pipeReader
	a.zstdReader = zstdReader.IOReadCloser()

	a.importerWg.Go(a.streamingImporter)

	return nil
}

func (a *SnapshotApplier) Apply(chunkIndex int, chunkBytes []byte) error {
	if len(chunkBytes) > SnapshotChunkSize {
		return stacktrace.Propagate(ErrMalformedChunk, "chunk too large")
	}
	hash := sha256.Sum256(chunkBytes)
	if !bytes.Equal(a.expectedChunkHashes[chunkIndex], hash[:]) {
		return stacktrace.Propagate(ErrMalformedChunk, "hash mismatch")
	}

	if chunkIndex == 0 {
		if len(chunkBytes) < 88 {
			return stacktrace.Propagate(ErrMalformedChunk, "chunk too small")
		}

		if string(chunkBytes[0:18]) != SnapshotFileMagic {
			return stacktrace.Propagate(ErrMalformedChunk, "invalid file magic")
		}

		if binary.BigEndian.Uint32(chunkBytes[20:]) != SnapshotFormatVersion {
			return stacktrace.Propagate(ErrMalformedChunk, "invalid snapshot format")
		}

		if binary.BigEndian.Uint64(chunkBytes[24:]) != uint64(a.treeVersion) {
			return stacktrace.Propagate(ErrMalformedChunk, "mismatched tree version")
		}

		if !bytes.Equal(chunkBytes[32:64], a.expectedFinalHash) {
			return stacktrace.Propagate(ErrMalformedChunk, "mismatched declared tree hash")
		}

		declaredFileSize := 88 + binary.BigEndian.Uint64(chunkBytes[64:])
		minExpectedSize := uint64((len(a.expectedChunkHashes) - 1) * SnapshotChunkSize)
		maxExpectedSize := uint64(len(a.expectedChunkHashes) * SnapshotChunkSize)
		if declaredFileSize < minExpectedSize ||
			declaredFileSize > maxExpectedSize {
			return stacktrace.Propagate(ErrMalformedChunk, "unexpected compressed section length")
		}

		a.claimedIndexEntryCount = int(binary.BigEndian.Uint64(chunkBytes[72:]))
		a.claimedNodeCount = int(binary.BigEndian.Uint64(chunkBytes[80:]))

		// move to the start of the compressed portion
		chunkBytes = chunkBytes[88:]

		err := a.initApplier()
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	_, err := a.pipeWriter.Write(chunkBytes)
	if err != nil {
		return stacktrace.Propagate(err)
	}

	isLastChunk := chunkIndex == len(a.expectedChunkHashes)-1
	if isLastChunk {
		_ = a.pipeWriter.Close()
		// wait for the importer to finish reading and importing everything
		a.importerWg.Wait()

		if a.numImportedIndexEntries != a.claimedIndexEntryCount {
			return stacktrace.Propagate(ErrIndexEntryCountMismatch, "imported index entry count mismatch")
		}

		if a.numImportedNodes != a.claimedNodeCount {
			return stacktrace.Propagate(ErrTreeHashMismatch, "imported node count mismatch")
		}

		err := a.writeIndex.Commit()
		if err != nil {
			return stacktrace.Propagate(err)
		}

		err = a.importer.Commit()
		if err != nil {
			if strings.Contains(err.Error(), "invalid node structure") {
				return stacktrace.Propagate(errors.Join(ErrMalformedChunk, err))
			}
			return stacktrace.Propagate(err)
		}

		err = a.closeCommons()
		if err != nil {
			return stacktrace.Propagate(err)
		}
		a.done = true

		if !bytes.Equal(a.tree.Hash(), a.expectedFinalHash) {
			return stacktrace.Propagate(ErrTreeHashMismatch)
		}
	}

	return nil
}
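
// streamingImporter runs on its own goroutine and drains the decompressed
// stream: first claimedIndexEntryCount length-prefixed index entries, then
// tree nodes until EOF. Read and import failures return silently on purpose;
// an incomplete import is caught afterwards by the entry/node count checks
// and the final tree hash comparison in Apply.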
func (a *SnapshotApplier) streamingImporter() {
	for {
		if a.numImportedIndexEntries < a.claimedIndexEntryCount {
			entryHeader := make([]byte, 4+4)
			n, err := io.ReadFull(a.zstdReader, entryHeader)
			if err != nil || n != 8 {
				// err may be EOF here if the stream ended early; the entry count check in Apply catches this
				return
			}

			// validate lengths against sensible limits to prevent OOM DoS by malicious third parties
			keyLength := binary.BigEndian.Uint32(entryHeader[0:4])
			valueLength := binary.BigEndian.Uint32(entryHeader[4:8])
			if keyLength > 1024*1024 || valueLength > 1024*1024 {
				return
			}

			key := make([]byte, keyLength)

			n, err = io.ReadFull(a.zstdReader, key)
			if err != nil || n != len(key) {
				// this shouldn't happen unless the data is corrupt
				// we can return silently here: since we didn't import everything, the checks in Apply will fail anyway
				return
			}

			value := make([]byte, valueLength)
			n, err = io.ReadFull(a.zstdReader, value)
			if err != nil || n != len(value) {
				return
			}

			err = a.writeIndex.IndexDB().Set(key, value)
			if err != nil {
				// we can return silently here: since we didn't import everything, the checks in Apply will fail anyway
				return
			}

			a.numImportedIndexEntries++
		} else {
			nodeHeader := make([]byte, 9+4+4)
			n, err := io.ReadFull(a.zstdReader, nodeHeader)
			if err != nil || n != 9+4+4 {
				// err may be EOF here, which is expected (end of the node stream)
				return
			}

			// validate lengths against sensible limits to prevent OOM DoS by malicious third parties
			keyLength := binary.BigEndian.Uint32(nodeHeader[9:13])
			var key []byte
			if keyLength != 0xffffffff {
				if keyLength > 1024*1024 {
					return
				}
				key = make([]byte, keyLength)

				n, err = io.ReadFull(a.zstdReader, key)
				if err != nil || n != len(key) {
					// this shouldn't happen unless the data is corrupt
					// we can return silently here: since we didn't import all nodes, the tree hash won't match anyway
					return
				}
			}

			valueLength := binary.BigEndian.Uint32(nodeHeader[13:17])
			var value []byte
			if valueLength != 0xffffffff {
				if valueLength > 1024*1024 {
					return
				}
				value = make([]byte, valueLength)
				n, err = io.ReadFull(a.zstdReader, value)
				if err != nil || n != len(value) {
					return
				}
			}

			err = a.compressImporter.Add(&iavl.ExportNode{
				Height:  int8(nodeHeader[0]),
				Version: int64(binary.BigEndian.Uint64(nodeHeader[1:9])),
				Key:     key,
				Value:   value,
			})
			if err != nil {
				// this shouldn't happen unless the data is corrupt
				// we can return silently here: since we didn't import all nodes, the tree hash won't match anyway
				return
			}
			a.numImportedNodes++
		}
	}
}

// Abort cancels the snapshot application
func (a *SnapshotApplier) Abort() error {
	err := a.closeCommons()
	if err != nil {
		return stacktrace.Propagate(err)
	}

	if a.tree != nil {
		err = a.tree.DeleteVersionsFrom(0)
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	return nil
}
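
// closeCommons tears the streaming pipeline down in dependency order: the
// zstd reader and both pipe ends first (closing the pipe unblocks a stuck
// importer goroutine), then the index transaction rollback, then a wait for
// the importer goroutine before closing the tree importer it uses.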
func (a *SnapshotApplier) closeCommons() error {
	if a.zstdReader != nil {
		err := a.zstdReader.Close()
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	if a.pipeReader != nil {
		err := a.pipeReader.Close()
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	if a.pipeWriter != nil {
		err := a.pipeWriter.Close()
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	if a.writeIndex != nil {
		err := a.writeIndex.Rollback()
		if err != nil {
			return stacktrace.Propagate(err)
		}
	}

	// waiting on a zero-value WaitGroup returns immediately, so this is safe
	// even if the importer goroutine was never started (comparing a
	// sync.WaitGroup by value would copy its internal lock and trip go vet)
	a.importerWg.Wait()

	a.importer.Close()

	return nil
}

// Done returns true if the snapshot has been fully applied
func (a *SnapshotApplier) Done() bool {
	return a.done
}

// ReadSnapshotMetadata reads metadata from a snapshot file
func (s *SnapshotStore) ReadSnapshotMetadata(filename string) (height uint64, format uint32, chunks uint32, hash []byte, chunksumsData []byte, err error) {
	// Extract height from filename pattern: %020d.snapshot
	base := filepath.Base(filename)
	if !strings.HasSuffix(base, ".snapshot") {
		return 0, 0, 0, nil, nil, stacktrace.NewError("invalid snapshot filename format: %s", filename)
	}
	heightStr := strings.TrimSuffix(base, ".snapshot")
	height, err = strconv.ParseUint(heightStr, 10, 64)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to parse height from filename: %s", filename)
	}

	// Open and read snapshot file header
	f, err := os.Open(filename)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to open snapshot file")
	}
	defer f.Close()

	// Read file magic (18 bytes)
	magic := make([]byte, len(SnapshotFileMagic))
	_, err = io.ReadFull(f, magic)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read file magic")
	}
	if string(magic) != SnapshotFileMagic {
		return 0, 0, 0, nil, nil, stacktrace.NewError("invalid file magic")
	}

	// Read version bytes (6 bytes)
	versionBytes := make([]byte, 6)
	_, err = io.ReadFull(f, versionBytes)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read version bytes")
	}
	format = binary.BigEndian.Uint32(versionBytes[2:])

	// Read height (8 bytes, big-endian)
	heightBytes := make([]byte, 8)
	_, err = io.ReadFull(f, heightBytes)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read height")
	}
	fileHeight := int64(binary.BigEndian.Uint64(heightBytes))
	if fileHeight != int64(height) {
		return 0, 0, 0, nil, nil, stacktrace.NewError("height mismatch")
	}

	// Read tree hash (32 bytes)
	hash = make([]byte, 32)
	_, err = io.ReadFull(f, hash)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read tree hash")
	}

	// Read corresponding chunksums file
	chunksumsFilename := strings.TrimSuffix(filename, ".snapshot") + ".chunksums"
	chunksumsData, err = os.ReadFile(chunksumsFilename)
	if err != nil {
		return 0, 0, 0, nil, nil, stacktrace.Propagate(err, "failed to read chunksums file")
	}

	// Calculate number of chunks
	chunks = uint32(len(chunksumsData) / SnapshotChunkHashSize)

	return height, format, chunks, hash, chunksumsData, nil
}
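
// Chunks are fixed-size slices of the raw snapshot file, header included, so
// chunk i always starts at byte offset i*SnapshotChunkSize. writeChunkHashes
// uses the same slicing, which is what makes the per-chunk hash verification
// in SnapshotApplier.Apply line up.
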
// LoadSnapshotChunk reads a chunk from a snapshot file at the given height and chunk index
func (s *SnapshotStore) LoadSnapshotChunk(snapshotDirectory string, height uint64, chunkIndex int) ([]byte, error) {
	snapshotFilename := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.snapshot", height))

	// Open the snapshot file
	f, err := os.Open(snapshotFilename)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to open snapshot file: %s", snapshotFilename)
	}
	defer f.Close()

	// Calculate the offset for the requested chunk (chunks start from the beginning of the file, including the header)
	offset := int64(chunkIndex) * int64(SnapshotChunkSize)
	_, err = f.Seek(offset, io.SeekStart)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to seek to chunk offset")
	}

	// Read up to SnapshotChunkSize bytes; io.ReadFull guards against short
	// reads mid-file (a bare Read may return fewer bytes without an error)
	chunkData := make([]byte, SnapshotChunkSize)
	n, err := io.ReadFull(f, chunkData)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, stacktrace.Propagate(err, "failed to read chunk data")
	}

	// If we read less than SnapshotChunkSize (the last chunk), trim the slice
	if n < SnapshotChunkSize {
		chunkData = chunkData[:n]
	}

	return chunkData, nil
}

func (s *SnapshotStore) PruneOldSnapshots(snapshotDirectory string, retentionCount int) (int, error) {
	// Get list of all snapshots sorted by height (newest first)
	files, err := filepath.Glob(filepath.Join(snapshotDirectory, "*.snapshot"))
	if err != nil {
		return 0, stacktrace.Propagate(err, "failed to list snapshots")
	}

	if len(files) <= retentionCount {
		return 0, nil
	}

	// Extract heights from filenames and sort
	heights := make([]uint64, 0, len(files))
	for _, f := range files {
		base := filepath.Base(f)
		heightStr := strings.TrimSuffix(base, ".snapshot")
		h, err := strconv.ParseUint(heightStr, 10, 64)
		if err != nil {
			continue
		}
		heights = append(heights, h)
	}

	slices.SortFunc(heights, func(a, b uint64) int {
		return cmp.Compare(b, a) // sort descending (newest first) without overflow-prone subtraction
	})

	// Delete snapshots beyond retention count
	toDelete := heights[retentionCount:]
	for _, h := range toDelete {
		snapshotFile := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.snapshot", h))
		chunksumsFile := filepath.Join(snapshotDirectory, fmt.Sprintf("%020d.chunksums", h))

		if err := os.Remove(snapshotFile); err != nil && !errors.Is(err, os.ErrNotExist) {
			return 0, stacktrace.Propagate(err, "failed to delete old snapshot file: %s", snapshotFile)
		}

		if err := os.Remove(chunksumsFile); err != nil && !errors.Is(err, os.ErrNotExist) {
			return 0, stacktrace.Propagate(err, "failed to delete old chunksums file: %s", chunksumsFile)
		}
	}

	return len(toDelete), nil
}

func (s *SnapshotStore) MostRecentSnapshotHeight(snapshotDirectory string) (uint64, error) {
	files, err := filepath.Glob(filepath.Join(snapshotDirectory, "*.snapshot"))
	if err != nil {
		return 0, stacktrace.Propagate(err, "failed to list snapshots")
	}

	var maxHeight uint64
	for _, f := range files {
		base := filepath.Base(f)
		heightStr := strings.TrimSuffix(base, ".snapshot")
		h, err := strconv.ParseUint(heightStr, 10, 64)
		if err != nil {
			continue
		}
		if h > maxHeight {
			maxHeight = h
		}
	}

	return maxHeight, nil
}
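
// A rough usage sketch for the export side (hypothetical caller; dir, height,
// tx, headerGetter, updateInterval and numRecentHeaders are placeholders, not
// part of this package):
//
//	f, err := os.Create(filepath.Join(dir, fmt.Sprintf("%020d.snapshot", height)))
//	// ... handle err ...
//	err = Snapshot.Export(f, tx, headerGetter, updateInterval, numRecentHeaders)
//	// ... handle err ...
//	sums, err := os.Create(filepath.Join(dir, fmt.Sprintf("%020d.chunksums", height)))
//	// ... handle err ...
//	err = Snapshot.WriteChunkHashes(f, sums) // seeks f back to the start itself
//
// The %020d naming matters: ReadSnapshotMetadata, LoadSnapshotChunk and
// PruneOldSnapshots all derive heights from these zero-padded filenames.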