From be8a96f95a821a7c14f5ec2bb28d55c7f1b80310 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 26 Apr 2018 16:11:30 +1000 Subject: [PATCH 1/3] Refactor main so tests can execute the generator --- codegen/build.go | 62 +++++++------ codegen/codegen.go | 177 +++++++++++++++++++++++++++++++++++++ codegen/enum_build.go | 4 +- codegen/input_build.go | 11 ++- codegen/interface_build.go | 23 ++--- codegen/models_build.go | 16 ++-- codegen/object_build.go | 16 ++-- codegen/type_build.go | 8 +- codegen/util.go | 50 +++++------ main.go | 135 +++------------------------- 10 files changed, 288 insertions(+), 214 deletions(-) create mode 100644 codegen/codegen.go diff --git a/codegen/build.go b/codegen/build.go index 8f7dc6f11cc..68720004645 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -5,9 +5,8 @@ import ( "go/build" "go/types" "os" - "path/filepath" - "github.com/vektah/gqlgen/neelance/schema" + "github.com/pkg/errors" "golang.org/x/tools/go/loader" ) @@ -31,58 +30,65 @@ type ModelBuild struct { } // Create a list of models that need to be generated -func Models(schema *schema.Schema, userTypes map[string]string, destDir string) *ModelBuild { - namedTypes := buildNamedTypes(schema, userTypes) +func (cfg *Config) models() (*ModelBuild, error) { + namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, destDir) - prog, err := loadProgram(imports, true) + imports := buildImports(namedTypes, cfg.modelDir) + prog, err := cfg.loadProgram(imports, true) if err != nil { - panic(err) + return nil, errors.Wrap(err, "loading failed") } - bindTypes(imports, namedTypes, destDir, prog) + cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog) - models := buildModels(namedTypes, schema, prog) + models := cfg.buildModels(namedTypes, prog) return &ModelBuild{ - PackageName: filepath.Base(destDir), + PackageName: cfg.ModelPackageName, Models: models, - Enums: buildEnums(namedTypes, schema), - Imports: buildImports(namedTypes, destDir), - } + Enums: cfg.buildEnums(namedTypes), + Imports: buildImports(namedTypes, cfg.modelDir), + }, nil } -// Bind a schema together with some code to generate a Build -func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*Build, error) { - namedTypes := buildNamedTypes(schema, userTypes) +// bind a schema together with some code to generate a Build +func (cfg *Config) bind() (*Build, error) { + namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, destDir) - prog, err := loadProgram(imports, false) + imports := buildImports(namedTypes, cfg.execDir) + prog, err := cfg.loadProgram(imports, false) if err != nil { - return nil, err + return nil, errors.Wrap(err, "loading failed") } - imports = bindTypes(imports, namedTypes, destDir, prog) + imports = cfg.bindTypes(imports, namedTypes, cfg.execDir, prog) - objects := buildObjects(namedTypes, schema, prog, imports) - inputs := buildInputs(namedTypes, schema, prog, imports) + objects, err := cfg.buildObjects(namedTypes, prog, imports) + if err != nil { + return nil, err + } + + inputs, err := cfg.buildInputs(namedTypes, prog, imports) + if err != nil { + return nil, err + } b := &Build{ - PackageName: filepath.Base(destDir), + PackageName: cfg.ExecPackageName, Objects: objects, - Interfaces: buildInterfaces(namedTypes, schema, prog), + Interfaces: cfg.buildInterfaces(namedTypes, prog), Inputs: inputs, Imports: imports, } - if qr, ok := schema.EntryPoints["query"]; ok { + if qr, ok := cfg.schema.EntryPoints["query"]; ok { b.QueryRoot = b.Objects.ByName(qr.TypeName()) } - if mr, ok := schema.EntryPoints["mutation"]; ok { + if mr, ok := cfg.schema.EntryPoints["mutation"]; ok { b.MutationRoot = b.Objects.ByName(mr.TypeName()) } - if sr, ok := schema.EntryPoints["subscription"]; ok { + if sr, ok := cfg.schema.EntryPoints["subscription"]; ok { b.SubscriptionRoot = b.Objects.ByName(sr.TypeName()) } @@ -113,7 +119,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* return b, nil } -func loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) { +func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) { conf := loader.Config{} if allowErrors { conf = loader.Config{ diff --git a/codegen/codegen.go b/codegen/codegen.go new file mode 100644 index 00000000000..283b110f8c8 --- /dev/null +++ b/codegen/codegen.go @@ -0,0 +1,177 @@ +package codegen + +import ( + "fmt" + "go/build" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/pkg/errors" + "github.com/vektah/gqlgen/codegen/templates" + "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/imports" +) + +type Config struct { + SchemaStr string + Typemap map[string]string + + schema *schema.Schema + + ExecFilename string + ExecPackageName string + execDir string + fullExecPackageName string + + ModelFilename string + ModelPackageName string + modelDir string + fullModelPackageName string +} + +func Generate(cfg Config) error { + if err := cfg.normalize(); err != nil { + return err + } + + modelsBuild, err := cfg.models() + if err != nil { + return errors.Wrap(err, "model plan failed") + } + if len(modelsBuild.Models) > 0 { + modelsBuild.PackageName = cfg.ModelPackageName + + buf, err := templates.Run("models.gotpl", modelsBuild) + if err != nil { + return errors.Wrap(err, "model generation failed") + } + + if err = write(cfg.ModelFilename, buf.Bytes()); err != nil { + return err + } + for _, model := range modelsBuild.Models { + cfg.Typemap[model.GQLType] = cfg.fullModelPackageName + "." + model.GoType + } + + for _, enum := range modelsBuild.Enums { + cfg.Typemap[enum.GQLType] = cfg.fullModelPackageName + "." + enum.GoType + } + } + + build, err := cfg.bind() + if err != nil { + return errors.Wrap(err, "exec plan failed") + } + build.SchemaRaw = cfg.SchemaStr + build.PackageName = cfg.ExecPackageName + + buf, err := templates.Run("generated.gotpl", build) + if err != nil { + return errors.Wrap(err, "exec codegen failed") + } + + if err = write(cfg.ExecFilename, buf.Bytes()); err != nil { + return err + } + return nil +} + +func (cfg *Config) normalize() error { + if cfg.ModelFilename == "" { + return errors.New("ModelFilename is required") + } + cfg.ModelFilename = abs(cfg.ModelFilename) + cfg.modelDir = filepath.Dir(cfg.ModelFilename) + if cfg.ModelPackageName == "" { + cfg.ModelPackageName = filepath.Base(cfg.modelDir) + } + cfg.fullModelPackageName = fullPackageName(cfg.modelDir, cfg.ModelPackageName) + + if cfg.ExecFilename == "" { + return errors.New("ModelFilename is required") + } + cfg.ExecFilename = abs(cfg.ExecFilename) + cfg.execDir = filepath.Dir(cfg.ExecFilename) + if cfg.ExecPackageName == "" { + cfg.ExecPackageName = filepath.Base(cfg.execDir) + } + cfg.fullExecPackageName = fullPackageName(cfg.execDir, cfg.ExecPackageName) + + builtins := map[string]string{ + "__Directive": "github.com/vektah/gqlgen/neelance/introspection.Directive", + "__Type": "github.com/vektah/gqlgen/neelance/introspection.Type", + "__Field": "github.com/vektah/gqlgen/neelance/introspection.Field", + "__EnumValue": "github.com/vektah/gqlgen/neelance/introspection.EnumValue", + "__InputValue": "github.com/vektah/gqlgen/neelance/introspection.InputValue", + "__Schema": "github.com/vektah/gqlgen/neelance/introspection.Schema", + "Int": "github.com/vektah/gqlgen/graphql.Int", + "Float": "github.com/vektah/gqlgen/graphql.Float", + "String": "github.com/vektah/gqlgen/graphql.String", + "Boolean": "github.com/vektah/gqlgen/graphql.Boolean", + "ID": "github.com/vektah/gqlgen/graphql.ID", + "Time": "github.com/vektah/gqlgen/graphql.Time", + "Map": "github.com/vektah/gqlgen/graphql.Map", + } + + if cfg.Typemap == nil { + cfg.Typemap = map[string]string{} + } + for k, v := range builtins { + if _, ok := cfg.Typemap[k]; !ok { + cfg.Typemap[k] = v + } + } + + cfg.schema = schema.New() + return cfg.schema.Parse(cfg.SchemaStr) +} + +func abs(path string) string { + absPath, err := filepath.Abs(path) + if err != nil { + panic(err) + } + return absPath +} + +func fullPackageName(dir string, pkgName string) string { + fullPkgName := filepath.Join(filepath.Dir(dir), pkgName) + + for _, gopath := range filepath.SplitList(build.Default.GOPATH) { + gopath = filepath.Join(gopath, "src") + string(os.PathSeparator) + if strings.HasPrefix(fullPkgName, gopath) { + fullPkgName = fullPkgName[len(gopath):] + } + } + return filepath.ToSlash(fullPkgName) +} + +func gofmt(filename string, b []byte) ([]byte, error) { + out, err := imports.Process(filename, b, nil) + if err != nil { + return b, errors.Wrap(err, "unable to gofmt") + } + return out, nil +} + +func write(filename string, b []byte) error { + err := os.MkdirAll(filepath.Dir(filename), 0755) + if err != nil { + return errors.Wrap(err, "failed to create directory") + } + + formatted, err := gofmt(filename, b) + if err != nil { + fmt.Fprintf(os.Stderr, "gofmt failed: %s", err.Error()) + formatted = b + } + + err = ioutil.WriteFile(filename, formatted, 0644) + if err != nil { + return errors.Wrapf(err, "failed to write %s", filename) + } + + return nil +} diff --git a/codegen/enum_build.go b/codegen/enum_build.go index 59a342f2ed4..e16a757520b 100644 --- a/codegen/enum_build.go +++ b/codegen/enum_build.go @@ -7,10 +7,10 @@ import ( "github.com/vektah/gqlgen/neelance/schema" ) -func buildEnums(types NamedTypes, s *schema.Schema) []Enum { +func (cfg *Config) buildEnums(types NamedTypes) []Enum { var enums []Enum - for _, typ := range s.Types { + for _, typ := range cfg.schema.Types { namedType := types[typ.TypeName()] e, isEnum := typ.(*schema.Enum) if !isEnum || strings.HasPrefix(typ.TypeName(), "__") || namedType.IsUserDefined { diff --git a/codegen/input_build.go b/codegen/input_build.go index 685c5df8558..13241495a82 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -1,27 +1,26 @@ package codegen import ( - "fmt" "go/types" - "os" "sort" "strings" + "github.com/pkg/errors" "github.com/vektah/gqlgen/neelance/schema" "golang.org/x/tools/go/loader" ) -func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects { +func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { var inputs Objects - for _, typ := range s.Types { + for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.InputObject: input := buildInput(namedTypes, typ) def, err := findGoType(prog, input.Package, input.GoType) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + return nil, errors.Wrap(err, "cannot find type") } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) @@ -36,7 +35,7 @@ func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program, return strings.Compare(inputs[i].GQLType, inputs[j].GQLType) == -1 }) - return inputs + return inputs, nil } func buildInput(types NamedTypes, typ *schema.InputObject) *Object { diff --git a/codegen/interface_build.go b/codegen/interface_build.go index de8541cb2d9..cdf0f59724b 100644 --- a/codegen/interface_build.go +++ b/codegen/interface_build.go @@ -11,12 +11,12 @@ import ( "golang.org/x/tools/go/loader" ) -func buildInterfaces(types NamedTypes, s *schema.Schema, prog *loader.Program) []*Interface { +func (cfg *Config) buildInterfaces(types NamedTypes, prog *loader.Program) []*Interface { var interfaces []*Interface - for _, typ := range s.Types { + for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.Union, *schema.Interface: - interfaces = append(interfaces, buildInterface(types, typ, prog)) + interfaces = append(interfaces, cfg.buildInterface(types, typ, prog)) default: continue } @@ -29,7 +29,7 @@ func buildInterfaces(types NamedTypes, s *schema.Schema, prog *loader.Program) [ return interfaces } -func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program) *Interface { +func (cfg *Config) buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program) *Interface { switch typ := typ.(type) { case *schema.Union: @@ -40,7 +40,7 @@ func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program i.Implementors = append(i.Implementors, InterfaceImplementor{ NamedType: t, - ValueReceiver: isValueReceiver(types[typ.Name], t, prog), + ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog), }) } @@ -54,7 +54,7 @@ func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program i.Implementors = append(i.Implementors, InterfaceImplementor{ NamedType: t, - ValueReceiver: isValueReceiver(types[typ.Name], t, prog), + ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog), }) } @@ -64,11 +64,14 @@ func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program } } -func isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool { - interfaceType := findGoInterface(prog, intf.Package, intf.GoType) - implementorType := findGoNamedType(prog, implementor.Package, implementor.GoType) +func (cfg *Config) isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool { + interfaceType, err := findGoInterface(prog, intf.Package, intf.GoType) + if interfaceType == nil || err != nil { + return true + } - if interfaceType == nil || implementorType == nil { + implementorType, err := findGoNamedType(prog, implementor.Package, implementor.GoType) + if implementorType == nil || err != nil { return true } diff --git a/codegen/models_build.go b/codegen/models_build.go index d75deebde78..937af104f7e 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -8,26 +8,26 @@ import ( "golang.org/x/tools/go/loader" ) -func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Model { +func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) []Model { var models []Model - for _, typ := range s.Types { + for _, typ := range cfg.schema.Types { var model Model switch typ := typ.(type) { case *schema.Object: - obj := buildObject(types, typ, s) + obj := cfg.buildObject(types, typ) if obj.Root || obj.IsUserDefined { continue } - model = obj2Model(s, obj) + model = cfg.obj2Model(obj) case *schema.InputObject: obj := buildInput(types, typ) if obj.IsUserDefined { continue } - model = obj2Model(s, obj) + model = cfg.obj2Model(obj) case *schema.Interface, *schema.Union: - intf := buildInterface(types, typ, prog) + intf := cfg.buildInterface(types, typ, prog) if intf.IsUserDefined { continue } @@ -46,7 +46,7 @@ func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Mod return models } -func obj2Model(s *schema.Schema, obj *Object) Model { +func (cfg *Config) obj2Model(obj *Object) Model { model := Model{ NamedType: obj.NamedType, Fields: []ModelField{}, @@ -72,7 +72,7 @@ func obj2Model(s *schema.Schema, obj *Object) Model { mf.GoFKName = ucFirst(field.GQLName) + "ID" mf.GoFKType = "string" - if obj, ok := s.Types[field.GQLType].(*schema.Object); ok { + if obj, ok := cfg.schema.Types[field.GQLType].(*schema.Object); ok { for _, f := range obj.Fields { if strings.EqualFold(f.Name, "id") { if strings.Contains(f.Type.String(), "Int") { diff --git a/codegen/object_build.go b/codegen/object_build.go index d4fbd5a96b0..1f30113e836 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -1,8 +1,6 @@ package codegen import ( - "fmt" - "os" "sort" "strings" @@ -10,17 +8,17 @@ import ( "golang.org/x/tools/go/loader" ) -func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects { +func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { var objects Objects - for _, typ := range s.Types { + for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.Object: - obj := buildObject(types, typ, s) + obj := cfg.buildObject(types, typ) def, err := findGoType(prog, obj.Package, obj.GoType) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + return nil, err } if def != nil { bindObject(def.Type(), obj, imports) @@ -34,10 +32,10 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, impo return strings.Compare(objects[i].GQLType, objects[j].GQLType) == -1 }) - return objects + return objects, nil } -func buildObject(types NamedTypes, typ *schema.Object, s *schema.Schema) *Object { +func (cfg *Config) buildObject(types NamedTypes, typ *schema.Object) *Object { obj := &Object{NamedType: types[typ.TypeName()]} for _, i := range typ.Interfaces { @@ -68,7 +66,7 @@ func buildObject(types NamedTypes, typ *schema.Object, s *schema.Schema) *Object }) } - for name, typ := range s.EntryPoints { + for name, typ := range cfg.schema.EntryPoints { schemaObj := typ.(*schema.Object) if schemaObj.TypeName() != obj.GQLType { continue diff --git a/codegen/type_build.go b/codegen/type_build.go index bbc9a64b6bf..aea5a425eb3 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -11,12 +11,12 @@ import ( ) // namedTypeFromSchema objects for every graphql type, including scalars. There should only be one instance of Type for each thing -func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes { +func (cfg *Config) buildNamedTypes() NamedTypes { types := map[string]*NamedType{} - for _, schemaType := range s.Types { + for _, schemaType := range cfg.schema.Types { t := namedTypeFromSchema(schemaType) - userType := userTypes[t.GQLType] + userType := cfg.Typemap[t.GQLType] t.IsUserDefined = userType != "" if userType == "" && t.IsScalar { userType = "github.com/vektah/gqlgen/graphql.String" @@ -31,7 +31,7 @@ func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes { return types } -func bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports { +func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports { for _, t := range namedTypes { if t.Package == "" { continue diff --git a/codegen/util.go b/codegen/util.go index a0ef5414d96..c69059a40d1 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -1,11 +1,10 @@ package codegen import ( - "fmt" "go/types" - "os" "strings" + "github.com/pkg/errors" "golang.org/x/tools/go/loader" ) @@ -20,12 +19,12 @@ func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Ob pkgName, err := resolvePkg(pkgName) if err != nil { - return nil, fmt.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) + return nil, errors.Errorf("unable to resolve package for %s: %s\n", fullName, err.Error()) } pkg := prog.Imported[pkgName] if pkg == nil { - return nil, fmt.Errorf("required package was not loaded: %s", fullName) + return nil, errors.Errorf("required package was not loaded: %s", fullName) } for astNode, def := range pkg.Defs { @@ -35,40 +34,42 @@ func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Ob return def, nil } - return nil, fmt.Errorf("unable to find type %s\n", fullName) + + return nil, errors.Errorf("unable to find type %s\n", fullName) } -func findGoNamedType(prog *loader.Program, pkgName string, typeName string) *types.Named { +func findGoNamedType(prog *loader.Program, pkgName string, typeName string) (*types.Named, error) { def, err := findGoType(prog, pkgName, typeName) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + return nil, err } if def == nil { - return nil + return nil, nil } namedType, ok := def.Type().(*types.Named) if !ok { - fmt.Fprintf(os.Stderr, "expected %s to be a named type, instead found %T\n", typeName, def.Type()) - return nil + return nil, errors.Errorf("expected %s to be a named type, instead found %T\n", typeName, def.Type()) } - return namedType + return namedType, nil } -func findGoInterface(prog *loader.Program, pkgName string, typeName string) *types.Interface { - namedType := findGoNamedType(prog, pkgName, typeName) +func findGoInterface(prog *loader.Program, pkgName string, typeName string) (*types.Interface, error) { + namedType, err := findGoNamedType(prog, pkgName, typeName) + if err != nil { + return nil, err + } if namedType == nil { - return nil + return nil, nil } underlying, ok := namedType.Underlying().(*types.Interface) if !ok { - fmt.Fprintf(os.Stderr, "expected %s to be a named interface, instead found %s", typeName, namedType.String()) - return nil + return nil, errors.Errorf("expected %s to be a named interface, instead found %s", typeName, namedType.String()) } - return underlying + return underlying, nil } func findMethod(typ *types.Named, name string) *types.Func { @@ -129,17 +130,15 @@ func findField(typ *types.Struct, name string) *types.Var { return nil } -func bindObject(t types.Type, object *Object, imports Imports) { +func bindObject(t types.Type, object *Object, imports Imports) error { namedType, ok := t.(*types.Named) if !ok { - fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String()) - return + return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String()) } underlying, ok := t.Underlying().(*types.Struct) if !ok { - fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String()) - return + return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String()) } for i := range object.Fields { @@ -161,14 +160,14 @@ func bindObject(t types.Type, object *Object, imports Imports) { continue l2 } } - fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) + return errors.Errorf("cannot match argument " + param.Name() + " to any argument in " + t.String()) } field.Args = newArgs if sig.Results().Len() == 1 { field.NoErr = true } else if sig.Results().Len() != 2 { - fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) + return errors.Errorf("weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) } continue } @@ -190,11 +189,12 @@ func bindObject(t types.Type, object *Object, imports Imports) { } default: - fmt.Fprintf(os.Stderr, "type mismatch on %s.%s, expected %s got %s\n", object.GQLType, field.GQLName, field.Type.FullSignature(), structField.Type()) + return errors.Errorf("type mismatch on %s.%s, expected %s got %s\n", object.GQLType, field.GQLName, field.Type.FullSignature(), structField.Type()) } continue } } + return nil } func modifiersFromGoType(t types.Type) []string { diff --git a/main.go b/main.go index fe59ac0d0c7..9cf0278455d 100644 --- a/main.go +++ b/main.go @@ -4,18 +4,11 @@ import ( "encoding/json" "flag" "fmt" - "go/build" "io/ioutil" "os" - "path/filepath" "syscall" - "strings" - "github.com/vektah/gqlgen/codegen" - "github.com/vektah/gqlgen/codegen/templates" - "github.com/vektah/gqlgen/neelance/schema" - "golang.org/x/tools/imports" ) var output = flag.String("out", "generated.go", "the file to write to") @@ -39,143 +32,41 @@ func main() { os.Exit(1) } - schema := schema.New() schemaRaw, err := ioutil.ReadFile(*schemaFilename) if err != nil { fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error()) os.Exit(1) } - if err = schema.Parse(string(schemaRaw)); err != nil { - fmt.Fprintln(os.Stderr, "unable to parse schema: "+err.Error()) - os.Exit(1) - } - _ = syscall.Unlink(*output) _ = syscall.Unlink(*models) types := loadTypeMap() - modelsBuild := codegen.Models(schema, types, dirName(*models)) - if len(modelsBuild.Models) > 0 { - if *packageName != "" { - modelsBuild.PackageName = *packageName - } - - buf, err := templates.Run("models.gotpl", modelsBuild) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to generate code: "+err.Error()) - os.Exit(1) - } - - write(*models, buf.Bytes()) - pkgName := fullPackageName(*models, *modelPackageName) - - for _, model := range modelsBuild.Models { - types[model.GQLType] = pkgName + "." + model.GoType - } - - for _, enum := range modelsBuild.Enums { - types[enum.GQLType] = pkgName + "." + enum.GoType - } - } - - build, err := codegen.Bind(schema, types, dirName(*output)) - if err != nil { - fmt.Fprintln(os.Stderr, "failed to generate code: "+err.Error()) - os.Exit(1) - } - build.SchemaRaw = string(schemaRaw) - - if *packageName != "" { - build.PackageName = *packageName - } - - buf, err := templates.Run("generated.gotpl", build) - if err != nil { - fmt.Fprintf(os.Stderr, "unable to generate code: "+err.Error()) - os.Exit(1) - } - - write(*output, buf.Bytes()) -} - -func gofmt(filename string, b []byte) []byte { - out, err := imports.Process(filename, b, nil) - if err != nil { - fmt.Fprintln(os.Stderr, "unable to gofmt: "+err.Error()) - return b - } - return out -} - -func write(filename string, b []byte) { - err := os.MkdirAll(filepath.Dir(filename), 0755) - if err != nil { - fmt.Fprintln(os.Stderr, "failed to create directory: ", err.Error()) - os.Exit(1) - } - - err = ioutil.WriteFile(filename, gofmt(filename, b), 0644) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to write %s: %s", filename, err.Error()) - os.Exit(1) - } -} + err = codegen.Generate(codegen.Config{ + ModelFilename: *models, + ExecFilename: *output, + ExecPackageName: *packageName, + ModelPackageName: *modelPackageName, + SchemaStr: string(schemaRaw), + Typemap: types, + }) -func abs(path string) string { - absPath, err := filepath.Abs(path) if err != nil { - panic(err) + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(2) } - return absPath -} - -func fullPackageName(file string, override string) string { - absPath, err := filepath.Abs(file) - if err != nil { - panic(err) - } - pkgName := filepath.Dir(absPath) - if override != "" { - pkgName = filepath.Join(filepath.Dir(pkgName), override) - } - - for _, gopath := range filepath.SplitList(build.Default.GOPATH) { - gopath = filepath.Join(gopath, "src") + string(os.PathSeparator) - if strings.HasPrefix(pkgName, gopath) { - pkgName = pkgName[len(gopath):] - } - } - return filepath.ToSlash(pkgName) -} - -func dirName(path string) string { - return filepath.Dir(abs(path)) } func loadTypeMap() map[string]string { - goTypes := map[string]string{ - "__Directive": "github.com/vektah/gqlgen/neelance/introspection.Directive", - "__Type": "github.com/vektah/gqlgen/neelance/introspection.Type", - "__Field": "github.com/vektah/gqlgen/neelance/introspection.Field", - "__EnumValue": "github.com/vektah/gqlgen/neelance/introspection.EnumValue", - "__InputValue": "github.com/vektah/gqlgen/neelance/introspection.InputValue", - "__Schema": "github.com/vektah/gqlgen/neelance/introspection.Schema", - "Int": "github.com/vektah/gqlgen/graphql.Int", - "Float": "github.com/vektah/gqlgen/graphql.Float", - "String": "github.com/vektah/gqlgen/graphql.String", - "Boolean": "github.com/vektah/gqlgen/graphql.Boolean", - "ID": "github.com/vektah/gqlgen/graphql.ID", - "Time": "github.com/vektah/gqlgen/graphql.Time", - "Map": "github.com/vektah/gqlgen/graphql.Map", - } + var goTypes map[string]string if *typemap != "" { b, err := ioutil.ReadFile(*typemap) if err != nil { fmt.Fprintln(os.Stderr, "unable to open typemap: "+err.Error()) - return goTypes + return nil } + if err = json.Unmarshal(b, &goTypes); err != nil { fmt.Fprintln(os.Stderr, "unable to parse typemap: "+err.Error()) os.Exit(1) From 35a959b816661fd469a1177dcdd49f1533b49698 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 26 Apr 2018 16:33:40 +1000 Subject: [PATCH 2/3] Add a better error message when passing a type into an input --- .gitignore | 1 + codegen/build.go | 5 ++++- codegen/codegen.go | 4 ++++ codegen/models_build.go | 9 ++++++--- codegen/object_build.go | 14 +++++++++++--- codegen/tests/input_union_test.go | 24 ++++++++++++++++++++++++ main.go | 4 ---- 7 files changed, 50 insertions(+), 11 deletions(-) create mode 100644 codegen/tests/input_union_test.go diff --git a/.gitignore b/.gitignore index e726518406a..2c5f0421a05 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /docs/public /example/chat/node_modules /example/chat/package-lock.json +/codegen/tests/gen diff --git a/codegen/build.go b/codegen/build.go index 68720004645..e1d9f2581a7 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -41,7 +41,10 @@ func (cfg *Config) models() (*ModelBuild, error) { cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog) - models := cfg.buildModels(namedTypes, prog) + models, err := cfg.buildModels(namedTypes, prog) + if err != nil { + return nil, err + } return &ModelBuild{ PackageName: cfg.ModelPackageName, Models: models, diff --git a/codegen/codegen.go b/codegen/codegen.go index 283b110f8c8..284f63e7b16 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "syscall" "github.com/pkg/errors" "github.com/vektah/gqlgen/codegen/templates" @@ -36,6 +37,9 @@ func Generate(cfg Config) error { return err } + _ = syscall.Unlink(cfg.ExecFilename) + _ = syscall.Unlink(cfg.ModelFilename) + modelsBuild, err := cfg.models() if err != nil { return errors.Wrap(err, "model plan failed") diff --git a/codegen/models_build.go b/codegen/models_build.go index 937af104f7e..0c03ae7fcbf 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -8,14 +8,17 @@ import ( "golang.org/x/tools/go/loader" ) -func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) []Model { +func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) ([]Model, error) { var models []Model for _, typ := range cfg.schema.Types { var model Model switch typ := typ.(type) { case *schema.Object: - obj := cfg.buildObject(types, typ) + obj, err := cfg.buildObject(types, typ) + if err != nil { + return nil, err + } if obj.Root || obj.IsUserDefined { continue } @@ -43,7 +46,7 @@ func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) []Model { return strings.Compare(models[i].GQLType, models[j].GQLType) == -1 }) - return models + return models, nil } func (cfg *Config) obj2Model(obj *Object) Model { diff --git a/codegen/object_build.go b/codegen/object_build.go index 1f30113e836..50c2c65c94b 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -4,6 +4,7 @@ import ( "sort" "strings" + "github.com/pkg/errors" "github.com/vektah/gqlgen/neelance/schema" "golang.org/x/tools/go/loader" ) @@ -14,7 +15,10 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.Object: - obj := cfg.buildObject(types, typ) + obj, err := cfg.buildObject(types, typ) + if err != nil { + return nil, err + } def, err := findGoType(prog, obj.Package, obj.GoType) if err != nil { @@ -35,7 +39,7 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports return objects, nil } -func (cfg *Config) buildObject(types NamedTypes, typ *schema.Object) *Object { +func (cfg *Config) buildObject(types NamedTypes, typ *schema.Object) (*Object, error) { obj := &Object{NamedType: types[typ.TypeName()]} for _, i := range typ.Interfaces { @@ -51,6 +55,10 @@ func (cfg *Config) buildObject(types NamedTypes, typ *schema.Object) *Object { Object: obj, } + if !newArg.Type.IsInput && !newArg.Type.IsScalar { + return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.GQLType, field.Name) + } + if arg.Default != nil { newArg.Default = arg.Default.Value(nil) newArg.StripPtr() @@ -80,5 +88,5 @@ func (cfg *Config) buildObject(types NamedTypes, typ *schema.Object) *Object { obj.Stream = true } } - return obj + return obj, nil } diff --git a/codegen/tests/input_union_test.go b/codegen/tests/input_union_test.go new file mode 100644 index 00000000000..a9ee295037e --- /dev/null +++ b/codegen/tests/input_union_test.go @@ -0,0 +1,24 @@ +package tests + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/vektah/gqlgen/codegen" +) + +func TestInputUnion(t *testing.T) { + err := codegen.Generate(codegen.Config{ + SchemaStr: ` + type Query { + addBookmark(b: Bookmarkable!): Boolean! + } + type Item {} + union Bookmarkable = Item + `, + ExecFilename: "gen/inputunion/exec.go", + ModelFilename: "gen/inputunion/model.go", + }) + + require.EqualError(t, err, "model plan failed: Bookmarkable! cannot be used as argument of Query.addBookmark. only input and scalar types are allowed") +} diff --git a/main.go b/main.go index 9cf0278455d..91ab30ede45 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ import ( "fmt" "io/ioutil" "os" - "syscall" "github.com/vektah/gqlgen/codegen" ) @@ -38,9 +37,6 @@ func main() { os.Exit(1) } - _ = syscall.Unlink(*output) - _ = syscall.Unlink(*models) - types := loadTypeMap() err = codegen.Generate(codegen.Config{ From e314b151f709835d99985b057465468fe0f98196 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 26 Apr 2018 16:42:05 +1000 Subject: [PATCH 3/3] Add an error message when using types inside inputs --- codegen/codegen.go | 36 +++++++++++++++---------------- codegen/input_build.go | 25 +++++++++++++++------ codegen/models_build.go | 5 ++++- codegen/object_build.go | 5 ++++- codegen/tests/input_union_test.go | 20 ++++++++++++++++- 5 files changed, 64 insertions(+), 27 deletions(-) diff --git a/codegen/codegen.go b/codegen/codegen.go index 284f63e7b16..7fd9ecc6a52 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -1,6 +1,7 @@ package codegen import ( + "bytes" "fmt" "go/build" "io/ioutil" @@ -21,15 +22,15 @@ type Config struct { schema *schema.Schema - ExecFilename string - ExecPackageName string - execDir string - fullExecPackageName string + ExecFilename string + ExecPackageName string + execPackagePath string + execDir string - ModelFilename string - ModelPackageName string - modelDir string - fullModelPackageName string + ModelFilename string + ModelPackageName string + modelPackagePath string + modelDir string } func Generate(cfg Config) error { @@ -46,8 +47,8 @@ func Generate(cfg Config) error { } if len(modelsBuild.Models) > 0 { modelsBuild.PackageName = cfg.ModelPackageName - - buf, err := templates.Run("models.gotpl", modelsBuild) + var buf *bytes.Buffer + buf, err = templates.Run("models.gotpl", modelsBuild) if err != nil { return errors.Wrap(err, "model generation failed") } @@ -56,11 +57,11 @@ func Generate(cfg Config) error { return err } for _, model := range modelsBuild.Models { - cfg.Typemap[model.GQLType] = cfg.fullModelPackageName + "." + model.GoType + cfg.Typemap[model.GQLType] = cfg.modelPackagePath + "." + model.GoType } for _, enum := range modelsBuild.Enums { - cfg.Typemap[enum.GQLType] = cfg.fullModelPackageName + "." + enum.GoType + cfg.Typemap[enum.GQLType] = cfg.modelPackagePath + "." + enum.GoType } } @@ -71,7 +72,8 @@ func Generate(cfg Config) error { build.SchemaRaw = cfg.SchemaStr build.PackageName = cfg.ExecPackageName - buf, err := templates.Run("generated.gotpl", build) + var buf *bytes.Buffer + buf, err = templates.Run("generated.gotpl", build) if err != nil { return errors.Wrap(err, "exec codegen failed") } @@ -91,7 +93,7 @@ func (cfg *Config) normalize() error { if cfg.ModelPackageName == "" { cfg.ModelPackageName = filepath.Base(cfg.modelDir) } - cfg.fullModelPackageName = fullPackageName(cfg.modelDir, cfg.ModelPackageName) + cfg.modelPackagePath = fullPackageName(cfg.modelDir, cfg.ModelPackageName) if cfg.ExecFilename == "" { return errors.New("ModelFilename is required") @@ -101,7 +103,7 @@ func (cfg *Config) normalize() error { if cfg.ExecPackageName == "" { cfg.ExecPackageName = filepath.Base(cfg.execDir) } - cfg.fullExecPackageName = fullPackageName(cfg.execDir, cfg.ExecPackageName) + cfg.execPackagePath = fullPackageName(cfg.execDir, cfg.ExecPackageName) builtins := map[string]string{ "__Directive": "github.com/vektah/gqlgen/neelance/introspection.Directive", @@ -145,9 +147,7 @@ func fullPackageName(dir string, pkgName string) string { for _, gopath := range filepath.SplitList(build.Default.GOPATH) { gopath = filepath.Join(gopath, "src") + string(os.PathSeparator) - if strings.HasPrefix(fullPkgName, gopath) { - fullPkgName = fullPkgName[len(gopath):] - } + fullPkgName = strings.TrimPrefix(fullPkgName, gopath) } return filepath.ToSlash(fullPkgName) } diff --git a/codegen/input_build.go b/codegen/input_build.go index 13241495a82..72817d77c2e 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -16,7 +16,10 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo for _, typ := range cfg.schema.Types { switch typ := typ.(type) { case *schema.InputObject: - input := buildInput(namedTypes, typ) + input, err := buildInput(namedTypes, typ) + if err != nil { + return nil, err + } def, err := findGoType(prog, input.Package, input.GoType) if err != nil { @@ -24,7 +27,10 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) - bindObject(def.Type(), input, imports) + err = bindObject(def.Type(), input, imports) + if err != nil { + return nil, err + } } inputs = append(inputs, input) @@ -38,17 +44,24 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, impo return inputs, nil } -func buildInput(types NamedTypes, typ *schema.InputObject) *Object { +func buildInput(types NamedTypes, typ *schema.InputObject) (*Object, error) { obj := &Object{NamedType: types[typ.TypeName()]} for _, field := range typ.Values { - obj.Fields = append(obj.Fields, Field{ + newField := Field{ GQLName: field.Name.Name, Type: types.getType(field.Type), Object: obj, - }) + } + + if !newField.Type.IsInput && !newField.Type.IsScalar { + return nil, errors.Errorf("%s cannot be used as a field of %s. only input and scalar types are allowed", newField.GQLType, obj.GQLType) + } + + obj.Fields = append(obj.Fields, newField) + } - return obj + return obj, nil } // if user has implemented an UnmarshalGQL method on the input type manually, use it diff --git a/codegen/models_build.go b/codegen/models_build.go index 0c03ae7fcbf..d694ce08e57 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -24,7 +24,10 @@ func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) ([]Model, } model = cfg.obj2Model(obj) case *schema.InputObject: - obj := buildInput(types, typ) + obj, err := buildInput(types, typ) + if err != nil { + return nil, err + } if obj.IsUserDefined { continue } diff --git a/codegen/object_build.go b/codegen/object_build.go index 50c2c65c94b..3b9e092dc71 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -25,7 +25,10 @@ func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports return nil, err } if def != nil { - bindObject(def.Type(), obj, imports) + err = bindObject(def.Type(), obj, imports) + if err != nil { + return nil, err + } } objects = append(objects, obj) diff --git a/codegen/tests/input_union_test.go b/codegen/tests/input_union_test.go index a9ee295037e..f20d702b391 100644 --- a/codegen/tests/input_union_test.go +++ b/codegen/tests/input_union_test.go @@ -7,7 +7,7 @@ import ( "github.com/vektah/gqlgen/codegen" ) -func TestInputUnion(t *testing.T) { +func TestTypeUnionAsInput(t *testing.T) { err := codegen.Generate(codegen.Config{ SchemaStr: ` type Query { @@ -22,3 +22,21 @@ func TestInputUnion(t *testing.T) { require.EqualError(t, err, "model plan failed: Bookmarkable! cannot be used as argument of Query.addBookmark. only input and scalar types are allowed") } + +func TestTypeInInput(t *testing.T) { + err := codegen.Generate(codegen.Config{ + SchemaStr: ` + type Query { + addBookmark(b: BookmarkableInput!): Boolean! + } + type Item {} + input BookmarkableInput { + item: Item + } + `, + ExecFilename: "gen/typeinput/exec.go", + ModelFilename: "gen/typeinput/model.go", + }) + + require.EqualError(t, err, "model plan failed: Item cannot be used as a field of BookmarkableInput. only input and scalar types are allowed") +}