Just do the partitioning/shuffling in the local leader worker. #96

Merged
merged 9 commits on Dec 21, 2022
86 changes: 75 additions & 11 deletions streaming/base/dataset.py
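For orientation before the diff: the change has the local leader compute the epoch's sample ordering once, write it to a temp file, and publish it with an atomic rename, while all workers poll for the finished file under a timeout. The following minimal sketch shows that pattern in isolation; the helper name write_or_wait_for_shuffle, the make_sample_ids callable, and the TICK value are illustrative assumptions, not the library's actual API.

# Minimal sketch of the leader-writes / workers-poll pattern (hypothetical names).
import os
import time

import numpy as np

TICK = 0.07  # polling interval in seconds (assumed value)


def write_or_wait_for_shuffle(is_local_leader: bool,
                              dirname: str,
                              make_sample_ids,
                              timeout: float = 60) -> np.ndarray:
    """Leader generates and publishes the sample IDs; everyone waits for the file."""
    tmp_filename = os.path.join(dirname, 'shuffle.npy.tmp')
    filename = os.path.join(dirname, 'shuffle.npy')

    if is_local_leader:
        os.makedirs(dirname, exist_ok=True)
        sample_ids = make_sample_ids()     # expensive: partition + shuffle
        sample_ids.tofile(tmp_filename)    # write under a temp name first...
        os.rename(tmp_filename, filename)  # ...then publish atomically

    # Everyone (leader included) waits until the finished file exists.
    t0 = time.time()
    while not os.path.exists(filename):
        if time.time() - t0 > timeout:
            raise RuntimeError('Timed out waiting for the shuffle file.')
        time.sleep(TICK)

    return np.fromfile(filename, np.int64)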
@@ -8,7 +8,7 @@
from enum import IntEnum
from multiprocessing.shared_memory import SharedMemory
from threading import Thread
from time import sleep
from time import sleep, time
from typing import Any, Dict, Iterator, Optional, Tuple

import numpy as np
@@ -348,13 +348,18 @@ def _get_progress(self, world: World) -> Tuple[int, int]:

return epoch, sample_in_epoch

def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
def _get_partition(self,
world: World,
epoch: int,
sample_in_epoch: int,
timeout: Optional[float] = 60) -> NDArray[np.int64]:
"""Get this worker's partition of this epoch's sample space.

Args:
world (World): World state.
epoch (int): Which epoch it is.
sample_in_epoch (int): Where we are in the epoch.
timeout (float, optional): Max seconds to wait for the partitioning/shuffle to be generated. Defaults to ``60``.

Returns:
NDArray[np.int64]: Our partition of the epoch.
@@ -366,14 +371,64 @@ def _get_partition(self, world: World, epoch: int, sample_in_epoch: int) -> NDAr
if self.shuffle_seed is None:
raise RuntimeError('Shuffle seed can never be None')

sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
world.num_nodes, world.ranks_per_node, world.workers_per_rank,
self.batch_size, sample_in_epoch)
if self.shuffle:
mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes, self.shuffle_seed,
epoch)
sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
return sample_ids[world.node, world.rank_of_node, world.worker_of_rank]
# Decide where to save shuffle data.
tmp_filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix,
'shuffle.npy.tmp')
filename = os.path.join(os.path.sep, 'tmp', 'streaming', self._prefix, 'shuffle.npy')

# In the local leader, generate this epoch's global sample ordering, then save to file.
# Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
# This operation is expensive.
if world.is_local_leader:
sample_ids = get_partitions(self.index.total_samples, self.num_canonical_nodes,
world.num_nodes, world.ranks_per_node,
world.workers_per_rank, self.batch_size, sample_in_epoch)
if self.shuffle:
mapping = get_shuffle(self.shard_sizes, self.num_canonical_nodes,
self.shuffle_seed, epoch)
sample_ids = np.where(sample_ids == -1, -1, mapping[sample_ids])
sample_ids.tofile(tmp_filename)
os.rename(tmp_filename, filename)

# Everyone waits for the file to become populated.
t0 = time()
while True:
sleep(TICK)
if os.path.exists(filename):
sleep(TICK)
break
if timeout is not None:
dt = time() - t0
if timeout < dt:
raise RuntimeError('Partitioning and shuffling took too long, bailing out: ' +
f'{timeout:.3f} < {dt:.3f} sec.')

# Each worker loads its slice of the sample ID tensor to iterate through.
# Tensor shape: (num nodes, ranks per node, workers per rank, samples per worker).
sample_id_nbytes = np.int64().nbytes
num_bytes = os.path.getsize(filename)
if num_bytes % sample_id_nbytes:
raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
num_samples = num_bytes // sample_id_nbytes
num_workers = world.num_nodes * world.ranks_per_node * world.workers_per_rank
if num_samples % num_workers:
raise ValueError(f'Generated shuffle is invalid: {filename} ({num_bytes} bytes).')
samples_per_worker = num_samples // num_workers
offset_in_bytes = world.worker * samples_per_worker * sample_id_nbytes
bytes_to_read = samples_per_worker * sample_id_nbytes
with open(filename, 'rb', 0) as fp:
fp.seek(offset_in_bytes)
data = fp.read(bytes_to_read)
sample_ids = np.frombuffer(data, np.int64)

# Wait for everyone to read their part.
self._worker_barrier(world.workers_per_node)

# Now clean up after ourselves.
if world.is_local_leader:
os.remove(filename)

return sample_ids
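As a quick check of the slicing arithmetic above, here is the offset math worked through for a hypothetical world of 2 nodes x 8 ranks per node x 2 workers per rank over a 6,400,000-sample epoch; the numbers are purely illustrative.

# Illustrative check of the byte-offset math (hypothetical world size).
import numpy as np

sample_id_nbytes = np.int64().nbytes  # 8 bytes per sample ID
num_nodes, ranks_per_node, workers_per_rank = 2, 8, 2
num_workers = num_nodes * ranks_per_node * workers_per_rank  # 32
num_samples = 6_400_000
samples_per_worker = num_samples // num_workers  # 200_000

worker = 5  # flat worker index, i.e. world.worker
offset_in_bytes = worker * samples_per_worker * sample_id_nbytes  # 8_000_000
bytes_to_read = samples_per_worker * sample_id_nbytes  # 1_600_000
print(offset_in_bytes, bytes_to_read)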

def _download_file(self, basename: str) -> str:
"""Safely download a file from remote to local cache.
@@ -384,17 +439,26 @@ def _download_file(self, basename: str) -> str:
Returns:
str: Local cache filename.
"""
# Calculate paths.
if self.remote is None:
remote = None
else:
remote = os.path.join(self.remote, self.split, basename)
local = os.path.join(self.local, self.split, basename)

# Attempt to download, possibly retrying on failure.
for _ in range(1 + self.download_retry):
try:
download(remote, local, self.download_timeout)
except:
except Exception as e:
print('Download exception:', e)
continue
break

# Verify the local file exists.
if not os.path.exists(local):
raise RuntimeError(f'Download has failed: {remote} to {local}.')

return local
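The retry structure above (attempt, continue on failure, break on success, then verify the file landed) is generic. Below is a standalone sketch of the same shape with hypothetical names (fetch, num_retries); the optional pause between attempts is an illustrative variant, since the method in this diff retries back-to-back.

# Standalone sketch of the retry-then-verify pattern (hypothetical names).
import os
import time


def fetch_with_retries(fetch, remote: str, local: str, num_retries: int = 2,
                       pause: float = 0.0) -> str:
    """Call ``fetch`` up to 1 + num_retries times, then confirm the file landed."""
    for _ in range(1 + num_retries):
        try:
            fetch(remote, local)  # may raise on transient network errors
        except Exception as e:
            print('Download exception:', e)
            time.sleep(pause)     # optional breather before the next attempt
            continue
        break                     # success: stop retrying

    # All attempts may have failed, so verify the file is actually present.
    if not os.path.exists(local):
        raise RuntimeError(f'Download has failed: {remote} to {local}.')
    return local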

def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_filename: str,