diff --git a/errors/errors.go b/errors/errors.go index 8ffe818e..08e2010a 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -40,3 +40,22 @@ func (err *QueryError) Error() string { } var _ error = &QueryError{} + +// SubscriptionError can be implemented by top-level resolver object to communicate to +// the library a terminal subscription error happened while the stream is still active. +// +// After a subscription has started, this is the mechanism to inform subscriber about stream +// failure in a graceful manner. +// +// **Note** This works only on the top-level object of the resolver, when implemented +// by fields selector, this has no effect. +type SubscriptionError interface { + // SubscriptionError is called to determined if a terminal error occurred. If the returned + // value is nil, subscription continues normally. If the error is non-nil, the subscription is + // assumed to have reached a terminal error, the subscription's channel is closed and the error + // is returned to the user. + // + // If the non-nil error returned is a *QueryError type, it is returned as-is to the user, otherwise, + // the non-nill error is wrapped using `Errorf("%s", err)` above. + SubscriptionError() error +} diff --git a/gqltesting/subscriptions.go b/gqltesting/subscriptions.go index 7a1cd0d1..188891c9 100644 --- a/gqltesting/subscriptions.go +++ b/gqltesting/subscriptions.go @@ -61,6 +61,10 @@ func RunSubscribe(t *testing.T, test *TestSubscription) { } for i, expected := range test.ExpectedResults { + if i+1 > len(results) { + t.Fatalf("missing result: wanted %d results, got only %d, next expected result is %+v", len(test.ExpectedResults), len(results), expected) + } + res := results[i] checkErrorStrings(t, expected.Errors, res.Errors) @@ -89,6 +93,10 @@ func RunSubscribe(t *testing.T, test *TestSubscription) { t.Fail() } } + + if len(results) > len(test.ExpectedResults) { + t.Fatalf("unexpected result: wanted %d results, got %d, first extra result was %+v", len(test.ExpectedResults), len(results), results[len(test.ExpectedResults)]) + } } func checkErrorStrings(t *testing.T, expected, actual []*errors.QueryError) { diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index 16cae69c..ff4a5338 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -103,6 +103,18 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query return } + if subErr, ok := resp.Interface().(errors.SubscriptionError); ok { + if err := subErr.SubscriptionError(); err != nil { + if gqlError, ok := err.(*errors.QueryError); ok { + c <- &Response{Errors: []*errors.QueryError{gqlError}} + } else { + c <- &Response{Errors: []*errors.QueryError{errors.Errorf("%s", err)}} + } + close(c) + return + } + } + subR := &Request{ Request: selected.Request{ Doc: r.Request.Doc, diff --git a/subscription_test.go b/subscription_test.go index 96649d94..aba75eb0 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -473,3 +473,53 @@ const schema = ` hello: String! } ` + +type subscriptionsErrorPropagation struct{} + +type subscriptionsErrorPropagationResolver struct { + msg string + err error +} + +func (r subscriptionsErrorPropagationResolver) Msg() string { return r.msg } +func (r subscriptionsErrorPropagationResolver) SubscriptionError() error { return r.err } + +func (r *subscriptionsErrorPropagation) OnMessage() <-chan *subscriptionsErrorPropagationResolver { + c := make(chan *subscriptionsErrorPropagationResolver) + go func() { + c <- &subscriptionsErrorPropagationResolver{msg: "first"} + c <- &subscriptionsErrorPropagationResolver{err: errors.New("error")} + close(c) + }() + + return c +} + +func TestSchemaSubscribe_ErrorPropagation(t *testing.T) { + r := &struct { + *subscriptionsErrorPropagation + }{ + subscriptionsErrorPropagation: &subscriptionsErrorPropagation{}, + } + gqltesting.RunSubscribe(t, &gqltesting.TestSubscription{ + Schema: graphql.MustParseSchema(` + type Query {} + type Subscription { + onMessage : Message! + } + + type Message { + msg: String! + } + `, r), + Query: ` + subscription { + onMessage { msg } + } + `, + ExpectedResults: []gqltesting.TestResponse{ + {Data: json.RawMessage(`{"onMessage":{"msg":"first"}}`)}, + {Errors: []*qerrors.QueryError{{Message: "error"}}}, + }, + }) +}