Skip to content

Commit

Permalink
Message is not hashed before signature verification (#1205)
Browse files Browse the repository at this point in the history
* Message not hashed before signature verification

* Refactor IsValidValidator unit tests

* Renames

* Move hashing to RecoverAddressFromSignature
  • Loading branch information
Stefan-Ethernal authored Feb 13, 2023
1 parent e3320d5 commit 327d115
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 76 deletions.
8 changes: 4 additions & 4 deletions consensus/polybft/consensus_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ func (c *consensusRuntime) BuildPrePrepareMessage(
},
}

message, err := c.config.Key.SignEcdsaMessage(&msg)
message, err := c.config.Key.SignIBFTMessage(&msg)
if err != nil {
c.logger.Error("Cannot sign message", "error", err)

Expand All @@ -837,7 +837,7 @@ func (c *consensusRuntime) BuildPrepareMessage(proposalHash []byte, view *proto.
},
}

message, err := c.config.Key.SignEcdsaMessage(&msg)
message, err := c.config.Key.SignIBFTMessage(&msg)
if err != nil {
c.logger.Error("Cannot sign message.", "error", err)

Expand Down Expand Up @@ -868,7 +868,7 @@ func (c *consensusRuntime) BuildCommitMessage(proposalHash []byte, view *proto.V
},
}

message, err := c.config.Key.SignEcdsaMessage(&msg)
message, err := c.config.Key.SignIBFTMessage(&msg)
if err != nil {
c.logger.Error("Cannot sign message", "Error", err)

Expand All @@ -895,7 +895,7 @@ func (c *consensusRuntime) BuildRoundChangeMessage(
}},
}

signedMsg, err := c.config.Key.SignEcdsaMessage(&msg)
signedMsg, err := c.config.Key.SignIBFTMessage(&msg)
if err != nil {
c.logger.Error("Cannot sign message", "Error", err)

Expand Down
170 changes: 109 additions & 61 deletions consensus/polybft/consensus_runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,84 +608,132 @@ func TestConsensusRuntime_validateVote_VoteSentFromUnknownValidator(t *testing.T
fmt.Sprintf("message is received from sender %s, which is not in current validator set", vote.From))
}

func TestConsensusRuntime_IsValidSender(t *testing.T) {
func TestConsensusRuntime_IsValidValidator_BasicCases(t *testing.T) {
t.Parallel()

validatorAccounts := newTestValidatorsWithAliases([]string{"A", "B", "C", "D", "E", "F"})

extra := &Extra{}
lastBuildBlock := &types.Header{
Number: 0,
ExtraData: append(make([]byte, ExtraVanity), extra.MarshalRLPTo(nil)...),
}
setupFn := func(t *testing.T) (*consensusRuntime, *testValidators) {
t.Helper()

blockchainMock := new(blockchainMock)
blockchainMock.On("NewBlockBuilder", mock.Anything).Return(&BlockBuilder{}, nil).Once()
validatorAccounts := newTestValidatorsWithAliases([]string{"A", "B", "C", "D", "E", "F"})
epoch := &epochMetadata{
Validators: validatorAccounts.getPublicIdentities("A", "B", "C", "D"),
}
runtime := &consensusRuntime{
epoch: epoch,
logger: hclog.NewNullLogger(),
fsm: &fsm{validators: NewValidatorSet(epoch.Validators, hclog.NewNullLogger())},
}

state := newTestState(t)
snapshot := NewProposerSnapshot(0, nil)
config := &runtimeConfig{
Key: validatorAccounts.getValidator("B").Key(),
blockchain: blockchainMock,
PolyBFTConfig: &PolyBFTConfig{EpochSize: 10, SprintSize: 5},
return runtime, validatorAccounts
}
runtime := &consensusRuntime{
state: state,
config: config,
lastBuiltBlock: lastBuildBlock,
epoch: &epochMetadata{
Number: 1,
Validators: validatorAccounts.getPublicIdentities()[:len(validatorAccounts.validators)-1],

cases := []struct {
name string
signerAlias string
senderAlias string
isValidSender bool
}{
{
name: "Valid sender",
signerAlias: "A",
senderAlias: "A",
isValidSender: true,
},
{
name: "Sender not amongst current validators",
signerAlias: "F",
senderAlias: "F",
isValidSender: false,
},
{
name: "Sender and signer accounts mismatch",
signerAlias: "A",
senderAlias: "B",
isValidSender: false,
},
logger: hclog.NewNullLogger(),
proposerCalculator: NewProposerCalculatorFromSnapshot(snapshot, config, hclog.NewNullLogger()),
stateSyncManager: &dummyStateSyncManager{},
checkpointManager: &dummyCheckpointManager{},
}

require.NoError(t, runtime.FSM())

sender := validatorAccounts.getValidator("A")
msg, err := sender.Key().SignEcdsaMessage(&proto.Message{
From: sender.Address().Bytes(),
})
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()

require.NoError(t, err)
runtime, validatorAccounts := setupFn(t)
signer := validatorAccounts.getValidator(c.signerAlias)
sender := validatorAccounts.getValidator(c.senderAlias)
msg, err := signer.Key().SignIBFTMessage(&proto.Message{From: sender.Address().Bytes()})

assert.True(t, runtime.IsValidValidator(msg))
blockchainMock.AssertExpectations(t)
require.NoError(t, err)
require.Equal(t, c.isValidSender, runtime.IsValidValidator(msg))
})
}
}

// sender not in current epoch validators
sender = validatorAccounts.getValidator("F")
msg, err = sender.Key().SignEcdsaMessage(&proto.Message{
From: sender.Address().Bytes(),
})
func TestConsensusRuntime_IsValidValidator_TamperSignature(t *testing.T) {
t.Parallel()

require.NoError(t, err)
validatorAccounts := newTestValidatorsWithAliases([]string{"A", "B", "C", "D", "E", "F"})
epoch := &epochMetadata{
Validators: validatorAccounts.getPublicIdentities("A", "B", "C", "D"),
}
runtime := &consensusRuntime{
epoch: epoch,
logger: hclog.NewNullLogger(),
fsm: &fsm{validators: NewValidatorSet(epoch.Validators, hclog.NewNullLogger())},
}

assert.False(t, runtime.IsValidValidator(msg))
blockchainMock.AssertExpectations(t)
// provide invalid signature
sender := validatorAccounts.getValidator("A")
msg := &proto.Message{
From: sender.Address().Bytes(),
Signature: []byte{1, 2, 3, 4, 5},
}
require.False(t, runtime.IsValidValidator(msg))
}

// signature does not come from sender
sender = validatorAccounts.getValidator("A")
msg, err = sender.Key().SignEcdsaMessage(&proto.Message{
From: validatorAccounts.getValidator("B").Address().Bytes(),
})
func TestConsensusRuntime_TamperMessageContent(t *testing.T) {
t.Parallel()

validatorAccounts := newTestValidatorsWithAliases([]string{"A", "B", "C", "D", "E", "F"})
epoch := &epochMetadata{
Validators: validatorAccounts.getPublicIdentities("A", "B", "C", "D"),
}
runtime := &consensusRuntime{
epoch: epoch,
logger: hclog.NewNullLogger(),
fsm: &fsm{validators: NewValidatorSet(epoch.Validators, hclog.NewNullLogger())},
}
sender := validatorAccounts.getValidator("A")
proposalHash := []byte{2, 4, 6, 8, 10}
proposalSignature, err := sender.Key().Sign(proposalHash)
require.NoError(t, err)

assert.False(t, runtime.IsValidValidator(msg))
blockchainMock.AssertExpectations(t)

// invalid signature
sender = validatorAccounts.getValidator("A")
msg = &proto.Message{
From: sender.Address().Bytes(),
Signature: []byte{1, 2},
msg := &proto.Message{
View: &proto.View{},
From: sender.Address().Bytes(),
Type: proto.MessageType_COMMIT,
Payload: &proto.Message_CommitData{
CommitData: &proto.CommitMessage{
ProposalHash: proposalHash,
CommittedSeal: proposalSignature,
},
},
}
// sign the message itself
msg, err = sender.Key().SignIBFTMessage(msg)
assert.NoError(t, err)
// signature verification works
assert.True(t, runtime.IsValidValidator(msg))

// modify message without signing it again
msg.Payload = &proto.Message_CommitData{
CommitData: &proto.CommitMessage{
ProposalHash: []byte{1, 3, 5, 7, 9}, // modification
CommittedSeal: proposalSignature,
},
}
// signature isn't valid, because message was tampered
assert.False(t, runtime.IsValidValidator(msg))
blockchainMock.AssertExpectations(t)
}

func TestConsensusRuntime_IsValidProposalHash(t *testing.T) {
Expand Down Expand Up @@ -952,7 +1000,7 @@ func TestConsensusRuntime_BuildRoundChangeMessage(t *testing.T) {
}},
}

signedMsg, err := key.SignEcdsaMessage(&expected)
signedMsg, err := key.SignIBFTMessage(&expected)
require.NoError(t, err)

assert.Equal(t, signedMsg, runtime.BuildRoundChangeMessage(proposal, certificate, view))
Expand Down Expand Up @@ -985,7 +1033,7 @@ func TestConsensusRuntime_BuildCommitMessage(t *testing.T) {
},
}

signedMsg, err := key.SignEcdsaMessage(&expected)
signedMsg, err := key.SignIBFTMessage(&expected)
require.NoError(t, err)

assert.Equal(t, signedMsg, runtime.BuildCommitMessage(proposalHash, view))
Expand Down Expand Up @@ -1030,7 +1078,7 @@ func TestConsensusRuntime_BuildPrepareMessage(t *testing.T) {
},
}

signedMsg, err := key.SignEcdsaMessage(&expected)
signedMsg, err := key.SignIBFTMessage(&expected)
require.NoError(t, err)

assert.Equal(t, signedMsg, runtime.BuildPrepareMessage(proposalHash, view))
Expand Down
24 changes: 14 additions & 10 deletions consensus/polybft/wallet/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,44 @@ func NewKey(raw *Account) *Key {
}
}

// String returns hex encoded ECDSA address
func (k *Key) String() string {
return k.raw.Ecdsa.Address().String()
}

// Address returns ECDSA address
func (k *Key) Address() ethgo.Address {
return k.raw.Ecdsa.Address()
}

func (k *Key) Sign(b []byte) ([]byte, error) {
s, err := k.raw.Bls.Sign(b)
// Sign signs the provided digest with BLS key
func (k *Key) Sign(digest []byte) ([]byte, error) {
signature, err := k.raw.Bls.Sign(digest)
if err != nil {
return nil, err
}

return s.Marshal()
return signature.Marshal()
}

// SignEcdsaMessage signs the proto message with ecdsa
func (k *Key) SignEcdsaMessage(msg *proto.Message) (*proto.Message, error) {
raw, err := protobuf.Marshal(msg)
// SignIBFTMessage signs the IBFT consensus message with ECDSA key
func (k *Key) SignIBFTMessage(msg *proto.Message) (*proto.Message, error) {
msgRaw, err := protobuf.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("cannot marshal message: %w", err)
}

if msg.Signature, err = k.raw.Ecdsa.Sign(raw); err != nil {
if msg.Signature, err = k.raw.Ecdsa.Sign(crypto.Keccak256(msgRaw)); err != nil {
return nil, fmt.Errorf("cannot create message signature: %w", err)
}

return msg, nil
}

// RecoverAddressFromSignature recovers signer address from the given digest and signature
func RecoverAddressFromSignature(sig, msg []byte) (types.Address, error) {
pub, err := crypto.RecoverPubkey(sig, msg)
// RecoverAddressFromSignature calculates keccak256 hash of provided rawContent
// and recovers signer address from given signature and hash
func RecoverAddressFromSignature(sig, rawContent []byte) (types.Address, error) {
pub, err := crypto.RecoverPubkey(sig, crypto.Keccak256(rawContent))
if err != nil {
return types.Address{}, fmt.Errorf("cannot recover address from signature: %w", err)
}
Expand Down
8 changes: 7 additions & 1 deletion consensus/polybft/wallet/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
)

func Test_RecoverAddressFromSignature(t *testing.T) {
t.Parallel()

for _, account := range []*Account{GenerateAccount(), GenerateAccount(), GenerateAccount()} {
key := NewKey(account)
msgNoSig := &proto.Message{
Expand All @@ -18,7 +20,7 @@ func Test_RecoverAddressFromSignature(t *testing.T) {
Payload: &proto.Message_CommitData{},
}

msg, err := key.SignEcdsaMessage(msgNoSig)
msg, err := key.SignIBFTMessage(msgNoSig)
require.NoError(t, err)

payload, err := msgNoSig.PayloadNoSig()
Expand All @@ -31,6 +33,8 @@ func Test_RecoverAddressFromSignature(t *testing.T) {
}

func Test_Sign(t *testing.T) {
t.Parallel()

msg := []byte("some message")

for _, account := range []*Account{GenerateAccount(), GenerateAccount()} {
Expand All @@ -47,6 +51,8 @@ func Test_Sign(t *testing.T) {
}

func Test_String(t *testing.T) {
t.Parallel()

for _, account := range []*Account{GenerateAccount(), GenerateAccount(), GenerateAccount()} {
key := NewKey(account)
assert.Equal(t, key.Address().String(), key.String())
Expand Down

0 comments on commit 327d115

Please sign in to comment.