diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index e43e212e8..292405528 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -922,14 +922,18 @@ def resample_streams(
         sample_ids = np.concatenate(sample_ids).astype(np.int64)
         return shuffle_units, sample_ids
 
-    def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]:
+    def _share_work(
+        self,
+        sample_ids: NDArray[np.int64],
+    ) -> Tuple[SharedMemory, Optional[SharedMemory]]:
         """Put an epoch's sample ordering into shared memory.
 
         Args:
             sample_ids (NDArray[np.int64]): Sample IDs.
 
         Returns:
-            Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data.
+            Tuple[SharedMemory, Optional[SharedMemory]]: Shared memory arrays containing shape and
+                data, if present.
         """
         ndim = 5
@@ -945,19 +949,26 @@ def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, Shar
         shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
         shape_shm.buf[:size] = np.array(sample_ids.shape, np.int64).tobytes()
 
-        # Save the generated epoch data to shared memory.
-        name = _get_path(self._shm_prefix_int, EPOCH_DATA)
-        size = sample_ids.size * np.int64().nbytes
-        data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
-        data_shm.buf[:size] = sample_ids.tobytes()
+        if sample_ids.size > 0:
+            # Save the generated epoch data to shared memory, but only if the sample partition is
+            # non-empty. Otherwise, the end of the epoch has been reached.
+            name = _get_path(self._shm_prefix_int, EPOCH_DATA)
+            size = sample_ids.size * np.int64().nbytes
+            data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
+            data_shm.buf[:size] = sample_ids.tobytes()
 
-        return shape_shm, data_shm
+            return shape_shm, data_shm
 
-    def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
+        else:
+
+            return shape_shm, None
+
+    def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]:
         """Get an epoch's sample ordering from shared memory.
 
         Returns:
-            NDArray[np.int64]: Sample IDs.
+            Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: Sample IDs, shared
+                memory array for shape, and shared memory array for data, if present.
         """
         ndim = 5
@@ -967,13 +978,22 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
         shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
         shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64))
 
-        # Attach to the generated epoch data in shared memory.
-        name = _get_path(self._shm_prefix_int, EPOCH_DATA)
-        size = int(np.prod(shape)) * np.int64().nbytes
-        data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
-        sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
+        num_elements = int(np.prod(shape))
+
+        if num_elements > 0:
+            # Attach to the generated epoch data in shared memory, but only if the sample partition
+            # is non-empty. Otherwise, the end of the epoch has been reached.
+            name = _get_path(self._shm_prefix_int, EPOCH_DATA)
+            size = num_elements * np.int64().nbytes
+            data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
+            sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
+
+            return sample_ids, shape_shm, data_shm
+
+        else:
 
-        return sample_ids, shape_shm, data_shm
+            sample_ids = np.empty(shape=shape, dtype=np.int64)
+            return sample_ids, shape_shm, None
 
     def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
         """Get this worker's partition of this epoch's sample space.
@@ -1025,7 +1045,9 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
 
         # Now clean up after ourselves.
         shape_shm.cleanup()
-        data_shm.cleanup()
+        # Can be None if the sample partition was empty.
+        if data_shm is not None:
+            data_shm.cleanup()
 
         return worker_sample_ids
 
diff --git a/streaming/base/partition/orig.py b/streaming/base/partition/orig.py
index dff6d7878..ce8832cf5 100644
--- a/streaming/base/partition/orig.py
+++ b/streaming/base/partition/orig.py
@@ -46,7 +46,7 @@ def get_partitions_orig(num_samples: int,
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
             batches per worker, batch size).
     """
-    if num_samples <= drop_first:
+    if num_samples < drop_first:
         raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
                          f'({num_samples})')
 
diff --git a/streaming/base/partition/relaxed.py b/streaming/base/partition/relaxed.py
index e84bb7efc..1812b977a 100644
--- a/streaming/base/partition/relaxed.py
+++ b/streaming/base/partition/relaxed.py
@@ -49,7 +49,7 @@ def get_partitions_relaxed(num_samples: int,
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
             batches per worker, batch size).
     """
-    if num_samples <= drop_first:
+    if num_samples < drop_first:
         raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
                          f'({num_samples})')
 
diff --git a/tests/test_partition.py b/tests/test_partition.py
index 68d4ba8e1..aa26a63d1 100644
--- a/tests/test_partition.py
+++ b/tests/test_partition.py
@@ -38,6 +38,51 @@ def test_partition_walk(partition_algo: str):
     assert x.shape == (22, 8, 8, 1, 10)
 
 
+@pytest.mark.parametrize('num_samples', [400, 1000])
+@pytest.mark.parametrize('num_canonical_nodes', [1, 4])
+@pytest.mark.parametrize('num_physical_nodes', [1, 4])
+@pytest.mark.parametrize('ranks_per_node', [1, 8])
+@pytest.mark.parametrize('workers_per_rank', [1, 8])
+@pytest.mark.parametrize('batch_size', [4])
+@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
+def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_physical_nodes: int,
+                            ranks_per_node: int, workers_per_rank: int, batch_size: int,
+                            partition_algo: str):
+    initial_physical_nodes = None
+    if partition_algo == 'relaxed' and num_canonical_nodes == 4 and ranks_per_node == 8:
+        num_canonical_nodes = 3
+        initial_physical_nodes = 3
+        batch_size = batch_size * 3
+        num_samples = 3 * num_samples
+
+    drop_first = num_samples
+
+    x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
+                       ranks_per_node, workers_per_rank, batch_size, drop_first,
+                       initial_physical_nodes)
+    # Partition should still have the appropriate shape, but without any samples in it.
+    assert x.shape == (num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size)
+    assert x.size == 0
+
+
+@pytest.mark.parametrize('num_samples', [400, 1000])
+@pytest.mark.parametrize('drop_additional', [1, 400])
+@pytest.mark.parametrize('num_canonical_nodes', [4])
+@pytest.mark.parametrize('num_physical_nodes', [4])
+@pytest.mark.parametrize('ranks_per_node', [8])
+@pytest.mark.parametrize('workers_per_rank', [8])
+@pytest.mark.parametrize('batch_size', [4])
+@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
+def test_partition_invalid_drop_first(num_samples: int, drop_additional: int,
+                                      num_canonical_nodes: int, num_physical_nodes: int,
+                                      ranks_per_node: int, workers_per_rank: int, batch_size: int,
+                                      partition_algo: str):
+    drop_first = num_samples + drop_additional
+    with pytest.raises(ValueError, match='Resuming further into the dataset'):
+        _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
+                           ranks_per_node, workers_per_rank, batch_size, drop_first)
+
+
 @pytest.mark.parametrize('num_samples', [1, 4])
 @pytest.mark.parametrize('num_canonical_nodes', [1, 4])
 @pytest.mark.parametrize('num_physical_nodes', [1, 4])
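
Note for readers of this patch: below is a minimal, self-contained NumPy sketch, not part of the diff above, with names and sizes borrowed from one parametrization of test_partition_drop_all purely for illustration. It shows the empty-partition invariant the change relies on: when drop_first equals num_samples, the partition keeps its full five-dimensional layout with zero batches per worker, so an identical empty array can be rebuilt from the shape alone and no data segment needs to exist in shared memory (hence the Optional[SharedMemory] returned by _share_work and _attach_work).

import numpy as np

# Sizes mirroring one parametrization of test_partition_drop_all (illustrative only).
num_physical_nodes, ranks_per_node, workers_per_rank, batch_size = 4, 8, 8, 4

# Dropping every sample leaves zero batches per worker, but the 5-D layout that
# downstream code indexes into is preserved.
empty_partition = np.empty((num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size),
                           dtype=np.int64)
assert empty_partition.size == 0

# Only the shape has to be communicated: an equivalent empty array can be rebuilt
# without attaching to any data buffer, which is why the data SharedMemory may be None.
rebuilt = np.empty(shape=empty_partition.shape, dtype=np.int64)
assert rebuilt.shape == empty_partition.shape
assert rebuilt.size == 0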