Skip to content

Commit

Permalink
Added support of custom directives
Browse files Browse the repository at this point in the history
based on graph-gophers#446 and work by @eko
  • Loading branch information
eko authored and Sean Sorrell committed Dec 22, 2022
1 parent 4423f25 commit 856e336
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.13

require (
github.com/opentracing/opentracing-go v1.2.0
github.com/stretchr/testify v1.8.1 // indirect
go.opentelemetry.io/otel v1.6.3
go.opentelemetry.io/otel/trace v1.6.3
)
12 changes: 9 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
Expand All @@ -12,9 +13,13 @@ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYr
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.opentelemetry.io/otel v1.6.3 h1:FLOfo8f9JzFVFVyU+MSRJc2HdEAXQgm7pIv2uFKRSZE=
go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI=
go.opentelemetry.io/otel/trace v1.6.3 h1:IqN4L+5b0mPNjdXIiZ90Ni4Bl5BRkDQywePLWemd9bc=
Expand All @@ -23,5 +28,6 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
10 changes: 10 additions & 0 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type Schema struct {
useStringDescriptions bool
disableIntrospection bool
subscribeResolverTimeout time.Duration
visitors map[string]types.DirectiveVisitor
}

func (s *Schema) ASTSchema() *types.Schema {
Expand Down Expand Up @@ -169,6 +170,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt {
}
}

// DirectiveVisitors allows to pass custom directive visitors that will be able to handle
// your GraphQL schema directives.
func DirectiveVisitors(visitors map[string]types.DirectiveVisitor) SchemaOpt {
return func(s *Schema) {
s.visitors = visitors
}
}

// Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or
// it may be further processed to a custom response type, for example to include custom error data.
// Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107
Expand Down Expand Up @@ -258,6 +267,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
Tracer: s.tracer,
Logger: s.logger,
PanicHandler: s.panicHandler,
Visitors: s.visitors,
}
varTypes := make(map[string]*introspection.Type)
for _, v := range op.Vars {
Expand Down
150 changes: 149 additions & 1 deletion graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/graph-gophers/graphql-go/gqltesting"
"github.com/graph-gophers/graphql-go/introspection"
"github.com/graph-gophers/graphql-go/trace/tracer"
"github.com/graph-gophers/graphql-go/types"
)

type helloWorldResolver1 struct{}
Expand Down Expand Up @@ -48,6 +49,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam
return "Hello " + args.FullName + "!", nil
}

type customDirectiveVisitor struct {
beforeWasCalled bool
}

func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
v.beforeWasCalled = true
return nil
}

func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
if v.beforeWasCalled == false {
return nil, errors.New("Before directive visitor method wasn't called.")
}

if value, ok := directive.Arguments.Get("customAttribute"); ok {
return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil
} else {
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
}
}

type theNumberResolver struct {
number int32
}
Expand Down Expand Up @@ -191,7 +213,6 @@ func TestHelloWorld(t *testing.T) {
}
`,
},

{
Schema: graphql.MustParseSchema(`
schema {
Expand All @@ -216,6 +237,67 @@ func TestHelloWorld(t *testing.T) {
})
}

func TestCustomDirective(t *testing.T) {
t.Parallel()

gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
directive @customDirective on FIELD_DEFINITION
schema {
query: Query
}
type Query {
hello_html: String! @customDirective
}
`, &helloSnakeResolver1{},
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
})),
Query: `
{
hello_html
}
`,
ExpectedResult: `
{
"hello_html": "Directive 'customDirective' modified result: Hello snake!"
}
`,
},
{
Schema: graphql.MustParseSchema(`
directive @customDirective(
customAttribute: String!
) on FIELD_DEFINITION
schema {
query: Query
}
type Query {
say_hello(full_name: String!): String! @customDirective(customAttribute: hi)
}
`, &helloSnakeResolver1{},
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
})),
Query: `
{
say_hello(full_name: "Johnny")
}
`,
ExpectedResult: `
{
"say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!"
}
`,
},
})
}

func TestHelloSnake(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -4550,3 +4632,69 @@ func TestQueryService(t *testing.T) {
},
})
}

type StructFieldResolver struct {
Hello string
}

func TestStructFieldResolver(t *testing.T) {
gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
schema {
query: Query
}
type Query {
hello: String!
}
`, &StructFieldResolver{Hello: "Hello world!"}, graphql.UseFieldResolvers()),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Hello world!"
}
`,
},
})
}

func TestDirectiveStructFieldResolver(t *testing.T) {
schemaOpt := []graphql.SchemaOpt{
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
}),
graphql.UseFieldResolvers(),
}

gqltesting.RunTests(t, []*gqltesting.Test{

{
Schema: graphql.MustParseSchema(`
directive @customDirective on FIELD_DEFINITION
schema {
query: Query
}
type Query {
hello: String! @customDirective
}
`, &StructFieldResolver{Hello: "Hello world!"}, schemaOpt...),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Directive 'customDirective' modified result: Hello world!"
}
`,
}})

}
73 changes: 73 additions & 0 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Request struct {
Logger log.Logger
PanicHandler errors.PanicHandler
SubscribeResolverTimeout time.Duration
Visitors map[string]types.DirectiveVisitor
}

func (r *Request) handlePanic(ctx context.Context) {
Expand Down Expand Up @@ -208,8 +209,48 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
if f.field.ArgsPacker != nil {
in = append(in, f.field.PackedArgs)
}

// Before hook directive visitor
if len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
values := make([]interface{}, 0, len(in))
for _, inValue := range in {
values = append(values, inValue.Interface())
}

visitorErr := visitor.Before(ctx, directive, values)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
}
}
}
}

// Call method
callOut := res.Method(f.field.MethodIndex).Call(in)
result = callOut[0]

// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(returned)
}
}
}
}

if f.field.HasError && !callOut[1].IsNil() {
resolverErr := callOut[1].Interface().(error)
err := errors.Errorf("%s", resolverErr)
Expand All @@ -225,8 +266,40 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
if res.Kind() == reflect.Ptr {
res = res.Elem()
}
// Before hook directive visitor struct field
if len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
// TODO check that directive arity == 0-that should be an error at schema init time
visitorErr := visitor.Before(ctx, directive, nil)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
}
}
}
}
result = res.FieldByIndex(f.field.FieldIndex)
// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(returned)
}
}
}
}
}

return nil
}()

Expand Down
11 changes: 10 additions & 1 deletion types/directive.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package types

import "github.com/graph-gophers/graphql-go/errors"
import (
"context"

"github.com/graph-gophers/graphql-go/errors"
)

// Directive is a representation of the GraphQL Directive.
//
Expand All @@ -24,6 +28,11 @@ type DirectiveDefinition struct {

type DirectiveList []*Directive

type DirectiveVisitor interface {
Before(ctx context.Context, directive *Directive, input interface{}) error
After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error)
}

// Returns the Directive in the DirectiveList by name or nil if not found.
func (l DirectiveList) Get(name string) *Directive {
for _, d := range l {
Expand Down

0 comments on commit 856e336

Please sign in to comment.