Replicating samples across devices (SP / TP enablement) #597

Merged: 13 commits, Feb 22, 2024
2 changes: 1 addition & 1 deletion streaming/base/dataloader.py
@@ -71,7 +71,7 @@ def state_dict(self) -> Optional[Dict[str, Any]]:
Optional[Dict[str, Any]]: The state, if a streaming dataset.
"""
if isinstance(self.dataset, StreamingDataset):
world = World()
world = World.detect()
num_samples = self.num_samples_yielded * world.num_ranks
return self.dataset.state_dict(num_samples, False)
return None
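For context, the streaming dataloader's state_dict shown above scales the per-rank count of yielded samples up to a global count before delegating to the dataset. A minimal sketch of that arithmetic, using hypothetical numbers (8 ranks, 100 samples yielded on this rank):

num_ranks = 8                                  # world.num_ranks from World.detect()
num_samples_yielded = 100                      # yielded by this rank's dataloader
num_samples = num_samples_yielded * num_ranks  # 800 samples counted across all ranks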
29 changes: 19 additions & 10 deletions streaming/base/dataset.py
@@ -308,6 +308,10 @@ class StreamingDataset(Array, IterableDataset):
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going if ``True`` or raise an error
if ``False``. Defaults to ``False``.
world (World, optional): Override the world state (otherwise it is detected from standard
env vars). Defaults to ``None``.
tensor_parallelism (int, optional): Tensor parallelism to apply to the given or detected
world state. Defaults to ``None``.
"""

def __init__(self,
@@ -333,7 +337,9 @@ def __init__(self,
shuffle_seed: int = 9176,
shuffle_block_size: Optional[int] = None,
batching_method: str = 'random',
allow_unsafe_types: bool = False) -> None:
allow_unsafe_types: bool = False,
world: Optional[World] = None,
tensor_parallelism: Optional[int] = None) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
self.cache_limit = cache_limit
@@ -349,6 +355,16 @@ def __init__(self,
self.batching_method = batching_method
self.allow_unsafe_types = allow_unsafe_types

# Initialize the World context.
#
# Beware: This information is for the per-rank process. DataLoader worker processes may see
# different values for these fields. We are saving the rank World here because we cannot
# instantiate a World inside the StreamingDataset destructor.
world = world or World.detect()
if tensor_parallelism is not None:
world = world.tensor_parallel(tensor_parallelism)
self._rank_world = world

# Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the
# number of physical nodes of the initial run in the _resume function.
self.initial_physical_nodes = None
@@ -443,13 +459,6 @@ def __init__(self,
self.streams = streams
self.num_streams = len(streams)

# Initialize the World context.
#
# Beware: This information is for the per-rank process. DataLoader worker processes may see
# different values for these fields. We are saving the rank World here because we cannot
# instantiate a World inside the StreamingDataset destructor.
self._rank_world = world = World()

# Download each stream's index, load their shards, and map streams <-> shards.
self.num_samples = 0
self.shards = []
@@ -771,7 +780,7 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state.
"""
world = World()
world = self._rank_world
epoch = self.next_epoch - 1
epoch, offset = self._resume(world, epoch)
if from_beginning:
@@ -1437,7 +1446,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:

# Discover where we left off, if there is a checkpoint, or start at the next epoch.
# Also pre-increment the epoch counter.
world = World()
world = self._rank_world.detect_workers()
epoch, sample_in_epoch = self._resume_incr_epoch(world)

# Get this worker's partition of samples to process.
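The new constructor arguments are easiest to see in use. Below is a minimal usage sketch, not taken from this PR: the remote/local paths and batch size are placeholders, and it assumes the package-level StreamingDataset import. With tensor_parallelism=2, each pair of adjacent ranks on a node collapses to a single data-parallel slot and therefore receives the same samples (see World.tensor_parallel in the world.py changes below).

from streaming import StreamingDataset

# Hypothetical paths and batch size; only tensor_parallelism (and world) are new in this PR.
dataset = StreamingDataset(
    remote='s3://my-bucket/my-dataset',  # placeholder remote
    local='/tmp/my-dataset',             # placeholder local cache
    batch_size=8,
    shuffle=True,
    tensor_parallelism=2,  # adjacent rank pairs see identical samples
)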
118 changes: 98 additions & 20 deletions streaming/base/world.py
@@ -3,7 +3,10 @@

"""Information about nodes, ranks, and workers."""

from typing import Any, Dict, Tuple

from torch.utils.data import get_worker_info
from typing_extensions import Self

from streaming.base import distributed as dist

@@ -37,28 +40,103 @@ class World:
- is_local_leader
"""

def __init__(self):
self.rank = dist.get_rank()
self.num_ranks = dist.get_world_size()
self.ranks_per_node = dist.get_local_world_size()
self.rank_of_node = self.rank % self.ranks_per_node
def __init__(
self,
num_nodes: int,
ranks_per_node: int,
workers_per_rank: int,
worker: int,
) -> None:
self.node = worker // (ranks_per_node * workers_per_rank)
self.num_nodes = num_nodes
self.is_multinode = 1 < num_nodes

self.rank = worker // workers_per_rank
self.num_ranks = num_nodes * ranks_per_node
self.rank_of_node = self.rank % ranks_per_node
self.ranks_per_node = ranks_per_node

self.worker = worker
self.num_workers = num_nodes * ranks_per_node * workers_per_rank
self.worker_of_node = self.worker % (ranks_per_node * workers_per_rank)
self.workers_per_node = ranks_per_node * workers_per_rank
self.worker_of_rank = self.worker % workers_per_rank
self.workers_per_rank = workers_per_rank
self.is_leader = not worker
self.is_local_leader = not self.worker_of_node

def to_json(self) -> Dict[str, Any]:
"""Get a JSON version of this config.

Returns:
Dict[str, Any]: JSON config.
"""
return dict(self.__dict__)

self.node = self.rank // self.ranks_per_node
self.num_nodes = self.num_ranks // self.ranks_per_node
self.is_multinode = 1 < self.num_nodes
@classmethod
def _get_worker_info(cls) -> Tuple[int, int]:
"""Get worker info, or default to 0 of 1.

Returns:
Tuple[int, int]: Worker ID out of how many workers.
"""
info = get_worker_info()
if info:
self.worker_of_rank = info.id
self.workers_per_rank = info.num_workers
ret = info.id, info.num_workers
else:
self.worker_of_rank = 0
self.workers_per_rank = 1

self.worker = self.rank * self.workers_per_rank + self.worker_of_rank
self.num_workers = self.num_ranks * self.workers_per_rank
self.worker_of_node = self.rank_of_node * self.workers_per_rank + self.worker_of_rank
self.workers_per_node = self.ranks_per_node * self.workers_per_rank

self.is_leader = not self.worker
self.is_local_leader = not self.worker_of_node
ret = 0, 1
return ret

@classmethod
def detect(cls) -> Self:
"""Detect the world state.

Returns:
Self: A new World state object according to dist and get_worker_info().
"""
rank = dist.get_rank()
ranks_per_node = dist.get_local_world_size()
num_nodes = dist.get_world_size() // ranks_per_node
worker_of_rank, workers_per_rank = cls._get_worker_info()
worker = rank * workers_per_rank + worker_of_rank
return cls(num_nodes, ranks_per_node, workers_per_rank, worker)

def tensor_parallel(self, ratio: int) -> Self:
"""Get a copy of this world state with the given tensor paralellism.

Args:
ratio (int): Ratio of tensor parallelism.

Returns:
Self: A new tensor-parallel version of this World state object.
"""
if ratio <= 0:
raise ValueError(f'Tensor parallelism ratio must be positive, but got {ratio}.')

if self.ranks_per_node % ratio:
raise ValueError(f'Ranks per node ({self.ranks_per_node}) must be divisible by the tensor parallelism ratio ({ratio}).')

rank_of_node = self.rank_of_node // ratio
ranks_per_node = self.ranks_per_node // ratio
# Rebuild the global worker ID from the node so that node membership survives the remap.
worker = (self.node * ranks_per_node + rank_of_node) * self.workers_per_rank + self.worker_of_rank
return World(
num_nodes=self.num_nodes,
ranks_per_node=ranks_per_node,
workers_per_rank=self.workers_per_rank,
worker=worker,
)

def detect_workers(self) -> Self:
"""Get a copy of this world state with the worker information newly detected.

Returns:
Self: A new workers-newly-detected version of this World state object.
"""
worker_of_rank, workers_per_rank = self._get_worker_info()
worker = self.rank * workers_per_rank + worker_of_rank
return World(
num_nodes=self.num_nodes,
ranks_per_node=self.ranks_per_node,
workers_per_rank=workers_per_rank,
worker=worker,
)
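The reworked World makes the bookkeeping easy to check by hand, since construction is now separate from detection. Below is a minimal sketch with hypothetical values (2 nodes, 8 ranks per node, 2 dataloader workers per rank), building the state directly instead of calling detect() so it does not require torch.distributed to be initialized.

from streaming.base.world import World

# Hypothetical topology: 2 nodes x 8 ranks x 2 workers per rank, seen from global worker 5.
world = World(num_nodes=2, ranks_per_node=8, workers_per_rank=2, worker=5)
assert world.rank == 2          # worker 5 belongs to rank 5 // 2
assert world.num_workers == 32  # 2 nodes * 8 ranks * 2 workers
assert world.node == 0          # worker 5 // (8 * 2)

# Tensor parallelism of 2 halves the effective ranks per node, so adjacent
# ranks map onto the same sample partition.
tp = world.tensor_parallel(2)
assert tp.ranks_per_node == 4
assert tp.num_ranks == 8        # 2 nodes * 4 effective ranks

# Inside a dataloader worker process, detect_workers() re-reads get_worker_info()
# to refresh the worker fields while keeping the rank and node topology fixed.
refreshed = world.detect_workers()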
12 changes: 6 additions & 6 deletions tests/test_shared.py
@@ -15,7 +15,7 @@
def test_get_shm_prefix(local_remote_dir: Tuple[str, str]):
local, remote = local_remote_dir

_, _ = get_shm_prefix(streams_local=[local], streams_remote=[remote], world=World())
_, _ = get_shm_prefix(streams_local=[local], streams_remote=[remote], world=World.detect())


@pytest.mark.usefixtures('local_remote_dir')
Expand All @@ -24,25 +24,25 @@ def test_get_shm_prefix_same_local_dir(local_remote_dir: Tuple[str, str]):
with pytest.raises(ValueError, match='Reused local directory.*Provide a different one.'):
_, _ = get_shm_prefix(streams_local=[local, local],
streams_remote=[remote, remote],
world=World())
world=World.detect())


@pytest.mark.usefixtures('local_remote_dir')
def test_get_shm_prefix_same_split_dir(local_remote_dir: Tuple[str, str]):
local, remote = local_remote_dir
_, _ = get_shm_prefix(streams_local=[local, remote],
streams_remote=[local, remote],
world=World())
world=World.detect())
with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'):
_, _ = get_shm_prefix(streams_local=[local, remote],
streams_remote=[local, remote],
world=World())
world=World.detect())


def test_same_local_remote_none(local_remote_dir: Tuple[str, str]):
local, _ = local_remote_dir
_, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())
_, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())
_, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World.detect())
_, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World.detect())


@pytest.mark.parametrize('from_beginning', [True, False])