From 8359f9749e6fd54be20325ff6aafb05503124238 Mon Sep 17 00:00:00 2001 From: foreverest Date: Wed, 13 Oct 2021 08:41:24 -0700 Subject: [PATCH] Allow custom websocket upgrader (#1595) --- docs/content/recipes/cors.md | 2 +- example/chat/server/server.go | 2 +- graphql/handler/transport/websocket.go | 10 ++++++++-- handler/handler.go | 5 ++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/content/recipes/cors.md b/docs/content/recipes/cors.md index 0b0ac114bfa..129768410d3 100644 --- a/docs/content/recipes/cors.md +++ b/docs/content/recipes/cors.md @@ -40,7 +40,7 @@ func main() { srv := handler.NewDefaultServer(starwars.NewExecutableSchema(starwars.NewResolver())) srv.AddTransport(&transport.Websocket{ - Upgrader: websocket.Upgrader{ + Upgrader: &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { // Check against your desired domains here return r.Host == "example.org" diff --git a/example/chat/server/server.go b/example/chat/server/server.go index 1592f6b0ab5..e07595319ca 100644 --- a/example/chat/server/server.go +++ b/example/chat/server/server.go @@ -34,7 +34,7 @@ func main() { srv.AddTransport(transport.POST{}) srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, - Upgrader: websocket.Upgrader{ + Upgrader: &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 787a879f50d..94ac293c827 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -32,7 +32,7 @@ const ( type ( Websocket struct { - Upgrader websocket.Upgrader + Upgrader WebsocketUpgrader InitFunc WebsocketInitFunc KeepAlivePingInterval time.Duration } @@ -52,10 +52,16 @@ type ( ID string `json:"id,omitempty"` Type string `json:"type"` } + WebsocketUpgrader interface { + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) + } WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) ) -var _ graphql.Transport = Websocket{} +var ( + _ graphql.Transport = Websocket{} + _ WebsocketUpgrader = &websocket.Upgrader{} +) func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" diff --git a/handler/handler.go b/handler/handler.go index 892df53986a..8eb2680a655 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -11,7 +11,6 @@ import ( "github.com/99designs/gqlgen/graphql/handler/lru" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/playground" - "github.com/gorilla/websocket" ) // Deprecated: switch to graphql/handler.New @@ -74,7 +73,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc // Deprecated: switch to graphql/handler.New type Config struct { cacheSize int - upgrader websocket.Upgrader + upgrader transport.WebsocketUpgrader websocketInitFunc transport.WebsocketInitFunc connectionKeepAlivePingInterval time.Duration recover graphql.RecoverFunc @@ -93,7 +92,7 @@ type Config struct { type Option func(cfg *Config) // Deprecated: switch to graphql/handler.New -func WebsocketUpgrader(upgrader websocket.Upgrader) Option { +func WebsocketUpgrader(upgrader transport.WebsocketUpgrader) Option { return func(cfg *Config) { cfg.upgrader = upgrader }