Better default values for StreamingDataset args (#479)
* streaming dataset better defaults

* relaxed partition default

* modified docstring

* removed assert statements

* corrected type in TypeError

* linting fix

* test cleanup

* test cleanup

* test modification

* test modification
snarayan21 authored Oct 27, 2023
1 parent 217e66e commit 93bf054
Showing 6 changed files with 61 additions and 40 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/pytest.yaml
@@ -47,11 +47,13 @@ jobs:
id: tests
run: |
set -ex
pytest --splits 8 --group 1 --cov-fail-under=10
pytest --splits 8 --group 2 --cov-fail-under=10
pytest --splits 8 --group 3 --cov-fail-under=10
pytest --splits 8 --group 4 --cov-fail-under=10
pytest --splits 8 --group 5 --cov-fail-under=10
pytest --splits 8 --group 6 --cov-fail-under=10
pytest --splits 8 --group 7 --cov-fail-under=10
pytest --splits 8 --group 8 --cov-fail-under=10
pytest --splits 10 --group 1 --cov-fail-under=10
pytest --splits 10 --group 2 --cov-fail-under=10
pytest --splits 10 --group 3 --cov-fail-under=10
pytest --splits 10 --group 4 --cov-fail-under=10
pytest --splits 10 --group 5 --cov-fail-under=10
pytest --splits 10 --group 6 --cov-fail-under=10
pytest --splits 10 --group 7 --cov-fail-under=10
pytest --splits 10 --group 8 --cov-fail-under=10
pytest --splits 10 --group 9 --cov-fail-under=10
pytest --splits 10 --group 10 --cov-fail-under=10
3 changes: 3 additions & 0 deletions streaming/base/batching/per_stream.py
@@ -63,6 +63,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
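
As a rough, standalone illustration (not part of the diff), the new guard fails fast when `shuffle_block_size` is accidentally a non-integer; the value below is hypothetical:

    shuffle_block_size = 262144.0  # hypothetical misconfigured value (float, not int)
    if not isinstance(shuffle_block_size, int):
        raise TypeError(f'Dataset `shuffle_block_size` must be an integer. '
                        f'Got {type(shuffle_block_size)} instead.')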
3 changes: 3 additions & 0 deletions streaming/base/batching/random.py
@@ -58,6 +58,9 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch

# If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
if dataset.shuffle:
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)
3 changes: 3 additions & 0 deletions streaming/base/batching/stratified.py
@@ -75,6 +75,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {type(dataset.shuffle_block_size)} instead.')
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
71 changes: 39 additions & 32 deletions streaming/base/dataset.py
@@ -265,9 +265,7 @@ class StreamingDataset(Array, IterableDataset):
of current sample. Workers will attempt to download ahead by this many samples during,
but not before, training. Recommendation is to provide a value greater than per device
batch size to ensure at-least per device batch size number of samples cached locally.
If ``None``, its value gets derived using per device batch size and number of
canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``.
Defaults to ``None``.
If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s)
may be evicted (deleted from the local cache) in order to stay under the limit.
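
For a quick, illustrative comparison of the removed and added `predownload` defaults (numbers chosen arbitrarily):

    batch_size = 32
    num_canonical_nodes = 64

    old_default = max(batch_size, 256 * batch_size // num_canonical_nodes)  # 128
    new_default = 8 * batch_size                                             # 256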
@@ -280,13 +278,14 @@ class StreamingDataset(Array, IterableDataset):
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. The sample space is divided evenly according to the number of canonical
nodes. The higher the value, the more independent non-overlapping paths the
StreamingDataset replicas take through the shards per model replica (increasing data
source diversity). Defaults to ``None``, which is interpreted as 64 times the number
of nodes of the initial run.
source diversity). If ``None``, this is interpreted as 64 times the number of physical
nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
number of physical nodes of the initial run otherwise. Defaults to ``None``.
.. note::
@@ -296,10 +295,12 @@ class StreamingDataset(Array, IterableDataset):
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int): Unit of shuffle. A canonical node's samples are split into blocks
of this size, and samples within each block are shuffled. Defaults to ``1 << 18``.
shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
into blocks of this size, and samples within each block are shuffled. If ``None``, its
value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to
``None``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
"""
@@ -319,13 +320,13 @@ def __init__(self,
cache_limit: Optional[Union[int, str]] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
partition_algo: str = 'orig',
partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1s',
shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
shuffle_block_size: int = 1 << 18,
shuffle_block_size: Optional[int] = None,
batching_method: str = 'random') -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
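
A minimal usage sketch relying on the new defaults; the paths below are hypothetical, and `partition_algo`, `shuffle_algo`, `shuffle_block_size`, and `predownload` are left at their defaults (`'relaxed'`, `'py1e'`, `None`, `None`):

    from streaming.base import StreamingDataset

    # predownload resolves to 8 * batch_size; shuffle_block_size is resolved later
    # from num_canonical_nodes once it is known.
    dataset = StreamingDataset(local='/tmp/mds-cache',           # hypothetical path
                               remote='s3://my-bucket/dataset',  # hypothetical path
                               batch_size=32,
                               shuffle=True)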
@@ -381,12 +382,15 @@ def __init__(self,
raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' +
f'{self.shuffle_seed}.')

# Check that predownload is at least per device batch size.
# Check that predownload is at least per device batch size, and set it if currently `None`.
if self.predownload is not None and self.batch_size is not None and \
self.predownload < self.batch_size:
warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}).' +
f'This may result in slower batch time. Recommendation is to set ' +
f'predownload to at-least batch_size.')
elif self.predownload is None:
self.predownload = 8 * self.batch_size if self.batch_size is not None else 64

# Convert epoch size from string to int, if needed. Cannot be negative.
epoch_size_value = None
if epoch_size:
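
A standalone sketch of the `predownload` resolution introduced above (the helper name is invented for illustration):

    def resolve_predownload(predownload, batch_size):
        # New default: 8 * batch_size when batch_size is known, else 64.
        if predownload is None:
            return 8 * batch_size if batch_size is not None else 64
        return predownload

    assert resolve_predownload(None, 16) == 128
    assert resolve_predownload(None, None) == 64
    assert resolve_predownload(100, 16) == 100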
@@ -636,12 +640,11 @@ def __len__(self) -> int:
"""
return self.length

def _set_predownload(self) -> None:
"""Set the predownload value which is per number of workers."""
if self.predownload is None:
self.predownload = max(
self.batch_size, 256 * self.batch_size // self.num_canonical_nodes
) if self.batch_size is not None and self.num_canonical_nodes is not None else 512
def _set_shuffle_block_size(self):
"""Set the shuffle block size value."""
if self.shuffle_block_size is None:
self.shuffle_block_size = max(4_000_000 // self.num_canonical_nodes, 1 << 18) \
if self.num_canonical_nodes is not None else 1 << 18

def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
"""Either resume from checkpoint or start at the beginning.
@@ -660,8 +663,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
except FileNotFoundError:
# There is nothing to resume.
if not self.num_canonical_nodes:
self.num_canonical_nodes = world.num_nodes * 64
self._set_predownload()
if self.shuffle_algo in ['py1s', 'py2s']:
self.num_canonical_nodes = 64 * world.num_nodes
else:
self.num_canonical_nodes = world.num_nodes
self._set_shuffle_block_size()
return epoch, 0

# SharedMemory buffers may contain additional null bytes at the end.
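
Illustrative resolution of `num_canonical_nodes` on a fresh (non-resumed) run, following the branch above (the helper name is invented for illustration):

    def resolve_num_canonical_nodes(shuffle_algo, num_physical_nodes):
        # 64x the physical nodes only for the py1s/py2s shuffles; otherwise 1x.
        if shuffle_algo in ['py1s', 'py2s']:
            return 64 * num_physical_nodes
        return num_physical_nodes

    assert resolve_num_canonical_nodes('py1s', 2) == 128
    assert resolve_num_canonical_nodes('py1e', 2) == 2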
@@ -673,8 +679,11 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
# Check if the resume state is stale.
if obj['epoch'] < epoch:
if not self.num_canonical_nodes:
self.num_canonical_nodes = world.num_nodes * 64
self._set_predownload()
if self.shuffle_algo in ['py1s', 'py2s']:
self.num_canonical_nodes = 64 * world.num_nodes
else:
self.num_canonical_nodes = world.num_nodes
self._set_shuffle_block_size()
return epoch, 0

# Load the correct resumption meta data.
@@ -685,7 +694,7 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
# Ensure that we are backwards compatible with old checkpoint dataset state, since the
# 'initial_physical_nodes' key may not be present.
self.initial_physical_nodes = obj.get('initial_physical_nodes', None)
self._set_predownload()
self._set_shuffle_block_size()

return epoch, sample_in_epoch

@@ -740,14 +749,12 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
else:
sample_in_epoch = offset + num_samples

if self.initial_physical_nodes is None:
# In this case, we are running for the first time, so we set initial_physical_nodes
# to the current number of physical nodes.
initial_physical_nodes = world.num_nodes
else:
# In this case, initial_physical_nodes has already been set from an initial run. We
# keep this value persisted in the state across the total run duration.
initial_physical_nodes = self.initial_physical_nodes
# If `self.initial_physical_nodes` is None, we are running for the first time, so we set
# initial_physical_nodes to the current number of physical nodes. Otherwise, we persist
# initial_physical_nodes as the value loaded and set from the resumption state.
initial_physical_nodes = world.num_nodes if self.initial_physical_nodes is None \
else self.initial_physical_nodes

return {
'epoch': epoch,
'sample_in_epoch': sample_in_epoch,
3 changes: 3 additions & 0 deletions tests/test_streaming.py
@@ -11,6 +11,7 @@
from torch.utils.data import DataLoader

from streaming.base import Stream, StreamingDataLoader, StreamingDataset
from streaming.base.util import clean_stale_shared_memory
from tests.common.utils import convert_to_mds


@@ -762,6 +763,8 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s
del dataloader
del dataset

clean_stale_shared_memory()

dataset = StreamingDataset(local=local_dir,
remote=remote_dir,
shuffle=shuffle,
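
`clean_stale_shared_memory` comes from `streaming.base.util` (imported at the top of this test file); a minimal sketch of the pattern the test now follows when constructing a second StreamingDataset in the same process:

    from streaming.base.util import clean_stale_shared_memory

    # After deleting the first dataloader/dataset, clear leftover shared-memory
    # state before instantiating a new StreamingDataset.
    clean_stale_shared_memory()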
