Enable correct resumption from the end of an epoch #700

Merged · 11 commits · Jun 18, 2024
streaming/base/dataset.py (37 additions, 17 deletions)

@@ -922,14 +922,16 @@ def resample_streams(
         sample_ids = np.concatenate(sample_ids).astype(np.int64)
         return shuffle_units, sample_ids

-    def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]:
+    def _share_work(self,
+                    sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, Optional[SharedMemory]]:
         """Put an epoch's sample ordering into shared memory.

         Args:
             sample_ids (NDArray[np.int64]): Sample IDs.

         Returns:
-            Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data.
+            Tuple[SharedMemory, Optional[SharedMemory]]: Shared memory arrays containing shape and
+                data, if present.
         """
         ndim = 5

@@ -945,19 +947,26 @@ def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]:
         shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
         shape_shm.buf[:size] = np.array(sample_ids.shape, np.int64).tobytes()

-        # Save the generated epoch data to shared memory.
-        name = _get_path(self._shm_prefix_int, EPOCH_DATA)
-        size = sample_ids.size * np.int64().nbytes
-        data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
-        data_shm.buf[:size] = sample_ids.tobytes()
+        if sample_ids.size > 0:
+            # Save the generated epoch data to shared memory, but only if the sample partition is
+            # non-empty. Otherwise, the end of the epoch has been reached.
+            name = _get_path(self._shm_prefix_int, EPOCH_DATA)
+            size = sample_ids.size * np.int64().nbytes
+            data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
+            data_shm.buf[:size] = sample_ids.tobytes()

-        return shape_shm, data_shm
+            return shape_shm, data_shm

-    def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
+        else:
+
+            return shape_shm, None
+
+    def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]:
         """Get an epoch's sample ordering from shared memory.

         Returns:
-            NDArray[np.int64]: Sample IDs.
+            Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: Sample IDs, shared
+                memory array for shape, and shared memory array for data, if present.
         """
         ndim = 5

@@ -967,13 +976,22 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
         shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
         shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64))

-        # Attach to the generated epoch data in shared memory.
-        name = _get_path(self._shm_prefix_int, EPOCH_DATA)
-        size = int(np.prod(shape)) * np.int64().nbytes
-        data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
-        sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
+        num_elements = int(np.prod(shape))
+
+        if num_elements > 0:
+            # Attach to the generated epoch data in shared memory, but only if the sample partition
+            # is non-empty. Otherwise, the end of the epoch has been reached.
+            name = _get_path(self._shm_prefix_int, EPOCH_DATA)
+            size = num_elements * np.int64().nbytes
+            data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
+            sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)

-        return sample_ids, shape_shm, data_shm
+            return sample_ids, shape_shm, data_shm
+
+        else:
+
+            sample_ids = np.empty(shape=shape, dtype=np.int64)
+            return sample_ids, shape_shm, None

     def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
         """Get this worker's partition of this epoch's sample space.
@@ -1025,7 +1043,9 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:

         # Now clean up after ourselves.
         shape_shm.cleanup()
-        data_shm.cleanup()
+        # Can be None if the sample partition was empty.
+        if data_shm is not None:
+            data_shm.cleanup()

         return worker_sample_ids

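The dataset change above always publishes the shape of the epoch's sample partition to shared memory, but only creates the data block when the partition actually contains samples; when resuming exactly at the end of an epoch, readers rebuild an empty array from the shape alone and the data handle is None. Below is a minimal, self-contained sketch of that pattern using the standard library's multiprocessing.shared_memory instead of streaming's own SharedMemory wrapper; the function names and the prefix argument are hypothetical, and the sketch assumes the usual 5-dim (nodes, ranks, workers, batches per worker, batch size) partition. One practical reason to skip the empty case is that a zero-byte shared-memory block generally cannot be created at all.

# Sketch of the share/attach pattern with an optional data block (illustrative only).
from multiprocessing import shared_memory
from typing import Optional, Tuple

import numpy as np


def share_sample_ids(sample_ids: np.ndarray,
                     prefix: str) -> Tuple[shared_memory.SharedMemory,
                                           Optional[shared_memory.SharedMemory]]:
    # The shape block is always created, so readers can recover an empty array.
    shape = np.array(sample_ids.shape, dtype=np.int64)
    shape_shm = shared_memory.SharedMemory(name=f'{prefix}_shape', create=True, size=shape.nbytes)
    shape_shm.buf[:shape.nbytes] = shape.tobytes()

    if sample_ids.size > 0:
        # Non-empty partition: also publish the sample IDs themselves.
        data_shm = shared_memory.SharedMemory(name=f'{prefix}_data', create=True,
                                              size=sample_ids.nbytes)
        data_shm.buf[:sample_ids.nbytes] = sample_ids.tobytes()
        return shape_shm, data_shm

    # Empty partition: the end of the epoch has been reached, so there is no data block.
    return shape_shm, None


def attach_sample_ids(prefix: str) -> Tuple[np.ndarray, shared_memory.SharedMemory,
                                            Optional[shared_memory.SharedMemory]]:
    shape_shm = shared_memory.SharedMemory(name=f'{prefix}_shape', create=False)
    shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64))

    if int(np.prod(shape)) > 0:
        data_shm = shared_memory.SharedMemory(name=f'{prefix}_data', create=False)
        sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
        return sample_ids, shape_shm, data_shm

    # Mirror of the dataset change: hand back an empty array and no data handle.
    return np.empty(shape, dtype=np.int64), shape_shm, None

A consumer of this sketch would guard its cleanup the same way _get_work now does, closing and unlinking the data block only when the returned handle is not None.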
streaming/base/partition/orig.py (1 addition, 1 deletion)

@@ -46,7 +46,7 @@ def get_partitions_orig(num_samples: int,
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
             batches per worker, batch size).
     """
-    if num_samples <= drop_first:
+    if num_samples < drop_first:
         raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
                          f'({num_samples})')

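This one-character change is the heart of the fix: a checkpoint taken exactly at the end of an epoch resumes with drop_first == num_samples, which the old <= check rejected even though no samples remain to be dropped beyond the dataset's length. A small illustration of the boundary, with made-up numbers:

import numpy as np

num_samples = 1000
drop_first = 1000   # resuming from a checkpoint saved exactly at the end of the epoch

# Old behavior: `num_samples <= drop_first` raised for this case.
# New behavior: only resuming *past* the end of the dataset is rejected.
if num_samples < drop_first:
    raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples '
                     f'({num_samples})')

remaining = np.arange(num_samples)[drop_first:]
print(remaining.size)   # 0 -> downstream, an empty but correctly shaped partition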
streaming/base/partition/relaxed.py (3 additions, 1 deletion)

@@ -49,7 +49,7 @@ def get_partitions_relaxed(num_samples: int,
         NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
             batches per worker, batch size).
     """
-    if num_samples <= drop_first:
+    if num_samples < drop_first:
         raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
                          f'({num_samples})')

@@ -65,6 +65,7 @@ def get_partitions_relaxed(num_samples: int,
         return get_partitions_orig(num_samples, num_canonical_nodes, num_physical_nodes,
                                    ranks_per_node, workers_per_rank, batch_size, drop_first)
     else:
+        print('WE HERE')
         # First, make a partition over the initial number of physical nodes and device batch size.
         # We assume that ranks_per_node and workers_per_rank stay constant during resumptions.
         global_batch_size = num_physical_nodes * ranks_per_node * batch_size

@@ -82,6 +83,7 @@ def get_partitions_relaxed(num_samples: int,
         partition = get_partitions_orig(num_samples, num_canonical_nodes, initial_physical_nodes,
                                         ranks_per_node, workers_per_rank, initial_batch_size,
                                         drop_first)
+        print('ORIG PARTITION SHAPE: ', partition.shape)

         # Flatten the initial partition in order of traversal.
         # partition was originally (nodes, ranks, workers, batches per worker, batch size)
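Aside from mirroring the strict < check from orig.py, the additions in this file are two debug print statements. For context, the global_batch_size used above is simply the product of the per-step dimensions; a quick worked example with hypothetical values:

num_physical_nodes = 4   # hypothetical values, for illustration only
ranks_per_node = 8
batch_size = 4           # per-device batch size

global_batch_size = num_physical_nodes * ranks_per_node * batch_size
print(global_batch_size)            # 4 * 8 * 4 = 128 samples per global batch
print(1280 // global_batch_size)    # e.g. resuming 1280 samples in = 10 whole global batches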
tests/test_partition.py (45 additions, 0 deletions)

@@ -38,6 +38,51 @@ def test_partition_walk(partition_algo: str):
     assert x.shape == (22, 8, 8, 1, 10)


+@pytest.mark.parametrize('num_samples', [400, 1000])
+@pytest.mark.parametrize('num_canonical_nodes', [1, 4])
+@pytest.mark.parametrize('num_physical_nodes', [1, 4])
+@pytest.mark.parametrize('ranks_per_node', [1, 8])
+@pytest.mark.parametrize('workers_per_rank', [1, 8])
+@pytest.mark.parametrize('batch_size', [4])
+@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
+def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_physical_nodes: int,
+                            ranks_per_node: int, workers_per_rank: int, batch_size: int,
+                            partition_algo: str):
+    initial_physical_nodes = None
+    if partition_algo == 'relaxed' and num_canonical_nodes == 4 and ranks_per_node == 8:
+        num_canonical_nodes = 3
+        initial_physical_nodes = 3
+        batch_size = batch_size * 3
+        num_samples = 3 * num_samples
+
+    drop_first = num_samples
+
+    x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
+                       ranks_per_node, workers_per_rank, batch_size, drop_first,
+                       initial_physical_nodes)
+    # Partition should still have the appropriate shape, but without any samples in it.
+    assert x.shape == (num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size)
+    assert x.size == 0
+
+
+@pytest.mark.parametrize('num_samples', [400, 1000])
+@pytest.mark.parametrize('drop_additional', [1, 400])
+@pytest.mark.parametrize('num_canonical_nodes', [4])
+@pytest.mark.parametrize('num_physical_nodes', [4])
+@pytest.mark.parametrize('ranks_per_node', [8])
+@pytest.mark.parametrize('workers_per_rank', [8])
+@pytest.mark.parametrize('batch_size', [4])
+@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
+def test_partition_invalid_drop_first(num_samples: int, drop_additional: int,
+                                      num_canonical_nodes: int, num_physical_nodes: int,
+                                      ranks_per_node: int, workers_per_rank: int, batch_size: int,
+                                      partition_algo: str):
+    drop_first = num_samples + drop_additional
+    with pytest.raises(ValueError, match=f'Resuming further into the dataset*'):
+        _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
+                           ranks_per_node, workers_per_rank, batch_size, drop_first)
+
+
 @pytest.mark.parametrize('num_samples', [1, 4])
 @pytest.mark.parametrize('num_canonical_nodes', [1, 4])
 @pytest.mark.parametrize('num_physical_nodes', [1, 4])
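The new tests cover both sides of the boundary: dropping exactly every sample yields an empty partition that keeps its shape, while dropping more samples than the dataset holds still raises. As a quick sanity check outside of pytest, one combination from the first test's parameter grid can be reproduced directly (assuming get_partitions is imported the same way the test module imports it):

from streaming.base.partition import get_partitions

# Drop exactly as many samples as the dataset has: 400 of 400.
x = get_partitions('orig', 400, 4, 4, 8, 8, 4, 400)

# The (nodes, ranks, workers, batches per worker, batch size) shape survives,
# but there are zero batches per worker and therefore no samples left to serve.
assert x.shape == (4, 8, 8, 0, 4)
assert x.size == 0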