diff --git a/handler/graphql.go b/handler/graphql.go index c918d649daf..7c5f70cfdfd 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -26,16 +26,16 @@ type params struct { } type Config struct { - cacheSize int - upgrader websocket.Upgrader - recover graphql.RecoverFunc - errorPresenter graphql.ErrorPresenterFunc - resolverHook graphql.FieldMiddleware - requestHook graphql.RequestMiddleware - tracer graphql.Tracer - complexityLimit int - disableIntrospection bool - connectionKeepAliveTimeout time.Duration + cacheSize int + upgrader websocket.Upgrader + recover graphql.RecoverFunc + errorPresenter graphql.ErrorPresenterFunc + resolverHook graphql.FieldMiddleware + requestHook graphql.RequestMiddleware + tracer graphql.Tracer + complexityLimit int + disableIntrospection bool + connectionKeepAlivePingInterval time.Duration } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -249,7 +249,7 @@ const DefaultCacheSize = 1000 // By default, keep-alive is disabled. func WebsocketKeepAliveDuration(duration time.Duration) Option { return func(cfg *Config) { - cfg.connectionKeepAliveTimeout = duration + cfg.connectionKeepAlivePingInterval = duration } } diff --git a/handler/websocket.go b/handler/websocket.go index 327182d5c89..09800c172b4 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -38,13 +38,13 @@ type operationMessage struct { } type wsConnection struct { - ctx context.Context - conn *websocket.Conn - exec graphql.ExecutableSchema - active map[string]context.CancelFunc - mu sync.Mutex - cfg *Config - keepAliveTimer *time.Timer + ctx context.Context + conn *websocket.Conn + exec graphql.ExecutableSchema + active map[string]context.CancelFunc + mu sync.Mutex + cfg *Config + keepAliveTicker *time.Ticker initPayload InitPayload } @@ -107,9 +107,6 @@ func (c *wsConnection) init() bool { func (c *wsConnection) write(msg *operationMessage) { c.mu.Lock() c.conn.WriteJSON(msg) - if c.cfg.connectionKeepAliveTimeout != 0 && c.keepAliveTimer != nil { - c.keepAliveTimer.Reset(c.cfg.connectionKeepAliveTimeout) - } c.mu.Unlock() } @@ -119,11 +116,10 @@ func (c *wsConnection) run() { ctx, cancel := context.WithCancel(c.ctx) defer cancel() - // Create a timer that will fire every interval if a write hasn't been made - // to keep the connection alive. - if c.cfg.connectionKeepAliveTimeout != 0 { + // Create a timer that will fire every interval to keep the connection alive. + if c.cfg.connectionKeepAlivePingInterval != 0 { c.mu.Lock() - c.keepAliveTimer = time.NewTimer(c.cfg.connectionKeepAliveTimeout) + c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval) c.mu.Unlock() go c.keepAlive(ctx) @@ -165,13 +161,9 @@ func (c *wsConnection) keepAlive(ctx context.Context) { for { select { case <-ctx.Done(): - if !c.keepAliveTimer.Stop() { - <-c.keepAliveTimer.C - } + c.keepAliveTicker.Stop() return - case <-c.keepAliveTimer.C: - // We don't reset the timer here, because the `c.write` command - // will reset the timer anyways. + case <-c.keepAliveTicker.C: c.write(&operationMessage{Type: connectionKeepAliveMsg}) } }