Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind directly to AST types, instead of copying out random bits #490

Merged
merged 6 commits into from
Jan 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,23 @@ type ServerBuild struct {

// Create a list of models that need to be generated
func (g *Generator) models() (*ModelBuild, error) {
namedTypes := g.buildNamedTypes()

progLoader := g.newLoaderWithoutErrors()

prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

g.bindTypes(namedTypes, g.Model.Dir(), prog)
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

models, err := g.buildModels(namedTypes, prog)
if err != nil {
Expand All @@ -77,11 +84,16 @@ func (g *Generator) resolver() (*ResolverBuild, error) {
return nil, err
}

destDir := g.Resolver.Dir()

namedTypes := g.buildNamedTypes()
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

g.bindTypes(namedTypes, destDir, prog)
directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

objects, err := g.buildObjects(namedTypes, prog)
if err != nil {
Expand Down Expand Up @@ -109,26 +121,29 @@ func (g *Generator) server(destDir string) *ServerBuild {

// bind a schema together with some code to generate a Build
func (g *Generator) bind() (*Build, error) {
namedTypes := g.buildNamedTypes()

progLoader := g.newLoaderWithoutErrors()
prog, err := progLoader.Load()
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

g.bindTypes(namedTypes, g.Exec.Dir(), prog)
namedTypes, err := g.buildNamedTypes(prog)
if err != nil {
return nil, errors.Wrap(err, "binding types failed")
}

objects, err := g.buildObjects(namedTypes, prog)
directives, err := g.buildDirectives(namedTypes)
if err != nil {
return nil, err
}
g.Directives = directives

inputs, err := g.buildInputs(namedTypes, prog)
objects, err := g.buildObjects(namedTypes, prog)
if err != nil {
return nil, err
}
directives, err := g.buildDirectives(namedTypes)

inputs, err := g.buildInputs(namedTypes, prog)
if err != nil {
return nil, err
}
Expand Down
39 changes: 26 additions & 13 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"sort"
"strings"

"go/types"

"github.com/99designs/gqlgen/internal/gopath"
"github.com/pkg/errors"
"github.com/vektah/gqlparser"
Expand Down Expand Up @@ -168,6 +170,10 @@ func (c *PackageConfig) Check() error {
return c.normalize()
}

func (c *PackageConfig) Pkg() *types.Package {
return types.NewPackage(c.ImportPath(), c.Dir())
}

func (c *PackageConfig) IsDefined() bool {
return c.Filename != ""
}
Expand Down Expand Up @@ -198,6 +204,11 @@ func (tm TypeMap) Exists(typeName string) bool {
return ok
}

func (tm TypeMap) UserDefined(typeName string) bool {
m, ok := tm[typeName]
return ok && m.Model != ""
}

func (tm TypeMap) Check() error {
for typeName, entry := range tm {
if strings.LastIndex(entry.Model, ".") < strings.LastIndex(entry.Model, "/") {
Expand Down Expand Up @@ -285,19 +296,21 @@ func (cfg *Config) normalize() error {
}

builtins := TypeMap{
"__Directive": {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
"__Type": {Model: "github.com/99designs/gqlgen/graphql/introspection.Type"},
"__Field": {Model: "github.com/99designs/gqlgen/graphql/introspection.Field"},
"__EnumValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.EnumValue"},
"__InputValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.InputValue"},
"__Schema": {Model: "github.com/99designs/gqlgen/graphql/introspection.Schema"},
"Int": {Model: "github.com/99designs/gqlgen/graphql.Int"},
"Float": {Model: "github.com/99designs/gqlgen/graphql.Float"},
"String": {Model: "github.com/99designs/gqlgen/graphql.String"},
"Boolean": {Model: "github.com/99designs/gqlgen/graphql.Boolean"},
"ID": {Model: "github.com/99designs/gqlgen/graphql.ID"},
"Time": {Model: "github.com/99designs/gqlgen/graphql.Time"},
"Map": {Model: "github.com/99designs/gqlgen/graphql.Map"},
"__Directive": {Model: "github.com/99designs/gqlgen/graphql/introspection.Directive"},
"__DirectiveLocation": {Model: "github.com/99designs/gqlgen/graphql.String"},
"__Type": {Model: "github.com/99designs/gqlgen/graphql/introspection.Type"},
"__TypeKind": {Model: "github.com/99designs/gqlgen/graphql.String"},
"__Field": {Model: "github.com/99designs/gqlgen/graphql/introspection.Field"},
"__EnumValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.EnumValue"},
"__InputValue": {Model: "github.com/99designs/gqlgen/graphql/introspection.InputValue"},
"__Schema": {Model: "github.com/99designs/gqlgen/graphql/introspection.Schema"},
"Int": {Model: "github.com/99designs/gqlgen/graphql.Int"},
"Float": {Model: "github.com/99designs/gqlgen/graphql.Float"},
"String": {Model: "github.com/99designs/gqlgen/graphql.String"},
"Boolean": {Model: "github.com/99designs/gqlgen/graphql.Boolean"},
"ID": {Model: "github.com/99designs/gqlgen/graphql.ID"},
"Time": {Model: "github.com/99designs/gqlgen/graphql.Time"},
"Map": {Model: "github.com/99designs/gqlgen/graphql.Map"},
}

if cfg.Models == nil {
Expand Down
4 changes: 2 additions & 2 deletions codegen/directive.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (d *Directive) CallArgs() string {
args := []string{"ctx", "obj", "n"}

for _, arg := range d.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+templates.CurrentImports.LookupType(arg.GoType)+")")
}

return strings.Join(args, ", ")
Expand Down Expand Up @@ -56,7 +56,7 @@ func (d *Directive) Declaration() string {
res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"

for _, arg := range d.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
res += fmt.Sprintf(", %s %s", arg.GoVarName, templates.CurrentImports.LookupType(arg.GoType))
}

res += ") (res interface{}, err error)"
Expand Down
2 changes: 1 addition & 1 deletion codegen/directive_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (g *Generator) buildDirectives(types NamedTypes) (map[string]*Directive, er
GoVarName: sanitizeArgName(arg.Name),
}

if !newArg.TypeReference.IsInput && !newArg.TypeReference.IsScalar {
if !newArg.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf("%s cannot be used as argument of directive %s(%s) only input and scalar types are allowed", arg.Type, dir.Name, arg.Name)
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/enum.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package codegen

type Enum struct {
*TypeDefinition
Definition *TypeDefinition
Description string
Values []EnumValue
}
Expand Down
19 changes: 11 additions & 8 deletions codegen/enum_build.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package codegen

import (
"go/types"
"sort"
"strings"

"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/ast"
)

func (g *Generator) buildEnums(types NamedTypes) []Enum {
func (g *Generator) buildEnums(ts NamedTypes) []Enum {
var enums []Enum

for _, typ := range g.schema.Types {
namedType := types[typ.Name]
if typ.Kind != ast.Enum || strings.HasPrefix(typ.Name, "__") || namedType.IsUserDefined {
namedType := ts[typ.Name]
if typ.Kind != ast.Enum || strings.HasPrefix(typ.Name, "__") || g.Models.UserDefined(typ.Name) {
continue
}

Expand All @@ -23,16 +24,18 @@ func (g *Generator) buildEnums(types NamedTypes) []Enum {
}

enum := Enum{
TypeDefinition: namedType,
Values: values,
Description: typ.Description,
Definition: namedType,
Values: values,
Description: typ.Description,
}
enum.GoType = templates.ToCamel(enum.GQLType)

enum.Definition.GoType = types.NewNamed(types.NewTypeName(0, g.Config.Model.Pkg(), templates.ToCamel(enum.Definition.GQLDefinition.Name), nil), nil, nil)

enums = append(enums, enum)
}

sort.Slice(enums, func(i, j int) bool {
return enums[i].GQLType < enums[j].GQLType
return enums[i].Definition.GQLDefinition.Name < enums[j].Definition.GQLDefinition.Name
})

return enums
Expand Down
21 changes: 7 additions & 14 deletions codegen/generator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"go/types"
"log"
"os"
"path/filepath"
Expand Down Expand Up @@ -40,14 +41,6 @@ func (g *Generator) Generate() error {
_ = syscall.Unlink(g.Exec.Filename)
_ = syscall.Unlink(g.Model.Filename)

namedTypes := g.buildNamedTypes()

directives, err := g.buildDirectives(namedTypes)
if err != nil {
return err
}
g.Directives = directives

modelsBuild, err := g.models()
if err != nil {
return errors.Wrap(err, "model plan failed")
Expand All @@ -58,15 +51,15 @@ func (g *Generator) Generate() error {
}

for _, model := range modelsBuild.Models {
modelCfg := g.Models[model.GQLType]
modelCfg.Model = g.Model.ImportPath() + "." + model.GoType
g.Models[model.GQLType] = modelCfg
modelCfg := g.Models[model.Definition.GQLDefinition.Name]
modelCfg.Model = types.TypeString(model.Definition.GoType, nil)
g.Models[model.Definition.GQLDefinition.Name] = modelCfg
}

for _, enum := range modelsBuild.Enums {
modelCfg := g.Models[enum.GQLType]
modelCfg.Model = g.Model.ImportPath() + "." + enum.GoType
g.Models[enum.GQLType] = modelCfg
modelCfg := g.Models[enum.Definition.GQLDefinition.Name]
modelCfg.Model = types.TypeString(enum.Definition.GoType, nil)
g.Models[enum.Definition.GQLDefinition.Name] = modelCfg
}
}

Expand Down
41 changes: 12 additions & 29 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package codegen

import (
"go/types"
"sort"

"go/types"

"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/loader"
Expand All @@ -20,13 +21,8 @@ func (g *Generator) buildInputs(namedTypes NamedTypes, prog *loader.Program) (Ob
return nil, err
}

def, err := findGoType(prog, input.Package, input.GoType)
if err != nil {
return nil, errors.Wrap(err, "cannot find type")
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindErrs := bindObject(def.Type(), input, g.StructTag)
if _, isMap := input.Definition.GoType.(*types.Map); !isMap {
bindErrs := bindObject(input, g.StructTag)
if len(bindErrs) > 0 {
return nil, bindErrs
}
Expand All @@ -37,14 +33,14 @@ func (g *Generator) buildInputs(namedTypes NamedTypes, prog *loader.Program) (Ob
}

sort.Slice(inputs, func(i, j int) bool {
return inputs[i].GQLType < inputs[j].GQLType
return inputs[i].Definition.GQLDefinition.Name < inputs[j].Definition.GQLDefinition.Name
})

return inputs, nil
}

func (g *Generator) buildInput(types NamedTypes, typ *ast.Definition) (*Object, error) {
obj := &Object{TypeDefinition: types[typ.Name]}
obj := &Object{Definition: types[typ.Name]}
typeEntry, entryExists := g.Models[typ.Name]

for _, field := range typ.Fields {
Expand Down Expand Up @@ -73,8 +69,12 @@ func (g *Generator) buildInput(types NamedTypes, typ *ast.Definition) (*Object,
}
}

if !newField.TypeReference.IsInput && !newField.TypeReference.IsScalar {
return nil, errors.Errorf("%s cannot be used as a field of %s. only input and scalar types are allowed", newField.GQLType, obj.GQLType)
if !newField.TypeReference.Definition.GQLDefinition.IsInputType() {
return nil, errors.Errorf(
"%s cannot be used as a field of %s. only input and scalar types are allowed",
newField.Definition.GQLDefinition.Name,
obj.Definition.GQLDefinition.Name,
)
}

obj.Fields = append(obj.Fields, newField)
Expand All @@ -88,20 +88,3 @@ func (g *Generator) buildInput(types NamedTypes, typ *ast.Definition) (*Object,

return obj, nil
}

// if user has implemented an UnmarshalGQL method on the input type manually, use it
// otherwise we will generate one.
func buildInputMarshaler(typ *ast.Definition, def types.Object) *TypeImplementation {
switch def := def.(type) {
case *types.TypeName:
namedType := def.Type().(*types.Named)
for i := 0; i < namedType.NumMethods(); i++ {
method := namedType.Method(i)
if method.Name() == "UnmarshalGQL" {
return nil
}
}
}

return &TypeImplementation{GoType: typ.Name}
}
6 changes: 2 additions & 4 deletions codegen/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package codegen
import (
"testing"

"github.com/vektah/gqlparser/gqlerror"

"github.com/vektah/gqlparser/ast"

"github.com/99designs/gqlgen/codegen/config"
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
"golang.org/x/tools/go/loader"
)

Expand Down
6 changes: 2 additions & 4 deletions codegen/interface.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package codegen

type Interface struct {
*TypeDefinition

Definition *TypeDefinition
Implementors []InterfaceImplementor
}

type InterfaceImplementor struct {
ValueReceiver bool

*TypeDefinition
Definition *TypeDefinition
}
Loading