From 3e2393f3afe1f3f530c34f4dbbece47bde6e736b Mon Sep 17 00:00:00 2001 From: Vlad Tokarev Date: Fri, 22 Sep 2023 23:23:14 +0300 Subject: [PATCH] add close flag into wsConnection to avoid duplicate calls of CloseFunc (#2803) * add close flag into wsConnection to avoid duplicate calls of CloseFunc * add test * Fix linter error --- graphql/handler/transport/websocket.go | 6 ++++ graphql/handler/transport/websocket_test.go | 33 +++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index ed1d9588c9c..8b375465c0c 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -40,6 +40,7 @@ type ( keepAliveTicker *time.Ticker pingPongTicker *time.Ticker exec graphql.GraphExecutor + closed bool initPayload InitPayload } @@ -441,10 +442,15 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { func (c *wsConnection) close(closeCode int, message string) { c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message)) for _, closer := range c.active { closer() } + c.closed = true c.mu.Unlock() _ = c.conn.Close() diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 678bfeab25c..d7c0a4c7439 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -500,6 +500,39 @@ func TestWebSocketCloseFunc(t *testing.T) { } }) + t.Run("the on close handler gets called only once when the websocket is closed", func(t *testing.T) { + closeFuncCalled := make(chan bool, 1) + h := testserver.New() + h.AddTransport(transport.Websocket{ + CloseFunc: func(_ context.Context, _closeCode int) { + closeFuncCalled <- true + }, + }) + + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + assert.Equal(t, connectionAckMsg, readOp(c).Type) + assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg})) + + select { + case res := <-closeFuncCalled: + assert.True(t, res) + case <-time.NewTimer(time.Millisecond * 20).C: + assert.Fail(t, "The close handler was not called in time") + } + + select { + case <-closeFuncCalled: + assert.Fail(t, "The close handler was called more than once") + case <-time.NewTimer(time.Millisecond * 20).C: + // ok + } + }) + t.Run("init func errors call the close handler", func(t *testing.T) { h := testserver.New() closeFuncCalled := make(chan bool, 1)