Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare .Kind() instead of direct equality checks on a dyn.Value #1520

Merged
merged 8 commits into from
Jun 27, 2024
Merged
4 changes: 2 additions & 2 deletions bundle/config/mutator/environments_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ func (m *environmentsToTargets) Apply(ctx context.Context, b *bundle.Bundle) dia
targets := v.Get("targets")

// Return an error if both "environments" and "targets" are set.
if environments != dyn.InvalidValue && targets != dyn.InvalidValue {
if environments.Kind() != dyn.KindInvalid && targets.Kind() != dyn.KindInvalid {
return dyn.InvalidValue, fmt.Errorf(
"both 'environments' and 'targets' are specified; only 'targets' should be used: %s",
environments.Location().String(),
)
}

// Rewrite "environments" to "targets".
if environments != dyn.InvalidValue && targets == dyn.InvalidValue {
if environments.Kind() != dyn.KindInvalid && targets.Kind() == dyn.KindInvalid {
nv, err := dyn.Set(v, "targets", environments)
if err != nil {
return dyn.InvalidValue, err
Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_job_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mergeJobClusters) jobClusterKey(v dyn.Value) string {

func (m *mergeJobClusters) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_job_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *mergeJobTasks) taskKeyString(v dyn.Value) string {

func (m *mergeJobTasks) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bundle/config/mutator/merge_pipeline_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (m *mergePipelineClusters) clusterLabel(v dyn.Value) string {

func (m *mergePipelineClusters) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
err := b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
if v == dyn.NilValue {
if v.Kind() == dyn.KindNil {
return v, nil
}

Expand Down
19 changes: 12 additions & 7 deletions bundle/config/mutator/run_as.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,20 @@ func (e errBothSpAndUserSpecified) Error() string {
}

func validateRunAs(b *bundle.Bundle) error {
runAs := b.Config.RunAs

// Error if neither service_principal_name nor user_name are specified
if runAs.ServicePrincipalName == "" && runAs.UserName == "" {
return fmt.Errorf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s", b.Config.GetLocation("run_as"))
neitherSpecifiedErr := fmt.Errorf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s", b.Config.GetLocation("run_as"))
// Error if neither service_principal_name nor user_name are specified, but the
// run_as section is present.
if b.Config.Value().Get("run_as").Kind() == dyn.KindNil {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
return neitherSpecifiedErr
}
// Error if one or both of service_principal_name and user_name are specified,
// but with empty values.
if b.Config.RunAs.ServicePrincipalName == "" && b.Config.RunAs.UserName == "" {
return neitherSpecifiedErr
}

// Error if both service_principal_name and user_name are specified
runAs := b.Config.RunAs
if runAs.UserName != "" && runAs.ServicePrincipalName != "" {
return errBothSpAndUserSpecified{
spName: runAs.ServicePrincipalName,
Expand Down Expand Up @@ -163,8 +169,7 @@ func setPipelineOwnersToRunAsIdentity(b *bundle.Bundle) {

func (m *setRunAs) Apply(_ context.Context, b *bundle.Bundle) diag.Diagnostics {
// Mutator is a no-op if run_as is not specified in the bundle
runAs := b.Config.RunAs
if runAs == nil {
if b.Config.Value().Get("run_as").Kind() == dyn.KindInvalid {
return nil
}

Expand Down
18 changes: 9 additions & 9 deletions bundle/config/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,39 +346,39 @@ func (r *Root) MergeTargetOverrides(name string) error {
}

// Merge `run_as`. This field must be overwritten if set, not merged.
if v := target.Get("run_as"); v != dyn.InvalidValue {
if v := target.Get("run_as"); v.Kind() != dyn.KindInvalid {
root, err = dyn.Set(root, "run_as", v)
if err != nil {
return err
}
}

// Below, we're setting fields on the bundle key, so make sure it exists.
if root.Get("bundle") == dyn.InvalidValue {
if root.Get("bundle").Kind() == dyn.KindInvalid {
root, err = dyn.Set(root, "bundle", dyn.NewValue(map[string]dyn.Value{}, dyn.Location{}))
if err != nil {
return err
}
}

// Merge `mode`. This field must be overwritten if set, not merged.
if v := target.Get("mode"); v != dyn.InvalidValue {
if v := target.Get("mode"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("mode")), v)
if err != nil {
return err
}
}

// Merge `compute_id`. This field must be overwritten if set, not merged.
if v := target.Get("compute_id"); v != dyn.InvalidValue {
if v := target.Get("compute_id"); v.Kind() != dyn.KindInvalid {
root, err = dyn.SetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("compute_id")), v)
if err != nil {
return err
}
}

// Merge `git`.
if v := target.Get("git"); v != dyn.InvalidValue {
if v := target.Get("git"); v.Kind() != dyn.KindInvalid {
ref, err := dyn.GetByPath(root, dyn.NewPath(dyn.Key("bundle"), dyn.Key("git")))
if err != nil {
ref = dyn.NewValue(map[string]dyn.Value{}, dyn.Location{})
Expand All @@ -391,7 +391,7 @@ func (r *Root) MergeTargetOverrides(name string) error {
}

// If the branch was overridden, we need to clear the inferred flag.
if branch := v.Get("branch"); branch != dyn.InvalidValue {
if branch := v.Get("branch"); branch.Kind() != dyn.KindInvalid {
out, err = dyn.SetByPath(out, dyn.NewPath(dyn.Key("inferred")), dyn.NewValue(false, dyn.Location{}))
if err != nil {
return err
Expand Down Expand Up @@ -419,7 +419,7 @@ func rewriteShorthands(v dyn.Value) (dyn.Value, error) {
// For each target, rewrite the variables block.
return dyn.Map(v, "targets", dyn.Foreach(func(_ dyn.Path, target dyn.Value) (dyn.Value, error) {
// Confirm it has a variables block.
if target.Get("variables") == dyn.InvalidValue {
if target.Get("variables").Kind() == dyn.KindInvalid {
return target, nil
}

Expand Down Expand Up @@ -464,15 +464,15 @@ func validateVariableOverrides(root, target dyn.Value) (err error) {
var tv map[string]variable.Variable

// Collect variables from the root.
if v := root.Get("variables"); v != dyn.InvalidValue {
if v := root.Get("variables"); v.Kind() != dyn.KindInvalid {
err = convert.ToTyped(&rv, v)
if err != nil {
return fmt.Errorf("unable to collect variables from root: %w", err)
}
}

// Collect variables from the target.
if v := target.Get("variables"); v != dyn.InvalidValue {
if v := target.Get("variables"); v.Kind() != dyn.KindInvalid {
err = convert.ToTyped(&tv, v)
if err != nil {
return fmt.Errorf("unable to collect variables from target: %w", err)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bundle:
name: "abc"

run_as:
service_principal_name: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bundle:
name: "abc"

run_as:
user_name: ""
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
bundle:
name: "abc"

run_as:
service_principal_name: ""
user_name: ""
64 changes: 45 additions & 19 deletions bundle/tests/run_as_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,27 +196,53 @@ func TestRunAsErrorWhenBothUserAndSpSpecified(t *testing.T) {
}

func TestRunAsErrorNeitherUserOrSpSpecified(t *testing.T) {
b := load(t, "./run_as/not_allowed/neither_sp_nor_user")

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
b.Config.Workspace.CurrentUser = &config.User{
User: &iam.User{
UserName: "my_service_principal",
},
}
return nil
})

diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()

configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/databricks.yml")
assert.EqualError(t, err, fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:8", configPath))
tcases := []struct {
name string
err string
}{
{
name: "empty_run_as",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:8", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_run_as/databricks.yml")),
},
{
name: "empty_sp",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_sp/databricks.yml")),
},
{
name: "empty_user",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_user/databricks.yml")),
},
{
name: "empty_user_and_sp",
err: fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:5:3", filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/empty_user_and_sp/databricks.yml")),
},
}

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {

bundlePath := fmt.Sprintf("./run_as/not_allowed/neither_sp_nor_user/%s", tc.name)
b := load(t, bundlePath)

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
b.Config.Workspace.CurrentUser = &config.User{
User: &iam.User{
UserName: "my_service_principal",
},
}
return nil
})

diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()
assert.EqualError(t, err, tc.err)
})
}
}

func TestRunAsErrorNeitherUserOrSpSpecifiedAtTargetOverride(t *testing.T) {
b := loadTarget(t, "./run_as/not_allowed/neither_sp_nor_user_override", "development")
b := loadTarget(t, "./run_as/not_allowed/neither_sp_nor_user/override", "development")

ctx := context.Background()
bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
Expand All @@ -231,7 +257,7 @@ func TestRunAsErrorNeitherUserOrSpSpecifiedAtTargetOverride(t *testing.T) {
diags := bundle.Apply(ctx, b, mutator.SetRunAs())
err := diags.Error()

configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user_override/override.yml")
configPath := filepath.FromSlash("run_as/not_allowed/neither_sp_nor_user/override/override.yml")
assert.EqualError(t, err, fmt.Sprintf("run_as section must specify exactly one identity. Neither service_principal_name nor user_name is specified at %s:4:12", configPath))
}

Expand Down
38 changes: 23 additions & 15 deletions libs/dyn/convert/from_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func fromTyped(src any, ref dyn.Value, options ...fromTypedOptions) (dyn.Value,
// Dereference pointer if necessary
for srcv.Kind() == reflect.Pointer {
if srcv.IsNil() {
return dyn.NilValue, nil
return dyn.NilValue.WithLocation(ref.Location()), nil
}
srcv = srcv.Elem()

Expand All @@ -55,27 +55,35 @@ func fromTyped(src any, ref dyn.Value, options ...fromTypedOptions) (dyn.Value,
}
}

var v dyn.Value
var err error
switch srcv.Kind() {
case reflect.Struct:
return fromTypedStruct(srcv, ref, options...)
v, err = fromTypedStruct(srcv, ref, options...)
case reflect.Map:
return fromTypedMap(srcv, ref)
v, err = fromTypedMap(srcv, ref)
case reflect.Slice:
return fromTypedSlice(srcv, ref)
v, err = fromTypedSlice(srcv, ref)
case reflect.String:
return fromTypedString(srcv, ref, options...)
v, err = fromTypedString(srcv, ref, options...)
case reflect.Bool:
return fromTypedBool(srcv, ref, options...)
v, err = fromTypedBool(srcv, ref, options...)
case reflect.Int, reflect.Int32, reflect.Int64:
return fromTypedInt(srcv, ref, options...)
v, err = fromTypedInt(srcv, ref, options...)
case reflect.Float32, reflect.Float64:
return fromTypedFloat(srcv, ref, options...)
v, err = fromTypedFloat(srcv, ref, options...)
case reflect.Invalid:
// If the value is untyped and not set (e.g. any type with nil value), we return nil.
return dyn.NilValue, nil
v, err = dyn.NilValue, nil
default:
return dyn.InvalidValue, fmt.Errorf("unsupported type: %s", srcv.Kind())
}

return dyn.InvalidValue, fmt.Errorf("unsupported type: %s", srcv.Kind())
// Ensure the location metadata is retained.
if err != nil {
return dyn.InvalidValue, err
}
return v.WithLocation(ref.Location()), err
}

func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, error) {
Expand Down Expand Up @@ -117,7 +125,7 @@ func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptio
}

// Either if the key was set in the reference or the field is not zero-valued, we include it.
if ok || nv != dyn.NilValue {
if ok || nv.Kind() != dyn.KindNil {
out.Set(refk, nv)
}
}
Expand All @@ -127,7 +135,7 @@ func fromTypedStruct(src reflect.Value, ref dyn.Value, options ...fromTypedOptio
// 2. The reference is a map (i.e. the struct was and still is empty).
// 3. The "includeZeroValues" option is set (i.e. the struct is a non-nil pointer).
if out.Len() > 0 || ref.Kind() == dyn.KindMap || slices.Contains(options, includeZeroValues) {
return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

// Otherwise, return nil.
Expand Down Expand Up @@ -179,7 +187,7 @@ func fromTypedMap(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
out.Set(refk, nv)
}

return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
Expand All @@ -206,7 +214,7 @@ func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
refv := ref.Index(i)

// Use nil reference if there is no reference for this index.
if refv == dyn.InvalidValue {
if refv.Kind() == dyn.KindInvalid {
refv = dyn.NilValue
}

Expand All @@ -219,7 +227,7 @@ func fromTypedSlice(src reflect.Value, ref dyn.Value) (dyn.Value, error) {
out[i] = nv
}

return dyn.NewValue(out, ref.Location()), nil
return dyn.V(out), nil
}

func fromTypedString(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, error) {
Expand Down
Loading
Loading