From 2954fc0b2c49eee6e93d5ac82ae5cae2c6bff440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20Junior?= Date: Wed, 29 Nov 2023 08:49:29 -0400 Subject: [PATCH] fix(dot/sync): verify justification before importing blocks (#3576) --- dot/interfaces.go | 2 +- dot/network/message.go | 2 +- dot/sync/chain_sync.go | 114 ++++++------------ dot/sync/chain_sync_test.go | 92 +++++++++++++- dot/sync/interfaces.go | 4 +- dot/sync/mocks_test.go | 37 +++--- dot/sync/syncer_integration_test.go | 7 +- lib/grandpa/message_handler.go | 64 +++++----- .../message_handler_integration_test.go | 29 ++--- lib/grandpa/votes_tracker_test.go | 6 + 10 files changed, 204 insertions(+), 153 deletions(-) diff --git a/dot/interfaces.go b/dot/interfaces.go index 33beecb61c..880b258f26 100644 --- a/dot/interfaces.go +++ b/dot/interfaces.go @@ -27,7 +27,7 @@ type ServiceRegisterer interface { // BlockJustificationVerifier has a verification method for block justifications. type BlockJustificationVerifier interface { - VerifyBlockJustification(common.Hash, []byte) error + VerifyBlockJustification(common.Hash, []byte) (round uint64, setID uint64, err error) } // Telemetry is the telemetry client to send telemetry messages. diff --git a/dot/network/message.go b/dot/network/message.go index 14768f6514..66f689db74 100644 --- a/dot/network/message.go +++ b/dot/network/message.go @@ -388,7 +388,7 @@ func NewAscendingBlockRequests(startNumber, targetNumber uint, requestedData byt numRequests := diff / MaxBlocksInResponse // we should check if the diff is in the maxResponseSize bounds // otherwise we should increase the numRequests by one, take this - // example, we want to sync from 0 to 259, the diff is 259 + // example, we want to sync from 1 to 259, the diff is 259 // then the num of requests is 2 (uint(259)/uint(128)) however two requests will // retrieve only 256 blocks (each request can retrieve a max of 128 blocks), so we should // create one more request to retrieve those missing blocks, 3 in this example. diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index 68da2fa83c..6f104c48c8 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -23,7 +23,6 @@ import ( "github.com/ChainSafe/gossamer/dot/telemetry" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/database" - "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/lib/trie" @@ -795,28 +794,19 @@ func (cs *chainSync) handleReadyBlock(bd *types.BlockData) error { // returns the index of the last BlockData it handled on success, // or the index of the block data that errored on failure. func (cs *chainSync) processBlockData(blockData types.BlockData) error { - headerInState, err := cs.blockState.HasHeader(blockData.Hash) - if err != nil { - return fmt.Errorf("checking if block state has header: %w", err) - } - - bodyInState, err := cs.blockState.HasBlockBody(blockData.Hash) - if err != nil { - return fmt.Errorf("checking if block state has body: %w", err) - } - // while in bootstrap mode we don't need to broadcast block announcements announceImportedBlock := cs.getSyncMode() == tip - if headerInState && bodyInState { - err = cs.processBlockDataWithStateHeaderAndBody(blockData, announceImportedBlock) - if err != nil { - return fmt.Errorf("processing block data with header and "+ - "body in block state: %w", err) - } - return nil + var blockDataJustification []byte + if blockData.Justification != nil { + blockDataJustification = *blockData.Justification } if blockData.Header != nil { + round, setID, err := cs.verifyJustification(blockData.Header.Hash(), blockDataJustification) + if err != nil { + return err + } + if blockData.Body != nil { err = cs.processBlockDataWithHeaderAndBody(blockData, announceImportedBlock) if err != nil { @@ -824,16 +814,16 @@ func (cs *chainSync) processBlockData(blockData types.BlockData) error { } } - if blockData.Justification != nil && len(*blockData.Justification) > 0 { - logger.Infof("handling justification for block %s (#%d)", blockData.Hash.Short(), blockData.Number()) - err = cs.handleJustification(blockData.Header, *blockData.Justification) - if err != nil { - return fmt.Errorf("handling justification: %w", err) - } + err = cs.finalizeAndSetJustification( + blockData.Header, + round, setID, + blockDataJustification) + if err != nil { + return fmt.Errorf("while setting justification: %w", err) } } - err = cs.blockState.CompareAndSetBlockData(&blockData) + err := cs.blockState.CompareAndSetBlockData(&blockData) if err != nil { return fmt.Errorf("comparing and setting block data: %w", err) } @@ -841,48 +831,14 @@ func (cs *chainSync) processBlockData(blockData types.BlockData) error { return nil } -func (cs *chainSync) processBlockDataWithStateHeaderAndBody(blockData types.BlockData, - announceImportedBlock bool) (err error) { - // TODO: fix this; sometimes when the node shuts down the "best block" isn't stored properly, - // so when the node restarts it has blocks higher than what it thinks is the best, causing it not to sync - // if we update the node to only store finalised blocks in the database, this should be fixed and the entire - // code block can be removed (#1784) - block, err := cs.blockState.GetBlockByHash(blockData.Hash) - if err != nil { - return fmt.Errorf("getting block by hash: %w", err) - } - - err = cs.blockState.AddBlockToBlockTree(block) - if errors.Is(err, blocktree.ErrBlockExists) { - logger.Debugf( - "block number %d with hash %s already exists in block tree, skipping it.", - block.Header.Number, blockData.Hash) - return nil - } else if err != nil { - return fmt.Errorf("adding block to blocktree: %w", err) - } - - if blockData.Justification != nil && len(*blockData.Justification) > 0 { - err = cs.handleJustification(&block.Header, *blockData.Justification) - if err != nil { - return fmt.Errorf("handling justification: %w", err) - } - } - - // TODO: this is probably unnecessary, since the state is already in the database - // however, this case shouldn't be hit often, since it's only hit if the node state - // is rewinded or if the node shuts down unexpectedly (#1784) - state, err := cs.storageState.TrieState(&block.Header.StateRoot) - if err != nil { - return fmt.Errorf("loading trie state: %w", err) - } - - err = cs.blockImportHandler.HandleBlockImport(block, state, announceImportedBlock) - if err != nil { - return fmt.Errorf("handling block import: %w", err) +func (cs *chainSync) verifyJustification(headerHash common.Hash, justification []byte) ( + round uint64, setID uint64, err error) { + if len(justification) > 0 { + round, setID, err = cs.finalityGadget.VerifyBlockJustification(headerHash, justification) + return round, setID, err } - return nil + return 0, 0, nil } func (cs *chainSync) processBlockDataWithHeaderAndBody(blockData types.BlockData, @@ -918,21 +874,27 @@ func (cs *chainSync) handleBody(body *types.Body) { blockSizeGauge.Set(float64(acc)) } -func (cs *chainSync) handleJustification(header *types.Header, justification []byte) (err error) { - logger.Debugf("handling justification for block %d...", header.Number) +func (cs *chainSync) finalizeAndSetJustification(header *types.Header, + round, setID uint64, justification []byte) (err error) { + if len(justification) > 0 { + err = cs.blockState.SetFinalisedHash(header.Hash(), round, setID) + if err != nil { + return fmt.Errorf("setting finalised hash: %w", err) + } - headerHash := header.Hash() - err = cs.finalityGadget.VerifyBlockJustification(headerHash, justification) - if err != nil { - return fmt.Errorf("verifying block number %d justification: %w", header.Number, err) - } + logger.Debugf( + "finalised block with hash #%d (%s), round %d and set id %d", + header.Number, header.Hash(), round, setID) - err = cs.blockState.SetJustification(headerHash, justification) - if err != nil { - return fmt.Errorf("setting justification for block number %d: %w", header.Number, err) + err = cs.blockState.SetJustification(header.Hash(), justification) + if err != nil { + return fmt.Errorf("setting justification for block number %d: %w", + header.Number, err) + } + + logger.Infof("🔨 finalised block number #%d (%s)", header.Number, header.Hash()) } - logger.Infof("🔨 finalised block number %d with hash %s", header.Number, headerHash) return nil } diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index fb932a8c60..27f8bb5845 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -1288,8 +1288,6 @@ func ensureSuccessfulBlockImportFlow(t *testing.T, parentHeader *types.Header, t.Helper() for idx, blockData := range blocksReceived { - mockBlockState.EXPECT().HasHeader(blockData.Header.Hash()).Return(false, nil) - mockBlockState.EXPECT().HasBlockBody(blockData.Header.Hash()).Return(false, nil) mockBabeVerifier.EXPECT().VerifyBlock(blockData.Header).Return(nil) var previousHeader *types.Header @@ -1676,3 +1674,93 @@ func TestChainSync_getHighestBlock(t *testing.T) { }) } } + +func TestChainSync_BootstrapSync_SuccessfulSync_WithInvalidJusticationBlock(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + mockFinalityGadget := NewMockFinalityGadget(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 129) + const announceBlock = false + + invalidJustificationBlock := blockResponse.BlockData[90] + invalidJustification := &[]byte{0x01, 0x01, 0x01, 0x02} + invalidJustificationBlock.Justification = invalidJustification + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &network.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData[:90], mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, announceBlock) + + errVerifyBlockJustification := errors.New("VerifyBlockJustification mock error") + mockFinalityGadget.EXPECT(). + VerifyBlockJustification( + invalidJustificationBlock.Header.Hash(), + *invalidJustification). + Return(uint64(0), uint64(0), errVerifyBlockJustification) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &network.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + responsePtr := response.(*network.BlockResponseMessage) + *responsePtr = *worker1Response + + fmt.Println("mocked request maker") + return nil + }) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 128 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + cs.finalityGadget = mockFinalityGadget + + target, err := cs.getTarget() + require.NoError(t, err) + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + //cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err = cs.requestMaxBlocksFrom(mockedGenesisHeader) + require.ErrorIs(t, err, errVerifyBlockJustification) + + err = cs.workerPool.stop() + require.NoError(t, err) + + // peer should be not in the worker pool + // peer should be in the ignore list + require.Len(t, cs.workerPool.workers, 1) +} diff --git a/dot/sync/interfaces.go b/dot/sync/interfaces.go index 806eedb659..29ac858ee6 100644 --- a/dot/sync/interfaces.go +++ b/dot/sync/interfaces.go @@ -20,7 +20,6 @@ type BlockState interface { BestBlockHeader() (*types.Header, error) BestBlockNumber() (number uint, err error) CompareAndSetBlockData(bd *types.BlockData) error - HasBlockBody(hash common.Hash) (bool, error) GetBlockBody(common.Hash) (*types.Body, error) GetHeader(common.Hash) (*types.Header, error) HasHeader(hash common.Hash) (bool, error) @@ -40,6 +39,7 @@ type BlockState interface { GetHeaderByNumber(num uint) (*types.Header, error) GetAllBlocksAtNumber(num uint) ([]common.Hash, error) IsDescendantOf(parent, child common.Hash) (bool, error) + SetFinalisedHash(common.Hash, uint64, uint64) error } // StorageState is the interface for the storage state @@ -60,7 +60,7 @@ type BabeVerifier interface { // FinalityGadget implements justification verification functionality type FinalityGadget interface { - VerifyBlockJustification(common.Hash, []byte) error + VerifyBlockJustification(common.Hash, []byte) (round uint64, setID uint64, err error) } // BlockImportHandler is the interface for the handler of newly imported blocks diff --git a/dot/sync/mocks_test.go b/dot/sync/mocks_test.go index 53350f11bc..8335588e01 100644 --- a/dot/sync/mocks_test.go +++ b/dot/sync/mocks_test.go @@ -276,21 +276,6 @@ func (mr *MockBlockStateMockRecorder) GetRuntime(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuntime", reflect.TypeOf((*MockBlockState)(nil).GetRuntime), arg0) } -// HasBlockBody mocks base method. -func (m *MockBlockState) HasBlockBody(arg0 common.Hash) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasBlockBody", arg0) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HasBlockBody indicates an expected call of HasBlockBody. -func (mr *MockBlockStateMockRecorder) HasBlockBody(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasBlockBody", reflect.TypeOf((*MockBlockState)(nil).HasBlockBody), arg0) -} - // HasHeader mocks base method. func (m *MockBlockState) HasHeader(arg0 common.Hash) (bool, error) { m.ctrl.T.Helper() @@ -351,6 +336,20 @@ func (mr *MockBlockStateMockRecorder) RangeInMemory(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeInMemory", reflect.TypeOf((*MockBlockState)(nil).RangeInMemory), arg0, arg1) } +// SetFinalisedHash mocks base method. +func (m *MockBlockState) SetFinalisedHash(arg0 common.Hash, arg1, arg2 uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetFinalisedHash", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetFinalisedHash indicates an expected call of SetFinalisedHash. +func (mr *MockBlockStateMockRecorder) SetFinalisedHash(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFinalisedHash", reflect.TypeOf((*MockBlockState)(nil).SetFinalisedHash), arg0, arg1, arg2) +} + // SetJustification mocks base method. func (m *MockBlockState) SetJustification(arg0 common.Hash, arg1 []byte) error { m.ctrl.T.Helper() @@ -535,11 +534,13 @@ func (m *MockFinalityGadget) EXPECT() *MockFinalityGadgetMockRecorder { } // VerifyBlockJustification mocks base method. -func (m *MockFinalityGadget) VerifyBlockJustification(arg0 common.Hash, arg1 []byte) error { +func (m *MockFinalityGadget) VerifyBlockJustification(arg0 common.Hash, arg1 []byte) (uint64, uint64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VerifyBlockJustification", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(uint64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // VerifyBlockJustification indicates an expected call of VerifyBlockJustification. diff --git a/dot/sync/syncer_integration_test.go b/dot/sync/syncer_integration_test.go index 9333486bce..68560ef4c3 100644 --- a/dot/sync/syncer_integration_test.go +++ b/dot/sync/syncer_integration_test.go @@ -111,9 +111,10 @@ func newTestSyncer(t *testing.T) *Service { cfg.LogLvl = log.Trace mockFinalityGadget := NewMockFinalityGadget(ctrl) mockFinalityGadget.EXPECT().VerifyBlockJustification(gomock.AssignableToTypeOf(common.Hash{}), - gomock.AssignableToTypeOf([]byte{})).DoAndReturn(func(hash common.Hash, justification []byte) error { - return nil - }).AnyTimes() + gomock.AssignableToTypeOf([]byte{})).DoAndReturn( + func(hash common.Hash, justification []byte) (uint64, uint64, error) { + return 1, 1, nil + }).AnyTimes() cfg.FinalityGadget = mockFinalityGadget cfg.Network = NewMockNetwork(ctrl) diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index cb5105324f..6d7cf472fd 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -402,54 +402,55 @@ func (h *MessageHandler) verifyPreCommitJustification(msg *CatchUpResponse) erro return nil } -// VerifyBlockJustification verifies the finality justification for a block, returns scale encoded justification with -// any extra bytes removed. -func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byte) error { +// VerifyBlockJustification verifies the finality justification for a block, +// if the justification is valid the return the round and set id otherwise error +func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byte) ( + round uint64, setID uint64, err error) { fj := Justification{} - err := scale.Unmarshal(justification, &fj) + err = scale.Unmarshal(justification, &fj) if err != nil { - return err + return 0, 0, err } if hash != fj.Commit.Hash { - return fmt.Errorf("%w: justification %s and block hash %s", + return 0, 0, fmt.Errorf("%w: justification %s and block hash %s", ErrJustificationMismatch, fj.Commit.Hash.Short(), hash.Short()) } - setID, err := s.grandpaState.GetSetIDByBlockNumber(uint(fj.Commit.Number)) + setID, err = s.grandpaState.GetSetIDByBlockNumber(uint(fj.Commit.Number)) if err != nil { - return fmt.Errorf("cannot get set ID from block number: %w", err) + return 0, 0, fmt.Errorf("cannot get set ID from block number: %w", err) } has, err := s.blockState.HasFinalisedBlock(fj.Round, setID) if err != nil { - return fmt.Errorf("checking if round and set id has finalised block: %w", err) + return 0, 0, fmt.Errorf("checking if round and set id has finalised block: %w", err) } if has { storedFinalisedHash, err := s.blockState.GetFinalisedHash(fj.Round, setID) if err != nil { - return fmt.Errorf("getting finalised hash: %w", err) + return 0, 0, fmt.Errorf("getting finalised hash: %w", err) } if storedFinalisedHash != hash { - return fmt.Errorf("%w, setID=%d and round=%d", errFinalisedBlocksMismatch, setID, fj.Round) + return 0, 0, fmt.Errorf("%w, setID=%d and round=%d", errFinalisedBlocksMismatch, setID, fj.Round) } - return nil + return fj.Round, setID, nil } isDescendant, err := isDescendantOfHighestFinalisedBlock(s.blockState, fj.Commit.Hash) if err != nil { - return fmt.Errorf("checking if descendant of highest block: %w", err) + return 0, 0, fmt.Errorf("checking if descendant of highest block: %w", err) } if !isDescendant { - return errVoteBlockMismatch + return 0, 0, errVoteBlockMismatch } auths, err := s.grandpaState.GetAuthorities(setID) if err != nil { - return fmt.Errorf("cannot get authorities for set ID: %w", err) + return 0, 0, fmt.Errorf("cannot get authorities for set ID: %w", err) } // threshold is two-thirds the number of authorities, @@ -457,7 +458,7 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt threshold := (2 * len(auths) / 3) if len(fj.Commit.Precommits) < threshold { - return ErrMinVotesNotMet + return 0, 0, ErrMinVotesNotMet } authPubKeys := make([]AuthData, len(fj.Commit.Precommits)) @@ -477,20 +478,20 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt // check if vote was for descendant of committed block isDescendant, err := s.blockState.IsDescendantOf(hash, just.Vote.Hash) if err != nil { - return err + return 0, 0, err } if !isDescendant { - return ErrPrecommitBlockMismatch + return 0, 0, ErrPrecommitBlockMismatch } publicKey, err := ed25519.NewPublicKey(just.AuthorityID[:]) if err != nil { - return err + return 0, 0, err } if !isInAuthSet(publicKey, auths) { - return ErrAuthorityNotInSet + return 0, 0, ErrAuthorityNotInSet } // verify signature for each precommit @@ -501,16 +502,16 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt SetID: setID, }) if err != nil { - return err + return 0, 0, err } ok, err := publicKey.Verify(msg, just.Signature[:]) if err != nil { - return err + return 0, 0, err } if !ok { - return ErrInvalidSignature + return 0, 0, ErrInvalidSignature } if _, ok := equivocatoryVoters[just.AuthorityID]; ok { @@ -521,30 +522,21 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt } if count+len(equivocatoryVoters) < threshold { - return ErrMinVotesNotMet + return 0, 0, ErrMinVotesNotMet } err = verifyBlockHashAgainstBlockNumber(s.blockState, fj.Commit.Hash, uint(fj.Commit.Number)) if err != nil { - return fmt.Errorf("verifying block hash against block number: %w", err) + return 0, 0, fmt.Errorf("verifying block hash against block number: %w", err) } for _, preCommit := range fj.Commit.Precommits { err := verifyBlockHashAgainstBlockNumber(s.blockState, preCommit.Vote.Hash, uint(preCommit.Vote.Number)) if err != nil { - return fmt.Errorf("verifying block hash against block number: %w", err) + return 0, 0, fmt.Errorf("verifying block hash against block number: %w", err) } } - - err = s.blockState.SetFinalisedHash(hash, fj.Round, setID) - if err != nil { - return fmt.Errorf("setting finalised hash: %w", err) - } - - logger.Debugf( - "set finalised block with hash %s, round %d and set id %d", - hash, fj.Round, setID) - return nil + return fj.Round, setID, nil } func verifyBlockHashAgainstBlockNumber(bs BlockState, hash common.Hash, number uint) error { diff --git a/lib/grandpa/message_handler_integration_test.go b/lib/grandpa/message_handler_integration_test.go index 831d863e64..3e137dc380 100644 --- a/lib/grandpa/message_handler_integration_test.go +++ b/lib/grandpa/message_handler_integration_test.go @@ -772,8 +772,11 @@ func TestMessageHandler_VerifyBlockJustification_WithEquivocatoryVotes(t *testin just := newJustification(round, testHash, number, precommits) data, err := scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + + actualRound, actualSetID, err := gs.VerifyBlockJustification(testHash, data) require.NoError(t, err) + require.Equal(t, round, actualRound) + require.Equal(t, setID, actualSetID) } func TestMessageHandler_VerifyBlockJustification(t *testing.T) { @@ -840,8 +843,10 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { just := newJustification(round, testHash, number, precommits) data, err := scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + actualRound, actualSetID, err := gs.VerifyBlockJustification(testHash, data) require.NoError(t, err) + require.Equal(t, round, actualRound) + require.Equal(t, setID, actualSetID) // use wrong hash, shouldn't verify precommits = buildTestJustification(t, 2, round+1, setID, kr, precommit) @@ -849,7 +854,7 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { just.Commit.Precommits[0].Vote.Hash = testHeader2.Hash() data, err = scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrPrecommitBlockMismatch, err) } @@ -899,7 +904,7 @@ func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { just.Commit.Precommits[0].Vote.Hash = genhash data, err := scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrPrecommitBlockMismatch, err) // use wrong round, shouldn't verify @@ -907,7 +912,7 @@ func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { just = newJustification(round+2, testHash, number, precommits) data, err = scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrInvalidSignature, err) // add authority not in set, shouldn't verify @@ -915,7 +920,7 @@ func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { just = newJustification(round+1, testHash, number, precommits) data, err = scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrAuthorityNotInSet, err) // not enough signatures, shouldn't verify @@ -923,7 +928,7 @@ func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { just = newJustification(round+1, testHash, number, precommits) data, err = scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrMinVotesNotMet, err) // mismatch justification header and block header @@ -932,7 +937,7 @@ func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { data, err = scale.Marshal(*just) require.NoError(t, err) otherHeader := types.NewEmptyHeader() - err = gs.VerifyBlockJustification(otherHeader.Hash(), data) + _, _, err = gs.VerifyBlockJustification(otherHeader.Hash(), data) require.ErrorIs(t, err, ErrJustificationMismatch) expectedErr := fmt.Sprintf("%s: justification %s and block hash %s", ErrJustificationMismatch, @@ -1001,7 +1006,7 @@ func TestMessageHandler_VerifyBlockJustification_ErrFinalisedBlockMismatch(t *te just := newJustification(round, testHash, number, precommits) data, err := scale.Marshal(*just) require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) + _, _, err = gs.VerifyBlockJustification(testHash, data) require.ErrorIs(t, err, errFinalisedBlocksMismatch) } @@ -1517,8 +1522,6 @@ func TestService_VerifyBlockJustification(t *testing.T) { //nolint mockBlockState.EXPECT().IsDescendantOf(testHash, testHash). Return(true, nil).Times(3) mockBlockState.EXPECT().GetHeader(testHash).Return(testHeader, nil).Times(3) - mockBlockState.EXPECT().SetFinalisedHash(testHash, uint64(1), - uint64(0)).Return(nil) return mockBlockState }, grandpaStateBuilder: func(ctrl *gomock.Controller) GrandpaState { @@ -1547,8 +1550,6 @@ func TestService_VerifyBlockJustification(t *testing.T) { //nolint mockBlockState.EXPECT().IsDescendantOf(testHash, testHash). Return(true, nil).Times(3) mockBlockState.EXPECT().GetHeader(testHash).Return(testHeader, nil).Times(3) - mockBlockState.EXPECT().SetFinalisedHash(testHash, uint64(1), - uint64(0)).Return(nil) return mockBlockState }, grandpaStateBuilder: func(ctrl *gomock.Controller) GrandpaState { @@ -1578,7 +1579,7 @@ func TestService_VerifyBlockJustification(t *testing.T) { //nolint blockState: tt.fields.blockStateBuilder(ctrl), grandpaState: tt.fields.grandpaStateBuilder(ctrl), } - err := s.VerifyBlockJustification(tt.args.hash, tt.args.justification) + _, _, err := s.VerifyBlockJustification(tt.args.hash, tt.args.justification) if tt.wantErr != nil { assert.ErrorContains(t, err, tt.wantErr.Error()) } else { diff --git a/lib/grandpa/votes_tracker_test.go b/lib/grandpa/votes_tracker_test.go index 91a9b220b8..54857b3678 100644 --- a/lib/grandpa/votes_tracker_test.go +++ b/lib/grandpa/votes_tracker_test.go @@ -5,16 +5,22 @@ package grandpa import ( "container/list" + "fmt" "sort" "testing" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/ed25519" + "github.com/ChainSafe/gossamer/pkg/scale" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestA(t *testing.T) { + fmt.Println(scale.MustMarshal([]byte{})) +} + // buildVoteMessage creates a test vote message using the // given block hash and authority ID only. func buildVoteMessage(blockHash common.Hash,