Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Azure service principal crawler and fix bug where tenant_id inside secret scope is not detected #942

Merged
merged 17 commits into from
Feb 15, 2024
2 changes: 1 addition & 1 deletion docs/assessment.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ These are Global Init Scripts that are incompatible with Unity Catalog compute.
# Assessment Finding Index
This section will help explain UCX Assessment findings and provide a recommended action.
The assessment finding index is grouped by:
- The 100 serieds findings are Databricks Runtime and compute configuration findings
- The 100 series findings are Databricks Runtime and compute configuration findings.
- The 200 series findings are centered around data related observations.

### AF101 - not supported DBR: ##.#.x-scala2.12
Expand Down
322 changes: 153 additions & 169 deletions src/databricks/labs/ucx/assessment/azure.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions src/databricks/labs/ucx/assessment/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
)

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_INIT_SCRIPT_DBFS_PATH,
AZURE_SP_CONF_FAILURE_MSG,
INCOMPATIBLE_SPARK_CONFIG_KEYS,
_azure_sp_conf_present_check,
INIT_SCRIPT_DBFS_PATH,
azure_sp_conf_present_check,
spark_version_compatibility,
)
from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin
Expand Down Expand Up @@ -51,16 +51,16 @@ def _check_cluster_policy(self, policy_id: str, source: str) -> list[str]:
policy = self._safe_get_cluster_policy(policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
if len(init_script_info.dbfs.destination.split(":")) == INIT_SCRIPT_DBFS_PATH:
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
if file_api_format_destination:
try:
Expand Down Expand Up @@ -95,8 +95,8 @@ def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
failures.append(f"using DBFS mount in configuration: {value}")
# Checking if Azure cluster config is present in spark config
if _azure_sp_conf_present_check(conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if azure_sp_conf_present_check(conf):
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
Expand Down
22 changes: 11 additions & 11 deletions src/databricks/labs/ucx/assessment/crawlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,31 @@
"spark.databricks.hive.metastore.glueCatalog.enabled",
]

_AZURE_SP_CONF = [
AZURE_SP_CONF = [
"fs.azure.account.auth.type",
"fs.azure.account.oauth.provider.type",
"fs.azure.account.oauth2.client.id",
"fs.azure.account.oauth2.client.secret",
"fs.azure.account.oauth2.client.endpoint",
]
_SECRET_PATTERN = r"{{(secrets.*?)}}"
_STORAGE_ACCOUNT_EXTRACT_PATTERN = r"(?:id|endpoint)(.*?)dfs"
_AZURE_SP_CONF_FAILURE_MSG = "Uses azure service principal credentials config in"
_SECRET_LIST_LENGTH = 3
_CLIENT_ENDPOINT_LENGTH = 6
_INIT_SCRIPT_DBFS_PATH = 2
SECRET_PATTERN = r"{{(secrets.*?)}}"
STORAGE_ACCOUNT_EXTRACT_PATTERN = r"(?:id|endpoint)(.*?)dfs"
AZURE_SP_CONF_FAILURE_MSG = "Uses azure service principal credentials config in"
SECRET_LIST_LENGTH = 3
CLIENT_ENDPOINT_LENGTH = 6
INIT_SCRIPT_DBFS_PATH = 2


def _azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
for conf in _AZURE_SP_CONF:
def azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
for conf in AZURE_SP_CONF:
if re.search(conf, init_script_data):
return True
return False


def _azure_sp_conf_present_check(config: dict) -> bool:
def azure_sp_conf_present_check(config: dict) -> bool:
for key in config.keys():
for conf in _AZURE_SP_CONF:
for conf in AZURE_SP_CONF:
if re.search(conf, key):
return True
return False
Expand Down
8 changes: 4 additions & 4 deletions src/databricks/labs/ucx/assessment/init_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from databricks.sdk.errors import ResourceDoesNotExist

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_in_init_scripts,
AZURE_SP_CONF_FAILURE_MSG,
azure_sp_conf_in_init_scripts,
)
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

Expand All @@ -33,8 +33,8 @@ def check_init_script(self, init_script_data: str | None, source: str) -> list[s
failures: list[str] = []
if not init_script_data:
return failures
if _azure_sp_conf_in_init_scripts(init_script_data):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if azure_sp_conf_in_init_scripts(init_script_data):
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


Expand Down
81 changes: 41 additions & 40 deletions tests/unit/assessment/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from unittest.mock import Mock

from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
from databricks.labs.ucx.assessment.azure import (
AzureServicePrincipalCrawler,
generate_service_principals,
)

from ..framework.mocks import MockBackend
from . import workspace_client_mock
Expand All @@ -9,8 +12,8 @@
def test_azure_spn_info_without_secret():
ws = workspace_client_mock(clusters="single-cluster-spn.json")
sample_spns = [{"application_id": "test123456789", "secret_scope": "", "secret_key": ""}]
AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._assess_service_principals(sample_spns)
AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
crawler = generate_service_principals(sample_spns)
result_set = list(crawler)

assert len(result_set) == 1
Expand All @@ -25,7 +28,7 @@ def test_azure_service_principal_info_crawl():
warehouse_config="spn-config.json",
secret_exists=True,
)
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 5

Expand All @@ -38,7 +41,7 @@ def test_azure_service_principal_info_spark_conf_crawl():
warehouse_config="spn-config.json",
)

spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 3

Expand All @@ -51,14 +54,14 @@ def test_azure_service_principal_info_no_spark_conf_crawl():
warehouse_config="single-config.json",
)

spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 0


def test_azure_service_principal_info_policy_family_conf_crawl(mocker):
ws = workspace_client_mock()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 0

Expand All @@ -67,14 +70,13 @@ def test_azure_service_principal_info_null_applid_crawl():
ws = workspace_client_mock(
clusters="single-cluster-spn-with-policy.json", pipelines="single-pipeline.json", jobs="single-job.json"
)
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
assert len(spn_crawler) == 0


def test_azure_spn_info_with_secret():
ws = workspace_client_mock(clusters="single-cluster-spn.json", secret_exists=True)
sample_spns = [{"application_id": "test123456780", "secret_scope": "abcff", "secret_key": "sp_app_client_id"}]
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._assess_service_principals(sample_spns)
crawler = generate_service_principals(sample_spns)
result_set = list(crawler)

assert len(result_set) == 1
Expand Down Expand Up @@ -120,37 +122,36 @@ def test_spn_with_spark_config_snapshot():

def test_list_all_cluster_with_spn_in_spark_conf_with_secret():
ws = workspace_client_mock(clusters="single-cluster-spn.json")
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
result_set = list(crawler)
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 1


def test_list_all_wh_config_with_spn_no_secret():
ws = workspace_client_mock(warehouse_config="spn-config.json")
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_spn_in_sql_warehouses_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 2
assert result_set[0].get("application_id") == "dummy_application_id"
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
assert result_set[0].get("storage_account") == "storage_acct2"
assert any(_ for _ in result_set if _.application_id == "dummy_application_id")
assert any(_ for _ in result_set if _.tenant_id == "dummy_tenant_id")
assert any(_ for _ in result_set if _.storage_account == "storage_acct2")


def test_list_all_wh_config_with_spn_and_secret():
ws = workspace_client_mock(warehouse_config="spn-secret-config.json", secret_exists=True)
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_spn_in_sql_warehouses_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 2
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
assert result_set[0].get("storage_account") == "abcde"
assert any(_ for _ in result_set if _.tenant_id == "dummy_tenant_id")
assert any(_ for _ in result_set if _.storage_account == "abcde")


def test_list_all_clusters_spn_in_spark_conf_with_tenant():
ws = workspace_client_mock(clusters="single-cluster-spn.json", secret_exists=True)
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 1
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
assert result_set[0].tenant_id == "dummy_tenant_id"


def test_azure_service_principal_info_policy_conf():
Expand All @@ -161,7 +162,7 @@ def test_azure_service_principal_info_policy_conf():
warehouse_config="spn-config.json",
secret_exists=True,
)
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 4

Expand All @@ -174,48 +175,48 @@ def test_azure_service_principal_info_dedupe():
warehouse_config="dupe-spn-config.json",
secret_exists=True,
)
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 2


def test_list_all_pipeline_with_conf_spn_in_spark_conf():
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json")
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 1
assert result_set[0].get("storage_account") == "newstorageacct"
assert result_set[0].get("tenant_id") == "directory_12345"
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
assert result_set[0].storage_account == "newstorageacct"
assert result_set[0].tenant_id == "directory_12345"
assert result_set[0].application_id == "pipeline_dummy_application_id"


def test_list_all_pipeline_wo_conf_spn_in_spark_conf():
ws = workspace_client_mock(pipelines="single-pipeline.json")
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 0


def test_list_all_pipeline_with_conf_spn_tenant():
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json")
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 1
assert result_set[0].get("storage_account") == "newstorageacct"
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
assert result_set[0].storage_account == "newstorageacct"
assert result_set[0].application_id == "pipeline_dummy_application_id"


def test_list_all_pipeline_with_conf_spn_secret():
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json", secret_exists=True)
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) == 1
assert result_set[0].get("storage_account") == "newstorageacct"
assert result_set[0].storage_account == "newstorageacct"


def test_azure_service_principal_info_policy_family():
ws = workspace_client_mock(clusters="single-cluster-spn-with-policy-overrides.json")
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(spn_crawler) == 1
assert spn_crawler[0].application_id == "dummy_appl_id"
Expand All @@ -225,19 +226,19 @@ def test_azure_service_principal_info_policy_family():
def test_list_all_pipeline_with_conf_spn_secret_unavlbl():
ws = workspace_client_mock(pipelines="single-pipeline.json", secret_exists=False)
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")
result_set = crawler._list_all_pipeline_with_spn_in_spark_conf()
result_set = crawler.snapshot()

assert len(result_set) == 0


def test_list_all_pipeline_with_conf_spn_secret_avlb():
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json", secret_exists=True)
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()

assert len(result_set) > 0
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
assert result_set[0].get("tenant_id") == "directory_12345"
assert result_set[0].get("storage_account") == "newstorageacct"
assert result_set[0].application_id == "pipeline_dummy_application_id"
assert result_set[0].tenant_id == "directory_12345"
assert result_set[0].storage_account == "newstorageacct"


def test_azure_spn_info_with_secret_unavailable():
Expand All @@ -251,6 +252,6 @@ def test_azure_spn_info_with_secret_unavailable():
"spark.hadoop.fs.azure.account."
"oauth2.client.secret.abcde.dfs.core.windows.net": "{{secrets/abcff/sp_secret}}",
}
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._get_azure_spn_list(spark_conf)
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._get_azure_spn_from_config(spark_conf)

assert crawler == []
Loading
Loading