diff --git a/bitswap.go b/bitswap.go index 1ff55517..0cd6b497 100644 --- a/bitswap.go +++ b/bitswap.go @@ -139,7 +139,11 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, pm := bspm.New(ctx, peerQueueFactory, network.Self()) pqm := bspqm.New(ctx, network) - sessionFactory := func(sessctx context.Context, onShutdown bssession.OnShutdown, id uint64, spm bssession.SessionPeerManager, + sessionFactory := func( + sessctx context.Context, + sessmgr bssession.SessionManager, + id uint64, + spm bssession.SessionPeerManager, sim *bssim.SessionInterestManager, pm bssession.PeerManager, bpm *bsbpm.BlockPresenceManager, @@ -147,7 +151,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork, provSearchDelay time.Duration, rebroadcastDelay delay.D, self peer.ID) bssm.Session { - return bssession.New(ctx, sessctx, onShutdown, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self) + return bssession.New(sessctx, sessmgr, id, spm, pqm, sim, pm, bpm, notif, provSearchDelay, rebroadcastDelay, self) } sessionPeerManagerFactory := func(ctx context.Context, id uint64) bssession.SessionPeerManager { return bsspm.New(id, network.ConnectionManager()) diff --git a/internal/blockpresencemanager/blockpresencemanager.go b/internal/blockpresencemanager/blockpresencemanager.go index 87821f2f..1d3acb0e 100644 --- a/internal/blockpresencemanager/blockpresencemanager.go +++ b/internal/blockpresencemanager/blockpresencemanager.go @@ -109,3 +109,13 @@ func (bpm *BlockPresenceManager) RemoveKeys(ks []cid.Cid) { delete(bpm.presence, c) } } + +// HasKey indicates whether the BlockPresenceManager is tracking the given key +// (used by the tests) +func (bpm *BlockPresenceManager) HasKey(c cid.Cid) bool { + bpm.Lock() + defer bpm.Unlock() + + _, ok := bpm.presence[c] + return ok +} diff --git a/internal/session/session.go b/internal/session/session.go index 47cbb548..7a0d23b3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -43,6 +43,14 @@ type PeerManager interface { SendCancels(context.Context, []cid.Cid) } +// SessionManager manages all the sessions +type SessionManager interface { + // Remove a session (called when the session shuts down) + RemoveSession(sesid uint64) + // Cancel wants (called when a call to GetBlocks() is cancelled) + CancelSessionWants(sid uint64, wants []cid.Cid) +} + // SessionPeerManager keeps track of peers in the session type SessionPeerManager interface { // PeersDiscovered indicates if any peers have been discovered yet @@ -86,19 +94,15 @@ type op struct { keys []cid.Cid } -type OnShutdown func(uint64) - // Session holds state for an individual bitswap transfer operation. // This allows bitswap to make smarter decisions about who to send wantlist // info to, and who to request blocks from. type Session struct { // dependencies - bsctx context.Context // context for bitswap - ctx context.Context // context for session + ctx context.Context shutdown func() - onShutdown OnShutdown + sm SessionManager pm PeerManager - bpm *bsbpm.BlockPresenceManager sprm SessionPeerManager providerFinder ProviderFinder sim *bssim.SessionInterestManager @@ -130,9 +134,8 @@ type Session struct { // New creates a new bitswap session whose lifetime is bounded by the // given context. func New( - bsctx context.Context, // context for bitswap - ctx context.Context, // context for this session - onShutdown OnShutdown, + ctx context.Context, + sm SessionManager, id uint64, sprm SessionPeerManager, providerFinder ProviderFinder, @@ -148,12 +151,10 @@ func New( s := &Session{ sw: newSessionWants(broadcastLiveWantsLimit), tickDelayReqs: make(chan time.Duration), - bsctx: bsctx, ctx: ctx, shutdown: cancel, - onShutdown: onShutdown, + sm: sm, pm: pm, - bpm: bpm, sprm: sprm, providerFinder: providerFinder, sim: sim, @@ -167,7 +168,7 @@ func New( periodicSearchDelay: periodicSearchDelay, self: self, } - s.sws = newSessionWantSender(id, pm, sprm, bpm, s.onWantsSent, s.onPeersExhausted) + s.sws = newSessionWantSender(id, pm, sprm, sm, bpm, s.onWantsSent, s.onPeersExhausted) go s.run(ctx) @@ -308,6 +309,7 @@ func (s *Session) run(ctx context.Context) { case opCancel: // Wants were cancelled s.sw.CancelPending(oper.keys) + s.sws.Cancel(oper.keys) case opWantsSent: // Wants were sent to a peer s.sw.WantsSent(oper.keys) @@ -402,23 +404,9 @@ func (s *Session) handleShutdown() { // Shut down the sessionWantSender (blocks until sessionWantSender stops // sending) s.sws.Shutdown() - - // Remove session's interest in the given blocks. - cancelKs := s.sim.RemoveSessionInterest(s.id) - - // Free up block presence tracking for keys that no session is interested - // in anymore - s.bpm.RemoveKeys(cancelKs) - - // Send CANCEL to all peers for blocks that no session is interested in - // anymore. - // Note: use bitswap context because session context has already been - // cancelled. - s.pm.SendCancels(s.bsctx, cancelKs) - // Signal to the SessionManager that the session has been shutdown // and can be cleaned up - s.onShutdown(s.id) + s.sm.RemoveSession(s.id) } // handleReceive is called when the session receives blocks from a peer diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 58e11172..028ee46e 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -18,6 +18,40 @@ import ( peer "github.com/libp2p/go-libp2p-core/peer" ) +type mockSessionMgr struct { + lk sync.Mutex + removeSession bool + cancels []cid.Cid +} + +func newMockSessionMgr() *mockSessionMgr { + return &mockSessionMgr{} +} + +func (msm *mockSessionMgr) removeSessionCalled() bool { + msm.lk.Lock() + defer msm.lk.Unlock() + return msm.removeSession +} + +func (msm *mockSessionMgr) cancelled() []cid.Cid { + msm.lk.Lock() + defer msm.lk.Unlock() + return msm.cancels +} + +func (msm *mockSessionMgr) RemoveSession(sesid uint64) { + msm.lk.Lock() + defer msm.lk.Unlock() + msm.removeSession = true +} + +func (msm *mockSessionMgr) CancelSessionWants(sid uint64, wants []cid.Cid) { + msm.lk.Lock() + defer msm.lk.Unlock() + msm.cancels = append(msm.cancels, wants...) +} + func newFakeSessionPeerManager() *bsspm.SessionPeerManager { return bsspm.New(1, newFakePeerTagger()) } @@ -61,8 +95,6 @@ type wantReq struct { type fakePeerManager struct { wantReqs chan wantReq - lk sync.Mutex - cancels []cid.Cid } func newFakePeerManager() *fakePeerManager { @@ -82,35 +114,7 @@ func (pm *fakePeerManager) BroadcastWantHaves(ctx context.Context, cids []cid.Ci case <-ctx.Done(): } } -func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) { - pm.lk.Lock() - defer pm.lk.Unlock() - pm.cancels = append(pm.cancels, cancels...) -} -func (pm *fakePeerManager) allCancels() []cid.Cid { - pm.lk.Lock() - defer pm.lk.Unlock() - return append([]cid.Cid{}, pm.cancels...) -} - -type onShutdownMonitor struct { - lk sync.Mutex - shutdown bool -} - -func (sm *onShutdownMonitor) onShutdown(uint64) { - sm.lk.Lock() - defer sm.lk.Unlock() - - sm.shutdown = true -} - -func (sm *onShutdownMonitor) shutdownCalled() bool { - sm.lk.Lock() - defer sm.lk.Unlock() - - return sm.shutdown -} +func (pm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) {} func TestSessionGetBlocks(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -122,8 +126,8 @@ func TestSessionGetBlocks(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - onShutdown := func(uint64) {} - session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + sm := newMockSessionMgr() + session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2) var cids []cid.Cid @@ -201,9 +205,9 @@ func TestSessionGetBlocks(t *testing.T) { time.Sleep(10 * time.Millisecond) - // Verify wants were cancelled - if len(fpm.allCancels()) != len(blks) { - t.Fatal("expected cancels to be sent for all wants") + // Verify session was removed + if !sm.removeSessionCalled() { + t.Fatal("expected session to be removed") } } @@ -218,8 +222,8 @@ func TestSessionFindMorePeers(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - onShutdown := func(uint64) {} - session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + sm := newMockSessionMgr() + session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") session.SetBaseTickDelay(200 * time.Microsecond) blockGenerator := blocksutil.NewBlockGenerator() blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2) @@ -293,8 +297,8 @@ func TestSessionOnPeersExhausted(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - onShutdown := func(uint64) {} - session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + sm := newMockSessionMgr() + session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() blks := blockGenerator.Blocks(broadcastLiveWantsLimit + 5) var cids []cid.Cid @@ -338,8 +342,8 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - onShutdown := func(uint64) {} - session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "") + sm := newMockSessionMgr() + session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "") blockGenerator := blocksutil.NewBlockGenerator() blks := blockGenerator.Blocks(4) var cids []cid.Cid @@ -451,12 +455,11 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - - osm := &onShutdownMonitor{} + sm := newMockSessionMgr() // Create a new session with its own context sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") timerCtx, timerCancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer timerCancel() @@ -487,8 +490,8 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) { time.Sleep(10 * time.Millisecond) - // Expect onShutdown to be called - if !osm.shutdownCalled() { + // Expect RemoveSession to be called + if !sm.removeSessionCalled() { t.Fatal("expected onShutdown to be called") } } @@ -502,27 +505,26 @@ func TestSessionOnShutdownCalled(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - - osm := &onShutdownMonitor{} + sm := newMockSessionMgr() // Create a new session with its own context sessctx, sesscancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer sesscancel() - session := New(context.Background(), sessctx, osm.onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + session := New(sessctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") // Shutdown the session session.Shutdown() time.Sleep(10 * time.Millisecond) - // Expect onShutdown to be called - if !osm.shutdownCalled() { + // Expect RemoveSession to be called + if !sm.removeSessionCalled() { t.Fatal("expected onShutdown to be called") } } func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) { - ctx, cancelCtx := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancelCtx := context.WithTimeout(context.Background(), 20*time.Millisecond) fpm := newFakePeerManager() fspm := newFakeSessionPeerManager() fpf := newFakeProviderFinder() @@ -532,8 +534,8 @@ func TestSessionReceiveMessageAfterCtxCancel(t *testing.T) { notif := notifications.New() defer notif.Shutdown() id := testutil.GenerateSessionID() - onShutdown := func(uint64) {} - session := New(ctx, ctx, onShutdown, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") + sm := newMockSessionMgr() + session := New(ctx, sm, id, fspm, fpf, sim, fpm, bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() blks := blockGenerator.Blocks(2) cids := []cid.Cid{blks[0].Cid(), blks[1].Cid()} diff --git a/internal/session/sessionwantsender.go b/internal/session/sessionwantsender.go index 8ccba8f8..094d9096 100644 --- a/internal/session/sessionwantsender.go +++ b/internal/session/sessionwantsender.go @@ -30,6 +30,12 @@ const ( BPHave ) +// SessionWantsCanceller provides a method to cancel wants +type SessionWantsCanceller interface { + // Cancel wants for this session + CancelSessionWants(sid uint64, wants []cid.Cid) +} + // update encapsulates a message received by the session type update struct { // Which peer sent the update @@ -53,6 +59,8 @@ type peerAvailability struct { type change struct { // new wants requested add []cid.Cid + // wants cancelled + cancel []cid.Cid // new message received by session (blocks / HAVEs / DONT_HAVEs) update update // peer has connected / disconnected @@ -94,6 +102,8 @@ type sessionWantSender struct { pm PeerManager // Keeps track of peers in the session spm SessionPeerManager + // Cancels wants + canceller SessionWantsCanceller // Keeps track of which peer has / doesn't have a block bpm *bsbpm.BlockPresenceManager // Called when wants are sent @@ -102,7 +112,7 @@ type sessionWantSender struct { onPeersExhausted onPeersExhaustedFn } -func newSessionWantSender(sid uint64, pm PeerManager, spm SessionPeerManager, +func newSessionWantSender(sid uint64, pm PeerManager, spm SessionPeerManager, canceller SessionWantsCanceller, bpm *bsbpm.BlockPresenceManager, onSend onSendFn, onPeersExhausted onPeersExhaustedFn) sessionWantSender { ctx, cancel := context.WithCancel(context.Background()) @@ -119,6 +129,7 @@ func newSessionWantSender(sid uint64, pm PeerManager, spm SessionPeerManager, pm: pm, spm: spm, + canceller: canceller, bpm: bpm, onSend: onSend, onPeersExhausted: onPeersExhausted, @@ -139,6 +150,14 @@ func (sws *sessionWantSender) Add(ks []cid.Cid) { sws.addChange(change{add: ks}) } +// Cancel is called when a request is cancelled +func (sws *sessionWantSender) Cancel(ks []cid.Cid) { + if len(ks) == 0 { + return + } + sws.addChange(change{cancel: ks}) +} + // Update is called when the session receives a message with incoming blocks // or HAVE / DONT_HAVE func (sws *sessionWantSender) Update(from peer.ID, ks []cid.Cid, haves []cid.Cid, dontHaves []cid.Cid) { @@ -156,7 +175,9 @@ func (sws *sessionWantSender) Update(from peer.ID, ks []cid.Cid, haves []cid.Cid // connected / disconnected func (sws *sessionWantSender) SignalAvailability(p peer.ID, isAvailable bool) { availability := peerAvailability{p, isAvailable} - sws.addChange(change{availability: availability}) + // Add the change in a non-blocking manner to avoid the possibility of a + // deadlock + sws.addChangeNonBlocking(change{availability: availability}) } // Run is the main loop for processing incoming changes @@ -193,6 +214,22 @@ func (sws *sessionWantSender) addChange(c change) { } } +// addChangeNonBlocking adds a new change to the queue, using a go-routine +// if the change blocks, so as to avoid potential deadlocks +func (sws *sessionWantSender) addChangeNonBlocking(c change) { + select { + case sws.changes <- c: + default: + // changes channel is full, so add change in a go routine instead + go func() { + select { + case sws.changes <- c: + case <-sws.ctx.Done(): + } + }() + } +} + // collectChanges collects all the changes that have occurred since the last // invocation of onChange func (sws *sessionWantSender) collectChanges(changes []change) []change { @@ -215,6 +252,7 @@ func (sws *sessionWantSender) onChange(changes []change) { // Apply each change availability := make(map[peer.ID]bool, len(changes)) + cancels := make([]cid.Cid, 0) var updates []update for _, chng := range changes { // Initialize info for new wants @@ -222,6 +260,12 @@ func (sws *sessionWantSender) onChange(changes []change) { sws.trackWant(c) } + // Remove cancelled wants + for _, c := range chng.cancel { + sws.untrackWant(c) + cancels = append(cancels, c) + } + // Consolidate updates and changes to availability if chng.update.from != "" { // If the update includes blocks or haves, treat it as signaling that @@ -247,6 +291,11 @@ func (sws *sessionWantSender) onChange(changes []change) { // don't have the want sws.checkForExhaustedWants(dontHaves, newlyUnavailable) + // If there are any cancels, send them + if len(cancels) > 0 { + sws.canceller.CancelSessionWants(sws.sessionID, cancels) + } + // If there are some connected peers, send any pending wants if sws.spm.HasPeers() { sws.sendNextWants(newlyAvailable) @@ -306,6 +355,11 @@ func (sws *sessionWantSender) trackWant(c cid.Cid) { } } +// untrackWant removes an entry from the map of CID -> want info +func (sws *sessionWantSender) untrackWant(c cid.Cid) { + delete(sws.wants, c) +} + // processUpdates processes incoming blocks and HAVE / DONT_HAVEs. // It returns all DONT_HAVEs. func (sws *sessionWantSender) processUpdates(updates []update) []cid.Cid { diff --git a/internal/session/sessionwantsender_test.go b/internal/session/sessionwantsender_test.go index 3593009a..6c3059c1 100644 --- a/internal/session/sessionwantsender_test.go +++ b/internal/session/sessionwantsender_test.go @@ -136,10 +136,12 @@ func TestSendWants(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -174,10 +176,12 @@ func TestSendsWantBlockToOnePeerOnly(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -232,10 +236,12 @@ func TestReceiveBlock(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -284,6 +290,40 @@ func TestReceiveBlock(t *testing.T) { } } +func TestCancelWants(t *testing.T) { + cids := testutil.GenerateCids(4) + sid := uint64(1) + pm := newMockPeerManager() + fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() + bpm := bsbpm.New() + onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} + onPeersExhausted := func([]cid.Cid) {} + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() + + go spm.Run() + + // add cid0, cid1, cid2 + blkCids := cids[0:3] + spm.Add(blkCids) + + time.Sleep(5 * time.Millisecond) + + // cancel cid0, cid2 + cancelCids := []cid.Cid{cids[0], cids[2]} + spm.Cancel(cancelCids) + + // Wait for processing to complete + time.Sleep(5 * time.Millisecond) + + // Should have sent cancels for cid0, cid2 + sent := swc.cancelled() + if !testutil.MatchKeysIgnoreOrder(sent, cancelCids) { + t.Fatal("Wrong keys") + } +} + func TestPeerUnavailable(t *testing.T) { cids := testutil.GenerateCids(2) peers := testutil.GeneratePeers(2) @@ -292,10 +332,12 @@ func TestPeerUnavailable(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -357,11 +399,12 @@ func TestPeersExhausted(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} ep := exhaustedPeers{} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, ep.onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, ep.onPeersExhausted) go spm.Run() @@ -433,11 +476,12 @@ func TestPeersExhaustedLastWaitingPeerUnavailable(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} ep := exhaustedPeers{} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, ep.onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, ep.onPeersExhausted) go spm.Run() @@ -481,11 +525,12 @@ func TestPeersExhaustedAllPeersUnavailable(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} ep := exhaustedPeers{} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, ep.onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, ep.onPeersExhausted) go spm.Run() @@ -520,10 +565,12 @@ func TestConsecutiveDontHaveLimit(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -576,10 +623,12 @@ func TestConsecutiveDontHaveLimitInterrupted(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -631,10 +680,12 @@ func TestConsecutiveDontHaveReinstateAfterRemoval(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() @@ -715,10 +766,12 @@ func TestConsecutiveDontHaveDontRemoveIfHasWantedBlock(t *testing.T) { sid := uint64(1) pm := newMockPeerManager() fpm := newFakeSessionPeerManager() + swc := newMockSessionMgr() bpm := bsbpm.New() onSend := func(peer.ID, []cid.Cid, []cid.Cid) {} onPeersExhausted := func([]cid.Cid) {} - spm := newSessionWantSender(sid, pm, fpm, bpm, onSend, onPeersExhausted) + spm := newSessionWantSender(sid, pm, fpm, swc, bpm, onSend, onPeersExhausted) + defer spm.Shutdown() go spm.Run() diff --git a/internal/sessioninterestmanager/sessioninterestmanager.go b/internal/sessioninterestmanager/sessioninterestmanager.go index 6e345b55..0ab32ed1 100644 --- a/internal/sessioninterestmanager/sessioninterestmanager.go +++ b/internal/sessioninterestmanager/sessioninterestmanager.go @@ -3,7 +3,6 @@ package sessioninterestmanager import ( "sync" - bsswl "github.com/ipfs/go-bitswap/internal/sessionwantlist" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" @@ -11,16 +10,22 @@ import ( // SessionInterestManager records the CIDs that each session is interested in. type SessionInterestManager struct { - lk sync.RWMutex - interested *bsswl.SessionWantlist - wanted *bsswl.SessionWantlist + lk sync.RWMutex + wants map[cid.Cid]map[uint64]bool } // New initializes a new SessionInterestManager. func New() *SessionInterestManager { return &SessionInterestManager{ - interested: bsswl.NewSessionWantlist(), - wanted: bsswl.NewSessionWantlist(), + // Map of cids -> sessions -> bool + // + // The boolean indicates whether the session still wants the block + // or is just interested in receiving messages about it. + // + // Note that once the block is received the session no longer wants + // the block, but still wants to receive messages from peers who have + // the block as they may have other blocks the session is interested in. + wants: make(map[cid.Cid]map[uint64]bool), } } @@ -30,25 +35,85 @@ func (sim *SessionInterestManager) RecordSessionInterest(ses uint64, ks []cid.Ci sim.lk.Lock() defer sim.lk.Unlock() - sim.interested.Add(ks, ses) - sim.wanted.Add(ks, ses) + // For each key + for _, c := range ks { + // Record that the session wants the blocks + if want, ok := sim.wants[c]; ok { + want[ses] = true + } else { + sim.wants[c] = map[uint64]bool{ses: true} + } + } } // When the session shuts down it calls RemoveSessionInterest(). -func (sim *SessionInterestManager) RemoveSessionInterest(ses uint64) []cid.Cid { +// Returns the keys that no session is interested in any more. +func (sim *SessionInterestManager) RemoveSession(ses uint64) []cid.Cid { sim.lk.Lock() defer sim.lk.Unlock() - sim.wanted.RemoveSession(ses) - return sim.interested.RemoveSession(ses) + // The keys that no session is interested in + deletedKs := make([]cid.Cid, 0) + + // For each known key + for c := range sim.wants { + // Remove the session from the list of sessions that want the key + delete(sim.wants[c], ses) + + // If there are no more sessions that want the key + if len(sim.wants[c]) == 0 { + // Clean up the list memory + delete(sim.wants, c) + // Add the key to the list of keys that no session is interested in + deletedKs = append(deletedKs, c) + } + } + + return deletedKs } // When the session receives blocks, it calls RemoveSessionWants(). -func (sim *SessionInterestManager) RemoveSessionWants(ses uint64, wants []cid.Cid) { +func (sim *SessionInterestManager) RemoveSessionWants(ses uint64, ks []cid.Cid) { + sim.lk.Lock() + defer sim.lk.Unlock() + + // For each key + for _, c := range ks { + // If the session wanted the block + if wanted, ok := sim.wants[c][ses]; ok && wanted { + // Mark the block as unwanted + sim.wants[c][ses] = false + } + } +} + +// When a request is cancelled, the session calls RemoveSessionInterested(). +// Returns the keys that no session is interested in any more. +func (sim *SessionInterestManager) RemoveSessionInterested(ses uint64, ks []cid.Cid) []cid.Cid { sim.lk.Lock() defer sim.lk.Unlock() - sim.wanted.RemoveSessionKeys(ses, wants) + // The keys that no session is interested in + deletedKs := make([]cid.Cid, 0, len(ks)) + + // For each key + for _, c := range ks { + // If there is a list of sessions that want the key + if _, ok := sim.wants[c]; ok { + // Remove the session from the list of sessions that want the key + delete(sim.wants[c], ses) + + // If there are no more sessions that want the key + if len(sim.wants[c]) == 0 { + // Clean up the list memory + delete(sim.wants, c) + // Add the key to the list of keys that no session is interested in + deletedKs = append(deletedKs, c) + } + } + } + + return deletedKs } // The session calls FilterSessionInterested() to filter the sets of keys for @@ -57,9 +122,20 @@ func (sim *SessionInterestManager) FilterSessionInterested(ses uint64, ksets ... sim.lk.RLock() defer sim.lk.RUnlock() + // For each set of keys kres := make([][]cid.Cid, len(ksets)) for i, ks := range ksets { - kres[i] = sim.interested.SessionHas(ses, ks).Keys() + // The set of keys that at least one session is interested in + has := make([]cid.Cid, 0, len(ks)) + + // For each key in the list + for _, c := range ks { + // If there is a session that's interested, add the key to the set + if _, ok := sim.wants[c][ses]; ok { + has = append(has, c) + } + } + kres[i] = has } return kres } @@ -70,12 +146,19 @@ func (sim *SessionInterestManager) SplitWantedUnwanted(blks []blocks.Block) ([]b sim.lk.RLock() defer sim.lk.RUnlock() - // Get the wanted block keys - ks := make([]cid.Cid, len(blks)) + // Get the wanted block keys as a set + wantedKs := cid.NewSet() for _, b := range blks { - ks = append(ks, b.Cid()) + c := b.Cid() + // For each session that is interested in the key + for ses := range sim.wants[c] { + // If the session wants the key (rather than just being interested) + if wanted, ok := sim.wants[c][ses]; ok && wanted { + // Add the key to the set + wantedKs.Add(c) + } + } } - wantedKs := sim.wanted.Has(ks) // Separate the blocks into wanted and unwanted wantedBlks := make([]blocks.Block, 0, len(blks)) @@ -101,5 +184,18 @@ func (sim *SessionInterestManager) InterestedSessions(blks []cid.Cid, haves []ci ks = append(ks, haves...) ks = append(ks, dontHaves...) - return sim.interested.SessionsFor(ks) + // Create a set of sessions that are interested in the keys + sesSet := make(map[uint64]struct{}) + for _, c := range ks { + for s := range sim.wants[c] { + sesSet[s] = struct{}{} + } + } + + // Convert the set into a list + ses := make([]uint64, 0, len(sesSet)) + for s := range sesSet { + ses = append(ses, s) + } + return ses } diff --git a/internal/sessioninterestmanager/sessioninterestmanager_test.go b/internal/sessioninterestmanager/sessioninterestmanager_test.go index ead92023..0bba6638 100644 --- a/internal/sessioninterestmanager/sessioninterestmanager_test.go +++ b/internal/sessioninterestmanager/sessioninterestmanager_test.go @@ -83,7 +83,7 @@ func TestInterestedSessions(t *testing.T) { } } -func TestRemoveSessionInterest(t *testing.T) { +func TestRemoveSession(t *testing.T) { sim := New() ses1 := uint64(1) @@ -92,7 +92,7 @@ func TestRemoveSessionInterest(t *testing.T) { cids2 := append(testutil.GenerateCids(1), cids1[1]) sim.RecordSessionInterest(ses1, cids1) sim.RecordSessionInterest(ses2, cids2) - sim.RemoveSessionInterest(ses1) + sim.RemoveSession(ses1) res := sim.FilterSessionInterested(ses1, cids1) if len(res) != 1 || len(res[0]) != 0 { @@ -111,6 +111,42 @@ func TestRemoveSessionInterest(t *testing.T) { } } +func TestRemoveSessionInterested(t *testing.T) { + sim := New() + + ses1 := uint64(1) + ses2 := uint64(2) + cids1 := testutil.GenerateCids(2) + cids2 := append(testutil.GenerateCids(1), cids1[1]) + sim.RecordSessionInterest(ses1, cids1) + sim.RecordSessionInterest(ses2, cids2) + + res := sim.RemoveSessionInterested(ses1, []cid.Cid{cids1[0]}) + if len(res) != 1 { + t.Fatal("Expected no interested sessions left") + } + + interested := sim.FilterSessionInterested(ses1, cids1) + if len(interested) != 1 || len(interested[0]) != 1 { + t.Fatal("Expected ses1 still interested in one cid") + } + + res = sim.RemoveSessionInterested(ses1, cids1) + if len(res) != 0 { + t.Fatal("Expected ses2 to be interested in one cid") + } + + interested = sim.FilterSessionInterested(ses1, cids1) + if len(interested) != 1 || len(interested[0]) != 0 { + t.Fatal("Expected ses1 to have no remaining interest") + } + + interested = sim.FilterSessionInterested(ses2, cids1) + if len(interested) != 1 || len(interested[0]) != 1 { + t.Fatal("Expected ses2 to still be interested in one key") + } +} + func TestSplitWantedUnwanted(t *testing.T) { blks := testutil.GenerateBlocksOfSize(3, 1024) sim := New() diff --git a/internal/sessionmanager/sessionmanager.go b/internal/sessionmanager/sessionmanager.go index 0f79a7aa..42b20938 100644 --- a/internal/sessionmanager/sessionmanager.go +++ b/internal/sessionmanager/sessionmanager.go @@ -25,7 +25,18 @@ type Session interface { } // SessionFactory generates a new session for the SessionManager to track. -type SessionFactory func(ctx context.Context, onShutdown bssession.OnShutdown, id uint64, sprm bssession.SessionPeerManager, sim *bssim.SessionInterestManager, pm bssession.PeerManager, bpm *bsbpm.BlockPresenceManager, notif notifications.PubSub, provSearchDelay time.Duration, rebroadcastDelay delay.D, self peer.ID) Session +type SessionFactory func( + ctx context.Context, + sm bssession.SessionManager, + id uint64, + sprm bssession.SessionPeerManager, + sim *bssim.SessionInterestManager, + pm bssession.PeerManager, + bpm *bsbpm.BlockPresenceManager, + notif notifications.PubSub, + provSearchDelay time.Duration, + rebroadcastDelay delay.D, + self peer.ID) Session // PeerManagerFactory generates a new peer manager for a session. type PeerManagerFactory func(ctx context.Context, id uint64) bssession.SessionPeerManager @@ -77,10 +88,12 @@ func (sm *SessionManager) NewSession(ctx context.Context, id := sm.GetNextSessionID() pm := sm.peerManagerFactory(ctx, id) - session := sm.sessionFactory(ctx, sm.removeSession, id, pm, sm.sessionInterestManager, sm.peerManager, sm.blockPresenceManager, sm.notif, provSearchDelay, rebroadcastDelay, sm.self) + session := sm.sessionFactory(ctx, sm, id, pm, sm.sessionInterestManager, sm.peerManager, sm.blockPresenceManager, sm.notif, provSearchDelay, rebroadcastDelay, sm.self) sm.sessLk.Lock() - sm.sessions[id] = session + if sm.sessions != nil { // check if SessionManager was shutdown + sm.sessions[id] = session + } sm.sessLk.Unlock() return session @@ -88,18 +101,38 @@ func (sm *SessionManager) NewSession(ctx context.Context, func (sm *SessionManager) Shutdown() { sm.sessLk.Lock() - defer sm.sessLk.Unlock() + sessions := make([]Session, 0, len(sm.sessions)) for _, ses := range sm.sessions { + sessions = append(sessions, ses) + } + + // Ensure that if Shutdown() is called twice we only shut down + // the sessions once + sm.sessions = nil + + sm.sessLk.Unlock() + + for _, ses := range sessions { ses.Shutdown() } } -func (sm *SessionManager) removeSession(sesid uint64) { +func (sm *SessionManager) RemoveSession(sesid uint64) { + // Remove session from SessionInterestManager - returns the keys that no + // session is interested in anymore. + cancelKs := sm.sessionInterestManager.RemoveSession(sesid) + + // Cancel keys that no session is interested in anymore + sm.cancelWants(cancelKs) + sm.sessLk.Lock() defer sm.sessLk.Unlock() - delete(sm.sessions, sesid) + // Clean up session + if sm.sessions != nil { // check if SessionManager was shutdown + delete(sm.sessions, sesid) + } } // GetNextSessionID returns the next sequential identifier for a session. @@ -119,6 +152,10 @@ func (sm *SessionManager) ReceiveFrom(ctx context.Context, p peer.ID, blks []cid // Notify each session that is interested in the blocks / HAVEs / DONT_HAVEs for _, id := range sm.sessionInterestManager.InterestedSessions(blks, haves, dontHaves) { sm.sessLk.RLock() + if sm.sessions == nil { // check if SessionManager was shutdown + sm.sessLk.RUnlock() + return + } sess, ok := sm.sessions[id] sm.sessLk.RUnlock() @@ -130,3 +167,23 @@ func (sm *SessionManager) ReceiveFrom(ctx context.Context, p peer.ID, blks []cid // Send CANCEL to all peers with want-have / want-block sm.peerManager.SendCancels(ctx, blks) } + +// CancelSessionWants is called when a session cancels wants because a call to +// GetBlocks() is cancelled +func (sm *SessionManager) CancelSessionWants(sesid uint64, wants []cid.Cid) { + // Remove session's interest in the given blocks - returns the keys that no + // session is interested in anymore. + cancelKs := sm.sessionInterestManager.RemoveSessionInterested(sesid, wants) + sm.cancelWants(cancelKs) +} + +func (sm *SessionManager) cancelWants(wants []cid.Cid) { + // Free up block presence tracking for keys that no session is interested + // in anymore + sm.blockPresenceManager.RemoveKeys(wants) + + // Send CANCEL to all peers for blocks that no session is interested in + // anymore. + // Note: use bitswap context because session context may already be Done. + sm.peerManager.SendCancels(sm.ctx, wants) +} diff --git a/internal/sessionmanager/sessionmanager_test.go b/internal/sessionmanager/sessionmanager_test.go index 19dbc981..3be1f9b5 100644 --- a/internal/sessionmanager/sessionmanager_test.go +++ b/internal/sessionmanager/sessionmanager_test.go @@ -2,6 +2,7 @@ package sessionmanager import ( "context" + "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( bspm "github.com/ipfs/go-bitswap/internal/peermanager" bssession "github.com/ipfs/go-bitswap/internal/session" bssim "github.com/ipfs/go-bitswap/internal/sessioninterestmanager" + "github.com/ipfs/go-bitswap/internal/testutil" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" @@ -23,8 +25,8 @@ type fakeSession struct { wantBlocks []cid.Cid wantHaves []cid.Cid id uint64 - onShutdown bssession.OnShutdown pm *fakeSesPeerManager + sm bssession.SessionManager notif notifications.PubSub } @@ -43,7 +45,7 @@ func (fs *fakeSession) ReceiveFrom(p peer.ID, ks []cid.Cid, wantBlocks []cid.Cid fs.wantHaves = append(fs.wantHaves, wantHaves...) } func (fs *fakeSession) Shutdown() { - go fs.onShutdown(fs.id) + fs.sm.RemoveSession(fs.id) } type fakeSesPeerManager struct { @@ -57,6 +59,7 @@ func (*fakeSesPeerManager) RemovePeer(peer.ID) bool { return false } func (*fakeSesPeerManager) HasPeers() bool { return false } type fakePeerManager struct { + lk sync.Mutex cancels []cid.Cid } @@ -65,11 +68,18 @@ func (*fakePeerManager) UnregisterSession(uint64) func (*fakePeerManager) SendWants(context.Context, peer.ID, []cid.Cid, []cid.Cid) {} func (*fakePeerManager) BroadcastWantHaves(context.Context, []cid.Cid) {} func (fpm *fakePeerManager) SendCancels(ctx context.Context, cancels []cid.Cid) { + fpm.lk.Lock() + defer fpm.lk.Unlock() fpm.cancels = append(fpm.cancels, cancels...) } +func (fpm *fakePeerManager) cancelled() []cid.Cid { + fpm.lk.Lock() + defer fpm.lk.Unlock() + return fpm.cancels +} func sessionFactory(ctx context.Context, - onShutdown bssession.OnShutdown, + sm bssession.SessionManager, id uint64, sprm bssession.SessionPeerManager, sim *bssim.SessionInterestManager, @@ -80,14 +90,14 @@ func sessionFactory(ctx context.Context, rebroadcastDelay delay.D, self peer.ID) Session { fs := &fakeSession{ - id: id, - onShutdown: onShutdown, - pm: sprm.(*fakeSesPeerManager), - notif: notif, + id: id, + pm: sprm.(*fakeSesPeerManager), + sm: sm, + notif: notif, } go func() { <-ctx.Done() - fs.onShutdown(fs.id) + sm.RemoveSession(fs.id) }() return fs } @@ -138,7 +148,7 @@ func TestReceiveFrom(t *testing.T) { t.Fatal("should have received want-haves but didn't") } - if len(pm.cancels) != 1 { + if len(pm.cancelled()) != 1 { t.Fatal("should have sent cancel for received blocks") } } @@ -179,8 +189,7 @@ func TestReceiveBlocksWhenManagerShutdown(t *testing.T) { } func TestReceiveBlocksWhenSessionContextCancelled(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() notif := notifications.New() defer notif.Shutdown() @@ -213,3 +222,38 @@ func TestReceiveBlocksWhenSessionContextCancelled(t *testing.T) { t.Fatal("received blocks for sessions that are canceled") } } + +func TestShutdown(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + notif := notifications.New() + defer notif.Shutdown() + sim := bssim.New() + bpm := bsbpm.New() + pm := &fakePeerManager{} + sm := New(ctx, sessionFactory, sim, peerManagerFactory, bpm, pm, notif, "") + + p := peer.ID(123) + block := blocks.NewBlock([]byte("block")) + cids := []cid.Cid{block.Cid()} + firstSession := sm.NewSession(ctx, time.Second, delay.Fixed(time.Minute)).(*fakeSession) + sim.RecordSessionInterest(firstSession.ID(), cids) + sm.ReceiveFrom(ctx, p, []cid.Cid{}, []cid.Cid{}, cids) + + if !bpm.HasKey(block.Cid()) { + t.Fatal("expected cid to be added to block presence manager") + } + + sm.Shutdown() + + // wait for cleanup + time.Sleep(10 * time.Millisecond) + + if bpm.HasKey(block.Cid()) { + t.Fatal("expected cid to be removed from block presence manager") + } + if !testutil.MatchKeysIgnoreOrder(pm.cancelled(), cids) { + t.Fatal("expected cancels to be sent") + } +} diff --git a/internal/sessionwantlist/sessionwantlist.go b/internal/sessionwantlist/sessionwantlist.go deleted file mode 100644 index 05c14336..00000000 --- a/internal/sessionwantlist/sessionwantlist.go +++ /dev/null @@ -1,137 +0,0 @@ -package sessionwantlist - -import ( - "sync" - - cid "github.com/ipfs/go-cid" -) - -// The SessionWantList keeps track of which sessions want a CID -type SessionWantlist struct { - sync.RWMutex - wants map[cid.Cid]map[uint64]struct{} -} - -func NewSessionWantlist() *SessionWantlist { - return &SessionWantlist{ - wants: make(map[cid.Cid]map[uint64]struct{}), - } -} - -// The given session wants the keys -func (swl *SessionWantlist) Add(ks []cid.Cid, ses uint64) { - swl.Lock() - defer swl.Unlock() - - for _, c := range ks { - if _, ok := swl.wants[c]; !ok { - swl.wants[c] = make(map[uint64]struct{}) - } - swl.wants[c][ses] = struct{}{} - } -} - -// Remove the keys for all sessions. -// Called when blocks are received. -func (swl *SessionWantlist) RemoveKeys(ks []cid.Cid) { - swl.Lock() - defer swl.Unlock() - - for _, c := range ks { - delete(swl.wants, c) - } -} - -// Remove the session's wants, and return wants that are no longer wanted by -// any session. -func (swl *SessionWantlist) RemoveSession(ses uint64) []cid.Cid { - swl.Lock() - defer swl.Unlock() - - deletedKs := make([]cid.Cid, 0) - for c := range swl.wants { - delete(swl.wants[c], ses) - if len(swl.wants[c]) == 0 { - delete(swl.wants, c) - deletedKs = append(deletedKs, c) - } - } - - return deletedKs -} - -// Remove the session's wants -func (swl *SessionWantlist) RemoveSessionKeys(ses uint64, ks []cid.Cid) { - swl.Lock() - defer swl.Unlock() - - for _, c := range ks { - if _, ok := swl.wants[c]; ok { - delete(swl.wants[c], ses) - if len(swl.wants[c]) == 0 { - delete(swl.wants, c) - } - } - } -} - -// All keys wanted by all sessions -func (swl *SessionWantlist) Keys() []cid.Cid { - swl.RLock() - defer swl.RUnlock() - - ks := make([]cid.Cid, 0, len(swl.wants)) - for c := range swl.wants { - ks = append(ks, c) - } - return ks -} - -// All sessions that want the given keys -func (swl *SessionWantlist) SessionsFor(ks []cid.Cid) []uint64 { - swl.RLock() - defer swl.RUnlock() - - sesMap := make(map[uint64]struct{}) - for _, c := range ks { - for s := range swl.wants[c] { - sesMap[s] = struct{}{} - } - } - - ses := make([]uint64, 0, len(sesMap)) - for s := range sesMap { - ses = append(ses, s) - } - return ses -} - -// Filter for keys that at least one session wants -func (swl *SessionWantlist) Has(ks []cid.Cid) *cid.Set { - swl.RLock() - defer swl.RUnlock() - - has := cid.NewSet() - for _, c := range ks { - if _, ok := swl.wants[c]; ok { - has.Add(c) - } - } - return has -} - -// Filter for keys that the given session wants -func (swl *SessionWantlist) SessionHas(ses uint64, ks []cid.Cid) *cid.Set { - swl.RLock() - defer swl.RUnlock() - - has := cid.NewSet() - for _, c := range ks { - if sesMap, cok := swl.wants[c]; cok { - if _, sok := sesMap[ses]; sok { - has.Add(c) - } - } - } - return has -} diff --git a/internal/sessionwantlist/sessionwantlist_test.go b/internal/sessionwantlist/sessionwantlist_test.go deleted file mode 100644 index d57f9395..00000000 --- a/internal/sessionwantlist/sessionwantlist_test.go +++ /dev/null @@ -1,258 +0,0 @@ -package sessionwantlist - -import ( - "os" - "testing" - - "github.com/ipfs/go-bitswap/internal/testutil" - - cid "github.com/ipfs/go-cid" -) - -var c0 cid.Cid -var c1 cid.Cid -var c2 cid.Cid - -const s0 = uint64(0) -const s1 = uint64(1) - -func setup() { - cids := testutil.GenerateCids(3) - c0 = cids[0] - c1 = cids[1] - c2 = cids[2] -} - -func TestMain(m *testing.M) { - setup() - os.Exit(m.Run()) -} - -func TestEmpty(t *testing.T) { - swl := NewSessionWantlist() - - if len(swl.Keys()) != 0 { - t.Fatal("Expected Keys() to be empty") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 0 { - t.Fatal("Expected SessionsFor() to be empty") - } -} - -func TestSimpleAdd(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0 - swl.Add([]cid.Cid{c0}, s0) - if len(swl.Keys()) != 1 { - t.Fatal("Expected Keys() to have length 1") - } - if !swl.Keys()[0].Equals(c0) { - t.Fatal("Expected Keys() to be [cid0]") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 1 { - t.Fatal("Expected SessionsFor() to have length 1") - } - if swl.SessionsFor([]cid.Cid{c0})[0] != s0 { - t.Fatal("Expected SessionsFor() to be [s0]") - } - - // s0: c0, c1 - swl.Add([]cid.Cid{c1}, s0) - if len(swl.Keys()) != 2 { - t.Fatal("Expected Keys() to have length 2") - } - if !testutil.MatchKeysIgnoreOrder(swl.Keys(), []cid.Cid{c0, c1}) { - t.Fatal("Expected Keys() to contain [cid0, cid1]") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 1 { - t.Fatal("Expected SessionsFor() to have length 1") - } - if swl.SessionsFor([]cid.Cid{c0})[0] != s0 { - t.Fatal("Expected SessionsFor() to be [s0]") - } - - // s0: c0, c1 - // s1: c0 - swl.Add([]cid.Cid{c0}, s1) - if len(swl.Keys()) != 2 { - t.Fatal("Expected Keys() to have length 2") - } - if !testutil.MatchKeysIgnoreOrder(swl.Keys(), []cid.Cid{c0, c1}) { - t.Fatal("Expected Keys() to contain [cid0, cid1]") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 2 { - t.Fatal("Expected SessionsFor() to have length 2") - } -} - -func TestMultiKeyAdd(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0, c1 - swl.Add([]cid.Cid{c0, c1}, s0) - if len(swl.Keys()) != 2 { - t.Fatal("Expected Keys() to have length 2") - } - if !testutil.MatchKeysIgnoreOrder(swl.Keys(), []cid.Cid{c0, c1}) { - t.Fatal("Expected Keys() to contain [cid0, cid1]") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 1 { - t.Fatal("Expected SessionsFor() to have length 1") - } - if swl.SessionsFor([]cid.Cid{c0})[0] != s0 { - t.Fatal("Expected SessionsFor() to be [s0]") - } -} - -func TestSessionHas(t *testing.T) { - swl := NewSessionWantlist() - - if swl.Has([]cid.Cid{c0, c1}).Len() > 0 { - t.Fatal("Expected Has([c0, c1]) to be []") - } - if swl.SessionHas(s0, []cid.Cid{c0, c1}).Len() > 0 { - t.Fatal("Expected SessionHas(s0, [c0, c1]) to be []") - } - - // s0: c0 - swl.Add([]cid.Cid{c0}, s0) - if !matchSet(swl.Has([]cid.Cid{c0, c1}), []cid.Cid{c0}) { - t.Fatal("Expected Has([c0, c1]) to be [c0]") - } - if !matchSet(swl.SessionHas(s0, []cid.Cid{c0, c1}), []cid.Cid{c0}) { - t.Fatal("Expected SessionHas(s0, [c0, c1]) to be [c0]") - } - if swl.SessionHas(s1, []cid.Cid{c0, c1}).Len() > 0 { - t.Fatal("Expected SessionHas(s1, [c0, c1]) to be []") - } - - // s0: c0, c1 - swl.Add([]cid.Cid{c1}, s0) - if !matchSet(swl.Has([]cid.Cid{c0, c1}), []cid.Cid{c0, c1}) { - t.Fatal("Expected Has([c0, c1]) to be [c0, c1]") - } - if !matchSet(swl.SessionHas(s0, []cid.Cid{c0, c1}), []cid.Cid{c0, c1}) { - t.Fatal("Expected SessionHas(s0, [c0, c1]) to be [c0, c1]") - } - - // s0: c0, c1 - // s1: c0 - swl.Add([]cid.Cid{c0}, s1) - if len(swl.Keys()) != 2 { - t.Fatal("Expected Keys() to have length 2") - } - if !matchSet(swl.Has([]cid.Cid{c0, c1}), []cid.Cid{c0, c1}) { - t.Fatal("Expected Has([c0, c1]) to be [c0, c1]") - } - if !matchSet(swl.SessionHas(s0, []cid.Cid{c0, c1}), []cid.Cid{c0, c1}) { - t.Fatal("Expected SessionHas(s0, [c0, c1]) to be [c0, c1]") - } - if !matchSet(swl.SessionHas(s1, []cid.Cid{c0, c1}), []cid.Cid{c0}) { - t.Fatal("Expected SessionHas(s1, [c0, c1]) to be [c0]") - } -} - -func TestSimpleRemoveKeys(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0, c1 - // s1: c0 - swl.Add([]cid.Cid{c0, c1}, s0) - swl.Add([]cid.Cid{c0}, s1) - - // s0: c1 - swl.RemoveKeys([]cid.Cid{c0}) - if len(swl.Keys()) != 1 { - t.Fatal("Expected Keys() to have length 1") - } - if !swl.Keys()[0].Equals(c1) { - t.Fatal("Expected Keys() to be [cid1]") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 0 { - t.Fatal("Expected SessionsFor(c0) to be empty") - } - if len(swl.SessionsFor([]cid.Cid{c1})) != 1 { - t.Fatal("Expected SessionsFor(c1) to have length 1") - } - if swl.SessionsFor([]cid.Cid{c1})[0] != s0 { - t.Fatal("Expected SessionsFor(c1) to be [s0]") - } -} - -func TestMultiRemoveKeys(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0, c1 - // s1: c0 - swl.Add([]cid.Cid{c0, c1}, s0) - swl.Add([]cid.Cid{c0}, s1) - - // - swl.RemoveKeys([]cid.Cid{c0, c1}) - if len(swl.Keys()) != 0 { - t.Fatal("Expected Keys() to be empty") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 0 { - t.Fatal("Expected SessionsFor() to be empty") - } -} - -func TestRemoveSession(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0, c1 - // s1: c0 - swl.Add([]cid.Cid{c0, c1}, s0) - swl.Add([]cid.Cid{c0}, s1) - - // s1: c0 - swl.RemoveSession(s0) - if len(swl.Keys()) != 1 { - t.Fatal("Expected Keys() to have length 1") - } - if !swl.Keys()[0].Equals(c0) { - t.Fatal("Expected Keys() to be [cid0]") - } - if len(swl.SessionsFor([]cid.Cid{c1})) != 0 { - t.Fatal("Expected SessionsFor(c1) to be empty") - } - if len(swl.SessionsFor([]cid.Cid{c0})) != 1 { - t.Fatal("Expected SessionsFor(c0) to have length 1") - } - if swl.SessionsFor([]cid.Cid{c0})[0] != s1 { - t.Fatal("Expected SessionsFor(c0) to be [s1]") - } -} - -func TestRemoveSessionKeys(t *testing.T) { - swl := NewSessionWantlist() - - // s0: c0, c1, c2 - // s1: c0 - swl.Add([]cid.Cid{c0, c1, c2}, s0) - swl.Add([]cid.Cid{c0}, s1) - - // s0: c2 - // s1: c0 - swl.RemoveSessionKeys(s0, []cid.Cid{c0, c1}) - if !matchSet(swl.SessionHas(s0, []cid.Cid{c0, c1, c2}), []cid.Cid{c2}) { - t.Fatal("Expected SessionHas(s0, [c0, c1, c2]) to be [c2]") - } - if !matchSet(swl.SessionHas(s1, []cid.Cid{c0, c1, c2}), []cid.Cid{c0}) { - t.Fatal("Expected SessionHas(s1, [c0, c1, c2]) to be [c0]") - } -} - -func matchSet(ks1 *cid.Set, ks2 []cid.Cid) bool { - if ks1.Len() != len(ks2) { - return false - } - - for _, k := range ks2 { - if !ks1.Has(k) { - return false - } - } - return true -}