Skip to content

Commit

Permalink
feat: Added update for model package (aws#4309)
Browse files Browse the repository at this point in the history
Co-authored-by: Keshav Chandak <chakesh@amazon.com>
  • Loading branch information
keshav-chandak and Keshav Chandak authored Dec 15, 2023
1 parent 53b9471 commit 12d5040
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 13 deletions.
102 changes: 95 additions & 7 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
)
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.enums import EndpointType
from sagemaker.session import get_add_model_package_inference_args

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -485,12 +486,6 @@ def register(
if response_types is not None:
self.response_types = response_types

if self.content_types is None:
raise ValueError("The supported MIME types for the input data is not set")

if self.response_types is None:
raise ValueError("The supported MIME types for the output data is not set")

if image_uri is not None:
self.image_uri = image_uri

Expand Down Expand Up @@ -2181,7 +2176,7 @@ def update_approval_status(self, approval_status, approval_description=None):
"""Update the approval status for the model package
Args:
approval_status (str or PipelineVariable): Model Approval Status, values can be
approval_status (str): Model Approval Status, values can be
"Approved", "Rejected", or "PendingManualApproval".
approval_description (str): Optional. Description for the approval status of the model
(default: None).
Expand All @@ -2202,3 +2197,96 @@ def update_approval_status(self, approval_status, approval_description=None):
update_approval_args["ApprovalDescription"] = approval_description

sagemaker_session.sagemaker_client.update_model_package(**update_approval_args)

def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]):
    """Send customer metadata properties for this model package.

    Issues a single ``update_model_package`` call carrying this package's ARN
    and the supplied properties.

    Args:
        customer_metadata_properties (dict[str, str]):
            A dictionary of key-value paired metadata properties.
    """
    request = dict(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataProperties=customer_metadata_properties,
    )

    # A default session is only created when none is attached to the package.
    client = (self.sagemaker_session or sagemaker.Session()).sagemaker_client
    client.update_model_package(**request)

def remove_customer_metadata_properties(
    self, customer_metadata_properties_to_remove: List[str]
):
    """Request removal of the given keys from the customer metadata properties.

    Issues a single ``update_model_package`` call carrying this package's ARN
    and the list of keys to remove.

    Args:
        customer_metadata_properties_to_remove (list[str]):
            Keys of the customer metadata properties to remove.
    """
    request = dict(
        ModelPackageArn=self.model_package_arn,
        CustomerMetadataPropertiesToRemove=customer_metadata_properties_to_remove,
    )

    # A default session is only created when none is attached to the package.
    client = (self.sagemaker_session or sagemaker.Session()).sagemaker_client
    client.update_model_package(**request)

def add_inference_specification(
    self,
    name: str,
    containers: Dict = None,
    image_uris: List[str] = None,
    description: str = None,
    content_types: List[str] = None,
    response_types: List[str] = None,
    inference_instances: List[str] = None,
    transform_instances: List[str] = None,
):
    """Add an additional inference specification to the model package.

    Exactly one of ``containers`` or ``image_uris`` must be supplied, and the
    one supplied must not be empty.

    Args:
        name (str): Name to identify the additional inference specification.
        containers: Container definitions, forwarded unchanged as the
            ``Containers`` value of the specification
            (for example ``[{"Image": "<ecr-image-uri>"}]``).
        image_uris (List[str]): The ECR paths where inference code is stored.
            Each URI is converted into an ``{"Image": uri}`` container definition.
        description (str): Description for the additional inference specification.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).

    Raises:
        ValueError: If both ``containers`` and ``image_uris`` are given, or if
            neither provides at least one entry.
    """
    # Validate before touching the session so bad input never triggers the
    # creation of a default Session.
    if containers is not None and image_uris is not None:
        raise ValueError("Cannot have both containers and image_uris.")
    # Truthiness (not ``is None``) so that an empty list is rejected here instead
    # of being forwarded as ``containers=None`` / ``[]`` to the request builder.
    if not containers and not image_uris:
        raise ValueError("Should have either containers or image_uris for inference.")

    if image_uris:
        container_def = [{"Image": uri} for uri in image_uris]
    else:
        container_def = containers

    sagemaker_session = self.sagemaker_session or sagemaker.Session()

    model_package_update_args = get_add_model_package_inference_args(
        model_package_arn=self.model_package_arn,
        name=name,
        containers=container_def,
        content_types=content_types,
        description=description,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
    )

    sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args)
88 changes: 82 additions & 6 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6557,15 +6557,21 @@ def get_create_model_package_request(
if task is not None:
request_dict["Task"] = task
if containers is not None:
if not all([content_types, response_types]):
raise ValueError(
"content_types and response_types " "must be provided if containers is present."
)
inference_specification = {
"Containers": containers,
"SupportedContentTypes": content_types,
"SupportedResponseMIMETypes": response_types,
}
if content_types is not None:
inference_specification.update(
{
"SupportedContentTypes": content_types,
}
)
if response_types is not None:
inference_specification.update(
{
"SupportedResponseMIMETypes": response_types,
}
)
if model_package_group_name is not None:
if inference_instances is not None:
inference_specification.update(
Expand Down Expand Up @@ -6598,6 +6604,76 @@ def get_create_model_package_request(
return request_dict


def get_add_model_package_inference_args(
    model_package_arn,
    name,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    description=None,
):
    """Get request dictionary for UpdateModelPackage API for additional inference.

    Only arguments that are not ``None`` are written into the inference
    specification.

    Args:
        model_package_arn (str): Arn for the model package.
        name (str): Name to identify the additional inference specification.
        containers: Container definitions, stored unchanged under ``Containers``.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).
        description (str): Description for the additional inference specification.

    Returns:
        dict: Request with ``AdditionalInferenceSpecificationsToAdd`` (a one-element
        list) and ``ModelPackageArn``.
    """
    # NOTE(review): the nesting of the original body is not visible in the
    # flattened source; this assumes the whole request is built only when
    # containers is provided and an empty dict is returned otherwise -- confirm.
    if containers is None:
        return {}

    # (request key, value) pairs in the order they are added to the specification.
    optional_fields = (
        ("Name", name),
        ("Description", description),
        ("SupportedContentTypes", content_types),
        ("SupportedResponseMIMETypes", response_types),
        ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
        ("SupportedTransformInstanceTypes", transform_instances),
    )

    spec = {"Containers": containers}
    for key, value in optional_fields:
        if value is not None:
            spec[key] = value

    return {
        "AdditionalInferenceSpecificationsToAdd": [spec],
        "ModelPackageArn": model_package_arn,
    }


def update_args(args: Dict[str, Any], **kwargs):
"""Updates the request arguments dict with the value if populated.
Expand Down
43 changes: 43 additions & 0 deletions tests/integ/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR
from sagemaker.xgboost import XGBoostModel
from sagemaker import image_uris

_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")

Expand Down Expand Up @@ -61,3 +62,45 @@ def test_update_approval_model_package(sagemaker_session):
sagemaker_session.sagemaker_client.delete_model_package_group(
ModelPackageGroupName=model_group_name
)


def test_inference_specification_addition(sagemaker_session):
    """Register an XGBoost model package and add one extra inference specification.

    Cleanup runs in a ``finally`` block so the model package and its group are
    deleted even when registration, the update, or an assertion fails.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    model_group_name = unique_name_from_base("test-model-group")

    sagemaker_client.create_model_package_group(ModelPackageGroupName=model_group_name)

    # Stays None until registration succeeds, so cleanup knows what exists.
    model_package = None
    try:
        xgb_model_data_s3 = sagemaker_session.upload_data(
            path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
            key_prefix="integ-test-data/xgboost/model",
        )
        model = XGBoostModel(
            model_data=xgb_model_data_s3,
            framework_version="1.3-1",
            sagemaker_session=sagemaker_session,
        )

        model_package = model.register(
            content_types=["text/csv"],
            response_types=["text/csv"],
            inference_instances=["ml.m5.large"],
            transform_instances=["ml.m5.large"],
            model_package_group_name=model_group_name,
        )

        xgb_image = image_uris.retrieve(
            "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
        )
        model_package.add_inference_specification(image_uris=[xgb_image], name="Inference")
        desc_model_package = sagemaker_client.describe_model_package(
            ModelPackageName=model_package.model_package_arn
        )
        assert len(desc_model_package["AdditionalInferenceSpecifications"]) == 1
        assert desc_model_package["AdditionalInferenceSpecifications"][0]["Name"] == "Inference"
    finally:
        # The package is deleted before the group that contains it, as in the
        # original cleanup order.
        if model_package is not None:
            sagemaker_client.delete_model_package(
                ModelPackageName=model_package.model_package_arn
            )
        sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_group_name)
73 changes: 73 additions & 0 deletions tests/unit/sagemaker/model/test_model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,76 @@ def test_model_package_auto_approve_on_deploy(update_approval_status, sagemaker_
update_approval_status.call_args_list[0][1]["approval_status"]
== ModelApprovalStatusEnum.APPROVED
)


def test_update_customer_metadata(sagemaker_session):
    """update_customer_metadata forwards the ARN and properties to update_model_package."""
    metadata = {"Key": "Value"}
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.update_customer_metadata(customer_metadata_properties=metadata)

    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataProperties=metadata,
    )


def test_remove_customer_metadata(sagemaker_session):
    """remove_customer_metadata_properties forwards the ARN and keys to update_model_package."""
    keys_to_remove = ["Key"]
    package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    package.remove_customer_metadata_properties(
        customer_metadata_properties_to_remove=keys_to_remove
    )

    update_call = sagemaker_session.sagemaker_client.update_model_package
    update_call.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        CustomerMetadataPropertiesToRemove=keys_to_remove,
    )


def test_add_inference_specification(sagemaker_session):
    """Check add_inference_specification validation and the request it sends.

    Both invalid argument combinations must raise ``ValueError``; the ``else``
    branches make the test fail if no exception is raised, instead of passing
    silently.
    """
    model_package = ModelPackage(
        role="role",
        model_package_arn=MODEL_PACKAGE_VERSIONED_ARN,
        sagemaker_session=sagemaker_session,
    )

    image_uris = ["image_uri"]

    containers = [{"Image": "image_uri"}]

    try:
        model_package.add_inference_specification(
            image_uris=image_uris, name="Inference", containers=containers
        )
    except ValueError as ve:
        assert "Cannot have both containers and image_uris." in str(ve)
    else:
        raise AssertionError("Expected ValueError when both containers and image_uris are given")

    try:
        model_package.add_inference_specification(name="Inference")
    except ValueError as ve:
        assert "Should have either containers or image_uris for inference." in str(ve)
    else:
        raise AssertionError("Expected ValueError when neither containers nor image_uris given")

    model_package.add_inference_specification(image_uris=image_uris, name="Inference")

    sagemaker_session.sagemaker_client.update_model_package.assert_called_with(
        ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN,
        AdditionalInferenceSpecificationsToAdd=[
            {
                "Containers": [{"Image": "image_uri"}],
                "Name": "Inference",
            }
        ],
    )

0 comments on commit 12d5040

Please sign in to comment.