Skip to content

Commit

Permalink
fixed nits, test cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Jul 28, 2023
1 parent 9f6970d commit 11672fe
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 26 deletions.
23 changes: 7 additions & 16 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,31 +313,22 @@ def __init__(self,
}
for stream in streams:
stream.apply_default(default)
elif epoch_size_value:
# if there are no streams provided but the epoch_size is specified
# we create a single stream that chooses the specified number of samples
# per epoch. The epoch consists of data from this single stream.
samples_specified_stream = Stream(remote=remote,
local=local,
split=split,
choose=epoch_size_value,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip)
# reset epoch_size_value back to default of None since we have already accounted
# for it inside the samples_specified_stream
epoch_size_value = None
streams = [samples_specified_stream]
else:
default = Stream(remote=remote,
local=local,
split=split,
choose=epoch_size_value,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip)
streams = [default]
# reset epoch_size_value back to its default of None since we have already accounted
# for it inside the single default stream created above. Needed to get correct
# `apply_weights` function behavior because if epoch_size is specified with a single
# stream, we are weighting absolutely.
if epoch_size_value:
epoch_size_value = None

# Validate the stream weighting scheme (relative or absolute) to catch errors before we go
# to the trouble of loading them.
Expand Down
5 changes: 4 additions & 1 deletion tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def local_remote_dir() -> Any:
mock_remote_dir = os.path.join(mock_dir.name, 'remote')
yield mock_local_dir, mock_remote_dir
finally:
mock_dir.cleanup() # pyright: ignore
try:
mock_dir.cleanup() # pyright: ignore
except OSError:
pass


@pytest.fixture(scope='function')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_laziness.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def one(remote: str, local: str):
"""
Verify __getitem__ accesses.
"""
dataset = StreamingDataset(local=remote)
dataset = StreamingDataset(local=local, remote=remote)
for i in range(dataset.num_samples):
sample = dataset[i]
assert sample['value'] == i
Expand All @@ -23,7 +23,7 @@ def two(remote: str, local: str):
"""
Verify __iter__ -> __getitem__ accesses.
"""
dataset = StreamingDataset(local=remote, num_canonical_nodes=1)
dataset = StreamingDataset(local=local, remote=remote, num_canonical_nodes=1)
for i, sample in zip(range(dataset.num_samples), dataset):
assert sample['value'] == i

Expand Down
18 changes: 11 additions & 7 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,13 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup
assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples'


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('seed', [2222])
@pytest.mark.parametrize('shuffle', [False])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('num_workers', [0, 8])
@pytest.mark.parametrize('num_canonical_nodes', [1])
@pytest.mark.parametrize('epoch_size', [10])
@pytest.mark.parametrize('num_workers', [3, 6])
@pytest.mark.parametrize('num_canonical_nodes', [1, 4, 8])
@pytest.mark.parametrize('epoch_size', [10, 200])
@pytest.mark.usefixtures('local_remote_dir')
def test_dataloader_epoch_size_no_streams(local_remote_dir: Any, batch_size: int, seed: int,
shuffle: bool, drop_last: bool, num_workers: int,
Expand Down Expand Up @@ -320,9 +320,13 @@ def test_dataloader_epoch_size_no_streams(local_remote_dir: Any, batch_size: int

samples_seen = 0
for batch in dataloader:
print(batch['sample'])
samples_seen += batch['sample'].size(dim=0)

if drop_last:
assert samples_seen == epoch_size - (epoch_size % batch_size)
if epoch_size % num_canonical_nodes != 0:
assert samples_seen == math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes
else:
assert samples_seen == epoch_size
if drop_last:
assert samples_seen == epoch_size - (epoch_size % batch_size)
else:
assert samples_seen == epoch_size

0 comments on commit 11672fe

Please sign in to comment.