-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Logic to detect hardware GPU count and aggregate GPU memory siz…
…e in MiB (#4389) * Add logic to detect hardware GPU count and aggregate GPU memory size in MiB * Fix all formatting * Addressed PR review comments * Addressed PR Review messages * Addressed PR Review Messages * Addressed PR Review comments * Addressed PR Review Comments * Add integration tests * Add config * Fix integration tests * Include Instance Types GPU infor Config files * Addressed PR review comments * Fix unit tests * Fix unit test: 'Mock' object is not subscriptable --------- Co-authored-by: Jonathan Makunga <makung@amazon.com>
- Loading branch information
Showing
6 changed files
with
1,107 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Accessors to retrieve instance types GPU info.""" | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
from typing import Dict | ||
|
||
|
||
def retrieve(region: str) -> Dict[str, Dict[str, int]]: | ||
"""Retrieves instance types GPU info of the given region. | ||
Args: | ||
region (str): The AWS region. | ||
Returns: | ||
dict[str, dict[str, int]]: A dictionary that contains instance types as keys | ||
and GPU info as values or empty dictionary if the | ||
config for the given region is not found. | ||
Raises: | ||
ValueError: If no config found. | ||
""" | ||
config_path = os.path.join( | ||
os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json" | ||
) | ||
try: | ||
with open(config_path) as f: | ||
instance_types_gpu_info_config = json.load(f) | ||
return instance_types_gpu_info_config.get(region, {}) | ||
except FileNotFoundError: | ||
raise ValueError("Could not find instance types gpu info.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance""" | ||
from __future__ import absolute_import | ||
|
||
import logging | ||
from typing import Tuple | ||
|
||
from botocore.exceptions import ClientError | ||
|
||
from sagemaker import Session | ||
from sagemaker import instance_types_gpu_info | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]: | ||
"""Get GPU info for the provided instance | ||
Args: | ||
instance_type (str) | ||
session: The session to use. | ||
Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0, | ||
and aggregate memory size in MiB at index 1. | ||
Raises: | ||
ValueError: If The given instance type does not exist or GPU is not enabled. | ||
""" | ||
ec2_client = session.boto_session.client("ec2") | ||
ec2_instance = _format_instance_type(instance_type) | ||
|
||
try: | ||
instance_info = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance]).get( | ||
"InstanceTypes" | ||
)[0] | ||
except ClientError: | ||
raise ValueError(f"Provided instance_type is not GPU enabled: [#{ec2_instance}]") | ||
|
||
if instance_info is not None: | ||
gpus_info = instance_info.get("GpuInfo") | ||
if gpus_info is not None: | ||
gpus = gpus_info.get("Gpus") | ||
if gpus is not None and len(gpus) > 0: | ||
count = gpus[0].get("Count") | ||
total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB") | ||
if count and total_gpu_memory_in_mib: | ||
instance_gpu_info = ( | ||
count, | ||
total_gpu_memory_in_mib, | ||
) | ||
logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info) | ||
return instance_gpu_info | ||
|
||
raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]") | ||
|
||
|
||
def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]: | ||
"""Get GPU info for the provided from the config | ||
Args: | ||
instance_type (str): | ||
region: The AWS region. | ||
Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0, | ||
and aggregate memory size in MiB at index 1. | ||
Raises: | ||
ValueError: If The given instance type does not exist. | ||
""" | ||
instance_types_gpu_info_config = instance_types_gpu_info.retrieve(region) | ||
fallback_instance_gpu_info = instance_types_gpu_info_config.get(instance_type) | ||
|
||
ec2_instance = _format_instance_type(instance_type) | ||
if fallback_instance_gpu_info is None: | ||
raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]") | ||
|
||
fallback_instance_gpu_info = ( | ||
fallback_instance_gpu_info.get("Count"), | ||
fallback_instance_gpu_info.get("TotalGpuMemoryInMiB"), | ||
) | ||
logger.info("GPU Info [%s]: %s", ec2_instance, fallback_instance_gpu_info) | ||
return fallback_instance_gpu_info | ||
|
||
|
||
def _format_instance_type(instance_type: str) -> str: | ||
"""Formats provided instance type name | ||
Args: | ||
instance_type (str): | ||
Returns: formatted instance type. | ||
""" | ||
split_instance = instance_type.split(".") | ||
|
||
if len(split_instance) > 2: | ||
split_instance.pop(0) | ||
|
||
ec2_instance = ".".join(split_instance) | ||
return ec2_instance |
44 changes: 44 additions & 0 deletions
44
tests/integ/sagemaker/serve/utils/test_hardware_detector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import pytest | ||
|
||
from sagemaker.serve.utils import hardware_detector | ||
|
||
REGION = "us-west-2" | ||
VALID_INSTANCE_TYPE = "ml.g5.48xlarge" | ||
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" | ||
EXPECTED_INSTANCE_GPU_INFO = (8, 196608) | ||
|
||
|
||
def test_get_gpu_info_success(sagemaker_session): | ||
gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session) | ||
|
||
assert gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
||
|
||
def test_get_gpu_info_throws(sagemaker_session): | ||
with pytest.raises(ValueError): | ||
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
||
|
||
def test_get_gpu_info_fallback_success(): | ||
gpu_info = hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION) | ||
|
||
assert gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
||
|
||
def test_get_gpu_info_fallback_throws(): | ||
with pytest.raises(ValueError): | ||
hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION) |
98 changes: 98 additions & 0 deletions
98
tests/unit/sagemaker/serve/utils/test_hardware_detector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
from botocore.exceptions import ClientError | ||
import pytest | ||
|
||
from sagemaker.serve.utils import hardware_detector | ||
|
||
REGION = "us-west-2" | ||
VALID_INSTANCE_TYPE = "ml.g5.48xlarge" | ||
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" | ||
EXPECTED_INSTANCE_GPU_INFO = (8, 196608) | ||
|
||
|
||
def test_get_gpu_info_success(sagemaker_session, boto_session): | ||
boto_session.client("ec2").describe_instance_types.return_value = { | ||
"InstanceTypes": [ | ||
{ | ||
"GpuInfo": { | ||
"Gpus": [ | ||
{ | ||
"Name": "A10G", | ||
"Manufacturer": "NVIDIA", | ||
"Count": 8, | ||
"MemoryInfo": {"SizeInMiB": 24576}, | ||
} | ||
], | ||
"TotalGpuMemoryInMiB": 196608, | ||
}, | ||
} | ||
] | ||
} | ||
|
||
instance_gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session) | ||
|
||
boto_session.client("ec2").describe_instance_types.assert_called_once_with( | ||
InstanceTypes=["g5.48xlarge"] | ||
) | ||
assert instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
||
|
||
def test_get_gpu_info_throws(sagemaker_session, boto_session): | ||
boto_session.client("ec2").describe_instance_types.return_value = {"InstanceTypes": [{}]} | ||
|
||
with pytest.raises(ValueError): | ||
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
||
|
||
def test_get_gpu_info_describe_instance_types_throws(sagemaker_session, boto_session): | ||
boto_session.client("ec2").describe_instance_types.side_effect = ClientError( | ||
{ | ||
"Error": { | ||
"Code": "InvalidInstanceType", | ||
"Message": f"An error occurred (InvalidInstanceType) when calling the DescribeInstanceTypes " | ||
f"operation: The following supplied instance types do not exist: [{INVALID_INSTANCE_TYPE}]", | ||
} | ||
}, | ||
"DescribeInstanceTypes", | ||
) | ||
|
||
with pytest.raises(ValueError): | ||
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
||
|
||
def test_get_gpu_info_fallback_success(): | ||
fallback_instance_gpu_info = hardware_detector._get_gpu_info_fallback( | ||
VALID_INSTANCE_TYPE, REGION | ||
) | ||
|
||
assert fallback_instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
||
|
||
def test_get_gpu_info_fallback_throws(): | ||
with pytest.raises(ValueError): | ||
hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION) | ||
|
||
|
||
def test_format_instance_type_success(): | ||
formatted_instance_type = hardware_detector._format_instance_type(VALID_INSTANCE_TYPE) | ||
|
||
assert formatted_instance_type == "g5.48xlarge" | ||
|
||
|
||
def test_format_instance_type_without_ml_success(): | ||
formatted_instance_type = hardware_detector._format_instance_type("g5.48xlarge") | ||
|
||
assert formatted_instance_type == "g5.48xlarge" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
from sagemaker import instance_types_gpu_info | ||
|
||
REGION = "us-west-2" | ||
INVALID_REGION = "invalid-region" | ||
|
||
|
||
def test_retrieve_success(): | ||
data = instance_types_gpu_info.retrieve(REGION) | ||
|
||
assert len(data) > 0 | ||
|
||
|
||
def test_retrieve_throws(): | ||
data = instance_types_gpu_info.retrieve(INVALID_REGION) | ||
|
||
assert len(data) == 0 |