This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

Commit 722239f

Merge pull request #76 from ipfs/fix/pubsub-wait-race

fix multiple data races

Stebalien authored Feb 20, 2019
2 parents fc3f5b0 + 52f9630
Showing 6 changed files with 55 additions and 36 deletions.
bitswap.go: 3 additions & 2 deletions
@@ -6,7 +6,6 @@ import (
     "context"
     "errors"
     "sync"
-    "sync/atomic"
     "time"

     bssrs "github.com/ipfs/go-bitswap/sessionrequestsplitter"
@@ -292,7 +291,9 @@ func (bs *Bitswap) receiveBlockFrom(blk blocks.Block, from peer.ID) error {
 }

 func (bs *Bitswap) ReceiveMessage(ctx context.Context, p peer.ID, incoming bsmsg.BitSwapMessage) {
-    atomic.AddUint64(&bs.counters.messagesRecvd, 1)
+    bs.counterLk.Lock()
+    bs.counters.messagesRecvd++
+    bs.counterLk.Unlock()

     // This call records changes to wantlists, blocks received,
     // and number of bytes transfered.
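Why the atomic increment had to go: Bitswap's counter readers (see the Stat() call in the session test below) take bs.counterLk before reading bs.counters, and an atomic write does not synchronize with a mutex-guarded read of the same field, so the mix is itself a data race. A minimal, self-contained sketch of the guarded-counter pattern the commit adopts; the counters, inc, and snapshot names are illustrative stand-ins, not the go-bitswap types:

    package main

    import (
        "fmt"
        "sync"
    )

    // counters mirrors the fixed code's shape: every access to the struct,
    // write or read, goes through the same mutex.
    type counters struct {
        mu            sync.Mutex
        messagesRecvd uint64
    }

    func (c *counters) inc() {
        c.mu.Lock()
        c.messagesRecvd++ // was atomic.AddUint64; now guarded like the readers
        c.mu.Unlock()
    }

    // snapshot copies the value out under the lock, so a reader can never
    // observe the field mid-update.
    func (c *counters) snapshot() uint64 {
        c.mu.Lock()
        defer c.mu.Unlock()
        return c.messagesRecvd
    }

    func main() {
        var c counters
        var wg sync.WaitGroup
        for i := 0; i < 100; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                c.inc()
            }()
        }
        wg.Wait()
        fmt.Println(c.snapshot()) // 100, and clean under the race detector
    }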
bitswap_test.go: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ func TestLargeSwarm(t *testing.T) {
     if detectrace.WithRace() {
         // when running with the race detector, 500 instances launches
         // well over 8k goroutines. This hits a race detector limit.
-        numInstances = 75
+        numInstances = 50
     } else if travis.IsRunning() {
         numInstances = 200
     } else {
bitswap_with_sessions_test.go: 5 additions & 1 deletion
@@ -104,7 +104,11 @@ func TestSessionBetweenPeers(t *testing.T) {
         }
     }
     for _, is := range inst[2:] {
-        if is.Exchange.counters.messagesRecvd > 2 {
+        stat, err := is.Exchange.Stat()
+        if err != nil {
+            t.Fatal(err)
+        }
+        if stat.MessagesReceived > 2 {
             t.Fatal("uninvolved nodes should only receive two messages", is.Exchange.counters.messagesRecvd)
         }
     }
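The test now goes through the exchange's public Stat() accessor instead of reading the unexported counters field from a different goroutine (the remaining direct read in the t.Fatal argument only runs once the test is already failing). A sketch of the snapshot-accessor shape this relies on, with stand-in types rather than the exact go-bitswap definitions:

    package main

    import (
        "fmt"
        "sync"
    )

    // Stand-ins for the real go-bitswap types.
    type counters struct{ messagesRecvd uint64 }

    type Stat struct{ MessagesReceived uint64 }

    type Bitswap struct {
        counterLk sync.Mutex
        counters  counters
    }

    // Stat copies the guarded fields into a plain struct while holding the
    // lock, so tests never have to touch counters directly.
    func (bs *Bitswap) Stat() (*Stat, error) {
        bs.counterLk.Lock()
        defer bs.counterLk.Unlock()
        return &Stat{MessagesReceived: bs.counters.messagesRecvd}, nil
    }

    func main() {
        bs := &Bitswap{}
        bs.counterLk.Lock()
        bs.counters.messagesRecvd++
        bs.counterLk.Unlock()

        stat, err := bs.Stat()
        if err != nil {
            panic(err)
        }
        fmt.Println(stat.MessagesReceived) // 1
    }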
notifications/notifications.go: 30 additions & 29 deletions
@@ -20,41 +20,38 @@ type PubSub interface {
 func New() PubSub {
     return &impl{
         wrapped: *pubsub.New(bufferSize),
-        cancel:  make(chan struct{}),
+        closed:  make(chan struct{}),
     }
 }

 type impl struct {
+    lk      sync.RWMutex
     wrapped pubsub.PubSub

-    // These two fields make up a shutdown "lock".
-    // We need them as calling, e.g., `Unsubscribe` after calling `Shutdown`
-    // blocks forever and fixing this in pubsub would be rather invasive.
-    cancel chan struct{}
-    wg     sync.WaitGroup
+    closed chan struct{}
 }

 func (ps *impl) Publish(block blocks.Block) {
-    ps.wg.Add(1)
-    defer ps.wg.Done()
-
+    ps.lk.RLock()
+    defer ps.lk.RUnlock()
     select {
-    case <-ps.cancel:
-        // Already shutdown, bail.
+    case <-ps.closed:
         return
     default:
     }

     ps.wrapped.Pub(block, block.Cid().KeyString())
 }

-// Not safe to call more than once.
 func (ps *impl) Shutdown() {
-    // Interrupt in-progress subscriptions.
-    close(ps.cancel)
-    // Wait for them to finish.
-    ps.wg.Wait()
-    // shutdown the pubsub.
+    ps.lk.Lock()
+    defer ps.lk.Unlock()
+    select {
+    case <-ps.closed:
+        return
+    default:
+    }
+    close(ps.closed)
     ps.wrapped.Shutdown()
 }

@@ -71,13 +68,11 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
     }

     // prevent shutdown
-    ps.wg.Add(1)
+    ps.lk.RLock()
+    defer ps.lk.RUnlock()

-    // check if shutdown *after* preventing shutdowns.
     select {
-    case <-ps.cancel:
-        // abort, allow shutdown to continue.
-        ps.wg.Done()
+    case <-ps.closed:
         close(blocksCh)
         return blocksCh
     default:
@@ -86,19 +81,26 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
     ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
     go func() {
         defer func() {
-            ps.wrapped.Unsub(valuesCh)
             close(blocksCh)

-            // Unblock shutdown.
-            ps.wg.Done()
+            ps.lk.RLock()
+            defer ps.lk.RUnlock()
+            // Don't touch the pubsub instance if we're
+            // already closed.
+            select {
+            case <-ps.closed:
+                return
+            default:
+            }
+
+            ps.wrapped.Unsub(valuesCh)
         }()

         for {
             select {
-            case <-ps.cancel:
-                return
             case <-ctx.Done():
                 return
+            case <-ps.closed:
             case val, ok := <-valuesCh:
                 if !ok {
                     return
@@ -108,11 +110,10 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block {
                     return
                 }
                 select {
-                case <-ps.cancel:
-                    return
                 case <-ctx.Done():
                     return
                 case blocksCh <- block: // continue
+                case <-ps.closed:
                 }
             }
         }
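The rewrite above swaps the WaitGroup-plus-cancel-channel shutdown "lock" for an RWMutex paired with a closed channel: each operation takes the read lock and checks closed before touching the wrapped pubsub, Shutdown takes the write lock so it can never overlap an in-flight call, and the select/default guard makes Shutdown idempotent (the deleted comment warned the old one was not safe to call twice). A self-contained sketch of that pattern; guarded and its events channel are stand-ins for the notifications type and its internally thread-safe pubsub, not the real API:

    package main

    import (
        "fmt"
        "sync"
    )

    // guarded distills the shutdown-"lock" pattern from the diff. The
    // RWMutex does not serialize operations against each other (the real
    // wrapped pubsub is internally thread-safe); it only guarantees that
    // Shutdown never overlaps an in-flight operation.
    type guarded struct {
        lk     sync.RWMutex
        closed chan struct{}
        events chan string // stands in for the wrapped pubsub
    }

    func newGuarded() *guarded {
        return &guarded{
            closed: make(chan struct{}),
            events: make(chan string, 16),
        }
    }

    func (g *guarded) Publish(key string) {
        g.lk.RLock()
        defer g.lk.RUnlock()
        select {
        case <-g.closed:
            return // already shut down: don't touch the wrapped resource
        default:
        }
        // Safe: Shutdown cannot close(g.events) while we hold the read lock.
        g.events <- key
    }

    // Shutdown is idempotent: the select/default guard turns repeat calls
    // into no-ops instead of closing the channel twice.
    func (g *guarded) Shutdown() {
        g.lk.Lock()
        defer g.lk.Unlock()
        select {
        case <-g.closed:
            return
        default:
        }
        close(g.closed)
        close(g.events)
    }

    func main() {
        g := newGuarded()
        var wg sync.WaitGroup
        for i := 0; i < 8; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                g.Publish("block")
            }()
        }
        g.Shutdown()
        g.Shutdown() // no-op
        wg.Wait()
        fmt.Println("events delivered before shutdown:", len(g.events))
    }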
notifications/notifications_test.go: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ func TestShutdownBeforeUnsubscribe(t *testing.T) {
         if ok {
             t.Fatal("channel should have been closed")
         }
-    default:
+    case <-time.After(5 * time.Second):
         t.Fatal("channel should have been closed")
     }
 }
sessionpeermanager/sessionpeermanager_test.go: 15 additions & 2 deletions
@@ -41,18 +41,24 @@ func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c cid.Cid
 }

 type fakePeerTagger struct {
+    lk          sync.Mutex
     taggedPeers []peer.ID
     wait        sync.WaitGroup
 }

 func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
     fpt.wait.Add(1)
+
+    fpt.lk.Lock()
+    defer fpt.lk.Unlock()
     fpt.taggedPeers = append(fpt.taggedPeers, p)
 }

 func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
     defer fpt.wait.Done()

+    fpt.lk.Lock()
+    defer fpt.lk.Unlock()
     for i := 0; i < len(fpt.taggedPeers); i++ {
         if fpt.taggedPeers[i] == p {
             fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1]
@@ -62,6 +68,12 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
     }
 }

+func (fpt *fakePeerTagger) count() int {
+    fpt.lk.Lock()
+    defer fpt.lk.Unlock()
+    return len(fpt.taggedPeers)
+}
+
 func TestFindingMorePeers(t *testing.T) {
     ctx := context.Background()
     ctx, cancel := context.WithCancel(ctx)
@@ -195,6 +207,7 @@ func TestOrderingPeers(t *testing.T) {
         t.Fatal("should not return the same random peers each time")
     }
 }
+
 func TestUntaggingPeers(t *testing.T) {
     ctx := context.Background()
     ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
@@ -216,13 +229,13 @@ func TestUntaggingPeers(t *testing.T) {
     }
     time.Sleep(2 * time.Millisecond)

-    if len(fpt.taggedPeers) != len(peers) {
+    if fpt.count() != len(peers) {
         t.Fatal("Peers were not tagged!")
     }
     <-ctx.Done()
     fpt.wait.Wait()

-    if len(fpt.taggedPeers) != 0 {
+    if fpt.count() != 0 {
         t.Fatal("Peers were not untagged!")
     }
 }
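The same discipline extends to test doubles: the race detector flags the test goroutine's len(fpt.taggedPeers) against the session manager's concurrent appends, even though len looks read-only, so the length checks now go through the locked count() helper. A minimal sketch of that fix; tagger is a stand-in for fakePeerTagger:

    package main

    import (
        "fmt"
        "sync"
    )

    // tagger mirrors the fixed fakePeerTagger: one mutex guards taggedPeers
    // for writers and for the seemingly read-only len() in count(). Drop the
    // locks from count() and `go run -race` reports a race with the append.
    type tagger struct {
        lk          sync.Mutex
        taggedPeers []string
    }

    func (t *tagger) tag(p string) {
        t.lk.Lock()
        defer t.lk.Unlock()
        t.taggedPeers = append(t.taggedPeers, p)
    }

    func (t *tagger) count() int {
        t.lk.Lock()
        defer t.lk.Unlock()
        return len(t.taggedPeers)
    }

    func main() {
        var tg tagger
        var wg sync.WaitGroup
        for i := 0; i < 10; i++ {
            wg.Add(1)
            go func(n int) {
                defer wg.Done()
                tg.tag(fmt.Sprintf("peer-%d", n))
            }(i)
        }
        wg.Wait()
        fmt.Println(tg.count()) // 10
    }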
