Skip to content

Commit

Permalink
feature: JumpStart CuratedHub class creation and function definitions (
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyoung-lim authored and bencrabtree committed Mar 21, 2024
1 parent fbed0fc commit f4c72ca
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
MODEL_TYPE_TO_MANIFEST_MAP,
MODEL_TYPE_TO_SPECS_MAP,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.exceptions import (
get_wildcard_model_version_msg,
Expand Down
19 changes: 0 additions & 19 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,25 +1594,6 @@ def from_describe_hub_content_response(self, response: DescribeHubContentRespons
else None
)

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartModelSpecs object."""
json_obj = {}
for att in self.__slots__:
if hasattr(self, att):
cur_val = getattr(self, att)
if issubclass(type(cur_val), JumpStartDataHolderType):
json_obj[att] = cur_val.to_json()
elif isinstance(cur_val, list):
json_obj[att] = []
for obj in cur_val:
if issubclass(type(obj), JumpStartDataHolderType):
json_obj[att].append(obj.to_json())
else:
json_obj[att].append(obj)
else:
json_obj[att] = cur_val
return json_obj

def supports_prepacked_inference(self) -> bool:
"""Returns True if the model has a prepacked inference artifact."""
return getattr(self, "hosting_prepacked_artifact_key", None) is not None
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,38 @@ def test_generate_hub_arn_for_init_kwargs():
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
)

assert (
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
== hub_arn
)


def test_generate_default_hub_bucket_name():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.boto_region_name = "us-east-1"

assert (
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
== "sagemaker-hubs-us-east-1-123456789123"
)


def test_create_hub_bucket_if_it_does_not_exist():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
JumpStartModelsCache,
)
from sagemaker.session_settings import SessionSettings
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
Expand Down Expand Up @@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(

mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
assert mocked_is_dir.call_count == 2
mocked_open.assert_not_called()
assert mocked_open.call_count == 2
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
calls=[
call("models_manifest.json"),
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
JUMPSTART_REGION_NAME_SET,
)
from sagemaker.jumpstart.types import (
HubContentType,
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelSpecs,
Expand All @@ -32,6 +31,7 @@
HubContentType,
)
from sagemaker.jumpstart.enums import JumpStartModelType

from sagemaker.jumpstart.utils import get_formatted_manifest
from tests.unit.sagemaker.jumpstart.constants import (
PROTOTYPICAL_MODEL_SPECS_DICT,
Expand Down Expand Up @@ -254,7 +254,7 @@ def patched_retrieval_function(
)
)
# TODO: Implement
if datatype == HubContentType.HUB:
if datatype == HubType.HUB:
return None

raise ValueError(f"Bad value for datatype: {datatype}")
Expand Down

0 comments on commit f4c72ca

Please sign in to comment.