Skip to content

Commit

Permalink
Compare .Kind() instead of direct equality checks on a dyn.Value (#…
Browse files Browse the repository at this point in the history
…1520)

## Changes

This PR makes two changes:

1. In #1510 we'll be adding
multiple associated location metadata with a dyn.Value. The Go compiler
does not allow comparing structs if they contain slice values
(presumably due to multiple possible definitions for equality). In
anticipation for adding a `[]dyn.Location` type field to `dyn.Value`
this PR removes all direct comparisons of `dyn.Value` and instead relies
on the kind.

2. Retain location metadata for values in convert.FromTyped. The change
diff is exactly the same as #1523.
It's been combined with this PR because they both depend on each other
to prevent test failures (forming a test failure deadlock).

Go patch used:
```
@@
var x expression
@@
-x == dyn.InvalidValue
+x.Kind() == dyn.KindInvalid

@@
var x expression
@@
-x != dyn.InvalidValue
+x.Kind() != dyn.KindInvalid

@@
var x expression
@@
-x == dyn.NilValue
+x.Kind() == dyn.KindNil

@@
var x expression
@@
-x != dyn.NilValue
+x.Kind() != dyn.KindNil
```
 

## Tests
Unit tests and integration tests pass.
  • Loading branch information
shreyas-goenka authored Jun 27, 2024
1 parent dba6164 commit 4d8eba0
Show file tree
Hide file tree
Showing 16 changed files with 195 additions and 81 deletions.
4 changes: 2 additions & 2 deletions bundle/config/mutator/environments_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ func (m *environmentsToTargets) Apply(ctx context.Context, b *bundle.Bundle) dia
targets := v.Get("targets")

// Return an error if both "environments" and "targets" are set.
if environments != dyn.InvalidValue && targets != dyn.InvalidValue {
if environments.Kind() != dyn.KindInvalid && targets.Kind() != dyn.KindInvalid {
return dyn.InvalidValue, fmt.Errorf(
"both 'environments' and 'targets' are specified; only 'targets' should be used: %s",
environments.Location().String(),
)
}

// Rewrite "environments" to "targets".
if environments != dyn.InvalidValue && targets == dyn.InvalidValue {
if environments.Kind() != dyn.KindInvalid && targets.Kind() == dyn.KindInvalid {
nv, err := dyn.Set(v, "targets", environments)
if err != nil {
return dyn.InvalidValue, err
Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_job_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mergeJobClusters) jobClusterKey(v dyn.Value) string {

func (m *mergeJobClusters) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_job_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mergeJobTasks) taskKeyString(v dyn.Value) string {

func (m *mergeJobTasks) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_pipeline_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (m *mergePipelineClusters) clusterLabel(v dyn.Value) string {

func (m *mergePipelineClusters) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
19 changes: 12 additions & 7 deletions bundle/config/mutator/run_as.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,20 @@ func (e errBothSpAndUserSpecified) Error() string {
}

func validateRunAs(b *bundle.Bundle) error {
runAs := b.Config.RunAs

// Error if neither service_principal_name nor user_name are specified
if runAs.ServicePrincipalName == "" && runAs.UserName == "" {
return fmt.Errorf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s", b.Config.GetLocation("run_as"))
neitherSpecifiedErr := fmt.Errorf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s", b.Config.GetLocation("run_as"))
// Error if neither service_principal_name nor user_name are specified, but the
// run_as section is present.
if b.Config.Value().Get("run_as").Kind() == dyn.KindNil {
return neitherSpecifiedErr
}
// Error if one or both of service_principal_name and user_name are specified,
// but with empty values.
if b.Config.RunAs.ServicePrincipalName == "" && b.Config.RunAs.UserName == "" {
return neitherSpecifiedErr
}

// Error if both service_principal_name and user_name are specified
runAs := b.Config.RunAs
if runAs.UserName != "" && runAs.ServicePrincipalName != "" {
return errBothSpAndUserSpecified{
spName: runAs.ServicePrincipalName,
Expand Down Expand Up @@ -163,8 +169,7 @@ func setPipelineOwnersToRunAsIdentity(b *bundle.Bundle) {

func (m *setRunAs) Apply(_ context.Context, b *bundle.Bundle) diag.Diagnostics {
// Mutator is a no-op if run_as is not specified in the bundle
runAs := b.Config.RunAs
if runAs == nil {
if b.Config.Value().Get("run_as").Kind() == dyn.KindInvalid {
return nil
}

Expand Down
18 changes: 9 additions & 9 deletions bundle/config/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,39 +346,39 @@ func (r *Root) MergeTargetOverrides(name string) error {
}

// Merge `run_as`. This field must be overwritten if set, not merged.
if v := target.Get("run_as"); v != dyn.InvalidValue {
if v := target.Get("run_as"); v.Kind() != dyn.KindInvalid {
root, err = dyn.Set(root, "run_as", v)
if err != nil {
return err
}
}

// Below, we're setting fields on the bundle key, so make sure it exists.
if root.Get("bundle") == dyn.InvalidValue {
if root.Get("bundle").Kind() == dyn.KindInvalid {
root, err = dyn.Set(root, "bundle", dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}))
if err != nil {
return err
}
}

// Merge `mode`. This field must be overwritten if set, not merged.
if v := target.Get("mode"); v != dyn.InvalidValue {
if v := target.Get("mode"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("mode")), v)
if err != nil {
return err
}
}

// Merge `compute_id`. This field must be overwritten if set, not merged.
if v := target.Get("compute_id"); v != dyn.InvalidValue {
if v := target.Get("compute_id"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("compute_id")), v)
if err != nil {
return err
}
}

// Merge `git`.
if v := target.Get("git"); v != dyn.InvalidValue {
if v := target.Get("git"); v.Kind() != dyn.KindInvalid {
ref, err := dyn.GetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("git")))
if err != nil {
ref = dyn.NewValue(map[string]dyn.Value{}, dyn.Location{})
Expand All @@ -391,7 +391,7 @@ func (r *Root) MergeTargetOverrides(name string) error {
}

// If the branch was overridden, we need to clear the inferred flag.
if branch := v.Get("branch"); branch != dyn.InvalidValue {
if branch := v.Get("branch"); branch.Kind() != dyn.KindInvalid {
out, err = dyn.SetByPath(out, dyn.NewPath(dyn.Key("inferred")), dyn.NewValue(false, dyn.Location{}))
if err != nil {
return err
Expand Down Expand Up @@ -419,7 +419,7 @@ func rewriteShorthands(v dyn.Value) (dyn.Value, error) {
// For each target, rewrite the variables block.
return dyn.Map(v, "targets", dyn.Foreach(func(_ dyn.Path, target dyn.Value) (dyn.Value, error) {
// Confirm it has a variables block.
if target.Get("variables") == dyn.InvalidValue {
if target.Get("variables").Kind() == dyn.KindInvalid {
return target, nil
}

Expand Down Expand Up @@ -464,15 +464,15 @@ func validateVariableOverrides(root, target dyn.Value) (err error) {
var tv map[string]variable.Variable

// Collect variables from the root.
if v := root.Get("variables"); v != dyn.InvalidValue {
if v := root.Get("variables"); v.Kind() != dyn.KindInvalid {
err = convert.ToTyped(&rv, v)
if err != nil {
return fmt.Errorf("unable to collect variables from root: %w", err)
}
}

// Collect variables from the target.
if v := target.Get("variables"); v != dyn.InvalidValue {
if v := target.Get("variables"); v.Kind() != dyn.KindInvalid {
err = convert.ToTyped(&tv, v)
if err != nil {
return fmt.Errorf("unable to collect variables from target: %w", err)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bundle:
name: "abc"

run_as:
service_principal_name: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bundle:
name: "abc"

run_as:
user_name: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
bundle:
name: "abc"

run_as:
service_principal_name: ""
user_name: ""
64 changes: 45 additions & 19 deletions bundle/tests/run_as_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,27 +196,53 @@ func TestRunAsErrorWhenBothUserAndSpSpecified(t *testing.T) {
}

func TestRunAsErrorNeitherUserOrSpSpecified(t *testing.T) {
b := load(t, "./run_as/not_allowed/neither_sp_nor_user")

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
b.Config.Workspace.CurrentUser = &config.User{
User: &iam.User{
UserName: "my_service_principal",
},
}
return nil
})

diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()

configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/databricks.yml")
assert.EqualError(t, err, fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:8", configPath))
tcases := []struct {
name string
err string
}{
{
name: "empty_run_as",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:8", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_run_as/databricks.yml")),
},
{
name: "empty_sp",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_sp/databricks.yml")),
},
{
name: "empty_user",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_user/databricks.yml")),
},
{
name: "empty_user_and_sp",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_user_and_sp/databricks.yml")),
},
}

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {

bundlePath := fmt.Sprintf("./run_as/not_allowed/neither_sp_nor_user/%s", tc.name)
b := load(t, bundlePath)

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
b.Config.Workspace.CurrentUser = &config.User{
User: &iam.User{
UserName: "my_service_principal",
},
}
return nil
})

diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()
assert.EqualError(t, err, tc.err)
})
}
}

func TestRunAsErrorNeitherUserOrSpSpecifiedAtTargetOverride(t *testing.T) {
b := loadTarget(t, "./run_as/not_allowed/neither_sp_nor_user_override", "development")
b := loadTarget(t, "./run_as/not_allowed/neither_sp_nor_user/override", "development")

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
Expand All @@ -231,7 +257,7 @@ func TestRunAsErrorNeitherUserOrSpSpecifiedAtTargetOverride(t *testing.T) {
diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()

configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user_override/override.yml")
configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/override/override.yml")
assert.EqualError(t, err, fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:12", configPath))
}

Expand Down
38 changes: 23 additions & 15 deletions libs/dyn/convert/from_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func fromTyped(src any, ref dyn.Value, options ...fromTypedOptions) (dyn.Value,
// Dereference pointer if necessary
for srcv.Kind() == reflect.Pointer {
if srcv.IsNil() {
return dyn.NilValue, nil
return dyn.NilValue.WithLocation(ref.Location()), nil
}
srcv = srcv.Elem()

Expand All @@ -55,27 +55,35 @@ func fromTyped(src any, ref dyn.Value, options ...fromTypedOptions) (dyn.Value,
}
}

var v dyn.Value
var err error
switch srcv.Kind() {
case reflect.Struct:
return fromTypedStruct(srcv, ref, options...)
v, err = fromTypedStruct(srcv, ref, options...)
case reflect.Map:
return fromTypedMap(srcv, ref)
v, err = fromTypedMap(srcv, ref)
case reflect.Slice:
return fromTypedSlice(srcv, ref)
v, err = fromTypedSlice(srcv, ref)
case reflect.String:
return fromTypedString(srcv, ref, options...)
v, err = fromTypedString(srcv, ref, options...)
case reflect.Bool:
return fromTypedBool(srcv, ref, options...)
v, err = fromTypedBool(srcv, ref, options...)
case reflect.Int, reflect.Int32, reflect.Int64:
return fromTypedInt(srcv, ref, options...)
v, err = fromTypedInt(srcv, ref, options...)
case reflect.Float32, reflect.Float64:
return fromTypedFloat(srcv, ref, options...)
v, err = fromTypedFloat(srcv, ref, options...)
case reflect.Invalid:
// If the value is untyped and not set (e.g. any type with nil value), we return nil.
return dyn.NilValue, nil
v, err = dyn.NilValue, nil
default:
return dyn.InvalidValue, fmt.Errorf("unsupported type: %s", srcv.Kind())
}

return dyn.InvalidValue, fmt.Errorf("unsupported type: %s", srcv.Kind())
// Ensure the location metadata is retained.
if err != nil {
return dyn.InvalidValue, err
}
return v.WithLocation(ref.Location()), err
}

func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, error) {
Expand Down Expand Up @@ -117,7 +125,7 @@ func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptio
}

// Either if the key was set in the reference or the field is not zero-valued, we include it.
if ok || nv != dyn.NilValue {
if ok || nv.Kind() != dyn.KindNil {
out.Set(refk, nv)
}
}
Expand All @@ -127,7 +135,7 @@ func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptio
// 2. The reference is a map (i.e. the struct was and still is empty).
// 3. The "includeZeroValues" option is set (i.e. the struct is a non-nil pointer).
if out.Len() > 0 || ref.Kind() == dyn.KindMap || slices.Contains(options, includeZeroValues) {
return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

// Otherwise, return nil.
Expand Down Expand Up @@ -179,7 +187,7 @@ func fromTypedMap(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
out.Set(refk, nv)
}

return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
Expand All @@ -206,7 +214,7 @@ func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
refv := ref.Index(i)

// Use nil reference if there is no reference for this index.
if refv == dyn.InvalidValue {
if refv.Kind() == dyn.KindInvalid {
refv = dyn.NilValue
}

Expand All @@ -219,7 +227,7 @@ func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
out[i] = nv
}

return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

func fromTypedString(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, error) {
Expand Down
Loading

0 comments on commit 4d8eba0

Please sign in to comment.