Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(lib/grandpa): capped number of tracked vote messages #2485

Merged
merged 12 commits into from
May 30, 2022
62 changes: 29 additions & 33 deletions lib/grandpa/message_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto/ed25519"
"github.com/libp2p/go-libp2p-core/peer"
)

// tracker keeps track of messages that have been received, but have failed to
Expand All @@ -18,8 +18,8 @@ import (
type tracker struct {
blockState BlockState
handler *MessageHandler
// map of vote block hash -> array of VoteMessages for that hash
voteMessages map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage
votes votesTracker

// map of commit block hash to commit message
commitMessages map[common.Hash]*CommitMessage
mapLock sync.Mutex
Expand All @@ -32,10 +32,11 @@ type tracker struct {
}

func newTracker(bs BlockState, handler *MessageHandler) *tracker {
const votesCapacity = 1000
return &tracker{
blockState: bs,
handler: handler,
voteMessages: make(map[common.Hash]map[ed25519.PublicKeyBytes]*networkVoteMessage),
votes: newVotesTracker(votesCapacity),
commitMessages: make(map[common.Hash]*CommitMessage),
mapLock: sync.Mutex{},
in: bs.GetImportedBlockNotifierChannel(),
Expand All @@ -53,21 +54,15 @@ func (t *tracker) stop() {
t.blockState.FreeImportedBlockNotifierChannel(t.in)
}

func (t *tracker) addVote(v *networkVoteMessage) {
if v.msg == nil {
func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) {
if message == nil {
return
}

t.mapLock.Lock()
defer t.mapLock.Unlock()

msgs, has := t.voteMessages[v.msg.Message.BlockHash]
if !has {
msgs = make(map[ed25519.PublicKeyBytes]*networkVoteMessage)
t.voteMessages[v.msg.Message.BlockHash] = msgs
}

msgs[v.msg.Message.AuthorityID] = v
t.votes.add(peerID, message)
}

func (t *tracker) addCommit(cm *CommitMessage) {
Expand All @@ -76,10 +71,11 @@ func (t *tracker) addCommit(cm *CommitMessage) {
t.commitMessages[cm.Vote.Hash] = cm
}

func (t *tracker) addCatchUpResponse(cr *CatchUpResponse) {
func (t *tracker) addCatchUpResponse(_ *CatchUpResponse) {
t.catchUpResponseMessageMutex.Lock()
defer t.catchUpResponseMessageMutex.Unlock()
t.catchUpResponseMessages[cr.Round] = cr
// uncomment when usage is setup properly, see #1531
// t.catchUpResponseMessages[cr.Round] = cr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this commented out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it's a memory leak as it is now, we never clean these up.
It's commented until it's properly implemented.

}

func (t *tracker) handleBlocks() {
Expand Down Expand Up @@ -108,18 +104,18 @@ func (t *tracker) handleBlock(b *types.Block) {
defer t.mapLock.Unlock()

h := b.Header.Hash()
if vms, has := t.voteMessages[h]; has {
for _, v := range vms {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Warnf("failed to handle vote message %v: %s", v, err)
}
vms := t.votes.messages(h)
for _, v := range vms {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Warnf("failed to handle vote message %v: %s", v, err)
}

delete(t.voteMessages, h)
}

// delete block hash that may or may not be in the tracker.
t.votes.delete(h)
qdm12 marked this conversation as resolved.
Show resolved Hide resolved

qdm12 marked this conversation as resolved.
Show resolved Hide resolved
if cm, has := t.commitMessages[h]; has {
_, err := t.handler.handleMessage("", cm)
if err != nil {
Expand All @@ -134,17 +130,17 @@ func (t *tracker) handleTick() {
t.mapLock.Lock()
defer t.mapLock.Unlock()

for _, vms := range t.voteMessages {
for _, v := range vms {
for _, networkVoteMessage := range t.votes.networkVoteMessages() {
peerID := networkVoteMessage.from
message := networkVoteMessage.msg
_, err := t.handler.handleMessage(peerID, message)
if err != nil {
// handleMessage would never error for vote message
_, err := t.handler.handleMessage(v.from, v.msg)
if err != nil {
logger.Debugf("failed to handle vote message %v: %s", v, err)
}
logger.Debugf("failed to handle vote message %v from peer id %s: %s", message, peerID, err)
}

if v.msg.Round < t.handler.grandpa.state.round && v.msg.SetID == t.handler.grandpa.state.setID {
delete(t.voteMessages, v.msg.Message.BlockHash)
}
if message.Round < t.handler.grandpa.state.round && message.SetID == t.handler.grandpa.state.setID {
t.votes.delete(message.Message.BlockHash)
}
}

Expand Down
71 changes: 37 additions & 34 deletions lib/grandpa/message_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ import (
"github.com/stretchr/testify/require"
)

// getMessageFromVotesTracker returns the vote message
// from the votes tracker for the given block hash and authority ID.
func getMessageFromVotesTracker(votes votesTracker,
blockHash common.Hash, authorityID ed25519.PublicKeyBytes) (
message *VoteMessage) {
authorityIDToElement, has := votes.mapping[blockHash]
if !has {
return nil
}

element, ok := authorityIDToElement[authorityID]
if !ok {
return nil
}

return element.Value.(networkVoteMessage).msg
}

func TestMessageTracker_ValidateMessage(t *testing.T) {
kr, err := keystore.NewEd25519Keyring()
require.NoError(t, err)
Expand All @@ -33,13 +51,11 @@ func TestMessageTracker_ValidateMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
require.Equal(t, expected, gs.tracker.voteMessages[fake.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID)
require.Equal(t, msg, voteMessage)
}

func TestMessageTracker_SendMessage(t *testing.T) {
Expand Down Expand Up @@ -72,13 +88,11 @@ func TestMessageTracker_SendMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Header: *next,
Expand Down Expand Up @@ -126,13 +140,11 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

expected := &networkVoteMessage{
msg: msg,
}

_, err = gs.validateVoteMessage("", msg)
require.Equal(t, ErrBlockDoesNotExist, err)
require.Equal(t, expected, gs.tracker.voteMessages[next.Hash()][kr.Alice().Public().(*ed25519.PublicKey).AsBytes()])
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Header: *next,
Expand All @@ -147,7 +159,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
}
pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes())
require.True(t, has)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.voteMessages)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes)
}

func TestMessageTracker_MapInsideMap(t *testing.T) {
Expand All @@ -163,24 +175,19 @@ func TestMessageTracker_MapInsideMap(t *testing.T) {
}

hash := header.Hash()
_, ok := gs.tracker.voteMessages[hash]
require.False(t, ok)
messages := gs.tracker.votes.messages(hash)
require.Empty(t, messages)

gs.keypair = kr.Alice().(*ed25519.Keypair)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
_, msg, err := gs.createSignedVoteAndVoteMessage(NewVoteFromHeader(header), prevote)
require.NoError(t, err)
gs.keypair = kr.Bob().(*ed25519.Keypair)

gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})

voteMsgs, ok := gs.tracker.voteMessages[hash]
require.True(t, ok)
gs.tracker.addVote("", msg)

_, ok = voteMsgs[authorityID]
require.True(t, ok)
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID)
require.NotEmpty(t, voteMessage)
}

func TestMessageTracker_handleTick(t *testing.T) {
Expand All @@ -197,9 +204,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
BlockHash: testHash,
},
}
gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})
gs.tracker.addVote("", msg)

gs.tracker.handleTick()

Expand All @@ -212,7 +217,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
}

// shouldn't be deleted as round in message >= grandpa round
require.Equal(t, 1, len(gs.tracker.voteMessages[testHash]))
require.Len(t, gs.tracker.votes.messages(testHash), 1)

gs.state.round = 1
msg = &VoteMessage{
Expand All @@ -221,9 +226,7 @@ func TestMessageTracker_handleTick(t *testing.T) {
BlockHash: testHash,
},
}
gs.tracker.addVote(&networkVoteMessage{
msg: msg,
})
gs.tracker.addVote("", msg)

gs.tracker.handleTick()

Expand All @@ -235,5 +238,5 @@ func TestMessageTracker_handleTick(t *testing.T) {
}

// should be deleted as round in message < grandpa round
require.Empty(t, len(gs.tracker.voteMessages[testHash]))
require.Empty(t, gs.tracker.votes.messages(testHash))
}
76 changes: 46 additions & 30 deletions lib/grandpa/vote_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,51 +126,70 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro
// check for message signature
pk, err := ed25519.NewPublicKey(m.Message.AuthorityID[:])
if err != nil {
// TODO Affect peer reputation
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, err
}

err = validateMessageSignature(pk, m)
if err != nil {
// TODO Affect peer reputation
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, err
}

if m.SetID != s.state.setID {
return nil, ErrSetIDMismatch
}

// check that vote is for current round
if m.Round != s.state.round {
if m.Round < s.state.round {
// peer doesn't know round was finalised, send out another commit message
header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID)
if err != nil {
return nil, err
}
const maxRoundsLag = 1
minRoundAccepted := s.state.round - maxRoundsLag
if minRoundAccepted > s.state.round {
// we overflowed below 0 so set the minimum to 0.
minRoundAccepted = 0
}

cm, err := s.newCommitMessage(header, m.Round)
if err != nil {
return nil, err
}
const maxRoundsAhead = 1
maxRoundAccepted := s.state.round + maxRoundsAhead
qdm12 marked this conversation as resolved.
Show resolved Hide resolved

// send finalised block from previous round to network
msg, err := cm.ToConsensusMessage()
if err != nil {
return nil, err
}
if m.Round < minRoundAccepted || m.Round > maxRoundAccepted {
// Discard message
// TODO: affect peer reputation, this is shameful impolite behaviour
// https://github.com/ChainSafe/gossamer/issues/2505
return nil, nil //nolint:nilnil
}

if err = s.network.SendMessage(from, msg); err != nil {
logger.Warnf("failed to send CommitMessage: %s", err)
}
} else {
// round is higher than ours, perhaps we are behind. store vote in tracker for now
s.tracker.addVote(&networkVoteMessage{
from: from,
msg: m,
})
if m.Round < s.state.round {
// message round is lagging by 1
// peer doesn't know round was finalised, send out another commit message
header, err := s.blockState.GetFinalisedHeader(m.Round, m.SetID)
if err != nil {
return nil, err
}

cm, err := s.newCommitMessage(header, m.Round)
if err != nil {
return nil, err
}

// send finalised block from previous round to network
msg, err := cm.ToConsensusMessage()
if err != nil {
return nil, err
}

if err = s.network.SendMessage(from, msg); err != nil {
logger.Warnf("failed to send CommitMessage: %s", err)
}

// TODO: get justification if your round is lower, or just do catch-up? (#1815)
return nil, errRoundMismatch(m.Round, s.state.round)
} else if m.Round > s.state.round {
// Message round is higher by 1 than the round of our state,
// we may be lagging behind, so store the message in the tracker
// for processing later in the coming few milliseconds.
s.tracker.addVote(from, m)
return nil, errRoundMismatch(m.Round, s.state.round)
}

// check for equivocation ie. multiple votes within one subround
Expand All @@ -192,10 +211,7 @@ func (s *Service) validateVoteMessage(from peer.ID, m *VoteMessage) (*Vote, erro
errors.Is(err, blocktree.ErrDescendantNotFound) ||
errors.Is(err, blocktree.ErrEndNodeNotFound) ||
errors.Is(err, blocktree.ErrStartNodeNotFound) {
s.tracker.addVote(&networkVoteMessage{
from: from,
msg: m,
})
s.tracker.addVote(from, m)
}
if err != nil {
return nil, err
Expand Down
Loading