diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index a0429b7481d..51b1104cccd 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -49,7 +49,24 @@ type ( var errReadTimeout = errors.New("read timeout") -var _ graphql.Transport = Websocket{} +type WebsocketError struct { + Err error + + // IsReadError flags whether the error occurred on read or write to the websocket + IsReadError bool +} + +func (e WebsocketError) Error() string { + if e.IsReadError { + return fmt.Sprintf("websocket read: %v", e.Err) + } + return fmt.Sprintf("websocket write: %v", e.Err) +} + +var ( + _ graphql.Transport = Websocket{} + _ error = WebsocketError{} +) func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" @@ -94,9 +111,12 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph conn.run() } -func (c *wsConnection) handlePossibleError(err error) { +func (c *wsConnection) handlePossibleError(err error, isReadError bool) { if c.ErrorFunc != nil && err != nil { - c.ErrorFunc(c.ctx, err) + c.ErrorFunc(c.ctx, WebsocketError{ + Err: err, + IsReadError: isReadError, + }) } } @@ -181,7 +201,7 @@ func (c *wsConnection) init() bool { func (c *wsConnection) write(msg *message) { c.mu.Lock() - c.handlePossibleError(c.me.Send(msg)) + c.handlePossibleError(c.me.Send(msg), false) c.mu.Unlock() } @@ -227,7 +247,7 @@ func (c *wsConnection) run() { if err != nil { // If the connection got closed by us, don't report the error if !errors.Is(err, net.ErrClosed) { - c.handlePossibleError(err) + c.handlePossibleError(err, true) } return } diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 7c84f352627..fb2a07bf8c2 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -356,7 +356,9 @@ func TestWebSocketErrorFunc(t *testing.T) { h.AddTransport(transport.Websocket{ ErrorFunc: func(_ context.Context, err error) { require.Error(t, err) - assert.Equal(t, err.Error(), "invalid message received") + assert.Equal(t, err.Error(), "websocket read: invalid message received") + assert.IsType(t, transport.WebsocketError{}, err) + assert.True(t, err.(transport.WebsocketError).IsReadError) errFuncCalled <- true }, })