Better default values for StreamingDataset args #479

Merged

12 commits merged on Oct 27, 2023
18 changes: 10 additions & 8 deletions .github/workflows/pytest.yaml
@@ -47,11 +47,13 @@ jobs:
id: tests
run: |
set -ex
pytest --splits 8 --group 1 --cov-fail-under=10
pytest --splits 8 --group 2 --cov-fail-under=10
pytest --splits 8 --group 3 --cov-fail-under=10
pytest --splits 8 --group 4 --cov-fail-under=10
pytest --splits 8 --group 5 --cov-fail-under=10
pytest --splits 8 --group 6 --cov-fail-under=10
pytest --splits 8 --group 7 --cov-fail-under=10
pytest --splits 8 --group 8 --cov-fail-under=10
pytest --splits 10 --group 1 --cov-fail-under=10
pytest --splits 10 --group 2 --cov-fail-under=10
pytest --splits 10 --group 3 --cov-fail-under=10
pytest --splits 10 --group 4 --cov-fail-under=10
pytest --splits 10 --group 5 --cov-fail-under=10
pytest --splits 10 --group 6 --cov-fail-under=10
pytest --splits 10 --group 7 --cov-fail-under=10
pytest --splits 10 --group 8 --cov-fail-under=10
pytest --splits 10 --group 9 --cov-fail-under=10
pytest --splits 10 --group 10 --cov-fail-under=10
3 changes: 3 additions & 0 deletions streaming/base/batching/per_stream.py
@@ -63,6 +63,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
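To illustrate the proportional split described in the comment above, here is a minimal sketch (illustrative Python, values invented, not from the PR): each stream's shuffle block portion is its share of the overall shuffle block size, so the portions sum back to approximately the user-specified value.

# Minimal sketch with hypothetical values: an overall shuffle_block_size
# already resolved to an int, and three streams whose proportions sum to 1.
shuffle_block_size = 1 << 18          # 262144
proportions = [0.5, 0.3, 0.2]         # per-stream sample shares (illustrative)
portions = [int(shuffle_block_size * p) for p in proportions]
print(portions)                       # [131072, 78643, 52428]
print(sum(portions))                  # 262143, roughly the overall block size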
3 changes: 3 additions & 0 deletions streaming/base/batching/random.py
@@ -58,6 +58,9 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch

# If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
if dataset.shuffle:
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)
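As an aside, a tiny sketch (toy values, not from the PR) of the np.where pattern used here: the shuffle permutation is applied only to valid sample IDs, while -1 padding entries are carried through unchanged.

import numpy as np

big_ids = np.array([2, -1, 0, 3])     # -1 marks padding (toy example)
shuffle = np.array([3, 1, 0, 2])      # a permutation of the sample IDs
# shuffle[big_ids] also indexes the -1 slots (harmlessly, since shuffle[-1]
# is just the last element), but np.where discards those results.
result = np.where(big_ids != -1, shuffle[big_ids], -1)
print(result)                         # [ 0 -1  3  2]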
3 changes: 3 additions & 0 deletions streaming/base/batching/stratified.py
@@ -75,6 +75,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
71 changes: 39 additions & 32 deletions streaming/base/dataset.py
@@ -265,9 +265,7 @@ class StreamingDataset(Array, IterableDataset):
of current sample. Workers will attempt to download ahead by this many samples during,
but not before, training. Recommendation is to provide a value greater than per device
batch size to ensure at-least per device batch size number of samples cached locally.
If ``None``, its value gets derived using per device batch size and number of
canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``.
Defaults to ``None``.
If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s)
may be evicted (deleted from the local cache) in order to stay under the limit.
@@ -280,13 +278,14 @@ class StreamingDataset(Array, IterableDataset):
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. The sample space is divided evenly according to the number of canonical
nodes. The higher the value, the more independent non-overlapping paths the
StreamingDataset replicas take through the shards per model replica (increasing data
source diversity). Defaults to ``None``, which is interpreted as 64 times the number
of nodes of the initial run.
source diversity). If ``None``, this is interpreted as 64 times the number of physical
nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
number of physical nodes of the initial run otherwise. Defaults to ``None``.

.. note::

@@ -296,10 +295,12 @@ class StreamingDataset(Array, IterableDataset):
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into blocks
of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
into blocks of this size, and samples within each block are shuffled. If ``None``, its
value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
``None``.

Review comment (Collaborator): Any specific reason for 4_000_000? Just wondering.

Reply (Collaborator, Author): This was based on the shuffle quality experiments, where a shuffle strength of 4_000_000 (i.e., each batch is drawn from 4_000_000 samples) gave good shuffle quality without making the number of downloads required too high.

Reply (Contributor): How about max(1 << 18, 1 << 22 // NCN)?

Reply (Collaborator, Author, @snarayan21, Oct 26, 2023): That works too, it's essentially the same, right? Going to keep as-is for now.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""
@@ -319,13 +320,13 @@ def __init__(self,
cache_limit: Optional[Union[int, str]] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
partition_algo: str = 'orig',
partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1s',
shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
shuffle_block_size: Optional[int] = None,
batching_method: str = 'random') -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
@@ -381,12 +382,15 @@ def __init__(self,
raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' +
f'{self.shuffle_seed}.')

# Check that predownload is at least per device batch size.
# Check that predownload is at least per device batch size, and set it if currently `None`.
if self.predownload is not None and self.batch_size is not None and \
self.predownload < self.batch_size:
warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}).' +
f'This may result in slower batch time. Recommendation is to set ' +
f'predownload to at-least batch_size.')
elif self.predownload is None:
self.predownload = 8 * self.batch_size if self.batch_size is not None else 64

# Convert epoch size from string to int, if needed. Cannot be negative.
epoch_size_value = None
if epoch_size:
@@ -636,12 +640,11 @@ def __len__(self) -> int:
"""
return self.length

def _set_predownload(self) -> None:
"""Set the predownload value which is per number of workers."""
if self.predownload is None:
self.predownload = max(
self.batch_size, 256 * self.batch_size // self.num_canonical_nodes
) if self.batch_size is not None and self.num_canonical_nodes is not None else 512
def _set_shuffle_block_size(self):
"""Set the shuffle block size value."""
if self.shuffle_block_size is None:
self.shuffle_block_size = max(4_000_000 // self.num_canonical_nodes, 1 << 18) \
if self.num_canonical_nodes is not None else 1 << 18
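A quick sanity check on this default with illustrative canonical-node counts (not from the PR): the 4_000_000-sample budget is divided across canonical nodes but never drops below the previous fixed block size of 1 << 18.

print(max(4_000_000 // 1, 1 << 18))    # 4000000
print(max(4_000_000 // 8, 1 << 18))    # 500000
print(max(4_000_000 // 16, 1 << 18))   # 262144 (250000 is below the 1 << 18 floor)
print(max(4_000_000 // 64, 1 << 18))   # 262144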

def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
"""Either resume from checkpoint or start at the beginning.
@@ -660,8 +663,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
except FileNotFoundError:
# There is nothing to resume.
if not self.num_canonical_nodes:
self.num_canonical_nodes = world.num_nodes * 64
self._set_predownload()
if self.shuffle_algo in ['py1s', 'py2s']:
self.num_canonical_nodes = 64 * world.num_nodes
else:
self.num_canonical_nodes = world.num_nodes
self._set_shuffle_block_size()
return epoch, 0

# SharedMemory buffers may contain additional null bytes at the end.
@@ -673,8 +679,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
# Check if the resume state is stale.
if obj['epoch'] < epoch:
if not self.num_canonical_nodes:
self.num_canonical_nodes = world.num_nodes * 64
self._set_predownload()
if self.shuffle_algo in ['py1s', 'py2s']:
self.num_canonical_nodes = 64 * world.num_nodes
else:
self.num_canonical_nodes = world.num_nodes
self._set_shuffle_block_size()
return epoch, 0

# Load the correct resumption meta data.
@@ -685,7 +694,7 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
# Ensure that we are backwards compatible with old checkpoint dataset state, since the
# 'initial_physical_nodes' key may not be present.
self.initial_physical_nodes = obj.get('initial_physical_nodes', None)
self._set_predownload()
self._set_shuffle_block_size()

return epoch, sample_in_epoch

@@ -740,14 +749,12 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
else:
sample_in_epoch = offset + num_samples

if self.initial_physical_nodes is None:
# In this case, we are running for the first time, so we set initial_physical_nodes
# to the current number of physical nodes.
initial_physical_nodes = world.num_nodes
else:
# In this case, initial_physical_nodes has already been set from an initial run. We
# keep this value persisted in the state across the total run duration.
initial_physical_nodes = self.initial_physical_nodes
# If `self.initial_physical_nodes` is None, we are running for the first time, so we set
# initial_physical_nodes to the current number of physical nodes. Otherwise, we persist
# initial_physical_nodes as the value loaded and set from the resumption state.
initial_physical_nodes = world.num_nodes if self.initial_physical_nodes is None \
else self.initial_physical_nodes

return {
'epoch': epoch,
'sample_in_epoch': sample_in_epoch,
3 changes: 3 additions & 0 deletions tests/test_streaming.py
@@ -11,6 +11,7 @@
from torch.utils.data import DataLoader

from streaming.base import Stream, StreamingDataLoader, StreamingDataset
from streaming.base.util import clean_stale_shared_memory
from tests.common.utils import convert_to_mds


@@ -762,6 +763,8 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s
del dataloader
del dataset

clean_stale_shared_memory()

dataset = StreamingDataset(local=local_dir,
remote=remote_dir,
shuffle=shuffle,
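For context, a rough sketch (not taken from the test file; paths and arguments are placeholders) of the pattern this change enforces: when a StreamingDataset is torn down and re-created in the same process, stale shared memory from the first instance is cleaned up first.

from streaming.base import StreamingDataset
from streaming.base.util import clean_stale_shared_memory

local_dir = '/tmp/mds_local'       # placeholder paths for illustration only
remote_dir = '/tmp/mds_remote'

# First instantiation: use the dataset, then drop it mid-epoch.
dataset = StreamingDataset(local=local_dir, remote=remote_dir, batch_size=8)
del dataset

# Clear shared memory left behind by the previous instance before creating a
# new StreamingDataset against the same local directory.
clean_stale_shared_memory()

dataset = StreamingDataset(local=local_dir, remote=remote_dir, batch_size=8)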