Skip to content

Commit

Permalink
Merge branch 'master' into 2PR
Browse files Browse the repository at this point in the history
  • Loading branch information
samruds authored Mar 5, 2024
2 parents 8abf8a9 + 790bd87 commit de2a0e7
Show file tree
Hide file tree
Showing 15 changed files with 240 additions and 35 deletions.
18 changes: 16 additions & 2 deletions doc/api/prep_data/feature_store.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Feature Definition
:members:
:show-inheritance:


Inputs
******

Expand Down Expand Up @@ -181,9 +182,13 @@ Feature Processor Data Source
:members:
:show-inheritance:

.. autoclass:: sagemaker.feature_store.feature_processor.PySparkDataSource
:members:
:show-inheritance:

Feature Processor Scheduler
***************************

Feature Processor Scheduler and Triggers
****************************************

.. automethod:: sagemaker.feature_store.feature_processor.to_pipeline

Expand All @@ -196,3 +201,12 @@ Feature Processor Scheduler
.. automethod:: sagemaker.feature_store.feature_processor.describe

.. automethod:: sagemaker.feature_store.feature_processor.list_pipelines

.. automethod:: sagemaker.feature_store.feature_processor.put_trigger

.. automethod:: sagemaker.feature_store.feature_processor.enable_trigger

.. automethod:: sagemaker.feature_store.feature_processor.disable_trigger

.. automethod:: sagemaker.feature_store.feature_processor.delete_trigger

2 changes: 1 addition & 1 deletion requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ awslogs==0.14.0
black==22.3.0
stopit==1.1.2
# Update tox.ini to have correct version of airflow constraints file
apache-airflow==2.8.1
apache-airflow==2.8.2
apache-airflow-providers-amazon==7.2.1
attrs>=23.1.0,<24
fabric==2.6.0
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def _load_config_from_file(file_path: str) -> dict:
f"Provide a valid file path"
)
logger.debug("Fetching defaults config from location: %s", file_path)
return yaml.safe_load(open(inferred_file_path, "r"))
with open(inferred_file_path, "r") as f:
content = yaml.safe_load(f)
return content


def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:
Expand Down
66 changes: 52 additions & 14 deletions src/sagemaker/jumpstart/artifacts/resource_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""This module contains functions for obtaining JumpStart resource requirements."""
from __future__ import absolute_import

from typing import Optional
from typing import Dict, Optional, Tuple

from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -28,6 +28,20 @@
from sagemaker.session import Session
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

# Lookup table used when building ``ResourceRequirements`` from a model's hosting specs.
#
# Outer key: the requirement type ("requests" or "limits"); each one becomes a keyword
# argument of ``ResourceRequirements``.
# Inner key: the field name as it appears in the model specs' resource requirements.
# Inner value: a 2-tuple of names. Element 0 is the key written into the requests/limits
# dict for that field. Element 1 is not read by the code visible in this module --
# presumably the matching ``ResourceRequirements`` attribute name; TODO confirm.
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
    str, Dict[str, Tuple[str, str]]
] = {
    "requests": {
        "num_accelerators": ("num_accelerators", "num_accelerators"),
        "num_cpus": ("num_cpus", "num_cpus"),
        "copies": ("copies", "copy_count"),
        "min_memory_mb": ("memory", "min_memory"),
    },
    "limits": {
        "max_memory_mb": ("memory", "max_memory"),
    },
}


def _retrieve_default_resources(
model_id: str,
Expand All @@ -37,6 +51,7 @@ def _retrieve_default_resources(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model.
Expand All @@ -60,6 +75,8 @@ def _retrieve_default_resources(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
Returns:
ResourceRequirements: The default resource requirements to use for the model or None.
Expand Down Expand Up @@ -87,23 +104,44 @@ def _retrieve_default_resources(
is_dynamic_container_deployment_supported = (
model_specs.dynamic_container_deployment_supported
)
default_resource_requirements = model_specs.hosting_resource_requirements
default_resource_requirements: Dict[str, int] = (
model_specs.hosting_resource_requirements or {}
)
else:
raise NotImplementedError(
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
)

instance_specific_resource_requirements: Dict[str, int] = (
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
instance_type
)
if instance_type
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
else {}
)

default_resource_requirements = {
**default_resource_requirements,
**instance_specific_resource_requirements,
}

if is_dynamic_container_deployment_supported:
requests = {}
if "num_accelerators" in default_resource_requirements:
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
if "min_memory_mb" in default_resource_requirements:
requests["memory"] = default_resource_requirements["min_memory_mb"]
if "num_cpus" in default_resource_requirements:
requests["num_cpus"] = default_resource_requirements["num_cpus"]

limits = {}
if "max_memory_mb" in default_resource_requirements:
limits["memory"] = default_resource_requirements["max_memory_mb"]
return ResourceRequirements(requests=requests, limits=limits)

all_resource_requirement_kwargs = {}

for (
requirement_type,
spec_field_to_resource_requirement_map,
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
requirement_kwargs = {}
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
if spec_field in default_resource_requirements:
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
spec_field
]

all_resource_requirement_kwargs[requirement_type] = requirement_kwargs

return ResourceRequirements(**all_resource_requirement_kwargs)
return None
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
)

return kwargs
Expand Down
23 changes: 23 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
instance_type=instance_type, property_name="artifact_key"
)

def get_instance_specific_resource_requirements(self, instance_type: str) -> dict:
    """Returns instance specific resource requirements.

    Resource requirements are looked up under
    ``variants[<key>]["properties"]["resource_requirements"]`` for both the exact
    instance type and its instance family, then merged. If a value exists for both
    the instance family and instance type, the instance type value is chosen.

    The merge is shallow: only top-level keys are combined, so a nested value from
    the instance type replaces the family's value for that key as a whole.

    Args:
        instance_type (str): Instance type to look up, for example ``ml.g5.xlarge``.

    Returns:
        dict: The merged resource requirements. An empty dict is returned when
            neither the instance type nor its family defines any.
    """

    def _requirements_for(variant_key: str) -> dict:
        # Every level falls back to an empty dict so a missing variant,
        # "properties" or "resource_requirements" entry yields {}.
        return (
            self.variants.get(variant_key, {})
            .get("properties", {})
            .get("resource_requirements", {})
        )

    instance_type_requirements = _requirements_for(instance_type)

    instance_type_family = get_instance_type_family(instance_type)
    instance_family_requirements = _requirements_for(instance_type_family)

    # Later unpacking wins: instance type values override instance family values.
    return {**instance_family_requirements, **instance_type_requirements}

def _get_instance_specific_property(
self, instance_type: str, property_name: str
) -> Optional[str]:
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,9 @@ def _create_docker_host(
# to setting --runtime=nvidia in the docker commandline.
if self.instance_type == "local_gpu":
host_config["deploy"] = {
"resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}}
"resources": {
"reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]}
}
}

if not self.is_studio and command == "serve":
Expand Down
5 changes: 0 additions & 5 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,6 @@ def __init__(
encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between
training containers is encrypted for the training job. Defaults to ``False``.
enable_network_isolation (bool): A flag that specifies whether container will run in
network isolation mode. Defaults to ``False``. Network isolation mode restricts the
container access to outside networks (such as the Internet). The container does not
make any inbound or outbound network calls. Also known as Internet-free mode.
spark_config (SparkConfig): Configurations to the Spark application that runs on
Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
will be used for training. Note that ``image_uri`` can not be specified at the
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/resource_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from typing import Optional
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
Expand All @@ -33,7 +34,8 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
instance_type: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model matching the given arguments.
Args:
Expand All @@ -56,6 +58,8 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
Returns:
ResourceRequirements: The default resource requirements to use for the model.
Expand All @@ -79,4 +83,5 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
)
23 changes: 18 additions & 5 deletions tests/unit/sagemaker/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,23 @@
@pytest.fixture()
def config_file_as_yaml(get_data_dir):
config_file_path = os.path.join(get_data_dir, "config.yaml")
return open(config_file_path, "r").read()
with open(config_file_path, "r") as f:
content = f.read()
return content


@pytest.fixture()
def expected_merged_config(get_data_dir):
expected_merged_config_file_path = os.path.join(
get_data_dir, "expected_output_config_after_merge.yaml"
)
return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
with open(expected_merged_config_file_path, "r") as f:
content = yaml.safe_load(f.read())
return content


def _raise_valueerror(*args):
raise ValueError(args)


def test_config_when_default_config_file_and_user_config_file_is_not_found():
Expand All @@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
def test_invalid_config_file_which_has_python_code(get_data_dir):
invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml")
# no exceptions will be thrown with yaml.unsafe_load
yaml.unsafe_load(open(invalid_config_file_path, "r"))
with open(invalid_config_file_path, "r") as f:
yaml.unsafe_load(f)
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
# yaml.safe_load internally
with pytest.raises(ConstructorError) as exception_info:
Expand Down Expand Up @@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
get_data_dir, expected_merged_config, s3_resource_mock
):
config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml")
config_file_as_yaml = open(config_file_content_path, "r").read()
with open(config_file_content_path, "r") as f:
config_file_as_yaml = f.read()
config_file_bucket = "config-file-bucket"
config_file_s3_prefix = "config/config.yaml"
config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix)
Expand Down Expand Up @@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config):
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)


def test_load_local_mode_config_when_config_file_is_not_found():
@patch("sagemaker.config.config._load_config_from_file", side_effect=_raise_valueerror)
def test_load_local_mode_config_when_config_file_is_not_found(mock_load_config):
# Patch is needed because one might actually have a local config file
assert load_local_mode_config() is None
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)


@pytest.mark.parametrize(
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,22 @@
"model_package_arn": "$gpu_model_package_arn",
}
},
"g5": {
"properties": {
"resource_requirements": {
"num_accelerators": 888810,
"randon-field-2": 2222,
}
}
},
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"ml.g5.xlarge": {
"properties": {
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
"resource_requirements": {"num_accelerators": 10},
}
},
"ml.g5.48xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
},
Expand All @@ -857,6 +871,12 @@
"framework_version": "1.5.0",
"py_version": "py3",
},
"dynamic_container_deployment_supported": True,
"hosting_resource_requirements": {
"min_memory_mb": 81999,
"num_accelerators": 1,
"random_field_1": 1,
},
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"variants": {
"ml.p2.12xlarge": {
"properties": {
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
"supported_inference_instance_types": ["ml.p5.xlarge"],
"default_inference_instance_type": "ml.p5.xlarge",
Expand All @@ -60,6 +61,11 @@
"p2": {
"regional_properties": {"image_uri": "$gpu_image_uri"},
"properties": {
"resource_requirements": {
"req2": {"2": 5, "9": 999},
"req3": 999,
"req4": "blah",
},
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
"default_inference_instance_type": "ml.p2.xlarge",
"metrics": [
Expand Down Expand Up @@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
)
is None
)


def test_jumpstart_resource_requirements_instance_variants():
    """Check resource requirements resolved per instance type, per family, and when absent."""
    # Expected value matches the "p2" family's resource_requirements in the fixture data
    # (presumably because the fixture has no "ml.p2.xlarge" entry of its own -- confirm).
    assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
        instance_type="ml.p2.xlarge"
    ) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"}

    # Instance type values win over family values for shared keys ("req2", "req3");
    # "req2" is replaced as a whole rather than merged key by key, and the
    # family-only key "req4" is still carried through.
    assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
        instance_type="ml.p2.12xlarge"
    ) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"}

    # An instance type with no matching type or family data yields an empty dict.
    assert (
        INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
            instance_type="ml.p99.12xlarge"
        )
        == {}
    )
Loading

0 comments on commit de2a0e7

Please sign in to comment.