Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --merge-config flag to support merging with default configuration #1075

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/polaris/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
)

var (
mergeConfig bool
configPath string
disallowExemptions bool
disallowConfigExemptions bool
Expand All @@ -42,6 +43,7 @@ var (

func init() {
// Flags
rootCmd.PersistentFlags().BoolVarP(&mergeConfig, "merge-config", "m", false, "If true, custom configuration will be merged with default configuration instead of replacing it.")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Location of Polaris configuration file.")
rootCmd.PersistentFlags().StringVarP(&kubeContext, "context", "x", "", "Set the kube context.")
rootCmd.PersistentFlags().BoolVarP(&disallowExemptions, "disallow-exemptions", "", false, "Disallow any configured exemption.")
Expand All @@ -65,7 +67,7 @@ var rootCmd = &cobra.Command{
logrus.SetLevel(parsedLevel)
}

config, err = conf.ParseFile(configPath)
config, err = conf.MergeConfigAndParseFile(configPath, mergeConfig)
if err != nil {
logrus.Errorf("Error parsing config at %s: %v", configPath, err)
os.Exit(1)
Expand Down
51 changes: 37 additions & 14 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,50 @@ type Exemption struct {
//go:embed default.yaml
var defaultConfig []byte

// ParseFile parses config from a file.
func ParseFile(path string) (Configuration, error) {
var rawBytes []byte
// MergeConfigAndParseFile parses config from a file.
func MergeConfigAndParseFile(customConfigPath string, mergeConfig bool) (Configuration, error) {
rawBytes, err := mergeConfigFile(customConfigPath, mergeConfig)
if err != nil {
return Configuration{}, err
}

return Parse(rawBytes)
}

func mergeConfigFile(customConfigPath string, mergeConfig bool) ([]byte, error) {
if customConfigPath == "" {
return defaultConfig, nil
}

var customConfigContent []byte
var err error
if path == "" {
rawBytes = defaultConfig
} else if strings.HasPrefix(path, "https://") || strings.HasPrefix(path, "http://") {
if strings.HasPrefix(customConfigPath, "https://") || strings.HasPrefix(customConfigPath, "http://") {
// path is a url
response, err2 := http.Get(path)
if err2 != nil {
return Configuration{}, err2
response, err := http.Get(customConfigPath)
if err != nil {
return nil, err
}
customConfigContent, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
rawBytes, err = io.ReadAll(response.Body)
} else {
// path is local
rawBytes, err = os.ReadFile(path)
customConfigContent, err = os.ReadFile(customConfigPath)
if err != nil {
return nil, err
}
}
if err != nil {
return Configuration{}, err

if mergeConfig {
mergedConfig, err := mergeYaml(defaultConfig, customConfigContent)
if err != nil {
return nil, err
}
return mergedConfig, nil
}
return Parse(rawBytes)

return customConfigContent, nil
}

// Parse parses config from a byte array.
Expand Down
4 changes: 2 additions & 2 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestConfigFromURL(t *testing.T) {
}()
time.Sleep(time.Second)

parsedConf, err = ParseFile("http://localhost:8081/exampleURL")
parsedConf, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false)
assert.NoError(t, err, "Expected no error when parsing YAML from URL")
if err := srv.Shutdown(context.TODO()); err != nil {
panic(err)
Expand All @@ -136,7 +136,7 @@ func TestConfigFromURL(t *testing.T) {

func TestConfigNoServerError(t *testing.T) {
var err error
_, err = ParseFile("http://localhost:8081/exampleURL")
_, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false)
assert.Error(t, err)
assert.Regexp(t, regexp.MustCompile("connection refused"), err.Error())
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/config/merger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package config

import (
"gopkg.in/yaml.v3" // do not change the yaml import
)

func mergeYaml(defaultConfig, overridesConfig []byte) ([]byte, error) {
var defaultData, overrideConfig map[string]any

err := yaml.Unmarshal([]byte(defaultConfig), &defaultData)
if err != nil {
return nil, err
}

err = yaml.Unmarshal([]byte(overridesConfig), &overrideConfig)
if err != nil {
return nil, err
}

mergedData := mergeYAMLMaps(defaultData, overrideConfig)

mergedConfig, err := yaml.Marshal(mergedData)
if err != nil {
return nil, err
}

return mergedConfig, nil
}

func mergeYAMLMaps(defaults, overrides map[string]any) map[string]any {
for k, v := range overrides {
if vMap, ok := v.(map[string]any); ok {
// if the key exists in defaults and is a map, recursively merge
if mv1, ok := defaults[k].(map[string]any); ok {
defaults[k] = mergeYAMLMaps(mv1, vMap)
} else {
defaults[k] = vMap
}
} else {
// add or overwrite the value in defaults
defaults[k] = v
}
}
return defaults
}
50 changes: 50 additions & 0 deletions pkg/config/merger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
)

var defaults = `
checks:
deploymentMissingReplicas: warning
priorityClassNotSet: warning
tagNotSpecified: danger
existing:
sub:
key: value
`

var overrides = `
checks:
pullPolicyNotAlways: ignore
tagNotSpecified: overrides
existing:
sub:
key1: value1
new: value
new:
key: value
`

func TestMergeYaml(t *testing.T) {
mergedContent, err := mergeYaml([]byte(defaults), []byte(overrides))
assert.NoError(t, err)

expectedYAML := `checks:
deploymentMissingReplicas: warning
priorityClassNotSet: warning
pullPolicyNotAlways: ignore
tagNotSpecified: overrides
existing:
new: value
sub:
key: value
key1: value1
new:
key: value
`

assert.Equal(t, expectedYAML, string(mergedContent))
}
Loading