diff --git a/go.mod b/go.mod index 423043e84..94a3a773e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.13 require ( github.com/opentracing/opentracing-go v1.2.0 + github.com/stretchr/testify v1.8.1 // indirect go.opentelemetry.io/otel v1.6.3 go.opentelemetry.io/otel/trace v1.6.3 ) diff --git a/go.sum b/go.sum index b987a5d21..75d0bd4c8 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -12,9 +13,13 @@ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYr github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.opentelemetry.io/otel v1.6.3 h1:FLOfo8f9JzFVFVyU+MSRJc2HdEAXQgm7pIv2uFKRSZE= go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= go.opentelemetry.io/otel/trace v1.6.3 h1:IqN4L+5b0mPNjdXIiZ90Ni4Bl5BRkDQywePLWemd9bc= @@ -23,5 +28,6 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/graphql.go b/graphql.go index 891c03799..5768c177a 100644 --- a/graphql.go +++ b/graphql.go @@ -83,6 +83,7 @@ type Schema struct { useStringDescriptions bool disableIntrospection bool subscribeResolverTimeout time.Duration + visitors map[string]types.DirectiveVisitor } func (s *Schema) ASTSchema() *types.Schema { @@ -169,6 +170,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { } } +// DirectiveVisitors allows to pass custom directive visitors that will be able to handle +// your GraphQL schema directives. +func DirectiveVisitors(visitors map[string]types.DirectiveVisitor) SchemaOpt { + return func(s *Schema) { + s.visitors = visitors + } +} + // Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or // it may be further processed to a custom response type, for example to include custom error data. // Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107 @@ -258,6 +267,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str Tracer: s.tracer, Logger: s.logger, PanicHandler: s.panicHandler, + Visitors: s.visitors, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { diff --git a/graphql_test.go b/graphql_test.go index c12334c85..6133aff31 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -14,6 +14,7 @@ import ( "github.com/graph-gophers/graphql-go/gqltesting" "github.com/graph-gophers/graphql-go/introspection" "github.com/graph-gophers/graphql-go/trace/tracer" + "github.com/graph-gophers/graphql-go/types" ) type helloWorldResolver1 struct{} @@ -48,6 +49,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam return "Hello " + args.FullName + "!", nil } +type customDirectiveVisitor struct { + beforeWasCalled bool +} + +func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error { + v.beforeWasCalled = true + return nil +} + +func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + if v.beforeWasCalled == false { + return nil, errors.New("Before directive visitor method wasn't called.") + } + + if value, ok := directive.Arguments.Get("customAttribute"); ok { + return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil + } else { + return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil + } +} + type theNumberResolver struct { number int32 } @@ -191,7 +213,6 @@ func TestHelloWorld(t *testing.T) { } `, }, - { Schema: graphql.MustParseSchema(` schema { @@ -216,6 +237,67 @@ func TestHelloWorld(t *testing.T) { }) } +func TestCustomDirective(t *testing.T) { + t.Parallel() + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + directive @customDirective on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello_html: String! @customDirective + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + hello_html + } + `, + ExpectedResult: ` + { + "hello_html": "Directive 'customDirective' modified result: Hello snake!" + } + `, + }, + { + Schema: graphql.MustParseSchema(` + directive @customDirective( + customAttribute: String! + ) on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + say_hello(full_name: String!): String! @customDirective(customAttribute: hi) + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + say_hello(full_name: "Johnny") + } + `, + ExpectedResult: ` + { + "say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!" + } + `, + }, + }) +} + func TestHelloSnake(t *testing.T) { t.Parallel() @@ -4550,3 +4632,69 @@ func TestQueryService(t *testing.T) { }, }) } + +type StructFieldResolver struct { + Hello string +} + +func TestStructFieldResolver(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + + type Query { + hello: String! + } + `, &StructFieldResolver{Hello: "Hello world!"}, graphql.UseFieldResolvers()), + Query: ` + { + hello + } + `, + ExpectedResult: ` + { + "hello": "Hello world!" + } + `, + }, + }) +} + +func TestDirectiveStructFieldResolver(t *testing.T) { + schemaOpt := []graphql.SchemaOpt{ + graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{ + "customDirective": &customDirectiveVisitor{}, + }), + graphql.UseFieldResolvers(), + } + + gqltesting.RunTests(t, []*gqltesting.Test{ + + { + Schema: graphql.MustParseSchema(` + directive @customDirective on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello: String! @customDirective + } + `, &StructFieldResolver{Hello: "Hello world!"}, schemaOpt...), + Query: ` + { + hello + } + `, + ExpectedResult: ` + { + "hello": "Directive 'customDirective' modified result: Hello world!" + } + `, + }}) + +} diff --git a/internal/exec/exec.go b/internal/exec/exec.go index e9056c53e..fd259acfa 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -25,6 +25,7 @@ type Request struct { Logger log.Logger PanicHandler errors.PanicHandler SubscribeResolverTimeout time.Duration + Visitors map[string]types.DirectiveVisitor } func (r *Request) handlePanic(ctx context.Context) { @@ -208,8 +209,48 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if f.field.ArgsPacker != nil { in = append(in, f.field.PackedArgs) } + + // Before hook directive visitor + if len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + values := make([]interface{}, 0, len(in)) + for _, inValue := range in { + values = append(values, inValue.Interface()) + } + + visitorErr := visitor.Before(ctx, directive, values) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + } + } + } + + // Call method callOut := res.Method(f.field.MethodIndex).Call(in) result = callOut[0] + + // After hook directive visitor (when no error is returned from resolver) + if !f.field.HasError && len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + returned, visitorErr := visitor.After(ctx, directive, result.Interface()) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } else { + result = reflect.ValueOf(returned) + } + } + } + } + if f.field.HasError && !callOut[1].IsNil() { resolverErr := callOut[1].Interface().(error) err := errors.Errorf("%s", resolverErr) @@ -225,8 +266,40 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if res.Kind() == reflect.Ptr { res = res.Elem() } + // Before hook directive visitor struct field + if len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + // TODO check that directive arity == 0-that should be an error at schema init time + visitorErr := visitor.Before(ctx, directive, nil) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + } + } + } result = res.FieldByIndex(f.field.FieldIndex) + // After hook directive visitor (when no error is returned from resolver) + if !f.field.HasError && len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + returned, visitorErr := visitor.After(ctx, directive, result.Interface()) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } else { + result = reflect.ValueOf(returned) + } + } + } + } } + return nil }() diff --git a/types/directive.go b/types/directive.go index 7b62d51e3..562338663 100644 --- a/types/directive.go +++ b/types/directive.go @@ -1,6 +1,10 @@ package types -import "github.com/graph-gophers/graphql-go/errors" +import ( + "context" + + "github.com/graph-gophers/graphql-go/errors" +) // Directive is a representation of the GraphQL Directive. // @@ -24,6 +28,11 @@ type DirectiveDefinition struct { type DirectiveList []*Directive +type DirectiveVisitor interface { + Before(ctx context.Context, directive *Directive, input interface{}) error + After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error) +} + // Returns the Directive in the DirectiveList by name or nil if not found. func (l DirectiveList) Get(name string) *Directive { for _, d := range l {