diff --git a/codegen/config/config.go b/codegen/config/config.go index ba939fcf59a..24ea9e7ef04 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -70,8 +70,8 @@ func LoadDefaultConfig() (*Config, error) { // LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories // walking up the tree. The closest config file will be returned. -func LoadConfigFromDefaultLocations() (*Config, error) { - cfgFile, err := findCfg() +func LoadConfigFromDefaultLocations() (cfg *Config, err error) { + cfgFile, cwd, err := findCfg() if err != nil { return nil, err } @@ -80,6 +80,12 @@ func LoadConfigFromDefaultLocations() (*Config, error) { if err != nil { return nil, errors.Wrap(err, "unable to enter config dir") } + defer func() { + if cerr := os.Chdir(cwd); cerr != nil { + cfg = nil + err = errors.Wrap(cerr, "unable to restore working directory") + } + }() return LoadConfig(cfgFile) } @@ -467,24 +473,24 @@ func inStrSlice(haystack []string, needle string) bool { // findCfg searches for the config file in this directory and all parents up the tree // looking for the closest match -func findCfg() (string, error) { - dir, err := os.Getwd() +func findCfg() (string, string, error) { + cwd, err := os.Getwd() if err != nil { - return "", errors.Wrap(err, "unable to get working dir to findCfg") + return "", "", errors.Wrap(err, "unable to get working dir to findCfg") } - cfg := findCfgInDir(dir) + cfg := findCfgInDir(cwd) - for cfg == "" && dir != filepath.Dir(dir) { + for dir := cwd; cfg == "" && dir != filepath.Dir(dir); { dir = filepath.Dir(dir) cfg = findCfgInDir(dir) } if cfg == "" { - return "", os.ErrNotExist + return "", "", os.ErrNotExist } - return cfg, nil + return cfg, cwd, nil } func findCfgInDir(dir string) string { diff --git a/codegen/config/config_test.go b/codegen/config/config_test.go index b16e90c11a5..de63b15d3f0 100644 --- a/codegen/config/config_test.go +++ b/codegen/config/config_test.go @@ -72,9 +72,14 @@ func TestLoadConfigFromDefaultLocation(t *testing.T) { err = os.Chdir(filepath.Join(testDir, "testdata", "cfg", "otherdir")) require.NoError(t, err) + before, err := os.Getwd() + require.NoError(t, err) cfg, err = LoadConfigFromDefaultLocations() require.NoError(t, err) require.Equal(t, StringList{"outer"}, cfg.SchemaFilename) + after, err := os.Getwd() + require.NoError(t, err) + require.Equal(t, before, after) }) t.Run("will return error if config doesn't exist", func(t *testing.T) {