diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py index 4f5ecf4ed9..dfc4c90fbb 100644 --- a/src/databricks/labs/ucx/install.py +++ b/src/databricks/labs/ucx/install.py @@ -243,13 +243,7 @@ def warehouse_type(_): cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies)) instance_profile, spark_conf_dict = self._get_ext_hms_conf_from_policy(cluster_policy) - logger.info("Creating UCX cluster policy.") - policy_id = self._ws.cluster_policies.create( - name=f"Unity Catalog Migration ({inventory_database})", - definition=self._cluster_policy_definition(conf=spark_conf_dict, instance_profile=instance_profile), - description="Custom cluster policy for Unity Catalog Migration (UCX)", - ).policy_id - + policy_id = self._create_cluster_policy(inventory_database, spark_conf_dict, instance_profile) config = WorkspaceConfig( inventory_database=inventory_database, workspace_group_regex=configure_groups.workspace_group_regex, @@ -275,6 +269,26 @@ def warehouse_type(_): def _policy_config(value: str): return {"type": "fixed", "value": value} + def _create_cluster_policy( + self, inventory_database: str, spark_conf: dict, instance_profile: str | None + ) -> str | None: + policy_name = f"Unity Catalog Migration ({inventory_database}) ({self._ws.current_user.me().user_name})" + policies = self._ws.cluster_policies.list() + policy_id = None + for policy in policies: + if policy.name == policy_name: + policy_id = policy.policy_id + logger.info(f"Cluster policy {policy_name} already present, reusing the same.") + break + if not policy_id: + logger.info("Creating UCX cluster policy.") + policy_id = self._ws.cluster_policies.create( + name=policy_name, + definition=self._cluster_policy_definition(conf=spark_conf, instance_profile=instance_profile), + description="Custom cluster policy for Unity Catalog Migration (UCX)", + ).policy_id + return policy_id + def _cluster_policy_definition(self, conf: dict, instance_profile: str | None) -> str: policy_definition = { "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)), @@ -543,22 +557,28 @@ def _upload_wheel(self): self._installation.save(self._config) return self._wheels.upload_to_wsfs() - def create_jobs(self): - logger.debug(f"Creating jobs from tasks in {main.__name__}") - remote_wheel = self._upload_wheel() + def _upload_cluster_policy(self, remote_wheel: str): try: - policy_definition = self._ws.cluster_policies.get(policy_id=self.config.policy_id).definition + if self.config.policy_id is None: + msg = "Cluster policy not present, please uninstall and reinstall ucx completely." + raise InvalidParameterValue(msg) + policy = self._ws.cluster_policies.get(policy_id=self.config.policy_id) except NotFound as err: msg = f"UCX Policy {self.config.policy_id} not found, please reinstall UCX" logger.error(msg) raise NotFound(msg) from err + if policy.name is not None: + self._ws.cluster_policies.edit( + policy_id=self.config.policy_id, + name=policy.name, + definition=policy.definition, + libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")], + ) - self._ws.cluster_policies.edit( - policy_id=self.config.policy_id, - name=f"Unity Catalog Migration ({self.config.inventory_database})", - definition=policy_definition, - libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")], - ) + def create_jobs(self): + logger.debug(f"Creating jobs from tasks in {main.__name__}") + remote_wheel = self._upload_wheel() + self._upload_cluster_policy(remote_wheel) desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)} wheel_runner = None diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 0d3f6facef..236de89ba0 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -116,10 +116,11 @@ def test_job_failure_propagates_correct_error_message_and_logs(ws, sql_backend, @retried(on=[NotFound, Unknown, InvalidParameterValue], timeout=timedelta(minutes=18)) def test_job_cluster_policy(ws, new_installation): install = new_installation(lambda wc: replace(wc, override_clusters=None)) + user_name = ws.current_user.me().user_name cluster_policy = ws.cluster_policies.get(policy_id=install.config.policy_id) policy_definition = json.loads(cluster_policy.definition) - assert cluster_policy.name == f"Unity Catalog Migration ({install.config.inventory_database})" + assert cluster_policy.name == f"Unity Catalog Migration ({install.config.inventory_database}) ({user_name})" assert policy_definition["spark_version"]["value"] == ws.clusters.select_spark_version(latest=True) assert policy_definition["node_type_id"]["value"] == ws.clusters.select_node_type(local_disk=True) diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 6d3d45b227..98ba51fb71 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -158,7 +158,7 @@ def test_install_cluster_override_jobs(ws, mock_installation, any_prompt): sql_backend = MockBackend() wheels = create_autospec(WheelsV2) workspace_installation = WorkspaceInstallation( - WorkspaceConfig(inventory_database='ucx', override_clusters={"main": 'one', "tacl": 'two'}), + WorkspaceConfig(inventory_database='ucx', override_clusters={"main": 'one', "tacl": 'two'}, policy_id='123'), mock_installation, sql_backend, wheels, @@ -190,7 +190,7 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): ) workspace_installation = WorkspaceInstallation( - WorkspaceConfig(inventory_database='ucx'), + WorkspaceConfig(inventory_database='ucx', policy_id='123'), mock_installation, sql_backend, wheels, @@ -214,6 +214,7 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): 'log_level': 'INFO', 'num_threads': 10, 'override_clusters': {'main': '2222-999999-nosecuri', 'tacl': '3333-999999-legacytc'}, + 'policy_id': '123', 'renamed_group_prefix': 'ucx-renamed-', 'workspace_start_path': '/', }, @@ -225,7 +226,7 @@ def test_writeable_dbfs(ws, tmp_path, mock_installation, any_prompt): sql_backend = MockBackend() wheels = create_autospec(WheelsV2) workspace_installation = WorkspaceInstallation( - WorkspaceConfig(inventory_database='ucx'), + WorkspaceConfig(inventory_database='ucx', policy_id='123'), mock_installation, sql_backend, wheels, @@ -452,6 +453,47 @@ def test_save_config_strip_group_names(ws, mock_installation): ) +def test_cluster_policy_definition_present_reuse(ws, mock_installation): + ws.config.is_aws = False + ws.config.is_azure = True + ws.config.is_gcp = False + ws.cluster_policies.list.return_value = [ + Policy( + policy_id="foo1", + name="Unity Catalog Migration (ucx) (me@example.com)", + definition=json.dumps({}), + description="Custom cluster policy for Unity Catalog Migration (UCX)", + ) + ] + prompts = MockPrompts( + { + r".*PRO or SERVERLESS SQL warehouse.*": "1", + r"Choose how to map the workspace groups.*": "2", # specify names + r".*workspace group names.*": "g1, g2, g99", + r".*We have identified one or more cluster.*": "No", + r".*Choose a cluster policy.*": "0", + r".*": "", + } + ) + install = WorkspaceInstaller(prompts, mock_installation, ws) + install.configure() + mock_installation.assert_file_written( + 'config.yml', + { + 'version': 2, + 'default_catalog': 'ucx_default', + 'include_group_names': ['g1', 'g2', 'g99'], + 'inventory_database': 'ucx', + 'log_level': 'INFO', + 'num_threads': 8, + 'policy_id': 'foo1', + 'renamed_group_prefix': 'db-temp-', + 'warehouse_id': 'abc', + 'workspace_start_path': '/', + }, + ) + + def test_cluster_policy_definition_azure_hms(ws, mock_installation): ws.config.is_aws = False ws.config.is_azure = True @@ -498,7 +540,7 @@ def test_cluster_policy_definition_azure_hms(ws, mock_installation): "azure_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_AZURE"}, } ws.cluster_policies.create.assert_called_with( - name="Unity Catalog Migration (ucx)", + name="Unity Catalog Migration (ucx) (me@example.com)", definition=json.dumps(policy_definition_actual), description="Custom cluster policy for Unity Catalog Migration (UCX)", ) @@ -541,7 +583,7 @@ def test_cluster_policy_definition_aws_glue(ws, mock_installation): "aws_attributes.instance_profile_arn": {"type": "fixed", "value": "role_arn_1"}, } ws.cluster_policies.create.assert_called_with( - name="Unity Catalog Migration (ucx)", + name="Unity Catalog Migration (ucx) (me@example.com)", definition=json.dumps(policy_definition_actual), description="Custom cluster policy for Unity Catalog Migration (UCX)", ) @@ -592,7 +634,7 @@ def test_cluster_policy_definition_gcp(ws, mock_installation): "gcp_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_GCP"}, } ws.cluster_policies.create.assert_called_with( - name="Unity Catalog Migration (ucx)", + name="Unity Catalog Migration (ucx) (me@example.com)", definition=json.dumps(policy_definition_actual), description="Custom cluster policy for Unity Catalog Migration (UCX)", ) @@ -611,17 +653,19 @@ def test_install_edit_policy_with_library(ws, mock_installation, any_prompt): timedelta(seconds=1), ) wheels.upload_to_wsfs.return_value = "path1" - ws.cluster_policies.get.return_value = Policy(policy_id="foo") + ws.cluster_policies.get.return_value = Policy( + policy_id="foo", name="Unity Catalog Migration (ucx) (me@example.com)" + ) workspace_installation.create_jobs() ws.cluster_policies.edit.assert_called_with( - name="Unity Catalog Migration (ucx)", + name="Unity Catalog Migration (ucx) (me@example.com)", policy_id="foo", definition=None, libraries=[compute.Library(whl="dbfs:path1")], ) -def test_install_edit_policy_not_present(ws, mock_installation, any_prompt): +def test_install_edit_policy_not_found(ws, mock_installation, any_prompt): sql_backend = MockBackend() wheels = create_autospec(WheelsV2) workspace_installation = WorkspaceInstallation( @@ -638,6 +682,22 @@ def test_install_edit_policy_not_present(ws, mock_installation, any_prompt): workspace_installation.create_jobs() +def test_install_edit_policy_not_present(ws, mock_installation, any_prompt): + sql_backend = MockBackend() + wheels = create_autospec(WheelsV2) + workspace_installation = WorkspaceInstallation( + WorkspaceConfig(inventory_database='ucx', override_clusters={"main": 'one', "tacl": 'two'}), + mock_installation, + sql_backend, + wheels, + ws, + any_prompt, + timedelta(seconds=1), + ) + with pytest.raises(InvalidParameterValue): + workspace_installation.create_jobs() + + def test_save_config_with_custom_policy(ws, mock_installation): policy_def = b"""{ "aws_attributes.instance_profile_arn": { @@ -750,7 +810,7 @@ def test_main_with_existing_conf_does_not_recreate_config(ws, mocker, mock_insta } ) workspace_installation = WorkspaceInstallation( - WorkspaceConfig(inventory_database="..."), + WorkspaceConfig(inventory_database="...", policy_id='123'), mock_installation, sql_backend, create_autospec(WheelsV2),