Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle training.model.tokenizer.model CLI arg #30

Merged
merged 5 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
from .slurm_install_strategy import NeMoLauncherSlurmInstallStrategy
from .template import NeMoLauncher

REQUIRE_ENV_VARS = [
"NCCL_SOCKET_IFNAME",
"NCCL_IB_GID_INDEX",
"NCCL_IB_TC",
"NCCL_IB_QPS_PER_CONNECTION",
"UCX_IB_GID_INDEX",
"NCCL_IB_ADAPTIVE_ROUTING",
"NCCL_IB_SPLIT_DATA_ON_QPS",
"NCCL_IBEXT_DISABLE",
]


@StrategyRegistry.strategy(CommandGenStrategy, [SlurmSystem], [NeMoLauncher])
class NeMoLauncherSlurmCommandGenStrategy(SlurmCommandGenStrategy):
Expand Down Expand Up @@ -52,18 +63,7 @@ def gen_exec_command(
nodes: List[str],
) -> str:
# Ensure required environment variables are included
required_env_vars = [
"NCCL_SOCKET_IFNAME",
"NCCL_IB_GID_INDEX",
"NCCL_IB_TC",
"NCCL_IB_QPS_PER_CONNECTION",
"UCX_IB_GID_INDEX",
"NCCL_IB_ADAPTIVE_ROUTING",
"NCCL_IB_SPLIT_DATA_ON_QPS",
"NCCL_IBEXT_DISABLE",
]

for key in required_env_vars:
for key in REQUIRE_ENV_VARS:
if key not in extra_env_vars:
extra_env_vars[key] = self.slurm_system.global_env_vars[key]

Expand Down Expand Up @@ -98,6 +98,9 @@ def gen_exec_command(

if extra_cmd_args:
full_cmd += " " + extra_cmd_args
if "training.model.tokenizer.model" in extra_cmd_args:
tokenizer_path = extra_cmd_args.split("training.model.tokenizer.model=")[1].split(" ")[0]
full_cmd += " " + f"container_mounts=[{tokenizer_path}:{tokenizer_path}]"

env_vars_str = " ".join(f"{key}={value}" for key, value in extra_env_vars.items())
full_cmd = f"{env_vars_str} {full_cmd}" if env_vars_str else full_cmd
Expand Down
53 changes: 52 additions & 1 deletion tests/test_slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from cloudai.schema.system.slurm import SlurmNode, SlurmNodeState
from cloudai.schema.system.slurm.strategy import SlurmCommandGenStrategy
from cloudai.schema.test_template.nccl_test.slurm_command_gen_strategy import NcclTestSlurmCommandGenStrategy
from cloudai.schema.test_template.nemo_launcher.slurm_command_gen_strategy import NeMoLauncherSlurmCommandGenStrategy
from cloudai.schema.test_template.nemo_launcher.slurm_command_gen_strategy import (
REQUIRE_ENV_VARS,
NeMoLauncherSlurmCommandGenStrategy,
)


@pytest.fixture
Expand Down Expand Up @@ -94,3 +97,51 @@ def test_docker_image_url_is_file(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGe
Path(nemo_cmd_gen.final_cmd_args["docker_image_url"]).touch()
nemo_cmd_gen.set_container_arg()
assert nemo_cmd_gen.final_cmd_args["container"] == nemo_cmd_gen.final_cmd_args["docker_image_url"]


class TestNeMoLauncherSlurmCommandGenStrategy__GenExecCommand:
@pytest.fixture
def nemo_cmd_gen(self, slurm_system: SlurmSystem) -> NeMoLauncherSlurmCommandGenStrategy:
env_vars = {"TEST_VAR": "VALUE"}
cmd_args = {"test_arg": "test_value"}
strategy = NeMoLauncherSlurmCommandGenStrategy(slurm_system, env_vars, cmd_args)
return strategy

def test_raises_if_required_env_var_missed(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
with pytest.raises(KeyError) as exc_info:
nemo_cmd_gen.gen_exec_command(
env_vars={}, cmd_args={}, extra_env_vars={}, extra_cmd_args="", output_path="", nodes=[]
)
assert REQUIRE_ENV_VARS[0] in str(exc_info.value)

def test_extra_env_vars_added(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
extra_env_vars = {v: "fake" for v in REQUIRE_ENV_VARS}
cmd_args = {
"docker_image_url": "fake",
"repository_url": "fake",
"repository_commit_hash": "fake",
}
cmd = nemo_cmd_gen.gen_exec_command(
env_vars={}, cmd_args=cmd_args, extra_env_vars=extra_env_vars, extra_cmd_args="", output_path="", nodes=[]
)

for k, v in extra_env_vars.items():
assert f"{k}={v}" in cmd

def test_tokenizer_handled(self, nemo_cmd_gen: NeMoLauncherSlurmCommandGenStrategy):
extra_env_vars = {v: "fake" for v in REQUIRE_ENV_VARS}
cmd_args = {
"docker_image_url": "fake",
"repository_url": "fake",
"repository_commit_hash": "fake",
}
cmd = nemo_cmd_gen.gen_exec_command(
env_vars={},
cmd_args=cmd_args,
extra_env_vars=extra_env_vars,
extra_cmd_args="training.model.tokenizer.model=value",
output_path="",
nodes=[],
)

assert "container_mounts=[value:value]" in cmd