diff --git a/consensus/ibft/ibft.go b/consensus/ibft/ibft.go index 687ad52c11..7113b15efa 100644 --- a/consensus/ibft/ibft.go +++ b/consensus/ibft/ibft.go @@ -916,12 +916,12 @@ func (i *Ibft) runValidateState() { panic(fmt.Sprintf("BUG: %s", reflect.TypeOf(msg.Type))) } - if i.state.numPrepared() > i.state.NumValid() { + if i.state.numPrepared() >= i.state.validators.QuorumSize() { // we have received enough pre-prepare messages sendCommit() } - if i.state.numCommitted() > i.state.NumValid() { + if i.state.numCommitted() >= i.state.validators.QuorumSize() { // we have received enough commit messages sendCommit() @@ -1102,17 +1102,15 @@ func (i *Ibft) runRoundChangeState() { // we only expect RoundChange messages right now num := i.state.AddRoundMessage(msg) - if num == i.state.NumValid() { + if num == i.state.validators.MaxFaultyNodes()+1 && i.state.view.Round < msg.View.Round { + // weak certificate, try to catch up if our round number is smaller + // update timer + timeout = exponentialTimeout(i.state.view.Round) + sendRoundChange(msg.View.Round) + } else if num == i.state.validators.QuorumSize() { // start a new round immediately i.state.view.Round = msg.View.Round i.setState(AcceptState) - } else if num == i.state.validators.MaxFaultyNodes()+1 { - // weak certificate, try to catch up if our round number is smaller - if i.state.view.Round < msg.View.Round { - // update timer - timeout = exponentialTimeout(i.state.view.Round) - sendRoundChange(msg.View.Round) - } } } } diff --git a/consensus/ibft/ibft_test.go b/consensus/ibft/ibft_test.go index e4fa4ac8ae..f2e953d75d 100644 --- a/consensus/ibft/ibft_test.go +++ b/consensus/ibft/ibft_test.go @@ -342,7 +342,7 @@ func TestTransition_RoundChangeState_CatchupRound(t *testing.T) { m.expect(expectResult{ sequence: 1, round: 2, - outgoing: 1, // our new round change + outgoing: 2, // our new round change state: AcceptState, }) } diff --git a/consensus/ibft/sign.go b/consensus/ibft/sign.go index 27b786270b..cb2e922b67 100644 --- a/consensus/ibft/sign.go +++ b/consensus/ibft/sign.go @@ -207,7 +207,7 @@ func verifyCommitedFields(snap *Snapshot, header *types.Header) error { // Valid committed seals must be at least 2F+1 // 2F is the required number of honest validators who provided the committed seals // +1 is the proposer - if validSeals := len(visited); validSeals <= 2*snap.Set.MaxFaultyNodes() { + if validSeals := len(visited); validSeals < snap.Set.QuorumSize() { return fmt.Errorf("not enough seals to seal block") } diff --git a/consensus/ibft/sign_test.go b/consensus/ibft/sign_test.go index d0add5a9b4..059b19a437 100644 --- a/consensus/ibft/sign_test.go +++ b/consensus/ibft/sign_test.go @@ -63,7 +63,7 @@ func TestSign_CommittedSeals(t *testing.T) { } // Correct - assert.NoError(t, buildCommittedSeal([]string{"A", "B", "C"})) + assert.NoError(t, buildCommittedSeal([]string{"A", "B", "C", "D"})) // Failed - Repeated signature assert.Error(t, buildCommittedSeal([]string{"A", "A"})) diff --git a/consensus/ibft/state.go b/consensus/ibft/state.go index 64b6173cc3..ba38ad52e2 100644 --- a/consensus/ibft/state.go +++ b/consensus/ibft/state.go @@ -2,6 +2,7 @@ package ibft import ( "fmt" + "math" "sync/atomic" "github.com/0xPolygon/polygon-edge/consensus/ibft/proto" @@ -96,15 +97,6 @@ func (c *currentState) setState(s IbftState) { atomic.StoreUint64(stateAddr, uint64(s)) } -// NumValid returns the number of required messages -func (c *currentState) NumValid() int { - // According to the IBFT spec, the number of valid messages - // needs to be 2F + 1 - // The 1 missing from this equation is accounted for elsewhere - // (the current node tallying the messages will include its own message) - return 2 * c.validators.MaxFaultyNodes() -} - // getErr returns the current error, if any, and consumes it func (c *currentState) getErr() error { err := c.err @@ -153,7 +145,12 @@ func (c *currentState) unlock() { // cleanRound deletes the specific round messages func (c *currentState) cleanRound(round uint64) { - delete(c.roundMessages, round) + // clear messages from previous round + for r := range c.roundMessages { + if r < round { + delete(c.roundMessages, r) + } + } } // AddRoundMessage adds a message to the round, and returns the round message size @@ -306,3 +303,20 @@ func (v *ValidatorSet) MaxFaultyNodes() int { // It should always take the floor of the result return (len(*v) - 1) / 3 } + +// QuorumSize returns the number of required messages for consensus +func (v ValidatorSet) QuorumSize() int { + // if the number of validators is less than 4, + // then the entire set is required + if v.MaxFaultyNodes() == 0 { + /* + N: 1 -> Q: 1 + N: 2 -> Q: 2 + N: 3 -> Q: 3 + */ + return v.Len() + } + + // (quorum optimal) Q = ceil(2/3 * N) + return int(math.Ceil(2 * float64(v.Len()) / 3)) +} diff --git a/consensus/ibft/state_test.go b/consensus/ibft/state_test.go index c9a46e68e8..414bc1efbc 100644 --- a/consensus/ibft/state_test.go +++ b/consensus/ibft/state_test.go @@ -1,6 +1,7 @@ package ibft import ( + "strconv" "testing" "github.com/0xPolygon/polygon-edge/consensus/ibft/proto" @@ -28,6 +29,44 @@ func TestState_FaultyNodes(t *testing.T) { } } +// TestNumValid checks if the quorum size is calculated +// correctly based on number of validators (network size). +func TestNumValid(t *testing.T) { + cases := []struct { + Network, Quorum uint64 + }{ + {1, 1}, + {2, 2}, + {3, 3}, + {4, 3}, + {5, 4}, + {6, 4}, + {7, 5}, + {8, 6}, + {9, 6}, + } + + addAccounts := func( + pool *testerAccountPool, + numAccounts int, + ) { + // add accounts + for i := 0; i < numAccounts; i++ { + pool.add(strconv.Itoa(i)) + } + } + + for _, c := range cases { + pool := newTesterAccountPool(int(c.Network)) + addAccounts(pool, int(c.Network)) + + assert.Equal(t, + int(c.Quorum), + pool.ValidatorSet().QuorumSize(), + ) + } +} + func TestState_AddMessages(t *testing.T) { pool := newTesterAccountPool() pool.add("A", "B", "C", "D")