Skip to content

Commit

Permalink
remove Deadlock from lib/grandpa/message_tracker.go (#2923)
Browse files Browse the repository at this point in the history
- Also added tests for commits tracker
  • Loading branch information
kishansagathiya authored Nov 15, 2022
1 parent ac2af46 commit bb637ff
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 23 deletions.
11 changes: 11 additions & 0 deletions lib/grandpa/commits_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package grandpa

import (
"container/list"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
)
Expand All @@ -14,6 +15,7 @@ import (
// its maximum capacity is reached.
// It is NOT THREAD SAFE to use.
type commitsTracker struct {
sync.Mutex
// map of commit block hash to linked list commit message.
mapping map[common.Hash]*list.Element
// double linked list of commit messages
Expand All @@ -36,6 +38,9 @@ func newCommitsTracker(capacity int) commitsTracker {
// If the commit message tracker capacity is reached,
// the oldest commit message is removed.
func (ct *commitsTracker) add(commitMessage *CommitMessage) {
ct.Lock()
defer ct.Unlock()

blockHash := commitMessage.Vote.Hash

listElement, has := ct.mapping[blockHash]
Expand Down Expand Up @@ -75,6 +80,9 @@ func (ct *commitsTracker) cleanup() {
// delete deletes all the vote messages for a particular
// block hash from the vote messages tracker.
func (ct *commitsTracker) delete(blockHash common.Hash) {
ct.Lock()
defer ct.Unlock()

listElement, has := ct.mapping[blockHash]
if !has {
return
Expand All @@ -90,6 +98,9 @@ func (ct *commitsTracker) delete(blockHash common.Hash) {
// does not exist in the tracker
func (ct *commitsTracker) message(blockHash common.Hash) (
message *CommitMessage) {
ct.Lock()
defer ct.Unlock()

listElement, ok := ct.mapping[blockHash]
if !ok {
return nil
Expand Down
52 changes: 51 additions & 1 deletion lib/grandpa/commits_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ import (
"container/list"
"crypto/rand"
"sort"
"sync"
"testing"
"time"

"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -51,7 +54,9 @@ func Test_newCommitsTracker(t *testing.T) {
}
vt := newCommitsTracker(capacity)

assert.Equal(t, expected, vt)
assert.Equal(t, expected.mapping, vt.mapping)
assert.Equal(t, expected.linkedList, vt.linkedList)
assert.Equal(t, expected.capacity, vt.capacity)
}

// We cannot really unit test each method independently
Expand Down Expand Up @@ -319,3 +324,48 @@ func Benchmark_ForEachVsSlice(b *testing.B) {
}
})
}

func Test_commitsTracker_threadSafety(t *testing.T) {
// This test is meant to be run with the `-race` flag
// to detect any data race.
t.Parallel()

const capacity = 2
commitsTracker := newCommitsTracker(capacity)

const parallelism = 10

var endWg sync.WaitGroup
defer endWg.Wait()

for i := 1; i < parallelism; i++ {
endWg.Add(1)
go func(i int) {
defer endWg.Done()

blockHash := common.Hash{byte(i)}

commitMessage := &CommitMessage{
Round: 1,
SetID: 1,
Vote: types.GrandpaVote{
Hash: blockHash,
Number: uint32(i),
},
}

timer := time.NewTimer(50 * time.Millisecond)
for {
select {
case <-timer.C:
return
default:
}

commitsTracker.add(commitMessage)
commitsTracker.delete(blockHash)
_ = commitsTracker.message(blockHash)
}
}(i)
}
}
13 changes: 0 additions & 13 deletions lib/grandpa/message_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type tracker struct {
handler *MessageHandler
votes votesTracker
commits commitsTracker
mapLock sync.Mutex
in chan *types.Block // receive imported block from BlockState
stopped chan struct{}

Expand All @@ -38,7 +37,6 @@ func newTracker(bs BlockState, handler *MessageHandler) *tracker {
handler: handler,
votes: newVotesTracker(votesCapacity),
commits: newCommitsTracker(commitsCapacity),
mapLock: sync.Mutex{},
in: bs.GetImportedBlockNotifierChannel(),
stopped: make(chan struct{}),
catchUpResponseMessages: make(map[uint64]*CatchUpResponse),
Expand All @@ -59,15 +57,10 @@ func (t *tracker) addVote(peerID peer.ID, message *VoteMessage) {
return
}

t.mapLock.Lock()
defer t.mapLock.Unlock()

t.votes.add(peerID, message)
}

func (t *tracker) addCommit(cm *CommitMessage) {
t.mapLock.Lock()
defer t.mapLock.Unlock()
t.commits.add(cm)
}

Expand Down Expand Up @@ -100,9 +93,6 @@ func (t *tracker) handleBlocks() {
}

func (t *tracker) handleBlock(b *types.Block) {
t.mapLock.Lock()
defer t.mapLock.Unlock()

h := b.Header.Hash()
vms := t.votes.messages(h)
for _, v := range vms {
Expand All @@ -128,9 +118,6 @@ func (t *tracker) handleBlock(b *types.Block) {
}

func (t *tracker) handleTick() {
t.mapLock.Lock()
defer t.mapLock.Unlock()

for _, networkVoteMessage := range t.votes.networkVoteMessages() {
peerID := networkVoteMessage.from
message := networkVoteMessage.msg
Expand Down
27 changes: 19 additions & 8 deletions lib/grandpa/message_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package grandpa

import (
"container/list"
"testing"
"time"

Expand All @@ -16,12 +17,12 @@ import (
"github.com/stretchr/testify/require"
)

// getMessageFromVotesTracker returns the vote message
// getMessageFromVotesMapping returns the vote message
// from the votes tracker for the given block hash and authority ID.
func getMessageFromVotesTracker(votes votesTracker,
func getMessageFromVotesMapping(votesMapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element,
blockHash common.Hash, authorityID ed25519.PublicKeyBytes) (
message *VoteMessage) {
authorityIDToElement, has := votes.mapping[blockHash]
authorityIDToElement, has := votesMapping[blockHash]
if !has {
return nil
}
Expand Down Expand Up @@ -54,7 +55,7 @@ func TestMessageTracker_ValidateMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, fake.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, fake.Hash(), authorityID)
require.Equal(t, msg, voteMessage)
}

Expand Down Expand Up @@ -91,7 +92,7 @@ func TestMessageTracker_SendMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, err, ErrBlockDoesNotExist)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Expand Down Expand Up @@ -143,7 +144,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
_, err = gs.validateVoteMessage("", msg)
require.Equal(t, ErrBlockDoesNotExist, err)
authorityID := kr.Alice().Public().(*ed25519.PublicKey).AsBytes()
voteMessage := getMessageFromVotesTracker(gs.tracker.votes, next.Hash(), authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, next.Hash(), authorityID)
require.Equal(t, msg, voteMessage)

err = gs.blockState.(*state.BlockState).AddBlock(&types.Block{
Expand All @@ -159,7 +160,7 @@ func TestMessageTracker_ProcessMessage(t *testing.T) {
}
pv, has := gs.prevotes.Load(kr.Alice().Public().(*ed25519.PublicKey).AsBytes())
require.True(t, has)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote, gs.tracker.votes)
require.Equal(t, expectedVote, &pv.(*SignedVote).Vote)
}

func TestMessageTracker_MapInsideMap(t *testing.T) {
Expand All @@ -186,7 +187,7 @@ func TestMessageTracker_MapInsideMap(t *testing.T) {

gs.tracker.addVote("", msg)

voteMessage := getMessageFromVotesTracker(gs.tracker.votes, hash, authorityID)
voteMessage := getMessageFromVotesMapping(gs.tracker.votes.mapping, hash, authorityID)
require.NotEmpty(t, voteMessage)
}

Expand Down Expand Up @@ -227,6 +228,15 @@ func TestMessageTracker_handleTick(t *testing.T) {
},
}
gs.tracker.addVote("", msg)
commitMessage := &CommitMessage{
Round: 100,
SetID: 1,
Vote: types.GrandpaVote{
Hash: testHash,
Number: 1,
},
}
gs.tracker.addCommit(commitMessage)

gs.tracker.handleTick()

Expand All @@ -239,4 +249,5 @@ func TestMessageTracker_handleTick(t *testing.T) {

// should be deleted as round in message < grandpa round
require.Empty(t, gs.tracker.votes.messages(testHash))
require.Empty(t, gs.tracker.commits.message(testHash))
}
14 changes: 14 additions & 0 deletions lib/grandpa/votes_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package grandpa

import (
"container/list"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto/ed25519"
Expand All @@ -16,6 +17,7 @@ import (
// its maximum capacity is reached.
// It is NOT THREAD SAFE to use.
type votesTracker struct {
sync.Mutex
// map of vote block hash to authority ID (ed25519 public Key)
// to linked list element pointer
mapping map[common.Hash]map[ed25519.PublicKeyBytes]*list.Element
Expand All @@ -38,6 +40,9 @@ func newVotesTracker(capacity int) votesTracker {
// If the vote message tracker capacity is reached,
// the oldest vote message is removed.
func (vt *votesTracker) add(peerID peer.ID, voteMessage *VoteMessage) {
vt.Lock()
defer vt.Unlock()

signedMessage := voteMessage.Message
blockHash := signedMessage.BlockHash
authorityID := signedMessage.AuthorityID
Expand Down Expand Up @@ -101,6 +106,9 @@ func (vt *votesTracker) cleanup() {
// delete deletes all the vote messages for a particular
// block hash from the vote messages tracker.
func (vt *votesTracker) delete(blockHash common.Hash) {
vt.Lock()
defer vt.Unlock()

authIDToElement, has := vt.mapping[blockHash]
if !has {
return
Expand All @@ -119,6 +127,9 @@ func (vt *votesTracker) delete(blockHash common.Hash) {
// It returns nil if the block hash does not exist.
func (vt *votesTracker) messages(blockHash common.Hash) (
messages []networkVoteMessage) {
vt.Lock()
defer vt.Unlock()

authIDToElement, ok := vt.mapping[blockHash]
if !ok {
// Note authIDToElement cannot be empty
Expand All @@ -138,6 +149,9 @@ func (vt *votesTracker) messages(blockHash common.Hash) (
// as a slice of networkVoteMessages.
func (vt *votesTracker) networkVoteMessages() (
messages []networkVoteMessage) {
vt.Lock()
defer vt.Unlock()

messages = make([]networkVoteMessage, 0, vt.linkedList.Len())
for _, authorityIDToElement := range vt.mapping {
for _, element := range authorityIDToElement {
Expand Down
4 changes: 3 additions & 1 deletion lib/grandpa/votes_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ func Test_newVotesTracker(t *testing.T) {
}
vt := newVotesTracker(capacity)

assert.Equal(t, expected, vt)
assert.Equal(t, expected.mapping, vt.mapping)
assert.Equal(t, expected.linkedList, vt.linkedList)
assert.Equal(t, expected.capacity, vt.capacity)
}

// We cannot really unit test each method independently
Expand Down

0 comments on commit bb637ff

Please sign in to comment.