Skip to content

Commit

Permalink
less strict dist formatting (mosaicml#1535)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanlint authored and Bandish Shah committed Sep 19, 2022
1 parent 8dd05c2 commit e8f16a2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 18 deletions.
8 changes: 6 additions & 2 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
log = logging.getLogger(__name__)


class MissingEnvironmentError(RuntimeError):
    """Raised when torch distributed is initialized but a required environment variable is not set.

    Derives from ``RuntimeError`` so that handlers written as ``except RuntimeError`` keep
    catching this condition, while new code can catch this narrower type to tell a missing
    environment variable apart from other runtime failures.
    """


def _get_distributed_config_var(
env_var: str,
human_name: str,
Expand All @@ -91,8 +95,8 @@ def _get_distributed_config_var(
return int(os.environ[env_var])

if dist.is_initialized():
raise RuntimeError('Torch distributed is initialized but environment variable '
f'{env_var} is not set.')
raise MissingEnvironmentError('Torch distributed is initialized but environment variable '
f'{env_var} is not set.')

return default

Expand Down
48 changes: 32 additions & 16 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
import tempfile
import uuid
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import requests
import tqdm
Expand All @@ -34,6 +34,34 @@
]


def _get_dist_config(strict: bool = True) -> Dict[str, Any]:
    """Collect the distributed settings (rank, world size, etc.) into a dict.

    Args:
        strict (bool): When ``True``, a setting that cannot be resolved (e.g. its
            environment variable is not set) re-raises the underlying
            ``dist.MissingEnvironmentError``. When ``False``, that setting is left
            out of the result and the remaining ones are still collected.

    Returns:
        Dict[str, Any]: Maps each resolved setting name to its value.
    """
    getters = (
        ('rank', dist.get_global_rank),
        ('local_rank', dist.get_local_rank),
        ('world_size', dist.get_world_size),
        ('local_world_size', dist.get_local_world_size),
        ('node_rank', dist.get_node_rank),
    )

    resolved: Dict[str, Any] = {}
    for key, getter in getters:
        try:
            # The key is only stored if the getter returns normally.
            resolved[key] = getter()
        except dist.MissingEnvironmentError:
            if strict:
                raise

    return resolved


def is_tar(name: Union[str, pathlib.Path]) -> bool:
"""Returns whether ``name`` has a tar-like extension.
Expand Down Expand Up @@ -89,11 +117,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
pattern = pattern.replace(f'{{{unit}}}', f'(?P<{unit}>\\d+)')

# Format rank information
pattern = pattern.format(rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
world_size=dist.get_world_size(),
local_world_size=dist.get_local_world_size(),
node_rank=dist.get_node_rank())
pattern = pattern.format(**_get_dist_config(strict=False))

template = re.compile(pattern)

Expand Down Expand Up @@ -143,11 +167,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object) -> str:  # noqa: D103
    # Fill ``format_str`` with the run name, the distributed settings, and any
    # caller-supplied extra fields.
    #
    # The distributed fields (rank, local_rank, world_size, local_world_size,
    # node_rank) are supplied only through ``_get_dist_config(strict=False)``.
    # Passing them explicitly as well would give ``str.format`` the same keyword
    # twice, which Python rejects with a ``TypeError``.
    #
    # With ``strict=False`` a setting that is unavailable is simply absent from
    # the mapping, so formatting only fails (``KeyError``) when ``format_str``
    # actually references that missing field.
    formatted_str = format_str.format(
        run_name=run_name,
        **_get_dist_config(strict=False),
        **extra_format_kwargs,
    )
    return formatted_str
Expand Down Expand Up @@ -240,11 +260,6 @@ def format_name_with_dist_and_time(
): # noqa: D103
formatted_str = format_str.format(
run_name=run_name,
rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
world_size=dist.get_world_size(),
local_world_size=dist.get_local_world_size(),
node_rank=dist.get_node_rank(),
epoch=int(timestamp.epoch),
batch=int(timestamp.batch),
batch_in_epoch=int(timestamp.batch_in_epoch),
Expand All @@ -255,6 +270,7 @@ def format_name_with_dist_and_time(
total_wct=timestamp.total_wct.total_seconds(),
epoch_wct=timestamp.epoch_wct.total_seconds(),
batch_wct=timestamp.batch_wct.total_seconds(),
**_get_dist_config(strict=False),
**extra_format_kwargs,
)
return formatted_str
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from composer.utils.file_helpers import (ensure_folder_has_no_conflicting_files, ensure_folder_is_empty,
format_name_with_dist, format_name_with_dist_and_time, get_file, is_tar)
from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore
from tests.common.markers import world_size


@pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.')
Expand Down Expand Up @@ -153,6 +154,28 @@ def test_format_name_with_dist():
assert format_name_with_dist(format_str, 'awesome_run', extra=42) == expected_str


@world_size(2)
def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
    """Formatting completes without ``NODE_RANK`` when the format string never uses ``node_rank``."""
    # Remove the variable up front; the template below does not reference it.
    monkeypatch.delenv('NODE_RANK')

    fields = ('run_name', 'world_size')
    # Builds 'run_name={run_name},world_size={world_size}'.
    template = ','.join(f'{field}={{{field}}}' for field in fields)

    formatted = format_name_with_dist(template, 'awesome_run')
    assert formatted == 'run_name=awesome_run,world_size=2'


@world_size(2)
def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
    """Node rank is deleted, but also in the format string, so expect error."""
    fields = ['run_name', 'node_rank']
    # Builds 'run_name={run_name},node_rank={node_rank}'.
    format_str = ','.join(f'{x}={{{x}}}' for x in fields)

    monkeypatch.delenv('NODE_RANK')
    # The call itself must raise, so there is no return value to compare against.
    with pytest.raises(KeyError):
        format_name_with_dist(format_str, 'awesome_run')


def test_format_name_with_dist_and_time():
vars = [
'run_name',
Expand Down

0 comments on commit e8f16a2

Please sign in to comment.