Skip to content

Commit

Permalink
Update ModelTrainer Interface Parameters (#1617)
Browse files (browse the repository at this point in the history)
  • Loading branch information
benieric authored and pintaoz-aws committed Dec 4, 2024
1 parent c015e3f commit 775a627
Show file tree
Hide file tree
Showing 16 changed files with 430 additions and 634 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ env/
**/_repack_script_launcher.sh
src/sagemaker/modules/train/container_drivers/sm_train.sh
src/sagemaker/modules/train/container_drivers/sourcecode.json
src/sagemaker/modules/train/container_drivers/distributed_runner.json
src/sagemaker/modules/train/container_drivers/distributed.json
tests/data/**/_repack_model.py
tests/data/experiment/sagemaker-dev-1.0.tar.gz
src/sagemaker/serve/tmp_workspace
70 changes: 14 additions & 56 deletions src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,58 +664,19 @@ def _simple_path(*args: str):
"minLength": 20,
"maxLength": 2048,
},
"baseJobName": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"sourceCode": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"distributed_runner": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"compute": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"networking": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"stoppingCondition": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"trainingImage": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"trainingImageConfig": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"algorithmName": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"outputDataConfig": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"trainingInputMode": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"environment": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"hyperparameters": {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
"baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"stoppingCondition": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"trainingImage": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"trainingImageConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"algorithmName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"outputDataConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
"hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
},
PROPERTIES: {
SCHEMA_VERSION: {
Expand Down Expand Up @@ -769,10 +730,7 @@ def _simple_path(*args: str):
},
},
},
MODEL_TRAINER: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: True
},
MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True},
ESTIMATOR: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
Expand Down
14 changes: 0 additions & 14 deletions src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,7 @@
TrainingImageConfig,
TrainingRepositoryAuthConfig,
Tag,
MetricDefinition,
DebugHookConfig,
CollectionConfiguration,
DebugRuleConfiguration,
ExperimentConfig,
InfraCheckConfig,
ProfilerConfig,
ProfilerRuleConfiguration,
RemoteDebugConfig,
SessionChainingConfig,
InstanceGroup,
Expand All @@ -69,14 +62,7 @@
"TrainingImageConfig",
"TrainingRepositoryAuthConfig",
"Tag",
"MetricDefinition",
"DebugHookConfig",
"CollectionConfiguration",
"DebugRuleConfiguration",
"ExperimentConfig",
"InfraCheckConfig",
"ProfilerConfig",
"ProfilerRuleConfiguration",
"RemoteDebugConfig",
"SessionChainingConfig",
"InstanceGroup",
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/modules/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

SOURCE_CODE_JSON = "sourcecode.json"
DISTRIBUTED_RUNNER_JSON = "distributed_runner.json"
DISTRIBUTED_JSON = "distributed.json"
TRAIN_SCRIPT = "sm_train.sh"

DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
Expand Down
17 changes: 9 additions & 8 deletions src/sagemaker/modules/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
return hyperparameters


class DistributedRunner(BaseModel):
"""Base class for DistributedRunner Class"""
class DistributedConfig(BaseModel):
"""Base class for distributed training configurations."""

_type: str = PrivateAttr()

Expand All @@ -84,11 +84,11 @@ def model_dump(self, *args, **kwargs):
return result


class Torchrun(DistributedRunner):
"""TorchDistributed.
class Torchrun(DistributedConfig):
"""Torchrun.
The Torchrun runner uses `torchrun` or `torch.distributed.launch` in the backend to
launch distributed training.
The Torchrun class configures a job that uses `torchrun` or
`torch.distributed.launch` in the backend to launch distributed training.
Attributes:
process_count_per_node (int):
Expand All @@ -104,10 +104,11 @@ class Torchrun(DistributedRunner):
smp: Optional["SMP"] = None


class MPI(DistributedRunner):
class MPI(DistributedConfig):
"""MPI.
The MPI runner uses `mpirun` in the backend to launch distributed training.
The MPI class configures a job that uses `mpirun` in the backend to launch
distributed training.
Attributes:
process_count_per_node (int):
Expand Down
8 changes: 0 additions & 8 deletions src/sagemaker/modules/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@
cat /opt/ml/input/config/inputdataconfig.json
echo
echo "/opt/ml/input/data/sm_drivers/sourcecode.json"
cat /opt/ml/input/data/sm_drivers/sourcecode.json
echo
echo "/opt/ml/input/data/sm_drivers/distributed_runner.json"
cat /opt/ml/input/data/sm_drivers/distributed_runner.json
echo
echo "Setting up environment variables"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py
source /opt/ml/input/sm_training.env
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@
" compute=compute,\n",
" hyperparameters=hyperparameters,\n",
" source_code=source_code,\n",
" distributed_runner=torchrun,\n",
" distributed=torchrun,\n",
" base_job_name=f\"{alias}-distributed-case-2\",\n",
")"
]
Expand Down Expand Up @@ -498,7 +498,7 @@
" hyperparameters=hyperparameters,\n",
" environment=env,\n",
" source_code=source_code,\n",
" distributed_runner=mpi,\n",
" distributed=mpi,\n",
" base_job_name=f\"{alias}-distributed-case-3\",\n",
")"
]
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/modules/train/container_drivers/mpi_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from utils import (
logger,
read_source_code_json,
read_distributed_runner_json,
read_distributed_json,
read_hyperparameters_json,
hyperparameters_to_cli_args,
get_process_count,
Expand Down Expand Up @@ -59,7 +59,7 @@ def main():
"""
source_code = read_source_code_json()
distribution = read_distributed_runner_json()
distribution = read_distributed_json()
hyperparameters = read_hyperparameters_json()

sm_current_host = os.environ["SM_CURRENT_HOST"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from utils import (
logger,
read_source_code_json,
read_distributed_runner_json,
read_distributed_json,
read_hyperparameters_json,
hyperparameters_to_cli_args,
get_process_count,
Expand Down Expand Up @@ -66,7 +66,7 @@ def setup_env():
def create_commands():
"""Create the Torch Distributed command to execute"""
source_code = read_source_code_json()
distribution = read_distributed_runner_json()
distribution = read_distributed_json()
hyperparameters = read_hyperparameters_json()

process_count = get_process_count(distribution)
Expand Down
14 changes: 7 additions & 7 deletions src/sagemaker/modules/train/container_drivers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

USER_CODE_PATH = "/opt/ml/input/data/sm_code"
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
DISTRIBUTED_RUNNER_JSON = "/opt/ml/input/data/sm_drivers/distributed_runner.json"
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"

HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"

Expand Down Expand Up @@ -79,14 +79,14 @@ def read_source_code_json(source_code_json: Dict[str, Any] = SOURCE_CODE_JSON):
return source_code_dict


def read_distributed_runner_json(distributed_json: Dict[str, Any] = DISTRIBUTED_RUNNER_JSON):
def read_distributed_json(distributed_json: Dict[str, Any] = DISTRIBUTED_JSON):
"""Read the distribution config json file."""
try:
with open(distributed_json, "r") as f:
distributed_runner_dict = json.load(f) or {}
distributed_dict = json.load(f) or {}
except FileNotFoundError:
distributed_runner_dict = {}
return distributed_runner_dict
distributed_dict = {}
return distributed_dict


def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAMETERS_JSON):
Expand All @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME
return hyperparameters_dict


def get_process_count(distributed_runner_dict: Dict[str, Any]) -> int:
def get_process_count(distributed_dict: Dict[str, Any]) -> int:
"""Get the number of processes to run on each node in the training job."""
return (
int(distributed_runner_dict.get("process_count_per_node", 0))
int(distributed_dict.get("process_count_per_node", 0))
or int(os.environ.get("SM_NUM_GPUS", 0))
or int(os.environ.get("SM_NUM_NEURONS", 0))
or 1
Expand Down
Loading

0 comments on commit 775a627

Please sign in to comment.