diff --git a/server/clustering.go b/server/clustering.go index 6373f8f8..9ae5c449 100644 --- a/server/clustering.go +++ b/server/clustering.go @@ -101,23 +101,18 @@ func (r *raftNode) shutdown() error { } r.closed = true r.Unlock() - if r.Raft != nil { - if err := r.Raft.Shutdown().Error(); err != nil { - return err - } - } if r.transport != nil { if err := r.transport.Close(); err != nil { return err } } - if r.store != nil { - if err := r.store.Close(); err != nil { + if r.Raft != nil { + if err := r.Raft.Shutdown().Error(); err != nil { return err } } - if r.joinSub != nil { - if err := r.joinSub.Unsubscribe(); err != nil { + if r.store != nil { + if err := r.store.Close(); err != nil { return err } } diff --git a/server/raft_transport.go b/server/raft_transport.go index 31942fe2..64dda97e 100644 --- a/server/raft_transport.go +++ b/server/raft_transport.go @@ -330,6 +330,9 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) { func (n *natsStreamLayer) Close() error { n.mu.Lock() + nc := n.conn + // Do not set nc.conn to nil since it is accessed in some functions + // without the stream layer lock conns := make(map[*natsConn]struct{}, len(n.conns)) for conn, s := range n.conns { conns[conn] = s @@ -338,7 +341,10 @@ func (n *natsStreamLayer) Close() error { for c := range conns { c.Close() } - return n.sub.Unsubscribe() + if nc != nil { + nc.Close() + } + return nil } func (n *natsStreamLayer) Addr() net.Addr { diff --git a/server/raft_transport_test.go b/server/raft_transport_test.go index 640aca8f..350940ea 100644 --- a/server/raft_transport_test.go +++ b/server/raft_transport_test.go @@ -124,7 +124,9 @@ func TestRAFTTransportHeartbeatFastPath(t *testing.T) { trans1.SetHeartbeatHandler(fastpath) // Transport 2 makes outbound request - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -204,7 +206,9 @@ func TestRAFTTransportAppendEntries(t *testing.T) { }() // Transport 2 makes outbound request - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -285,7 +289,9 @@ func TestRAFTTransportAppendEntriesPipeline(t *testing.T) { }() // Transport 2 makes outbound request - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -372,7 +378,9 @@ func TestRAFTTransportRequestVote(t *testing.T) { }() // Transport 2 makes outbound request - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -455,7 +463,9 @@ func TestRAFTTransportInstallSnapshot(t *testing.T) { }() // Transport 2 makes outbound request - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) } @@ -583,7 +593,9 @@ func TestRAFTTransportPooledConn(t *testing.T) { }() // Transport 2 makes outbound request, 3 conn pool - trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t)) + nc2 := newNatsConnection(t) + defer nc2.Close() + trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t)) if err != nil { t.Fatalf("err: %v", err) }