Add get_artifact_handle for Jobs
* Adds a helper method to centralize any extra smart_open options
  we might need when accessing artifacts via a job

GitOrigin-RevId: cc2938dd68d210289f2d6a9bf181989ee86feb90
mckornfield committed Aug 8, 2023
1 parent 6e2bd5c commit 0ecf14f
Showing 6 changed files with 169 additions and 43 deletions.
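
The addition is easiest to see from the caller's side: any job artifact can now be streamed through a context manager, with Azure-specific smart_open options resolved internally. A minimal usage sketch (the project name and model ID below are placeholders; "report_json" is one of the model artifact types exercised in the tests in this commit):

    from gretel_client.projects import get_project

    # Look up an existing project and trained model (placeholder identifiers).
    project = get_project(name="my-project")
    model = project.get_model("64d1e2f7a9b3c4d5e6f70123")

    # Stream the artifact; the same call works for Gretel Cloud links as well
    # as s3://, gs://, and azure:// endpoints in hybrid mode.
    with model.get_artifact_handle("report_json") as handle:
        report_bytes = handle.read()
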
79 changes: 53 additions & 26 deletions src/gretel_client/projects/artifact_handlers.py
@@ -8,7 +8,7 @@
 
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, BinaryIO, Dict, List, Optional, Protocol, Tuple, Union
 from urllib.parse import urlparse
 
 import requests
@@ -28,14 +28,37 @@
 try:
     from azure.identity import DefaultAzureCredential
+    from azure.storage.blob import BlobServiceClient
 except ImportError:  # pragma: no cover
     DefaultAzureCredential = None
-try:
-    from azure.storage.blob import BlobServiceClient
-except ImportError:  # pragma: no cover
     BlobServiceClient = None
 
 HYBRID_ARTIFACT_ENDPOINT_PREFIXES = ["azure://", "gs://", "s3://"]
 
 
+def _get_azure_blob_srv_client() -> Optional[BlobServiceClient]:
+    if (storage_account := os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")) is not None:
+        oauth_url = "https://{}.blob.core.windows.net".format(storage_account)
+        return BlobServiceClient(
+            account_url=oauth_url, credential=DefaultAzureCredential()
+        )
+
+    if (connect_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING")) is not None:
+        return BlobServiceClient.from_connection_string(connect_str)
+
+
+def _get_transport_params(endpoint: str) -> dict:
+    """Returns a set of transport params that are suitable for passing
+    into calls to ``smart_open.open``.
+    """
+    client = None
+    if endpoint and endpoint.startswith("azure"):
+        client = _get_azure_blob_srv_client()
+    return {"client": client} if client else {}
+
+
 class ArtifactsException(Exception):
     pass
 
Expand All @@ -48,7 +71,8 @@ class ManifestNotFoundException(Exception):

class ManifestPendingException(Exception):
def __init__(self, msg: Optional[str] = None, manifest: dict = {}):
# Piggyback the pending manifest onto the exception. If we give up, we still want to pass it back as a normal return value.
# Piggyback the pending manifest onto the exception.
# If we give up, we still want to pass it back as a normal return value.
self.manifest = manifest
super().__init__(msg)

@@ -106,6 +130,9 @@ def list_project_artifacts(self) -> List[dict]:
     def get_project_artifact_link(self, key: str) -> str:
         ...
 
+    def get_project_artifact_handle(self, key: str) -> BinaryIO:
+        ...
+
     def get_project_artifact_manifest(
         self,
         key: str,
@@ -187,6 +214,13 @@ def get_project_artifact_link(self, key: str) -> str:
         )
         return resp[f.DATA][f.DATA][f.URL]
 
+    @contextmanager
+    def get_project_artifact_handle(self, key: str) -> BinaryIO:
+        link = self.get_project_artifact_link(key)
+        transport_params = _get_transport_params(link)
+        with smart_open.open(link, "rb", transport_params=transport_params) as handle:
+            yield handle
+
     # The server side API will return manifests with PENDING status if artifact processing has not completed
     # or it will return a 404 (not found) if you immediately request the artifact before processing has even started.
     # This is correct but not convenient. To keep every end user from writing their own retry logic, we add some here.
@@ -201,7 +235,8 @@ def get_project_artifact_link(self, key: str) -> str:
         wait=wait_fixed(3),
         stop=stop_after_attempt(5),
         retry=retry_if_exception_type(ManifestNotFoundException),
-        # Instead of throwing an exception, return None. Given that we waited for a short grace period to let the artifact become PENDING,
+        # Instead of throwing an exception, return None.
+        # Given that we waited for a short grace period to let the artifact become PENDING,
         # if we are still getting 404's the key probably does not actually exist.
         retry_error_callback=lambda retry_state: None,
     )
@@ -273,25 +308,6 @@ def validate_data_source(
 
         return common.validate_data_source(artifact_path)
 
-    def _get_azure_blob_srv_client(self) -> Optional[BlobServiceClient]:
-        if (storage_account := os.getenv("OAUTH_STORAGE_ACCOUNT_NAME")) is not None:
-            oauth_url = "https://{}.blob.core.windows.net".format(storage_account)
-            return BlobServiceClient(
-                account_url=oauth_url, credential=DefaultAzureCredential()
-            )
-
-        if (connect_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING")) is not None:
-            return BlobServiceClient.from_connection_string(connect_str)
-
-    def _get_transport_params(self) -> dict:
-        """Returns a set of transport params that are suitable for passing
-        into calls to ``smart_open.open``.
-        """
-        client = None
-        if self.endpoint.startswith("azure"):
-            client = self._get_azure_blob_srv_client()
-        return {"client": client} if client else {}
-
     def upload_project_artifact(
         self,
         artifact_path: Union[Path, str, _DataFrameT],
@@ -311,9 +327,9 @@ def upload_project_artifact(
             artifact_path,
             "rb",
             ignore_ext=True,
-            transport_params=self._get_transport_params(),
+            transport_params=_get_transport_params(self.endpoint),
         ) as in_stream, smart_open.open(
-            target_out, "wb", transport_params=self._get_transport_params()
+            target_out, "wb", transport_params=_get_transport_params(self.endpoint)
         ) as out_stream:
             shutil.copyfileobj(in_stream, out_stream)
@@ -343,6 +359,13 @@ def get_project_artifact_link(self, key: str) -> str:
             return key
         return f"{self.data_sources_dir}/{key}"
 
+    @contextmanager
+    def get_project_artifact_handle(self, key: str) -> BinaryIO:
+        link = self.get_project_artifact_link(key)
+        transport_params = _get_transport_params(link)
+        with smart_open.open(link, "rb", transport_params=transport_params) as handle:
+            yield handle
+
     def get_project_artifact_manifest(
         self,
         key: str,
@@ -378,7 +401,11 @@ def download(
         log: logging.Logger,
     ) -> None:
         _download(
-            download_link, output_path, artifact_type, log, self._get_transport_params()
+            download_link,
+            output_path,
+            artifact_type,
+            log,
+            _get_transport_params(self.endpoint),
         )
 
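The net effect in this file is that Azure credential resolution now lives in module-level helpers rather than in private methods of HybridArtifactsHandler, so the cloud handler and Jobs can share it. A rough behavioral sketch of _get_transport_params, assuming azure-storage-blob is installed (the connection string below is a placeholder):

    import os

    from gretel_client.projects.artifact_handlers import _get_transport_params

    # Non-Azure endpoints need no extra transport configuration.
    assert _get_transport_params("s3://my-bucket/key.csv") == {}
    assert _get_transport_params("gs://my-bucket/key.csv") == {}

    # Azure endpoints pick up a BlobServiceClient from the environment;
    # OAUTH_STORAGE_ACCOUNT_NAME works the same way via DefaultAzureCredential.
    os.environ["AZURE_STORAGE_CONNECTION_STRING"] = (
        "BlobEndpoint=https://example.blob.core.windows.net/"
    )
    params = _get_transport_params("azure://my-bucket/key.csv")
    # params == {"client": <BlobServiceClient>}, ready for smart_open.open.
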
49 changes: 38 additions & 11 deletions src/gretel_client/projects/jobs.py
@@ -5,27 +5,32 @@
 import time
 
 from abc import ABC, abstractmethod, abstractproperty
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
-from urllib.parse import urlparse
+from typing import (
+    BinaryIO,
+    Callable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 
 import requests
 import smart_open
 
 import gretel_client.rest.exceptions
 
 from gretel_client.cli.utils.parser_utils import RefData
-from gretel_client.config import (
-    DEFAULT_RUNNER,
-    get_logger,
-    get_session_config,
-    RunnerMode,
-)
+from gretel_client.config import get_logger, get_session_config, RunnerMode
 from gretel_client.dataframe import _DataFrameT
 from gretel_client.models.config import get_model_type_config
 from gretel_client.projects.artifact_handlers import (
+    _get_transport_params,
     ArtifactsHandler,
     CloudArtifactsHandler,
     HybridArtifactsHandler,
@@ -347,14 +352,33 @@ def get_artifact_link(self, artifact_key: str) -> str:
         artifact type.
 
         Args:
-            artifact_type: Artifact type to download.
+            artifact_key: Artifact key to download.
         """
         if artifact_key not in self.artifact_types:
             raise Exception(
                 f"artifact_key {artifact_key} not a valid key. Valid keys are {self.artifact_types}"
             )
         return self._do_get_artifact(artifact_key)
 
+    @contextmanager
+    def get_artifact_handle(self, artifact_key: str) -> BinaryIO:
+        """Returns a reference to a remote artifact that can be used to
+        read binary data within a context manager.
+
+        >>> with job.get_artifact_handle("report_json") as file:
+        ...     print(file.read())
+
+        Args:
+            artifact_key: Artifact key to download.
+
+        Returns:
+            A file-like object.
+        """
+        link = self.get_artifact_link(artifact_key)
+        transport_params = _get_transport_params(link)
+        with smart_open.open(link, "rb", transport_params=transport_params) as handle:
+            yield handle
+
     def download_artifacts(self, target_dir: Union[str, Path]):
         """Given a target directory, either as a string or a Path object, attempt to enumerate
         and download all artifacts associated with this Job
@@ -390,7 +414,10 @@ def _get_report_contents(
     report_contents = None
     if report_path:
         try:
-            with smart_open.open(report_path, "rb") as rh:  # type:ignore
+            transport_params = _get_transport_params(report_path)
+            with smart_open.open(
+                report_path, "rb", transport_params=transport_params
+            ) as rh:  # type:ignore
                 report_contents = rh.read()
         except Exception:
             pass
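Because get_artifact_handle is defined on the Job base class, RecordHandler runs inherit it alongside Model. A short sketch, assuming a completed record handler obtained elsewhere and that "data" is a valid artifact type for it (both assumptions, not shown in this diff):

    import csv
    import io

    # Stream generated records without a temporary file; "data" is assumed
    # to be the record handler's output artifact key.
    with record_handler.get_artifact_handle("data") as handle:
        reader = csv.reader(io.TextIOWrapper(handle, encoding="utf-8"))
        header = next(reader)
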
20 changes: 17 additions & 3 deletions src/gretel_client/projects/projects.py
@@ -1,13 +1,11 @@
 """
 High level API for interacting with a Gretel Project
 """
-import uuid
-
+from contextlib import contextmanager
 from functools import wraps
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional, Type, TypeVar, Union
+from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Type, TypeVar, Union
 
 from backports.cached_property import cached_property
@@ -269,6 +267,22 @@ def get_artifact_link(self, key: str) -> str:
         """
         return self.default_artifacts_handler.get_project_artifact_link(key)
 
+    @contextmanager
+    def get_artifact_handle(self, key: str) -> BinaryIO:
+        """Returns a reference to a remote artifact that can be used to
+        read binary data within a context manager.
+
+        >>> with project.get_artifact_handle("my-data-source-key") as file:
+        ...     print(file.read())
+
+        Args:
+            key: Artifact key to download.
+
+        Returns:
+            A file-like object.
+        """
+        with self.default_artifacts_handler.get_project_artifact_handle(key) as handle:
+            yield handle
+
     def get_artifact_manifest(
         self, key: str, retry_on_not_found: bool = True, retry_on_pending: bool = True
     ) -> dict:
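On the Project side, the new handle complements the existing upload path, so a data source can be round-tripped without touching the local filesystem again. A sketch, assuming a project handle and using a placeholder file name:

    # upload_artifact returns the artifact key the handlers expect.
    key = project.upload_artifact("train.csv")

    # Read the uploaded data source back through the same handler machinery.
    with project.get_artifact_handle(key) as handle:
        first_line = handle.readline()
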
2 changes: 2 additions & 0 deletions tests/gretel_client/integration/test_cli_models_local.py
@@ -358,6 +358,8 @@ def test_create_records_from_model_obj(
             str(tmpdir),
             "--model-id",
             str(model_obj),
+            "--project",
+            project.project_id,
             "--model-path",
             str(tmpdir / "model.tar.gz"),
             "--in-data",
7 changes: 4 additions & 3 deletions tests/gretel_client/test_artifact_handlers.py
@@ -7,12 +7,12 @@
 import pandas as pd
 import pytest
 
-from azure.identity import DefaultAzureCredential
 from azure.storage.blob import BlobServiceClient
 
 from gretel_client.config import DEFAULT_GRETEL_ARTIFACT_ENDPOINT
 from gretel_client.projects.artifact_handlers import (
     _get_artifact_path_and_file_name,
+    _get_transport_params,
     ArtifactsException,
     hybrid_handler,
     HybridArtifactsHandler,
@@ -78,7 +78,8 @@ def test_hybrid_created_with_custom_artifact_endpoint():
 )
 def test_hybrid_created_with_azure_artifact_endpoint(key: str, value: str):
     with patch.dict(os.environ, {key: value}):
-        config = Mock(artifact_endpoint="azure://my-bucket")
+        artifact_endpoint = "azure://my-bucket"
+        config = Mock(artifact_endpoint=artifact_endpoint)
         project = Mock(
             project_id="123",
             name="proj",
@@ -87,7 +88,7 @@ def test_hybrid_created_with_azure_artifact_endpoint(key: str, value: str):
         )
 
         assert isinstance(hybrid_handler(project), HybridArtifactsHandler)
-        transport_params = hybrid_handler(project)._get_transport_params()
+        transport_params = _get_transport_params(artifact_endpoint)
 
         assert transport_params is not None
         assert isinstance(transport_params.get("client"), BlobServiceClient)
55 changes: 55 additions & 0 deletions tests/gretel_client/test_projects.py
@@ -1,3 +1,5 @@
+import os
+
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -8,6 +10,7 @@
     CloudArtifactsHandler,
     HybridArtifactsHandler,
 )
+from gretel_client.projects.models import Model
 from gretel_client.projects.projects import GretelProjectError, Project
 
@@ -64,3 +67,55 @@ def test_default_aritfacts_handler_raises_under_unsupported_runner_modes(
 
     with pytest.raises(GretelProjectError):
         project.default_artifacts_handler
+
+
+@patch("gretel_client.projects.projects.get_session_config")
+@patch("smart_open.open")
+@patch("gretel_client.projects.artifact_handlers.BlobServiceClient")
+@patch.dict(
+    os.environ,
+    {
+        "AZURE_STORAGE_CONNECTION_STRING": "BlobEndpoint=https://test.blob.core.windows.net/"
+    },
+)
+def test_get_artifact_handle_azure(
+    blob_client_mock: MagicMock,
+    smart_open_mock: MagicMock,
+    get_session_config: MagicMock,
+):
+    config = MagicMock(
+        artifact_endpoint="azure://my-bucket",
+        default_runner=RunnerMode.HYBRID,
+    )
+    get_session_config.return_value = config
+    blob_client_mock_from_conn = MagicMock()
+    blob_client_mock.from_connection_string.return_value = blob_client_mock_from_conn
+
+    run = Model(Project(name="proj", project_id="123"), model_id="my_model_id")
+    with run.get_artifact_handle("report_json"):
+        smart_open_mock.assert_called_once_with(
+            "azure://my-bucket/123/model/my_model_id/report_json.json.gz",
+            "rb",
+            transport_params={"client": blob_client_mock_from_conn},
+        )
+
+
+@patch("gretel_client.projects.projects.get_session_config")
+@patch("smart_open.open")
+def test_get_artifact_handle_gs(
+    smart_open_mock: MagicMock,
+    get_session_config: MagicMock,
+):
+    config = MagicMock(
+        artifact_endpoint="gs://my-bucket",
+        default_runner=RunnerMode.HYBRID,
+    )
+    get_session_config.return_value = config
+
+    run = Model(Project(name="proj", project_id="123"), model_id="my_model_id")
+    with run.get_artifact_handle("report_json"):
+        smart_open_mock.assert_called_once_with(
+            "gs://my-bucket/123/model/my_model_id/report_json.json.gz",
+            "rb",
+            transport_params={},
+        )
