diff --git a/graphql.go b/graphql.go index 663c36f80db..5a4c801d928 100644 --- a/graphql.go +++ b/graphql.go @@ -87,39 +87,65 @@ func execSelectionSet(s *Schema, r *request, t *schema.Object, selSet *query.Sel for _, sel := range selSet.Selections { switch sel := sel.(type) { case *query.Field: - sf := t.Fields[sel.Name] - m := resolver.Method(findMethod(resolver.Type(), sel.Name)) - var in []reflect.Value - if len(sf.Parameters) != 0 { - args := reflect.New(m.Type().In(0)) - for name, param := range sf.Parameters { - value, ok := sel.Arguments[name] - if !ok { - value = &query.Literal{Value: param.Default} - } - rf := args.Elem().FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, name) }) - switch v := value.(type) { - case *query.Variable: - rf.Set(reflect.ValueOf(r.Variables[v.Name])) - case *query.Literal: - rf.Set(reflect.ValueOf(v.Value)) - default: - panic("invalid value") - } - } - in = []reflect.Value{args.Elem()} + if skipByDirective(r, sel.Directives) { + continue } - result[sel.Alias] = exec(s, r, sf.Type, sel.SelSet, m.Call(in)[0]) - + execField(s, r, t, sel, resolver, result) case *query.FragmentSpread: + if skipByDirective(r, sel.Directives) { + continue + } execSelectionSet(s, r, t, r.Fragments[sel.Name].SelSet, resolver, result) - default: panic("invalid type") } } } +func execField(s *Schema, r *request, t *schema.Object, f *query.Field, resolver reflect.Value, result map[string]interface{}) { + sf := t.Fields[f.Name] + m := resolver.Method(findMethod(resolver.Type(), f.Name)) + var in []reflect.Value + if len(sf.Parameters) != 0 { + args := reflect.New(m.Type().In(0)) + for name, param := range sf.Parameters { + value, ok := f.Arguments[name] + if !ok { + value = &query.Literal{Value: param.Default} + } + rf := args.Elem().FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, name) }) + rf.Set(reflect.ValueOf(execValue(r, value))) + } + in = []reflect.Value{args.Elem()} + } + result[f.Alias] = exec(s, r, sf.Type, f.SelSet, m.Call(in)[0]) +} + +func skipByDirective(r *request, d map[string]*query.Directive) bool { + if skip, ok := d["skip"]; ok { + if execValue(r, skip.Arguments["if"]).(bool) { + return true + } + } + if include, ok := d["include"]; ok { + if !execValue(r, include.Arguments["if"]).(bool) { + return true + } + } + return false +} + +func execValue(r *request, v query.Value) interface{} { + switch v := v.(type) { + case *query.Variable: + return r.Variables[v.Name] + case *query.Literal: + return v.Value + default: + panic("invalid value") + } +} + func findMethod(t reflect.Type, name string) int { for i := 0; i < t.NumMethod(); i++ { if strings.EqualFold(name, t.Method(i).Name) { diff --git a/graphql_test.go b/graphql_test.go index 405cc819507..f057842ba65 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -617,6 +617,144 @@ var tests = []struct { } `, }, + + { + name: "StarWarsInclude1", + schema: starWarsSchema, + resolver: &starWarsResolver{}, + query: ` + query Hero($episode: Episode, $withoutFriends: Boolean!) { + hero(episode: $episode) { + name + friends @skip(if: $withoutFriends) { + name + } + } + } + `, + variables: map[string]interface{}{ + "episode": "JEDI", + "withoutFriends": true, + }, + result: ` + { + "hero": { + "name": "R2-D2" + } + } + `, + }, + + { + name: "StarWarsInclude2", + schema: starWarsSchema, + resolver: &starWarsResolver{}, + query: ` + query Hero($episode: Episode, $withoutFriends: Boolean!) { + hero(episode: $episode) { + name + friends @skip(if: $withoutFriends) { + name + } + } + } + `, + variables: map[string]interface{}{ + "episode": "JEDI", + "withoutFriends": false, + }, + result: ` + { + "hero": { + "name": "R2-D2", + "friends": [ + { + "name": "Luke Skywalker" + }, + { + "name": "Han Solo" + }, + { + "name": "Leia Organa" + } + ] + } + } + `, + }, + + { + name: "StarWarsSkip1", + schema: starWarsSchema, + resolver: &starWarsResolver{}, + query: ` + query Hero($episode: Episode, $withFriends: Boolean!) { + hero(episode: $episode) { + name + ...friendsFragment @include(if: $withFriends) + } + } + + fragment friendsFragment on Character { + friends { + name + } + } + `, + variables: map[string]interface{}{ + "episode": "JEDI", + "withFriends": false, + }, + result: ` + { + "hero": { + "name": "R2-D2" + } + } + `, + }, + + { + name: "StarWarsSkip2", + schema: starWarsSchema, + resolver: &starWarsResolver{}, + query: ` + query Hero($episode: Episode, $withFriends: Boolean!) { + hero(episode: $episode) { + name + ...friendsFragment @include(if: $withFriends) + } + } + + fragment friendsFragment on Character { + friends { + name + } + } + `, + variables: map[string]interface{}{ + "episode": "JEDI", + "withFriends": true, + }, + result: ` + { + "hero": { + "name": "R2-D2", + "friends": [ + { + "name": "Luke Skywalker" + }, + { + "name": "Han Solo" + }, + { + "name": "Leia Organa" + } + ] + } + } + `, + }, } func TestAll(t *testing.T) { diff --git a/internal/query/query.go b/internal/query/query.go index 709c0f8956c..27992aab905 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -48,14 +48,21 @@ type Selection interface { } type Field struct { - Alias string + Alias string + Name string + Arguments map[string]Value + Directives map[string]*Directive + SelSet *SelectionSet +} + +type Directive struct { Name string Arguments map[string]Value - SelSet *SelectionSet } type FragmentSpread struct { - Name string + Name string + Directives map[string]*Directive } func (Field) isSelection() {} @@ -184,7 +191,7 @@ func parseSelection(l *lexer.Lexer) Selection { func parseField(l *lexer.Lexer) *Field { f := &Field{ - Arguments: make(map[string]Value), + Directives: make(map[string]*Directive), } f.Alias = l.ConsumeIdent() f.Name = f.Alias @@ -193,17 +200,11 @@ func parseField(l *lexer.Lexer) *Field { f.Name = l.ConsumeIdent() } if l.Peek() == '(' { - l.ConsumeToken('(') - if l.Peek() != ')' { - name, value := parseArgument(l) - f.Arguments[name] = value - for l.Peek() != ')' { - l.ConsumeToken(',') - name, value := parseArgument(l) - f.Arguments[name] = value - } - } - l.ConsumeToken(')') + f.Arguments = parseArguments(l) + } + for l.Peek() == '@' { + d := parseDirective(l) + f.Directives[d.Name] = d } if l.Peek() == '{' { f.SelSet = parseSelectionSet(l) @@ -211,11 +212,45 @@ func parseField(l *lexer.Lexer) *Field { return f } +func parseArguments(l *lexer.Lexer) map[string]Value { + args := make(map[string]Value) + l.ConsumeToken('(') + if l.Peek() != ')' { + name, value := parseArgument(l) + args[name] = value + for l.Peek() != ')' { + l.ConsumeToken(',') + name, value := parseArgument(l) + args[name] = value + } + } + l.ConsumeToken(')') + return args +} + +func parseDirective(l *lexer.Lexer) *Directive { + d := &Directive{} + l.ConsumeToken('@') + d.Name = l.ConsumeIdent() + if l.Peek() == '(' { + d.Arguments = parseArguments(l) + } + return d +} + func parseFragmentSpread(l *lexer.Lexer) *FragmentSpread { + fs := &FragmentSpread{ + Directives: make(map[string]*Directive), + } l.ConsumeToken('.') l.ConsumeToken('.') l.ConsumeToken('.') - return &FragmentSpread{Name: l.ConsumeIdent()} + fs.Name = l.ConsumeIdent() + for l.Peek() == '@' { + d := parseDirective(l) + fs.Directives[d.Name] = d + } + return fs } func parseArgument(l *lexer.Lexer) (string, Value) {