Skip to content

Commit

Permalink
feat: Logic to detect hardware GPU count and aggregate GPU memory siz…
Browse files Browse the repository at this point in the history
…e in MiB (#4389)

* Add logic to detect hardware GPU count and aggregate GPU memory size in MiB

* Fix all formatting

* Addressed PR review comments

* Addressed PR Review messages

* Addressed PR Review Messages

* Addressed PR Review comments

* Addressed PR Review Comments

* Add integration tests

* Add config

* Fix integration tests

* Include Instance Types GPU infor Config files

* Addressed PR review comments

* Fix unit tests

* Fix unit test: 'Mock' object is not subscriptable

---------

Co-authored-by: Jonathan Makunga <makung@amazon.com>
  • Loading branch information
makungaj1 and Jonathan Makunga authored Jan 30, 2024
1 parent fc11ace commit 427dec6
Show file tree
Hide file tree
Showing 6 changed files with 1,107 additions and 0 deletions.
782 changes: 782 additions & 0 deletions src/sagemaker/image_uri_config/instance_gpu_info.json

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/sagemaker/instance_types_gpu_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve instance types GPU info."""
from __future__ import absolute_import

import json
import os
from typing import Dict


def retrieve(region: str) -> Dict[str, Dict[str, int]]:
"""Retrieves instance types GPU info of the given region.
Args:
region (str): The AWS region.
Returns:
dict[str, dict[str, int]]: A dictionary that contains instance types as keys
and GPU info as values or empty dictionary if the
config for the given region is not found.
Raises:
ValueError: If no config found.
"""
config_path = os.path.join(
os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json"
)
try:
with open(config_path) as f:
instance_types_gpu_info_config = json.load(f)
return instance_types_gpu_info_config.get(region, {})
except FileNotFoundError:
raise ValueError("Could not find instance types gpu info.")
110 changes: 110 additions & 0 deletions src/sagemaker/serve/utils/hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance"""
from __future__ import absolute_import

import logging
from typing import Tuple

from botocore.exceptions import ClientError

from sagemaker import Session
from sagemaker import instance_types_gpu_info

logger = logging.getLogger(__name__)


def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
"""Get GPU info for the provided instance
Args:
instance_type (str)
session: The session to use.
Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0,
and aggregate memory size in MiB at index 1.
Raises:
ValueError: If The given instance type does not exist or GPU is not enabled.
"""
ec2_client = session.boto_session.client("ec2")
ec2_instance = _format_instance_type(instance_type)

try:
instance_info = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance]).get(
"InstanceTypes"
)[0]
except ClientError:
raise ValueError(f"Provided instance_type is not GPU enabled: [#{ec2_instance}]")

if instance_info is not None:
gpus_info = instance_info.get("GpuInfo")
if gpus_info is not None:
gpus = gpus_info.get("Gpus")
if gpus is not None and len(gpus) > 0:
count = gpus[0].get("Count")
total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB")
if count and total_gpu_memory_in_mib:
instance_gpu_info = (
count,
total_gpu_memory_in_mib,
)
logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info)
return instance_gpu_info

raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")


def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]:
"""Get GPU info for the provided from the config
Args:
instance_type (str):
region: The AWS region.
Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0,
and aggregate memory size in MiB at index 1.
Raises:
ValueError: If The given instance type does not exist.
"""
instance_types_gpu_info_config = instance_types_gpu_info.retrieve(region)
fallback_instance_gpu_info = instance_types_gpu_info_config.get(instance_type)

ec2_instance = _format_instance_type(instance_type)
if fallback_instance_gpu_info is None:
raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")

fallback_instance_gpu_info = (
fallback_instance_gpu_info.get("Count"),
fallback_instance_gpu_info.get("TotalGpuMemoryInMiB"),
)
logger.info("GPU Info [%s]: %s", ec2_instance, fallback_instance_gpu_info)
return fallback_instance_gpu_info


def _format_instance_type(instance_type: str) -> str:
"""Formats provided instance type name
Args:
instance_type (str):
Returns: formatted instance type.
"""
split_instance = instance_type.split(".")

if len(split_instance) > 2:
split_instance.pop(0)

ec2_instance = ".".join(split_instance)
return ec2_instance
44 changes: 44 additions & 0 deletions tests/integ/sagemaker/serve/utils/test_hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from sagemaker.serve.utils import hardware_detector

REGION = "us-west-2"
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)


def test_get_gpu_info_success(sagemaker_session):
gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)

assert gpu_info == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_throws(sagemaker_session):
with pytest.raises(ValueError):
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)


def test_get_gpu_info_fallback_success():
gpu_info = hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)

assert gpu_info == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_fallback_throws():
with pytest.raises(ValueError):
hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION)
98 changes: 98 additions & 0 deletions tests/unit/sagemaker/serve/utils/test_hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from botocore.exceptions import ClientError
import pytest

from sagemaker.serve.utils import hardware_detector

REGION = "us-west-2"
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)


def test_get_gpu_info_success(sagemaker_session, boto_session):
boto_session.client("ec2").describe_instance_types.return_value = {
"InstanceTypes": [
{
"GpuInfo": {
"Gpus": [
{
"Name": "A10G",
"Manufacturer": "NVIDIA",
"Count": 8,
"MemoryInfo": {"SizeInMiB": 24576},
}
],
"TotalGpuMemoryInMiB": 196608,
},
}
]
}

instance_gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)

boto_session.client("ec2").describe_instance_types.assert_called_once_with(
InstanceTypes=["g5.48xlarge"]
)
assert instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_throws(sagemaker_session, boto_session):
boto_session.client("ec2").describe_instance_types.return_value = {"InstanceTypes": [{}]}

with pytest.raises(ValueError):
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)


def test_get_gpu_info_describe_instance_types_throws(sagemaker_session, boto_session):
boto_session.client("ec2").describe_instance_types.side_effect = ClientError(
{
"Error": {
"Code": "InvalidInstanceType",
"Message": f"An error occurred (InvalidInstanceType) when calling the DescribeInstanceTypes "
f"operation: The following supplied instance types do not exist: [{INVALID_INSTANCE_TYPE}]",
}
},
"DescribeInstanceTypes",
)

with pytest.raises(ValueError):
hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)


def test_get_gpu_info_fallback_success():
fallback_instance_gpu_info = hardware_detector._get_gpu_info_fallback(
VALID_INSTANCE_TYPE, REGION
)

assert fallback_instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_fallback_throws():
with pytest.raises(ValueError):
hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION)


def test_format_instance_type_success():
formatted_instance_type = hardware_detector._format_instance_type(VALID_INSTANCE_TYPE)

assert formatted_instance_type == "g5.48xlarge"


def test_format_instance_type_without_ml_success():
formatted_instance_type = hardware_detector._format_instance_type("g5.48xlarge")

assert formatted_instance_type == "g5.48xlarge"
30 changes: 30 additions & 0 deletions tests/unit/sagemaker/test_instance_types_gpu_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import instance_types_gpu_info

REGION = "us-west-2"
INVALID_REGION = "invalid-region"


def test_retrieve_success():
data = instance_types_gpu_info.retrieve(REGION)

assert len(data) > 0


def test_retrieve_throws():
data = instance_types_gpu_info.retrieve(INVALID_REGION)

assert len(data) == 0

0 comments on commit 427dec6

Please sign in to comment.