diff --git a/src/databricks/labs/ucx/inventory/inventorizer.py b/src/databricks/labs/ucx/inventory/inventorizer.py
index 376354d0a2..034d8889b0 100644
--- a/src/databricks/labs/ucx/inventory/inventorizer.py
+++ b/src/databricks/labs/ucx/inventory/inventorizer.py
@@ -15,6 +15,7 @@
     ObjectType,
     SecretScope,
 )
+from ratelimit import limits, sleep_and_retry
 
 from databricks.labs.ucx.inventory.listing import WorkspaceListing
 from databricks.labs.ucx.inventory.types import (
@@ -23,7 +24,6 @@
     PermissionsInventoryItem,
     RequestObjectType,
 )
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.providers.groups_info import GroupMigrationState
 from databricks.labs.ucx.utils import ProgressReporter, ThreadedExecution
 
@@ -57,7 +57,7 @@ def logical_object_types(self) -> list[LogicalObjectType]:
 
     def __init__(
         self,
-        ws: ImprovedWorkspaceClient,
+        ws: WorkspaceClient,
         logical_object_type: LogicalObjectType,
         request_object_type: RequestObjectType,
         listing_function: Callable[..., Iterator[InventoryObject]],
@@ -72,9 +72,14 @@ def __init__(
         self._permissions_function = permissions_function if permissions_function else self._safe_get_permissions
         self._objects: list[InventoryObject] = []
 
+    @sleep_and_retry
+    @limits(calls=100, period=1)
+    def _get_permissions(self, request_object_type: RequestObjectType, request_object_id: str):
+        return self._ws.permissions.get(request_object_type=request_object_type, request_object_id=request_object_id)
+
     def _safe_get_permissions(self, request_object_type: RequestObjectType, object_id: str) -> ObjectPermissions | None:
         try:
-            permissions = self._ws.get_permissions(request_object_type, object_id)
+            permissions = self._get_permissions(request_object_type, object_id)
             return permissions
         except DatabricksError as e:
             if e.error_code in ["RESOURCE_DOES_NOT_EXIST", "RESOURCE_NOT_FOUND", "PERMISSION_DENIED"]:
@@ -119,7 +124,7 @@ class TokensAndPasswordsInventorizer(BaseInventorizer[InventoryObject]):
     def logical_object_types(self) -> list[LogicalObjectType]:
         return [LogicalObjectType.TOKEN, LogicalObjectType.PASSWORD]
 
-    def __init__(self, ws: ImprovedWorkspaceClient):
+    def __init__(self, ws: WorkspaceClient):
         self._ws = ws
         self._tokens_acl = []
         self._passwords_acl = []
@@ -188,7 +193,7 @@ class SecretScopeInventorizer(BaseInventorizer[InventoryObject]):
     def logical_object_types(self) -> list[LogicalObjectType]:
         return [LogicalObjectType.SECRET_SCOPE]
 
-    def __init__(self, ws: ImprovedWorkspaceClient):
+    def __init__(self, ws: WorkspaceClient):
         self._ws = ws
         self._scopes = ws.secrets.list_scopes()
 
@@ -221,7 +226,7 @@ class WorkspaceInventorizer(BaseInventorizer[InventoryObject]):
     def logical_object_types(self) -> list[LogicalObjectType]:
         return [LogicalObjectType.NOTEBOOK, LogicalObjectType.DIRECTORY, LogicalObjectType.REPO, LogicalObjectType.FILE]
 
-    def __init__(self, ws: ImprovedWorkspaceClient, num_threads=20, start_path: str | None = "/"):
+    def __init__(self, ws: WorkspaceClient, num_threads=20, start_path: str | None = "/"):
         self._ws = ws
         self.listing = WorkspaceListing(
             ws,
@@ -262,13 +267,18 @@ def __convert_request_object_type_to_logical_type(request_object_type: RequestOb
             case RequestObjectType.FILES:
                 return LogicalObjectType.FILE
 
+    @sleep_and_retry
+    @limits(calls=100, period=1)
+    def _get_permissions(self, request_object_type: RequestObjectType, request_object_id: str):
+        return self._ws.permissions.get(request_object_type=request_object_type, request_object_id=request_object_id)
+
     def _convert_result_to_permission_item(self, _object: ObjectInfo) -> PermissionsInventoryItem | None:
         request_object_type = self.__convert_object_type_to_request_type(_object)
         if not request_object_type:
             return
         else:
             try:
-                permissions = self._ws.get_permissions(
+                permissions = self._get_permissions(
                     request_object_type=request_object_type, request_object_id=_object.object_id
                 )
             except DatabricksError as e:
@@ -306,7 +316,7 @@ class RolesAndEntitlementsInventorizer(BaseInventorizer[InventoryObject]):
     def logical_object_types(self) -> list[LogicalObjectType]:
         return [LogicalObjectType.ROLES, LogicalObjectType.ENTITLEMENTS]
 
-    def __init__(self, ws: ImprovedWorkspaceClient, migration_state: GroupMigrationState):
+    def __init__(self, ws: WorkspaceClient, migration_state: GroupMigrationState):
        self._ws = ws
        self._migration_state = migration_state
        self._group_info: list[Group] = []
@@ -363,7 +373,7 @@ def inner() -> Iterator[ModelDatabricks]:
 
 class Inventorizers:
     @staticmethod
-    def provide(ws: ImprovedWorkspaceClient, migration_state: GroupMigrationState, num_threads: int):
+    def provide(ws: WorkspaceClient, migration_state: GroupMigrationState, num_threads: int):
         return [
             RolesAndEntitlementsInventorizer(ws, migration_state),
             TokensAndPasswordsInventorizer(ws),
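The change above inlines rate limiting into each call site instead of inheriting it from `ImprovedWorkspaceClient`: `@sleep_and_retry` makes the caller block until the `@limits(calls=100, period=1)` window allows another request. A minimal, self-contained sketch of that decorator pattern; `fake_api_call` is a hypothetical stand-in for `ws.permissions.get(...)`:

```python
import time

from ratelimit import limits, sleep_and_retry


def fake_api_call(i: int) -> int:
    # Hypothetical stand-in for a REST call such as ws.permissions.get(...).
    return i


@sleep_and_retry              # sleep until the window frees up instead of raising
@limits(calls=100, period=1)  # at most 100 invocations per 1-second window
def rate_limited_call(i: int) -> int:
    return fake_api_call(i)


start = time.time()
for i in range(250):  # 250 calls at 100/sec should take roughly two seconds
    rate_limited_call(i)
print(f"elapsed: {time.time() - start:.1f}s")
```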
diff --git a/src/databricks/labs/ucx/inventory/listing.py b/src/databricks/labs/ucx/inventory/listing.py
index 2830fd98ef..b8a4411426 100644
--- a/src/databricks/labs/ucx/inventory/listing.py
+++ b/src/databricks/labs/ucx/inventory/listing.py
@@ -1,11 +1,12 @@
 import datetime as dt
 import logging
+from collections.abc import Iterator
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from itertools import groupby
 
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.workspace import ObjectInfo, ObjectType
-
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
+from ratelimit import limits, sleep_and_retry
 
 logger = logging.getLogger(__name__)
 
@@ -13,7 +14,7 @@
 class WorkspaceListing:
     def __init__(
         self,
-        ws: ImprovedWorkspaceClient,
+        ws: WorkspaceClient,
         num_threads: int,
         *,
         with_directories: bool = True,
@@ -39,12 +40,16 @@ def _progress_report(self, _):
             f" rps: {rps:.3f}/sec"
         )
 
+    @sleep_and_retry
+    @limits(calls=45, period=1)  # safety value, can be 50 actually
+    def _list_workspace(self, path: str) -> Iterator[ObjectInfo]:
+        # TODO: remove, use SDK
+        return self._ws.workspace.list(path=path, recursive=False)
+
     def _list_and_analyze(self, obj: ObjectInfo) -> (list[ObjectInfo], list[ObjectInfo]):
         directories = []
         others = []
-        grouped_iterator = groupby(
-            self._ws.list_workspace(obj.path), key=lambda x: x.object_type == ObjectType.DIRECTORY
-        )
+        grouped_iterator = groupby(self._list_workspace(obj.path), key=lambda x: x.object_type == ObjectType.DIRECTORY)
         for is_directory, objects in grouped_iterator:
             if is_directory:
                 directories.extend(list(objects))
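`_list_and_analyze` relies on `itertools.groupby`, which only groups *consecutive* items with an equal key; that is why the loop extends `directories`/`others` on every run rather than expecting exactly two groups. A small sketch of that split, with a hypothetical listing result standing in for `_list_workspace(obj.path)`:

```python
from itertools import groupby

from databricks.sdk.service.workspace import ObjectInfo, ObjectType

# Hypothetical listing result; in the real code this comes from _list_workspace(obj.path).
objects = [
    ObjectInfo(path="/a", object_type=ObjectType.DIRECTORY),
    ObjectInfo(path="/b.py", object_type=ObjectType.FILE),
    ObjectInfo(path="/c", object_type=ObjectType.DIRECTORY),
]

directories: list[ObjectInfo] = []
others: list[ObjectInfo] = []
# groupby yields a new group every time the key flips, so both lists keep growing.
for is_directory, group in groupby(objects, key=lambda o: o.object_type == ObjectType.DIRECTORY):
    if is_directory:
        directories.extend(group)
    else:
        others.extend(group)

print([o.path for o in directories], [o.path for o in others])
```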
diff --git a/src/databricks/labs/ucx/inventory/permissions.py b/src/databricks/labs/ucx/inventory/permissions.py
index 85578bb095..b5cf9bf610 100644
--- a/src/databricks/labs/ucx/inventory/permissions.py
+++ b/src/databricks/labs/ucx/inventory/permissions.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import random
 import time
@@ -6,8 +7,10 @@
 from functools import partial
 from typing import Literal
 
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.iam import AccessControlRequest, Group, ObjectPermissions
 from databricks.sdk.service.workspace import AclItem as SdkAclItem
+from ratelimit import limits, sleep_and_retry
 from tenacity import retry, stop_after_attempt, wait_fixed, wait_random
 
 from databricks.labs.ucx.inventory.inventorizer import BaseInventorizer
@@ -19,7 +22,6 @@
     RequestObjectType,
     RolesAndEntitlements,
 )
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.providers.groups_info import GroupMigrationState
 from databricks.labs.ucx.utils import ThreadedExecution, safe_get_acls
 
@@ -51,7 +53,7 @@ class RolesAndEntitlementsRequestPayload:
 
 # TODO: this class has too many @staticmethod and they must not be such. write a unit test for this logic.
 class PermissionManager:
-    def __init__(self, ws: ImprovedWorkspaceClient, inventory_table_manager: InventoryTableManager):
+    def __init__(self, ws: WorkspaceClient, inventory_table_manager: InventoryTableManager):
         self._ws = ws
         self.inventory_table_manager = inventory_table_manager
         self._inventorizers = []
@@ -195,8 +197,22 @@ def _scope_permissions_applicator(self, request_payload: SecretsPermissionReques
                 f"Expected: {_acl_item.permission}. Actual: {applied_acls.permission}"
             )
 
+    @sleep_and_retry
+    @limits(calls=30, period=1)
+    def _update_permissions(
+        self,
+        request_object_type: RequestObjectType,
+        request_object_id: str,
+        access_control_list: list[AccessControlRequest],
+    ):
+        return self._ws.permissions.update(
+            request_object_type=request_object_type,
+            request_object_id=request_object_id,
+            access_control_list=access_control_list,
+        )
+
     def _standard_permissions_applicator(self, request_payload: PermissionRequestPayload):
-        self._ws.update_permissions(
+        self._update_permissions(
             request_object_type=request_payload.request_object_type,
             request_object_id=request_payload.object_id,
             access_control_list=request_payload.access_control_list,
@@ -204,7 +220,7 @@ def _standard_permissions_applicator(self, request_payload: PermissionRequestPay
 
     def applicator(self, request_payload: AnyRequestPayload):
         if isinstance(request_payload, RolesAndEntitlementsRequestPayload):
-            self._ws.apply_roles_and_entitlements(
+            self._apply_roles_and_entitlements(
                 group_id=request_payload.group_id,
                 roles=request_payload.payload.roles,
                 entitlements=request_payload.payload.entitlements,
@@ -216,6 +232,49 @@
         else:
             logger.warning(f"Unsupported payload type {type(request_payload)}")
 
+    @sleep_and_retry
+    @limits(calls=10, period=1)  # assumption
+    def _apply_roles_and_entitlements(self, group_id: str, roles: list, entitlements: list):
+        # TODO: move to other places, this won't be in SDK
+        op_schema = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
+        schemas = []
+        operations = []
+
+        if entitlements:
+            schemas.append(op_schema)
+            entitlements_payload = {
+                "op": "add",
+                "path": "entitlements",
+                "value": entitlements,
+            }
+            operations.append(entitlements_payload)
+
+        if roles:
+            schemas.append(op_schema)
+            roles_payload = {
+                "op": "add",
+                "path": "roles",
+                "value": roles,
+            }
+            operations.append(roles_payload)
+
+        if operations:
+            request = {
+                "schemas": schemas,
+                "Operations": operations,
+            }
+            self._patch_workspace_group(group_id, request)
+
+    def _patch_workspace_group(self, group_id: str, payload: dict):
+        # TODO: replace usages
+        # self.groups.patch(group_id,
+        #                   schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
+        #                   operations=[
+        #                       Patch(op=PatchOp.ADD, path='..', value='...')
+        #                   ])
+        path = f"/api/2.0/preview/scim/v2/Groups/{group_id}"
+        self._ws.api_client.do("PATCH", path, data=json.dumps(payload))
+
     def _apply_permissions_in_parallel(
         self,
         requests: list[AnyRequestPayload],
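`_apply_roles_and_entitlements` builds a SCIM `PatchOp` body and hands it to `_patch_workspace_group`, which PATCHes `/api/2.0/preview/scim/v2/Groups/{group_id}`. A condensed sketch of the same payload shape; the role and entitlement values below are made-up examples, not taken from the patch:

```python
import json

OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"


def build_scim_patch(roles: list, entitlements: list) -> dict | None:
    # Same structure _apply_roles_and_entitlements assembles: one "add" operation
    # per non-empty collection, with the PatchOp schema appended per operation.
    operations = [
        {"op": "add", "path": path, "value": values}
        for path, values in (("entitlements", entitlements), ("roles", roles))
        if values
    ]
    if not operations:
        return None
    return {"schemas": [OP_SCHEMA] * len(operations), "Operations": operations}


payload = build_scim_patch(
    roles=[{"value": "arn:aws:iam::123456789012:instance-profile/example"}],  # example value
    entitlements=[{"value": "allow-cluster-create"}],  # example value
)
print(json.dumps(payload, indent=2))
# The real code then sends it with:
# ws.api_client.do("PATCH", f"/api/2.0/preview/scim/v2/Groups/{group_id}", data=json.dumps(payload))
```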
diff --git a/src/databricks/labs/ucx/managers/group.py b/src/databricks/labs/ucx/managers/group.py
index 3e7e9e330f..47e1aa3b9f 100644
--- a/src/databricks/labs/ucx/managers/group.py
+++ b/src/databricks/labs/ucx/managers/group.py
@@ -1,12 +1,14 @@
+import json
 import logging
 import typing
 from functools import partial
 
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.iam import Group
+from ratelimit import limits, sleep_and_retry
 
 from databricks.labs.ucx.config import GroupsConfig
 from databricks.labs.ucx.generic import StrEnum
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.providers.groups_info import (
     GroupMigrationState,
     MigrationGroupInfo,
@@ -24,7 +26,7 @@ class GroupLevel(StrEnum):
 class GroupManager:
     SYSTEM_GROUPS: typing.ClassVar[list[str]] = ["users", "admins", "account users"]
 
-    def __init__(self, ws: ImprovedWorkspaceClient, groups: GroupsConfig):
+    def __init__(self, ws: WorkspaceClient, groups: GroupsConfig):
         self._ws = ws
         self.config = groups
         self._migration_state: GroupMigrationState = GroupMigrationState()
@@ -39,9 +41,18 @@ def _find_eligible_groups(self) -> list[str]:
         logger.info(f"Found {len(eligible_groups)} eligible groups")
         return [g.display_name for g in eligible_groups]
 
+    @sleep_and_retry
+    @limits(calls=100, period=1)  # assumption
+    def _list_account_level_groups(
+        self, filter: str, attributes: str | None = None, excluded_attributes: str | None = None  # noqa: A002
+    ) -> list[Group]:
+        query = {"filter": filter, "attributes": attributes, "excludedAttributes": excluded_attributes}
+        response = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query=query)
+        return [Group.from_dict(v) for v in response.get("Resources", [])]
+
     def _get_group(self, group_name, level: GroupLevel) -> Group | None:
         # TODO: calling this can cause issues for SCIM backend, cache groups instead
-        method = self._ws.groups.list if level == GroupLevel.WORKSPACE else self._ws.list_account_level_groups
+        method = self._ws.groups.list if level == GroupLevel.WORKSPACE else self._list_account_level_groups
         query_filter = f"displayName eq '{group_name}'"
         attributes = ",".join(["id", "displayName", "meta", "entitlements", "roles", "members"])
 
@@ -99,7 +110,20 @@ def _replace_group(self, migration_info: MigrationGroupInfo):
         else:
             logger.warning(f"Workspace-level group {ws_group.display_name} does not exist, skipping")
 
-        self._ws.reflect_account_group_to_workspace(acc_group)
+        self._reflect_account_group_to_workspace(acc_group)
+
+    @sleep_and_retry
+    @limits(calls=5, period=1)  # assumption
+    def _reflect_account_group_to_workspace(self, acc_group: Group) -> None:
+        logger.info(f"Reflecting group {acc_group.display_name} to workspace")
+
+        # TODO: add OpenAPI spec for it
+        principal_id = acc_group.id
+        permissions = ["USER"]
+        path = f"/api/2.0/preview/permissionassignments/principals/{principal_id}"
+        self._ws.api_client.do("PUT", path, data=json.dumps({"permissions": permissions}))
+
+        logger.info(f"Group {acc_group.display_name} successfully reflected to workspace")
 
     # please keep the public methods below this line
 
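`GroupManager` now issues the two raw REST calls itself: a SCIM `GET` on the account groups endpoint and a `PUT` on the permission-assignments endpoint. A hedged sketch of the listing call, assuming unified auth and an example group name:

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.iam import Group

ws = WorkspaceClient()  # credentials resolved via unified auth

# The same raw SCIM query that _list_account_level_groups wraps; the display name is an example.
query = {
    "filter": "displayName eq 'data-engineers'",
    "attributes": "id,displayName,meta,entitlements,roles,members",
}
response = ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query=query)
groups = [Group.from_dict(v) for v in response.get("Resources", [])]
print([g.display_name for g in groups])
```

Reflecting a matched account group back into the workspace is then the `PUT` on `/api/2.0/preview/permissionassignments/principals/{principal_id}` shown in `_reflect_account_group_to_workspace` above.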
diff --git a/src/databricks/labs/ucx/providers/client.py b/src/databricks/labs/ucx/providers/client.py
deleted file mode 100644
index 76c7da81d7..0000000000
--- a/src/databricks/labs/ucx/providers/client.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import json
-import logging
-from collections.abc import Iterator
-
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.iam import AccessControlRequest, Group
-from databricks.sdk.service.workspace import ObjectType
-from ratelimit import limits, sleep_and_retry
-
-from databricks.labs.ucx.inventory.types import RequestObjectType
-
-logger = logging.getLogger(__name__)
-
-
-class ImprovedWorkspaceClient(WorkspaceClient):
-    # ***
-    # *** CAUTION: DO NOT ADD ANY METHODS THAT WON'T END UP IN THE SDK ***
-    # ***
-    # to this class we add rate-limited methods to make calls to various APIs
-    # source info - https://docs.databricks.com/resources/limits.html
-
-    @sleep_and_retry
-    @limits(calls=5, period=1)  # assumption
-    def assign_permissions(self, principal_id: str, permissions: list[str]):
-        # TODO: add OpenAPI spec for it
-        request_string = f"/api/2.0/preview/permissionassignments/principals/{principal_id}"
-        self.api_client.do("put", request_string, data=json.dumps({"permissions": permissions}))
-
-    @sleep_and_retry
-    @limits(calls=10, period=1)  # assumption
-    def patch_workspace_group(self, group_id: str, payload: dict):
-        # TODO: replace usages
-        # self.groups.patch(group_id,
-        #                   schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
-        #                   operations=[
-        #                       Patch(op=PatchOp.ADD, path='..', value='...')
-        #                   ])
-        path = f"/api/2.0/preview/scim/v2/Groups/{group_id}"
-        self.api_client.do("PATCH", path, data=json.dumps(payload))
-
-    @sleep_and_retry
-    @limits(calls=100, period=1)  # assumption
-    def list_account_level_groups(
-        self, filter: str, attributes: str | None = None, excluded_attributes: str | None = None  # noqa: A002
-    ) -> list[Group]:
-        # TODO: move to other places, this won't be in SDK
-        query = {"filter": filter, "attributes": attributes, "excludedAttributes": excluded_attributes}
-        response = self.api_client.do("get", "/api/2.0/account/scim/v2/Groups", query=query)
-        return [Group.from_dict(v) for v in response.get("Resources", [])]
-
-    def reflect_account_group_to_workspace(self, acc_group: Group) -> None:
-        logger.info(f"Reflecting group {acc_group.display_name} to workspace")
-        self.assign_permissions(principal_id=acc_group.id, permissions=["USER"])
-        logger.info(f"Group {acc_group.display_name} successfully reflected to workspace")
-
-    @sleep_and_retry
-    @limits(calls=45, period=1)  # safety value, can be 50 actually
-    def list_workspace(self, path: str) -> Iterator[ObjectType]:
-        # TODO: remove, use SDK
-        return self.workspace.list(path=path, recursive=False)
-
-    @sleep_and_retry
-    @limits(calls=100, period=1)
-    def get_permissions(self, request_object_type: RequestObjectType, request_object_id: str):
-        return self.permissions.get(request_object_type=request_object_type, request_object_id=request_object_id)
-
-    @sleep_and_retry
-    @limits(calls=30, period=1)
-    def update_permissions(
-        self,
-        request_object_type: RequestObjectType,
-        request_object_id: str,
-        access_control_list: list[AccessControlRequest],
-    ):
-        return self.permissions.update(
-            request_object_type=request_object_type,
-            request_object_id=request_object_id,
-            access_control_list=access_control_list,
-        )
-
-    def apply_roles_and_entitlements(self, group_id: str, roles: list, entitlements: list):
-        # TODO: move to other places, this won't be in SDK
-        op_schema = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
-        schemas = []
-        operations = []
-
-        if entitlements:
-            schemas.append(op_schema)
-            entitlements_payload = {
-                "op": "add",
-                "path": "entitlements",
-                "value": entitlements,
-            }
-            operations.append(entitlements_payload)
-
-        if roles:
-            schemas.append(op_schema)
-            roles_payload = {
-                "op": "add",
-                "path": "roles",
-                "value": roles,
-            }
-            operations.append(roles_payload)
-
-        if operations:
-            request = {
-                "schemas": schemas,
-                "Operations": operations,
-            }
-            self.patch_workspace_group(group_id, request)
"add", - "path": "roles", - "value": roles, - } - operations.append(roles_payload) - - if operations: - request = { - "schemas": schemas, - "Operations": operations, - } - self.patch_workspace_group(group_id, request) diff --git a/src/databricks/labs/ucx/toolkits/group_migration.py b/src/databricks/labs/ucx/toolkits/group_migration.py index 299ff40a6a..299b911024 100644 --- a/src/databricks/labs/ucx/toolkits/group_migration.py +++ b/src/databricks/labs/ucx/toolkits/group_migration.py @@ -1,11 +1,12 @@ import logging +from databricks.sdk import WorkspaceClient + from databricks.labs.ucx.config import MigrationConfig from databricks.labs.ucx.inventory.inventorizer import Inventorizers from databricks.labs.ucx.inventory.permissions import PermissionManager from databricks.labs.ucx.inventory.table import InventoryTableManager from databricks.labs.ucx.managers.group import GroupManager -from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient class GroupMigrationToolkit: @@ -17,7 +18,7 @@ def __init__(self, config: MigrationConfig): # integrate with connection pool settings properly # https://github.com/databricks/databricks-sdk-py/pull/276 - self._ws = ImprovedWorkspaceClient(config=databricks_config) + self._ws = WorkspaceClient(config=databricks_config) self._ws.api_client._session.adapters["https://"].max_retries.total = 20 self._verify_ws_client(self._ws) @@ -26,7 +27,7 @@ def __init__(self, config: MigrationConfig): self.permissions_manager = PermissionManager(self._ws, self.table_manager) @staticmethod - def _verify_ws_client(w: ImprovedWorkspaceClient): + def _verify_ws_client(w: WorkspaceClient): _me = w.current_user.me() is_workspace_admin = any(g.display == "admins" for g in _me.groups) if not is_workspace_admin: diff --git a/src/databricks/labs/ucx/utils.py b/src/databricks/labs/ucx/utils.py index de74b9ac9c..76e841362c 100644 --- a/src/databricks/labs/ucx/utils.py +++ b/src/databricks/labs/ucx/utils.py @@ -5,10 +5,10 @@ from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor from typing import Generic, TypeVar +from databricks.sdk import WorkspaceClient from databricks.sdk.service.workspace import AclItem from databricks.labs.ucx.generic import StrEnum -from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient ExecutableResult = TypeVar("ExecutableResult") ExecutableFunction = Callable[..., ExecutableResult] @@ -85,7 +85,7 @@ class WorkspaceLevelEntitlement(StrEnum): ALLOW_INSTANCE_POOL_CREATE = "allow-instance-pool-create" -def safe_get_acls(ws: ImprovedWorkspaceClient, scope_name: str, group_name: str) -> AclItem | None: +def safe_get_acls(ws: WorkspaceClient, scope_name: str, group_name: str) -> AclItem | None: all_acls = ws.secrets.list_acls(scope=scope_name) for acl in all_acls: if acl.principal == group_name: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4cb5a156ac..3da1451401 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -9,7 +9,7 @@ import databricks.sdk.core import pytest from _pytest.fixtures import SubRequest -from databricks.sdk import AccountClient +from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.core import Config, DatabricksError from databricks.sdk.service.compute import ( ClusterDetails, @@ -38,7 +38,6 @@ from databricks.labs.ucx.config import InventoryTable from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient from 
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 4cb5a156ac..3da1451401 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -9,7 +9,7 @@
 import databricks.sdk.core
 import pytest
 from _pytest.fixtures import SubRequest
-from databricks.sdk import AccountClient
+from databricks.sdk import AccountClient, WorkspaceClient
 from databricks.sdk.core import Config, DatabricksError
 from databricks.sdk.service.compute import (
     ClusterDetails,
@@ -38,7 +38,6 @@
 
 from databricks.labs.ucx.config import InventoryTable
 from databricks.labs.ucx.inventory.types import RequestObjectType
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.providers.mixins.fixtures import *  # noqa: F403
 from databricks.labs.ucx.providers.mixins.sql import StatementExecutionExt
 from databricks.labs.ucx.utils import ThreadedExecution
@@ -87,13 +86,6 @@ def account_host(self: databricks.sdk.core.Config) -> str:
         return "https://accounts.cloud.databricks.com"
 
 
-@pytest.fixture(scope="session")
-def ws() -> ImprovedWorkspaceClient:
-    # Use variables from Unified Auth
-    # See https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html
-    return ImprovedWorkspaceClient()
-
-
 @pytest.fixture(scope="session")
 def acc(ws) -> AccountClient:
     # TODO: move to SDK
@@ -111,14 +103,14 @@ def account_host(cfg: Config) -> str:
 
 
 @pytest.fixture
-def sql_exec(ws: ImprovedWorkspaceClient):
+def sql_exec(ws: WorkspaceClient):
     warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"]
     statement_execution = StatementExecutionExt(ws.api_client)
     return partial(statement_execution.execute, warehouse_id)
 
 
 @pytest.fixture
-def sql_fetch_all(ws: ImprovedWorkspaceClient):
+def sql_fetch_all(ws: WorkspaceClient):
     warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"]
     statement_execution = StatementExecutionExt(ws.api_client)
     return partial(statement_execution.execute_fetch_all, warehouse_id)
@@ -229,7 +221,7 @@ def test_table_fixture(make_table):
 
 
 @pytest.fixture(scope="session")
-def env(ws: ImprovedWorkspaceClient, acc: AccountClient, request: SubRequest) -> EnvironmentInfo:
+def env(ws: WorkspaceClient, acc: AccountClient, request: SubRequest) -> EnvironmentInfo:
     # prepare environment
     test_uid = f"{UCX_TESTING_PREFIX}_{str(uuid.uuid4())[:8]}"
     logger.debug(f"Creating environment with uid {test_uid}")
@@ -270,7 +262,7 @@ def _wrapped(*args, **kwargs):
 
 
 @pytest.fixture(scope="session")
-def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[InstanceProfile]:
+def instance_profiles(env: EnvironmentInfo, ws: WorkspaceClient) -> list[InstanceProfile]:
     logger.debug("Adding test instance profiles")
 
     profiles: list[InstanceProfile] = []
@@ -304,7 +296,7 @@ def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list
 
 
 @pytest.fixture(scope="session")
-def instance_pools(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateInstancePoolResponse]:
+def instance_pools(env: EnvironmentInfo, ws: WorkspaceClient) -> list[CreateInstancePoolResponse]:
     logger.debug("Creating test instance pools")
 
     test_instance_pools: list[CreateInstancePoolResponse] = [
@@ -329,7 +321,7 @@ def instance_pools(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[Cr
 
 
 @pytest.fixture(scope="session")
-def pipelines(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreatePipelineResponse]:
+def pipelines(env: EnvironmentInfo, ws: WorkspaceClient) -> list[CreatePipelineResponse]:
     logger.debug("Creating test DLT pipelines")
 
     test_pipelines: list[CreatePipelineResponse] = [
@@ -359,7 +351,7 @@ def pipelines(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateP
 
 
 @pytest.fixture(scope="session")
-def jobs(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateResponse]:
+def jobs(env: EnvironmentInfo, ws: WorkspaceClient) -> list[CreateResponse]:
     logger.debug("Creating test jobs")
 
     test_jobs: list[CreateResponse] = [
@@ -386,7 +378,7 @@ def jobs(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateRespon
 
 
 @pytest.fixture(scope="session")
-def cluster_policies(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreatePolicyResponse]:
+def cluster_policies(env: EnvironmentInfo, ws: WorkspaceClient) -> list[CreatePolicyResponse]:
     logger.debug("Creating test cluster policies")
 
     test_cluster_policies: list[CreatePolicyResponse] = [
@@ -421,7 +413,7 @@ def cluster_policies(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[
 
 
 @pytest.fixture(scope="session")
-def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterDetails]:
+def clusters(env: EnvironmentInfo, ws: WorkspaceClient) -> list[ClusterDetails]:
     logger.debug("Creating test clusters")
 
     creators = [
@@ -456,7 +448,7 @@ def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterD
 
 
 @pytest.fixture(scope="session")
-def experiments(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[CreateExperimentResponse]:
+def experiments(ws: WorkspaceClient, env: EnvironmentInfo) -> list[CreateExperimentResponse]:
     logger.debug("Creating test experiments")
 
     try:
@@ -489,7 +481,7 @@ def experiments(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[Creat
 
 
 @pytest.fixture(scope="session")
-def models(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[ModelDatabricks]:
+def models(ws: WorkspaceClient, env: EnvironmentInfo) -> list[ModelDatabricks]:
     logger.debug("Creating models")
 
     test_models: list[ModelDatabricks] = [
@@ -522,7 +514,7 @@ def models(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[ModelDatab
 
 
 @pytest.fixture(scope="session")
-def warehouses(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[GetWarehouseResponse]:
+def warehouses(ws: WorkspaceClient, env: EnvironmentInfo) -> list[GetWarehouseResponse]:
     logger.debug("Creating warehouses")
 
     creators = [
@@ -557,7 +549,7 @@ def warehouses(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[GetWar
 
 
 @pytest.fixture(scope="session")
-def tokens(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[AccessControlRequest]:
+def tokens(ws: WorkspaceClient, env: EnvironmentInfo) -> list[AccessControlRequest]:
     logger.debug("Adding token-level permissions to groups")
 
     token_permissions = [
@@ -575,7 +567,7 @@ def tokens(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[AccessCont
 
 
 @pytest.fixture(scope="session")
-def secret_scopes(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[SecretScope]:
+def secret_scopes(ws: WorkspaceClient, env: EnvironmentInfo) -> list[SecretScope]:
     logger.debug("Creating test secret scopes")
 
     for i in range(NUM_TEST_SECRET_SCOPES):
@@ -596,7 +588,7 @@ def secret_scopes(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[Sec
 
 
 @pytest.fixture(scope="session")
-def workspace_objects(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> WorkspaceObjects:
+def workspace_objects(ws: WorkspaceClient, env: EnvironmentInfo) -> WorkspaceObjects:
     logger.info(f"Creating test workspace objects under /{env.test_uid}")
     ws.workspace.mkdirs(f"/{env.test_uid}")
 
@@ -681,7 +673,7 @@ def verifiable_objects(
 
 
 @pytest.fixture()
-def inventory_table(env: EnvironmentInfo, ws: ImprovedWorkspaceClient, make_catalog, make_schema) -> InventoryTable:
+def inventory_table(env: EnvironmentInfo, ws: WorkspaceClient, make_catalog, make_schema) -> InventoryTable:
     catalog, schema = make_schema(make_catalog()).split(".")
     table = InventoryTable(
         catalog=catalog,
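With the local `ws` fixture deleted, the session-scoped client is presumably provided by `databricks.labs.ucx.providers.mixins.fixtures` (imported with `*`). A minimal equivalent of such a fixture, assuming unified-auth environment variables, would look like this sketch:

```python
import pytest
from databricks.sdk import WorkspaceClient


@pytest.fixture(scope="session")
def ws() -> WorkspaceClient:
    # Use variables from Unified Auth
    # See https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html
    return WorkspaceClient()
```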
diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py
index 923fa688db..29534bce15 100644
--- a/tests/integration/test_e2e.py
+++ b/tests/integration/test_e2e.py
@@ -2,6 +2,7 @@
 from typing import Literal
 
 import pytest
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.iam import (
     AccessControlRequest,
     AccessControlResponse,
@@ -19,7 +20,6 @@
     MigrationConfig,
 )
 from databricks.labs.ucx.inventory.types import RequestObjectType
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.providers.groups_info import GroupMigrationState
 from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit
 from databricks.labs.ucx.utils import safe_get_acls
@@ -33,7 +33,7 @@ def _verify_group_permissions(
     objects: list | WorkspaceObjects | None,
     id_attribute: str,
     request_object_type: RequestObjectType | None,
-    ws: ImprovedWorkspaceClient,
+    ws: WorkspaceClient,
     toolkit: GroupMigrationToolkit,
     target: Literal["backup", "account"],
 ):
@@ -149,7 +149,7 @@ def _verify_group_permissions(
 
 def _verify_roles_and_entitlements(
     migration_state: GroupMigrationState,
-    ws: ImprovedWorkspaceClient,
+    ws: WorkspaceClient,
     target: Literal["backup", "account"],
 ):
     for el in migration_state.groups:
@@ -166,7 +166,7 @@
 def test_e2e(
     env: EnvironmentInfo,
     inventory_table: InventoryTable,
-    ws: ImprovedWorkspaceClient,
+    ws: WorkspaceClient,
     verifiable_objects: list[tuple[list, str, RequestObjectType | None]],
 ):
     logger.debug(f"Test environment: {env.test_uid}")
@@ -192,7 +192,7 @@ def test_e2e(
         toolkit.group_manager.migration_groups_provider.groups
     )
 
-    assert len(ws.list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len(
+    assert len(toolkit.group_manager._list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len(
         toolkit.group_manager.migration_groups_provider.groups
     )
 
diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py
index c67b497c0e..e5cdfc1b07 100644
--- a/tests/integration/test_jobs.py
+++ b/tests/integration/test_jobs.py
@@ -1,6 +1,7 @@
 import logging
 
 import pytest
+from databricks.sdk import WorkspaceClient
 from pyspark.errors import AnalysisException
 
 from databricks.labs.ucx.config import (
@@ -11,7 +12,6 @@
     MigrationConfig,
 )
 from databricks.labs.ucx.inventory.types import RequestObjectType
-from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient
 from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit
 
 from .test_e2e import _verify_group_permissions, _verify_roles_and_entitlements
@@ -23,7 +23,7 @@
 def test_jobs(
     env: EnvironmentInfo,
     inventory_table: InventoryTable,
-    ws: ImprovedWorkspaceClient,
+    ws: WorkspaceClient,
     jobs,
 ):
     logger.debug(f"Test environment: {env.test_uid}")
@@ -49,7 +49,7 @@ def test_jobs(
         toolkit.group_manager.migration_groups_provider.groups
     )
 
-    assert len(ws.list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len(
+    assert len(toolkit.group_manager._list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len(
         toolkit.group_manager.migration_groups_provider.groups
     )
 
os.environ["TEST_DEFAULT_WAREHOUSE_ID"] logger.info("setting up fixtures") @@ -42,9 +43,7 @@ def test_describe_all_tables(ws: ImprovedWorkspaceClient, make_catalog, make_sch assert all_tables[view].view_text == "SELECT 2+2 AS four" -def test_all_grants_in_database( - ws: ImprovedWorkspaceClient, sql_exec, make_catalog, make_schema, make_table, make_group -): +def test_all_grants_in_database(ws: WorkspaceClient, sql_exec, make_catalog, make_schema, make_table, make_group): warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"] group_a = make_group() diff --git a/tests/integration/utils.py b/tests/integration/utils.py index e9cd656786..7d886a6aa1 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -17,7 +17,6 @@ from databricks.sdk.service.workspace import ObjectInfo from databricks.labs.ucx.inventory.types import RequestObjectType -from databricks.labs.ucx.providers.client import ImprovedWorkspaceClient from databricks.labs.ucx.utils import WorkspaceLevelEntitlement logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ def get_random_entitlements(): def _create_groups( - _ws: ImprovedWorkspaceClient, _acc: AccountClient, prefix: str, num_test_groups: int, threader: callable + _ws: WorkspaceClient, _acc: AccountClient, prefix: str, num_test_groups: int, threader: callable ) -> list[tuple[Group, Group]]: logger.debug("Listing users to create sample groups") test_users = list(_ws.users.list(filter="displayName sw 'test-user-'", attributes="id, userName, displayName")) @@ -93,7 +92,7 @@ def _set_random_permissions( id_attribute: str, request_object_type: RequestObjectType, env: EnvironmentInfo, - ws: ImprovedWorkspaceClient, + ws: WorkspaceClient, permission_levels: list[PermissionLevel], num_acls: int | None = 3, ): diff --git a/tests/unit/test_listing.py b/tests/unit/test_listing.py index a7b37c1898..36aded06b3 100644 --- a/tests/unit/test_listing.py +++ b/tests/unit/test_listing.py @@ -24,7 +24,7 @@ def test_list_and_analyze_should_separate_folders_and_other_objects(): notebook = ObjectInfo(path="/rootPath/notebook", object_type=ObjectType.NOTEBOOK) client = Mock() - client.list_workspace.return_value = [file, directory, notebook] + client.workspace.list.return_value = [file, directory, notebook] listing = WorkspaceListing(client, 1) directories, others = listing._list_and_analyze(rootobj) @@ -37,7 +37,7 @@ def test_walk_with_an_empty_folder_should_return_it(): rootobj = ObjectInfo(path="/rootPath") client = Mock() - client.list_workspace.return_value = [] + client.workspace.list.return_value = [] client.workspace.get_status.return_value = rootobj listing = WorkspaceListing(client, 1) @@ -53,7 +53,7 @@ def test_walk_with_two_files_should_return_rootpath_and_two_files(): notebook = ObjectInfo(path="/rootPath/notebook", object_type=ObjectType.NOTEBOOK) client = Mock() - client.list_workspace.return_value = [file, notebook] + client.workspace.list.return_value = [file, notebook] client.workspace.get_status.return_value = rootobj listing = WorkspaceListing(client, 1) @@ -69,14 +69,14 @@ def test_walk_with_nested_folders_should_return_nested_objects(): nested_folder = ObjectInfo(path="/rootPath/nested_folder", object_type=ObjectType.DIRECTORY) nested_notebook = ObjectInfo(path="/rootPath/nested_folder/notebook", object_type=ObjectType.NOTEBOOK) - def my_side_effect(*args): - if args[0] == "/rootPath": + def my_side_effect(path, **kwargs): # noqa: ARG001 + if path == "/rootPath": return [file, nested_folder] - elif args[0] == "/rootPath/nested_folder": + elif path == 
"/rootPath/nested_folder": return [nested_notebook] client = Mock() - client.list_workspace.side_effect = my_side_effect + client.workspace.list.side_effect = my_side_effect client.workspace.get_status.return_value = rootobj listing = WorkspaceListing(client, 1) @@ -98,16 +98,16 @@ def test_walk_with_three_level_nested_folders_returns_three_levels(): path="/rootPath/nested_folder/second_nested_folder/notebook2", object_type=ObjectType.NOTEBOOK ) - def my_side_effect(*args): - if args[0] == "/rootPath": + def my_side_effect(path, **kwargs): # noqa: ARG001 + if path == "/rootPath": return [file, nested_folder] - elif args[0] == "/rootPath/nested_folder": + elif path == "/rootPath/nested_folder": return [nested_notebook, second_nested_folder] - elif args[0] == "/rootPath/nested_folder/second_nested_folder": + elif path == "/rootPath/nested_folder/second_nested_folder": return [second_nested_notebook] client = Mock() - client.list_workspace.side_effect = my_side_effect + client.workspace.list.side_effect = my_side_effect client.workspace.get_status.return_value = rootobj listing = WorkspaceListing(client, 2) listing.walk("/rootPath")