Skip to content

Commit

Permalink
Merge pull request #838 from nats-io/raft_shutdown
Browse files Browse the repository at this point in the history
Changes to how raft node is shutdown
  • Loading branch information
kozlovic authored May 18, 2019
2 parents 7b9c7fe + 931300d commit 51d6217
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
13 changes: 4 additions & 9 deletions server/clustering.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,18 @@ func (r *raftNode) shutdown() error {
}
r.closed = true
r.Unlock()
if r.Raft != nil {
if err := r.Raft.Shutdown().Error(); err != nil {
return err
}
}
if r.transport != nil {
if err := r.transport.Close(); err != nil {
return err
}
}
if r.store != nil {
if err := r.store.Close(); err != nil {
if r.Raft != nil {
if err := r.Raft.Shutdown().Error(); err != nil {
return err
}
}
if r.joinSub != nil {
if err := r.joinSub.Unsubscribe(); err != nil {
if r.store != nil {
if err := r.store.Close(); err != nil {
return err
}
}
Expand Down
8 changes: 7 additions & 1 deletion server/raft_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {

func (n *natsStreamLayer) Close() error {
n.mu.Lock()
nc := n.conn
// Do not set nc.conn to nil since it is accessed in some functions
// without the stream layer lock
conns := make(map[*natsConn]struct{}, len(n.conns))
for conn, s := range n.conns {
conns[conn] = s
Expand All @@ -338,7 +341,10 @@ func (n *natsStreamLayer) Close() error {
for c := range conns {
c.Close()
}
return n.sub.Unsubscribe()
if nc != nil {
nc.Close()
}
return nil
}

func (n *natsStreamLayer) Addr() net.Addr {
Expand Down
24 changes: 18 additions & 6 deletions server/raft_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ func TestRAFTTransportHeartbeatFastPath(t *testing.T) {
trans1.SetHeartbeatHandler(fastpath)

// Transport 2 makes outbound request
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -204,7 +206,9 @@ func TestRAFTTransportAppendEntries(t *testing.T) {
}()

// Transport 2 makes outbound request
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -285,7 +289,9 @@ func TestRAFTTransportAppendEntriesPipeline(t *testing.T) {
}()

// Transport 2 makes outbound request
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -372,7 +378,9 @@ func TestRAFTTransportRequestVote(t *testing.T) {
}()

// Transport 2 makes outbound request
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -455,7 +463,9 @@ func TestRAFTTransportInstallSnapshot(t *testing.T) {
}()

// Transport 2 makes outbound request
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -583,7 +593,9 @@ func TestRAFTTransportPooledConn(t *testing.T) {
}()

// Transport 2 makes outbound request, 3 conn pool
trans2, err := newNATSTransportWithLogger("b", nc, time.Second, newTestLogger(t))
nc2 := newNatsConnection(t)
defer nc2.Close()
trans2, err := newNATSTransportWithLogger("b", nc2, time.Second, newTestLogger(t))
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down

0 comments on commit 51d6217

Please sign in to comment.