Skip to content

Commit

Permalink
add close flag into wsConnection to avoid duplicate calls of CloseFunc (
Browse files Browse the repository at this point in the history
#2803)

* add close flag into wsConnection to avoid duplicate calls of CloseFunc

* add test

* Fix linter error
  • Loading branch information
vlad-tokarev authored Sep 22, 2023
1 parent af4d394 commit 3e2393f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
6 changes: 6 additions & 0 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type (
keepAliveTicker *time.Ticker
pingPongTicker *time.Ticker
exec graphql.GraphExecutor
closed bool

initPayload InitPayload
}
Expand Down Expand Up @@ -441,10 +442,15 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {

func (c *wsConnection) close(closeCode int, message string) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return
}
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
for _, closer := range c.active {
closer()
}
c.closed = true
c.mu.Unlock()
_ = c.conn.Close()

Expand Down
33 changes: 33 additions & 0 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,39 @@ func TestWebSocketCloseFunc(t *testing.T) {
}
})

t.Run("the on close handler gets called only once when the websocket is closed", func(t *testing.T) {
closeFuncCalled := make(chan bool, 1)
h := testserver.New()
h.AddTransport(transport.Websocket{
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
},
})

srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
assert.Equal(t, connectionAckMsg, readOp(c).Type)
assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))

select {
case res := <-closeFuncCalled:
assert.True(t, res)
case <-time.NewTimer(time.Millisecond * 20).C:
assert.Fail(t, "The close handler was not called in time")
}

select {
case <-closeFuncCalled:
assert.Fail(t, "The close handler was called more than once")
case <-time.NewTimer(time.Millisecond * 20).C:
// ok
}
})

t.Run("init func errors call the close handler", func(t *testing.T) {
h := testserver.New()
closeFuncCalled := make(chan bool, 1)
Expand Down

0 comments on commit 3e2393f

Please sign in to comment.