Skip to content

Commit

Permalink
[FIXED] Call ConnectedCB with RetryOnFailedConnect when initial conn …
Browse files Browse the repository at this point in the history
…failed (#1619)

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Aug 15, 2024
1 parent 0110705 commit 10bd240
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 47 deletions.
14 changes: 9 additions & 5 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -2875,15 +2875,19 @@ func (nc *Conn) doReconnect(err error) {
// This is where we are truly connected.
nc.status = CONNECTED

// Queue up the correct callback. If we are in initial connect state
// (using retry on failed connect), we will call the ConnectedCB,
// otherwise the ReconnectedCB.
if nc.Opts.ReconnectedCB != nil && !nc.initc {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
} else if nc.Opts.ConnectedCB != nil && nc.initc {
nc.ach.push(func() { nc.Opts.ConnectedCB(nc) })
}

// If we are here with a retry on failed connect, indicate that the
// initial connect is now complete.
nc.initc = false

// Queue up the reconnect callback.
if nc.Opts.ReconnectedCB != nil {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
}

// Release lock here, we will return below.
nc.mu.Unlock()

Expand Down
169 changes: 127 additions & 42 deletions test/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1094,16 +1094,21 @@ func TestCallbacksOrder(t *testing.T) {
}

func TestConnectHandler(t *testing.T) {
handler := func(ch chan bool) func(*nats.Conn) {
return func(*nats.Conn) {
ch <- true
}
}
t.Run("with RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true))

if err != nil {
Expand All @@ -1113,59 +1118,135 @@ func TestConnectHandler(t *testing.T) {
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("no RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)
nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("no RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)
_, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)))

if err == nil {
t.Fatalf("Expected error on connect, got nil")
}
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, initial connection failed, reconnect successful", func(t *testing.T) {
connected := make(chan bool)
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true),
nats.ReconnectWait(100*time.Millisecond))

if err != nil {
t.Fatalf("Expected error on connect, got nil")
}

defer nc.Close()

s := RunDefaultServer()
defer s.Shutdown()

if err != nil {
t.Fatalf("Expected error on connect, got nil")
}
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for reconnect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, initial connection successful, server restart", func(t *testing.T) {
connected := make(chan bool)
reconnected := make(chan bool)

s := RunDefaultServer()
defer s.Shutdown()

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true),
nats.ReconnectWait(100*time.Millisecond))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}

s.Shutdown()

s = RunDefaultServer()
defer s.Shutdown()

if err = Wait(reconnected); err != nil {
t.Fatal("Timeout waiting for reconnect handler")
}
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
})
}
Expand Down Expand Up @@ -2709,7 +2790,8 @@ func TestRetryOnFailedConnect(t *testing.T) {
nc.Close()
t.Fatal("Expected error, did not get one")
}
ch := make(chan bool, 1)
reconnectedCh := make(chan bool, 1)
connectedCh := make(chan bool, 1)
dch := make(chan bool, 1)
nc, err = nats.Connect(nats.DefaultURL,
nats.RetryOnFailedConnect(true),
Expand All @@ -2718,8 +2800,11 @@ func TestRetryOnFailedConnect(t *testing.T) {
nats.DisconnectErrHandler(func(_ *nats.Conn, _ error) {
dch <- true
}),
nats.ConnectHandler(func(_ *nats.Conn) {
connectedCh <- true
}),
nats.ReconnectHandler(func(_ *nats.Conn) {
ch <- true
reconnectedCh <- true
}),
nats.NoCallbacksAfterClientClose())
if err != nil {
Expand All @@ -2737,19 +2822,19 @@ func TestRetryOnFailedConnect(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

var action string
switch i {
case 0:
action = "connected"
select {
case <-connectedCh:
case <-time.After(2 * time.Second):
t.Fatal("Should have connected")
}
case 1:
action = "reconnected"
}

// Wait for the reconnect CB which in this context means that we connected ok
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatalf("Should have %s", action)
select {
case <-reconnectedCh:
case <-time.After(2 * time.Second):
t.Fatal("Should have reconnected")
}
}

// Now make sure that the pub worked and sub worked.
Expand Down Expand Up @@ -2782,7 +2867,7 @@ func TestRetryOnFailedConnect(t *testing.T) {
nats.MaxReconnects(-1),
nats.ReconnectWait(15*time.Millisecond),
nats.ReconnectHandler(func(_ *nats.Conn) {
ch <- true
reconnectedCh <- true
}),
nats.ClosedHandler(func(_ *nats.Conn) {
closedCh <- true
Expand All @@ -2807,7 +2892,7 @@ func TestRetryOnFailedConnect(t *testing.T) {
}
// Make sure that we did not get the (re)connected CB
select {
case <-ch:
case <-reconnectedCh:
t.Fatal("(re)connected callback should not have been invoked")
default:
}
Expand All @@ -2830,14 +2915,14 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
s := RunServerWithOptions(&opts)
defer s.Shutdown()

ch := make(chan bool, 1)
connectedCh := make(chan bool, 1)
nc, err := nats.Connect(nats.DefaultURL,
nats.Secure(&tls.Config{InsecureSkipVerify: true}),
nats.RetryOnFailedConnect(true),
nats.MaxReconnects(-1),
nats.ReconnectWait(15*time.Millisecond),
nats.ReconnectHandler(func(_ *nats.Conn) {
ch <- true
nats.ConnectHandler(func(_ *nats.Conn) {
connectedCh <- true
}),
nats.NoCallbacksAfterClientClose())
if err != nil {
Expand All @@ -2854,7 +2939,7 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
defer s.Shutdown()

select {
case <-ch:
case <-connectedCh:
case <-time.After(time.Second):
t.Fatal("Should have connected")
}
Expand Down

0 comments on commit 10bd240

Please sign in to comment.