From 0c1149a9099b8816c2c19091aeb8618be12ff0c0 Mon Sep 17 00:00:00 2001 From: Chris Pride Date: Fri, 15 Apr 2022 07:16:56 +0000 Subject: [PATCH] Fix the ability of websockets to get errors Because DispatchOperation creates tempResponseContext, which is passed into Exec, which is then used in _Subscription to generate the next function. Inside the various subscription functions when generating next the context was captured there. Which means later when the returned function from DispatchOperation is called. The responseContext which accumulates the errors is the tempResponseContext which we no longer have access to to read the errors out of it. Instead add a context to next() so that it can be passed through and accumulated the errors as expected. Added a unit test for this as well. --- _examples/chat/generated.go | 18 +-- codegen/directives.gotpl | 10 +- codegen/field.gotpl | 4 +- codegen/generated!.gotpl | 2 +- codegen/object.gotpl | 2 +- codegen/root_.gotpl | 2 +- codegen/testserver/followschema/nulls.graphql | 4 + codegen/testserver/followschema/resolver.go | 4 + .../followschema/root_.generated.go | 14 ++- .../followschema/schema.generated.go | 94 ++++++++++++--- codegen/testserver/followschema/stub.go | 4 + .../followschema/subscription_test.go | 53 +++++++++ codegen/testserver/singlefile/generated.go | 108 +++++++++++++++--- codegen/testserver/singlefile/nulls.graphql | 4 + codegen/testserver/singlefile/resolver.go | 4 + codegen/testserver/singlefile/stub.go | 4 + .../singlefile/subscription_test.go | 53 +++++++++ 17 files changed, 333 insertions(+), 51 deletions(-) diff --git a/_examples/chat/generated.go b/_examples/chat/generated.go index 5f01ae298a4..ddaf5ac6323 100644 --- a/_examples/chat/generated.go +++ b/_examples/chat/generated.go @@ -220,7 +220,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer return func(ctx context.Context) *graphql.Response { buf.Reset() - data := next() + data := next(ctx) if data == nil { return nil @@ -419,7 +419,7 @@ func (ec *executionContext) field___Type_fields_args(ctx context.Context, rawArg // region ************************** directives.gotpl ************************** -func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func() graphql.Marshaler { +func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func(ctx context.Context) graphql.Marshaler { for _, d := range obj.Directives { switch d.Name { case "user": @@ -427,7 +427,7 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as args, err := ec.dir_user_args(ctx, rawArgs) if err != nil { ec.Error(ctx, err) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } @@ -443,15 +443,15 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as tmp, err := next(ctx) if err != nil { ec.Error(ctx, err) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } - if data, ok := tmp.(func() graphql.Marshaler); ok { + if data, ok := tmp.(func(ctx context.Context) graphql.Marshaler); ok { return data } ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } @@ -986,7 +986,7 @@ func (ec *executionContext) fieldContext_Query___schema(ctx context.Context, fie return fc, nil } -func (ec *executionContext) _Subscription_messageAdded(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_messageAdded(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_messageAdded(ctx, field) if err != nil { return nil @@ -1012,7 +1012,7 @@ func (ec *executionContext) _Subscription_messageAdded(ctx context.Context, fiel } return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *Message) if !ok { return nil @@ -3050,7 +3050,7 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr var subscriptionImplementors = []string{"Subscription"} -func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler { +func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors) ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ Object: "Subscription", diff --git a/codegen/directives.gotpl b/codegen/directives.gotpl index e6d2455f6c8..23bcf0f879b 100644 --- a/codegen/directives.gotpl +++ b/codegen/directives.gotpl @@ -70,7 +70,7 @@ func (ec *executionContext) _mutationMiddleware(ctx context.Context, obj *ast.Op {{ end }} {{ if .Directives.LocationDirectives "SUBSCRIPTION" }} -func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func() graphql.Marshaler { +func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func(ctx context.Context) graphql.Marshaler { for _, d := range obj.Directives { switch d.Name { {{- range $directive := .Directives.LocationDirectives "SUBSCRIPTION" }} @@ -80,7 +80,7 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs) if err != nil { ec.Error(ctx, err) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } @@ -98,15 +98,15 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as tmp, err := next(ctx) if err != nil { ec.Error(ctx, err) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } - if data, ok := tmp.(func() graphql.Marshaler); ok { + if data, ok := tmp.(func(ctx context.Context) graphql.Marshaler); ok { return data } ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp) - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { return graphql.Null } } diff --git a/codegen/field.gotpl b/codegen/field.gotpl index 0a1042b6b4a..3629d92c7f9 100644 --- a/codegen/field.gotpl +++ b/codegen/field.gotpl @@ -1,6 +1,6 @@ {{- range $object := .Objects }}{{- range $field := $object.Fields }} -func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(){{ end }}graphql.Marshaler) { +func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(ctx context.Context){{ end }}graphql.Marshaler) { {{- $null := "graphql.Null" }} {{- if $object.Stream }} {{- $null = "nil" }} @@ -38,7 +38,7 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex return {{ $null }} } {{- if $object.Stream }} - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}}) if !ok { return nil diff --git a/codegen/generated!.gotpl b/codegen/generated!.gotpl index bf59daccef9..07d6e2b4ce8 100644 --- a/codegen/generated!.gotpl +++ b/codegen/generated!.gotpl @@ -190,7 +190,7 @@ var buf bytes.Buffer return func(ctx context.Context) *graphql.Response { buf.Reset() - data := next() + data := next(ctx) if data == nil { return nil diff --git a/codegen/object.gotpl b/codegen/object.gotpl index 8cb9d28ced7..3dd2a5abf25 100644 --- a/codegen/object.gotpl +++ b/codegen/object.gotpl @@ -3,7 +3,7 @@ var {{ $object.Name|lcFirst}}Implementors = {{$object.Implementors}} {{- if .Stream }} -func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler { +func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors) ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ Object: {{$object.Name|quote}}, diff --git a/codegen/root_.gotpl b/codegen/root_.gotpl index 13d77961837..355a21bb1d0 100644 --- a/codegen/root_.gotpl +++ b/codegen/root_.gotpl @@ -157,7 +157,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer return func(ctx context.Context) *graphql.Response { buf.Reset() - data := next() + data := next(ctx) if data == nil { return nil diff --git a/codegen/testserver/followschema/nulls.graphql b/codegen/testserver/followschema/nulls.graphql index a1fea680ce2..b213b17b0fb 100644 --- a/codegen/testserver/followschema/nulls.graphql +++ b/codegen/testserver/followschema/nulls.graphql @@ -6,6 +6,10 @@ extend type Query { valid: String! } +extend type Subscription { + errorRequired: Error! +} + type Errors { a: Error! b: Error! diff --git a/codegen/testserver/followschema/resolver.go b/codegen/testserver/followschema/resolver.go index c104c7c1040..04d4f581be7 100644 --- a/codegen/testserver/followschema/resolver.go +++ b/codegen/testserver/followschema/resolver.go @@ -372,6 +372,10 @@ func (r *subscriptionResolver) Issue896b(ctx context.Context) (<-chan []*CheckIs panic("not implemented") } +func (r *subscriptionResolver) ErrorRequired(ctx context.Context) (<-chan *Error, error) { + panic("not implemented") +} + func (r *userResolver) Friends(ctx context.Context, obj *User) ([]*User, error) { panic("not implemented") } diff --git a/codegen/testserver/followschema/root_.generated.go b/codegen/testserver/followschema/root_.generated.go index 65d486f2764..16ff85a5e12 100644 --- a/codegen/testserver/followschema/root_.generated.go +++ b/codegen/testserver/followschema/root_.generated.go @@ -367,6 +367,7 @@ type ComplexityRoot struct { DirectiveDouble func(childComplexity int) int DirectiveNullableArg func(childComplexity int, arg *int, arg2 *int, arg3 *string) int DirectiveUnimplemented func(childComplexity int) int + ErrorRequired func(childComplexity int) int InitPayload func(childComplexity int) int Issue896b func(childComplexity int) int Updated func(childComplexity int) int @@ -1713,6 +1714,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.DirectiveUnimplemented(childComplexity), true + case "Subscription.errorRequired": + if e.complexity.Subscription.ErrorRequired == nil { + break + } + + return e.complexity.Subscription.ErrorRequired(childComplexity), true + case "Subscription.initPayload": if e.complexity.Subscription.InitPayload == nil { break @@ -1948,7 +1956,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer return func(ctx context.Context) *graphql.Response { buf.Reset() - data := next() + data := next(ctx) if data == nil { return nil @@ -2252,6 +2260,10 @@ input NestedInput { valid: String! } +extend type Subscription { + errorRequired: Error! +} + type Errors { a: Error! b: Error! diff --git a/codegen/testserver/followschema/schema.generated.go b/codegen/testserver/followschema/schema.generated.go index 6160ff997b9..142c2c8ed94 100644 --- a/codegen/testserver/followschema/schema.generated.go +++ b/codegen/testserver/followschema/schema.generated.go @@ -106,6 +106,7 @@ type SubscriptionResolver interface { DirectiveDouble(ctx context.Context) (<-chan *string, error) DirectiveUnimplemented(ctx context.Context) (<-chan *string, error) Issue896b(ctx context.Context) (<-chan []*CheckIssue896, error) + ErrorRequired(ctx context.Context) (<-chan *Error, error) } type UserResolver interface { Friends(ctx context.Context, obj *User) ([]*User, error) @@ -4698,7 +4699,7 @@ func (ec *executionContext) fieldContext_Query___schema(ctx context.Context, fie return fc, nil } -func (ec *executionContext) _Subscription_updated(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_updated(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_updated(ctx, field) if err != nil { return nil @@ -4721,7 +4722,7 @@ func (ec *executionContext) _Subscription_updated(ctx context.Context, field gra } return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan string) if !ok { return nil @@ -4749,7 +4750,7 @@ func (ec *executionContext) fieldContext_Subscription_updated(ctx context.Contex return fc, nil } -func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_initPayload(ctx, field) if err != nil { return nil @@ -4772,7 +4773,7 @@ func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field } return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan string) if !ok { return nil @@ -4800,7 +4801,7 @@ func (ec *executionContext) fieldContext_Subscription_initPayload(ctx context.Co return fc, nil } -func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveArg(ctx, field) if err != nil { return nil @@ -4820,7 +4821,7 @@ func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, fiel if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -4859,7 +4860,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveArg(ctx context.C return fc, nil } -func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveNullableArg(ctx, field) if err != nil { return nil @@ -4879,7 +4880,7 @@ func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Conte if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -4918,7 +4919,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveNullableArg(ctx c return fc, nil } -func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveDouble(ctx, field) if err != nil { return nil @@ -4964,7 +4965,7 @@ func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, f if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -4992,7 +4993,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveDouble(ctx contex return fc, nil } -func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveUnimplemented(ctx, field) if err != nil { return nil @@ -5032,7 +5033,7 @@ func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Con if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -5060,7 +5061,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveUnimplemented(ctx return fc, nil } -func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_issue896b(ctx, field) if err != nil { return nil @@ -5080,7 +5081,7 @@ func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field g if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan []*CheckIssue896) if !ok { return nil @@ -5112,6 +5113,67 @@ func (ec *executionContext) fieldContext_Subscription_issue896b(ctx context.Cont return fc, nil } +func (ec *executionContext) _Subscription_errorRequired(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + fc, err := ec.fieldContext_Subscription_errorRequired(ctx, field) + if err != nil { + return nil + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().ErrorRequired(rctx) + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func(ctx context.Context) graphql.Marshaler { + res, ok := <-resTmp.(<-chan *Error) + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNError2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚋfollowschemaᚐError(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + } +} + +func (ec *executionContext) fieldContext_Subscription_errorRequired(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_Error_id(ctx, field) + case "errorOnNonRequiredField": + return ec.fieldContext_Error_errorOnNonRequiredField(ctx, field) + case "errorOnRequiredField": + return ec.fieldContext_Error_errorOnRequiredField(ctx, field) + case "nilOnRequiredField": + return ec.fieldContext_Error_nilOnRequiredField(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Error", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _User_id(ctx context.Context, field graphql.CollectedField, obj *User) (ret graphql.Marshaler) { fc, err := ec.fieldContext_User_id(ctx, field) if err != nil { @@ -7223,7 +7285,7 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr var subscriptionImplementors = []string{"Subscription"} -func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler { +func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors) ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ Object: "Subscription", @@ -7248,6 +7310,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection return ec._Subscription_directiveUnimplemented(ctx, fields[0]) case "issue896b": return ec._Subscription_issue896b(ctx, fields[0]) + case "errorRequired": + return ec._Subscription_errorRequired(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/codegen/testserver/followschema/stub.go b/codegen/testserver/followschema/stub.go index 592cf3a6ef9..a1bb8859ff4 100644 --- a/codegen/testserver/followschema/stub.go +++ b/codegen/testserver/followschema/stub.go @@ -124,6 +124,7 @@ type Stub struct { DirectiveDouble func(ctx context.Context) (<-chan *string, error) DirectiveUnimplemented func(ctx context.Context) (<-chan *string, error) Issue896b func(ctx context.Context) (<-chan []*CheckIssue896, error) + ErrorRequired func(ctx context.Context) (<-chan *Error, error) } UserResolver struct { Friends func(ctx context.Context, obj *User) ([]*User, error) @@ -488,6 +489,9 @@ func (r *stubSubscription) DirectiveUnimplemented(ctx context.Context) (<-chan * func (r *stubSubscription) Issue896b(ctx context.Context) (<-chan []*CheckIssue896, error) { return r.SubscriptionResolver.Issue896b(ctx) } +func (r *stubSubscription) ErrorRequired(ctx context.Context) (<-chan *Error, error) { + return r.SubscriptionResolver.ErrorRequired(ctx) +} type stubUser struct{ *Stub } diff --git a/codegen/testserver/followschema/subscription_test.go b/codegen/testserver/followschema/subscription_test.go index 221cbe27288..6126db99ce8 100644 --- a/codegen/testserver/followschema/subscription_test.go +++ b/codegen/testserver/followschema/subscription_test.go @@ -51,6 +51,24 @@ func TestSubscriptions(t *testing.T) { return channel, nil } + errorTick := make(chan *Error, 1) + resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) { + res := make(chan *Error, 1) + + go func() { + for { + select { + case t := <-errorTick: + res <- t + case <-ctx.Done(): + close(res) + return + } + } + }() + return res, nil + } + resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) { res := make(chan string, 1) @@ -138,4 +156,39 @@ func TestSubscriptions(t *testing.T) { require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) sub.Close() }) + + t.Run("websocket gets errors", func(t *testing.T) { + runtime.GC() // ensure no go-routines left from preceding tests + initialGoroutineCount := runtime.NumGoroutine() + + sub := c.Websocket(`subscription { errorRequired { id } }`) + + errorTick <- &Error{ID: "ID1234"} + + var msg struct { + resp struct { + ErrorRequired *struct { + Id string + } + } + } + + err := sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id) + + errorTick <- nil + err = sub.Next(&msg.resp) + require.Error(t, err) + + sub.Close() + + // need a little bit of time for goroutines to settle + start := time.Now() + for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { + time.Sleep(5 * time.Millisecond) + } + + require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) + }) } diff --git a/codegen/testserver/singlefile/generated.go b/codegen/testserver/singlefile/generated.go index 16d6f6be363..0c1a487052c 100644 --- a/codegen/testserver/singlefile/generated.go +++ b/codegen/testserver/singlefile/generated.go @@ -378,6 +378,7 @@ type ComplexityRoot struct { DirectiveDouble func(childComplexity int) int DirectiveNullableArg func(childComplexity int, arg *int, arg2 *int, arg3 *string) int DirectiveUnimplemented func(childComplexity int) int + ErrorRequired func(childComplexity int) int InitPayload func(childComplexity int) int Issue896b func(childComplexity int) int Updated func(childComplexity int) int @@ -555,6 +556,7 @@ type SubscriptionResolver interface { DirectiveDouble(ctx context.Context) (<-chan *string, error) DirectiveUnimplemented(ctx context.Context) (<-chan *string, error) Issue896b(ctx context.Context) (<-chan []*CheckIssue896, error) + ErrorRequired(ctx context.Context) (<-chan *Error, error) } type UserResolver interface { Friends(ctx context.Context, obj *User) ([]*User, error) @@ -1852,6 +1854,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.DirectiveUnimplemented(childComplexity), true + case "Subscription.errorRequired": + if e.complexity.Subscription.ErrorRequired == nil { + break + } + + return e.complexity.Subscription.ErrorRequired(childComplexity), true + case "Subscription.initPayload": if e.complexity.Subscription.InitPayload == nil { break @@ -2087,7 +2096,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer return func(ctx context.Context) *graphql.Response { buf.Reset() - data := next() + data := next(ctx) if data == nil { return nil @@ -2391,6 +2400,10 @@ input NestedInput { valid: String! } +extend type Subscription { + errorRequired: Error! +} + type Errors { a: Error! b: Error! @@ -11255,7 +11268,7 @@ func (ec *executionContext) fieldContext_Slices_test4(ctx context.Context, field return fc, nil } -func (ec *executionContext) _Subscription_updated(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_updated(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_updated(ctx, field) if err != nil { return nil @@ -11278,7 +11291,7 @@ func (ec *executionContext) _Subscription_updated(ctx context.Context, field gra } return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan string) if !ok { return nil @@ -11306,7 +11319,7 @@ func (ec *executionContext) fieldContext_Subscription_updated(ctx context.Contex return fc, nil } -func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_initPayload(ctx, field) if err != nil { return nil @@ -11329,7 +11342,7 @@ func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field } return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan string) if !ok { return nil @@ -11357,7 +11370,7 @@ func (ec *executionContext) fieldContext_Subscription_initPayload(ctx context.Co return fc, nil } -func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveArg(ctx, field) if err != nil { return nil @@ -11377,7 +11390,7 @@ func (ec *executionContext) _Subscription_directiveArg(ctx context.Context, fiel if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -11416,7 +11429,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveArg(ctx context.C return fc, nil } -func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveNullableArg(ctx, field) if err != nil { return nil @@ -11436,7 +11449,7 @@ func (ec *executionContext) _Subscription_directiveNullableArg(ctx context.Conte if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -11475,7 +11488,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveNullableArg(ctx c return fc, nil } -func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveDouble(ctx, field) if err != nil { return nil @@ -11521,7 +11534,7 @@ func (ec *executionContext) _Subscription_directiveDouble(ctx context.Context, f if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -11549,7 +11562,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveDouble(ctx contex return fc, nil } -func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_directiveUnimplemented(ctx, field) if err != nil { return nil @@ -11589,7 +11602,7 @@ func (ec *executionContext) _Subscription_directiveUnimplemented(ctx context.Con if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan *string) if !ok { return nil @@ -11617,7 +11630,7 @@ func (ec *executionContext) fieldContext_Subscription_directiveUnimplemented(ctx return fc, nil } -func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { +func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { fc, err := ec.fieldContext_Subscription_issue896b(ctx, field) if err != nil { return nil @@ -11637,7 +11650,7 @@ func (ec *executionContext) _Subscription_issue896b(ctx context.Context, field g if resTmp == nil { return nil } - return func() graphql.Marshaler { + return func(ctx context.Context) graphql.Marshaler { res, ok := <-resTmp.(<-chan []*CheckIssue896) if !ok { return nil @@ -11669,6 +11682,67 @@ func (ec *executionContext) fieldContext_Subscription_issue896b(ctx context.Cont return fc, nil } +func (ec *executionContext) _Subscription_errorRequired(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + fc, err := ec.fieldContext_Subscription_errorRequired(ctx, field) + if err != nil { + return nil + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().ErrorRequired(rctx) + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func(ctx context.Context) graphql.Marshaler { + res, ok := <-resTmp.(<-chan *Error) + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNError2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚋsinglefileᚐError(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + } +} + +func (ec *executionContext) fieldContext_Subscription_errorRequired(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_Error_id(ctx, field) + case "errorOnNonRequiredField": + return ec.fieldContext_Error_errorOnNonRequiredField(ctx, field) + case "errorOnRequiredField": + return ec.fieldContext_Error_errorOnRequiredField(ctx, field) + case "nilOnRequiredField": + return ec.fieldContext_Error_nilOnRequiredField(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Error", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _User_id(ctx context.Context, field graphql.CollectedField, obj *User) (ret graphql.Marshaler) { fc, err := ec.fieldContext_User_id(ctx, field) if err != nil { @@ -18508,7 +18582,7 @@ func (ec *executionContext) _Slices(ctx context.Context, sel ast.SelectionSet, o var subscriptionImplementors = []string{"Subscription"} -func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler { +func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors) ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ Object: "Subscription", @@ -18533,6 +18607,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection return ec._Subscription_directiveUnimplemented(ctx, fields[0]) case "issue896b": return ec._Subscription_issue896b(ctx, fields[0]) + case "errorRequired": + return ec._Subscription_errorRequired(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/codegen/testserver/singlefile/nulls.graphql b/codegen/testserver/singlefile/nulls.graphql index a1fea680ce2..b213b17b0fb 100644 --- a/codegen/testserver/singlefile/nulls.graphql +++ b/codegen/testserver/singlefile/nulls.graphql @@ -6,6 +6,10 @@ extend type Query { valid: String! } +extend type Subscription { + errorRequired: Error! +} + type Errors { a: Error! b: Error! diff --git a/codegen/testserver/singlefile/resolver.go b/codegen/testserver/singlefile/resolver.go index 0fd649789b7..87ca51c010b 100644 --- a/codegen/testserver/singlefile/resolver.go +++ b/codegen/testserver/singlefile/resolver.go @@ -372,6 +372,10 @@ func (r *subscriptionResolver) Issue896b(ctx context.Context) (<-chan []*CheckIs panic("not implemented") } +func (r *subscriptionResolver) ErrorRequired(ctx context.Context) (<-chan *Error, error) { + panic("not implemented") +} + func (r *userResolver) Friends(ctx context.Context, obj *User) ([]*User, error) { panic("not implemented") } diff --git a/codegen/testserver/singlefile/stub.go b/codegen/testserver/singlefile/stub.go index 1ee3c2bc283..40305096641 100644 --- a/codegen/testserver/singlefile/stub.go +++ b/codegen/testserver/singlefile/stub.go @@ -124,6 +124,7 @@ type Stub struct { DirectiveDouble func(ctx context.Context) (<-chan *string, error) DirectiveUnimplemented func(ctx context.Context) (<-chan *string, error) Issue896b func(ctx context.Context) (<-chan []*CheckIssue896, error) + ErrorRequired func(ctx context.Context) (<-chan *Error, error) } UserResolver struct { Friends func(ctx context.Context, obj *User) ([]*User, error) @@ -488,6 +489,9 @@ func (r *stubSubscription) DirectiveUnimplemented(ctx context.Context) (<-chan * func (r *stubSubscription) Issue896b(ctx context.Context) (<-chan []*CheckIssue896, error) { return r.SubscriptionResolver.Issue896b(ctx) } +func (r *stubSubscription) ErrorRequired(ctx context.Context) (<-chan *Error, error) { + return r.SubscriptionResolver.ErrorRequired(ctx) +} type stubUser struct{ *Stub } diff --git a/codegen/testserver/singlefile/subscription_test.go b/codegen/testserver/singlefile/subscription_test.go index dd53ba15afa..24186e5ebee 100644 --- a/codegen/testserver/singlefile/subscription_test.go +++ b/codegen/testserver/singlefile/subscription_test.go @@ -51,6 +51,24 @@ func TestSubscriptions(t *testing.T) { return channel, nil } + errorTick := make(chan *Error, 1) + resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) { + res := make(chan *Error, 1) + + go func() { + for { + select { + case t := <-errorTick: + res <- t + case <-ctx.Done(): + close(res) + return + } + } + }() + return res, nil + } + resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) { res := make(chan string, 1) @@ -138,4 +156,39 @@ func TestSubscriptions(t *testing.T) { require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) sub.Close() }) + + t.Run("websocket gets errors", func(t *testing.T) { + runtime.GC() // ensure no go-routines left from preceding tests + initialGoroutineCount := runtime.NumGoroutine() + + sub := c.Websocket(`subscription { errorRequired { id } }`) + + errorTick <- &Error{ID: "ID1234"} + + var msg struct { + resp struct { + ErrorRequired *struct { + Id string + } + } + } + + err := sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id) + + errorTick <- nil + err = sub.Next(&msg.resp) + require.Error(t, err) + + sub.Close() + + // need a little bit of time for goroutines to settle + start := time.Now() + for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { + time.Sleep(5 * time.Millisecond) + } + + require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) + }) }