Just do the partitioning/shuffling in the local leader worker. #96

Merged
merged 9 commits on Dec 21, 2022
81 changes: 70 additions & 11 deletions streaming/base/dataset.py
@@ -8,7 +8,7 @@
from enum import IntEnum
from multiprocessing.shared_memory import SharedMemory
from threading import Thread
from time import sleep
from time import sleep, time
from typing import Any, Dict, Iterator, Optional, Tuple

import numpy as np
@@ -153,7 +153,7 @@ def __init__(self,
self.validate_hash = validate_hash or None

if tdist.is_available() and not tdist.is_initialized() and torch.cuda.is_available() and \
hasattr(os.environ, 'RANK'):
'RANK' in os.environ:
tdist.init_process_group('nccl')

# Seed is set below.
@@ -348,13 +348,18 @@ def _get_progress(self, world: World) -> Tuple[int, int]:

return epoch, sample_in_epoch

def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
def _get_partition(self,
world: World,
epoch: int,
sample_in_epoch: int,
timeout: Optional[float] = 60) -> NDArray[np.int64]:
"""Get this worker's partition of this epoch's sample space.

Args:
world (World): World state.
epoch (int): Which epoch it is.
sample_in_epoch (int): Where we are in the epoch.
timeout (float, optional): Max seconds to wait for the partitioning/shuffle to be generated. Defaults to ``60``.

Returns:
Optional[NDArray[np.int64]]: Our partition of the epoch.
@@ -366,14 +371,64 @@ def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
if self.shuffle_seed is None:
raise RuntimeError('Shuffle seed can never be None')

sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
world.num_nodes, world.ranks_per_node, world.workers_per_rank,
self.batch_size, sample_in_epoch)
if self.shuffle:
mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes, self.shuffle_seed,
epoch)
sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
return sample_ids[world.node, world.rank_of_node, world.worker_of_rank]
# Decide where to save shuffle data.
tmp_filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix,
'shuffle.npy.tmp')
filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix, 'shuffle.npy')

# In the local leader, generate this epoch's global sample ordering, then save to file.
# Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
# This operation is expensive.
if world.is_local_leader:
sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
world.num_nodes, world.ranks_per_node,
world.workers_per_rank, self.batch_size, sample_in_epoch)
if self.shuffle:
mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes,
self.shuffle_seed, epoch)
sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
sample_ids.tofile(tmp_filename)
os.rename(tmp_filename, filename)

# Everyone waits for the file to become populated.
t0 = time()
while True:
sleep(TICK)
if os.path.exists(filename):
sleep(TICK)
break
if timeout is not None:
dt = time() - t0
if timeout < dt:
raise RuntimeError('Partitioning and shuffling took too long, bailing out: ' +
f'{timeout:.3f} < {dt:.3f} sec.')

# Each worker loads its slice of the sample ID tensor to iterate through.
# Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
sample_id_nbytes = np.int64().nbytes
num_bytes = os.path.getsize(filename)
if num_bytes % sample_id_nbytes:
raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
num_samples = num_bytes // sample_id_nbytes
num_workers = world.num_nodes * world.ranks_per_node * world.workers_per_rank
if num_samples % num_workers:
raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
samples_per_worker = num_samples // num_workers
offset_in_bytes = world.worker * samples_per_worker * sample_id_nbytes
bytes_to_read = samples_per_worker * sample_id_nbytes
with open(filename, 'rb', 0) as fp:
fp.seek(offset_in_bytes)
data = fp.read(bytes_to_read)
sample_ids = np.frombuffer(data, np.int64)

# Wait for everyone to read their part.
self._worker_barrier(world.workers_per_node)

# Now clean up after ourselves.
if world.is_local_leader:
os.remove(filename)

return sample_ids
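
For reviewers following the new `_get_partition` flow: the local leader materializes the full sample-ID tensor once, writes it to a temp file, and atomically renames it into place, while every other worker simply polls for the file to appear. A minimal standalone sketch of that coordination pattern (the `TICK` value and helper names below are illustrative, not the library's API):

```python
import os
import time

import numpy as np

TICK = 0.07  # illustrative polling interval; the real constant lives in streaming.base


def leader_write(sample_ids: np.ndarray, filename: str) -> None:
    """Leader: dump the flat int64 tensor, then rename so readers never see a partial file."""
    tmp = filename + '.tmp'
    sample_ids.tofile(tmp)
    os.rename(tmp, filename)  # atomic on POSIX when tmp and filename share a filesystem


def follower_wait(filename: str, timeout: float = 60) -> None:
    """Non-leaders: poll until the leader's rename makes the file visible, or time out."""
    start = time.time()
    while not os.path.exists(filename):
        if time.time() - start > timeout:
            raise RuntimeError(f'Timed out after {timeout} sec waiting for {filename}')
        time.sleep(TICK)
```

Because the rename is atomic, the extra `sleep(TICK)` after the existence check in the diff reads as a defensive margin (e.g., for shared or network filesystems) rather than a correctness requirement.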

def _download_file(self, basename: str) -> str:
"""Safely download a file from remote to local cache.
@@ -384,17 +439,21 @@ def _download_file(self, basename: str) -> str:
Returns:
str: Local cache filename.
"""
# Calculate paths.
if self.remote is None:
remote = None
else:
remote = os.path.join(self.remote, self.split, basename)
local = os.path.join(self.local, self.split, basename)

# Attempt to download, possibly repeating on failure.
for _ in range(1 + self.download_retry):
try:
download(remote, local, self.download_timeout)
except:
continue
break

return local

def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_filename: str,
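The per-worker read at the end of `_get_partition` relies on the sample-ID tensor being written row-major, so each worker's samples form one contiguous run in the flat int64 file. A condensed sketch of that slice read, assuming the same layout (`read_worker_slice` is an illustrative helper, not part of the library):

```python
import os

import numpy as np


def read_worker_slice(filename: str, worker: int, num_workers: int) -> np.ndarray:
    """Read one worker's contiguous slice of the flat int64 sample-ID file.

    The file holds the (num nodes, ranks per node, workers per rank, samples per worker)
    tensor in row-major order, so slicing by flat worker index is a single seek + read.
    """
    itemsize = np.int64().nbytes  # 8 bytes per sample ID
    num_samples = os.path.getsize(filename) // itemsize
    samples_per_worker = num_samples // num_workers
    offset = worker * samples_per_worker * itemsize
    with open(filename, 'rb') as fp:
        fp.seek(offset)
        data = fp.read(samples_per_worker * itemsize)
    return np.frombuffer(data, np.int64)
```

With `world.worker` as the flat index and `num_workers = num_nodes * ranks_per_node * workers_per_rank`, this mirrors the offset arithmetic added in the diff.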
5 changes: 2 additions & 3 deletions tests/test_reader.py
@@ -138,15 +138,14 @@ def test_reader_download_fail(mds_dataset_dir: Any, missing_file: str):
os.remove(os.path.join(remote_dir, 'shard.00000.mds'))

# Build and iterate over a StreamingDataset
try:
with pytest.raises(FileNotFoundError) as exc_info:
dataset = StreamingDataset(local=local_dir,
remote=remote_dir,
shuffle=False,
download_timeout=1)
for _ in dataset:
pass
except FileNotFoundError as e:
logger.debug(f'Successfully raised error: {e}')
assert exc_info.match(r'.*No such file or directory*')


@pytest.mark.parametrize('created_ago', [0.5, 1.0])