diff --git a/graphql.go b/graphql.go index 76a6434d..b2bcf85f 100644 --- a/graphql.go +++ b/graphql.go @@ -82,6 +82,7 @@ type Schema struct { useStringDescriptions bool disableIntrospection bool subscribeResolverTimeout time.Duration + visitors map[string]types.DirectiveVisitor } func (s *Schema) ASTSchema() *types.Schema { @@ -168,6 +169,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { } } +// DirectiveVisitors allows to pass custom directive visitors that will be able to handle +// your GraphQL schema directives. +func DirectiveVisitors(visitors map[string]types.DirectiveVisitor) SchemaOpt { + return func(s *Schema) { + s.visitors = visitors + } +} + // Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or // it may be further processed to a custom response type, for example to include custom error data. // Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107 @@ -257,6 +266,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str Tracer: s.tracer, Logger: s.logger, PanicHandler: s.panicHandler, + Visitors: s.visitors, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { diff --git a/graphql_test.go b/graphql_test.go index 497a74f3..e37ff88a 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -14,6 +14,7 @@ import ( "github.com/graph-gophers/graphql-go/gqltesting" "github.com/graph-gophers/graphql-go/introspection" "github.com/graph-gophers/graphql-go/trace" + "github.com/graph-gophers/graphql-go/types" ) type helloWorldResolver1 struct{} @@ -48,6 +49,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam return "Hello " + args.FullName + "!", nil } +type customDirectiveVisitor struct { + beforeWasCalled bool +} + +func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error { + v.beforeWasCalled = true + return nil +} + +func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + if v.beforeWasCalled == false { + return nil, errors.New("Before directive visitor method wasn't called.") + } + + if value, ok := directive.Arguments.Get("customAttribute"); ok { + return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil + } else { + return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil + } +} + type theNumberResolver struct { number int32 } @@ -191,7 +213,6 @@ func TestHelloWorld(t *testing.T) { } `, }, - { Schema: graphql.MustParseSchema(` schema { @@ -216,6 +237,67 @@ func TestHelloWorld(t *testing.T) { }) } +func TestCustomDirective(t *testing.T) { + t.Parallel() + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + directive @customDirective on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello_html: String! @customDirective + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + hello_html + } + `, + ExpectedResult: ` + { + "hello_html": "Directive 'customDirective' modified result: Hello snake!" + } + `, + }, + { + Schema: graphql.MustParseSchema(` + directive @customDirective( + customAttribute: String! + ) on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + say_hello(full_name: String!): String! @customDirective(customAttribute: hi) + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + say_hello(full_name: "Johnny") + } + `, + ExpectedResult: ` + { + "say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!" + } + `, + }, + }) +} + func TestHelloSnake(t *testing.T) { t.Parallel() diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 6b478487..b69cc7b8 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -25,6 +25,7 @@ type Request struct { Logger log.Logger PanicHandler errors.PanicHandler SubscribeResolverTimeout time.Duration + Visitors map[string]types.DirectiveVisitor } func (r *Request) handlePanic(ctx context.Context) { @@ -208,8 +209,47 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if f.field.ArgsPacker != nil { in = append(in, f.field.PackedArgs) } + + // Before hook directive visitor + if len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + var values = make([]interface{}, 0) + for _, inValue := range in { + values = append(values, inValue.Interface()) + } + + if visitorErr := visitor.Before(ctx, directive, values); err != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + } + } + } + + // Call method callOut := res.Method(f.field.MethodIndex).Call(in) result = callOut[0] + + // After hook directive visitor (when no error is returned from resolver) + if !f.field.HasError && len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + returned, visitorErr := visitor.After(ctx, directive, result.Interface()) + if err != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } else { + result = reflect.ValueOf(returned) + } + } + } + } + if f.field.HasError && !callOut[1].IsNil() { resolverErr := callOut[1].Interface().(error) err := errors.Errorf("%s", resolverErr) diff --git a/types/directive.go b/types/directive.go index 0f8a4b99..2d0d9168 100644 --- a/types/directive.go +++ b/types/directive.go @@ -1,6 +1,10 @@ package types -import "github.com/graph-gophers/graphql-go/errors" +import ( + "context" + + "github.com/graph-gophers/graphql-go/errors" +) // Directive is a representation of the GraphQL Directive. // @@ -23,6 +27,11 @@ type DirectiveDefinition struct { type DirectiveList []*Directive +type DirectiveVisitor interface { + Before(ctx context.Context, directive *Directive, input interface{}) error + After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error) +} + // Returns the Directive in the DirectiveList by name or nil if not found. func (l DirectiveList) Get(name string) *Directive { for _, d := range l {