diff --git a/bitswap.go b/bitswap.go
index 0bd53b3d..97e1daa1 100644
--- a/bitswap.go
+++ b/bitswap.go
@@ -6,7 +6,6 @@ import (
 	"context"
 	"errors"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	bssrs "github.com/ipfs/go-bitswap/sessionrequestsplitter"
@@ -292,7 +291,9 @@ func (bs *Bitswap) receiveBlockFrom(blk blocks.Block, from peer.ID) error {
 }
 
 func (bs *Bitswap) ReceiveMessage(ctx context.Context, p peer.ID, incoming bsmsg.BitSwapMessage) {
-	atomic.AddUint64(&bs.counters.messagesRecvd, 1)
+	bs.counterLk.Lock()
+	bs.counters.messagesRecvd++
+	bs.counterLk.Unlock()
 
 	// This call records changes to wantlists, blocks received,
 	// and number of bytes transfered.
diff --git a/bitswap_test.go b/bitswap_test.go
index 7882147e..6b0f5c75 100644
--- a/bitswap_test.go
+++ b/bitswap_test.go
@@ -140,7 +140,7 @@ func TestLargeSwarm(t *testing.T) {
 	if detectrace.WithRace() {
 		// when running with the race detector, 500 instances launches
 		// well over 8k goroutines. This hits a race detector limit.
-		numInstances = 75
+		numInstances = 50
 	} else if travis.IsRunning() {
 		numInstances = 200
 	} else {
diff --git a/bitswap_with_sessions_test.go b/bitswap_with_sessions_test.go
index 0be7bc97..d4d0cfee 100644
--- a/bitswap_with_sessions_test.go
+++ b/bitswap_with_sessions_test.go
@@ -104,7 +104,11 @@ func TestSessionBetweenPeers(t *testing.T) {
 		}
 	}
 	for _, is := range inst[2:] {
-		if is.Exchange.counters.messagesRecvd > 2 {
-			t.Fatal("uninvolved nodes should only receive two messages", is.Exchange.counters.messagesRecvd)
+		stat, err := is.Exchange.Stat()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if stat.MessagesReceived > 2 {
+			t.Fatal("uninvolved nodes should only receive two messages", stat.MessagesReceived)
 		}
 	}
diff --git a/notifications/notifications.go b/notifications/notifications.go
index 81ba3949..240379ae 100644
--- a/notifications/notifications.go
+++ b/notifications/notifications.go
@@ -20,27 +20,22 @@ type PubSub interface {
 func New() PubSub {
 	return &impl{
 		wrapped: *pubsub.New(bufferSize),
-		cancel:  make(chan struct{}),
+		closed:  make(chan struct{}),
 	}
 }
 
 type impl struct {
+	lk      sync.RWMutex
 	wrapped pubsub.PubSub
 
-	// These two fields make up a shutdown "lock".
-	// We need them as calling, e.g., `Unsubscribe` after calling `Shutdown`
-	// blocks forever and fixing this in pubsub would be rather invasive.
-	cancel chan struct{}
-	wg     sync.WaitGroup
+	closed chan struct{}
 }
 
 func (ps *impl) Publish(block blocks.Block) {
-	ps.wg.Add(1)
-	defer ps.wg.Done()
-
+	ps.lk.RLock()
+	defer ps.lk.RUnlock()
 	select {
-	case <-ps.cancel:
-		// Already shutdown, bail.
+	case <-ps.closed:
 		return
 	default:
 	}
@@ -48,13 +43,15 @@
 	ps.wrapped.Pub(block, block.Cid().KeyString())
 }
 
-// Not safe to call more than once.
 func (ps *impl) Shutdown() {
-	// Interrupt in-progress subscriptions.
-	close(ps.cancel)
-	// Wait for them to finish.
-	ps.wg.Wait()
-	// shutdown the pubsub.
+	ps.lk.Lock()
+	defer ps.lk.Unlock()
+	select {
+	case <-ps.closed:
+		return
+	default:
+	}
+	close(ps.closed)
 	ps.wrapped.Shutdown()
 }
 
@@ -71,13 +68,11 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Bl
 	}
 
 	// prevent shutdown
-	ps.wg.Add(1)
+	ps.lk.RLock()
+	defer ps.lk.RUnlock()
 
-	// check if shutdown *after* preventing shutdowns.
 	select {
-	case <-ps.cancel:
-		// abort, allow shutdown to continue.
-		ps.wg.Done()
+	case <-ps.closed:
 		close(blocksCh)
 		return blocksCh
 	default:
@@ -84,21 +79,28 @@ func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Bl
 	}
 
 	ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...)
 
 	go func() {
 		defer func() {
-			ps.wrapped.Unsub(valuesCh)
 			close(blocksCh)
-			// Unblock shutdown.
-			ps.wg.Done()
+			ps.lk.RLock()
+			defer ps.lk.RUnlock()
+			// Don't touch the pubsub instance if we're
+			// already closed.
+			select {
+			case <-ps.closed:
+				return
+			default:
+			}
+
+			ps.wrapped.Unsub(valuesCh)
 		}()
 
 		for {
 			select {
-			case <-ps.cancel:
-				return
 			case <-ctx.Done():
 				return
+			case <-ps.closed:
 			case val, ok := <-valuesCh:
 				if !ok {
 					return
@@ -108,11 +110,10 @@
 					return
 				}
 				select {
-				case <-ps.cancel:
-					return
 				case <-ctx.Done():
 					return
 				case blocksCh <- block: // continue
+				case <-ps.closed:
 				}
 			}
 		}
diff --git a/notifications/notifications_test.go b/notifications/notifications_test.go
index 38ab6f9a..4e59ae9b 100644
--- a/notifications/notifications_test.go
+++ b/notifications/notifications_test.go
@@ -114,7 +114,7 @@ func TestShutdownBeforeUnsubscribe(t *testing.T) {
 		if ok {
 			t.Fatal("channel should have been closed")
 		}
-	default:
+	case <-time.After(5 * time.Second):
 		t.Fatal("channel should have been closed")
 	}
 }
diff --git a/sessionpeermanager/sessionpeermanager_test.go b/sessionpeermanager/sessionpeermanager_test.go
index d6d1440a..1cad238a 100644
--- a/sessionpeermanager/sessionpeermanager_test.go
+++ b/sessionpeermanager/sessionpeermanager_test.go
@@ -41,18 +41,24 @@ func (fppf *fakePeerProviderFinder) FindProvidersAsync(ctx context.Context, c ci
 }
 
 type fakePeerTagger struct {
+	lk          sync.Mutex
 	taggedPeers []peer.ID
 	wait        sync.WaitGroup
 }
 
 func (fpt *fakePeerTagger) TagPeer(p peer.ID, tag string, n int) {
 	fpt.wait.Add(1)
+
+	fpt.lk.Lock()
+	defer fpt.lk.Unlock()
 	fpt.taggedPeers = append(fpt.taggedPeers, p)
 }
 
 func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
 	defer fpt.wait.Done()
 
+	fpt.lk.Lock()
+	defer fpt.lk.Unlock()
 	for i := 0; i < len(fpt.taggedPeers); i++ {
 		if fpt.taggedPeers[i] == p {
 			fpt.taggedPeers[i] = fpt.taggedPeers[len(fpt.taggedPeers)-1]
@@ -62,6 +68,12 @@ func (fpt *fakePeerTagger) UntagPeer(p peer.ID, tag string) {
 	}
 }
 
+func (fpt *fakePeerTagger) count() int {
+	fpt.lk.Lock()
+	defer fpt.lk.Unlock()
+	return len(fpt.taggedPeers)
+}
+
 func TestFindingMorePeers(t *testing.T) {
 	ctx := context.Background()
 	ctx, cancel := context.WithCancel(ctx)
@@ -195,6 +207,7 @@ func TestOrderingPeers(t *testing.T) {
 			t.Fatal("should not return the same random peers each time")
 		}
 	}
+
 func TestUntaggingPeers(t *testing.T) {
 	ctx := context.Background()
 	ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
@@ -216,13 +229,13 @@
 	}
 
 	time.Sleep(2 * time.Millisecond)
-	if len(fpt.taggedPeers) != len(peers) {
+	if fpt.count() != len(peers) {
 		t.Fatal("Peers were not tagged!")
 	}
 
 	<-ctx.Done()
 	fpt.wait.Wait()
-	if len(fpt.taggedPeers) != 0 {
+	if fpt.count() != 0 {
 		t.Fatal("Peers were not untagged!")
 	}
 }
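
Note on the notifications.go change above: it replaces the WaitGroup-based shutdown "lock" with an RWMutex guarding a channel that is closed exactly once. Readers (Publish, Subscribe, and each subscription goroutine's cleanup) take the read lock and bail out if the channel is already closed; Shutdown takes the write lock, so it can never overlap a reader, and its select makes it safe to call repeatedly. The sketch below is a minimal, self-contained illustration of that pattern, not code from the patch; the gate type and its method names are invented for the example.

package main

import (
	"fmt"
	"sync"
)

// gate pairs an RWMutex with a close-once channel, mirroring the
// lk/closed fields added to impl in notifications.go.
type gate struct {
	lk     sync.RWMutex
	closed chan struct{}
}

func newGate() *gate {
	return &gate{closed: make(chan struct{})}
}

// do runs op unless the gate has shut down. Holding the read lock
// guarantees shutdown cannot run concurrently with op.
func (g *gate) do(op func()) {
	g.lk.RLock()
	defer g.lk.RUnlock()
	select {
	case <-g.closed:
		return // already shut down; leave shared state alone
	default:
	}
	op()
}

// shutdown is idempotent: the write lock excludes all readers, and
// the select skips close(g.closed) on repeat calls.
func (g *gate) shutdown() {
	g.lk.Lock()
	defer g.lk.Unlock()
	select {
	case <-g.closed:
		return
	default:
	}
	close(g.closed)
}

func main() {
	g := newGate()
	g.do(func() { fmt.Println("runs before shutdown") })
	g.shutdown()
	g.shutdown() // second call is a no-op, no double-close panic
	g.do(func() { fmt.Println("never runs") })
}

One subtlety in the patch itself: in the subscription goroutine, case <-ps.closed: has no return. The loop keeps draining until valuesCh is closed, presumably because the wrapped pubsub's Shutdown closes subscriber channels, which gives the goroutine its normal exit path without touching the pubsub after shutdown.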