Skip to content

Commit

Permalink
added default behavior if no streams and epoch_size specified
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jul 27, 2023
1 parent f36318c commit 9f6970d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 7 deletions.
30 changes: 23 additions & 7 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ def __init__(self,
raise ValueError(
'You must provide either `streams` or `remote`/`local`, but not both.')

# Convert epoch size from string to int, if needed. Cannot be negative.
epoch_size_value = None
if epoch_size:
epoch_size_value = number_abbrev_to_int(epoch_size)
if epoch_size_value < 0:
raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.')

# Initialize torch dist ourselves, if necessary.
destroy_dist = maybe_init_dist()

Expand All @@ -306,6 +313,22 @@ def __init__(self,
}
for stream in streams:
stream.apply_default(default)
elif epoch_size_value:
# If no streams are provided but epoch_size is specified, we create a
# single stream that chooses the specified number of samples per epoch.
# The epoch consists of data from this single stream.
samples_specified_stream = Stream(remote=remote,
local=local,
split=split,
choose=epoch_size_value,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip)
# reset epoch_size_value back to default of None since we have already accounted
# for it inside the samples_specified_stream
epoch_size_value = None
streams = [samples_specified_stream]
else:
default = Stream(remote=remote,
local=local,
Expand Down Expand Up @@ -370,13 +393,6 @@ def __init__(self,
self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard
self.spanner = Spanner(self.samples_per_shard)

# Convert epoch size from string to int, if needed. Cannot be negative.
epoch_size_value = None
if epoch_size:
epoch_size_value = number_abbrev_to_int(epoch_size)
if epoch_size_value < 0:
raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.')

# Now that we know the number of underlying samples of each stream, derive each stream's
# true proportion/repeat/choose, as well as the total epoch size.
self.epoch_size = Stream.apply_weights(self.streams, self.samples_per_stream,
Expand Down
42 changes: 42 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,45 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup

assert len(train_sample_order) == len(val_sample_order), 'Missing samples'
assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples'


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seed', [2222])
@pytest.mark.parametrize('shuffle', [False])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('num_workers', [0, 8])
@pytest.mark.parametrize('num_canonical_nodes', [1])
@pytest.mark.parametrize('epoch_size', [10])
@pytest.mark.usefixtures('local_remote_dir')
def test_dataloader_epoch_size_no_streams(local_remote_dir: Any, batch_size: int, seed: int,
                                          shuffle: bool, drop_last: bool, num_workers: int,
                                          num_canonical_nodes: int, epoch_size: int):
    """Check the sample count per epoch when ``epoch_size`` is set without streams.

    A 117-sample dataset is written to the remote directory, then read back
    through a ``StreamingDataset`` constructed from ``local``/``remote`` only
    (no ``streams`` argument) with an explicit ``epoch_size``. One pass over
    the dataloader must yield ``epoch_size`` samples, or that count rounded
    down to a multiple of ``batch_size`` when ``drop_last`` is enabled.
    """
    remote_dir, local_dir = local_remote_dir

    # Materialize the on-disk dataset that the StreamingDataset will read.
    convert_to_mds(out_root=remote_dir,
                   dataset_name='sequencedataset',
                   num_samples=117,
                   size_limit=1 << 8)

    # Dataset is configured purely via local/remote + epoch_size, with no streams.
    dataset_config = {
        'local': local_dir,
        'remote': remote_dir,
        'shuffle': shuffle,
        'batch_size': batch_size,
        'shuffle_seed': seed,
        'num_canonical_nodes': num_canonical_nodes,
        'epoch_size': epoch_size,
    }
    dataset = StreamingDataset(**dataset_config)

    dataloader = StreamingDataLoader(dataset=dataset,
                                     batch_size=batch_size,
                                     num_workers=num_workers,
                                     drop_last=drop_last)

    # Count every sample delivered during one pass over the dataloader.
    samples_seen = sum(batch['sample'].size(dim=0) for batch in dataloader)

    # With drop_last, the trailing partial batch is not expected to be counted.
    if drop_last:
        expected_samples = epoch_size - (epoch_size % batch_size)
    else:
        expected_samples = epoch_size

    assert samples_seen == expected_samples

0 comments on commit 9f6970d

Please sign in to comment.