Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor #4463

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -36,6 +37,8 @@ def retrieve_options(
retrieve the supported accept types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported accept types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -59,11 +62,12 @@ def retrieve_options(
)

return artifacts._retrieve_supported_accept_types(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand All @@ -72,6 +76,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -85,6 +90,8 @@ def retrieve_default(
retrieve the default accept type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default accept type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -108,10 +115,11 @@ def retrieve_default(
)

return artifacts._retrieve_default_accept_type(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
28 changes: 18 additions & 10 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -36,6 +37,8 @@ def retrieve_options(
retrieve the supported content types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported content types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -59,11 +62,12 @@ def retrieve_options(
)

return artifacts._retrieve_supported_content_types(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
Comment on lines +65 to +70
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for making kwarg style arguments!

sagemaker_session=sagemaker_session,
)

Expand All @@ -72,6 +76,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -85,6 +90,8 @@ def retrieve_default(
retrieve the default content type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default content type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -108,11 +115,12 @@ def retrieve_default(
)

return artifacts._retrieve_default_content_type(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand Down
28 changes: 18 additions & 10 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -55,6 +56,8 @@ def retrieve_options(
retrieve the supported deserializers. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported deserializers. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -79,11 +82,12 @@ def retrieve_options(
)

return artifacts._retrieve_deserializer_options(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand All @@ -92,6 +96,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -105,6 +110,8 @@ def retrieve_default(
retrieve the default deserializer. (Default: None).
model_version (str): The version of the model for which to retrieve the
default deserializer. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -129,10 +136,11 @@ def retrieve_default(
)

return artifacts._retrieve_default_deserializer(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
def _retrieve_model_init_kwargs(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -43,6 +44,8 @@ def _retrieve_model_init_kwargs(
retrieve the kwargs.
model_version (str): Version of the JumpStart model for which to retrieve the
kwargs.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve kwargs.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand All @@ -66,6 +69,7 @@ def _retrieve_model_init_kwargs(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.INFERENCE,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand All @@ -85,6 +89,7 @@ def _retrieve_model_deploy_kwargs(
model_id: str,
model_version: str,
instance_type: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -99,6 +104,8 @@ def _retrieve_model_deploy_kwargs(
kwargs.
instance_type (str): Instance type of the hosting endpoint, to determine if volume size
is supported.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve kwargs.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand All @@ -123,6 +130,7 @@ def _retrieve_model_deploy_kwargs(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.INFERENCE,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _retrieve_model_package_arn(
model_version: str,
instance_type: Optional[str],
region: Optional[str],
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -46,6 +47,8 @@ def _retrieve_model_package_arn(
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
specific for the instance type.
region (Optional[str]): Region for which to retrieve the model package arn.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
scope (Optional[str]): Scope for which to retrieve the model package arn.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
Expand All @@ -69,6 +72,7 @@ def _retrieve_model_package_arn(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
Loading