Skip to content

Commit

Permalink
[EVM-701]: Validate ExitRootHash in CheckpointData (#1638)
Browse files Browse the repository at this point in the history
* Initial changes

* Comments fix
  • Loading branch information
goran-ethernal authored Jun 20, 2023
1 parent 2801876 commit de084a9
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 38 deletions.
12 changes: 2 additions & 10 deletions consensus/polybft/blockchain_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ type blockchainBackend interface {
txPool txPoolInterface, blockTime time.Duration, logger hclog.Logger) (blockBuilder, error)

// ProcessBlock builds a final block from given 'block' on top of 'parent'.
ProcessBlock(parent *types.Header, block *types.Block,
callback func(*state.Transition) error) (*types.FullBlock, error)
ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error)

// GetStateProviderForBlock returns a reference to make queries to the state at 'block'.
GetStateProviderForBlock(block *types.Header) (contract.Provider, error)
Expand Down Expand Up @@ -83,8 +82,7 @@ func (p *blockchainWrapper) CommitBlock(block *types.FullBlock) error {
}

// ProcessBlock builds a final block from given 'block' on top of 'parent'
func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Block,
callback func(*state.Transition) error) (*types.FullBlock, error) {
func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error) {
header := block.Header.Copy()
start := time.Now().UTC()

Expand All @@ -100,12 +98,6 @@ func (p *blockchainWrapper) ProcessBlock(parent *types.Header, block *types.Bloc
}
}

if callback != nil {
if err := callback(transition); err != nil {
return nil, err
}
}

_, root, err := transition.Commit()
if err != nil {
return nil, fmt.Errorf("failed to commit the state changes: %w", err)
Expand Down
9 changes: 8 additions & 1 deletion consensus/polybft/extra.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ func (c *CheckpointData) ValidateBasic(parentCheckpoint *CheckpointData) error {
// Validate encapsulates validation logic for checkpoint data
// (with regards to current and next epoch validators)
func (c *CheckpointData) Validate(parentCheckpoint *CheckpointData,
currentValidators validator.AccountSet, nextValidators validator.AccountSet) error {
currentValidators validator.AccountSet, nextValidators validator.AccountSet,
exitRootHash types.Hash) error {
if err := c.ValidateBasic(parentCheckpoint); err != nil {
return err
}
Expand Down Expand Up @@ -459,6 +460,12 @@ func (c *CheckpointData) Validate(parentCheckpoint *CheckpointData,
return fmt.Errorf("epoch number should not change for epoch-ending block")
}

// exit root hash of proposer and
// validator that validates proposal have to match
if exitRootHash != c.EventRoot {
return fmt.Errorf("exit root hash not as expected")
}

return nil
}

Expand Down
15 changes: 14 additions & 1 deletion consensus/polybft/extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ func TestCheckpointData_Validate(t *testing.T) {
nextValidators validator.AccountSet
currentValidatorsHash types.Hash
nextValidatorsHash types.Hash
exitRootHash types.Hash
errString string
}{
{
Expand Down Expand Up @@ -713,6 +714,17 @@ func TestCheckpointData_Validate(t *testing.T) {
nextValidatorsHash: nextValidatorsHash,
errString: "epoch number should not change for epoch-ending block",
},
{
name: "Invalid exit root hash",
parentEpochNumber: 2,
epochNumber: 2,
currentValidators: currentValidators,
nextValidators: currentValidators,
currentValidatorsHash: currentValidatorsHash,
nextValidatorsHash: currentValidatorsHash,
exitRootHash: types.BytesToHash([]byte{0, 1, 2, 3, 4, 5, 6, 7}),
errString: "exit root hash not as expected",
},
}

for _, c := range cases {
Expand All @@ -723,9 +735,10 @@ func TestCheckpointData_Validate(t *testing.T) {
EpochNumber: c.epochNumber,
CurrentValidatorsHash: c.currentValidatorsHash,
NextValidatorsHash: c.nextValidatorsHash,
EventRoot: c.exitRootHash,
}
parentCheckpoint := &CheckpointData{EpochNumber: c.parentEpochNumber}
err := checkpoint.Validate(parentCheckpoint, c.currentValidators, c.nextValidators)
err := checkpoint.Validate(parentCheckpoint, c.currentValidators, c.nextValidators, types.ZeroHash)

if c.errString != "" {
require.ErrorContains(t, err, c.errString)
Expand Down
32 changes: 17 additions & 15 deletions consensus/polybft/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,24 +334,26 @@ func (f *fsm) Validate(proposal []byte) error {
}

currentValidators := f.validators.Accounts()
nextValidators := f.validators.Accounts()

validateExtraData := func(transition *state.Transition) error {
if f.isEndOfEpoch {
if !extra.Validators.Equals(f.newValidatorsDelta) {
return errValidatorSetDeltaMismatch
}
} else if !extra.Validators.IsEmpty() {
// delta should be empty in non epoch ending blocks
return errValidatorsUpdateInNonEpochEnding
// validate validators delta
if f.isEndOfEpoch {
if !extra.Validators.Equals(f.newValidatorsDelta) {
return errValidatorSetDeltaMismatch
}
} else if !extra.Validators.IsEmpty() {
// delta should be empty in non epoch ending blocks
return errValidatorsUpdateInNonEpochEnding
}

nextValidators, err = f.getValidatorsTransition(extra.Validators)
if err != nil {
return err
}
nextValidators, err := f.getValidatorsTransition(extra.Validators)
if err != nil {
return err
}

return extra.Checkpoint.Validate(parentExtra.Checkpoint, currentValidators, nextValidators)
// validate checkpoint data
if err := extra.Checkpoint.Validate(parentExtra.Checkpoint,
currentValidators, nextValidators, f.exitEventRootHash); err != nil {
return err
}

if f.logger.IsTrace() && block.Number() > 1 {
Expand All @@ -363,7 +365,7 @@ func (f *fsm) Validate(proposal []byte) error {
f.logger.Trace("[FSM Validate]", "Block", block.Number(), "parent validators", validators)
}

stateBlock, err := f.backend.ProcessBlock(f.parent, &block, validateExtraData)
stateBlock, err := f.backend.ProcessBlock(f.parent, &block)
if err != nil {
return err
}
Expand Down
69 changes: 66 additions & 3 deletions consensus/polybft/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,69 @@ func TestFSM_ValidateCommit_Good(t *testing.T) {
require.NoError(t, err)
}

func TestFSM_Validate_ExitEventRootNotExpected(t *testing.T) {
t.Parallel()

const (
accountsCount = 5
parentBlockNumber = 25
signaturesCount = 3
)

validators := validator.NewTestValidators(t, accountsCount)
parentExtra := createTestExtraObject(validators.GetPublicIdentities(), validator.AccountSet{}, 4, signaturesCount, signaturesCount)
parentExtra.Validators = nil

parent := &types.Header{
Number: parentBlockNumber,
ExtraData: parentExtra.MarshalRLPTo(nil),
}
parent.ComputeHash()

polybftBackendMock := new(polybftBackendMock)
polybftBackendMock.On("GetValidators", mock.Anything, mock.Anything).Return(validators.GetPublicIdentities(), nil).Once()

extra := createTestExtraObject(validators.GetPublicIdentities(), validator.AccountSet{}, 4, signaturesCount, signaturesCount)
extra.Validators = nil
parentCheckpointHash, err := extra.Checkpoint.Hash(0, parentBlockNumber, parent.Hash)
require.NoError(t, err)

currentValSetHash, err := validators.GetPublicIdentities().Hash()
require.NoError(t, err)

extra.Parent = createSignature(t, validators.GetPrivateIdentities(), parentCheckpointHash, bls.DomainCheckpointManager)
extra.Checkpoint.EpochNumber = 1
extra.Checkpoint.CurrentValidatorsHash = currentValSetHash
extra.Checkpoint.NextValidatorsHash = currentValSetHash

stateBlock := createDummyStateBlock(parent.Number+1, types.Hash{100, 15}, extra.MarshalRLPTo(nil))

proposalHash, err := extra.Checkpoint.Hash(0, stateBlock.Block.Number(), stateBlock.Block.Hash())
require.NoError(t, err)

stateBlock.Block.Header.Hash = proposalHash
stateBlock.Block.Header.ParentHash = parent.Hash
stateBlock.Block.Header.Timestamp = uint64(time.Now().UTC().Unix())
stateBlock.Block.Transactions = []*types.Transaction{}

proposal := stateBlock.Block.MarshalRLP()

fsm := &fsm{
parent: parent,
backend: new(blockchainMock),
validators: validators.ToValidatorSet(),
logger: hclog.NewNullLogger(),
polybftBackend: polybftBackendMock,
config: &PolyBFTConfig{BlockTimeDrift: 1},
exitEventRootHash: types.BytesToHash([]byte{0, 1, 2, 3, 4}), // expect this to be in proposal extra
}

err = fsm.Validate(proposal)
require.ErrorContains(t, err, "exit root hash not as expected")

polybftBackendMock.AssertExpectations(t)
}

func TestFSM_Validate_EpochEndingBlock_MismatchInDeltas(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -763,7 +826,7 @@ func TestFSM_Validate_EpochEndingBlock_MismatchInDeltas(t *testing.T) {
proposal := stateBlock.Block.MarshalRLP()

blockchainMock := new(blockchainMock)
blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything).
blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything).
Return(stateBlock, error(nil)).
Maybe()

Expand Down Expand Up @@ -854,7 +917,7 @@ func TestFSM_Validate_EpochEndingBlock_UpdatingValidatorSetInNonEpochEndingBlock
proposal := stateBlock.Block.MarshalRLP()

blockchainMock := new(blockchainMock)
blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything).
blockchainMock.On("ProcessBlock", mock.Anything, mock.Anything).
Return(stateBlock, error(nil)).
Maybe()

Expand Down Expand Up @@ -1062,7 +1125,7 @@ func TestFSM_Insert_Good(t *testing.T) {
builderMock := newBlockBuilderMock(builtBlock)
chainMock := &blockchainMock{}
chainMock.On("CommitBlock", mock.Anything).Return(error(nil)).Once()
chainMock.On("ProcessBlock", mock.Anything, mock.Anything, mock.Anything).
chainMock.On("ProcessBlock", mock.Anything, mock.Anything).
Return(builtBlock, error(nil)).
Maybe()

Expand Down
10 changes: 2 additions & 8 deletions consensus/polybft/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,8 @@ func (m *blockchainMock) NewBlockBuilder(parent *types.Header, coinbase types.Ad
return args.Get(0).(blockBuilder), args.Error(1) //nolint:forcetypeassert
}

func (m *blockchainMock) ProcessBlock(parent *types.Header, block *types.Block, callback func(*state.Transition) error) (*types.FullBlock, error) {
args := m.Called(parent, block, callback)

if callback != nil {
if err := callback(nil); err != nil {
return nil, err
}
}
func (m *blockchainMock) ProcessBlock(parent *types.Header, block *types.Block) (*types.FullBlock, error) {
args := m.Called(parent, block)

return args.Get(0).(*types.FullBlock), args.Error(1) //nolint:forcetypeassert
}
Expand Down

0 comments on commit de084a9

Please sign in to comment.