diff --git a/example/chat/package.json b/example/chat/package.json index 2c1efd0106a..38aaba9f933 100644 --- a/example/chat/package.json +++ b/example/chat/package.json @@ -4,9 +4,7 @@ "private": true, "dependencies": { "@apollo/client": "^3.2.3", - "@apollo/react-hooks": "^4.0.0", "apollo-cache-inmemory": "^1.3.11", - "apollo-link-ws": "^1.0.10", "apollo-utilities": "^1.0.26", "graphql": "^14.0.2", "graphql-tag": "^2.10.0", @@ -18,6 +16,7 @@ }, "scripts": { "start": "react-scripts start", + "start:graphql-transport-ws": "REACT_APP_WS_PROTOCOL=graphql-transport-ws npm run start", "build": "react-scripts build", "test": "react-scripts test --env=jsdom", "eject": "react-scripts eject" diff --git a/example/chat/readme.md b/example/chat/readme.md index 95b944fd826..416a9fb083a 100644 --- a/example/chat/readme.md +++ b/example/chat/readme.md @@ -1,14 +1,28 @@ -### chat app +# Chat App Example app using subscriptions to build a chat room. -to run this server +### Server ```bash go run ./server/server.go ``` -to run the react app +### Client +The react app uses two different implementation for the websocket link +- [apollo-link-ws](https://www.apollographql.com/docs/react/api/link/apollo-link-ws) which uses the deprecated [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws) library +- [graphql-ws](https://github.com/enisdenjo/graphql-ws) + +First you need to install the dependencies ```bash npm install +``` + +Then to run the app with the `apollo-link-ws` implementation do +```bash npm run start ``` + +or to run the app with the `graphql-ws` implementation (and the newer `graphql-transport-ws` protocol) do +```bash +npm run start:graphql-transport-ws +``` diff --git a/example/chat/src/Room.js b/example/chat/src/Room.js index 5fe0af5e988..ac018d24c70 100644 --- a/example/chat/src/Room.js +++ b/example/chat/src/Room.js @@ -1,6 +1,6 @@ import React, { useState, useEffect, useRef } from 'react'; import gql from 'graphql-tag'; -import { useQuery, useMutation } from '@apollo/react-hooks'; +import { useQuery, useMutation } from '@apollo/client'; import { Chat, ChatContainer, Message, MessageReceived } from './components/room'; export const Room = ({ channel, name }) => { diff --git a/example/chat/src/graphql-ws.js b/example/chat/src/graphql-ws.js new file mode 100644 index 00000000000..8b9f49834b1 --- /dev/null +++ b/example/chat/src/graphql-ws.js @@ -0,0 +1,46 @@ +import { createClient } from 'graphql-ws'; +import { print } from 'graphql'; +import { ApolloLink, Observable } from '@apollo/client'; + +export class WebSocketLink extends ApolloLink { + client; + + constructor(options) { + super(); + this.client = createClient(options); + } + + request(operation) { + return new Observable((sink) => { + return this.client.subscribe( + { ...operation, query: print(operation.query) }, + { + next: sink.next.bind(sink), + complete: sink.complete.bind(sink), + error: (err) => { + if (err instanceof Error) { + return sink.error(err); + } + + if (err instanceof CloseEvent) { + return sink.error( + // reason will be available on clean closes + new Error( + `Socket closed with event ${err.code} ${err.reason || ''}`, + ), + ); + } + + return sink.error( + new Error( + err + .map(({ message }) => message) + .join(', '), + ), + ); + }, + }, + ); + }); + } +} \ No newline at end of file diff --git a/example/chat/src/index.js b/example/chat/src/index.js index 7d1fa9b3033..66815636589 100644 --- a/example/chat/src/index.js +++ b/example/chat/src/index.js @@ -7,16 +7,24 @@ import { split, } from '@apollo/client'; import { InMemoryCache } from 'apollo-cache-inmemory'; -import { WebSocketLink } from 'apollo-link-ws'; +import { WebSocketLink as ApolloWebSocketLink} from '@apollo/client/link/ws'; import { getMainDefinition } from 'apollo-utilities'; import { App } from './App'; - -const wsLink = new WebSocketLink({ - uri: `ws://localhost:8085/query`, - options: { - reconnect: true - } -}); +import { WebSocketLink as GraphQLWSWebSocketLink } from './graphql-ws' + +let wsLink; +if (process.env.REACT_APP_WS_PROTOCOL === 'graphql-transport-ws') { + wsLink = new GraphQLWSWebSocketLink({ + url: `ws://localhost:8085/query` + }); +} else { + wsLink = new ApolloWebSocketLink({ + uri: `ws://localhost:8085/query`, + options: { + reconnect: true + } + }); +} const httpLink = new HttpLink({ uri: 'http://localhost:8085/query' }); diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 787a879f50d..bb148aec2c4 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -17,29 +17,19 @@ import ( "github.com/vektah/gqlparser/v2/gqlerror" ) -const ( - connectionInitMsg = "connection_init" // Client -> Server - connectionTerminateMsg = "connection_terminate" // Client -> Server - startMsg = "start" // Client -> Server - stopMsg = "stop" // Client -> Server - connectionAckMsg = "connection_ack" // Server -> Client - connectionErrorMsg = "connection_error" // Server -> Client - dataMsg = "data" // Server -> Client - errorMsg = "error" // Server -> Client - completeMsg = "complete" // Server -> Client - connectionKeepAliveMsg = "ka" // Server -> Client -) - type ( Websocket struct { Upgrader websocket.Upgrader InitFunc WebsocketInitFunc KeepAlivePingInterval time.Duration + + didInjectSubprotocols bool } wsConnection struct { Websocket ctx context.Context conn *websocket.Conn + me messageExchanger active map[string]context.CancelFunc mu sync.Mutex keepAliveTicker *time.Ticker @@ -47,11 +37,7 @@ type ( initPayload InitPayload } - operationMessage struct { - Payload json.RawMessage `json:"payload,omitempty"` - ID string `json:"id,omitempty"` - Type string `json:"type"` - } + WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) ) @@ -62,20 +48,34 @@ func (t Websocket) Supports(r *http.Request) bool { } func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { - ws, err := t.Upgrader.Upgrade(w, r, http.Header{ - "Sec-Websocket-Protocol": []string{"graphql-ws"}, - }) + t.injectGraphQLWSSubprotocols() + ws, err := t.Upgrader.Upgrade(w, r, http.Header{}) if err != nil { log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error()) SendErrorf(w, http.StatusBadRequest, "unable to upgrade") return } + var me messageExchanger + switch ws.Subprotocol() { + default: + msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol())) + ws.WriteMessage(websocket.CloseMessage, msg) + return + case graphqlwsSubprotocol, "": + // clients are required to send a subprotocol, to be backward compatible with the previous implementation we select + // "graphql-ws" by default + me = graphqlwsMessageExchanger{c: ws} + case graphqltransportwsSubprotocol: + me = graphqltransportwsMessageExchanger{c: ws} + } + conn := wsConnection{ active: map[string]context.CancelFunc{}, conn: ws, ctx: r.Context(), exec: exec, + me: me, Websocket: t, } @@ -87,17 +87,21 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph } func (c *wsConnection) init() bool { - message := c.readOp() - if message == nil { + m, err := c.me.NextMessage() + if err != nil { + if err == errInvalidMsg { + c.sendConnectionError("invalid json") + } + c.close(websocket.CloseProtocolError, "decoding error") return false } - switch message.Type { - case connectionInitMsg: - if len(message.Payload) > 0 { + switch m.t { + case initMessageType: + if len(m.payload) > 0 { c.initPayload = make(InitPayload) - err := json.Unmarshal(message.Payload, &c.initPayload) + err := json.Unmarshal(m.payload, &c.initPayload) if err != nil { return false } @@ -113,13 +117,13 @@ func (c *wsConnection) init() bool { c.ctx = ctx } - c.write(&operationMessage{Type: connectionAckMsg}) - c.write(&operationMessage{Type: connectionKeepAliveMsg}) - case connectionTerminateMsg: + c.write(&message{t: connectionAckMessageType}) + c.write(&message{t: keepAliveMessageType}) + case connectionCloseMessageType: c.close(websocket.CloseNormalClosure, "terminated") return false default: - c.sendConnectionError("unexpected message %s", message.Type) + c.sendConnectionError("unexpected message %s", m.t) c.close(websocket.CloseProtocolError, "unexpected message") return false } @@ -127,9 +131,11 @@ func (c *wsConnection) init() bool { return true } -func (c *wsConnection) write(msg *operationMessage) { +func (c *wsConnection) write(msg *message) { c.mu.Lock() - c.conn.WriteJSON(msg) + // TODO: missing error handling here, err from previous implementation + // was ignored + _ = c.me.Send(msg) c.mu.Unlock() } @@ -153,26 +159,27 @@ func (c *wsConnection) run() { for { start := graphql.Now() - message := c.readOp() - if message == nil { + m, err := c.me.NextMessage() + if err != nil { + // TODO: better error handling here return } - switch message.Type { - case startMsg: - c.subscribe(start, message) - case stopMsg: + switch m.t { + case startMessageType: + c.subscribe(start, &m) + case stopMessageType: c.mu.Lock() - closer := c.active[message.ID] + closer := c.active[m.id] c.mu.Unlock() if closer != nil { closer() } - case connectionTerminateMsg: + case connectionCloseMessageType: c.close(websocket.CloseNormalClosure, "terminated") return default: - c.sendConnectionError("unexpected message %s", message.Type) + c.sendConnectionError("unexpected message %s", m.t) c.close(websocket.CloseProtocolError, "unexpected message") return } @@ -186,17 +193,17 @@ func (c *wsConnection) keepAlive(ctx context.Context) { c.keepAliveTicker.Stop() return case <-c.keepAliveTicker.C: - c.write(&operationMessage{Type: connectionKeepAliveMsg}) + c.write(&message{t: keepAliveMessageType}) } } } -func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { +func (c *wsConnection) subscribe(start time.Time, msg *message) { ctx := graphql.StartOperationTrace(c.ctx) var params *graphql.RawParams - if err := jsonDecode(bytes.NewReader(message.Payload), ¶ms); err != nil { - c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"}) - c.complete(message.ID) + if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil { + c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"}) + c.complete(msg.id) return } @@ -210,12 +217,12 @@ func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) switch errcode.GetErrorKind(err) { case errcode.KindProtocol: - c.sendError(message.ID, resp.Errors...) + c.sendError(msg.id, resp.Errors...) default: - c.sendResponse(message.ID, &graphql.Response{Errors: err}) + c.sendResponse(msg.id, &graphql.Response{Errors: err}) } - c.complete(message.ID) + c.complete(msg.id) return } @@ -227,7 +234,7 @@ func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { ctx, cancel := context.WithCancel(ctx) c.mu.Lock() - c.active[message.ID] = cancel + c.active[msg.id] = cancel c.mu.Unlock() go func() { @@ -241,11 +248,11 @@ func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { gqlerr.Message = err.Error() } } - c.sendError(message.ID, gqlerr) + c.sendError(msg.id, gqlerr) } - c.complete(message.ID) + c.complete(msg.id) c.mu.Lock() - delete(c.active, message.ID) + delete(c.active, msg.id) c.mu.Unlock() cancel() }() @@ -256,8 +263,15 @@ func (c *wsConnection) subscribe(start time.Time, message *operationMessage) { if response == nil { break } - c.sendResponse(message.ID, response) + + c.sendResponse(msg.id, response) } + c.complete(msg.id) + + c.mu.Lock() + delete(c.active, msg.id) + c.mu.Unlock() + cancel() }() } @@ -266,15 +280,15 @@ func (c *wsConnection) sendResponse(id string, response *graphql.Response) { if err != nil { panic(err) } - c.write(&operationMessage{ - Payload: b, - ID: id, - Type: dataMsg, + c.write(&message{ + payload: b, + id: id, + t: dataMessageType, }) } func (c *wsConnection) complete(id string) { - c.write(&operationMessage{ID: id, Type: completeMsg}) + c.write(&message{id: id, t: completeMessageType}) } func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { @@ -286,7 +300,7 @@ func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) { if err != nil { panic(err) } - c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b}) + c.write(&message{t: errorMessageType, id: id, payload: b}) } func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { @@ -295,24 +309,7 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { panic(err) } - c.write(&operationMessage{Type: connectionErrorMsg, Payload: b}) -} - -func (c *wsConnection) readOp() *operationMessage { - _, r, err := c.conn.NextReader() - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - return nil - } else if err != nil { - c.sendConnectionError("invalid json: %T %s", err, err.Error()) - return nil - } - message := operationMessage{} - if err := jsonDecode(r, &message); err != nil { - c.sendConnectionError("invalid json") - return nil - } - - return &message + c.write(&message{t: connectionErrorMessageType, payload: b}) } func (c *wsConnection) close(closeCode int, message string) { diff --git a/graphql/handler/transport/websocket_graphql_transport_ws.go b/graphql/handler/transport/websocket_graphql_transport_ws.go new file mode 100644 index 00000000000..c998fc9f477 --- /dev/null +++ b/graphql/handler/transport/websocket_graphql_transport_ws.go @@ -0,0 +1,139 @@ +package transport + +import ( + "encoding/json" + "fmt" + + "github.com/gorilla/websocket" +) + +// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md +const ( + graphqltransportwsSubprotocol = "graphql-transport-ws" + + graphqltransportwsConnectionInitMsg = graphqltransportwsMessageType("connection_init") + graphqltransportwsConnectionAckMsg = graphqltransportwsMessageType("connection_ack") + graphqltransportwsSubscribeMsg = graphqltransportwsMessageType("subscribe") + graphqltransportwsNextMsg = graphqltransportwsMessageType("next") + graphqltransportwsErrorMsg = graphqltransportwsMessageType("error") + graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete") +) + +var ( + allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{ + graphqltransportwsConnectionInitMsg, + graphqltransportwsConnectionAckMsg, + graphqltransportwsSubscribeMsg, + graphqltransportwsNextMsg, + graphqltransportwsErrorMsg, + graphqltransportwsCompleteMsg, + } +) + +type ( + graphqltransportwsMessageExchanger struct { + c *websocket.Conn + } + + graphqltransportwsMessage struct { + Payload json.RawMessage `json:"payload,omitempty"` + ID string `json:"id,omitempty"` + Type graphqltransportwsMessageType `json:"type"` + noOp bool + } + + graphqltransportwsMessageType string +) + +func (me graphqltransportwsMessageExchanger) NextMessage() (message, error) { + _, r, err := me.c.NextReader() + if err != nil { + return message{}, handleNextReaderError(err) + } + + var graphqltransportwsMessage graphqltransportwsMessage + if err := jsonDecode(r, &graphqltransportwsMessage); err != nil { + return message{}, errInvalidMsg + } + + return graphqltransportwsMessage.toMessage() +} + +func (me graphqltransportwsMessageExchanger) Send(m *message) error { + msg := &graphqltransportwsMessage{} + if err := msg.fromMessage(m); err != nil { + return err + } + + if msg.noOp { + return nil + } + + return me.c.WriteJSON(msg) +} + +func (t *graphqltransportwsMessageType) UnmarshalText(text []byte) (err error) { + var found bool + for _, candidate := range allGraphqltransportwsMessageTypes { + if string(candidate) == string(text) { + *t = candidate + found = true + break + } + } + + if !found { + err = fmt.Errorf("invalid message type %s", string(text)) + } + + return err +} + +func (t graphqltransportwsMessageType) MarshalText() ([]byte, error) { + return []byte(string(t)), nil +} + +func (m graphqltransportwsMessage) toMessage() (message, error) { + var t messageType + var err error + switch m.Type { + default: + err = fmt.Errorf("invalid client->server message type %s", m.Type) + case graphqltransportwsConnectionInitMsg: + t = initMessageType + case graphqltransportwsSubscribeMsg: + t = startMessageType + case graphqltransportwsCompleteMsg: + t = stopMessageType + } + + return message{ + payload: m.Payload, + id: m.ID, + t: t, + }, err +} + +func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) { + m.ID = msg.id + m.Payload = msg.payload + + switch msg.t { + default: + err = fmt.Errorf("invalid server->client message type %s", msg.t) + case connectionAckMessageType: + m.Type = graphqltransportwsConnectionAckMsg + case keepAliveMessageType: + m.noOp = true + case connectionErrorMessageType: + m.noOp = true + case dataMessageType: + m.Type = graphqltransportwsNextMsg + case completeMessageType: + m.Type = graphqltransportwsCompleteMsg + case errorMessageType: + m.Type = graphqltransportwsErrorMsg + } + + return err +} diff --git a/graphql/handler/transport/websocket_graphqlws.go b/graphql/handler/transport/websocket_graphqlws.go new file mode 100644 index 00000000000..f8e05caf4a7 --- /dev/null +++ b/graphql/handler/transport/websocket_graphqlws.go @@ -0,0 +1,171 @@ +package transport + +import ( + "encoding/json" + "fmt" + + "github.com/gorilla/websocket" +) + +// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +const ( + graphqlwsSubprotocol = "graphql-ws" + + graphqlwsConnectionInitMsg = graphqlwsMessageType("connection_init") + graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate") + graphqlwsStartMsg = graphqlwsMessageType("start") + graphqlwsStopMsg = graphqlwsMessageType("stop") + graphqlwsConnectionAckMsg = graphqlwsMessageType("connection_ack") + graphqlwsConnectionErrorMsg = graphqlwsMessageType("connection_error") + graphqlwsDataMsg = graphqlwsMessageType("data") + graphqlwsErrorMsg = graphqlwsMessageType("error") + graphqlwsCompleteMsg = graphqlwsMessageType("complete") + graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka") +) + +var ( + allGraphqlwsMessageTypes = []graphqlwsMessageType{ + graphqlwsConnectionInitMsg, + graphqlwsConnectionTerminateMsg, + graphqlwsStartMsg, + graphqlwsStopMsg, + graphqlwsConnectionAckMsg, + graphqlwsConnectionErrorMsg, + graphqlwsDataMsg, + graphqlwsErrorMsg, + graphqlwsCompleteMsg, + graphqlwsConnectionKeepAliveMsg, + } +) + +type ( + graphqlwsMessageExchanger struct { + c *websocket.Conn + } + + graphqlwsMessage struct { + Payload json.RawMessage `json:"payload,omitempty"` + ID string `json:"id,omitempty"` + Type graphqlwsMessageType `json:"type"` + } + + graphqlwsMessageType string +) + +func (me graphqlwsMessageExchanger) NextMessage() (message, error) { + _, r, err := me.c.NextReader() + if err != nil { + return message{}, handleNextReaderError(err) + } + + var graphqlwsMessage graphqlwsMessage + if err := jsonDecode(r, &graphqlwsMessage); err != nil { + return message{}, errInvalidMsg + } + + return graphqlwsMessage.toMessage() +} + +func (me graphqlwsMessageExchanger) Send(m *message) error { + msg := &graphqlwsMessage{} + if err := msg.fromMessage(m); err != nil { + return err + } + + return me.c.WriteJSON(msg) +} + +func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) { + var found bool + for _, candidate := range allGraphqlwsMessageTypes { + if string(candidate) == string(text) { + *t = candidate + found = true + break + } + } + + if !found { + err = fmt.Errorf("invalid message type %s", string(text)) + } + + return err +} + +func (t graphqlwsMessageType) MarshalText() ([]byte, error) { + return []byte(string(t)), nil +} + +func (t graphqlwsMessageType) toMessageType() (mt messageType, err error) { + switch t { + default: + err = fmt.Errorf("unknown message type mapping for %s", t) + case graphqlwsConnectionInitMsg: + mt = initMessageType + case graphqlwsConnectionTerminateMsg: + mt = connectionCloseMessageType + case graphqlwsStartMsg: + mt = startMessageType + case graphqlwsStopMsg: + mt = stopMessageType + case graphqlwsConnectionAckMsg: + mt = connectionAckMessageType + case graphqlwsConnectionErrorMsg: + mt = connectionErrorMessageType + case graphqlwsDataMsg: + mt = dataMessageType + case graphqlwsErrorMsg: + mt = errorMessageType + case graphqlwsCompleteMsg: + mt = completeMessageType + case graphqlwsConnectionKeepAliveMsg: + mt = keepAliveMessageType + } + + return mt, err +} + +func (t *graphqlwsMessageType) fromMessageType(mt messageType) (err error) { + switch mt { + default: + err = fmt.Errorf("failed to convert message %s to %s subprotocol", mt, graphqlwsSubprotocol) + case initMessageType: + *t = graphqlwsConnectionInitMsg + case connectionAckMessageType: + *t = graphqlwsConnectionAckMsg + case keepAliveMessageType: + *t = graphqlwsConnectionKeepAliveMsg + case connectionErrorMessageType: + *t = graphqlwsConnectionErrorMsg + case connectionCloseMessageType: + *t = graphqlwsConnectionTerminateMsg + case startMessageType: + *t = graphqlwsStartMsg + case stopMessageType: + *t = graphqlwsStopMsg + case dataMessageType: + *t = graphqlwsDataMsg + case completeMessageType: + *t = graphqlwsCompleteMsg + case errorMessageType: + *t = graphqlwsErrorMsg + } + + return err +} + +func (m graphqlwsMessage) toMessage() (message, error) { + mt, err := m.Type.toMessageType() + return message{ + payload: m.Payload, + id: m.ID, + t: mt, + }, err +} + +func (m *graphqlwsMessage) fromMessage(msg *message) (err error) { + err = m.Type.fromMessageType(msg.t) + m.ID = msg.id + m.Payload = msg.payload + return err +} diff --git a/graphql/handler/transport/websocket_subprotocol.go b/graphql/handler/transport/websocket_subprotocol.go new file mode 100644 index 00000000000..f47bfee8092 --- /dev/null +++ b/graphql/handler/transport/websocket_subprotocol.go @@ -0,0 +1,110 @@ +package transport + +import ( + "encoding/json" + "errors" + + "github.com/gorilla/websocket" +) + +const ( + initMessageType messageType = iota + connectionAckMessageType + keepAliveMessageType + connectionErrorMessageType + connectionCloseMessageType + startMessageType + stopMessageType + dataMessageType + completeMessageType + errorMessageType +) + +var ( + supportedSubprotocols = []string{ + graphqlwsSubprotocol, + graphqltransportwsSubprotocol, + } + + errWsConnClosed = errors.New("websocket connection closed") + errInvalidMsg = errors.New("invalid message received") +) + +type ( + messageType int + message struct { + payload json.RawMessage + id string + t messageType + } + messageExchanger interface { + NextMessage() (message, error) + Send(m *message) error + } +) + +func (t messageType) String() string { + var text string + switch t { + default: + text = "unknown" + case initMessageType: + text = "init" + case connectionAckMessageType: + text = "connection ack" + case keepAliveMessageType: + text = "keep alive" + case connectionErrorMessageType: + text = "connection error" + case connectionCloseMessageType: + text = "connection close" + case startMessageType: + text = "start" + case stopMessageType: + text = "stop subscription" + case dataMessageType: + text = "data" + case completeMessageType: + text = "complete" + case errorMessageType: + text = "error" + } + return text +} + +func contains(list []string, elem string) bool { + for _, e := range list { + if e == elem { + return true + } + } + + return false +} + +func (t *Websocket) injectGraphQLWSSubprotocols() { + // the list of subprotocols is specified by the consumer of the Websocket struct, + // in order to preserve backward compatibility, we inject the graphql specific subprotocols + // at runtime + if !t.didInjectSubprotocols { + defer func() { + t.didInjectSubprotocols = true + }() + + for _, subprotocol := range supportedSubprotocols { + if !contains(t.Upgrader.Subprotocols, subprotocol) { + t.Upgrader.Subprotocols = append(t.Upgrader.Subprotocols, subprotocol) + } + } + } +} + +func handleNextReaderError(err error) error { + // TODO: should we consider all closure scenarios here for the ws connection? + // for now we only list the error codes from the previous implementation + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { + return errWsConnClosed + } + + return err +} diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 566777b621a..4280b1205f9 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "net/http" "net/http/httptest" "strings" "testing" @@ -267,8 +268,67 @@ func TestWebsocketInitFunc(t *testing.T) { }) } +func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) { + handler := testserver.New() + handler.AddTransport(transport.Websocket{}) + + srv := httptest.NewServer(handler) + defer srv.Close() + + t.Run("server acks init", func(t *testing.T) { + c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) + + assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) + }) + + t.Run("client can receive data", func(t *testing.T) { + c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg})) + assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type) + + require.NoError(t, c.WriteJSON(&operationMessage{ + Type: graphqltransportwsSubscribeMsg, + ID: "test_1", + Payload: json.RawMessage(`{"query": "subscription { name }"}`), + })) + + handler.SendNextSubscriptionMessage() + msg := readOp(c) + require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) + require.Equal(t, "test_1", msg.ID, string(msg.Payload)) + require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) + + handler.SendNextSubscriptionMessage() + msg = readOp(c) + require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload)) + require.Equal(t, "test_1", msg.ID, string(msg.Payload)) + require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload)) + + require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"})) + + msg = readOp(c) + require.Equal(t, graphqltransportwsCompleteMsg, msg.Type) + require.Equal(t, "test_1", msg.ID) + }) +} + func wsConnect(url string) *websocket.Conn { - c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), nil) + + return wsConnectWithSubprocotol(url, "") +} + +func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn { + h := make(http.Header) + if subprocotol != "" { + h.Add("Sec-WebSocket-Protocol", subprocotol) + } + + c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h) if err != nil { panic(err) } @@ -291,7 +351,7 @@ func readOp(conn *websocket.Conn) operationMessage { return msg } -// copied out from weboscket.go to keep these private +// copied out from websocket_graphqlws.go to keep these private const ( connectionInitMsg = "connection_init" // Client -> Server @@ -306,6 +366,18 @@ const ( connectionKeepAliveMsg = "ka" // Server -> Client ) +// copied out from websocket_graphql_transport_ws.go to keep these private + +const ( + graphqltransportwsSubprotocol = "graphql-transport-ws" + + graphqltransportwsConnectionInitMsg = "connection_init" + graphqltransportwsConnectionAckMsg = "connection_ack" + graphqltransportwsSubscribeMsg = "subscribe" + graphqltransportwsNextMsg = "next" + graphqltransportwsCompleteMsg = "complete" +) + type operationMessage struct { Payload json.RawMessage `json:"payload,omitempty"` ID string `json:"id,omitempty"`