diff --git a/lib/grandpa/commits_tracker.go b/lib/grandpa/commits_tracker.go index 14addd5a86..8e5bd4c0c1 100644 --- a/lib/grandpa/commits_tracker.go +++ b/lib/grandpa/commits_tracker.go @@ -5,6 +5,7 @@ package grandpa import ( "container/list" + "sync" "github.com/ChainSafe/gossamer/lib/common" ) @@ -14,6 +15,7 @@ import ( // its maximum capacity is reached. // It is NOT THREAD SAFE to use. type commitsTracker struct { + sync.Mutex // map of commit block hash to linked list commit message. mapping map[common.Hash]*list.Element // double linked list of commit messages @@ -36,6 +38,9 @@ func newCommitsTracker(capacity int) commitsTracker { // If the commit message tracker capacity is reached, // the oldest commit message is removed. func (ct *commitsTracker) add(commitMessage *CommitMessage) { + ct.Lock() + defer ct.Unlock() + blockHash := commitMessage.Vote.Hash listElement, has := ct.mapping[blockHash] @@ -75,6 +80,9 @@ func (ct *commitsTracker) cleanup() { // delete deletes all the vote messages for a particular // block hash from the vote messages tracker. func (ct *commitsTracker) delete(blockHash common.Hash) { + ct.Lock() + defer ct.Unlock() + listElement, has := ct.mapping[blockHash] if !has { return @@ -90,6 +98,9 @@ func (ct *commitsTracker) delete(blockHash common.Hash) { // does not exist in the tracker func (ct *commitsTracker) message(blockHash common.Hash) ( message *CommitMessage) { + ct.Lock() + defer ct.Unlock() + listElement, ok := ct.mapping[blockHash] if !ok { return nil diff --git a/lib/grandpa/commits_tracker_test.go b/lib/grandpa/commits_tracker_test.go index d5d28bb588..7adde210ff 100644 --- a/lib/grandpa/commits_tracker_test.go +++ b/lib/grandpa/commits_tracker_test.go @@ -8,8 +8,11 @@ import ( "container/list" "crypto/rand" "sort" + "sync" "testing" + "time" + "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -51,7 +54,9 @@ func Test_newCommitsTracker(t *testing.T) { } vt := newCommitsTracker(capacity) - assert.Equal(t, expected, vt) + assert.Equal(t, expected.mapping, vt.mapping) + assert.Equal(t, expected.linkedList, vt.linkedList) + assert.Equal(t, expected.capacity, vt.capacity) } // We cannot really unit test each method independently @@ -319,3 +324,48 @@ func Benchmark_ForEachVsSlice(b *testing.B) { } }) } + +func Test_commitsTracker_threadSafety(t *testing.T) { + // This test is meant to be run with the `-race` flag + // to detect any data race. + t.Parallel() + + const capacity = 2 + commitsTracker := newCommitsTracker(capacity) + + const parallelism = 10 + + var endWg sync.WaitGroup + defer endWg.Wait() + + for i := 1; i < parallelism; i++ { + endWg.Add(1) + go func(i int) { + defer endWg.Done() + + blockHash := common.Hash{byte(i)} + + commitMessage := &CommitMessage{ + Round: 1, + SetID: 1, + Vote: types.GrandpaVote{ + Hash: blockHash, + Number: uint32(i), + }, + } + + timer := time.NewTimer(50 * time.Millisecond) + for { + select { + case <-timer.C: + return + default: + } + + commitsTracker.add(commitMessage) + commitsTracker.delete(blockHash) + _ = commitsTracker.message(blockHash) + } + }(i) + } +} diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index 1c380bb175..051d681f96 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -19,7 +19,6 @@ type tracker struct { handler *MessageHandler votes votesTracker commits commitsTracker - mapLock sync.Mutex in chan *types.Block // receive imported block from BlockState stopped chan struct{} @@ -38,7 +37,6 @@ func newTracker(bs BlockState, handler *MessageHandler) *tracker { handler: handler, votes: newVotesTracker(votesCapacity), commits: newCommitsTracker(commitsCapacity), - mapLock: sync.Mutex{}, in: bs.GetImportedBlockNotifierChannel(), stopped: make(chan struct{}), catchUpResponseMessages: make(map[uint64]*CatchUpResponse), @@ -59,15 +57,10 @@ func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) { return } - t.mapLock.Lock() - defer t.mapLock.Unlock() - t.votes.add(peerID, message) } func (t *tracker) addCommit(cm *CommitMessage) { - t.mapLock.Lock() - defer t.mapLock.Unlock() t.commits.add(cm) } @@ -100,9 +93,6 @@ func (t *tracker) handleBlocks() { } func (t *tracker) handleBlock(b *types.Block) { - t.mapLock.Lock() - defer t.mapLock.Unlock() - h := b.Header.Hash() vms := t.votes.messages(h) for _, v := range vms { @@ -128,9 +118,6 @@ func (t *tracker) handleBlock(b *types.Block) { } func (t *tracker) handleTick() { - t.mapLock.Lock() - defer t.mapLock.Unlock() - for _, networkVoteMessage := range t.votes.networkVoteMessages() { peerID := networkVoteMessage.from message := networkVoteMessage.msg diff --git a/lib/grandpa/message_tracker_test.go b/lib/grandpa/message_tracker_test.go index 56a43adeef..7de9ce4309 100644 --- a/lib/grandpa/message_tracker_test.go +++ b/lib/grandpa/message_tracker_test.go @@ -4,6 +4,7 @@ package grandpa import ( + "container/list" "testing" "time" @@ -16,12 +17,12 @@ import ( "github.com/stretchr/testify/require" ) -// getMessageFromVotesTracker returns the vote message +// getMessageFromVotesMapping returns the vote message // from the votes tracker for the given block hash and authority ID. -func getMessageFromVotesTracker(votes votesTracker, +func getMessageFromVotesMapping(votesMapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element, blockHash common.Hash, authorityID ed25519.PublicKeyBytes) ( message *VoteMessage) { - authorityIDToElement, has := votes.mapping[blockHash] + authorityIDToElement, has := votesMapping[blockHash] if !has { return nil } @@ -54,7 +55,7 @@ func TestMessageTracker_ValidateMessage(t *testing.T) { _, err = gs.validateVoteMessage("", msg) require.Equal(t, err, ErrBlockDoesNotExist) authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() - voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID) + voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, fake.Hash(), authorityID) require.Equal(t, msg, voteMessage) } @@ -91,7 +92,7 @@ func TestMessageTracker_SendMessage(t *testing.T) { _, err = gs.validateVoteMessage("", msg) require.Equal(t, err, ErrBlockDoesNotExist) authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() - voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID) + voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID) require.Equal(t, msg, voteMessage) err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{ @@ -143,7 +144,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { _, err = gs.validateVoteMessage("", msg) require.Equal(t, ErrBlockDoesNotExist, err) authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes() - voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID) + voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID) require.Equal(t, msg, voteMessage) err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{ @@ -159,7 +160,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) { } pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes()) require.True(t, has) - require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes) + require.Equal(t, expectedVote, &pv.(*SignedVote).Vote) } func TestMessageTracker_MapInsideMap(t *testing.T) { @@ -186,7 +187,7 @@ func TestMessageTracker_MapInsideMap(t *testing.T) { gs.tracker.addVote("", msg) - voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID) + voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, hash, authorityID) require.NotEmpty(t, voteMessage) } @@ -227,6 +228,15 @@ func TestMessageTracker_handleTick(t *testing.T) { }, } gs.tracker.addVote("", msg) + commitMessage := &CommitMessage{ + Round: 100, + SetID: 1, + Vote: types.GrandpaVote{ + Hash: testHash, + Number: 1, + }, + } + gs.tracker.addCommit(commitMessage) gs.tracker.handleTick() @@ -239,4 +249,5 @@ func TestMessageTracker_handleTick(t *testing.T) { // should be deleted as round in message < grandpa round require.Empty(t, gs.tracker.votes.messages(testHash)) + require.Empty(t, gs.tracker.commits.message(testHash)) } diff --git a/lib/grandpa/votes_tracker.go b/lib/grandpa/votes_tracker.go index ed69088e5c..0a4bf1c273 100644 --- a/lib/grandpa/votes_tracker.go +++ b/lib/grandpa/votes_tracker.go @@ -5,6 +5,7 @@ package grandpa import ( "container/list" + "sync" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/ed25519" @@ -16,6 +17,7 @@ import ( // its maximum capacity is reached. // It is NOT THREAD SAFE to use. type votesTracker struct { + sync.Mutex // map of vote block hash to authority ID (ed25519 public Key) // to linked list element pointer mapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element @@ -38,6 +40,9 @@ func newVotesTracker(capacity int) votesTracker { // If the vote message tracker capacity is reached, // the oldest vote message is removed. func (vt *votesTracker) add(peerID peer.ID, voteMessage *VoteMessage) { + vt.Lock() + defer vt.Unlock() + signedMessage := voteMessage.Message blockHash := signedMessage.BlockHash authorityID := signedMessage.AuthorityID @@ -101,6 +106,9 @@ func (vt *votesTracker) cleanup() { // delete deletes all the vote messages for a particular // block hash from the vote messages tracker. func (vt *votesTracker) delete(blockHash common.Hash) { + vt.Lock() + defer vt.Unlock() + authIDToElement, has := vt.mapping[blockHash] if !has { return @@ -119,6 +127,9 @@ func (vt *votesTracker) delete(blockHash common.Hash) { // It returns nil if the block hash does not exist. func (vt *votesTracker) messages(blockHash common.Hash) ( messages []networkVoteMessage) { + vt.Lock() + defer vt.Unlock() + authIDToElement, ok := vt.mapping[blockHash] if !ok { // Note authIDToElement cannot be empty @@ -138,6 +149,9 @@ func (vt *votesTracker) messages(blockHash common.Hash) ( // as a slice of networkVoteMessages. func (vt *votesTracker) networkVoteMessages() ( messages []networkVoteMessage) { + vt.Lock() + defer vt.Unlock() + messages = make([]networkVoteMessage, 0, vt.linkedList.Len()) for _, authorityIDToElement := range vt.mapping { for _, element := range authorityIDToElement { diff --git a/lib/grandpa/votes_tracker_test.go b/lib/grandpa/votes_tracker_test.go index 7735cc7615..2ea483295b 100644 --- a/lib/grandpa/votes_tracker_test.go +++ b/lib/grandpa/votes_tracker_test.go @@ -70,7 +70,9 @@ func Test_newVotesTracker(t *testing.T) { } vt := newVotesTracker(capacity) - assert.Equal(t, expected, vt) + assert.Equal(t, expected.mapping, vt.mapping) + assert.Equal(t, expected.linkedList, vt.linkedList) + assert.Equal(t, expected.capacity, vt.capacity) } // We cannot really unit test each method independently