From 4ad35aa8c5d0dd62d40c457248edede0d9260bc4 Mon Sep 17 00:00:00 2001 From: Etienne Audet-Cobello Date: Wed, 25 Sep 2024 14:25:42 -0400 Subject: [PATCH] refactor --- src/k8s/cmd/k8s/k8s_bootstrap.go | 36 +++++++++++++-------------- src/k8s/cmd/k8s/k8s_bootstrap_test.go | 26 +++++++++++++++++++ src/k8s/pkg/utils/cidr.go | 5 +++- src/k8s/pkg/utils/cidr_test.go | 18 +++++++++----- 4 files changed, 59 insertions(+), 26 deletions(-) diff --git a/src/k8s/cmd/k8s/k8s_bootstrap.go b/src/k8s/cmd/k8s/k8s_bootstrap.go index 1cfbda487e..55bfe6a189 100644 --- a/src/k8s/cmd/k8s/k8s_bootstrap.go +++ b/src/k8s/cmd/k8s/k8s_bootstrap.go @@ -128,10 +128,12 @@ func newBootstrapCmd(env cmdutil.ExecutionEnvironment) *cobra.Command { } } - if err = validateCIDROverlapAndSize(*bootstrapConfig.PodCIDR, *bootstrapConfig.ServiceCIDR); err != nil { - cmd.PrintErrf("Error: Failed to validate the CIDR configuration.\n\nThe error was: %v\n", err) - env.Exit(1) - return + if bootstrapConfig.PodCIDR != nil && bootstrapConfig.ServiceCIDR != nil { + if err = validateCIDROverlapAndSize(*bootstrapConfig.PodCIDR, *bootstrapConfig.ServiceCIDR); err != nil { + cmd.PrintErrf("Error: Failed to validate the CIDR configuration.\n\nThe error was: %v\n", err) + env.Exit(1) + return + } } cmd.PrintErrln("Bootstrapping the cluster. This may take a few seconds, please wait.") @@ -277,27 +279,22 @@ func askQuestion(stdin io.Reader, stdout io.Writer, stderr io.Writer, question s } } +// validateCIDROverlapAndSize checks for overlap and size constraints between pod and service CIDRs. +// It parses the provided podCIDR and serviceCIDR strings, checks for IPv4 and IPv6 overlaps, and ensures +// that the service IPv6 CIDR does not have a prefix length of 64 or more. func validateCIDROverlapAndSize(podCIDR string, serviceCIDR string) error { // Parse the CIDRs - var err error - - var podIPv4CIDR, podIPv6CIDR string - if podCIDR != "" { - podIPv4CIDR, podIPv6CIDR, err = utils.ParseCIDRs(podCIDR) - if err != nil { - return err - } + podIPv4CIDR, podIPv6CIDR, err := utils.ParseCIDRs(podCIDR) + if err != nil { + return err } - var svcIPv4CIDR, svcIPv6CIDR string - if serviceCIDR != "" { - svcIPv4CIDR, svcIPv6CIDR, err = utils.ParseCIDRs(serviceCIDR) - if err != nil { - return err - } + svcIPv4CIDR, svcIPv6CIDR, err := utils.ParseCIDRs(serviceCIDR) + if err != nil { + return err } - // Check for overlap + // Check for IPv4 overlap if podIPv4CIDR != "" && svcIPv4CIDR != "" { if overlap, err := utils.CIDRsOverlap(podIPv4CIDR, svcIPv4CIDR); err != nil { return err @@ -306,6 +303,7 @@ func validateCIDROverlapAndSize(podCIDR string, serviceCIDR string) error { } } + // Check for IPv6 overlap if podIPv6CIDR != "" && svcIPv6CIDR != "" { if overlap, err := utils.CIDRsOverlap(podIPv6CIDR, svcIPv6CIDR); err != nil { return err diff --git a/src/k8s/cmd/k8s/k8s_bootstrap_test.go b/src/k8s/cmd/k8s/k8s_bootstrap_test.go index b486beb185..21d563b20a 100644 --- a/src/k8s/cmd/k8s/k8s_bootstrap_test.go +++ b/src/k8s/cmd/k8s/k8s_bootstrap_test.go @@ -155,3 +155,29 @@ func TestGetConfigFromYaml_Stdin(t *testing.T) { expectedConfig := apiv1.BootstrapConfig{SecurePort: utils.Pointer(5000)} g.Expect(config).To(Equal(expectedConfig)) } + +func TestValidateCIDROverlapAndSize(t *testing.T) { + tests := []struct { + name string + podCIDR string + serviceCIDR string + expectErr bool + }{ + {"Empty", "", "", true}, + {"SameIPv4CIDRs", "192.168.100.0/24", "192.168.100.0/24", true}, + {"DifferentIPv4CIDRs", "192.0.2.0/24", "192.0.15.0/24", false}, + {"OverlappingIPv4CIDRs", "10.2.0.13/24", "10.2.0.0/24", true}, + {"SameIPv6CIDRs", "fe80::1/64", "fe80::1/64", true}, + {"DifferentIPv6CIDRs", "fe80::/64", "2001:db8::/32", false}, + {"OverlappingIPv6CIDRs", "fe80::/64", "fe80::dead/64", true}, + {"IPv6CIDRPrefixAtLimit", "192.168.1.0/24", "192.168.2.0/24,fe80::/64", true}, + {"IPv6CIDRPrefixGreaterThanLimit", "192.168.1.0/24", "fe80::/68", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if err := validateCIDROverlapAndSize(tc.podCIDR, tc.serviceCIDR); (err != nil) != tc.expectErr { + t.Errorf("validateCIDROverlapAndSize() error = %v, expectErr %v", err, tc.expectErr) + } + }) + } +} diff --git a/src/k8s/pkg/utils/cidr.go b/src/k8s/pkg/utils/cidr.go index 6bb852e3fb..4d969d7cf9 100644 --- a/src/k8s/pkg/utils/cidr.go +++ b/src/k8s/pkg/utils/cidr.go @@ -143,12 +143,15 @@ func ToIPString(ip net.IP) string { return "[" + ip.String() + "]" } +// CIDRsOverlap checks if two given CIDR blocks overlap. +// It takes two strings representing the CIDR blocks as input and returns a boolean indicating +// whether they overlap and an error if any of the CIDR blocks are invalid. func CIDRsOverlap(cidr1, cidr2 string) (bool, error) { _, ipNet1, err1 := net.ParseCIDR(cidr1) _, ipNet2, err2 := net.ParseCIDR(cidr2) if err1 != nil || err2 != nil { - return false, fmt.Errorf("invalid CIDR blocks") + return false, fmt.Errorf("invalid CIDR blocks, %v, %v", err1, err2) } if ipNet1.Contains(ipNet2.IP) || ipNet2.Contains(ipNet1.IP) { diff --git a/src/k8s/pkg/utils/cidr_test.go b/src/k8s/pkg/utils/cidr_test.go index ebaef7372b..4df6dc76e3 100644 --- a/src/k8s/pkg/utils/cidr_test.go +++ b/src/k8s/pkg/utils/cidr_test.go @@ -210,18 +210,24 @@ func TestCIDRsOverlap(t *testing.T) { g := NewWithT(t) tests := []struct { + name string cidr1 string cidr2 string expected bool }{ - {"192.168.100.0/24", "192.168.100.0/24", true}, - {"192.168.1.0/32", "192.168.5.0/32", false}, - {"fe80::1/64", "fe80::1/64", true}, - {"fe80::/64", "2001:db8::/32", false}, - {"fe80::/64", "fe80::dead/64", true}, + {"SameIPv4CIDRs", "192.168.100.0/24", "192.168.100.0/24", true}, + {"DifferentIPv4CIDRs", "192.0.2.0/24", "192.0.15.0/24", false}, + {"OverlappingIPv4CIDRs", "10.2.0.13/24", "10.2.0.0/24", true}, + {"SameIPv6CIDRs", "fe80::1/64", "fe80::1/64", true}, + {"DifferentIPv6CIDRs", "fe80::/64", "2001:db8::/32", false}, + {"OverlappingIPv6CIDRs", "fe80::/64", "fe80::dead/64", true}, } for _, tc := range tests { - g.Expect(utils.CIDRsOverlap(tc.cidr1, tc.cidr2)).To(Equal(tc.expected)) + t.Run(tc.name, func(t *testing.T) { + overlap, err := utils.CIDRsOverlap(tc.cidr1, tc.cidr2) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(overlap).To(Equal(tc.expected)) + }) } }