Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New CLI command for workspace mapping #678

Merged
merged 8 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions labs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ commands:
{{range .}}{{.step}}\t{{.state}}\t{{.started}}
{{end}}


- name: installations
description: Show installations by different users on the same workspace
table_template: |-
User\tDatabase\tWarehouse
{{range .}}{{.user_name}}\t{{.database}}\t{{.warehouse_id}}
{{end}}

- name: sync-workspace-info
is_account_level: true
description: upload workspace config to all workspaces in the account where ucx is installed
161 changes: 52 additions & 109 deletions src/databricks/labs/ucx/account/workspaces.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import base64
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import ClassVar

import requests
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import AzureCliTokenSource, Config, DatabricksError
from databricks.sdk.service.provisioning import PricingTier, Workspace
from databricks.sdk.core import DatabricksError
from databricks.sdk.errors import NotFound
from databricks.sdk.service.provisioning import Workspace
from databricks.sdk.service.workspace import ImportFormat
from requests.exceptions import ConnectionError

from databricks.labs.ucx.__about__ import __version__
Expand All @@ -17,86 +16,6 @@
logger = logging.getLogger(__name__)


@dataclass
class AzureSubscription:
name: str
subscription_id: str
tenant_id: str


class AzureWorkspaceLister:
def __init__(self, cfg: Config):
endpoint = cfg.arm_environment.resource_manager_endpoint
self._token_source = AzureCliTokenSource(endpoint)
self._endpoint = endpoint

def _get(self, path: str, *, api_version=None) -> dict:
token = self._token_source.token()
headers = {"Authorization": f"{token.token_type} {token.access_token}"}
return requests.get(
self._endpoint + path, headers=headers, params={"api-version": api_version}, timeout=10
).json()

def _all_subscriptions(self):
for sub in self._get("/subscriptions", api_version="2022-12-01").get("value", []):
yield AzureSubscription(
name=sub["displayName"], subscription_id=sub["subscriptionId"], tenant_id=sub["tenantId"]
)

def _tenant_id(self):
token = self._token_source.token()
_, payload, _ = token.access_token.split(".")
b64_decoded = base64.standard_b64decode(payload + "==").decode("utf8")
claims = json.loads(b64_decoded)
return claims["tid"]

def current_tenant_subscriptions(self):
tenant_id = self._tenant_id()
for sub in self._all_subscriptions():
if sub.tenant_id != tenant_id:
continue
yield sub

def subscriptions_name_to_id(self):
return {sub.name: sub.subscription_id for sub in self.current_tenant_subscriptions()}

def list_workspaces(self, subscription_id):
endpoint = f"/subscriptions/{subscription_id}/providers/Microsoft.Databricks/workspaces"
sku_tiers = {
"premium": PricingTier.PREMIUM,
"enterprise": PricingTier.ENTERPRISE,
"standard": PricingTier.STANDARD,
"unknown": PricingTier.UNKNOWN,
}
items = self._get(endpoint, api_version="2023-02-01").get("value", [])
for item in sorted(items, key=lambda _: _["name"].lower()):
properties = item["properties"]
if properties["provisioningState"] != "Succeeded":
continue
if "workspaceUrl" not in properties:
continue
parameters = properties.get("parameters", {})
workspace_url = properties["workspaceUrl"]
tags = item.get("tags", {})
if "AzureSubscriptionID" not in tags:
tags["AzureSubscriptionID"] = subscription_id
if "AzureResourceGroup" not in tags:
tags["AzureResourceGroup"] = item["id"].split("resourceGroups/")[1].split("/")[0]
yield Workspace(
cloud="azure",
location=item["location"],
workspace_name=item["name"],
workspace_id=int(properties["workspaceId"]),
workspace_status_message=properties["provisioningState"],
deployment_name=workspace_url.replace(".azuredatabricks.net", ""),
pricing_tier=sku_tiers.get(item.get("sku", {"name": None})["name"], None),
# These fields are just approximation for the fields with same meaning in AWS and GCP
storage_configuration_id=parameters.get("storageAccountName", {"value": None})["value"],
network_id=parameters.get("customVirtualNetworkId", {"value": None})["value"],
custom_tags=tags,
)


class Workspaces:
_tlds: ClassVar[dict[str, str]] = {
"aws": "cloud.databricks.com",
Expand All @@ -119,36 +38,60 @@ def configured_workspaces(self):
continue
yield workspace

def _get_cloud(self) -> str:
if self._ac.config.is_azure:
return "azure"
elif self._ac.config.is_gcp:
return "gcp"
return "aws"

def client_for(self, workspace: Workspace) -> WorkspaceClient:
config = self._ac.config.as_dict()
if "databricks_cli_path" in config:
del config["databricks_cli_path"]
# copy current config and swap with a host relevant to a workspace
config["host"] = f"https://{workspace.deployment_name}.{self._tlds[workspace.cloud]}"
config["host"] = f"https://{workspace.deployment_name}.{self._tlds[self._get_cloud()]}"
return WorkspaceClient(**config, product="ucx", product_version=__version__)

def _all_workspaces(self):
if self._ac.config.is_azure:
yield from self._azure_workspaces()
else:
yield from self._native_workspaces()

def _native_workspaces(self):
yield from self._ac.workspaces.list()

def _azure_workspaces(self):
azure_lister = AzureWorkspaceLister(self._ac.config)
for sub in azure_lister.current_tenant_subscriptions():
if self._cfg.include_azure_subscription_ids:
if sub.subscription_id not in self._cfg.include_azure_subscription_ids:
logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included")
continue
if self._cfg.include_azure_subscription_names:
if sub.name not in self._cfg.include_azure_subscription_names:
logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included")
def workspace_clients(self) -> list[WorkspaceClient]:
"""
Return a list of WorkspaceClient for each configured workspace in the account
:return: list[WorkspaceClient]
"""
clients = []
for workspace in self.configured_workspaces():
ws = self.client_for(workspace)
clients.append(ws)
return clients

def sync_workspace_info(self):
"""
Create a json dump for each Workspace in account
For each user that has ucx installed in their workspace,
upload the json dump of workspace info in the .ucx folder
:return:
"""
workspaces = []
for workspace in self._ac.workspaces.list():
workspaces.append(workspace.as_dict())
workspaces_json = json.dumps(workspaces, indent=2).encode("utf8")

workspaces_in_account = Workspaces(self._cfg)
for ws in workspaces_in_account.workspace_clients():
for user in ws.users.list(attributes="userName"):
try:
potential_install = f"/Users/{user.user_name}/.ucx"
ws.workspace.upload(
f"{potential_install}/workspaces.json",
workspaces_json,
overwrite=True,
format=ImportFormat.AUTO,
)
except NotFound:
continue
for workspace in azure_lister.list_workspaces(sub.subscription_id):
if "AzureSubscription" not in workspace.custom_tags:
workspace.custom_tags["AzureSubscription"] = sub.name
yield workspace

def _all_workspaces(self):
return self._ac.workspaces.list()


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions src/databricks/labs/ucx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.account.workspaces import Workspaces
from databricks.labs.ucx.config import AccountConfig
from databricks.labs.ucx.install import WorkspaceInstaller
from databricks.labs.ucx.installer import InstallationManager

Expand Down Expand Up @@ -34,10 +36,20 @@ def list_installations():
print(json.dumps(all_users))


def sync_workspace_info():
"""
Cli function to upload a mapping file to each ucx installation folder
:return:
"""
workspaces = Workspaces(AccountConfig())
workspaces.sync_workspace_info()


MAPPING = {
"open-remote-config": open_remote_config,
"installations": list_installations,
"workflows": workflows,
"sync-workspace-info": sync_workspace_info,
}


Expand Down
124 changes: 42 additions & 82 deletions tests/unit/account/test_workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import json

import pytest
from databricks.sdk.core import Config
from databricks.sdk.oauth import Token
from databricks.sdk.service.provisioning import PricingTier, Workspace
from databricks.sdk.service.iam import User
from databricks.sdk.service.provisioning import Workspace

from databricks.labs.ucx.account.workspaces import AzureWorkspaceLister, Workspaces
from databricks.labs.ucx.account.workspaces import Workspaces
from databricks.labs.ucx.config import AccountConfig, ConnectConfig


Expand Down Expand Up @@ -34,95 +34,35 @@ def mock_response(endpoint, **kwargs):
return inner


def test_subscriptions_name_to_id(arm_requests):
arm_requests(
{
"/subscriptions": {
"value": [
{"displayName": "first", "subscriptionId": "001", "tenantId": "xxx"},
{"displayName": "second", "subscriptionId": "002", "tenantId": "def_from_token"},
{"displayName": "third", "subscriptionId": "003", "tenantId": "def_from_token"},
]
}
}
)
cfg = Config(host="https://accounts.azuredatabricks.net", auth_type="azure-cli")

awl = AzureWorkspaceLister(cfg)
subs = awl.subscriptions_name_to_id()

assert {"second": "002", "third": "003"} == subs


def test_list_azure_workspaces(arm_requests):
arm_requests(
{
"/subscriptions": {
"value": [
{"displayName": "first", "subscriptionId": "001", "tenantId": "xxx"},
{"displayName": "second", "subscriptionId": "002", "tenantId": "def_from_token"},
{"displayName": "third", "subscriptionId": "003", "tenantId": "def_from_token"},
]
},
"/subscriptions/002/providers/Microsoft.Databricks/workspaces": {
"value": [
{
"id": ".../resourceGroups/first-rg/...",
"name": "first-workspace",
"location": "eastus",
"sku": {"name": "premium"},
"properties": {
"provisioningState": "Succeeded",
"workspaceUrl": "adb-123.10.azuredatabricks.net",
"workspaceId": "123",
},
},
{
"id": ".../resourceGroups/first-rg/...",
"name": "second-workspace",
"location": "eastus",
"sku": {"name": "premium"},
"properties": {
"provisioningState": "Succeeded",
"workspaceUrl": "adb-123.10.azuredatabricks.net",
"workspaceId": "123",
},
},
]
},
}
)
@pytest.fixture()
def account_workspaces_mock(mocker, arm_requests):
acc_cfg = AccountConfig()
acc_client = mocker.patch("databricks.sdk.AccountClient.__init__")
acc_cfg.to_databricks_config = lambda: acc_client
acc_client.config = mocker.Mock()
acc_client.config.as_dict = lambda: {}

wrksp = Workspaces(
AccountConfig(
connect=ConnectConfig(host="https://accounts.azuredatabricks.net"),
include_workspace_names=["first-workspace"],
include_azure_subscription_names=["second"],
)
acc_cfg.to_account_client = lambda: acc_client
acc_cfg.include_workspace_names = ["foo", "bar"]
mock_workspace1 = Workspace(
workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"
)
mock_workspace2 = Workspace(
workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"
)

all_workspaces = list(wrksp.configured_workspaces())

assert [
Workspace(
cloud="azure",
location="eastus",
workspace_id=123,
pricing_tier=PricingTier.PREMIUM,
workspace_name="first-workspace",
deployment_name="adb-123.10",
workspace_status_message="Succeeded",
custom_tags={"AzureResourceGroup": "first-rg", "AzureSubscription": "second", "AzureSubscriptionID": "002"},
)
] == all_workspaces
mock_user1 = User(user_name="jack")
acc_client.workspaces.users.list.return_value = [mock_user1]
acc_client.workspaces.list.return_value = [mock_workspace1, mock_workspace2]
return Workspaces(acc_cfg)


def test_client_for_workspace():
wrksp = Workspaces(
AccountConfig(
connect=ConnectConfig(
host="https://accounts.azuredatabricks.net",
azure_tenant_id="abc",
azure_tenant_id="abc.com",
azure_client_id="bcd",
azure_client_secret="def",
)
Expand All @@ -131,3 +71,23 @@ def test_client_for_workspace():
specified_workspace_client = wrksp.client_for(Workspace(cloud="azure", deployment_name="adb-123.10"))
assert "azure-client-secret" == specified_workspace_client.config.auth_type
assert "https://adb-123.10.azuredatabricks.net" == specified_workspace_client.config.host


def test_workspace_clients(account_workspaces_mock):
ws_clients = account_workspaces_mock.workspace_clients()
assert len(ws_clients) == 2
assert ws_clients[0].config.auth_type == "azure-cli"
assert ws_clients[0].config.host == "https://abc.azuredatabricks.net"


def test_configured_workspaces(account_workspaces_mock):
ws_clients = []
for ws in account_workspaces_mock.configured_workspaces():
ws_clients.append(account_workspaces_mock.client_for(ws))

# test for number of workspaces returned
assert len(ws_clients) == 2

# test for cloud and deployment name
assert ws_clients[1].config.auth_type == "azure-cli"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this auth type is failing tests, don't test for it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the mock back seems to allow the unit test to pass.

assert ws_clients[1].config.host == "https://def.azuredatabricks.net"
Loading