Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate method return types #193

Merged
merged 1 commit into from
Jul 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
err = bindObject(def.Type(), input, imports)
if err != nil {
return nil, err
bindErrs := bindObject(def.Type(), input, imports)
if len(bindErrs) > 0 {
return nil, bindErrs
}
}

Expand Down
7 changes: 4 additions & 3 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"log"
"sort"
"strings"

Expand All @@ -25,9 +26,9 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports
return nil, err
}
if def != nil {
err = bindObject(def.Type(), obj, imports)
if err != nil {
return nil, err
for _, bindErr := range bindObject(def.Type(), obj, imports) {
log.Println(bindErr.Error())
log.Println(" Adding resolver method")
}
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/testdata/recursive.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package testdata

type RecursiveInputSlice struct {
Self *[]*RecursiveInputSlice
Self []RecursiveInputSlice
}
183 changes: 134 additions & 49 deletions codegen/util.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"fmt"
"go/types"
"regexp"
"strings"
Expand Down Expand Up @@ -131,71 +132,155 @@ func findField(typ *types.Struct, name string) *types.Var {
return nil
}

func bindObject(t types.Type, object *Object, imports *Imports) error {
type BindError struct {
object *Object
field *Field
typ types.Type
methodErr error
varErr error
}

func (b BindError) Error() string {
return fmt.Sprintf(
"Unable to bind %s.%s to %s\n %s\n %s",
b.object.GQLType,
b.field.GQLName,
b.typ.String(),
b.methodErr.Error(),
b.varErr.Error(),
)
}

type BindErrors []BindError

func (b BindErrors) Error() string {
var errs []string
for _, err := range b {
errs = append(errs, err.Error())
}
return strings.Join(errs, "\n\n")
}

func bindObject(t types.Type, object *Object, imports *Imports) BindErrors {
var errs BindErrors
for i := range object.Fields {
field := &object.Fields[i]

// first try binding to a method
methodErr := bindMethod(imports, t, field)
if methodErr == nil {
continue
}

// otherwise try binding to a var
varErr := bindVar(imports, t, field)

if varErr != nil {
errs = append(errs, BindError{
object: object,
typ: t,
field: field,
varErr: varErr,
methodErr: methodErr,
})
}
}
return errs
}

func bindMethod(imports *Imports, t types.Type, field *Field) error {
namedType, ok := t.(*types.Named)
if !ok {
return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return fmt.Errorf("not a named type")
}

method := findMethod(namedType, field.GQLName)
if method == nil {
return fmt.Errorf("no method named %s", field.GQLName)
}
sig := method.Type().(*types.Signature)

if sig.Results().Len() == 1 {
field.NoErr = true
} else if sig.Results().Len() != 2 {
return fmt.Errorf("method has wrong number of args")
}
newArgs, err := matchArgs(field, sig.Params())
if err != nil {
return err
}

result := sig.Results().At(0)
if err := validateTypeBinding(imports, field, result.Type()); err != nil {
return errors.Wrap(err, "method has wrong return type")
}

// success, args and return type match. Bind to method
field.GoMethodName = "obj." + method.Name()
field.Args = newArgs
return nil
}

func bindVar(imports *Imports, t types.Type, field *Field) error {
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return fmt.Errorf("not a struct")
}

for i := range object.Fields {
field := &object.Fields[i]
if method := findMethod(namedType, field.GQLName); method != nil {
sig := method.Type().(*types.Signature)
field.GoMethodName = "obj." + method.Name()
field.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql
var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
}
return errors.Errorf("cannot match argument " + param.Name() + " to any argument in " + t.String())
}
field.Args = newArgs
structField := findField(underlying, field.GQLName)
if structField == nil {
return fmt.Errorf("no field named %s", field.GQLName)
}

if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
return errors.Wrap(err, "field has wrong type")
}

if sig.Results().Len() == 1 {
field.NoErr = true
} else if sig.Results().Len() != 2 {
return errors.Errorf("weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
// success, bind to var
field.GoVarName = structField.Name()
return nil
}

func matchArgs(field *Field, params *types.Tuple) ([]FieldArgument, error) {
var newArgs []FieldArgument

nextArg:
for j := 0; j < params.Len(); j++ {
param := params.At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue nextArg
}
continue
}

if structField := findField(underlying, field.GQLName); structField != nil {
prevModifiers := field.Type.Modifiers
field.Type.Modifiers = modifiersFromGoType(structField.Type())
field.GoVarName = structField.Name()
// no matching arg found, abort
return nil, fmt.Errorf("arg %s not found on method", param.Name())
}
return newArgs, nil
}

switch normalizeVendor(field.Type.FullSignature()) {
case normalizeVendor(structField.Type().String()):
// everything is fine
func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
gqlType := normalizeVendor(field.Type.FullSignature())
goTypeStr := normalizeVendor(goType.String())

case normalizeVendor(structField.Type().Underlying().String()):
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPath(pkg)
field.CastType = &Ref{GoType: typ, Import: imp}
if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
field.Type.Modifiers = modifiersFromGoType(goType)
return nil
}

default:
// type mismatch, require custom resolver for field
field.GoVarName = ""
field.Type.Modifiers = prevModifiers
}
continue
}
// deal with type aliases
underlyingStr := normalizeVendor(goType.Underlying().String())
if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
field.Type.Modifiers = modifiersFromGoType(goType)
pkg, typ := pkgAndType(goType.String())
imp := imports.findByPath(pkg)
field.CastType = &Ref{GoType: typ, Import: imp}
return nil
}
return nil

return fmt.Errorf("%s is not compatible with %s", gqlType, goTypeStr)
}

func modifiersFromGoType(t types.Type) []string {
Expand Down
60 changes: 12 additions & 48 deletions example/chat/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading