diff --git a/src/k8s/cmd/k8s/k8s_set.go b/src/k8s/cmd/k8s/k8s_set.go index a6aba831e..46750f2a6 100644 --- a/src/k8s/cmd/k8s/k8s_set.go +++ b/src/k8s/cmd/k8s/k8s_set.go @@ -2,13 +2,12 @@ package k8s import ( "fmt" - "github.com/canonical/k8s/pkg/utils" - "strconv" "strings" - "unicode" apiv1 "github.com/canonical/k8s/api/v1" cmdutil "github.com/canonical/k8s/cmd/util" + "github.com/canonical/k8s/pkg/utils" + "github.com/mitchellh/mapstructure" "github.com/spf13/cobra" ) @@ -34,7 +33,7 @@ func newSetCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { config := apiv1.UserFacingClusterConfig{} for _, arg := range args { - if err := updateConfig(&config, arg); err != nil { + if err := updateConfigMapstructure(&config, arg); err != nil { cmd.PrintErrf("Error: Invalid option %q.\n\nThe error was: %v\n", arg, err) env.Exit(1) } @@ -66,7 +65,48 @@ func newSetCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { return cmd } -func updateConfig(config *apiv1.UserFacingClusterConfig, arg string) error { +var knownSetKeys = map[string]struct{}{ + "cloud-provider": struct{}{}, + "dns.cluster-domain": struct{}{}, + "dns.enabled": struct{}{}, + "dns.service-ip": struct{}{}, + "dns.upstream-nameservers": struct{}{}, + "gateway.enabled": struct{}{}, + "ingress.default-tls-secret": struct{}{}, + "ingress.enable-proxy-protocol": struct{}{}, + "ingress.enabled": struct{}{}, + "load-balancer.bgp-local-asn": struct{}{}, + "load-balancer.bgp-mode": struct{}{}, + "load-balancer.bgp-peer-address": struct{}{}, + "load-balancer.bgp-peer-asn": struct{}{}, + "load-balancer.bgp-peer-port": struct{}{}, + "load-balancer.cidrs": struct{}{}, + "load-balancer.enabled": struct{}{}, + "load-balancer.l2-interfaces": struct{}{}, + "load-balancer.l2-mode": struct{}{}, + "local-storage.default": struct{}{}, + "local-storage.enabled": struct{}{}, + "local-storage.local-path": struct{}{}, + "local-storage.reclaim-policy": struct{}{}, + "metrics-server.enabled": struct{}{}, + "network.enabled": struct{}{}, +} + +func updateConfigMapstructure(config *apiv1.UserFacingClusterConfig, arg string) error { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + ErrorUnused: true, + Result: config, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + utils.YAMLToStringSliceHookFunc, + utils.StringToFieldsSliceHookFunc(','), + ), + }) + if err != nil { + panic(fmt.Sprintf("failed to define decoder with error %v", err.Error())) + } + parts := strings.SplitN(arg, "=", 2) if len(parts) != 2 { return fmt.Errorf("option not in = format") @@ -74,111 +114,20 @@ func updateConfig(config *apiv1.UserFacingClusterConfig, arg string) error { key := parts[0] value := parts[1] - switch key { - case "network.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for network.enabled: %w", err) - } - config.Network.Enabled = &v - case "dns.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for dns.enabled: %w", err) - } - config.DNS.Enabled = &v - case "dns.upstream-nameservers": - config.DNS.UpstreamNameservers = utils.Pointer(strings.FieldsFunc(value, func(r rune) bool { return unicode.IsSpace(r) || r == ',' })) - case "dns.cluster-domain": - config.DNS.ClusterDomain = utils.Pointer(value) - case "dns.service-ip": - config.DNS.ServiceIP = utils.Pointer(value) - case "gateway.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for gateway.enabled: %w", err) - } - config.Gateway.Enabled = &v - case "ingress.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for ingress.enabled: %w", err) - } - config.Ingress.Enabled = &v - case "ingress.default-tls-secret": - config.Ingress.DefaultTLSSecret = utils.Pointer(value) - case "ingress.enable-proxy-protocol": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for ingress.enable-proxy-protocol: %w", err) - } - config.Ingress.EnableProxyProtocol = &v - case "local-storage.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for local-storage.enabled: %w", err) - } - config.LocalStorage.Enabled = &v - case "local-storage.local-path": - config.LocalStorage.LocalPath = utils.Pointer(value) - case "local-storage.reclaim-policy": - config.LocalStorage.ReclaimPolicy = utils.Pointer(value) - case "local-storage.default": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for local-storage.default: %w", err) - } - config.LocalStorage.Default = &v - case "load-balancer.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for load-balancer.enabled: %w", err) - } - config.LoadBalancer.Enabled = &v - case "load-balancer.cidrs": - config.LoadBalancer.CIDRs = utils.Pointer(strings.FieldsFunc(value, func(r rune) bool { return unicode.IsSpace(r) || r == ',' })) - case "load-balancer.l2-mode": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for load-balancer.l2-mode: %w", err) - } - config.LoadBalancer.L2Mode = &v - case "load-balancer.l2-interfaces": - config.LoadBalancer.L2Interfaces = utils.Pointer(strings.FieldsFunc(value, func(r rune) bool { return unicode.IsSpace(r) || r == ',' })) - case "load-balancer.bgp-mode": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for load-balancer.bgp-mode: %w", err) - } - config.LoadBalancer.BGPMode = &v - case "load-balancer.bgp-local-asn": - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("invalid integer value for load-balancer.bgp-local-asn: %w", err) - } - config.LoadBalancer.BGPLocalASN = &v - case "load-balancer.bgp-peer-address": - config.LoadBalancer.BGPPeerAddress = utils.Pointer(value) - case "load-balancer.bgp-peer-port": - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("invalid integer value for load-balancer.bgp-peer-port: %w", err) - } - config.LoadBalancer.BGPPeerPort = &v - case "load-balancer.bgp-peer-asn": - v, err := strconv.Atoi(value) - if err != nil { - return fmt.Errorf("invalid integer value for load-balancer.bgp-peer-asn: %w", err) - } - config.LoadBalancer.BGPPeerASN = &v - case "metrics-server.enabled": - v, err := strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid boolean value for metrics-server.enabled: %w", err) - } - config.MetricsServer.Enabled = &v - default: - return fmt.Errorf("unknown config key %q", key) + if _, ok := knownSetKeys[key]; !ok { + return fmt.Errorf("unknown option key %q", key) + } + + if err := decoder.Decode(toRecursiveMap(key, value)); err != nil { + return fmt.Errorf("invalid option %q: %w", arg, err) } return nil } + +func toRecursiveMap(key, value string) map[string]any { + parts := strings.SplitN(key, ".", 2) + if len(parts) == 2 { + return map[string]any{parts[0]: toRecursiveMap(parts[1], value)} + } + return map[string]any{key: value} +} diff --git a/src/k8s/cmd/k8s/k8s_set_test.go b/src/k8s/cmd/k8s/k8s_set_test.go new file mode 100644 index 000000000..52c624840 --- /dev/null +++ b/src/k8s/cmd/k8s/k8s_set_test.go @@ -0,0 +1,158 @@ +package k8s + +import ( + "fmt" + "testing" + + apiv1 "github.com/canonical/k8s/api/v1" + "github.com/canonical/k8s/pkg/utils" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/types" +) + +type mapstructureTestCase struct { + name string + val string + expectErr bool + assertions []types.GomegaMatcher +} + +func generateMapstructureTestCasesBool(keyName string, fieldName string) []mapstructureTestCase { + return []mapstructureTestCase{ + { + val: fmt.Sprintf("%s=true", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(true))}, + }, + { + val: fmt.Sprintf("%s=false", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(false))}, + }, + { + val: fmt.Sprintf("%s=", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(false))}, + }, + { + val: fmt.Sprintf("%s=yes", keyName), + expectErr: true, + }, + } +} + +func generateMapstructureTestCasesStringSlice(keyName string, fieldName string) []mapstructureTestCase { + return []mapstructureTestCase{ + { + val: fmt.Sprintf("%s=", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{}))}, + }, + { + val: fmt.Sprintf("%s=[]", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{}))}, + }, + { + val: fmt.Sprintf("%s=100", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"100"}))}, + }, + { + val: fmt.Sprintf("%s=t1", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1"}))}, + }, + { + val: fmt.Sprintf(`%s=["t1"]`, keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1"}))}, + }, + { + val: fmt.Sprintf("%s=[t1]", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1"}))}, + }, + { + val: fmt.Sprintf("%s=t1, t2", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1", "t2"}))}, + }, + { + val: fmt.Sprintf(`%s=["t1", "t2"]`, keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1", "t2"}))}, + }, + { + val: fmt.Sprintf(`%s=[t1, t2]`, keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer([]string{"t1", "t2"}))}, + }, + } +} + +func generateMapstructureTestCasesString(keyName string, fieldName string) []mapstructureTestCase { + return []mapstructureTestCase{ + { + val: fmt.Sprintf("%s=", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(""))}, + }, + { + val: fmt.Sprintf("%s=t1", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer("t1"))}, + }, + } +} + +func generateMapstructureTestCasesInt(keyName string, fieldName string) []mapstructureTestCase { + return []mapstructureTestCase{ + { + val: fmt.Sprintf("%s=", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(0))}, + }, + { + val: fmt.Sprintf("%s=100", keyName), + assertions: []types.GomegaMatcher{HaveField(fieldName, utils.Pointer(100))}, + }, + { + val: fmt.Sprintf("%s=notanumber", keyName), + expectErr: true, + }, + } +} + +func Test_updateConfigMapstructure(t *testing.T) { + for _, tcs := range [][]mapstructureTestCase{ + generateMapstructureTestCasesBool("dns.enabled", "DNS.Enabled"), + generateMapstructureTestCasesBool("gateway.enabled", "Gateway.Enabled"), + generateMapstructureTestCasesBool("ingress.enable-proxy-protocol", "Ingress.EnableProxyProtocol"), + generateMapstructureTestCasesBool("ingress.enabled", "Ingress.Enabled"), + generateMapstructureTestCasesBool("load-balancer.bgp-mode", "LoadBalancer.BGPMode"), + generateMapstructureTestCasesBool("load-balancer.l2-mode", "LoadBalancer.L2Mode"), + generateMapstructureTestCasesBool("load-balancer.enabled", "LoadBalancer.Enabled"), + generateMapstructureTestCasesBool("load-balancer.enabled", "LoadBalancer.Enabled"), + generateMapstructureTestCasesBool("local-storage.default", "LocalStorage.Default"), + generateMapstructureTestCasesBool("local-storage.enabled", "LocalStorage.Enabled"), + generateMapstructureTestCasesBool("metrics-server.enabled", "MetricsServer.Enabled"), + generateMapstructureTestCasesBool("network.enabled", "Network.Enabled"), + + generateMapstructureTestCasesString("cloud-provider", "CloudProvider"), + generateMapstructureTestCasesString("dns.cluster-domain", "DNS.ClusterDomain"), + generateMapstructureTestCasesString("dns.service-ip", "DNS.ServiceIP"), + generateMapstructureTestCasesString("ingress.default-tls-secret", "Ingress.DefaultTLSSecret"), + generateMapstructureTestCasesString("load-balancer.bgp-peer-address", "LoadBalancer.BGPPeerAddress"), + generateMapstructureTestCasesString("local-storage.local-path", "LocalStorage.LocalPath"), + generateMapstructureTestCasesString("local-storage.reclaim-policy", "LocalStorage.ReclaimPolicy"), + + generateMapstructureTestCasesStringSlice("dns.upstream-nameservers", "DNS.UpstreamNameservers"), + generateMapstructureTestCasesStringSlice("load-balancer.cidrs", "LoadBalancer.CIDRs"), + generateMapstructureTestCasesStringSlice("load-balancer.l2-interfaces", "LoadBalancer.L2Interfaces"), + + generateMapstructureTestCasesInt("load-balancer.bgp-local-asn", "LoadBalancer.BGPLocalASN"), + generateMapstructureTestCasesInt("load-balancer.bgp-peer-asn", "LoadBalancer.BGPPeerASN"), + generateMapstructureTestCasesInt("load-balancer.bgp-peer-port", "LoadBalancer.BGPPeerPort"), + } { + for _, tc := range tcs { + t.Run(tc.val, func(t *testing.T) { + g := NewWithT(t) + + var cfg apiv1.UserFacingClusterConfig + err := updateConfigMapstructure(&cfg, tc.val) + if tc.expectErr { + g.Expect(err).To(HaveOccurred()) + } else { + g.Expect(err).To(BeNil()) + g.Expect(cfg).To(SatisfyAll(tc.assertions...)) + } + }) + } + } +} diff --git a/src/k8s/go.mod b/src/k8s/go.mod index fe870f777..c13c06ee4 100644 --- a/src/k8s/go.mod +++ b/src/k8s/go.mod @@ -6,6 +6,7 @@ require ( github.com/canonical/go-dqlite v1.21.0 github.com/canonical/lxd v0.0.0-20240403135607-df45915ce961 github.com/canonical/microcluster v0.0.0-20240418162032-e0f837527e02 + github.com/mitchellh/mapstructure v1.5.0 github.com/moby/sys/mountinfo v0.7.1 github.com/onsi/gomega v1.30.0 github.com/pelletier/go-toml v1.9.5 diff --git a/src/k8s/go.sum b/src/k8s/go.sum index 0c719333f..336883afe 100644 --- a/src/k8s/go.sum +++ b/src/k8s/go.sum @@ -443,6 +443,8 @@ github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0Qu github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= diff --git a/src/k8s/pkg/utils/mapstructure.go b/src/k8s/pkg/utils/mapstructure.go new file mode 100644 index 000000000..c6ecf1fbc --- /dev/null +++ b/src/k8s/pkg/utils/mapstructure.go @@ -0,0 +1,44 @@ +package utils + +import ( + "reflect" + "strings" + "unicode" + + "github.com/mitchellh/mapstructure" + "gopkg.in/yaml.v2" +) + +// YAMLToStringSliceHookFunc returns a mapstructure.DecodeHookFunc that converts string to []string by parsing YAML. +func YAMLToStringSliceHookFunc(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + if data.(string) == "" { + return data, nil + } + + var result []string + if err := yaml.Unmarshal([]byte(data.(string)), &result); err != nil { + return data, nil + } + + return result, nil +} + +// StringToFieldsSliceHookFunc is like mapstructure.StringToSliceHookFunc() but uses strings.Fields() and filters whitespace. +func StringToFieldsSliceHookFunc(r rune) mapstructure.DecodeHookFunc { + return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + raw := data.(string) + if raw == "" { + return []string{}, nil + } + + return strings.FieldsFunc(raw, func(this rune) bool { return this == r || unicode.IsSpace(this) }), nil + } +}