diff --git a/.gitignore b/.gitignore index d966fa1ccc..fc07847fba 100644 --- a/.gitignore +++ b/.gitignore @@ -34,7 +34,7 @@ env/ **/_repack_script_launcher.sh src/sagemaker/modules/train/container_drivers/sm_train.sh src/sagemaker/modules/train/container_drivers/sourcecode.json -src/sagemaker/modules/train/container_drivers/distributed_runner.json +src/sagemaker/modules/train/container_drivers/distributed.json tests/data/**/_repack_model.py tests/data/experiment/sagemaker-dev-1.0.tar.gz src/sagemaker/serve/tmp_workspace \ No newline at end of file diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index a95e062519..34a98c0b8e 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -664,58 +664,19 @@ def _simple_path(*args: str): "minLength": 20, "maxLength": 2048, }, - "baseJobName": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "sourceCode": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "distributed_runner": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "compute": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "networking": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "stoppingCondition": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "trainingImage": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "trainingImageConfig": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "algorithmName": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "outputDataConfig": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "trainingInputMode": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "environment": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, - "hyperparameters": { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, + "baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "stoppingCondition": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImage": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImageConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "algorithmName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "outputDataConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -769,10 +730,7 @@ def _simple_path(*args: str): }, }, }, - MODEL_TRAINER: { - TYPE: OBJECT, - ADDITIONAL_PROPERTIES: True - }, + MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, ESTIMATOR: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index bb84f2d780..0e45711023 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -39,14 +39,7 @@ TrainingImageConfig, TrainingRepositoryAuthConfig, Tag, - MetricDefinition, - DebugHookConfig, - CollectionConfiguration, - DebugRuleConfiguration, - ExperimentConfig, InfraCheckConfig, - ProfilerConfig, - ProfilerRuleConfiguration, RemoteDebugConfig, SessionChainingConfig, InstanceGroup, @@ -69,14 +62,7 @@ "TrainingImageConfig", "TrainingRepositoryAuthConfig", "Tag", - "MetricDefinition", - "DebugHookConfig", - 
"CollectionConfiguration", - "DebugRuleConfiguration", - "ExperimentConfig", "InfraCheckConfig", - "ProfilerConfig", - "ProfilerRuleConfiguration", "RemoteDebugConfig", "SessionChainingConfig", "InstanceGroup", diff --git a/src/sagemaker/modules/constants.py b/src/sagemaker/modules/constants.py index 271c74313f..9103c03b21 100644 --- a/src/sagemaker/modules/constants.py +++ b/src/sagemaker/modules/constants.py @@ -26,7 +26,7 @@ ) SOURCE_CODE_JSON = "sourcecode.json" -DISTRIBUTED_RUNNER_JSON = "distributed_runner.json" +DISTRIBUTED_JSON = "distributed.json" TRAIN_SCRIPT = "sm_train.sh" DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"] diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 4e8334178e..c66cd0bb2d 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -72,8 +72,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedRunner(BaseModel): - """Base class for DistributedRunner Class""" +class DistributedConfig(BaseModel): + """Base class for distributed training configurations.""" _type: str = PrivateAttr() @@ -84,11 +84,11 @@ def model_dump(self, *args, **kwargs): return result -class Torchrun(DistributedRunner): - """TorchDistributed. +class Torchrun(DistributedConfig): + """Torchrun. - The Torchrun runner uses `torchrun` or `torch.distributed.launch` in the backend to - launch distributed training. + The Torchrun class configures a job that uses `torchrun` or + `torch.distributed.launch` in the backend to launch distributed training. Attributes: process_count_per_node (int): @@ -104,10 +104,11 @@ class Torchrun(DistributedRunner): smp: Optional["SMP"] = None -class MPI(DistributedRunner): +class MPI(DistributedConfig): """MPI. - The MPI runner uses `mpirun` in the backend to launch distributed training. + The MPI class configures a job that uses `mpirun` in the backend to launch + distributed training. 
Attributes: process_count_per_node (int): diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index d52ac6eb46..fba60dda47 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -76,14 +76,6 @@ cat /opt/ml/input/config/inputdataconfig.json echo -echo "/opt/ml/input/data/sm_drivers/sourcecode.json" -cat /opt/ml/input/data/sm_drivers/sourcecode.json -echo - -echo "/opt/ml/input/data/sm_drivers/distributed_runner.json" -cat /opt/ml/input/data/sm_drivers/distributed_runner.json -echo - echo "Setting up environment variables" $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py source /opt/ml/input/sm_training.env diff --git a/src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb b/src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb index 35a1198c73..0c26877ebc 100644 --- a/src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb +++ b/src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb @@ -402,7 +402,7 @@ " compute=compute,\n", " hyperparameters=hyperparameters,\n", " source_code=source_code,\n", - " distributed_runner=torchrun,\n", + " distributed=torchrun,\n", " base_job_name=f\"{alias}-distributed-case-2\",\n", ")" ] @@ -498,7 +498,7 @@ " hyperparameters=hyperparameters,\n", " environment=env,\n", " source_code=source_code,\n", - " distributed_runner=mpi,\n", + " distributed=mpi,\n", " base_job_name=f\"{alias}-distributed-case-3\",\n", ")" ] diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/mpi_driver.py index b0e501fd9e..dceb748cc0 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/mpi_driver.py @@ -20,7 +20,7 @@ from utils import ( logger, read_source_code_json, - read_distributed_runner_json, + read_distributed_json, read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, @@ -59,7 +59,7 @@ def main(): """ source_code = read_source_code_json() - distribution = read_distributed_runner_json() + distribution = read_distributed_json() hyperparameters = read_hyperparameters_json() sm_current_host = os.environ["SM_CURRENT_HOST"] diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/torchrun_driver.py index 95548d35be..666479ec84 100644 --- a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/torchrun_driver.py @@ -21,7 +21,7 @@ from utils import ( logger, read_source_code_json, - read_distributed_runner_json, + read_distributed_json, read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, @@ -66,7 +66,7 @@ def setup_env(): def create_commands(): """Create the Torch Distributed command to execute""" source_code = read_source_code_json() - distribution = read_distributed_runner_json() + distribution = read_distributed_json() hyperparameters = read_hyperparameters_json() process_count = get_process_count(distribution) diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/utils.py index 93e3e4dc03..222dadd688 100644 --- a/src/sagemaker/modules/train/container_drivers/utils.py +++ b/src/sagemaker/modules/train/container_drivers/utils.py @@ -38,7 +38,7 @@ USER_CODE_PATH = "/opt/ml/input/data/sm_code" SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json" -DISTRIBUTED_RUNNER_JSON = 
"/opt/ml/input/data/sm_drivers/distributed_runner.json" +DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json" HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json" @@ -79,14 +79,14 @@ def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON): return source_code_dict -def read_distributed_runner_json(distributed_json: Dict[str, Any] = DISTRIBUTED_RUNNER_JSON): +def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON): """Read the distribution config json file.""" try: with open(distributed_json, "r") as f: - distributed_runner_dict = json.load(f) or {} + distributed_dict = json.load(f) or {} except FileNotFoundError: - distributed_runner_dict = {} - return distributed_runner_dict + distributed_dict = {} + return distributed_dict def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON): @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME return hyperparameters_dict -def get_process_count(distributed_runner_dict: Dict[str, Any]) -> int: +def get_process_count(distributed_dict: Dict[str, Any]) -> int: """Get the number of processes to run on each node in the training job.""" return ( - int(distributed_runner_dict.get("process_count_per_node", 0)) + int(distributed_dict.get("process_count_per_node", 0)) or int(os.environ.get("SM_NUM_GPUS", 0)) or int(os.environ.get("SM_NUM_NEURONS", 0)) or 1 diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index bd099c5036..05d75497a7 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -29,19 +29,22 @@ from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call -from sagemaker.config.config_schema import (_simple_path, SAGEMAKER, - MODEL_TRAINER, MODULES, - PYTHON_SDK, - TRAINING_JOB_ENVIRONMENT_PATH, - TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - TRAINING_JOB_VPC_CONFIG_PATH, - TRAINING_JOB_SUBNETS_PATH, - TRAINING_JOB_SECURITY_GROUP_IDS_PATH, - TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, - TRAINING_JOB_PROFILE_CONFIG_PATH, - TRAINING_JOB_RESOURCE_CONFIG_PATH, - TRAINING_JOB_ROLE_ARN_PATH, - TRAINING_JOB_TAGS_PATH) +from sagemaker.config.config_schema import ( + _simple_path, + SAGEMAKER, + MODEL_TRAINER, + MODULES, + PYTHON_SDK, + TRAINING_JOB_ENVIRONMENT_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_TAGS_PATH, +) from sagemaker.utils import resolve_value_from_config from sagemaker.modules import Session, get_execution_role @@ -58,13 +61,7 @@ FileSystemDataSource, Networking, Tag, - MetricDefinition, - DebugHookConfig, - DebugRuleConfiguration, - ExperimentConfig, InfraCheckConfig, - ProfilerConfig, - ProfilerRuleConfiguration, RemoteDebugConfig, SessionChainingConfig, TensorBoardOutputConfig, @@ -73,10 +70,7 @@ ) from sagemaker.modules.local_core.local_container import _LocalContainer -from sagemaker.modules.distributed import ( - DistributedRunner, - Torchrun, -) +from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig from sagemaker.modules.utils import ( _get_repo_name_from_image, _get_unique_name, @@ -95,7 +89,7 @@ DEFAULT_CONTAINER_ENTRYPOINT, DEFAULT_CONTAINER_ARGUMENTS, SOURCE_CODE_JSON, - DISTRIBUTED_RUNNER_JSON, + DISTRIBUTED_JSON, ) 
from sagemaker.modules.templates import ( TRAIN_SCRIPT_TEMPLATE, @@ -151,7 +145,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed_runner (Optional[DistributedRunner]): + distributed (Optional[Union[MPI, Torchrun]]): The distributed training configuration for the training job. If specified, `source_code` must also be provided. @@ -206,7 +200,7 @@ class ModelTrainer(BaseModel): role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed_runner: Optional[DistributedRunner] = None + distributed: Optional[Union[MPI, Torchrun]] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None @@ -222,31 +216,45 @@ class ModelTrainer(BaseModel): tags: Optional[List[Tag]] = None local_container_root: Optional[str] = os.getcwd() - CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = ["role", - "base_job_name", - "source_code", - "distributed_runner", - "compute", - "networking", - "stopping_condition", - "training_image", - "training_image_config", - "algorithm_name", - "output_data_config", - "checkpoint_config", - "training_input_mode", - "environment", - "hyperparameters"] + # Created Artifacts + _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None) + + # Private TrainingJob Parameters + _tensorboard_output_config: Optional[TensorBoardOutputConfig] = PrivateAttr(default=None) + _retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None) + _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) + _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) + _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) + + _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + + CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ + "role", + "base_job_name", + "source_code", + "distributed", + "compute", + "networking", + "stopping_condition", + "training_image", + "training_image_config", + "algorithm_name", + "output_data_config", + "checkpoint_config", + "training_input_mode", + "environment", + "hyperparameters", + ] SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = { "source_code": SourceCode, - "distributed_runner": type(DistributedRunner), - "compute": type(Compute), - "networking": type(Networking), - "stopping_condition": type(StoppingCondition), - "training_image_config": type(TrainingImageConfig), - "output_data_config": type(OutputDataConfig), - "checkpoint_config": type(CheckpointConfig) + "distributed": DistributedConfig, + "compute": Compute, + "networking": Networking, + "stopping_condition": StoppingCondition, + "training_image_config": TrainingImageConfig, + "output_data_config": OutputDataConfig, + "checkpoint_config": CheckpointConfig, } def _populate_intelligent_defaults(self): @@ -261,52 +269,51 @@ def _populate_intelligent_defaults_from_training_job_space(self): """Function to populate all the possible default configs from Training Job Space""" if not self.environment: self.environment = resolve_value_from_config( - config_path=TRAINING_JOB_ENVIRONMENT_PATH, - sagemaker_session=self.sagemaker_session) + config_path=TRAINING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session + ) default_enable_network_isolation = resolve_value_from_config(
config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, - sagemaker_session=self.sagemaker_session) + sagemaker_session=self.sagemaker_session, + ) default_vpc_config = resolve_value_from_config( - config_path=TRAINING_JOB_VPC_CONFIG_PATH, - sagemaker_session=self.sagemaker_session) + config_path=TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session + ) if not self.networking: - if (default_enable_network_isolation is not None - or default_vpc_config is not None): + if default_enable_network_isolation is not None or default_vpc_config is not None: self.networking = Networking( enable_network_isolation=default_enable_network_isolation, subnets=resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH), security_group_ids=resolve_value_from_config( - config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH), + config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH + ), ) else: if self.networking.enable_network_isolation is None: self.networking.enable_network_isolation = default_enable_network_isolation if self.networking.subnets is None: - self.networking.subnets = ( - resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH)) + self.networking.subnets = resolve_value_from_config( + config_path=TRAINING_JOB_SUBNETS_PATH + ) if self.networking.security_group_ids is None: - self.networking.subnets = ( - resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH)) + self.networking.security_group_ids = resolve_value_from_config( + config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH + ) if not self.output_data_config: default_output_data_config = resolve_value_from_config( - config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH) + config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH + ) if default_output_data_config: self.output_data_config = OutputDataConfig( - **self._convert_keys_to_snake(default_output_data_config)) - - if not self._profiler_config: - default_profiler_config = resolve_value_from_config( - config_path=TRAINING_JOB_PROFILE_CONFIG_PATH) - if default_profiler_config: - self._profiler_config = ProfilerConfig( - **self._convert_keys_to_snake(default_profiler_config)) + **self._convert_keys_to_snake(default_output_data_config) + ) if not self.compute: default_resource_config = resolve_value_from_config( - config_path=TRAINING_JOB_RESOURCE_CONFIG_PATH) + config_path=TRAINING_JOB_RESOURCE_CONFIG_PATH + ) if default_resource_config: self.compute = Compute(**self._convert_keys_to_snake(default_resource_config)) @@ -318,10 +325,7 @@ def _populate_intelligent_defaults_from_training_job_space(self): def _convert_keys_to_snake(self, config: dict) -> dict: """Utility helper function that converts the keys of a dictionary into snake case""" - return { - to_snake_case(key): value - for key, value in config.items() - } + return {to_snake_case(key): value for key, value in config.items()} def _populate_intelligent_defaults_from_model_trainer_space(self): """Function to populate all the possible default configs from Model Trainer Space""" for configurable_attribute in self.CONFIGURABLE_ATTRIBUTES: if getattr(self, configurable_attribute) is None: default_config = resolve_value_from_config( - config_path=_simple_path(SAGEMAKER, - PYTHON_SDK, - MODULES, - MODEL_TRAINER, - to_camel_case(configurable_attribute)), - sagemaker_session=self.sagemaker_session) + config_path=_simple_path( + SAGEMAKER, + PYTHON_SDK, + MODULES, + MODEL_TRAINER, + to_camel_case(configurable_attribute), + ), +
sagemaker_session=self.sagemaker_session, + ) if default_config is not None: if configurable_attribute in self.SERIALIZABLE_CONFIG_ATTRIBUTES: - default_config = (self.SERIALIZABLE_CONFIG_ATTRIBUTES - .get(configurable_attribute)(**default_config)) # noqa + default_config = self.SERIALIZABLE_CONFIG_ATTRIBUTES.get( + configurable_attribute + )( + **default_config # pylint: disable=E1134 + ) # noqa setattr(self, configurable_attribute, default_config) - # Created Artifacts - _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None) - - # Metrics settings - _enable_sage_maker_metrics_time_series: Optional[bool] = PrivateAttr(default=False) - _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) - - # Debugger settings - _debug_hook_config: Optional[DebugHookConfig] = PrivateAttr(default=None) - _debug_rule_configurations: Optional[List[DebugRuleConfiguration]] = PrivateAttr(default=None) - _profiler_config: Optional[ProfilerConfig] = PrivateAttr(default=None) - _profiler_rule_configurations: Optional[List[ProfilerRuleConfiguration]] = PrivateAttr( - default=None - ) - _tensor_board_output_config: Optional[TensorBoardOutputConfig] = PrivateAttr(default=None) - - # Additional settings - _retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None) - _experiment_config: Optional[ExperimentConfig] = PrivateAttr(default=None) - _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) - _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) - _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) - - _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) - def __del__(self): """Destructor method to clean up the temporary directory.""" # Clean up the temporary directory if it exists @@ -385,13 +370,13 @@ def _validate_training_image_and_algorithm_name( "Only one of 'training_image' or 'algorithm_name' must be provided.", ) - def _validate_distributed_runner( + def _validate_distributed_config( self, source_code: Optional[SourceCode], - distributed_runner: Optional[DistributedRunner], + distributed: Optional[DistributedConfig], ): """Validate the distributed training configuration.""" - if distributed_runner and not source_code.entry_script: + if distributed and not source_code.entry_script: raise ValueError( "Must provide 'entry_script' in 'source_code' " + "if 'distributed' is provided.", ) @@ -436,7 +421,7 @@ def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name) self._validate_source_code(self.source_code) - self._validate_distributed_runner(self.source_code, self.distributed_runner) + self._validate_distributed_config(self.source_code, self.distributed) if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: if self.sagemaker_session is None: @@ -546,17 +531,15 @@ def train( self._prepare_train_script( tmp_dir=drivers_dir, source_code=self.source_code, - distributed_runner=self.distributed_runner, + distributed=self.distributed, ) - if isinstance(self.distributed_runner, Torchrun) and self.distributed_runner.smp: - mp_parameters = self.distributed_runner.smp._to_mp_hyperparameters() + if isinstance(self.distributed, Torchrun) and self.distributed.smp: + mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters)
self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) - self._write_distributed_runner_json( - tmp_dir=drivers_dir, distributed_runner=self.distributed_runner - ) + self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel(SM_DRIVERS, drivers_dir.name) @@ -577,8 +560,6 @@ def train( training_image_config=self.training_image_config, container_entrypoint=container_entrypoint, container_arguments=container_arguments, - metric_definitions=self._metric_definitions, - enable_sage_maker_metrics_time_series=self._enable_sage_maker_metrics_time_series, ) resource_config = self.compute._to_resource_config() @@ -610,14 +591,9 @@ def train( self.networking.enable_network_isolation if self.networking else None ), # Private Instance Attributes - debug_hook_config=self._debug_hook_config, - debug_rule_configurations=self._debug_rule_configurations, remote_debug_config=self._remote_debug_config, - profiler_config=self._profiler_config, - profiler_rule_configurations=self._profiler_rule_configurations, - tensor_board_output_config=self._tensor_board_output_config, + tensor_board_output_config=self._tensorboard_output_config, retry_strategy=self._retry_strategy, - experiment_config=self._experiment_config, infra_check_config=self._infra_check_config, session_chaining_config=self._session_chaining_config, ) @@ -747,22 +723,22 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour dump = source_code.model_dump(exclude_none=True) if source_code else {} f.write(json.dumps(dump)) - def _write_distributed_runner_json( + def _write_distributed_json( self, tmp_dir: TemporaryDirectory, - distributed_runner: Optional[DistributedRunner] = None, + distributed: Optional[DistributedConfig] = None, ): """Write the distributed training configuration to a JSON file.""" - file_path = os.path.join(tmp_dir.name, DISTRIBUTED_RUNNER_JSON) + file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) with open(file_path, "w") as f: - dump = distributed_runner.model_dump(exclude_none=True) if distributed_runner else {} + dump = distributed.model_dump(exclude_none=True) if distributed else {} f.write(json.dumps(dump)) def _prepare_train_script( self, tmp_dir: TemporaryDirectory, source_code: SourceCode, - distributed_runner: Optional[DistributedRunner] = None, + distributed: Optional[DistributedConfig] = None, ): """Prepare the training script to be executed in the training job container. @@ -791,15 +767,15 @@ def _prepare_train_script( if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) - elif distributed_runner: - distribution_type = distributed_runner._type + elif distributed: + distribution_type = distributed._type if distribution_type == "mpi": execute_driver = EXECUTE_MPI_DRIVER elif distribution_type == "torchrun": execute_driver = EXEUCTE_TORCHRUN_DRIVER else: raise ValueError(f"Unsupported distribution type: {distribution_type}.") - elif source_code.entry_script and not source_code.command and not distributed_runner: + elif source_code.entry_script and not source_code.command and not distributed: if not source_code.entry_script.endswith((".py", ".sh")): raise ValueError( f"Unsupported entry script: {source_code.entry_script}."
@@ -816,139 +792,6 @@ def _prepare_train_script( with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f: f.write(train_script) - def with_metric_settings( - self, - enable_sage_maker_metrics_time_series: bool = True, - metric_definitions: Optional[List[MetricDefinition]] = None, - ) -> "ModelTrainer": - """Set the metrics configuration for the training job. - - Example: - ```python - model_trainer = ModelTrainer(...).with_metric_settings( - enable_sage_maker_metrics_time_series=True, - metric_definitions=[ - MetricDefinition( - name="loss", - regex="Loss: (.*?),", - ), - MetricDefinition( - name="accuracy", - regex="Accuracy: (.*?),", - ), - ] - ) - ``` - - Args: - enable_sage_maker_metrics_time_series (Optional[bool]): - Whether to enable SageMaker metrics time series. Defaults to True. - metric_definitions (Optional[List[MetricDefinition]]): - A list of metric definition objects. Each object specifies - the metric name and regular expressions used to parse algorithm logs. - SageMaker publishes each metric to Amazon CloudWatch. - """ - self._enable_sage_maker_metrics_time_series = enable_sage_maker_metrics_time_series - self._metric_definitions = metric_definitions - return self - - def with_debugger_settings( - self, - debug_hook_config: Optional[DebugHookConfig] = None, - debug_rule_configurations: Optional[List[DebugRuleConfiguration]] = None, - profiler_config: Optional[ProfilerConfig] = None, - profiler_rule_configurations: Optional[List[ProfilerRuleConfiguration]] = None, - tensor_board_output_config: Optional[TensorBoardOutputConfig] = None, - ) -> "ModelTrainer": - """Set the configuration for settings related to Amazon SageMaker Debugger. - - Example: - ```python - model_trainer = ModelTrainer(...).with_debugger_settings( - debug_hook_config=DebugHookConfig( - s3_output_path="s3://bucket/path", - collection_configurations=[ - CollectionConfiguration( - collection_name="some_collection", - collection_parameters={ - "include_regex": ".*", - } - ) - ] - ) - ) - ``` - - Args: - debug_hook_config (Optional[DebugHookConfig]): - Configuration information for the Amazon SageMaker Debugger hook parameters, - metric and tensor collections, and storage paths. - To learn more see: - https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-createtrainingjob-api.html - debug_rule_configurations (Optional[List[DebugRuleConfiguration]]): - Configuration information for Amazon SageMaker Debugger rules for debugging - output ensors. - profiler_config (ProfilerConfig): - Configuration information for Amazon SageMaker Debugger system monitoring, - framework profiling, and storage paths. - profiler_rule_configurations (List[ProfilerRuleConfiguration]): - Configuration information for Amazon SageMaker Debugger rules for profiling - system and framework metrics. - tensor_board_output_config (Optional[TensorBoardOutputConfig]): - Configuration of storage locations for the Amazon SageMaker Debugger TensorBoard - output data. 
- """ - self._debug_hook_config = debug_hook_config - self._debug_rule_configurations = debug_rule_configurations - self._profiler_config = profiler_config - self._profiler_rule_configurations = profiler_rule_configurations - self._tensor_board_output_config = tensor_board_output_config - return self - - def with_additional_settings( - self, - retry_strategy: Optional[RetryStrategy] = None, - experiment_config: Optional[ExperimentConfig] = None, - infra_check_config: Optional[InfraCheckConfig] = None, - session_chaining_config: Optional[SessionChainingConfig] = None, - remote_debug_config: Optional[RemoteDebugConfig] = None, - ) -> "ModelTrainer": - """Set any additional settings for the training job. - - Example: - ```python - model_trainer = ModelTrainer(...).with_additional_settings( - experiment_config=ExperimentConfig( - experiment_name="my-experiment", - trial_name="my-trial", - ) - ) - ``` - - Args: - retry_strategy (Optional[RetryStrategy]): - The number of times to retry the job when the job fails due to an - `InternalServerError`. - experiment_config (Optional[ExperimentConfig]): - Configuration information for the Amazon SageMaker Experiment. - Associates a SageMaker job as a trial component with an experiment and trial - infra_check_config (Optional[InfraCheckConfig]): - Contains information about the infrastructure health check configuration for the - training job. - session_chaining_config (Optional[SessionChainingConfig]): - Contains information about attribute-based access control (ABAC) for the training - job. - remote_debug_config (Optional[RemoteDebugConfig]): - Configuration for remote debugging through AWS Systems Manager. To learn more see: - https://docs.aws.amazon.com/sagemaker/latest/dg/train-remote-debugging.html - """ - self._retry_strategy = retry_strategy - self._experiment_config = experiment_config - self._infra_check_config = infra_check_config - self._session_chaining_config = session_chaining_config - self._remote_debug_config = remote_debug_config - return self - @classmethod def from_recipe( cls, @@ -1038,7 +881,7 @@ def from_recipe( # The training recipe is used to prepare the following args: # - source_code # - training_image - # - distributed_runner + # - distributed # - compute # - hyperparameters model_trainer_args, recipe_train_dir = _get_args_from_recipe( @@ -1067,3 +910,57 @@ def from_recipe( model_trainer._temp_recipe_train_dir = recipe_train_dir return model_trainer + + def with_tensorboard_output_config( + self, tensorboard_output_config: TensorBoardOutputConfig + ) -> "ModelTrainer": + """Set the TensorBoard output configuration. + + Args: + tensorboard_output_config (TensorBoardOutputConfig): + The TensorBoard output configuration. + """ + self._tensorboard_output_config = tensorboard_output_config + return self + + def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": + """Set the retry strategy for the training job. + + Args: + retry_strategy (RetryStrategy): + The retry strategy for the training job. + """ + self._retry_strategy = retry_strategy + return self + + def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer": + """Set the infra check configuration for the training job. + + Args: + infra_check_config (InfraCheckConfig): + The infra check configuration for the training job. 
+ """ + self._infra_check_config = infra_check_config + return self + + def with_session_chaining_config( + self, session_chaining_config: SessionChainingConfig + ) -> "ModelTrainer": + """Set the session chaining configuration for the training job. + + Args: + session_chaining_config (SessionChainingConfig): + The session chaining configuration for the training job. + """ + self._session_chaining_config = session_chaining_config + return self + + def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer": + """Set the remote debug configuration for the training job. + + Args: + remote_debug_config (RemoteDebugConfig): + The remote debug configuration for the training job. + """ + self._remote_debug_config = remote_debug_config + return self diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 450856f74f..ddd93b2432 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -177,7 +177,7 @@ def _configure_gpu_args( { "source_code": source_code, "training_image": training_image, - "distributed_runner": torch_distributed, + "distributed": torch_distributed, } ) return args @@ -212,7 +212,7 @@ def _configure_trainium_args( { "source_code": source_code, "training_image": training_image, - "distributed_runner": Torchrun(), + "distributed": Torchrun(), } ) return args @@ -232,7 +232,7 @@ def _get_args_from_recipe( { "source_code": SourceCode, "training_image": str, - "distributed_runner": DistributedRunner, + "distributed": DistributedConfig, "compute": Compute, "hyperparameters": Dict[str, Any], } @@ -275,7 +275,7 @@ def _get_args_from_recipe( if requirements and not os.path.isfile(requirements): raise ValueError(f"Recipe requirements file {requirements} not found.") - # Get Training Image, SourceCode, and DistributedRunner args + # Get Training Image, SourceCode, and distributed args device_type = _determine_device_type(compute.instance_type) recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") if device_type == "gpu": diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cb1b8bb765..cd298402b2 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -82,7 +82,7 @@ def test_hp_contract_mpi_script(modules_sagemaker_session): compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, source_code=source_code, - distributed_runner=MPI(), + distributed=MPI(), base_job_name="hp-contract-mpi-script", ) @@ -101,7 +101,7 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session): compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, source_code=source_code, - distributed_runner=Torchrun(), + distributed=Torchrun(), base_job_name="hp-contract-torchrun-script", ) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index 9c9fc8feb7..a1a84da1ab 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -44,7 +44,7 @@ "source_code": "source_code", "entry_script": "script.py", } -DUMMY_DISTRIBUTED_RUNNER = { +DUMMY_DISTRIBUTED = { "_type": "mpi", "process_count_per_node": 2, "mpi_additional_options": [ @@ -64,7 +64,7 @@ "SM_HOST_COUNT": "2", }, ) 
-@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_runner_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") @patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") @patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") @patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") @@ -82,11 +82,11 @@ def test_mpi_driver_worker( mock_start_sshd_daemon, mock_write_env_vars_to_file, mock_read_source_code_json, - mock_read_distributed_runner_json, + mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_runner_json.return_value = DUMMY_DISTRIBUTED_RUNNER + mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mpi_driver.main() @@ -108,7 +108,7 @@ def test_mpi_driver_worker( "SM_HOST_COUNT": "2", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_runner_json") +@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") @patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") @patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") @patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") @@ -130,11 +130,11 @@ def test_mpi_driver_master( mock_start_sshd_daemon, mock_write_env_vars_to_file, mock_read_source_code_config_json, - mock_read_distributed_runner_json, + mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_runner_json.return_value = DUMMY_DISTRIBUTED_RUNNER + mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND mock_get_process_count.return_value = 2 mock_execute_commands.return_value = (0, "") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index f32b440e39..4cff07a0c0 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -27,7 +27,7 @@ "entry_script": "script.py", } -DUMMY_DISTRIBUTED_RUNNER = {"_type": "torchrun", "process_count_per_node": 2} +DUMMY_DISTRIBUTED = {"_type": "torchrun", "process_count_per_node": 2} @patch( @@ -83,8 +83,8 @@ def test_get_base_pytorch_command_torch_distributed_launch( return_value=DUMMY_SOURCE_CODE, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_runner_json", - return_value=DUMMY_DISTRIBUTED_RUNNER, + "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", + return_value=DUMMY_DISTRIBUTED, ) @patch( "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_runner_json, + mock_read_distributed_json, mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, @@ -139,8 +139,8 @@ def test_create_commands_single_node( return_value=DUMMY_SOURCE_CODE, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_runner_json", -
return_value=DUMMY_DISTRIBUTED_RUNNER, + "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", + return_value=DUMMY_DISTRIBUTED, ) @patch( "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_runner_json, + mock_read_distributed_json, mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index c876871485..962e2e0852 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -24,14 +24,16 @@ from sagemaker_core.main.shapes import ResourceConfig from sagemaker.config import SAGEMAKER, PYTHON_SDK, MODULES -from sagemaker.config.config_schema import (MODEL_TRAINER, - _simple_path, - TRAINING_JOB_RESOURCE_CONFIG_PATH) +from sagemaker.config.config_schema import ( + MODEL_TRAINER, + _simple_path, + TRAINING_JOB_RESOURCE_CONFIG_PATH, +) from sagemaker.modules import Session from sagemaker.modules.train.model_trainer import ModelTrainer from sagemaker.modules.constants import ( DEFAULT_INSTANCE_TYPE, - DISTRIBUTED_RUNNER_JSON, + DISTRIBUTED_JSON, SOURCE_CODE_JSON, TRAIN_SCRIPT, ) @@ -43,14 +45,8 @@ SourceCode, S3DataSource, FileSystemDataSource, - MetricDefinition, - DebugHookConfig, - DebugRuleConfiguration, RemoteDebugConfig, - ProfilerConfig, - ProfilerRuleConfiguration, TensorBoardOutputConfig, - ExperimentConfig, InfraCheckConfig, SessionChainingConfig, InputData, @@ -174,17 +170,16 @@ def test_train_with_default_params(mock_training_job, model_trainer): @patch("sagemaker.modules.train.model_trainer.TrainingJob") @patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_train_with_intelligent_defaults(mock_resolve_value_from_config, - mock_training_job, - model_trainer): - source_code_path = _simple_path(SAGEMAKER, - PYTHON_SDK, - MODULES, - MODEL_TRAINER, - "sourceCode") - - mock_resolve_value_from_config.side_effect = lambda **kwargs: {"command": "echo 'Hello World' && env"} \ - if kwargs['config_path'] == source_code_path else None +def test_train_with_intelligent_defaults( + mock_resolve_value_from_config, mock_training_job, model_trainer +): + source_code_path = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, "sourceCode") + + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + {"command": "echo 'Hello World' && env"} + if kwargs["config_path"] == source_code_path + else None + ) model_trainer.train() @@ -196,58 +191,59 @@ def test_train_with_intelligent_defaults(mock_resolve_value_from_config, @patch("sagemaker.modules.train.model_trainer.TrainingJob") @patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_train_with_intelligent_defaults_training_job_space(mock_resolve_value_from_config, - mock_training_job, - model_trainer): - mock_resolve_value_from_config.side_effect = lambda **kwargs: { - "instanceType": DEFAULT_INSTANCE_TYPE, - "instanceCount": 1, - "volumeSizeInGB": 30, - } if kwargs['config_path'] == TRAINING_JOB_RESOURCE_CONFIG_PATH else None +def test_train_with_intelligent_defaults_training_job_space( + mock_resolve_value_from_config, mock_training_job, model_trainer +): + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + { +
"instanceType": DEFAULT_INSTANCE_TYPE, + "instanceCount": 1, + "volumeSizeInGB": 30, + } + if kwargs["config_path"] == TRAINING_JOB_RESOURCE_CONFIG_PATH + else None + ) model_trainer.train() - mock_training_job.create.assert_called_once_with(training_job_name=ANY, - algorithm_specification=ANY, - hyper_parameters={}, - input_data_config=[], - resource_config=ResourceConfig( - volume_size_in_gb=30, - instance_type='ml.m5.xlarge', - instance_count=1, - volume_kms_key_id=None, - keep_alive_period_in_seconds=None, - instance_groups=None), - vpc_config=None, - session=ANY, - role_arn='arn:aws:iam::000000000000:' - 'role/test-role', - tags=None, - stopping_condition=StoppingCondition( - max_runtime_in_seconds=3600, - max_wait_time_in_seconds=None, - max_pending_time_in_seconds=None), - output_data_config=OutputDataConfig( - s3_output_path='s3://' - 'sagemaker-us-west-2' - '-000000000000/d' - 'ummy-image-job', - kms_key_id=None, compression_type='GZIP'), - checkpoint_config=None, - environment=None, - enable_managed_spot_training=None, - enable_inter_container_traffic_encryption=None, - enable_network_isolation=None, - debug_hook_config=None, - debug_rule_configurations=None, - remote_debug_config=None, - profiler_config=None, - profiler_rule_configurations=None, - tensor_board_output_config=None, - retry_strategy=None, - experiment_config=None, - infra_check_config=None, - session_chaining_config=None) + mock_training_job.create.assert_called_once_with( + training_job_name=ANY, + algorithm_specification=ANY, + hyper_parameters={}, + input_data_config=[], + resource_config=ResourceConfig( + volume_size_in_gb=30, + instance_type="ml.m5.xlarge", + instance_count=1, + volume_kms_key_id=None, + keep_alive_period_in_seconds=None, + instance_groups=None, + ), + vpc_config=None, + session=ANY, + role_arn="arn:aws:iam::000000000000:" "role/test-role", + tags=None, + stopping_condition=StoppingCondition( + max_runtime_in_seconds=3600, + max_wait_time_in_seconds=None, + max_pending_time_in_seconds=None, + ), + output_data_config=OutputDataConfig( + s3_output_path="s3://" "sagemaker-us-west-2" "-000000000000/d" "ummy-image-job", + kms_key_id=None, + compression_type="GZIP", + ), + checkpoint_config=None, + environment=None, + enable_managed_spot_training=None, + enable_inter_container_traffic_encryption=None, + enable_network_isolation=None, + remote_debug_config=None, + tensor_board_output_config=None, + retry_strategy=None, + infra_check_config=None, + session_chaining_config=None, + ) training_job_instance = mock_training_job.create.return_value training_job_instance.wait.assert_called_once_with(logs=True) @@ -336,181 +332,18 @@ def test_create_input_data_channel(mock_upload_data, model_trainer, test_case): assert channel.data_source.s3_data_source.s3_uri == expected_s3_uri -@patch("sagemaker.modules.train.model_trainer.TrainingJob") -@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_metric_settings(resolve_value_from_config, - mock_training_job, - modules_session): - image_uri = DEFAULT_IMAGE - role = DEFAULT_ROLE - metric_definition = MetricDefinition( - name="test-metric", - regex="test-regex", - ) - resolve_value_from_config.return_value = None - - model_trainer = ModelTrainer( - training_image=image_uri, - sagemaker_session=modules_session, - role=role, - ).with_metric_settings( - enable_sage_maker_metrics_time_series=True, metric_definitions=[metric_definition] - ) - - assert model_trainer._enable_sage_maker_metrics_time_series - assert 
model_trainer._metric_definitions == [metric_definition] - - with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: - mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" - model_trainer.train() - - mock_training_job.create.assert_called_once() - assert mock_training_job.create.call_args.kwargs[ - "algorithm_specification" - ].metric_definitions == [metric_definition] - - assert mock_training_job.create.call_args.kwargs[ - "algorithm_specification" - ].enable_sage_maker_metrics_time_series - - -@patch("sagemaker.modules.train.model_trainer.TrainingJob") -@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_debugger_settings(mock_resolve_value_from_config, - mock_training_job, - modules_session): - image_uri = DEFAULT_IMAGE - role = DEFAULT_ROLE - mock_resolve_value_from_config.return_value = None - - debug_hook_config = DebugHookConfig(s3_output_path="s3://dummy-bucket/dummy-prefix") - debug_rule_config = DebugRuleConfiguration( - rule_configuration_name="rule-name", - rule_evaluator_image=image_uri, - rule_parameters={"parameter": "value"}, - ) - profiler_config = ProfilerConfig(s3_output_path="s3://dummy-bucket/dummy-prefix") - profiler_rule_config = ProfilerRuleConfiguration( - rule_configuration_name="rule-name", - rule_evaluator_image=image_uri, - ) - tensor_board_output_config = TensorBoardOutputConfig( - s3_output_path="s3://dummy-bucket/dummy-prefix" - ) - - model_trainer = ModelTrainer( - training_image=image_uri, - sagemaker_session=modules_session, - role=role, - ).with_debugger_settings( - debug_hook_config=debug_hook_config, - debug_rule_configurations=debug_rule_config, - profiler_config=profiler_config, - profiler_rule_configurations=profiler_rule_config, - tensor_board_output_config=tensor_board_output_config, - ) - - assert model_trainer._debug_hook_config == debug_hook_config - assert model_trainer._debug_rule_configurations == debug_rule_config - - assert model_trainer._profiler_config == profiler_config - assert model_trainer._profiler_rule_configurations == profiler_rule_config - assert model_trainer._tensor_board_output_config == tensor_board_output_config - - with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: - mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" - model_trainer.train() - - mock_training_job.create.assert_called_once() - assert mock_training_job.create.call_args.kwargs["debug_hook_config"] == debug_hook_config - assert ( - mock_training_job.create.call_args.kwargs["debug_rule_configurations"] - == debug_rule_config - ) - assert mock_training_job.create.call_args.kwargs["profiler_config"] == profiler_config - assert ( - mock_training_job.create.call_args.kwargs["profiler_rule_configurations"] - == profiler_rule_config - ) - assert ( - mock_training_job.create.call_args.kwargs["tensor_board_output_config"] - == tensor_board_output_config - ) - - -@patch("sagemaker.modules.train.model_trainer.TrainingJob") -@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_additional_settings(mock_resolve_value_from_config, - mock_training_job, - modules_session): - image_uri = DEFAULT_IMAGE - role = DEFAULT_ROLE - mock_resolve_value_from_config.return_value = None - - retry_strategy = RetryStrategy( - maximum_retry_attempts=3, - ) - remote_debug_config = RemoteDebugConfig( - enable_remote_debug=True, - ) - experiment_config = ExperimentConfig( - experiment_name="experiment-name", - 
trial_name="trial-name", - ) - infra_check_config = InfraCheckConfig( - enable_infra_check=True, - ) - session_chaining_config = SessionChainingConfig( - enable_session_tag_chaining=True, - ) - model_trainer = ModelTrainer( - training_image=image_uri, - sagemaker_session=modules_session, - role=role, - ).with_additional_settings( - retry_strategy=retry_strategy, - experiment_config=experiment_config, - remote_debug_config=remote_debug_config, - infra_check_config=infra_check_config, - session_chaining_config=session_chaining_config, - ) - - assert model_trainer._retry_strategy == retry_strategy - assert model_trainer._experiment_config == experiment_config - assert model_trainer._infra_check_config == infra_check_config - assert model_trainer._session_chaining_config == session_chaining_config - assert model_trainer._remote_debug_config == remote_debug_config - - with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: - mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" - model_trainer.train() - - mock_training_job.create.assert_called_once() - - assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy - assert mock_training_job.create.call_args.kwargs["experiment_config"] == experiment_config - assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config - assert ( - mock_training_job.create.call_args.kwargs["session_chaining_config"] - == session_chaining_config - ) - assert ( - mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config - ) - - @pytest.mark.parametrize( "test_case", [ { "source_code": DEFAULT_SOURCE_CODE, - "distributed_runner": Torchrun(), + "distributed": Torchrun(), "expected_template": EXEUCTE_TORCHRUN_DRIVER, "expected_hyperparameters": {}, }, { "source_code": DEFAULT_SOURCE_CODE, - "distributed_runner": Torchrun( + "distributed": Torchrun( smp=SMP( hybrid_shard_degree=3, sm_activation_offloading=True, @@ -532,7 +365,7 @@ def test_additional_settings(mock_resolve_value_from_config, }, { "source_code": DEFAULT_SOURCE_CODE, - "distributed_runner": MPI( + "distributed": MPI( custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], ), "expected_template": EXECUTE_MPI_DRIVER, @@ -548,13 +381,13 @@ def test_additional_settings(mock_resolve_value_from_config, @patch("sagemaker.modules.train.model_trainer.TrainingJob") @patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") @patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") -def test_train_with_distributed_runner( - mock_resolve_value_from_config, - mock_tmp_dir, - mock_training_job, - test_case, - request, - modules_session +def test_train_with_distributed_config( + mock_resolve_value_from_config, + mock_tmp_dir, + mock_training_job, + test_case, + request, + modules_session, ): mock_resolve_value_from_config.return_value = None modules_session.upload_data.return_value = ( @@ -567,7 +400,7 @@ def test_train_with_distributed_runner( mock_tmp_dir.return_value = tmp_dir expected_train_script_path = os.path.join(tmp_dir.name, TRAIN_SCRIPT) - expected_runner_json_path = os.path.join(tmp_dir.name, DISTRIBUTED_RUNNER_JSON) + expected_runner_json_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) expected_source_code_json_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) try: @@ -575,7 +408,7 @@ def test_train_with_distributed_runner( sagemaker_session=modules_session, training_image=DEFAULT_IMAGE, source_code=test_case["source_code"], - 
distributed_runner=test_case["distributed_runner"], + distributed=test_case["distributed"], ) model_trainer.train() @@ -592,7 +425,7 @@ def test_train_with_distributed_runner( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed_runner"].model_dump(exclude_none=True) == ( + assert test_case["distributed"].model_dump(exclude_none=True) == ( json.loads(runner_json_content) ) assert os.path.exists(expected_source_code_json_path) @@ -618,3 +451,132 @@ def test_train_stores_created_training_job(mock_training_job, model_trainer): model_trainer.train(wait=False) assert model_trainer._latest_training_job is not None assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job") + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_tensorboard_output_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + tensorboard_output_config = TensorBoardOutputConfig( + s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}", + local_path="/opt/ml/output/tensorboard", + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_tensorboard_output_config(tensorboard_output_config) + + assert model_trainer._tensorboard_output_config == tensorboard_output_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["tensor_board_output_config"] + == tensorboard_output_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_retry_strategy(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + retry_strategy = RetryStrategy( + maximum_retry_attempts=3, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_retry_strategy(retry_strategy) + + assert model_trainer._retry_strategy == retry_strategy + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_infra_check_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + infra_check_config = InfraCheckConfig( + enable_infra_check=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_infra_check_config(infra_check_config) + + assert model_trainer._infra_check_config == infra_check_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_session_chaining_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = 
DEFAULT_ROLE + session_chaining_config = SessionChainingConfig( + enable_session_tag_chaining=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_session_chaining_config(session_chaining_config) + + assert model_trainer._session_chaining_config == session_chaining_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["session_chaining_config"] + == session_chaining_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_remote_debug_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + remote_debug_config = RemoteDebugConfig( + enable_remote_debug=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_remote_debug_config(remote_debug_config) + + assert model_trainer._remote_debug_config == remote_debug_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config + )
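For reference, the rename is purely at the API surface: callers now pass `distributed=` where they previously passed `distributed_runner=`, and the config serializes to the renamed `distributed.json` consumed by the container drivers. Below is a minimal sketch of the renamed interface, mirroring the notebook and integ-test snippets in this diff; the image URI, role, instance type, and source-code paths are placeholders, not values taken from this change.

```python
from sagemaker.modules.configs import Compute, SourceCode
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train.model_trainer import ModelTrainer

# A distributed config requires an entry_script in source_code;
# _validate_distributed_config() raises a ValueError otherwise.
source_code = SourceCode(
    source_dir="my-training-code",  # placeholder directory
    entry_script="train.py",        # placeholder script
)

model_trainer = ModelTrainer(
    training_image="<training-image-uri>",  # placeholder
    role="<execution-role-arn>",            # placeholder
    compute=Compute(instance_type="ml.g5.12xlarge", instance_count=2),
    source_code=source_code,
    distributed=Torchrun(),  # was: distributed_runner=Torchrun()
)
model_trainer.train()
```

Passing `sagemaker.modules.distributed.MPI()` instead selects the `mpirun`-based driver; both configs are written to `distributed.json`, which `read_distributed_json()` loads inside the container.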
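The removed `with_metric_settings`, `with_debugger_settings`, and `with_additional_settings` builders are replaced by one chainable setter per remaining private attribute. A short sketch of the new setters chained together, using the same constructor arguments as the unit tests above; it assumes these shapes are importable from `sagemaker.modules.configs` the way `model_trainer.py` imports them, and the bucket/prefix values are placeholders.

```python
from sagemaker.modules.configs import (
    InfraCheckConfig,
    RemoteDebugConfig,
    RetryStrategy,
    SessionChainingConfig,
    TensorBoardOutputConfig,
)
from sagemaker.modules.train.model_trainer import ModelTrainer

# Each with_* setter stores its config on a private attribute and returns
# self, so the calls chain; train() then forwards the values to
# TrainingJob.create().
model_trainer = (
    ModelTrainer(
        training_image="<training-image-uri>",  # placeholder
        role="<execution-role-arn>",            # placeholder
    )
    .with_tensorboard_output_config(
        TensorBoardOutputConfig(
            s3_output_path="s3://<bucket>/<prefix>",  # placeholder
            local_path="/opt/ml/output/tensorboard",
        )
    )
    .with_retry_strategy(RetryStrategy(maximum_retry_attempts=3))
    .with_infra_check_config(InfraCheckConfig(enable_infra_check=True))
    .with_session_chaining_config(SessionChainingConfig(enable_session_tag_chaining=True))
    .with_remote_debug_config(RemoteDebugConfig(enable_remote_debug=True))
)
```

Dropping the monolithic builders also removed the debugger, profiler, metrics, and experiment plumbing; only the five configs shown above retain private attributes on `ModelTrainer`.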