Skip to content

Commit

Permalink
feature: support remote debug for sagemaker training job (aws#4315)
Browse files Browse the repository at this point in the history
* feature: support remote debug for sagemaker training job

* change: Replace update_remote_config with 2 helper methods for enable and disable respectively

* change: add new argument enable_remote_debug to skip set of test_jumpstart_estimator_kwargs_match_parent_class

* chore: add jumpstart support for remote debug

---------

Co-authored-by: Xinyu Xie <xixinyu@amazon.com>
Co-authored-by: Evan Kravitz <evakravi@amazon.com>
  • Loading branch information
3 people authored Dec 20, 2023
1 parent 4968a91 commit 0541b26
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 3 deletions.
66 changes: 63 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
container_entry_point: Optional[List[str]] = None,
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -540,6 +541,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -777,6 +780,8 @@ def __init__(

self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)

self._enable_remote_debug = enable_remote_debug

@abstractmethod
def training_image_uri(self):
"""Return the Docker image to use for training.
Expand Down Expand Up @@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
if max_wait:
init_params["max_wait"] = max_wait

if "RemoteDebugConfig" in job_details:
init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
"EnableRemoteDebug"
)
return init_params

def _get_instance_type(self):
Expand Down Expand Up @@ -2292,6 +2302,32 @@ def update_profiler(

_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)

def get_remote_debug_config(self):
    """dict: Return the RemoteDebug configuration, or ``None`` if it is unset.

    When ``_enable_remote_debug`` holds a value, the result is a dict with
    the single key ``EnableRemoteDebug`` mapped to that value.
    """
    remote_debug_flag = self._enable_remote_debug
    if remote_debug_flag is None:
        return None
    return {"EnableRemoteDebug": remote_debug_flag}

def enable_remote_debug(self):
    """Turn remote debug on for a training job.

    Delegates to ``_update_remote_debug`` with ``True``.
    """
    turn_on = True
    self._update_remote_debug(turn_on)

def disable_remote_debug(self):
    """Turn remote debug off for a training job.

    Delegates to ``_update_remote_debug`` with ``False``.
    """
    turn_on = False
    self._update_remote_debug(turn_on)

def _update_remote_debug(self, enable_remote_debug: bool):
    """Enable or disable remote debug for a training job.

    Calls ``_ensure_latest_training_job``, submits the requested
    ``EnableRemoteDebug`` value through ``_TrainingJob.update``, and then
    stores that value in ``_enable_remote_debug``.

    Args:
        enable_remote_debug (bool): The value to send as
            ``EnableRemoteDebug`` and to store on this estimator.
    """
    self._ensure_latest_training_job()
    requested_config = {"EnableRemoteDebug": enable_remote_debug}
    _TrainingJob.update(self, remote_debug_config=requested_config)
    # The local flag is written only after the update call has returned.
    self._enable_remote_debug = enable_remote_debug

def get_app_url(
self,
app_type,
Expand Down Expand Up @@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.profiler_config:
train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

if estimator.get_remote_debug_config() is not None:
train_args["remote_debug_config"] = estimator.get_remote_debug_config()

return train_args

@classmethod
Expand Down Expand Up @@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri):

@classmethod
def update(
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
cls,
estimator,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Update a running Amazon SageMaker training job.
Expand All @@ -2562,20 +2606,31 @@ def update(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: None).
remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the updated training job.
"""
update_args = cls._get_update_args(
estimator, profiler_rule_configs, profiler_config, resource_config
estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
)
estimator.sagemaker_session.update_training_job(**update_args)

return estimator.latest_training_job

@classmethod
def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config):
def _get_update_args(
cls, estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config
):
"""Constructs a dict of arguments for updating an Amazon SageMaker training job.
Args:
Expand All @@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs))
update_args.update(build_dict("profiler_config", profiler_config))
update_args.update(build_dict("resource_config", resource_config))
update_args.update(build_dict("remote_debug_config", remote_debug_config))

return update_args

Expand Down Expand Up @@ -2694,6 +2750,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3055,6 +3112,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3106,6 +3165,7 @@ def __init__(
container_entry_point=container_entry_point,
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
container_entry_point: Optional[List[str]] = None,
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.
Expand Down Expand Up @@ -495,6 +496,8 @@ def __init__(
a training job.
disable_output_compression (Optional[bool]): When set to true, Model is uploaded
to Amazon S3 without compression after training finishes.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -569,6 +572,7 @@ def _is_valid_model_id_hook():
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
)

self.model_id = estimator_init_kwargs.model_id
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def get_init_kwargs(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -183,6 +184,7 @@ def get_init_kwargs(
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
)

estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"container_arguments",
"disable_output_compression",
"enable_infra_check",
"enable_remote_debug",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1344,6 +1345,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -1401,6 +1403,7 @@ def __init__(
self.container_arguments = container_arguments
self.disable_output_compression = disable_output_compression
self.enable_infra_check = enable_infra_check
self.enable_remote_debug = enable_remote_debug


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
48 changes: 48 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def train( # noqa: C901
profiler_config=None,
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
):
"""Create an Amazon SageMaker training job.
Expand Down Expand Up @@ -858,6 +859,15 @@ def train( # noqa: C901
configurations.
profiler_config (dict): Configuration for how profiling information is emitted
with SageMaker Profiler. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -950,6 +960,7 @@ def train( # noqa: C901
enable_sagemaker_metrics=enable_sagemaker_metrics,
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
remote_debug_config=remote_debug_config,
environment=environment,
retry_strategy=retry_strategy,
)
Expand Down Expand Up @@ -992,6 +1003,7 @@ def _get_train_request( # noqa: C901
enable_sagemaker_metrics=None,
profiler_rule_configs=None,
profiler_config=None,
remote_debug_config=None,
environment=None,
retry_strategy=None,
):
Expand Down Expand Up @@ -1103,6 +1115,15 @@ def _get_train_request( # noqa: C901
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
profiler_config(dict): Configuration for how profiling information is emitted with
SageMaker Profiler. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -1206,6 +1227,9 @@ def _get_train_request( # noqa: C901
if profiler_config is not None:
train_request["ProfilerConfig"] = profiler_config

if remote_debug_config is not None:
train_request["RemoteDebugConfig"] = remote_debug_config

if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy

Expand All @@ -1217,6 +1241,7 @@ def update_training_job(
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
Expand All @@ -1228,6 +1253,15 @@ def update_training_job(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
"""
# No injections from sagemaker_config because the UpdateTrainingJob API's resource_config
# object accepts fewer parameters than the CreateTrainingJob API, and none that the
Expand All @@ -1240,6 +1274,7 @@ def update_training_job(
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
resource_config=resource_config,
remote_debug_config=remote_debug_config,
)
LOGGER.info("Updating training job with name %s", job_name)
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
Expand All @@ -1251,6 +1286,7 @@ def _get_update_training_job_request(
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
remote_debug_config=None,
):
"""Constructs a request compatible for updating an Amazon SageMaker training job.
Expand All @@ -1262,6 +1298,15 @@ def _get_update_training_job_request(
resource_config (dict): Configuration of the resources for the training job. You can
update the keep-alive period if the warm pool status is `Available`. No other fields
can be updated. (default: ``None``).
remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
The dict can contain 'EnableRemoteDebug'(bool).
For example,
.. code:: python
remote_debug_config = {
"EnableRemoteDebug": True,
}
Returns:
Dict: an update training request dict
Expand All @@ -1279,6 +1324,9 @@ def _get_update_training_job_request(
if resource_config is not None:
update_training_job_request["ResourceConfig"] = resource_config

if remote_debug_config is not None:
update_training_job_request["RemoteDebugConfig"] = remote_debug_config

return update_training_job_request

def process(
Expand Down
Loading

0 comments on commit 0541b26

Please sign in to comment.