diff --git a/libs/config/convert/end_to_end_test.go b/libs/config/convert/end_to_end_test.go new file mode 100644 index 0000000000..c06830e83e --- /dev/null +++ b/libs/config/convert/end_to_end_test.go @@ -0,0 +1,61 @@ +package convert + +import ( + "testing" + + "github.com/databricks/cli/libs/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func assertFromTypedToTypedEqual[T any](t *testing.T, src T) { + nv, err := FromTyped(src, config.NilValue) + require.NoError(t, err) + + var dst T + err = ToTyped(&dst, nv) + require.NoError(t, err) + assert.Equal(t, src, dst) +} + +func TestAdditional(t *testing.T) { + type StructType struct { + Str string `json:"str"` + } + + type Tmp struct { + MapToPointer map[string]*string `json:"map_to_pointer"` + SliceOfPointer []*string `json:"slice_of_pointer"` + NestedStruct StructType `json:"nested_struct"` + } + + t.Run("nil", func(t *testing.T) { + assertFromTypedToTypedEqual(t, Tmp{}) + }) + + t.Run("empty map", func(t *testing.T) { + assertFromTypedToTypedEqual(t, Tmp{ + MapToPointer: map[string]*string{}, + }) + }) + + t.Run("map with nil value", func(t *testing.T) { + assertFromTypedToTypedEqual(t, Tmp{ + MapToPointer: map[string]*string{ + "key": nil, + }, + }) + }) + + t.Run("empty slice", func(t *testing.T) { + assertFromTypedToTypedEqual(t, Tmp{ + SliceOfPointer: []*string{}, + }) + }) + + t.Run("slice with nil value", func(t *testing.T) { + assertFromTypedToTypedEqual(t, Tmp{ + SliceOfPointer: []*string{nil}, + }) + }) +} diff --git a/libs/config/convert/from_typed.go b/libs/config/convert/from_typed.go new file mode 100644 index 0000000000..e3911a9e5a --- /dev/null +++ b/libs/config/convert/from_typed.go @@ -0,0 +1,214 @@ +package convert + +import ( + "fmt" + "reflect" + + "github.com/databricks/cli/libs/config" +) + +// FromTyped converts changes made in the typed structure w.r.t. the configuration value +// back to the configuration value, retaining existing location information where possible. +func FromTyped(src any, ref config.Value) (config.Value, error) { + srcv := reflect.ValueOf(src) + + // Dereference pointer if necessary + for srcv.Kind() == reflect.Pointer { + if srcv.IsNil() { + return config.NilValue, nil + } + srcv = srcv.Elem() + } + + switch srcv.Kind() { + case reflect.Struct: + return fromTypedStruct(srcv, ref) + case reflect.Map: + return fromTypedMap(srcv, ref) + case reflect.Slice: + return fromTypedSlice(srcv, ref) + case reflect.String: + return fromTypedString(srcv, ref) + case reflect.Bool: + return fromTypedBool(srcv, ref) + case reflect.Int, reflect.Int32, reflect.Int64: + return fromTypedInt(srcv, ref) + case reflect.Float32, reflect.Float64: + return fromTypedFloat(srcv, ref) + } + + return config.NilValue, fmt.Errorf("unsupported type: %s", srcv.Kind()) +} + +func fromTypedStruct(src reflect.Value, ref config.Value) (config.Value, error) { + // Check that the reference value is compatible or nil. + switch ref.Kind() { + case config.KindMap, config.KindNil: + default: + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) + } + + out := make(map[string]config.Value) + info := getStructInfo(src.Type()) + for k, v := range info.FieldValues(src) { + // Convert the field taking into account the reference value (may be equal to config.NilValue). + nv, err := FromTyped(v.Interface(), ref.Get(k)) + if err != nil { + return config.Value{}, err + } + + if nv != config.NilValue { + out[k] = nv + } + } + + // If the struct was equal to its zero value, emit a nil. + if len(out) == 0 { + return config.NilValue, nil + } + + return config.NewValue(out, ref.Location()), nil +} + +func fromTypedMap(src reflect.Value, ref config.Value) (config.Value, error) { + // Check that the reference value is compatible or nil. + switch ref.Kind() { + case config.KindMap, config.KindNil: + default: + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) + } + + // Return nil if the map is nil. + if src.IsNil() { + return config.NilValue, nil + } + + out := make(map[string]config.Value) + iter := src.MapRange() + for iter.Next() { + k := iter.Key().String() + v := iter.Value() + + // Convert entry taking into account the reference value (may be equal to config.NilValue). + nv, err := FromTyped(v.Interface(), ref.Get(k)) + if err != nil { + return config.Value{}, err + } + + // Every entry is represented, even if it is a nil. + // Otherwise, a map with zero-valued structs would yield a nil as well. + out[k] = nv + } + + return config.NewValue(out, ref.Location()), nil +} + +func fromTypedSlice(src reflect.Value, ref config.Value) (config.Value, error) { + // Check that the reference value is compatible or nil. + switch ref.Kind() { + case config.KindSequence, config.KindNil: + default: + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) + } + + // Return nil if the slice is nil. + if src.IsNil() { + return config.NilValue, nil + } + + out := make([]config.Value, src.Len()) + for i := 0; i < src.Len(); i++ { + v := src.Index(i) + + // Convert entry taking into account the reference value (may be equal to config.NilValue). + nv, err := FromTyped(v.Interface(), ref.Index(i)) + if err != nil { + return config.Value{}, err + } + + out[i] = nv + } + + return config.NewValue(out, ref.Location()), nil +} + +func fromTypedString(src reflect.Value, ref config.Value) (config.Value, error) { + switch ref.Kind() { + case config.KindString: + value := src.String() + if value == ref.MustString() { + return ref, nil + } + + return config.V(value), nil + case config.KindNil: + // This field is not set in the reference, so we only include it if it has a non-zero value. + // Otherwise, we would always include all zero valued fields. + if src.IsZero() { + return config.NilValue, nil + } + return config.V(src.String()), nil + } + + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) +} + +func fromTypedBool(src reflect.Value, ref config.Value) (config.Value, error) { + switch ref.Kind() { + case config.KindBool: + value := src.Bool() + if value == ref.MustBool() { + return ref, nil + } + return config.V(value), nil + case config.KindNil: + // This field is not set in the reference, so we only include it if it has a non-zero value. + // Otherwise, we would always include all zero valued fields. + if src.IsZero() { + return config.NilValue, nil + } + return config.V(src.Bool()), nil + } + + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) +} + +func fromTypedInt(src reflect.Value, ref config.Value) (config.Value, error) { + switch ref.Kind() { + case config.KindInt: + value := src.Int() + if value == ref.MustInt() { + return ref, nil + } + return config.V(value), nil + case config.KindNil: + // This field is not set in the reference, so we only include it if it has a non-zero value. + // Otherwise, we would always include all zero valued fields. + if src.IsZero() { + return config.NilValue, nil + } + return config.V(src.Int()), nil + } + + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) +} + +func fromTypedFloat(src reflect.Value, ref config.Value) (config.Value, error) { + switch ref.Kind() { + case config.KindFloat: + value := src.Float() + if value == ref.MustFloat() { + return ref, nil + } + return config.V(value), nil + case config.KindNil: + // This field is not set in the reference, so we only include it if it has a non-zero value. + // Otherwise, we would always include all zero valued fields. + if src.IsZero() { + return config.NilValue, nil + } + return config.V(src.Float()), nil + } + + return config.Value{}, fmt.Errorf("unhandled type: %s", ref.Kind()) +} diff --git a/libs/config/convert/from_typed_test.go b/libs/config/convert/from_typed_test.go new file mode 100644 index 0000000000..2b28f549cd --- /dev/null +++ b/libs/config/convert/from_typed_test.go @@ -0,0 +1,394 @@ +package convert + +import ( + "testing" + + "github.com/databricks/cli/libs/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFromTypedStructZeroFields(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + + src := Tmp{} + ref := config.NilValue + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedStructSetFields(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + + src := Tmp{ + Foo: "foo", + Bar: "bar", + } + + ref := config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(map[string]config.Value{ + "foo": config.V("foo"), + "bar": config.V("bar"), + }), nv) +} + +func TestFromTypedStructSetFieldsRetainLocationIfUnchanged(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + + src := Tmp{ + Foo: "bar", + Bar: "qux", + } + + ref := config.V(map[string]config.Value{ + "foo": config.NewValue("bar", config.Location{File: "foo"}), + "bar": config.NewValue("baz", config.Location{File: "bar"}), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + + // Assert foo has retained its location. + assert.Equal(t, config.NewValue("bar", config.Location{File: "foo"}), nv.Get("foo")) + + // Assert bar lost its location (because it was overwritten). + assert.Equal(t, config.NewValue("qux", config.Location{}), nv.Get("bar")) +} + +func TestFromTypedMapNil(t *testing.T) { + var src map[string]string = nil + + ref := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedMapEmpty(t *testing.T) { + var src = map[string]string{} + + ref := config.V(map[string]config.Value{ + "foo": config.V("bar"), + "bar": config.V("baz"), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(map[string]config.Value{}), nv) +} + +func TestFromTypedMapNonEmpty(t *testing.T) { + var src = map[string]string{ + "foo": "foo", + "bar": "bar", + } + + ref := config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(map[string]config.Value{ + "foo": config.V("foo"), + "bar": config.V("bar"), + }), nv) +} + +func TestFromTypedMapNonEmptyRetainLocationIfUnchanged(t *testing.T) { + var src = map[string]string{ + "foo": "bar", + "bar": "qux", + } + + ref := config.V(map[string]config.Value{ + "foo": config.NewValue("bar", config.Location{File: "foo"}), + "bar": config.NewValue("baz", config.Location{File: "bar"}), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + + // Assert foo has retained its location. + assert.Equal(t, config.NewValue("bar", config.Location{File: "foo"}), nv.Get("foo")) + + // Assert bar lost its location (because it was overwritten). + assert.Equal(t, config.NewValue("qux", config.Location{}), nv.Get("bar")) +} + +func TestFromTypedMapFieldWithZeroValue(t *testing.T) { + var src = map[string]string{ + "foo": "", + } + + ref := config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(map[string]config.Value{ + "foo": config.NilValue, + }), nv) +} + +func TestFromTypedSliceNil(t *testing.T) { + var src []string = nil + + ref := config.V([]config.Value{ + config.V("bar"), + config.V("baz"), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedSliceEmpty(t *testing.T) { + var src = []string{} + + ref := config.V([]config.Value{ + config.V("bar"), + config.V("baz"), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V([]config.Value{}), nv) +} + +func TestFromTypedSliceNonEmpty(t *testing.T) { + var src = []string{ + "foo", + "bar", + } + + ref := config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V([]config.Value{ + config.V("foo"), + config.V("bar"), + }), nv) +} + +func TestFromTypedSliceNonEmptyRetainLocationIfUnchanged(t *testing.T) { + var src = []string{ + "foo", + "bar", + } + + ref := config.V([]config.Value{ + config.NewValue("foo", config.Location{File: "foo"}), + config.NewValue("baz", config.Location{File: "baz"}), + }) + + nv, err := FromTyped(src, ref) + require.NoError(t, err) + + // Assert foo has retained its location. + assert.Equal(t, config.NewValue("foo", config.Location{File: "foo"}), nv.Index(0)) + + // Assert bar lost its location (because it was overwritten). + assert.Equal(t, config.NewValue("bar", config.Location{}), nv.Index(1)) +} + +func TestFromTypedStringEmpty(t *testing.T) { + var src string + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedStringEmptyOverwrite(t *testing.T) { + var src string + var ref = config.V("old") + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(""), nv) +} + +func TestFromTypedStringNonEmpty(t *testing.T) { + var src string = "new" + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V("new"), nv) +} + +func TestFromTypedStringNonEmptyOverwrite(t *testing.T) { + var src string = "new" + var ref = config.V("old") + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V("new"), nv) +} + +func TestFromTypedStringRetainsLocationsIfUnchanged(t *testing.T) { + var src string = "foo" + var ref = config.NewValue("foo", config.Location{File: "foo"}) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NewValue("foo", config.Location{File: "foo"}), nv) +} + +func TestFromTypedStringTypeError(t *testing.T) { + var src string = "foo" + var ref = config.V(1234) + _, err := FromTyped(src, ref) + require.Error(t, err) +} + +func TestFromTypedBoolEmpty(t *testing.T) { + var src bool + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedBoolEmptyOverwrite(t *testing.T) { + var src bool + var ref = config.V(true) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(false), nv) +} + +func TestFromTypedBoolNonEmpty(t *testing.T) { + var src bool = true + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(true), nv) +} + +func TestFromTypedBoolNonEmptyOverwrite(t *testing.T) { + var src bool = true + var ref = config.V(false) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(true), nv) +} + +func TestFromTypedBoolRetainsLocationsIfUnchanged(t *testing.T) { + var src bool = true + var ref = config.NewValue(true, config.Location{File: "foo"}) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NewValue(true, config.Location{File: "foo"}), nv) +} + +func TestFromTypedBoolTypeError(t *testing.T) { + var src bool = true + var ref = config.V("string") + _, err := FromTyped(src, ref) + require.Error(t, err) +} + +func TestFromTypedIntEmpty(t *testing.T) { + var src int + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedIntEmptyOverwrite(t *testing.T) { + var src int + var ref = config.V(1234) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(int64(0)), nv) +} + +func TestFromTypedIntNonEmpty(t *testing.T) { + var src int = 1234 + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(int64(1234)), nv) +} + +func TestFromTypedIntNonEmptyOverwrite(t *testing.T) { + var src int = 1234 + var ref = config.V(1233) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(int64(1234)), nv) +} + +func TestFromTypedIntRetainsLocationsIfUnchanged(t *testing.T) { + var src int = 1234 + var ref = config.NewValue(1234, config.Location{File: "foo"}) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NewValue(1234, config.Location{File: "foo"}), nv) +} + +func TestFromTypedIntTypeError(t *testing.T) { + var src int = 1234 + var ref = config.V("string") + _, err := FromTyped(src, ref) + require.Error(t, err) +} + +func TestFromTypedFloatEmpty(t *testing.T) { + var src float64 + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NilValue, nv) +} + +func TestFromTypedFloatEmptyOverwrite(t *testing.T) { + var src float64 + var ref = config.V(1.23) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(0.0), nv) +} + +func TestFromTypedFloatNonEmpty(t *testing.T) { + var src float64 = 1.23 + var ref = config.NilValue + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(1.23), nv) +} + +func TestFromTypedFloatNonEmptyOverwrite(t *testing.T) { + var src float64 = 1.23 + var ref = config.V(1.24) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.V(1.23), nv) +} + +func TestFromTypedFloatRetainsLocationsIfUnchanged(t *testing.T) { + var src float64 = 1.23 + var ref = config.NewValue(1.23, config.Location{File: "foo"}) + nv, err := FromTyped(src, ref) + require.NoError(t, err) + assert.Equal(t, config.NewValue(1.23, config.Location{File: "foo"}), nv) +} + +func TestFromTypedFloatTypeError(t *testing.T) { + var src float64 = 1.23 + var ref = config.V("string") + _, err := FromTyped(src, ref) + require.Error(t, err) +} diff --git a/libs/config/convert/struct_info.go b/libs/config/convert/struct_info.go index 367b9ecdc4..2457b3c297 100644 --- a/libs/config/convert/struct_info.go +++ b/libs/config/convert/struct_info.go @@ -85,3 +85,31 @@ func buildStructInfo(typ reflect.Type) structInfo { return out } + +func (s *structInfo) FieldValues(v reflect.Value) map[string]reflect.Value { + var out = make(map[string]reflect.Value) + + for k, index := range s.Fields { + fv := v + + // Locate value in struct (it could be an embedded type). + for i, x := range index { + if i > 0 { + if fv.Kind() == reflect.Pointer && fv.Type().Elem().Kind() == reflect.Struct { + if fv.IsNil() { + fv = reflect.Value{} + break + } + fv = fv.Elem() + } + } + fv = fv.Field(x) + } + + if fv.IsValid() { + out[k] = fv + } + } + + return out +} diff --git a/libs/config/convert/struct_info_test.go b/libs/config/convert/struct_info_test.go index 3079958b2b..2e31adac16 100644 --- a/libs/config/convert/struct_info_test.go +++ b/libs/config/convert/struct_info_test.go @@ -87,3 +87,110 @@ func TestStructInfoAnonymousByPointer(t *testing.T) { assert.Equal(t, []int{0, 0}, si.Fields["foo"]) assert.Equal(t, []int{0, 1, 0}, si.Fields["bar"]) } + +func TestStructInfoFieldValues(t *testing.T) { + type Tmp struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + + var src = Tmp{ + Foo: "foo", + Bar: "bar", + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + fv := si.FieldValues(reflect.ValueOf(src)) + assert.Len(t, fv, 2) + assert.Equal(t, "foo", fv["foo"].String()) + assert.Equal(t, "bar", fv["bar"].String()) +} + +func TestStructInfoFieldValuesAnonymousByValue(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + Bar + } + + type Tmp struct { + Foo + } + + var src = Tmp{ + Foo: Foo{ + Foo: "foo", + Bar: Bar{ + Bar: "bar", + }, + }, + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + fv := si.FieldValues(reflect.ValueOf(src)) + assert.Len(t, fv, 2) + assert.Equal(t, "foo", fv["foo"].String()) + assert.Equal(t, "bar", fv["bar"].String()) +} + +func TestStructInfoFieldValuesAnonymousByPointer(t *testing.T) { + type Bar struct { + Bar string `json:"bar"` + } + + type Foo struct { + Foo string `json:"foo"` + *Bar + } + + type Tmp struct { + *Foo + } + + // Test that the embedded fields are dereferenced properly. + t.Run("all are set", func(t *testing.T) { + src := Tmp{ + Foo: &Foo{ + Foo: "foo", + Bar: &Bar{ + Bar: "bar", + }, + }, + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + fv := si.FieldValues(reflect.ValueOf(src)) + assert.Len(t, fv, 2) + assert.Equal(t, "foo", fv["foo"].String()) + assert.Equal(t, "bar", fv["bar"].String()) + }) + + // Test that fields of embedded types are skipped if the embedded type is nil. + t.Run("top level is set", func(t *testing.T) { + src := Tmp{ + Foo: &Foo{ + Foo: "foo", + Bar: nil, + }, + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + fv := si.FieldValues(reflect.ValueOf(src)) + assert.Len(t, fv, 1) + assert.Equal(t, "foo", fv["foo"].String()) + }) + + // Test that fields of embedded types are skipped if the embedded type is nil. + t.Run("none are set", func(t *testing.T) { + src := Tmp{ + Foo: nil, + } + + si := getStructInfo(reflect.TypeOf(Tmp{})) + fv := si.FieldValues(reflect.ValueOf(src)) + assert.Empty(t, fv) + }) +} diff --git a/libs/config/convert/to_typed.go b/libs/config/convert/to_typed.go index 9915d30a6e..ca09fce42b 100644 --- a/libs/config/convert/to_typed.go +++ b/libs/config/convert/to_typed.go @@ -13,6 +13,12 @@ func ToTyped(dst any, src config.Value) error { // Dereference pointer if necessary for dstv.Kind() == reflect.Pointer { + // If the source value is nil and the destination is a settable pointer, + // set the destination to nil. Also see `end_to_end_test.go`. + if dstv.CanSet() && src == config.NilValue { + dstv.SetZero() + return nil + } if dstv.IsNil() { dstv.Set(reflect.New(dstv.Type().Elem())) }