Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add websocket keepalive support #530

Merged
merged 3 commits into from
Feb 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gometalinter.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
}
},
"Skip": ["internal/imports/testdata"],
"Disable": ["gas","golint","gocyclo","goconst", "gotype", "maligned", "gosec"]
"Disable": ["gas","golint","gocyclo","goconst", "gotype", "maligned", "gosec", "staticcheck"]
}
28 changes: 19 additions & 9 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
Expand All @@ -25,15 +26,16 @@ type params struct {
}

type Config struct {
cacheSize int
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
tracer graphql.Tracer
complexityLimit int
disableIntrospection bool
cacheSize int
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
tracer graphql.Tracer
complexityLimit int
disableIntrospection bool
connectionKeepAlivePingInterval time.Duration
}

func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
Expand Down Expand Up @@ -243,6 +245,14 @@ func CacheSize(size int) Option {

const DefaultCacheSize = 1000

// WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior.
// By default, keep-alive is disabled.
func WebsocketKeepAliveDuration(duration time.Duration) Option {
return func(cfg *Config) {
cfg.connectionKeepAlivePingInterval = duration
}
}

func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
cfg := &Config{
cacheSize: DefaultCacheSize,
Expand Down
42 changes: 35 additions & 7 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net/http"
"sync"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
Expand All @@ -27,7 +28,7 @@ const (
dataMsg = "data" // Server -> Client
errorMsg = "error" // Server -> Client
completeMsg = "complete" // Server -> Client
//connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
connectionKeepAliveMsg = "ka" // Server -> Client
)

type operationMessage struct {
Expand All @@ -37,12 +38,13 @@ type operationMessage struct {
}

type wsConnection struct {
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
cfg *Config
keepAliveTicker *time.Ticker

initPayload InitPayload
}
Expand Down Expand Up @@ -109,6 +111,20 @@ func (c *wsConnection) write(msg *operationMessage) {
}

func (c *wsConnection) run() {
// We create a cancellation that will shutdown the keep-alive when we leave
// this function.
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()

// Create a timer that will fire every interval to keep the connection alive.
if c.cfg.connectionKeepAlivePingInterval != 0 {
c.mu.Lock()
c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
c.mu.Unlock()

go c.keepAlive(ctx)
}

for {
message := c.readOp()
if message == nil {
Expand Down Expand Up @@ -141,6 +157,18 @@ func (c *wsConnection) run() {
}
}

func (c *wsConnection) keepAlive(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.keepAliveTicker.Stop()
return
case <-c.keepAliveTicker.C:
c.write(&operationMessage{Type: connectionKeepAliveMsg})
}
}
}

func (c *wsConnection) subscribe(message *operationMessage) bool {
var reqParams params
if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
Expand Down
36 changes: 36 additions & 0 deletions handler/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -122,6 +123,41 @@ func TestWebsocket(t *testing.T) {
})
}

func TestWebsocketWithKeepAlive(t *testing.T) {
next := make(chan struct{})
h := GraphQL(&executableSchemaStub{next}, WebsocketKeepAliveDuration(10*time.Millisecond))

srv := httptest.NewServer(h)
defer srv.Close()

t.Run("client must receive keepalive", func(t *testing.T) {
c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
require.Equal(t, connectionAckMsg, readOp(c).Type)

require.NoError(t, c.WriteJSON(&operationMessage{
Type: startMsg,
ID: "test_1",
Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`),
}))

// keepalive
msg := readOp(c)
require.Equal(t, connectionKeepAliveMsg, msg.Type)

// server message
next <- struct{}{}
msg = readOp(c)
require.Equal(t, dataMsg, msg.Type)

// keepalive
msg = readOp(c)
require.Equal(t, connectionKeepAliveMsg, msg.Type)
})
}

func wsConnect(url string) *websocket.Conn {
c, _, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
if err != nil {
Expand Down