Commit
removing temporary directories in tests
snarayan21 committed Jul 29, 2023
1 parent bb16a83 commit 9bd8896
Showing 2 changed files with 44 additions and 162 deletions.
tests/test_laziness.py (4 changes: 2 additions & 2 deletions)
@@ -13,7 +13,7 @@ def one(remote: str, local: str):
"""
Verify __getitem__ accesses.
"""
dataset = StreamingDataset(local=remote)
dataset = StreamingDataset(local=local, remote=remote)
for i in range(dataset.num_samples):
sample = dataset[i]
assert sample['value'] == i
@@ -23,7 +23,7 @@ def two(remote: str, local: str):
"""
Verify __iter__ -> __getitem__ accesses.
"""
dataset = StreamingDataset(local=remote, num_canonical_nodes=1)
dataset = StreamingDataset(local=local, remote=remote, num_canonical_nodes=1)
for i, sample in zip(range(dataset.num_samples), dataset):
assert sample['value'] == i

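The corrected pattern across these tests is the same: build the StreamingDataset with separate local and remote directories, then remove both directories once the test finishes. Below is a minimal sketch of that pattern; it assumes the usual "from streaming import StreamingDataset" import and that the local_remote_dir fixture yields two fresh directory paths, neither of which is shown in this diff.

import shutil

from streaming import StreamingDataset  # assumed import path; not shown in this diff


def test_getitem_with_cleanup(local_remote_dir):
    # Assumes the fixture yields (local, remote) paths to fresh, writable directories,
    # with an MDS dataset already written under remote, as in the tests above.
    local, remote = local_remote_dir
    dataset = StreamingDataset(local=local, remote=remote)
    try:
        for i in range(dataset.num_samples):
            assert dataset[i]['value'] == i
    finally:
        # Remove the temporary directories, as the updated tests do.
        shutil.rmtree(local)
        shutil.rmtree(remote)

The try/finally is only a convenience in this sketch; the tests in this commit call shutil.rmtree unconditionally at the end of each test body.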
tests/test_streaming.py (202 changes: 42 additions & 160 deletions)
@@ -12,180 +12,50 @@
from tests.common.utils import convert_to_mds


# @pytest.mark.parametrize('batch_size', [4])
# @pytest.mark.parametrize('seed', [2222])
# @pytest.mark.parametrize('shuffle', [False])
# @pytest.mark.parametrize('drop_last', [False, True])
# @pytest.mark.parametrize('num_workers', [3, 6])
# @pytest.mark.parametrize('num_canonical_nodes', [4, 8])
# @pytest.mark.parametrize('epoch_size', [10, 200])
# @pytest.mark.usefixtures('local_remote_dir')
# def test_dataloader_epoch_size_no_streams(local_remote_dir: Tuple[str, str], batch_size: int, seed: int,
# shuffle: bool, drop_last: bool, num_workers: int,
# num_canonical_nodes: int, epoch_size: int):
# local, remote = local_remote_dir
# convert_to_mds(out_root=remote,
# dataset_name='sequencedataset',
# num_samples=117,
# size_limit=1 << 8)

# # Build StreamingDataset
# dataset = StreamingDataset(local=local,
# remote=remote,
# shuffle=shuffle,
# batch_size=batch_size,
# shuffle_seed=seed,
# num_canonical_nodes=num_canonical_nodes,
# epoch_size=epoch_size)

# # Build DataLoader
# dataloader = StreamingDataLoader(dataset=dataset,
# batch_size=batch_size,
# num_workers=num_workers,
# drop_last=drop_last)

# samples_seen = 0
# for batch in dataloader:
# print(batch['sample'])
# samples_seen += batch['sample'].size(dim=0)

# if epoch_size % num_canonical_nodes != 0:
# assert samples_seen == math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes
# else:
# if drop_last:
# assert samples_seen == epoch_size - (epoch_size % batch_size)
# else:
# assert samples_seen == epoch_size

# shutil.rmtree(local)
# shutil.rmtree(remote)

@pytest.mark.parametrize('batch_size', [128])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('shuffle', [False, True])
@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('num_samples', [9867, 30_000])
@pytest.mark.parametrize('size_limit', [8_192])
@pytest.mark.parametrize('seed', [1234])
@pytest.mark.usefixtures('local_remote_dir')
def test_dataloader_single_device2(local_remote_dir: Tuple[str, str], batch_size: int,
drop_last: bool, shuffle: bool, num_workers: int,
num_samples: int, size_limit: int, seed: int):
local, remote = local_remote_dir
convert_to_mds(out_root=remote,
dataset_name='sequencedataset',
num_samples=num_samples,
size_limit=size_limit)

# Build a StreamingDataset
dataset = StreamingDataset(local=local,
remote=remote,
shuffle=shuffle,
batch_size=batch_size,
shuffle_seed=seed)

# Build DataLoader
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=drop_last)

# Expected number of batches based on batch_size and drop_last
expected_num_batches = (num_samples // batch_size) if drop_last else math.ceil(num_samples /
batch_size)
expected_num_samples = expected_num_batches * batch_size if drop_last else num_samples

# Iterate over DataLoader
rcvd_batches = 0
sample_order = []

for batch_ix, batch in enumerate(dataloader):
rcvd_batches += 1

# Every batch should be complete except (maybe) final one
if batch_ix + 1 < expected_num_batches:
assert len(batch['id']) == batch_size
else:
if drop_last:
assert len(batch['id']) == batch_size
else:
assert len(batch['id']) <= batch_size

sample_order.extend(batch['id'][:])

# Test dataloader length
assert len(dataloader) == expected_num_batches
assert rcvd_batches == expected_num_batches

# Test that all samples arrived with no duplicates
assert len(set(sample_order)) == expected_num_samples
if not drop_last:
assert len(set(sample_order)) == num_samples

shutil.rmtree(local)
shutil.rmtree(remote)

@pytest.mark.parametrize('batch_size', [128])
@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('seed', [2222])
@pytest.mark.parametrize('shuffle', [False])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('shuffle', [False, True])
@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('num_samples', [9867, 30_000])
@pytest.mark.parametrize('size_limit', [8_192])
@pytest.mark.parametrize('seed', [1234])
@pytest.mark.parametrize('num_workers', [3, 6])
@pytest.mark.parametrize('num_canonical_nodes', [4, 8])
@pytest.mark.parametrize('epoch_size', [10, 200])
@pytest.mark.usefixtures('local_remote_dir')
def test_dataloader_single_device3(local_remote_dir: Tuple[str, str], batch_size: int,
drop_last: bool, shuffle: bool, num_workers: int,
num_samples: int, size_limit: int, seed: int):
def test_dataloader_epoch_size_no_streams(local_remote_dir: Tuple[str, str], batch_size: int, seed: int,
shuffle: bool, drop_last: bool, num_workers: int,
num_canonical_nodes: int, epoch_size: int):
local, remote = local_remote_dir
convert_to_mds(out_root=remote,
dataset_name='sequencedataset',
num_samples=num_samples,
size_limit=size_limit)
num_samples=117,
size_limit=1 << 8)

# Build a StreamingDataset
# Build StreamingDataset
dataset = StreamingDataset(local=local,
remote=remote,
shuffle=shuffle,
batch_size=batch_size,
shuffle_seed=seed)
shuffle_seed=seed,
num_canonical_nodes=num_canonical_nodes,
epoch_size=epoch_size)

# Build DataLoader
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=drop_last)

# Expected number of batches based on batch_size and drop_last
expected_num_batches = (num_samples // batch_size) if drop_last else math.ceil(num_samples /
batch_size)
expected_num_samples = expected_num_batches * batch_size if drop_last else num_samples

# Iterate over DataLoader
rcvd_batches = 0
sample_order = []
dataloader = StreamingDataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=drop_last)

for batch_ix, batch in enumerate(dataloader):
rcvd_batches += 1
samples_seen = 0
for batch in dataloader:
print(batch['sample'])
samples_seen += batch['sample'].size(dim=0)

# Every batch should be complete except (maybe) final one
if batch_ix + 1 < expected_num_batches:
assert len(batch['id']) == batch_size
if epoch_size % num_canonical_nodes != 0:
assert samples_seen == math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes
else:
if drop_last:
assert samples_seen == epoch_size - (epoch_size % batch_size)
else:
if drop_last:
assert len(batch['id']) == batch_size
else:
assert len(batch['id']) <= batch_size

sample_order.extend(batch['id'][:])

# Test dataloader length
assert len(dataloader) == expected_num_batches
assert rcvd_batches == expected_num_batches

# Test that all samples arrived with no duplicates
assert len(set(sample_order)) == expected_num_samples
if not drop_last:
assert len(set(sample_order)) == num_samples
assert samples_seen == epoch_size

shutil.rmtree(local)
shutil.rmtree(remote)
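The branching assertions in test_dataloader_epoch_size_no_streams above encode the expected epoch length: when epoch_size is not a multiple of num_canonical_nodes, the epoch is padded up to the next multiple; otherwise it is epoch_size, minus any ragged final batch when drop_last is set. A standalone restatement of that arithmetic follows; the helper name is hypothetical and not part of the test suite.

import math


def expected_samples_seen(epoch_size: int, num_canonical_nodes: int, batch_size: int,
                          drop_last: bool) -> int:
    # Hypothetical helper mirroring the assertions in test_dataloader_epoch_size_no_streams.
    if epoch_size % num_canonical_nodes != 0:
        # The epoch is padded up to a multiple of num_canonical_nodes.
        return math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes
    # Exact split across canonical nodes; drop_last trims any ragged final batch.
    return epoch_size - (epoch_size % batch_size) if drop_last else epoch_size


# For example, with batch_size=4: epoch_size=10, num_canonical_nodes=4 gives 12;
# epoch_size=10, num_canonical_nodes=8 gives 16; epoch_size=200 gives 200 either way.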
@@ -305,6 +175,9 @@ def test_dataloader_determinism(local_remote_dir: Any, batch_size: int, seed: in
assert len(sample_order) == len(second_sample_order)
assert sample_order == second_sample_order

shutil.rmtree(local_dir)
shutil.rmtree(remote_dir)


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seed', [2222])
@@ -348,6 +221,9 @@ def test_dataloader_sample_order(local_remote_dir: Any, batch_size: int, seed: i

assert expected_sample_order == sample_order

shutil.rmtree(local_dir)
shutil.rmtree(remote_dir)


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seed', [3456])
@@ -413,6 +289,9 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s
assert len(set(sample_order)) == len(set(expected_sample_order)), 'Duplicate samples'
assert sample_order == expected_sample_order, 'Incorrect sample order'

shutil.rmtree(local_dir)
shutil.rmtree(remote_dir)


@pytest.mark.parametrize('shuffle_seed', [(9876, 9876), (12345, 1567)])
@pytest.mark.usefixtures('local_remote_dir')
@@ -464,4 +343,7 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup
val_sample_order.extend(batch['id'][:])

assert len(train_sample_order) == len(val_sample_order), 'Missing samples'
assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples'
assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples'

shutil.rmtree(local_dir)
shutil.rmtree(remote_dir)
