diff --git a/cmd/polaris/root.go b/cmd/polaris/root.go index e1ee10da2..946e8d413 100644 --- a/cmd/polaris/root.go +++ b/cmd/polaris/root.go @@ -24,6 +24,7 @@ import ( ) var ( + mergeConfig bool configPath string disallowExemptions bool disallowConfigExemptions bool @@ -42,6 +43,7 @@ var ( func init() { // Flags + rootCmd.PersistentFlags().BoolVarP(&mergeConfig, "merge-config", "m", false, "If true, custom configuration will be merged with default configuration instead of replacing it.") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Location of Polaris configuration file.") rootCmd.PersistentFlags().StringVarP(&kubeContext, "context", "x", "", "Set the kube context.") rootCmd.PersistentFlags().BoolVarP(&disallowExemptions, "disallow-exemptions", "", false, "Disallow any configured exemption.") @@ -65,7 +67,7 @@ var rootCmd = &cobra.Command{ logrus.SetLevel(parsedLevel) } - config, err = conf.ParseFile(configPath) + config, err = conf.MergeConfigAndParseFile(configPath, mergeConfig) if err != nil { logrus.Errorf("Error parsing config at %s: %v", configPath, err) os.Exit(1) diff --git a/pkg/config/config.go b/pkg/config/config.go index 051473d5b..168872262 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -52,27 +52,50 @@ type Exemption struct { //go:embed default.yaml var defaultConfig []byte -// ParseFile parses config from a file. -func ParseFile(path string) (Configuration, error) { - var rawBytes []byte +// MergeConfigAndParseFile parses config from a file. +func MergeConfigAndParseFile(customConfigPath string, mergeConfig bool) (Configuration, error) { + rawBytes, err := mergeConfigFile(customConfigPath, mergeConfig) + if err != nil { + return Configuration{}, err + } + + return Parse(rawBytes) +} + +func mergeConfigFile(customConfigPath string, mergeConfig bool) ([]byte, error) { + if customConfigPath == "" { + return defaultConfig, nil + } + + var customConfigContent []byte var err error - if path == "" { - rawBytes = defaultConfig - } else if strings.HasPrefix(path, "https://") || strings.HasPrefix(path, "http://") { + if strings.HasPrefix(customConfigPath, "https://") || strings.HasPrefix(customConfigPath, "http://") { // path is a url - response, err2 := http.Get(path) - if err2 != nil { - return Configuration{}, err2 + response, err := http.Get(customConfigPath) + if err != nil { + return nil, err + } + customConfigContent, err = io.ReadAll(response.Body) + if err != nil { + return nil, err } - rawBytes, err = io.ReadAll(response.Body) } else { // path is local - rawBytes, err = os.ReadFile(path) + customConfigContent, err = os.ReadFile(customConfigPath) + if err != nil { + return nil, err + } } - if err != nil { - return Configuration{}, err + + if mergeConfig { + mergedConfig, err := mergeYaml(defaultConfig, customConfigContent) + if err != nil { + return nil, err + } + return mergedConfig, nil } - return Parse(rawBytes) + + return customConfigContent, nil } // Parse parses config from a byte array. diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index cd7417492..5ad91f32a 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -125,7 +125,7 @@ func TestConfigFromURL(t *testing.T) { }() time.Sleep(time.Second) - parsedConf, err = ParseFile("http://localhost:8081/exampleURL") + parsedConf, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false) assert.NoError(t, err, "Expected no error when parsing YAML from URL") if err := srv.Shutdown(context.TODO()); err != nil { panic(err) @@ -136,7 +136,7 @@ func TestConfigFromURL(t *testing.T) { func TestConfigNoServerError(t *testing.T) { var err error - _, err = ParseFile("http://localhost:8081/exampleURL") + _, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false) assert.Error(t, err) assert.Regexp(t, regexp.MustCompile("connection refused"), err.Error()) } diff --git a/pkg/config/merger.go b/pkg/config/merger.go new file mode 100644 index 000000000..36003e255 --- /dev/null +++ b/pkg/config/merger.go @@ -0,0 +1,45 @@ +package config + +import ( + "gopkg.in/yaml.v3" // do not change the yaml import +) + +func mergeYaml(defaultConfig, overridesConfig []byte) ([]byte, error) { + var defaultData, overrideConfig map[string]any + + err := yaml.Unmarshal([]byte(defaultConfig), &defaultData) + if err != nil { + return nil, err + } + + err = yaml.Unmarshal([]byte(overridesConfig), &overrideConfig) + if err != nil { + return nil, err + } + + mergedData := mergeYAMLMaps(defaultData, overrideConfig) + + mergedConfig, err := yaml.Marshal(mergedData) + if err != nil { + return nil, err + } + + return mergedConfig, nil +} + +func mergeYAMLMaps(defaults, overrides map[string]any) map[string]any { + for k, v := range overrides { + if vMap, ok := v.(map[string]any); ok { + // if the key exists in defaults and is a map, recursively merge + if mv1, ok := defaults[k].(map[string]any); ok { + defaults[k] = mergeYAMLMaps(mv1, vMap) + } else { + defaults[k] = vMap + } + } else { + // add or overwrite the value in defaults + defaults[k] = v + } + } + return defaults +} diff --git a/pkg/config/merger_test.go b/pkg/config/merger_test.go new file mode 100644 index 000000000..4f1df6711 --- /dev/null +++ b/pkg/config/merger_test.go @@ -0,0 +1,50 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var defaults = ` +checks: + deploymentMissingReplicas: warning + priorityClassNotSet: warning + tagNotSpecified: danger +existing: + sub: + key: value +` + +var overrides = ` +checks: + pullPolicyNotAlways: ignore + tagNotSpecified: overrides +existing: + sub: + key1: value1 + new: value +new: + key: value +` + +func TestMergeYaml(t *testing.T) { + mergedContent, err := mergeYaml([]byte(defaults), []byte(overrides)) + assert.NoError(t, err) + + expectedYAML := `checks: + deploymentMissingReplicas: warning + priorityClassNotSet: warning + pullPolicyNotAlways: ignore + tagNotSpecified: overrides +existing: + new: value + sub: + key: value + key1: value1 +new: + key: value +` + + assert.Equal(t, expectedYAML, string(mergedContent)) +}