From 1a3e28f9b980f513b79938efc7c05a14b58a7690 Mon Sep 17 00:00:00 2001 From: Tim Green Date: Thu, 6 Oct 2022 13:14:55 -0500 Subject: [PATCH 1/4] files that are larger than the 1MB limit allowed in the dbfs put API v2.0 will use the streaming api calls to upload the content and avoid the http 400 MAX_BLOCK_SIZE_EXCEEDED error --- dbx/sync/clients.py | 53 +++++++++-- tests/unit/sync/clients/conftest.py | 9 ++ tests/unit/sync/clients/test_dbfs_client.py | 96 ++++++++++++++------ tests/unit/sync/clients/test_repos_client.py | 18 ++-- 4 files changed, 131 insertions(+), 45 deletions(-) diff --git a/dbx/sync/clients.py b/dbx/sync/clients.py index f02e6852..765f5adc 100644 --- a/dbx/sync/clients.py +++ b/dbx/sync/clients.py @@ -97,7 +97,7 @@ async def _api( more_opts = {"ssl": ssl} if ssl is not None else {} async with session.post(url=url, json=json_data, headers=headers, **more_opts) as resp: if resp.status in ok_status: - break + return await resp.json() if resp.status == 429: dbx_echo("Rate limited") await _rate_limit_sleep(resp) @@ -196,14 +196,49 @@ async def put( path = f"{self.base_path}/{sub_path}" with open(full_source_path, "rb") as f: contents = base64.b64encode(f.read()).decode("ascii") - await self._api_put( - api_base_path=self.api_base_path, - path=path, - session=session, - api_token=self.api_token, - contents=contents, - ssl=self.ssl, - ) + + if len(contents) <= 1024 * 1024: + await self._api_put( + api_base_path=self.api_base_path, + path=path, + session=session, + api_token=self.api_token, + contents=contents, + ssl=self.ssl, + ) + else: + dbx_echo(f"Streaming {path}") + + resp = await self._api( + url=f"{self.api_base_path}/create", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + overwrite=True, + ) + handle = resp.get("handle") + import textwrap + + chunks = textwrap.wrap(contents, 1024 * 1024) + for chunk in chunks: + await self._api( + url=f"{self.api_base_path}/add-block", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + handle=handle, + data=chunk, + ) + await self._api( + url=f"{self.api_base_path}/close", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + handle=handle, + ) class ReposClient(BaseClient): diff --git a/tests/unit/sync/clients/conftest.py b/tests/unit/sync/clients/conftest.py index 03d55563..28f00ab4 100644 --- a/tests/unit/sync/clients/conftest.py +++ b/tests/unit/sync/clients/conftest.py @@ -18,3 +18,12 @@ def dummy_file_path() -> str: with open(file_path, "w") as f: f.write("yo") yield file_path + + +@pytest.fixture +def dummy_file_path_2mb() -> str: + with temporary_directory() as tempdir: + file_path = os.path.join(tempdir, "file") + with open(file_path, "w") as f: + f.write("y" * 1024 * 2048) + yield file_path diff --git a/tests/unit/sync/clients/test_dbfs_client.py b/tests/unit/sync/clients/test_dbfs_client.py index 088f3f86..c155ff74 100644 --- a/tests/unit/sync/clients/test_dbfs_client.py +++ b/tests/unit/sync/clients/test_dbfs_client.py @@ -1,6 +1,8 @@ import asyncio import base64 -from unittest.mock import AsyncMock, MagicMock, PropertyMock +import textwrap +from tests.unit.sync.utils import create_async_with_result +from unittest.mock import AsyncMock, MagicMock, PropertyMock, call import pytest @@ -21,7 +23,7 @@ def test_init(client): def test_delete(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) 
asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -38,7 +40,7 @@ def test_delete_secure(client: DBFSClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=False) client = DBFSClient(base_path="/tmp/foo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -49,21 +51,6 @@ def test_delete_secure(client: DBFSClient): assert session.post.call_args[1]["ssl"] is True -def test_delete_secure(client: DBFSClient): - mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=True) - client = DBFSClient(base_path="/tmp/foo", config=mock_config) - session = MagicMock() - resp = MagicMock() - setattr(type(resp), "status", PropertyMock(return_value=200)) - session.post.return_value = create_async_with_result(resp) - asyncio.run(client.delete(sub_path="foo/bar", session=session)) - - assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" - assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} - assert session.post.call_args[1]["ssl"] is False - - def test_delete_backslash(client: DBFSClient): session = MagicMock() resp = MagicMock() @@ -81,7 +68,7 @@ def test_delete_no_path(client: DBFSClient): def test_delete_recursive(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) @@ -97,7 +84,7 @@ def test_delete_rate_limited(client: DBFSClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -117,7 +104,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -145,7 +132,7 @@ def test_delete_unauthorized(client: DBFSClient): def test_mkdirs(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) @@ -178,7 +165,7 @@ def test_mkdirs_rate_limited(client: DBFSClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -198,7 +185,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient): 
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -226,7 +213,7 @@ def test_mkdirs_unauthorized(client: DBFSClient): def test_put(client: DBFSClient, dummy_file_path: str): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) @@ -243,6 +230,61 @@ def test_put(client: DBFSClient, dummy_file_path: str): assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"]) +def test_put_max_block_size_exceeded(client: DBFSClient, dummy_file_path_2mb: str): + expected_handle = 1234 + + async def mock_json(*args, **kwargs): + return {"handle": expected_handle} + + def mock_post(url, *args, **kwargs): + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=200)) + if "/base/api/2.0/dbfs/put" in url: + contents = kwargs.get("json").get("contents") + if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed + setattr(type(resp), "status", PropertyMock(return_value=400)) + elif "/base/api/2.0/dbfs/create" in url: + # return a mock response json + resp.json = MagicMock(side_effect=mock_json) + + return create_async_with_result(resp) + + session = AsyncMock() + post = MagicMock(side_effect=mock_post) + session.post = post + + asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path_2mb, session=session)) + + with open(dummy_file_path_2mb, "r") as f: + expected_contents = f.read() + + chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024) + + assert session.post.call_count == len(chunks) + 2 + assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/create" + assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" + assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" + assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" + assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/close" + + assert session.post.call_args_list[0][1]["json"] == { + "path": "dbfs:/tmp/foo/foo/bar", + "overwrite": True, + } + + for i, chunk in enumerate(chunks): + assert session.post.call_args_list[i + 1][1]["json"] == { + "data": chunk, + "path": "dbfs:/tmp/foo/foo/bar", + "handle": expected_handle, + }, f"invalid json for chunk {i}" + + assert session.post.call_args_list[4][1]["json"] == { + "path": "dbfs:/tmp/foo/foo/bar", + "handle": expected_handle, + } + + def test_put_backslash(client: DBFSClient, dummy_file_path: str): session = MagicMock() resp = MagicMock() @@ -266,7 +308,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -290,7 +332,7 @@ 
def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] diff --git a/tests/unit/sync/clients/test_repos_client.py b/tests/unit/sync/clients/test_repos_client.py index 5accf60d..456985c9 100644 --- a/tests/unit/sync/clients/test_repos_client.py +++ b/tests/unit/sync/clients/test_repos_client.py @@ -31,7 +31,7 @@ def test_init(mock_config): def test_delete(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -48,7 +48,7 @@ def test_delete_secure(client: ReposClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=False) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -63,7 +63,7 @@ def test_delete_insecure(client: ReposClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=True) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -91,7 +91,7 @@ def test_delete_no_path(client: ReposClient): def test_delete_recursive(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) @@ -107,7 +107,7 @@ def test_delete_rate_limited(client: ReposClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -127,7 +127,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -155,7 +155,7 @@ def test_delete_unauthorized(client: ReposClient): def test_mkdirs(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) 
session.post.return_value = create_async_with_result(resp) asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) @@ -188,7 +188,7 @@ def test_mkdirs_rate_limited(client: ReposClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -208,7 +208,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] From 4a17b4fe812d2d1390e206b66c8c42095cf86df8 Mon Sep 17 00:00:00 2001 From: Tim Green Date: Tue, 13 Dec 2022 10:15:02 -0600 Subject: [PATCH 2/4] merge with upstream changes --- .github/workflows/onpush.yml | 4 +- .github/workflows/onrelease.yml | 4 +- .gitignore | 1 + CHANGELOG.md | 256 +++++++++++---- Makefile | 4 +- dbx/__init__.py | 2 +- .../parameters => api/adjuster}/__init__.py | 0 dbx/api/adjuster/adjuster.py | 189 +++++++++++ dbx/api/adjuster/mixins/__init__.py | 0 dbx/api/adjuster/mixins/base.py | 32 ++ dbx/api/adjuster/mixins/existing_cluster.py | 43 +++ dbx/api/adjuster/mixins/file_reference.py | 13 + dbx/api/adjuster/mixins/instance_pool.py | 46 +++ dbx/api/adjuster/mixins/instance_profile.py | 46 +++ dbx/api/adjuster/mixins/pipeline.py | 10 + dbx/api/adjuster/mixins/service_principal.py | 38 +++ dbx/api/adjuster/mixins/sql_properties.py | 120 +++++++ dbx/api/adjuster/policy.py | 156 +++++++++ dbx/api/build.py | 31 -- dbx/api/cluster.py | 9 +- dbx/api/config_reader.py | 35 ++- dbx/api/configure.py | 17 +- dbx/api/context.py | 11 +- dbx/api/dependency/__init__.py | 0 dbx/api/dependency/core_package.py | 42 +++ dbx/api/dependency/requirements.py | 36 +++ dbx/api/deployment.py | 41 +++ dbx/api/destroyer.py | 37 ++- dbx/api/execute.py | 56 ++-- dbx/api/launch/functions.py | 31 +- dbx/api/launch/pipeline_models.py | 53 ++++ dbx/api/launch/processors.py | 29 +- dbx/api/launch/runners.py | 191 ----------- dbx/api/launch/runners/__init__.py | 0 dbx/api/launch/runners/asset_based.py | 83 +++++ dbx/api/launch/runners/base.py | 12 + dbx/api/launch/runners/pipeline.py | 57 ++++ dbx/api/launch/runners/standard.py | 65 ++++ dbx/api/launch/tracer.py | 34 +- dbx/api/output_provider.py | 2 +- dbx/api/services/__init__.py | 0 dbx/api/services/_base.py | 23 ++ dbx/api/services/jobs.py | 86 +++++ dbx/api/services/permissions.py | 23 ++ dbx/api/services/pipelines.py | 75 +++++ dbx/api/storage/io.py | 29 ++ dbx/api/storage/mlflow_based.py | 15 +- dbx/callbacks.py | 8 +- dbx/cli.py | 6 +- dbx/commands/configure.py | 20 +- dbx/commands/deploy.py | 219 ++++--------- dbx/commands/destroy.py | 35 +-- dbx/commands/execute.py | 137 +++----- dbx/commands/launch.py | 166 ++++++---- dbx/commands/sync/options.py | 2 +- dbx/commands/sync/sync.py | 16 + dbx/constants.py | 3 + dbx/models/build.py | 54 ++++ dbx/models/cli/__init__.py | 0 dbx/models/cli/destroyer.py | 20 ++ dbx/models/cli/execute.py | 22 ++ dbx/models/{ => cli}/options.py | 0 dbx/models/deployment.py | 156 ++++++--- dbx/models/destroyer.py | 32 -- 
dbx/models/files/__init__.py | 0 dbx/models/{ => files}/context.py | 0 dbx/models/{ => files}/project.py | 1 + dbx/models/job_clusters.py | 34 -- dbx/models/parameters/common.py | 21 -- dbx/models/parameters/execute.py | 29 -- dbx/models/parameters/run_now.py | 27 -- dbx/models/parameters/run_submit.py | 132 -------- dbx/models/task.py | 71 ----- dbx/models/validators.py | 73 +++++ dbx/models/workflow/__init__.py | 0 dbx/models/workflow/common/__init__.py | 0 dbx/models/workflow/common/access_control.py | 40 +++ .../workflow/common/deployment_config.py | 7 + dbx/models/workflow/common/flexible.py | 25 ++ .../common/job_email_notifications.py | 10 + dbx/models/workflow/common/libraries.py | 37 +++ dbx/models/workflow/common/new_cluster.py | 61 ++++ dbx/models/workflow/common/parameters.py | 32 ++ dbx/models/workflow/common/pipeline.py | 45 +++ dbx/models/workflow/common/task.py | 100 ++++++ dbx/models/workflow/common/task_type.py | 18 ++ dbx/models/workflow/common/workflow.py | 29 ++ dbx/models/workflow/common/workflow_types.py | 7 + dbx/models/workflow/v2dot0/__init__.py | 0 dbx/models/workflow/v2dot0/parameters.py | 16 + dbx/models/workflow/v2dot0/task.py | 26 ++ dbx/models/workflow/v2dot0/workflow.py | 54 ++++ dbx/models/workflow/v2dot1/__init__.py | 0 dbx/models/workflow/v2dot1/_parameters.py | 19 ++ dbx/models/workflow/v2dot1/job_cluster.py | 39 +++ .../workflow/v2dot1/job_task_settings.py | 28 ++ dbx/models/workflow/v2dot1/parameters.py | 32 ++ dbx/models/workflow/v2dot1/task.py | 86 +++++ dbx/models/workflow/v2dot1/workflow.py | 71 +++++ dbx/options.py | 46 +-- dbx/sync/__init__.py | 4 +- dbx/sync/clients.py | 32 +- .../render/hooks/post_gen_project.py | 9 +- .../{{cookiecutter.project_name}}/README.md | 2 +- .../{{cookiecutter.project_name}}/setup.py | 2 +- .../tests/entrypoint.py | 6 +- dbx/types.py | 6 + dbx/utils/__init__.py | 3 + dbx/utils/adjuster.py | 96 ------ dbx/utils/common.py | 14 - dbx/utils/dependency_manager.py | 88 ------ dbx/utils/file_uploader.py | 55 ++-- dbx/utils/job_listing.py | 26 -- dbx/utils/named_properties.py | 162 ---------- dbx/utils/policy_parser.py | 69 ---- dbx/utils/url.py | 12 + docs/concepts/artifact_storage.md | 74 +++++ docs/concepts/cluster_types.md | 2 +- docs/custom/custom.css | 3 - docs/extras/styles.css | 181 +++++++++++ docs/faq.md | 23 +- docs/features/file_references.md | 32 +- docs/features/named_properties.md | 163 +++++++++- docs/features/permissions_management.md | 48 ++- docs/guides/general/custom_templates.md | 81 +++++ docs/guides/general/delta_live_tables.md | 134 ++++++++ docs/guides/general/passing_parameters.md | 157 ++++++++++ docs/guides/python/python_quickstart.md | 24 +- docs/index.md | 129 +++----- docs/intro.md | 90 ++++++ docs/migration.md | 26 +- docs/overrides/404.html | 18 ++ docs/overrides/partials/source-file.html | 27 ++ docs/reference/deployment.md | 223 ++++++++++--- mkdocs.yml | 26 +- prospector.yaml | 6 +- setup.py | 34 +- tests/unit/api/adjuster/__init__.py | 0 tests/unit/api/adjuster/test_complex.py | 145 +++++++++ .../api/adjuster/test_existing_cluster.py | 82 +++++ tests/unit/api/adjuster/test_instance_pool.py | 83 +++++ .../api/adjuster/test_instance_profile.py | 89 ++++++ tests/unit/api/adjuster/test_pipeline.py | 62 ++++ tests/unit/api/adjuster/test_policy.py | 296 ++++++++++++++++++ .../api/adjuster/test_service_principals.py | 67 ++++ tests/unit/api/launch/test_functions.py | 10 +- tests/unit/api/launch/test_pipeline_runner.py | 31 ++ tests/unit/api/launch/test_processors.py | 25 +- 
tests/unit/api/launch/test_runners.py | 101 ++---- tests/unit/api/launch/test_tracer.py | 34 +- tests/unit/api/storage/test_io.py | 11 + tests/unit/api/storage/test_mlflow_storage.py | 5 +- tests/unit/api/test_build.py | 25 +- tests/unit/api/test_context.py | 2 +- tests/unit/api/test_deployment.py | 42 +++ tests/unit/api/test_destroyer.py | 25 +- tests/unit/api/test_jinja.py | 2 + tests/unit/api/test_jobs_service.py | 25 ++ tests/unit/commands/test_deploy.py | 157 ++++------ .../test_deploy_jinja_variables_file.py | 12 +- tests/unit/commands/test_destroy.py | 5 +- tests/unit/commands/test_execute.py | 142 +++++---- tests/unit/commands/test_launch.py | 92 ++++-- tests/unit/conftest.py | 43 ++- tests/unit/models/test_acls.py | 37 +++ tests/unit/models/test_deployment.py | 48 ++- tests/unit/models/test_destroyer.py | 19 +- tests/unit/models/test_git_source.py | 13 + tests/unit/models/test_job_clusters.py | 24 +- tests/unit/models/test_new_cluster.py | 29 ++ tests/unit/models/test_parameters.py | 85 +++-- tests/unit/models/test_pipeline.py | 8 + tests/unit/models/test_task.py | 107 ++++--- tests/unit/models/test_v2dot0_workflow.py | 39 +++ tests/unit/models/test_v2dot1_workflow.py | 35 +++ tests/unit/sync/clients/conftest.py | 2 +- tests/unit/sync/clients/test_dbfs_client.py | 27 +- tests/unit/sync/clients/test_get_user.py | 17 +- tests/unit/sync/clients/test_repos_client.py | 90 +++++- tests/unit/sync/test_commands.py | 216 +++++++------ tests/unit/utils/test_common.py | 81 ++--- tests/unit/utils/test_dependency_manager.py | 100 +----- tests/unit/utils/test_file_uploader.py | 46 +-- tests/unit/utils/test_named_properties.py | 167 ---------- tests/unit/utils/test_policy_parser.py | 25 -- 185 files changed, 6315 insertions(+), 2799 deletions(-) rename dbx/{models/parameters => api/adjuster}/__init__.py (100%) create mode 100644 dbx/api/adjuster/adjuster.py create mode 100644 dbx/api/adjuster/mixins/__init__.py create mode 100644 dbx/api/adjuster/mixins/base.py create mode 100644 dbx/api/adjuster/mixins/existing_cluster.py create mode 100644 dbx/api/adjuster/mixins/file_reference.py create mode 100644 dbx/api/adjuster/mixins/instance_pool.py create mode 100644 dbx/api/adjuster/mixins/instance_profile.py create mode 100644 dbx/api/adjuster/mixins/pipeline.py create mode 100644 dbx/api/adjuster/mixins/service_principal.py create mode 100644 dbx/api/adjuster/mixins/sql_properties.py create mode 100644 dbx/api/adjuster/policy.py create mode 100644 dbx/api/dependency/__init__.py create mode 100644 dbx/api/dependency/core_package.py create mode 100644 dbx/api/dependency/requirements.py create mode 100644 dbx/api/deployment.py create mode 100644 dbx/api/launch/pipeline_models.py delete mode 100644 dbx/api/launch/runners.py create mode 100644 dbx/api/launch/runners/__init__.py create mode 100644 dbx/api/launch/runners/asset_based.py create mode 100644 dbx/api/launch/runners/base.py create mode 100644 dbx/api/launch/runners/pipeline.py create mode 100644 dbx/api/launch/runners/standard.py create mode 100644 dbx/api/services/__init__.py create mode 100644 dbx/api/services/_base.py create mode 100644 dbx/api/services/jobs.py create mode 100644 dbx/api/services/permissions.py create mode 100644 dbx/api/services/pipelines.py create mode 100644 dbx/api/storage/io.py create mode 100644 dbx/models/build.py create mode 100644 dbx/models/cli/__init__.py create mode 100644 dbx/models/cli/destroyer.py create mode 100644 dbx/models/cli/execute.py rename dbx/models/{ => cli}/options.py (100%) delete mode 100644 
dbx/models/destroyer.py create mode 100644 dbx/models/files/__init__.py rename dbx/models/{ => files}/context.py (100%) rename dbx/models/{ => files}/project.py (96%) delete mode 100644 dbx/models/job_clusters.py delete mode 100644 dbx/models/parameters/common.py delete mode 100644 dbx/models/parameters/execute.py delete mode 100644 dbx/models/parameters/run_now.py delete mode 100644 dbx/models/parameters/run_submit.py delete mode 100644 dbx/models/task.py create mode 100644 dbx/models/validators.py create mode 100644 dbx/models/workflow/__init__.py create mode 100644 dbx/models/workflow/common/__init__.py create mode 100644 dbx/models/workflow/common/access_control.py create mode 100644 dbx/models/workflow/common/deployment_config.py create mode 100644 dbx/models/workflow/common/flexible.py create mode 100644 dbx/models/workflow/common/job_email_notifications.py create mode 100644 dbx/models/workflow/common/libraries.py create mode 100644 dbx/models/workflow/common/new_cluster.py create mode 100644 dbx/models/workflow/common/parameters.py create mode 100644 dbx/models/workflow/common/pipeline.py create mode 100644 dbx/models/workflow/common/task.py create mode 100644 dbx/models/workflow/common/task_type.py create mode 100644 dbx/models/workflow/common/workflow.py create mode 100644 dbx/models/workflow/common/workflow_types.py create mode 100644 dbx/models/workflow/v2dot0/__init__.py create mode 100644 dbx/models/workflow/v2dot0/parameters.py create mode 100644 dbx/models/workflow/v2dot0/task.py create mode 100644 dbx/models/workflow/v2dot0/workflow.py create mode 100644 dbx/models/workflow/v2dot1/__init__.py create mode 100644 dbx/models/workflow/v2dot1/_parameters.py create mode 100644 dbx/models/workflow/v2dot1/job_cluster.py create mode 100644 dbx/models/workflow/v2dot1/job_task_settings.py create mode 100644 dbx/models/workflow/v2dot1/parameters.py create mode 100644 dbx/models/workflow/v2dot1/task.py create mode 100644 dbx/models/workflow/v2dot1/workflow.py create mode 100644 dbx/types.py delete mode 100644 dbx/utils/adjuster.py delete mode 100644 dbx/utils/dependency_manager.py delete mode 100644 dbx/utils/job_listing.py delete mode 100644 dbx/utils/named_properties.py delete mode 100644 dbx/utils/policy_parser.py create mode 100644 dbx/utils/url.py create mode 100644 docs/concepts/artifact_storage.md delete mode 100644 docs/custom/custom.css create mode 100644 docs/extras/styles.css create mode 100644 docs/guides/general/custom_templates.md create mode 100644 docs/guides/general/delta_live_tables.md create mode 100644 docs/guides/general/passing_parameters.md create mode 100644 docs/intro.md create mode 100644 docs/overrides/404.html create mode 100644 docs/overrides/partials/source-file.html create mode 100644 tests/unit/api/adjuster/__init__.py create mode 100644 tests/unit/api/adjuster/test_complex.py create mode 100644 tests/unit/api/adjuster/test_existing_cluster.py create mode 100644 tests/unit/api/adjuster/test_instance_pool.py create mode 100644 tests/unit/api/adjuster/test_instance_profile.py create mode 100644 tests/unit/api/adjuster/test_pipeline.py create mode 100644 tests/unit/api/adjuster/test_policy.py create mode 100644 tests/unit/api/adjuster/test_service_principals.py create mode 100644 tests/unit/api/launch/test_pipeline_runner.py create mode 100644 tests/unit/api/storage/test_io.py create mode 100644 tests/unit/api/test_deployment.py create mode 100644 tests/unit/api/test_jobs_service.py create mode 100644 tests/unit/models/test_acls.py create mode 100644 
tests/unit/models/test_git_source.py create mode 100644 tests/unit/models/test_new_cluster.py create mode 100644 tests/unit/models/test_pipeline.py create mode 100644 tests/unit/models/test_v2dot0_workflow.py create mode 100644 tests/unit/models/test_v2dot1_workflow.py delete mode 100644 tests/unit/utils/test_named_properties.py delete mode 100644 tests/unit/utils/test_policy_parser.py diff --git a/.github/workflows/onpush.yml b/.github/workflows/onpush.yml index df91682c..c493e277 100644 --- a/.github/workflows/onpush.yml +++ b/.github/workflows/onpush.yml @@ -16,7 +16,9 @@ jobs: os: [ ubuntu-latest, windows-latest ] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 + with: + fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/onrelease.yml b/.github/workflows/onrelease.yml index e54bdfbd..0ba1fdd4 100644 --- a/.github/workflows/onrelease.yml +++ b/.github/workflows/onrelease.yml @@ -15,7 +15,9 @@ jobs: os: [ ubuntu-latest ] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 + with: + fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/.gitignore b/.gitignore index c41c22e7..9fab74fb 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,4 @@ temp/ out/ site/ +.dbx/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 26d83aa3..ccef6273 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # Changelog + All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), @@ -10,19 +11,124 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] - YYYY-MM-DD ----- -> Unreleased changes must be tracked above this line. -> When releasing, Copy the changelog to below this line, with proper version and date. -> And empty the **[Unreleased]** section above. ----- +# Fixed -## [0.7.5] - 2022-09-15 +- 🩹 Reload config after build in case if there are any dynamic components dependent on it +- 🩹 Check if target repo exists before syncing and produce more clear error message if it does not. 
+- 🩹 Type recognition of `named_parameters` in `python_wheel_task` +- 🔨 Add support for extras for cloud file operations + +## [0.8.7] - 2022-11-14 ## Added + +- 📖 Documentation on how to use custom templates +- 🦺 Add explicit Python file extension validation for `spark_python_task` + +## Fixed + +- 🩹 Build logic in case when `no_build` is specified + + +## [0.8.6] - 2022-11-09 + +### Changed + +- ♻️ Allow `init_scripts` in DLT pipelines +- 🔇 Hide the rst version overlay from read the docs + + +## [0.8.5] - 2022-11-09 + +### Changed + +- ⬆️ Bump typer to 0.7.0 +- 👔 improve docs and add landing page + + +## [0.8.4] - 2022-11-07 + +### Fixed + +- 🩹 Argument parsing logic in `dbx execute` without any arguments + + +## [0.8.3] - 2022-11-06 + +### Fixed + +- 🩹 Wheel dependency for setup has been removed +- 🩹 Add host cleanup logic to `dbx sync` commands +- 🩹 Return auto-add functionality from `dist` folder +- 🩹 Make `pause_status` property optional +- 🚨 Make traversal process fail-safe for dictionaries + +### Changed + +- ⚡️ Use improved method for job search + +## [0.8.2] - 2022-11-02 + +### Fixed + +- 🩹 Deletion logic in the workflow eraser + + +## [0.8.1] - 2022-11-02 + +### Changed + +- 📖 Reference documentation for deployment file +- ♻️ Add extensive caching for job listing + + +## [0.8.0] - 2022-11-02 + +### Changed + +- ♻️ Introduce model matching for workflow object +- ♻️ Heavily refactor parameter passing logic +- ♻️ Heavily refactor the models used by `dbx` internal APIs +- ♻️ Make empty `workflows` list a noop instead of error +- ♻️ Handle `pytest` exit code in cookiecutter project integration test entrypoint + +### Added + +- 🔥 Delta Live Tables support +- 📖 Documentation on the differences between `dbx execute` and `dbx launch` +- 📖 Documentation on how to use parameter passing in various cases +- 📖 Documentation on how to enable Photon +- 📖 Documentation on artifact storage +- 🪄 Functionality to automatically enable context-based upload +- 🪄 Automatic conversion from `wasbs://` to `abfss://` references when using ADLS as artifact storage. +- ♻️ New init scripts append logic in case when `cluster-policy://` resolution is used. + +### Fixed + +- 🐛 Message with rich markup [] is properly displayed now +- 📖 Broken link in the generated README.md in Python template + + +## [0.7.6] - 2022-10-05 + +### Changed + +- ✨ Empty list of workflows is now a noop instead of throwing an error +- 🩹 Disable the soft-wrap for printed out text + +### Fixed + +- 🐛 Rollback to the failsafe behaviour for assets-based property preprocessing + +## [0.7.5] - 2022-09-15 + +### Added + - 📖 documentation on the dependency management - ✨ failsafe switch for assets-based shared job clusters -## Fixed +### Fixed + - 🎨 404 page in docs is now rendered correctly - ✏️ Small typos in the docs - ✏️ Reference structures for `libraries` section @@ -30,52 +136,60 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.7.4] - 2022-08-31 -## Added +### Added + - 📖 documentation on the integration tests -## Changed +### Changed + - ♻️ refactored poetry build logic -## Fixed +### Fixed + - 📖 indents in quickstart doc - 📝 add integration tests to the quickstart structure ## [0.7.3] - 2022-08-29 -## Added +### Added + - ✨ add pip install extras option - 🎨 Nice spinners for long-running processes (e.g. 
cluster start and run tracing) - 🧪 Add convenient integration tests interface example -## Changed +### Changed + - 📖 Small typos in Jinja docs - 📖 Formatting issues in cluster types doc ## [0.7.2] - 2022-08-28 -## Fixed +### Fixed + - 🐛 bug with context provisioning for `dbx execute` ## [0.7.1] - 2022-08-28 -## Added +### Added + - ⚡️`dbx destroy` command - ☁️ failsafe behaviour for shared clusters when assets-based launch is used - 📖 Documentation with cluster types guidance - 📖 Documentation with scheduling and orchestration links - 📖 Documentation for mixed-mode projects DevOps -## Changed +### Changed + - ✨Add `.dbx/sync` folder to template gitignore - ✨Changed the dependencies from the `mlflow` to a more lightweight `mlflow-skinny` option - ✨Added suppression for too verbose `click` stacktraces - ⚡️added `execute_shell_command` fixture, improving tests performance x2 - ⚡️added failsafe check for `get_experiment_by_name` call - ## [0.7.0] - 2022-08-24 -## Added +### Added + - 🎨Switch all the CLI interfaces to `typer` - ✨Add `workflow-name` argument to `dbx deploy`, `dbx launch` and `dbx execute` - ✨Add `--workflows` argument to `dbx deploy` @@ -89,7 +203,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ✨Add build logic customization with `build.commands` section - ✨Add support for custom Python functions in Jinja templates -## Changed +### Changed + - ✨Arguments `--allow-delete-unmatched`/`--disallow-delete-unmatched` were **replaced** with `--unmatched-behaviour` option. - 🏷️Deprecate `jobs` section and rename it to `workflows` - 🏷️Deprecate `job` and `jobs` options and rename it to `workflow` argument @@ -104,68 +219,71 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - 💎Documentation framework changed from `sphinx` to `mkdocs` - 💎Documentation has been heavily re-worked and improved -## Fixed +### Fixed + - 🐛`dbx sync` now takes into account `HTTP(S)_PROXY` env variables - 🐛empty task parameters are now supported - 🐛ACLs are now properly updated for Jobs API 2.1 ## [0.6.12] - 2022-08-15 -## Added +### Added - `--jinja-variables-file` for `dbx execute` -## Fixed +### Fixed + - Support `jobs_api_version` values provided by config in `ApiClient` construction - References and wording in the Python template ## [0.6.11] - 2022-08-09 -## Fixed -- Callback issue in `--jinja-variables-file` for `dbx deploy` +### Fixed +- Callback issue in `--jinja-variables-file` for `dbx deploy` ## [0.6.10] - 2022-08-04 -## Added +### Added + - Added support for `python_wheel_task` in `dbx execute` -## Fixed +### Fixed + - Error in case when `.dbx/project.json` is non-existent - Error in case when `environment` is not provided in the project file - Path usage when `--upload-via-context` on win platform ## [0.6.9] - 2022-08-03 -## Added +### Added - Additional `sync` command options (`--no-use-gitignore`, `--force-include`, etc.) for more control over what is synced. - Additional `init` command option `--template` was added to allow using dbx templates distributed as part of python packages. 
- Refactored the `--deployment-file` option for better modularity of the code - Add upload via context for `dbx execute` - ## [0.6.8] - 2022-07-21 -## Fixed +### Fixed - Tasks naming in tests imports for Python template ## [0.6.7] - 2022-07-21 -## Fixed +### Fixed - Task naming and references in the Python template - Small typo in Python template ## [0.6.6] - 2022-07-21 -## Changed +### Changed - Rename `workloads` to `tasks` in the Python package template - Documentation structure has been refactored -## Added +### Added - Option (`--include-output`) to include run stderr and stdout output to the console output - Docs describing how-to for Python packaging @@ -174,26 +292,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.6.5] - 2022-07-19 -## Fixed +### Fixed - Local build command now produces only one file in the `dist` folder -## Added +### Added - Add `dist` directory cleanup before core package build - Add `--job-run-log-level` option to `dbx launch` to retrieve log after trace run -## Changed +### Changed - Separate `unit-requirements.txt` file has been deleted from the template ## [0.6.4] - 2022-07-01 -## Fixed +### Fixed - `RunSubmit` based launch when cloud storage is used as an artifact location - ## [0.6.3] - 2022-06-28 ### Added @@ -204,12 +321,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - All invocations in Azure Pipelines template are now module-based (`python -m ...`) - ## [0.6.2] - 2022-06-24 - Fix auth ordering (now env-variables based auth has priority across any other auth methods) - ## [0.6.1] - 2022-06-22 - Fix import issues in `dbx.api.storage` package @@ -234,7 +349,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Documentation improvements for Jinja-based templates - Now package builds are performed with `pip` by default - ### Fixed - Parsing of `requirements.txt` has been improved to properly handle comments in requirements files @@ -259,7 +373,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.4.1] - 2022-03-01 -## Fixed +### Fixed - Jinja2-based file recognition behaviour @@ -277,7 +391,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Issues with `--no-package` argument for multi-task jobs - Issues with named properties propagation for Jobs API 2.1 - ## [0.3.3] - 2022-02-08 ### Fixed @@ -298,6 +411,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.3.1] - 2022-01-30 ### Added + - Recognition of `conf/deployment.yml` file from conf directory as a default parameter - Remove unnecessary references of `conf/deployment.yml` in CI pipelines @@ -308,6 +422,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Upgraded minimal requirements for Azure Data Factory dependent libraries ### Fixed + - Provided bugfix for emoji-based messages in certain shell environments - Provided bugfix for cases when not all jobs are listed due to usage of Jobs API 2.1 - Provided bugfix for cases when file names are reused multiple times @@ -315,63 +430,74 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Provided bugfix for ADF integration that deleted pipeline-level properties ## [0.3.0] - 2022-01-04 + ### Added + - Add support for named property of the driver instance pool name - Add support for built-in templates and project initialization via :code:`dbx init` ### Fixed -- 
Provided bugfix for named property resolution in multitask-based jobs - +- Provided bugfix for named property resolution in multitask-based jobs ## [0.2.2] - 2021-12-03 + ### Changed + - Update the contribution docs with CLA - Update documentation about environment variables ### Added + - Add support for named job properties - Add support for `spark_jar_task` in Azure Data Factory reflector ### Fixed + - Provide bugfix for strict path resolving in the execute command - Provide bugfix for Azure Datafactory when using `existing_cluster_id` ## [0.2.1] - 2021-11-04 + ### Changed + - Update `databricks-cli` dependency to 0.16.2 - Improved code coverage ### Added + - Added support for environment variables in deployment files ### Fixed + - Fixed minor bug in exception text - Provide a bugfix for execute issue ## [0.2.0] - 2021-09-12 + ### Changed + - Removed pydash from package dependencies, as it is not used. Still need it as a dev-requirement. ### Added + - Added support for [multitask jobs](https://docs.databricks.com/data-engineering/jobs/index.html). - Added more explanations around DATABRICKS_HOST exception during API client initialization - Add strict path adjustment policy and FUSE-based path adjustment - - - - ## [0.1.6] - 2021-08-26 + ### Fixed + - Fix issue which stripped non-pyspark libraries from a requirements file during deploys. - Fix issue which didn't update local package during remote execution. - ## [0.1.5] - 2021-08-12 ### Added + - Support for [yaml-based deployment files](https://github.com/databrickslabs/dbx/issues/39). ### Changed + - Now dbx finds the git branch name from any subdirectory in the repository. - Minor alterations in the documentation. - Altered the Changelog based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) @@ -380,70 +506,80 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `make clean install` will set you up with all that is needed. - `make help` to see all available commands. - ## [0.1.4] + ### Fixed + - Fix issue with execute parameters passing - Fix issue with multi-version package upload - ## [0.1.3] ### Added + - Add explicit exception for artifact location change - Add experimental support for fixed properties' propagation from cluster policies - ## [0.1.2] + ### Added -- Added Run Submit API support. +- Added Run Submit API support. ## [0.1.1] + ### Fixed -- Fixed the issue with pywin32 installation for Azure imports on win platforms. +- Fixed the issue with pywin32 installation for Azure imports on win platforms. ## [0.1.0] + ### Added + - Integration with Azure Data Factory. ### Fixed + - Some small internal behaviour fixes. ### Changed -- Changed the behaviour of `dbx deploy --write-specs-to-file`, to make the structure of specs file compatible with environment structure. +- Changed the behaviour of `dbx deploy --write-specs-to-file`, to make the structure of specs file compatible with environment structure. ## [0.0.14] + ### Added -- Added integrated permission management, please refer to documentation for details. +- Added integrated permission management, please refer to documentation for details. ## [0.0.13] ### Added -- Added `--write-specs-to-file` option for `dbx deploy` command. +- Added `--write-specs-to-file` option for `dbx deploy` command. ## [0.0.12] ### Fixed -- HotFix for execute command. +- HotFix for execute command. ## [0.0.11] ### Changed -- Made Internal refactorings after code coverage analysis. +- Made Internal refactorings after code coverage analysis. 
## [0.0.10] + ### Fixed -- Fixed issue with job spec adjustment. +- Fixed issue with job spec adjustment. ## [0.0.9] + ### Changed + - Finalized the CI setup for the project. - No code changes were done. - Release is required to start correct numeration in pypi. - ## [0.0.8] + ### Added + - Initial public release version. diff --git a/Makefile b/Makefile index e3b6b007..1ab464c4 100644 --- a/Makefile +++ b/Makefile @@ -198,7 +198,9 @@ test: ## Run the tests. (option): file=tests/path/to/file.py @echo "" @echo "${YELLOW}Running tests:${NORMAL}" @make helper-line - $(PYTHON) -m pytest -vv --cov dbx $(file) -n auto + $(PYTHON) -m pytest -vv --cov dbx $(file) -n auto \ + --cov-report=xml \ + --cov-report=term-missing:skip-covered test-with-html-report: ## Run all tests with html reporter. @echo "" diff --git a/dbx/__init__.py b/dbx/__init__.py index ab55bb1a..aa00ec3d 100644 --- a/dbx/__init__.py +++ b/dbx/__init__.py @@ -1 +1 @@ -__version__ = "0.7.5" +__version__ = "0.8.8" diff --git a/dbx/models/parameters/__init__.py b/dbx/api/adjuster/__init__.py similarity index 100% rename from dbx/models/parameters/__init__.py rename to dbx/api/adjuster/__init__.py diff --git a/dbx/api/adjuster/adjuster.py b/dbx/api/adjuster/adjuster.py new file mode 100644 index 00000000..cfcf38c3 --- /dev/null +++ b/dbx/api/adjuster/adjuster.py @@ -0,0 +1,189 @@ +from typing import Any, Optional, Union, List + +from databricks_cli.sdk import ApiClient +from pydantic import BaseModel + +from dbx.api.adjuster.mixins.existing_cluster import ExistingClusterAdjuster +from dbx.api.adjuster.mixins.file_reference import FileReferenceAdjuster +from dbx.api.adjuster.mixins.instance_pool import InstancePoolAdjuster +from dbx.api.adjuster.mixins.instance_profile import InstanceProfileAdjuster +from dbx.api.adjuster.mixins.pipeline import PipelineAdjuster +from dbx.api.adjuster.mixins.service_principal import ServicePrincipalAdjuster +from dbx.api.adjuster.mixins.sql_properties import SqlPropertiesAdjuster +from dbx.api.adjuster.policy import PolicyAdjuster +from dbx.models.deployment import WorkflowList +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.new_cluster import NewCluster +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.job_cluster import JobCluster +from dbx.models.workflow.v2dot1.job_task_settings import JobTaskSettings +from dbx.utils import dbx_echo +from dbx.utils.file_uploader import AbstractFileUploader + + +class AdditionalLibrariesProvider(FlexibleModel): + no_package: Optional[bool] = False + core_package: Optional[Library] + libraries_from_requirements: Optional[List[Library]] = [] + + +class PropertyAdjuster( + InstancePoolAdjuster, + ExistingClusterAdjuster, + InstanceProfileAdjuster, + PipelineAdjuster, + ServicePrincipalAdjuster, + SqlPropertiesAdjuster, + PolicyAdjuster, +): + def traverse(self, _object: Any, parent: Optional[Any] = None, index_in_parent: Optional[Any] = None): + + # if element is a dictionary, simply continue traversing + if isinstance(_object, dict): + for key in list(_object.keys()): + item = _object[key] + yield item, _object, key + for _out in self.traverse(item, _object, index_in_parent): + yield _out + + # if element is a list, simply continue traversing + elif isinstance(_object, list): + for idx, sub_item in enumerate(_object): + yield sub_item, _object, idx + for _out in self.traverse(sub_item, _object, idx): + 
yield _out + + # process any other kind of nested references + elif isinstance(_object, (BaseModel, FlexibleModel)): + for key, sub_element in _object.__dict__.items(): + if sub_element is not None: + yield sub_element, _object, key + for _out in self.traverse(sub_element, _object, key): + yield _out + else: + # yield the low-level objects + yield _object, parent, index_in_parent + + @staticmethod + def _preprocess_libraries( + element: Union[JobTaskSettings, V2dot0Workflow], additional_libraries: AdditionalLibrariesProvider + ): + _element_string = ( + f"workflow {element.name}" if isinstance(element, V2dot0Workflow) else f"task {element.task_key}" + ) + dbx_echo(f"Processing libraries for {_element_string}") + element.libraries += additional_libraries.libraries_from_requirements + if additional_libraries.no_package or (element.deployment_config and element.deployment_config.no_package): + pass + else: + element.libraries += [additional_libraries.core_package] if additional_libraries.core_package else [] + dbx_echo(f"✅ Processing libraries for {_element_string} - done") + + def library_traverse(self, workflows: WorkflowList, additional_libraries: AdditionalLibrariesProvider): + + for element, _, __ in self.traverse(workflows): + + if isinstance(element, (V2dot0Workflow, JobTaskSettings)): + self._preprocess_libraries(element, additional_libraries) + + def _new_cluster_handler(self, element: NewCluster): + # driver_instance_pool_name -> driver_instance_pool_id + if element.driver_instance_pool_name is not None: + self._adjust_legacy_driver_instance_pool_ref(element) + # instance_pool_name -> instance_pool_id + if element.instance_pool_name is not None: + self._adjust_legacy_instance_pool_ref(element) + # instance_profile_name -> instance_profile_arn + if element.aws_attributes is not None and element.aws_attributes.instance_profile_name is not None: + self._adjust_legacy_instance_profile_ref(element) + + def property_traverse(self, workflows: WorkflowList): + """ + This traverse applies all the transformations to the workflows + :param workflows: + :return: None + """ + for element, parent, index in self.traverse(workflows): + + if isinstance(element, V2dot0Workflow): + # legacy named conversion + # existing_cluster_name -> existing_cluster_id + if element.existing_cluster_name is not None: + self._adjust_legacy_existing_cluster(element) + + if isinstance(element, NewCluster): + self._new_cluster_handler(element) + + if isinstance(element, str): + + if element.startswith("cluster://"): + self._adjust_existing_cluster_ref(element, parent, index) + + elif element.startswith("instance-profile://"): + self._adjust_instance_profile_ref(element, parent, index) + + elif element.startswith("instance-pool://"): + self._adjust_instance_pool_ref(element, parent, index) + + elif element.startswith("pipeline://"): + self._adjust_pipeline_ref(element, parent, index) + + elif element.startswith("service-principal://"): + self._adjust_service_principal_ref(element, parent, index) + + elif element.startswith("warehouse://"): + self._adjust_warehouse_ref(element, parent, index) + + elif element.startswith("query://"): + self._adjust_query_ref(element, parent, index) + + elif element.startswith("dashboard://"): + self._adjust_dashboard_ref(element, parent, index) + + elif element.startswith("alert://"): + self._adjust_alert_ref(element, parent, index) + + def cluster_policy_traverse(self, workflows: WorkflowList): + """ + This traverse applies only the policy_name OR policy_id traverse. 
+ Please note that this traverse should go STRICTLY after all other rules, + when ids and other transformations are already resolved. + :param workflows: + :return: None + """ + for element, parent, _ in self.traverse(workflows): + if isinstance(parent, (V2dot0Workflow, JobTaskSettings, JobCluster)) and isinstance(element, NewCluster): + if element.policy_name is not None or ( + isinstance(element, NewCluster) + and element.policy_id is not None + and element.policy_id.startswith("cluster-policy://") + ): + element = self._adjust_policy_ref(element) + parent.new_cluster = element + + def file_traverse(self, workflows, file_adjuster: FileReferenceAdjuster): + for element, parent, index in self.traverse(workflows): + if isinstance(element, str): + if element.startswith("file://") or element.startswith("file:fuse://"): + file_adjuster.adjust_file_ref(element, parent, index) + + +class Adjuster: + def __init__( + self, + additional_libraries: AdditionalLibrariesProvider, + file_uploader: AbstractFileUploader, + api_client: ApiClient, + ): + self.property_adjuster = PropertyAdjuster(api_client=api_client) + self.file_adjuster = FileReferenceAdjuster(file_uploader) + self.additional_libraries = additional_libraries + + def traverse(self, workflows: Union[WorkflowList, List[str]]): + dbx_echo("Starting the traversal process") + self.property_adjuster.library_traverse(workflows, self.additional_libraries) + self.property_adjuster.file_traverse(workflows, self.file_adjuster) + self.property_adjuster.property_traverse(workflows) + self.property_adjuster.cluster_policy_traverse(workflows) + dbx_echo("Traversal process finished, all provided references were resolved") diff --git a/dbx/api/adjuster/mixins/__init__.py b/dbx/api/adjuster/mixins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/api/adjuster/mixins/base.py b/dbx/api/adjuster/mixins/base.py new file mode 100644 index 00000000..c02de068 --- /dev/null +++ b/dbx/api/adjuster/mixins/base.py @@ -0,0 +1,32 @@ +from abc import ABC + +from databricks_cli.sdk import ApiClient +from pydantic import BaseModel + +from dbx.models.workflow.common.flexible import FlexibleModel + + +class ApiClientMixin(ABC): + def __init__(self, api_client: ApiClient): + self.api_client = api_client + + +class ElementSetterMixin: + @classmethod + def set_element_at_parent(cls, element, parent, index) -> None: + """ + Sets the element value for various types of parent + :param element: New element value + :param parent: A nested structure where element should be placed + :param index: Position (or pointer) where element should be provided + :return: None + """ + if isinstance(parent, (dict, list)): + parent[index] = element + elif isinstance(parent, (BaseModel, FlexibleModel)): + setattr(parent, index, element) + else: + raise ValueError( + "Cannot apply reference to the parent structure." 
+ f"Please create a GitHub issue providing the following parent object type: {type(parent)}" + ) diff --git a/dbx/api/adjuster/mixins/existing_cluster.py b/dbx/api/adjuster/mixins/existing_cluster.py new file mode 100644 index 00000000..58301603 --- /dev/null +++ b/dbx/api/adjuster/mixins/existing_cluster.py @@ -0,0 +1,43 @@ +import functools +from typing import Optional, List, Any + +from databricks_cli.sdk import ClusterService + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow + + +class ClusterInfo(FlexibleModel): + cluster_id: str + cluster_name: str + + +class ListClustersResponse(FlexibleModel): + clusters: Optional[List[ClusterInfo]] = [] + + @property + def cluster_names(self) -> List[str]: + return [p.cluster_name for p in self.clusters] + + def get_cluster(self, name: str) -> ClusterInfo: + _found = list(filter(lambda p: p.cluster_name == name, self.clusters)) + assert _found, NameError( + f"No clusters with name {name} were found. Available clusters are {self.cluster_names}" + ) + assert len(_found) == 1, NameError(f"More than one cluster with name {name} was found: {_found}") + return _found[0] + + +class ExistingClusterAdjuster(ApiClientMixin, ElementSetterMixin): + def _adjust_legacy_existing_cluster(self, element: V2dot0Workflow): + element.existing_cluster_id = self._clusters.get_cluster(element.existing_cluster_name).cluster_id + + def _adjust_existing_cluster_ref(self, element: str, parent: Any, index: Any): + _id = self._clusters.get_cluster(element.replace("cluster://", "")).cluster_id + self.set_element_at_parent(_id, parent, index) + + @functools.cached_property + def _clusters(self) -> ListClustersResponse: + _service = ClusterService(self.api_client) + return ListClustersResponse(**_service.list_clusters()) diff --git a/dbx/api/adjuster/mixins/file_reference.py b/dbx/api/adjuster/mixins/file_reference.py new file mode 100644 index 00000000..51051425 --- /dev/null +++ b/dbx/api/adjuster/mixins/file_reference.py @@ -0,0 +1,13 @@ +from typing import Any + +from dbx.api.adjuster.mixins.base import ElementSetterMixin +from dbx.utils.file_uploader import AbstractFileUploader + + +class FileReferenceAdjuster(ElementSetterMixin): + def __init__(self, file_uploader: AbstractFileUploader): + self._uploader = file_uploader + + def adjust_file_ref(self, element: str, parent: Any, index: Any): + _uploaded = self._uploader.upload_and_provide_path(element) + self.set_element_at_parent(_uploaded, parent, index) diff --git a/dbx/api/adjuster/mixins/instance_pool.py b/dbx/api/adjuster/mixins/instance_pool.py new file mode 100644 index 00000000..fd222318 --- /dev/null +++ b/dbx/api/adjuster/mixins/instance_pool.py @@ -0,0 +1,46 @@ +import functools +from typing import Any, List, Optional + +from databricks_cli.sdk import InstancePoolService + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.new_cluster import NewCluster + + +class InstancePoolInfo(FlexibleModel): + instance_pool_name: str + instance_pool_id: str + + +class ListInstancePoolsResponse(FlexibleModel): + instance_pools: Optional[List[InstancePoolInfo]] = [] + + @property + def pool_names(self) -> List[str]: + return [p.instance_pool_name for p in self.instance_pools] + + def get_pool(self, name: str) -> InstancePoolInfo: + _found = 
list(filter(lambda p: p.instance_pool_name == name, self.instance_pools)) + assert _found, NameError(f"No pools with name {name} were found, available pools are {self.pool_names}") + assert len(_found) == 1, NameError(f"More than one pool with name {name} was found: {_found}") + return _found[0] + + +class InstancePoolAdjuster(ApiClientMixin, ElementSetterMixin): + @functools.cached_property + def _instance_pools(self) -> ListInstancePoolsResponse: + _service = InstancePoolService(self.api_client) + return ListInstancePoolsResponse(**_service.list_instance_pools()) + + def _adjust_legacy_driver_instance_pool_ref(self, element: NewCluster): + element.driver_instance_pool_id = self._instance_pools.get_pool( + element.driver_instance_pool_name + ).instance_pool_id + + def _adjust_legacy_instance_pool_ref(self, element: NewCluster): + element.instance_pool_id = self._instance_pools.get_pool(element.instance_pool_name).instance_pool_id + + def _adjust_instance_pool_ref(self, element: str, parent: Any, index: Any): + pool_id = self._instance_pools.get_pool(element.replace("instance-pool://", "")).instance_pool_id + self.set_element_at_parent(pool_id, parent, index) diff --git a/dbx/api/adjuster/mixins/instance_profile.py b/dbx/api/adjuster/mixins/instance_profile.py new file mode 100644 index 00000000..2c683d94 --- /dev/null +++ b/dbx/api/adjuster/mixins/instance_profile.py @@ -0,0 +1,46 @@ +import functools +from typing import Optional, Any, List + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.new_cluster import NewCluster + + +class InstanceProfileInfo(FlexibleModel): + instance_profile_arn: str + + @property + def instance_profile_name(self): + return self.instance_profile_arn.split("/")[-1] + + +class ListInstanceProfilesResponse(FlexibleModel): + instance_profiles: Optional[List[InstanceProfileInfo]] = [] + + @property + def names(self) -> List[str]: + return [p.instance_profile_name for p in self.instance_profiles] + + def get(self, name: str) -> InstanceProfileInfo: + _found = list(filter(lambda p: p.instance_profile_name == name, self.instance_profiles)) + assert _found, NameError( + f"No instance profiles with name {name} were found, available instance profiles are {self.names}" + ) + assert len(_found) == 1, NameError(f"More than one instance profile with name {name} was found: {_found}") + return _found[0] + + +class InstanceProfileAdjuster(ApiClientMixin, ElementSetterMixin): + def _adjust_legacy_instance_profile_ref(self, element: NewCluster): + _arn = self._instance_profiles.get(element.aws_attributes.instance_profile_name).instance_profile_arn + element.aws_attributes.instance_profile_arn = _arn + + def _adjust_instance_profile_ref(self, element: str, parent: Any, index: Any): + _arn = self._instance_profiles.get(element.replace("instance-profile://", "")).instance_profile_arn + self.set_element_at_parent(_arn, parent, index) + + @functools.cached_property + def _instance_profiles(self) -> ListInstanceProfilesResponse: + return ListInstanceProfilesResponse( + **self.api_client.perform_query(method="GET", path="/instance-profiles/list") + ) diff --git a/dbx/api/adjuster/mixins/pipeline.py b/dbx/api/adjuster/mixins/pipeline.py new file mode 100644 index 00000000..1e0edee7 --- /dev/null +++ b/dbx/api/adjuster/mixins/pipeline.py @@ -0,0 +1,10 @@ +from typing import Any + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from 
dbx.api.services.pipelines import NamedPipelinesService + + +class PipelineAdjuster(ApiClientMixin, ElementSetterMixin): + def _adjust_pipeline_ref(self, element: str, parent: Any, index: Any): + _pipeline_id = NamedPipelinesService(self.api_client).find_by_name_strict(element.replace("pipeline://", "")) + self.set_element_at_parent(_pipeline_id, parent, index) diff --git a/dbx/api/adjuster/mixins/service_principal.py b/dbx/api/adjuster/mixins/service_principal.py new file mode 100644 index 00000000..1ff7cda6 --- /dev/null +++ b/dbx/api/adjuster/mixins/service_principal.py @@ -0,0 +1,38 @@ +import functools +from typing import Any, List + +from pydantic import Field + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from dbx.models.workflow.common.flexible import FlexibleModel + + +class ResourceInfo(FlexibleModel): + display_name: str = Field(str, alias="displayName") # noqa + application_id: str = Field(str, alias="applicationId") # noqa + + +class ListServicePrincipals(FlexibleModel): + Resources: List[ResourceInfo] + + @property + def names(self) -> List[str]: + return [p.display_name for p in self.Resources] + + def get(self, name: str) -> ResourceInfo: + _found = list(filter(lambda p: p.display_name == name, self.Resources)) + assert _found, NameError( + f"No service principals with name {name} were found, available objects are {self.names}" + ) + assert len(_found) == 1, NameError(f"More than one service principal with name {name} was found: {_found}") + return _found[0] + + +class ServicePrincipalAdjuster(ApiClientMixin, ElementSetterMixin): + def _adjust_service_principal_ref(self, element: str, parent: Any, index: Any): + app_id = self._principals.get(element.replace("service-principal://", "")).application_id + self.set_element_at_parent(app_id, parent, index) + + @functools.cached_property + def _principals(self) -> ListServicePrincipals: + return ListServicePrincipals(**self.api_client.perform_query("GET", path="/preview/scim/v2/ServicePrincipals")) diff --git a/dbx/api/adjuster/mixins/sql_properties.py b/dbx/api/adjuster/mixins/sql_properties.py new file mode 100644 index 00000000..8933edc8 --- /dev/null +++ b/dbx/api/adjuster/mixins/sql_properties.py @@ -0,0 +1,120 @@ +import functools +from abc import abstractmethod +from typing import Any, List + +from dbx.api.adjuster.mixins.base import ApiClientMixin, ElementSetterMixin +from dbx.models.workflow.common.flexible import FlexibleModel + + +class NamedModel(FlexibleModel): + id: str + name: str + + +class WarehouseInfo(NamedModel): + """""" + + +class QueryInfo(FlexibleModel): + """""" + + +class DashboardInfo(FlexibleModel): + """""" + + +class AlertInfo(FlexibleModel): + """""" + + +class WarehousesList(FlexibleModel): + warehouses: List[WarehouseInfo] = [] + + @property + def names(self) -> List[str]: + return [p.name for p in self.warehouses] + + def get(self, name: str) -> WarehouseInfo: + _found = list(filter(lambda p: p.name == name, self.warehouses)) + assert _found, NameError(f"No warehouses with name {name} were found, available warehouses are {self.names}") + assert len(_found) == 1, NameError(f"More than one warehouse with name {name} was found:\n{_found}") + return _found[0] + + +class ResultsListGetterMixin: + results: List[Any] + + @property + @abstractmethod + def object_type(self) -> str: + """To be implemented in subclasses""" + + @property + def names(self) -> List[str]: + return [p.name for p in self.results] + + def get(self, name: str) -> Any: + _found = list(filter(lambda 
p: p.name == name, self.results)) + assert _found, NameError(f"No {self.object_type} with name {name} were found") + assert len(_found) == 1, NameError(f"More than one {self.object_type} with name {name} was found: {_found}") + return _found[0] + + +class QueriesList(FlexibleModel, ResultsListGetterMixin): + @property + def object_type(self) -> str: + return "query" + + results: List[QueryInfo] = [] + + +class DashboardsList(FlexibleModel, ResultsListGetterMixin): + results: List[DashboardInfo] = [] + + @property + def object_type(self) -> str: + return "dashboard" + + +class AlertsList(FlexibleModel, ResultsListGetterMixin): + results: List[AlertInfo] = [] + + @property + def object_type(self) -> str: + return "alert" + + +class SqlPropertiesAdjuster(ApiClientMixin, ElementSetterMixin): + # TODO: design of this class is a terrible copy-paste. It must be rewritten. + + @functools.cached_property + def _warehouses(self) -> WarehousesList: + return WarehousesList(**self.api_client.perform_query("GET", path="/sql/warehouses/")) + + def _adjust_warehouse_ref(self, element: str, parent: Any, index: Any): + _id = self._warehouses.get(element.replace("warehouse://", "")).id + self.set_element_at_parent(_id, parent, index) + + def _adjust_query_ref(self, element: str, parent: Any, index: Any): + query_name = element.replace("query://", "") + _relevant = QueriesList( + **self.api_client.perform_query("GET", path="/preview/sql/queries", data={"q": query_name}) + ) + _id = _relevant.get(query_name).id + self.set_element_at_parent(_id, parent, index) + + def _adjust_dashboard_ref(self, element: str, parent: Any, index: Any): + dashboard_name = element.replace("dashboard://", "") + _relevant = DashboardsList( + **self.api_client.perform_query("GET", path="/preview/sql/dashboards", data={"q": dashboard_name}) + ) + _id = _relevant.get(dashboard_name).id + self.set_element_at_parent(_id, parent, index) + + def _adjust_alert_ref(self, element: str, parent: Any, index: Any): + alert_name = element.replace("alert://", "") + _relevant = DashboardsList( + **self.api_client.perform_query("GET", path="/preview/sql/alerts", data={"q": alert_name}) + ) + _id = _relevant.get(alert_name).id + self.set_element_at_parent(_id, parent, index) diff --git a/dbx/api/adjuster/policy.py b/dbx/api/adjuster/policy.py new file mode 100644 index 00000000..6e970904 --- /dev/null +++ b/dbx/api/adjuster/policy.py @@ -0,0 +1,156 @@ +import json +from collections import defaultdict +from collections.abc import Mapping +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +from databricks_cli.cluster_policies.api import PolicyService + +from dbx.api.adjuster.mixins.base import ApiClientMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.new_cluster import NewCluster + + +class Policy(FlexibleModel): + policy_id: str + name: str + definition: str + description: Optional[str] + + +class PoliciesResponse(FlexibleModel): + policies: List[Policy] + + +class PolicyAdjuster(ApiClientMixin): + """ + This policy parser is based on: + - API Doc: policy parser is based on API doc https://docs.databricks.com/dev-tools/api/latest/policies.html + - Policy definition docs: + - AWS: https://docs.databricks.com/administration-guide/clusters/policies.html#cluster-policy-attribute-paths + - Azure: https://docs.microsoft.com/en-us/azure/databricks/administration-guide/clusters/policies + - GCP: Cluster policies were not supported at the moment of 0.1.3 release. 
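+
+    For illustration, a "fixed" entry in a policy definition JSON looks roughly like
+    {"spark_conf.spark.databricks.delta.preview.enabled": {"type": "fixed", "value": "true"}},
+    where the key is a cluster attribute path as described in the docs linked above.
+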
+    Please note that only "fixed" values will be automatically added to the job definition.
+    """
+
+    def _adjust_policy_ref(self, cluster: NewCluster):
+        policy_service = PolicyService(self.api_client)
+        policy = self._get_policy(policy_service, cluster.policy_name, cluster.policy_id)
+        traversed_policy = self._traverse_policy(policy_payload=json.loads(policy.definition))
+        _updated_object = self._deep_update(cluster.dict(exclude_none=True), traversed_policy)
+        _updated_object = NewCluster(**_updated_object)
+        _updated_object.policy_id = policy.policy_id
+        return _updated_object
+
+    @staticmethod
+    def _get_policy(policy_service: PolicyService, policy_name: Optional[str], policy_id: Optional[str]) -> Policy:
+        policy_name = policy_name if policy_name else policy_id.replace("cluster-policy://", "")
+        all_policies = PoliciesResponse(**policy_service.list_policies())
+        relevant_policy = list(filter(lambda p: p.name == policy_name, all_policies.policies))
+
+        if relevant_policy:
+            if len(relevant_policy) != 1:
+                raise ValueError(
+                    f"More than one cluster policy with name {policy_name} was found. "
+                    f"Available policies are: {all_policies}"
+                )
+            return relevant_policy[0]
+
+        raise ValueError(
+            f"No cluster policies were found under name {policy_name}. "
+            f"Available policy names are: {[p.name for p in all_policies.policies]}"
+        )
+
+    @staticmethod
+    def _append_init_scripts(policy_init_scripts: List, existing_init_scripts: List) -> List:
+        final_init_scripts = deepcopy(policy_init_scripts)
+        flat_policy_init_scripts = defaultdict(list)
+        for script in policy_init_scripts:
+            for k, v in script.items():
+                flat_policy_init_scripts[k].append(v["destination"])
+        for script in existing_init_scripts:
+            for k, v in script.items():
+                if not v or not v.get("destination"):
+                    raise Exception("init_scripts section format is incorrect in the deployment file")
+                destination = v["destination"]
+                if destination not in flat_policy_init_scripts.get(k, []):
+                    # deduplicate and ensure that init scripts from the policy run first
+                    final_init_scripts.append(script)
+        return final_init_scripts
+
+    @classmethod
+    def _deep_update(cls, d: Dict, u: Mapping) -> Dict:
+        for k, v in u.items():
+            if isinstance(v, Mapping):
+                d[k] = cls._deep_update(d.get(k, {}), v)
+            else:
+                # if the key is already provided in the deployment configuration, we need to verify the value
+                # if the value exists, we verify that it's the same as in the policy
+                if existing_value := d.get(k):
+                    if k == "init_scripts":
+                        d[k] = PolicyAdjuster._append_init_scripts(v, existing_value)
+                        continue
+                    if existing_value != v:
+                        err_msg = (
+                            f"For key {k} there is a value in the cluster definition: {existing_value} \n"
+                            f"However, this value is fixed in the policy and must be equal to: {v}."
+                        )
+                        raise ValueError(err_msg)
+                d[k] = v
+        return d
+
+    @staticmethod
+    def _traverse_policy(policy_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        The idea of this function is the following:
+        1. Walk through all items in the source policy
+        2. Take only the fixed policies
+        3. Parse the key:
+        3.0 if there are no dots, the key is a simple string
+        3.1 the key might be a composite one, with dots - then we split this key by dots into a tuple
+        3.2 a specific case is spark_conf (such keys might have multiple dots after the spark_conf prefix)
+        4. Definitions will be added into the parsed_props variable
+        5. 
Generate a Jobs API compatible dictionary with the fixed properties
+        :return: dictionary in a Jobs API compatible format
+        """
+
+        parsed_props: List[Tuple[Union[List[str], str], Any]] = []
+        for key, definition in policy_payload.items():
+            if definition.get("type") == "fixed":
+                # preprocess the key
+                # spark_conf keys might contain multiple dots, so only the first dot is used for splitting
+                if key.startswith("spark_conf"):
+                    _key = key.split(".", 1)
+                elif "." in key:
+                    _key = key.split(".")
+                else:
+                    _key = key
+                _value = definition["value"]
+                parsed_props.append((_key, _value))
+
+        result = {}
+        init_scripts = {}
+
+        for key_candidate, value in parsed_props:
+            if isinstance(key_candidate, str):
+                result[key_candidate] = value
+            else:
+                if key_candidate[0] == "init_scripts":
+                    idx = int(key_candidate[1])
+                    payload = {key_candidate[2]: {key_candidate[3]: value}}
+                    init_scripts[idx] = payload
+                else:
+                    d = {key_candidate[-1]: value}
+                    for _k in reversed(key_candidate[1:-1]):
+                        d = {_k: d}
+
+                    updatable = result.get(key_candidate[0], {})
+                    updatable.update(d)
+
+                    result[key_candidate[0]] = updatable
+
+        init_scripts = [init_scripts[k] for k in sorted(init_scripts)]
+        if init_scripts:
+            result["init_scripts"] = init_scripts
+
+        return result
diff --git a/dbx/api/build.py b/dbx/api/build.py
index ee864067..ccdfbef6 100644
--- a/dbx/api/build.py
+++ b/dbx/api/build.py
@@ -3,9 +3,7 @@ import sys
 from pathlib import Path
 from typing import Union, List, Optional
 
-from rich.console import Console
 
-from dbx.models.deployment import BuildConfiguration, PythonBuild
 from dbx.utils import dbx_echo
 
 
@@ -29,32 +27,3 @@ def execute_shell_command(
     except subprocess.CalledProcessError as exc:
         dbx_echo("\n💥Command execution failed")
         raise exc
-
-
-def prepare_build(build_config: BuildConfiguration):
-    if build_config.no_build:
-        dbx_echo("No build actions will be performed.")
-    else:
-        dbx_echo("Following the provided build logic")
-
-        if build_config.commands:
-            dbx_echo("Running the build commands")
-            for command in build_config.commands:
-                with Console().status(f"🔨Running command {command}", spinner="dots"):
-                    execute_shell_command(command)
-        elif build_config.python:
-            dbx_echo("🐍 Building a Python-based project")
-            cleanup_dist()
-
-            if build_config.python == PythonBuild.poetry:
-                build_kwargs = {"cmd": "poetry build -f wheel"}
-            elif build_config.python == PythonBuild.flit:
-                command = "-m flit build --format wheel"
-                build_kwargs = {"cmd": command, "with_python_executable": True}
-            else:
-                command = "-m pip wheel -w dist -e . 
--prefer-binary --no-deps" - build_kwargs = {"cmd": command, "with_python_executable": True} - - with Console().status("Building the package :hammer:", spinner="dots"): - execute_shell_command(**build_kwargs) - dbx_echo(":white_check_mark: Python-based project build finished") diff --git a/dbx/api/cluster.py b/dbx/api/cluster.py index c5d7404d..5eff250c 100644 --- a/dbx/api/cluster.py +++ b/dbx/api/cluster.py @@ -8,12 +8,13 @@ class ClusterController: - def __init__(self, api_client: ApiClient): + def __init__(self, api_client: ApiClient, cluster_name: Optional[str], cluster_id: Optional[str]): self._cluster_service = ClusterService(api_client) + self.cluster_id = self.preprocess_cluster_args(cluster_name, cluster_id) - def awake_cluster(self, cluster_id): - with Console().status("Preparing the all-purpose cluster", spinner="dots") as status: - self._awake_cluster(cluster_id, status) + def awake_cluster(self): + with Console().status("Preparing the all-purpose cluster to accept commands", spinner="dots") as status: + self._awake_cluster(self.cluster_id, status) def _awake_cluster(self, cluster_id, status: Status): cluster_info = self._cluster_service.get_cluster(cluster_id) diff --git a/dbx/api/config_reader.py b/dbx/api/config_reader.py index 4140fd2f..674dc108 100644 --- a/dbx/api/config_reader.py +++ b/dbx/api/config_reader.py @@ -6,6 +6,7 @@ import jinja2 import yaml +from pydantic import BaseModel import dbx.api.jinja as dbx_jinja from dbx.api._module_loader import load_module_from_source @@ -19,9 +20,9 @@ class _AbstractConfigReader(ABC): def __init__(self, path: Path): self._path = path - self.config = self._get_config() + self.config = self.get_config() - def _get_config(self) -> DeploymentConfig: + def get_config(self) -> DeploymentConfig: return self._read_file() @abstractmethod @@ -103,6 +104,11 @@ def _read_file(self) -> DeploymentConfig: raise Exception(f"Unexpected extension for Jinja reader: {self._ext}") +class BuildProperties(BaseModel): + potential_build: bool = False + no_rebuild: bool = False + + class ConfigReader: """ Entrypoint for reading the raw configurations from files. @@ -114,6 +120,11 @@ def __init__(self, path: Path, jinja_vars_file: Optional[Path] = None): self._jinja_vars_file = jinja_vars_file self._path = path self._reader = self._define_reader() + self._build_properties = BuildProperties() + + def with_build_properties(self, build_properties: BuildProperties): + self._build_properties = build_properties + return self def _define_reader(self) -> _AbstractConfigReader: if len(self._path.suffixes) > 1: @@ -146,7 +157,25 @@ def _define_reader(self) -> _AbstractConfigReader: ) def get_config(self) -> DeploymentConfig: - return self._reader.config + + if self._build_properties.potential_build: + dbx_echo("Reading the build config section first to identify build steps") + build_config = self._reader.config.build + + if self._build_properties.no_rebuild: + dbx_echo( + """[yellow bold] + Legacy [code]--no-rebuild[/code] flag has been used. 
+ Please specify build logic in the build section of the deployment file instead.[/yellow bold]""" + ) + build_config.no_build = True + return self._reader.config + + build_config.trigger_build_process() + dbx_echo("🔄 Build process finished, reloading the config to catch changes if any") + return self._reader.get_config() # reload config after build + else: + return self._reader.config def get_environment(self, environment: str) -> Optional[EnvironmentDeploymentInfo]: """ diff --git a/dbx/api/configure.py b/dbx/api/configure.py index d4160d4d..910246cc 100644 --- a/dbx/api/configure.py +++ b/dbx/api/configure.py @@ -2,7 +2,7 @@ from typing import Optional from dbx.constants import PROJECT_INFO_FILE_PATH -from dbx.models.project import EnvironmentInfo, ProjectInfo +from dbx.models.files.project import EnvironmentInfo, ProjectInfo from dbx.utils.json import JsonUtils @@ -67,6 +67,15 @@ def get_failsafe_cluster_reuse(self): _result = self._read_typed().failsafe_cluster_reuse_with_assets if self._file.exists() else False return _result + def enable_context_based_upload_for_execute(self): + _typed = self._read_typed() + _typed.context_based_upload_for_execute = True + JsonUtils.write(self._file, _typed.dict()) + + def get_context_based_upload_for_execute(self) -> bool: + _result = self._read_typed().context_based_upload_for_execute if self._file.exists() else False + return _result + class ProjectConfigurationManager: def __init__(self): @@ -92,3 +101,9 @@ def enable_failsafe_cluster_reuse(self): def get_failsafe_cluster_reuse(self) -> bool: return self._manager.get_failsafe_cluster_reuse() + + def enable_context_based_upload_for_execute(self): + self._manager.enable_context_based_upload_for_execute() + + def get_context_based_upload_for_execute(self) -> bool: + return self._manager.get_context_based_upload_for_execute() diff --git a/dbx/api/context.py b/dbx/api/context.py index a14c2b59..8a8f813a 100644 --- a/dbx/api/context.py +++ b/dbx/api/context.py @@ -3,11 +3,12 @@ from pathlib import Path from typing import Optional, List, Any +import typer from databricks_cli.sdk import ApiClient from dbx.api.client_provider import ApiV1Client from dbx.constants import LOCK_FILE_PATH -from dbx.models.context import ContextInfo +from dbx.models.files.context import ContextInfo from dbx.utils import dbx_echo from dbx.utils.json import JsonUtils @@ -69,13 +70,14 @@ def execute_command(self, command: str, verbose=True) -> Optional[str]: final_result = execution_result["results"]["resultType"] if final_result == "error": dbx_echo("Execution failed, please follow the given error") - raise RuntimeError( - "Command execution failed. 
Traceback from cluster: \n" f'{execution_result["results"]["cause"]}' - ) + _traceback = execution_result["results"]["cause"] + print(_traceback) + raise typer.Exit(1) if verbose: dbx_echo("Command successfully executed") if result_data: + dbx_echo("🔊 stdout from the execution is shown below:") print(result_data) return result_data @@ -118,6 +120,7 @@ def context_id(self): class RichExecutionContextClient: def __init__(self, v2_client: ApiClient, cluster_id: str, language: str = "python"): + self.api_client = v2_client self._client = LowLevelExecutionContextClient(v2_client, cluster_id, language) def install_package(self, package_file: str, pip_install_extras: Optional[str]): diff --git a/dbx/api/dependency/__init__.py b/dbx/api/dependency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/api/dependency/core_package.py b/dbx/api/dependency/core_package.py new file mode 100644 index 00000000..abd0da4d --- /dev/null +++ b/dbx/api/dependency/core_package.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional + +from dbx.models.workflow.common.libraries import Library +from dbx.utils import dbx_echo + + +class CorePackageManager: + def __init__(self): + self._core_package: Optional[Library] = self.prepare_core_package() + + @property + def core_package(self) -> Optional[Library]: + return self._core_package + + def prepare_core_package(self) -> Optional[Library]: + package_file = self.get_package_file() + + if package_file: + return Library(whl=f"file://{package_file}") + else: + dbx_echo( + "Package file was not found. Please check the dist folder if you expect to use package-based imports" + ) + + @staticmethod + def get_package_file() -> Optional[Path]: + dbx_echo("Locating package file") + file_locator = list(Path("dist").glob("*.whl")) + sorted_locator = sorted( + file_locator, key=os.path.getmtime + ) # get latest modified file, aka latest package version + if sorted_locator: + file_path = sorted_locator[-1] + dbx_echo(f"Package file located in: {file_path}") + return file_path + else: + dbx_echo("Package file was not found") + return None diff --git a/dbx/api/dependency/requirements.py b/dbx/api/dependency/requirements.py new file mode 100644 index 00000000..a7cf545e --- /dev/null +++ b/dbx/api/dependency/requirements.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from pathlib import Path +from typing import List + +import pkg_resources + +from dbx.models.workflow.common.libraries import Library, PythonPyPiLibrary +from dbx.utils import dbx_echo + + +class RequirementsFileProcessor: + def __init__(self, requirements_file: Path): + self._requirements_file = requirements_file + self._libraries = self.parse_requirements() + + @property + def libraries(self) -> List[Library]: + return self._libraries + + @staticmethod + def _delete_managed_libraries(packages: List[pkg_resources.Requirement]) -> List[pkg_resources.Requirement]: + output_packages = [] + for package in packages: + if package.key == "pyspark": + dbx_echo("pyspark dependency deleted from the list of libraries, because it's a managed library") + else: + output_packages.append(package) + return output_packages + + def parse_requirements(self) -> List[Library]: + with self._requirements_file.open(encoding="utf-8") as requirements_txt: + requirements_content = pkg_resources.parse_requirements(requirements_txt) + filtered_libraries = self._delete_managed_libraries(requirements_content) + libraries = 
[Library(pypi=PythonPyPiLibrary(package=str(req))) for req in filtered_libraries] + return libraries diff --git a/dbx/api/deployment.py b/dbx/api/deployment.py new file mode 100644 index 00000000..c3689c14 --- /dev/null +++ b/dbx/api/deployment.py @@ -0,0 +1,41 @@ +from databricks_cli.sdk import ApiClient + +from dbx.api.adjuster.mixins.base import ApiClientMixin +from dbx.api.services.jobs import NamedJobsService +from dbx.api.services.permissions import PermissionsService +from dbx.api.services.pipelines import NamedPipelinesService +from dbx.models.deployment import WorkflowList, AnyWorkflow +from dbx.models.workflow.common.workflow_types import WorkflowType +from dbx.utils import dbx_echo + + +class WorkflowDeploymentManager(ApiClientMixin): + def __init__(self, api_client: ApiClient, workflows: WorkflowList): + super().__init__(api_client) + self._wfs = workflows + self._deployment_data = {} + self._pipeline_service = NamedPipelinesService(api_client) + self._jobs_service = NamedJobsService(api_client) + + def _apply_permissions(self, wf: AnyWorkflow): + PermissionsService(self.api_client).apply(wf) + + def _deploy(self, wf: AnyWorkflow): + service_instance = ( + self._jobs_service if not wf.workflow_type == WorkflowType.pipeline else self._pipeline_service + ) + obj_id = service_instance.find_by_name(wf.name) + + if not obj_id: + service_instance.create(wf) + else: + service_instance.update(obj_id, wf) + + def apply(self): + dbx_echo("🤖 Applying workflow definitions via API") + + for wf in self._wfs: + self._deploy(wf) + self._apply_permissions(wf) + + dbx_echo("✅ Applying workflow definitions - done") diff --git a/dbx/api/destroyer.py b/dbx/api/destroyer.py index 609202ad..b2912302 100644 --- a/dbx/api/destroyer.py +++ b/dbx/api/destroyer.py @@ -6,14 +6,18 @@ from typing import List import mlflow -from databricks_cli.jobs.api import JobsApi from databricks_cli.sdk import ApiClient from mlflow.entities import Run +from rich.markup import escape from rich.progress import track from typer.rich_utils import _get_rich_console # noqa -from dbx.models.destroyer import DestroyerConfig, DeletionMode -from dbx.models.project import EnvironmentInfo +from dbx.api.services.jobs import NamedJobsService +from dbx.api.services.pipelines import NamedPipelinesService +from dbx.models.cli.destroyer import DestroyerConfig, DeletionMode +from dbx.models.deployment import AnyWorkflow +from dbx.models.files.project import EnvironmentInfo +from dbx.models.workflow.common.workflow_types import WorkflowType from dbx.utils import dbx_echo @@ -24,27 +28,28 @@ def erase(self): class WorkflowEraser(Eraser): - def __init__(self, api_client: ApiClient, workflows: List[str], dry_run: bool): + def __init__(self, api_client: ApiClient, workflows: List[AnyWorkflow], dry_run: bool): self._client = api_client self._workflows = workflows self._dry_run = dry_run + self._pipeline_service = NamedPipelinesService(self._client) + self._jobs_service = NamedJobsService(self._client) - def _delete_workflow(self, workflow): - dbx_echo(f"Job object {workflow} will be deleted") - api = JobsApi(self._client) - found = api._list_jobs_by_name(workflow) # noqa + def _delete_workflow(self, workflow: AnyWorkflow): + dbx_echo(f"Workflow {escape(workflow.name)} will be deleted") + service_instance = ( + self._jobs_service if not workflow.workflow_type == WorkflowType.pipeline else self._pipeline_service + ) + obj_id = service_instance.find_by_name(workflow.name) - if len(found) > 1: - raise Exception(f"More than one job with name 
{workflow} was found, please check the duplicates in the UI") - if len(found) == 0: - dbx_echo(f"Job with name {workflow} doesn't exist, no deletion is required") + if not obj_id: + dbx_echo(f"Workflow with name {escape(workflow.name)} doesn't exist, no deletion is required") else: - _job = found[0] if self._dry_run: - dbx_echo(f"Job {workflow} with definition {_job} would be deleted in case of a real run") + dbx_echo(f"Workflow {escape(workflow.name)} with definition would be deleted in case of a real run") else: - api.delete_job(_job["job_id"]) - dbx_echo(f"Job object with name {workflow} was successfully deleted ✅") + service_instance.delete(obj_id) + dbx_echo(f"Workflow object with name {escape(workflow.name)} was successfully deleted ✅") def erase(self): for w in self._workflows: diff --git a/dbx/api/execute.py b/dbx/api/execute.py index a47f5870..76bda1de 100644 --- a/dbx/api/execute.py +++ b/dbx/api/execute.py @@ -1,14 +1,16 @@ from pathlib import Path -from typing import Optional, List, Any +from typing import Optional, List, Union, Dict import mlflow from rich.console import Console +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider from dbx.api.context import RichExecutionContextClient -from dbx.models.task import Task, TaskType, PythonWheelTask +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.task_type import TaskType +from dbx.models.workflow.v2dot1.task import PythonWheelTask +from dbx.types import ExecuteTask from dbx.utils import dbx_echo -from dbx.utils.adjuster import adjust_path, walk_content -from dbx.utils.common import get_package_file from dbx.utils.file_uploader import MlflowFileUploader, ContextBasedUploader @@ -17,24 +19,27 @@ def __init__( self, client: RichExecutionContextClient, no_package: bool, + core_package: Optional[Library], upload_via_context: bool, requirements_file: Optional[Path], - task: Task, + task: ExecuteTask, pip_install_extras: Optional[str], ): + self.additional_libraries = AdditionalLibrariesProvider(no_package=no_package, core_package=core_package) self._client = client self._requirements_file = requirements_file - self._no_package = no_package self._task = task self._upload_via_context = upload_via_context self._pip_install_extras = pip_install_extras self._run = None - if not self._upload_via_context: + if self._upload_via_context: + dbx_echo("Context-based file uploader will be used") + self._file_uploader = ContextBasedUploader(self._client) + else: + dbx_echo("Mlflow-based file uploader will be used") self._run = mlflow.start_run() self._file_uploader = MlflowFileUploader(self._run.info.artifact_uri) - else: - self._file_uploader = ContextBasedUploader(self._client) def execute_entrypoint_file(self, _file: Path): dbx_echo("Starting entrypoint file execution") @@ -49,15 +54,16 @@ def execute_entrypoint(self, task: PythonWheelTask): dbx_echo("Entrypoint execution finished") def run(self): - if self._requirements_file.exists(): + if self._requirements_file: self.install_requirements_file() - if not self._no_package: + if not self.additional_libraries.no_package: self.install_package(self._pip_install_extras) if self._task.task_type == TaskType.spark_python_task: self.preprocess_task_parameters(self._task.spark_python_task.parameters) - self.execute_entrypoint_file(self._task.spark_python_task.python_file) + self.execute_entrypoint_file(self._task.spark_python_task.execute_file) + elif self._task.task_type == TaskType.python_wheel_task: if 
self._task.python_wheel_task.named_parameters: self.preprocess_task_parameters(self._task.python_wheel_task.named_parameters) @@ -69,32 +75,38 @@ def run(self): mlflow.end_run() def install_requirements_file(self): + if not self._requirements_file.exists(): + raise Exception(f"Requirements file provided, but doesn't exist at path {self._requirements_file}") + dbx_echo("Installing provided requirements") - localized_requirements_path = self._file_uploader.upload_and_provide_path(self._requirements_file, as_fuse=True) + localized_requirements_path = self._file_uploader.upload_and_provide_path( + f"file:fuse://{self._requirements_file}" + ) installation_command = f"%pip install -U -r {localized_requirements_path}" self._client.client.execute_command(installation_command, verbose=False) dbx_echo("Provided requirements installed") def install_package(self, pip_install_extras: Optional[str]): - package_file = get_package_file() - if not package_file: + if not self.additional_libraries.core_package: raise FileNotFoundError("Project package was not found. Please check that /dist directory exists.") dbx_echo("Uploading package") - driver_package_path = self._file_uploader.upload_and_provide_path(package_file, as_fuse=True) + stripped_package_path = self.additional_libraries.core_package.whl.replace("file://", "") + localized_package_path = self._file_uploader.upload_and_provide_path(f"file:fuse://{stripped_package_path}") dbx_echo(":white_check_mark: Uploading package - done") with Console().status("Installing package on the cluster 📦", spinner="dots"): - self._client.install_package(driver_package_path, pip_install_extras) + self._client.install_package(localized_package_path, pip_install_extras) dbx_echo(":white_check_mark: Installing package - done") - def preprocess_task_parameters(self, parameters: List[str]): + def preprocess_task_parameters(self, parameters: Union[List[str], Dict[str, str]]): dbx_echo(f":fast_forward: Processing task parameters: {parameters}") - def adjustment_callback(p: Any): - return adjust_path(p, self._file_uploader) - - walk_content(adjustment_callback, parameters) + Adjuster( + api_client=self._client.api_client, + additional_libraries=self.additional_libraries, + file_uploader=self._file_uploader, + ).traverse(parameters) self._client.setup_arguments(parameters) dbx_echo(":white_check_mark: Processing task parameters") diff --git a/dbx/api/launch/functions.py b/dbx/api/launch/functions.py index 57190bbf..b0942520 100644 --- a/dbx/api/launch/functions.py +++ b/dbx/api/launch/functions.py @@ -1,37 +1,26 @@ -import tempfile import time -from pathlib import Path from typing import Dict, Any, List import mlflow from databricks_cli.sdk import ApiClient, JobsService from mlflow.entities import Run -from mlflow.tracking import MlflowClient from rich.console import Console from dbx.api.configure import ProjectConfigurationManager +from dbx.api.launch.runners.base import RunData from dbx.constants import TERMINAL_RUN_LIFECYCLE_STATES from dbx.utils import dbx_echo, format_dbx_message -from dbx.utils.json import JsonUtils -def cancel_run(api_client: ApiClient, run_data: Dict[str, Any]): +def cancel_run(api_client: ApiClient, run_data: RunData): jobs_service = JobsService(api_client) - jobs_service.cancel_run(run_data["run_id"]) + jobs_service.cancel_run(run_data.run_id) wait_run(api_client, run_data) -def load_dbx_file(run_id: str, file_name: str) -> Dict[Any, Any]: - client = MlflowClient() - with tempfile.TemporaryDirectory() as tmp: - dbx_file_path = f".dbx/{file_name}" - 
client.download_artifacts(run_id, dbx_file_path, tmp) - return JsonUtils.read(Path(tmp) / dbx_file_path) - - -def wait_run(api_client: ApiClient, run_data: Dict[str, Any]) -> Dict[str, Any]: +def wait_run(api_client: ApiClient, run_data: RunData) -> Dict[str, Any]: with Console().status( - format_dbx_message(f"Tracing run with id {run_data['run_id']}"), spinner="dots" + format_dbx_message(f"Tracing run with id {run_data.run_id}"), spinner="dots" ) as console_status: while True: time.sleep(5) # runs API is eventually consistent, it's better to have a short pause for status update @@ -39,15 +28,15 @@ def wait_run(api_client: ApiClient, run_data: Dict[str, Any]) -> Dict[str, Any]: run_state = status["state"] life_cycle_state = run_state.get("life_cycle_state", None) - console_status.update(format_dbx_message(f"[Run Id: {run_data['run_id']}] run state: {run_state}")) + console_status.update(format_dbx_message(f"[Run Id: {run_data.run_id}] run state: {run_state}")) if life_cycle_state in TERMINAL_RUN_LIFECYCLE_STATES: return status -def get_run_status(api_client: ApiClient, run_data: Dict[str, Any]) -> Dict[str, Any]: +def get_run_status(api_client: ApiClient, run_data: RunData) -> Dict[str, Any]: jobs_service = JobsService(api_client) - run_status = jobs_service.get_run(run_data["run_id"]) + run_status = jobs_service.get_run(run_data.run_id) return run_status @@ -99,9 +88,9 @@ def _filter_by_tags(run: Run) -> bool: return last_run_info -def trace_run(api_client: ApiClient, run_data: Dict[str, Any]) -> [str, Dict[str, Any]]: +def trace_run(api_client: ApiClient, run_data: RunData) -> [str, Dict[str, Any]]: final_status = wait_run(api_client, run_data) - dbx_echo(f"Finished tracing run with id {run_data['run_id']}") + dbx_echo(f"Finished tracing run with id {run_data.run_id}") result_state = final_status["state"].get("result_state", None) if result_state == "SUCCESS": dbx_echo("Job run finished successfully") diff --git a/dbx/api/launch/pipeline_models.py b/dbx/api/launch/pipeline_models.py new file mode 100644 index 00000000..3208fe1e --- /dev/null +++ b/dbx/api/launch/pipeline_models.py @@ -0,0 +1,53 @@ +from enum import Enum +from typing import Optional, List + +from pydantic import BaseModel + + +class UpdateStatus(str, Enum): + ACTIVE = "ACTIVE" + TERMINATED = "TERMINATED" + + +class PipelineGlobalState(str, Enum): + IDLE = "IDLE" + RUNNING = "RUNNING" + + +class PipelineUpdateState(str, Enum): + QUEUED = "QUEUED" + CREATED = "CREATED" + WAITING_FOR_RESOURCES = "WAITING_FOR_RESOURCES" + INITIALIZING = "INITIALIZING" + RESETTING = "RESETTING" + SETTING_UP_TABLES = "SETTING_UP_TABLES" + RUNNING = "RUNNING" + STOPPING = "STOPPING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELED = "CANCELED" + + +class StartCause(str, Enum): + API_CALL = "API_CALL" + RETRY_ON_FAILURE = "RETRY_ON_FAILURE" + SERVICE_UPGRADE = "SERVICE_UPGRADE" + SCHEMA_CHANGE = "SCHEMA_CHANGE" + JOB_TASK = "JOB_TASK" + USER_ACTION = "USER_ACTION" + + +class LatestUpdate(BaseModel): + update_id: str + state: PipelineUpdateState + cause: Optional[StartCause] + + +class PipelineUpdateStatus(BaseModel): + status: Optional[UpdateStatus] + latest_update: LatestUpdate + + +class PipelineDetails(BaseModel): + state: Optional[PipelineGlobalState] + latest_updates: Optional[List[LatestUpdate]] = [] diff --git a/dbx/api/launch/processors.py b/dbx/api/launch/processors.py index a192c5b9..dc9227b7 100644 --- a/dbx/api/launch/processors.py +++ b/dbx/api/launch/processors.py @@ -1,28 +1,17 @@ -from copy import deepcopy -from typing 
import Dict, Any - from rich.console import Console -from dbx.models.job_clusters import JobClusters +from dbx.models.workflow.v2dot1.workflow import Workflow from dbx.utils import dbx_echo class ClusterReusePreprocessor: - def __init__(self, job_spec: Dict[str, Any]): - self._job_spec = deepcopy(job_spec) - self._job_clusters = JobClusters(**job_spec) - # delete the job clusters section from the spec - self._job_spec.pop("job_clusters") - - def _preprocess_task_definition(self, task: Dict[str, Any]): - task_cluster_key = task.pop("job_cluster_key") - definition = self._job_clusters.get_cluster_definition(task_cluster_key).new_cluster - task.update({"new_cluster": definition}) - - def process(self) -> Dict[str, Any]: + @classmethod + def process(cls, workflow: Workflow): with Console().status("🔍 Iterating over task definitions to find shared job cluster usages", spinner="dots"): - for task in self._job_spec.get("tasks", []): - if "job_cluster_key" in task: - self._preprocess_task_definition(task) + for task in workflow.tasks: + if task.job_cluster_key is not None: + _definition = workflow.get_job_cluster_definition(task.job_cluster_key) + task.job_cluster_key = None + task.new_cluster = _definition.new_cluster + workflow.job_clusters = None dbx_echo("✅ All shared job cluster usages were replaced with their relevant cluster definitions") - return self._job_spec diff --git a/dbx/api/launch/runners.py b/dbx/api/launch/runners.py deleted file mode 100644 index 38b1ac26..00000000 --- a/dbx/api/launch/runners.py +++ /dev/null @@ -1,191 +0,0 @@ -import json -from copy import deepcopy -from typing import Optional, Union, Tuple, Dict, Any - -from databricks_cli.sdk import ApiClient, JobsService - -from dbx.api.configure import ProjectConfigurationManager -from dbx.api.launch.functions import cancel_run, load_dbx_file, wait_run -from dbx.api.launch.processors import ClusterReusePreprocessor -from dbx.models.options import ExistingRunsOption -from dbx.models.parameters.run_now import RunNowV2d0ParamInfo, RunNowV2d1ParamInfo -from dbx.models.parameters.run_submit import RunSubmitV2d0ParamInfo, RunSubmitV2d1ParamInfo -from dbx.utils import dbx_echo -from dbx.utils.job_listing import find_job_by_name - - -class RunSubmitLauncher: - def __init__( - self, - job: str, - api_client: ApiClient, - deployment_run_id: str, - environment: str, - parameters: Optional[str] = None, - ): - self.run_id = deployment_run_id - self.job = job - self.api_client = api_client - self.environment = environment - self.failsafe_cluster_reuse = ProjectConfigurationManager().get_failsafe_cluster_reuse() - self._parameters = None if not parameters else self._process_parameters(parameters) - - def _process_parameters(self, payload: str) -> Union[RunSubmitV2d0ParamInfo, RunSubmitV2d1ParamInfo]: - _payload = json.loads(payload) - - if self.api_client.jobs_api_version == "2.1": - return RunSubmitV2d1ParamInfo(**_payload) - else: - return RunSubmitV2d0ParamInfo(**_payload) - - def launch(self) -> Tuple[Dict[Any, Any], Optional[str]]: - dbx_echo("Launching workflow via run submit API") - - env_spec = load_dbx_file(self.run_id, "deployment-result.json").get(self.environment) - - if not env_spec: - raise Exception(f"No workflow definitions found for environment {self.environment}") - - job_specs = env_spec.get("jobs") - - found_jobs = [j for j in job_specs if j["name"] == self.job] - - if not found_jobs: - raise Exception(f"Workflow definition {self.job} not found in deployment spec") - - job_spec: Dict[str, Any] = found_jobs[0] - 
job_spec.pop("name") - - service = JobsService(self.api_client) - - if self.failsafe_cluster_reuse: - if "job_clusters" in job_spec: - processor = ClusterReusePreprocessor(job_spec) - job_spec = processor.process() - - if self._parameters: - final_spec = self._add_parameters(job_spec, self._parameters) - else: - final_spec = job_spec - - run_data = service.submit_run(**final_spec) - return run_data, None - - @staticmethod - def override_v2d0_parameters(_spec: Dict[str, Any], parameters: RunSubmitV2d0ParamInfo): - expected_task_key = parameters.get_task_key() - task_section = _spec.get(expected_task_key) - if not task_section: - raise ValueError( - f""" - While overriding launch parameters the task key {expected_task_key} was not found in the - workflow specification {_spec}. - Please check that you override the task parameters correctly and - accordingly to the RunSubmit V2.0 API. - """ - ) - - expected_parameters_key = parameters.get_defined_task().get_parameters_key() - task_section[expected_parameters_key] = parameters.get_defined_task().get_parameters() - - @staticmethod - def override_v2d1_parameters(_spec: Dict[str, Any], parameters: RunSubmitV2d1ParamInfo): - tasks_in_spec = _spec.get("tasks") - if not tasks_in_spec: - raise ValueError( - f""" - While overriding launch parameters the "tasks" section was not found in the - workflow specification {_spec}. - Please check that you override the task parameters correctly and - accordingly to the RunSubmit V2.1 API. - """ - ) - for _task in parameters.tasks: - if _task.task_key not in [t.get("task_key") for t in tasks_in_spec]: - raise ValueError( - f""" - While overriding launch parameters task with key {_task.task_key} was not found in the tasks - specification {tasks_in_spec}. - Please check that you override the task parameters correctly and - accordingly to the RunSubmit V2.1 API. - """ - ) - - _task_container_spec = [t for t in tasks_in_spec if t["task_key"] == _task.task_key][0] - _task_spec = _task_container_spec.get(_task.get_task_key()) - - if not _task_spec: - raise ValueError( - f""" - While overriding launch parameters task with key {_task.task_key} was found in the tasks - specification, but task has a different type then the provided parameters: - - Provided parameters: {_task.dict(exclude_none=True)} - Task payload: {_task_spec} - - Please check that you override the task parameters correctly and - accordingly to the RunSubmit V2.1 API. 
- """ - ) - expected_parameters_key = _task.get_defined_task().get_parameters_key() - _task_spec[expected_parameters_key] = _task.get_defined_task().get_parameters() - - def _add_parameters( - self, workflow_spec: Dict[str, Any], parameters: Union[RunSubmitV2d0ParamInfo, RunSubmitV2d1ParamInfo] - ) -> Dict[str, Any]: - _spec = deepcopy(workflow_spec) - - if isinstance(parameters, RunSubmitV2d0ParamInfo): - self.override_v2d0_parameters(_spec, parameters) - else: - self.override_v2d1_parameters(_spec, parameters) - return _spec - - -class RunNowLauncher: - def __init__( - self, job: str, api_client: ApiClient, existing_runs: ExistingRunsOption, parameters: Optional[str] = None - ): - self.job = job - self.api_client = api_client - self.existing_runs: ExistingRunsOption = existing_runs - self._parameters = None if not parameters else self._process_parameters(parameters) - - def _process_parameters(self, payload: str) -> Union[RunNowV2d0ParamInfo, RunNowV2d1ParamInfo]: - _payload = json.loads(payload) - if self.api_client.jobs_api_version == "2.1": - return RunNowV2d1ParamInfo(**_payload) - else: - return RunNowV2d0ParamInfo(**_payload) - - def launch(self) -> Tuple[Dict[Any, Any], Optional[str]]: - dbx_echo("Launching job via run now API") - jobs_service = JobsService(self.api_client) - job_data = find_job_by_name(jobs_service, self.job) - - if not job_data: - raise Exception(f"Job with name {self.job} not found") - - job_id = job_data["job_id"] - - active_runs = jobs_service.list_runs(job_id, active_only=True).get("runs", []) - - for run in active_runs: - if self.existing_runs == ExistingRunsOption.pass_: - dbx_echo("Passing the existing runs status check") - elif self.existing_runs == ExistingRunsOption.wait: - dbx_echo(f'Waiting for job run with id {run["run_id"]} to be finished') - wait_run(self.api_client, run) - elif self.existing_runs == ExistingRunsOption.cancel: - dbx_echo(f'Cancelling run with id {run["run_id"]}') - cancel_run(self.api_client, run) - - if self._parameters: - dbx_echo(f"Running the workload with the provided parameters {self._parameters.dict()}") - _additional_parameters = self._parameters.dict() - else: - _additional_parameters = {} - - run_data = jobs_service.run_now(job_id=job_id, **_additional_parameters) - - return run_data, job_id diff --git a/dbx/api/launch/runners/__init__.py b/dbx/api/launch/runners/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/api/launch/runners/asset_based.py b/dbx/api/launch/runners/asset_based.py new file mode 100644 index 00000000..a42afe31 --- /dev/null +++ b/dbx/api/launch/runners/asset_based.py @@ -0,0 +1,83 @@ +import inspect +import json +from typing import Optional, Union, Tuple, Dict, Any + +from databricks_cli.sdk import ApiClient, JobsService + +from dbx.api.configure import ProjectConfigurationManager +from dbx.api.launch.processors import ClusterReusePreprocessor +from dbx.api.launch.runners.base import RunData +from dbx.api.storage.io import StorageIO +from dbx.models.deployment import EnvironmentDeploymentInfo +from dbx.models.workflow.v2dot0.parameters import AssetBasedRunPayload as V2dot0AssetBasedParametersPayload +from dbx.models.workflow.v2dot1.parameters import AssetBasedRunPayload as V2dot1AssetBasedParametersPayload +from dbx.models.workflow.v2dot1.workflow import Workflow as V2dot1Workflow +from dbx.utils import dbx_echo + + +class AssetBasedLauncher: + def __init__( + self, + workflow_name: str, + api_client: ApiClient, + deployment_run_id: str, + environment_name: str, + parameters: 
Optional[str] = None, + ): + self.run_id = deployment_run_id + self.workflow_name = workflow_name + self.api_client = api_client + self.environment_name = environment_name + self.failsafe_cluster_reuse = ProjectConfigurationManager().get_failsafe_cluster_reuse() + self._parameters = None if not parameters else self._process_parameters(parameters) + + def _process_parameters( + self, payload: str + ) -> Union[V2dot0AssetBasedParametersPayload, V2dot1AssetBasedParametersPayload]: + + if self.api_client.jobs_api_version == "2.0": + return V2dot0AssetBasedParametersPayload(**json.loads(payload)) + else: + return V2dot1AssetBasedParametersPayload.from_string(payload) + + def launch(self) -> Tuple[RunData, Optional[int]]: + dbx_echo( + f"Launching workflow in assets-based mode " + f"(via RunSubmit method, Jobs API V{self.api_client.jobs_api_version})" + ) + + service = JobsService(self.api_client) + env_spec = StorageIO.load(self.run_id, "deployment-result.json") + _config = EnvironmentDeploymentInfo.from_spec( + self.environment_name, env_spec.get(self.environment_name), reader_type="remote" + ) + + workflow = _config.payload.get_workflow(self.workflow_name) + + if self.failsafe_cluster_reuse: + if isinstance(workflow, V2dot1Workflow) and workflow.job_clusters: + ClusterReusePreprocessor.process(workflow) + + if self._parameters: + dbx_echo(f"Running the workload with the provided parameters {self._parameters.dict(exclude_none=True)}") + workflow.override_asset_based_launch_parameters(self._parameters) + + final_spec = workflow.dict(exclude_none=True, exclude_unset=True) + cleaned_spec = self._cleanup_unsupported_properties(final_spec) + run_data = service.submit_run(**cleaned_spec) + + return RunData(**run_data), None + + @staticmethod + def _cleanup_unsupported_properties(spec: Dict[str, Any]) -> Dict[str, Any]: + expected_props = inspect.getfullargspec(JobsService.submit_run).args + cleaned_args = {} + for _prop in spec: + if _prop not in expected_props: + dbx_echo( + f"[yellow bold]Property {_prop} is not supported in the assets-only launch mode." 
+ f" It will be ignored during current launch.[/yellow bold]" + ) + else: + cleaned_args[_prop] = spec[_prop] + return cleaned_args diff --git a/dbx/api/launch/runners/base.py b/dbx/api/launch/runners/base.py new file mode 100644 index 00000000..a65f5517 --- /dev/null +++ b/dbx/api/launch/runners/base.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import BaseModel + + +class RunData(BaseModel): + run_id: Optional[int] + + +class PipelineUpdateResponse(BaseModel): + update_id: str + request_id: str diff --git a/dbx/api/launch/runners/pipeline.py b/dbx/api/launch/runners/pipeline.py new file mode 100644 index 00000000..19a0e743 --- /dev/null +++ b/dbx/api/launch/runners/pipeline.py @@ -0,0 +1,57 @@ +import json +import time +from functools import partial +from typing import Optional, List, Tuple + +from databricks_cli.sdk import ApiClient +from pydantic import BaseModel +from rich.console import Console + +from dbx.api.launch.pipeline_models import PipelineDetails, PipelineGlobalState +from dbx.api.launch.runners.base import PipelineUpdateResponse +from dbx.api.services.pipelines import NamedPipelinesService + + +class PipelinesRunPayload(BaseModel): + full_refresh: Optional[bool] + refresh_selection: Optional[List[str]] = [] + full_refresh_selection: Optional[List[str]] = [] + + +class PipelineLauncher: + def __init__( + self, + workflow_name: str, + api_client: ApiClient, + parameters: Optional[str] = None, + ): + self.api_client = api_client + self.name = workflow_name + self.parameters = self._process_parameters(parameters) + + @staticmethod + def _process_parameters(payload: Optional[str]) -> Optional[PipelinesRunPayload]: + if payload: + _payload = json.loads(payload) + return PipelinesRunPayload(**_payload) + + def _stop_current_updates(self, pipeline_id: str): + msg = f"Checking if there are any running updates for the pipeline {pipeline_id} and stopping them." 
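+        # poll the pipeline state and request a stop while an update is still running;
+        # the loop below exits once the pipeline reports a non-RUNNING state (with a short pause to avoid API throttling)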
+ with Console().status(msg, spinner="dots") as status: + while True: + raw_response = self.api_client.perform_query("GET", f"/pipelines/{pipeline_id}") + current_pipeline_status = PipelineDetails(**raw_response) + if current_pipeline_status.state == PipelineGlobalState.RUNNING: + status.update("Found a running pipeline update, stopping it") + self.api_client.perform_query("POST", f"/pipelines/{pipeline_id}/stop") + time.sleep(5) + else: + status.update(f"Pipeline {pipeline_id} is stopped, starting a new update") + break + + def launch(self) -> Tuple[PipelineUpdateResponse, str]: + pipeline_id = NamedPipelinesService(self.api_client).find_by_name_strict(self.name) + self._stop_current_updates(pipeline_id) + prepared_query = partial(self.api_client.perform_query, "POST", f"/pipelines/{pipeline_id}/updates") + resp = prepared_query(data=self.parameters.dict(exclude_none=True)) if self.parameters else prepared_query() + return PipelineUpdateResponse(**resp), pipeline_id diff --git a/dbx/api/launch/runners/standard.py b/dbx/api/launch/runners/standard.py new file mode 100644 index 00000000..d60ccb49 --- /dev/null +++ b/dbx/api/launch/runners/standard.py @@ -0,0 +1,65 @@ +import json +from typing import Optional, Union, Tuple + +from databricks_cli.sdk import ApiClient, JobsService + +from dbx.api.launch.functions import wait_run, cancel_run +from dbx.api.launch.runners.base import RunData +from dbx.api.services.jobs import NamedJobsService +from dbx.models.cli.options import ExistingRunsOption +from dbx.models.workflow.v2dot0.parameters import StandardRunPayload as V2dot0StandardRunPayload +from dbx.models.workflow.v2dot1.parameters import StandardRunPayload as V2dot1StandardRunPayload +from dbx.utils import dbx_echo + + +class StandardLauncher: + def __init__( + self, + workflow_name: str, + api_client: ApiClient, + existing_runs: ExistingRunsOption, + parameters: Optional[str] = None, + ): + self.workflow_name = workflow_name + self.api_client = api_client + self.existing_runs: ExistingRunsOption = existing_runs + self._parameters = None if not parameters else self._process_parameters(parameters) + + def _process_parameters(self, payload: str) -> Union[V2dot0StandardRunPayload, V2dot1StandardRunPayload]: + _payload = json.loads(payload) + if self.api_client.jobs_api_version == "2.0": + return V2dot0StandardRunPayload(**_payload) + else: + return V2dot1StandardRunPayload(**_payload) + + def launch(self) -> Tuple[RunData, int]: + dbx_echo("Launching job via run now API") + named_service = NamedJobsService(self.api_client) + standard_service = JobsService(self.api_client) + job_id = named_service.find_by_name(self.workflow_name) + + if not job_id: + raise Exception(f"Workflow with name {self.workflow_name} not found") + + active_runs = standard_service.list_runs(job_id, active_only=True).get("runs", []) + + for run in active_runs: + _run_data = RunData(**run) + if self.existing_runs == ExistingRunsOption.wait: + dbx_echo(f"Waiting for job run with id {_run_data.run_id} to be finished") + wait_run(self.api_client, _run_data) + elif self.existing_runs == ExistingRunsOption.cancel: + dbx_echo(f"Cancelling run with id {_run_data.run_id}") + cancel_run(self.api_client, _run_data) + else: + dbx_echo("Passing the existing runs status check") + + api_request_payload = {"job_id": job_id} + + if self._parameters: + dbx_echo(f"Running the workload with the provided parameters {self._parameters.dict(exclude_none=True)}") + api_request_payload.update(self._parameters.dict(exclude_none=True)) + + run_data = 
self.api_client.perform_query("POST", "/jobs/run-now", data=api_request_payload) + + return RunData(**run_data), job_id diff --git a/dbx/api/launch/tracer.py b/dbx/api/launch/tracer.py index 13f549b9..dd9a3d82 100644 --- a/dbx/api/launch/tracer.py +++ b/dbx/api/launch/tracer.py @@ -1,14 +1,17 @@ -from typing import Dict, Any +import time from databricks_cli.sdk import ApiClient +from rich.console import Console from dbx.api.launch.functions import trace_run, cancel_run +from dbx.api.launch.pipeline_models import PipelineUpdateState, PipelineUpdateStatus +from dbx.api.launch.runners.base import RunData, PipelineUpdateResponse from dbx.utils import dbx_echo class RunTracer: @staticmethod - def start(kill_on_sigterm: bool, api_client: ApiClient, run_data: Dict[str, Any]): + def start(kill_on_sigterm: bool, api_client: ApiClient, run_data: RunData): if kill_on_sigterm: dbx_echo("Click Ctrl+C to stop the run") try: @@ -23,3 +26,30 @@ def start(kill_on_sigterm: bool, api_client: ApiClient, run_data: Dict[str, Any] dbx_status, final_run_state = trace_run(api_client, run_data) return dbx_status, final_run_state + + +class PipelineTracer: + TERMINAL_STATES = [PipelineUpdateState.COMPLETED, PipelineUpdateState.FAILED, PipelineUpdateState.CANCELED] + + @classmethod + def start( + cls, api_client: ApiClient, process_info: PipelineUpdateResponse, pipeline_id: str + ) -> PipelineUpdateState: + _path = f"/pipelines/{pipeline_id}/requests/{process_info.request_id}" + with Console().status(f"Tracing the DLT pipeline with id {pipeline_id}", spinner="dots") as display: + while True: + time.sleep(1) # to avoid API throttling + raw_response = api_client.perform_query("GET", _path) + status_response = PipelineUpdateStatus(**raw_response) + latest_update = status_response.latest_update + + msg = ( + f"Tracing the DLT pipeline with id {pipeline_id}, " + f"started with cause {latest_update.cause}, " + f"current state: {latest_update.state}" + ) + + display.update(msg) + if latest_update.state in cls.TERMINAL_STATES: + display.update(f"DLT pipeline with id {pipeline_id} finished with state {latest_update.state}") + return latest_update.state diff --git a/dbx/api/output_provider.py b/dbx/api/output_provider.py index 742506fc..c0be7979 100644 --- a/dbx/api/output_provider.py +++ b/dbx/api/output_provider.py @@ -2,7 +2,7 @@ from databricks_cli.sdk import JobsService -from dbx.models.options import IncludeOutputOption +from dbx.models.cli.options import IncludeOutputOption from dbx.utils import dbx_echo diff --git a/dbx/api/services/__init__.py b/dbx/api/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/api/services/_base.py b/dbx/api/services/_base.py new file mode 100644 index 00000000..8dcf370a --- /dev/null +++ b/dbx/api/services/_base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from typing import Optional, Any + +from dbx.api.adjuster.mixins.base import ApiClientMixin +from dbx.models.deployment import AnyWorkflow + + +class WorkflowBaseService(ApiClientMixin): + @abstractmethod + def find_by_name(self, name: str) -> Optional[int]: + """Searches for the workflow by name and returns its id or None if not found""" + + @abstractmethod + def create(self, wf: AnyWorkflow): + """Creates the workflow from a given payload""" + + @abstractmethod + def update(self, object_id: int, wf: AnyWorkflow): + """Updates the workflow by provided id""" + + @abstractmethod + def delete(self, object_id: Any): + """Deletes the workflow by provided id""" diff --git a/dbx/api/services/jobs.py 
b/dbx/api/services/jobs.py new file mode 100644 index 00000000..407874de --- /dev/null +++ b/dbx/api/services/jobs.py @@ -0,0 +1,86 @@ +from typing import List, Optional, Union + +from databricks_cli.sdk import ApiClient, JobsService +from requests import HTTPError +from rich.markup import escape + +from dbx.api.adjuster.mixins.base import ApiClientMixin +from dbx.api.services._base import WorkflowBaseService +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.workflow import Workflow as V2dot1Workflow +from dbx.utils import dbx_echo + +AnyJob = Union[V2dot0Workflow, V2dot1Workflow] + + +class JobSettingsResponse(FlexibleModel): + name: str + + +class JobResponse(FlexibleModel): + job_id: int + settings: JobSettingsResponse + + +class ListJobsResponse(FlexibleModel): + has_more: Optional[bool] = False + jobs: Optional[List[JobResponse]] = [] + + +class JobListing(ApiClientMixin): + def by_name(self, name) -> ListJobsResponse: + raw = self.api_client.perform_query(method="get", version="2.1", path="/jobs/list", data={"name": name}) + return ListJobsResponse(**raw) + + +class NamedJobsService(WorkflowBaseService): + DEFAULT_LIST_LIMIT = 25 + JOBS_API_VERSION_FOR_SEARCH = "2.1" + + def __init__(self, api_client: ApiClient): + super().__init__(api_client) + self._service = JobsService(api_client) + + def find_by_name(self, name: str) -> Optional[int]: + response = JobListing(self.api_client).by_name(name) + + if len(response.jobs) > 1: + raise Exception( + f"""There are more than one jobs with name {name}. + Please delete duplicated jobs first.""" + ) + + if not response.jobs: + return None + else: + return response.jobs[0].job_id + + def create(self, wf: AnyJob): + """ + Please note that this method adjusts the provided workflow definition + by setting the job_id field value on it + """ + dbx_echo(f"🪄 Creating new workflow with name {escape(wf.name)}") + payload = wf.dict(exclude_none=True) + try: + _response = self.api_client.perform_query("POST", "/jobs/create", data=payload) + wf.job_id = _response["job_id"] + except HTTPError as e: + dbx_echo(":boom: Failed to create job with definition:") + dbx_echo(payload) + raise e + + def update(self, object_id: int, wf: AnyJob): + dbx_echo(f"🪄 Updating existing workflow with name {escape(wf.name)} and id: {object_id}") + payload = wf.dict(exclude_none=True) + wf.job_id = object_id + try: + self._service.reset_job(object_id, payload) + except HTTPError as e: + dbx_echo(":boom: Failed to update job with definition:") + dbx_echo(payload) + raise e + + def delete(self, object_id: int): + self._service.delete_job(object_id) diff --git a/dbx/api/services/permissions.py b/dbx/api/services/permissions.py new file mode 100644 index 00000000..c4d36662 --- /dev/null +++ b/dbx/api/services/permissions.py @@ -0,0 +1,23 @@ +from rich.markup import escape + +from dbx.api.adjuster.mixins.base import ApiClientMixin +from dbx.models.deployment import AnyWorkflow +from dbx.models.workflow.common.workflow_types import WorkflowType +from dbx.utils import dbx_echo + + +class PermissionsService(ApiClientMixin): + def apply(self, wf: AnyWorkflow): + path = ( + f"/permissions/pipelines/{wf.pipeline_id}" + if wf.workflow_type == WorkflowType.pipeline + else f"/permissions/jobs/{wf.job_id}" + ) + if wf.access_control_list: + dbx_echo(f"🛂 Applying permission settings for workflow {escape(wf.name)}") + self.api_client.perform_query( + "PUT", + path, + 
data=wf.get_acl_payload(), + ) + dbx_echo(f"✅ Permission settings were successfully set for workflow {escape(wf.name)}") diff --git a/dbx/api/services/pipelines.py b/dbx/api/services/pipelines.py new file mode 100644 index 00000000..a2603ea8 --- /dev/null +++ b/dbx/api/services/pipelines.py @@ -0,0 +1,75 @@ +from typing import Optional, List + +from requests import HTTPError +from rich.markup import escape + +from dbx.api.services._base import WorkflowBaseService +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.pipeline import Pipeline +from dbx.utils import dbx_echo + + +class NamedPipelinesService(WorkflowBaseService): + def delete(self, object_id: str): + self.api_client.perform_query("DELETE", path=f"/pipelines/{object_id}") + + def find_by_name_strict(self, name: str): + _response = ListPipelinesResponse( + **self.api_client.perform_query(method="GET", path="/pipelines/", data={"filter": f"name like '{name}'"}) + ) + _found = _response.get(name, strict_exist=True) + return _found.pipeline_id + + def find_by_name(self, name: str) -> Optional[str]: + _response = ListPipelinesResponse( + **self.api_client.perform_query(method="GET", path="/pipelines/", data={"filter": f"name like '{name}'"}) + ) + _found = _response.get(name, strict_exist=False) + return _found.pipeline_id if _found else None + + def create(self, wf: Pipeline): + dbx_echo(f"🪄 Creating new DLT pipeline with name {escape(wf.name)}") + payload = wf.dict(exclude_none=True) + try: + _response = self.api_client.perform_query("POST", path="/pipelines", data=payload) + wf.pipeline_id = _response["pipeline_id"] + except HTTPError as e: + dbx_echo(":boom: Failed to create pipeline with definition:") + dbx_echo(payload) + raise e + + def update(self, object_id: int, wf: Pipeline): + dbx_echo(f"🪄 Updating existing DLT pipeline with name [yellow]{escape(wf.name)}[/yellow] and id: {object_id}") + payload = wf.dict(exclude_none=True) + try: + self.api_client.perform_query("PUT", path=f"/pipelines/{object_id}", data=payload) + wf.pipeline_id = object_id + except HTTPError as e: + dbx_echo(":boom: Failed to edit pipeline with definition:") + dbx_echo(payload) + raise e + + +class PipelineStateInfo(FlexibleModel): + pipeline_id: str + name: str + + +class ListPipelinesResponse(FlexibleModel): + statuses: List[PipelineStateInfo] = [] + + @property + def pipeline_names(self) -> List[str]: + return [p.name for p in self.statuses] + + def get(self, name: str, strict_exist: bool = True) -> Optional[PipelineStateInfo]: + _found = list(filter(lambda p: p.name == name, self.statuses)) + + if strict_exist: + assert _found, NameError(f"No pipelines with name {name} were found!") + + if not _found: + return None + else: + assert len(_found) == 1, NameError(f"More than one pipeline with name {name} was found: {_found}") + return _found[0] diff --git a/dbx/api/storage/io.py b/dbx/api/storage/io.py new file mode 100644 index 00000000..c3760304 --- /dev/null +++ b/dbx/api/storage/io.py @@ -0,0 +1,29 @@ +import json +import shutil +import tempfile +from pathlib import Path +from typing import Dict, Any + +import mlflow +from mlflow.tracking import MlflowClient + +from dbx.utils.json import JsonUtils + + +class StorageIO: + @staticmethod + def save(content: Dict[Any, Any], name: str): + temp_dir = tempfile.mkdtemp() + serialized_data = json.dumps(content, indent=4) + temp_path = Path(temp_dir, name) + temp_path.write_text(serialized_data, encoding="utf-8") + mlflow.log_artifact(str(temp_path), ".dbx") + 
shutil.rmtree(temp_dir) + + @staticmethod + def load(run_id: str, file_name: str) -> Dict[Any, Any]: + client = MlflowClient() + with tempfile.TemporaryDirectory() as tmp: + dbx_file_path = f".dbx/{file_name}" + client.download_artifacts(run_id, dbx_file_path, tmp) + return JsonUtils.read(Path(tmp) / dbx_file_path) diff --git a/dbx/api/storage/mlflow_based.py b/dbx/api/storage/mlflow_based.py index 4e9ea569..3d055311 100644 --- a/dbx/api/storage/mlflow_based.py +++ b/dbx/api/storage/mlflow_based.py @@ -1,7 +1,6 @@ import os from pathlib import PurePosixPath from typing import Optional -from urllib.parse import urlparse import mlflow from databricks_cli.sdk import WorkspaceService @@ -12,27 +11,17 @@ from dbx.api.client_provider import DatabricksClientProvider from dbx.api.configure import EnvironmentInfo from dbx.constants import DATABRICKS_MLFLOW_URI +from dbx.utils.url import strip_databricks_url class MlflowStorageConfigurationManager: DATABRICKS_HOST_ENV: str = "DATABRICKS_HOST" DATABRICKS_TOKEN_ENV: str = "DATABRICKS_TOKEN" - @staticmethod - def _strip_url(url: str) -> str: - """ - Mlflow API requires url to be stripped, e.g. - {scheme}://{netloc}/some-stuff/ shall be transformed to {scheme}://{netloc} - :param url: url to be stripped - :return: stripped url - """ - parsed = urlparse(url) - return f"{parsed.scheme}://{parsed.netloc}" - @classmethod def _setup_tracking_uri(cls): config = AuthConfigProvider.get_config() - os.environ[cls.DATABRICKS_HOST_ENV] = cls._strip_url(config.host) + os.environ[cls.DATABRICKS_HOST_ENV] = strip_databricks_url(config.host) os.environ[cls.DATABRICKS_TOKEN_ENV] = config.token mlflow.set_tracking_uri(DATABRICKS_MLFLOW_URI) diff --git a/dbx/callbacks.py b/dbx/callbacks.py index 1c0bf73d..d11bc483 100644 --- a/dbx/callbacks.py +++ b/dbx/callbacks.py @@ -6,7 +6,7 @@ import typer from dbx import __version__ -from dbx.models.parameters.execute import ExecuteWorkloadParamInfo +from dbx.models.cli.execute import ExecuteParametersPayload from dbx.utils import dbx_echo @@ -49,7 +49,7 @@ def deployment_file_callback(_, value: Optional[str]) -> Path: def version_callback(value: bool): if value: dbx_echo( - f":brick:[red]Databricks[/red] e[red]X[/red]tensions aka [red]dbx[/red], " + f":brick: [red]Databricks[/red] e[red]X[/red]tensions aka [red]dbx[/red], " f"version ~> [green]{__version__}[/green]" ) raise typer.Exit() @@ -64,12 +64,12 @@ def debug_callback(_, value): def execute_parameters_callback(_, value: str) -> Optional[str]: if value: try: - _parsed = json.loads(value) + json.loads(value) except json.JSONDecodeError as e: dbx_echo(":boom: Provided parameters payload cannot be parsed since it's not in json format") raise e - ExecuteWorkloadParamInfo(**_parsed) + ExecuteParametersPayload.from_json(value) return value diff --git a/dbx/cli.py b/dbx/cli.py index ce75cca8..99c09158 100644 --- a/dbx/cli.py +++ b/dbx/cli.py @@ -16,12 +16,12 @@ typer.rich_utils._get_help_text = _get_custom_help_text -app = typer.Typer(rich_markup_mode="markdown") +app = typer.Typer(rich_markup_mode="markdown", pretty_exceptions_show_locals=False) app.callback( name="dbx", help=""" - 🧱Databricks eXtensions aka dbx. Please find the main docs page [here](https://dbx.readthedocs.io/). + 🧱 Databricks eXtensions aka dbx. Please find the main docs page [here](https://dbx.readthedocs.io/). """, )(version_entrypoint) @@ -182,6 +182,8 @@ )(destroy) +# click app object here is used in the mkdocs. +# Don't delete it! 
def get_click_app(): return typer.main.get_command(app) diff --git a/dbx/commands/configure.py b/dbx/commands/configure.py index 1d02d535..24428e00 100644 --- a/dbx/commands/configure.py +++ b/dbx/commands/configure.py @@ -1,7 +1,7 @@ import typer from dbx.api.configure import ProjectConfigurationManager, EnvironmentInfo -from dbx.models.project import MlflowStorageProperties, StorageType +from dbx.models.files.project import MlflowStorageProperties, StorageType from dbx.options import ENVIRONMENT_OPTION, PROFILE_OPTION from dbx.utils import dbx_echo, current_folder_name @@ -45,6 +45,19 @@ def configure( This flag ignores any other flags. + Project file should exist, otherwise command will fail.""", + ), + enable_context_based_upload_for_execute: bool = typer.Option( + False, + "--enable-context-based-upload-for-execute", + is_flag=True, + help=""" + Enables failsafe behaviour for assets-based launches with definitions + that are based on shared job clusters feature. + + This flag ignores any other flags. + + Project file should exist, otherwise command will fail.""", ), ): @@ -60,6 +73,11 @@ def configure( manager.enable_failsafe_cluster_reuse() dbx_echo("✅ Enabling failsafe cluster reuse with assets") + elif enable_context_based_upload_for_execute: + dbx_echo("Enabling context-based upload for execute") + manager.enable_context_based_upload_for_execute() + dbx_echo("✅ Enabling context-based upload for execute") + else: dbx_echo(f"Configuring new environment with name {environment}") manager.create_or_update( diff --git a/dbx/commands/deploy.py b/dbx/commands/deploy.py index 287f78ec..d55388bc 100644 --- a/dbx/commands/deploy.py +++ b/dbx/commands/deploy.py @@ -1,18 +1,18 @@ import json -import shutil -import tempfile from pathlib import Path -from typing import Dict, Any, Union, Optional from typing import List +from typing import Optional import mlflow import typer -from databricks_cli.jobs.api import JobsService, JobsApi -from databricks_cli.sdk.api_client import ApiClient -from requests.exceptions import HTTPError -from dbx.api.config_reader import ConfigReader -from dbx.models.deployment import EnvironmentDeploymentInfo +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from dbx.api.config_reader import ConfigReader, BuildProperties +from dbx.api.dependency.core_package import CorePackageManager +from dbx.api.dependency.requirements import RequirementsFileProcessor +from dbx.api.deployment import WorkflowDeploymentManager +from dbx.api.storage.io import StorageIO +from dbx.models.workflow.common.workflow_types import WorkflowType from dbx.options import ( DEPLOYMENT_FILE_OPTION, ENVIRONMENT_OPTION, @@ -26,35 +26,32 @@ WORKFLOW_ARGUMENT, ) from dbx.utils import dbx_echo -from dbx.utils.adjuster import adjust_job_definitions from dbx.utils.common import ( prepare_environment, parse_multiple, get_current_branch_name, ) -from dbx.utils.dependency_manager import DependencyManager from dbx.utils.file_uploader import MlflowFileUploader -from dbx.utils.job_listing import find_job_by_name def deploy( workflow_name: str = WORKFLOW_ARGUMENT, deployment_file: Path = DEPLOYMENT_FILE_OPTION, - job: Optional[str] = typer.Option( + job_name: Optional[str] = typer.Option( None, "--job", help="This option is deprecated, please use workflow name as argument instead.", show_default=False, ), - jobs: Optional[str] = typer.Option( + job_names: Optional[str] = typer.Option( None, "--jobs", help="This option is deprecated, please use `--workflows` instead.", show_default=False ), 
- workflows: Optional[str] = typer.Option( + workflow_names: Optional[str] = typer.Option( None, "--workflows", help="Comma-separated list of workflow names to be deployed", show_default=False ), requirements_file: Optional[Path] = REQUIREMENTS_FILE_OPTION, tags: Optional[List[str]] = TAGS_OPTION, - environment: str = ENVIRONMENT_OPTION, + environment_name: str = ENVIRONMENT_OPTION, no_rebuild: bool = NO_REBUILD_OPTION, no_package: bool = NO_PACKAGE_OPTION, files_only: bool = typer.Option( @@ -70,6 +67,7 @@ def deploy( help="""When provided, will **only** upload assets (📁 referenced files, 📦 core package and workflow definition) to the artifact storage. + ⚠️ A workflow(s) won't be created or updated in the Jobs UI. @@ -96,84 +94,83 @@ def deploy( jinja_variables_file: Optional[Path] = JINJA_VARIABLES_FILE_OPTION, debug: Optional[bool] = DEBUG_OPTION, # noqa ): - dbx_echo(f"Starting new deployment for environment {environment}") + dbx_echo(f"Starting new deployment for environment {environment_name}") - api_client = prepare_environment(environment) + api_client = prepare_environment(environment_name) additional_tags = parse_multiple(tags) if not branch_name: branch_name = get_current_branch_name() config_reader = ConfigReader(deployment_file, jinja_variables_file) - config = config_reader.get_config() - - deployment = config.get_environment(environment, raise_if_not_found=True) - - if workflow_name: - job = workflow_name - - if workflows: - jobs = workflows + config = config_reader.with_build_properties( + BuildProperties(potential_build=True, no_rebuild=no_rebuild) + ).get_config() - requested_jobs = _define_deployable_jobs(job, jobs) + environment_info = config.get_environment(environment_name, raise_if_not_found=True) - _preprocess_deployment(deployment, requested_jobs) + workflow_name = workflow_name if workflow_name else job_name + workflow_names = workflow_names.split(",") if workflow_names else job_names.split(",") if job_names else [] - if no_rebuild: - dbx_echo( - """[yellow bold] - Legacy [code]--no-rebuild[/code] flag has been used. 
- Please specify build logic in the build section of the deployment file instead.[/yellow bold]""" - ) - config.build.no_build = True + deployable_workflows = environment_info.payload.select_relevant_or_all_workflows(workflow_name, workflow_names) + environment_info.payload.workflows = deployable_workflows # filter out the chosen set of workflows - dependency_manager = DependencyManager(config.build, no_package, requirements_file) + core_package = CorePackageManager().core_package + libraries_from_requirements = RequirementsFileProcessor(requirements_file).libraries if requirements_file else [] _assets_only = assets_only if assets_only else files_only - with mlflow.start_run() as deployment_run: - - artifact_base_uri = deployment_run.info.artifact_uri - _file_uploader = MlflowFileUploader(artifact_base_uri) + if _assets_only: + any_pipelines = [w for w in deployable_workflows if w.workflow_type == WorkflowType.pipeline] + if any_pipelines: + raise Exception( + f"Assets-only deployment mode is not supported for DLT pipelines: {[p.name for p in any_pipelines]}" + ) - adjust_job_definitions(deployment.payload.workflows, dependency_manager, _file_uploader, api_client) + with mlflow.start_run() as deployment_run: - if not _assets_only: - dbx_echo("Updating job definitions") - deployment_data = _create_jobs(deployment.payload.workflows, api_client) - _log_dbx_file(deployment_data, "deployments.json") + adjuster = Adjuster( + api_client=api_client, + file_uploader=MlflowFileUploader(deployment_run.info.artifact_uri), + additional_libraries=AdditionalLibrariesProvider( + no_package=no_package, + core_package=core_package, + libraries_from_requirements=libraries_from_requirements, + ), + ) - for job_spec in deployment.payload.workflows: - permissions = job_spec.get("permissions") - if permissions: - job_name = job_spec.get("name") - dbx_echo(f"Permission settings are provided for job {job_name}, setting it up") - job_id = deployment_data.get(job_spec.get("name")) - api_client.perform_query("PUT", f"/permissions/jobs/{job_id}", data=permissions) - dbx_echo(f"Permission settings were successfully set for job {job_name}") + pipelines = [p for p in deployable_workflows if p.workflow_type == WorkflowType.pipeline] + workflows = [w for w in deployable_workflows if w.workflow_type != WorkflowType.pipeline] - dbx_echo("Updating job definitions - done") + if pipelines: + dbx_echo("Found DLT pipelines definition, applying them first for proper reference resolution") + for elements in [pipelines, workflows]: + adjuster.traverse(elements) + wf_manager = WorkflowDeploymentManager(api_client, elements) + wf_manager.apply() + else: + adjuster.traverse(deployable_workflows) + if not _assets_only: + wf_manager = WorkflowDeploymentManager(api_client, deployable_workflows) + wf_manager.apply() deployment_tags = { "dbx_action_type": "deploy", - "dbx_environment": environment, + "dbx_environment": environment_name, "dbx_status": "SUCCESS", + "dbx_branch_name": branch_name, } - deployment_spec = deployment.to_spec() - deployment_tags.update(additional_tags) - - if branch_name: - deployment_tags["dbx_branch_name"] = branch_name - if _assets_only: deployment_tags["dbx_deploy_type"] = "files_only" - _log_dbx_file(deployment_spec, "deployment-result.json") + environment_spec = environment_info.to_spec() + + StorageIO.save(environment_spec, "deployment-result.json") mlflow.set_tags(deployment_tags) - dbx_echo(f":sparkles: Deployment for environment {environment} finished successfully") + dbx_echo(f":sparkles: Deployment for 
environment {environment_name} finished successfully") if write_specs_to_file: dbx_echo("Writing final job specifications into file") @@ -182,98 +179,4 @@ def deploy( if specs_file.exists(): specs_file.unlink() - specs_file.write_text(json.dumps(deployment_spec, indent=4), encoding="utf-8") - - -def _log_dbx_file(content: Dict[Any, Any], name: str): - temp_dir = tempfile.mkdtemp() - serialized_data = json.dumps(content, indent=4) - temp_path = Path(temp_dir, name) - temp_path.write_text(serialized_data, encoding="utf-8") - mlflow.log_artifact(str(temp_path), ".dbx") - shutil.rmtree(temp_dir) - - -def _define_deployable_jobs(job: str, jobs: str) -> Optional[List[str]]: - if jobs and job: - raise Exception("Both --job and --jobs cannot be provided together") - - if job: - requested_jobs = [job] - elif jobs: - requested_jobs = jobs.split(",") - else: - requested_jobs = None - - return requested_jobs - - -def _preprocess_deployment(deployment: EnvironmentDeploymentInfo, requested_jobs: Union[List[str], None]): - if not deployment.payload.workflows: - raise Exception("No jobs provided for deployment") - - deployment.payload.workflows = _preprocess_jobs(deployment.payload.workflows, requested_jobs) - - -def _preprocess_jobs(jobs: List[Dict[str, Any]], requested_jobs: Union[List[str], None]) -> List[Dict[str, Any]]: - job_names = [job["name"] for job in jobs] - if requested_jobs: - dbx_echo(f"Deployment will be performed only for the following jobs: {requested_jobs}") - for requested_job_name in requested_jobs: - if requested_job_name not in job_names: - raise Exception( - f""" - Workflow {requested_job_name} was requested, but not provided in deployment file. - Available workflows are: {job_names} - """ - ) - preprocessed_jobs = [job for job in jobs if job["name"] in requested_jobs] - else: - preprocessed_jobs = jobs - return preprocessed_jobs - - -def _create_jobs(jobs: List[Dict[str, Any]], api_client: ApiClient) -> Dict[str, int]: - deployment_data = {} - for job in jobs: - dbx_echo(f'Processing deployment for job: {job["name"]}') - jobs_service = JobsService(api_client) - matching_job = find_job_by_name(jobs_service, job["name"]) - - if not matching_job: - job_id = _create_job(api_client, job) - else: - job_id = matching_job["job_id"] - _update_job(jobs_service, job_id, job) - - deployment_data[job["name"]] = job_id - return deployment_data - - -def _create_job(api_client: ApiClient, job: Dict[str, Any]) -> str: - dbx_echo(f'Creating a new job with name {job["name"]}') - try: - jobs_api = JobsApi(api_client) - job_id = jobs_api.create_job(job)["job_id"] - except HTTPError as e: - dbx_echo(":boom: Failed to create job with definition:") - dbx_echo(job) - raise e - return job_id - - -def _update_job(jobs_service: JobsService, job_id: str, job: Dict[str, Any]) -> str: - dbx_echo(f'Updating existing job with id: {job_id} and name: {job["name"]}') - try: - jobs_service.reset_job(job_id, job) - except HTTPError as e: - dbx_echo(":boom: Failed to update job with definition:") - dbx_echo(job) - raise e - - _acl = job.get("access_control_list") - if _acl: - _client = jobs_service.client - _client.perform_query("PUT", f"/permissions/jobs/{job_id}", data={"access_control_list": _acl}) - - return job_id + specs_file.write_text(json.dumps(environment_spec, indent=4), encoding="utf-8") diff --git a/dbx/commands/destroy.py b/dbx/commands/destroy.py index 16255b5d..6aef8916 100644 --- a/dbx/commands/destroy.py +++ b/dbx/commands/destroy.py @@ -3,12 +3,13 @@ from typing import Optional import typer +from 
rich.markup import escape from rich.prompt import Prompt from typer.rich_utils import _get_rich_console # noqa from dbx.api.config_reader import ConfigReader from dbx.api.destroyer import Destroyer -from dbx.models.destroyer import DestroyerConfig, DeletionMode +from dbx.models.cli.destroyer import DestroyerConfig, DeletionMode from dbx.options import ( WORKFLOW_ARGUMENT, DEPLOYMENT_FILE_OPTION, @@ -20,12 +21,12 @@ def destroy( - workflow: Optional[str] = WORKFLOW_ARGUMENT, - workflows: Optional[str] = typer.Option( + workflow_name: Optional[str] = WORKFLOW_ARGUMENT, + workflow_names: Optional[str] = typer.Option( None, "--workflows", help="Comma-separated list of workflow names to be deleted", show_default=False ), deployment_file: Optional[Path] = DEPLOYMENT_FILE_OPTION, - environment: str = ENVIRONMENT_OPTION, + environment_name: str = ENVIRONMENT_OPTION, jinja_variables_file: Optional[Path] = JINJA_VARIABLES_FILE_OPTION, deletion_mode: DeletionMode = typer.Option( DeletionMode.all, @@ -33,10 +34,10 @@ def destroy( help="""Deletion mode. - If `assets-only`, will only delete the stored assets in the artifact storage, but won't affect job objects. + If `assets-only`, will only delete the stored assets in the artifact storage, but won't affect workflow objects. - If `workflows-only`, will only delete the defined job objects, but won't affect job objects. + If `workflows-only`, will only delete the defined workflow objects, but won't affect the artifact storage. If `all`, will delete everything.""", @@ -54,19 +55,18 @@ def destroy( False, "--dracarys", help="🔥 add more fire to the CLI output, making the deletion absolutely **epic**." ), ): - if workflow and workflows: - raise Exception(f"arguments {workflow} and {workflows} cannot be provided together") - _workflows = [workflow] if workflow else workflows.split(",") if workflows else [] + workflow_names = workflow_names.split(",") if workflow_names else [] - config_reader = ConfigReader(deployment_file, jinja_variables_file) - config = config_reader.get_config() + global_config = ConfigReader(deployment_file, jinja_variables_file).get_config() + env_config = global_config.get_environment(environment_name, raise_if_not_found=True) + relevant_workflows = env_config.payload.select_relevant_or_all_workflows(workflow_name, workflow_names) _d_config = DestroyerConfig( - workflows=_workflows, + workflows=relevant_workflows, deletion_mode=deletion_mode, dracarys=dracarys, - deployment=config.get_environment(environment, raise_if_not_found=True), + deployment=env_config, dry_run=dry_run, ) @@ -79,7 +79,7 @@ def destroy( if not confirm: ask_for_confirmation(_d_config) - api_client = prepare_environment(environment) + api_client = prepare_environment(environment_name) destroyer = Destroyer(api_client, _d_config) destroyer.launch() @@ -93,14 +93,13 @@ def ask_for_confirmation(conf: DestroyerConfig): 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 """ + wf_names = [escape(w.name) for w in conf.workflows] if conf.deletion_mode == DeletionMode.assets_only: deletion_message = "All assets will de deleted, but the workflow definitions won't be affected." elif conf.deletion_mode == DeletionMode.workflows_only: - deletion_message = ( - f"The following workflows are marked for deletion: {conf.workflows}, assets won't be affected" - ) + deletion_message = f"The following workflows are marked for deletion: {wf_names}, assets won't be affected" else: - deletion_message = f"""The following workflows are marked for deletion: {conf.workflows}. 
+ deletion_message = f"""The following workflows are marked for deletion: {wf_names}. [bold]All assets are also marked for deletion.[/bold]""" _c = _get_rich_console() diff --git a/dbx/commands/execute.py b/dbx/commands/execute.py index c916b671..24dcc7f8 100644 --- a/dbx/commands/execute.py +++ b/dbx/commands/execute.py @@ -4,12 +4,13 @@ import typer from dbx.api.cluster import ClusterController -from dbx.api.config_reader import ConfigReader +from dbx.api.config_reader import ConfigReader, BuildProperties +from dbx.api.configure import ProjectConfigurationManager from dbx.api.context import RichExecutionContextClient +from dbx.api.dependency.core_package import CorePackageManager from dbx.api.execute import ExecutionController -from dbx.models.deployment import EnvironmentDeploymentInfo -from dbx.models.parameters.execute import ExecuteWorkloadParamInfo -from dbx.models.task import Task, TaskType +from dbx.models.cli.execute import ExecuteParametersPayload +from dbx.models.workflow.common.workflow_types import WorkflowType from dbx.options import ( DEPLOYMENT_FILE_OPTION, ENVIRONMENT_OPTION, @@ -21,13 +22,13 @@ WORKFLOW_ARGUMENT, EXECUTE_PARAMETERS_OPTION, ) +from dbx.types import ExecuteTask from dbx.utils import dbx_echo from dbx.utils.common import prepare_environment -from dbx.api.build import prepare_build def execute( - workflow: str = WORKFLOW_ARGUMENT, + workflow_name: str = WORKFLOW_ARGUMENT, environment: str = ENVIRONMENT_OPTION, cluster_id: Optional[str] = typer.Option( None, "--cluster-id", help="Cluster ID. Cannot be provided together with `--cluster-name`" @@ -35,14 +36,14 @@ def execute( cluster_name: Optional[str] = typer.Option( None, "--cluster-name", help="Cluster name. Cannot be provided together with `--cluster-id`" ), - job: str = typer.Option( + job_name: str = typer.Option( None, "--job", help="This option is deprecated. Please use `workflow-name` as argument instead" ), - task: Optional[str] = typer.Option( + task_name: Optional[str] = typer.Option( None, "--task", help="""Task name (`task_key` field) inside the workflow to be executed. - Required if the workflow is a multitask job""", + Required if the workflow is a multitask job with more than one task""", ), deployment_file: Path = DEPLOYMENT_FILE_OPTION, requirements_file: Optional[Path] = REQUIREMENTS_FILE_OPTION, @@ -68,111 +69,63 @@ def execute( debug: Optional[bool] = DEBUG_OPTION, # noqa ): api_client = prepare_environment(environment) - controller = ClusterController(api_client) - cluster_id = controller.preprocess_cluster_args(cluster_name, cluster_id) + cluster_controller = ClusterController(api_client, cluster_name=cluster_name, cluster_id=cluster_id) - _job = workflow if workflow else job + workflow_name = workflow_name if workflow_name else job_name - if not _job: + if not workflow_name: raise Exception("Please provide workflow name as an argument") - dbx_echo(f"Executing job: {_job} in environment {environment} on cluster {cluster_name} (id: {cluster_id})") + dbx_echo( + f"Executing workflow: {workflow_name} in environment {environment} " + f"on cluster {cluster_name} (id: {cluster_id})" + ) config_reader = ConfigReader(deployment_file, jinja_variables_file) - config = config_reader.get_config() - deployment = config.get_environment(environment) - - if no_rebuild: - dbx_echo( - """[yellow bold] - Legacy [code]--no-rebuild[/code] flag has been used. 
- Please specify build logic in the build section of the deployment file instead.[/yellow bold]""" - ) - config.build.no_build = True - - prepare_build(config.build) - - _verify_deployment(deployment, deployment_file) - - found_jobs = [j for j in deployment.payload.workflows if j["name"] == _job] + config = config_reader.with_build_properties( + BuildProperties(potential_build=True, no_rebuild=no_rebuild) + ).get_config() - if not found_jobs: - raise RuntimeError(f"Job {_job} was not found in environment jobs, please check the deployment file") + environment_config = config.get_environment(environment, raise_if_not_found=True) - job_payload = found_jobs[0] + workflow = environment_config.payload.get_workflow(workflow_name) - if task: - _tasks = job_payload.get("tasks", []) - found_tasks = [t for t in _tasks if t.get("task_key") == task] + if workflow.workflow_type == WorkflowType.pipeline: + raise Exception("DLT pipelines are not supported in the execute mode.") - if not found_tasks: - raise Exception(f"Task {task} not found in the definition of job {_job}") + if not task_name and workflow.workflow_type == WorkflowType.job_v2d1: + if len(workflow.task_names) == 1: + dbx_echo("Task key wasn't provided, automatically picking it since there is only one task in the workflow") + task_name = workflow.task_names[0] + else: + raise ValueError("Task key is not provided and there is more than one task in the workflow.") - if len(found_tasks) > 1: - raise Exception(f"Task keys are not unique, more then one task found for job {_job} with task name {task}") + task: ExecuteTask = workflow.get_task(task_name) if task_name else workflow - _task = found_tasks[0] - - _payload = _task - else: - if "tasks" in job_payload: - raise Exception( - "You're trying to execute a multitask job without passing the task name. 
" - "Please provide the task name via --task parameter" - ) - _payload = job_payload - - task = Task(**_payload) + task.check_if_supported_in_execute() + core_package = CorePackageManager().core_package if not no_package else None if parameters: - override_parameters(parameters, task) + task.override_execute_parameters(ExecuteParametersPayload.from_json(parameters)) - dbx_echo("Preparing interactive cluster to accept jobs") - controller.awake_cluster(cluster_id) + cluster_controller.awake_cluster() - context_client = RichExecutionContextClient(api_client, cluster_id) + context_client = RichExecutionContextClient(api_client, cluster_controller.cluster_id) + + upload_via_context = ( + upload_via_context + if upload_via_context + else ProjectConfigurationManager().get_context_based_upload_for_execute() + ) - controller_instance = ExecutionController( + execution_controller = ExecutionController( client=context_client, no_package=no_package, + core_package=core_package, requirements_file=requirements_file, task=task, upload_via_context=upload_via_context, pip_install_extras=pip_install_extras, ) - controller_instance.run() - - -def _verify_deployment(deployment: EnvironmentDeploymentInfo, deployment_file): - if not deployment: - raise NameError( - f"Environment {deployment.name} is not provided in deployment file {deployment_file}" - + " please add this environment first" - ) - env_jobs = deployment.payload.workflows - if not env_jobs: - raise RuntimeError(f"No jobs section found in environment {deployment.name}, please check the deployment file") - - -def override_parameters(raw_params_info: str, task: Task): - param_info = ExecuteWorkloadParamInfo.from_string(raw_params_info) - if param_info.named_parameters is not None and task.task_type != TaskType.python_wheel_task: - raise Exception(f"named parameters are only supported if task type is {TaskType.python_wheel_task.value}") - - if param_info.named_parameters: - dbx_echo(":twisted_rightwards_arrows:Overriding named_parameters section for the task") - task.python_wheel_task.named_parameters = param_info.named_parameters - task.python_wheel_task.parameters = [] - dbx_echo(":white_check_mark:Overriding named_parameters section for the task") - - if param_info.parameters: - dbx_echo(":twisted_rightwards_arrows:Overriding parameters section for the task") - - if task.task_type == TaskType.python_wheel_task: - task.python_wheel_task.parameters = param_info.parameters - task.python_wheel_task.named_parameters = [] - elif task.task_type == TaskType.spark_python_task: - task.spark_python_task.parameters = param_info.parameters - - dbx_echo(":white_check_mark:Overriding parameters section for the task") + execution_controller.run() diff --git a/dbx/commands/launch.py b/dbx/commands/launch.py index 4b7e7189..64746774 100644 --- a/dbx/commands/launch.py +++ b/dbx/commands/launch.py @@ -1,15 +1,20 @@ -from typing import List +from typing import List, Dict, Any from typing import Optional import mlflow import typer from databricks_cli.jobs.api import JobsService +from rich.markup import escape from dbx.api.launch.functions import find_deployment_run -from dbx.api.launch.runners import RunSubmitLauncher, RunNowLauncher -from dbx.api.launch.tracer import RunTracer +from dbx.api.launch.pipeline_models import PipelineUpdateState +from dbx.api.launch.runners.asset_based import AssetBasedLauncher +from dbx.api.launch.runners.base import RunData +from dbx.api.launch.runners.pipeline import PipelineLauncher +from dbx.api.launch.runners.standard import 
StandardLauncher +from dbx.api.launch.tracer import RunTracer, PipelineTracer from dbx.api.output_provider import OutputProvider -from dbx.models.options import ExistingRunsOption, IncludeOutputOption +from dbx.models.cli.options import ExistingRunsOption, IncludeOutputOption from dbx.options import ( ENVIRONMENT_OPTION, TAGS_OPTION, @@ -28,14 +33,21 @@ def launch( - workflow: str = WORKFLOW_ARGUMENT, - environment: str = ENVIRONMENT_OPTION, - job: str = typer.Option( + workflow_name: str = WORKFLOW_ARGUMENT, + environment_name: str = ENVIRONMENT_OPTION, + job_name: str = typer.Option( None, "--job", help="This option is deprecated, please use workflow name as argument instead.", show_default=False, ), + is_pipeline: bool = typer.Option( + False, + "--pipeline", + "-p", + is_flag=True, + help="Search for the workflow in the DLT pipelines instead of standard job objects.", + ), trace: bool = typer.Option(False, "--trace", help="Trace the workload until it finishes.", is_flag=True), kill_on_sigterm: bool = typer.Option( False, @@ -100,75 +112,113 @@ def launch( parameters: Optional[str] = LAUNCH_PARAMETERS_OPTION, debug: Optional[bool] = DEBUG_OPTION, # noqa ): - _job = workflow if workflow else job + workflow_name = workflow_name if workflow_name else job_name - if not _job: + if not workflow_name: raise Exception("Please provide workflow name as an argument") - dbx_echo(f"Launching job {_job} on environment {environment}") + if is_pipeline and from_assets: + raise Exception("DLT pipelines cannot be launched in the asset-based mode") - api_client = prepare_environment(environment) + dbx_echo(f"Launching workflow {escape(workflow_name)} on environment {environment_name}") + + api_client = prepare_environment(environment_name) additional_tags = parse_multiple(tags) if not branch_name: branch_name = get_current_branch_name() - filter_string = generate_filter_string(environment, branch_name) + filter_string = generate_filter_string(environment_name, branch_name) _from_assets = from_assets if from_assets else as_run_submit - last_deployment_run = find_deployment_run(filter_string, additional_tags, _from_assets, environment) + last_deployment_run = find_deployment_run(filter_string, additional_tags, _from_assets, environment_name) with mlflow.start_run(run_id=last_deployment_run.info.run_id): with mlflow.start_run(nested=True): - if not _from_assets: - run_launcher = RunNowLauncher( - job=_job, api_client=api_client, existing_runs=existing_runs, parameters=parameters - ) + if is_pipeline: + launcher = PipelineLauncher(workflow_name=workflow_name, api_client=api_client, parameters=parameters) else: - run_launcher = RunSubmitLauncher( - job=_job, - api_client=api_client, - deployment_run_id=last_deployment_run.info.run_id, - environment=environment, - parameters=parameters, - ) - - run_data, job_id = run_launcher.launch() + if not _from_assets: + launcher = StandardLauncher( + workflow_name=workflow_name, + api_client=api_client, + existing_runs=existing_runs, + parameters=parameters, + ) + else: + launcher = AssetBasedLauncher( + workflow_name=workflow_name, + api_client=api_client, + deployment_run_id=last_deployment_run.info.run_id, + environment_name=environment_name, + parameters=parameters, + ) - jobs_service = JobsService(api_client) - run_info = jobs_service.get_run(run_data["run_id"]) - run_url = run_info.get("run_page_url") - dbx_echo(f"Run URL: {run_url}") - if trace: - dbx_status, final_run_state = RunTracer.start(kill_on_sigterm, api_client, run_data) - if include_output: - 
log_provider = OutputProvider(jobs_service, final_run_state) - dbx_echo(f"Run output provisioning requested with level {include_output.value}") - log_provider.provide(include_output) + process_info, object_id = launcher.launch() - if dbx_status == "ERROR": - raise Exception( - "Tracked run failed during execution. " - "Please check the status and logs of the run for details." - ) + if isinstance(process_info, RunData): + jobs_service = JobsService(api_client) + run_info = jobs_service.get_run(process_info.run_id) + run_url = run_info.get("run_page_url") + dbx_echo(f"Run URL: {run_url}") + else: + dbx_echo("DLT pipeline launched successfully") + + if trace: + if isinstance(process_info, RunData): + status = trace_workflow_object(api_client, process_info, include_output, kill_on_sigterm) + additional_tags = { + "job_id": object_id, + "run_id": process_info.run_id, + } else: - dbx_status = "NOT_TRACKED" - dbx_echo( - "Run successfully launched in non-tracking mode :rocket:. " - "Please check Databricks UI for job status :eyes:" + final_state = PipelineTracer.start( + api_client=api_client, process_info=process_info, pipeline_id=object_id ) - - deployment_tags = { - "job_id": job_id, - "run_id": run_data.get("run_id"), - "dbx_action_type": "launch", - "dbx_status": dbx_status, - "dbx_environment": environment, - } - - if branch_name: - deployment_tags["dbx_branch_name"] = branch_name - - mlflow.set_tags(deployment_tags) + if final_state == PipelineUpdateState.FAILED: + raise Exception( + f"Tracked pipeline {object_id} failed during execution, please check the UI for details." + ) + status = final_state + additional_tags = {"pipeline_id": object_id} + else: + status = "NOT_TRACKED" + dbx_echo( + "Workflow successfully launched in the non-tracking mode 🚀. " + "Please check Databricks UI for job status 👀" + ) + log_launch_info(additional_tags, status, environment_name, branch_name) + + +def trace_workflow_object( + api_client, + run_data: RunData, + include_output, + kill_on_sigterm, +): + dbx_status, final_run_state = RunTracer.start(kill_on_sigterm, api_client, run_data) + js = JobsService(api_client) + if include_output: + log_provider = OutputProvider(js, final_run_state) + dbx_echo(f"Run output provisioning requested with level {include_output.value}") + log_provider.provide(include_output) + + if dbx_status == "ERROR": + raise Exception("Tracked run failed during execution. Please check the status and logs of the run for details.") + return dbx_status + + +def log_launch_info(additional_tags: Dict[str, Any], dbx_status, environment_name, branch_name): + deployment_tags = { + "dbx_action_type": "launch", + "dbx_status": dbx_status, + "dbx_environment": environment_name, + } + + if branch_name: + deployment_tags["dbx_branch_name"] = branch_name + + deployment_tags.update(additional_tags) + mlflow.set_tags(deployment_tags) diff --git a/dbx/commands/sync/options.py b/dbx/commands/sync/options.py index a44312fa..8a20b2f8 100644 --- a/dbx/commands/sync/options.py +++ b/dbx/commands/sync/options.py @@ -139,7 +139,7 @@ that are not present locally with the current filters. So for the example above, this would remove `foo` in the destination when syncing with`-i bar`. - * `---unmatched-behaviour=allow-delete-unmatched=disallow-delete-unmatched` will NOT delete files/directories + * `--unmatched-behaviour=allow-delete-unmatched=disallow-delete-unmatched` will NOT delete files/directories in the destination that are not present locally with the current filters. 
So for the example above, this would leave `foo` in the destination when syncing with`-i bar`.""", ) diff --git a/dbx/commands/sync/sync.py b/dbx/commands/sync/sync.py index 5b600f06..ad154d21 100644 --- a/dbx/commands/sync/sync.py +++ b/dbx/commands/sync/sync.py @@ -1,5 +1,7 @@ +import asyncio from typing import List, Optional +import aiohttp import click import typer from databricks_cli.configure.provider import ProfileConfigProvider @@ -231,6 +233,12 @@ def dbfs( ) +async def repo_exists(client: ReposClient) -> bool: + connector = aiohttp.TCPConnector(limit=1) + async with aiohttp.ClientSession(connector=connector, trust_env=True) as session: + return await client.exists(session=session) + + @sync_app.command( short_help=""" 🔀 Syncs from a source directory to a Databricks Repo @@ -315,6 +323,14 @@ def repo( client = ReposClient(user=user_name, repo_name=dest_repo, config=config) + if not asyncio.run(repo_exists(client)): + raise click.UsageError( + f"Destination repo {dest_repo} does not exist. " + "Please create the repo using the Databricks UI and try again. You can create an empty repo by " + "clicking 'Add Repo', unchecking the 'Create repo by cloning a Git repository' option, and providing " + f"{dest_repo} as the repository name." + ) + main_loop( source=source, matcher=matcher, diff --git a/dbx/constants.py b/dbx/constants.py index e6f8e82b..ec2092cc 100644 --- a/dbx/constants.py +++ b/dbx/constants.py @@ -2,6 +2,8 @@ import pkg_resources +from dbx.models.workflow.common.task_type import TaskType + DBX_PATH = Path(".dbx") PROJECT_INFO_FILE_PATH = DBX_PATH / "project.json" LOCK_FILE_PATH = DBX_PATH / "lock.json" @@ -17,3 +19,4 @@ # would want to sync these, so we don't make this configurable. DBX_SYNC_DEFAULT_IGNORES = [".git/", ".dbx", "*.isorted"] TERMINAL_RUN_LIFECYCLE_STATES = ["TERMINATED", "SKIPPED", "INTERNAL_ERROR"] +TASKS_SUPPORTED_IN_EXECUTE = [TaskType.spark_python_task, TaskType.python_wheel_task] diff --git a/dbx/models/build.py b/dbx/models/build.py new file mode 100644 index 00000000..afb2cb16 --- /dev/null +++ b/dbx/models/build.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from enum import Enum +from typing import Optional, List + +from pydantic import BaseModel +from rich.console import Console + +from dbx.api.build import execute_shell_command, cleanup_dist +from dbx.utils import dbx_echo + + +class PythonBuild(str, Enum): + pip = "pip" + poetry = "poetry" + flit = "flit" + + +class BuildConfiguration(BaseModel): + no_build: Optional[bool] = False + commands: Optional[List[str]] + python: Optional[PythonBuild] = PythonBuild.pip + + def _build_process(self): + if self.commands: + dbx_echo("Running the build commands") + for command in self.commands: + with Console().status(f"🔨Running command {command}", spinner="dots"): + execute_shell_command(command) + elif self.python: + dbx_echo("🐍 Building a Python-based project") + cleanup_dist() + + if self.python == PythonBuild.poetry: + build_kwargs = {"cmd": "poetry build -f wheel"} + elif self.python == PythonBuild.flit: + command = "-m flit build --format wheel" + build_kwargs = {"cmd": command, "with_python_executable": True} + else: + command = "-m pip wheel -w dist -e . 
--prefer-binary --no-deps" + build_kwargs = {"cmd": command, "with_python_executable": True} + + with Console().status("Building the package :hammer:", spinner="dots"): + execute_shell_command(**build_kwargs) + dbx_echo(":white_check_mark: Python-based project build finished") + else: + dbx_echo("Neither commands nor python building configuration was provided, skipping the build stage") + + def trigger_build_process(self): + if self.no_build: + dbx_echo("No build actions will be performed.") + else: + dbx_echo("Following the provided build logic") + self._build_process() diff --git a/dbx/models/cli/__init__.py b/dbx/models/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/cli/destroyer.py b/dbx/models/cli/destroyer.py new file mode 100644 index 00000000..fdcfe457 --- /dev/null +++ b/dbx/models/cli/destroyer.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Optional, List + +from pydantic import BaseModel + +from dbx.models.deployment import EnvironmentDeploymentInfo, AnyWorkflow + + +class DeletionMode(str, Enum): + all = "all" + assets_only = "assets-only" + workflows_only = "workflows-only" + + +class DestroyerConfig(BaseModel): + workflows: Optional[List[AnyWorkflow]] = [] + deletion_mode: DeletionMode + dry_run: Optional[bool] = False + dracarys: Optional[bool] = False + deployment: EnvironmentDeploymentInfo diff --git a/dbx/models/cli/execute.py b/dbx/models/cli/execute.py new file mode 100644 index 00000000..0b48acd8 --- /dev/null +++ b/dbx/models/cli/execute.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import json + +from pydantic import root_validator + +from dbx.models.validators import mutually_exclusive, at_least_one_of +from dbx.models.workflow.common.parameters import ParametersMixin, NamedParametersMixin + + +class ExecuteParametersPayload(ParametersMixin, NamedParametersMixin): + """Parameters for execute""" + + @root_validator(pre=True) + def _validate(cls, values): # noqa + at_least_one_of(["parameters", "named_parameters"], values) + mutually_exclusive(["parameters", "named_parameters"], values) + return values + + @staticmethod + def from_json(raw: str) -> ExecuteParametersPayload: + return ExecuteParametersPayload(**json.loads(raw)) diff --git a/dbx/models/options.py b/dbx/models/cli/options.py similarity index 100% rename from dbx/models/options.py rename to dbx/models/cli/options.py diff --git a/dbx/models/deployment.py b/dbx/models/deployment.py index b392f5e0..75861e5f 100644 --- a/dbx/models/deployment.py +++ b/dbx/models/deployment.py @@ -1,45 +1,93 @@ from __future__ import annotations -from copy import deepcopy -from enum import Enum -from typing import Optional, Dict, Any, List +import collections +from typing import Optional, Dict, Any, List, Union -from pydantic import BaseModel, root_validator, validator +from pydantic import BaseModel, validator, Field +from rich.markup import escape +from typing_extensions import Annotated from dbx.api.configure import ProjectConfigurationManager -from dbx.models.project import EnvironmentInfo +from dbx.models.build import BuildConfiguration +from dbx.models.files.project import EnvironmentInfo +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.pipeline import Pipeline +from dbx.models.workflow.common.workflow_types import WorkflowType +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.workflow import Workflow as V2dot1Workflow from dbx.utils import dbx_echo 
+AnyWorkflow = Annotated[Union[V2dot0Workflow, V2dot1Workflow, Pipeline], Field(discriminator="workflow_type")] +WorkflowList = List[AnyWorkflow] + + +class WorkflowListMixin(BaseModel): + workflows: Optional[WorkflowList] + + @property + def workflow_names(self) -> List[str]: + return [w.name for w in self.workflows] + + @validator("workflows") + def _validate_unique(cls, workflows: Optional[WorkflowList]): # noqa + if workflows: + _duplicates = [ + name for name, count in collections.Counter([w.name for w in workflows]).items() if count > 1 + ] + if _duplicates: + raise ValueError(f"Duplicated workflow names: {_duplicates}") + return workflows + else: + return [] + + def get_workflow(self, name) -> AnyWorkflow: + _found = list(filter(lambda w: w.name == name, self.workflows)) + if not _found: + raise ValueError(f"Workflow {name} not found. Available workflows are {self.workflow_names}") + return _found[0] -class Deployment(BaseModel): - workflows: Optional[List[Dict[str, Any]]] - @root_validator(pre=True) - def check_inputs(cls, values: Dict[str, Any]): # noqa - if "jobs" in values: +class Deployment(FlexibleModel, WorkflowListMixin): + @staticmethod + def from_spec_local(raw_spec: Dict[str, Any]) -> Deployment: + if "jobs" in raw_spec: dbx_echo( "[yellow bold]Usage of jobs keyword in deployment file is deprecated. " "Please use [bold]workflows[bold] instead (simply rename this section to workflows).[/yellow bold]" ) - _w = values.get("jobs") if "jobs" in values else values.get("workflows") - return {"workflows": _w} - - -class PythonBuild(str, Enum): - pip = "pip" - poetry = "poetry" - flit = "flit" + return Deployment.from_spec_remote(raw_spec) - -class BuildConfiguration(BaseModel): - no_build: Optional[bool] = False - commands: Optional[List[str]] = [] - python: Optional[PythonBuild] - - @root_validator(pre=True) - def init_default(cls, values): # noqa - _v = values if values else {"python": "pip"} - return _v + @staticmethod + def from_spec_remote(raw_spec: Dict[str, Any]) -> Deployment: + _wfs = raw_spec.get("jobs") if "jobs" in raw_spec else raw_spec.get("workflows") + assert isinstance(_wfs, list), ValueError(f"Provided payload is not a list {_wfs}") + + for workflow_def in _wfs: + if not workflow_def.get("workflow_type"): + workflow_def["workflow_type"] = ( + WorkflowType.job_v2d1 if "tasks" in workflow_def else WorkflowType.job_v2d0 + ) + return Deployment(**{"workflows": _wfs}) + + def select_relevant_or_all_workflows( + self, workflow_name: Optional[str] = None, workflow_names: Optional[List[str]] = None + ) -> WorkflowList: + + if workflow_name and workflow_names: + raise Exception("Workflow argument and --workflows (or --job and --jobs) cannot be provided together") + + if workflow_name: + dbx_echo(f"The workflow {escape(workflow_name)} was selected for further operations") + return [self.get_workflow(workflow_name)] + elif workflow_names: + dbx_echo(f"Workflows {[escape(w) for w in workflow_names]} were selected for further operations") + return [self.get_workflow(w) for w in workflow_names] + else: + dbx_echo( + f"All available workflows were selected for further operations: " + f"{[escape(w) for w in self.workflow_names]}" + ) + return self.workflows class EnvironmentDeploymentInfo(BaseModel): @@ -47,9 +95,21 @@ class EnvironmentDeploymentInfo(BaseModel): payload: Deployment def to_spec(self) -> Dict[str, Any]: - _spec = {self.name: {"jobs": self.payload.workflows}} + _spec = {self.name: self.payload.dict(exclude_none=True)} return _spec + @staticmethod + def from_spec( 
+ environment_name: str, raw_spec: Dict[str, Any], reader_type: Optional[str] = "local" + ) -> EnvironmentDeploymentInfo: + deployment_reader = Deployment.from_spec_local if reader_type == "local" else Deployment.from_spec_remote + if not raw_spec: + raise ValueError(f"Deployment result for {environment_name} doesn't contain any workflow definitions") + + _spec = {"name": environment_name, "payload": deployment_reader(raw_spec)} + + return EnvironmentDeploymentInfo(**_spec) + def get_project_info(self) -> EnvironmentInfo: """ Convenience method for cases when the project information about specific environment is required. @@ -59,12 +119,17 @@ def get_project_info(self) -> EnvironmentInfo: class DeploymentConfig(BaseModel): environments: List[EnvironmentDeploymentInfo] - build: Optional[BuildConfiguration] + build: Optional[BuildConfiguration] = BuildConfiguration() - @validator("build", pre=True) - def default_build(cls, value): # noqa - build_spec = value if value else {"python": "pip"} - return build_spec + @staticmethod + def _prepare_build(payload: Dict[str, Any]) -> BuildConfiguration: + _build_payload = payload.get("build", {}) + if not _build_payload: + dbx_echo( + "No build logic defined in the deployment file. " + "Default [code]pip[/code]-based build logic will be used." + ) + return BuildConfiguration(**_build_payload) def get_environment(self, name, raise_if_not_found: Optional[bool] = False) -> Optional[EnvironmentDeploymentInfo]: _found = [env for env in self.environments if env.name == name] @@ -82,20 +147,10 @@ def get_environment(self, name, raise_if_not_found: Optional[bool] = False) -> O return _found[0] - @staticmethod - def prepare_build(payload: Dict[str, Any]) -> BuildConfiguration: - _build_payload = payload.get("build", {}) - if not _build_payload: - dbx_echo( - "No build logic defined in the deployment file. " - "Default [code]pip[/code]-based build logic will be used." - ) - return BuildConfiguration(**_build_payload) - @classmethod def from_legacy_json_payload(cls, payload: Dict[str, Any]) -> DeploymentConfig: - _build = cls.prepare_build(payload) + _build = cls._prepare_build(payload) _envs = [] for name, _env_payload in payload.items(): @@ -105,17 +160,14 @@ def from_legacy_json_payload(cls, payload: Dict[str, Any]) -> DeploymentConfig: This behaviour is not supported since dbx v0.7.0. 
Please nest all environment configurations under "environments" key in the deployment file.""" ) - _env = EnvironmentDeploymentInfo(name=name, payload=_env_payload) + _env = EnvironmentDeploymentInfo.from_spec(name, _env_payload) _envs.append(_env) return DeploymentConfig(environments=_envs, build=_build) @classmethod def from_payload(cls, payload: Dict[str, Any]) -> DeploymentConfig: - _payload = deepcopy(payload) - _envs = [ - EnvironmentDeploymentInfo(name=name, payload=env_payload) - for name, env_payload in _payload.get("environments", {}).items() - ] - _build = cls.prepare_build(_payload) + _env_payloads = payload.get("environments", {}) + _envs = [EnvironmentDeploymentInfo.from_spec(name, env_payload) for name, env_payload in _env_payloads.items()] + _build = cls._prepare_build(payload) return DeploymentConfig(environments=_envs, build=_build) diff --git a/dbx/models/destroyer.py b/dbx/models/destroyer.py deleted file mode 100644 index f4545d4d..00000000 --- a/dbx/models/destroyer.py +++ /dev/null @@ -1,32 +0,0 @@ -from enum import Enum -from typing import Optional, List - -from pydantic import BaseModel, root_validator - -from dbx.models.deployment import EnvironmentDeploymentInfo - - -class DeletionMode(str, Enum): - all = "all" - assets_only = "assets-only" - workflows_only = "workflows-only" - - -class DestroyerConfig(BaseModel): - workflows: Optional[List[str]] - deletion_mode: DeletionMode - dry_run: Optional[bool] = False - dracarys: Optional[bool] = False - deployment: EnvironmentDeploymentInfo - - @root_validator() - def validate_all(cls, values): # noqa - _dc = values["deployment"] - if not values["workflows"]: - values["workflows"] = [w["name"] for w in _dc.payload.workflows] - else: - _ws_names = [w["name"] for w in _dc.payload.workflows] - for w in values["workflows"]: - if w not in _ws_names: - raise ValueError(f"Workflow name {w} not found in {_ws_names}") - return values diff --git a/dbx/models/files/__init__.py b/dbx/models/files/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/context.py b/dbx/models/files/context.py similarity index 100% rename from dbx/models/context.py rename to dbx/models/files/context.py diff --git a/dbx/models/project.py b/dbx/models/files/project.py similarity index 96% rename from dbx/models/project.py rename to dbx/models/files/project.py index 03e6bde2..296e85ab 100644 --- a/dbx/models/project.py +++ b/dbx/models/files/project.py @@ -44,6 +44,7 @@ class ProjectInfo(BaseModel): environments: Dict[str, Union[EnvironmentInfo, LegacyEnvironmentInfo]] inplace_jinja_support: Optional[bool] = False failsafe_cluster_reuse_with_assets: Optional[bool] = False + context_based_upload_for_execute: Optional[bool] = False def get_environment(self, name: str) -> EnvironmentInfo: _env = self.environments.get(name) diff --git a/dbx/models/job_clusters.py b/dbx/models/job_clusters.py deleted file mode 100644 index 202215d9..00000000 --- a/dbx/models/job_clusters.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Dict, Any, List - -from pydantic import BaseModel, root_validator - - -class JobCluster(BaseModel): - job_cluster_key: str - new_cluster: Dict[str, Any] - - -class JobClusters(BaseModel): - job_clusters: List[JobCluster] = [] - - @root_validator(pre=True) - def validator(cls, values: Dict[str, Any]): # noqa - if values: - job_clusters = values.get("job_clusters", []) - - # checks that structure is provided in expected format - assert isinstance(job_clusters, list), f"Job clusters payload should be a list, 
provided: {job_clusters}" - - cluster_keys = [JobCluster(**v).job_cluster_key for v in job_clusters] - - # checks that there are no duplicates - for key in cluster_keys: - if cluster_keys.count(key) > 1: - raise ValueError(f"Duplicated cluster key {key} found in the job_clusters section") - return values - - def get_cluster_definition(self, key: str) -> JobCluster: - _found = list(filter(lambda jc: jc.job_cluster_key == key, self.job_clusters)) - if not _found: - raise ValueError(f"Cluster key {key} is not provided in the job_clusters section: {self.job_clusters}") - return _found[0] diff --git a/dbx/models/parameters/common.py b/dbx/models/parameters/common.py deleted file mode 100644 index 2c569650..00000000 --- a/dbx/models/parameters/common.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Dict, List, Any - -ParamPair = Optional[Dict[str, str]] -StringArray = Optional[List[str]] - - -def validate_contains(fields: Dict[str, Any], values: Dict[str, Any]): - _matching_fields = [f for f in fields if f in values] - if not _matching_fields: - raise ValueError(f"Provided payload {values} doesn't contain any of the supported fields: {fields}") - return values - - -def validate_unique(fields: Dict[str, Any], values: Dict[str, Any]): - _matching_fields = [f for f in fields if f in values] - if len(_matching_fields) > 1: - raise ValueError(f"Provided payload {values} contains more than one definition") - - return values diff --git a/dbx/models/parameters/execute.py b/dbx/models/parameters/execute.py deleted file mode 100644 index 5424c9bd..00000000 --- a/dbx/models/parameters/execute.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import json -from typing import Optional, List, Dict, Any - -from pydantic import BaseModel, root_validator, validator - -from dbx.models.parameters.common import validate_contains, validate_unique -from dbx.models.task import validate_named_parameters - - -class ExecuteWorkloadParamInfo(BaseModel): - parameters: Optional[List[str]] # for spark_python_task, python_wheel_task - named_parameters: Optional[List[str]] # only for python_wheel_task - - @root_validator(pre=True) - def initialize(cls, values: Dict[str, Any]): # noqa - validate_contains(cls.__fields__, values) - validate_unique(cls.__fields__, values) - return values - - @staticmethod - def from_string(payload: str) -> ExecuteWorkloadParamInfo: - return ExecuteWorkloadParamInfo(**json.loads(payload)) - - @validator("named_parameters", pre=True) - def _validate_named_parameters(cls, values: List[str]): # noqa - validate_named_parameters(values) - return values diff --git a/dbx/models/parameters/run_now.py b/dbx/models/parameters/run_now.py deleted file mode 100644 index 32b2345a..00000000 --- a/dbx/models/parameters/run_now.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Any - -from pydantic import BaseModel, root_validator - -from dbx.models.parameters.common import StringArray, ParamPair -from dbx.models.parameters.common import validate_contains - - -class RunNowV2d0ParamInfo(BaseModel): - jar_params: StringArray - python_params: StringArray - spark_submit_params: StringArray - notebook_params: ParamPair - - @root_validator(pre=True) - def initialize(cls, values: Dict[str, Any]): # noqa - return validate_contains(cls.__fields__, values) - - -class RunNowV2d1ParamInfo(RunNowV2d0ParamInfo): - python_named_params: ParamPair - - @root_validator(pre=True) - def initialize(cls, values: 
Dict[str, Any]): # noqa - return validate_contains(cls.__fields__, values) diff --git a/dbx/models/parameters/run_submit.py b/dbx/models/parameters/run_submit.py deleted file mode 100644 index 078f1dc5..00000000 --- a/dbx/models/parameters/run_submit.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import List, Optional, Union - -from pydantic import BaseModel, validator, root_validator - -from dbx.models.parameters.common import ParamPair, StringArray, validate_contains, validate_unique -from dbx.models.task import validate_named_parameters - - -class BaseTaskModel(BaseModel): - @abstractmethod - def get_parameters_key(self) -> str: - """""" - - @abstractmethod - def get_parameters(self) -> Union[ParamPair, StringArray]: - """""" - - -class NotebookTask(BaseTaskModel): - base_parameters: ParamPair - - def get_parameters_key(self) -> str: - return "base_parameters" - - def get_parameters(self) -> Union[ParamPair, StringArray]: - return self.base_parameters - - -class SparkJarTask(BaseTaskModel): - parameters: StringArray - - def get_parameters_key(self) -> str: - return "parameters" - - def get_parameters(self) -> Union[ParamPair, StringArray]: - return self.parameters - - -class SparkPythonTask(BaseModel): - parameters: StringArray - - def get_parameters_key(self) -> str: # noqa - return "parameters" - - def get_parameters(self) -> Union[ParamPair, StringArray]: - return self.parameters - - -class SparkSubmitTask(BaseModel): - parameters: StringArray - - def get_parameters_key(self) -> str: # noqa - return "parameters" - - def get_parameters(self) -> Union[ParamPair, StringArray]: - return self.parameters - - -class PythonWheelTask(BaseTaskModel): - parameters: StringArray - named_parameters: ParamPair - - @root_validator(pre=True) - def initialize(cls, values): # noqa - validate_contains(cls.__fields__, values) - validate_unique(cls.__fields__, values) - return values - - @validator("named_parameters", pre=True) - def _validate_named_parameters(cls, values: List[str]): # noqa - validate_named_parameters(values) - return values - - def get_parameters_key(self) -> str: - _key = "parameters" if self.parameters else "named_parameters" - return _key - - def get_parameters(self) -> Union[ParamPair, StringArray]: - _params = self.parameters if self.parameters else self.named_parameters - return _params - - -class TaskContainerModel(BaseModel): - def get_task_key(self) -> str: - """ - Returns the name of the non-empty task section - """ - _task_key = [ - k - for k in self.dict(exclude_none=True, exclude_unset=True, exclude_defaults=True).keys() - if k.endswith("_task") - ] - return _task_key[0] - - def get_defined_task(self) -> Union[NotebookTask, SparkJarTask, SparkPythonTask, SparkSubmitTask]: - return getattr(self, self.get_task_key()) - - -class RunSubmitV2d0ParamInfo(TaskContainerModel): - notebook_task: Optional[NotebookTask] - spark_jar_task: Optional[SparkJarTask] - spark_python_task: Optional[SparkPythonTask] - spark_submit_task: Optional[SparkSubmitTask] - - @root_validator(pre=True) - def initialize(cls, values): # noqa - validate_contains(cls.__fields__, values) - validate_unique(cls.__fields__, values) - return values - - -class NamedV2d1Task(TaskContainerModel): - task_key: str - notebook_task: Optional[NotebookTask] - spark_jar_task: Optional[SparkJarTask] - spark_python_task: Optional[SparkPythonTask] - spark_submit_task: Optional[SparkSubmitTask] - python_wheel_task: Optional[PythonWheelTask] - - 
@root_validator(pre=True) - def initialize(cls, values): # noqa - task_fields = {k: v for k, v in cls.__fields__.items() if k != "task_key"} - validate_contains(task_fields, values) - validate_unique(task_fields, values) - return values - - -class RunSubmitV2d1ParamInfo(BaseModel): - tasks: List[NamedV2d1Task] diff --git a/dbx/models/task.py b/dbx/models/task.py deleted file mode 100644 index 777f83f9..00000000 --- a/dbx/models/task.py +++ /dev/null @@ -1,71 +0,0 @@ -from enum import Enum -from pathlib import Path -from typing import List, Optional - -from pydantic import BaseModel, root_validator, validator - - -def validate_named_parameters(values: List[str]): - for v in values: - if not v.startswith("--"): - raise ValueError(f"Named parameter shall start with --, provided value: {v}") - if "=" not in v: - raise ValueError(f"Named parameter shall contain equal sign, provided value: {v}") - - -class TaskType(Enum): - spark_python_task = "spark_python_task" - python_wheel_task = "python_wheel_task" - - -class PythonWheelTask(BaseModel): - package_name: str - entry_point: str - parameters: Optional[List[str]] = [] - named_parameters: Optional[List[str]] = [] - - @root_validator(pre=True) - def validate_parameters(cls, values): # noqa - if all(param in values for param in ["parameters", "named_parameters"]): - raise ValueError("Both named_parameters and parameters cannot be provided at the same time") - return values - - @validator("named_parameters", pre=True) - def _validate_named_parameters(cls, values: List[str]): # noqa - validate_named_parameters(values) - return values - - -class SparkPythonTask(BaseModel): - python_file: Path - parameters: Optional[List[str]] = [] - - @validator("python_file", always=True) - def python_file_validator(cls, v: Path, values) -> Path: # noqa - stripped = v.relative_to("file://") # we need to strip out the file:// prefix - if not stripped.exists(): - raise FileNotFoundError(f"File {stripped} is mentioned in the task or job definition, but is non-existent") - return stripped - - -class Task(BaseModel): - spark_python_task: Optional[SparkPythonTask] - python_wheel_task: Optional[PythonWheelTask] - task_type: Optional[TaskType] - - @root_validator - def validate_all(cls, values): # noqa - if all(values.get(_type.name) is None for _type in TaskType): - raise ValueError( - f"Provided task or job definition doesn't contain one of the supported types: \n" - f"{[t.value for t in TaskType]}" - ) - if sum(1 if values.get(_type.name) else 0 for _type in TaskType) > 1: - raise ValueError("More then one definition has been provided, please review the job or task definition") - return values - - @validator("task_type", always=True) - def task_type_validator(cls, v, values) -> TaskType: # noqa - for _type in TaskType: - if values.get(_type.name): - return TaskType(_type) diff --git a/dbx/models/validators.py b/dbx/models/validators.py new file mode 100644 index 00000000..045ee252 --- /dev/null +++ b/dbx/models/validators.py @@ -0,0 +1,73 @@ +from typing import Dict, Any, List + + +def at_least_one_by_suffix(suffix: str, values: Dict[str, Any]): + _matching_fields = [f for f in values if f.endswith(suffix)] + if not _matching_fields: + raise ValueError( + f""" + At least one field with suffix {suffix} should be provided. 
+ Provided payload: {values} + """, + ) + return values + + +def only_one_by_suffix(suffix: str, values: Dict[str, Any]): + _matching_fields = [f for f in values if f.endswith(suffix)] + + if len(_matching_fields) != 1: + _filtered_values = {k: v for k, v in values.items() if v is not None} + raise ValueError( + f""" + Only one field with suffix {suffix} should be provided. + Provided payload: {_filtered_values} + """, + ) + return values + + +def at_least_one_of(fields_names: List[str], values: Dict[str, Any]): + """ + Verifies that provided payload contains at least one of the fields + :param fields_names: List of the field names to be validated + :param values: Raw payload values + :return: Nothing, raises an error if validation didn't pass. + """ + _matching_fields = [f for f in fields_names if f in values] + if not _matching_fields: + raise ValueError( + f""" + At least one of the following fields should be provided in the payload: {fields_names}. + Provided payload: {values} + """, + ) + return values + + +def only_one_provided(suffix: str, values: Dict[str, Any]): + """Function verifies if value IS provided and it's unique""" + at_least_one_by_suffix(suffix, values) + only_one_by_suffix(suffix, values) + return values + + +def mutually_exclusive(fields_names: List[str], values: Dict[str, Any]): + non_empty_values = [key for key, item in values.items() if item] # will coalesce both checks for None and [] + _matching_fields = [f for f in fields_names if f in non_empty_values] + if len(_matching_fields) > 1: + raise ValueError( + f""" + The following fields {_matching_fields} are mutually exclusive. + Provided payload: {values} + """, + ) + return values + + +def check_dbt_commands(commands): + if commands: + for cmd in commands: + if not cmd.startswith("dbt"): + raise ValueError("All commands in the dbt_task must start with `dbt`, e.g. 
`dbt command1`") + return commands diff --git a/dbx/models/workflow/__init__.py b/dbx/models/workflow/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/workflow/common/__init__.py b/dbx/models/workflow/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/workflow/common/access_control.py b/dbx/models/workflow/common/access_control.py new file mode 100644 index 00000000..c732441e --- /dev/null +++ b/dbx/models/workflow/common/access_control.py @@ -0,0 +1,40 @@ +from enum import Enum +from typing import Optional, List, Dict, Any + +from pydantic import root_validator, validator + +from dbx.models.validators import at_least_one_of +from dbx.models.workflow.common.flexible import FlexibleModel + + +class PermissionLevel(str, Enum): + CAN_MANAGE = "CAN_MANAGE" + CAN_MANAGE_RUN = "CAN_MANAGE_RUN" + CAN_VIEW = "CAN_VIEW" + IS_OWNER = "IS_OWNER" + + +class AccessControlRequest(FlexibleModel): + user_name: Optional[str] + group_name: Optional[str] + permission_level: PermissionLevel + + _one_of_provided = root_validator(pre=True, allow_reuse=True)( + lambda _, values: at_least_one_of(["user_name", "group_name"], values) + ) + + +class AccessControlMixin(FlexibleModel): + access_control_list: Optional[List[AccessControlRequest]] + + @validator("access_control_list") + def owner_is_provided(cls, acls: List[AccessControlRequest]): # noqa + owner_info = [o for o in acls if o.permission_level == PermissionLevel.IS_OWNER] + if not owner_info: + raise ValueError("At least one owner (IS_OWNER) should be provided in the access control list") + if len(owner_info) > 1: + raise ValueError("Only one owner should be provided in the access control list") + return acls + + def get_acl_payload(self) -> Dict[str, Any]: + return self.dict(exclude_none=True) diff --git a/dbx/models/workflow/common/deployment_config.py b/dbx/models/workflow/common/deployment_config.py new file mode 100644 index 00000000..9d18dab3 --- /dev/null +++ b/dbx/models/workflow/common/deployment_config.py @@ -0,0 +1,7 @@ +from typing import Optional + +from pydantic import BaseModel + + +class DbxDeploymentConfig(BaseModel): + no_package: Optional[bool] = False diff --git a/dbx/models/workflow/common/flexible.py b/dbx/models/workflow/common/flexible.py new file mode 100644 index 00000000..e7a91bc9 --- /dev/null +++ b/dbx/models/workflow/common/flexible.py @@ -0,0 +1,25 @@ +from typing import List + +from pydantic import BaseModel, Extra + +from dbx.utils import dbx_echo + + +class FlexibleModel(BaseModel, extra=Extra.allow): + """ + Base class for models used across all domain objects in dbx. + Provides extensible functions for verification. 
+ """ + + @classmethod + def get_field_names(cls) -> List[str]: + return list(cls.__fields__.keys()) + + @classmethod + def field_deprecated(cls, field_id: str, field_name: str, reference: str, value): + dbx_echo( + f"""[yellow]⚠️ Field [bold]{field_name}[/bold] is DEPRECATED.[/yellow] + Please use the in-place reference instead: + [code]{field_id}: "{reference}://{value}"[/code] + """ + ) diff --git a/dbx/models/workflow/common/job_email_notifications.py b/dbx/models/workflow/common/job_email_notifications.py new file mode 100644 index 00000000..0ac32c12 --- /dev/null +++ b/dbx/models/workflow/common/job_email_notifications.py @@ -0,0 +1,10 @@ +from typing import Optional, List + +from dbx.models.workflow.common.flexible import FlexibleModel + + +class JobEmailNotifications(FlexibleModel): + on_start: Optional[List[str]] + on_success: Optional[List[str]] + on_failure: Optional[List[str]] + no_alert_for_skipped_runs: Optional[bool] diff --git a/dbx/models/workflow/common/libraries.py b/dbx/models/workflow/common/libraries.py new file mode 100644 index 00000000..c36a96c0 --- /dev/null +++ b/dbx/models/workflow/common/libraries.py @@ -0,0 +1,37 @@ +from typing import Optional, List + +from pydantic import root_validator + +from dbx.models.validators import at_least_one_of, mutually_exclusive +from dbx.models.workflow.common.flexible import FlexibleModel + + +class PythonPyPiLibrary(FlexibleModel): + package: str + repo: Optional[str] + + +class MavenLibrary(FlexibleModel): + coordinates: str + repo: Optional[str] + exclusions: Optional[List[str]] + + +class RCranLibrary(FlexibleModel): + package: str + repo: Optional[str] + + +class Library(FlexibleModel): + jar: Optional[str] + egg: Optional[str] + whl: Optional[str] + pypi: Optional[PythonPyPiLibrary] + maven: Optional[MavenLibrary] + cran: Optional[RCranLibrary] + + @root_validator(pre=True) + def _validate(cls, values): # noqa + at_least_one_of(cls.get_field_names(), values) + mutually_exclusive(cls.get_field_names(), values) + return values diff --git a/dbx/models/workflow/common/new_cluster.py b/dbx/models/workflow/common/new_cluster.py new file mode 100644 index 00000000..942eeaa1 --- /dev/null +++ b/dbx/models/workflow/common/new_cluster.py @@ -0,0 +1,61 @@ +from typing import Optional + +from pydantic import root_validator, validator + +from dbx.models.workflow.common.flexible import FlexibleModel + + +class AutoScale(FlexibleModel): + min_workers: int + max_workers: int + + @root_validator() + def _validate(cls, values): # noqa + assert values["max_workers"] > values["min_workers"], ValueError( + f""" + max_workers ({values["max_workers"]}) should be bigger than min_workers ({values["min_workers"]}) + """, + ) + return values + + +class AwsAttributes(FlexibleModel): + first_on_demand: Optional[int] + availability: Optional[str] + zone_id: Optional[str] + instance_profile_arn: Optional[str] + instance_profile_name: Optional[str] + + @validator("instance_profile_name") + def _validate(cls, value): # noqa + cls.field_deprecated("instance_profile_arn", "instance_profile_name", "instance-profile", value) + return value + + +class NewCluster(FlexibleModel): + spark_version: str + node_type_id: Optional[str] + num_workers: Optional[int] + autoscale: Optional[AutoScale] + instance_pool_name: Optional[str] + driver_instance_pool_name: Optional[str] + driver_instance_pool_id: Optional[str] + instance_pool_id: Optional[str] + aws_attributes: Optional[AwsAttributes] + policy_name: Optional[str] + policy_id: Optional[str] + + 
@validator("instance_pool_name") + def instance_pool_name_validate(cls, value): # noqa + cls.field_deprecated("instance_pool_id", "instance_pool_name", "instance-pool", value) + return value + + @validator("driver_instance_pool_name") + def driver_instance_pool_name_validate(cls, value): # noqa + cls.field_deprecated("driver_instance_pool_id", "driver_instance_pool_name", "instance-pool", value) + return value + + @validator("policy_name") + def policy_name_validate(cls, value): # noqa + cls.field_deprecated("policy_id", "policy_name", "cluster-policy", value) + return value diff --git a/dbx/models/workflow/common/parameters.py b/dbx/models/workflow/common/parameters.py new file mode 100644 index 00000000..9a731937 --- /dev/null +++ b/dbx/models/workflow/common/parameters.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Dict, List, Any +from typing import Optional + +from pydantic import BaseModel + +ParamPair = Optional[Dict[str, str]] +StringArray = Optional[List[str]] + + +class ParametersMixin(BaseModel): + parameters: Optional[StringArray] + + +class NamedParametersMixin(BaseModel): + named_parameters: Optional[Dict[str, Any]] + + +class PipelineTaskParametersPayload(BaseModel): + full_refresh: Optional[bool] + + +class BaseParametersMixin(BaseModel): + base_parameters: Optional[ParamPair] + + +class StandardBasePayload(BaseModel): + jar_params: Optional[StringArray] + notebook_params: Optional[ParamPair] + python_params: Optional[StringArray] + spark_submit_params: Optional[StringArray] diff --git a/dbx/models/workflow/common/pipeline.py b/dbx/models/workflow/common/pipeline.py new file mode 100644 index 00000000..abf6b093 --- /dev/null +++ b/dbx/models/workflow/common/pipeline.py @@ -0,0 +1,45 @@ +from typing import Optional, Dict, List, Any, Literal + +from pydantic import validator + +from dbx.models.workflow.common.access_control import AccessControlMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.new_cluster import NewCluster +from dbx.utils import dbx_echo + + +class PipelinesNewCluster(NewCluster): + label: Optional[str] + spark_version: Optional[str] = None + init_scripts: List[Any] = [] + + @staticmethod + def _omit_msg(property_name: str): + dbx_echo( + f"[yellow bold]The `{property_name}` property cannot be applied for DLT pipelines. 
" + "Provided value will be omitted.[/yellow bold]" + ) + + @validator("spark_version", pre=True) + def _validate_spark_version(cls, value): # noqa + if value: + cls._omit_msg("spark_version") + + +class NotebookLibrary(FlexibleModel): + path: str + + +class PipelineLibrary(FlexibleModel): + notebook: NotebookLibrary + + +class Pipeline(AccessControlMixin): + name: str + pipeline_id: Optional[str] + workflow_type: Literal["pipeline"] + storage: Optional[str] + target: Optional[str] + configuration: Optional[Dict[str, str]] + clusters: Optional[List[PipelinesNewCluster]] = [] + libraries: Optional[List[PipelineLibrary]] = [] diff --git a/dbx/models/workflow/common/task.py b/dbx/models/workflow/common/task.py new file mode 100644 index 00000000..879b9d41 --- /dev/null +++ b/dbx/models/workflow/common/task.py @@ -0,0 +1,100 @@ +from abc import ABC +from pathlib import Path +from typing import Optional + +from pydantic import validator, root_validator, BaseModel + +from dbx.constants import TASKS_SUPPORTED_IN_EXECUTE +from dbx.models.cli.execute import ExecuteParametersPayload +from dbx.models.validators import at_least_one_of, only_one_provided +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.parameters import ParamPair, StringArray +from dbx.models.workflow.common.task_type import TaskType +from dbx.utils import dbx_echo + + +class BaseNotebookTask(FlexibleModel, ABC): + notebook_path: str + base_parameters: Optional[ParamPair] + + +class SparkJarTask(FlexibleModel): + main_class_name: str + parameters: Optional[StringArray] + jar_params: Optional[StringArray] + jar_uri: Optional[str] + + @validator("jar_uri") + def _deprecated_msg(cls, value): # noqa + dbx_echo( + "[yellow bold] Field jar_uri is DEPRECATED since 04/2016. " + "Provide a [code]jar[/code] through the [code]libraries[/code] field instead." + ) + return value + + +class SparkPythonTask(BaseModel): + python_file: str + parameters: Optional[StringArray] = [] + + @validator("python_file") + def _not_fuse(cls, v): # noqa + if v.startswith("file:fuse://"): + raise ValueError("The python_file property cannot be FUSE-based") + if not v.endswith(".py"): + raise ValueError(f"Only a .py file can be used in this property, provided: {v}") + return v + + @property + def execute_file(self) -> Path: + if not self.python_file.startswith("file://"): + raise ValueError("File for execute mode should be located locally and referenced via file:// prefix.") + + _path = Path(self.python_file).relative_to("file://") + + if not _path.exists(): + raise ValueError(f"Provided file doesn't exist {_path}") + + return _path + + +class SparkSubmitTask(FlexibleModel): + parameters: Optional[StringArray] + spark_submit_params: Optional[StringArray] + + _validate_provided = root_validator(allow_reuse=True)( + lambda _, values: at_least_one_of(["parameters", "spark_submit_params"], values) + ) + + +class BasePipelineTask(FlexibleModel, ABC): + pipeline_id: str + + +class BaseTaskMixin(FlexibleModel): + _only_one_provided = root_validator(pre=True, allow_reuse=True)( + lambda _, values: only_one_provided("_task", values) + ) + + @property + def task_type(self) -> TaskType: + for _type in TaskType: + if self.dict().get(_type): + return TaskType(_type) + return TaskType.undefined_task + + def check_if_supported_in_execute(self): + if self.task_type not in TASKS_SUPPORTED_IN_EXECUTE: + raise RuntimeError( + f"Provided task type {self.task_type} is not supported in execute mode. 
" + f"Supported types are: {TASKS_SUPPORTED_IN_EXECUTE}" + ) + + def override_execute_parameters(self, payload: ExecuteParametersPayload): + if payload.named_parameters and self.task_type == TaskType.spark_python_task: + raise ValueError( + "`named_parameters` are not supported by spark_python_task. Please use `parameters` instead." + ) + + pointer = getattr(self, self.task_type) + pointer.__dict__.update(payload.dict(exclude_none=True)) diff --git a/dbx/models/workflow/common/task_type.py b/dbx/models/workflow/common/task_type.py new file mode 100644 index 00000000..e1f479a7 --- /dev/null +++ b/dbx/models/workflow/common/task_type.py @@ -0,0 +1,18 @@ +from enum import Enum + + +class TaskType(str, Enum): + # task types defined both in v2.0 and v2.1 + notebook_task = "notebook_task" + spark_jar_task = "spark_jar_task" + spark_python_task = "spark_python_task" + spark_submit_task = "spark_submit_task" + pipeline_task = "pipeline_task" + + # specific to v2.1 + python_wheel_task = "python_wheel_task" + sql_task = "sql_task" + dbt_task = "dbt_task" + + # undefined handler for cases when a new task type is added + undefined_task = "undefined_task" diff --git a/dbx/models/workflow/common/workflow.py b/dbx/models/workflow/common/workflow.py new file mode 100644 index 00000000..27d01ea1 --- /dev/null +++ b/dbx/models/workflow/common/workflow.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.job_email_notifications import JobEmailNotifications + + +class CronSchedule(FlexibleModel): + quartz_cron_expression: str + timezone_id: str + pause_status: Optional[str] + + +class WorkflowBase(FlexibleModel, ABC): + # common fields between 2.0 and 2.1 + name: str + email_notifications: Optional[JobEmailNotifications] + timeout_seconds: Optional[Union[int, str]] + schedule: Optional[CronSchedule] + max_concurrent_runs: Optional[int] + job_id: Optional[str] + + @abstractmethod + def get_task(self, task_key: str): + """Abstract method to be implemented""" + + @abstractmethod + def override_asset_based_launch_parameters(self, payload): + """Abstract method to be implemented""" diff --git a/dbx/models/workflow/common/workflow_types.py b/dbx/models/workflow/common/workflow_types.py new file mode 100644 index 00000000..a7c7c08f --- /dev/null +++ b/dbx/models/workflow/common/workflow_types.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class WorkflowType(str, Enum): + pipeline = "pipeline" + job_v2d0 = "job-v2.0" + job_v2d1 = "job-v2.1" diff --git a/dbx/models/workflow/v2dot0/__init__.py b/dbx/models/workflow/v2dot0/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/workflow/v2dot0/parameters.py b/dbx/models/workflow/v2dot0/parameters.py new file mode 100644 index 00000000..5322173c --- /dev/null +++ b/dbx/models/workflow/v2dot0/parameters.py @@ -0,0 +1,16 @@ +from pydantic import root_validator + +from dbx.models.validators import mutually_exclusive +from dbx.models.workflow.common.parameters import BaseParametersMixin, ParametersMixin, StandardBasePayload + + +class AssetBasedRunPayload(BaseParametersMixin, ParametersMixin): + """""" + + _validate_unique = root_validator(pre=True)( + lambda _, values: mutually_exclusive(["base_parameters", "parameters"], values) + ) + + +class StandardRunPayload(StandardBasePayload): + """""" diff --git a/dbx/models/workflow/v2dot0/task.py b/dbx/models/workflow/v2dot0/task.py new file mode 100644 index 
00000000..e895d6ab --- /dev/null +++ b/dbx/models/workflow/v2dot0/task.py @@ -0,0 +1,26 @@ +from typing import Optional + +from dbx.models.workflow.common.task import ( + BaseNotebookTask, + BasePipelineTask, + BaseTaskMixin, + SparkJarTask, + SparkPythonTask, + SparkSubmitTask, +) + + +class NotebookTask(BaseNotebookTask): + revision_timestamp: Optional[int] + + +class PipelineTask(BasePipelineTask): + """Simple reference to the base""" + + +class TaskMixin(BaseTaskMixin): + notebook_task: Optional[NotebookTask] + spark_jar_task: Optional[SparkJarTask] + spark_python_task: Optional[SparkPythonTask] + spark_submit_task: Optional[SparkSubmitTask] + pipeline_task: Optional[PipelineTask] diff --git a/dbx/models/workflow/v2dot0/workflow.py b/dbx/models/workflow/v2dot0/workflow.py new file mode 100644 index 00000000..5d7b5358 --- /dev/null +++ b/dbx/models/workflow/v2dot0/workflow.py @@ -0,0 +1,54 @@ +from typing import Optional, List, Union, Literal + +from pydantic import root_validator, validator + +from dbx.models.workflow.common.access_control import AccessControlMixin +from dbx.models.workflow.common.deployment_config import DbxDeploymentConfig +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.new_cluster import NewCluster +from dbx.models.workflow.common.task import SparkPythonTask, SparkJarTask, SparkSubmitTask +from dbx.models.workflow.common.task_type import TaskType +from dbx.models.workflow.common.workflow import WorkflowBase +from dbx.models.workflow.common.workflow_types import WorkflowType +from dbx.models.workflow.v2dot0.parameters import AssetBasedRunPayload +from dbx.models.workflow.v2dot0.task import TaskMixin, NotebookTask + +ALLOWED_TASK_TYPES = Union[SparkPythonTask, NotebookTask, SparkJarTask, SparkSubmitTask] + + +class Workflow(WorkflowBase, TaskMixin, AccessControlMixin): + # this follows structure of 2.0 API + # https://docs.databricks.com/dev-tools/api/2.0/jobs.html + existing_cluster_id: Optional[str] + existing_cluster_name: Optional[str] # deprecated field + new_cluster: Optional[NewCluster] + libraries: Optional[List[Library]] = [] + max_retries: Optional[int] + min_retry_interval_millis: Optional[int] + retry_on_timeout: Optional[bool] + deployment_config: Optional[DbxDeploymentConfig] + workflow_type: Literal[WorkflowType.job_v2d0] = WorkflowType.job_v2d0 + + @validator("existing_cluster_name") + def _deprecated(cls, value): # noqa + cls.field_deprecated("existing_cluster_id", "existing_cluster_name", "cluster", value) + return value + + @root_validator() + def mutually_exclusive(cls, values): # noqa + if not values.get("pipeline_task"): + if values.get("new_cluster") and (values.get("existing_cluster_id") or values.get("existing_cluster_name")): + raise ValueError( + 'Fields ("existing_cluster_id" or "existing_cluster_name") and "new_cluster" are mutually exclusive' + ) + return values + + def get_task(self, task_key: str): + raise RuntimeError("Provided workflow format is V2.0, and it doesn't support tasks") + + def override_asset_based_launch_parameters(self, payload: AssetBasedRunPayload): + if self.task_type == TaskType.notebook_task: + self.notebook_task.base_parameters = payload.base_parameters + else: + pointer = getattr(self, self.task_type) + pointer.__dict__.update(payload.dict(exclude_none=True)) diff --git a/dbx/models/workflow/v2dot1/__init__.py b/dbx/models/workflow/v2dot1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbx/models/workflow/v2dot1/_parameters.py 
b/dbx/models/workflow/v2dot1/_parameters.py new file mode 100644 index 00000000..dec6f563 --- /dev/null +++ b/dbx/models/workflow/v2dot1/_parameters.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import Optional, Union + +from dbx.models.workflow.common.parameters import ( + ParametersMixin, + ParamPair, + StringArray, + BaseParametersMixin, + PipelineTaskParametersPayload, +) + + +class FlexibleParametersMixin(ParametersMixin): + parameters: Optional[Union[ParamPair, StringArray]] + + +class PayloadElement(FlexibleParametersMixin, BaseParametersMixin, PipelineTaskParametersPayload): + task_key: str diff --git a/dbx/models/workflow/v2dot1/job_cluster.py b/dbx/models/workflow/v2dot1/job_cluster.py new file mode 100644 index 00000000..4c506034 --- /dev/null +++ b/dbx/models/workflow/v2dot1/job_cluster.py @@ -0,0 +1,39 @@ +import collections +from typing import Dict, Any, List, Optional + +from pydantic import root_validator + +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.new_cluster import NewCluster + + +class JobCluster(FlexibleModel): + job_cluster_key: str + new_cluster: NewCluster + + +class JobClustersMixin(FlexibleModel): + job_clusters: Optional[List[JobCluster]] = [] + + @root_validator(pre=True) + def _jc_validator(cls, values: Dict[str, Any]): # noqa + if values: + job_clusters = values.get("job_clusters", []) + + # checks that structure is provided in expected format + assert isinstance(job_clusters, list), f"Job clusters payload should be a list, provided: {job_clusters}" + + _duplicates = [ + name + for name, count in collections.Counter([jc.get("job_cluster_key") for jc in job_clusters]).items() + if count > 1 + ] + if _duplicates: + raise ValueError(f"Duplicated cluster keys {_duplicates} found in the job_clusters section") + return values + + def get_job_cluster_definition(self, key: str) -> JobCluster: + _found = list(filter(lambda jc: jc.job_cluster_key == key, self.job_clusters)) + if not _found: + raise ValueError(f"Cluster key {key} is not provided in the job_clusters section: {self.job_clusters}") + return _found[0] diff --git a/dbx/models/workflow/v2dot1/job_task_settings.py b/dbx/models/workflow/v2dot1/job_task_settings.py new file mode 100644 index 00000000..4084e666 --- /dev/null +++ b/dbx/models/workflow/v2dot1/job_task_settings.py @@ -0,0 +1,28 @@ +from typing import Optional, List + +from dbx.models.workflow.common.deployment_config import DbxDeploymentConfig +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.new_cluster import NewCluster +from dbx.models.workflow.common.job_email_notifications import JobEmailNotifications +from dbx.models.workflow.v2dot1.task import TaskMixin + + +class TaskDependencies(FlexibleModel): + task_key: str + + +class JobTaskSettings(TaskMixin): + task_key: str + description: Optional[str] + depends_on: Optional[List[TaskDependencies]] + existing_cluster_id: Optional[str] + new_cluster: Optional[NewCluster] + job_cluster_key: Optional[str] + libraries: Optional[List[Library]] = [] + email_notifications: Optional[JobEmailNotifications] + timeout_seconds: Optional[int] + max_retries: Optional[int] + min_retry_interval_millis: Optional[int] + retry_on_timeout: Optional[bool] + deployment_config: Optional[DbxDeploymentConfig] diff --git a/dbx/models/workflow/v2dot1/parameters.py b/dbx/models/workflow/v2dot1/parameters.py new file mode 100644 index 
00000000..1bc23e79 --- /dev/null +++ b/dbx/models/workflow/v2dot1/parameters.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import json +from typing import Optional, List + +from pydantic import BaseModel, validator + +from dbx.models.validators import check_dbt_commands +from dbx.models.workflow.common.parameters import ( + ParamPair, + StringArray, + StandardBasePayload, + PipelineTaskParametersPayload, +) +from dbx.models.workflow.v2dot1._parameters import PayloadElement + + +class AssetBasedRunPayload(BaseModel): + elements: Optional[List[PayloadElement]] + + @staticmethod + def from_string(raw: str) -> AssetBasedRunPayload: + return AssetBasedRunPayload(elements=json.loads(raw)) + + +class StandardRunPayload(StandardBasePayload): + python_named_params: Optional[ParamPair] + pipeline_params: Optional[PipelineTaskParametersPayload] + sql_params: Optional[ParamPair] + dbt_commands: Optional[StringArray] + + _verify_dbt_commands = validator("dbt_commands", allow_reuse=True)(check_dbt_commands) diff --git a/dbx/models/workflow/v2dot1/task.py b/dbx/models/workflow/v2dot1/task.py new file mode 100644 index 00000000..7005da14 --- /dev/null +++ b/dbx/models/workflow/v2dot1/task.py @@ -0,0 +1,86 @@ +from enum import Enum +from typing import Optional, List, Dict + +from pydantic import root_validator, validator, BaseModel +from pydantic.fields import Field + +from dbx.models.validators import check_dbt_commands, at_least_one_of, mutually_exclusive +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.task import ( + BaseTaskMixin, + BaseNotebookTask, + SparkJarTask, + SparkPythonTask, + SparkSubmitTask, + BasePipelineTask, +) + + +class NotebookSource(str, Enum): + WORKSPACE = "WORKSPACE" + GIT = "GIT" + + +class NotebookTask(BaseNotebookTask): + source: Optional[NotebookSource] + + +class PipelineTask(BasePipelineTask): + full_refresh: Optional[bool] + + +class SqlTaskQuery(FlexibleModel): + query_id: str + + +class SqlTaskDashboard(FlexibleModel): + dashboard_id: str + + +class SqlTaskAlert(FlexibleModel): + alert_id: str + + +class SqlTask(FlexibleModel): + warehouse_id: str + query: Optional[SqlTaskQuery] + dashboard: Optional[SqlTaskDashboard] + alert: Optional[SqlTaskAlert] + + @root_validator(pre=True) + def _validate(cls, values): # noqa + at_least_one_of(["query", "dashboard", "alert"], values) + mutually_exclusive(["query", "dashboard", "alert"], values) + return values + + +class DbtTask(FlexibleModel): + project_directory: Optional[str] + profiles_directory: Optional[str] + commands: List[str] + _schema: str = Field(alias="schema") # noqa + warehouse_id: str + + _verify_dbt_commands = validator("commands", allow_reuse=True)(check_dbt_commands) + + +class PythonWheelTask(BaseModel): + package_name: str + entry_point: str + parameters: Optional[List[str]] = [] + named_parameters: Optional[Dict[str, str]] = {} + + _validate_exclusive = root_validator(pre=True, allow_reuse=True)( + lambda _, values: mutually_exclusive(["parameters", "named_parameters"], values) + ) + + +class TaskMixin(BaseTaskMixin): + notebook_task: Optional[NotebookTask] + spark_jar_task: Optional[SparkJarTask] + spark_python_task: Optional[SparkPythonTask] + spark_submit_task: Optional[SparkSubmitTask] + python_wheel_task: Optional[PythonWheelTask] + pipeline_task: Optional[PipelineTask] + sql_task: Optional[SqlTask] + dbt_task: Optional[DbtTask] diff --git a/dbx/models/workflow/v2dot1/workflow.py b/dbx/models/workflow/v2dot1/workflow.py new file mode 100644 index 
00000000..8e1535cf --- /dev/null +++ b/dbx/models/workflow/v2dot1/workflow.py @@ -0,0 +1,71 @@ +import collections +from typing import Optional, List, Dict, Any, Literal + +from pydantic import root_validator, validator + +from dbx.models.validators import at_least_one_of, mutually_exclusive +from dbx.models.workflow.common.access_control import AccessControlMixin +from dbx.models.workflow.common.flexible import FlexibleModel +from dbx.models.workflow.common.workflow import WorkflowBase +from dbx.models.workflow.common.workflow_types import WorkflowType +from dbx.models.workflow.v2dot1.job_cluster import JobClustersMixin +from dbx.models.workflow.v2dot1.job_task_settings import JobTaskSettings +from dbx.models.workflow.v2dot1.parameters import AssetBasedRunPayload +from dbx.utils import dbx_echo + + +class GitSource(FlexibleModel): + git_url: str + git_provider: str + git_branch: Optional[str] + git_tag: Optional[str] + git_commit: Optional[str] + + @root_validator(pre=True) + def _validate(cls, values): # noqa + at_least_one_of(["git_branch", "git_tag", "git_commit"], values) + mutually_exclusive(["git_branch", "git_tag", "git_commit"], values) + return values + + +class Workflow(WorkflowBase, AccessControlMixin, JobClustersMixin): + # this follows the structure of 2.1 Jobs API + # https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + tags: Optional[Dict[str, Any]] + tasks: Optional[List[JobTaskSettings]] + git_source: Optional[GitSource] + format: Optional[str] + workflow_type: Literal[WorkflowType.job_v2d1] = WorkflowType.job_v2d1 + + @validator("tasks") + def _validate_tasks(cls, tasks: Optional[List[JobTaskSettings]]) -> Optional[List[JobTaskSettings]]: # noqa + if tasks: + _duplicates = [ + name for name, count in collections.Counter([t.task_key for t in tasks]).items() if count > 1 + ] + if _duplicates: + raise ValueError("Duplicated task keys are not allowed.") + else: + dbx_echo( + "[yellow bold]No task definitions were provided for workflow. " + "This might cause errors during deployment[/yellow bold]" + ) + return tasks + + def get_task(self, task_key: str) -> JobTaskSettings: + _found = list(filter(lambda t: t.task_key == task_key, self.tasks)) + assert len(_found) == 1, ValueError( + f"Requested task key {task_key} doesn't exist in the workflow definition." + f"Available tasks are: {self.task_names}" + ) + return _found[0] + + @property + def task_names(self) -> List[str]: + return [t.task_key for t in self.tasks] + + def override_asset_based_launch_parameters(self, payload: AssetBasedRunPayload): + for task_parameters in payload.elements: + _t = self.get_task(task_parameters.task_key) + pointer = getattr(_t, _t.task_type) + pointer.__dict__.update(task_parameters.dict(exclude_none=True)) diff --git a/dbx/options.py b/dbx/options.py index 9154113f..e4893c34 100644 --- a/dbx/options.py +++ b/dbx/options.py @@ -1,5 +1,3 @@ -from pathlib import Path - import typer from databricks_cli.configure.provider import DEFAULT_SECTION @@ -59,7 +57,8 @@ ) REQUIREMENTS_FILE_OPTION = typer.Option( - Path("requirements.txt"), + None, + "--requirements-file", help=""" This option is deprecated. @@ -127,7 +126,7 @@ dbx execute --parameters='{parameters: ["argument1", "argument2"]}' - dbx execute --parameters='{named_parameters: ["--a=1", "--b=2"]}' + dbx execute --parameters='{named_parameters: {"a": 1, "b": 1}}' Please note that various tasks have various parameter structures. 
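Editor's note (illustrative, not part of the patch): the updated `dbx execute` help text above switches the `named_parameters` example from the old `["--a=1", "--b=2"]` list form to a plain dictionary, which matches the new v2.1 `PythonWheelTask` model added in this patch (`named_parameters: Optional[Dict[str, str]] = {}`). A minimal pydantic sketch of how such a `--parameters` payload parses; the model name `ExampleExecuteOverride` is a hypothetical stand-in, not the real dbx class:

# Illustrative sketch only -- mirrors the payload shapes shown in the help text above.
import json
from typing import Dict, List, Optional

from pydantic import BaseModel


class ExampleExecuteOverride(BaseModel):
    # positional parameters for spark_python_task / python_wheel_task
    parameters: Optional[List[str]] = None
    # v2.1 python_wheel_task now takes a plain mapping instead of "--key=value" strings
    named_parameters: Optional[Dict[str, str]] = None


# positional form, unchanged between the old and new help text
positional = ExampleExecuteOverride(**json.loads('{"parameters": ["argument1", "argument2"]}'))

# keyword form, now expressed as a mapping rather than ["--a=1", "--b=2"]
keyword = ExampleExecuteOverride(**json.loads('{"named_parameters": {"a": "1", "b": "2"}}'))

print(positional.dict(exclude_none=True))
print(keyword.dict(exclude_none=True))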
@@ -148,44 +147,9 @@ "--parameters", help="""If provided, overrides parameters of the chosen workflow. - Depending on the workflow type and launch type, it might contain various payloads. - - Provided payload shall match the expected payload of a chosen workflow or task. - - Payload should be wrapped as a JSON-compatible string with curly brackets around API-compatible payload. - - Examples: - - - Jobs API v2.0 `spark_python_task`, or `python_wheel_task` workflow without tasks inside - - `dbx launch --parameters='{"parameters": ["argument1", "argument2"]}' - - - Jobs API v2.1 multitask job with one `python_wheel_task` - - `dbx execute --parameters='[{"task_key": "some", "named_parameters": - ["--a=1", "--b=2"]}]'` - - - Jobs API v2.1 multitask job with one notebook_task - `dbx execute --parameters='[{"task_key": "some", "base_parameters": - {"a": 1, "b": 2}}]'` - - - Jobs API v2.1 multitask job with 2 tasks - - `dbx execute --parameters='[ - {"task_key": "first", "base_parameters": {"a": 1, "b": 2}}, - {"task_key": "second", "parameters": ["a", "b"]}]'` - - - Also note that all parameters provided for the workflow or task will be preprocessed with file uploader. - - - It means that if you reference a `file://` or `file:fuse://` string - in the parameter override, it will be resolved and uploaded to DBFS. - - - You can find more on the parameter structures for various Jobs API - versions in the official documentation""", + Please read more details on the parameter passing in the dbx docs. + """, show_default=True, callback=launch_parameters_callback, ) diff --git a/dbx/sync/__init__.py b/dbx/sync/__init__.py index 48527c52..284fddc4 100644 --- a/dbx/sync/__init__.py +++ b/dbx/sync/__init__.py @@ -184,7 +184,7 @@ async def _apply_dirs_deleted( async def _apply_file_puts(self, session: aiohttp.ClientSession, paths: List[str], msg: str) -> None: tasks = [] op_count = 0 - sem = asyncio.Semaphore(self.max_parallel_puts) + sem = asyncio.Semaphore(self.max_parallel_puts) # noqa for path in sorted(paths): op_count += 1 if not self.dry_run: @@ -192,7 +192,7 @@ async def _apply_file_puts(self, session: aiohttp.ClientSession, paths: List[str async def task(p): # Files can be created in parallel, but we limit how many are opened at a time # so we don't use memory excessively. 
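Editor's note (illustrative, not part of the patch): the comment above describes the semaphore-bounded upload pattern that the `# noqa`-annotated lines below keep intact. A self-contained sketch of that pattern under the assumption that each put is an independent coroutine; the names `upload` and `MAX_PARALLEL` are hypothetical stand-ins:

# Editor's sketch of the bounded-concurrency pattern described in the comment above.
import asyncio

MAX_PARALLEL = 10


async def upload(path: str) -> None:
    # placeholder for an actual put/upload call
    await asyncio.sleep(0.01)
    print(f"uploaded {path}")


async def put_all(paths: list) -> None:
    sem = asyncio.Semaphore(MAX_PARALLEL)

    async def task(p: str) -> None:
        # the semaphore caps how many uploads are open at once,
        # so memory use stays bounded even for large file sets
        async with sem:
            await upload(p)

    await asyncio.gather(*(task(p) for p in sorted(paths)))


if __name__ == "__main__":
    asyncio.run(put_all([f"file_{i}" for i in range(25)]))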
- async with sem: + async with sem: # noqa await self.client.put(get_relative_path(self.source, p), p, session=session) tasks.append(task(path)) diff --git a/dbx/sync/clients.py b/dbx/sync/clients.py index 765f5adc..b7a1b955 100644 --- a/dbx/sync/clients.py +++ b/dbx/sync/clients.py @@ -8,6 +8,7 @@ from databricks_cli.version import version as databricks_cli_version from dbx.utils import dbx_echo +from dbx.utils.url import strip_databricks_url class ClientError(Exception): @@ -38,7 +39,7 @@ def get_user(config: DatabricksConfig) -> dict: or isn't supported """ api_token = config.token - host = config.host.rstrip("/") + host = strip_databricks_url(config.host) headers = get_headers(api_token) url = f"{host}/api/2.0/preview/scim/v2/Me" resp = requests.get(url, headers=headers, timeout=10) @@ -158,7 +159,7 @@ def __init__(self, *, base_path: str, config: DatabricksConfig): check_path(base_path) self.base_path = "dbfs:" + base_path.rstrip("/") self.api_token = config.token - self.host = config.host.rstrip("/") + self.host = strip_databricks_url(config.host) self.api_base_path = f"{self.host}/api/2.0/dbfs" if config.insecure is None: self.ssl = None @@ -251,7 +252,7 @@ def __init__(self, *, user: str, repo_name: str, config: DatabricksConfig): raise ValueError("repo_name is required") self.base_path = f"/Repos/{user}/{repo_name}" self.api_token = config.token - self.host = config.host.rstrip("/") + self.host = strip_databricks_url(config.host) self.workspace_api_base_path = f"{self.host}/api/2.0/workspace" self.workspace_files_api_base_path = f"{self.host}/api/2.0/workspace-files/import-file" if config.insecure is None: @@ -272,6 +273,31 @@ async def delete(self, sub_path: str, *, session: aiohttp.ClientSession, recursi ssl=self.ssl, ) + async def exists(self, *, session: aiohttp.ClientSession) -> bool: + """Checks if the target repo the client will sync to exists. 
+ + Args: + session (aiohttp.ClientSession): client session + + Raises: + ClientError: failed to check repos API + + Returns: + bool: True if the repo exists, otherwise False + """ + headers = get_headers(self.api_token, self.name) + more_opts = {"ssl": self.ssl} if self.ssl is not None else {} + url = f"{self.host}/api/2.0/repos" + params = {"path_prefix": self.base_path} + async with session.get(url=url, headers=headers, params=params, **more_opts) as resp: + if resp.status == 200: + json_resp = await resp.json() + return self.base_path in [repo["path"] for repo in json_resp.get("repos", [])] + else: + txt = await resp.text() + dbx_echo(f"HTTP {resp.status}: {txt}") + raise ClientError(resp.status) + async def mkdirs(self, sub_path: str, *, session: aiohttp.ClientSession): check_path(sub_path) path = f"{self.base_path}/{sub_path}" diff --git a/dbx/templates/projects/python_basic/render/hooks/post_gen_project.py b/dbx/templates/projects/python_basic/render/hooks/post_gen_project.py index d1eaf320..0a7b54ae 100644 --- a/dbx/templates/projects/python_basic/render/hooks/post_gen_project.py +++ b/dbx/templates/projects/python_basic/render/hooks/post_gen_project.py @@ -92,8 +92,13 @@ def process_cloud_component(env: Environment): def process(): configure( - environment="default", workspace_dir=WORKSPACE_DIR, artifact_location=ARTIFACT_LOCATION, profile=PROFILE, - enable_inplace_jinja_support=False, enable_failsafe_cluster_reuse_with_assets=False + environment="default", + workspace_dir=WORKSPACE_DIR, + artifact_location=ARTIFACT_LOCATION, + profile=PROFILE, + enable_inplace_jinja_support=False, + enable_failsafe_cluster_reuse_with_assets=False, + enable_context_based_upload_for_execute=False, ) env = Environment(loader=FileSystemLoader(COMPONENTS_PATH)) diff --git a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/README.md b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/README.md index ada24f0a..62167a45 100644 --- a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/README.md +++ b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/README.md @@ -83,7 +83,7 @@ To start working with your notebooks from a Repos, do the following steps: databricks repos create --url --provider ``` This command will create your personal repository under `/Repos//{{cookiecutter.project_slug}}`. -3. Use `git_source` in your job definition as described [here](https://dbx.readthedocs.io/en/latest/examples/notebook_remote.html) +3. 
Use `git_source` in your job definition as described [here](https://dbx.readthedocs.io/en/latest/guides/python/devops/notebook/?h=git_source#using-git_source-to-specify-the-remote-source) ## CI/CD pipeline settings diff --git a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/setup.py b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/setup.py index c5eeaa53..670ad187 100644 --- a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/setup.py +++ b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/setup.py @@ -25,7 +25,7 @@ "pytest", "coverage[toml]", "pytest-cov", - "dbx>=0.7,<0.8" + "dbx>=0.8" ] setup( diff --git a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/tests/entrypoint.py b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/tests/entrypoint.py index 6d3a158d..31ca684e 100644 --- a/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/tests/entrypoint.py +++ b/dbx/templates/projects/python_basic/render/{{cookiecutter.project_name}}/tests/entrypoint.py @@ -2,5 +2,7 @@ import pytest -if __name__ == '__main__': - pytest.main(sys.argv[1:]) +if __name__ == "__main__": + exit_code = pytest.main(sys.argv[1:]) + if exit_code != pytest.ExitCode.OK: + raise RuntimeError(f"pytest returned non-zero exit code: {str(exit_code)}. See logs for details.") diff --git a/dbx/types.py b/dbx/types.py new file mode 100644 index 00000000..cbe391d6 --- /dev/null +++ b/dbx/types.py @@ -0,0 +1,6 @@ +from typing import Union + +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.job_task_settings import JobTaskSettings + +ExecuteTask = Union[V2dot0Workflow, JobTaskSettings] diff --git a/dbx/utils/__init__.py b/dbx/utils/__init__.py index 5231fd51..911d6c9f 100644 --- a/dbx/utils/__init__.py +++ b/dbx/utils/__init__.py @@ -4,6 +4,9 @@ import typer from rich import print as rich_print +from rich import reconfigure + +reconfigure(soft_wrap=True) def format_dbx_message(message: Any) -> str: diff --git a/dbx/utils/adjuster.py b/dbx/utils/adjuster.py deleted file mode 100644 index bb95d551..00000000 --- a/dbx/utils/adjuster.py +++ /dev/null @@ -1,96 +0,0 @@ -import pathlib -from typing import List, Dict, Any - -from databricks_cli.sdk import ApiClient - -from dbx.utils import dbx_echo -from dbx.utils.dependency_manager import DependencyManager -from dbx.utils.file_uploader import MlflowFileUploader, AbstractFileUploader -from dbx.utils.named_properties import WorkloadPropertiesProcessor, NewClusterPropertiesProcessor, PolicyNameProcessor - - -def adjust_job_definitions( - jobs: List[Dict[str, Any]], - dependency_manager: DependencyManager, - file_uploader: MlflowFileUploader, - api_client: ApiClient, -): - def adjustment_callback(p: Any): - return adjust_path(p, file_uploader) - - for job in jobs: - - # please note that all adjustments here have side effects to the main jobs object. 
- - workload_processor = WorkloadPropertiesProcessor(api_client) - new_cluster_processor = NewClusterPropertiesProcessor(api_client) - policy_name_processor = PolicyNameProcessor(api_client) - - adjustable_references = [] - - if "tasks" in job: - dbx_echo(f"Tasks section found in the job {job['name']}, job will be deployed as a multitask job") - adjustable_references += job["tasks"] - job_clusters = job.get("job_clusters", []) - for jc_reference in job_clusters: - cluster_definition = jc_reference.get("new_cluster", {}) - policy_name_processor.process(cluster_definition) - new_cluster_processor.process(cluster_definition) - else: - dbx_echo(f"Tasks section not found in the job {job['name']}, job will be deployed as a single-task job") - adjustable_references.append(job) - - for workload_reference in adjustable_references: - - dependency_manager.process_dependencies(workload_reference) - workload_processor.process(workload_reference) - - new_cluster_definition = workload_reference.get("new_cluster", {}) - - if new_cluster_definition: - policy_name_processor.process(new_cluster_definition) - new_cluster_processor.process(new_cluster_definition) - - walk_content(adjustment_callback, workload_reference) - - -def walk_content(func, content, parent=None, index=None): - if isinstance(content, dict): - for key, item in content.items(): - walk_content(func, item, content, key) - elif isinstance(content, list): - for idx, sub_item in enumerate(content): - walk_content(func, sub_item, content, idx) - else: - parent[index] = func(content) - - -def path_adjustment(candidate: str, file_uploader: AbstractFileUploader) -> str: - if candidate.startswith("file:"): - fuse_flag = candidate.startswith("file:fuse:") - replace_string = "file:fuse://" if fuse_flag else "file://" - local_path = pathlib.Path(candidate.replace(replace_string, "")) - - if not local_path.exists(): - raise FileNotFoundError( - f"Path {candidate} is referenced in the deployment configuration, but is non-existent." 
- ) - - adjusted_path = file_uploader.upload_and_provide_path(local_path, as_fuse=fuse_flag) - - return adjusted_path - - else: - return candidate - - -def adjust_path(candidate, file_uploader: AbstractFileUploader): - if isinstance(candidate, str): - # path already adjusted or points to another dbfs object - pass it - if candidate.startswith("dbfs") or candidate.startswith("/dbfs"): - return candidate - else: - adjusted_path = path_adjustment(candidate, file_uploader) - return adjusted_path - else: - return candidate diff --git a/dbx/utils/common.py b/dbx/utils/common.py index 9e5e8149..06fe4061 100644 --- a/dbx/utils/common.py +++ b/dbx/utils/common.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from typing import Dict, List, Optional import git @@ -52,19 +51,6 @@ def get_environment_data(environment: str) -> EnvironmentInfo: return ProjectConfigurationManager().get(environment) -def get_package_file() -> Optional[Path]: - dbx_echo("Locating package file") - file_locator = list(Path("dist").glob("*.whl")) - sorted_locator = sorted(file_locator, key=os.path.getmtime) # get latest modified file, aka latest package version - if sorted_locator: - file_path = sorted_locator[-1] - dbx_echo(f"Package file located in: {file_path}") - return file_path - else: - dbx_echo("Package file was not found") - return None - - def get_current_branch_name() -> Optional[str]: if "GITHUB_REF" in os.environ: ref = os.environ["GITHUB_REF"].split("/") diff --git a/dbx/utils/dependency_manager.py b/dbx/utils/dependency_manager.py deleted file mode 100644 index b4201557..00000000 --- a/dbx/utils/dependency_manager.py +++ /dev/null @@ -1,88 +0,0 @@ -from pathlib import Path -from typing import Optional, Dict, List, Union, Any - -import pkg_resources - -from dbx.models.deployment import BuildConfiguration -from dbx.utils import dbx_echo -from dbx.utils.common import get_package_file -from dbx.api.build import prepare_build - -LibraryReference = Dict[str, Union[str, Dict[str, Any]]] - - -class DependencyManager: - """ - This class manages dependency references in the job or task deployment. 
- """ - - def __init__(self, build_config: BuildConfiguration, global_no_package: bool, requirements_file: Optional[Path]): - self.build_config = build_config - self._global_no_package = global_no_package - self._core_package_reference: Optional[LibraryReference] = self._get_package_requirement() - self._requirements_references: List[LibraryReference] = self._get_requirements_from_file(requirements_file) - - @staticmethod - def _delete_managed_libraries(packages: List[pkg_resources.Requirement]) -> List[pkg_resources.Requirement]: - output_packages = [] - - for package in packages: - - if package.key == "pyspark": - dbx_echo("pyspark dependency deleted from the list of libraries, because it's a managed library") - else: - output_packages.append(package) - - return output_packages - - def _get_requirements_from_file(self, requirements_file: Optional[Path]) -> List[LibraryReference]: - if not requirements_file: - dbx_echo("No requirements file was provided") - return [] - else: - - if not requirements_file.exists(): - dbx_echo("Requirements file doesn't exist") - return [] - else: - with requirements_file.open(encoding="utf-8") as requirements_txt: - requirements_content = pkg_resources.parse_requirements(requirements_txt) - filtered_libraries = self._delete_managed_libraries(requirements_content) - requirements_payload = [{"pypi": {"package": str(req)}} for req in filtered_libraries] - return requirements_payload - - def _get_package_requirement(self) -> Optional[LibraryReference]: - """ - Prepare package requirement to be added into the definition in case it's required. - """ - prepare_build(self.build_config) - package_file = get_package_file() - - if self._global_no_package: - dbx_echo("No package definition will be added into any jobs in the given deployment") - return None - else: - if package_file: - return {"whl": f"file://{package_file}"} - else: - dbx_echo( - "Package file was not found! " - "Please check your dist folder if you expect to use package-based imports" - ) - return None - - def process_dependencies(self, reference: Dict[str, Any]): - reference_level_deployment_config = reference.get("deployment_config", {}) - no_package_reference = reference_level_deployment_config.get("no_package", False) - - if self._global_no_package and not no_package_reference: - dbx_echo( - ":warning: Global --no-package option is set to true, " - "but task or job level deployment config is set to false. " - "Global-level property will take priority." 
- ) - - reference["libraries"] = reference.get("libraries", []) + self._requirements_references - - if not no_package_reference: - reference["libraries"] += [self._core_package_reference] diff --git a/dbx/utils/file_uploader.py b/dbx/utils/file_uploader.py index 0fd86a5e..31bdf2e5 100644 --- a/dbx/utils/file_uploader.py +++ b/dbx/utils/file_uploader.py @@ -1,6 +1,7 @@ +import functools from abc import ABC, abstractmethod from pathlib import Path, PurePosixPath -from typing import Optional, Dict +from typing import Optional, Tuple import mlflow from retry import retry @@ -11,36 +12,54 @@ class AbstractFileUploader(ABC): def __init__(self, base_uri: Optional[str] = None): - self._base_uri = base_uri - self._uploaded_files: Dict[Path, str] = {} # contains mapping from local to remote paths for all uploaded files + self.base_uri = base_uri @abstractmethod def _upload_file(self, local_file_path: Path): """""" def _verify_fuse_support(self): - if not self._base_uri.startswith("dbfs:/"): + if not self.base_uri.startswith("dbfs:/"): raise Exception( "Fuse-based paths are not supported for non-dbfs artifact locations." "If fuse-like paths are required, consider using experiment with DBFS as a location." ) - def upload_and_provide_path(self, local_file_path: Path, as_fuse: Optional[bool] = False) -> str: - if as_fuse: - self._verify_fuse_support() + def _postprocess_path(self, local_file_path: Path, as_fuse) -> str: + remote_path = "/".join([self.base_uri, str(local_file_path.as_posix())]) + remote_path = remote_path.replace("dbfs:/", "/dbfs/") if as_fuse else remote_path - if local_file_path in self._uploaded_files: - remote_path = self._uploaded_files[local_file_path] - else: - dbx_echo(f":arrow_up: Uploading local file {local_file_path}") - self._upload_file(local_file_path) - remote_path = "/".join([self._base_uri, str(local_file_path.as_posix())]) - self._uploaded_files[local_file_path] = remote_path - dbx_echo(f":white_check_mark: Uploading local file {local_file_path}") + if self.base_uri.startswith("wasbs://"): + remote_path = remote_path.replace("wasbs://", "abfss://") + remote_path = remote_path.replace(".blob.", ".dfs.") - remote_path = remote_path.replace("dbfs:/", "/dbfs/") if as_fuse else remote_path return remote_path + @staticmethod + def _preprocess_reference(ref: str) -> Tuple[Path, bool]: + _as_fuse = ref.startswith("file:fuse://") + _corrected = ref.replace("file:fuse://", "") if _as_fuse else ref.replace("file://", "") + _path = Path(_corrected) + return _path, _as_fuse + + @staticmethod + def _verify_reference(ref, _path: Path): + if not _path.exists(): + raise FileNotFoundError(f"Provided file reference: {ref} doesn't exist in the local FS") + + @functools.lru_cache(maxsize=3000) + def upload_and_provide_path(self, file_reference: str) -> str: + local_file_path, as_fuse = self._preprocess_reference(file_reference) + self._verify_reference(file_reference, local_file_path) + + if as_fuse: + self._verify_fuse_support() + + dbx_echo(f":arrow_up: Uploading local file {local_file_path}") + self._upload_file(local_file_path) + dbx_echo(f":white_check_mark: Uploading local file {local_file_path}") + return self._postprocess_path(local_file_path, as_fuse) + class MlflowFileUploader(AbstractFileUploader): """ @@ -65,12 +84,12 @@ def _verify_fuse_support(self): dbx_echo("Skipping the FUSE check since context-based uploader is used") def _upload_file(self, local_file_path: Path): - self._client.upload_file(local_file_path, self._base_uri) + self._client.upload_file(local_file_path, 
self.base_uri) def __del__(self): try: dbx_echo("Cleaning up the temp directory") - self._client.remove_dir(self._base_uri) + self._client.remove_dir(self.base_uri) dbx_echo(":white_check_mark: Cleaning up the temp directory") except Exception as e: dbx_echo(f"Cannot cleanup temp directory due to {e}") diff --git a/dbx/utils/job_listing.py b/dbx/utils/job_listing.py deleted file mode 100644 index 561e387e..00000000 --- a/dbx/utils/job_listing.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import List, Dict, Any, Optional - -from databricks_cli.sdk.service import JobsService - - -def list_all_jobs(js: JobsService) -> List[Dict[str, Any]]: - all_jobs = js.list_jobs( - version="2.0" - ) # version 2.0 is expected to list all jobs without iterations over limit/offset - return all_jobs.get("jobs", []) - - -def find_job_by_name(js: JobsService, job_name: str) -> Optional[Dict[str, Any]]: - all_jobs = list_all_jobs(js) - matching_jobs = [j for j in all_jobs if j["settings"]["name"] == job_name] - - if len(matching_jobs) > 1: - raise Exception( - f"""There are more than one jobs with name {job_name}. - Please delete duplicated jobs first""" - ) - - if not matching_jobs: - return None - else: - return matching_jobs[0] diff --git a/dbx/utils/named_properties.py b/dbx/utils/named_properties.py deleted file mode 100644 index c8f03f3e..00000000 --- a/dbx/utils/named_properties.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Databricks Jobs API supports various arguments with _id-based API, but this is really quirky for end users. -In particular, this code supports the following resolutions: -1. Job and task-name resolutions: - - existing_cluster_name - - for new_cluster on job name: - - aws_attributes.instance_profile_id - - new_cluster.instance_pool_name - - new_cluster.driver_instance_pool_name - - new_cluster.aws_attributes.instance_profile_name -2. Multitask-jobs job_clusters properties -3. 
policy_name on the new_cluster structure -""" -import abc -import collections.abc -import json -from typing import Dict, Any - -from databricks_cli.sdk import ApiClient, InstancePoolService, PolicyService - -from dbx.api.cluster import ClusterController -from dbx.utils import dbx_echo -from dbx.utils.policy_parser import PolicyParser - - -class AbstractProcessor(abc.ABC): - def __init__(self, api_client: ApiClient): - self._api_client = api_client - - @abc.abstractmethod - def process(self, object_reference: Dict[str, Any]): - """""" - - -class PolicyNameProcessor(AbstractProcessor): - def process(self, object_reference: Dict[str, Any]): - policy_name = object_reference.get("policy_name") - - if policy_name: - dbx_echo(f"Processing policy name {policy_name}") - policy_spec = self._preprocess_policy_name(policy_name) - policy = json.loads(policy_spec["definition"]) - policy_props = PolicyParser(policy).parse() - self._deep_update(object_reference, policy_props, policy_name) - object_reference["policy_id"] = policy_spec["policy_id"] - - @staticmethod - def _deep_update(d: Dict, u: collections.abc.Mapping, policy_name: str) -> Dict: - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = PolicyNameProcessor._deep_update(d.get(k, {}), v, policy_name) - else: - # if the key is already provided in deployment configuration, we need to verify the value - # if value exists, we verify that it's the same as in the policy - existing_value = d.get(k) - if existing_value: - if existing_value != v: - raise Exception( - f"For key {k} there is a value in the cluster definition: {existing_value} \n" - f"However this value is fixed in the policy {policy_name} and shall be equal to: {v}" - ) - d[k] = v - return d - - def _preprocess_policy_name(self, policy_name: str): - policies = PolicyService(self._api_client).list_policies().get("policies", []) - found_policies = [p for p in policies if p["name"] == policy_name] - - if not found_policies: - raise Exception(f"Policy {policy_name} not found") - - if len(found_policies) > 1: - raise Exception(f"Policy with name {policy_name} is not unique. 
Please make unique names for policies.") - - policy_spec = found_policies[0] - return policy_spec - - -class WorkloadPropertiesProcessor(AbstractProcessor): - def process(self, object_reference: Dict[str, Any]): - self._preprocess_existing_cluster_name(object_reference) - - def _preprocess_existing_cluster_name(self, object_reference: Dict[str, Any]): - existing_cluster_name = object_reference.get("existing_cluster_name") - - if existing_cluster_name: - dbx_echo("Named parameter existing_cluster_name is provided, looking for it's id") - existing_cluster_id = ClusterController(self._api_client).preprocess_cluster_args( - existing_cluster_name, None - ) - object_reference["existing_cluster_id"] = existing_cluster_id - - -class NewClusterPropertiesProcessor(AbstractProcessor): - def process(self, object_reference: Dict[str, Any]): - self._preprocess_instance_profile_name(object_reference) - self._preprocess_driver_instance_pool_name(object_reference) - self._preprocess_instance_pool_name(object_reference) - - @staticmethod - def _name_from_profile(profile_def) -> str: - return profile_def.get("instance_profile_arn").split("/")[-1] - - def _preprocess_instance_profile_name(self, object_reference: Dict[str, Any]): - instance_profile_name = object_reference.get("aws_attributes", {}).get("instance_profile_name") - - if instance_profile_name: - dbx_echo("Named parameter instance_profile_name is provided, looking for it's id") - all_instance_profiles = self._api_client.perform_query("get", "/instance-profiles/list").get( - "instance_profiles", [] - ) - instance_profile_names = [self._name_from_profile(p) for p in all_instance_profiles] - matching_profiles = [ - p for p in all_instance_profiles if self._name_from_profile(p) == instance_profile_name - ] - - if not matching_profiles: - raise Exception( - f"No instance profile with name {instance_profile_name} found." - f"Available instance profiles are: {instance_profile_names}" - ) - - if len(matching_profiles) > 1: - raise Exception( - f"Found multiple instance profiles with name {instance_profile_name}" - f"Please provide unique names for the instance profiles." 
- ) - - object_reference["aws_attributes"]["instance_profile_arn"] = matching_profiles[0]["instance_profile_arn"] - - def _preprocess_driver_instance_pool_name(self, object_reference: Dict[str, Any]): - self._generic_instance_pool_name_preprocessor( - object_reference, "driver_instance_pool_name", "instance_pool_id", "driver_instance_pool_id" - ) - - def _preprocess_instance_pool_name(self, object_reference: Dict[str, Any]): - self._generic_instance_pool_name_preprocessor( - object_reference, "instance_pool_name", "instance_pool_id", "instance_pool_id" - ) - - def _generic_instance_pool_name_preprocessor( - self, object_reference: Dict[str, Any], named_parameter, search_id, property_name - ): - instance_pool_name = object_reference.get(named_parameter) - - if instance_pool_name: - dbx_echo(f"Named parameter {named_parameter} is provided, looking for its id") - all_pools = InstancePoolService(self._api_client).list_instance_pools().get("instance_pools", []) - instance_pool_names = [p.get("instance_pool_name") for p in all_pools] - matching_pools = [p for p in all_pools if p["instance_pool_name"] == instance_pool_name] - - if not matching_pools: - raise Exception( - f"No instance pool with name {instance_pool_name} found, available pools: {instance_pool_names}" - ) - - if len(matching_pools) > 1: - raise Exception( - f"Found multiple pools with name {instance_pool_name}, please provide unique names for the pools" - ) - - object_reference[property_name] = matching_pools[0][search_id] diff --git a/dbx/utils/policy_parser.py b/dbx/utils/policy_parser.py deleted file mode 100644 index d8d89842..00000000 --- a/dbx/utils/policy_parser.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -This policy parser is based on: -- API Doc: policy parser is based on API doc https://docs.databricks.com/dev-tools/api/latest/policies.html -- Policy definition docs: - - AWS: https://docs.databricks.com/administration-guide/clusters/policies.html#cluster-policy-attribute-paths - - Azure: https://docs.microsoft.com/en-us/azure/databricks/administration-guide/clusters/policies - - GCP: Cluster policies were not supported at the moment of 0.1.3 release. -Please note that only "fixed" values will be automatically added to the job definition. -""" -from typing import Dict, Any, Tuple, List, Union - - -class PolicyParser: - def __init__(self, policy: Dict[str, Dict[str, Any]]): - self.source_policy = policy - - def parse(self) -> Dict[str, Any]: - """ - Idea of this function is the following: - 1. Walk through all items in the source policy - 2. Take only fixed policies - 3. parse the key: - 3.0 if there are no dots, key is a simple string - 3.1 key might be either a composite one, with dots - then we split this key by dots into a tuple - 3.2 a specific case is with spark_conf (such keys might have multiple dots after the spark_conf - 4. definitions will be added into parsed_props variable - 5. Generate Jobs API compatible dictionary with fixed properties - :return: dictionary in a Jobs API compatible format - """ - parsed_props: List[Tuple[Union[List[str], str], Any]] = [] - for key, definition in self.source_policy.items(): - if definition.get("type") == "fixed": - # preprocess key - # for spark_conf keys might contain multiple dots - if key.startswith("spark_conf"): - _key = key.split(".", 1) - elif "." 
in key: - _key = key.split(".") - else: - _key = key - _value = definition["value"] - parsed_props.append((_key, _value)) - - result = {} - init_scripts = {} - - for key_candidate, value in parsed_props: - if isinstance(key_candidate, str): - result[key_candidate] = value - else: - if key_candidate[0] == "init_scripts": - idx = int(key_candidate[1]) - payload = {key_candidate[2]: {key_candidate[3]: value}} - init_scripts[idx] = payload - else: - d = {key_candidate[-1]: value} - for _k in key_candidate[1:-1]: - d[_k] = d - - updatable = result.get(key_candidate[0], {}) - updatable.update(d) - - result[key_candidate[0]] = updatable - - init_scripts = [init_scripts[k] for k in sorted(init_scripts)] - if init_scripts: - result["init_scripts"] = init_scripts - - return result diff --git a/dbx/utils/url.py b/dbx/utils/url.py new file mode 100644 index 00000000..4bdf158f --- /dev/null +++ b/dbx/utils/url.py @@ -0,0 +1,12 @@ +from urllib.parse import urlparse + + +def strip_databricks_url(url: str) -> str: + """ + Mlflow API requires url to be stripped, e.g. + {scheme}://{netloc}/some-stuff/ shall be transformed to {scheme}://{netloc} + :param url: url to be stripped + :return: stripped url + """ + parsed = urlparse(url) + return f"{parsed.scheme}://{parsed.netloc}" diff --git a/docs/concepts/artifact_storage.md b/docs/concepts/artifact_storage.md new file mode 100644 index 00000000..f879d930 --- /dev/null +++ b/docs/concepts/artifact_storage.md @@ -0,0 +1,74 @@ +# :material-diamond-stone: Artifact storage + +To properly resolve the [file references](../features/file_references.md) and store the artifacts (e.g. packages, files and deployment definitions), +`dbx` uses persistent cloud storage which is called **artifact storage**. + +## :material-bookshelf: Storage configuration +To perform upload/download/lookup operations, `dbx` uses MLflow APIs under the hood. +Currently, `dbx` only supports MLflow-based API for file operations. + +MLflow-based artifact storage properties are specified per environment in the [project configuration file](../reference/project.md). + +When project is configured by default, the definition looks like this: + +```json title="project.json" hl_lines="6-8" +{ + "environments": { + "default": { + "profile": "some-profile", + "storage_type": "mlflow", + "properties": { + "workspace_directory": "/Shared/dbx/some-project-name", + "artifact_location": "dbfs:/Shared/dbx/projects/some-project-name" + } + } + } +} +``` + +1. Workspace directory points to an MLflow experiment which will be used as a basis for Mlflow-based operations. + There is no need to create a new experiment before running any `dbx` commands. Please note that for security purposes to protect this experiment use the experiment permissions model. +2. Artifact location is one of the locations supported by MLflow on Databricks, namely: `dbfs://` (both mounts and root container are supported), `s3://`, `wasbs://` and `gs://`. + + +!!! warning "Security of the experiment files and the artifact storage" + + To ensure protected R/W access to the deployed objects we **recommend** using `s3://` or `wasbs://` or `gs://` artifact locations. + + **By default, any file stored in the `dbfs://` non-mounted location could be accessed in R/W mode by any user of the same workspace.** + + Therefore, we recommend storing your deployment artifacts in `s3://`, `wasbs://` or `gs://`-based artifact locations. + By doing so you'll ensure that only the relevant people will be able to work with this location. + +!!! 
tip "Using a non-`dbfs` based artifact location with `dbx execute`" + + Available since 0.8.0 + + Since `dbx execute` expects from the artifact location FUSE support, this limitation could be overcome by using context-based loading. + + To enable context-based loading, you can either specify a project-wide property by configuring it: + ```bash + dbx configure --enable-context-based-upload-for-execute + ``` + Alternatively, you can specift it per each execution by using `dbx execute` with `--upload-via-context` switch. + + +!!! info "ADLS resolution specifics" + + Since `mlfow` only supports `wasbs://`-based paths, and Databricks API requires job objects to be referenced via `abfss://`, + during the file reference resolution `dbx` will use the `wasbs://` protocol for uploads but references will be still resolved into `abfss://` format. + + +## :material-book-plus: Additional libraries + +Available since 0.8.8 + +In case if you're using a cloud-based storage, you might require additional libraries to be installed. + +Add the following extra to your `pip install dbx[chosen-identifier]`: + +- :material-microsoft-azure: for `wasbs://` use `dbx[azure]` +- :material-aws: for `s3://` use `dbx[aws]` +- :material-google-cloud: for `gs://` use `dbx[gcp]` + + diff --git a/docs/concepts/cluster_types.md b/docs/concepts/cluster_types.md index d870cc33..3fde51b8 100644 --- a/docs/concepts/cluster_types.md +++ b/docs/concepts/cluster_types.md @@ -93,6 +93,6 @@ To sum up the cases and potential choices for :material-lightning-bolt-circle: a * Use **all-purpose cluster** for development loop. Synchronize local files to Repo with Notebooks via [`dbx sync repo`](../reference/cli.md#dbx-sync-repo) as [described here](../guides/python/devloop/mixed.md). * Use **job clusters** and `dbx deploy` together with `dbx launch` for automated workflows as [described here](../guides/python/devops/mixed.md). * Developing a :material-language-java: JVM-based project in IDE? - * Use local tests and **job clusters** with [:fontawesome-solid-microchip: instance pools](https://docs.databricks.com/clusters/instance-pools/index.html) for development loop as [described here](../guides/jvm/jvm_devloop.md) + * Use local tests and **job clusters** with [:fontawesome-solid-microchip: instance pools](https://docs.databricks.com/clusters/instance-pools/index.html) for development loop as [described here](../guides/jvm/jvm_devloop.md). * Use **job clusters** and `dbx deploy` together with `dbx launch` for automated workflows as [described here](../guides/jvm/jvm_devops.md). 
diff --git a/docs/custom/custom.css b/docs/custom/custom.css deleted file mode 100644 index d7abe4f0..00000000 --- a/docs/custom/custom.css +++ /dev/null @@ -1,3 +0,0 @@ -.nowrap { - white-space: nowrap ; -} diff --git a/docs/extras/styles.css b/docs/extras/styles.css new file mode 100644 index 00000000..76489bad --- /dev/null +++ b/docs/extras/styles.css @@ -0,0 +1,181 @@ +.nowrap { + white-space: nowrap; +} + +.margined { + margin-top: 10vh; +} + +.centered_text { + display: inline-block; + vertical-align: middle; +} + +.date-tooltip { + position: relative; /* making the .tooltip span a container for the tooltip text */ +} + +.date-tooltip:before { + content: attr(data-text); /* here's the magic */ + position:absolute; + + /* vertically center */ + /*top:50%;*/ + transform:translateY(60%); + + /*left:100%;*/ + margin-top:15px; /* and add a small left margin */ + + /* basic styles */ + width:100px; + padding:4px; + border-radius:10px; + text-align:center; + color: var(--md-primary-bg-color); + background: var(--md-primary-fg-color); + display:none; /* hide by default */ +} + +.date-tooltip:hover:before { + display:block; +} + + + + +.index_separator { + border-right: rgb(255, 54, 33) solid 0.3em; + display: inline-block; + animation-name: line_animation; + animation-duration: 1.4s; + animation-timing-function: linear; + animation-fill-mode: forwards; +} + + +@keyframes line_animation { + from { + height: 0px; + } + to { + height: 200px; + } +} + +.title_element { + font-size: 4em; + margin: 0 !important; + font-family: DM Mono; + animation-name: text_appear; + animation-duration: 1.4s; + animation-timing-function: linear; + animation-fill-mode: forwards; +} + +.desc_element { + font-size: 1em !important; + font-family: DM Mono; + animation-name: text_appear; + animation-duration: 1.4s; + animation-timing-function: linear; + animation-fill-mode: forwards; +} + +@keyframes text_appear { + from { + opacity: 0; + } + to { + opacity: 1; + } +} + + +.bounce { + -moz-animation: bounce 3s infinite; + -webkit-animation: bounce 3s infinite; + animation: bounce 3s infinite; +} +@-moz-keyframes bounce { + 0%, 20%, 50%, 80%, 100% { + -moz-transform: translateY(0); + transform: translateY(0); + } + 40% { + -moz-transform: translateY(-30px); + transform: translateY(-30px); + } + 60% { + -moz-transform: translateY(-15px); + transform: translateY(-15px); + } +} +@-webkit-keyframes bounce { + 0%, 20%, 50%, 80%, 100% { + -webkit-transform: translateY(0); + transform: translateY(0); + } + 40% { + -webkit-transform: translateY(-30px); + transform: translateY(-30px); + } + 60% { + -webkit-transform: translateY(-15px); + transform: translateY(-15px); + } +} +@keyframes bounce { + 0%, 20%, 50%, 80%, 100% { + -moz-transform: translateY(0); + -ms-transform: translateY(0); + -webkit-transform: translateY(0); + transform: translateY(0); + } + 40% { + -moz-transform: translateY(-30px); + -ms-transform: translateY(-30px); + -webkit-transform: translateY(-30px); + transform: translateY(-30px); + } + 60% { + -moz-transform: translateY(-15px); + -ms-transform: translateY(-15px); + -webkit-transform: translateY(-15px); + transform: translateY(-15px); + } +} + +.cards-container { + display: flex; + width: 100%; + justify-content: center; + padding-right: 5rem; + padding-left: 5rem; + font-size: 0.5em; +} + +.card-element-title { + padding-top: 0.1em; + font-size: 1em; + margin-bottom: 0.5em; +} + +.card-element { + box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2); + display: flex; + flex: 50%; + flex-direction: column; + 
padding: 0.5em; + border: hsla(232, 62%, 95%, 0.66) solid 1px; + margin: 1em; + min-height: 10em; + border-radius: 5px; +} + +.card-element:hover { + border: rgb(255, 54, 33) solid 1px; + } + +.rst-current-version { + display: none; +} diff --git a/docs/faq.md b/docs/faq.md index fd1416dc..7ff4b5e8 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -1,14 +1,31 @@ # Frequently asked questions -- Does `dbx` support R? +## Does `dbx` support R? Unfortunately at the moment `dbx` doesn't support R in `dbx execute` mode. At the same time you can use `dbx deploy` and `dbx launch` to work with R-based notebook tasks. -- Does `dbx` support `poetry`? +## Does `dbx` support `poetry`? Yes, setup the build logic for poetry as described [here](./features/build_management.md). -- Does `dbx` support `flit`? +## Does `dbx` support `flit`? Yes, setup the build logic for flit as described [here](./features/build_management.md). + +## What's the difference between `dbx execute` and `dbx launch`? + +The `dbx execute` command runs your code on `all-purpose` [cluster](../concepts/cluster_types#all-purpose-clusters). +It's very handy for interactive development and data exploration. + +!!! danger "Don't use `dbx execute` for production workloads" + + It's not recommended to use `dbx execute` for production workloads. Run your workflows on the dedicated job clusters instead. + Reasoning is described in detail in the [concepts section](../concepts/cluster_types). + +In contrast to the `dbx execute`, `dbx launch` launches your workflow on a [dedicated job cluster](./concepts/cluster_types#job-clusters). This is a recommended way for CI pipelines, automated launches etc. + +When in doubt, follow the [summary section](../concepts/cluster_types#summary) for precise guidance. + +## Can I have multiple `deployment files` ? +Yes, `dbx deploy` accepts `--deployment-file PATH` as described [here](./reference/cli/#dbx-deploy). diff --git a/docs/features/file_references.md b/docs/features/file_references.md index 88e21254..07505196 100644 --- a/docs/features/file_references.md +++ b/docs/features/file_references.md @@ -3,6 +3,8 @@ During the workflow deployment, you frequently would like to upload some specific files to `DBFS` and reference them in the workflow definition. +## :material-file-link: File referencing by example + Any keys referenced in the deployment file starting with `file://` or `file:fuse://` will be uploaded to the artifact storage. @@ -47,16 +49,40 @@ environments: As you can see there are two **different** prefixes for file references. +## :material-file-search: FUSE and standard reference resolution + There are two types of how the file path will be resolved and referenced in the final deployment definition: === "Standard" This definition looks like this `file://some/path/in/project/some.file`.
-    It will be resolved into `dbfs:///some/path/in/project/some.file`.
+    It will be resolved into `/some/path/in/project/some.file`.
+
+    Files that were referenced in a standard way are in most cases used as workflow properties (e.g. init scripts or `spark_python_file` references).
+    To programmatically read a file that was referenced in a standard way, you'll need to use object-storage-compatible APIs, for example the Spark APIs:
+    ```python
+    standard_referenced_path = "dbfs://some/path"  # or "s3://some/path" or "gs://some/path" or "abfss://some/path"
+
+    def read_text_payload(_path):
+        return "\n".join(spark.read.format("text").load(_path).select("value").toPandas()["value"])
+
+    raw_text_payload = read_text_payload(standard_referenced_path)
+    ```

 === "FUSE"

     This definition looks like this `file:fuse://some/path/in/project/some.file`.
- It will be resolved into `/dbfs//some/path/in/project/some.file`
. + It will be resolved into `/dbfs//some/path/in/project/some.file`. + + In most cases FUSE-based paths are used to pass something into a library which only supports reading files from local FS. + + For instance, a use-case might be to read a configuration file using Python `pathlib` library: + ```python + from pathlib import Path + fuse_based_path = "/dbfs/some/path" + payload = Path(fuse_based_path).read_text() + ``` + Although FUSE is a very convenient approach, unfortunately it only works with `dbfs://`-based artifact locations (both mounted and non-mounted). + +!!! tip "Various artifact storage types process references differently" -The latter type of path resolution might come in handy when the using system doesn't know how to work with cloud storage protocols. + Please read the [artifact storage](../concepts/artifact_storage.md) for details on how various references would work with different artifact storage types. diff --git a/docs/features/named_properties.md b/docs/features/named_properties.md index aa83526d..780197ca 100644 --- a/docs/features/named_properties.md +++ b/docs/features/named_properties.md @@ -1,17 +1,164 @@ -# :material-rename-box: Named properties +# :material-rename-box: Name-based properties referencing -With `dbx` you can use name-based properties instead of providing ids in the [:material-file-code: deployment file](../reference/deployment.md). +## :material-cog-stop: Legacy approach + +!!! danger "This approach is considered legacy" + + Please don't use the approach described here. Use the new approach described below on the same page. + +With `dbx` you can use name-based properties instead of providing ids in +the [:material-file-code: deployment file](../reference/deployment.md). The following properties are supported: -- :material-state-machine: `existing_cluster_name` will be automatically replaced with `existing_cluster_id` -- :fontawesome-solid-microchip: `new_cluster.instance_pool_name` will be automatically replaced with `new_cluster.instance_pool_id` -- :fontawesome-solid-microchip: `new_cluster.driver_instance_pool_name` will be automatically replaced with `new_cluster.driver_instance_pool_id` -- :material-aws: `new_cluster.aws_attributes.instance_profile_name` will be automatically replaced with `new_cluster.aws_attributes.instance_profile_arn` -- :material-list-status: `new_cluster.policy_name` will automatically fetch all the missing policy parts and properly resolved them, replacing the `policy_name` with `policy_id` +- :material-state-machine: `existing_cluster_name` will be automatically replaced with `existing_cluster_id` +- :fontawesome-solid-microchip: `new_cluster.instance_pool_name` will be automatically replaced + with `new_cluster.instance_pool_id` +- :fontawesome-solid-microchip: `new_cluster.driver_instance_pool_name` will be automatically replaced + with `new_cluster.driver_instance_pool_id` +- :material-aws: `new_cluster.aws_attributes.instance_profile_name` will be automatically replaced + with `new_cluster.aws_attributes.instance_profile_arn` +- :material-list-status: `new_cluster.policy_name` will automatically fetch all the missing policy parts and properly + resolved them, replacing the `policy_name` with `policy_id` By this simplification, you don't need to look-up for these id-based properties, you can simply provide the names. -!!! warning +!!! warning "Name verification" `dbx` will automatically check if the provided name exists and is unique, and if it's doesn't or it's non-unique you'll get an exception. + +!!! 
danger "DLT support" + + Please note that `*_name`-based legacy properties **will not work** with DLT. Use the reference-based approach described below. + +## :material-vector-link: Reference-based approach + +Available since 0.8.0 + +With the new approach introduced in 0.8.0, we've made the named parameter passing way easier. + +Simply use the string prefixed by the object type to create a reference which will be automatically replaced during +deployment. + +General format for a reference looks line this: + +``` +object-type://object-name +``` + +The following references are supported: + +| Reference prefix | Referencing target | API Method used for reference resolution | +|-----------------------------|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------| +| `instance-pool://` | Instance Pools | [ListInstancePools](https://docs.databricks.com/dev-tools/api/latest/instance-pools.html#list) | +| `instance-profile://` | Instance Profiles | [ListInstanceProfiles](https://docs.databricks.com/dev-tools/api/latest/instance-profiles.html#list) | +| `pipeline://` | Delta Live Tables | [ListPipelines](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-api-guide.html#list-pipelines) | +| `service-principal://` | Service Principals | [GetServicePrincipals (for workspaces)](https://docs.databricks.com/dev-tools/api/latest/scim/scim-sp.html#get-service-principals) | +| `warehouse://` | Databricks SQL Warehouses | [ListSqlWarehouses](https://docs.databricks.com/sql/api/sql-endpoints.html#list) | +| `query://` | Databricks SQL Queries | [GetSqlQueries](https://docs.databricks.com/sql/api/queries-dashboards.html#operation/sql-analytics-get-queries) | +| `dashboard://` | Databricks SQL Dashboards | [GetSqlDashboards](https://docs.databricks.com/sql/api/queries-dashboards.html#operation/get-sql-analytics-dashboards) | +| `alert://` | Databricks SQL Alerts | [GetSqlAlerts](https://docs.databricks.com/sql/api/queries-dashboards.html#operation/databricks-sql-get-alerts) | +| `cluster-policy://` | Cluster Policies | [ListClusterPolicies](https://docs.databricks.com/dev-tools/api/latest/policies.html#operation/list-cluster-policies) | +| `file://` or `file:fuse://` | Files | Please refer to the [file references documentation](./file_references.md) | + +The provided object references are expected to be **unique**. If the name of the object is not unique, an error will be +raised. + +## :material-list-status: Cluster policies resolution + +[Cluster policies](https://docs.databricks.com/administration-guide/clusters/policies.html) is a very convenient +interface that allows generalizing specific rules to a wide set of clusters. + +`dbx` provides capabilities to reference the policy name in the cluster definition, and some of the policy +properties will be automatically added to the cluster definition during deployment step. + +### :material-format-list-checks: Resolution logic for properties + +Please note that cluster policies are only resolved in the following cases: + +- `policy_id` OR `policy_name` (latter is legacy) are provided as a part of `new_cluster` definition +- `policy_id` startswith `cluster-policy://` + +The following logic is then applied to the policy and cluster definition: + +1. Policy definition is traversed and transformed into Jobs API compatible format. Only the `fixed` properties are + selected during traversal. +2. 
Policy definition deeply updates the cluster definition. If there are any keys provided in the cluster definition + that are fixed in the policy, an error will be thrown. +3. Updated cluster definition goes back to the overall workflow definition + +!!! warning "Other policy elements" + + `dbx` doesn't resolve and verify any other policy elements except the [Fixed ones](https://docs.databricks.com/administration-guide/clusters/policies.html#fixed-policy). + + Therefore, if you have for instance: + + * [Forbidden Policies](https://docs.databricks.com/administration-guide/clusters/policies.html#forbidden-policy) + * [Limiting Policies](https://docs.databricks.com/administration-guide/clusters/policies.html#limiting-policies-common-fields) + * [Allowlist Policies](https://docs.databricks.com/administration-guide/clusters/policies.html#allow-list-policy) + + They will only be resolved during the workflow deployment API call. + +### :material-script: Init scripts resolution logic + +Available since 0.8.0 + +[Init scripts](https://docs.databricks.com/clusters/init-scripts.html) is a powerful tool in Databricks to setup the +workflow environment before the workflow is running. + +A very common use case is +to [setup the Python pip.conf](https://kb.databricks.com/en_US/clusters/install-private-pypi-repo) +if the workflow needs some private packages, then you don't need to declare it in +each [pip install](https://docs.databricks.com/libraries/notebooks-python-libraries.html#install-a-private-package-with-credentials-managed-by-databricks-secrets-with-pip) +. + +To properly resolve init scripts together with the policy settings, dbx will merge in order and with deduplication the +init scripts from the cluster policy and those from the key `new_cluster.init_scripts`. + +For instance, assuming there is a policy `policy-with-pip-install-script` that enforces adding an `init_script`: +```json title="cluster-policy.json" +{ + "init_scripts.0.dbfs.destination": { + "type": "fixed", + "value": "dbfs://some/path/script.sh" + } +} +``` + +With the following [:material-file-code: deployment file](../reference/deployment.md) that references this policy: + +```yaml title="conf/deployment.yml" linenums="1" hl_lines="8 12" +# irrelevant parts are omitted +environments: + default: + workflows: + - name: workflow_name + job_clusters: + - new_cluster: + policy_id: "cluster-policy://policy-with-pip-install-script" + init_scripts: + - dbfs: + destination: dbfs:/some/path/install_sql_driver.sh + tasks: + ... +``` + +`dbx` will correctly resolve the `init_scripts` array and turn it into the following definition: + +```json title="playload-for-api.json" +{ + "policy_id": "AAABBBCCCDDDD", + "init_scripts": [ + { + "dbfs": { + "destination": "dbfs://some/path/script.sh" + } + }, + { + "dbfs": { + "destination": "dbfs:/some/path/install_sql_driver.sh" + } + } + ] +} +``` diff --git a/docs/features/permissions_management.md b/docs/features/permissions_management.md index c1922a63..d7a13069 100644 --- a/docs/features/permissions_management.md +++ b/docs/features/permissions_management.md @@ -1,41 +1,35 @@ # :fontawesome-solid-users-gear: Permissions management -`dbx` supports permissions management both for Jobs API 2.0 and Jobs API 2.1. +`dbx` supports permissions management for Jobs API :material-surround-sound-2-0: and Jobs API :material-surround-sound-2-1: as well as for workflows in the format of :material-table-heart: Delta Live Tables. !!! 
tip You can find the full specification for Permissions API [here](https://docs.databricks.com/dev-tools/api/latest/permissions.html). -## :material-tag-outline: For Jobs API 2.0 +## :material-file-check: Providing the permissions -To manage permissions for Jobs API 2.0, provide the following payload at the workflow level: +To manage permissions provide the following payload at the workflow level: ```yaml environments: default: workflows: - - name": "job-v2.0" - permissions: - ## here goes payload compliant with Permissions API - access_control_list: - - user_name: "some_user@example.com" - permission_level: "IS_OWNER" - - group_name: "some-user-group" - permission_level: "CAN_VIEW" -``` - - - -## :material-tag-multiple: For Jobs API 2.1 - -To manage permissions for Jobs API 2.1, provide the following payload at the workflow level: + # example for DLT pipeline + - name: "some-dlt-pipeline" + libraries: + - notebook: + path: "/some/repos" + access_control_list: + - user_name: "some_user@example.com" + permission_level: "IS_OWNER" + - group_name: "some-user-group" + permission_level: "CAN_VIEW" -```yaml -environments: - default: - workflows: - - name": "job-v2.0" + # example for multitask workflow + - name: "some-workflow" + tasks: + ... access_control_list: - user_name: "some_user@example.com" permission_level: "IS_OWNER" @@ -43,5 +37,9 @@ environments: permission_level: "CAN_VIEW" ``` -Please note that in both cases (v2.0 and v2.1) the permissions **must be exhaustive**. -It means that per each job at least the job owner shall be specified (even if it's already specified in the UI). +Please note that the permissions **must be exhaustive**. +It means that per each workflow at least the owner (`permission_level: "IS_OWNER"`) shall be specified (even if it's already specified in the UI). + +!!! tip "Managing permissions for service principals" + + Take a look at [this example](../reference/deployment.md#managing-the-workflow-as-a-service-principal) in the deployment file reference. diff --git a/docs/guides/general/custom_templates.md b/docs/guides/general/custom_templates.md new file mode 100644 index 00000000..6237fe15 --- /dev/null +++ b/docs/guides/general/custom_templates.md @@ -0,0 +1,81 @@ +# :octicons-repo-template-24: Using custom templates + +When using [`dbx init`](../../reference/cli.md#dbx-init) you have the option to define custom templates to use to create your project structure. +You can store these on git or decide to ship them as a Python package, and you can also use built in templates that are provided as a part of dbx. + +If you would like to create a custom template, feel free to re-use the code you’ll find in the +[`python_basic` dbx templates folder](https://github.com/databrickslabs/dbx/tree/main/dbx/templates/projects/python_basic) and adjust it according to your needs, for example: + +- generate other payload in the config file +- change the code in the `common.py` +- change the project structure +- choose another packaging tool, etc. + +Pretty much anything you would like to add to your template could be configured using this functionality. + +There are two options on how to ship your templates for further dbx usage: +- Git repo +- Python package + +This page further describes both approaches. + +## :material-git: Git repo + +The following command: +```bash +dbx init --path PATH [--checkout LOC] +``` +will check out a project template based on the [cookiecutter](https://cookiecutter.readthedocs.io/en/latest/index.html) approach from a git repository. 
+ +The `--checkout` flag is optional. If provided, `dbx` will check out a branch, tag or commit after cloning the repo. Default behaviour is to check out the `master` or `main` branch. + +``` +dbx init --path=https://git/repo/with/template.git +``` + +If you need versioning to your package add the `--checkout` flag: +```bash +#specific tag +dbx init --path=https://git/repo/with/template.git --checkout=v0.0.1 + +#specific branch +dbx init --path=https://git/repo/with/template.git --checkout=prod + +#specific git commit +dbx init --path=https://git/repo/with/template.git --checkout=aaa111bbb +``` + +## :material-package-variant-closed: Python package + +When you don't have direct access to the git repo, or you want to ship your template as packaged Python library this flag will be helpful to you. +Some organizations prefer to store their templates in an internal PyPi Repo as Python package. For such use-cases dbx can pick up the template source directly from the package code. + +In your package you will need to have a root folder called `render`, and dbx will pick up the template spec from it. + +Find the example steps below: + + 1. Create template package with the following structure (minimal example): + + ``` + ├── render # define your cookiecutter project inside it + │ └── cookiecutter.json + └── setup.py + ``` + + 2. Host this package in PyPi or Nexus, etc. + 3. Install package *before* the `init` command: + ```bash + pip install "my-template-pkg==0.0.1" # or whatever version + ``` + 4. Initialize template from the package: + ```bash + dbx init --package=my-template-pkg + ``` + +## :material-note-edit-outline: Default templates + +Using `dbx init --template` option provides access to the templates that are pre-shipped with dbx (currently there is only one template which is `python_basic`). + +``` +dbx init --template=python_basic +``` diff --git a/docs/guides/general/delta_live_tables.md b/docs/guides/general/delta_live_tables.md new file mode 100644 index 00000000..299f37ba --- /dev/null +++ b/docs/guides/general/delta_live_tables.md @@ -0,0 +1,134 @@ +# :material-table-heart: Working with Delta Live Tables + +Available since 0.8.0 + +`dbx` provides capabilities to deploy, launch and reference pipelines based on the [Delta Live Tables framework](https://docs.databricks.com/workflows/delta-live-tables/index.html). + +!!! warning "Development support for DLT Pipelines" + + Please note that `dbx` doesn't provide interactive development and execution capabilities for DLT Pipelines. + +## :material-hexagon-multiple-outline: Providing pipeline definition in the deployment file + +Example pipeline definition would look like this: + +```yaml title="conf/deployment.yml" +environments: + workflows: + - name: "sample-dlt-pipeline" + workflow_type: "pipeline" #(1) + storage: "dbfs:/some/location" #(2) + configuration: #(3) + "property1": "value" + "property2": "value2" + clusters: #(4) + - label: "some-label" #(5) + spark_conf: + "spark.property1": "value" + "spark.property2": "value2" + aws_attributes: + ... + instance_pool_id: "instance-pool://some-pool" + driver_instance_pool_id: "instance-pool://some-pool" + policy_id: "cluster-policy://some-policy" + autoscale: + min_workers: 1 + max_workers: 4 + mode: "legacy" #(6) + libraries: #(7) + - notebook: + path: "/Repos/some/path" + - notebook: + path: "/Repos/some/other/path" + target: "some_target_db" + ... #(8) +``` + +1. [REQUIRED] If not provided, `dbx` will try to parse the workflow definition as a workflow in Jobs format. +2. 
[OPTIONAL] A path to a DBFS directory for storing checkpoints and tables created by the pipeline. The system uses a default location if this field is empty. +3. [OPTIONAL] A list of key-value pairs to add to the Spark configuration of the cluster that will run the pipeline. +4. [OPTIONAL] If this is not specified, the system will select a default cluster configuration for the pipeline. +5. Follow documentation for this section [here](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-api-guide.html#pipelinesnewcluster). +6. Also, could be `mode: "enchanced"`, read more on [this feature here](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-concepts.html#databricks-enhanced-autoscaling). +7. [REQUIRED] The notebooks containing the pipeline code and any dependencies required to run the pipeline. +8. Follow the [official documentation page](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-api-guide.html#pipelinesettings) for other fields and properties + +!!! tip "Payload structure for DLT pipelines" + + In general, `dbx` will use the payload structure specified in the [CreatePipeline](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-api-guide.html#create-a-pipeline) API of DLT. + All features of `dbx` such as [named properties](../../features/named_properties.md) are fully supported with pipelines deployment as well. + +Please note that prior to deploying the pipeline, you'll need to update the relevant notebook sources. + +This could be done by using the functionality of the main `databricks-cli`: + +```bash +databricks repos update --path="/Repos/some/path" --branch="specific-branch" +databricks repos update --path="/Repos/some/path" --tag="specific-tag" +``` + +## :material-rocket-launch: Launching DLT pipelines using `dbx` + + +To launch a DLT pipeline, simply use the `dbx launch` command with `-p` or `--pipeline` switch: + +```bash +dbx launch -p # also could be --pipeline instead of -p +``` + +!!! danger "Assets-based launch is not supported in DLT pipelines" + + Please note that [assets-based launch](../../features/assets.md) is **not supported for DLT pipelines**. + + Use the properties of the DLT pipeline, such as `target` and `development` if you're looking for capabilities to launch a specific branch. + +## :material-code-brackets: Passing parameters to DLT pipelines during launch + +Following the API structures and examples provided [here](https://docs.databricks.com/workflows/delta-live-tables/delta-live-tables-api-guide.html#start-a-pipeline-update): + +```bash +dbx launch --parameters='{ "full_refresh": "true" }' # for full refresh +dbx launch --parameters='{ "refresh_selection": ["sales_orders_cleaned", "sales_order_in_chicago"] }' # start an update of selected tables +dbx launch --parameters='{ + "refresh_selection": ["sales_orders_cleaned", "sales_order_in_chicago"], + "full_refresh_selection": ["customers", "sales_orders_raw"] +}' # start a full update of selected tables +``` + +## :material-link-plus: Referencing DLT pipelines inside multitask workflows + +Sometimes you might need to chain various tasks around the DLT pipeline. In this case you can use the `pipeline_task` capabilities of the Databricks Workflows. 
+For example, your deployment file definition could look like this: + +```yaml title="conf/deployment.yml" hl_lines="21-22" +environments: + default: + workflows: + - name: "some-pipeline" + workflow_type: "pipeline" + libraries: + - notebook: + path: "/Repos/some/project" + - name: "dbx-pipeline-chain" + job_clusters: + - job_cluster_key: "main" + <<: *basic-static-cluster + tasks: + - task_key: "one" + python_wheel_task: + entry_point: "some-ep" + package_name: "some-pkg" + - task_key: "two" + depends_on: + - task_key: "one" + pipeline_task: + pipeline_id: "pipeline://some-pipeline" #(1) +``` + +1. Read more on the reference syntax [here](../../features/named_properties.md). + +!!! tip "Order of the definitions" + + Please note that if you're planning to reference the pipeline definition, you should explicitly put them first in the deployment file `workflows` section, **prior to the reference**. + + Without the proper order you'll run into a situation where your DLT pipeline wasn't yet deployed, therefore the reference won't be resolved which will cause an error. diff --git a/docs/guides/general/passing_parameters.md b/docs/guides/general/passing_parameters.md new file mode 100644 index 00000000..66952ea7 --- /dev/null +++ b/docs/guides/general/passing_parameters.md @@ -0,0 +1,157 @@ +# :material-code-brackets: Passing parameters + +Available since 0.8.0 + + +`dbx` provides various interfaces to pass the parameters to the workflows and tasks. +Unfortunately, the underlying APIs and the fact that `dbx` is a CLI tool are somewhat +limiting the capabilities of parameter passing and require additional payload preparation. + +This documentation section explains various ways to pass parameters both statically and dynamically. + + +## :material-pin-outline: Static parameter passing + + +To pass the parameters statically, you shall use the deployment file. +Please find reference on various tasks and their respective parameters in the [deployment file reference](../../reference/deployment.md). + + +## :material-clipboard-flow-outline: Dynamic parameter passing + +In some cases you would like to override the parameters that are defined in the static file to launch the workflow. + +Therefore, there are 3 fundamentally different options on how parameters could be provided dynamically: + +- Parameters of a **:octicons-package-dependents-24: task** in [`dbx execute`](../../reference/cli#dbx-execute) +- Parameters of a **:octicons-workflow-24: workflow** in the [asset-based launch mode](../../features/assets.md) (`dbx launch --from-assets`) +- Parameters of a **:octicons-workflow-24: workflow** in the normal launch mode (`dbx launch`) + +All parameters should be provided in a form of a JSON-compatible payload (see examples below). + +!!! tip "Multiline strings in the shell" + + To handle a multiline string in the shell, use single quotes (`'`), for instance: + + ```bash + dbx ... --parameters='{ + "nicely": "formatted" + "multiline": [ + "string", + "with", + "many", + "rows" + ] + }' + ``` + +=== ":material-lightning-bolt-circle: `dbx execute`" + + The `execute` command expects the parameters only for one specific task (if it's a multitask job) or only for the job itself (if it's the old 2.0 API format). + The provided payload shall match the expected payload of a chosen workflow or task. + It also should be wrapped as a JSON-compatible string with curly brackets around the API-compatible payload. 
+ Since `dbx execute` only supports `spark_python_task` and `python_wheel_task`, only the compatible parameters would work. + + Examples: + ```bash + dbx execute --parameters='{"parameters": ["argument1", "argument2"]}' # compatible with spark_python_task and python_wheel_task + dbx execute --parameters='{"named_parameters": {"a":1, "b": 2}}' # compatible only with python_wheel_task + ``` + +=== ":material-lightning-bolt-outline: `dbx launch --assets-only`" + + !!! tip "General note about 2.0 and 2.1 Jobs API" + + We strongly recommend using the Jobs API 2.1, since it provides extensible choice of tasks and options. + + Assets-based launch provides 2 different interfaces depending on the Jobs API version you use. + + For Jobs API 2.0, the payload format should be compliant with the [RunSubmit](https://docs.databricks.com/dev-tools/api/2.0/jobs.html#runs-submit) API structures in terms of the task arguments. + + Examples for Jobs API 2.0: + ```bash + dbx launch --assets-only --parameters='{"base_parameters": {"key1": "value1", "key2": "value2"}}' # for notebook_task + dbx launch --assets-only --parameters='{"parameters": ["argument1", "argument2"]}' # for spark_jar_task, spark_python_task and spark_submit_task + ``` + + For Jobs API 2.1, the payload format should be compliant with the [JobsRunSubmit](https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit). + This API supports task-level parameter passing. + + Examples for Jobs API 2.1: + + ```bash + # notebook_task + dbx launch --assets-only --parameters='[ + {"task_key": "some", "base_parameters": {"a": 1, "b": 2}} + ]' + + # spark_python_task, python_wheel_task, spark_jar_task, spark_submit_task, python_wheel_task + dbx launch --assets-only --parameters='[ + {"task_key": "some", "parameters": ["a", "b"]} + ]' + + # python_wheel_task + dbx launch --assets-only --parameters='[ + {"task_key": "some", "named_parameters": {"a": 1, "b": 2}} + ]' + + # pipeline_task + dbx launch --assets-only --parameters='[ + {"task_key": "some", "full_refresh": true} + ]' + + # sql_task + dbx launch --assets-only --parameters='[ + {"task_key": "some", "parameters": {"key1": "value2"}} + ]' + ``` + + +=== ":material-lightning-bolt-outline: `dbx launch`" + + !!! tip "General note about 2.0 and 2.1 Jobs API" + + We strongly recommend using the Jobs API 2.1, since it provides extensible choice of tasks and options. + + For Jobs API 2.0, the payload format should be compliant with the [RunNow](https://docs.databricks.com/dev-tools/api/2.0/jobs.html#run-now) API structures in terms of the task arguments. + + Examples: + + ```bash + dbx launch --parameters='{"jar_params": ["a1", "b1"]}' # spark_jar_task + dbx launch --parameters='{"notebook_params":{"name":"john doe","age":"35"}}' # notebook_task + dbx launch --parameters='{"python_params":["john doe","35"]}' # spark_python_task + dbx launch --parameters='{"spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"]}' # spark_submit_task + ``` + + For Jobs API 2.1, the payload format should be compliant with the [RunNow](https://docs.databricks.com/dev-tools/api/2.0/jobs.html#run-now) API structures in terms of the task arguments. + + !!! danger "Per-task parameter passing is not supported" + + Unfortunately it's not possible to pass parameters to the each task individually. + + As a workaround, you can name parameters differently in different tasks. + + It is also possible to provide a combined payload, e.g. 
if you have a `notebook_task` and a `spark_python_task` in your workflow, you can combine them as follows:
+    ```bash
+    dbx launch --parameters='{
+        "notebook_params":{"name":"john doe","age":"35"},
+        "python_params":["john doe","35"]
+        }'
+    ```
+
+    Examples of parameter provisioning:
+    ```bash
+    dbx launch --parameters='{"jar_params": ["a1", "b1"]}' # spark_jar_task
+    dbx launch --parameters='{"notebook_params":{"name":"john doe","age":"35"}}' # notebook_task
+    dbx launch --parameters='{"python_params":["john doe","35"]}' # spark_python_task or python_wheel_task
+    dbx launch --parameters='{"spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"]}' # spark_submit_task
+    dbx launch --parameters='{"python_named_params": {"name": "task", "data": "dbfs:/path/to/data.json"}}' # python_wheel_task
+    dbx launch --parameters='{"pipeline_params": {"full_refresh": true}}' # pipeline_task as a part of a workflow
+    dbx launch --parameters='{"sql_params": {"name": "john doe", "age": "35"}}' # sql_task
+    dbx launch --parameters='{"dbt_commands": ["dbt deps", "dbt seed", "dbt run"]}' # dbt_task
+    ```
+
+!!! tip "Passing parameters for DLT pipelines"
+
+    Please follow the [Delta Live Tables guide](./delta_live_tables.md) to pass parameters that are specific to DLT pipelines.
diff --git a/docs/guides/python/python_quickstart.md b/docs/guides/python/python_quickstart.md
index 20400043..5398c648 100644
--- a/docs/guides/python/python_quickstart.md
+++ b/docs/guides/python/python_quickstart.md
@@ -1,6 +1,6 @@
 # :material-language-python: :material-airplane-takeoff: Python quickstart
 
-In this guide we're going to walkthough a typical setup for development purposes.
+In this guide we're going to walk through a typical setup for development purposes.
 
 In the end of this guide you'll have a prepared local environment, as well as capabilities to:
 
@@ -64,14 +64,30 @@ databricks --profile charming-aurora workspace ls /
 Now the preparation is done, let’s generate a project skeleton using [`dbx init`](../../reference/cli.md):
 
 ```bash
-dbx init -p \
-    "cicd_tool=GitHub Actions" \
+dbx init \
+    -p "cicd_tool=GitHub Actions" \
     -p "cloud=" \
     -p "project_name=charming-aurora" \
     -p "profile=charming-aurora" \
     --no-input
 ```
 
+!!! warning "Choosing the artifact storage wisely"
+
+    Although for the quickstart we're using the standard `dbfs://`-based artifact location, it's **strongly recommended** to use
+    proper cloud-based storage for artifacts. Please read more in [this section](../../concepts/artifact_storage.md).
+
+    To change the artifact location for this specific guide, add the following parameter to the command above:
+    ```bash
+    dbx init \
+      -p "cicd_tool=GitHub Actions" \
+      -p "cloud=" \
+      -p "project_name=charming-aurora" \
+      -p "profile=charming-aurora" \
+      -p "artifact_location=" \
+      --no-input
+    ```
+
 Step into the newly generated folder:
 
 ```bash
@@ -101,7 +117,7 @@ pip install -e ".[local,test]"
 ```
 
 Use the [`findspark`](https://github.com/minrk/findspark) package in your local unit tests to correctly identify Apache Spark on the local path.
-After installing all of the dependencies, it's time to run local unit tests:
+After installing all the dependencies, it's time to run local unit tests:
 
 ```bash
 pytest tests/unit --cov
diff --git a/docs/index.md b/docs/index.md
index a57aafeb..ac0943f9 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,90 +1,39 @@
-# :octicons-feed-rocket-16: dbx by Databricks Labs
-
-

- - logo - -

- -🧱 Databricks CLI eXtensions - aka `dbx` is a CLI tool for development and advanced Databricks workflows management. - -## :octicons-light-bulb-24: Concept - -`dbx` aims to improve development experience for Data and ML teams that use Databricks, by providing the following capabilities: - -- Ready to use project templates, with built-in support to use custom templates -- Simple configuration for multi-environment setup -- Interactive development loop for Python-based projects -- Flexible deployment configuration -- Built-in versioning for deployments - -Since `dbx` primary interface is CLI, it's easy to use it in various CI/CD pipelines, independent of the CI provider. - -Read more about the place of `dbx` and potential use-cases in the [Ecosystem section](concepts/ecosystem.md). - -## :thinking: Differences from other tools - -| Tool | Comment | -|--------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [databricks-cli](https://github.com/databricks/databricks-cli) | dbx is NOT a replacement for databricks-cli. Quite the opposite - dbx is heavily dependent on databricks-cli and uses most of the APIs exactly from databricks-cli SDK. | -| [mlflow cli](https://www.mlflow.org/docs/latest/cli.html) | dbx is NOT a replacement for mlflow cli. dbx uses some of the MLflow APIs under the hood to store serialized job objects, but doesn't use mlflow CLI directly. | -| [Databricks Terraform Provider](https://github.com/databrickslabs/terraform-provider-databricks) | While dbx is primarily oriented on versioned job management, Databricks Terraform Provider provides much wider set of infrastructure settings. In comparison, dbx doesn't provide infrastructure management capabilities, but brings more flexible deployment and launch options. | -| [Databricks Stack CLI](https://docs.databricks.com/dev-tools/cli/stack-cli.html) | Databricks Stack CLI is a great component for managing a stack of objects. dbx concentrates on the versioning and packaging jobs together, not treating files and notebooks as a separate component. | - -Read more about the place of `dbx` and potential use-cases in the concepts section. - -## :octicons-link-external-24: Next steps - -Depending on your developer journey and overall tasks, you might use `dbx` in various ways: - -=== ":material-language-python: Python" - - | :material-sign-direction: Developer journey | :octicons-link-24: Link | :octicons-tag-24: Tags | - |-------------------------------------------------------------------------------------------------|----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| - | Develop a new Python project on Databricks solely in IDE without using Notebooks | [Python quickstart](./guides/python/python_quickstart.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
| - | Develop a new Python project on Databricks with Databricks Notebooks and partially in the IDE | [Python quickstart](./guides/python/python_quickstart.md) followed by
[Mixed-mode development loop for Python projects](./guides/python/devloop/mixed.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:material-notebook-heart-outline: Notebook
| - | Organize a development loop for an existing Notebooks-based project together with IDE | [Mixed-mode development loop for Python projects](./guides/python/devloop/mixed.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:material-notebook-heart-outline: Notebook
| - | Organize a development loop for an existing Python package-based project | [Development loop for Python package-based projects](./guides/python/devloop/package.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:octicons-package-16: Packaging
| - | Add workflow deployment and automation capabilities to an existing Python package-based project | [DevOps for Python package-based projects](./guides/python/devops/package.md) |
:fontawesome-brands-python: Python
:octicons-package-16: Packaging
:fontawesome-solid-ship: Deployment
| - | Add workflow deployment and automation capabilities to an existing Notebooks-based project | [DevOps for Notebooks-based projects](./guides/python/devops/notebook.md) |
:fontawesome-brands-python: Python
:material-notebook-heart-outline: Notebook
:fontawesome-solid-ship: Deployment
| - -=== ":material-language-java: Java and Scala" - - | :material-sign-direction: Developer journey | :octicons-link-24: Link | :octicons-tag-24: Tags | - |-------------------------------------------------------------------------------------------------|----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| - | Organize a development loop for an existing JVM-based project (e.g. Java or Scala) in IDE | [Development loop for JVM-based projects](./guides/jvm/jvm_devloop.md) |
:fontawesome-brands-java: JVM
:fontawesome-solid-laptop: IDE
| - | Add workflow deployment and automation capabilities to an existing JVM-based project | [DevOps for JVM-based projects](./guides/jvm/jvm_devops.md) |
:fontawesome-brands-java: JVM
:octicons-package-16: Packaging
:fontawesome-solid-ship: Deployment
| - - -## :octicons-stop-24: Limitations - -- For interactive development `dbx` can only be used for Python and JVM-based projects. - Please note that development workflow for JVM-based projects is different from the Python ones. - For R-based projects, `dbx` can only be used as a deployment management and workflow launch tool. -- `dbx` currently doesn't provide interactive debugging capabilities. - If you want to use interactive debugging, you can use [Databricks - Connect](https://docs.databricks.com/dev-tools/databricks-connect.html), and then use - `dbx` for deployment operations. -- [Delta Live - Tables](https://databricks.com/product/delta-live-tables) currently are not supported. We plan do add support for DLT deployments in near future. - -## :octicons-law-24: Legal Information - -!!! danger "Support notice" - - This software is provided as-is and is not officially supported by - Databricks through customer technical support channels. Support, questions, and feature requests can be communicated through the Issues - page of the [dbx repo](https://github.com/databrickslabs/dbx/issues). Please see the legal agreement and understand that - issues with the use of this code will not be answered or investigated by - Databricks Support. - -## :octicons-comment-24: Feedback - -Issues with `dbx`? Found a :octicons-bug-24: bug? -Have a great idea for an addition? Want to improve the documentation? Please feel -free to file an [issue](https://github.com/databrickslabs/dbx/issues/new/choose). - -## :fontawesome-solid-user-plus: Contributing - -Please find more details about contributing to `dbx` in the contributing -[doc](https://github.com/databrickslabs/dbx/blob/master/contrib/CONTRIBUTING.md). +--- +hide: +- navigation +- toc +--- + + + +
+ + +
+
+
+ + logo + +
+ +
+

+ dbx by Databricks Labs +

+
+
+
+

+ 🧱 Databricks CLI eXtensions - aka dbx is a CLI tool for development and advanced Databricks + workflows + management. +

+
+
diff --git a/docs/intro.md b/docs/intro.md new file mode 100644 index 00000000..687e3cac --- /dev/null +++ b/docs/intro.md @@ -0,0 +1,90 @@ +# :octicons-feed-rocket-16: dbx by Databricks Labs - intro + +

+ + logo + +

+ +🧱 Databricks CLI eXtensions - aka `dbx` is a CLI tool for development and advanced Databricks workflows management. + +## :octicons-light-bulb-24: Concept + +`dbx` aims to improve development experience for Data and ML teams that use Databricks, by providing the following capabilities: + +- Ready to use project templates, with built-in support to use custom templates +- Simple configuration for multi-environment setup +- Interactive development loop for Python-based projects +- Flexible deployment configuration +- Built-in versioning for deployments + +Since `dbx` primary interface is CLI, it's easy to use it in various CI/CD pipelines, independent of the CI provider. + +Read more about the place of `dbx` and potential use-cases in the [ecosystem section](concepts/ecosystem.md). + +## :thinking: Differences from other tools + +| Tool | Comment | +|--------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [databricks-cli](https://github.com/databricks/databricks-cli) | dbx is NOT a replacement for databricks-cli. Quite the opposite - dbx is heavily dependent on databricks-cli and uses most of the APIs exactly from databricks-cli SDK. | +| [mlflow cli](https://www.mlflow.org/docs/latest/cli.html) | dbx is NOT a replacement for mlflow cli. dbx uses some of the MLflow APIs under the hood to store serialized job objects, but doesn't use mlflow CLI directly. | +| [Databricks Terraform Provider](https://github.com/databrickslabs/terraform-provider-databricks) | While dbx is primarily oriented on versioned job management, Databricks Terraform Provider provides much wider set of infrastructure settings. In comparison, dbx doesn't provide infrastructure management capabilities, but brings more flexible deployment and launch options. | +| [Databricks Stack CLI](https://docs.databricks.com/dev-tools/cli/stack-cli.html) | Databricks Stack CLI is a great component for managing a stack of objects. dbx concentrates on the versioning and packaging jobs together, not treating files and notebooks as a separate component. | + +Read more about the differences between `dbx` and other instruments in the [ecosystem section](concepts/ecosystem.md). + +## :octicons-link-external-24: Next steps + +Depending on your developer journey and overall tasks, you might use `dbx` in various ways: + +=== ":material-language-python: Python" + + | :material-sign-direction: Developer journey | :octicons-link-24: Link | :octicons-tag-24: Tags | + |-------------------------------------------------------------------------------------------------|----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| + | Develop a new Python project on Databricks solely in IDE without using Notebooks | [Python quickstart](./guides/python/python_quickstart.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
| + | Develop a new Python project on Databricks with Databricks Notebooks and partially in the IDE | [Python quickstart](./guides/python/python_quickstart.md) followed by
[Mixed-mode development loop for Python projects](./guides/python/devloop/mixed.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:material-notebook-heart-outline: Notebook
| + | Organize a development loop for an existing Notebooks-based project together with IDE | [Mixed-mode development loop for Python projects](./guides/python/devloop/mixed.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:material-notebook-heart-outline: Notebook
| + | Organize a development loop for an existing Python package-based project | [Development loop for Python package-based projects](./guides/python/devloop/package.md) |
:fontawesome-brands-python: Python
:fontawesome-solid-laptop: IDE
:octicons-package-16: Packaging
| + | Add workflow deployment and automation capabilities to an existing Python package-based project | [DevOps for Python package-based projects](./guides/python/devops/package.md) |
:fontawesome-brands-python: Python
:octicons-package-16: Packaging
:fontawesome-solid-ship: Deployment
| + | Add workflow deployment and automation capabilities to an existing Notebooks-based project | [DevOps for Notebooks-based projects](./guides/python/devops/notebook.md) |
:fontawesome-brands-python: Python
:material-notebook-heart-outline: Notebook
:fontawesome-solid-ship: Deployment
| + +=== ":material-language-java: Java and Scala" + + | :material-sign-direction: Developer journey | :octicons-link-24: Link | :octicons-tag-24: Tags | + |-------------------------------------------------------------------------------------------------|----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| + | Organize a development loop for an existing JVM-based project (e.g. Java or Scala) in IDE | [Development loop for JVM-based projects](./guides/jvm/jvm_devloop.md) |
:fontawesome-brands-java: JVM
:fontawesome-solid-laptop: IDE
| + | Add workflow deployment and automation capabilities to an existing JVM-based project | [DevOps for JVM-based projects](./guides/jvm/jvm_devops.md) |
:fontawesome-brands-java: JVM
:octicons-package-16: Packaging
:fontawesome-solid-ship: Deployment
| + + +## :octicons-stop-24: Limitations + +- For interactive development `dbx` can only be used for Python and JVM-based projects. + Please note that development workflow for JVM-based projects is different from the Python ones. + For R-based projects, `dbx` can only be used as a deployment management and workflow launch tool. +- `dbx` currently doesn't provide interactive debugging capabilities. + If you want to use interactive debugging, you can use [Databricks + Connect](https://docs.databricks.com/dev-tools/databricks-connect.html), and then use + `dbx` for deployment operations. +- [Delta Live + Tables](https://databricks.com/product/delta-live-tables) are supported for deployment and launch. The interactive execution mode is not supported. Please read more on DLT with `dbx` in [this guide](guides/general/delta_live_tables.md). + +## :octicons-law-24: Legal Information + +!!! danger "Support notice" + + This software is provided as-is and is not officially supported by + Databricks through customer technical support channels. Support, questions, and feature requests can be communicated through the Issues + page of the [dbx repo](https://github.com/databrickslabs/dbx/issues). Please see the legal agreement and understand that + issues with the use of this code will not be answered or investigated by + Databricks Support. + +## :octicons-comment-24: Feedback + +Issues with `dbx`? Found a :octicons-bug-24: bug? +Have a great idea for an addition? Want to improve the documentation? Please feel +free to file an [issue](https://github.com/databrickslabs/dbx/issues/new/choose). + +## :fontawesome-solid-user-plus: Contributing + +Please find more details about contributing to `dbx` in the contributing +[doc](https://github.com/databrickslabs/dbx/blob/master/contrib/CONTRIBUTING.md). diff --git a/docs/migration.md b/docs/migration.md index d8a75f47..7339158b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1,7 +1,17 @@ -# Migration between dbx versions +# :material-arrow-up-bold-hexagon-outline: Migration between dbx versions +## :material-arrow-up-bold-hexagon-outline: From 0.7.0 to 0.8.0 -## From 0.6.0 and earlier to 0.7.0 +- The interface for `--parameters` passing has been changed. Please check + the [doc dedicated to parameter passing](./guides/general/passing_parameters.md). +- `dbx deploy --write-specs-to-file` now generates a JSON payload which is `workflows` based (not `jobs`). +- The `permissions` section is **not processed** anymore in the workflow definitions. Simply delete the `permissions` + section and replace it with the content of `access_control_list` subsection. Read more on this + change [here](features/permissions_management.md). +- Logic of `init_scripts` resolution in case of a policy reference has changed. Read + more [here](features/named_properties.md#init-scripts-resolution-logic). + +## :material-arrow-up-bold-hexagon-outline: From 0.6.0 and earlier to 0.7.0 - Azure Data Factory support has been **dropped**.
Please use Azure Data Factory APIs directly on top of the deployed workflow definitions.
@@ -9,9 +19,13 @@ ```bash dbx deploy ... --write-specs-to-file=.dbx/deployment-result.json ``` -- `--job`, `--jobs` arguments were deprecated. Please pass the workflow name as argument, and for `--jobs` use `--workflows`. -- `dbx sync` arguments `--allow-delete-unmatched`/`--disallow-delete-unmatched` were **replaced** with `--unmatched-behaviour` option. -- `jobs` section in the deployment file has been renamed to `workflows`. Old versions will continue working, but a warning will pop up. +- `--job`, `--jobs` arguments were deprecated. Please pass the workflow name as argument, and for `--jobs` + use `--workflows`. +- `dbx sync` arguments `--allow-delete-unmatched`/`--disallow-delete-unmatched` were **replaced** + with `--unmatched-behaviour` option. +- `jobs` section in the deployment file has been renamed to `workflows`. Old versions will continue working, but a + warning will pop up. - `--files-only` and `--as-run-submit` options are deprecated. Please use `--assets-only` and `--from-assets` instead. -- Project file format has been changed. Old format is supported, but a warning pops up. Please migrate to the new format as described [here](./reference/project.md). +- Project file format has been changed. Old format is supported, but a warning pops up. Please migrate to the new format + as described [here](./reference/project.md). diff --git a/docs/overrides/404.html b/docs/overrides/404.html new file mode 100644 index 00000000..9cd62fc7 --- /dev/null +++ b/docs/overrides/404.html @@ -0,0 +1,18 @@ +{% extends "base.html" %} + +{% block htmltitle %} + dbx - Not Found +{% endblock %} + +{% block content %} +
+

😱 Unfortunately, this page doesn't exist.

+

+ Most probably it was moved to another location.
+ Please use the 🔍 search to find what you're looking for. +

+

+ If you want to report an issue with the dbx docs, please create it on GitHub. +

+
+{% endblock %} diff --git a/docs/overrides/partials/source-file.html b/docs/overrides/partials/source-file.html new file mode 100644 index 00000000..dbbd21ff --- /dev/null +++ b/docs/overrides/partials/source-file.html @@ -0,0 +1,27 @@ +
+
+ +
+ + + + + + + + {{ page.meta.git_revision_date_localized }} + + + + + + + + {{ page.meta.git_creation_date_localized }} + +
+ +
+
diff --git a/docs/reference/deployment.md b/docs/reference/deployment.md index dee58154..0cfc4fb7 100644 --- a/docs/reference/deployment.md +++ b/docs/reference/deployment.md @@ -2,6 +2,7 @@ Deployment file is one of the most important files for `dbx` functionality. +## :material-format-columns: File format options It contains workflows definitions, as well as `build` configurations. The following file extensions are supported: @@ -14,8 +15,9 @@ The following file extensions are supported: By default `dbx` commands will search for a `deployment.*` file in the `conf` directory of the project. Alternatively, all commands that require a deployment file support passing it explicitly via `--deployment-file` option. +## :material-page-layout-header-footer: Layout -Typical layout of this file looks like this: +Typical layout of the deployment file looks like this: ```yaml title="conf/deployment.yml" @@ -45,41 +47,122 @@ environments: #(2) As the project file, deployment file supports multiple environments. You can configure them by naming new environments under the `environments` section. -The `workflows` section of the deployment file fully follows the [Databricks Jobs API structures](https://docs.databricks.com/dev-tools/api/latest/jobs.html). +The `workflows` section of the deployment file fully follows the [Databricks Jobs API structures](https://docs.databricks.com/dev-tools/api/latest/jobs.html) with features described in [this section](../features/assets.md). -## :material-package-up: Advanced package dependency management +## :material-alpha-t-box-outline: Various workflow types +Available since 0.8.0 -By default `dbx` is heavily oriented towards Python package-based projects. However, for pure Notebook or JVM projects this might be not necessary. +`dbx` supports the following workflow types: -Therefore, to disable the default behaviour of `dbx` which tries to add the Python package dependency, use the `deployment_config` section inside the task definition: +- Workflows in [Jobs API :material-surround-sound-2-0: format](https://docs.databricks.com/dev-tools/api/2.0/jobs.html) +- Workflows in [Jobs API :material-surround-sound-2-1: format](https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate) +- Workflows in [:material-table-heart: Delta Live Tables pipeline format](https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate) + +If the `workflow_type` property is not specified on the workflow level, `dbx` will define the type based on the following rules: + +- If `tasks` section is provided, then the workflow type is `jobs-v2.1`. +- If this section is not provided, workflow will be parsed as a `jobs-v2.0`. 
+ +Allowed values for the `workflow_type` field are: + +- `jobs-v2.1` +- `jobs-v2.0` +- `pipeline` + +Examples below demonstrate how to define various workflow types: + +```yaml title="conf/deployment.yml" + +build: + python: "pip" -```yaml title="conf/deployment.yml" hl_lines="12-13" -# some code omitted environments: default: workflows: - - name: "workflow1" + + ################################################ + - name: "workflow-in-v2.1-format" tasks: - task_key: "task1" - python_wheel_task: #(1) + python_wheel_task: package_name: "some-pkg" entry_point: "some-ep" - - task_key: "task2" #(2) - deployment_config: + + ################################################ + - name: "workflow-in-v2.0-format" + spark_python_task: + python_file: "file://some/file.py" + + ################################################ + - name: "workflow-in-pipeline-format" + target: "some-target-db" + workflow_type: "pipeline" # enforces the recognition + libraries: + - notebook: + path: "/Repos/some/path" +``` + +!!! tip "DLT with `dbx` guide" + + Read more on the topic of DLT pipelines with `dbx` [here](../guides/general/delta_live_tables.md). + +## :material-package-up: Advanced package dependency management + +By default `dbx` is heavily oriented towards Python package-based projects. However, for pure Notebook or JVM projects this might be not necessary. + +Therefore, to disable the default behaviour of `dbx` which tries to add the Python package dependency, use the `deployment_config` section inside the task definition: + +=== "Latest - Jobs API :material-surround-sound-2-1:" + + ```yaml title="conf/deployment.yml" hl_lines="12-13" + # some code omitted + environments: + default: + workflows: + - name: "workflow1" + tasks: + - task_key: "task1" + python_wheel_task: #(1) + package_name: "some-pkg" + entry_point: "some-ep" + - task_key: "task2" #(2) + deployment_config: + no_package: true + notebook_task: + notebook_path: "/some/notebook/path" + ``` + + 1. Standard Python package-based payload, the python wheel dependency will be added by default + 2. In the notebook task, the Python package is not required since code is delivered together with the Notebook. + Therefore, we disable this behaviour by providing this property. + +=== "Legacy - Jobs API :material-surround-sound-2-0:" + + ```yaml title="conf/deployment.yml" hl_lines="10-11" + # some code omitted + environments: + default: + workflows: + - name: "wheel-workflow" + python_wheel_task: #(1) + package_name: "some-pkg" + entry_point: "some-ep" + - name: "notebook-workflow" #(2) + deployment_config: no_package: true - notebook_task: + notebook_task: notebook_path: "/some/notebook/path" -``` + ``` -1. Standard Python package-based payload, the python wheel dependency will be added by default -2. In the notebook task, the Python package is not required since code is delivered together with the Notebook. - Therefore, we disable this behaviour by providing this property. + 1. Standard Python package-based payload, the python wheel dependency will be added by default + 2. In the notebook task, the Python package is not required since code is delivered together with the Notebook. + Therefore, we disable this behaviour by providing this property. ## :material-folder-star-multiple: Examples This section contains various examples of the deployment file for various cases. 
-Most of the examples below use inplace Jinja functionality which is [described here](../features/jinja_support.md#enable-inplace-jinja) +Most of the examples below use inplace Jinja functionality which is [described here](../features/jinja_support.md#enable-inplace-jinja). ### :material-tag-plus: Tagging workflows @@ -119,7 +202,7 @@ environments: timezone_id: "Europe/Berlin" #(2) ``` -1. This sets up the schedule for every day at midnight. Check [chrontab.guru](https://crontab.guru/) for more examples. +1. This sets up the schedule for every day at midnight. Check [this site](http://www.quartz-scheduler.org/documentation/quartz-2.3.0/tutorials/crontrigger.html) for more examples. 2. Timezone is set accordingly to the Java [`TimeZone`](https://docs.oracle.com/javase/7/docs/api/java/util/TimeZone.html) class. !!! tip "Official Databricks docs" @@ -132,6 +215,44 @@ environments: Here are some examples for [Apache Airflow](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/connections/databricks.html) and [Prefect](https://docs-v1.prefect.io/api/0.15.13/tasks/databricks.html). +### :octicons-zap-24: Enabling Photon + +To define job clusters with [Photon](https://www.databricks.com/product/photon) support, add the following to the configuration: + +```yaml title="conf/deployment.yaml" hl_lines="9" +custom: + basic-cluster-props: &basic-cluster-props + spark_version: "your-spark-version" + node_type_id: "your-node-type-id" + spark_conf: + spark.databricks.delta.preview.enabled: 'true' + instance_pool_name: + driver_instance_pool_name: + runtime_engine: PHOTON + init_scripts: + - dbfs: + destination: dbfs:/ + +``` + +### :material-dots-hexagon: Managing the workflow as a service principal + +Available since 0.8.0 + +This example uses the [named reference feature](../features/named_properties.md#reference-based-approach): + +```yaml title="conf/deployment.yml" hl_lines="5-7" +environments: + default: + workflows: + - name: "example-workflow" + access_control_list: + - user_name: "service-principal://some-service-principal-name" + permission_level: "IS_OWNER" + - user_name: "some-real-user@email.com" + permission_level: "CAN_MANAGE" +``` + ### :material-code-array: Configuring complex deployments While configuring complex deployments, it's recommended to use YAML anchor mechanics to avoid repeating code blocks. @@ -145,28 +266,28 @@ custom: node_type_id: "your-node-type-id" spark_conf: spark.databricks.delta.preview.enabled: 'true' - instance_pool_name: - driver_instance_pool_name: + instance_pool_id: "instance-pool://some-pool-name" + driver_instance_pool_id: "instance-pool://some-pool-name" runtime_engine: STANDARD init_scripts: - dbfs: - destination: dbfs:/ + destination: dbfs:/ basic-auto-scale-props: &basic-auto-scale-props autoscale: - min_workers: 2 - max_workers: 4 + min_workers: 2 + max_workers: 4 basic-static-cluster: &basic-static-cluster new_cluster: - <<: *basic-cluster-props - num_workers: 2 + <<: *basic-cluster-props + num_workers: 2 basic-autoscale-cluster: &basic-autoscale-cluster new_cluster: - <<: # merge these two maps and place them here. - - *basic-cluster-props - - *basic-auto-scale-props + <<: # merge these two maps and place them here. 
+ - *basic-cluster-props + - *basic-auto-scale-props environments: default: @@ -177,12 +298,13 @@ environments: on_start: [ "user@email.com" ] on_success: [ "user@email.com" ] on_failure: [ "user@email.com" ] - no_alert_for_skipped_runs: false + + no_alert_for_skipped_runs: false schedule: - quartz_cron_expression: "00 25 03 * * ?" #(1) - timezone_id: "UTC" - pause_status: "PAUSED" + quartz_cron_expression: "00 25 03 * * ?" #(1) + timezone_id: "UTC" + pause_status: "PAUSED" tags: your-key: "your-value" @@ -190,22 +312,17 @@ environments: format: MULTI_TASK #(2) - permissions: - access_control_list: - - user_name: "user@email.com" - permission_level: "IS_OWNER" - #- group_name: "your-group-name" - #permission_level: "CAN_VIEW" - #- user_name: "user2@databricks.com" - #permission_level: "CAN_VIEW" - #- user_name: "user3@databricks.com" - #permission_level: "CAN_VIEW" + access_control_list: + - user_name: "service-principal://some-service-principal" + permission_level: "IS_OWNER" + - group_name: "your-group-name" + permission_level: "CAN_VIEW" job_clusters: - job_cluster_key: "basic-cluster" - <<: *basic-static-cluster + <<: *basic-static-cluster - job_cluster_key: "basic-autoscale-cluster" - <<: *basic-autoscale-cluster + <<: *basic-autoscale-cluster tasks: - task_key: "your-task-01" @@ -233,12 +350,28 @@ environments: depends_on: - task_key: "your-task-01" - - task_key: "your-task-02" + - task_key: "your-task-03" job_cluster_key: "basic-cluster" notebook_task: notebook_path: "/Repos/some/project/notebook" depends_on: - task_key: "your-task-01" + + - task_key: "example_sql_task" + job_cluster_key: "basic-cluster" + sql_task: + query: "query://some-query-name" + warehouse_id: "warehouse://some-warehouse-id" + + - task_key: "example_dbt_task" + job_cluster_key: "basic-cluster" + dbt_task: + project_directory: "/some/project/dir" + profiles_directory: "/some/profiles/dir" + warehouse_id: "warehouse://some-warehouse-id" + commands: + - "dbt cmd1" + - "dbt cmd2" ``` 1. Read more about scheduling in [this section](./deployment.md#scheduling-workflows) diff --git a/mkdocs.yml b/mkdocs.yml index c8c2f734..10381657 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,6 +5,7 @@ copyright: © Copyright 2022, Databricks Inc. 
theme: name: material + custom_dir: docs/overrides font: text: DM Sans code: DM Mono @@ -12,14 +13,14 @@ theme: - media: '(prefers-color-scheme: light)' scheme: default primary: black - accent: amber + accent: indigo toggle: icon: material/lightbulb name: Switch to light mode - media: '(prefers-color-scheme: dark)' scheme: slate primary: black - accent: amber + accent: indigo toggle: icon: material/lightbulb-outline name: Switch to dark mode @@ -32,6 +33,7 @@ theme: - toc.follow - navigation.sections - navigation.tabs + - navigation.top icon: repo: fontawesome/brands/github logo: img/logo.svg @@ -40,12 +42,19 @@ theme: repo_name: databrickslabs/dbx repo_url: https://github.com/databrickslabs/dbx +edit_uri: edit/main/docs plugins: - - glightbox - search + - glightbox + - git-revision-date-localized: + type: timeago + enable_creation_date: true + exclude: + - index.md -extra_css: [ custom/custom.css ] +extra_css: + - extras/styles.css markdown_extensions: - attr_list @@ -73,15 +82,20 @@ markdown_extensions: custom_checkbox: true nav: - - Home: index.md + - 'index.md' + - Intro: intro.md - Concepts: - Place of dbx in the ecosystem: concepts/ecosystem.md - DevOps for workflows: concepts/devops.md - Testing workflows: concepts/testing.md - Cluster types for various flows: concepts/cluster_types.md + - Artifact storage: concepts/artifact_storage.md - Guides: - General: - Dependency management: guides/general/dependency_management.md + - Passing parameters: guides/general/passing_parameters.md + - Delta Live Tables: guides/general/delta_live_tables.md + - Custom templates: guides/general/custom_templates.md - Python: - Python quickstart: guides/python/python_quickstart.md - Development loop: @@ -109,5 +123,3 @@ nav: - Project file reference: reference/project.md - Migration: migration.md - Frequently Asked Questions: faq.md - - diff --git a/prospector.yaml b/prospector.yaml index 1a8634f0..3efb372f 100644 --- a/prospector.yaml +++ b/prospector.yaml @@ -11,6 +11,7 @@ max-line-length: 120 pylint: disable: + - too-many-ancestors # in dbx we pretty comfortable having a lot of mixins - too-many-branches - too-many-statements - too-many-instance-attributes @@ -31,14 +32,11 @@ pylint: - broad-except - arguments-differ -pep8: +pycodestyle: # W293: disabled because we have newlines in docstrings # E203: disabled because pep8 and black disagree on whitespace before colon in some cases disable: W293,E203 -pycodestyle: - disable: E203 # conflicts with black formatting - mccabe: disable: - MC0001 diff --git a/setup.py b/setup.py index 0ce95b5f..5c8da95e 100644 --- a/setup.py +++ b/setup.py @@ -11,17 +11,17 @@ # to use Databricks and MLflow APIs "retry>=0.9.2, <1.0.0", "requests>=2.24.0, <3.0.0", - "mlflow-skinny>=1.28.0,<=2.0.0", + "mlflow-skinny>=1.28.0,<3.0.0", "databricks-cli>=0.17,<0.18", # CLI interface "click>=8.1.0,<9.0.0", - "rich==12.5.1", - "typer[all]==0.6.1", + "rich==12.6.0", + "typer[all]==0.7.0", # for templates creation "cookiecutter>=1.7.2, <3.0.0", # file formats and models "pyyaml>=6.0", - "pydantic>=1.9.1", + "pydantic>=1.9.1,<=2.0.0", "Jinja2>=2.11.2", # misc - enforced to avoid issues with dependent libraries "cryptography>=3.3.1,<39.0.0", @@ -38,33 +38,45 @@ # utilities for documentation "mkdocs>=1.1.2,<2.0.0", "mkdocs-click>=0.8.0,<1.0", - "mkdocs-material>=8.5,<9.0.0", + "mkdocs-material>=8.5.9,<9.0.0", "mdx-include>=1.4.1,<2.0.0", "mkdocs-markdownextradata-plugin>=0.1.7,<0.3.0", "mkdocs-glightbox>=0.2.1,<1.0", + "mkdocs-git-revision-date-localized-plugin>=1.1.0,<=2.0", # pre-commit 
and linting utilities "pre-commit>=2.20.0,<3.0.0", - "prospector==1.7.0", + "pylint==2.15.6", + "pycodestyle==2.8.0", + "pyflakes==2.5.0", + "mccabe==0.6.1", + "prospector==1.7.7", "black>=22.3.0,<23.0.0", "MarkupSafe>=2.1.1,<3.0.0", # testing framework - "pytest>=7.1.2,<8.0.0", - "pytest-mock>=3.8.2,<3.9.0", + "pytest>=7.1.3,<8.0.0", + "pytest-mock>=3.8.2,<3.11.0", "pytest-xdist[psutil]>=2.5.0,<3.0.0", "pytest-asyncio>=0.18.3,<1.0.0", - "pytest-cov>=3.0.0,<4.0.0", + "pytest-cov>=4.0.0,<5.0.0", "pytest-timeout>=2.1.0,<3.0.0", "pytest-clarity>=1.0.1,<2.0.0", "poetry>=1.2.0", ] +AZURE_EXTRAS = ["azure-storage-blob>=12.14.1,<13.0.0", "azure-identity>=1.12.0,<2.0.0"] + +AWS_EXTRAS = [ + "boto3>=1.26.13,<2", +] + +GCP_EXTRAS = ["google-cloud-storage>=2.6.0,<3.0.0"] + setup( name="dbx", python_requires=">=3.8", packages=find_packages(exclude=["tests", "tests.*"]), - setup_requires=["wheel>=0.37.1,<0.38"], install_requires=INSTALL_REQUIRES, - extras_require={"dev": DEV_REQUIREMENTS}, + extras_require={"dev": DEV_REQUIREMENTS, "azure": AZURE_EXTRAS, "aws": AWS_EXTRAS, "gcp": GCP_EXTRAS}, entry_points={"console_scripts": ["dbx=dbx.cli:entrypoint"]}, long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/unit/api/adjuster/__init__.py b/tests/unit/api/adjuster/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/api/adjuster/test_complex.py b/tests/unit/api/adjuster/test_complex.py new file mode 100644 index 00000000..44ab508c --- /dev/null +++ b/tests/unit/api/adjuster/test_complex.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock + +import pytest +import yaml +from databricks_cli.sdk import InstancePoolService +from pytest_mock import MockerFixture + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from dbx.api.services.pipelines import NamedPipelinesService +from dbx.models.deployment import DeploymentConfig +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.workflow_types import WorkflowType + +TEST_PAYLOAD = yaml.safe_load( + """ +custom: + basic-cluster-props: &basic-cluster-props + spark_version: "9.1.x-cpu-ml-scala2.12" + + basic-static-cluster: &basic-static-cluster + new_cluster: + <<: *basic-cluster-props + num_workers: 1 + instance_pool_id: "instance-pool://pool-2" + instance_pool_id: "instance-pool://pool-2" + +environments: + default: + workflows: + + - name: "dlt-test" + workflow_type: "pipeline" + deployment_config: + no_package: true + target: "some" + libraries: + - notebook: + path: "/Repos/some" + clusters: + - label: "default" + autoscale: + min_workers: 1 + max_workers: 3 + mode: "ENHANCED" + instance_pool_id: "instance-pool://pool-1" + driver_instance_pool_id: "instance-pool://pool-1" + + + - name: "dbx-pipeline-chain" + access_control_list: + - user_name: "some@email.com" + permission_level: "IS_OWNER" + job_clusters: + - job_cluster_key: "main" + <<: *basic-static-cluster + tasks: + - task_key: "first" + job_cluster_key: "main" + python_wheel_task: + entry_point: "etl" + package_name: "dbx_exec_srv" + - task_key: "second" + deployment_config: + no_package: true + pipeline_task: + pipeline_id: "pipeline://dlt-test" +""" +) + +TEST_CONFIG = DeploymentConfig.from_payload(TEST_PAYLOAD) +ENVIRONMENT_DEFINITION = TEST_CONFIG.get_environment("default") + + +@pytest.fixture +def complex_instance_pool_mock(mocker: MockerFixture): + mocker.patch.object( + InstancePoolService, + "list_instance_pools", + MagicMock( + return_value={ + 
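+                # Two named pools matching the "instance-pool://pool-1" and "instance-pool://pool-2"
+                # references used in TEST_PAYLOAD above; the adjuster is expected to resolve them to these ids.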
"instance_pools": [ + {"instance_pool_name": "pool-1", "instance_pool_id": "some-id-1"}, + {"instance_pool_name": "pool-2", "instance_pool_id": "some-id-2"}, + ] + } + ), + ) + + +@pytest.fixture +def named_pipeline_mock(mocker: MockerFixture): + mocker.patch.object(NamedPipelinesService, "find_by_name_strict", MagicMock(return_value="aa-bb")) + + +def test_complex(complex_instance_pool_mock, named_pipeline_mock): + wfs = TEST_CONFIG.get_environment("default").payload.workflows + core_pkg = Library(whl="/some/local/file") + client_mock = MagicMock() + _adj = Adjuster( + additional_libraries=AdditionalLibrariesProvider(core_package=core_pkg), + file_uploader=MagicMock(), + api_client=client_mock, + ) + _adj.traverse(wfs) + + assert wfs[0].workflow_type == WorkflowType.pipeline + assert wfs[0].clusters[0].instance_pool_id == "some-id-1" + assert wfs[0].clusters[0].driver_instance_pool_id == "some-id-1" + assert len(wfs[0].libraries) == 1 + assert wfs[1].workflow_type == WorkflowType.job_v2d1 + assert core_pkg in wfs[1].get_task("first").libraries + assert core_pkg not in wfs[1].get_task("second").libraries + + +TEST_ARBITRARY_TRVS_PAYLOAD = yaml.safe_load( + """ +environments: + default: + workflows: + - name: "sample-wf" + tags: + key: "instance-pool://pool-1" + tasks: + - task_key: "some-task" + some_task: + list_props: + - "instance-pool://pool-1" + nested_props: + nested_key: "instance-pool://pool-1" +""" +) + + +def test_arbitrary_traversals(complex_instance_pool_mock): + default = DeploymentConfig.from_payload(TEST_ARBITRARY_TRVS_PAYLOAD).get_environment("default") + _adj = Adjuster( + additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + api_client=MagicMock(), + ) + _adj.traverse(default.payload.workflows) + _wf = default.payload.get_workflow("sample-wf") + assert _wf.tags["key"] == "some-id-1" + assert getattr(_wf.get_task("some-task"), "some_task")["list_props"][0] == "some-id-1" + assert getattr(_wf.get_task("some-task"), "some_task")["nested_props"]["nested_key"] == "some-id-1" diff --git a/tests/unit/api/adjuster/test_existing_cluster.py b/tests/unit/api/adjuster/test_existing_cluster.py new file mode 100644 index 00000000..a3b2ea88 --- /dev/null +++ b/tests/unit/api/adjuster/test_existing_cluster.py @@ -0,0 +1,82 @@ +from typing import List +from unittest.mock import MagicMock + +import pytest +import yaml +from databricks_cli.sdk import ClusterService +from pytest_mock import MockerFixture + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from dbx.models.deployment import AnyWorkflow +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.workflow import Workflow as V2dot1Workflow + +TEST_PAYLOADS = { + "legacy": """ + name: "a" + some_task: "a" + existing_cluster_name: "some-cluster" + """, + "property": """ + name: "a" + tasks: + - task_key: "a" + some_task: "a" + existing_cluster_id: "cluster://some-cluster" + """, + "duplicated": """ + name: "a" + tasks: + - task_key: "a" + some_task: "a" + existing_cluster_id: "cluster://some-duplicated-cluster" + """, + "not_found": """ + name: "a" + tasks: + - task_key: "a" + some_task: "a" + existing_cluster_id: "cluster://some-not-found-cluster" + """, +} + + +@pytest.fixture +def existing_cluster_mock(mocker: MockerFixture): + mocker.patch.object( + ClusterService, + "list_clusters", + MagicMock( + return_value={ + "clusters": [ + {"cluster_name": "some-cluster", "cluster_id": "some-id"}, + {"cluster_name": 
"some-duplicated-cluster", "cluster_id": "some-id-1"}, + {"cluster_name": "some-duplicated-cluster", "cluster_id": "some-id-2"}, + ] + } + ), + ) + + +def convert_to_workflow_with_legacy_support(key: str, payload: str) -> List[AnyWorkflow]: + _base_class = V2dot0Workflow if key == "legacy" else V2dot1Workflow + return [_base_class(**yaml.safe_load(payload))] + + +@pytest.mark.parametrize("key", list(TEST_PAYLOADS.keys())) +def test_instance_pools(key, existing_cluster_mock): + _wf = convert_to_workflow_with_legacy_support(key, TEST_PAYLOADS[key]) + _adj = Adjuster( + api_client=MagicMock(), + additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + ) + if key in ["duplicated", "not_found"]: + with pytest.raises(AssertionError): + _adj.traverse(_wf) + elif key == "legacy": + _adj.traverse(_wf) + assert _wf[0].existing_cluster_id == "some-id" + else: + _adj.traverse(_wf) + assert _wf[0].get_task("a").existing_cluster_id == "some-id" diff --git a/tests/unit/api/adjuster/test_instance_pool.py b/tests/unit/api/adjuster/test_instance_pool.py new file mode 100644 index 00000000..01abb60c --- /dev/null +++ b/tests/unit/api/adjuster/test_instance_pool.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest +from databricks_cli.sdk import InstancePoolService +from pytest_mock import MockerFixture + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from .test_instance_profile import convert_to_workflow + +TEST_PAYLOADS = { + "legacy": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + instance_pool_name: "some-pool" + driver_instance_pool_name: "some-pool" + name: "test" + some_task: "a" + """, + "property": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + instance_pool_id: "instance-pool://some-pool" + driver_instance_pool_id: "instance-pool://some-pool" + name: "test" + some_task: "a" + """, + "duplicated": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + instance_pool_id: "instance-pool://some-duplicated-pool" + name: "instance-profile-test" + some_task: "a" + """, + "not_found": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + aws_attributes: + instance_pool_id: "instance-pool://some-non-existent-pool" + name: "test" + some_task: "a" + """, +} + + +@pytest.fixture +def instance_pool_mock(mocker: MockerFixture): + mocker.patch.object( + InstancePoolService, + "list_instance_pools", + MagicMock( + return_value={ + "instance_pools": [ + {"instance_pool_name": "some-pool", "instance_pool_id": "some-id"}, + {"instance_pool_name": "some-duplicated-pool", "instance_pool_id": "some-id-1"}, + {"instance_pool_name": "some-duplicated-pool", "instance_pool_id": "some-id-2"}, + ] + } + ), + ) + + +@pytest.mark.parametrize("key", list(TEST_PAYLOADS.keys())) +def test_instance_pools(key, instance_pool_mock): + _wf = convert_to_workflow(TEST_PAYLOADS[key]) + _adj = Adjuster( + api_client=MagicMock(), + additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + ) + if key in ["duplicated", "not_found"]: + with pytest.raises(AssertionError): + _adj.traverse(_wf) + else: + _adj.traverse(_wf) + assert _wf[0].get_job_cluster_definition("some-cluster").new_cluster.instance_pool_id == "some-id" diff --git a/tests/unit/api/adjuster/test_instance_profile.py b/tests/unit/api/adjuster/test_instance_profile.py 
new file mode 100644 index 00000000..4ab6f426 --- /dev/null +++ b/tests/unit/api/adjuster/test_instance_profile.py @@ -0,0 +1,89 @@ +from typing import List +from unittest.mock import MagicMock + +import pytest +import yaml + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from dbx.models.workflow.v2dot1.workflow import Workflow + +TEST_PAYLOADS = { + "legacy": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + aws_attributes: + instance_profile_name: "some-instance-profile" + name: "instance-profile-test" + some_task: "a" + """, + "property": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + aws_attributes: + instance_profile_arn: "instance-profile://some-instance-profile" + name: "instance-profile-test" + some_task: "a" + """, + "duplicated": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + aws_attributes: + instance_profile_arn: "instance-profile://some-duplicated-profile" + name: "instance-profile-test" + some_task: "a" + """, + "not_found": """ + job_clusters: + - job_cluster_key: "some-cluster" + new_cluster: + spark_version: "some" + aws_attributes: + instance_profile_arn: "instance-profile://some-not-found-profile" + name: "instance-profile-test" + some_task: "a" + """, +} + + +def convert_to_workflow(payload: str) -> List[Workflow]: + return [Workflow(**yaml.safe_load(payload))] + + +@pytest.fixture +def instance_profile_mock() -> MagicMock: + client = MagicMock() + client.perform_query = MagicMock( + return_value={ + "instance_profiles": [ + {"instance_profile_arn": "some-arn/some-instance-profile"}, + {"instance_profile_arn": "some-arn/some-duplicated-profile"}, + {"instance_profile_arn": "some-arn/some-duplicated-profile"}, + ] + } + ) + return client + + +@pytest.mark.parametrize("key", list(TEST_PAYLOADS.keys())) +def test_instance_profile(key, instance_profile_mock): + _wf = convert_to_workflow(TEST_PAYLOADS[key]) + _adj = Adjuster( + api_client=instance_profile_mock, + additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + ) + if key in ["duplicated", "not_found"]: + with pytest.raises(AssertionError): + _adj.traverse(_wf) + else: + _adj.traverse(_wf) + assert ( + _wf[0].get_job_cluster_definition("some-cluster").new_cluster.aws_attributes.instance_profile_arn + == "some-arn/some-instance-profile" + ) diff --git a/tests/unit/api/adjuster/test_pipeline.py b/tests/unit/api/adjuster/test_pipeline.py new file mode 100644 index 00000000..3212048a --- /dev/null +++ b/tests/unit/api/adjuster/test_pipeline.py @@ -0,0 +1,62 @@ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from .test_instance_profile import convert_to_workflow + +TEST_PAYLOADS = { + "property": """ + name: "test" + tasks: + - task_key: "p1" + pipeline_task: + pipeline_id: "pipeline://some-pipeline" + """, + "duplicated": """ + name: "test" + tasks: + - task_key: "p1" + pipeline_task: + pipeline_id: "pipeline://some-duplicated-pipeline" + """, + "not_found": """ + name: "test" + tasks: + - task_key: "p1" + pipeline_task: + pipeline_id: "pipeline://some-non-existent-pipeline" + """, +} + + +@pytest.fixture +def pipelines_mock(mocker: MockerFixture): + client = MagicMock() + client.perform_query = MagicMock( + return_value={ + "statuses": [ + {"pipeline_id": "some-id", "name": 
"some-pipeline"}, + {"pipeline_id": "some-id-1", "name": "some-duplicated-pipeline"}, + {"pipeline_id": "some-id-2", "name": "some-duplicated-pipeline"}, + ] + } + ) + return client + + +@pytest.mark.parametrize("key", list(TEST_PAYLOADS.keys())) +def test_pipelines(key, pipelines_mock): + _wf = convert_to_workflow(TEST_PAYLOADS[key]) + _adj = Adjuster( + api_client=pipelines_mock, + additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + ) + if key in ["duplicated", "not_found"]: + with pytest.raises(AssertionError): + _adj.traverse(_wf) + else: + _adj.traverse(_wf) + assert _wf[0].get_task("p1").pipeline_task.pipeline_id == "some-id" diff --git a/tests/unit/api/adjuster/test_policy.py b/tests/unit/api/adjuster/test_policy.py new file mode 100644 index 00000000..4631e026 --- /dev/null +++ b/tests/unit/api/adjuster/test_policy.py @@ -0,0 +1,296 @@ +from unittest.mock import MagicMock + +import pytest +import yaml +from databricks_cli.sdk import PolicyService +from pytest_mock import MockerFixture + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from dbx.api.adjuster.policy import PolicyAdjuster +from dbx.models.deployment import DeploymentConfig +from dbx.models.workflow.common.libraries import Library +from dbx.models.workflow.common.new_cluster import NewCluster + + +def test_base_aws_policy(): + _policy = { + "aws_attributes.instance_profile_arn": { + "type": "fixed", + "value": "arn:aws:iam::123456789:instance-profile/sample-aws-iam", + }, + "spark_version": {"type": "fixed", "value": "lts"}, + "node_type_id": {"type": "allowlist", "values": ["node_1", "node_2"]}, + "spark_conf.spark.my.conf": {"type": "fixed", "value": "my_value"}, + "spark_conf.spark.my.other.conf": {"type": "fixed", "value": "my_other_value"}, + "init_scripts.0.dbfs.destination": {"type": "fixed", "value": "dbfs:/some/init-scripts/sc1.sh"}, + "init_scripts.1.dbfs.destination": {"type": "fixed", "value": "dbfs:/some/init-scripts/sc2.sh"}, + } + _formatted = { + "aws_attributes": {"instance_profile_arn": "arn:aws:iam::123456789:instance-profile/sample-aws-iam"}, + "spark_conf": {"spark.my.conf": "my_value", "spark.my.other.conf": "my_other_value"}, + "spark_version": "lts", + "init_scripts": [ + {"dbfs": {"destination": "dbfs:/some/init-scripts/sc1.sh"}}, + {"dbfs": {"destination": "dbfs:/some/init-scripts/sc2.sh"}}, + ], + } + api_client = MagicMock() + adj = PolicyAdjuster(api_client) + result = adj._traverse_policy(_policy) + assert result == _formatted + + +@pytest.fixture() +def policy_mock(mocker: MockerFixture): + mocker.patch.object( + PolicyService, + "list_policies", + MagicMock( + return_value={ + "policies": [ + { + "policy_id": 1, + "name": "good-policy", + "definition": """ + {"spark_conf.spark.my.conf": {"type": "fixed", "value": "my_value"}} + """, + }, + { + "policy_id": 2, + "name": "duplicated-name", + "definition": """ + {"spark_conf.spark.my.conf": {"type": "fixed", "value": "my_value"}} + """, + }, + { + "policy_id": 3, + "name": "duplicated-name", + "definition": """ + {"spark_conf.spark.my.conf": {"type": "fixed", "value": "my_value"}} + """, + }, + { + "policy_id": 4, + "name": "conflicting", + "definition": """ + {"spark_version": {"type": "fixed", "value": "some-other"}} + """, + }, + { + "policy_id": 20, + "name": "policy-with-one-script", + "definition": """ + { + "init_scripts.0.dbfs.destination": { + "type": "fixed", + "value": "dbfs://some/path/script.sh" + } + } + """, + }, + { + "policy_id": 10, + "name": 
"policy-with-multiple-scripts", + "definition": """ + { + "init_scripts.0.dbfs.destination": { + "type": "fixed", + "value": "dbfs://some/path/script.sh" + }, + "init_scripts.1.dbfs.destination": { + "type": "fixed", + "value": "dbfs://some/path/other-script.sh" + } + } + """, + }, + ] + } + ), + ) + + +@pytest.mark.parametrize( + "cluster_def", + [ + NewCluster(spark_version="lts", policy_name="good-policy"), + NewCluster(spark_version="lts", policy_id="cluster-policy://good-policy"), + ], +) +def test_adjusting(cluster_def, policy_mock): + _adj = PolicyAdjuster(api_client=MagicMock()) + _obj = _adj._adjust_policy_ref(cluster_def) + assert getattr(_obj, "spark_conf").get("spark.my.conf") == "my_value" + + +@pytest.mark.parametrize( + "cluster_def", + [ + NewCluster(spark_version="lts", policy_id="cluster-policy://duplicated-name"), + NewCluster(spark_version="lts", policy_id="cluster-policy://not-found"), + NewCluster(spark_version="lts", policy_id="cluster-policy://conflicting"), + ], +) +def test_negative_cases(cluster_def, policy_mock): + _adj = PolicyAdjuster(api_client=MagicMock()) + with pytest.raises(ValueError): + _obj = _adj._adjust_policy_ref(cluster_def) + + +TEST_DEFINITIONS = yaml.safe_load( + """ +environments: + default: + workflows: + - name: "legacy-definition" + some_task: "here" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://good-policy" + - name: "v2.1-inplace" + job_clusters: + - job_cluster_key: "base" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://good-policy" + tasks: + - task_key: "inplace" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://good-policy" + some_task: "here" + - task_key: "from-job-clusters" + job_cluster_key: "base" + some_task: "here" +""" +) + +TEST_CONFIG = DeploymentConfig.from_payload(TEST_DEFINITIONS) +ENVIRONMENT_DEFINITION = TEST_CONFIG.get_environment("default") + + +def test_locations(policy_mock): + wfs = TEST_CONFIG.get_environment("default").payload.workflows + core_pkg = Library(whl="/some/local/file") + client_mock = MagicMock() + _adj = Adjuster( + additional_libraries=AdditionalLibrariesProvider(core_package=core_pkg), + file_uploader=MagicMock(), + api_client=client_mock, + ) + _adj.traverse(wfs) + for element in [ + wfs[0].new_cluster, + wfs[1].get_task("inplace").new_cluster, + wfs[1].get_job_cluster_definition("base").new_cluster, + ]: + assert getattr(element, "spark_conf").get("spark.my.conf") == "my_value" + assert element.policy_id == "1" + + +@pytest.mark.parametrize( + "existing_init_scripts, expected", + [ + ( + [], + [ + {"dbfs": {"destination": "dbfs1"}}, + {"dbfs": {"destination": "dbfs2"}}, + {"s3": {"destination": "s31"}}, + {"s3": {"destination": "s32"}}, + ], + ), + ( + [ + {"dbfs": {"destination": "dbfs1"}}, + {"dbfs": {"destination": "dbfs2"}}, + {"s3": {"destination": "s31"}}, + {"s3": {"destination": "s32"}}, + ], + [ + {"dbfs": {"destination": "dbfs1"}}, + {"dbfs": {"destination": "dbfs2"}}, + {"s3": {"destination": "s31"}}, + {"s3": {"destination": "s32"}}, + ], + ), + ( + [ + {"dbfs": {"destination": "dbfs2"}}, + {"dbfs": {"destination": "dbfs3"}}, + {"s3": {"destination": "s32"}}, + {"s3": {"destination": "s33"}}, + ], + [ + {"dbfs": {"destination": "dbfs1"}}, + {"dbfs": {"destination": "dbfs2"}}, + {"s3": {"destination": "s31"}}, + {"s3": {"destination": "s32"}}, + {"dbfs": {"destination": "dbfs3"}}, + {"s3": {"destination": "s33"}}, + ], + ), + ], +) +def test_append_init_scripts(existing_init_scripts, expected): + policy_init_scripts = [ + 
{"dbfs": {"destination": "dbfs1"}}, + {"dbfs": {"destination": "dbfs2"}}, + {"s3": {"destination": "s31"}}, + {"s3": {"destination": "s32"}}, + ] + assert expected == PolicyAdjuster._append_init_scripts(policy_init_scripts, existing_init_scripts) + + +TESTS_WITH_SCRIPTS = yaml.safe_load( + """ +environments: + default: + workflows: + - name: "one-script" + some_task: "here" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://policy-with-one-script" + init_scripts: + - dbfs: + destination: "dbfs:/some/script.sh" + - name: "multiple-scripts" + some_task: "here" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://policy-with-multiple-scripts" + init_scripts: + - dbfs: + destination: "dbfs:/some/script.sh" + - name: "wrong-format" + some_task: "here" + new_cluster: + spark_version: "some" + policy_id: "cluster-policy://policy-with-one-script" + init_scripts: + - dbfs: + destination: "dbfs:/some/script.sh" +""" +) + + +@pytest.mark.parametrize( + "wf_name, amount_or_behaviour", [("one-script", 2), ("multiple-scripts", 3), ("wrong-format", Exception())] +) +def test_with_scripts(wf_name, amount_or_behaviour, policy_mock): + wf = DeploymentConfig.from_payload(TESTS_WITH_SCRIPTS).get_environment("default").payload.get_workflow(wf_name) + core_pkg = Library(whl="/some/local/file") + client_mock = MagicMock() + _adj = Adjuster( + additional_libraries=AdditionalLibrariesProvider(core_package=core_pkg), + file_uploader=MagicMock(), + api_client=client_mock, + ) + if isinstance(amount_or_behaviour, Exception): + with pytest.raises(Exception): + _adj.traverse([wf]) + else: + _adj.traverse([wf]) + assert not wf.new_cluster.policy_id.startswith("cluster-policy://") + assert len(getattr(wf.new_cluster, "init_scripts")) == amount_or_behaviour diff --git a/tests/unit/api/adjuster/test_service_principals.py b/tests/unit/api/adjuster/test_service_principals.py new file mode 100644 index 00000000..ee8c05eb --- /dev/null +++ b/tests/unit/api/adjuster/test_service_principals.py @@ -0,0 +1,67 @@ +from unittest.mock import MagicMock + +import pytest + +from dbx.api.adjuster.adjuster import Adjuster, AdditionalLibrariesProvider +from .test_instance_profile import convert_to_workflow + +TEST_PAYLOADS = { + "property": """ + name: "test" + tasks: + - task_key: "p1" + some_task: "a" + access_control_list: + - user_name: "service-principal://some-principal" + permission_level: "IS_OWNER" + """, + "duplicated": """ + name: "test" + tasks: + - task_key: "p1" + some_task: "a" + access_control_list: + - user_name: "service-principal://some-duplicated-principal" + permission_level: "IS_OWNER" + """, + "not_found": """ + name: "test" + tasks: + - task_key: "p1" + some_task: "a" + access_control_list: + - user_name: "service-principal://some-non-existent-principal" + permission_level: "IS_OWNER" + """, +} + + +@pytest.fixture +def service_principal_mock() -> MagicMock: + client = MagicMock() + client.perform_query = MagicMock( + return_value={ + "Resources": [ + {"displayName": "some-principal", "applicationId": "some-id"}, + {"displayName": "some-duplicated-principal", "applicationId": "some-id-1"}, + {"displayName": "some-duplicated-principal", "applicationId": "some-id-2"}, + ] + } + ) + return client + + +@pytest.mark.parametrize("key", list(TEST_PAYLOADS.keys())) +def test_service_principals(key, service_principal_mock): + _wf = convert_to_workflow(TEST_PAYLOADS[key]) + _adj = Adjuster( + api_client=service_principal_mock, + 
additional_libraries=AdditionalLibrariesProvider(core_package=None), + file_uploader=MagicMock(), + ) + if key in ["duplicated", "not_found"]: + with pytest.raises(AssertionError): + _adj.traverse(_wf) + else: + _adj.traverse(_wf) + assert _wf[0].access_control_list[0].user_name == "some-id" diff --git a/tests/unit/api/launch/test_functions.py b/tests/unit/api/launch/test_functions.py index 7d115a9f..ccf75951 100644 --- a/tests/unit/api/launch/test_functions.py +++ b/tests/unit/api/launch/test_functions.py @@ -3,14 +3,16 @@ import mlflow from pytest_mock import MockerFixture -from dbx.api.launch.functions import cancel_run, load_dbx_file, wait_run +from dbx.api.launch.functions import cancel_run, wait_run +from dbx.api.launch.runners.base import RunData +from dbx.api.storage.io import StorageIO from dbx.utils.json import JsonUtils def test_cancel(mocker: MockerFixture): wait_mock = mocker.patch("dbx.api.launch.functions.wait_run", MagicMock()) client = MagicMock() - cancel_run(client, {"run_id": 1}) + cancel_run(client, RunData(run_id=1)) wait_mock.assert_called() @@ -25,7 +27,7 @@ def test_wait_run(mocker: MockerFixture): ), ) client = MagicMock() - wait_run(client, {"run_id": 1}) + wait_run(client, RunData(**{"run_id": 1})) def test_load_file(tmp_path): @@ -34,5 +36,5 @@ def test_load_file(tmp_path): _file = tmp_path / "conf.json" JsonUtils.write(_file, content) mlflow.log_artifact(str(_file.absolute()), ".dbx") - _result = load_dbx_file(test_run.info.run_id, "conf.json") + _result = StorageIO.load(test_run.info.run_id, "conf.json") assert _result == content diff --git a/tests/unit/api/launch/test_pipeline_runner.py b/tests/unit/api/launch/test_pipeline_runner.py new file mode 100644 index 00000000..c82c66e2 --- /dev/null +++ b/tests/unit/api/launch/test_pipeline_runner.py @@ -0,0 +1,31 @@ +import pytest + +from dbx.api.launch.runners.base import PipelineUpdateResponse +from dbx.api.launch.runners.pipeline import PipelineLauncher, PipelinesRunPayload + +TEST_PIPELINE_ID = "aaa-bbb" +TEST_PIPELINE_UPDATE_PAYLOAD = {"update_id": "u1", "request_id": "r1"} + + +def test_basic(pipeline_launch_mock): + launcher = PipelineLauncher("some", api_client=pipeline_launch_mock) + process_info, object_id = launcher.launch() + assert object_id == TEST_PIPELINE_ID + assert process_info == PipelineUpdateResponse(**TEST_PIPELINE_UPDATE_PAYLOAD) + + +@pytest.mark.parametrize( + "payload, expected", + [ + ('{"full_refresh": "true"}', PipelinesRunPayload(full_refresh=True)), + ('{"refresh_selection": ["tab1"]}', PipelinesRunPayload(refresh_selection=["tab1"])), + ( + '{"refresh_selection": ["tab1"], "full_refresh_selection": ["tab2"]}', + PipelinesRunPayload(refresh_selection=["tab1"], full_refresh_selection=["tab2"]), + ), + ], +) +def test_with_parameters(payload, expected, pipeline_launch_mock): + launcher = PipelineLauncher("some", api_client=pipeline_launch_mock, parameters=payload) + assert launcher.parameters is not None + assert launcher.parameters == expected diff --git a/tests/unit/api/launch/test_processors.py b/tests/unit/api/launch/test_processors.py index 1f93bdfc..a0b632f1 100644 --- a/tests/unit/api/launch/test_processors.py +++ b/tests/unit/api/launch/test_processors.py @@ -1,18 +1,25 @@ from dbx.api.launch.processors import ClusterReusePreprocessor +from dbx.models.workflow.v2dot1.workflow import Workflow def test_preprocessor_positive(): - nc_payload = {"some_key": "some_value"} - nc_untouched_payload = {"some_key", "some_other_value"} + nc_payload = {"some_key": "some_value", 
"spark_version": "some-version"} + nc_untouched_payload = {"some_key": "some-other-value", "spark_version": "some-version"} test_spec = { + "name": "name", "job_clusters": [{"job_cluster_key": "main", "new_cluster": nc_payload}], "tasks": [ - {"task_key": "some-task", "job_cluster_key": "main"}, - {"task_key": "some-other-task", "new_cluster": nc_untouched_payload}, + {"task_key": "some-task", "job_cluster_key": "main", "spark_python_task": {"python_file": "here.py"}}, + { + "task_key": "some-other-task", + "new_cluster": nc_untouched_payload, + "spark_python_task": {"python_file": "here.py"}, + }, ], } - proc = ClusterReusePreprocessor(test_spec) - result_spec = proc.process() - assert "job_clusters" not in result_spec - assert result_spec["tasks"][0]["new_cluster"] == nc_payload - assert result_spec["tasks"][1]["new_cluster"] == nc_untouched_payload + wf = Workflow(**test_spec) + proc = ClusterReusePreprocessor() + proc.process(wf) + assert wf.job_clusters is None + assert wf.get_task("some-task").new_cluster.dict(exclude_none=True) == nc_payload + assert wf.get_task("some-other-task").new_cluster.dict(exclude_none=True) == nc_untouched_payload diff --git a/tests/unit/api/launch/test_runners.py b/tests/unit/api/launch/test_runners.py index 8c97293a..f8620a4e 100644 --- a/tests/unit/api/launch/test_runners.py +++ b/tests/unit/api/launch/test_runners.py @@ -1,76 +1,37 @@ from unittest.mock import MagicMock import pytest -from databricks_cli.sdk import JobsService +from databricks_cli.sdk import JobsService, ApiClient from pytest_mock import MockerFixture -from dbx.api.configure import ProjectConfigurationManager -from dbx.api.launch.runners import RunSubmitLauncher -from dbx.models.parameters.run_submit import RunSubmitV2d0ParamInfo, RunSubmitV2d1ParamInfo - - -def test_v2d0_parameter_override_negative(): - spec = {"spark_python_task": {"parameters": ["a"]}} - parameters = RunSubmitV2d0ParamInfo(notebook_task={"base_parameters": {"a": 1}}) - with pytest.raises(ValueError): - RunSubmitLauncher.override_v2d0_parameters(spec, parameters) - - -def test_v2d0_parameter_override_positive(): - spec = {"spark_python_task": {"parameters": ["a"]}} - parameters = RunSubmitV2d0ParamInfo(spark_python_task={"parameters": ["b"]}) - RunSubmitLauncher.override_v2d0_parameters(spec, parameters) - assert spec["spark_python_task"]["parameters"] == ["b"] - - -def test_vd21_parameter_override_no_tasks(): - spec = {"a": "b"} - parameters = RunSubmitV2d1ParamInfo(tasks=[{"task_key": "first", "spark_python_task": {"parameters": ["a"]}}]) - with pytest.raises(ValueError): - RunSubmitLauncher.override_v2d1_parameters(spec, parameters) - - -def test_vd21_parameter_override_no_task_key(): - spec = {"tasks": [{"task_key": "this", "spark_python_task": {"parameters": ["a"]}}]} - parameters = RunSubmitV2d1ParamInfo(tasks=[{"task_key": "that", "spark_python_task": {"parameters": ["a"]}}]) - with pytest.raises(ValueError): - RunSubmitLauncher.override_v2d1_parameters(spec, parameters) - - -def test_vd21_parameter_override_incorrect_type(): - spec = {"tasks": [{"task_key": "this", "python_wheel_task": {"parameters": ["a"]}}]} - parameters = RunSubmitV2d1ParamInfo(tasks=[{"task_key": "this", "spark_python_task": {"parameters": ["a"]}}]) - with pytest.raises(ValueError): - RunSubmitLauncher.override_v2d1_parameters(spec, parameters) - - -def test_vd21_parameter_override_positive(): - spec = {"tasks": [{"task_key": "this", "python_wheel_task": {"parameters": ["a"]}}]} - parameters = RunSubmitV2d1ParamInfo(tasks=[{"task_key": 
"this", "python_wheel_task": {"parameters": ["b"]}}]) - RunSubmitLauncher.override_v2d1_parameters(spec, parameters) - assert spec["tasks"][0]["python_wheel_task"]["parameters"] == ["b"] - - -def test_run_submit_reuse(temp_project, mocker: MockerFixture): - ProjectConfigurationManager().enable_failsafe_cluster_reuse() - service_mock = mocker.patch.object(JobsService, "submit_run", MagicMock()) - cluster_def = {"some_key": "some_value"} - mocker.patch( - "dbx.api.launch.runners.load_dbx_file", - MagicMock( - return_value={ - "default": { - "jobs": [ - { - "name": "test", - "job_clusters": [{"job_cluster_key": "some", "new_cluster": cluster_def}], - "tasks": [{"task_key": "one", "job_cluster_key": "some"}], - } - ] - } - } - ), - ) - launcher = RunSubmitLauncher(job="test", api_client=MagicMock(), deployment_run_id="aaa-bbb", environment="default") +from dbx.api.launch.runners.standard import StandardLauncher +from dbx.api.services.jobs import NamedJobsService +from dbx.models.cli.options import ExistingRunsOption + + +def test_not_found(mocker: MockerFixture): + mocker.patch.object(NamedJobsService, "find_by_name", MagicMock(return_value=None)) + launcher = StandardLauncher("non-existent", api_client=MagicMock(), existing_runs=ExistingRunsOption.pass_) + with pytest.raises(Exception): + launcher.launch() + + +@pytest.mark.parametrize( + "behaviour, msg", + [ + (ExistingRunsOption.pass_, "Passing the existing"), + (ExistingRunsOption.wait, "Waiting for job run"), + (ExistingRunsOption.cancel, "Cancelling run"), + ], +) +def test_with_behaviours(behaviour, msg, mocker: MockerFixture, capsys): + mocker.patch("dbx.api.launch.runners.standard.wait_run", MagicMock()) + mocker.patch("dbx.api.launch.runners.standard.cancel_run", MagicMock()) + client = MagicMock() + client.perform_query = MagicMock(return_value={"run_id": 1}) + mocker.patch.object(NamedJobsService, "find_by_name", MagicMock(return_value=1)) + mocker.patch.object(JobsService, "list_runs", MagicMock(return_value={"runs": [{"run_id": 1}]})) + + launcher = StandardLauncher("non-existent", api_client=client, existing_runs=behaviour) launcher.launch() - service_mock.assert_called_once_with(tasks=[{"task_key": "one", "new_cluster": cluster_def}]) + assert msg in capsys.readouterr().out diff --git a/tests/unit/api/launch/test_tracer.py b/tests/unit/api/launch/test_tracer.py index 1769a61e..106b3221 100644 --- a/tests/unit/api/launch/test_tracer.py +++ b/tests/unit/api/launch/test_tracer.py @@ -1,14 +1,44 @@ from unittest.mock import MagicMock +import pytest from pytest_mock import MockerFixture -from dbx.api.launch.tracer import RunTracer +from dbx.api.launch.pipeline_models import PipelineUpdateState, UpdateStatus +from dbx.api.launch.runners.base import RunData, PipelineUpdateResponse +from dbx.api.launch.tracer import RunTracer, PipelineTracer def test_tracer_with_interruption(mocker: MockerFixture): mocker.patch("dbx.api.launch.tracer.trace_run", MagicMock(side_effect=KeyboardInterrupt())) cancel_mock = MagicMock() mocker.patch("dbx.api.launch.tracer.cancel_run", cancel_mock) - _st, _ = RunTracer.start(kill_on_sigterm=True, api_client=MagicMock(), run_data={"run_id": 1}) + _st, _ = RunTracer.start(kill_on_sigterm=True, api_client=MagicMock(), run_data=RunData(run_id=1)) assert _st == "CANCELLED" cancel_mock.assert_called_once() + + +@pytest.fixture +def tracer_mock(): + client = MagicMock() + client.perform_query = MagicMock( + side_effect=[ + {"status": UpdateStatus.ACTIVE, "latest_update": {"update_id": "a", "state": 
PipelineUpdateState.CREATED}}, + { + "status": UpdateStatus.ACTIVE, + "latest_update": {"update_id": "a", "state": PipelineUpdateState.WAITING_FOR_RESOURCES}, + }, + {"status": UpdateStatus.ACTIVE, "latest_update": {"update_id": "a", "state": PipelineUpdateState.RUNNING}}, + {"status": UpdateStatus.ACTIVE, "latest_update": {"update_id": "a", "state": PipelineUpdateState.RUNNING}}, + { + "status": UpdateStatus.TERMINATED, + "latest_update": {"update_id": "a", "state": PipelineUpdateState.COMPLETED}, + }, + ] + ) + return client + + +def test_pipeline_tracer(tracer_mock): + pipeline_update = PipelineUpdateResponse(update_id="a", request_id="b") + final_state = PipelineTracer.start(api_client=tracer_mock, pipeline_id="aaa-bbb", process_info=pipeline_update) + assert final_state == PipelineUpdateState.COMPLETED diff --git a/tests/unit/api/storage/test_io.py b/tests/unit/api/storage/test_io.py new file mode 100644 index 00000000..362ba28b --- /dev/null +++ b/tests/unit/api/storage/test_io.py @@ -0,0 +1,11 @@ +import mlflow + +from dbx.api.storage.io import StorageIO + + +def test_storage_serde(): + payload = {"a": 1} + with mlflow.start_run() as _run: + StorageIO.save(payload, "content.json") + result = StorageIO.load(_run.info.run_id, "content.json") + assert result == payload diff --git a/tests/unit/api/storage/test_mlflow_storage.py b/tests/unit/api/storage/test_mlflow_storage.py index bdf49ce1..60b463b5 100644 --- a/tests/unit/api/storage/test_mlflow_storage.py +++ b/tests/unit/api/storage/test_mlflow_storage.py @@ -6,11 +6,12 @@ from pytest_mock import MockerFixture from dbx.api.storage.mlflow_based import MlflowStorageConfigurationManager -from dbx.models.project import EnvironmentInfo, MlflowStorageProperties +from dbx.models.files.project import EnvironmentInfo, MlflowStorageProperties +from dbx.utils.url import strip_databricks_url def test_url_strip(): - _stripped = MlflowStorageConfigurationManager._strip_url("https://some-location.com/with/some-postfix") + _stripped = strip_databricks_url("https://some-location.com/with/some-postfix") assert _stripped == "https://some-location.com" diff --git a/tests/unit/api/test_build.py b/tests/unit/api/test_build.py index 4a6818a8..5826efd1 100644 --- a/tests/unit/api/test_build.py +++ b/tests/unit/api/test_build.py @@ -3,21 +3,26 @@ from pytest_mock import MockerFixture -from dbx.api.build import prepare_build -from dbx.models.deployment import BuildConfiguration +from dbx.models.build import BuildConfiguration def test_empty(capsys): - prepare_build(BuildConfiguration(no_build=True)) + BuildConfiguration(no_build=True).trigger_build_process() res = capsys.readouterr() assert "No build actions will be performed" in res.out -def test_commands(mocker: MockerFixture, capsys): +def test_no_action(capsys): + BuildConfiguration(no_build=False, commands=None, python=None).trigger_build_process() + res = capsys.readouterr() + assert "skipping the build stage" in res.out + + +def test_commands(temp_project, mocker: MockerFixture, capsys): exec_mock = MagicMock() - mocker.patch("dbx.api.build.execute_shell_command", exec_mock) + mocker.patch("dbx.models.build.execute_shell_command", exec_mock) conf = BuildConfiguration(commands=["sleep 1", "sleep 2", "sleep 3"]) - prepare_build(conf) + conf.trigger_build_process() res = capsys.readouterr() assert "Running the build commands" in res.out assert exec_mock.call_count == 3 @@ -44,12 +49,12 @@ def test_poetry(temp_project): (temp_project / "pyproject.toml").write_text(inspect.cleandoc(pyproject_content)) conf 
= BuildConfiguration(python="poetry") - prepare_build(conf) + conf.trigger_build_process() -def test_flit(mocker: MockerFixture): +def test_flit(temp_project, mocker: MockerFixture): exec_mock = MagicMock() - mocker.patch("dbx.api.build.execute_shell_command", exec_mock) + mocker.patch("dbx.models.build.execute_shell_command", exec_mock) conf = BuildConfiguration(python="flit") - prepare_build(conf) + conf.trigger_build_process() exec_mock.assert_called_once_with(cmd="-m flit build --format wheel", with_python_executable=True) diff --git a/tests/unit/api/test_context.py b/tests/unit/api/test_context.py index 9693a210..d666d7f3 100644 --- a/tests/unit/api/test_context.py +++ b/tests/unit/api/test_context.py @@ -1,5 +1,5 @@ from dbx.api.context import LocalContextManager -from dbx.models.context import ContextInfo +from dbx.models.files.context import ContextInfo def test_local_context_serde(temp_project): diff --git a/tests/unit/api/test_deployment.py b/tests/unit/api/test_deployment.py new file mode 100644 index 00000000..aaf3e078 --- /dev/null +++ b/tests/unit/api/test_deployment.py @@ -0,0 +1,42 @@ +from unittest.mock import MagicMock + +import yaml +from pytest_mock import MockerFixture + +from dbx.api.deployment import WorkflowDeploymentManager +from dbx.api.services.jobs import NamedJobsService +from dbx.api.services.pipelines import NamedPipelinesService +from dbx.models.deployment import EnvironmentDeploymentInfo + +TEST_PAYLOAD = """ + workflows: + - name: "some-wf" + access_control_list: + - user_name: "some@email.com" + permission_level: "IS_OWNER" + tasks: + - task_key: "some-task" + python_wheel_task: + entry_point: "some-ep" + package_name: "some-pkg" + - name: "some-pipeline" + workflow_type: "pipeline" + target: "target_db" + libraries: + - notebook: + path: "/Some/path/in/repos" + """ + +WFS = EnvironmentDeploymentInfo.from_spec("default", yaml.safe_load(TEST_PAYLOAD)) + + +def test_basic(): + mgr = WorkflowDeploymentManager(api_client=MagicMock(), workflows=WFS.payload.workflows) + mgr.apply() + + +def test_update_basic(mocker: MockerFixture): + mocker.patch.object(NamedJobsService, "find_by_name", MagicMock(return_value=1)) + mocker.patch.object(NamedPipelinesService, "find_by_name", MagicMock(return_value="aaa-bbb")) + mgr = WorkflowDeploymentManager(api_client=MagicMock(), workflows=WFS.payload.workflows) + mgr.apply() diff --git a/tests/unit/api/test_destroyer.py b/tests/unit/api/test_destroyer.py index b172094c..cbc64f2a 100644 --- a/tests/unit/api/test_destroyer.py +++ b/tests/unit/api/test_destroyer.py @@ -4,15 +4,18 @@ import mlflow import pytest -from databricks_cli.jobs.api import JobsApi from pytest_mock import MockerFixture from dbx.api.destroyer import Destroyer, WorkflowEraser, AssetEraser +from dbx.api.services.jobs import NamedJobsService +from dbx.models.cli.destroyer import DestroyerConfig, DeletionMode from dbx.models.deployment import EnvironmentDeploymentInfo -from dbx.models.destroyer import DestroyerConfig, DeletionMode +from dbx.models.workflow.v2dot1.workflow import Workflow from tests.unit.conftest import invoke_cli_runner -test_env_info = EnvironmentDeploymentInfo(name="default", payload={"workflows": [{"name": "w1"}]}) +test_env_info = EnvironmentDeploymentInfo( + name="default", payload={"workflows": [{"name": "w1", "workflow_type": "job-v2.0", "some_task": "t"}]} +) basic_config = partial(DestroyerConfig, deployment=test_env_info) @@ -44,9 +47,8 @@ def test_destroyer_modes(conf, wf_expected, assets_expected, temp_project, capsy 
@pytest.mark.parametrize("dry_run", [True, False]) def test_workflow_eraser_dr(dry_run, capsys, mocker: MockerFixture): - api_client = MagicMock() - mocker.patch.object(JobsApi, "_list_jobs_by_name", MagicMock(return_value=[{"name": "w1", "job_id": "1"}])) - eraser = WorkflowEraser(api_client, ["w1"], dry_run=dry_run) + mocker.patch.object(NamedJobsService, "find_by_name", MagicMock(return_value=1)) + eraser = WorkflowEraser(MagicMock(), [Workflow(name="test", some_task="here")], dry_run=dry_run) eraser.erase() out = capsys.readouterr().out _test = "would be deleted" in out @@ -54,17 +56,16 @@ def test_workflow_eraser_dr(dry_run, capsys, mocker: MockerFixture): @pytest.mark.parametrize( - "job_response, expected", + "job_response,expected", [ - ([{"name": "w1", "job_id": "1"}, {"name": "w1", "job_id": "1"}], Exception()), - ([], "doesn't exist"), - ([{"name": "w1", "job_id": "1"}], "w1 was successfully deleted"), + (None, "doesn't exist"), + (1, "was successfully deleted"), ], ) def test_workflow_eraser_list(job_response, expected, capsys, mocker: MockerFixture): api_client = MagicMock() - mocker.patch.object(JobsApi, "_list_jobs_by_name", MagicMock(return_value=job_response)) - eraser = WorkflowEraser(api_client, ["w1"], dry_run=False) + mocker.patch.object(NamedJobsService, "find_by_name", MagicMock(return_value=job_response)) + eraser = WorkflowEraser(api_client, [Workflow(name="test", some_task="here")], dry_run=False) if isinstance(expected, Exception): with pytest.raises(Exception): eraser.erase() diff --git a/tests/unit/api/test_jinja.py b/tests/unit/api/test_jinja.py index af973b8a..0f934542 100644 --- a/tests/unit/api/test_jinja.py +++ b/tests/unit/api/test_jinja.py @@ -1,3 +1,4 @@ +import time from pathlib import Path from textwrap import dedent from typing import Tuple @@ -15,6 +16,7 @@ def temp_with_file(tmp_path) -> Tuple[Path, Path]: _content_path.mkdir(exist_ok=True) (_content_path / "file1.dat").write_bytes(b"a") + time.sleep(3) # for CI stability, the FS in the CI run is pretty slow and might fail the order of writes last = _content_path / "file2.dat" last.write_bytes(b"b") (_content_path / "file2.ndat").write_bytes(b"b") diff --git a/tests/unit/api/test_jobs_service.py b/tests/unit/api/test_jobs_service.py new file mode 100644 index 00000000..a7e2e69a --- /dev/null +++ b/tests/unit/api/test_jobs_service.py @@ -0,0 +1,25 @@ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from dbx.api.services.jobs import NamedJobsService, JobListing, ListJobsResponse + + +def test_duplicated_jobs(mocker: MockerFixture): + mocker.patch.object( + JobListing, + "by_name", + MagicMock( + return_value=ListJobsResponse( + **{ + "jobs": [ + {"job_id": 1, "settings": {"name": "dup"}}, + {"job_id": 1, "settings": {"name": "dup"}}, + ] + } + ) + ), + ) + with pytest.raises(Exception): + NamedJobsService(api_client=MagicMock()).find_by_name("dup") diff --git a/tests/unit/commands/test_deploy.py b/tests/unit/commands/test_deploy.py index 9c404182..0fe13d2c 100644 --- a/tests/unit/commands/test_deploy.py +++ b/tests/unit/commands/test_deploy.py @@ -1,24 +1,17 @@ import shutil +import textwrap from pathlib import Path -from unittest.mock import Mock, MagicMock +from unittest.mock import MagicMock -import mlflow import pytest import yaml -from databricks_cli.sdk import ApiClient, JobsService -from requests import HTTPError +from pytest_mock import MockerFixture from dbx.api.config_reader import ConfigReader from dbx.api.configure import 
ProjectConfigurationManager, EnvironmentInfo +from dbx.api.services.jobs import NamedJobsService from dbx.api.storage.mlflow_based import MlflowStorageConfigurationManager -from dbx.commands.deploy import ( # noqa - _create_job, - _log_dbx_file, - _preprocess_jobs, - _update_job, - deploy, -) -from dbx.models.project import MlflowStorageProperties +from dbx.models.files.project import MlflowStorageProperties from dbx.utils.json import JsonUtils from tests.unit.conftest import ( get_path_with_relation_to_current_file, @@ -26,26 +19,40 @@ ) -def test_deploy_smoke_default(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_deploy_smoke_default(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): deploy_result = invoke_cli_runner("deploy") assert deploy_result.exit_code == 0 -def test_deploy_files_only_smoke_default( - temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client +@pytest.mark.parametrize("argset", [["--files-only"], ["--assets-only"], ["--no-rebuild"]]) +def test_deploy_assets_only_smoke_default( + argset, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): - deploy_result = invoke_cli_runner(["deploy", "--files-only"]) + deploy_result = invoke_cli_runner(["deploy"] + argset) assert deploy_result.exit_code == 0 -def test_deploy_assets_only_smoke_default( - temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client -): - deploy_result = invoke_cli_runner(["deploy", "--assets-only"]) - assert deploy_result.exit_code == 0 +def test_deploy_assets_pipeline(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): + (temp_project / "conf" / "deployment.yml").write_text( + """ + environments: + default: + workflows: + - name: "pipe" + workflow_type: "pipeline" + libraries: + - notebook: + path: "/some/path" + """ + ) + deploy_result = invoke_cli_runner(["deploy", "--assets-only"], expected_error=True) + assert "not supported for DLT pipelines" in str(deploy_result.exception) -def test_deploy_multitask_smoke(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +def test_deploy_multitask_smoke( + mlflow_file_uploader, mocker: MockerFixture, mock_storage_io, mock_api_v2_client, temp_project +): + mocker.patch.object(NamedJobsService, "create", MagicMock(return_value=1)) samples_path = get_path_with_relation_to_current_file("../deployment-configs/") for file_name in ["03-multitask-job.json", "03-multitask-job.yaml"]: deployment_file = Path("./conf/") / file_name @@ -66,11 +73,12 @@ def test_deploy_multitask_smoke(mlflow_file_uploader, mock_dbx_file_upload, mock ) assert deploy_result.exit_code == 0 _content = JsonUtils.read(Path(".dbx/deployment-result.json")) - assert "libraries" not in _content["default"]["jobs"][0] - assert "libraries" in _content["default"]["jobs"][0]["tasks"][0] + assert "libraries" not in _content["default"]["workflows"][0] + assert "libraries" in _content["default"]["workflows"][0]["tasks"][0] -def test_deploy_path_adjustment_json(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +def test_deploy_path_adjustment_json(mlflow_file_uploader, mocker, mock_storage_io, mock_api_v2_client, temp_project): + mocker.patch.object(NamedJobsService, "create", MagicMock(return_value=1)) samples_path = get_path_with_relation_to_current_file("../deployment-configs/") for file_name in ["04-path-adjustment-policy.json", "04-path-adjustment-policy.yaml"]: 
deployment_file = Path("./conf/") / file_name @@ -91,10 +99,11 @@ def test_deploy_path_adjustment_json(mlflow_file_uploader, mock_dbx_file_upload, ], ) _content = JsonUtils.read(Path(".dbx/deployment-result.json")) - expected_prefix = mlflow.get_tracking_uri() - assert _content["default"]["jobs"][0]["libraries"][0]["whl"].startswith(expected_prefix) - assert _content["default"]["jobs"][0]["spark_python_task"]["python_file"].startswith(expected_prefix) - assert _content["default"]["jobs"][0]["spark_python_task"]["parameters"][0].startswith(expected_prefix) + expected_prefix = "dbfs:/mocks/testing" + + assert _content["default"]["workflows"][0]["libraries"][0]["whl"].startswith(expected_prefix) + assert _content["default"]["workflows"][0]["spark_python_task"]["python_file"].startswith(expected_prefix) + assert _content["default"]["workflows"][0]["spark_python_task"]["parameters"][0].startswith("/dbfs/") assert deploy_result.exit_code == 0 @@ -133,46 +142,39 @@ def test_non_existent_env(mock_api_v2_client, temp_project): assert "not found in the deployment file" in str(deploy_result.exception) -def test_deploy_only_chosen_workflow(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +def test_deploy_only_chosen_workflow(mlflow_file_uploader, mocker, mock_storage_io, mock_api_v2_client, temp_project): + mocker.patch.object(NamedJobsService, "create", MagicMock(return_value=1)) result_file = ".dbx/deployment-result.json" deployment_info = ConfigReader(Path("conf/deployment.yml")).get_environment("default") - _chosen = [j["name"] for j in deployment_info.payload.workflows][0] + _chosen = deployment_info.payload.workflow_names[0] deploy_result = invoke_cli_runner( ["deploy", "--environment=default", f"--write-specs-to-file={result_file}", _chosen], ) assert deploy_result.exit_code == 0 _content = JsonUtils.read(Path(result_file)) - assert _chosen in [j["name"] for j in _content["default"]["jobs"]] - - -def test_deploy_only_chosen_jobs(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): - result_file = ".dbx/deployment-result.json" - deployment_info = ConfigReader(Path("conf/deployment.yml")).get_environment("default") - _chosen = [j["name"] for j in deployment_info.payload.workflows][:2] - deploy_result = invoke_cli_runner( - ["deploy", "--environment", "default", "--jobs", ",".join(_chosen), "--write-specs-to-file", result_file], - ) - assert deploy_result.exit_code == 0 - _content = JsonUtils.read(Path(result_file)) - assert _chosen == [j["name"] for j in _content["default"]["jobs"]] + assert [w["name"] for w in _content["default"]["workflows"]] == [_chosen] -def test_deploy_only_chosen_workflows(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +@pytest.mark.parametrize("workflow_arg", ["--workflows", "--jobs"]) +def test_deploy_only_chosen( + workflow_arg, mlflow_file_uploader, mocker, mock_storage_io, mock_api_v2_client, temp_project +): + mocker.patch.object(NamedJobsService, "create", MagicMock(return_value=1)) result_file = ".dbx/deployment-result.json" deployment_info = ConfigReader(Path("conf/deployment.yml")).get_environment("default") - _chosen = [j["name"] for j in deployment_info.payload.workflows][:2] + _chosen = deployment_info.payload.workflow_names[:2] deploy_result = invoke_cli_runner( - ["deploy", "--environment", "default", "--workflows", ",".join(_chosen), "--write-specs-to-file", result_file], + ["deploy", "--environment", "default", workflow_arg, ",".join(_chosen), "--write-specs-to-file", 
result_file], ) assert deploy_result.exit_code == 0 _content = JsonUtils.read(Path(result_file)) - assert _chosen == [j["name"] for j in _content["default"]["jobs"]] + assert [w["name"] for w in _content["default"]["workflows"]] == _chosen -def test_negative_both_arguments(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +def test_negative_both_arguments(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project): result_file = ".dbx/deployment-result.json" deployment_info = ConfigReader(Path("conf/deployment.yml")).get_environment("default") - _chosen = [j["name"] for j in deployment_info.payload.workflows][:2] + _chosen = deployment_info.payload.workflow_names[0] deploy_result = invoke_cli_runner( [ "deploy", @@ -190,9 +192,7 @@ def test_negative_both_arguments(mlflow_file_uploader, mock_dbx_file_upload, moc assert "cannot be provided together" in str(deploy_result.exception) -def test_deploy_with_requirements_and_branch( - mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project -): +def test_deploy_with_requirements_and_branch(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project): sample_requirements = "\n".join(["pyspark==3.0.0", "xgboost==0.6.0", "pyspark3d"]) Path("runtime_requirements.txt").write_text(sample_requirements) @@ -214,31 +214,8 @@ def test_deploy_with_requirements_and_branch( assert deploy_result.exit_code == 0 -def test_smoke_update_job_positive(): - js = Mock(JobsService) - _update_job(js, "aa-bbb-ccc-111", {"name": 1}) - - -def test_smoke_update_job_negative(): - js = Mock(JobsService) - js.reset_job.side_effect = Mock(side_effect=HTTPError()) - with pytest.raises(HTTPError): - _update_job(js, "aa-bbb-ccc-111", {"name": 1}) - - -def test_create_job_with_error(): - client = Mock(ApiClient) - client.perform_query.side_effect = Mock(side_effect=HTTPError()) - with pytest.raises(HTTPError): - _create_job(client, {"name": "some-job"}) - - -def test_preprocess_jobs(): - with pytest.raises(Exception): - _preprocess_jobs([], ["some-job-name"]) - - -def test_with_permissions(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project): +def test_with_permissions(mocker, mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project): + mocker.patch.object(NamedJobsService, "create", MagicMock(return_value=1)) deployment_file = Path("conf/deployment.yml") deploy_content = yaml.safe_load(deployment_file.read_text()) @@ -261,7 +238,7 @@ def test_with_permissions(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v assert deploy_result.exit_code == 0 -def test_jinja_custom_path(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project: Path): +def test_jinja_custom_path(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project: Path): samples_path = get_path_with_relation_to_current_file("../deployment-configs/") nested_config_dir = samples_path / "nested-configs" shutil.copytree(nested_config_dir, temp_project.parent / "configs") @@ -271,14 +248,14 @@ def test_jinja_custom_path(mlflow_file_uploader, mock_dbx_file_upload, mock_api_ assert deploy_result.exit_code == 0 -def test_update_job_v21_with_permissions(): - _client = MagicMock(spec=ApiClient) - _jobs_service = JobsService(_client) - acl_definition = {"access_control_list": [{"user_name": "test@user.com", "permission_level": "IS_OWNER"}]} - job_definition = { - "name": "test", - } - job_definition.update(acl_definition) # noqa - - _update_job(_jobs_service, "1", job_definition) - 
_client.perform_query.assert_called_with("PUT", "/permissions/jobs/1", data=acl_definition) +def test_deploy_empty_workflows_list(temp_project, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): + payload = textwrap.dedent( + """\ + environments: + default: + workflows: [] + """ + ) + Path("conf/deployment.yml").write_text(payload) + deploy_result = invoke_cli_runner("deploy") + assert deploy_result.exit_code == 0 diff --git a/tests/unit/commands/test_deploy_jinja_variables_file.py b/tests/unit/commands/test_deploy_jinja_variables_file.py index 1837ece3..34331582 100644 --- a/tests/unit/commands/test_deploy_jinja_variables_file.py +++ b/tests/unit/commands/test_deploy_jinja_variables_file.py @@ -9,19 +9,19 @@ ) -def test_incorrect_file_name(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_incorrect_file_name(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): deploy_result = invoke_cli_runner(["deploy", "--jinja-variables-file", "some-file.py"], expected_error=True) assert "Jinja variables file shall be provided" in str(deploy_result.exception) -def test_non_existent_file(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_non_existent_file(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): deploy_result = invoke_cli_runner( ["deploy", "--jinja-variables-file", "some-non-existent.yml"], expected_error=True ) assert "file is non-existent" in str(deploy_result.exception) -def test_passed_with_unsupported(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_passed_with_unsupported(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): file_name = "jinja-template-variables-file.yaml" src_vars_file = get_path_with_relation_to_current_file(f"../deployment-configs/jinja-vars/{file_name}") dst_vars_file = Path("./conf") / file_name @@ -33,7 +33,7 @@ def test_passed_with_unsupported(temp_project: Path, mlflow_file_uploader, mock_ assert "deployment file is not based on Jinja" in str(deploy_result.exception) -def test_passed_with_inplace(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_passed_with_inplace(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): invoke_cli_runner(["configure", "--enable-inplace-jinja-support"]) file_name = "jinja-template-variables-file.yaml" src_vars_file = get_path_with_relation_to_current_file(f"../deployment-configs/jinja-vars/{file_name}") @@ -43,7 +43,7 @@ def test_passed_with_inplace(temp_project: Path, mlflow_file_uploader, mock_dbx_ assert isinstance(rdr, Jinja2ConfigReader) -def test_jinja_vars_file_api(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project: Path): +def test_jinja_vars_file_api(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project: Path): jinja_vars_dir = get_path_with_relation_to_current_file("../deployment-configs/jinja-vars/") project_config_dir = temp_project / "conf" vars_file = Path("./conf/jinja-template-variables-file.yaml") @@ -59,7 +59,7 @@ def test_jinja_vars_file_api(mlflow_file_uploader, mock_dbx_file_upload, mock_ap assert _def == definitions[0] -def test_jinja_vars_file_cli(mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client, temp_project: Path): +def test_jinja_vars_file_cli(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project: Path): deployment_file_name = 
"09-jinja-with-custom-vars.yaml.j2" vars_file_name = "jinja-template-variables-file.yaml" project_config_dir = temp_project / "conf" diff --git a/tests/unit/commands/test_destroy.py b/tests/unit/commands/test_destroy.py index 7892f940..4485c374 100644 --- a/tests/unit/commands/test_destroy.py +++ b/tests/unit/commands/test_destroy.py @@ -9,7 +9,7 @@ from dbx.api.config_reader import ConfigReader from dbx.api.destroyer import Destroyer from dbx.commands.destroy import ask_for_confirmation -from dbx.models.destroyer import DestroyerConfig, DeletionMode +from dbx.models.cli.destroyer import DestroyerConfig, DeletionMode from tests.unit.conftest import invoke_cli_runner @@ -18,7 +18,8 @@ def base_config(temp_project): config_reader = ConfigReader(Path("conf/deployment.yml"), None) config = config_reader.get_config() deployment = config.get_environment("default", raise_if_not_found=True) - return partial(DestroyerConfig, dracarys=False, deployment=deployment) + wfs = deployment.payload.select_relevant_or_all_workflows() + return partial(DestroyerConfig, dracarys=False, deployment=deployment, workflows=wfs) def test_ask_for_confirmation_positive(monkeypatch, base_config): diff --git a/tests/unit/commands/test_execute.py b/tests/unit/commands/test_execute.py index d7e7ac6a..6e152aac 100644 --- a/tests/unit/commands/test_execute.py +++ b/tests/unit/commands/test_execute.py @@ -5,7 +5,7 @@ from dbx.api.cluster import ClusterController from dbx.api.context import LocalContextManager -from dbx.models.context import ContextInfo +from dbx.models.files.context import ContextInfo from tests.unit.conftest import invoke_cli_runner @@ -39,7 +39,7 @@ def test_smoke_execute( mock_api_v2_client, mock_local_context_manager, mlflow_file_uploader, - mock_dbx_file_upload, + mock_storage_io, ): # noqa with patch( "dbx.api.client_provider.ApiV1Client.get_command_status", @@ -72,7 +72,7 @@ def test_smoke_execute_workflow( mock_api_v2_client, mock_local_context_manager, mlflow_file_uploader, - mock_dbx_file_upload, + mock_storage_io, ): # noqa with patch( "dbx.api.client_provider.ApiV1Client.get_command_status", @@ -101,7 +101,44 @@ def test_smoke_execute_spark_python_task( mock_api_v2_client, mock_local_context_manager, mlflow_file_uploader, - mock_dbx_file_upload, + mock_storage_io, + mocker, +): # noqa + mocker.patch( + "dbx.api.client_provider.ApiV1Client.get_command_status", + MagicMock( + return_value={ + "status": "Finished", + "results": {"resultType": "Ok", "data": "Ok!"}, + } + ), + ) + execute_result = invoke_cli_runner( + [ + "execute", + "--deployment-file", + "conf/deployment.yml", + "--environment", + "default", + "--cluster-id", + "000-some-cluster-id", + "--job", + f"{temp_project.name}-sample-multitask", + "--task", + "etl", + ], + ) + + assert execute_result.exit_code == 0 + + +def test_smoke_execute_python_wheel_task( + temp_project, + mock_api_v1_client, + mock_api_v2_client, + mock_local_context_manager, + mlflow_file_uploader, + mock_storage_io, ): # noqa with patch( "dbx.api.client_provider.ApiV1Client.get_command_status", @@ -122,27 +159,30 @@ def test_smoke_execute_spark_python_task( "--job", f"{temp_project.name}-sample-multitask", "--task", - "etl", + "ml", ], ) assert execute_result.exit_code == 0 -def test_smoke_execute_python_wheel_task( +@pytest.mark.parametrize("param_set", ['{"parameters": ["a", 1]}', '{"named_parameters": {"a":1}}']) +def test_smoke_execute_python_wheel_task_with_params( + param_set, temp_project, mock_api_v1_client, mock_api_v2_client, mock_local_context_manager, 
mlflow_file_uploader, - mock_dbx_file_upload, + mock_storage_io, ): # noqa + mock_retval = { + "status": "Finished", + "results": {"resultType": "Ok", "data": "Ok!"}, + } with patch( "dbx.api.client_provider.ApiV1Client.get_command_status", - return_value={ - "status": "Finished", - "results": {"resultType": "Ok", "data": "Ok!"}, - }, + return_value=mock_retval, ): execute_result = invoke_cli_runner( [ @@ -157,58 +197,21 @@ def test_smoke_execute_python_wheel_task( f"{temp_project.name}-sample-multitask", "--task", "ml", + "--parameters", + param_set, ], ) assert execute_result.exit_code == 0 -def test_smoke_execute_python_wheel_task_with_params( - temp_project, - mock_api_v1_client, - mock_api_v2_client, - mock_local_context_manager, - mlflow_file_uploader, - mock_dbx_file_upload, -): # noqa - _params_options = ['{"parameters": ["a", 1]}', '{"named_parameters": ["--a=1", "--b=1"]}'] - mock_retval = { - "status": "Finished", - "results": {"resultType": "Ok", "data": "Ok!"}, - } - for _params in _params_options: - with patch( - "dbx.api.client_provider.ApiV1Client.get_command_status", - return_value=mock_retval, - ): - execute_result = invoke_cli_runner( - [ - "execute", - "--deployment-file", - "conf/deployment.yml", - "--environment", - "default", - "--cluster-id", - "000-some-cluster-id", - "--job", - f"{temp_project.name}-sample-multitask", - "--task", - "ml", - "--parameters", - _params, - ], - ) - - assert execute_result.exit_code == 0 - - def test_smoke_execute_spark_python_task_with_params( temp_project, mock_api_v1_client, mock_api_v2_client, mock_local_context_manager, mlflow_file_uploader, - mock_dbx_file_upload, + mock_storage_io, ): # noqa mock_retval = { "status": "Finished", @@ -256,28 +259,31 @@ def test_smoke_execute_spark_python_task_with_params( ) def test_preprocess_cluster_args(*_): # noqa api_client = Mock(ApiClient) - controller = ClusterController(api_client) with pytest.raises(RuntimeError): - controller.preprocess_cluster_args(None, None) + ClusterController(api_client, cluster_name=None, cluster_id=None) - id_by_name = controller.preprocess_cluster_args("some-cluster-name", None) - assert id_by_name == "aaa-111" + c1 = ClusterController(api_client, "some-cluster-name", None) + assert c1.cluster_id == "aaa-111" - id_by_id = controller.preprocess_cluster_args(None, "aaa-bbb-ccc") - assert id_by_id == "aaa-bbb-ccc" + c2 = ClusterController(api_client, None, "aaa-bbb-ccc") + assert c2.cluster_id == "aaa-bbb-ccc" - negative_funcs = [ - lambda: controller.preprocess_cluster_args("non-existent-cluster-by-name", None), - lambda: controller.preprocess_cluster_args("duplicated-name", None), - lambda: controller.preprocess_cluster_args(None, "non-existent-id"), + negative_controllers = [ + lambda: ClusterController(api_client, "non-existent-cluster-by-name", None), + lambda: ClusterController(api_client, "duplicated-name", None), + lambda: ClusterController(api_client, None, "non-existent-id"), ] - for func in negative_funcs: + for _c in negative_controllers: with pytest.raises(NameError): - func() + _c() -def test_awake_cluster(): +@patch( + "databricks_cli.clusters.api.ClusterService.get_cluster", + side_effect=lambda cid: "something" if cid in ("aaa-bbb-ccc", "aaa-111") else None, +) +def test_awake_cluster(*_): # normal behavior client_mock = MagicMock() side_effect = [ @@ -287,11 +293,11 @@ def test_awake_cluster(): {"state": "RUNNING"}, ] with patch.object(ClusterService, "get_cluster", side_effect=side_effect) as cluster_service_mock: - controller = 
ClusterController(client_mock) - controller.awake_cluster("aaa-bbb") + controller = ClusterController(client_mock, None, "aaa-bbb-ccc") + controller.awake_cluster() assert cluster_service_mock("aaa-bbb").get("state") == "RUNNING" with patch.object(ClusterService, "get_cluster", return_value={"state": "ERROR"}): - controller = ClusterController(client_mock) + controller = ClusterController(client_mock, None, "aaa-bbb-ccc") with pytest.raises(RuntimeError): - controller.awake_cluster("aaa-bbb") + controller.awake_cluster() diff --git a/tests/unit/commands/test_launch.py b/tests/unit/commands/test_launch.py index 38872639..5373bb1b 100644 --- a/tests/unit/commands/test_launch.py +++ b/tests/unit/commands/test_launch.py @@ -1,13 +1,21 @@ +import textwrap from pathlib import Path from typing import List, Optional from unittest.mock import MagicMock, PropertyMock +import pytest from databricks_cli.sdk import JobsService from pytest_mock import MockFixture from dbx.api.client_provider import DatabricksClientProvider from dbx.api.config_reader import ConfigReader -from dbx.api.launch.tracer import RunTracer +from dbx.api.launch.pipeline_models import PipelineUpdateState +from dbx.api.launch.runners.base import PipelineUpdateResponse +from dbx.api.launch.runners.pipeline import PipelineLauncher +from dbx.api.launch.tracer import RunTracer, PipelineTracer +from dbx.api.services.jobs import JobListing, ListJobsResponse +from dbx.api.services.pipelines import NamedPipelinesService +from dbx.api.storage.io import StorageIO from dbx.utils.json import JsonUtils from tests.unit.conftest import invoke_cli_runner @@ -19,7 +27,7 @@ def deploy_and_get_job_name(deploy_args: Optional[List[str]] = None) -> str: deploy_result = invoke_cli_runner(["deploy"] + deploy_args) assert deploy_result.exit_code == 0 deployment_info = ConfigReader(Path("conf/deployment.yml")).get_environment("default") - _chosen_job = deployment_info.payload.workflows[0]["name"] + _chosen_job = deployment_info.payload.workflows[0].name return _chosen_job @@ -34,7 +42,8 @@ def prepare_job_service_mock(mocker: MockFixture, job_name): } ] } - mocker.patch.object(JobsService, "list_jobs", MagicMock(return_value=jobs_payload)) + response = ListJobsResponse(**jobs_payload) + mocker.patch.object(JobListing, "by_name", MagicMock(return_value=response)) def prepare_tracing_mock(mocker: MockFixture, final_result_state: str): @@ -63,7 +72,7 @@ def prepare_tracing_mock(mocker: MockFixture, final_result_state: str): def test_smoke_launch( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name() prepare_job_service_mock(mocker, _chosen_job) @@ -77,7 +86,7 @@ def test_smoke_launch( def test_smoke_launch_workflow( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name() prepare_job_service_mock(mocker, _chosen_job) @@ -87,7 +96,7 @@ def test_smoke_launch_workflow( def test_launch_no_arguments( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name() prepare_job_service_mock(mocker, 
_chosen_job) @@ -97,7 +106,7 @@ def test_launch_no_arguments( def test_parametrized_tags( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): tags_definition = ["--tags", "cake=cheesecake", "--branch-name", "test-branch"] _chosen_job = deploy_and_get_job_name(tags_definition) @@ -108,7 +117,7 @@ def test_parametrized_tags( def test_long_tags_list( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): tags_definition = [ "--tags", @@ -128,7 +137,7 @@ def test_long_tags_list( def test_unmatched_deploy_and_launch( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name() prepare_job_service_mock(mocker, _chosen_job) @@ -138,17 +147,17 @@ def test_unmatched_deploy_and_launch( def test_launch_run_submit( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): deployment_result = Path(".dbx/deployment-result.json") _chosen_job = deploy_and_get_job_name(["--files-only", "--write-specs-to-file", deployment_result]) mocked_result = JsonUtils.read(deployment_result) - mocker.patch("dbx.api.launch.runners.load_dbx_file", MagicMock(return_value=mocked_result)) + mocker.patch.object(StorageIO, "load", MagicMock(return_value=mocked_result)) launch_result = invoke_cli_runner(["launch", "--job", _chosen_job] + ["--as-run-submit"]) assert launch_result.exit_code == 0 -def test_launch_not_found(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_launch_not_found(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): _chosen_job = deploy_and_get_job_name(["--tags", "soup=beautiful"]) launch_result = invoke_cli_runner( ["launch", "--job", _chosen_job] + ["--tags", "cake=cheesecake"], expected_error=True @@ -156,7 +165,7 @@ def test_launch_not_found(temp_project: Path, mlflow_file_uploader, mock_dbx_fil assert "No deployments provided per given set of filters" in str(launch_result.exception) -def test_launch_empty_runs(temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client): +def test_launch_empty_runs(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): _chosen_job = deploy_and_get_job_name(["--files-only", "--tags", "cake=strudel"]) launch_result = invoke_cli_runner( ["launch", "--job", _chosen_job] + ["--as-run-submit", "--tags", "cake=cheesecake"], expected_error=True @@ -165,7 +174,7 @@ def test_launch_empty_runs(temp_project: Path, mlflow_file_uploader, mock_dbx_fi def test_launch_with_output( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name() prepare_job_service_mock(mocker, _chosen_job) @@ -173,9 +182,7 @@ def test_launch_with_output( assert launch_result.exit_code == 0 -def test_launch_with_unparsable_params( - temp_project: Path, 
mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client -): +def test_launch_with_unparsable_params(temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client): _chosen_job = deploy_and_get_job_name() launch_result = invoke_cli_runner( ["launch", "--job", _chosen_job, "--parameters", "{very[bad]_json}"], expected_error=True @@ -183,9 +190,7 @@ def test_launch_with_unparsable_params( assert "Provided parameters payload cannot be" in launch_result.stdout -def test_launch_with_run_now_v21_params( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload -): +def test_launch_with_run_now_v21_params(mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io): client_mock = MagicMock() p = PropertyMock(return_value="2.1") type(client_mock).jobs_api_version = p @@ -198,9 +203,7 @@ def test_launch_with_run_now_v21_params( assert launch_result.exit_code == 0 -def test_launch_with_run_now_v20_params( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload -): +def test_launch_with_run_now_v20_params(mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io): client_mock = MagicMock() type(client_mock).jobs_api_version = PropertyMock(return_value="2.0") mocker.patch.object(DatabricksClientProvider, "get_v2_client", lambda: client_mock) @@ -211,7 +214,7 @@ def test_launch_with_run_now_v20_params( def test_launch_with_trace( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name(["--tags", "soup=beautiful"]) prepare_job_service_mock(mocker, _chosen_job) @@ -220,8 +223,43 @@ def test_launch_with_trace( assert launch_result.exit_code == 0 +@pytest.mark.parametrize("state, err", [(PipelineUpdateState.COMPLETED, None), (PipelineUpdateState.FAILED, Exception)]) +def test_launch_pipeline( + state, err, mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client +): + (temp_project / "conf" / "deployment.yml").write_text( + textwrap.dedent( + """ + environments: + default: + workflows: + - name: "some" + workflow_type: "pipeline" + target: "some" + libraries: + - notebook: + path: "/Repos/some/path" + """ + ) + ) + mocker.patch.object(NamedPipelinesService, "find_by_name_strict", MagicMock(return_value=1)) + mocker.patch.object( + PipelineLauncher, "launch", MagicMock(return_value=(PipelineUpdateResponse(update_id="a", request_id="a"), 1)) + ) + invoke_cli_runner(["deploy", "some"]) + + mocker.patch.object(PipelineTracer, "start", MagicMock(return_value=state)) + + if err: + launch_result = invoke_cli_runner(["launch", "some", "-p", "--trace"], expected_error=True) + assert "failed during execution" in str(launch_result.exception) + else: + launch_result = invoke_cli_runner(["launch", "some", "-p", "--trace"]) + assert launch_result.exit_code == 0 + + def test_launch_with_trace_failed( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mock_storage_io, mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name(["--tags", "soup=beautiful"]) prepare_job_service_mock(mocker, _chosen_job) @@ -233,7 +271,7 @@ def test_launch_with_trace_failed( def test_launch_with_trace_and_kill_on_sigterm( - mocker: MockFixture, temp_project: Path, 
mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name(["--tags", "soup=beautiful"]) prepare_job_service_mock(mocker, _chosen_job) @@ -245,7 +283,7 @@ def test_launch_with_trace_and_kill_on_sigterm( def test_launch_with_trace_and_kill_on_sigterm_with_interruption( - mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_dbx_file_upload, mock_api_v2_client + mocker: MockFixture, temp_project: Path, mlflow_file_uploader, mock_storage_io, mock_api_v2_client ): _chosen_job = deploy_and_get_job_name(["--tags", "soup=beautiful"]) prepare_job_service_mock(mocker, _chosen_job) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a595e5da..0337180d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,12 +15,13 @@ from typer.testing import CliRunner from dbx.api.client_provider import DatabricksClientProvider +from dbx.api.launch.pipeline_models import PipelineGlobalState +from dbx.api.storage.io import StorageIO from dbx.api.storage.mlflow_based import MlflowStorageConfigurationManager from dbx.cli import app -from dbx.commands.deploy import _log_dbx_file from dbx.commands.init import init -from dbx.utils.adjuster import adjust_path from dbx.utils.file_uploader import MlflowFileUploader +from tests.unit.api.launch.test_pipeline_runner import TEST_PIPELINE_ID, TEST_PIPELINE_UPDATE_PAYLOAD TEST_HOST = "https:/dbx.cloud.databricks.com" TEST_TOKEN = "dapiDBXTEST" @@ -88,7 +89,7 @@ def generate_wheel(*args, **kwargs): if "disable_auto_execute_mock" in request.keywords: logging.info("Disabling the execute_shell_command for specific test") else: - mocker.patch("dbx.api.build.execute_shell_command", generate_wheel) + mocker.patch("dbx.models.build.execute_shell_command", generate_wheel) yield project_path @@ -127,23 +128,20 @@ def mlflow_fixture(session_mocker): @pytest.fixture(scope="function") def mlflow_file_uploader(mocker, mlflow_fixture): - real_adjuster = adjust_path + mocker.patch.object(MlflowFileUploader, "_verify_fuse_support", MagicMock()) + mocker.patch.object(MlflowFileUploader, "_upload_file", MagicMock()) - def fake_adjuster(candidate: str, file_uploader: MlflowFileUploader) -> str: - if str(candidate).startswith(file_uploader._base_uri): - return candidate - else: - adjusted = real_adjuster(candidate, file_uploader) - return adjusted + def _mocked_processor(local_file_path: Path, as_fuse) -> str: + remote_path = "/".join(["dbfs:/mocks/testing", str(local_file_path.as_posix())]) + remote_path = remote_path.replace("dbfs:/", "/dbfs/") if as_fuse else remote_path + return remote_path - mocker.patch.object(MlflowFileUploader, "_verify_fuse_support", MagicMock()) - mocker.patch(extract_function_name(real_adjuster), MagicMock(side_effect=fake_adjuster)) + mocker.patch.object(MlflowFileUploader, "_postprocess_path", MagicMock(side_effect=_mocked_processor)) @pytest.fixture() -def mock_dbx_file_upload(mocker): - func = _log_dbx_file - mocker.patch(extract_function_name(func), MagicMock()) +def mock_storage_io(mocker): + mocker.patch.object(StorageIO, "save", MagicMock()) @pytest.fixture() @@ -154,3 +152,18 @@ def mock_api_v2_client(mocker): @pytest.fixture() def mock_api_v1_client(mocker): mocker.patch.object(DatabricksClientProvider, "get_v2_client", MagicMock()) + + +@pytest.fixture +def pipeline_launch_mock(mocker: MockerFixture): + client = MagicMock() + client.perform_query = MagicMock( + 
side_effect=[ + {"statuses": [{"pipeline_id": TEST_PIPELINE_ID, "name": "some"}]}, # get pipeline + {"state": PipelineGlobalState.RUNNING}, # get current state + {}, # stop pipeline + {"state": PipelineGlobalState.IDLE}, # second verification get + TEST_PIPELINE_UPDATE_PAYLOAD, # start pipeline + ] + ) + return client diff --git a/tests/unit/models/test_acls.py b/tests/unit/models/test_acls.py new file mode 100644 index 00000000..32481c4f --- /dev/null +++ b/tests/unit/models/test_acls.py @@ -0,0 +1,37 @@ +import pytest + +from dbx.models.workflow.common.access_control import AccessControlMixin + + +def test_acls_positive(): + acls = AccessControlMixin( + **{ + "access_control_list": [ + {"user_name": "test1", "permission_level": "IS_OWNER"}, + {"user_name": "test2", "permission_level": "CAN_VIEW"}, + ] + } + ) + assert acls.access_control_list is not None + + +def test_owner_not_provided(): + with pytest.raises(ValueError): + AccessControlMixin(**{"access_control_list": [{"user_name": "test1", "permission_level": "CAN_MANAGE"}]}) + + +def test_two_owners_provided(): + with pytest.raises(ValueError): + AccessControlMixin( + **{ + "access_control_list": [ + {"user_name": "test1", "permission_level": "IS_OWNER"}, + {"user_name": "test2", "permission_level": "IS_OWNER"}, + ] + } + ) + + +def test_empty_acl(): + _e = AccessControlMixin(**{}) + assert _e.access_control_list is None diff --git a/tests/unit/models/test_deployment.py b/tests/unit/models/test_deployment.py index 9d98c594..fa26798d 100644 --- a/tests/unit/models/test_deployment.py +++ b/tests/unit/models/test_deployment.py @@ -1,7 +1,11 @@ import pytest +import yaml from dbx.api.config_reader import ConfigReader -from dbx.models.deployment import DeploymentConfig, EnvironmentDeploymentInfo +from dbx.models.deployment import DeploymentConfig, EnvironmentDeploymentInfo, WorkflowListMixin, Deployment +from dbx.models.workflow.common.pipeline import Pipeline +from dbx.models.workflow.v2dot0.workflow import Workflow as V2dot0Workflow +from dbx.models.workflow.v2dot1.workflow import Workflow as V2dot1Workflow from tests.unit.conftest import get_path_with_relation_to_current_file @@ -38,14 +42,14 @@ def test_raise_if_not_found(): def test_build_payload(capsys): - _payload = DeploymentConfig.prepare_build({"build": {"commands": ["sleep 5"]}}) + _payload = DeploymentConfig._prepare_build({"build": {"commands": ["sleep 5"]}}) res = capsys.readouterr() assert "No build logic defined in the deployment file" not in res.out assert _payload.commands is not None def test_build_payload_warning(capsys): - _payload = DeploymentConfig.prepare_build({}) + _payload = DeploymentConfig._prepare_build({}) res = capsys.readouterr() assert "No build logic defined in the deployment file" in res.out @@ -54,3 +58,41 @@ def test_legacy_build_conflict(): with pytest.raises(ValueError) as exc_info: DeploymentConfig.from_legacy_json_payload({"build": {"some": "value"}}) assert "Deployment file with a legacy syntax" in str(exc_info) + + +def test_empty_spec(): + with pytest.raises(ValueError): + EnvironmentDeploymentInfo.from_spec("test", {}) + + +def test_workflows_list_duplicates(): + with pytest.raises(ValueError): + WorkflowListMixin( + **{"workflows": [{"name": "a", "workflow_type": "job-v2.1"}, {"name": "a", "workflow_type": "job-v2.1"}]} + ) + + +def test_workflows_list_bad_get(): + _wf = WorkflowListMixin(**{"workflows": [{"name": "a", "workflow_type": "job-v2.1"}]}) + with pytest.raises(ValueError): + _wf.get_workflow("b") + + +def 
test_various_workflow_definitions(): + test_payload = """ + workflows: + - name: "dlt-pipeline" + workflow_type: "pipeline" + - name: "job-v21" + tasks: + - task_key: "first" + spark_python_task: + python_file: "/some/file.py" + - name: "job-v20" + spark_python_task: + python_file: "/some/file.py" + """ + _dep = Deployment.from_spec_remote(yaml.safe_load(test_payload)) + assert isinstance(_dep.get_workflow("dlt-pipeline"), Pipeline) + assert isinstance(_dep.get_workflow("job-v21"), V2dot1Workflow) + assert isinstance(_dep.get_workflow("job-v20"), V2dot0Workflow) diff --git a/tests/unit/models/test_destroyer.py b/tests/unit/models/test_destroyer.py index fa4986e3..a3487e2b 100644 --- a/tests/unit/models/test_destroyer.py +++ b/tests/unit/models/test_destroyer.py @@ -1,24 +1,15 @@ -from functools import partial from pathlib import Path -import pytest - from dbx.api.config_reader import ConfigReader -from dbx.models.destroyer import DestroyerConfig, DeletionMode +from dbx.models.cli.destroyer import DestroyerConfig, DeletionMode def test_destroy_model(temp_project): config_reader = ConfigReader(Path("conf/deployment.yml"), None) config = config_reader.get_config() deployment = config.get_environment("default", raise_if_not_found=True) - base_config = partial(DestroyerConfig, deletion_mode=DeletionMode.all, dracarys=False, deployment=deployment) - good_config: DestroyerConfig = base_config( - workflows=[f"{temp_project.name}-sample-etl"], + selected_wfs = [deployment.payload.workflows[0]] + base_config = DestroyerConfig( + workflows=selected_wfs, deletion_mode=DeletionMode.all, dracarys=False, deployment=deployment ) - assert good_config.workflows == [f"{temp_project.name}-sample-etl"] - - with pytest.raises(ValueError): - base_config(workflows=["some-non-existent"]) - - config_autofill: DestroyerConfig = base_config(workflows=[]) - assert config_autofill.workflows is not None + assert base_config.workflows == selected_wfs diff --git a/tests/unit/models/test_git_source.py b/tests/unit/models/test_git_source.py new file mode 100644 index 00000000..9f1cd960 --- /dev/null +++ b/tests/unit/models/test_git_source.py @@ -0,0 +1,13 @@ +import pytest + +from dbx.models.workflow.v2dot1.workflow import GitSource + + +def test_git_source_positive(): + gs = GitSource(git_url="http://some", git_provider="some", git_branch="some") + assert gs.git_branch == "some" + + +def test_git_source_negative(): + with pytest.raises(ValueError): + GitSource(git_url="http://some", git_provider="some", git_branch="some", git_tag="some") diff --git a/tests/unit/models/test_job_clusters.py b/tests/unit/models/test_job_clusters.py index 740c87b8..b3591d00 100644 --- a/tests/unit/models/test_job_clusters.py +++ b/tests/unit/models/test_job_clusters.py @@ -1,17 +1,17 @@ import pytest from pydantic import ValidationError -from dbx.models.job_clusters import JobClusters +from dbx.models.workflow.v2dot1.job_cluster import JobClustersMixin def test_empty(): - jc = JobClusters(**{}) + jc = JobClustersMixin(**{}) assert jc.job_clusters == [] def test_duplicates(): with pytest.raises(ValueError): - JobClusters( + JobClustersMixin( **{ "job_clusters": [ {"job_cluster_key": "some", "new_cluster": {}}, @@ -23,7 +23,7 @@ def test_duplicates(): def test_incorrect_format(): with pytest.raises(ValidationError): - JobClusters( + JobClustersMixin( **{ "job_clusters": {"job_cluster_key": "some", "new_cluster": {}}, } @@ -31,26 +31,26 @@ def test_incorrect_format(): def test_not_found(): - jc = JobClusters( + jc = JobClustersMixin( **{ 
"job_clusters": [ - {"job_cluster_key": "some", "new_cluster": {}}, + {"job_cluster_key": "some", "new_cluster": {"spark_version": "some"}}, ] } ) with pytest.raises(ValueError): - jc.get_cluster_definition("non-existent") + jc.get_job_cluster_definition("non-existent") def test_positive(): - nc_content = {"node_type_id": "some-node-type-id"} - jc = JobClusters( + nc_content = {"node_type_id": "some-node-type-id", "spark_version": "some"} + jc = JobClustersMixin( **{ "job_clusters": [ {"job_cluster_key": "some", "new_cluster": nc_content}, ] } ) - assert jc.get_cluster_definition("some") is not None - assert jc.get_cluster_definition("some").job_cluster_key == "some" - assert jc.get_cluster_definition("some").new_cluster == nc_content + assert jc.get_job_cluster_definition("some") is not None + assert jc.get_job_cluster_definition("some").job_cluster_key == "some" + assert jc.get_job_cluster_definition("some").new_cluster.dict(exclude_none=True) == nc_content diff --git a/tests/unit/models/test_new_cluster.py b/tests/unit/models/test_new_cluster.py new file mode 100644 index 00000000..57e43be5 --- /dev/null +++ b/tests/unit/models/test_new_cluster.py @@ -0,0 +1,29 @@ +import pytest + +from dbx.models.workflow.common.new_cluster import NewCluster, AutoScale + + +def test_legacy_msg(capsys): + NewCluster( + spark_version="some", + instance_pool_name="some", + driver_instance_pool_name="some", + policy_name="some", + aws_attributes={"instance_profile_name": "some"}, + ) + out = capsys.readouterr().out + assert "cluster-policy://" in out + assert "instance-pool://" in out + assert "driver_instance_pool_id" in out + assert "instance-profile://" in out + + +def test_autoscale_negative(): + with pytest.raises(ValueError): + AutoScale(min_workers=10, max_workers=5) + + +def test_autoscale_positive(): + _as = AutoScale(min_workers=1, max_workers=5) + assert _as.min_workers == 1 + assert _as.max_workers == 5 diff --git a/tests/unit/models/test_parameters.py b/tests/unit/models/test_parameters.py index 0b001164..e14cbc61 100644 --- a/tests/unit/models/test_parameters.py +++ b/tests/unit/models/test_parameters.py @@ -1,83 +1,68 @@ import pytest -from dbx.models.parameters.execute import ExecuteWorkloadParamInfo -from dbx.models.parameters.run_now import RunNowV2d0ParamInfo, RunNowV2d1ParamInfo -from dbx.models.parameters.run_submit import RunSubmitV2d0ParamInfo, RunSubmitV2d1ParamInfo +from dbx.models.workflow.v2dot0.parameters import StandardRunPayload as V2dot0StandardRunPayload +from dbx.models.workflow.v2dot0.parameters import AssetBasedRunPayload as V2dot0AssetBasedRunPayload +from dbx.models.workflow.v2dot1.parameters import StandardRunPayload as V2dot1StandardRunPayload +from dbx.models.workflow.v2dot1.parameters import AssetBasedRunPayload as V2dot1AssetBasedRunPayload +from dbx.models.cli.execute import ExecuteParametersPayload def test_empty_execute(): with pytest.raises(ValueError): - ExecuteWorkloadParamInfo(**{}) + ExecuteParametersPayload(**{}) def test_multiple_execute(): with pytest.raises(ValueError): - ExecuteWorkloadParamInfo(**{"parameters": ["a", "b"], "named_parameters": ["--c=1"]}) + ExecuteParametersPayload(**{"parameters": ["a", "b"], "named_parameters": ["--c=1"]}) def test_params_execute(): - _p = ExecuteWorkloadParamInfo(**{"parameters": ["a", "b"]}) + _p = ExecuteParametersPayload(**{"parameters": ["a", "b"]}) assert _p.parameters is not None def test_named_params_execute(): - _p = ExecuteWorkloadParamInfo(**{"named_parameters": ["--a=1", "--b=2"]}) + _p = 
ExecuteParametersPayload(**{"named_parameters": {"a": 1}}) assert _p.named_parameters is not None -def test_runnow_v20(): - _rn = RunNowV2d0ParamInfo(**{"jar_params": ["a"]}) +def test_standard_v2dot0(): + _rn = V2dot0StandardRunPayload(**{"jar_params": ["a"]}) assert _rn.jar_params is not None -def test_runnow_v20_two(): - _rn = RunNowV2d0ParamInfo(**{"jar_params": ["a"], "notebook_params": {"a": 1}}) +def test_standard_v2dot0_multiple(): + _rn = V2dot0StandardRunPayload(**{"jar_params": ["a"], "notebook_params": {"a": 1}}) assert _rn.jar_params is not None assert _rn.notebook_params is not None -def test_runnow_v20_negative(): - with pytest.raises(ValueError): - RunNowV2d0ParamInfo(**{}) - - -def test_runnow_v21(): - _rn = RunNowV2d1ParamInfo(**{"python_named_params": {"a": 1}}) +def test_standard_v2dot1(): + _rn = V2dot1StandardRunPayload(**{"python_named_params": {"a": 1}}) assert _rn.python_named_params is not None -def test_runsubmit_v20(): - _rs = RunSubmitV2d0ParamInfo(**{"spark_python_task": {"parameters": ["a"]}}) - assert _rs.spark_python_task.parameters is not None - - -def test_runsubmit_v20_non_unique(): - with pytest.raises(ValueError): - RunSubmitV2d0ParamInfo(**{"spark_python_task": {"parameters": ["a"]}, "spark_jar_task": {"parameters": ["a"]}}) - - -def test_runsubmit_v20_empty(): - with pytest.raises(ValueError): - RunSubmitV2d0ParamInfo(**{}) - - -def test_runsubmit_v21_empty_tasks(): - with pytest.raises(ValueError): - RunSubmitV2d0ParamInfo(**{"tasks": []}) +def test_assert_based_v2dot0(): + _rs = V2dot0AssetBasedRunPayload(**{"parameters": ["a"]}) + assert _rs.parameters is not None -def test_runsubmit_v21_empty(): +def test_assert_based_v2dot0_not_unique(): with pytest.raises(ValueError): - RunSubmitV2d1ParamInfo(**{}) - - -def test_runsubmit_v21_good(): - _sp_task = {"task_key": "first", "spark_python_task": {"parameters": ["a"]}} - _sj_task = {"task_key": "second", "spark_jar_task": {"parameters": ["a"]}} - _rs = RunSubmitV2d1ParamInfo(**{"tasks": [_sp_task, _sj_task]}) - - assert _rs.tasks[0].task_key == "first" - assert _rs.tasks[0].spark_python_task.parameters is not None - - assert _rs.tasks[1].task_key == "second" - assert _rs.tasks[1].spark_jar_task.parameters is not None + V2dot0AssetBasedRunPayload(**{"parameters": ["a"], "base_parameters": {"a": "b"}}) + + +@pytest.mark.parametrize( + "param_raw_payload", + [ + '[{"task_key": "some", "base_parameters": {"a": 1, "b": 2}}]', + '[{"task_key": "some", "parameters": ["a", "b"]}]', + '[{"task_key": "some", "named_parameters": {"a": 1}}]', + '[{"task_key": "some", "full_refresh": true}]', + '[{"task_key": "some", "parameters": {"key1": "value2"}}]', + ], +) +def test_assert_based_v2dot1_good(param_raw_payload): + parsed = V2dot1AssetBasedRunPayload.from_string(param_raw_payload) + assert parsed.elements is not None diff --git a/tests/unit/models/test_pipeline.py b/tests/unit/models/test_pipeline.py new file mode 100644 index 00000000..e24c7c8f --- /dev/null +++ b/tests/unit/models/test_pipeline.py @@ -0,0 +1,8 @@ +from dbx.models.workflow.common.pipeline import PipelinesNewCluster + + +def test_omits(capsys): + nc = PipelinesNewCluster(spark_version="some") + _out = capsys.readouterr().out + assert "The `spark_version` property cannot be applied" in _out + assert nc.spark_version is None diff --git a/tests/unit/models/test_task.py b/tests/unit/models/test_task.py index b058552f..9b213f71 100644 --- a/tests/unit/models/test_task.py +++ b/tests/unit/models/test_task.py @@ -1,10 +1,12 @@ -from copy import deepcopy 
from pathlib import Path import pytest from pydantic import ValidationError -from dbx.models.task import Task, TaskType, SparkPythonTask +from dbx.models.cli.execute import ExecuteParametersPayload +from dbx.models.workflow.common.task import SparkPythonTask, SparkJarTask, SparkSubmitTask, BaseTaskMixin +from dbx.models.workflow.common.task_type import TaskType +from dbx.models.workflow.v2dot1.task import SqlTask def get_spark_python_task_payload(py_file: str): @@ -30,60 +32,83 @@ def test_spark_python_task_positive(temp_project: Path): py_file = f"file://{temp_project.name}/tasks/sample_etl_task.py" _payload = get_spark_python_task_payload(py_file).get("spark_python_task") _t = SparkPythonTask(**_payload) - assert isinstance(_t.python_file, Path) + assert isinstance(_t.execute_file, Path) -def test_task_recognition(temp_project: Path): +def test_spark_python_task_non_py_file(temp_project: Path): + py_file = f"file://{temp_project.name}/tasks/sample_etl_task.ipynb" + _payload = get_spark_python_task_payload(py_file).get("spark_python_task") + with pytest.raises(ValidationError): + SparkPythonTask(**_payload) + + +def test_sql_task_non_unique(): + payload = {"query": {"query_id": "some"}, "dashboard": {"dashboard_id": "some"}, "warehouse_id": "some"} + with pytest.raises(ValueError): + SqlTask(**payload) + + +def test_sql_task_good(): + payload = {"query": {"query_id": "some"}, "warehouse_id": "some"} + _task = SqlTask(**payload) + assert _task.query.query_id is not None + + +def test_spark_jar_deprecated(capsys): + _jt = SparkJarTask(main_class_name="some.Class", jar_uri="file://some/uri") + assert _jt.jar_uri is not None + assert "Field jar_uri is DEPRECATED since" in capsys.readouterr().out + + +def test_spark_python_task_not_fuse(temp_project: Path): py_file = f"file://{temp_project.name}/tasks/sample_etl_task.py" - _payload = get_spark_python_task_payload(py_file) - _result = Task(**_payload) - assert _result.spark_python_task is not None - assert _result.python_wheel_task is None - assert _result.task_type == TaskType.spark_python_task + _payload = get_spark_python_task_payload(py_file).get("spark_python_task") + _payload["python_file"] = "file:fuse://some/file" + with pytest.raises(ValueError): + SparkPythonTask(**_payload) + +def test_spark_python_task_execute_incorrect(temp_project: Path): + py_file = f"file://{temp_project.name}/tasks/sample_etl_task.py" + _payload = get_spark_python_task_payload(py_file).get("spark_python_task") + _payload["python_file"] = "dbfs:/some/path.py" + with pytest.raises(ValueError): + _st = SparkPythonTask(**_payload) + _st.execute_file # noqa -def test_python_wheel_task(): - _result = Task(**python_wheel_task_payload) - assert _result.spark_python_task is None - assert _result.python_wheel_task is not None - assert _result.task_type == TaskType.python_wheel_task +def test_spark_python_task_execute_non_existent(temp_project: Path): + py_file = f"file://{temp_project.name}/tasks/sample_etl_task.py" + _payload = get_spark_python_task_payload(py_file).get("spark_python_task") + _t = SparkPythonTask(**_payload) + Path(py_file.replace("file://", "")).unlink() + with pytest.raises(ValueError): + _st = SparkPythonTask(**_payload) + _st.execute_file # noqa -def test_python_wheel_task_named(): - _c = deepcopy(python_wheel_task_payload) - _c["python_wheel_task"].pop("parameters") - _c["python_wheel_task"]["named_parameters"] = ["--a=1", "--b=2"] - _result = Task(**_c) - assert _result.task_type == TaskType.python_wheel_task - assert 
_result.python_wheel_task.named_parameters is not None +def test_spark_submit_task(): + st = SparkSubmitTask(**{"parameters": ["some", "other"]}) + assert st.parameters is not None -def test_python_wheel_task_named_invalid_prefix(): - _c = deepcopy(python_wheel_task_payload) - _c["python_wheel_task"].pop("parameters") - _c["python_wheel_task"]["named_parameters"] = ["a", "--b=2"] - with pytest.raises(ValidationError): - Task(**_c) +def test_mixin_undefined_type(): + bt = BaseTaskMixin(**{"unknown_task": {"prop1": "arg1"}}) + assert bt.task_type == TaskType.undefined_task -def test_python_wheel_task_named_invalid_equal(): - _c = deepcopy(python_wheel_task_payload) - _c["python_wheel_task"].pop("parameters") - _c["python_wheel_task"]["named_parameters"] = ["--a"] - with pytest.raises(ValidationError): - Task(**_c) +def test_mixin_execute_unsupported(): + bt = BaseTaskMixin(**{"unknown_task": {"prop1": "arg1"}}) + with pytest.raises(RuntimeError): + bt.check_if_supported_in_execute() -def test_negative(): - _payload = {"spark_jar_task": {"main_class_name": "org.some.Class"}} +def test_mixin_incorrect_override(): + bt = BaseTaskMixin(**{"spark_python_task": {"python_file": "/some/file"}}) with pytest.raises(ValueError): - Task(**_payload) + bt.override_execute_parameters(ExecuteParametersPayload(named_parameters={"p1": 1})) -def test_multiple(temp_project): - py_file = f"file://{temp_project.name}/tasks/sample_etl_task.py" - _sp_payload = get_spark_python_task_payload(py_file) - _payload = {**_sp_payload, **python_wheel_task_payload} +def test_mixin_multiple_provided(): with pytest.raises(ValueError): - Task(**_payload) + BaseTaskMixin(**{"spark_python_task": {"python_file": "/some/file"}, "unknown_task": {"some": "props"}}) diff --git a/tests/unit/models/test_v2dot0_workflow.py b/tests/unit/models/test_v2dot0_workflow.py new file mode 100644 index 00000000..c57dcd21 --- /dev/null +++ b/tests/unit/models/test_v2dot0_workflow.py @@ -0,0 +1,39 @@ +import pytest + +from dbx.models.workflow.v2dot0.parameters import AssetBasedRunPayload +from dbx.models.workflow.v2dot0.workflow import Workflow + + +def test_wf(capsys): + wf = Workflow(existing_cluster_name="something", pipeline_task={"pipeline_id": "something"}, name="test") + assert "cluster://" in capsys.readouterr().out + with pytest.raises(RuntimeError): + wf.get_task("whatever") + + +@pytest.mark.parametrize( + "wf_def", + [ + { + "name": "test1", + "new_cluster": {"spark_version": "lts"}, + "existing_cluster_name": "some-cluster", + "some_task": "here", + } + ], +) +def test_validation(wf_def): + with pytest.raises(ValueError): + Workflow(**wf_def) + + +def test_v2dot0_overrides_parameters(): + wf = Workflow(**{"name": "test", "spark_python_task": {"python_file": "some/file.py", "parameters": ["a"]}}) + wf.override_asset_based_launch_parameters(AssetBasedRunPayload(parameters=["b"])) + assert wf.spark_python_task.parameters == ["b"] + + +def test_v2dot0_overrides_notebook(): + wf = Workflow(**{"name": "test", "notebook_task": {"notebook_path": "/some/path", "base_parameters": {"k1": "v1"}}}) + wf.override_asset_based_launch_parameters(AssetBasedRunPayload(base_parameters={"k1": "v2"})) + assert wf.notebook_task.base_parameters == {"k1": "v2"} diff --git a/tests/unit/models/test_v2dot1_workflow.py b/tests/unit/models/test_v2dot1_workflow.py new file mode 100644 index 00000000..30a60338 --- /dev/null +++ b/tests/unit/models/test_v2dot1_workflow.py @@ -0,0 +1,35 @@ +import pytest +from pydantic import ValidationError + +from 
dbx.models.workflow.v2dot1.parameters import AssetBasedRunPayload +from dbx.models.workflow.v2dot1.workflow import Workflow + + +def test_empty_tasks(capsys): + Workflow(tasks=[], name="empty") + assert "might cause errors" in capsys.readouterr().out + + +def test_task_not_provided(): + with pytest.raises(ValidationError): + Workflow(name="test", tasks=[{"some": "a"}, {"other": "b"}]) + + +def test_duplicated_tasks(capsys): + with pytest.raises(ValidationError): + Workflow(tasks=[{"task_key": "d", "some_task": "prop"}, {"task_key": "d", "some_task": "prop"}], name="empty") + + +def test_override_positive(): + wf = Workflow( + name="some", + tasks=[{"task_key": "one", "spark_python_task": {"python_file": "/some/file.py", "parameters": ["a"]}}], + ) + override_payload = AssetBasedRunPayload.from_string( + """[ + {"task_key": "one", "parameters": ["a", "b"]} + ]""" + ) + wf.override_asset_based_launch_parameters(override_payload) + assert wf.get_task("one").spark_python_task.parameters == ["a", "b"] + assert wf.task_names == ["one"] diff --git a/tests/unit/sync/clients/conftest.py b/tests/unit/sync/clients/conftest.py index 28f00ab4..1baedd76 100644 --- a/tests/unit/sync/clients/conftest.py +++ b/tests/unit/sync/clients/conftest.py @@ -8,7 +8,7 @@ @pytest.fixture def mock_config(): - return mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=None) + return mocked_props(token="fake-token", host="http://fakehost.asdf/?o=1234", insecure=None) @pytest.fixture diff --git a/tests/unit/sync/clients/test_dbfs_client.py b/tests/unit/sync/clients/test_dbfs_client.py index c155ff74..d2f21902 100644 --- a/tests/unit/sync/clients/test_dbfs_client.py +++ b/tests/unit/sync/clients/test_dbfs_client.py @@ -18,7 +18,8 @@ def client(mock_config) -> DBFSClient: def test_init(client): assert client.api_token == "fake-token" - assert client.host == "http://fakehost.asdf/base" + assert client.host == "http://fakehost.asdf" + assert client.api_base_path == "http://fakehost.asdf/api/2.0/dbfs" def test_delete(client: DBFSClient): @@ -29,7 +30,7 @@ def test_delete(client: DBFSClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} assert "ssl" not in session.post.call_args[1] assert session.post.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" @@ -37,7 +38,7 @@ def test_delete(client: DBFSClient): def test_delete_secure(client: DBFSClient): - mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=False) + mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False) client = DBFSClient(base_path="/tmp/foo", config=mock_config) session = MagicMock() resp = AsyncMock() @@ -46,7 +47,7 @@ def test_delete_secure(client: DBFSClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} assert session.post.call_args[1]["ssl"] is True @@ -74,7 +75,7 @@ def test_delete_recursive(client: DBFSClient): 
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar", "recursive": True} @@ -93,7 +94,7 @@ def test_delete_rate_limited(client: DBFSClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} @@ -112,7 +113,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} @@ -138,7 +139,7 @@ def test_mkdirs(client: DBFSClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/mkdirs" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/mkdirs" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} assert session.post.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"]) @@ -174,7 +175,7 @@ def test_mkdirs_rate_limited(client: DBFSClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/mkdirs" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/mkdirs" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} @@ -193,7 +194,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/mkdirs" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/mkdirs" assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} @@ -220,7 +221,7 @@ def test_put(client: DBFSClient, dummy_file_path: str): asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path, session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/put" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/put" assert session.post.call_args[1]["json"] == { "path": "dbfs:/tmp/foo/foo/bar", "contents": base64.b64encode(b"yo").decode("ascii"), @@ -317,7 +318,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str): asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path, session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/put" + assert 
session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/put" assert session.post.call_args[1]["json"] == { "path": "dbfs:/tmp/foo/foo/bar", "contents": base64.b64encode(b"yo").decode("ascii"), @@ -340,7 +341,7 @@ def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str): asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path, session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/put" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/put" assert session.post.call_args[1]["json"] == { "path": "dbfs:/tmp/foo/foo/bar", "contents": base64.b64encode(b"yo").decode("ascii"), diff --git a/tests/unit/sync/clients/test_get_user.py b/tests/unit/sync/clients/test_get_user.py index 3711accc..8f2bb448 100644 --- a/tests/unit/sync/clients/test_get_user.py +++ b/tests/unit/sync/clients/test_get_user.py @@ -1,10 +1,24 @@ from unittest.mock import MagicMock, PropertyMock, patch +import pytest + from dbx.sync.clients import get_user +from tests.unit.sync.utils import mocked_props +@pytest.mark.parametrize( + "test_case", + [ + ("http://fakehost.asdf", "http://fakehost.asdf/api/2.0/preview/scim/v2/Me"), + ("http://fakehost.asdf/", "http://fakehost.asdf/api/2.0/preview/scim/v2/Me"), + ("http://fakehost.asdf/?o=1234", "http://fakehost.asdf/api/2.0/preview/scim/v2/Me"), + ("http://fakehost.asdf:8080/?o=1234", "http://fakehost.asdf:8080/api/2.0/preview/scim/v2/Me"), + ], +) @patch("dbx.sync.clients.requests") -def test_get_user(mock_requests, mock_config): +def test_get_user(mock_requests, test_case): + config_host, expected_url = test_case + mock_config = mocked_props(token="fake-token", host=config_host, insecure=None) resp = MagicMock() setattr(type(resp), "status_code", PropertyMock(return_value=200)) user_info = {"userName": "foo"} @@ -12,6 +26,7 @@ def test_get_user(mock_requests, mock_config): mock_requests.get.return_value = resp assert get_user(mock_config) == user_info assert resp.json.call_count == 1 + assert mock_requests.get.call_args[0][0] == expected_url @patch("dbx.sync.clients.requests") diff --git a/tests/unit/sync/clients/test_repos_client.py b/tests/unit/sync/clients/test_repos_client.py index 456985c9..ed2c033e 100644 --- a/tests/unit/sync/clients/test_repos_client.py +++ b/tests/unit/sync/clients/test_repos_client.py @@ -16,7 +16,9 @@ def client(mock_config): def test_init(mock_config): client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) assert client.api_token == "fake-token" - assert client.host == "http://fakehost.asdf/base" + assert client.host == "http://fakehost.asdf" + assert client.workspace_api_base_path == "http://fakehost.asdf/api/2.0/workspace" + assert client.workspace_files_api_base_path == "http://fakehost.asdf/api/2.0/workspace-files/import-file" assert client.base_path == "/Repos/foo@somewhere.com/my-repo" with pytest.raises(ValueError): @@ -37,7 +39,7 @@ def test_delete(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} assert "ssl" not in session.post.call_args[1] assert 
session.post.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" @@ -45,7 +47,7 @@ def test_delete(client: ReposClient): def test_delete_secure(client: ReposClient): - mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=False) + mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() resp = AsyncMock() @@ -54,13 +56,13 @@ def test_delete_secure(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} assert session.post.call_args[1]["ssl"] is True def test_delete_insecure(client: ReposClient): - mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/base/", insecure=True) + mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() resp = AsyncMock() @@ -69,7 +71,7 @@ def test_delete_insecure(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} assert session.post.call_args[1]["ssl"] is False @@ -97,7 +99,7 @@ def test_delete_recursive(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar", "recursive": True} @@ -116,7 +118,7 @@ def test_delete_rate_limited(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} @@ -135,7 +137,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient): asyncio.run(client.delete(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/delete" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/delete" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} @@ -161,7 +163,7 @@ def test_mkdirs(client: ReposClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/mkdirs" + 
assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/mkdirs" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} assert session.post.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" assert is_repos_user_agent(session.post.call_args[1]["headers"]["user-agent"]) @@ -197,7 +199,7 @@ def test_mkdirs_rate_limited(client: ReposClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/mkdirs" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/mkdirs" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} @@ -216,7 +218,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient): asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) assert session.post.call_count == 2 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/base/api/2.0/workspace/mkdirs" + assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/workspace/mkdirs" assert session.post.call_args[1]["json"] == {"path": "/Repos/foo@somewhere.com/my-repo/foo/bar"} @@ -245,7 +247,7 @@ def test_put(client: ReposClient, dummy_file_path: str): assert session.post.call_count == 1 assert ( session.post.call_args[1]["url"] - == "http://fakehost.asdf/base/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" + == "http://fakehost.asdf/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" ) assert session.post.call_args[1]["data"] == b"yo" assert session.post.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" @@ -286,7 +288,7 @@ def test_put_rate_limited(client: ReposClient, dummy_file_path: str): assert session.post.call_count == 2 assert ( session.post.call_args[1]["url"] - == "http://fakehost.asdf/base/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" + == "http://fakehost.asdf/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" ) assert session.post.call_args[1]["data"] == b"yo" @@ -308,7 +310,7 @@ def test_put_rate_limited_retry_after(client: ReposClient, dummy_file_path: str) assert session.post.call_count == 2 assert ( session.post.call_args[1]["url"] - == "http://fakehost.asdf/base/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" + == "http://fakehost.asdf/api/2.0/workspace-files/import-file/Repos/foo@somewhere.com/my-repo/foo/bar" ) assert session.post.call_args[1]["data"] == b"yo" @@ -325,3 +327,61 @@ def test_put_unauthorized(client: ReposClient, dummy_file_path: str): with pytest.raises(ClientError): asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path, session=session)) + + +def test_exists(client: ReposClient): + session = MagicMock() + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=200)) + resp.json.return_value = {"repos": [{"path": "/Repos/foo@somewhere.com/my-repo"}]} + session.get.return_value = create_async_with_result(resp) + assert asyncio.run(client.exists(session=session)) + + assert session.get.call_count == 1 + assert session.get.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/repos" + assert session.get.call_args[1]["params"] == {"path_prefix": "/Repos/foo@somewhere.com/my-repo"} + assert "ssl" not in session.get.call_args[1] + assert 
session.get.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" + assert is_repos_user_agent(session.get.call_args[1]["headers"]["user-agent"]) + + +def test_exists_not_found(client: ReposClient): + session = MagicMock() + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=200)) + resp.json.return_value = {"repos": [{"path": "/Repos/foo@somewhere.com/other-repo"}]} + session.get.return_value = create_async_with_result(resp) + assert not asyncio.run(client.exists(session=session)) + + assert session.get.call_count == 1 + assert session.get.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/repos" + assert session.get.call_args[1]["params"] == {"path_prefix": "/Repos/foo@somewhere.com/my-repo"} + assert "ssl" not in session.get.call_args[1] + assert session.get.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" + assert is_repos_user_agent(session.get.call_args[1]["headers"]["user-agent"]) + + +def test_exists_empty_response(client: ReposClient): + session = MagicMock() + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=200)) + resp.json.return_value = {} + session.get.return_value = create_async_with_result(resp) + assert not asyncio.run(client.exists(session=session)) + + assert session.get.call_count == 1 + assert session.get.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/repos" + assert session.get.call_args[1]["params"] == {"path_prefix": "/Repos/foo@somewhere.com/my-repo"} + assert "ssl" not in session.get.call_args[1] + assert session.get.call_args[1]["headers"]["Authorization"] == "Bearer fake-token" + assert is_repos_user_agent(session.get.call_args[1]["headers"]["user-agent"]) + + +def test_exists_failure(client: ReposClient): + session = MagicMock() + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=500)) + resp.json.return_value = {"repos": [{"path": "/Repos/foo@somewhere.com/my-repo"}]} + session.get.return_value = create_async_with_result(resp) + with pytest.raises(ClientError): + asyncio.run(client.exists(session=session)) diff --git a/tests/unit/sync/test_commands.py b/tests/unit/sync/test_commands.py index 5c53c81a..2e02ca3d 100644 --- a/tests/unit/sync/test_commands.py +++ b/tests/unit/sync/test_commands.py @@ -1,18 +1,33 @@ +import asyncio import os -from unittest.mock import patch, call, MagicMock +from unittest.mock import patch, call, MagicMock, AsyncMock import click import pytest from databricks_cli.configure.provider import ProfileConfigProvider +from dbx.commands.sync.sync import repo_exists from dbx.commands.sync.functions import get_user_name, get_source_base_name from dbx.constants import DBX_SYNC_DEFAULT_IGNORES from dbx.sync import DeleteUnmatchedOption from dbx.sync.clients import DBFSClient, ReposClient +from tests.unit.sync.utils import mocked_props from .conftest import invoke_cli_runner from .utils import temporary_directory, pushd +def get_config(): + return mocked_props(token="fake-token", host="http://fakehost.asdf/?o=1234", insecure=None) + + +@pytest.fixture +def mock_get_config(): + with patch("dbx.commands.sync.sync.get_databricks_config") as mock_get_databricks_config: + config = get_config() + mock_get_databricks_config.return_value = config + yield mock_get_databricks_config + + @patch("dbx.commands.sync.functions.get_user") def test_get_user_name(mock_get_user): mock_get_user.return_value = {"userName": "foo"} @@ -30,21 +45,23 @@ def test_get_source_base_name(): get_source_base_name("/") 
-@patch("dbx.commands.sync.sync.get_databricks_config") @patch("dbx.commands.sync.sync.main_loop") -def test_repo_no_opts(mock_get_config, mock_main_loop): +def test_repo_no_opts(mock_main_loop): # some options are required res = invoke_cli_runner(["repo"], expected_error=True) assert "Missing option" in res.output +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -def test_repo_environment(mock_main_loop, mock_get_user_name, temp_project): +def test_repo_environment(mock_main_loop, mock_get_user_name, mock_repo_exists, temp_project): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True mock_get_user_name.return_value = "me" with patch.object(ProfileConfigProvider, "get_config") as config_mock: + config_mock.return_value = get_config() invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "--environment", "default"]) assert mock_main_loop.call_count == 1 @@ -59,19 +76,19 @@ def test_dbfs_environment(mock_main_loop, mock_get_user_name, temp_project): mock_get_user_name.return_value = "me" with patch.object(ProfileConfigProvider, "get_config") as config_mock: + config_mock.return_value = get_config() invoke_cli_runner(["dbfs", "-s", tempdir, "-d", "the-repo", "--environment", "default"]) assert mock_main_loop.call_count == 1 assert config_mock.call_count == 1 +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_basic_opts(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_basic_opts(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True mock_get_user_name.return_value = "me" invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo"]) @@ -98,13 +115,12 @@ def test_repo_basic_opts(mock_get_config, mock_main_loop, mock_get_user_name): assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_unknown_user(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_unknown_user(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True mock_get_user_name.return_value = None res = invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo"], expected_error=True) @@ -116,13 +132,30 @@ def test_repo_unknown_user(mock_get_config, mock_main_loop, mock_get_user_name): assert "Destination repo path can't be automatically determined because the user is" in res.output +@patch("dbx.commands.sync.sync.repo_exists") +@patch("dbx.commands.sync.sync.get_user_name") +@patch("dbx.commands.sync.sync.main_loop") +def test_repo_unknown_repo(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): + with temporary_directory() as tempdir: + mock_repo_exists.return_value = False + mock_get_user_name.return_value = "me" + + res = invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo"], expected_error=True) + + assert mock_main_loop.call_count == 0 + assert mock_get_config.call_count == 1 + assert 
mock_get_user_name.call_count == 1 + + assert "lease create the repo" in res.output + + +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_dry_run(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_dry_run(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True + invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "--dry-run"]) assert mock_main_loop.call_count == 1 @@ -147,13 +180,12 @@ def test_repo_dry_run(mock_get_config, mock_main_loop, mock_get_user_name): assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_polling(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_polling(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True mock_get_user_name.return_value = "me" invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "--polling-interval", "2"]) @@ -181,15 +213,15 @@ def test_repo_polling(mock_get_config, mock_main_loop, mock_get_user_name): assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_include_dir(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_include_dir(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-i", "foo"]) assert mock_main_loop.call_count == 1 @@ -214,15 +246,15 @@ def test_repo_include_dir(mock_get_config, mock_main_loop, mock_get_user_name): assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_force_include_dir(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_force_include_dir(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-fi", "foo"]) assert mock_main_loop.call_count == 1 @@ -247,15 +279,15 @@ def test_repo_force_include_dir(mock_get_config, mock_main_loop, mock_get_user_n assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_include_pattern(mock_get_config, 
mock_main_loop, mock_get_user_name): +def test_repo_include_pattern(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-ip", "foo/*.py"]) assert mock_main_loop.call_count == 1 @@ -280,15 +312,15 @@ def test_repo_include_pattern(mock_get_config, mock_main_loop, mock_get_user_nam assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_force_include_pattern(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_force_include_pattern(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-fip", "foo/*.py"]) assert mock_main_loop.call_count == 1 @@ -313,15 +345,15 @@ def test_repo_force_include_pattern(mock_get_config, mock_main_loop, mock_get_us assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_exclude_dir(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_exclude_dir(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-e", "foo"]) assert mock_main_loop.call_count == 1 @@ -346,15 +378,15 @@ def test_repo_exclude_dir(mock_get_config, mock_main_loop, mock_get_user_name): assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_exclude_pattern(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_exclude_pattern(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-ep", "foo/**/*.py"]) assert mock_main_loop.call_count == 1 @@ -381,15 +413,15 @@ def test_repo_exclude_pattern(mock_get_config, mock_main_loop, mock_get_user_nam assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_include_dir_not_exists(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_include_dir_not_exists(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() 
as tempdir: + mock_repo_exists.return_value = True + # we don't create the "foo" subdir, so it should produce an error - config = MagicMock() - mock_get_config.return_value = config res = invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-i", "foo"], expected_error=True) assert mock_main_loop.call_count == 0 @@ -399,15 +431,15 @@ def test_repo_include_dir_not_exists(mock_get_config, mock_main_loop, mock_get_u assert "does not exist" in res.output +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_inferred_source(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_inferred_source(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, ".git")) - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-d", "the-repo", "-u", "me"]) assert mock_main_loop.call_count == 1 @@ -432,15 +464,15 @@ def test_repo_inferred_source(mock_get_config, mock_main_loop, mock_get_user_nam assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_inferred_source_no_git(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_inferred_source_no_git(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): + mock_repo_exists.return_value = True + # source can only be inferred when the cwd contains a .git subdir - config = MagicMock() - mock_get_config.return_value = config res = invoke_cli_runner(["repo", "-d", "the-repo", "-u", "me"], expected_error=True) assert mock_main_loop.call_count == 0 @@ -450,13 +482,12 @@ def test_repo_inferred_source_no_git(mock_get_config, mock_main_loop, mock_get_u assert "Must specify source" in res.output +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_allow_delete_unmatched(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_allow_delete_unmatched(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True invoke_cli_runner( ["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "--unmatched-behaviour=allow-delete-unmatched"] @@ -482,13 +513,12 @@ def test_repo_allow_delete_unmatched(mock_get_config, mock_main_loop, mock_get_u assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_disallow_delete_unmatched(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_disallow_delete_unmatched(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config + mock_repo_exists.return_value = True invoke_cli_runner( 
["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "--unmatched-behaviour=disallow-delete-unmatched"] @@ -516,14 +546,11 @@ def test_repo_disallow_delete_unmatched(mock_get_config, mock_main_loop, mock_ge @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_no_opts(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_no_opts(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): # infer source based on cwd having a .git directory os.mkdir(os.path.join(tempdir, ".git")) - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = "me" # we can run with no options as long as the source and user can be automatically inferred @@ -553,14 +580,11 @@ def test_dbfs_no_opts(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_polling(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_polling(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): # infer source based on cwd having a .git directory os.mkdir(os.path.join(tempdir, ".git")) - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = "me" # we can run with no options as long as the source and user can be automatically inferred @@ -587,14 +611,11 @@ def test_dbfs_polling(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_dry_run(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_dry_run(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): # infer source based on cwd having a .git directory os.mkdir(os.path.join(tempdir, ".git")) - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = "me" # we can run with no options as long as the source and user can be automatically inferred @@ -620,11 +641,8 @@ def test_dbfs_dry_run(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_source_dest(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_source_dest(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = "me" # we can run with no options as long as the source and user can be automatically inferred @@ -650,15 +668,11 @@ def test_dbfs_source_dest(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_specify_user(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_specify_user(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir, pushd(tempdir): # infer source based on cwd having a .git directory os.mkdir(os.path.join(tempdir, ".git")) - config = MagicMock() - 
mock_get_config.return_value = config - # we can run with no options as long as the source and user can be automatically inferred invoke_cli_runner(["dbfs", "-u", "someone"]) @@ -683,11 +697,8 @@ def test_dbfs_specify_user(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_unknown_user(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_unknown_user(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = None # we can run with no options as long as the source and user can be automatically inferred @@ -701,11 +712,8 @@ def test_dbfs_unknown_user(mock_get_config, mock_main_loop, mock_get_user_name): @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_dbfs_no_root(mock_get_config, mock_main_loop, mock_get_user_name): +def test_dbfs_no_root(mock_main_loop, mock_get_user_name, mock_get_config): with temporary_directory() as tempdir: - config = MagicMock() - mock_get_config.return_value = config mock_get_user_name.return_value = "me" # we can run with no options as long as the source and user can be automatically inferred @@ -717,11 +725,13 @@ def test_dbfs_no_root(mock_get_config, mock_main_loop, mock_get_user_name): assert "Destination cannot be the root path" in res.output +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_use_gitignore(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) # .gitignore will be used by default for ignore patterns @@ -729,8 +739,6 @@ def test_repo_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_name) gif.write("/bar\n") gif.write("/baz\n") - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-i", "foo"]) assert mock_main_loop.call_count == 1 @@ -757,11 +765,13 @@ def test_repo_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_name) assert client.base_path == "/Repos/me/the-repo" +@patch("dbx.commands.sync.sync.repo_exists") @patch("dbx.commands.sync.sync.get_user_name") @patch("dbx.commands.sync.sync.main_loop") -@patch("dbx.commands.sync.sync.get_databricks_config") -def test_repo_no_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_name): +def test_repo_no_use_gitignore(mock_main_loop, mock_get_user_name, mock_repo_exists, mock_get_config): with temporary_directory() as tempdir: + mock_repo_exists.return_value = True + os.mkdir(os.path.join(tempdir, "foo")) # .gitignore will be used by default for ignore patterns @@ -769,8 +779,6 @@ def test_repo_no_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_na gif.write("/bar\n") gif.write("/baz\n") - config = MagicMock() - mock_get_config.return_value = config invoke_cli_runner(["repo", "-s", tempdir, "-d", "the-repo", "-u", "me", "-i", "foo", "--no-use-gitignore"]) assert mock_main_loop.call_count == 1 @@ 
-793,3 +801,9 @@ def test_repo_no_use_gitignore(mock_get_config, mock_main_loop, mock_get_user_na assert isinstance(client, ReposClient) assert client.base_path == "/Repos/me/the-repo" + + +def test_repo_exists(): + client = AsyncMock() + asyncio.run(repo_exists(client)) + assert client.exists.call_count == 1 diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index d102e9c1..af62664e 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -3,22 +3,17 @@ from pathlib import Path from subprocess import CalledProcessError from unittest import mock -from unittest.mock import MagicMock import pytest -from databricks_cli.sdk import JobsService from pytest_mock import MockFixture from dbx.api.config_reader import ConfigReader -from dbx.models.deployment import BuildConfiguration -from dbx.utils.adjuster import adjust_path, path_adjustment +from dbx.models.build import BuildConfiguration from dbx.utils.common import ( generate_filter_string, get_current_branch_name, get_environment_data, ) -from dbx.api.build import prepare_build -from dbx.utils.job_listing import find_job_by_name from tests.unit.conftest import get_path_with_relation_to_current_file json_file_01 = get_path_with_relation_to_current_file("../deployment-configs/01-json-test.json") @@ -53,7 +48,7 @@ def test_all_file_formats_contents_match(temp_project): assert yaml_default_env == json_default_env == jinja_json_default_env == jinja_yaml_default_env -@mock.patch.dict(os.environ, {"TIMEOUT": "100"}, clear=True) +@mock.patch.dict(os.environ, {"TIMEOUT": "100", "ALERT_EMAIL": "test@test.com"}, clear=True) def test_jinja_files_with_env_variables_scalar_type(temp_project): """ JINJA2: Simple Scalar (key-value) type for timeout_seconds parameter @@ -62,8 +57,8 @@ def test_jinja_files_with_env_variables_scalar_type(temp_project): json_default_envs = ConfigReader(json_j2_file_04).get_environment("default") yaml_default_envs = ConfigReader(yaml_j2_file_04).get_environment("default") - json_timeout_seconds = json_default_envs.payload.workflows[0].get("timeout_seconds") - yaml_timeout_seconds = yaml_default_envs.payload.workflows[0].get("timeout_seconds") + json_timeout_seconds = json_default_envs.payload.workflows[0].timeout_seconds + yaml_timeout_seconds = yaml_default_envs.payload.workflows[0].timeout_seconds assert int(json_timeout_seconds) == 100 assert int(yaml_timeout_seconds) == 100 @@ -77,13 +72,14 @@ def test_jinja_files_with_env_variables_array_type(temp_project): json_default_envs = ConfigReader(json_j2_file_04).get_environment("default") yaml_default_envs = ConfigReader(yaml_j2_file_04).get_environment("default") - json_emails = json_default_envs.payload.workflows[0].get("email_notifications").get("on_failure") - yaml_emails = yaml_default_envs.payload.workflows[0].get("email_notifications").get("on_failure") + json_emails = json_default_envs.payload.workflows[0].email_notifications.on_failure + yaml_emails = yaml_default_envs.payload.workflows[0].email_notifications.on_failure assert json_emails == yaml_emails assert json_emails[0] == "test@test.com" +@mock.patch.dict(os.environ, {"ALERT_EMAIL": "test@test.com"}, clear=True) def test_jinja_file_with_env_variables_default_values(temp_project): """ JINJA: @@ -96,10 +92,10 @@ def test_jinja_file_with_env_variables_default_values(temp_project): json_default_envs = ConfigReader(json_j2_file_04).get_environment("default") yaml_default_envs = ConfigReader(yaml_j2_file_04).get_environment("default") - json_max_retries = 
json_default_envs.payload.workflows[0].get("max_retries") - yaml_max_retries = yaml_default_envs.payload.workflows[0].get("max_retries") - json_avail = json_default_envs.payload.workflows[0].get("new_cluster").get("aws_attributes").get("availability") - yaml_avail = yaml_default_envs.payload.workflows[0].get("new_cluster").get("aws_attributes").get("availability") + json_max_retries = json_default_envs.payload.workflows[0].max_retries + yaml_max_retries = yaml_default_envs.payload.workflows[0].max_retries + json_avail = json_default_envs.payload.workflows[0].new_cluster.aws_attributes.availability + yaml_avail = yaml_default_envs.payload.workflows[0].new_cluster.aws_attributes.availability assert int(json_max_retries) == int(yaml_max_retries) assert int(json_max_retries) == 3 @@ -123,10 +119,10 @@ def test_jinja_files_with_env_variables_logic_1(temp_project): json_default_envs = ConfigReader(json_j2_file_06).get_environment("default") yaml_default_envs = ConfigReader(yaml_j2_file_06).get_environment("default") - json_max_retries = json_default_envs.payload.workflows[0].get("max_retries") - yaml_max_retries = yaml_default_envs.payload.workflows[0].get("max_retries") - json_emails = json_default_envs.payload.workflows[0].get("email_notifications").get("on_failure") - yaml_emails = yaml_default_envs.payload.workflows[0].get("email_notifications").get("on_failure") + json_max_retries = json_default_envs.payload.workflows[0].max_retries + yaml_max_retries = yaml_default_envs.payload.workflows[0].max_retries + json_emails = json_default_envs.payload.workflows[0].email_notifications.on_failure + yaml_emails = yaml_default_envs.payload.workflows[0].email_notifications.on_failure assert int(json_max_retries) == -1 assert int(yaml_max_retries) == -1 @@ -148,10 +144,10 @@ def test_jinja_files_with_env_variables_logic_2(temp_project): json_default_envs = ConfigReader(json_j2_file_06).get_environment("default") yaml_default_envs = ConfigReader(yaml_j2_file_06).get_environment("default") - json_max_retries = json_default_envs.payload.workflows[0].get("max_retries") - yaml_max_retries = yaml_default_envs.payload.workflows[0].get("max_retries") - json_emails = json_default_envs.payload.workflows[0].get("email_notifications") - yaml_emails = yaml_default_envs.payload.workflows[0].get("email_notifications") + json_max_retries = json_default_envs.payload.workflows[0].max_retries + yaml_max_retries = yaml_default_envs.payload.workflows[0].max_retries + json_emails = json_default_envs.payload.workflows[0].email_notifications + yaml_emails = yaml_default_envs.payload.workflows[0].email_notifications assert int(json_max_retries) == 3 assert int(yaml_max_retries) == 3 @@ -166,7 +162,7 @@ def test_jinja_with_include(temp_project): cluster. 
""" json_default_envs = ConfigReader(json_j2_file_09).get_environment("default") - json_node_type = json_default_envs.payload.workflows[0].get("new_cluster").get("node_type_id") + json_node_type = json_default_envs.payload.workflows[0].new_cluster.node_type_id assert json_node_type == "some-node-type" @@ -196,44 +192,9 @@ def test_handle_package_no_setup(temp_project): Path("setup.py").unlink() build = BuildConfiguration() with pytest.raises(CalledProcessError): - prepare_build(build) - - -def test_non_existent_path_adjustment(): - with pytest.raises(FileNotFoundError): - path_adjustment("file://some/non-existent/file", MagicMock()) - - -def test_path_adjustment(): - dbfs_path = "dbfs:/some/path" - _dbfs_result = adjust_path(dbfs_path, MagicMock()) - assert dbfs_path == _dbfs_result + build.trigger_build_process() def test_filter_string(): output = generate_filter_string(env="test", branch_name=None) assert "dbx_branch_name" not in output - - -def test_job_listing_duplicates(): - duplicated_name = "some-name" - jobs_payload = { - "jobs": [ - { - "settings": { - "name": duplicated_name, - }, - "job_id": 1, - }, - { - "settings": { - "name": duplicated_name, - }, - "job_id": 2, - }, - ] - } - js = JobsService(MagicMock()) - js.list_jobs = MagicMock(return_value=jobs_payload) - with pytest.raises(Exception): - find_job_by_name(js, duplicated_name) diff --git a/tests/unit/utils/test_dependency_manager.py b/tests/unit/utils/test_dependency_manager.py index 65ca7993..ed9c32f2 100644 --- a/tests/unit/utils/test_dependency_manager.py +++ b/tests/unit/utils/test_dependency_manager.py @@ -1,8 +1,10 @@ import textwrap from pathlib import Path -from dbx.models.deployment import BuildConfiguration -from dbx.utils.dependency_manager import DependencyManager +import pytest + +from dbx.api.dependency.requirements import RequirementsFileProcessor +from dbx.models.workflow.common.libraries import Library, PythonPyPiLibrary def write_requirements(parent: Path, content: str) -> Path: @@ -11,98 +13,30 @@ def write_requirements(parent: Path, content: str) -> Path: return _file -def test_simple_requirements_file(tmp_path: Path): - requirements_txt = write_requirements( - tmp_path, +@pytest.mark.parametrize( + "req_payload", + [ """\ tqdm rstcheck prospector>=1.3.1,<1.7.0""", - ) - - dm = DependencyManager( - BuildConfiguration(no_package=True), - global_no_package=True, - requirements_file=requirements_txt.resolve(), - ) - assert dm._requirements_references == [ - {"pypi": {"package": "tqdm"}}, - {"pypi": {"package": "rstcheck"}}, - {"pypi": {"package": "prospector<1.7.0,>=1.3.1"}}, - ] - - -def test_requirements_with_comments(tmp_path: Path): - requirements_txt = write_requirements( - tmp_path, """\ # simple comment tqdm rstcheck # use this library prospector>=1.3.1,<1.7.0""", - ) - - dm = DependencyManager( - BuildConfiguration(no_package=True), - global_no_package=True, - requirements_file=requirements_txt.resolve(), - ) - assert dm._requirements_references == [ - {"pypi": {"package": "tqdm"}}, - {"pypi": {"package": "rstcheck"}}, - {"pypi": {"package": "prospector<1.7.0,>=1.3.1"}}, - ] - - -def test_requirements_with_empty_line(tmp_path): - requirements_txt = write_requirements( - tmp_path, """\ tqdm rstcheck prospector>=1.3.1,<1.7.0""", - ) - - dm = DependencyManager( - BuildConfiguration(no_package=True), - global_no_package=True, - requirements_file=requirements_txt.resolve(), - ) - assert dm._requirements_references == [ - {"pypi": {"package": "tqdm"}}, - {"pypi": {"package": "rstcheck"}}, - {"pypi": 
{"package": "prospector<1.7.0,>=1.3.1"}}, + ], +) +def test_simple_requirements_file(req_payload, tmp_path: Path): + requirements_txt = write_requirements(tmp_path, req_payload) + + parsed = RequirementsFileProcessor(requirements_txt).parse_requirements() + assert parsed == [ + Library(pypi=PythonPyPiLibrary(package="tqdm")), + Library(pypi=PythonPyPiLibrary(package="rstcheck")), + Library(pypi=PythonPyPiLibrary(package="prospector<1.7.0,>=1.3.1")), ] - - -def test_requirements_with_filtered_pyspark(tmp_path): - requirements_txt = write_requirements( - tmp_path, - """\ - tqdm - pyspark==1.2.3 - rstcheck - prospector>=1.3.1,<1.7.0""", - ) - - dm = DependencyManager( - BuildConfiguration(no_package=True), - global_no_package=True, - requirements_file=requirements_txt.resolve(), - ) - assert dm._requirements_references == [ - {"pypi": {"package": "tqdm"}}, - {"pypi": {"package": "rstcheck"}}, - {"pypi": {"package": "prospector<1.7.0,>=1.3.1"}}, - ] - - -def test_not_matching_conditions(tmp_path, capsys): - - dm = DependencyManager(BuildConfiguration(no_package=True), global_no_package=True, requirements_file=None) - - reference = {"deployment_config": {"no_package": False}} - - dm.process_dependencies(reference) - captured = capsys.readouterr() - assert "--no-package option is set to true" in captured.out diff --git a/tests/unit/utils/test_file_uploader.py b/tests/unit/utils/test_file_uploader.py index f742a4e9..b88ae844 100644 --- a/tests/unit/utils/test_file_uploader.py +++ b/tests/unit/utils/test_file_uploader.py @@ -1,48 +1,24 @@ -from pathlib import PurePosixPath, PureWindowsPath, Path -from unittest.mock import patch, MagicMock - import pytest -from dbx.utils.file_uploader import MlflowFileUploader, ContextBasedUploader +from dbx.utils.file_uploader import MlflowFileUploader TEST_ARTIFACT_PATHS = ["s3://some/prefix", "dbfs:/some/prefix", "adls://some/prefix", "gs://some/prefix"] -@patch("mlflow.log_artifact", return_value=None) -def test_mlflow_uploader(_): - local_paths = [PurePosixPath("/some/local/file"), PureWindowsPath("C:\\some\\file")] - - for artifact_uri in TEST_ARTIFACT_PATHS: - for local_path in local_paths: - uploader = MlflowFileUploader(base_uri=artifact_uri) - resulting_path = uploader.upload_and_provide_path(local_path) - expected_path = "/".join([artifact_uri, str(local_path.as_posix())]) - assert expected_path == resulting_path - - -def test_context_uploader(): - local_paths = [PurePosixPath("/some/local/file"), PureWindowsPath("C:\\some\\file")] - client = MagicMock() - base_uri = "/tmp/some/path" - client.get_temp_dir = MagicMock(return_value=base_uri) - - for local_path in local_paths: - uploader = ContextBasedUploader(client) - resulting_path = uploader.upload_and_provide_path(local_path) - expected_path = "/".join([base_uri, str(local_path.as_posix())]) - assert expected_path == resulting_path - - -@patch("mlflow.log_artifact", return_value=None) -def test_fuse_support(_): - local_path = Path("/some/local/file") +@pytest.mark.parametrize("artifact_uri", TEST_ARTIFACT_PATHS) +def test_fuse_support(artifact_uri, mocker): + mocker.patch("mlflow.log_artifact", return_value=None) + mocker.patch("dbx.utils.file_uploader.MlflowFileUploader._verify_reference", return_value=None) + test_reference = "file:fuse://some-path" for artifact_uri in TEST_ARTIFACT_PATHS: uploader = MlflowFileUploader(base_uri=artifact_uri) if not artifact_uri.startswith("dbfs:/"): with pytest.raises(Exception): - uploader.upload_and_provide_path(local_path, as_fuse=True) + 
uploader.upload_and_provide_path(test_reference) else: - resulting_path = uploader.upload_and_provide_path(local_path, as_fuse=True) - expected_path = "/".join([artifact_uri.replace("dbfs:/", "/dbfs/"), str(local_path.as_posix())]) + resulting_path = uploader.upload_and_provide_path(test_reference) + expected_path = "/".join( + [artifact_uri.replace("dbfs:/", "/dbfs/"), test_reference.replace("file:fuse://", "")] + ) assert expected_path == resulting_path diff --git a/tests/unit/utils/test_named_properties.py b/tests/unit/utils/test_named_properties.py deleted file mode 100644 index 258efd65..00000000 --- a/tests/unit/utils/test_named_properties.py +++ /dev/null @@ -1,167 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from databricks_cli.clusters.api import ClusterService -from databricks_cli.instance_pools.api import InstancePoolService - -from dbx.api.config_reader import ConfigReader -from dbx.models.deployment import EnvironmentDeploymentInfo, BuildConfiguration -from dbx.utils.adjuster import adjust_job_definitions -from dbx.utils.dependency_manager import DependencyManager -from dbx.utils.named_properties import NewClusterPropertiesProcessor, WorkloadPropertiesProcessor -from tests.unit.conftest import get_path_with_relation_to_current_file - -samples_root_path = get_path_with_relation_to_current_file("../deployment-configs/") - -json_conf = samples_root_path / "05-json-with-named-properties.json" -yaml_conf = samples_root_path / "05-yaml-with-named-properties.yaml" - -mtj_json_conf = samples_root_path / "08-json-with-named-properties-mtj.json" -mtj_yaml_conf = samples_root_path / "08-yaml-with-named-properties-mtj.yaml" - -json_deployment_conf = ConfigReader(json_conf).get_environment("default") -yaml_deployment_conf = ConfigReader(yaml_conf).get_environment("default") - -mtj_json_dep_conf = ConfigReader(mtj_json_conf).get_environment("default") -mtj_yaml_dep_conf = ConfigReader(mtj_yaml_conf).get_environment("default") - - -def get_job_by_name(src: EnvironmentDeploymentInfo, name: str): - matched = [j for j in src.payload.workflows if j["name"] == name] - return matched[0] - - -def test_instance_profile_name_positive(): - job_in_json = get_job_by_name(json_deployment_conf, "named-props-instance-profile-name") - job_in_yaml = get_job_by_name(yaml_deployment_conf, "named-props-instance-profile-name") - - api_client = MagicMock() - test_profile_arn = "arn:aws:iam::123456789:instance-profile/some-instance-profile-name" - api_client.perform_query = MagicMock( - return_value={"instance_profiles": [{"instance_profile_arn": test_profile_arn}]} - ) - - processor = NewClusterPropertiesProcessor(api_client) - processor.process(job_in_json["new_cluster"]) - processor.process(job_in_yaml["new_cluster"]) - - assert job_in_json["new_cluster"]["aws_attributes"]["instance_profile_arn"] == test_profile_arn - assert job_in_yaml["new_cluster"]["aws_attributes"]["instance_profile_arn"] == test_profile_arn - - -def test_instance_profile_name_negative(): - job_in_json = get_job_by_name(json_deployment_conf, "named-props-instance-profile-name") - job_in_yaml = get_job_by_name(yaml_deployment_conf, "named-props-instance-profile-name") - - api_client = MagicMock() - api_client.perform_query = MagicMock( - return_value={ - "instance_profiles": [ - {"instance_profile_arn": "arn:aws:iam::123456789:instance-profile/another-instance-profile-name"} - ] - } - ) - - processor = NewClusterPropertiesProcessor(api_client) - - funcs = [ - lambda: processor.process(job_in_json["new_cluster"]), - 
lambda: processor.process(job_in_yaml["new_cluster"]), - ] - - for func in funcs: - with pytest.raises(Exception): - func() - - -def test_instance_pool_name_positive(): - job_in_json = get_job_by_name(json_deployment_conf, "named-props-instance-pool-name") - job_in_yaml = get_job_by_name(yaml_deployment_conf, "named-props-instance-pool-name") - - api_client = MagicMock() - test_pool_id = "aaa-bbb-000-ccc" - - with patch.object( - InstancePoolService, - "list_instance_pools", - return_value={ - "instance_pools": [{"instance_pool_name": "some-instance-pool-name", "instance_pool_id": test_pool_id}] - }, - ): - processor = NewClusterPropertiesProcessor(api_client) - processor.process(job_in_json["new_cluster"]) - processor.process(job_in_yaml["new_cluster"]) - - assert job_in_json["new_cluster"]["instance_pool_id"] == test_pool_id - assert job_in_yaml["new_cluster"]["instance_pool_id"] == test_pool_id - - -def test_instance_pool_name_negative(): - job_in_json = get_job_by_name(json_deployment_conf, "named-props-instance-pool-name") - job_in_yaml = get_job_by_name(yaml_deployment_conf, "named-props-instance-pool-name") - - api_client = MagicMock() - - processor = NewClusterPropertiesProcessor(api_client) - - funcs = [ - lambda: processor.process(job_in_json["new_cluster"]), - lambda: processor.process(job_in_yaml["new_cluster"]), - ] - - for func in funcs: - with pytest.raises(Exception): - func() - - -def test_existing_cluster_name_positive(): - job_in_json = get_job_by_name(json_deployment_conf, "named-props-existing-cluster-name") - job_in_yaml = get_job_by_name(yaml_deployment_conf, "named-props-existing-cluster-name") - - api_client = MagicMock() - test_existing_cluster_id = "aaa-bbb-000-ccc" - with patch.object( - ClusterService, - "list_clusters", - return_value={"clusters": [{"cluster_name": "some-cluster", "cluster_id": test_existing_cluster_id}]}, - ): - processor = WorkloadPropertiesProcessor(api_client) - processor.process(job_in_json) - processor.process(job_in_yaml) - - assert job_in_yaml["existing_cluster_id"] == test_existing_cluster_id - assert job_in_json["existing_cluster_id"] == test_existing_cluster_id - - -def test_existing_cluster_name_negative(): - job1 = get_job_by_name(json_deployment_conf, "named-props-existing-cluster-name") - api_client = MagicMock() - - processor = WorkloadPropertiesProcessor(api_client) - - with pytest.raises(Exception): - processor.process(job1) - - -def test_mtj_named_positive(): - file_uploader = MagicMock() - api_client = MagicMock() - test_profile_arn = "arn:aws:iam::123456789:instance-profile/some-instance-profile-name" - - dm = DependencyManager(BuildConfiguration(no_package=True), global_no_package=False, requirements_file=None) - - api_client.perform_query = MagicMock( - return_value={"instance_profiles": [{"instance_profile_arn": test_profile_arn}]} - ) - - sample_reference = {"whl": "path/to/some/file"} - dm._core_package_reference = sample_reference - - for deployment_conf in [mtj_json_dep_conf, mtj_yaml_dep_conf]: - jobs = deployment_conf.payload.workflows - - adjust_job_definitions(jobs=jobs, dependency_manager=dm, file_uploader=file_uploader, api_client=api_client) - - assert jobs[0]["job_clusters"][0]["new_cluster"]["aws_attributes"]["instance_profile_arn"] is not None - assert jobs[0]["tasks"][0]["libraries"] == [] - assert jobs[0]["tasks"][1]["libraries"] == [sample_reference] diff --git a/tests/unit/utils/test_policy_parser.py b/tests/unit/utils/test_policy_parser.py deleted file mode 100644 index 652a936c..00000000 --- 
a/tests/unit/utils/test_policy_parser.py +++ /dev/null @@ -1,25 +0,0 @@ -from dbx.utils.policy_parser import PolicyParser - - -def test_base_aws_policy(): - _policy = { - "aws_attributes.instance_profile_arn": { - "type": "fixed", - "value": "arn:aws:iam::123456789:instance-profile/sample-aws-iam", - }, - "spark_conf.spark.my.conf": {"type": "fixed", "value": "my_value"}, - "spark_conf.spark.my.other.conf": {"type": "fixed", "value": "my_other_value"}, - "init_scripts.0.dbfs.destination": {"type": "fixed", "value": "dbfs:/some/init-scripts/sc1.sh"}, - "init_scripts.1.dbfs.destination": {"type": "fixed", "value": "dbfs:/some/init-scripts/sc2.sh"}, - } - _formatted = { - "aws_attributes": {"instance_profile_arn": "arn:aws:iam::123456789:instance-profile/sample-aws-iam"}, - "spark_conf": {"spark.my.conf": "my_value", "spark.my.other.conf": "my_other_value"}, - "init_scripts": [ - {"dbfs": {"destination": "dbfs:/some/init-scripts/sc1.sh"}}, - {"dbfs": {"destination": "dbfs:/some/init-scripts/sc2.sh"}}, - ], - } - parser = PolicyParser(_policy) - _parsed = parser.parse() - assert _formatted == _parsed From e5ebbe013c7a4e6a1a18b75bd151485f0dad6597 Mon Sep 17 00:00:00 2001 From: Matt Hayes Date: Mon, 20 Mar 2023 09:32:25 -0700 Subject: [PATCH 3/4] fix test --- tests/unit/sync/clients/test_dbfs_client.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/sync/clients/test_dbfs_client.py b/tests/unit/sync/clients/test_dbfs_client.py index d2f21902..e983390f 100644 --- a/tests/unit/sync/clients/test_dbfs_client.py +++ b/tests/unit/sync/clients/test_dbfs_client.py @@ -240,11 +240,11 @@ async def mock_json(*args, **kwargs): def mock_post(url, *args, **kwargs): resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) - if "/base/api/2.0/dbfs/put" in url: + if "/api/2.0/dbfs/put" in url: contents = kwargs.get("json").get("contents") if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed setattr(type(resp), "status", PropertyMock(return_value=400)) - elif "/base/api/2.0/dbfs/create" in url: + elif "/api/2.0/dbfs/create" in url: # return a mock response json resp.json = MagicMock(side_effect=mock_json) @@ -262,11 +262,11 @@ def mock_post(url, *args, **kwargs): chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024) assert session.post.call_count == len(chunks) + 2 - assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/create" - assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" - assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" - assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block" - assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/close" + assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/create" + assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/close" assert session.post.call_args_list[0][1]["json"] == { "path": "dbfs:/tmp/foo/foo/bar", From 
843356c965ed3b7785c78c0687be86231a468c69 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 21 Mar 2023 08:10:57 +0100 Subject: [PATCH 4/4] add changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52e7fadf..e360b92f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - 📌 switch from using `retry` to `tenacity` +### Added +- ✨ support for files bigger than 1MB in sync + ## [0.8.8] - 2022-02-22 # Fixed
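
Note: the streaming upload that these patches add and that the tests above exercise follows the standard DBFS API 2.0 flow: a single `put` call for payloads at or under 1MB (larger payloads get an HTTP 400, as the mocked `post` in the test replicates), and a `create` / `add-block` / `close` handle sequence for anything bigger. The sketch below is a minimal, synchronous illustration of that flow and is not taken from the patches themselves: `HOST`, `TOKEN`, `HEADERS`, `MAX_PUT_SIZE`, and the `upload` helper are placeholder names, and the use of `requests` is an assumption for brevity; the real client issues these calls through its async session with rate-limit and retry handling.

```python
# Minimal sketch of the small-file vs. streaming upload paths against the
# documented DBFS API 2.0 endpoints. Placeholder values, not the dbx client.
import base64
import requests

HOST = "https://<databricks-host>"   # placeholder
TOKEN = "<api-token>"                # placeholder
HEADERS = {"Authorization": f"Bearer {TOKEN}"}
MAX_PUT_SIZE = 1024 * 1024           # 1MB limit accepted by /dbfs/put


def upload(local_path: str, dbfs_path: str) -> None:
    # The content is base64-encoded before the size check, mirroring the client.
    with open(local_path, "rb") as f:
        contents = base64.b64encode(f.read()).decode("ascii")

    if len(contents) <= MAX_PUT_SIZE:
        # Small file: a single /put call is enough.
        requests.post(
            f"{HOST}/api/2.0/dbfs/put",
            headers=HEADERS,
            json={"path": dbfs_path, "contents": contents, "overwrite": True},
        ).raise_for_status()
        return

    # Large file: open a streaming handle, send 1MB blocks, then close it.
    resp = requests.post(
        f"{HOST}/api/2.0/dbfs/create",
        headers=HEADERS,
        json={"path": dbfs_path, "overwrite": True},
    )
    resp.raise_for_status()
    handle = resp.json()["handle"]

    for i in range(0, len(contents), MAX_PUT_SIZE):
        requests.post(
            f"{HOST}/api/2.0/dbfs/add-block",
            headers=HEADERS,
            json={"handle": handle, "data": contents[i : i + MAX_PUT_SIZE]},
        ).raise_for_status()

    requests.post(
        f"{HOST}/api/2.0/dbfs/close",
        headers=HEADERS,
        json={"handle": handle},
    ).raise_for_status()
```

The chunks here are slices of the base64-encoded string rather than of the raw bytes, which matches what `test_dbfs_client.py` asserts: the expected call count is derived from `textwrap.wrap(...)` over the encoded contents, plus one `create` and one `close` call.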