-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first pass at sync function with util classes
- Loading branch information
1 parent
352a5c1
commit 344d26b
Showing
17 changed files
with
881 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
112 changes: 112 additions & 0 deletions
112
src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from functools import singledispatchmethod | ||
from typing import Any, Dict, List, Optional | ||
|
||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
|
||
|
||
class FileGenerator:
    """Utility class to help format HubContent data files.

    Converts either an S3 location or a set of JumpStart model specs into a
    list of ``FileInfo`` records, using the supplied S3 client to look up
    object sizes and modification times.
    """

    def __init__(
        self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None
    ):
        """Stores the collaborators used by the formatting methods.

        Args:
            region: AWS region forwarded to ``PublicModelDataAccessor``.
            s3_client: Client used for ``list_objects_v2`` / ``head_object`` calls.
            studio_specs: Optional Studio metadata forwarded to
                ``PublicModelDataAccessor`` when formatting model specs.
        """
        self.region = region
        self.s3_client = s3_client
        self.studio_specs = studio_specs

    @singledispatchmethod
    def format(self, file_input) -> List[FileInfo]:
        """Dispatches on the type of ``file_input``.

        ``S3ObjectLocation`` inputs are handled by ``s3_format`` and
        ``JumpStartModelSpecs`` inputs by ``specs_format``. Any other input
        type falls through to this base implementation, which does nothing
        and returns ``None``.
        """
        # pylint: disable=W0107
        pass

    @format.register
    def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Formats the objects found under an S3 location."""
        return self.s3_format(file_input)

    @format.register
    def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
        """Formats the dependencies referenced by JumpStart model specs."""
        return self.specs_format(file_input, self.studio_specs)

    def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Lists the objects under an S3 prefix and converts them into FileInfo.

        Args:
            file_input: Bucket and key prefix to list.

        Returns:
            One ``FileInfo`` per listed object, or an empty list when the
            prefix matches nothing.
        """
        parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key}
        contents = []
        # A single list_objects_v2 call returns at most one page of keys, so
        # follow the continuation token until the listing is complete.
        while True:
            response = self.s3_client.list_objects_v2(**parameters)
            contents.extend(response.get("Contents") or [])
            if not response.get("IsTruncated"):
                break
            parameters["ContinuationToken"] = response["NextContinuationToken"]

        if not contents:
            print("Nothing to download")
            return []

        return [
            FileInfo(
                s3_obj.get("Key"),
                s3_obj.get("Size", None),
                s3_obj.get("LastModified", None),
            )
            for s3_obj in contents
        ]

    def specs_format(
        self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
    ) -> List[FileInfo]:
        """Collects data locations from JumpStart public model specs and
        converts into FileInfo.

        Args:
            file_input: Model specs whose dependencies should be resolved.
            studio_specs: Studio metadata forwarded to ``PublicModelDataAccessor``.

        Returns:
            One ``FileInfo`` per ``HubContentDependencyType`` member, tagged
            with that dependency type.
        """
        public_model_data_accessor = PublicModelDataAccessor(
            region=self.region, model_specs=file_input, studio_specs=studio_specs
        )
        # Maps every dependency type to the accessor method that resolves its
        # S3 location; the loop below relies on this table covering the enum.
        function_table = {
            HubContentDependencyType.INFERENCE_ARTIFACT: (
                public_model_data_accessor.get_inference_artifact_s3_reference
            ),
            HubContentDependencyType.TRAINING_ARTIFACT: (
                public_model_data_accessor.get_training_artifact_s3_reference
            ),
            HubContentDependencyType.INFERNECE_SCRIPT: (
                public_model_data_accessor.get_inference_script_s3_reference
            ),
            HubContentDependencyType.TRAINING_SCRIPT: (
                public_model_data_accessor.get_training_script_s3_reference
            ),
            HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
                public_model_data_accessor.get_default_training_dataset_s3_reference
            ),
            HubContentDependencyType.DEMO_NOTEBOOK: (
                public_model_data_accessor.get_demo_notebook_s3_reference
            ),
            HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
        }
        files = []
        for dependency in HubContentDependencyType:
            location = function_table[dependency]()
            # head_object addresses a single object and therefore takes "Key";
            # "Prefix" is only a list_objects_v2 parameter.
            # NOTE(review): the default training dataset reference is described
            # as an S3 directory; head_object on a directory-style prefix may
            # not find an object - confirm the intended handling.
            response = self.s3_client.head_object(Bucket=location.bucket, Key=location.key)
            files.append(
                FileInfo(
                    location.key,
                    response.get("ContentLength", None),
                    response.get("LastModified", None),
                    dependency,
                )
            )
        return files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important details related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
|
||
from enum import Enum | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
|
||
class HubContentDependencyType(str, Enum):
    """Enum class for HubContent dependency names.

    Each member's value is the dependency name as a string; because the enum
    also derives from ``str``, members compare equal to those strings.
    """

    INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
    TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
    INFERENCE_SCRIPT = "INFERENCE_SCRIPT"
    # Misspelled name kept as an alias of INFERENCE_SCRIPT (same value) so
    # existing references keep resolving to the same member. Aliases are
    # skipped during iteration, so the enum still yields seven members.
    INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
    TRAINING_SCRIPT = "TRAINING_SCRIPT"
    DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
    DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
    MARKDOWN = "MARKDOWN"
|
||
|
||
@dataclass
class FileInfo:
    """Data class for additional S3 file info.

    The fields are declared at class level so that ``@dataclass`` generates a
    field-aware ``__init__``, ``__eq__`` and ``__repr__``. Constructor
    parameter names, order and defaults are unchanged.
    """

    # Object key of the file.
    name: str
    # Object size as reported by S3, if known.
    size: Optional[int]
    # Last-modified value as reported by S3, if known.
    last_updated: Optional[str]
    # Which HubContent dependency this file represents, when applicable.
    # (Attribute spelling is kept as-is for backward compatibility.)
    dependecy_type: Optional["HubContentDependencyType"] = None
46 changes: 46 additions & 0 deletions
46
src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module utilites to assist S3 client calls for the Curated Hub.""" | ||
from __future__ import absolute_import | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
|
||
@dataclass
class S3ObjectLocation:
    """Bucket/key pair identifying an object (or prefix) in S3."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Builds the ``{"Bucket": ..., "Key": ...}`` mapping used by S3 copy calls."""
        copy_source = {}
        copy_source["Bucket"] = self.bucket
        copy_source["Key"] = self.key
        return copy_source

    def get_uri(self) -> str:
        """Renders this location as an ``s3://bucket/key`` URI."""
        return "s3://{}/{}".format(self.bucket, self.key)
|
||
|
||
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
    """Utility to build an S3 object reference from an ``s3://`` URI.

    The first ``s3://`` occurrence is removed; everything before the first
    ``/`` of the remainder becomes the bucket and everything after it becomes
    the key (an empty string when there is no ``/``).
    """
    without_scheme = s3_uri.replace("s3://", "", 1)
    bucket_name, _, object_key = without_scheme.partition("/")
    return S3ObjectLocation(bucket=bucket_name, key=object_key)
111 changes: 111 additions & 0 deletions
111
src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module accessors for the SageMaker JumpStart Public Hub.""" | ||
from __future__ import absolute_import | ||
from typing import Dict, Any | ||
from sagemaker import model_uris, script_uris | ||
from sagemaker.jumpstart.curated_hub.utils import ( | ||
get_model_framework, | ||
) | ||
from sagemaker.jumpstart.enums import JumpStartScriptScope | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import ( | ||
S3ObjectLocation, | ||
create_s3_object_reference_from_uri, | ||
) | ||
|
||
|
||
class PublicModelDataAccessor:
    """Resolves the S3 locations of a JumpStart model's public data files."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Records the region, model specs and Studio specs, and looks up the
        JumpStart content bucket for the region.
        """
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        # Necessary for SDK - Studio metadata drift
        self.studio_specs = studio_specs

    def get_bucket_name(self) -> str:
        """Returns the JumpStart content bucket resolved at construction time."""
        return self._bucket

    def get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's inference artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_training_artifact_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's training artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_inference_script_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's inference script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(script_uri)

    def get_training_script_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's training script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(script_uri)

    def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 directory that holds the model's training datasets."""
        bucket_name = self.get_bucket_name()
        dataset_prefix = self._get_training_dataset_prefix()
        return S3ObjectLocation(bucket_name, dataset_prefix)

    def _get_training_dataset_prefix(self) -> str:
        """Reads the training dataset prefix from ``studio_specs["defaultDataKey"]``."""
        return self.studio_specs["defaultDataKey"]

    def get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's demo Jupyter notebook."""
        model_framework = get_model_framework(self.model_specs)
        notebook_key = f"{model_framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        return S3ObjectLocation(self.get_bucket_name(), notebook_key)

    def get_markdown_s3_reference(self) -> S3ObjectLocation:
        """Returns the S3 location of the model's markdown file."""
        model_framework = get_model_framework(self.model_specs)
        markdown_key = f"{model_framework}-metadata/{self.model_specs.model_id}.md"
        return S3ObjectLocation(self.get_bucket_name(), markdown_key)

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Looks up the script URI for the given scope via ``script_uris.retrieve``.

        Vulnerable and deprecated models are tolerated in the lookup.
        """
        specs = self.model_specs
        return script_uris.retrieve(
            region=self._region,
            model_id=specs.model_id,
            model_version=specs.version,
            script_scope=model_scope,
            tolerate_vulnerable_model=True,
            tolerate_deprecated_model=True,
        )

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Looks up the artifact URI for the given scope via ``model_uris.retrieve``.

        Vulnerable and deprecated models are tolerated in the lookup.
        """
        specs = self.model_specs
        return model_uris.retrieve(
            region=self._region,
            model_id=specs.model_id,
            model_version=specs.version,
            model_scope=model_scope,
            tolerate_vulnerable_model=True,
            tolerate_deprecated_model=True,
        )
Oops, something went wrong.