Skip to content

Commit

Permalink
fix: ensure NillableDuration is round-trippable
Browse files Browse the repository at this point in the history
  • Loading branch information
jmdeal committed Aug 9, 2024
1 parent 5758fa2 commit 14e0a2b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 12 deletions.
27 changes: 21 additions & 6 deletions pkg/apis/v1/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package v1

import (
"encoding/json"
"fmt"
"slices"
"time"
)

Expand All @@ -28,10 +30,15 @@ const Never = "Never"
// that the duration is disabled and sets the inner duration as nil
type NillableDuration struct {
*time.Duration

// Raw is used to ensure we remarshal the NillableDuration in the same format it was specified.
// This ensures tools like Flux and ArgoCD don't mistakenly detect drift due to our conversion webhooks.
Raw []byte
}

// UnmarshalJSON implements the json.Unmarshaller interface.
func (d *NillableDuration) UnmarshalJSON(b []byte) error {
fmt.Printf("raw string (0): %s\n", string(b))
var str string
err := json.Unmarshal(b, &str)
if err != nil {
Expand All @@ -40,26 +47,34 @@ func (d *NillableDuration) UnmarshalJSON(b []byte) error {
if str == Never {
return nil
}
fmt.Printf("raw string (1): %s\n", str)
pd, err := time.ParseDuration(str)
if err != nil {
return err
}
d.Raw = slices.Clone(b)
d.Duration = &pd
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (d NillableDuration) MarshalJSON() ([]byte, error) {
if d.Duration == nil {
return json.Marshal(Never)
if d.Raw != nil {
return d.Raw, nil
}
if d.Duration != nil {
return json.Marshal(d.Duration.String())
}
return json.Marshal(d.Duration.String())
return json.Marshal(Never)
}

// ToUnstructured implements the value.UnstructuredConverter interface.
func (d NillableDuration) ToUnstructured() interface{} {
if d.Duration == nil {
return Never
if d.Raw != nil {
return d.Raw
}
if d.Duration != nil {
return d.Duration.String()
}
return d.Duration.String()
return Never
}
10 changes: 10 additions & 0 deletions pkg/apis/v1/nodepool_conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,13 @@ func (in *NodeClaimTemplate) convertFrom(ctx context.Context, v1beta1np *v1beta1

return "", nil
}

// func (in *NillableDuration) convertFrom(src *v1beta1.NillableDuration) {
// in.Raw = slices.Clone(src.Raw)
// in.Duration = lo.ToPtr(*src.Duration)
// }
//
// func (in *NillableDuration) convertTo(dst *v1beta1.NillableDuration) {
// dst.Raw = slices.Clone(in.Raw)
// dst.Duration = lo.ToPtr(*in.Duration)
// }
42 changes: 42 additions & 0 deletions pkg/apis/v1/nodepool_conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package v1_test

import (
"encoding/json"
"fmt"
"time"

"sigs.k8s.io/karpenter/pkg/test/v1alpha1"
Expand Down Expand Up @@ -269,6 +270,24 @@ var _ = Describe("Convert V1 to V1beta1 NodePool API", func() {
Expect(v1beta1nodepool.Status.Resources[resource]).To(Equal(v1nodepool.Status.Resources[resource]))
}
})
Context("Round Trip", func() {
DescribeTable(
"NillableDuration",
func(value string, field func() *NillableDuration) {
str := fmt.Sprintf("%q", value)
duration := NillableDuration{}
Expect(json.Unmarshal([]byte(str), &duration)).Should(Succeed())
*field() = duration
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(*field())
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(str))
},
Entry("spec.template.spec.expireAfter", "720h", func() *NillableDuration { return &v1nodepool.Spec.Template.Spec.ExpireAfter }),
Entry("spec.disruption.consolidateAfter", "15m", func() *NillableDuration { return &v1nodepool.Spec.Disruption.ConsolidateAfter }),
)
})
})

var _ = Describe("Convert V1beta1 to V1 NodePool API", func() {
Expand Down Expand Up @@ -537,4 +556,27 @@ var _ = Describe("Convert V1beta1 to V1 NodePool API", func() {
Expect(v1beta1nodepool.Status.Resources[resource]).To(Equal(v1nodepool.Status.Resources[resource]))
}
})
Context("Round Trip", func() {
DescribeTable(
"NillableDuration",
func(value string, field func() *v1beta1.NillableDuration) {
str := fmt.Sprintf("%q", value)
duration := v1beta1.NillableDuration{}
Expect(json.Unmarshal([]byte(str), &duration)).Should(Succeed())
*field() = duration
Expect(v1nodepool.ConvertFrom(ctx, v1beta1nodepool)).To(Succeed())
Expect(v1nodepool.ConvertTo(ctx, v1beta1nodepool)).To(Succeed())
result, err := json.Marshal(*field())
Expect(err).To(BeNil())
Expect(string(result)).To(Equal(str))
},
Entry("spec.template.spec.expireAfter", "720h", func() *v1beta1.NillableDuration { return &v1beta1nodepool.Spec.Disruption.ExpireAfter }),
Entry("spec.disruption.consolidateAfter", "15m", func() *v1beta1.NillableDuration {
if v1beta1nodepool.Spec.Disruption.ConsolidateAfter == nil {
v1beta1nodepool.Spec.Disruption.ConsolidateAfter = &v1beta1.NillableDuration{}
}
return v1beta1nodepool.Spec.Disruption.ConsolidateAfter
}),
)
})
})
5 changes: 5 additions & 0 deletions pkg/apis/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 21 additions & 6 deletions pkg/apis/v1beta1/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package v1beta1

import (
"encoding/json"
"fmt"
"slices"
"time"
)

Expand All @@ -28,10 +30,15 @@ const Never = "Never"
// that the duration is disabled and sets the inner duration as nil
type NillableDuration struct {
*time.Duration

// Raw is used to ensure we remarshal the NillableDuration in the same format it was specified.
// This ensures tools like Flux and ArgoCD don't mistakenly detect drift due to our conversion webhooks.
Raw []byte
}

// UnmarshalJSON implements the json.Unmarshaller interface.
func (d *NillableDuration) UnmarshalJSON(b []byte) error {
fmt.Printf("raw string (0): %s\n", string(b))
var str string
err := json.Unmarshal(b, &str)
if err != nil {
Expand All @@ -40,26 +47,34 @@ func (d *NillableDuration) UnmarshalJSON(b []byte) error {
if str == Never {
return nil
}
fmt.Printf("raw string (1): %s\n", str)
pd, err := time.ParseDuration(str)
if err != nil {
return err
}
d.Raw = slices.Clone(b)
d.Duration = &pd
return nil
}

// MarshalJSON implements the json.Marshaler interface.
func (d NillableDuration) MarshalJSON() ([]byte, error) {
if d.Duration == nil {
return json.Marshal(Never)
if d.Raw != nil {
return d.Raw, nil
}
if d.Duration != nil {
return json.Marshal(d.Duration.String())
}
return json.Marshal(d.Duration.String())
return json.Marshal(Never)
}

// ToUnstructured implements the value.UnstructuredConverter interface.
func (d NillableDuration) ToUnstructured() interface{} {
if d.Duration == nil {
return Never
if d.Raw != nil {
return d.Raw
}
if d.Duration != nil {
return d.Duration.String()
}
return d.Duration.String()
return Never
}
5 changes: 5 additions & 0 deletions pkg/apis/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 14e0a2b

Please sign in to comment.