Better default values for StreamingDataset args #479

Merged Oct 27, 2023 (12 commits)
3 changes: 3 additions & 0 deletions streaming/base/batching/per_stream.py

@@ -63,6 +63,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
     # same as the ratio of the stream's samples to overall samples.
     # This ensures that the overall training shuffle block size is still approximately
     # equal to what is set by the user, and allows for reasoning about cache_limit as well.
+    if not isinstance(dataset.shuffle_block_size, int):
+        raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
+                        f'Got {dataset.shuffle_block_size} instead.')
     shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
     stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
                                  dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
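For illustration only (not part of the diff), a minimal sketch of the per-stream block sizing that the check above guards. The stream names and sample counts are hypothetical; each stream's shuffle block portion is its share of total samples times the dataset-level shuffle_block_size.

# Hypothetical sketch of the shuffle_block_portion computation above.
shuffle_block_size = 1 << 18  # 262_144, the dataset-level unit of shuffle

stream_samples = {'web': 6_000_000, 'code': 3_000_000, 'books': 1_000_000}
total = sum(stream_samples.values())

for name, num_samples in stream_samples.items():
    proportion = num_samples / total
    # Mirrors: int(dataset.shuffle_block_size * stream.proportion)
    shuffle_block_portion = int(shuffle_block_size * proportion)
    print(f'{name}: {shuffle_block_portion}')
# web: 157286, code: 78643, books: 26214 -- portions sum to ~shuffle_block_size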
3 changes: 3 additions & 0 deletions streaming/base/batching/random.py

@@ -57,6 +57,9 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch

     # If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
     if dataset.shuffle:
+        if not isinstance(dataset.shuffle_block_size, int):
+            raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
+                            f'Got {dataset.shuffle_block_size} instead.')
         shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
                               dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
         big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)
3 changes: 3 additions & 0 deletions streaming/base/batching/stratified.py

@@ -74,6 +74,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
     # same as the ratio of the stream's samples to overall samples.
     # This ensures that the overall training shuffle block size is still approximately
     # equal to what is set by the user, and allows for reasoning about cache_limit as well.
+    if not isinstance(dataset.shuffle_block_size, int):
+        raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
+                        f'Got {dataset.shuffle_block_size} instead.')
     shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
     stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
                                  dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
57 changes: 33 additions & 24 deletions streaming/base/dataset.py

@@ -265,9 +265,7 @@ class StreamingDataset(Array, IterableDataset):
             of current sample. Workers will attempt to download ahead by this many samples during,
             but not before, training. Recommendation is to provide a value greater than per device
             batch size to ensure at-least per device batch size number of samples cached locally.
-            If ``None``, its value gets derived using per device batch size and number of
-            canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``.
-            Defaults to ``None``.
+            If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
         cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
             shard cache. Before downloading a shard, the least recently used resident shard(s)
             may be evicted (deleted from the local cache) in order to stay under the limit.
@@ -280,13 +278,14 @@ class StreamingDataset(Array, IterableDataset):
             how many samples to pick from the same shard at a time (``1`` for evenly balanced
             across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
             Defaults to ``1``.
-        partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
+        partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
         num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
             resumption. The sample space is divided evenly according to the number of canonical
             nodes. The higher the value, the more independent non-overlapping paths the
             StreamingDataset replicas take through the shards per model replica (increasing data
-            source diversity). Defaults to ``None``, which is interpreted as 64 times the number
-            of nodes of the initial run.
+            source diversity). If ``None``, this is interpreted as 64 times the number of physical
+            nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
+            number of physical nodes of the initial run otherwise. Defaults to ``None``.

             .. note::

@@ -296,10 +295,12 @@ class StreamingDataset(Array, IterableDataset):
             partitioned over the workers. Defaults to ``None``.
         shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
             ``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
+        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
         shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into blocks
-            of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
+        shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
+            into blocks of this size, and samples within each block are shuffled. If ``None``, its
+            value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
+            ``None``.
         batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
             ``per_stream``. Defaults to ``random``.
         """

Review discussion on the new shuffle_block_size default:

Collaborator: Any specific reason for 4_000_000? Just wondering

Author (snarayan21): This was based on the shuffle quality experiments, where a shuffle strength of 4_000_000 (i.e., each batch is drawn from 4_000_000 samples) gave good shuffle quality without making the number of downloads required too high.

Contributor: how about max(1 << 18, 1 << 22 // NCN) lol?

Author (snarayan21, Oct 26, 2023): That works too, it's essentially the same right? Gonna keep as-is for now
@@ -319,13 +320,13 @@ def __init__(self,
                  cache_limit: Optional[Union[int, str]] = None,
                  sampling_method: str = 'balanced',
                  sampling_granularity: int = 1,
-                 partition_algo: str = 'orig',
+                 partition_algo: str = 'relaxed',
                  num_canonical_nodes: Optional[int] = None,
                  batch_size: Optional[int] = None,
                  shuffle: bool = False,
-                 shuffle_algo: str = 'py1s',
+                 shuffle_algo: str = 'py1e',
                  shuffle_seed: int = 9176,
-                 shuffle_block_size: int = 1 << 18,
+                 shuffle_block_size: Optional[int] = None,
                  batching_method: str = 'random') -> None:
         # Global arguments (which do not live in Streams).
         self.predownload = predownload
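A minimal usage sketch (the local path is hypothetical) showing how a caller now picks up the new defaults without passing these arguments explicitly:

from streaming import StreamingDataset

# With this PR, leaving these args unset yields partition_algo='relaxed',
# shuffle_algo='py1e', predownload resolved to 8 * batch_size, and a
# shuffle_block_size derived from num_canonical_nodes at runtime.
dataset = StreamingDataset(local='/tmp/my_dataset',  # hypothetical path
                           batch_size=32,
                           shuffle=True)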
@@ -377,12 +378,15 @@ def __init__(self,
             raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' +
                              f'{self.shuffle_seed}.')

-        # Check that predownload is at least per device batch size.
+        # Check that predownload is at least per device batch size, and set it if currently None.
         if self.predownload is not None and self.batch_size is not None and \
                 self.predownload < self.batch_size:
             warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}).' +
                           f'This may result in slower batch time. Recommendation is to set ' +
                           f'predownload to at-least batch_size.')
+        elif self.predownload is None:
+            self.predownload = 8 * self.batch_size if self.batch_size is not None else 64

         # Convert epoch size from string to int, if needed. Cannot be negative.
         epoch_size_value = None
         if epoch_size:
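A standalone sketch (assumed behavior, mirroring the branch above) of how predownload now resolves:

# Mirrors the predownload resolution above: explicit values win; None becomes
# 8 * batch_size when batch_size is known, else a fallback of 64.
def resolve_predownload(predownload, batch_size):
    if predownload is None:
        return 8 * batch_size if batch_size is not None else 64
    return predownload

assert resolve_predownload(None, 32) == 256
assert resolve_predownload(None, None) == 64
assert resolve_predownload(100, 32) == 100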
@@ -632,12 +636,11 @@ def __len__(self) -> int:
         """
         return self.length

-    def _set_predownload(self) -> None:
-        """Set the predownload value which is per number of workers."""
-        if self.predownload is None:
-            self.predownload = max(
-                self.batch_size, 256 * self.batch_size // self.num_canonical_nodes
-            ) if self.batch_size is not None and self.num_canonical_nodes is not None else 512
+    def _set_shuffle_block_size(self):
+        """Set the shuffle block size value."""
+        if self.shuffle_block_size is None:
+            self.shuffle_block_size = max(4_000_000 // self.num_canonical_nodes, 1 << 18) \
+                if self.num_canonical_nodes is not None else 1 << 18

     def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
         """Either resume from checkpoint or start at the beginning.
@@ -656,8 +659,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
         except FileNotFoundError:
             # There is nothing to resume.
             if not self.num_canonical_nodes:
-                self.num_canonical_nodes = world.num_nodes * 64
-            self._set_predownload()
+                if self.shuffle_algo in ['py1s', 'py2s']:
+                    self.num_canonical_nodes = 64 * world.num_nodes
+                else:
+                    self.num_canonical_nodes = world.num_nodes
+            self._set_shuffle_block_size()
             return epoch, 0

         # SharedMemory buffers may contain additional null bytes at the end.
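For clarity, a sketch (not part of the diff) of the new num_canonical_nodes default as described in the docstring: 64x the physical node count for the py1s/py2s shuffle algorithms, and 1x otherwise.

# Mirrors the branch above: 'py1s'/'py2s' keep the historical 64x multiplier;
# other algorithms (including the new 'py1e' default) use one canonical node
# per physical node.
def default_num_canonical_nodes(shuffle_algo, num_physical_nodes):
    if shuffle_algo in ['py1s', 'py2s']:
        return 64 * num_physical_nodes
    return num_physical_nodes

assert default_num_canonical_nodes('py1s', 2) == 128
assert default_num_canonical_nodes('py1e', 2) == 2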
@@ -669,16 +675,19 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:

         # Check if the resume state is stale.
         if obj['epoch'] < epoch:
             if not self.num_canonical_nodes:
-                self.num_canonical_nodes = world.num_nodes * 64
-            self._set_predownload()
+                if self.shuffle_algo in ['py1s', 'py2s']:
+                    self.num_canonical_nodes = 64 * world.num_nodes
+                else:
+                    self.num_canonical_nodes = world.num_nodes
+            self._set_shuffle_block_size()
             return epoch, 0

         # Load the correct resumption meta data.
         epoch = obj['epoch']
         sample_in_epoch = obj['sample_in_epoch']
         self.num_canonical_nodes = obj['num_canonical_nodes']
         self.shuffle_seed = obj['shuffle_seed']
-        self._set_predownload()
+        self._set_shuffle_block_size()

         return epoch, sample_in_epoch