diff --git a/graphql.go b/graphql.go index 9ccc3592..b3b7048a 100644 --- a/graphql.go +++ b/graphql.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "reflect" + "time" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" @@ -64,13 +65,14 @@ type Schema struct { schema *schema.Schema res *resolvable.Schema - maxDepth int - maxParallelism int - tracer trace.Tracer - validationTracer trace.ValidationTracer - logger log.Logger - useStringDescriptions bool - disableIntrospection bool + maxDepth int + maxParallelism int + tracer trace.Tracer + validationTracer trace.ValidationTracer + logger log.Logger + useStringDescriptions bool + disableIntrospection bool + subscribeResolverTimeout time.Duration } // SchemaOpt is an option to pass to ParseSchema or MustParseSchema. @@ -135,6 +137,15 @@ func DisableIntrospection() SchemaOpt { } } +// SubscribeResolverTimeout is an option to control the amount of time +// we allow for a single subscribe message resolver to complete it's job +// before it times out and returns an error to the subscriber. +func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { + return func(s *Schema) { + s.subscribeResolverTimeout = timeout + } +} + // Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or // it may be further processed to a custom response type, for example to include custom error data. // Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107 @@ -190,7 +201,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str // Subscriptions are not valid in Exec. Use schema.Subscribe() instead. if op.Type == query.Subscription { - return &Response{Errors: []*errors.QueryError{&errors.QueryError{Message: "graphql-ws protocol header is missing"}}} + return &Response{Errors: []*errors.QueryError{{Message: "graphql-ws protocol header is missing"}}} } if op.Type == query.Mutation { if _, ok := s.schema.EntryPoints["mutation"]; !ok { diff --git a/internal/exec/exec.go b/internal/exec/exec.go index bdb8bed1..1db65464 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "sync" + "time" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" @@ -20,9 +21,10 @@ import ( type Request struct { selected.Request - Limiter chan struct{} - Tracer trace.Tracer - Logger log.Logger + Limiter chan struct{} + Tracer trace.Tracer + Logger log.Logger + SubscribeResolverTimeout time.Duration } func (r *Request) handlePanic(ctx context.Context) { diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index 16cae69c..246f7e1f 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -115,8 +115,12 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query } var out bytes.Buffer func() { - // TODO: configurable timeout - subCtx, cancel := context.WithTimeout(ctx, time.Second) + timeout := r.SubscribeResolverTimeout + if timeout == 0 { + timeout = time.Second + } + + subCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() // resolve response diff --git a/subscription_test.go b/subscription_test.go index 96649d94..a8162eca 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "testing" + "time" graphql "github.com/graph-gophers/graphql-go" qerrors "github.com/graph-gophers/graphql-go/errors" @@ -473,3 +474,50 @@ const schema = ` hello: String! } ` + +type subscriptionsCustomTimeout struct{} + +type messageResolver struct{} + +func (r messageResolver) Msg() string { + time.Sleep(5 * time.Millisecond) + return "failed!" +} + +func (r *subscriptionsCustomTimeout) OnTimeout() <-chan *messageResolver { + c := make(chan *messageResolver) + go func() { + c <- &messageResolver{} + close(c) + }() + + return c +} + +func TestSchemaSubscribe_CustomResolverTimeout(t *testing.T) { + r := &struct { + *subscriptionsCustomTimeout + }{ + subscriptionsCustomTimeout: &subscriptionsCustomTimeout{}, + } + gqltesting.RunSubscribe(t, &gqltesting.TestSubscription{ + Schema: graphql.MustParseSchema(` + type Query {} + type Subscription { + onTimeout : Message! + } + + type Message { + msg: String! + } + `, r, graphql.SubscribeResolverTimeout(1*time.Millisecond)), + Query: ` + subscription { + onTimeout { msg } + } + `, + ExpectedResults: []gqltesting.TestResponse{ + {Errors: []*qerrors.QueryError{{Message: "context deadline exceeded"}}}, + }, + }) +} diff --git a/subscriptions.go b/subscriptions.go index 4199c06d..0709796a 100644 --- a/subscriptions.go +++ b/subscriptions.go @@ -54,9 +54,10 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam Vars: variables, Schema: s.schema, }, - Limiter: make(chan struct{}, s.maxParallelism), - Tracer: s.tracer, - Logger: s.logger, + Limiter: make(chan struct{}, s.maxParallelism), + Tracer: s.tracer, + Logger: s.logger, + SubscribeResolverTimeout: s.subscribeResolverTimeout, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars {