diff --git a/connmgr.go b/connmgr.go
index 0ee8901..804730d 100644
--- a/connmgr.go
+++ b/connmgr.go
@@ -212,7 +212,7 @@ func (cm *BasicConnMgr) getConnsToClose(ctx context.Context) []network.Conn {
         // disabled
         return nil
     }
-    now := time.Now()
+
     nconns := int(atomic.LoadInt32(&cm.connCount))
     if nconns <= cm.lowWater {
         log.Info("open connection count below limit")
@@ -221,6 +221,9 @@ func (cm *BasicConnMgr) getConnsToClose(ctx context.Context) []network.Conn {
 
     npeers := cm.segments.countPeers()
     candidates := make([]*peerInfo, 0, npeers)
+    ncandidates := 0
+    gracePeriodStart := time.Now().Add(-cm.gracePeriod)
+
     cm.plk.RLock()
     for _, s := range cm.segments {
         s.Lock()
@@ -229,12 +232,26 @@ func (cm *BasicConnMgr) getConnsToClose(ctx context.Context) []network.Conn {
                 // skip over protected peer.
                 continue
             }
+            if inf.firstSeen.After(gracePeriodStart) {
+                // skip peers in the grace period.
+                continue
+            }
             candidates = append(candidates, inf)
+            ncandidates += len(inf.conns)
         }
         s.Unlock()
     }
     cm.plk.RUnlock()
 
+    if ncandidates < cm.lowWater {
+        log.Info("open connection count above limit but too many are in the grace period")
+        // We have too many connections but fewer than lowWater
+        // connections out of the grace period.
+        //
+        // If we trimmed now, we'd kill potentially useful connections.
+        return nil
+    }
+
     // Sort peers according to their value.
     sort.Slice(candidates, func(i, j int) bool {
         left, right := candidates[i], candidates[j]
@@ -246,7 +263,7 @@ func (cm *BasicConnMgr) getConnsToClose(ctx context.Context) []network.Conn {
         return left.value < right.value
     })
 
-    target := nconns - cm.lowWater
+    target := ncandidates - cm.lowWater
 
     // slightly overallocate because we may have more than one conns per peer
     selected := make([]network.Conn, 0, target+10)
@@ -255,10 +272,6 @@ func (cm *BasicConnMgr) getConnsToClose(ctx context.Context) []network.Conn {
         if target <= 0 {
             break
         }
-        // TODO: should we be using firstSeen or the time associated with the connection itself?
-        if inf.firstSeen.Add(cm.gracePeriod).After(now) {
-            continue
-        }
 
         // lock this to protect from concurrent modifications from connect/disconnect events
         s := cm.segments.get(inf.id)
diff --git a/connmgr_test.go b/connmgr_test.go
index edaff7b..ad78786 100644
--- a/connmgr_test.go
+++ b/connmgr_test.go
@@ -303,6 +303,63 @@ func TestDisconnected(t *testing.T) {
     }
 }
 
+func TestGracePeriod(t *testing.T) {
+    if detectrace.WithRace() {
+        t.Skip("race detector is unhappy with this test")
+    }
+
+    SilencePeriod = 0
+    cm := NewConnManager(10, 20, 100*time.Millisecond)
+    SilencePeriod = 10 * time.Second
+
+    not := cm.Notifee()
+
+    var conns []network.Conn
+
+    // Add a connection and wait the grace period.
+    {
+        rc := randConn(t, not.Disconnected)
+        conns = append(conns, rc)
+        not.Connected(nil, rc)
+
+        time.Sleep(200 * time.Millisecond)
+
+        if rc.(*tconn).closed {
+            t.Fatal("expected conn to remain open")
+        }
+    }
+
+    // quickly add 30 connections (sending us above the high watermark)
+    for i := 0; i < 30; i++ {
+        rc := randConn(t, not.Disconnected)
+        conns = append(conns, rc)
+        not.Connected(nil, rc)
+    }
+
+    cm.TrimOpenConns(context.Background())
+
+    for _, c := range conns {
+        if c.(*tconn).closed {
+            t.Fatal("expected no conns to be closed")
+        }
+    }
+
+    time.Sleep(200 * time.Millisecond)
+
+    cm.TrimOpenConns(context.Background())
+
+    closed := 0
+    for _, c := range conns {
+        if c.(*tconn).closed {
+            closed++
+        }
+    }
+
+    if closed != 21 {
+        t.Fatal("expected to have closed 21 connections")
+    }
+}
+
 // see https://github.com/libp2p/go-libp2p-connmgr/issues/23
 func TestQuickBurstRespectsSilencePeriod(t *testing.T) {
     if detectrace.WithRace() {
@@ -350,13 +407,17 @@ func TestPeerProtectionSingleTag(t *testing.T) {
 
     not := cm.Notifee()
 
-    // produce 20 connections with unique peers.
     var conns []network.Conn
-    for i := 0; i < 20; i++ {
+    addConn := func(value int) {
         rc := randConn(t, not.Disconnected)
         conns = append(conns, rc)
         not.Connected(nil, rc)
-        cm.TagPeer(rc.RemotePeer(), "test", 20)
+        cm.TagPeer(rc.RemotePeer(), "test", value)
+    }
+
+    // produce 20 connections with unique peers.
+    for i := 0; i < 20; i++ {
+        addConn(20)
     }
 
     // protect the first 5 peers.
@@ -368,8 +429,21 @@
         cm.TagPeer(c.RemotePeer(), "test", -100)
     }
 
-    // add one more connection, sending the connection manager overboard.
-    not.Connected(nil, randConn(t, not.Disconnected))
+    // add 1 more conn; this shouldn't send us over the limit as protected conns don't count
+    addConn(20)
+
+    cm.TrimOpenConns(context.Background())
+
+    for _, c := range conns {
+        if c.(*tconn).closed {
+            t.Error("connection was closed by connection manager")
+        }
+    }
+
+    // add 5 more connections, sending the connection manager overboard.
+    for i := 0; i < 5; i++ {
+        addConn(20)
+    }
 
     cm.TrimOpenConns(context.Background())
 
@@ -379,15 +453,22 @@
         }
     }
 
+    closed := 0
+    for _, c := range conns {
+        if c.(*tconn).closed {
+            closed++
+        }
+    }
+    if closed != 2 {
+        t.Errorf("expected 2 connections to be closed, found %d", closed)
+    }
+
     // unprotect the first peer.
     cm.Unprotect(protected[0].RemotePeer(), "global")
 
     // add 2 more connections, sending the connection manager overboard again.
     for i := 0; i < 2; i++ {
-        rc := randConn(t, not.Disconnected)
-        conns = append(conns, rc)
-        not.Connected(nil, rc)
-        cm.TagPeer(rc.RemotePeer(), "test", 20)
+        addConn(20)
     }
 
     cm.TrimOpenConns(context.Background())
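For review purposes, a minimal standalone sketch of the trimming policy this patch moves to: peers still inside the grace period are excluded from the candidate list up front, trimming becomes a no-op while the out-of-grace candidate connection count is below the low watermark, and the trim target is computed from that candidate count rather than from the total connection count. The `peerInfo` struct and the `lowWater`/`gracePeriod` parameters below are simplified stand-ins for the manager's internal state (protected peers, locking, and the high-watermark trigger are omitted); this is not the library's actual implementation.

```go
package main

import (
	"fmt"
	"sort"
	"time"
)

// peerInfo is a reduced stand-in for the connection manager's per-peer record;
// only the fields relevant to trimming are kept.
type peerInfo struct {
	id        string    // peer identifier
	value     int       // tag-derived value; lower values are trimmed first
	firstSeen time.Time // when we first connected to this peer
	conns     int       // number of open connections to this peer
}

// connsToClose mirrors the patched policy: grace-period peers are never
// candidates, trimming does nothing while fewer than lowWater candidate
// connections exist, and the target is derived from the candidate count.
// It returns the ids of the peers whose connections would be closed.
func connsToClose(peers []peerInfo, lowWater int, gracePeriod time.Duration) []string {
	gracePeriodStart := time.Now().Add(-gracePeriod)

	var candidates []peerInfo
	ncandidates := 0
	for _, p := range peers {
		if p.firstSeen.After(gracePeriodStart) {
			continue // still in the grace period; never trim these
		}
		candidates = append(candidates, p)
		ncandidates += p.conns
	}

	if ncandidates < lowWater {
		// Too many of our connections are new; trimming now would kill
		// potentially useful ones, so do nothing.
		return nil
	}

	// Close the lowest-valued peers first.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].value < candidates[j].value
	})

	target := ncandidates - lowWater
	var selected []string
	for _, p := range candidates {
		if target <= 0 {
			break
		}
		selected = append(selected, p.id)
		target -= p.conns
	}
	return selected
}

func main() {
	now := time.Now()
	peers := []peerInfo{
		{id: "old-low-value", value: 1, firstSeen: now.Add(-time.Hour), conns: 2},
		{id: "old-high-value", value: 50, firstSeen: now.Add(-time.Hour), conns: 1},
		{id: "brand-new", value: 0, firstSeen: now, conns: 1},
	}
	// With lowWater=1 and a 30s grace period, only the two old peers are
	// candidates and the lowest-valued one is selected for closing.
	fmt.Println(connsToClose(peers, 1, 30*time.Second))
}
```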