diff --git a/src/gretel_client/projects/artifact_handlers.py b/src/gretel_client/projects/artifact_handlers.py index 37921f62..b7d578af 100644 --- a/src/gretel_client/projects/artifact_handlers.py +++ b/src/gretel_client/projects/artifact_handlers.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List, Optional, Protocol, Tuple, Union +from typing import Any, BinaryIO, Dict, List, Optional, Protocol, Tuple, Union from urllib.parse import urlparse import requests @@ -28,14 +28,37 @@ try: from azure.identity import DefaultAzureCredential - from azure.storage.blob import BlobServiceClient except ImportError: # pragma: no cover DefaultAzureCredential = None +try: + from azure.storage.blob import BlobServiceClient +except ImportError: # pragma: no cover BlobServiceClient = None HYBRID_ARTIFACT_ENDPOINT_PREFIXES = ["azure://", "gs://", "s3://"] +def _get_azure_blob_srv_client() -> Optional[BlobServiceClient]: + if (storage_account := os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")) is not None: + oauth_url = "https://{}.blob.core.windows.net".format(storage_account) + return BlobServiceClient( + account_url=oauth_url, credential=DefaultAzureCredential() + ) + + if (connect_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING")) is not None: + return BlobServiceClient.from_connection_string(connect_str) + + +def _get_transport_params(endpoint: str) -> dict: + """Returns a set of transport params that are suitable for passing + into calls to ``smart_open.open``. + """ + client = None + if endpoint and endpoint.startswith("azure"): + client = _get_azure_blob_srv_client() + return {"client": client} if client else {} + + class ArtifactsException(Exception): pass @@ -48,7 +71,8 @@ class ManifestNotFoundException(Exception): class ManifestPendingException(Exception): def __init__(self, msg: Optional[str] = None, manifest: dict = {}): - # Piggyback the pending manifest onto the exception. If we give up, we still want to pass it back as a normal return value. + # Piggyback the pending manifest onto the exception. + # If we give up, we still want to pass it back as a normal return value. self.manifest = manifest super().__init__(msg) @@ -106,6 +130,9 @@ def list_project_artifacts(self) -> List[dict]: def get_project_artifact_link(self, key: str) -> str: ... + def get_project_artifact_handle(self, key: str) -> BinaryIO: + ... + def get_project_artifact_manifest( self, key: str, @@ -187,6 +214,13 @@ def get_project_artifact_link(self, key: str) -> str: ) return resp[f.DATA][f.DATA][f.URL] + @contextmanager + def get_project_artifact_handle(self, key: str) -> BinaryIO: + link = self.get_project_artifact_link(key) + transport_params = _get_transport_params(link) + with smart_open.open(link, "rb", transport_params=transport_params) as handle: + yield handle + # The server side API will return manifests with PENDING status if artifact processing has not completed # or it will return a 404 (not found) if you immediately request the artifact before processing has even started. # This is correct but not convenient. To keep every end user from writing their own retry logic, we add some here. @@ -201,7 +235,8 @@ def get_project_artifact_link(self, key: str) -> str: wait=wait_fixed(3), stop=stop_after_attempt(5), retry=retry_if_exception_type(ManifestNotFoundException), - # Instead of throwing an exception, return None. Given that we waited for a short grace period to let the artifact become PENDING, + # Instead of throwing an exception, return None. + # Given that we waited for a short grace period to let the artifact become PENDING, # if we are still getting 404's the key probably does not actually exist. retry_error_callback=lambda retry_state: None, ) @@ -273,25 +308,6 @@ def validate_data_source( return common.validate_data_source(artifact_path) - def _get_azure_blob_srv_client(self) -> Optional[BlobServiceClient]: - if (storage_account := os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")) is not None: - oauth_url = "https://{}.blob.core.windows.net".format(storage_account) - return BlobServiceClient( - account_url=oauth_url, credential=DefaultAzureCredential() - ) - - if (connect_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING")) is not None: - return BlobServiceClient.from_connection_string(connect_str) - - def _get_transport_params(self) -> dict: - """Returns a set of transport params that are suitable for passing - into calls to ``smart_open.open``. - """ - client = None - if self.endpoint.startswith("azure"): - client = self._get_azure_blob_srv_client() - return {"client": client} if client else {} - def upload_project_artifact( self, artifact_path: Union[Path, str, _DataFrameT], @@ -311,9 +327,9 @@ def upload_project_artifact( artifact_path, "rb", ignore_ext=True, - transport_params=self._get_transport_params(), + transport_params=_get_transport_params(self.endpoint), ) as in_stream, smart_open.open( - target_out, "wb", transport_params=self._get_transport_params() + target_out, "wb", transport_params=_get_transport_params(self.endpoint) ) as out_stream: shutil.copyfileobj(in_stream, out_stream) @@ -343,6 +359,13 @@ def get_project_artifact_link(self, key: str) -> str: return key return f"{self.data_sources_dir}/{key}" + @contextmanager + def get_project_artifact_handle(self, key: str) -> BinaryIO: + link = self.get_project_artifact_link(key) + transport_params = _get_transport_params(link) + with smart_open.open(link, "rb", transport_params=transport_params) as handle: + yield handle + def get_project_artifact_manifest( self, key: str, @@ -378,7 +401,11 @@ def download( log: logging.Logger, ) -> None: _download( - download_link, output_path, artifact_type, log, self._get_transport_params() + download_link, + output_path, + artifact_type, + log, + _get_transport_params(self.endpoint), ) diff --git a/src/gretel_client/projects/jobs.py b/src/gretel_client/projects/jobs.py index 178b82d2..6688bc38 100644 --- a/src/gretel_client/projects/jobs.py +++ b/src/gretel_client/projects/jobs.py @@ -5,27 +5,32 @@ import time from abc import ABC, abstractmethod, abstractproperty +from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Callable, Iterator, List, Optional, Tuple, Type, TYPE_CHECKING, Union -from urllib.parse import urlparse +from typing import ( + BinaryIO, + Callable, + Iterator, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) -import requests import smart_open import gretel_client.rest.exceptions from gretel_client.cli.utils.parser_utils import RefData -from gretel_client.config import ( - DEFAULT_RUNNER, - get_logger, - get_session_config, - RunnerMode, -) +from gretel_client.config import get_logger, get_session_config, RunnerMode from gretel_client.dataframe import _DataFrameT from gretel_client.models.config import get_model_type_config from gretel_client.projects.artifact_handlers import ( + _get_transport_params, ArtifactsHandler, CloudArtifactsHandler, HybridArtifactsHandler, @@ -347,7 +352,7 @@ def get_artifact_link(self, artifact_key: str) -> str: artifact type. Args: - artifact_type: Artifact type to download. + artifact_key: Artifact type to download. """ if artifact_key not in self.artifact_types: raise Exception( @@ -355,6 +360,25 @@ def get_artifact_link(self, artifact_key: str) -> str: ) return self._do_get_artifact(artifact_key) + @contextmanager + def get_artifact_handle(self, artifact_key: str) -> BinaryIO: + """Returns a reference to a remote artifact that can be used to + read binary data within a context manager + + >>> with job.get_artifact_handle("report_json") as file: + ... print(file.read()) + + Args: + artifact_key: Artifact type to download. + + Returns: + a file like object + """ + link = self.get_artifact_link(artifact_key) + transport_params = _get_transport_params(link) + with smart_open.open(link, "rb", transport_params=transport_params) as handle: + yield handle + def download_artifacts(self, target_dir: Union[str, Path]): """Given a target directory, either as a string or a Path object, attempt to enumerate and download all artifacts associated with this Job @@ -390,7 +414,10 @@ def _get_report_contents( report_contents = None if report_path: try: - with smart_open.open(report_path, "rb") as rh: # type:ignore + transport_params = _get_transport_params(report_path) + with smart_open.open( + report_path, "rb", transport_params=transport_params + ) as rh: # type:ignore report_contents = rh.read() except Exception: pass diff --git a/src/gretel_client/projects/projects.py b/src/gretel_client/projects/projects.py index e9e9ee6b..8059d236 100644 --- a/src/gretel_client/projects/projects.py +++ b/src/gretel_client/projects/projects.py @@ -1,13 +1,11 @@ """ High level API for interacting with a Gretel Project """ -import uuid from contextlib import contextmanager from functools import wraps -from io import BytesIO from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Type, TypeVar, Union +from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Type, TypeVar, Union from backports.cached_property import cached_property @@ -269,6 +267,22 @@ def get_artifact_link(self, key: str) -> str: """ return self.default_artifacts_handler.get_project_artifact_link(key) + @contextmanager + def get_artifact_handle(self, key: str) -> BinaryIO: + """Returns a reference to a remote artifact that can be used to + read binary data within a context manager + + >>> with job.get_artifact_handle("report_json") as file: + ... print(file.read()) + + Args: + key: Artifact key to download. + + Returns: + a file like object + """ + return self.default_artifacts_handler.get_project_artifact_handle(key) + def get_artifact_manifest( self, key: str, retry_on_not_found: bool = True, retry_on_pending: bool = True ) -> dict: diff --git a/tests/gretel_client/integration/test_cli_models_local.py b/tests/gretel_client/integration/test_cli_models_local.py index a88e9985..499054ad 100644 --- a/tests/gretel_client/integration/test_cli_models_local.py +++ b/tests/gretel_client/integration/test_cli_models_local.py @@ -358,6 +358,8 @@ def test_create_records_from_model_obj( str(tmpdir), "--model-id", str(model_obj), + "--project", + project.project_id, "--model-path", str(tmpdir / "model.tar.gz"), "--in-data", diff --git a/tests/gretel_client/test_artifact_handlers.py b/tests/gretel_client/test_artifact_handlers.py index 8687e42b..7c9b886d 100644 --- a/tests/gretel_client/test_artifact_handlers.py +++ b/tests/gretel_client/test_artifact_handlers.py @@ -7,12 +7,12 @@ import pandas as pd import pytest -from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient from gretel_client.config import DEFAULT_GRETEL_ARTIFACT_ENDPOINT from gretel_client.projects.artifact_handlers import ( _get_artifact_path_and_file_name, + _get_transport_params, ArtifactsException, hybrid_handler, HybridArtifactsHandler, @@ -78,7 +78,8 @@ def test_hybrid_created_with_custom_artifact_endpoint(): ) def test_hybrid_created_with_azure_artifact_endpoint(key: str, value: str): with patch.dict(os.environ, {key: value}): - config = Mock(artifact_endpoint="azure://my-bucket") + artifact_endpoint = "azure://my-bucket" + config = Mock(artifact_endpoint=artifact_endpoint) project = Mock( project_id="123", name="proj", @@ -87,7 +88,7 @@ def test_hybrid_created_with_azure_artifact_endpoint(key: str, value: str): ) assert isinstance(hybrid_handler(project), HybridArtifactsHandler) - transport_params = hybrid_handler(project)._get_transport_params() + transport_params = _get_transport_params(artifact_endpoint) assert transport_params is not None assert isinstance(transport_params.get("client"), BlobServiceClient) diff --git a/tests/gretel_client/test_projects.py b/tests/gretel_client/test_projects.py index afa6c288..22be1295 100644 --- a/tests/gretel_client/test_projects.py +++ b/tests/gretel_client/test_projects.py @@ -1,3 +1,5 @@ +import os + from unittest.mock import MagicMock, patch import pytest @@ -8,6 +10,7 @@ CloudArtifactsHandler, HybridArtifactsHandler, ) +from gretel_client.projects.models import Model from gretel_client.projects.projects import GretelProjectError, Project @@ -64,3 +67,55 @@ def test_default_aritfacts_handler_raises_under_unsupported_runner_modes( with pytest.raises(GretelProjectError): project.default_artifacts_handler + + +@patch("gretel_client.projects.projects.get_session_config") +@patch("smart_open.open") +@patch("gretel_client.projects.artifact_handlers.BlobServiceClient") +@patch.dict( + os.environ, + { + "AZURE_STORAGE_CONNECTION_STRING": "BlobEndpoint=https://test.blob.core.windows.net/" + }, +) +def test_get_artifact_handle_azure( + blob_client_mock: MagicMock, + smart_open_mock: MagicMock, + get_session_config: MagicMock, +): + config = MagicMock( + artifact_endpoint="azure://my-bucket", + default_runner=RunnerMode.HYBRID, + ) + get_session_config.return_value = config + blob_client_mock_from_conn = MagicMock() + blob_client_mock.from_connection_string.return_value = blob_client_mock_from_conn + + run = Model(Project(name="proj", project_id="123"), model_id="my_model_id") + with run.get_artifact_handle("report_json"): + smart_open_mock.assert_called_once_with( + "azure://my-bucket/123/model/my_model_id/report_json.json.gz", + "rb", + transport_params={"client": blob_client_mock_from_conn}, + ) + + +@patch("gretel_client.projects.projects.get_session_config") +@patch("smart_open.open") +def test_get_artifact_handle_gs( + smart_open_mock: MagicMock, + get_session_config: MagicMock, +): + config = MagicMock( + artifact_endpoint="gs://my-bucket", + default_runner=RunnerMode.HYBRID, + ) + get_session_config.return_value = config + + run = Model(Project(name="proj", project_id="123"), model_id="my_model_id") + with run.get_artifact_handle("report_json"): + smart_open_mock.assert_called_once_with( + "gs://my-bucket/123/model/my_model_id/report_json.json.gz", + "rb", + transport_params={}, + )