diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8d0f1832bf..417bae77c7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -34,7 +34,6 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, MODEL_TYPE_TO_MANIFEST_MAP, MODEL_TYPE_TO_SPECS_MAP, - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, @@ -443,7 +442,7 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index 0393b4234a..e5ea072d86 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -22,6 +22,7 @@ S3ObjectLocation, ) from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor +from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -65,6 +66,10 @@ def generate_file_infos_from_model_specs( files = [] for dependency in HubContentDependencyType: location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + # Training dependencies will return as None if training is unsupported + if not location or is_gated_bucket(location.bucket): + continue + location_type = "prefix" if location.key.endswith("/") else "object" if location_type == "prefix": diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 89e3a2f108..a4e339591b 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """This module accessors for the SageMaker JumpStart Public Hub.""" from __future__ import absolute_import -from typing import Dict, Any +from typing import Dict, Any, Optional from sagemaker import model_uris, script_uris from sagemaker.jumpstart.curated_hub.types import ( HubContentDependencyType, @@ -21,7 +21,10 @@ from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartModelSpecs -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + get_jumpstart_gated_content_bucket, +) class PublicModelDataAccessor: @@ -34,7 +37,11 @@ def __init__( studio_specs: Dict[str, Dict[str, Any]], ): self._region = region - self._bucket = get_jumpstart_content_bucket(region) + self._bucket = ( + get_jumpstart_gated_content_bucket(region) + if model_specs.gated_bucket + else get_jumpstart_content_bucket(region) + ) self.model_specs = model_specs self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift @@ -43,47 +50,53 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType): return getattr(self, dependency_type.value) @property - def inference_artifact_s3_reference(self): + def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model inference artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) ) @property - def training_artifact_s3_reference(self): + def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model training artifact""" + if not self.model_specs.training_supported: + return None return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) ) @property - def inference_script_s3_reference(self): + def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model inference script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) ) @property - def training_script_s3_reference(self): + def training_script_s3_reference(self) -> Optional[S3ObjectLocation]: """Retrieves s3 reference for model training script""" + if not self.model_specs.training_supported: + return None return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) ) @property - def default_training_dataset_s3_reference(self): + def default_training_dataset_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for s3 directory containing model training datasets""" - return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) + if not self.model_specs.training_supported: + return None + return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix()) @property - def demo_notebook_s3_reference(self): + def demo_notebook_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model demo jupyter notebook""" framework = self.model_specs.get_framework() key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" return S3ObjectLocation(self._get_bucket_name(), key) @property - def markdown_s3_reference(self): + def markdown_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model markdown""" framework = self.model_specs.get_framework() key = f"{framework}-metadata/{self.model_specs.model_id}.md" @@ -93,24 +106,30 @@ def _get_bucket_name(self) -> str: """Retrieves s3 bucket""" return self._bucket - def __get_training_dataset_prefix(self) -> str: + def _get_training_dataset_prefix(self) -> Optional[str]: """Retrieves training dataset location""" - return self.studio_specs["defaultDataKey"] + return self.studio_specs.get("defaultDataKey") - def _jumpstart_script_s3_uri(self, model_scope: str) -> str: + def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]: """Retrieves JumpStart script s3 location""" - return script_uris.retrieve( - region=self._region, - model_id=self.model_specs.model_id, - model_version=self.model_specs.version, - script_scope=model_scope, - ) + try: + return script_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + script_scope=model_scope, + ) + except ValueError: + return None - def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: + def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]: """Retrieves JumpStart artifact s3 location""" - return model_uris.retrieve( - region=self._region, - model_id=self.model_specs.model_id, - model_version=self.model_specs.version, - model_scope=model_scope, - ) + try: + return model_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + model_scope=model_scope, + ) + except ValueError: + return None diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index b116411801..71008ab5b4 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -133,8 +133,11 @@ def generate_default_hub_bucket_name( return f"sagemaker-hubs-{region}-{account_id}" -def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: +def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]: """Utiity to help generate an S3 object reference""" + if not s3_uri: + return None + bucket, key = parse_s3_url(s3_uri) return S3ObjectLocation( @@ -164,3 +167,8 @@ def create_hub_bucket_if_it_does_not_exist( ) return bucket_name + + +def is_gated_bucket(bucket_name: str) -> bool: + """Returns true if the bucket name is the JumpStart gated bucket.""" + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4fc8752625..210548511d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -868,12 +868,13 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str: """Returns the Studio Spec file prefix given a model ID and version.""" return f"studio_models/{model_id}/studio_specs_v{model_version}.json" + def extract_info_from_hub_content_arn( arn: str, ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: """Extracts hub_name, content_name, and content_version from a HubContentArn""" - match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) if match: hub_name = match.group(4) hub_region = match.group(2) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 5332f7bdd0..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index accd2a5c8d..8fcb8dd740 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -127,3 +127,46 @@ def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.assert_called_once() assert response == [] + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client): + specs = Mock() + specs.model_id = "mock_model_123" + specs.training_supported = False + specs.gated_bucket = False + specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz" + specs.hosting_script_key = "/my/inference/script.py" + + response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client) + + assert response == [ + FileInfo( + "jumpstart-cache-prod-us-west-2", + "/my/inference/tarball.tgz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "/my/inference/script.py", + 123456789, + "08-14-1997 00:00:00", + ), + ] + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client): + specs = Mock() + specs.model_id = "mock_model_123" + specs.gated_bucket = True + specs.training_supported = True + specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz" + specs.hosting_script_key = "/my/inference/script.py" + specs.training_prepacked_artifact_key = "/my/training/tarball.tgz" + specs.training_script_key = "/my/training/script.py" + + response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client) + + assert response == [] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index b4b2eaabb2..ac5fdaba3e 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -177,29 +177,11 @@ def test_create_hub_bucket_if_it_does_not_exist(): assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn -def test_generate_default_hub_bucket_name(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.boto_region_name = "us-east-1" +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True - assert ( - utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) - == "sagemaker-hubs-us-east-1-123456789123" - ) + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name + assert utils.is_gated_bucket("") is False diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 5d527dd5a1..79eeb4b7f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=version ) - mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_specs.assert_called_once_with( + model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS + ) mock_cache.get_hub_model.assert_not_called() accessors.JumpStartModelsAccessor.get_model_specs( @@ -98,6 +100,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): ) ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_jumpstart_proprietary_models_cache_get(mock_cache): @@ -138,37 +141,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): ) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") -def test_jumpstart_models_cache_get_model_specs(mock_cache): - mock_cache.get_specs = Mock() - mock_cache.get_hub_model = Mock() - model_id, version = "pytorch-ic-mobilenet-v2", "*" - region = "us-west-2" - - accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=version - ) - mock_cache.get_specs.assert_called_once_with( - model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS - ) - mock_cache.get_hub_model.assert_not_called() - - accessors.JumpStartModelsAccessor.get_model_specs( - region=region, - model_id=model_id, - version=version, - hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", - ) - mock_cache.get_hub_model.assert_called_once_with( - hub_model_arn=( - f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" - ) - ) - - # necessary because accessors is a static module - reload(accessors) - - @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index ad093640b7..410aba4d03 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -254,16 +254,6 @@ def patched_retrieval_function( ) ) - if datatype == HubContentType.MODEL: - _, _, _, model_name, model_version = id_info.split("/") - return JumpStartCachedContentValue( - formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) - ) - - # TODO: Implement - if datatype == HubType.HUB: - return None - raise ValueError(f"Bad value for datatype: {datatype}")