diff --git a/codegen/data.go b/codegen/data.go index 71206885be1..ed93bcbb15a 100644 --- a/codegen/data.go +++ b/codegen/data.go @@ -5,8 +5,10 @@ import ( "sort" "github.com/99designs/gqlgen/codegen/config" + "github.com/99designs/gqlgen/internal/code" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" + "golang.org/x/tools/go/packages" ) // Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement @@ -25,6 +27,9 @@ type Data struct { QueryRoot *Object MutationRoot *Object SubscriptionRoot *Object + + // This is important for looking up packages during code generation + NameForPackage code.NameForPackage } type builder struct { @@ -75,12 +80,18 @@ func BuildData(cfg *config.Config) (*Data, error) { } } + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) + if err != nil { + return nil, errors.Wrap(err, "loading failed") + } + s := Data{ - Config: cfg, - Directives: dataDirectives, - Schema: b.Schema, - SchemaStr: b.SchemaStr, - Interfaces: map[string]*Interface{}, + Config: cfg, + Directives: dataDirectives, + Schema: b.Schema, + SchemaStr: b.SchemaStr, + Interfaces: map[string]*Interface{}, + NameForPackage: code.NewNameForPackage(pkgs), } for _, schemaType := range b.Schema.Types { diff --git a/codegen/generate.go b/codegen/generate.go index eafa3f87434..820b5d56db8 100644 --- a/codegen/generate.go +++ b/codegen/generate.go @@ -11,5 +11,6 @@ func GenerateCode(data *Data) error { Data: data, RegionTags: true, GeneratedHeader: true, + NameForPackage: data.NameForPackage, }) } diff --git a/codegen/templates/import.go b/codegen/templates/import.go index d5bd16a6a1e..5ec7304b335 100644 --- a/codegen/templates/import.go +++ b/codegen/templates/import.go @@ -10,14 +10,16 @@ import ( ) type Import struct { - Name string - Path string - Alias string + NameForPackage code.NameForPackage + Name string + Path string + Alias string } type Imports struct { - imports []*Import - destDir string + nameForPackage code.NameForPackage + imports []*Import + destDir string } func (i *Import) String() string { @@ -49,7 +51,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) { return "", nil } - name := code.NameForPackage(path) + name := s.nameForPackage.Get(path) var alias string if len(aliases) != 1 { alias = name @@ -69,9 +71,10 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) { } s.imports = append(s.imports, &Import{ - Name: name, - Path: path, - Alias: alias, + NameForPackage: s.nameForPackage, + Name: name, + Path: path, + Alias: alias, }) return "", nil @@ -94,8 +97,9 @@ func (s *Imports) Lookup(path string) string { } imp := &Import{ - Name: code.NameForPackage(path), - Path: path, + NameForPackage: s.nameForPackage, + Name: s.nameForPackage.Get(path), + Path: path, } s.imports = append(s.imports, imp) diff --git a/codegen/templates/import_test.go b/codegen/templates/import_test.go index d225457576f..f0f01ff6d44 100644 --- a/codegen/templates/import_test.go +++ b/codegen/templates/import_test.go @@ -5,7 +5,9 @@ import ( "os" "testing" + "github.com/99designs/gqlgen/internal/code" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func TestImports(t *testing.T) { @@ -16,15 +18,20 @@ func TestImports(t *testing.T) { bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar" mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch" + ps, err := packages.Load(nil, aBar, bBar, mismatch) + require.NoError(t, err) + + nameForPackage := code.NewNameForPackage(ps) + t.Run("multiple lookups is ok", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar", a.Lookup(aBar)) }) t.Run("lookup by type", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar") typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil) @@ -33,7 +40,7 @@ func TestImports(t *testing.T) { }) t.Run("duplicates are decollisioned", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar1", a.Lookup(bBar)) @@ -44,13 +51,13 @@ func TestImports(t *testing.T) { }) t.Run("package name defined in code will be used", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} require.Equal(t, "turtles", a.Lookup(mismatch)) }) t.Run("string printing for import block", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} a.Lookup(aBar) a.Lookup(bBar) a.Lookup(mismatch) @@ -65,7 +72,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`, }) t.Run("aliased imports will not collide", func(t *testing.T) { - a := Imports{destDir: wd} + a := Imports{nameForPackage: nameForPackage, destDir: wd} _, _ = a.Reserve(aBar, "abar") _, _ = a.Reserve(bBar, "bbar") diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index 5d5f69bf88d..fd7ec088576 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -15,6 +15,7 @@ import ( "text/template" "unicode" + "github.com/99designs/gqlgen/internal/code" "github.com/99designs/gqlgen/internal/imports" "github.com/pkg/errors" ) @@ -43,6 +44,9 @@ type Options struct { // Data will be passed to the template execution. Data interface{} Funcs template.FuncMap + + // Lookups for pre-cached package names + NameForPackage code.NameForPackage } // Render renders a gql plugin template from the given Options. Render is an @@ -53,7 +57,7 @@ func Render(cfg Options) error { if CurrentImports != nil { panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) } - CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)} + CurrentImports = &Imports{nameForPackage: cfg.NameForPackage, destDir: filepath.Dir(cfg.Filename)} // load path relative to calling source file _, callerFile, _, _ := runtime.Caller(1) @@ -143,7 +147,7 @@ func Render(cfg Options) error { } CurrentImports = nil - return write(cfg.Filename, result.Bytes()) + return write(cfg.Filename, result.Bytes(), cfg.NameForPackage) } func center(width int, pad string, s string) string { @@ -551,13 +555,13 @@ func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { return buf, t.Execute(buf, tpldata) } -func write(filename string, b []byte) error { +func write(filename string, b []byte, nameForPackage code.NameForPackage) error { err := os.MkdirAll(filepath.Dir(filename), 0755) if err != nil { return errors.Wrap(err, "failed to create directory") } - formatted, err := imports.Prune(filename, b) + formatted, err := imports.Prune(filename, b, nameForPackage) if err != nil { fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) formatted = b diff --git a/internal/code/imports.go b/internal/code/imports.go index ad62f7c56b1..e9b18c6999d 100644 --- a/internal/code/imports.go +++ b/internal/code/imports.go @@ -14,8 +14,6 @@ import ( "golang.org/x/tools/go/packages" ) -var nameForPackageCache = sync.Map{} - var gopaths []string func init() { @@ -93,23 +91,41 @@ func ImportPathForDir(dir string) (res string) { var modregex = regexp.MustCompile("module (.*)\n") // NameForPackage returns the package name for a given import path. This can be really slow. -func NameForPackage(importPath string) string { +type NameForPackage struct { + cache *sync.Map + packages []*packages.Package +} + +// NewNameForPackage creates a NameForPackage +func NewNameForPackage(packages []*packages.Package) NameForPackage { + return NameForPackage{ + cache: &sync.Map{}, + packages: packages, + } +} + +// Get returns the package name for a given import path. This can be really slow. +func (n NameForPackage) Get(importPath string) string { if importPath == "" { panic(errors.New("import path can not be empty")) } - if v, ok := nameForPackageCache.Load(importPath); ok { + + if v, ok := n.cache.Load(importPath); ok { return v.(string) } importPath = QualifyPackagePath(importPath) - p, _ := packages.Load(&packages.Config{ - Mode: packages.NeedName, - }, importPath) + var p *packages.Package + for _, pkg := range n.packages { + if pkg.PkgPath == importPath { + p = pkg + } + } - if len(p) != 1 || p[0].Name == "" { + if p == nil || p.Name == "" { return SanitizePackageName(filepath.Base(importPath)) } - nameForPackageCache.Store(importPath, p[0].Name) + n.cache.Store(importPath, p.Name) - return p[0].Name + return p.Name } diff --git a/internal/code/imports_test.go b/internal/code/imports_test.go index e3bc9474f99..f2aed0bbe7e 100644 --- a/internal/code/imports_test.go +++ b/internal/code/imports_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func TestImportPathForDir(t *testing.T) { @@ -31,11 +32,15 @@ func TestImportPathForDir(t *testing.T) { } func TestNameForPackage(t *testing.T) { - assert.Equal(t, "api", NameForPackage("github.com/99designs/gqlgen/api")) + ps, _ := packages.Load(&packages.Config{Mode: packages.NeedName}, + "github.com/99designs/gqlgen/api", "github.com/99designs/gqlgen/docs", "github.com") + nfp := NewNameForPackage(ps) + + assert.Equal(t, "api", nfp.Get("github.com/99designs/gqlgen/api")) // does not contain go code, should still give a valid name - assert.Equal(t, "docs", NameForPackage("github.com/99designs/gqlgen/docs")) - assert.Equal(t, "github_com", NameForPackage("github.com")) + assert.Equal(t, "docs", nfp.Get("github.com/99designs/gqlgen/docs")) + assert.Equal(t, "github_com", nfp.Get("github.com")) } func TestNameForDir(t *testing.T) { diff --git a/internal/imports/prune.go b/internal/imports/prune.go index 27ac94ac0f0..62b1a4f2f50 100644 --- a/internal/imports/prune.go +++ b/internal/imports/prune.go @@ -24,7 +24,7 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor { } // Prune removes any unused imports -func Prune(filename string, src []byte) ([]byte, error) { +func Prune(filename string, src []byte, nameForPackage code.NameForPackage) ([]byte, error) { fset := token.NewFileSet() file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors) @@ -32,7 +32,7 @@ func Prune(filename string, src []byte) ([]byte, error) { return nil, err } - unused := getUnusedImports(file) + unused := getUnusedImports(file, nameForPackage) for ipath, name := range unused { astutil.DeleteNamedImport(fset, file, name, ipath) } @@ -46,7 +46,7 @@ func Prune(filename string, src []byte) ([]byte, error) { return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8}) } -func getUnusedImports(file ast.Node) map[string]string { +func getUnusedImports(file ast.Node, nameForPackage code.NameForPackage) map[string]string { imported := map[string]*ast.ImportSpec{} used := map[string]bool{} @@ -65,7 +65,7 @@ func getUnusedImports(file ast.Node) map[string]string { break } - local := code.NameForPackage(ipath) + local := nameForPackage.Get(ipath) imported[local] = v case *ast.SelectorExpr: diff --git a/internal/imports/prune_test.go b/internal/imports/prune_test.go index d0691bf242e..533443ca1e7 100644 --- a/internal/imports/prune_test.go +++ b/internal/imports/prune_test.go @@ -4,11 +4,12 @@ import ( "io/ioutil" "testing" + "github.com/99designs/gqlgen/internal/code" "github.com/stretchr/testify/require" ) func TestPrune(t *testing.T) { - b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go")) + b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), code.NewNameForPackage(nil)) require.NoError(t, err) require.Equal(t, string(mustReadFile("testdata/unused.expected.go")), string(b)) } diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index b7b224d7cbc..e83d69844bd 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -7,8 +7,11 @@ import ( "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" + "github.com/99designs/gqlgen/internal/code" "github.com/99designs/gqlgen/plugin" + "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" + "golang.org/x/tools/go/packages" ) type BuildMutateHook = func(b *ModelBuild) *ModelBuild @@ -235,11 +238,17 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { b = m.MutateHook(b) } + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) + if err != nil { + return errors.Wrap(err, "loading failed") + } + return templates.Render(templates.Options{ PackageName: cfg.Model.Package, Filename: cfg.Model.Filename, Data: b, GeneratedHeader: true, + NameForPackage: code.NewNameForPackage(pkgs), }) } diff --git a/plugin/resolvergen/resolver.go b/plugin/resolvergen/resolver.go index 6785c77c45f..d3b7d2af62b 100644 --- a/plugin/resolvergen/resolver.go +++ b/plugin/resolvergen/resolver.go @@ -36,9 +36,10 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { if _, err := os.Stat(filename); os.IsNotExist(errors.Cause(err)) { return templates.Render(templates.Options{ - PackageName: data.Config.Resolver.Package, - Filename: data.Config.Resolver.Filename, - Data: resolverBuild, + PackageName: data.Config.Resolver.Package, + Filename: data.Config.Resolver.Filename, + Data: resolverBuild, + NameForPackage: data.NameForPackage, }) } diff --git a/plugin/servergen/server.go b/plugin/servergen/server.go index 22289c0254d..883c5aeeabc 100644 --- a/plugin/servergen/server.go +++ b/plugin/servergen/server.go @@ -31,9 +31,10 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { if _, err := os.Stat(m.filename); os.IsNotExist(errors.Cause(err)) { return templates.Render(templates.Options{ - PackageName: "main", - Filename: m.filename, - Data: serverBuild, + PackageName: "main", + Filename: m.filename, + Data: serverBuild, + NameForPackage: data.NameForPackage, }) } diff --git a/plugin/stubgen/stubs.go b/plugin/stubgen/stubs.go index af5171b4cfa..9d4af8904e7 100644 --- a/plugin/stubgen/stubs.go +++ b/plugin/stubgen/stubs.go @@ -48,6 +48,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { TypeName: m.typeName, }, GeneratedHeader: true, + NameForPackage: data.NameForPackage, }) }