Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better default values for StreamingDataset args #479

Merged
merged 12 commits into from
Oct 27, 2023
4 changes: 3 additions & 1 deletion streaming/base/batching/per_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
assert isinstance(dataset.shuffle_block_size, int)
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {dataset.shuffle_block_size} instead.')
snarayan21 marked this conversation as resolved.
Show resolved Hide resolved
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
Expand Down
4 changes: 3 additions & 1 deletion streaming/base/batching/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch

# If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
if dataset.shuffle:
assert isinstance(dataset.shuffle_block_size, int)
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {dataset.shuffle_block_size} instead.')
snarayan21 marked this conversation as resolved.
Show resolved Hide resolved
shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes,
dataset.shuffle_seed, epoch, dataset.shuffle_block_size)
big_ids = np.where(big_ids != -1, shuffle[big_ids], -1)
Expand Down
4 changes: 3 additions & 1 deletion streaming/base/batching/stratified.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
# same as the ratio of the stream's samples to overall samples.
# This ensures that the overall training shuffle block size is still approximately
# equal to what is set by the user, and allows for reasoning about cache_limit as well.
assert isinstance(dataset.shuffle_block_size, int)
if not isinstance(dataset.shuffle_block_size, int):
raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' +
f'Got {dataset.shuffle_block_size} instead.')
snarayan21 marked this conversation as resolved.
Show resolved Hide resolved
shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion)
stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units,
dataset.num_canonical_nodes, dataset.shuffle_seed, epoch,
Expand Down
Loading