diff --git a/libs/dyn/convert/end_to_end_test.go b/libs/dyn/convert/end_to_end_test.go index 7c048136ee..33902bea85 100644 --- a/libs/dyn/convert/end_to_end_test.go +++ b/libs/dyn/convert/end_to_end_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/convert/from_typed.go b/libs/dyn/convert/from_typed.go index 4778edb960..c344d12dff 100644 --- a/libs/dyn/convert/from_typed.go +++ b/libs/dyn/convert/from_typed.go @@ -71,17 +71,28 @@ func fromTypedStruct(src reflect.Value, ref dyn.Value) (dyn.Value, error) { return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind()) } - out := make(map[string]dyn.Value) + refm, _ := ref.AsMap() + out := dyn.NewMapping() info := getStructInfo(src.Type()) for k, v := range info.FieldValues(src) { + pair, ok := refm.GetPairByString(k) + refk := pair.Key + refv := pair.Value + + // Use nil reference if there is no reference for this key + if !ok { + refk = dyn.V(k) + refv = dyn.NilValue + } + // Convert the field taking into account the reference value (may be equal to config.NilValue). - nv, err := fromTyped(v.Interface(), ref.Get(k)) + nv, err := fromTyped(v.Interface(), refv) if err != nil { return dyn.InvalidValue, err } if nv != dyn.NilValue { - out[k] = nv + out.Set(refk, nv) } } @@ -101,21 +112,31 @@ func fromTypedMap(src reflect.Value, ref dyn.Value) (dyn.Value, error) { return dyn.NilValue, nil } - out := make(map[string]dyn.Value) + refm, _ := ref.AsMap() + out := dyn.NewMapping() iter := src.MapRange() for iter.Next() { k := iter.Key().String() v := iter.Value() + pair, ok := refm.GetPairByString(k) + refk := pair.Key + refv := pair.Value + + // Use nil reference if there is no reference for this key + if !ok { + refk = dyn.V(k) + refv = dyn.NilValue + } // Convert entry taking into account the reference value (may be equal to dyn.NilValue). - nv, err := fromTyped(v.Interface(), ref.Get(k), includeZeroValues) + nv, err := fromTyped(v.Interface(), refv, includeZeroValues) if err != nil { return dyn.InvalidValue, err } // Every entry is represented, even if it is a nil. // Otherwise, a map with zero-valued structs would yield a nil as well. - out[k] = nv + out.Set(refk, nv) } return dyn.NewValue(out, ref.Location()), nil diff --git a/libs/dyn/convert/from_typed_test.go b/libs/dyn/convert/from_typed_test.go index f7e97fc7e7..f75470f420 100644 --- a/libs/dyn/convert/from_typed_test.go +++ b/libs/dyn/convert/from_typed_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/convert/normalize.go b/libs/dyn/convert/normalize.go index d6539be952..f18b27fd24 100644 --- a/libs/dyn/convert/normalize.go +++ b/libs/dyn/convert/normalize.go @@ -74,30 +74,32 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen switch src.Kind() { case dyn.KindMap: - out := make(map[string]dyn.Value) + out := dyn.NewMapping() info := getStructInfo(typ) - for k, v := range src.MustMap() { - index, ok := info.Fields[k] + for _, pair := range src.MustMap().Pairs() { + pk := pair.Key + pv := pair.Value + index, ok := info.Fields[pk.MustString()] if !ok { diags = diags.Append(diag.Diagnostic{ Severity: diag.Warning, - Summary: fmt.Sprintf("unknown field: %s", k), - Location: src.Location(), + Summary: fmt.Sprintf("unknown field: %s", pk.MustString()), + Location: pk.Location(), }) continue } // Normalize the value according to the field type. - v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen) + nv, err := n.normalizeType(typ.FieldByIndex(index).Type, pv, seen) if err != nil { diags = diags.Extend(err) // Skip the element if it cannot be normalized. - if !v.IsValid() { + if !nv.IsValid() { continue } } - out[k] = v + out.Set(pk, nv) } // Return the normalized value if missing fields are not included. @@ -107,7 +109,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen // Populate missing fields with their zero values. for k, index := range info.Fields { - if _, ok := out[k]; ok { + if _, ok := out.GetByString(k); ok { continue } @@ -143,7 +145,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen continue } if v.IsValid() { - out[k] = v + out.Set(dyn.V(k), v) } } @@ -160,19 +162,22 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []r switch src.Kind() { case dyn.KindMap: - out := make(map[string]dyn.Value) - for k, v := range src.MustMap() { + out := dyn.NewMapping() + for _, pair := range src.MustMap().Pairs() { + pk := pair.Key + pv := pair.Value + // Normalize the value according to the map element type. - v, err := n.normalizeType(typ.Elem(), v, seen) + nv, err := n.normalizeType(typ.Elem(), pv, seen) if err != nil { diags = diags.Extend(err) // Skip the element if it cannot be normalized. - if !v.IsValid() { + if !nv.IsValid() { continue } } - out[k] = v + out.Set(pk, nv) } return dyn.NewValue(out, src.Location()), diags diff --git a/libs/dyn/convert/normalize_test.go b/libs/dyn/convert/normalize_test.go index a2a6038e43..78c487d3f3 100644 --- a/libs/dyn/convert/normalize_test.go +++ b/libs/dyn/convert/normalize_test.go @@ -5,7 +5,7 @@ import ( "github.com/databricks/cli/libs/diag" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestNormalizeStruct(t *testing.T) { diff --git a/libs/dyn/convert/struct_info_test.go b/libs/dyn/convert/struct_info_test.go index 08be3c47ef..20348ff601 100644 --- a/libs/dyn/convert/struct_info_test.go +++ b/libs/dyn/convert/struct_info_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestStructInfoPlain(t *testing.T) { diff --git a/libs/dyn/convert/to_typed.go b/libs/dyn/convert/to_typed.go index 8b3cf3bb8a..f10853a2e4 100644 --- a/libs/dyn/convert/to_typed.go +++ b/libs/dyn/convert/to_typed.go @@ -59,8 +59,11 @@ func toTypedStruct(dst reflect.Value, src dyn.Value) error { dst.SetZero() info := getStructInfo(dst.Type()) - for k, v := range src.MustMap() { - index, ok := info.Fields[k] + for _, pair := range src.MustMap().Pairs() { + pk := pair.Key + pv := pair.Value + + index, ok := info.Fields[pk.MustString()] if !ok { // Ignore unknown fields. // A warning will be printed later. See PR #904. @@ -82,7 +85,7 @@ func toTypedStruct(dst reflect.Value, src dyn.Value) error { f = f.Field(x) } - err := ToTyped(f.Addr().Interface(), v) + err := ToTyped(f.Addr().Interface(), pv) if err != nil { return err } @@ -112,12 +115,14 @@ func toTypedMap(dst reflect.Value, src dyn.Value) error { m := src.MustMap() // Always overwrite. - dst.Set(reflect.MakeMapWithSize(dst.Type(), len(m))) - for k, v := range m { - kv := reflect.ValueOf(k) + dst.Set(reflect.MakeMapWithSize(dst.Type(), m.Len())) + for _, pair := range m.Pairs() { + pk := pair.Key + pv := pair.Value + kv := reflect.ValueOf(pk.MustString()) kt := dst.Type().Key() vv := reflect.New(dst.Type().Elem()) - err := ToTyped(vv.Interface(), v) + err := ToTyped(vv.Interface(), pv) if err != nil { return err } diff --git a/libs/dyn/convert/to_typed_test.go b/libs/dyn/convert/to_typed_test.go index a3c340e81c..56d98a3cf4 100644 --- a/libs/dyn/convert/to_typed_test.go +++ b/libs/dyn/convert/to_typed_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/dynassert/assert.go b/libs/dyn/dynassert/assert.go new file mode 100644 index 0000000000..dc6676ca27 --- /dev/null +++ b/libs/dyn/dynassert/assert.go @@ -0,0 +1,113 @@ +package dynassert + +import ( + "github.com/databricks/cli/libs/dyn" + "github.com/stretchr/testify/assert" +) + +func Equal(t assert.TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + ev, eok := expected.(dyn.Value) + av, aok := actual.(dyn.Value) + if eok && aok && ev.IsValid() && av.IsValid() { + if !assert.Equal(t, ev.AsAny(), av.AsAny(), msgAndArgs...) { + return false + } + + // The values are equal on contents. Now compare the locations. + if !assert.Equal(t, ev.Location(), av.Location(), msgAndArgs...) { + return false + } + + // Walk ev and av and compare the locations of each element. + _, err := dyn.Walk(ev, func(p dyn.Path, evv dyn.Value) (dyn.Value, error) { + avv, err := dyn.GetByPath(av, p) + if assert.NoError(t, err, "unable to get value from actual value at path %v", p.String()) { + assert.Equal(t, evv.Location(), avv.Location()) + } + return evv, nil + }) + return assert.NoError(t, err) + } + + return assert.Equal(t, expected, actual, msgAndArgs...) +} + +func EqualValues(t assert.TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + return assert.EqualValues(t, expected, actual, msgAndArgs...) +} + +func NotEqual(t assert.TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + return assert.NotEqual(t, expected, actual, msgAndArgs...) +} + +func Len(t assert.TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { + return assert.Len(t, object, length, msgAndArgs...) +} + +func Empty(t assert.TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return assert.Empty(t, object, msgAndArgs...) +} + +func Nil(t assert.TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return assert.Nil(t, object, msgAndArgs...) +} + +func NotNil(t assert.TestingT, object interface{}, msgAndArgs ...interface{}) bool { + return assert.NotNil(t, object, msgAndArgs...) +} + +func NoError(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { + return assert.NoError(t, err, msgAndArgs...) +} + +func Error(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { + return assert.Error(t, err, msgAndArgs...) +} + +func EqualError(t assert.TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { + return assert.EqualError(t, theError, errString, msgAndArgs...) +} + +func ErrorContains(t assert.TestingT, theError error, contains string, msgAndArgs ...interface{}) bool { + return assert.ErrorContains(t, theError, contains, msgAndArgs...) +} + +func ErrorIs(t assert.TestingT, theError, target error, msgAndArgs ...interface{}) bool { + return assert.ErrorIs(t, theError, target, msgAndArgs...) +} + +func True(t assert.TestingT, value bool, msgAndArgs ...interface{}) bool { + return assert.True(t, value, msgAndArgs...) +} + +func False(t assert.TestingT, value bool, msgAndArgs ...interface{}) bool { + return assert.False(t, value, msgAndArgs...) +} + +func Contains(t assert.TestingT, list interface{}, element interface{}, msgAndArgs ...interface{}) bool { + return assert.Contains(t, list, element, msgAndArgs...) +} + +func NotContains(t assert.TestingT, list interface{}, element interface{}, msgAndArgs ...interface{}) bool { + return assert.NotContains(t, list, element, msgAndArgs...) +} + +func ElementsMatch(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) bool { + return assert.ElementsMatch(t, listA, listB, msgAndArgs...) +} + +func Panics(t assert.TestingT, f func(), msgAndArgs ...interface{}) bool { + return assert.Panics(t, f, msgAndArgs...) +} + +func PanicsWithValue(t assert.TestingT, expected interface{}, f func(), msgAndArgs ...interface{}) bool { + return assert.PanicsWithValue(t, expected, f, msgAndArgs...) +} + +func PanicsWithError(t assert.TestingT, errString string, f func(), msgAndArgs ...interface{}) bool { + return assert.PanicsWithError(t, errString, f, msgAndArgs...) +} + +func NotPanics(t assert.TestingT, f func(), msgAndArgs ...interface{}) bool { + return assert.NotPanics(t, f, msgAndArgs...) +} diff --git a/libs/dyn/dynassert/assert_test.go b/libs/dyn/dynassert/assert_test.go new file mode 100644 index 0000000000..43258bd205 --- /dev/null +++ b/libs/dyn/dynassert/assert_test.go @@ -0,0 +1,45 @@ +package dynassert + +import ( + "go/parser" + "go/token" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestThatThisTestPackageIsUsed(t *testing.T) { + var base = ".." + var files []string + err := fs.WalkDir(os.DirFS(base), ".", func(path string, d fs.DirEntry, err error) error { + if d.IsDir() { + // Filter this directory. + if filepath.Base(path) == "dynassert" { + return fs.SkipDir + } + } + if ok, _ := filepath.Match("*_test.go", d.Name()); ok { + files = append(files, filepath.Join(base, path)) + } + return nil + }) + require.NoError(t, err) + + // Confirm that none of the test files under `libs/dyn` import the + // `testify/assert` package and instead import this package for asserts. + fset := token.NewFileSet() + for _, file := range files { + f, err := parser.ParseFile(fset, file, nil, parser.ParseComments) + require.NoError(t, err) + + for _, imp := range f.Imports { + if strings.Contains(imp.Path.Value, `github.com/stretchr/testify/assert`) { + t.Errorf("File %s should not import github.com/stretchr/testify/assert", file) + } + } + } +} diff --git a/libs/dyn/dynvar/lookup_test.go b/libs/dyn/dynvar/lookup_test.go index 2341d72084..b78115ee8f 100644 --- a/libs/dyn/dynvar/lookup_test.go +++ b/libs/dyn/dynvar/lookup_test.go @@ -4,8 +4,8 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/databricks/cli/libs/dyn/dynvar" - "github.com/stretchr/testify/assert" ) func TestDefaultLookup(t *testing.T) { diff --git a/libs/dyn/dynvar/ref_test.go b/libs/dyn/dynvar/ref_test.go index 0922373687..aff3643e02 100644 --- a/libs/dyn/dynvar/ref_test.go +++ b/libs/dyn/dynvar/ref_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/dynvar/resolve_test.go b/libs/dyn/dynvar/resolve_test.go index 304ed9391c..bbecbb7760 100644 --- a/libs/dyn/dynvar/resolve_test.go +++ b/libs/dyn/dynvar/resolve_test.go @@ -4,8 +4,8 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/databricks/cli/libs/dyn/dynvar" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/kind.go b/libs/dyn/kind.go index 8f51c25c66..9d507fbc52 100644 --- a/libs/dyn/kind.go +++ b/libs/dyn/kind.go @@ -22,7 +22,7 @@ const ( func kindOf(v any) Kind { switch v.(type) { - case map[string]Value: + case Mapping: return KindMap case []Value: return KindSequence diff --git a/libs/dyn/kind_test.go b/libs/dyn/kind_test.go index 84c90713fb..9889d31e11 100644 --- a/libs/dyn/kind_test.go +++ b/libs/dyn/kind_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestKindZeroValue(t *testing.T) { diff --git a/libs/dyn/location_test.go b/libs/dyn/location_test.go index 6d856410b6..e11f7cb56b 100644 --- a/libs/dyn/location_test.go +++ b/libs/dyn/location_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestLocation(t *testing.T) { diff --git a/libs/dyn/mapping.go b/libs/dyn/mapping.go new file mode 100644 index 0000000000..668f57ecc4 --- /dev/null +++ b/libs/dyn/mapping.go @@ -0,0 +1,148 @@ +package dyn + +import ( + "fmt" + "maps" + "slices" +) + +// Pair represents a single key-value pair in a Mapping. +type Pair struct { + Key Value + Value Value +} + +// Mapping represents a key-value map of dynamic values. +// It exists because plain Go maps cannot use dynamic values for keys. +// We need to use dynamic values for keys because it lets us associate metadata +// with keys (i.e. their definition location). Keys must be strings. +type Mapping struct { + pairs []Pair + index map[string]int +} + +// NewMapping creates a new empty Mapping. +func NewMapping() Mapping { + return Mapping{ + pairs: make([]Pair, 0), + index: make(map[string]int), + } +} + +// newMappingWithSize creates a new Mapping preallocated to the specified size. +func newMappingWithSize(size int) Mapping { + return Mapping{ + pairs: make([]Pair, 0, size), + index: make(map[string]int, size), + } +} + +// newMappingFromGoMap creates a new Mapping from a Go map of string keys and dynamic values. +func newMappingFromGoMap(vin map[string]Value) Mapping { + m := newMappingWithSize(len(vin)) + for k, v := range vin { + m.Set(V(k), v) + } + return m +} + +// Pairs returns all the key-value pairs in the Mapping. +func (m Mapping) Pairs() []Pair { + return m.pairs +} + +// Len returns the number of key-value pairs in the Mapping. +func (m Mapping) Len() int { + return len(m.pairs) +} + +// GetPair returns the key-value pair with the specified key. +// It also returns a boolean indicating whether the pair was found. +func (m Mapping) GetPair(key Value) (Pair, bool) { + skey, ok := key.AsString() + if !ok { + return Pair{}, false + } + return m.GetPairByString(skey) +} + +// GetPairByString returns the key-value pair with the specified string key. +// It also returns a boolean indicating whether the pair was found. +func (m Mapping) GetPairByString(skey string) (Pair, bool) { + if i, ok := m.index[skey]; ok { + return m.pairs[i], true + } + return Pair{}, false +} + +// Get returns the value associated with the specified key. +// It also returns a boolean indicating whether the value was found. +func (m Mapping) Get(key Value) (Value, bool) { + p, ok := m.GetPair(key) + return p.Value, ok +} + +// GetByString returns the value associated with the specified string key. +// It also returns a boolean indicating whether the value was found. +func (m *Mapping) GetByString(skey string) (Value, bool) { + p, ok := m.GetPairByString(skey) + return p.Value, ok +} + +// Set sets the value for the given key in the mapping. +// If the key already exists, the value is updated. +// If the key does not exist, a new key-value pair is added. +// The key must be a string, otherwise an error is returned. +func (m *Mapping) Set(key Value, value Value) error { + skey, ok := key.AsString() + if !ok { + return fmt.Errorf("key must be a string, got %s", key.Kind()) + } + + // If the key already exists, update the value. + if i, ok := m.index[skey]; ok { + m.pairs[i].Value = value + return nil + } + + // Otherwise, add a new pair. + m.pairs = append(m.pairs, Pair{key, value}) + if m.index == nil { + m.index = make(map[string]int) + } + m.index[skey] = len(m.pairs) - 1 + return nil +} + +// Keys returns all the keys in the Mapping. +func (m Mapping) Keys() []Value { + keys := make([]Value, 0, len(m.pairs)) + for _, p := range m.pairs { + keys = append(keys, p.Key) + } + return keys +} + +// Values returns all the values in the Mapping. +func (m Mapping) Values() []Value { + values := make([]Value, 0, len(m.pairs)) + for _, p := range m.pairs { + values = append(values, p.Value) + } + return values +} + +// Clone creates a shallow copy of the Mapping. +func (m Mapping) Clone() Mapping { + return Mapping{ + pairs: slices.Clone(m.pairs), + index: maps.Clone(m.index), + } +} + +// Merge merges the key-value pairs from another Mapping into the current Mapping. +func (m *Mapping) Merge(n Mapping) { + for _, p := range n.pairs { + m.Set(p.Key, p.Value) + } +} diff --git a/libs/dyn/mapping_test.go b/libs/dyn/mapping_test.go new file mode 100644 index 0000000000..43b24b0c5a --- /dev/null +++ b/libs/dyn/mapping_test.go @@ -0,0 +1,204 @@ +package dyn_test + +import ( + "fmt" + "testing" + + "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" + "github.com/stretchr/testify/require" +) + +func TestNewMapping(t *testing.T) { + m := dyn.NewMapping() + assert.Equal(t, 0, m.Len()) +} + +func TestMappingZeroValue(t *testing.T) { + var m dyn.Mapping + assert.Equal(t, 0, m.Len()) + + value, ok := m.Get(dyn.V("key")) + assert.Equal(t, dyn.InvalidValue, value) + assert.False(t, ok) + assert.Len(t, m.Keys(), 0) + assert.Len(t, m.Values(), 0) +} + +func TestMappingGet(t *testing.T) { + var m dyn.Mapping + err := m.Set(dyn.V("key"), dyn.V("value")) + assert.NoError(t, err) + assert.Equal(t, 1, m.Len()) + + // Call GetPair + p, ok := m.GetPair(dyn.V("key")) + assert.True(t, ok) + assert.Equal(t, dyn.V("key"), p.Key) + assert.Equal(t, dyn.V("value"), p.Value) + + // Modify the value to make sure we're not getting a reference + p.Value = dyn.V("newvalue") + + // Call GetPair with invalid key + p, ok = m.GetPair(dyn.V(1234)) + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, p.Key) + assert.Equal(t, dyn.InvalidValue, p.Value) + + // Call GetPair with non-existent key + p, ok = m.GetPair(dyn.V("enoexist")) + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, p.Key) + assert.Equal(t, dyn.InvalidValue, p.Value) + + // Call GetPairByString + p, ok = m.GetPairByString("key") + assert.True(t, ok) + assert.Equal(t, dyn.V("key"), p.Key) + assert.Equal(t, dyn.V("value"), p.Value) + + // Modify the value to make sure we're not getting a reference + p.Value = dyn.V("newvalue") + + // Call GetPairByString with with non-existent key + p, ok = m.GetPairByString("enoexist") + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, p.Key) + assert.Equal(t, dyn.InvalidValue, p.Value) + + // Call Get + value, ok := m.Get(dyn.V("key")) + assert.True(t, ok) + assert.Equal(t, dyn.V("value"), value) + + // Call Get with invalid key + value, ok = m.Get(dyn.V(1234)) + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, value) + + // Call Get with non-existent key + value, ok = m.Get(dyn.V("enoexist")) + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, value) + + // Call GetByString + value, ok = m.GetByString("key") + assert.True(t, ok) + assert.Equal(t, dyn.V("value"), value) + + // Call GetByString with non-existent key + value, ok = m.GetByString("enoexist") + assert.False(t, ok) + assert.Equal(t, dyn.InvalidValue, value) +} + +func TestMappingSet(t *testing.T) { + var err error + var m dyn.Mapping + + // Set a value + err = m.Set(dyn.V("key1"), dyn.V("foo")) + assert.NoError(t, err) + assert.Equal(t, 1, m.Len()) + + // Confirm the value + value, ok := m.GetByString("key1") + assert.True(t, ok) + assert.Equal(t, dyn.V("foo"), value) + + // Set another value + err = m.Set(dyn.V("key2"), dyn.V("bar")) + assert.NoError(t, err) + assert.Equal(t, 2, m.Len()) + + // Confirm the value + value, ok = m.Get(dyn.V("key2")) + assert.True(t, ok) + assert.Equal(t, dyn.V("bar"), value) + + // Overwrite first value + err = m.Set(dyn.V("key1"), dyn.V("qux")) + assert.NoError(t, err) + assert.Equal(t, 2, m.Len()) + + // Confirm the value + value, ok = m.Get(dyn.V("key1")) + assert.True(t, ok) + assert.Equal(t, dyn.V("qux"), value) + + // Try to set non-string key + err = m.Set(dyn.V(1), dyn.V("qux")) + assert.Error(t, err) + assert.Equal(t, 2, m.Len()) +} + +func TestMappingKeysValues(t *testing.T) { + var err error + + // Configure mapping + var m dyn.Mapping + err = m.Set(dyn.V("key1"), dyn.V("foo")) + assert.NoError(t, err) + err = m.Set(dyn.V("key2"), dyn.V("bar")) + assert.NoError(t, err) + + // Confirm keys + keys := m.Keys() + assert.Len(t, keys, 2) + assert.Contains(t, keys, dyn.V("key1")) + assert.Contains(t, keys, dyn.V("key2")) + + // Confirm values + values := m.Values() + assert.Len(t, values, 2) + assert.Contains(t, values, dyn.V("foo")) + assert.Contains(t, values, dyn.V("bar")) +} + +func TestMappingClone(t *testing.T) { + var err error + + // Configure mapping + var m1 dyn.Mapping + err = m1.Set(dyn.V("key1"), dyn.V("foo")) + assert.NoError(t, err) + err = m1.Set(dyn.V("key2"), dyn.V("bar")) + assert.NoError(t, err) + + // Clone mapping + m2 := m1.Clone() + assert.Equal(t, m1.Len(), m2.Len()) + + // Modify original mapping + err = m1.Set(dyn.V("key1"), dyn.V("qux")) + assert.NoError(t, err) + + // Confirm values + value, ok := m1.Get(dyn.V("key1")) + assert.True(t, ok) + assert.Equal(t, dyn.V("qux"), value) + value, ok = m2.Get(dyn.V("key1")) + assert.True(t, ok) + assert.Equal(t, dyn.V("foo"), value) +} + +func TestMappingMerge(t *testing.T) { + var m1 dyn.Mapping + for i := 0; i < 10; i++ { + err := m1.Set(dyn.V(fmt.Sprintf("%d", i)), dyn.V(i)) + require.NoError(t, err) + } + + var m2 dyn.Mapping + for i := 5; i < 15; i++ { + err := m2.Set(dyn.V(fmt.Sprintf("%d", i)), dyn.V(i)) + require.NoError(t, err) + } + + var out dyn.Mapping + out.Merge(m1) + assert.Equal(t, 10, out.Len()) + out.Merge(m2) + assert.Equal(t, 15, out.Len()) +} diff --git a/libs/dyn/merge/elements_by_key_test.go b/libs/dyn/merge/elements_by_key_test.go index c61f834e5f..ef316cc666 100644 --- a/libs/dyn/merge/elements_by_key_test.go +++ b/libs/dyn/merge/elements_by_key_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/merge/merge.go b/libs/dyn/merge/merge.go index 1cadbea608..69ccf516ae 100644 --- a/libs/dyn/merge/merge.go +++ b/libs/dyn/merge/merge.go @@ -51,27 +51,27 @@ func merge(a, b dyn.Value) (dyn.Value, error) { } func mergeMap(a, b dyn.Value) (dyn.Value, error) { - out := make(map[string]dyn.Value) + out := dyn.NewMapping() am := a.MustMap() bm := b.MustMap() // Add the values from a into the output map. - for k, v := range am { - out[k] = v - } + out.Merge(am) // Merge the values from b into the output map. - for k, v := range bm { - if _, ok := out[k]; ok { + for _, pair := range bm.Pairs() { + pk := pair.Key + pv := pair.Value + if ov, ok := out.Get(pk); ok { // If the key already exists, merge the values. - merged, err := merge(out[k], v) + merged, err := merge(ov, pv) if err != nil { return dyn.NilValue, err } - out[k] = merged + out.Set(pk, merged) } else { // Otherwise, just set the value. - out[k] = v + out.Set(pk, pv) } } diff --git a/libs/dyn/merge/merge_test.go b/libs/dyn/merge/merge_test.go index c4928e3536..eaaaab16f4 100644 --- a/libs/dyn/merge/merge_test.go +++ b/libs/dyn/merge/merge_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestMergeMaps(t *testing.T) { diff --git a/libs/dyn/path_string_test.go b/libs/dyn/path_string_test.go index 9af394c6f1..0d64bf1107 100644 --- a/libs/dyn/path_string_test.go +++ b/libs/dyn/path_string_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestNewPathFromString(t *testing.T) { diff --git a/libs/dyn/path_test.go b/libs/dyn/path_test.go index 1152a060ad..44df2050b0 100644 --- a/libs/dyn/path_test.go +++ b/libs/dyn/path_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestPathAppend(t *testing.T) { diff --git a/libs/dyn/pattern.go b/libs/dyn/pattern.go index 960a50d5b8..a265dad081 100644 --- a/libs/dyn/pattern.go +++ b/libs/dyn/pattern.go @@ -2,7 +2,6 @@ package dyn import ( "fmt" - "maps" "slices" ) @@ -55,10 +54,13 @@ func (c anyKeyComponent) visit(v Value, prefix Path, suffix Pattern, opts visitO return InvalidValue, fmt.Errorf("expected a map at %q, found %s", prefix, v.Kind()) } - m = maps.Clone(m) - for key, value := range m { + m = m.Clone() + for _, pair := range m.Pairs() { + pk := pair.Key + pv := pair.Value + var err error - nv, err := visit(value, append(prefix, Key(key)), suffix, opts) + nv, err := visit(pv, append(prefix, Key(pk.MustString())), suffix, opts) if err != nil { // Leave the value intact if the suffix pattern didn't match any value. if IsNoSuchKeyError(err) || IsIndexOutOfBoundsError(err) { @@ -66,7 +68,8 @@ func (c anyKeyComponent) visit(v Value, prefix Path, suffix Pattern, opts visitO } return InvalidValue, err } - m[key] = nv + + m.Set(pk, nv) } return NewValue(m, v.Location()), nil diff --git a/libs/dyn/pattern_test.go b/libs/dyn/pattern_test.go index 372fe74678..1b54953efe 100644 --- a/libs/dyn/pattern_test.go +++ b/libs/dyn/pattern_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestNewPattern(t *testing.T) { diff --git a/libs/dyn/value.go b/libs/dyn/value.go index ecf21abbe8..2e8f1b9aff 100644 --- a/libs/dyn/value.go +++ b/libs/dyn/value.go @@ -27,14 +27,16 @@ var NilValue = Value{ // V constructs a new Value with the given value. func V(v any) Value { - return Value{ - v: v, - k: kindOf(v), - } + return NewValue(v, Location{}) } // NewValue constructs a new Value with the given value and location. func NewValue(v any, loc Location) Value { + switch vin := v.(type) { + case map[string]Value: + v = newMappingFromGoMap(vin) + } + return Value{ v: v, k: kindOf(v), @@ -72,12 +74,14 @@ func (v Value) AsAny() any { case KindInvalid: panic("invoked AsAny on invalid value") case KindMap: - vv := v.v.(map[string]Value) - m := make(map[string]any, len(vv)) - for k, v := range vv { - m[k] = v.AsAny() + m := v.v.(Mapping) + out := make(map[string]any, m.Len()) + for _, pair := range m.pairs { + pk := pair.Key + pv := pair.Value + out[pk.MustString()] = pv.AsAny() } - return m + return out case KindSequence: vv := v.v.([]Value) a := make([]any, len(vv)) @@ -109,7 +113,7 @@ func (v Value) Get(key string) Value { return NilValue } - vv, ok := m[key] + vv, ok := m.GetByString(key) if !ok { return NilValue } diff --git a/libs/dyn/value_test.go b/libs/dyn/value_test.go index 7c9a9d990e..bbdc2c96ba 100644 --- a/libs/dyn/value_test.go +++ b/libs/dyn/value_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestInvalidValue(t *testing.T) { @@ -22,14 +22,12 @@ func TestValueIsAnchor(t *testing.T) { func TestValueAsMap(t *testing.T) { var zeroValue dyn.Value - m, ok := zeroValue.AsMap() + _, ok := zeroValue.AsMap() assert.False(t, ok) - assert.Nil(t, m) var intValue = dyn.NewValue(1, dyn.Location{}) - m, ok = intValue.AsMap() + _, ok = intValue.AsMap() assert.False(t, ok) - assert.Nil(t, m) var mapValue = dyn.NewValue( map[string]dyn.Value{ @@ -37,9 +35,9 @@ func TestValueAsMap(t *testing.T) { }, dyn.Location{File: "file", Line: 1, Column: 2}, ) - m, ok = mapValue.AsMap() + m, ok := mapValue.AsMap() assert.True(t, ok) - assert.Len(t, m, 1) + assert.Equal(t, 1, m.Len()) } func TestValueIsValid(t *testing.T) { diff --git a/libs/dyn/value_underlying.go b/libs/dyn/value_underlying.go index c8c5037900..2f0f26a1f6 100644 --- a/libs/dyn/value_underlying.go +++ b/libs/dyn/value_underlying.go @@ -5,16 +5,16 @@ import ( "time" ) -// AsMap returns the underlying map if this value is a map, +// AsMap returns the underlying mapping if this value is a map, // the zero value and false otherwise. -func (v Value) AsMap() (map[string]Value, bool) { - vv, ok := v.v.(map[string]Value) +func (v Value) AsMap() (Mapping, bool) { + vv, ok := v.v.(Mapping) return vv, ok } -// MustMap returns the underlying map if this value is a map, +// MustMap returns the underlying mapping if this value is a map, // panics otherwise. -func (v Value) MustMap() map[string]Value { +func (v Value) MustMap() Mapping { vv, ok := v.AsMap() if !ok || v.k != KindMap { panic(fmt.Sprintf("expected kind %s, got %s", KindMap, v.k)) diff --git a/libs/dyn/value_underlying_test.go b/libs/dyn/value_underlying_test.go index 17cb959418..9878cfaf9d 100644 --- a/libs/dyn/value_underlying_test.go +++ b/libs/dyn/value_underlying_test.go @@ -5,7 +5,7 @@ import ( "time" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestValueUnderlyingMap(t *testing.T) { diff --git a/libs/dyn/visit.go b/libs/dyn/visit.go index 376dcc22db..3fe3561943 100644 --- a/libs/dyn/visit.go +++ b/libs/dyn/visit.go @@ -3,7 +3,6 @@ package dyn import ( "errors" "fmt" - "maps" "slices" ) @@ -77,7 +76,7 @@ func (component pathComponent) visit(v Value, prefix Path, suffix Pattern, opts } // Lookup current value in the map. - ev, ok := m[component.key] + ev, ok := m.GetByString(component.key) if !ok { return InvalidValue, noSuchKeyError{path} } @@ -94,8 +93,8 @@ func (component pathComponent) visit(v Value, prefix Path, suffix Pattern, opts } // Return an updated map value. - m = maps.Clone(m) - m[component.key] = nv + m = m.Clone() + m.Set(V(component.key), nv) return Value{ v: m, k: KindMap, diff --git a/libs/dyn/visit_get_test.go b/libs/dyn/visit_get_test.go index 22dce0858b..adc307794c 100644 --- a/libs/dyn/visit_get_test.go +++ b/libs/dyn/visit_get_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestGetWithEmptyPath(t *testing.T) { diff --git a/libs/dyn/visit_map.go b/libs/dyn/visit_map.go index 18fc668ede..f5cfea3114 100644 --- a/libs/dyn/visit_map.go +++ b/libs/dyn/visit_map.go @@ -2,7 +2,6 @@ package dyn import ( "fmt" - "maps" "slices" ) @@ -15,13 +14,15 @@ func Foreach(fn MapFunc) MapFunc { return func(p Path, v Value) (Value, error) { switch v.Kind() { case KindMap: - m := maps.Clone(v.MustMap()) - for key, value := range m { - var err error - m[key], err = fn(append(p, Key(key)), value) + m := v.MustMap().Clone() + for _, pair := range m.Pairs() { + pk := pair.Key + pv := pair.Value + nv, err := fn(append(p, Key(pk.MustString())), pv) if err != nil { return InvalidValue, err } + m.Set(pk, nv) } return NewValue(m, v.Location()), nil case KindSequence: diff --git a/libs/dyn/visit_map_test.go b/libs/dyn/visit_map_test.go index f87f0a40d4..df6bad4965 100644 --- a/libs/dyn/visit_map_test.go +++ b/libs/dyn/visit_map_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/visit_set.go b/libs/dyn/visit_set.go index edcd9bb735..b086fb8a91 100644 --- a/libs/dyn/visit_set.go +++ b/libs/dyn/visit_set.go @@ -2,7 +2,6 @@ package dyn import ( "fmt" - "maps" "slices" ) @@ -41,8 +40,8 @@ func SetByPath(v Value, p Path, nv Value) (Value, error) { } // Return an updated map value. - m = maps.Clone(m) - m[component.key] = nv + m = m.Clone() + m.Set(V(component.key), nv) return Value{ v: m, k: KindMap, diff --git a/libs/dyn/visit_set_test.go b/libs/dyn/visit_set_test.go index b384715875..df58941e17 100644 --- a/libs/dyn/visit_set_test.go +++ b/libs/dyn/visit_set_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestSetWithEmptyPath(t *testing.T) { diff --git a/libs/dyn/walk.go b/libs/dyn/walk.go index 26ddfc11d9..97b99b061e 100644 --- a/libs/dyn/walk.go +++ b/libs/dyn/walk.go @@ -34,16 +34,18 @@ func walk(v Value, p Path, fn func(p Path, v Value) (Value, error)) (Value, erro switch v.Kind() { case KindMap: m := v.MustMap() - out := make(map[string]Value, len(m)) - for k := range m { - nv, err := walk(m[k], append(p, Key(k)), fn) + out := newMappingWithSize(m.Len()) + for _, pair := range m.Pairs() { + pk := pair.Key + pv := pair.Value + nv, err := walk(pv, append(p, Key(pk.MustString())), fn) if err == ErrDrop { continue } if err != nil { return NilValue, err } - out[k] = nv + out.Set(pk, nv) } v.v = out case KindSequence: diff --git a/libs/dyn/walk_test.go b/libs/dyn/walk_test.go index 1b94ad9027..d62b9a4db8 100644 --- a/libs/dyn/walk_test.go +++ b/libs/dyn/walk_test.go @@ -5,7 +5,7 @@ import ( "testing" . "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/stretchr/testify/require" ) diff --git a/libs/dyn/yamlloader/loader.go b/libs/dyn/yamlloader/loader.go index 899e1d7b8a..908793d584 100644 --- a/libs/dyn/yamlloader/loader.go +++ b/libs/dyn/yamlloader/loader.go @@ -92,7 +92,7 @@ func (d *loader) loadSequence(node *yaml.Node, loc dyn.Location) (dyn.Value, err func (d *loader) loadMapping(node *yaml.Node, loc dyn.Location) (dyn.Value, error) { var merge *yaml.Node - acc := make(map[string]dyn.Value) + acc := dyn.NewMapping() for i := 0; i < len(node.Content); i += 2 { key := node.Content[i] val := node.Content[i+1] @@ -116,12 +116,17 @@ func (d *loader) loadMapping(node *yaml.Node, loc dyn.Location) (dyn.Value, erro return dyn.NilValue, errorf(loc, "invalid key tag: %v", st) } + k, err := d.load(key) + if err != nil { + return dyn.NilValue, err + } + v, err := d.load(val) if err != nil { return dyn.NilValue, err } - acc[key.Value] = v + acc.Set(k, v) } if merge == nil { @@ -146,7 +151,7 @@ func (d *loader) loadMapping(node *yaml.Node, loc dyn.Location) (dyn.Value, erro // Build a sequence of values to merge. // The entries that we already accumulated have precedence. - var seq []map[string]dyn.Value + var seq []dyn.Mapping for _, n := range mnodes { v, err := d.load(n) if err != nil { @@ -161,11 +166,9 @@ func (d *loader) loadMapping(node *yaml.Node, loc dyn.Location) (dyn.Value, erro // Append the accumulated entries to the sequence. seq = append(seq, acc) - out := make(map[string]dyn.Value) + out := dyn.NewMapping() for _, m := range seq { - for k, v := range m { - out[k] = v - } + out.Merge(m) } return dyn.NewValue(out, loc), nil diff --git a/libs/dyn/yamlloader/yaml_anchor_test.go b/libs/dyn/yamlloader/yaml_anchor_test.go index 05beb5401d..29ce69f0ac 100644 --- a/libs/dyn/yamlloader/yaml_anchor_test.go +++ b/libs/dyn/yamlloader/yaml_anchor_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestYAMLAnchor01(t *testing.T) { diff --git a/libs/dyn/yamlloader/yaml_error_test.go b/libs/dyn/yamlloader/yaml_error_test.go index 11c444ad36..0ae424341e 100644 --- a/libs/dyn/yamlloader/yaml_error_test.go +++ b/libs/dyn/yamlloader/yaml_error_test.go @@ -5,8 +5,8 @@ import ( "os" "testing" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/databricks/cli/libs/dyn/yamlloader" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) diff --git a/libs/dyn/yamlloader/yaml_mix_test.go b/libs/dyn/yamlloader/yaml_mix_test.go index 307b93dbf3..55ded6bafc 100644 --- a/libs/dyn/yamlloader/yaml_mix_test.go +++ b/libs/dyn/yamlloader/yaml_mix_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestYAMLMix01(t *testing.T) { diff --git a/libs/dyn/yamlloader/yaml_test.go b/libs/dyn/yamlloader/yaml_test.go index 14269feeef..9bb0377dd7 100644 --- a/libs/dyn/yamlloader/yaml_test.go +++ b/libs/dyn/yamlloader/yaml_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" + assert "github.com/databricks/cli/libs/dyn/dynassert" "github.com/databricks/cli/libs/dyn/yamlloader" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) diff --git a/libs/dyn/yamlsaver/order_test.go b/libs/dyn/yamlsaver/order_test.go index ed2877f6c8..ee9dc4752f 100644 --- a/libs/dyn/yamlsaver/order_test.go +++ b/libs/dyn/yamlsaver/order_test.go @@ -3,7 +3,7 @@ package yamlsaver import ( "testing" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestOrderReturnsIncreasingIndex(t *testing.T) { diff --git a/libs/dyn/yamlsaver/saver.go b/libs/dyn/yamlsaver/saver.go index 84483a12f0..fe4cfb8549 100644 --- a/libs/dyn/yamlsaver/saver.go +++ b/libs/dyn/yamlsaver/saver.go @@ -9,7 +9,6 @@ import ( "strconv" "github.com/databricks/cli/libs/dyn" - "golang.org/x/exp/maps" "gopkg.in/yaml.v3" ) @@ -75,25 +74,27 @@ func (s *saver) toYamlNodeWithStyle(v dyn.Value, style yaml.Style) (*yaml.Node, switch v.Kind() { case dyn.KindMap: m, _ := v.AsMap() - keys := maps.Keys(m) + // We're using location lines to define the order of keys in YAML. // The location is set when we convert API response struct to config.Value representation // See convert.convertMap for details - sort.SliceStable(keys, func(i, j int) bool { - return m[keys[i]].Location().Line < m[keys[j]].Location().Line + pairs := m.Pairs() + sort.SliceStable(pairs, func(i, j int) bool { + return pairs[i].Value.Location().Line < pairs[j].Value.Location().Line }) content := make([]*yaml.Node, 0) - for _, k := range keys { - item := m[k] - node := yaml.Node{Kind: yaml.ScalarNode, Value: k, Style: style} + for _, pair := range pairs { + pk := pair.Key + pv := pair.Value + node := yaml.Node{Kind: yaml.ScalarNode, Value: pk.MustString(), Style: style} var nestedNodeStyle yaml.Style - if customStyle, ok := s.hasStyle(k); ok { + if customStyle, ok := s.hasStyle(pk.MustString()); ok { nestedNodeStyle = customStyle } else { nestedNodeStyle = style } - c, err := s.toYamlNodeWithStyle(item, nestedNodeStyle) + c, err := s.toYamlNodeWithStyle(pv, nestedNodeStyle) if err != nil { return nil, err } diff --git a/libs/dyn/yamlsaver/saver_test.go b/libs/dyn/yamlsaver/saver_test.go index ec44a42987..bdf1891cdd 100644 --- a/libs/dyn/yamlsaver/saver_test.go +++ b/libs/dyn/yamlsaver/saver_test.go @@ -5,7 +5,7 @@ import ( "time" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" "gopkg.in/yaml.v3" ) diff --git a/libs/dyn/yamlsaver/utils.go b/libs/dyn/yamlsaver/utils.go index 0fb4064b54..6149491d60 100644 --- a/libs/dyn/yamlsaver/utils.go +++ b/libs/dyn/yamlsaver/utils.go @@ -26,7 +26,9 @@ func ConvertToMapValue(strct any, order *Order, skipFields []string, dst map[str } func skipAndOrder(mv dyn.Value, order *Order, skipFields []string, dst map[string]dyn.Value) (dyn.Value, error) { - for k, v := range mv.MustMap() { + for _, pair := range mv.MustMap().Pairs() { + k := pair.Key.MustString() + v := pair.Value if v.Kind() == dyn.KindNil { continue } diff --git a/libs/dyn/yamlsaver/utils_test.go b/libs/dyn/yamlsaver/utils_test.go index 32c9143bea..04b4c404fb 100644 --- a/libs/dyn/yamlsaver/utils_test.go +++ b/libs/dyn/yamlsaver/utils_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/databricks/cli/libs/dyn" - "github.com/stretchr/testify/assert" + assert "github.com/databricks/cli/libs/dyn/dynassert" ) func TestConvertToMapValueWithOrder(t *testing.T) { @@ -32,7 +32,7 @@ func TestConvertToMapValueWithOrder(t *testing.T) { result, err := ConvertToMapValue(v, NewOrder([]string{"list", "name", "map"}), []string{"format"}, map[string]dyn.Value{}) assert.NoError(t, err) - assert.Equal(t, map[string]dyn.Value{ + assert.Equal(t, dyn.V(map[string]dyn.Value{ "list": dyn.NewValue([]dyn.Value{ dyn.V("a"), dyn.V("b"), @@ -44,5 +44,5 @@ func TestConvertToMapValueWithOrder(t *testing.T) { "key2": dyn.V("value2"), }, dyn.Location{Line: -1}), "long_name_field": dyn.NewValue("long name goes here", dyn.Location{Line: 1}), - }, result.MustMap()) + }), result) }