Skip to content

Commit

Permalink
first pass at sync function with util classes
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 4, 2024
1 parent 352a5c1 commit 344d26b
Show file tree
Hide file tree
Showing 17 changed files with 881 additions and 17 deletions.
11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
used for SageMaker interactions. Default: Session in region associated with boto3
session.
"""

self._region = region
Expand Down Expand Up @@ -358,7 +359,9 @@ def _retrieval_function(
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
Expand All @@ -372,7 +375,9 @@ def _retrieval_function(
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
return JumpStartCachedContentValue(
formatted_content=DescribeHubResponse(hub_description)
)
raise ValueError(
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
Expand Down
Empty file.
112 changes: 112 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from functools import singledispatchmethod
from typing import Any, Dict, List, Optional

from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.types import JumpStartModelSpecs


class FileGenerator:
    """Utility class to help format HubContent data files.

    Turns either an ``S3ObjectLocation`` or a ``JumpStartModelSpecs`` into a list of
    ``FileInfo`` records describing the S3 objects behind it.
    """

    def __init__(
        self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None
    ):
        """Store the collaborators used by the formatting methods.

        Args:
            region: Region forwarded to ``PublicModelDataAccessor``.
            s3_client: Client used for the ``list_objects_v2`` and ``head_object`` calls.
            studio_specs: Optional Studio metadata forwarded to ``specs_format``.
        """
        self.region = region
        self.s3_client = s3_client
        self.studio_specs = studio_specs

    @singledispatchmethod
    def format(self, file_input) -> List[FileInfo]:
        """Build ``FileInfo`` records, dispatching on the type of ``file_input``.

        This is the fallback for input types with no registered handler: it does
        nothing and implicitly returns ``None``.
        """
        # pylint: disable=W0107
        pass

    @format.register
    def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Handle an S3 location by listing the objects under it."""
        files = self.s3_format(file_input)
        return files

    @format.register
    def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
        """Handle model specs by resolving each dependency's S3 location."""
        files = self.specs_format(file_input, self.studio_specs)
        return files

    def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Retrieves data from a bucket and formats into FileInfo.

        Args:
            file_input: Bucket and key prefix to list.

        Returns:
            One ``FileInfo`` per listed object, or an empty list when nothing
            matches the prefix.
        """
        parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key}
        files: List[FileInfo] = []

        # list_objects_v2 returns results in pages; keep requesting pages until
        # the response is no longer truncated so no object is silently dropped.
        while True:
            response = self.s3_client.list_objects_v2(**parameters)
            for s3_obj in response.get("Contents") or []:
                files.append(
                    FileInfo(
                        s3_obj.get("Key"),
                        s3_obj.get("Size", None),
                        s3_obj.get("LastModified", None),
                    )
                )

            next_token = response.get("NextContinuationToken")
            if not response.get("IsTruncated") or not next_token:
                break
            parameters["ContinuationToken"] = next_token

        if not files:
            print("Nothing to download")
        return files

    def specs_format(
        self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
    ) -> List[FileInfo]:
        """Collects data locations from JumpStart public model specs and
        converts into FileInfo.

        Args:
            file_input: Model specs whose dependencies are looked up.
            studio_specs: Studio metadata forwarded to ``PublicModelDataAccessor``.

        Returns:
            One ``FileInfo`` per ``HubContentDependencyType`` member, tagged with
            that dependency type.
        """
        public_model_data_accessor = PublicModelDataAccessor(
            region=self.region, model_specs=file_input, studio_specs=studio_specs
        )
        # Maps every dependency type to the accessor method that locates it.
        function_table = {
            HubContentDependencyType.INFERENCE_ARTIFACT: (
                public_model_data_accessor.get_inference_artifact_s3_reference
            ),
            HubContentDependencyType.TRAINING_ARTIFACT: (
                public_model_data_accessor.get_training_artifact_s3_reference
            ),
            HubContentDependencyType.INFERNECE_SCRIPT: (
                public_model_data_accessor.get_inference_script_s3_reference
            ),
            HubContentDependencyType.TRAINING_SCRIPT: (
                public_model_data_accessor.get_training_script_s3_reference
            ),
            HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
                public_model_data_accessor.get_default_training_dataset_s3_reference
            ),
            HubContentDependencyType.DEMO_NOTEBOOK: (
                public_model_data_accessor.get_demo_notebook_s3_reference
            ),
            HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
        }
        files = []
        for dependency in HubContentDependencyType:
            location = function_table[dependency]()
            # head_object identifies the object with "Key"; "Prefix" is only a
            # list_objects_v2 parameter and is rejected here.
            # NOTE(review): the default training dataset reference is described as a
            # directory by its accessor; confirm head_object is valid for it.
            response = self.s3_client.head_object(Bucket=location.bucket, Key=location.key)
            files.append(
                FileInfo(
                    location.key,
                    response.get("ContentLength", None),
                    response.get("LastModified", None),
                    dependency,
                )
            )
        return files
47 changes: 47 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important details related to HubContent data files."""
from __future__ import absolute_import

from enum import Enum
from dataclasses import dataclass
from typing import Optional


class HubContentDependencyType(str, Enum):
    """Enum class for HubContent dependency names.

    Each member's value is the dependency name as a string; because the enum
    also derives from ``str``, members compare equal to those strings.
    """

    INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
    TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
    INFERENCE_SCRIPT = "INFERENCE_SCRIPT"
    # Misspelled original name, kept as an alias of INFERENCE_SCRIPT so existing
    # references keep working. Aliases are skipped when iterating the enum.
    INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
    TRAINING_SCRIPT = "TRAINING_SCRIPT"
    DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
    DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
    MARKDOWN = "MARKDOWN"


@dataclass
class FileInfo:
    """Data class for additional S3 file info.

    The attributes are declared as dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` actually take them into account. With no
    declared fields, every instance compares equal to every other.
    """

    # Name of the file (callers in this package pass the S3 object key).
    name: str
    # Object size as reported by S3, if known.
    size: Optional[int]
    # Last-modified marker as reported by S3, if known.
    # NOTE(review): annotated as str, but S3 clients may hand back a datetime; confirm.
    last_updated: Optional[str]
    # Spelling kept as-is because it is part of the existing constructor interface.
    dependecy_type: Optional["HubContentDependencyType"] = None

    @property
    def dependency_type(self) -> Optional["HubContentDependencyType"]:
        """Correctly spelled read-only alias for ``dependecy_type``."""
        return self.dependecy_type
46 changes: 46 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module utilites to assist S3 client calls for the Curated Hub."""
from __future__ import absolute_import
from dataclasses import dataclass
from typing import Dict


@dataclass
class S3ObjectLocation:
    """Bucket/key pair identifying an object (or prefix) in S3."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Return the location as a ``{"Bucket": ..., "Key": ...}`` dict for S3 copy calls."""
        copy_source = dict(Bucket=self.bucket, Key=self.key)
        return copy_source

    def get_uri(self) -> str:
        """Return the location as an ``s3://bucket/key`` URI."""
        bucket, key = self.bucket, self.key
        return f"s3://{bucket}/{key}"


def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
    """Utility to help generate an S3 object reference.

    Args:
        s3_uri: URI of the form ``s3://bucket/key``. The ``s3://`` scheme is
            optional; without it the text before the first ``/`` is the bucket.

    Returns:
        An ``S3ObjectLocation`` whose bucket is the first path segment and whose
        key is everything after the first ``/`` (empty when there is none).
    """
    scheme = "s3://"
    # Strip the scheme only when it is a true prefix; removing the first
    # occurrence anywhere would corrupt a key that itself contains "s3://".
    path = s3_uri[len(scheme) :] if s3_uri.startswith(scheme) else s3_uri
    bucket, _, key = path.partition("/")

    return S3ObjectLocation(bucket=bucket, key=key)
111 changes: 111 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module accessors for the SageMaker JumpStart Public Hub."""
from __future__ import absolute_import
from typing import Dict, Any
from sagemaker import model_uris, script_uris
from sagemaker.jumpstart.curated_hub.utils import (
get_model_framework,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import (
S3ObjectLocation,
create_s3_object_reference_from_uri,
)


class PublicModelDataAccessor:
    """Resolves the S3 locations of the data files that belong to a JumpStart model."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Remember the region, look up its JumpStart content bucket, and keep both specs."""
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        self.studio_specs = studio_specs  # Necessary for SDK - Studio metadata drift

    def get_bucket_name(self) -> str:
        """Return the content bucket resolved at construction time."""
        return self._bucket

    def get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's inference artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_training_artifact_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's training artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_inference_script_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's inference script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(script_uri)

    def get_training_script_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's training script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(script_uri)

    def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 directory that holds the model's training datasets."""
        bucket = self.get_bucket_name()
        prefix = self._get_training_dataset_prefix()
        return S3ObjectLocation(bucket, prefix)

    def _get_training_dataset_prefix(self) -> str:
        """Read the training dataset key from the ``defaultDataKey`` studio spec entry."""
        return self.studio_specs["defaultDataKey"]

    def get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's demo Jupyter notebook."""
        framework = get_model_framework(self.model_specs)
        notebook_key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        bucket = self.get_bucket_name()
        return S3ObjectLocation(bucket, notebook_key)

    def get_markdown_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model's markdown file."""
        framework = get_model_framework(self.model_specs)
        markdown_key = f"{framework}-metadata/{self.model_specs.model_id}.md"
        bucket = self.get_bucket_name()
        return S3ObjectLocation(bucket, markdown_key)

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Look up the script S3 URI for ``model_scope`` via ``script_uris.retrieve``.

        Vulnerable and deprecated models are tolerated by the lookup.
        """
        specs = self.model_specs
        retrieve_kwargs = {
            "region": self._region,
            "model_id": specs.model_id,
            "model_version": specs.version,
            "script_scope": model_scope,
            "tolerate_vulnerable_model": True,
            "tolerate_deprecated_model": True,
        }
        return script_uris.retrieve(**retrieve_kwargs)

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Look up the artifact S3 URI for ``model_scope`` via ``model_uris.retrieve``.

        Vulnerable and deprecated models are tolerated by the lookup.
        """
        specs = self.model_specs
        retrieve_kwargs = {
            "region": self._region,
            "model_id": specs.model_id,
            "model_version": specs.version,
            "model_scope": model_scope,
            "tolerate_vulnerable_model": True,
            "tolerate_deprecated_model": True,
        }
        return model_uris.retrieve(**retrieve_kwargs)
Loading

0 comments on commit 344d26b

Please sign in to comment.