diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index db8d572a75..551a42ad55 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -178,6 +178,7 @@ def __init__( container_entry_point: Optional[List[str]] = None, container_arguments: Optional[List[str]] = None, disable_output_compression: bool = False, + enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -540,6 +541,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_infra_check (bool or PipelineVariable): Optional. Specifies whether it is running Sagemaker built-in infra check jobs. + enable_remote_debug (bool or PipelineVariable): Optional. + Specifies whether RemoteDebug is enabled for the training job """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -777,6 +780,8 @@ def __init__( self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name) + self._enable_remote_debug = enable_remote_debug + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. @@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds") if max_wait: init_params["max_wait"] = max_wait + + if "RemoteDebugConfig" in job_details: + init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get( + "EnableRemoteDebug" + ) return init_params def _get_instance_type(self): @@ -2292,6 +2302,32 @@ def update_profiler( _TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict) + def get_remote_debug_config(self): + """dict: Return the configuration of RemoteDebug""" + return ( + None + if self._enable_remote_debug is None + else {"EnableRemoteDebug": self._enable_remote_debug} + ) + + def enable_remote_debug(self): + """Enable remote debug for a training job.""" + self._update_remote_debug(True) + + def disable_remote_debug(self): + """Disable remote debug for a training job.""" + self._update_remote_debug(False) + + def _update_remote_debug(self, enable_remote_debug: bool): + """Update to enable or disable remote debug for a training job. + + This method updates the ``_enable_remote_debug`` parameter + and enables or disables remote debug for a training job + """ + self._ensure_latest_training_job() + _TrainingJob.update(self, remote_debug_config={"EnableRemoteDebug": enable_remote_debug}) + self._enable_remote_debug = enable_remote_debug + def get_app_url( self, app_type, @@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): if estimator.profiler_config: train_args["profiler_config"] = estimator.profiler_config._to_request_dict() + if estimator.get_remote_debug_config() is not None: + train_args["remote_debug_config"] = estimator.get_remote_debug_config() + return train_args @classmethod @@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri): @classmethod def update( - cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None + cls, + estimator, + profiler_rule_configs=None, + profiler_config=None, + resource_config=None, + remote_debug_config=None, ): """Update a running Amazon SageMaker training job. @@ -2562,20 +2606,31 @@ def update( resource_config (dict): Configuration of the resources for the training job. You can update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: None). + remote_debug_config (dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } (default: None). Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the updated training job. """ update_args = cls._get_update_args( - estimator, profiler_rule_configs, profiler_config, resource_config + estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config ) estimator.sagemaker_session.update_training_job(**update_args) return estimator.latest_training_job @classmethod - def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, resource_config): + def _get_update_args( + cls, estimator, profiler_rule_configs, profiler_config, resource_config, remote_debug_config + ): """Constructs a dict of arguments for updating an Amazon SageMaker training job. Args: @@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res update_args.update(build_dict("profiler_rule_configs", profiler_rule_configs)) update_args.update(build_dict("profiler_config", profiler_config)) update_args.update(build_dict("resource_config", resource_config)) + update_args.update(build_dict("remote_debug_config", remote_debug_config)) return update_args @@ -2694,6 +2750,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: bool = False, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, + enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3055,6 +3112,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_infra_check (bool or PipelineVariable): Optional. Specifies whether it is running Sagemaker built-in infra check jobs. + enable_remote_debug (bool or PipelineVariable): Optional. + Specifies whether RemoteDebug is enabled for the training job """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3106,6 +3165,7 @@ def __init__( container_entry_point=container_entry_point, container_arguments=container_arguments, disable_output_compression=disable_output_compression, + enable_remote_debug=enable_remote_debug, **kwargs, ) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4f7a041df0..e6047e9009 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -106,6 +106,7 @@ def __init__( container_entry_point: Optional[List[str]] = None, container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, + enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -495,6 +496,8 @@ def __init__( a training job. disable_output_compression (Optional[bool]): When set to true, Model is uploaded to Amazon S3 without compression after training finishes. + enable_remote_debug (bool or PipelineVariable): Optional. + Specifies whether RemoteDebug is enabled for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -569,6 +572,7 @@ def _is_valid_model_id_hook(): container_arguments=container_arguments, disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, + enable_remote_debug=enable_remote_debug, ) self.model_id = estimator_init_kwargs.model_id diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index baa9d55085..7479c23832 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -127,6 +127,7 @@ def get_init_kwargs( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, + enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -183,6 +184,7 @@ def get_init_kwargs( container_arguments=container_arguments, disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, + enable_remote_debug=enable_remote_debug, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index de9e2c10a3..7c06282894 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1280,6 +1280,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "container_arguments", "disable_output_compression", "enable_infra_check", + "enable_remote_debug", ] SERIALIZATION_EXCLUSION_SET = { @@ -1344,6 +1345,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, + enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1401,6 +1403,7 @@ def __init__( self.container_arguments = container_arguments self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check + self.enable_remote_debug = enable_remote_debug class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3b2de0239e..5b5df7a792 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -748,6 +748,7 @@ def train( # noqa: C901 profiler_config=None, environment: Optional[Dict[str, str]] = None, retry_strategy=None, + remote_debug_config=None, ): """Create an Amazon SageMaker training job. @@ -858,6 +859,15 @@ def train( # noqa: C901 configurations.src/sagemaker/lineage/artifact.py:285 profiler_config (dict): Configuration for how profiling information is emitted with SageMaker Profiler. (default: ``None``). + remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. @@ -950,6 +960,7 @@ def train( # noqa: C901 enable_sagemaker_metrics=enable_sagemaker_metrics, profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, + remote_debug_config=remote_debug_config, environment=environment, retry_strategy=retry_strategy, ) @@ -992,6 +1003,7 @@ def _get_train_request( # noqa: C901 enable_sagemaker_metrics=None, profiler_rule_configs=None, profiler_config=None, + remote_debug_config=None, environment=None, retry_strategy=None, ): @@ -1103,6 +1115,15 @@ def _get_train_request( # noqa: C901 profiler_rule_configs (list[dict]): A list of profiler rule configurations. profiler_config(dict): Configuration for how profiling information is emitted with SageMaker Profiler. (default: ``None``). + remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. @@ -1206,6 +1227,9 @@ def _get_train_request( # noqa: C901 if profiler_config is not None: train_request["ProfilerConfig"] = profiler_config + if remote_debug_config is not None: + train_request["RemoteDebugConfig"] = remote_debug_config + if retry_strategy is not None: train_request["RetryStrategy"] = retry_strategy @@ -1217,6 +1241,7 @@ def update_training_job( profiler_rule_configs=None, profiler_config=None, resource_config=None, + remote_debug_config=None, ): """Calls the UpdateTrainingJob API for the given job name and returns the response. @@ -1228,6 +1253,15 @@ def update_training_job( resource_config (dict): Configuration of the resources for the training job. You can update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: ``None``). + remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } """ # No injections from sagemaker_config because the UpdateTrainingJob API's resource_config # object accepts fewer parameters than the CreateTrainingJob API, and none that the @@ -1240,6 +1274,7 @@ def update_training_job( profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, resource_config=resource_config, + remote_debug_config=remote_debug_config, ) LOGGER.info("Updating training job with name %s", job_name) LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4)) @@ -1251,6 +1286,7 @@ def _get_update_training_job_request( profiler_rule_configs=None, profiler_config=None, resource_config=None, + remote_debug_config=None, ): """Constructs a request compatible for updating an Amazon SageMaker training job. @@ -1262,6 +1298,15 @@ def _get_update_training_job_request( resource_config (dict): Configuration of the resources for the training job. You can update the keep-alive period if the warm pool status is `Available`. No other fields can be updated. (default: ``None``). + remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``) + The dict can contain 'EnableRemoteDebug'(bool). + For example, + + .. code:: python + + remote_debug_config = { + "EnableRemoteDebug": True, + } Returns: Dict: an update training request dict @@ -1279,6 +1324,9 @@ def _get_update_training_job_request( if resource_config is not None: update_training_job_request["ResourceConfig"] = resource_config + if remote_debug_config is not None: + update_training_job_request["RemoteDebugConfig"] = remote_debug_config + return update_training_job_request def process( diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 437c150c8b..3d8b0c454d 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2012,6 +2012,82 @@ def test_sagemaker_model_custom_channel_name(sagemaker_session): ] +def test_framework_with_remote_debug_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + enable_remote_debug=True, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["remote_debug_config"]["EnableRemoteDebug"] + assert f.get_remote_debug_config()["EnableRemoteDebug"] + + +def test_framework_without_remote_debug_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args.get("remote_debug_config") is None + assert f.get_remote_debug_config() is None + + +def test_framework_enable_remote_debug(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) + f.fit("s3://mydata") + f.enable_remote_debug() + + sagemaker_session.update_training_job.assert_called_once() + _, args = sagemaker_session.update_training_job.call_args + assert args["remote_debug_config"] == { + "EnableRemoteDebug": True, + } + assert f.get_remote_debug_config()["EnableRemoteDebug"] + assert len(args) == 2 + + +def test_framework_disable_remote_debug(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + enable_remote_debug=True, + ) + f.fit("s3://mydata") + f.disable_remote_debug() + + sagemaker_session.update_training_job.assert_called_once() + _, args = sagemaker_session.update_training_job.call_args + assert args["remote_debug_config"] == { + "EnableRemoteDebug": False, + } + assert not f.get_remote_debug_config()["EnableRemoteDebug"] + assert len(args) == 2 + + @patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket(time, sagemaker_session): code_bucket = "codebucket" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 57ba8daad5..d3bba53504 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1876,6 +1876,15 @@ def test_update_training_job_with_sagemaker_config_injection(sagemaker_session): ) +def test_update_training_job_with_remote_debug_config(sagemaker_session): + sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB + sagemaker_session.update_training_job( + job_name="MyTestJob", remote_debug_config={"EnableRemoteDebug": False} + ) + _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] + assert not actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"] + + def test_train_with_sagemaker_config_injection(sagemaker_session): sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_TRAINING_JOB @@ -2128,6 +2137,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): } CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"] CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"] + remote_debug_config = {"EnableRemoteDebug": True} sagemaker_session.train( image_uri=IMAGE, @@ -2152,6 +2162,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): training_image_config=TRAINING_IMAGE_CONFIG, container_entry_point=CONTAINER_ENTRY_POINT, container_arguments=CONTAINER_ARGUMENTS, + remote_debug_config=remote_debug_config, ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -2174,6 +2185,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): actual_train_args["AlgorithmSpecification"]["ContainerEntrypoint"] == CONTAINER_ENTRY_POINT ) assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS + assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"] def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):