From dde7b9ee02db23d78b219998e1b9be9d92aadec5 Mon Sep 17 00:00:00 2001 From: 1046102779 Date: Fri, 13 Apr 2018 16:48:52 +0800 Subject: [PATCH 1/2] optimise code style --- language/ast/arguments.go | 8 +- language/ast/definitions.go | 20 +- language/ast/name.go | 7 +- language/ast/selections.go | 18 +- language/ast/types.go | 7 +- language/ast/values.go | 7 +- language/parser/parser.go | 115 +++++----- language/printer/printer.go | 3 +- language/visitor/visitor.go | 347 +++++++++++++------------------ language/visitor/visitor_test.go | 1 + rules.go | 195 ++++++++--------- schema.go | 6 +- validator.go | 1 - values.go | 52 ++--- 14 files changed, 337 insertions(+), 450 deletions(-) diff --git a/language/ast/arguments.go b/language/ast/arguments.go index 5f7ef0d2..2ebd0fa7 100644 --- a/language/ast/arguments.go +++ b/language/ast/arguments.go @@ -16,12 +16,8 @@ func NewArgument(arg *Argument) *Argument { if arg == nil { arg = &Argument{} } - return &Argument{ - Kind: kinds.Argument, - Loc: arg.Loc, - Name: arg.Name, - Value: arg.Value, - } + arg.Kind = kinds.Argument + return arg } func (arg *Argument) GetKind() string { diff --git a/language/ast/definitions.go b/language/ast/definitions.go index cd527f0a..e16cf18d 100644 --- a/language/ast/definitions.go +++ b/language/ast/definitions.go @@ -39,15 +39,8 @@ func NewOperationDefinition(op *OperationDefinition) *OperationDefinition { if op == nil { op = &OperationDefinition{} } - return &OperationDefinition{ - Kind: kinds.OperationDefinition, - Loc: op.Loc, - Operation: op.Operation, - Name: op.Name, - VariableDefinitions: op.VariableDefinitions, - Directives: op.Directives, - SelectionSet: op.SelectionSet, - } + op.Kind = kinds.OperationDefinition + return op } func (op *OperationDefinition) GetKind() string { @@ -143,13 +136,8 @@ func NewVariableDefinition(vd *VariableDefinition) *VariableDefinition { if vd == nil { vd = &VariableDefinition{} } - return &VariableDefinition{ - Kind: kinds.VariableDefinition, - Loc: vd.Loc, - Variable: vd.Variable, - Type: vd.Type, - DefaultValue: vd.DefaultValue, - } + vd.Kind = kinds.VariableDefinition + return vd } func (vd *VariableDefinition) GetKind() string { diff --git a/language/ast/name.go b/language/ast/name.go index 00fddbcd..ce0e9ebd 100644 --- a/language/ast/name.go +++ b/language/ast/name.go @@ -15,11 +15,8 @@ func NewName(node *Name) *Name { if node == nil { node = &Name{} } - return &Name{ - Kind: kinds.Name, - Value: node.Value, - Loc: node.Loc, - } + node.Kind = kinds.Name + return node } func (node *Name) GetKind() string { diff --git a/language/ast/selections.go b/language/ast/selections.go index 0dc0ea12..55df71a3 100644 --- a/language/ast/selections.go +++ b/language/ast/selections.go @@ -28,15 +28,8 @@ func NewField(f *Field) *Field { if f == nil { f = &Field{} } - return &Field{ - Kind: kinds.Field, - Loc: f.Loc, - Alias: f.Alias, - Name: f.Name, - Arguments: f.Arguments, - Directives: f.Directives, - SelectionSet: f.SelectionSet, - } + f.Kind = kinds.Field + return f } func (f *Field) GetKind() string { @@ -128,11 +121,8 @@ func NewSelectionSet(ss *SelectionSet) *SelectionSet { if ss == nil { ss = &SelectionSet{} } - return &SelectionSet{ - Kind: kinds.SelectionSet, - Loc: ss.Loc, - Selections: ss.Selections, - } + ss.Kind = kinds.SelectionSet + return ss } func (ss *SelectionSet) GetKind() string { diff --git a/language/ast/types.go b/language/ast/types.go index 27f00997..0308a609 100644 --- a/language/ast/types.go +++ b/language/ast/types.go @@ -26,11 +26,8 @@ func NewNamed(t *Named) *Named { if t == nil { t = &Named{} } - return &Named{ - Kind: kinds.Named, - Loc: t.Loc, - Name: t.Name, - } + t.Kind = kinds.Named + return t } func (t *Named) GetKind() string { diff --git a/language/ast/values.go b/language/ast/values.go index 67912bdc..6c3c8864 100644 --- a/language/ast/values.go +++ b/language/ast/values.go @@ -31,11 +31,8 @@ func NewVariable(v *Variable) *Variable { if v == nil { v = &Variable{} } - return &Variable{ - Kind: kinds.Variable, - Loc: v.Loc, - Name: v.Name, - } + v.Kind = kinds.Variable + return v } func (v *Variable) GetKind() string { diff --git a/language/parser/parser.go b/language/parser/parser.go index c3fdc0c8..764a170e 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -249,20 +249,18 @@ func parseOperationType(parser *Parser) (string, error) { */ func parseVariableDefinitions(parser *Parser) ([]*ast.VariableDefinition, error) { variableDefinitions := []*ast.VariableDefinition{} - if peek(parser, lexer.TokenKind[lexer.PAREN_L]) { - vdefs, err := reverse(parser, - lexer.TokenKind[lexer.PAREN_L], parseVariableDefinition, lexer.TokenKind[lexer.PAREN_R], - true, - ) + if !peek(parser, lexer.TokenKind[lexer.PAREN_L]) { + return nil, nil + } + if vdefs, err := reverse(parser, + lexer.TokenKind[lexer.PAREN_L], parseVariableDefinition, lexer.TokenKind[lexer.PAREN_R], + true, + ); err != nil { + return variableDefinitions, nil + } else { for _, vdef := range vdefs { - if vdef != nil { - variableDefinitions = append(variableDefinitions, vdef.(*ast.VariableDefinition)) - } - } - if err != nil { - return variableDefinitions, err + variableDefinitions = append(variableDefinitions, vdef.(*ast.VariableDefinition)) } - return variableDefinitions, nil } return variableDefinitions, nil } @@ -271,28 +269,28 @@ func parseVariableDefinitions(parser *Parser) ([]*ast.VariableDefinition, error) * VariableDefinition : Variable : Type DefaultValue? */ func parseVariableDefinition(parser *Parser) (interface{}, error) { + var ( + variable *ast.Variable + ttype ast.Type + err error + ) start := parser.Token.Start - variable, err := parseVariable(parser) - if err != nil { + if variable, err = parseVariable(parser); err != nil { return nil, err } - _, err = expect(parser, lexer.TokenKind[lexer.COLON]) - if err != nil { + if _, err = expect(parser, lexer.TokenKind[lexer.COLON]); err != nil { return nil, err } - ttype, err := parseType(parser) - if err != nil { + if ttype, err = parseType(parser); err != nil { return nil, err } var defaultValue ast.Value if skp, err := skip(parser, lexer.TokenKind[lexer.EQUALS]); err != nil { return nil, err } else if skp { - dv, err := parseValueLiteral(parser, true) - if err != nil { + if defaultValue, err = parseValueLiteral(parser, true); err != nil { return nil, err } - defaultValue = dv } return ast.NewVariableDefinition(&ast.VariableDefinition{ Variable: variable, @@ -306,13 +304,15 @@ func parseVariableDefinition(parser *Parser) (interface{}, error) { * Variable : $ Name */ func parseVariable(parser *Parser) (*ast.Variable, error) { + var ( + err error + name *ast.Name + ) start := parser.Token.Start - _, err := expect(parser, lexer.TokenKind[lexer.DOLLAR]) - if err != nil { + if _, err = expect(parser, lexer.TokenKind[lexer.DOLLAR]); err != nil { return nil, err } - name, err := parseName(parser) - if err != nil { + if name, err = parseName(parser); err != nil { return nil, err } return ast.NewVariable(&ast.Variable{ @@ -326,17 +326,16 @@ func parseVariable(parser *Parser) (*ast.Variable, error) { */ func parseSelectionSet(parser *Parser) (*ast.SelectionSet, error) { start := parser.Token.Start - iSelections, err := reverse(parser, + selections := []ast.Selection{} + if iSelections, err := reverse(parser, lexer.TokenKind[lexer.BRACE_L], parseSelection, lexer.TokenKind[lexer.BRACE_R], true, - ) - if err != nil { + ); err != nil { return nil, err - } - selections := []ast.Selection{} - for _, iSelection := range iSelections { - // type assert interface{} into Selection interface - selections = append(selections, iSelection.(ast.Selection)) + } else { + for _, iSelection := range iSelections { + selections = append(selections, iSelection.(ast.Selection)) + } } return ast.NewSelectionSet(&ast.SelectionSet{ @@ -353,8 +352,7 @@ func parseSelectionSet(parser *Parser) (*ast.SelectionSet, error) { */ func parseSelection(parser *Parser) (interface{}, error) { if peek(parser, lexer.TokenKind[lexer.SPREAD]) { - r, err := parseFragment(parser) - return r, err + return parseFragment(parser) } return parseField(parser) } @@ -412,15 +410,15 @@ func parseField(parser *Parser) (*ast.Field, error) { func parseArguments(parser *Parser) ([]*ast.Argument, error) { arguments := []*ast.Argument{} if peek(parser, lexer.TokenKind[lexer.PAREN_L]) { - iArguments, err := reverse(parser, + if iArguments, err := reverse(parser, lexer.TokenKind[lexer.PAREN_L], parseArgument, lexer.TokenKind[lexer.PAREN_R], true, - ) - if err != nil { + ); err != nil { return arguments, err - } - for _, iArgument := range iArguments { - arguments = append(arguments, iArgument.(*ast.Argument)) + } else { + for _, iArgument := range iArguments { + arguments = append(arguments, iArgument.(*ast.Argument)) + } } } return arguments, nil @@ -462,9 +460,11 @@ func parseArgument(parser *Parser) (interface{}, error) { * InlineFragment : ... TypeCondition? Directives? SelectionSet */ func parseFragment(parser *Parser) (interface{}, error) { + var ( + err error + ) start := parser.Token.Start - _, err := expect(parser, lexer.TokenKind[lexer.SPREAD]) - if err != nil { + if _, err = expect(parser, lexer.TokenKind[lexer.SPREAD]); err != nil { return nil, err } if peek(parser, lexer.TokenKind[lexer.NAME]) && parser.Token.Value != "on" { @@ -518,7 +518,7 @@ func parseFragment(parser *Parser) (interface{}, error) { */ func parseFragmentDefinition(parser *Parser) (ast.Node, error) { start := parser.Token.Start - _, err := expectKeyWord(parser, "fragment") + _, err := expectKeyWord(parser, lexer.FRAGMENT) if err != nil { return nil, err } @@ -738,15 +738,12 @@ func parseObjectField(parser *Parser, isConst bool) (*ast.ObjectField, error) { */ func parseDirectives(parser *Parser) ([]*ast.Directive, error) { directives := []*ast.Directive{} - for { - if !peek(parser, lexer.TokenKind[lexer.AT]) { - break - } - directive, err := parseDirective(parser) - if err != nil { + for peek(parser, lexer.TokenKind[lexer.AT]) { + if directive, err := parseDirective(parser); err != nil { return directives, err + } else { + directives = append(directives, directive) } - directives = append(directives, directive) } return directives, nil } @@ -944,7 +941,7 @@ func parseScalarTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "scalar") + _, err = expectKeyWord(parser, lexer.SCALAR) if err != nil { return nil, err } @@ -976,7 +973,7 @@ func parseObjectTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "type") + _, err = expectKeyWord(parser, lexer.TYPE) if err != nil { return nil, err } @@ -1161,7 +1158,7 @@ func parseInterfaceTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "interface") + _, err = expectKeyWord(parser, lexer.INTERFACE) if err != nil { return nil, err } @@ -1204,7 +1201,7 @@ func parseUnionTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "union") + _, err = expectKeyWord(parser, lexer.UNION) if err != nil { return nil, err } @@ -1264,7 +1261,7 @@ func parseEnumTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "enum") + _, err = expectKeyWord(parser, lexer.ENUM) if err != nil { return nil, err } @@ -1335,7 +1332,7 @@ func parseInputObjectTypeDefinition(parser *Parser) (ast.Node, error) { if err != nil { return nil, err } - _, err = expectKeyWord(parser, "input") + _, err = expectKeyWord(parser, lexer.INPUT) if err != nil { return nil, err } @@ -1374,7 +1371,7 @@ func parseInputObjectTypeDefinition(parser *Parser) (ast.Node, error) { */ func parseTypeExtensionDefinition(parser *Parser) (ast.Node, error) { start := parser.Token.Start - _, err := expectKeyWord(parser, "extend") + _, err := expectKeyWord(parser, lexer.EXTEND) if err != nil { return nil, err } @@ -1405,7 +1402,7 @@ func parseDirectiveDefinition(parser *Parser) (ast.Node, error) { if description, err = parseDescription(parser); err != nil { return nil, err } - if _, err = expectKeyWord(parser, "directive"); err != nil { + if _, err = expectKeyWord(parser, lexer.DIRECTIVE); err != nil { return nil, err } if _, err = expect(parser, lexer.TokenKind[lexer.AT]); err != nil { diff --git a/language/printer/printer.go b/language/printer/printer.go index 3add3852..173604a0 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -4,9 +4,10 @@ import ( "fmt" "strings" + "reflect" + "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/visitor" - "reflect" ) func getMapValue(m map[string]interface{}, key string) interface{} { diff --git a/language/visitor/visitor.go b/language/visitor/visitor.go index 9a1c2ac2..fc42b748 100644 --- a/language/visitor/visitor.go +++ b/language/visitor/visitor.go @@ -3,9 +3,10 @@ package visitor import ( "encoding/json" "fmt" + "reflect" + "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/typeInfo" - "reflect" ) const ( @@ -184,27 +185,33 @@ func Visit(root ast.Node, visitorOpts *VisitorOptions, keyMap KeyMap) interface{ visitorKeys = QueryDocumentKeys } - var result interface{} - var newRoot = root - var sstack *stack - var parent interface{} - var parentSlice []interface{} - inSlice := false - prevInSlice := false - keys := []interface{}{newRoot} - index := -1 - edits := []*edit{} - path := []interface{}{} - ancestors := []interface{}{} - ancestorsSlice := [][]interface{}{} + var ( + result interface{} + newRoot ast.Node = root + sstack *stack + parent interface{} + parentSlice []interface{} + inSlice = false + prevInSlice = false + keys = []interface{}{root} + index = -1 + edits = []*edit{} // key-value + path = []interface{}{} + ancestors = []interface{}{} + ancestorsSlice = [][]interface{}{} + ) + // these algorithm must be simple!!! + // abstract algorithm Loop: for { index = index + 1 isLeaving := (len(keys) == index) - var key interface{} // string for structs or int for slices - var node interface{} // ast.Node or can be anything - var nodeSlice []interface{} + var ( + key interface{} // string for structs or int for slices + node interface{} // ast.Node or can be anything + nodeSlice []interface{} + ) isEdited := (isLeaving && len(edits) != 0) if isLeaving { @@ -238,7 +245,7 @@ Loop: arrayEditKey = edit.Key.(int) } if inSlice && isNilNode(edit.Value) { - nodeSlice = spliceNode(nodeSlice, arrayEditKey) + nodeSlice = removeNodeByIndex(nodeSlice, arrayEditKey) editOffset = editOffset + 1 } else { if inSlice { @@ -440,8 +447,6 @@ Loop: continue } } - } else { - resultIsUndefined = true } } @@ -520,218 +525,167 @@ Loop: return result } -func pop(a []interface{}) (x interface{}, aa []interface{}) { +func pop(a []interface{}) (interface{}, []interface{}) { if len(a) == 0 { - return x, aa + return nil, nil } - x, aa = a[len(a)-1], a[:len(a)-1] - return x, aa + return a[len(a)-1], a[:len(a)-1] } -func popNodeSlice(a [][]interface{}) (x []interface{}, aa [][]interface{}) { + +func popNodeSlice(a [][]interface{}) ([]interface{}, [][]interface{}) { if len(a) == 0 { - return x, aa + return nil, nil } - x, aa = a[len(a)-1], a[:len(a)-1] - return x, aa + return a[len(a)-1], a[:len(a)-1] } -func spliceNode(a interface{}, i int) (result []interface{}) { - if i < 0 { - return result - } - typeOf := reflect.TypeOf(a) - if typeOf == nil { - return result - } - switch typeOf.Kind() { - case reflect.Slice: - s := reflect.ValueOf(a) - for i := 0; i < s.Len(); i++ { - elem := s.Index(i) - elemInterface := elem.Interface() - result = append(result, elemInterface) - } - if i >= s.Len() { - return result - } - return append(result[:i], result[i+1:]...) - default: - return result + +func removeNodeByIndex(a []interface{}, pos int) []interface{} { + if pos < 0 || pos >= len(a) { + return a } + return append(a[:pos], a[pos+1:]...) } +// get value by key from struct | slice | map | wrap(prev) +// when obj type is struct, the key's type must be string +// ... slice, ... int +// ... map, ... any type. But the type satisfies map's key definition(feature: compare...) func getFieldValue(obj interface{}, key interface{}) interface{} { + var value reflect.Value val := reflect.ValueOf(obj) - if val.Type().Kind() == reflect.Ptr { + if val.Kind() == reflect.Ptr { val = val.Elem() } - if val.Type().Kind() == reflect.Struct { - key, ok := key.(string) - if !ok { - return nil - } - valField := val.FieldByName(key) - if valField.IsValid() { - return valField.Interface() - } - return nil - } - if val.Type().Kind() == reflect.Slice { - key, ok := key.(int) - if !ok { - return nil - } - if key >= val.Len() { + switch val.Kind() { + case reflect.Struct: + value = val.FieldByName(key.(string)) + case reflect.Map: + value = val.MapIndex(reflect.ValueOf(key)) + case reflect.Slice: + if index, ok := key.(int); !ok { return nil + } else if index >= 0 || val.Len() > index { + value = val.Index(index) } - valField := val.Index(key) - if valField.IsValid() { - return valField.Interface() - } - return nil } - if val.Type().Kind() == reflect.Map { - keyVal := reflect.ValueOf(key) - valField := val.MapIndex(keyVal) - if valField.IsValid() { - return valField.Interface() - } + if !value.IsValid() { return nil } - return nil + return value.Interface() } -func updateNodeField(value interface{}, fieldName string, fieldValue interface{}) (retVal interface{}) { - retVal = value - val := reflect.ValueOf(value) - - isPtr := false - if val.IsValid() && val.Type().Kind() == reflect.Ptr { - val = val.Elem() +// currenty only supports update struct field value +func updateNodeField(src interface{}, targetName string, target interface{}) interface{} { + var isPtr bool + srcVal := reflect.ValueOf(src) + // verify condition + if srcVal.Kind() == reflect.Ptr { isPtr = true + srcVal = srcVal.Elem() + } + targetVal := reflect.ValueOf(target) + if srcVal.Kind() != reflect.Struct { + return src + } + srcFieldValue := srcVal.FieldByName(targetName) + if !srcFieldValue.IsValid() || srcFieldValue.Kind() != targetVal.Kind() { + return src } - if !val.IsValid() { - return retVal - } - if val.Type().Kind() == reflect.Struct { - for i := 0; i < val.NumField(); i++ { - valueField := val.Field(i) - typeField := val.Type().Field(i) - - // try matching the field name - if typeField.Name == fieldName { - fieldValueVal := reflect.ValueOf(fieldValue) - if valueField.CanSet() { - - if fieldValueVal.IsValid() { - if valueField.Type().Kind() == fieldValueVal.Type().Kind() { - if fieldValueVal.Type().Kind() == reflect.Slice { - newSliceValue := reflect.MakeSlice(reflect.TypeOf(valueField.Interface()), fieldValueVal.Len(), fieldValueVal.Len()) - for i := 0; i < newSliceValue.Len(); i++ { - dst := newSliceValue.Index(i) - src := fieldValueVal.Index(i) - srcValue := reflect.ValueOf(src.Interface()) - if dst.CanSet() { - dst.Set(srcValue) - } - } - valueField.Set(newSliceValue) - - } else { - valueField.Set(fieldValueVal) - } - } - } else { - valueField.Set(reflect.New(valueField.Type()).Elem()) - } - if isPtr == true { - retVal = val.Addr().Interface() - return retVal - } - retVal = val.Interface() - return retVal - } + if srcFieldValue.CanSet() { + if srcFieldValue.Kind() == reflect.Slice { + items := reflect.MakeSlice(srcFieldValue.Type(), targetVal.Len(), targetVal.Len()) + for index := 0; index < items.Len(); index++ { + tmp := targetVal.Index(index).Interface() + items.Index(index).Set(reflect.ValueOf(tmp)) } + srcFieldValue.Set(items) + } else { + srcFieldValue.Set(targetVal) } } - return retVal + if isPtr { + return srcVal.Addr().Interface() + } + return srcVal.Interface() } -func toSliceInterfaces(slice interface{}) (result []interface{}) { - switch reflect.TypeOf(slice).Kind() { - case reflect.Slice: - s := reflect.ValueOf(slice) - for i := 0; i < s.Len(); i++ { - elem := s.Index(i) - elemInterface := elem.Interface() - if elem, ok := elemInterface.(ast.Node); ok { - result = append(result, elem) - } - } - return result - default: - return result + +func toSliceInterfaces(src interface{}) []interface{} { + var list []interface{} + value := reflect.ValueOf(src) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + if value.Kind() != reflect.Slice { + return nil + } + for index := 0; index < value.Len(); index++ { + list = append(list, value.Index(index).Interface()) } + return list } func isSlice(value interface{}) bool { - val := reflect.ValueOf(value) - if val.IsValid() && val.Type().Kind() == reflect.Slice { + if value == nil { + return false + } + typ := reflect.TypeOf(value) + if typ.Kind() == reflect.Slice { return true } return false } -func isNode(node interface{}) bool { - val := reflect.ValueOf(node) - if val.IsValid() && val.Type().Kind() == reflect.Ptr { - val = val.Elem() - } - if !val.IsValid() { + +func isStructNode(node interface{}) bool { + if node == nil { return false } - if val.Type().Kind() == reflect.Map { - keyVal := reflect.ValueOf("Kind") - valField := val.MapIndex(keyVal) - return valField.IsValid() + value := reflect.ValueOf(node) + if value.Kind() == reflect.Ptr { + value = value.Elem() } - if val.Type().Kind() == reflect.Struct { - valField := val.FieldByName("Kind") - return valField.IsValid() + if value.Kind() == reflect.Struct { + _, ok := node.(ast.Node) + return ok } return false } -func isStructNode(node interface{}) bool { - val := reflect.ValueOf(node) - if val.IsValid() && val.Type().Kind() == reflect.Ptr { - val = val.Elem() + +// notice: type: Named, List or NonNull maybe map type +// and it can't be asserted to ast.Node +func isNode(node interface{}) bool { + if node == nil { + return false } + val := reflect.ValueOf(node) if !val.IsValid() { return false } - if val.Type().Kind() == reflect.Struct { - valField := val.FieldByName("Kind") - return valField.IsValid() + switch val.Kind() { + case reflect.Map: + return true + case reflect.Ptr: + val = val.Elem() } - return false + _, ok := node.(ast.Node) + return ok } func isNilNode(node interface{}) bool { + if node == nil { + return true + } val := reflect.ValueOf(node) if !val.IsValid() { return true } - if val.Type().Kind() == reflect.Ptr { + switch val.Kind() { + case reflect.Ptr, reflect.Map, reflect.Slice: return val.IsNil() + case reflect.Bool: + return node.(bool) } - if val.Type().Kind() == reflect.Slice { - return val.Len() == 0 - } - if val.Type().Kind() == reflect.Map { - return val.Len() == 0 - } - if val.Type().Kind() == reflect.Bool { - return val.Interface().(bool) - } - return val.Interface() == nil + return false } // VisitInParallel Creates a new visitor instance which delegates to many visitors to run in @@ -830,44 +784,43 @@ func VisitWithTypeInfo(ttypeInfo typeInfo.TypeInfoI, visitorOpts *VisitorOptions // GetVisitFn Given a visitor instance, if it is leaving or not, and a node kind, return // the function the visitor runtime should call. +// priority [high->low] in VisitorOptions: +// KindFuncMap{Kind> {Leave, Enter}} > {Leave, Enter} > {EnterKindMap, LeaveKindMap} func GetVisitFn(visitorOpts *VisitorOptions, kind string, isLeaving bool) VisitFunc { if visitorOpts == nil { return nil } - kindVisitor, ok := visitorOpts.KindFuncMap[kind] - if ok { + if kindVisitor, ok := visitorOpts.KindFuncMap[kind]; ok { if !isLeaving && kindVisitor.Kind != nil { // { Kind() {} } return kindVisitor.Kind - } - if isLeaving { + } else if isLeaving { // { Kind: { leave() {} } } return kindVisitor.Leave + } else { + // { Kind: { enter() {} } } + return kindVisitor.Enter } - // { Kind: { enter() {} } } - return kindVisitor.Enter - } if isLeaving { - // { enter() {} } - specificVisitor := visitorOpts.Leave - if specificVisitor != nil { - return specificVisitor + // { leave() {} } + if genericVisitor := visitorOpts.Leave; genericVisitor != nil { + return genericVisitor } if specificKindVisitor, ok := visitorOpts.LeaveKindMap[kind]; ok { // { leave: { Kind() {} } } return specificKindVisitor } - } - // { leave() {} } - specificVisitor := visitorOpts.Enter - if specificVisitor != nil { - return specificVisitor - } - if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { - // { enter: { Kind() {} } } - return specificKindVisitor + } else { + // { enter() {} } + if genericVisitor := visitorOpts.Enter; genericVisitor != nil { + return genericVisitor + } + if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { + // { enter: { Kind() {} } } + return specificKindVisitor + } } return nil } diff --git a/language/visitor/visitor_test.go b/language/visitor/visitor_test.go index 400dc2bf..33e6fee7 100644 --- a/language/visitor/visitor_test.go +++ b/language/visitor/visitor_test.go @@ -6,6 +6,7 @@ import ( "testing" "fmt" + "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/kinds" diff --git a/rules.go b/rules.go index 95d80aad..41b94494 100644 --- a/rules.go +++ b/rules.go @@ -3,6 +3,7 @@ package graphql import ( "fmt" "math" + "reflect" "sort" "strings" @@ -72,25 +73,21 @@ func ArgumentsOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInsta kinds.Argument: { Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if argAST, ok := p.Node.(*ast.Argument); ok { - value := argAST.Value - argDef := context.Argument() - if argDef != nil { - isValid, messages := isValidLiteralValue(argDef.Type, value) - if !isValid { - argNameValue := "" + if argDef := context.Argument(); argDef != nil { + if isValid, messages := isValidLiteralValue(argDef.Type, argAST.Value); !isValid { + var messagesStr, argNameValue string if argAST.Name != nil { argNameValue = argAST.Name.Value } - messagesStr := "" if len(messages) > 0 { messagesStr = "\n" + strings.Join(messages, "\n") } reportError( context, fmt.Sprintf(`Argument "%v" has invalid value %v.%v`, - argNameValue, printer.Print(value), messagesStr), - []ast.Node{value}, + argNameValue, printer.Print(argAST.Value), messagesStr), + []ast.Node{argAST.Value}, ) } @@ -116,13 +113,17 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI kinds.VariableDefinition: { Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if varDefAST, ok := p.Node.(*ast.VariableDefinition); ok { - name := "" + var ( + name string + defaultValue = varDefAST.DefaultValue + messagesStr string + ) if varDefAST.Variable != nil && varDefAST.Variable.Name != nil { name = varDefAST.Variable.Name.Value } - defaultValue := varDefAST.DefaultValue ttype := context.InputType() + // when input variable value must be nonNull, and set default value is unnecessary if ttype, ok := ttype.(*NonNull); ok && defaultValue != nil { reportError( context, @@ -131,9 +132,7 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI []ast.Node{defaultValue}, ) } - isValid, messages := isValidLiteralValue(ttype, defaultValue) - if ttype != nil && defaultValue != nil && !isValid { - messagesStr := "" + if isValid, messages := isValidLiteralValue(ttype, defaultValue); !isValid && defaultValue != nil { if len(messages) > 0 { messagesStr = "\n" + strings.Join(messages, "\n") } @@ -211,25 +210,21 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance kinds.Field: { Kind: func(p visitor.VisitFuncParams) (string, interface{}) { var action = visitor.ActionNoChange - var result interface{} if node, ok := p.Node.(*ast.Field); ok { - ttype := context.ParentType() - if ttype == nil { - return action, result - } - if t, ok := ttype.(*Object); ok && t == nil { - return action, result + var ttype Composite + if ttype = context.ParentType(); ttype == nil { + return action, nil } - if t, ok := ttype.(*Interface); ok && t == nil { - return action, result - } - if t, ok := ttype.(*Union); ok && t == nil { - return action, result + switch ttype.(type) { + case *Object, *Interface, *Union: + if reflect.ValueOf(ttype).IsNil() { + return action, nil + } } fieldDef := context.FieldDef() if fieldDef == nil { // This field doesn't exist, lets look for suggestions. - nodeName := "" + var nodeName string if node.Name != nil { nodeName = node.Name.Value } @@ -248,7 +243,7 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance ) } } - return action, result + return action, nil }, }, }, @@ -263,16 +258,16 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance // suggest them, sorted by how often the type is referenced, starting // with Interfaces. func getSuggestedTypeNames(schema *Schema, ttype Output, fieldName string) []string { - + var ( + suggestedObjectTypes = []string{} + suggestedInterfaces = []*suggestedInterface{} + // stores a map of interface name => index in suggestedInterfaces + suggestedInterfaceMap = map[string]int{} + // stores a maps of object name => true to remove duplicates from results + suggestedObjectMap = map[string]bool{} + ) possibleTypes := schema.PossibleTypes(ttype) - suggestedObjectTypes := []string{} - suggestedInterfaces := []*suggestedInterface{} - // stores a map of interface name => index in suggestedInterfaces - suggestedInterfaceMap := map[string]int{} - // stores a maps of object name => true to remove duplicates from results - suggestedObjectMap := map[string]bool{} - for _, possibleType := range possibleTypes { if field, ok := possibleType.Fields()[fieldName]; !ok || field == nil { continue @@ -446,72 +441,75 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance kinds.Argument: { Kind: func(p visitor.VisitFuncParams) (string, interface{}) { var action = visitor.ActionNoChange - var result interface{} if node, ok := p.Node.(*ast.Argument); ok { var argumentOf ast.Node if len(p.Ancestors) > 0 { argumentOf = p.Ancestors[len(p.Ancestors)-1] } if argumentOf == nil { - return action, result + return action, nil } - var fieldArgDef *Argument - if argumentOf.GetKind() == kinds.Field { - fieldDef := context.FieldDef() + // verify node, if the node's name exists in Arguments{Field, Directive} + var ( + fieldArgDef *Argument + fieldDef = context.FieldDef() + directive = context.Directive() + argNames []string + parentTypeName string + ) + switch argumentOf.GetKind() { + case kinds.Field: + // get field definition if fieldDef == nil { - return action, result - } - nodeName := "" - if node.Name != nil { - nodeName = node.Name.Value + return action, nil } - argNames := []string{} for _, arg := range fieldDef.Args { - argNames = append(argNames, arg.Name()) - if arg.Name() == nodeName { + if arg.Name() == node.Name.Value { fieldArgDef = arg + break } + argNames = append(argNames, arg.Name()) } if fieldArgDef == nil { parentType := context.ParentType() - parentTypeName := "" if parentType != nil { parentTypeName = parentType.Name() } reportError( context, - unknownArgMessage(nodeName, fieldDef.Name, parentTypeName, suggestionList(nodeName, argNames)), + unknownArgMessage( + node.Name.Value, + fieldDef.Name, + parentTypeName, suggestionList(node.Name.Value, argNames), + ), []ast.Node{node}, ) } - } else if argumentOf.GetKind() == kinds.Directive { - directive := context.Directive() - if directive == nil { - return action, result - } - nodeName := "" - if node.Name != nil { - nodeName = node.Name.Value + case kinds.Directive: + if directive = context.Directive(); directive == nil { + return action, nil } - argNames := []string{} - var directiveArgDef *Argument for _, arg := range directive.Args { - argNames = append(argNames, arg.Name()) - if arg.Name() == nodeName { - directiveArgDef = arg + if arg.Name() == node.Name.Value { + fieldArgDef = arg + break } + argNames = append(argNames, arg.Name()) } - if directiveArgDef == nil { + if fieldArgDef == nil { reportError( context, - unknownDirectiveArgMessage(nodeName, directive.Name, suggestionList(nodeName, argNames)), + unknownDirectiveArgMessage( + node.Name.Value, + directive.Name, + suggestionList(node.Name.Value, argNames), + ), []ast.Node{node}, ) } } - } - return action, result + return action, nil }, }, }, @@ -1727,8 +1725,20 @@ func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleI // Note that this only validates literal values, variables are assumed to // provide values of the correct type. func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { - // A value must be provided if the type is non-null. - if ttype, ok := ttype.(*NonNull); ok { + if _, ok := ttype.(*NonNull); !ok { + if valueAST == nil { + return true, nil + } + + // This function only tests literals, and assumes variables will provide + // values of the correct type. + if valueAST.GetKind() == kinds.Variable { + return true, nil + } + } + switch ttype := ttype.(type) { + case *NonNull: + // A value must be provided if the type is non-null. if e := ttype.Error(); e != nil { return false, []string{e.Error()} } @@ -1740,20 +1750,8 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { } ofType, _ := ttype.OfType.(Input) return isValidLiteralValue(ofType, valueAST) - } - - if valueAST == nil { - return true, nil - } - - // This function only tests literals, and assumes variables will provide - // values of the correct type. - if valueAST.GetKind() == kinds.Variable { - return true, nil - } - - // Lists accept a non-list value as a list of one. - if ttype, ok := ttype.(*List); ok { + case *List: + // Lists accept a non-list value as a list of one. itemType, _ := ttype.OfType.(Input) if valueAST, ok := valueAST.(*ast.ListValue); ok { messagesReduce := []string{} @@ -1766,11 +1764,8 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { return (len(messagesReduce) == 0), messagesReduce } return isValidLiteralValue(itemType, valueAST) - - } - - // Input objects check each defined field and look for undefined fields. - if ttype, ok := ttype.(*InputObject); ok { + case *InputObject: + // Input objects check each defined field and look for undefined fields. valueAST, ok := valueAST.(*ast.ObjectValue) if !ok { return false, []string{fmt.Sprintf(`Expected "%v", found not an object.`, ttype.Name())} @@ -1782,23 +1777,16 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { fieldASTs := valueAST.Fields fieldASTMap := map[string]*ast.ObjectField{} for _, fieldAST := range fieldASTs { - fieldASTName := "" - if fieldAST.Name != nil { - fieldASTName = fieldAST.Name.Value - } - - fieldASTMap[fieldASTName] = fieldAST - - field, ok := fields[fieldASTName] + fieldASTMap[fieldAST.Name.Value] = fieldAST + field, ok := fields[fieldAST.Name.Value] if !ok || field == nil { - messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": Unknown field.`, fieldASTName)) + messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": Unknown field.`, fieldAST.Name.Value)) } } // Ensure every defined field is valid. for fieldName, field := range fields { - fieldAST, _ := fieldASTMap[fieldName] var fieldASTValue ast.Value - if fieldAST != nil { + if fieldAST := fieldASTMap[fieldName]; fieldAST != nil { fieldASTValue = fieldAST.Value } if isValid, messages := isValidLiteralValue(field.Type, fieldASTValue); !isValid { @@ -1808,14 +1796,11 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { } } return (len(messagesReduce) == 0), messagesReduce - } - - if ttype, ok := ttype.(*Scalar); ok { + case *Scalar: if isNullish(ttype.ParseLiteral(valueAST)) { return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))} } - } - if ttype, ok := ttype.(*Enum); ok { + case *Enum: if isNullish(ttype.ParseLiteral(valueAST)) { return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))} } diff --git a/schema.go b/schema.go index 3d34da95..f4be8160 100644 --- a/schema.go +++ b/schema.go @@ -227,10 +227,10 @@ func (gq *Schema) Type(name string) Type { } func (gq *Schema) PossibleTypes(abstractType Abstract) []*Object { - if abstractType, ok := abstractType.(*Union); ok { + switch abstractType := abstractType.(type) { + case *Union: return abstractType.Types() - } - if abstractType, ok := abstractType.(*Interface); ok { + case *Interface: if impls, ok := gq.implementations[abstractType.Name()]; ok { return impls } diff --git a/validator.go b/validator.go index 73c213eb..33379b85 100644 --- a/validator.go +++ b/validator.go @@ -31,7 +31,6 @@ func ValidateDocument(schema *Schema, astDoc *ast.Document, rules []ValidationRu rules = SpecifiedRules } - vr.IsValid = false if schema == nil { vr.Errors = append(vr.Errors, gqlerrors.NewFormattedError("Must provide schema")) return vr diff --git a/values.go b/values.go index 36cba6ec..7453e394 100644 --- a/values.go +++ b/values.go @@ -7,11 +7,12 @@ import ( "reflect" "strings" + "sort" + "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/kinds" "github.com/graphql-go/graphql/language/printer" - "sort" ) // Prepares an object map of variableValues of the correct type based on the @@ -316,41 +317,26 @@ func isValidInputValue(value interface{}, ttype Input) (bool, []string) { } // Returns true if a value is null, undefined, or NaN. -func isNullish(value interface{}) bool { - if value, ok := value.(*string); ok { - if value == nil { - return true - } - return *value == "" - } - if value, ok := value.(int); ok { - return math.IsNaN(float64(value)) - } - if value, ok := value.(*int); ok { - if value == nil { - return true - } - return math.IsNaN(float64(*value)) - } - if value, ok := value.(float32); ok { - return math.IsNaN(float64(value)) - } - if value, ok := value.(*float32); ok { - if value == nil { - return true - } - return math.IsNaN(float64(*value)) - } - if value, ok := value.(float64); ok { - return math.IsNaN(value) - } - if value, ok := value.(*float64); ok { - if value == nil { +func isNullish(src interface{}) bool { + if src == nil { + return true + } + value := reflect.ValueOf(src) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + switch value.Kind() { + case reflect.String: + // if src is ptr type and len(string)=0, it returns false + if !value.IsValid() { return true } - return math.IsNaN(*value) + case reflect.Int: + return math.IsNaN(float64(value.Int())) + case reflect.Float32, reflect.Float64: + return math.IsNaN(float64(value.Float())) } - return value == nil + return false } /** From 047e926f3db962db1c0f451bd1cedc2adb95015b Mon Sep 17 00:00:00 2001 From: 1046102779 Date: Fri, 13 Apr 2018 17:34:27 +0800 Subject: [PATCH 2/2] fix --- language/parser/parser.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/language/parser/parser.go b/language/parser/parser.go index 764a170e..3fd83149 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -250,13 +250,13 @@ func parseOperationType(parser *Parser) (string, error) { func parseVariableDefinitions(parser *Parser) ([]*ast.VariableDefinition, error) { variableDefinitions := []*ast.VariableDefinition{} if !peek(parser, lexer.TokenKind[lexer.PAREN_L]) { - return nil, nil + return variableDefinitions, nil } if vdefs, err := reverse(parser, lexer.TokenKind[lexer.PAREN_L], parseVariableDefinition, lexer.TokenKind[lexer.PAREN_R], true, ); err != nil { - return variableDefinitions, nil + return variableDefinitions, err } else { for _, vdef := range vdefs { variableDefinitions = append(variableDefinitions, vdef.(*ast.VariableDefinition))