Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/gated model support #4510

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
MODEL_TYPE_TO_MANIFEST_MAP,
MODEL_TYPE_TO_SPECS_MAP,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.exceptions import (
get_wildcard_model_version_msg,
Expand Down Expand Up @@ -443,7 +442,7 @@ def _retrieval_function(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)

if data_type in {
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
JumpStartS3FileType.PROPRIETARY_SPECS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
S3ObjectLocation,
)
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
from sagemaker.jumpstart.types import JumpStartModelSpecs


Expand Down Expand Up @@ -65,6 +66,10 @@ def generate_file_infos_from_model_specs(
files = []
for dependency in HubContentDependencyType:
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
# Training dependencies will return as None if training is unsupported
if not location or is_gated_bucket(location.bucket):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we checking for gated?

Copy link
Collaborator Author

@bencrabtree bencrabtree Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to / can't copy any files over from the gated bucket

continue

location_type = "prefix" if location.key.endswith("/") else "object"

if location_type == "prefix":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
"""This module accessors for the SageMaker JumpStart Public Hub."""
from __future__ import absolute_import
from typing import Dict, Any
from typing import Dict, Any, Optional
from sagemaker import model_uris, script_uris
from sagemaker.jumpstart.curated_hub.types import (
HubContentDependencyType,
Expand All @@ -21,7 +21,10 @@
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
get_jumpstart_gated_content_bucket,
)


class PublicModelDataAccessor:
Expand All @@ -34,7 +37,11 @@ def __init__(
studio_specs: Dict[str, Dict[str, Any]],
):
self._region = region
self._bucket = get_jumpstart_content_bucket(region)
self._bucket = (
get_jumpstart_gated_content_bucket(region)
if model_specs.gated_bucket
else get_jumpstart_content_bucket(region)
)
self.model_specs = model_specs
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift

Expand All @@ -43,47 +50,53 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType):
return getattr(self, dependency_type.value)

@property
def inference_artifact_s3_reference(self):
def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model inference artifact"""
return create_s3_object_reference_from_uri(
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
)

@property
def training_artifact_s3_reference(self):
def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model training artifact"""
if not self.model_specs.training_supported:
return None
return create_s3_object_reference_from_uri(
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
)

@property
def inference_script_s3_reference(self):
def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model inference script"""
return create_s3_object_reference_from_uri(
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
)

@property
def training_script_s3_reference(self):
def training_script_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model training script"""
if not self.model_specs.training_supported:
return None
return create_s3_object_reference_from_uri(
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
)

@property
def default_training_dataset_s3_reference(self):
def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for s3 directory containing model training datasets"""
if not self.model_specs.training_supported:
return None
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())

@property
def demo_notebook_s3_reference(self):
def demo_notebook_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model demo jupyter notebook"""
framework = self.model_specs.get_framework()
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
return S3ObjectLocation(self._get_bucket_name(), key)

@property
def markdown_s3_reference(self):
def markdown_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model markdown"""
framework = self.model_specs.get_framework()
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
Expand All @@ -93,24 +106,30 @@ def _get_bucket_name(self) -> str:
"""Retrieves s3 bucket"""
return self._bucket

def __get_training_dataset_prefix(self) -> str:
def _get_training_dataset_prefix(self) -> Optional[str]:
"""Retrieves training dataset location"""
return self.studio_specs["defaultDataKey"]
return self.studio_specs.get("defaultDataKey")

def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]:
"""Retrieves JumpStart script s3 location"""
return script_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
script_scope=model_scope,
)
try:
return script_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
script_scope=model_scope,
)
except ValueError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log something perhaps?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can log something here, though I don't think we'll ever reach this since we're only calling this function if training is supported

return None

def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]:
"""Retrieves JumpStart artifact s3 location"""
return model_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
model_scope=model_scope,
)
try:
return model_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
model_scope=model_scope,
)
except ValueError:
return None
10 changes: 9 additions & 1 deletion src/sagemaker/jumpstart/curated_hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def generate_default_hub_bucket_name(
return f"sagemaker-hubs-{region}-{account_id}"


def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
"""Utiity to help generate an S3 object reference"""
if not s3_uri:
return None

bucket, key = parse_s3_url(s3_uri)

return S3ObjectLocation(
Expand Down Expand Up @@ -164,3 +167,8 @@ def create_hub_bucket_if_it_does_not_exist(
)

return bucket_name


def is_gated_bucket(bucket_name: str) -> bool:
"""Returns true if the bucket name is the JumpStart gated bucket."""
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET
3 changes: 2 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,12 +868,13 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
"""Returns the Studio Spec file prefix given a model ID and version."""
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"


def extract_info_from_hub_content_arn(
arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""

match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
if match:
hub_name = match.group(4)
hub_region = match.group(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,46 @@ def test_s3_path_file_generator_with_no_objects(s3_client):

s3_client.list_objects_v2.assert_called_once()
assert response == []


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
specs = Mock()
specs.model_id = "mock_model_123"
specs.training_supported = False
specs.gated_bucket = False
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
specs.hosting_script_key = "/my/inference/script.py"

response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)

assert response == [
FileInfo(
"jumpstart-cache-prod-us-west-2",
"/my/inference/tarball.tgz",
123456789,
"08-14-1997 00:00:00",
),
FileInfo(
"jumpstart-cache-prod-us-west-2",
"/my/inference/script.py",
123456789,
"08-14-1997 00:00:00",
),
]


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
specs = Mock()
specs.model_id = "mock_model_123"
specs.gated_bucket = True
specs.training_supported = True
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
specs.hosting_script_key = "/my/inference/script.py"
specs.training_prepacked_artifact_key = "/my/training/tarball.tgz"
specs.training_script_key = "/my/training/script.py"

response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)

assert response == []
28 changes: 5 additions & 23 deletions tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,29 +177,11 @@ def test_create_hub_bucket_if_it_does_not_exist():
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


def test_generate_default_hub_bucket_name():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.boto_region_name = "us-east-1"
def test_is_gated_bucket():
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True

assert (
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
== "sagemaker-hubs-us-east-1-123456789123"
)
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True

assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False

def test_create_hub_bucket_if_it_does_not_exist():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name
assert utils.is_gated_bucket("") is False
36 changes: 4 additions & 32 deletions tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
mock_cache.get_specs.assert_called_once_with(
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
Expand All @@ -98,6 +100,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
)
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_proprietary_models_cache_get(mock_cache):

Expand Down Expand Up @@ -138,37 +141,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_models_cache_get_model_specs(mock_cache):
mock_cache.get_specs = Mock()
mock_cache.get_hub_model = Mock()
model_id, version = "pytorch-ic-mobilenet-v2", "*"
region = "us-west-2"

accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
region=region,
model_id=model_id,
version=version,
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
)
mock_cache.get_hub_model.assert_called_once_with(
hub_model_arn=(
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
)
)

# necessary because accessors is a static module
reload(accessors)


@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):

Expand Down
10 changes: 0 additions & 10 deletions tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,6 @@ def patched_retrieval_function(
)
)

if datatype == HubContentType.MODEL:
_, _, _, model_name, model_version = id_info.split("/")
return JumpStartCachedContentValue(
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
)

# TODO: Implement
if datatype == HubType.HUB:
return None

raise ValueError(f"Bad value for datatype: {datatype}")


Expand Down
Loading