Skip to content

Commit

Permalink
feat: hub service sm session (aws#4422)
Browse files Browse the repository at this point in the history
  • Loading branch information
evakravi authored and bencrabtree committed Feb 27, 2024
1 parent c799d1a commit b774aad
Show file tree
Hide file tree
Showing 2 changed files with 428 additions and 1 deletion.
265 changes: 264 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6213,7 +6213,7 @@ def _intercept_create_request(
self,
request: typing.Dict,
create,
func_name: str = None
func_name: str = None,
# pylint: disable=unused-argument
):
"""This function intercepts the create job request.
Expand Down Expand Up @@ -6476,6 +6476,269 @@ def wait_for_inference_recommendations_job(
_check_job_status(job_name, desc, "Status")
return desc

def create_hub(
self,
hub_name: str,
hub_description: str,
hub_display_name: str = None,
hub_search_keywords: List[str] = None,
s3_storage_config: Dict[str, Any] = None,
tags: List[Dict[str, Any]] = None,
) -> Dict[str, str]:
"""Creates a SageMaker Hub
Args:
hub_name (str): The name of the Hub to create.
hub_description (str): A description of the Hub.
hub_display_name (str): The display name of the Hub.
hub_search_keywords (list): The searchable keywords for the Hub.
s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub.
tags (list): Any tags to associate with the Hub.
Returns:
(dict): Return value from the ``CreateHub`` API.
"""
request = {"HubName": hub_name, "HubDescription": hub_description}
if hub_display_name:
request["HubDisplayName"] = hub_display_name
if hub_search_keywords:
request["HubSearchKeywords"] = hub_search_keywords
if s3_storage_config:
request["S3StorageConfig"] = s3_storage_config
if tags:
request["Tags"] = tags

return self.sagemaker_client.create_hub(**request)

def describe_hub(self, hub_name: str) -> Dict[str, Any]:
"""Describes a SageMaker Hub
Args:
hub_name (str): The name of the hub to describe.
Returns:
(dict): Return value for ``DescribeHub`` API
"""
request = {"HubName": hub_name}

return self.sagemaker_client.describe_hub(**request)

def list_hubs(
self,
creation_time_after: str = None,
creation_time_before: str = None,
max_results: int = None,
max_schema_version: str = None,
name_contains: str = None,
next_token: str = None,
sort_by: str = None,
sort_order: str = None,
) -> Dict[str, Any]:
"""Lists all existing SageMaker Hubs
Args:
creation_time_after (str): Only list HubContent that was created after
the time specified.
creation_time_before (str): Only list HubContent that was created
before the time specified.
max_results (int): The maximum amount of HubContent to list.
max_schema_version (str): The upper bound of the HubContentSchemaVersion.
name_contains (str): Only list HubContent if the name contains the specified string.
next_token (str): If the response to a previous ``ListHubContents`` request was
truncated, the response includes a ``NextToken``. To retrieve the next set of
hub content, use the token in the next request.
sort_by (str): Sort HubContent versions by either name or creation time.
sort_order (str): Sort Hubs by ascending or descending order.
Returns:
(dict): Return value for ``ListHubs`` API
"""
request = {}
if creation_time_after:
request["CreationTimeAfter"] = creation_time_after
if creation_time_before:
request["CreationTimeBefore"] = creation_time_before
if max_results:
request["MaxResults"] = max_results
if max_schema_version:
request["MaxSchemaVersion"] = max_schema_version
if name_contains:
request["NameContains"] = name_contains
if next_token:
request["NextToken"] = next_token
if sort_by:
request["SortBy"] = sort_by
if sort_order:
request["SortOrder"] = sort_order

return self.sagemaker_client.list_hubs(**request)

def list_hub_contents(
self,
hub_name: str,
hub_content_type: str,
creation_time_after: str = None,
creation_time_before: str = None,
max_results: int = None,
max_schema_version: str = None,
name_contains: str = None,
next_token: str = None,
sort_by: str = None,
sort_order: str = None,
) -> Dict[str, Any]:
"""Lists the HubContents in a SageMaker Hub
Args:
hub_name (str): The name of the Hub to list the contents of.
hub_content_type (str): The type of the HubContent to list.
creation_time_after (str): Only list HubContent that was created after the
time specified.
creation_time_before (str): Only list HubContent that was created before the
time specified.
max_results (int): The maximum amount of HubContent to list.
max_schema_version (str): The upper bound of the HubContentSchemaVersion.
name_contains (str): Only list HubContent if the name contains the specified string.
next_token (str): If the response to a previous ``ListHubContents`` request was
truncated, the response includes a ``NextToken``. To retrieve the next set of
hub content, use the token in the next request.
sort_by (str): Sort HubContent versions by either name or creation time.
sort_order (str): Sort Hubs by ascending or descending order.
Returns:
(dict): Return value for ``ListHubContents`` API
"""
request = {"HubName": hub_name, "HubContentType": hub_content_type}
if creation_time_after:
request["CreationTimeAfter"] = creation_time_after
if creation_time_before:
request["CreationTimeBefore"] = creation_time_before
if max_results:
request["MaxResults"] = max_results
if max_schema_version:
request["MaxSchemaVersion"] = max_schema_version
if name_contains:
request["NameContains"] = name_contains
if next_token:
request["NextToken"] = next_token
if sort_by:
request["SortBy"] = sort_by
if sort_order:
request["SortOrder"] = sort_order

return self.sagemaker_client.list_hub_contents(**request)

def delete_hub(self, hub_name: str) -> None:
"""Deletes a SageMaker Hub
Args:
hub_name (str): The name of the hub to delete.
"""
request = {"HubName": hub_name}

return self.sagemaker_client.delete_hub(**request)

def import_hub_content(
self,
document_schema_version: str,
hub_content_name: str,
hub_content_type: str,
hub_name: str,
hub_content_document: str,
hub_content_display_name: str = None,
hub_content_description: str = None,
hub_content_version: str = None,
hub_content_markdown: str = None,
hub_content_search_keywords: List[str] = None,
tags: List[Dict[str, Any]] = None,
) -> Dict[str, str]:
"""Imports a new HubContent into a SageMaker Hub
Args:
document_schema_version (str): The version of the HubContent schema to import.
hub_content_name (str): The name of the HubContent to import.
hub_content_version (str): The version of the HubContent to import.
hub_content_type (str): The type of HubContent to import.
hub_name (str): The name of the Hub to import content to.
hub_content_document (str): The hub content document that describes information
about the hub content such as type, associated containers, scripts, and more.
hub_content_display_name (str): The display name of the HubContent to import.
hub_content_description (str): The description of the HubContent to import.
hub_content_markdown (str): A string that provides a description of the HubContent.
This string can include links, tables, and standard markdown formatting.
hub_content_search_keywords (list): The searchable keywords of the HubContent.
tags (list): Any tags associated with the HubContent.
Returns:
(dict): Return value for ``ImportHubContent`` API
"""
request = {
"DocumentSchemaVersion": document_schema_version,
"HubContentName": hub_content_name,
"HubContentType": hub_content_type,
"HubName": hub_name,
"HubContentDocument": hub_content_document,
}
if hub_content_display_name:
request["HubContentDisplayName"] = hub_content_display_name
if hub_content_description:
request["HubContentDescription"] = hub_content_description
if hub_content_version:
request["HubContentVersion"] = hub_content_version
if hub_content_markdown:
request["HubContentMarkdown"] = hub_content_markdown
if hub_content_search_keywords:
request["HubContentSearchKeywords"] = hub_content_search_keywords
if tags:
request["Tags"] = tags

return self.sagemaker_client.import_hub_content(**request)

def describe_hub_content(
self,
hub_content_name: str,
hub_content_type: str,
hub_name: str,
hub_content_version: str = None,
) -> Dict[str, Any]:
"""Describes a HubContent in a SageMaker Hub
Args:
hub_content_name (str): The name of the HubContent to describe.
hub_content_type (str): The type of HubContent in the Hub.
hub_name (str): The name of the Hub that contains the HubContent to describe.
hub_content_version (str): The version of the HubContent to describe
Returns:
(dict): Return value for ``DescribeHubContent`` API
"""
request = {
"HubContentName": hub_content_name,
"HubContentType": hub_content_type,
"HubName": hub_name,
}
if hub_content_version:
request["HubContentVersion"] = hub_content_version

return self.sagemaker_client.describe_hub_content(**request)

def delete_hub_content(
self, hub_content_name: str, hub_content_version: str, hub_content_type: str, hub_name: str
) -> None:
"""Deletes a given HubContent in a SageMaker Hub
Args:
hub_content_name (str): The name of the content thatyou want to delete from a Hub.
hub_content_version (str): The version of the content that you want to delete from
a Hub.
hub_content_type (str): The type of the content that you want to delete from a Hub.
hub_name (str): The name of the Hub that you want to delete content in.
"""
request = {
"HubContentName": hub_content_name,
"HubContentType": hub_content_type,
"HubName": hub_name,
"HubContentVersion": hub_content_version,
}

return self.sagemaker_client.delete_hub_content(**request)


def get_model_package_args(
content_types=None,
Expand Down
Loading

0 comments on commit b774aad

Please sign in to comment.