diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index bf2ed2698..4ed36140a 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -8,7 +8,7 @@
 from enum import IntEnum
 from multiprocessing.shared_memory import SharedMemory
 from threading import Thread
-from time import sleep
+from time import sleep, time
 from typing import Any, Dict, Iterator, Optional, Tuple
 
 import numpy as np
@@ -153,7 +153,7 @@ def __init__(self,
         self.validate_hash = validate_hash or None
 
         if tdist.is_available() and not tdist.is_initialized() and torch.cuda.is_available() and \
-                hasattr(os.environ, 'RANK'):
+                'RANK' in os.environ:
             tdist.init_process_group('nccl')
 
         # Seed is set below.
@@ -348,13 +348,18 @@ def _get_progress(self, world: World) -> Tuple[int, int]:
 
         return epoch, sample_in_epoch
 
-    def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
+    def _get_partition(self,
+                       world: World,
+                       epoch: int,
+                       sample_in_epoch: int,
+                       timeout: Optional[float] = 60) -> NDArray[np.int64]:
         """Get this worker's partition of this epoch's sample space.
 
         Args:
             world (World): World state.
             epoch (int): Which epoch it is.
             sample_in_epoch (int): Where we are in the epoch.
+            timeout (float): Max seconds to wait for the partitioning/shuffle to be generated.
 
         Returns:
             Optional[NDArray[np.int64]]: Our partition of the epoch.
@@ -366,14 +371,64 @@ def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDAr
         if self.shuffle_seed is None:
             raise RuntimeError('Shuffle seed can never be None')
 
-        sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
-                                    world.num_nodes, world.ranks_per_node, world.workers_per_rank,
-                                    self.batch_size, sample_in_epoch)
-        if self.shuffle:
-            mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes, self.shuffle_seed,
-                                  epoch)
-            sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
-        return sample_ids[world.node, world.rank_of_node, world.worker_of_rank]
+        # Decide where to save shuffle data.
+        tmp_filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix,
+                                    'shuffle.npy.tmp')
+        filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix, 'shuffle.npy')
+
+        # In the local leader, generate this epoch's global sample ordering, then save to file.
+        # Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
+        # This operation is expensive.
+        if world.is_local_leader:
+            sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
+                                        world.num_nodes, world.ranks_per_node,
+                                        world.workers_per_rank, self.batch_size, sample_in_epoch)
+            if self.shuffle:
+                mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes,
+                                      self.shuffle_seed, epoch)
+                sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
+            sample_ids.tofile(tmp_filename)
+            os.rename(tmp_filename, filename)
+
+        # Everyone waits for the file to become populated.
+        t0 = time()
+        while True:
+            sleep(TICK)
+            if os.path.exists(filename):
+                sleep(TICK)
+                break
+            if timeout is not None:
+                dt = time() - t0
+                if timeout < dt:
+                    raise RuntimeError('Partitioning and shuffling took too long, bailing out: ' +
+                                       f'{timeout:.3f} < {dt:.3f} sec.')
+
+        # Each worker loads its slice of the sample ID tensor to iterate through.
+        # Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
+        sample_id_nbytes = np.int64().nbytes
+        num_bytes = os.path.getsize(filename)
+        if num_bytes % sample_id_nbytes:
+            raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
+        num_samples = num_bytes // sample_id_nbytes
+        num_workers = world.num_nodes * world.ranks_per_node * world.workers_per_rank
+        if num_samples % num_workers:
+            raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
+        samples_per_worker = num_samples // num_workers
+        offset_in_bytes = world.worker * samples_per_worker * sample_id_nbytes
+        bytes_to_read = samples_per_worker * sample_id_nbytes
+        with open(filename, 'rb', 0) as fp:
+            fp.seek(offset_in_bytes)
+            data = fp.read(bytes_to_read)
+        sample_ids = np.frombuffer(data, np.int64)
+
+        # Wait for everyone to read their part.
+        self._worker_barrier(world.workers_per_node)
+
+        # Now clean up after ourselves.
+        if world.is_local_leader:
+            os.remove(filename)
+
+        return sample_ids
 
     def _download_file(self, basename: str) -> str:
         """Safely download a file from remote to local cache.
@@ -384,17 +439,21 @@ def _download_file(self, basename: str) -> str:
         Returns:
             str: Local cache filename.
         """
+        # Calculate paths.
         if self.remote is None:
             remote = None
         else:
             remote = os.path.join(self.remote, self.split, basename)
         local = os.path.join(self.local, self.split, basename)
+
+        # Attempt to download, possibly repeating on failure.
         for _ in range(1 + self.download_retry):
             try:
                 download(remote, local, self.download_timeout)
             except:
                 continue
             break
+
         return local
 
     def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_filename: str,
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 5cb262e3b..a03879740 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -138,15 +138,14 @@ def test_reader_download_fail(mds_dataset_dir: Any, missing_file: str):
         os.remove(os.path.join(remote_dir, 'shard.00000.mds'))
 
     # Build and iterate over a StreamingDataset
-    try:
+    with pytest.raises(FileNotFoundError) as exc_info:
         dataset = StreamingDataset(local=local_dir,
                                    remote=remote_dir,
                                    shuffle=False,
                                    download_timeout=1)
         for _ in dataset:
             pass
-    except FileNotFoundError as e:
-        logger.debug(f'Successfully raised error: {e}')
+    assert exc_info.match(r'.*No such file or directory.*')
 
 
 @pytest.mark.parametrize('created_ago', [0.5, 1.0])