diff --git a/cmd/gen.go b/cmd/gen.go index 3842f02b37d..97bd6342f1e 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -46,11 +46,6 @@ var genCmd = cli.Command{ config.SchemaStr[filename] = string(schemaRaw) } - if err = config.Check(); err != nil { - fmt.Fprintln(os.Stderr, "invalid config format: "+err.Error()) - os.Exit(1) - } - err = codegen.Generate(*config) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) diff --git a/cmd/init.go b/cmd/init.go index 1e7c18b9327..6ad7e58c811 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -78,11 +78,6 @@ func GenerateGraphServer(config *codegen.Config, serverFilename string) { config.SchemaStr[filename] = string(schemaRaw) } - if err := config.Check(); err != nil { - fmt.Fprintln(os.Stderr, "invalid config format: "+err.Error()) - os.Exit(1) - } - if err := codegen.Generate(*config); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) diff --git a/codegen/codegen.go b/codegen/codegen.go index 773e3db7cc4..8aadd48cd2d 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -1,6 +1,7 @@ package codegen import ( + "fmt" "log" "os" "path/filepath" @@ -19,6 +20,9 @@ func Generate(cfg Config) error { return err } + if err := cfg.check(); err != nil { + return fmt.Errorf("invalid config format: " + err.Error()) + } _ = syscall.Unlink(cfg.Exec.Filename) _ = syscall.Unlink(cfg.Model.Filename) diff --git a/codegen/config.go b/codegen/config.go index 4319953ba18..9826ca5dfe5 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -178,7 +178,7 @@ func (c *PackageConfig) IsDefined() bool { return c.Filename != "" } -func (cfg *Config) Check() error { +func (cfg *Config) check() error { if err := cfg.Models.Check(); err != nil { return errors.Wrap(err, "config.models") } @@ -200,36 +200,18 @@ func (cfg *Config) Check() error { cfg.Resolver, } filesMap := make(map[string]bool) - pkgConfigsByDir := make(map[string][]PackageConfig) - for i, current := range packageConfigList { - if i == 0 { - filesMap[current.Filename] = true - pkgConfigsByDir[current.Dir()] = []PackageConfig{current} - continue - } + pkgConfigsByDir := make(map[string]PackageConfig) + for _, current := range packageConfigList { _, fileFound := filesMap[current.Filename] if fileFound { return fmt.Errorf("filename %s defined more than once", current.Filename) } filesMap[current.Filename] = true - prevPkgList, inSameDir := pkgConfigsByDir[current.Dir()] - if inSameDir { - for _, previous := range prevPkgList { - if current.Package != previous.Package { - eitherPackageEmpty := previous.Package != "" || current.Package != "" - if eitherPackageEmpty { - if current.Package == filepath.Base(current.Dir()) && previous.Package == "" { - break - } - if previous.Package == filepath.Base(previous.Dir()) && current.Package == "" { - break - } - return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", current.Filename, previous.Filename) - } - } - } + previous, inSameDir := pkgConfigsByDir[current.Dir()] + if inSameDir && current.Package != previous.Package { + return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename)) } - pkgConfigsByDir[current.Dir()] = append(pkgConfigsByDir[current.Dir()], current) + pkgConfigsByDir[current.Dir()] = current } return nil @@ -312,3 +294,7 @@ func findCfgInDir(dir string) string { } return "" } + +func stripPath(path string) string { + return filepath.Base(path) +}