Skip to content

Commit

Permalink
fix: ensure NillableDuration is round-trippable (#1545)
Browse files Browse the repository at this point in the history
Co-authored-by: Nick Tran <10810510+njtran@users.noreply.github.com>
  • Loading branch information
jmdeal and njtran authored Aug 12, 2024
1 parent ba15aa4 commit feb6857
Show file tree
Hide file tree
Showing 18 changed files with 191 additions and 74 deletions.
34 changes: 28 additions & 6 deletions pkg/apis/v1/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ package v1

import (
"encoding/json"
"fmt"
"slices"
"time"

"github.com/samber/lo"
)

const Never = "Never"
Expand All @@ -28,6 +32,17 @@ const Never = "Never"
// that the duration is disabled and sets the inner duration as nil
type NillableDuration struct {
*time.Duration

// Raw is used to ensure we remarshal the NillableDuration in the same format it was specified.
// This ensures tools like Flux and ArgoCD don't mistakenly detect drift due to our conversion webhooks.
Raw []byte `hash:"ignore"`
}

func MustParseNillableDuration(val string) NillableDuration {
nd := NillableDuration{}
// Use %q instead of %s to ensure that we unmarshal the value as a string and not an int
lo.Must0(json.Unmarshal([]byte(fmt.Sprintf("%q", val)), &nd))
return nd
}

// UnmarshalJSON implements the json.Unmarshaller interface.
Expand All @@ -44,22 +59,29 @@ func (d *NillableDuration) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
d.Raw = slices.Clone(b)
d.Duration = &pd
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (d NillableDuration) MarshalJSON() ([]byte, error) {
if d.Duration == nil {
return json.Marshal(Never)
if d.Raw != nil {
return d.Raw, nil
}
if d.Duration != nil {
return json.Marshal(d.Duration.String())
}
return json.Marshal(d.Duration.String())
return json.Marshal(Never)
}

// ToUnstructured implements the value.UnstructuredConverter interface.
func (d NillableDuration) ToUnstructured() interface{} {
if d.Duration == nil {
return Never
if d.Raw != nil {
return d.Raw
}
if d.Duration != nil {
return d.Duration.String()
}
return d.Duration.String()
return Never
}
6 changes: 3 additions & 3 deletions pkg/apis/v1/nodeclaim_conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ var _ = Describe("Convert V1beta1 to V1 NodeClaim API", func() {

BeforeEach(func() {
v1nodePool = test.NodePool()
v1nodePool.Spec.Template.Spec.ExpireAfter = NillableDuration{Duration: lo.ToPtr(30 * time.Minute)}
v1nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("30m")
v1nodeclaim = &NodeClaim{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
Expand Down Expand Up @@ -306,13 +306,13 @@ var _ = Describe("Convert V1beta1 to V1 NodeClaim API", func() {
})
It("should default the v1beta1 expireAfter to v1 when the nodepool doesn't exist", func() {
Expect(env.Client.Delete(ctx, v1nodePool)).To(Succeed())
v1nodePool.Spec.Template.Spec.ExpireAfter = NillableDuration{Duration: lo.ToPtr(30 * time.Minute)}
v1nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("30m")
Expect(v1nodeclaim.ConvertFrom(ctx, v1beta1nodeclaim)).To(Succeed())
Expect(v1nodeclaim.Spec.ExpireAfter.Duration).To(BeNil())
})
It("should default the v1beta1 expireAfter to v1 when the nodepool label doesn't exist", func() {
delete(v1beta1nodeclaim.Labels, v1beta1.NodePoolLabelKey)
v1nodePool.Spec.Template.Spec.ExpireAfter = NillableDuration{Duration: lo.ToPtr(30 * time.Minute)}
v1nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("30m")
Expect(env.Client.Update(ctx, v1nodePool)).To(Succeed())
Expect(v1nodeclaim.ConvertFrom(ctx, v1beta1nodeclaim)).To(Succeed())
Expect(v1nodeclaim.Spec.ExpireAfter.Duration).To(BeNil())
Expand Down
74 changes: 71 additions & 3 deletions pkg/apis/v1/nodepool_conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,13 @@ var _ = Describe("Convert V1 to V1beta1 NodePool API", func() {
Context("Disruption", func() {
It("should convert v1 nodepool consolidateAfter to nil with WhenEmptyOrUnderutilized", func() {
v1nodepool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmptyOrUnderutilized
v1nodepool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(time.Second * 2121)}
v1nodepool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("2121s")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1beta1nodepool.Spec.Disruption.ConsolidateAfter).To(BeNil())
})
It("should convert v1 nodepool consolidateAfter with WhenEmpty", func() {
v1nodepool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmpty
v1nodepool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(time.Second * 2121)}
v1nodepool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("2121s")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(lo.FromPtr(v1beta1nodepool.Spec.Disruption.ConsolidateAfter.Duration)).To(Equal(lo.FromPtr(v1nodepool.Spec.Disruption.ConsolidateAfter.Duration)))
})
Expand All @@ -234,7 +234,7 @@ var _ = Describe("Convert V1 to V1beta1 NodePool API", func() {
Expect(string(v1beta1nodepool.Spec.Disruption.ConsolidationPolicy)).To(Equal(string(v1nodepool.Spec.Disruption.ConsolidationPolicy)))
})
It("should convert v1 nodepool ExpireAfter", func() {
v1nodepool.Spec.Template.Spec.ExpireAfter = NillableDuration{Duration: lo.ToPtr(time.Second * 2121)}
v1nodepool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("2121s")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1beta1nodepool.Spec.Disruption.ExpireAfter.Duration).To(Equal(v1nodepool.Spec.Template.Spec.ExpireAfter.Duration))
})
Expand Down Expand Up @@ -279,6 +279,40 @@ var _ = Describe("Convert V1 to V1beta1 NodePool API", func() {
Expect(v1beta1nodepool.Status.Resources[resource]).To(Equal(v1nodepool.Status.Resources[resource]))
}
})
Context("Round Trip", func() {
It("spec.template.spec.expireAfter", func() {
v1nodepool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("10h")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1nodepool.Spec.Template.Spec.ExpireAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"10h"`))
})
It("spec.template.spec.expireAfter (Never)", func() {
v1nodepool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("Never")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1nodepool.Spec.Template.Spec.ExpireAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"Never"`))
})
It("spec.disruption.consolidateAfter", func() {
v1nodepool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("10h")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1nodepool.Spec.Disruption.ConsolidateAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"10h"`))
})
It("spec.disruption.consolidateAfter (Never)", func() {
v1nodepool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("Never")
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1nodepool.Spec.Disruption.ConsolidateAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"Never"`))
})
})
})

var _ = Describe("Convert V1beta1 to V1 NodePool API", func() {
Expand Down Expand Up @@ -555,4 +589,38 @@ var _ = Describe("Convert V1beta1 to V1 NodePool API", func() {
Expect(v1beta1nodepool.Status.Resources[resource]).To(Equal(v1nodepool.Status.Resources[resource]))
}
})
Context("Round Trip", func() {
It("spec.disruption.expireAfter", func() {
v1beta1nodepool.Spec.Disruption.ExpireAfter = v1beta1.MustParseNillableDuration("10h")
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1beta1nodepool.Spec.Disruption.ExpireAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"10h"`))
})
It("spec.disruption.expireAfter (Never)", func() {
v1beta1nodepool.Spec.Disruption.ExpireAfter = v1beta1.MustParseNillableDuration("Never")
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(v1beta1nodepool.Spec.Disruption.ExpireAfter)
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"Never"`))
})
It("spec.disruption.consolidateAfter", func() {
v1beta1nodepool.Spec.Disruption.ConsolidateAfter = lo.ToPtr(v1beta1.MustParseNillableDuration("10h"))
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(lo.FromPtr(v1beta1nodepool.Spec.Disruption.ConsolidateAfter))
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"10h"`))
})
It("spec.disruption.consolidateAfter (Never)", func() {
v1beta1nodepool.Spec.Disruption.ConsolidateAfter = lo.ToPtr(v1beta1.MustParseNillableDuration("Never"))
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(lo.FromPtr(v1beta1nodepool.Spec.Disruption.ConsolidateAfter))
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(`"Never"`))
})
})
})
20 changes: 10 additions & 10 deletions pkg/apis/v1/nodepool_validation_cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,47 +64,47 @@ var _ = Describe("CEL/Validation", func() {
})
Context("Disruption", func() {
It("should fail on negative expireAfter", func() {
nodePool.Spec.Template.Spec.ExpireAfter.Duration = lo.ToPtr(lo.Must(time.ParseDuration("-1s")))
nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("-1s")
Expect(env.Client.Create(ctx, nodePool)).ToNot(Succeed())
})
It("should succeed on a disabled expireAfter", func() {
nodePool.Spec.Template.Spec.ExpireAfter.Duration = nil
nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("Never")
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed on a valid expireAfter", func() {
nodePool.Spec.Template.Spec.ExpireAfter.Duration = lo.ToPtr(lo.Must(time.ParseDuration("30s")))
nodePool.Spec.Template.Spec.ExpireAfter = MustParseNillableDuration("30s")
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should fail on negative consolidateAfter", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(lo.Must(time.ParseDuration("-1s")))}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("-1s")
Expect(env.Client.Create(ctx, nodePool)).ToNot(Succeed())
})
It("should succeed on a disabled consolidateAfter", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: nil}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("Never")
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed on a valid consolidateAfter", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(lo.Must(time.ParseDuration("30s")))}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("30s")
nodePool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmpty
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed when setting consolidateAfter with consolidationPolicy=WhenEmpty", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(lo.Must(time.ParseDuration("30s")))}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("30s")
nodePool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmpty
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed when setting consolidateAfter with consolidationPolicy=WhenUnderutilized", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: lo.ToPtr(lo.Must(time.ParseDuration("30s")))}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("30s")
nodePool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmptyOrUnderutilized
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed when setting consolidateAfter to 'Never' with consolidationPolicy=WhenUnderutilized", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: nil}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("Never")
nodePool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmptyOrUnderutilized
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
It("should succeed when setting consolidateAfter to 'Never' with consolidationPolicy=WhenEmpty", func() {
nodePool.Spec.Disruption.ConsolidateAfter = NillableDuration{Duration: nil}
nodePool.Spec.Disruption.ConsolidateAfter = MustParseNillableDuration("Never")
nodePool.Spec.Disruption.ConsolidationPolicy = ConsolidationPolicyWhenEmpty
Expect(env.Client.Create(ctx, nodePool)).To(Succeed())
})
Expand Down
5 changes: 5 additions & 0 deletions pkg/apis/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 28 additions & 6 deletions pkg/apis/v1beta1/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ package v1beta1

import (
"encoding/json"
"fmt"
"slices"
"time"

"github.com/samber/lo"
)

const Never = "Never"
Expand All @@ -28,6 +32,17 @@ const Never = "Never"
// that the duration is disabled and sets the inner duration as nil
type NillableDuration struct {
*time.Duration

// Raw is used to ensure we remarshal the NillableDuration in the same format it was specified.
// This ensures tools like Flux and ArgoCD don't mistakenly detect drift due to our conversion webhooks.
Raw []byte `hash:"ignore"`
}

func MustParseNillableDuration(val string) NillableDuration {
nd := NillableDuration{}
// Use %q instead of %s to ensure that we unmarshal the value as a string and not an int
lo.Must0(json.Unmarshal([]byte(fmt.Sprintf("%q", val)), &nd))
return nd
}

// UnmarshalJSON implements the json.Unmarshaller interface.
Expand All @@ -44,22 +59,29 @@ func (d *NillableDuration) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
d.Raw = slices.Clone(b)
d.Duration = &pd
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (d NillableDuration) MarshalJSON() ([]byte, error) {
if d.Duration == nil {
return json.Marshal(Never)
if d.Raw != nil {
return d.Raw, nil
}
if d.Duration != nil {
return json.Marshal(d.Duration.String())
}
return json.Marshal(d.Duration.String())
return json.Marshal(Never)
}

// ToUnstructured implements the value.UnstructuredConverter interface.
func (d NillableDuration) ToUnstructured() interface{} {
if d.Duration == nil {
return Never
if d.Raw != nil {
return d.Raw
}
if d.Duration != nil {
return d.Duration.String()
}
return d.Duration.String()
return Never
}
5 changes: 5 additions & 0 deletions pkg/apis/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit feb6857

Please sign in to comment.