Skip to content

Commit

Permalink
Merge branch 'master' into feat/instance-specific-jumpstart-host-requ…
Browse files Browse the repository at this point in the history
…irements
  • Loading branch information
evakravi authored Jan 31, 2024
2 parents 39d3fa6 + 8b206ba commit a3deb08
Show file tree
Hide file tree
Showing 9 changed files with 1,221 additions and 24 deletions.
782 changes: 782 additions & 0 deletions src/sagemaker/image_uri_config/instance_gpu_info.json

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/sagemaker/instance_types_gpu_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve instance types GPU info."""
from __future__ import absolute_import

import json
import os
from typing import Dict


def retrieve(region: str) -> Dict[str, Dict[str, int]]:
    """Retrieve instance-type GPU info for the given region.

    Args:
        region (str): The AWS region.

    Returns:
        dict[str, dict[str, int]]: A dictionary keyed by instance type with
        GPU info dictionaries as values, or an empty dictionary if the
        given region is not present in the config.

    Raises:
        ValueError: If the bundled config file cannot be found.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json"
    )
    try:
        # Explicit encoding: the bundled JSON is UTF-8; without it the open()
        # default is platform-dependent (W1514).
        with open(config_path, encoding="utf-8") as f:
            instance_types_gpu_info_config = json.load(f)
        return instance_types_gpu_info_config.get(region, {})
    except FileNotFoundError as exc:
        # Chain the original error so the missing path is visible in tracebacks.
        raise ValueError("Could not find instance types gpu info.") from exc
2 changes: 2 additions & 0 deletions src/sagemaker/model_monitor/clarify_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def create_monitoring_schedule(
self.monitoring_schedule_name = monitor_schedule_name
except Exception:
logger.exception("Failed to create monitoring schedule.")
self.monitoring_schedule_name = None
# noinspection PyBroadException
try:
self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
Expand Down Expand Up @@ -1109,6 +1110,7 @@ def create_monitoring_schedule(
self.monitoring_schedule_name = monitor_schedule_name
except Exception:
logger.exception("Failed to create monitoring schedule.")
self.monitoring_schedule_name = None
# noinspection PyBroadException
try:
self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(
Expand Down
54 changes: 30 additions & 24 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,30 +415,34 @@ def create_monitoring_schedule(
if arguments is not None:
self.arguments = arguments

self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
schedule_expression=schedule_cron_expression,
statistics_s3_uri=statistics_s3_uri,
constraints_s3_uri=constraints_s3_uri,
monitoring_inputs=[normalized_monitoring_input],
monitoring_output_config=monitoring_output_config,
instance_count=self.instance_count,
instance_type=self.instance_type,
volume_size_in_gb=self.volume_size_in_gb,
volume_kms_key=self.volume_kms_key,
image_uri=self.image_uri,
entrypoint=self.entrypoint,
arguments=self.arguments,
record_preprocessor_source_uri=None,
post_analytics_processor_source_uri=None,
max_runtime_in_seconds=self.max_runtime_in_seconds,
environment=self.env,
network_config=network_config_dict,
role_arn=self.sagemaker_session.expand_role(self.role),
tags=self.tags,
data_analysis_start_time=data_analysis_start_time,
data_analysis_end_time=data_analysis_end_time,
)
try:
self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
schedule_expression=schedule_cron_expression,
statistics_s3_uri=statistics_s3_uri,
constraints_s3_uri=constraints_s3_uri,
monitoring_inputs=[normalized_monitoring_input],
monitoring_output_config=monitoring_output_config,
instance_count=self.instance_count,
instance_type=self.instance_type,
volume_size_in_gb=self.volume_size_in_gb,
volume_kms_key=self.volume_kms_key,
image_uri=self.image_uri,
entrypoint=self.entrypoint,
arguments=self.arguments,
record_preprocessor_source_uri=None,
post_analytics_processor_source_uri=None,
max_runtime_in_seconds=self.max_runtime_in_seconds,
environment=self.env,
network_config=network_config_dict,
role_arn=self.sagemaker_session.expand_role(self.role),
tags=self.tags,
data_analysis_start_time=data_analysis_start_time,
data_analysis_end_time=data_analysis_end_time,
)
except Exception:
self.monitoring_schedule_name = None
raise

def update_monitoring_schedule(
self,
Expand Down Expand Up @@ -2054,6 +2058,7 @@ def create_monitoring_schedule(
self.monitoring_schedule_name = monitor_schedule_name
except Exception:
logger.exception("Failed to create monitoring schedule.")
self.monitoring_schedule_name = None
# noinspection PyBroadException
try:
self.sagemaker_session.sagemaker_client.delete_data_quality_job_definition(
Expand Down Expand Up @@ -3173,6 +3178,7 @@ def create_monitoring_schedule(
self.monitoring_schedule_name = monitor_schedule_name
except Exception:
logger.exception("Failed to create monitoring schedule.")
self.monitoring_schedule_name = None
# noinspection PyBroadException
try:
self.sagemaker_session.sagemaker_client.delete_model_quality_job_definition(
Expand Down
110 changes: 110 additions & 0 deletions src/sagemaker/serve/utils/hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance"""
from __future__ import absolute_import

import logging
from typing import Tuple

from botocore.exceptions import ClientError

from sagemaker import Session
from sagemaker import instance_types_gpu_info

logger = logging.getLogger(__name__)


def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
    """Get GPU info for the provided instance via the EC2 API.

    Args:
        instance_type (str): The SageMaker instance type (e.g. ``ml.g5.xlarge``).
        session: The SageMaker session whose boto session is used for the
            EC2 ``describe_instance_types`` call.

    Returns:
        tuple[int, int]: Number of GPUs available at index 0 and aggregate
        GPU memory size in MiB at index 1.

    Raises:
        ValueError: If the given instance type does not exist or is not
            GPU enabled.
    """
    ec2_client = session.boto_session.client("ec2")
    ec2_instance = _format_instance_type(instance_type)

    try:
        instance_info = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance]).get(
            "InstanceTypes"
        )[0]
    except ClientError as exc:
        # Fix: message previously read "[#{...}]" with a stray "#" (a Ruby-style
        # interpolation leftover), inconsistent with the message raised below.
        raise ValueError(
            f"Provided instance_type is not GPU enabled: [{ec2_instance}]"
        ) from exc

    if instance_info is not None:
        gpus_info = instance_info.get("GpuInfo")
        if gpus_info is not None:
            gpus = gpus_info.get("Gpus")
            if gpus is not None and len(gpus) > 0:
                count = gpus[0].get("Count")
                total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB")
                if count and total_gpu_memory_in_mib:
                    instance_gpu_info = (
                        count,
                        total_gpu_memory_in_mib,
                    )
                    logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info)
                    return instance_gpu_info

    raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")


def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]:
    """Look up GPU info for the provided instance type from the bundled config.

    Args:
        instance_type (str): The SageMaker instance type (e.g. ``ml.g5.xlarge``).
        region: The AWS region whose config section is consulted.

    Returns:
        tuple[int, int]: Number of GPUs available at index 0 and aggregate
        GPU memory size in MiB at index 1.

    Raises:
        ValueError: If the given instance type is not present in the config.
    """
    region_config = instance_types_gpu_info.retrieve(region)
    gpu_entry = region_config.get(instance_type)

    ec2_instance = _format_instance_type(instance_type)
    if gpu_entry is None:
        raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")

    gpu_info = (gpu_entry.get("Count"), gpu_entry.get("TotalGpuMemoryInMiB"))
    logger.info("GPU Info [%s]: %s", ec2_instance, gpu_info)
    return gpu_info


def _format_instance_type(instance_type: str) -> str:
"""Formats provided instance type name
Args:
instance_type (str):
Returns: formatted instance type.
"""
split_instance = instance_type.split(".")

if len(split_instance) > 2:
split_instance.pop(0)

ec2_instance = ".".join(split_instance)
return ec2_instance
44 changes: 44 additions & 0 deletions tests/integ/sagemaker/serve/utils/test_hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from sagemaker.serve.utils import hardware_detector

REGION = "us-west-2"
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)


def test_get_gpu_info_success(sagemaker_session):
    """A known GPU instance type resolves to the expected (count, memory) tuple."""
    assert (
        hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)
        == EXPECTED_INSTANCE_GPU_INFO
    )


def test_get_gpu_info_throws(sagemaker_session):
    """An unknown instance type raises ValueError from the EC2 lookup path."""
    bad_instance_type = INVALID_INSTANCE_TYPE
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(bad_instance_type, sagemaker_session)


def test_get_gpu_info_fallback_success():
    """The bundled-config fallback returns the same (count, memory) tuple."""
    assert (
        hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
        == EXPECTED_INSTANCE_GPU_INFO
    )


def test_get_gpu_info_fallback_throws():
    """An unknown instance type raises ValueError from the config fallback."""
    bad_instance_type = INVALID_INSTANCE_TYPE
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(bad_instance_type, REGION)
82 changes: 82 additions & 0 deletions tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2488,3 +2488,85 @@ def test_one_time_monitoring_schedule(sagemaker_session):
my_default_monitor.stop_monitoring_schedule()
my_default_monitor.delete_monitoring_schedule()
raise e


def test_create_monitoring_schedule_with_validation_error(sagemaker_session):
    """Verify a failed schedule creation does not block a subsequent retry.

    First attempts to create a schedule with an over-long (>63 chars) name,
    which must fail with a ValidationException, then verifies that creating
    a schedule with a valid name succeeds on the same monitor instance.
    """
    my_default_monitor = DefaultModelMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=ENVIRONMENT,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )

    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )

    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    try:
        my_default_monitor.create_monitoring_schedule(
            monitor_schedule_name="schedule-name-more-than-63-characters-to-get-a-validation-exception",
            batch_transform_input=batch_transform_input,
            output_s3_uri=output_s3_uri,
            statistics=statistics,
            constraints=constraints,
            schedule_cron_expression=CronExpressionGenerator.now(),
            data_analysis_start_time="-PT1H",
            data_analysis_end_time="-PT0H",
            enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
        )
    except Exception as e:
        assert "ValidationException" in str(e)
    else:
        # Bug fix: without this branch the test silently passed when no
        # exception was raised at all, never exercising the retry scenario.
        raise AssertionError(
            "Expected a ValidationException for the over-long schedule name"
        )

    # The retry with a valid name must succeed, proving the failed attempt
    # did not leave the monitor in a state that blocks re-creation.
    my_default_monitor.create_monitoring_schedule(
        monitor_schedule_name=unique_name_from_base("valid-schedule-name"),
        batch_transform_input=batch_transform_input,
        output_s3_uri=output_s3_uri,
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=CronExpressionGenerator.now(),
        data_analysis_start_time="-PT1H",
        data_analysis_end_time="-PT0H",
        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
    )
    try:
        _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

        my_default_monitor.stop_monitoring_schedule()
        my_default_monitor.delete_monitoring_schedule()
    except Exception as e:
        # Best-effort cleanup before re-raising so the schedule is not leaked.
        my_default_monitor.stop_monitoring_schedule()
        my_default_monitor.delete_monitoring_schedule()
        raise e
Loading

0 comments on commit a3deb08

Please sign in to comment.