From 472f245b3ae2731a168b512b4e375b65aebe1a44 Mon Sep 17 00:00:00 2001 From: Manav Aggarwal Date: Sat, 9 Sep 2023 03:54:39 -0400 Subject: [PATCH] Validate exact expected error in signed header verification tests (#1165) ## Overview Closes: #1049 Stacked on top of #1162 ## Checklist - [x] New and updated code has appropriate documentation - [x] New and updated code has new and/or updated testing - [x] Required CI checks are passing - [ ] Visual proof for any user facing features like CLI or documentation updates - [x] Linked issues closed with keywords --------- Co-authored-by: Matthew Sevey --- types/header.go | 21 ++++++++----- types/signed_header.go | 24 +++++++++++--- types/signed_header_test.go | 62 ++++++++++++++++++++++++++----------- 3 files changed, 77 insertions(+), 30 deletions(-) diff --git a/types/header.go b/types/header.go index 61ce131218a..e03aa1bb6cb 100644 --- a/types/header.go +++ b/types/header.go @@ -2,7 +2,6 @@ package types import ( "encoding" - "errors" "fmt" "time" @@ -85,7 +84,7 @@ func (h *Header) Verify(untrst header.Header) error { } // sanity check fields if err := verifyNewHeaderAndVals(h, untrstH); err != nil { - return &header.VerifyError{Reason: err} + return err } // Check the validator hashes are the same in the case headers are adjacent @@ -132,16 +131,22 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error { } if !untrusted.Time().After(trusted.Time()) { - return fmt.Errorf("expected new header time %v to be after old header time %v", - untrusted.Time(), - trusted.Time()) + return fmt.Errorf("%w: %w", + ErrNewHeaderTimeBeforeOldHeaderTime, + fmt.Errorf("expected new header time %v to be after %v", + untrusted.Time(), + trusted.Time(), + ), + ) } if !untrusted.Time().Before(time.Now().Add(maxClockDrift)) { - return fmt.Errorf("new header has a time from the future %v (now: %v; max clock drift: %v)", + return fmt.Errorf("%w: new header time %v (now: %v; max clock drift: %v)", + ErrNewHeaderTimeFromFuture, untrusted.Time(), time.Now(), - maxClockDrift) + maxClockDrift, + ) } return nil @@ -150,7 +155,7 @@ func verifyNewHeaderAndVals(trusted, untrusted *Header) error { // ValidateBasic performs basic validation of a header. func (h *Header) ValidateBasic() error { if len(h.ProposerAddress) == 0 { - return errors.New("no proposer address") + return ErrNoProposerAddress } return nil diff --git a/types/signed_header.go b/types/signed_header.go index 575ed715715..142d4b2efbf 100644 --- a/types/signed_header.go +++ b/types/signed_header.go @@ -17,6 +17,16 @@ func (sH *SignedHeader) IsZero() bool { return sH == nil } +var ( + ErrAggregatorSetHashMismatch = errors.New("aggregator set hash in signed header and hash of validator set do not match") + ErrSignatureVerificationFailed = errors.New("signature verification failed") + ErrNoProposerAddress = errors.New("no proposer address") + ErrLastHeaderHashMismatch = errors.New("last header hash mismatch") + ErrLastCommitHashMismatch = errors.New("last commit hash mismatch") + ErrNewHeaderTimeBeforeOldHeaderTime = errors.New("new header has time before old header time") + ErrNewHeaderTimeFromFuture = errors.New("new header has time from future") +) + func (sH *SignedHeader) Verify(untrst header.Header) error { // Explicit type checks are required due to embedded Header which also does the explicit type check untrstH, ok := untrst.(*SignedHeader) @@ -44,13 +54,19 @@ func (sH *SignedHeader) Verify(untrst header.Header) error { sHHash := sH.Header.Hash() if !bytes.Equal(untrstH.LastHeaderHash[:], sHHash) { return &header.VerifyError{ - Reason: fmt.Errorf("last header hash %v does not match hash of previous header %v", untrstH.LastHeaderHash[:], sHHash), + Reason: fmt.Errorf("%w: expected %v, but got %v", + ErrLastHeaderHashMismatch, + untrstH.LastHeaderHash[:], sHHash, + ), } } sHLastCommitHash := sH.Commit.GetCommitHash(&untrstH.Header, sH.ProposerAddress) if !bytes.Equal(untrstH.LastCommitHash[:], sHLastCommitHash) { return &header.VerifyError{ - Reason: fmt.Errorf("last commit hash %v does not match hash of previous header %v", untrstH.LastCommitHash[:], sHHash), + Reason: fmt.Errorf("%w: expected %v, but got %v", + ErrLastCommitHashMismatch, + untrstH.LastCommitHash[:], sHHash, + ), } } return nil @@ -78,7 +94,7 @@ func (h *SignedHeader) ValidateBasic() error { } if !bytes.Equal(h.Validators.Hash(), h.AggregatorsHash[:]) { - return errors.New("aggregator set hash in signed header and hash of validator set do not match") + return ErrAggregatorSetHashMismatch } // Make sure there is exactly one signature @@ -94,7 +110,7 @@ func (h *SignedHeader) ValidateBasic() error { return errors.New("signature verification failed, unable to marshal header") } if !pubKey.VerifySignature(msg, signature) { - return errors.New("signature verification failed") + return ErrSignatureVerificationFailed } return nil diff --git a/types/signed_header_test.go b/types/signed_header_test.go index 1242a308968..c772b711b3c 100644 --- a/types/signed_header_test.go +++ b/types/signed_header_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/celestiaorg/go-header" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,37 +16,46 @@ func TestVerify(t *testing.T) { time.Sleep(time.Second) untrustedAdj, err := GetNextRandomHeader(trusted, privKey) require.NoError(t, err) + fakeAggregatorsHash := header.Hash(GetRandomBytes(32)) + fakeLastHeaderHash := header.Hash(GetRandomBytes(32)) + fakeLastCommitHash := header.Hash(GetRandomBytes(32)) tests := []struct { prepare func() (*SignedHeader, bool) - err bool + err error }{ { prepare: func() (*SignedHeader, bool) { return untrustedAdj, false }, - err: false, + err: nil, }, { prepare: func() (*SignedHeader, bool) { untrusted := *untrustedAdj - untrusted.AggregatorsHash = GetRandomBytes(32) - return &untrusted, true + untrusted.AggregatorsHash = fakeAggregatorsHash + return &untrusted, false + }, + err: &header.VerifyError{ + Reason: ErrAggregatorSetHashMismatch, }, - err: true, }, { prepare: func() (*SignedHeader, bool) { untrusted := *untrustedAdj - untrusted.LastHeaderHash = GetRandomBytes(32) + untrusted.LastHeaderHash = fakeLastHeaderHash return &untrusted, true }, - err: true, + err: &header.VerifyError{ + Reason: ErrLastHeaderHashMismatch, + }, }, { prepare: func() (*SignedHeader, bool) { untrusted := *untrustedAdj - untrusted.LastCommitHash = GetRandomBytes(32) + untrusted.LastCommitHash = fakeLastCommitHash return &untrusted, true }, - err: true, + err: &header.VerifyError{ + Reason: ErrLastCommitHashMismatch, + }, }, { prepare: func() (*SignedHeader, bool) { @@ -54,7 +64,7 @@ func TestVerify(t *testing.T) { untrusted.Header.BaseHeader.Height++ return &untrusted, true }, - err: false, // Accepts non-adjacent headers + err: nil, // Accepts non-adjacent headers }, { prepare: func() (*SignedHeader, bool) { @@ -62,7 +72,9 @@ func TestVerify(t *testing.T) { untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Truncate(time.Hour).UnixNano()) return &untrusted, true }, - err: true, + err: &header.VerifyError{ + Reason: ErrNewHeaderTimeBeforeOldHeaderTime, + }, }, { prepare: func() (*SignedHeader, bool) { @@ -70,7 +82,9 @@ func TestVerify(t *testing.T) { untrusted.Header.BaseHeader.Time = uint64(untrusted.Header.Time().Add(time.Minute).UnixNano()) return &untrusted, true }, - err: true, + err: &header.VerifyError{ + Reason: ErrNewHeaderTimeFromFuture, + }, }, { prepare: func() (*SignedHeader, bool) { @@ -78,7 +92,9 @@ func TestVerify(t *testing.T) { untrusted.BaseHeader.ChainID = "toaster" return &untrusted, false // Signature verification should fail }, - err: true, + err: &header.VerifyError{ + Reason: ErrSignatureVerificationFailed, + }, }, { prepare: func() (*SignedHeader, bool) { @@ -86,7 +102,9 @@ func TestVerify(t *testing.T) { untrusted.Version.App = untrusted.Version.App + 1 return &untrusted, false // Signature verification should fail }, - err: true, + err: &header.VerifyError{ + Reason: ErrSignatureVerificationFailed, + }, }, { prepare: func() (*SignedHeader, bool) { @@ -94,7 +112,9 @@ func TestVerify(t *testing.T) { untrusted.ProposerAddress = nil return &untrusted, true }, - err: true, + err: &header.VerifyError{ + Reason: ErrNoProposerAddress, + }, }, } @@ -107,11 +127,17 @@ func TestVerify(t *testing.T) { preparedHeader.Commit = *commit } err = trusted.Verify(preparedHeader) - if test.err { - assert.Error(t, err) - } else { + if test.err == nil { assert.NoError(t, err) + return + } + if err == nil { + t.Errorf("expected err: %v, got nil", test.err) + return } + reason := err.(*header.VerifyError).Reason + testReason := test.err.(*header.VerifyError).Reason + assert.ErrorIs(t, reason, testReason) }) } }