diff --git a/l1infotreesync/downloader.go b/l1infotreesync/downloader.go
index 2051f7b58..83d343291 100644
--- a/l1infotreesync/downloader.go
+++ b/l1infotreesync/downloader.go
@@ -86,7 +86,7 @@ func buildAppender(client EthClienter, globalExitRoot, rollupManager common.Addr
 				l, err,
 			)
 		}
-		log.Infof("updateL1InfoTreeSignatureV2: expected root: %s", common.Bytes2Hex(l1InfoTreeUpdate.CurrentL1InfoRoot[:]))
+		log.Infof("updateL1InfoTreeSignatureV2: expected root: %s", common.BytesToHash(l1InfoTreeUpdate.CurrentL1InfoRoot[:]).String())
 		return nil
 	}
 
diff --git a/l1infotreesync/e2e_test.go b/l1infotreesync/e2e_test.go
index 10ef34e7f..ac36f8d20 100644
--- a/l1infotreesync/e2e_test.go
+++ b/l1infotreesync/e2e_test.go
@@ -173,7 +173,7 @@ func TestWithReorgs(t *testing.T) {
 	rd, err := reorgdetector.New(client.Client(), reorgdetector.Config{DBPath: dbPathReorg, CheckReorgsInterval: cdktypes.NewDuration(time.Millisecond * 30)})
 	require.NoError(t, err)
 	require.NoError(t, rd.Start(ctx))
-	syncer, err := l1infotreesync.New(ctx, dbPathSyncer, gerAddr, verifyAddr, 10, etherman.LatestBlock, rd, client.Client(), time.Millisecond, 0, time.Second, 5)
+	syncer, err := l1infotreesync.New(ctx, dbPathSyncer, gerAddr, verifyAddr, 10, etherman.LatestBlock, rd, client.Client(), time.Millisecond, 0, time.Second, 25)
 	require.NoError(t, err)
 	go syncer.Start(ctx)
 
@@ -272,12 +272,12 @@ func TestWithReorgs(t *testing.T) {
 
 func TestStressAndReorgs(t *testing.T) {
 	const (
-		totalIterations       = 200  // Have tested with much larger number (+10k)
-		enableReorgs          = true // test fails when set to true
-		reorgEveryXIterations = 53
-		maxReorgDepth         = 12
-		maxEventsPerBlock     = 7
-		maxRollups            = 31
+		totalIterations       = 2 // Have tested with much larger number (+10k)
+		blocksInIteration     = 140
+		reorgEveryXIterations = 70
+		reorgSizeInBlocks     = 2
+		maxRollupID           = 31
+		extraBlocksToMine     = 10
 	)
 
 	ctx := context.Background()
@@ -292,51 +292,49 @@ func TestStressAndReorgs(t *testing.T) {
 	rd, err := reorgdetector.New(client.Client(), reorgdetector.Config{DBPath: dbPathReorg, CheckReorgsInterval: cdktypes.NewDuration(time.Millisecond * 100)})
 	require.NoError(t, err)
 	require.NoError(t, rd.Start(ctx))
-	syncer, err := l1infotreesync.New(ctx, dbPathSyncer, gerAddr, verifyAddr, 10, etherman.LatestBlock, rd, client.Client(), time.Millisecond, 0, time.Second, 5)
+	syncer, err := l1infotreesync.New(ctx, dbPathSyncer, gerAddr, verifyAddr, 10, etherman.LatestBlock, rd, client.Client(), time.Millisecond, 0, time.Second, 100)
 	require.NoError(t, err)
 	go syncer.Start(ctx)
 
-	var extraBlocksToMine int
-	for i := 0; i < totalIterations; i++ {
-		for j := 0; j < i%maxEventsPerBlock; j++ {
-			switch j % 3 {
-			case 0: // Update L1 Info Tree
-				_, err := gerSc.UpdateExitRoot(auth, common.HexToHash(strconv.Itoa(i)))
-				require.NoError(t, err)
-			case 1: // Update L1 Info Tree + Rollup Exit Tree
-				newLocalExitRoot := common.HexToHash(strconv.Itoa(i) + "ffff" + strconv.Itoa(j))
-				_, err := verifySC.VerifyBatches(auth, 1+uint32(i%maxRollups), 0, newLocalExitRoot, common.Hash{}, true)
-				require.NoError(t, err)
-			case 2: // Update Rollup Exit Tree
-				newLocalExitRoot := common.HexToHash(strconv.Itoa(i) + "ffff" + strconv.Itoa(j))
-				_, err := verifySC.VerifyBatches(auth, 1+uint32(i%maxRollups), 0, newLocalExitRoot, common.Hash{}, false)
-				require.NoError(t, err)
-			}
-		}
+	updateL1InfoTreeAndRollupExitTree := func(i, j int, rollupID uint32) {
+		// Update L1 Info Tree
+		_, err := gerSc.UpdateExitRoot(auth, common.HexToHash(strconv.Itoa(i)))
+		require.NoError(t, err)
 
-		time.Sleep(time.Microsecond * 30) // Sleep just enough for goroutine to switch
+		// Update L1 Info Tree + Rollup Exit Tree
+		newLocalExitRoot := common.HexToHash(strconv.Itoa(i) + "ffff" + strconv.Itoa(j))
+		_, err = verifySC.VerifyBatches(auth, rollupID, 0, newLocalExitRoot, common.Hash{}, true)
+		require.NoError(t, err)
 
-		if enableReorgs && i > 0 && i%reorgEveryXIterations == 0 {
-			reorgDepth := i%maxReorgDepth + 1
-			extraBlocksToMine += reorgDepth + 1
-			currentBlockNum, err := client.Client().BlockNumber(ctx)
-			require.NoError(t, err)
+		// Update Rollup Exit Tree
+		newLocalExitRoot = common.HexToHash(strconv.Itoa(i) + "fffa" + strconv.Itoa(j))
+		_, err = verifySC.VerifyBatches(auth, rollupID, 0, newLocalExitRoot, common.Hash{}, false)
+		require.NoError(t, err)
+	}
 
-			targetReorgBlockNum := currentBlockNum
-			if uint64(reorgDepth) <= currentBlockNum {
-				targetReorgBlockNum -= uint64(reorgDepth)
-			}
+	for i := 1; i <= totalIterations; i++ {
+		for j := 1; j <= blocksInIteration; j++ {
+			commitBlocks(t, client, 1, time.Millisecond*10)
 
-			if targetReorgBlockNum < currentBlockNum { // we are dealing with uints...
-				reorgBlock, err := client.Client().BlockByNumber(ctx, big.NewInt(int64(targetReorgBlockNum)))
+			if j%reorgEveryXIterations == 0 {
+				currentBlockNum, err := client.Client().BlockNumber(ctx)
 				require.NoError(t, err)
-				err = client.Fork(reorgBlock.Hash())
+
+				block, err := client.Client().BlockByNumber(ctx, big.NewInt(int64(currentBlockNum-reorgSizeInBlocks)))
 				require.NoError(t, err)
+				reorgFrom := block.Hash()
+				err = client.Fork(reorgFrom)
+				require.NoError(t, err)
+
+				fmt.Println("reorg from block:", block.Number().Uint64()+1)
+			} else {
+				updateL1InfoTreeAndRollupExitTree(i, j, uint32(j%maxRollupID)+1)
 			}
 		}
 	}
 
-	commitBlocks(t, client, extraBlocksToMine, time.Millisecond*100)
+	commitBlocks(t, client, 1, time.Millisecond*10)
+
 	waitForSyncerToCatchUp(ctx, t, syncer, client)
 
 	// Assert rollup exit root
@@ -356,8 +354,9 @@ func TestStressAndReorgs(t *testing.T) {
 	info, err := syncer.GetInfoByIndex(ctx, lastRoot.Index)
 	require.NoError(t, err, fmt.Sprintf("index: %d", lastRoot.Index))
 
-	require.Equal(t, common.Hash(expectedL1InfoRoot), lastRoot.Hash)
+	t.Logf("lastRoot: %+v", lastRoot)
 	require.Equal(t, common.Hash(expectedGER), info.GlobalExitRoot, fmt.Sprintf("%+v", info))
+	require.Equal(t, common.Hash(expectedL1InfoRoot), lastRoot.Hash)
 }
 
 func waitForSyncerToCatchUp(ctx context.Context, t *testing.T, syncer *l1infotreesync.L1InfoTreeSync, client *simulated.Backend) {
@@ -366,7 +365,7 @@ func waitForSyncerToCatchUp(ctx context.Context, t *testing.T, syncer *l1infotre
 	syncerUpToDate := false
 	var errMsg string
 
-	for i := 0; i < 50; i++ {
+	for i := 0; i < 200; i++ {
 		lpb, err := syncer.GetLastProcessedBlock(ctx)
 		require.NoError(t, err)
 		lb, err := client.Client().BlockNumber(ctx)
diff --git a/l1infotreesync/processor.go b/l1infotreesync/processor.go
index a672c5eff..a61b9e867 100644
--- a/l1infotreesync/processor.go
+++ b/l1infotreesync/processor.go
@@ -211,6 +211,11 @@ func (p *processor) Reorg(ctx context.Context, firstReorgedBlock uint64) error {
 		return err
 	}
 
+	_, err = tx.Exec(`DELETE FROM l1info_leaf WHERE block_num >= $1;`, firstReorgedBlock)
+	if err != nil {
+		return err
+	}
+
 	if err = p.l1InfoTree.Reorg(tx, firstReorgedBlock); err != nil {
 		return err
 	}
@@ -240,7 +245,7 @@ func (p *processor) ProcessBlock(ctx context.Context, b sync.Block) error {
 		}
 	}()
 
-	if _, err := tx.Exec(`INSERT INTO block (num) VALUES ($1)`, b.Num); err != nil {
+	if _, err = tx.Exec(`INSERT INTO block (num) VALUES ($1)`, b.Num); err != nil {
 		return fmt.Errorf("err: %w", err)
 	}
 
diff --git a/lastgersync/evmdownloader.go b/lastgersync/evmdownloader.go
index 97235c289..e76bb5788 100644
--- a/lastgersync/evmdownloader.go
+++ b/lastgersync/evmdownloader.go
@@ -105,7 +105,11 @@ func (d *downloader) Download(ctx context.Context, fromBlock uint64, downloadedC
 			break
 		}
 
-		blockHeader := d.GetBlockHeader(ctx, lastBlock)
+		blockHeader, isCanceled := d.GetBlockHeader(ctx, lastBlock)
+		if isCanceled {
+			return
+		}
+
 		block := &sync.EVMBlock{
 			EVMBlockHeader: sync.EVMBlockHeader{
 				Num:        blockHeader.Num,
diff --git a/reorgdetector/reorgdetector.go b/reorgdetector/reorgdetector.go
index 8efde6b9a..5e0d17383 100644
--- a/reorgdetector/reorgdetector.go
+++ b/reorgdetector/reorgdetector.go
@@ -120,12 +120,21 @@ func (rd *ReorgDetector) detectReorgInTrackedList(ctx context.Context) error {
 		errGroup errgroup.Group
 	)
 
-	rd.trackedBlocksLock.Lock()
-	defer rd.trackedBlocksLock.Unlock()
+	subscriberIDs := rd.getSubscriberIDs()
 
-	for id, hdrs := range rd.trackedBlocks {
+	for _, id := range subscriberIDs {
 		id := id
-		hdrs := hdrs
+
+		// This is done like this because of a possible deadlock between AddBlocksToTrack and detectReorgInTrackedList
+		// because detectReorgInTrackedList would take the trackedBlocksLock and try to notify the subscriber in case of a reorg
+		// but the subscriber would be trying to add a block to track and save it to trackedBlocks, resulting in a deadlock
+		rd.trackedBlocksLock.RLock()
+		hdrs, ok := rd.trackedBlocks[id]
+		rd.trackedBlocksLock.RUnlock()
+
+		if !ok {
+			continue
+		}
 
 		errGroup.Go(func() error {
 			headers := hdrs.getSorted()
@@ -153,6 +162,8 @@ func (rd *ReorgDetector) detectReorgInTrackedList(ctx context.Context) error {
 					continue
 				}
 
+				log.Info("[ReorgDetector] Reorg detected", "blockNum", hdr.Num)
+
 				// Notify the subscriber about the reorg
 				rd.notifySubscriber(id, hdr)
 
diff --git a/reorgdetector/reorgdetector_db.go b/reorgdetector/reorgdetector_db.go
index 3174cbc02..552f9d695 100644
--- a/reorgdetector/reorgdetector_db.go
+++ b/reorgdetector/reorgdetector_db.go
@@ -53,6 +53,11 @@ func (rd *ReorgDetector) getTrackedBlocks(ctx context.Context) (map[string]*head
 
 // saveTrackedBlock saves the tracked block for a subscriber in db and in memory
 func (rd *ReorgDetector) saveTrackedBlock(ctx context.Context, id string, b header) error {
+	rd.trackedBlocksLock.Lock()
+
+	// this has to go after the lock, because of a possible deadlock between AddBlocksToTrack and detectReorgInTrackedList
+	// because AddBlocksToTrack would start a transaction on db, but detectReorgInTrackedList would lock the trackedBlocksLock
+	// and then try to start a transaction on db, resulting in a deadlock
 	tx, err := rd.db.BeginRw(ctx)
 	if err != nil {
 		return err
@@ -60,7 +65,6 @@ func (rd *ReorgDetector) saveTrackedBlock(ctx context.Context, id string, b head
 
 	defer tx.Rollback()
 
-	rd.trackedBlocksLock.Lock()
 	hdrs, ok := rd.trackedBlocks[id]
 	if !ok || hdrs.isEmpty() {
 		hdrs = newHeadersList(b)
diff --git a/reorgdetector/reorgdetector_sub.go b/reorgdetector/reorgdetector_sub.go
index 675a81c5f..ebdb97e95 100644
--- a/reorgdetector/reorgdetector_sub.go
+++ b/reorgdetector/reorgdetector_sub.go
@@ -34,9 +34,24 @@ func (rd *ReorgDetector) Subscribe(id string) (*Subscription, error) {
 func (rd *ReorgDetector) notifySubscriber(id string, startingBlock header) {
 	// Notify subscriber about this particular reorg
 	rd.subscriptionsLock.RLock()
-	if sub, ok := rd.subscriptions[id]; ok {
+	sub, ok := rd.subscriptions[id]
+	rd.subscriptionsLock.RUnlock()
+
+	if ok {
 		sub.ReorgedBlock <- startingBlock.Num
 		<-sub.ReorgProcessed
 	}
-	rd.subscriptionsLock.RUnlock()
+}
+
+// getSubscriberIDs returns a list of subscriber IDs
+func (rd *ReorgDetector) getSubscriberIDs() []string {
+	rd.subscriptionsLock.RLock()
+	defer rd.subscriptionsLock.RUnlock()
+
+	var ids []string
+	for id := range rd.subscriptions {
+		ids = append(ids, id)
+	}
+
+	return ids
 }
diff --git a/sync/evmdownloader.go b/sync/evmdownloader.go
index c9c4e6610..13539f2f3 100644
--- a/sync/evmdownloader.go
+++ b/sync/evmdownloader.go
@@ -2,6 +2,7 @@ package sync
 
 import (
 	"context"
+	"errors"
 	"math/big"
 	"time"
 
@@ -24,7 +25,7 @@ type EVMDownloaderInterface interface {
 	WaitForNewBlocks(ctx context.Context, lastBlockSeen uint64) (newLastBlock uint64)
 	GetEventsByBlockRange(ctx context.Context, fromBlock, toBlock uint64) []EVMBlock
 	GetLogs(ctx context.Context, fromBlock, toBlock uint64) []types.Log
-	GetBlockHeader(ctx context.Context, blockNum uint64) EVMBlockHeader
+	GetBlockHeader(ctx context.Context, blockNum uint64) (EVMBlockHeader, bool)
 }
 
 type LogAppenderMap map[common.Hash]func(b *EVMBlock, l types.Log) error
@@ -101,8 +102,13 @@ func (d *EVMDownloader) Download(ctx context.Context, fromBlock uint64, download
 		if len(blocks) == 0 || blocks[len(blocks)-1].Num < toBlock {
 			// Indicate the last downloaded block if there are not events on it
 			d.log.Debugf("sending block %d to the driver (without events)", toBlock)
+			header, isCanceled := d.GetBlockHeader(ctx, toBlock)
+			if isCanceled {
+				return
+			}
+
 			downloadedCh <- EVMBlock{
-				EVMBlockHeader: d.GetBlockHeader(ctx, toBlock),
+				EVMBlockHeader: header,
 			}
 		}
 		fromBlock = toBlock + 1
@@ -170,44 +176,53 @@ func (d *EVMDownloaderImplementation) WaitForNewBlocks(
 }
 
 func (d *EVMDownloaderImplementation) GetEventsByBlockRange(ctx context.Context, fromBlock, toBlock uint64) []EVMBlock {
-	blocks := []EVMBlock{}
-	logs := d.GetLogs(ctx, fromBlock, toBlock)
-	for _, l := range logs {
-		if len(blocks) == 0 || blocks[len(blocks)-1].Num < l.BlockNumber {
-			b := d.GetBlockHeader(ctx, l.BlockNumber)
-			if b.Hash != l.BlockHash {
-				d.log.Infof(
-					"there has been a block hash change between the event query and the block query "+
-						"for block %d: %s vs %s. Retrying.",
-					l.BlockNumber, b.Hash, l.BlockHash,
-				)
-				return d.GetEventsByBlockRange(ctx, fromBlock, toBlock)
+	select {
+	case <-ctx.Done():
+		return nil
+	default:
+		blocks := []EVMBlock{}
+		logs := d.GetLogs(ctx, fromBlock, toBlock)
+		for _, l := range logs {
+			if len(blocks) == 0 || blocks[len(blocks)-1].Num < l.BlockNumber {
+				b, canceled := d.GetBlockHeader(ctx, l.BlockNumber)
+				if canceled {
+					return nil
+				}
+
+				if b.Hash != l.BlockHash {
+					d.log.Infof(
+						"there has been a block hash change between the event query and the block query "+
+							"for block %d: %s vs %s. Retrying.",
+						l.BlockNumber, b.Hash, l.BlockHash,
+					)
+					return d.GetEventsByBlockRange(ctx, fromBlock, toBlock)
+				}
+				blocks = append(blocks, EVMBlock{
+					EVMBlockHeader: EVMBlockHeader{
+						Num:        l.BlockNumber,
+						Hash:       l.BlockHash,
+						Timestamp:  b.Timestamp,
+						ParentHash: b.ParentHash,
+					},
+					Events: []interface{}{},
+				})
 			}
-			blocks = append(blocks, EVMBlock{
-				EVMBlockHeader: EVMBlockHeader{
-					Num:        l.BlockNumber,
-					Hash:       l.BlockHash,
-					Timestamp:  b.Timestamp,
-					ParentHash: b.ParentHash,
-				},
-				Events: []interface{}{},
-			})
-		}
 
-		for {
-			attempts := 0
-			err := d.appender[l.Topics[0]](&blocks[len(blocks)-1], l)
-			if err != nil {
-				attempts++
-				d.log.Error("error trying to append log: ", err)
-				d.rh.Handle("getLogs", attempts)
-				continue
+			for {
+				attempts := 0
+				err := d.appender[l.Topics[0]](&blocks[len(blocks)-1], l)
+				if err != nil {
+					attempts++
+					d.log.Error("error trying to append log: ", err)
+					d.rh.Handle("getLogs", attempts)
+					continue
+				}
+				break
 			}
-			break
 		}
-	}
 
-	return blocks
+		return blocks
+	}
 }
 
 func (d *EVMDownloaderImplementation) GetLogs(ctx context.Context, fromBlock, toBlock uint64) []types.Log {
@@ -224,6 +239,11 @@
 	for {
 		unfilteredLogs, err = d.ethClient.FilterLogs(ctx, query)
 		if err != nil {
+			if errors.Is(err, context.Canceled) {
+				// context is canceled, we don't want to fatal on max attempts in this case
+				return nil
+			}
+
 			attempts++
 			d.log.Error("error calling FilterLogs to eth client: ", err)
 			d.rh.Handle("getLogs", attempts)
@@ -243,11 +263,16 @@ func (d *EVMDownloaderImplementation) GetLogs(ctx context.Context, fromBlock, to
 	return logs
 }
 
-func (d *EVMDownloaderImplementation) GetBlockHeader(ctx context.Context, blockNum uint64) EVMBlockHeader {
+func (d *EVMDownloaderImplementation) GetBlockHeader(ctx context.Context, blockNum uint64) (EVMBlockHeader, bool) {
 	attempts := 0
 	for {
 		header, err := d.ethClient.HeaderByNumber(ctx, new(big.Int).SetUint64(blockNum))
 		if err != nil {
+			if errors.Is(err, context.Canceled) {
+				// context is canceled, we don't want to fatal on max attempts in this case
+				return EVMBlockHeader{}, true
+			}
+
 			attempts++
 			d.log.Errorf("error getting block header for block %d, err: %v", blockNum, err)
 			d.rh.Handle("getBlockHeader", attempts)
@@ -258,6 +283,6 @@ func (d *EVMDownloaderImplementation) GetBlockHeader(ctx context.Context, blockN
 			Hash:       header.Hash(),
 			ParentHash: header.ParentHash,
 			Timestamp:  header.Time,
-		}
+		}, false
 	}
 }
diff --git a/sync/evmdownloader_test.go b/sync/evmdownloader_test.go
index 59c43b8ff..6fd7e1a2f 100644
--- a/sync/evmdownloader_test.go
+++ b/sync/evmdownloader_test.go
@@ -369,14 +369,16 @@ func TestGetBlockHeader(t *testing.T) {
 
 	// at first attempt
 	clientMock.On("HeaderByNumber", ctx, blockNumBig).Return(returnedBlock, nil).Once()
-	actualBlock := d.GetBlockHeader(ctx, blockNum)
+	actualBlock, isCanceled := d.GetBlockHeader(ctx, blockNum)
 	assert.Equal(t, expectedBlock, actualBlock)
+	assert.False(t, isCanceled)
 
 	// after error from client
 	clientMock.On("HeaderByNumber", ctx, blockNumBig).Return(nil, errors.New("foo")).Once()
 	clientMock.On("HeaderByNumber", ctx, blockNumBig).Return(returnedBlock, nil).Once()
-	actualBlock = d.GetBlockHeader(ctx, blockNum)
+	actualBlock, isCanceled = d.GetBlockHeader(ctx, blockNum)
 	assert.Equal(t, expectedBlock, actualBlock)
+	assert.False(t, isCanceled)
 }
 
 func buildAppender() LogAppenderMap {
diff --git a/sync/evmdriver.go b/sync/evmdriver.go
index ae2f3a952..af8bcb7cf 100644
--- a/sync/evmdriver.go
+++ b/sync/evmdriver.go
@@ -65,12 +65,15 @@ func NewEVMDriver(
 }
 
 func (d *EVMDriver) Sync(ctx context.Context) {
+	syncID := 0
 reset:
 	var (
 		lastProcessedBlock uint64
 		attempts           int
 		err                error
 	)
+
+	syncID++
 	for {
 		lastProcessedBlock, err = d.processor.GetLastProcessedBlock(ctx)
 		if err != nil {
@@ -84,17 +87,19 @@ reset:
 	cancellableCtx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
+	log.Info("Starting sync. syncID: ", syncID, " lastProcessedBlock", lastProcessedBlock)
 	// start downloading
 	downloadCh := make(chan EVMBlock, d.downloadBufferSize)
-	go d.downloader.Download(cancellableCtx, lastProcessedBlock, downloadCh)
+	go d.downloader.Download(cancellableCtx, lastProcessedBlock+1, downloadCh)
 
 	for {
 		select {
 		case b := <-downloadCh:
-			d.log.Debug("handleNewBlock: ", b.Num, b.Hash)
+			d.log.Debug("handleNewBlock: ", b.Num, b.Hash, " syncID ", syncID)
 			d.handleNewBlock(ctx, b)
+			d.log.Debug("handleNewBlock done: ", b.Num, b.Hash, " syncID ", syncID)
 		case firstReorgedBlock := <-d.reorgSub.ReorgedBlock:
-			d.log.Debug("handleReorg: ", firstReorgedBlock)
+			d.log.Debug("handleReorg: ", firstReorgedBlock, " syncID ", syncID)
 			d.handleReorg(ctx, cancel, downloadCh, firstReorgedBlock)
 			goto reset
 		}
@@ -135,10 +140,7 @@ func (d *EVMDriver) handleReorg(
 ) {
 	// stop downloader
 	cancel()
-	_, ok := <-downloadCh
-	for ok {
-		_, ok = <-downloadCh
-	}
+
 	// handle reorg
 	attempts := 0
 	for {
diff --git a/sync/mock_downloader_test.go b/sync/mock_downloader_test.go
index c965efb6a..f28045b59 100644
--- a/sync/mock_downloader_test.go
+++ b/sync/mock_downloader_test.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.40.1. DO NOT EDIT.
+// Code generated by mockery v2.45.0. DO NOT EDIT.
 
 package sync
 
@@ -20,7 +20,7 @@ func (_m *EVMDownloaderMock) Download(ctx context.Context, fromBlock uint64, dow
 }
 
 // GetBlockHeader provides a mock function with given fields: ctx, blockNum
-func (_m *EVMDownloaderMock) GetBlockHeader(ctx context.Context, blockNum uint64) EVMBlockHeader {
+func (_m *EVMDownloaderMock) GetBlockHeader(ctx context.Context, blockNum uint64) (EVMBlockHeader, bool) {
 	ret := _m.Called(ctx, blockNum)
 
 	if len(ret) == 0 {
@@ -28,13 +28,23 @@ func (_m *EVMDownloaderMock) GetBlockHeader(ctx context.Context, blockNum uint64
 	}
 
 	var r0 EVMBlockHeader
+	var r1 bool
+	if rf, ok := ret.Get(0).(func(context.Context, uint64) (EVMBlockHeader, bool)); ok {
+		return rf(ctx, blockNum)
+	}
 	if rf, ok := ret.Get(0).(func(context.Context, uint64) EVMBlockHeader); ok {
 		r0 = rf(ctx, blockNum)
 	} else {
 		r0 = ret.Get(0).(EVMBlockHeader)
 	}
 
-	return r0
+	if rf, ok := ret.Get(1).(func(context.Context, uint64) bool); ok {
+		r1 = rf(ctx, blockNum)
+	} else {
+		r1 = ret.Get(1).(bool)
+	}
+
+	return r0, r1
 }
 
 // GetEventsByBlockRange provides a mock function with given fields: ctx, fromBlock, toBlock
diff --git a/tree/tree_test.go b/tree/tree_test.go
index b5278723b..97c6c80e5 100644
--- a/tree/tree_test.go
+++ b/tree/tree_test.go
@@ -2,6 +2,7 @@ package tree_test
 
 import (
 	"context"
+	"database/sql"
 	"encoding/json"
 	"fmt"
 	"os"
@@ -18,6 +19,95 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+func TestCheckExpectedRoot(t *testing.T) {
+	t.Parallel()
+
+	createTreeDb := func(t *testing.T) *sql.DB {
+		dbPath := path.Join(t.TempDir(), "file::memory:?cache=shared")
+		log.Debug("DB created at: ", dbPath)
+		require.NoError(t, migrations.RunMigrations(dbPath))
+		treeDB, err := db.NewSQLiteDB(dbPath)
+		require.NoError(t, err)
+
+		return treeDB
+	}
+
+	addLeaves := func(t *testing.T,
+		merkletree *tree.AppendOnlyTree,
+		treeDB *sql.DB,
+		numOfLeavesToAdd, from int) {
+		tx, err := db.NewTx(context.Background(), treeDB)
+		require.NoError(t, err)
+
+		for i := from; i < from+numOfLeavesToAdd; i++ {
+			require.NoError(t, merkletree.AddLeaf(tx, uint64(i), 0, types.Leaf{
+				Index: uint32(i),
+				Hash:  common.HexToHash(fmt.Sprintf("%x", i)),
+			}))
+		}
+
+		require.NoError(t, tx.Commit())
+	}
+
+	t.Run("Check when no reorg", func(t *testing.T) {
+		t.Parallel()
+
+		numOfLeavesToAdd := 10
+		indexToCheck := uint32(numOfLeavesToAdd - 1)
+
+		treeDB := createTreeDb(t)
+		merkletree := tree.NewAppendOnlyTree(treeDB, "")
+
+		addLeaves(t, merkletree, treeDB, numOfLeavesToAdd, 0)
+
+		expectedRoot, err := merkletree.GetLastRoot(context.Background())
+		require.NoError(t, err)
+
+		addLeaves(t, merkletree, treeDB, numOfLeavesToAdd, numOfLeavesToAdd)
+
+		root2, err := merkletree.GetRootByIndex(context.Background(), indexToCheck)
+		require.NoError(t, err)
+		require.Equal(t, expectedRoot.Hash, root2.Hash)
+		require.Equal(t, expectedRoot.Index, root2.Index)
+	})
+
+	t.Run("Check after rebuild tree when reorg", func(t *testing.T) {
+		t.Parallel()
+
+		numOfLeavesToAdd := 10
+		indexToCheck := uint32(numOfLeavesToAdd - 1)
+		treeDB := createTreeDb(t)
+		merkletree := tree.NewAppendOnlyTree(treeDB, "")
+
+		addLeaves(t, merkletree, treeDB, numOfLeavesToAdd, 0)
+
+		expectedRoot, err := merkletree.GetLastRoot(context.Background())
+		require.NoError(t, err)
+
+		addLeaves(t, merkletree, treeDB, numOfLeavesToAdd, numOfLeavesToAdd)
+
+		// reorg tree
+		tx, err := db.NewTx(context.Background(), treeDB)
+		require.NoError(t, err)
+		require.NoError(t, merkletree.Reorg(tx, uint64(indexToCheck+1)))
+		require.NoError(t, tx.Commit())
+
+		// rebuild cache on adding new leaf
+		tx, err = db.NewTx(context.Background(), treeDB)
+		require.NoError(t, err)
+		require.NoError(t, merkletree.AddLeaf(tx, uint64(indexToCheck+1), 0, types.Leaf{
+			Index: indexToCheck + 1,
+			Hash:  common.HexToHash(fmt.Sprintf("%x", indexToCheck+1)),
+		}))
+		tx.Commit()
+
+		root2, err := merkletree.GetRootByIndex(context.Background(), indexToCheck)
+		require.NoError(t, err)
+		require.Equal(t, expectedRoot.Hash, root2.Hash)
+		require.Equal(t, expectedRoot.Index, root2.Index)
+	})
+}
+
 func TestMTAddLeaf(t *testing.T) {
 	data, err := os.ReadFile("testvectors/root-vectors.json")
 	require.NoError(t, err)