diff --git a/bundle/python/warning.go b/bundle/python/warning.go index 9a718ae758..443b8fd27c 100644 --- a/bundle/python/warning.go +++ b/bundle/python/warning.go @@ -2,11 +2,13 @@ package python import ( "context" + "fmt" "strings" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/libraries" - "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" "golang.org/x/mod/semver" ) @@ -19,7 +21,7 @@ func WrapperWarning() bundle.Mutator { func (m *wrapperWarning) Apply(ctx context.Context, b *bundle.Bundle) error { if hasIncompatibleWheelTasks(ctx, b) { - cmdio.LogString(ctx, "Python wheel tasks with local libraries require compute with DBR 13.1+. Please change your cluster configuration or set experimental 'python_wheel_wrapper' setting to 'true'") + return fmt.Errorf("python wheel tasks with local libraries require compute with DBR 13.1+. Please change your cluster configuration or set experimental 'python_wheel_wrapper' setting to 'true'") } return nil } @@ -44,6 +46,20 @@ func hasIncompatibleWheelTasks(ctx context.Context, b *bundle.Bundle) bool { } } } + + if task.ExistingClusterId != "" { + version, err := getSparkVersionForCluster(ctx, b.WorkspaceClient(), task.ExistingClusterId) + + // If there's error getting spark version for cluster, do not mark it as incompatible + if err != nil { + log.Warnf(ctx, "unable to get spark version for cluster %s, err: %s", task.ExistingClusterId, err.Error()) + return false + } + + if lowerThanExpectedVersion(ctx, version) { + return true + } + } } return false @@ -63,3 +79,12 @@ func lowerThanExpectedVersion(ctx context.Context, sparkVersion string) bool { func (m *wrapperWarning) Name() string { return "PythonWrapperWarning" } + +func getSparkVersionForCluster(ctx context.Context, w *databricks.WorkspaceClient, clusterId string) (string, error) { + details, err := w.Clusters.GetByClusterId(ctx, clusterId) + if err != nil { + return "", err + } + + return details.SparkVersion, nil +} diff --git a/bundle/python/warning_test.go b/bundle/python/warning_test.go index 46bbd6562d..83bc142f1f 100644 --- a/bundle/python/warning_test.go +++ b/bundle/python/warning_test.go @@ -12,6 +12,117 @@ import ( "github.com/stretchr/testify/require" ) +type MockClusterService struct{} + +// ChangeOwner implements compute.ClustersService. +func (MockClusterService) ChangeOwner(ctx context.Context, request compute.ChangeClusterOwner) error { + panic("unimplemented") +} + +// Create implements compute.ClustersService. +func (MockClusterService) Create(ctx context.Context, request compute.CreateCluster) (*compute.CreateClusterResponse, error) { + panic("unimplemented") +} + +// Delete implements compute.ClustersService. +func (MockClusterService) Delete(ctx context.Context, request compute.DeleteCluster) error { + panic("unimplemented") +} + +// Edit implements compute.ClustersService. +func (MockClusterService) Edit(ctx context.Context, request compute.EditCluster) error { + panic("unimplemented") +} + +// Events implements compute.ClustersService. +func (MockClusterService) Events(ctx context.Context, request compute.GetEvents) (*compute.GetEventsResponse, error) { + panic("unimplemented") +} + +// Get implements compute.ClustersService. +func (MockClusterService) Get(ctx context.Context, request compute.GetClusterRequest) (*compute.ClusterDetails, error) { + clusterDetails := map[string]*compute.ClusterDetails{ + "test-key-1": { + SparkVersion: "12.2.x-scala2.12", + }, + "test-key-2": { + SparkVersion: "13.2.x-scala2.12", + }, + } + + return clusterDetails[request.ClusterId], nil +} + +// GetPermissionLevels implements compute.ClustersService. +func (MockClusterService) GetPermissionLevels(ctx context.Context, request compute.GetClusterPermissionLevelsRequest) (*compute.GetClusterPermissionLevelsResponse, error) { + panic("unimplemented") +} + +// GetPermissions implements compute.ClustersService. +func (MockClusterService) GetPermissions(ctx context.Context, request compute.GetClusterPermissionsRequest) (*compute.ClusterPermissions, error) { + panic("unimplemented") +} + +// List implements compute.ClustersService. +func (MockClusterService) List(ctx context.Context, request compute.ListClustersRequest) (*compute.ListClustersResponse, error) { + panic("unimplemented") +} + +// ListNodeTypes implements compute.ClustersService. +func (MockClusterService) ListNodeTypes(ctx context.Context) (*compute.ListNodeTypesResponse, error) { + panic("unimplemented") +} + +// ListZones implements compute.ClustersService. +func (MockClusterService) ListZones(ctx context.Context) (*compute.ListAvailableZonesResponse, error) { + panic("unimplemented") +} + +// PermanentDelete implements compute.ClustersService. +func (MockClusterService) PermanentDelete(ctx context.Context, request compute.PermanentDeleteCluster) error { + panic("unimplemented") +} + +// Pin implements compute.ClustersService. +func (MockClusterService) Pin(ctx context.Context, request compute.PinCluster) error { + panic("unimplemented") +} + +// Resize implements compute.ClustersService. +func (MockClusterService) Resize(ctx context.Context, request compute.ResizeCluster) error { + panic("unimplemented") +} + +// Restart implements compute.ClustersService. +func (MockClusterService) Restart(ctx context.Context, request compute.RestartCluster) error { + panic("unimplemented") +} + +// SetPermissions implements compute.ClustersService. +func (MockClusterService) SetPermissions(ctx context.Context, request compute.ClusterPermissionsRequest) (*compute.ClusterPermissions, error) { + panic("unimplemented") +} + +// SparkVersions implements compute.ClustersService. +func (MockClusterService) SparkVersions(ctx context.Context) (*compute.GetSparkVersionsResponse, error) { + panic("unimplemented") +} + +// Start implements compute.ClustersService. +func (MockClusterService) Start(ctx context.Context, request compute.StartCluster) error { + panic("unimplemented") +} + +// Unpin implements compute.ClustersService. +func (MockClusterService) Unpin(ctx context.Context, request compute.UnpinCluster) error { + panic("unimplemented") +} + +// UpdatePermissions implements compute.ClustersService. +func (MockClusterService) UpdatePermissions(ctx context.Context, request compute.ClusterPermissionsRequest) (*compute.ClusterPermissions, error) { + panic("unimplemented") +} + func TestIncompatibleWheelTasksWithNewCluster(t *testing.T) { b := &bundle.Bundle{ Config: config.Root{ @@ -100,6 +211,43 @@ func TestIncompatibleWheelTasksWithJobClusterKey(t *testing.T) { require.True(t, hasIncompatibleWheelTasks(context.Background(), b)) } +func TestIncompatibleWheelTasksWithExistingClusterId(t *testing.T) { + b := &bundle.Bundle{ + Config: config.Root{ + Resources: config.Resources{ + Jobs: map[string]*resources.Job{ + "job1": { + JobSettings: &jobs.JobSettings{ + Tasks: []jobs.Task{ + { + TaskKey: "key1", + PythonWheelTask: &jobs.PythonWheelTask{}, + ExistingClusterId: "test-key-1", + Libraries: []compute.Library{ + {Whl: "./dist/test.whl"}, + }, + }, + { + TaskKey: "key2", + PythonWheelTask: &jobs.PythonWheelTask{}, + ExistingClusterId: "test-key-2", + Libraries: []compute.Library{ + {Whl: "./dist/test.whl"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + b.WorkspaceClient().Clusters.WithImpl(MockClusterService{}) + + require.True(t, hasIncompatibleWheelTasks(context.Background(), b)) +} + func TestNoIncompatibleWheelTasks(t *testing.T) { b := &bundle.Bundle{ Config: config.Root{ @@ -168,6 +316,14 @@ func TestNoIncompatibleWheelTasks(t *testing.T) { {Whl: "./dist/test.whl"}, }, }, + { + TaskKey: "key6", + PythonWheelTask: &jobs.PythonWheelTask{}, + ExistingClusterId: "test-key-2", + Libraries: []compute.Library{ + {Whl: "./dist/test.whl"}, + }, + }, }, }, }, @@ -176,6 +332,8 @@ func TestNoIncompatibleWheelTasks(t *testing.T) { }, } + b.WorkspaceClient().Clusters.WithImpl(MockClusterService{}) + require.False(t, hasIncompatibleWheelTasks(context.Background(), b)) }