diff --git a/tests/test_laziness.py b/tests/test_laziness.py index 2b0570bcd..60f1010a3 100644 --- a/tests/test_laziness.py +++ b/tests/test_laziness.py @@ -13,7 +13,7 @@ def one(remote: str, local: str): """ Verify __getitem__ accesses. """ - dataset = StreamingDataset(local=remote) + dataset = StreamingDataset(local=local, remote=remote) for i in range(dataset.num_samples): sample = dataset[i] assert sample['value'] == i @@ -23,7 +23,7 @@ def two(remote: str, local: str): """ Verify __iter__ -> __getitem__ accesses. """ - dataset = StreamingDataset(local=remote, num_canonical_nodes=1) + dataset = StreamingDataset(local=local, remote=remote, num_canonical_nodes=1) for i, sample in zip(range(dataset.num_samples), dataset): assert sample['value'] == i diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 068cf9339..31d3afd76 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -12,180 +12,50 @@ from tests.common.utils import convert_to_mds -# @pytest.mark.parametrize('batch_size', [4]) -# @pytest.mark.parametrize('seed', [2222]) -# @pytest.mark.parametrize('shuffle', [False]) -# @pytest.mark.parametrize('drop_last', [False, True]) -# @pytest.mark.parametrize('num_workers', [3, 6]) -# @pytest.mark.parametrize('num_canonical_nodes', [4, 8]) -# @pytest.mark.parametrize('epoch_size', [10, 200]) -# @pytest.mark.usefixtures('local_remote_dir') -# def test_dataloader_epoch_size_no_streams(local_remote_dir: Tuple[str, str], batch_size: int, seed: int, -# shuffle: bool, drop_last: bool, num_workers: int, -# num_canonical_nodes: int, epoch_size: int): -# local, remote = local_remote_dir -# convert_to_mds(out_root=remote, -# dataset_name='sequencedataset', -# num_samples=117, -# size_limit=1 << 8) - -# # Build StreamingDataset -# dataset = StreamingDataset(local=local, -# remote=remote, -# shuffle=shuffle, -# batch_size=batch_size, -# shuffle_seed=seed, -# num_canonical_nodes=num_canonical_nodes, -# epoch_size=epoch_size) - -# # Build DataLoader -# dataloader = StreamingDataLoader(dataset=dataset, -# batch_size=batch_size, -# num_workers=num_workers, -# drop_last=drop_last) - -# samples_seen = 0 -# for batch in dataloader: -# print(batch['sample']) -# samples_seen += batch['sample'].size(dim=0) - -# if epoch_size % num_canonical_nodes != 0: -# assert samples_seen == math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes -# else: -# if drop_last: -# assert samples_seen == epoch_size - (epoch_size % batch_size) -# else: -# assert samples_seen == epoch_size - -# shutil.rmtree(local) -# shutil.rmtree(remote) - -@pytest.mark.parametrize('batch_size', [128]) -@pytest.mark.parametrize('drop_last', [False, True]) -@pytest.mark.parametrize('shuffle', [False, True]) -@pytest.mark.parametrize('num_workers', [0, 4]) -@pytest.mark.parametrize('num_samples', [9867, 30_000]) -@pytest.mark.parametrize('size_limit', [8_192]) -@pytest.mark.parametrize('seed', [1234]) -@pytest.mark.usefixtures('local_remote_dir') -def test_dataloader_single_device2(local_remote_dir: Tuple[str, str], batch_size: int, - drop_last: bool, shuffle: bool, num_workers: int, - num_samples: int, size_limit: int, seed: int): - local, remote = local_remote_dir - convert_to_mds(out_root=remote, - dataset_name='sequencedataset', - num_samples=num_samples, - size_limit=size_limit) - - # Build a StreamingDataset - dataset = StreamingDataset(local=local, - remote=remote, - shuffle=shuffle, - batch_size=batch_size, - shuffle_seed=seed) - - # Build DataLoader - dataloader = DataLoader(dataset=dataset, - batch_size=batch_size, - num_workers=num_workers, - drop_last=drop_last) - - # Expected number of batches based on batch_size and drop_last - expected_num_batches = (num_samples // batch_size) if drop_last else math.ceil(num_samples / - batch_size) - expected_num_samples = expected_num_batches * batch_size if drop_last else num_samples - - # Iterate over DataLoader - rcvd_batches = 0 - sample_order = [] - - for batch_ix, batch in enumerate(dataloader): - rcvd_batches += 1 - - # Every batch should be complete except (maybe) final one - if batch_ix + 1 < expected_num_batches: - assert len(batch['id']) == batch_size - else: - if drop_last: - assert len(batch['id']) == batch_size - else: - assert len(batch['id']) <= batch_size - - sample_order.extend(batch['id'][:]) - - # Test dataloader length - assert len(dataloader) == expected_num_batches - assert rcvd_batches == expected_num_batches - - # Test that all samples arrived with no duplicates - assert len(set(sample_order)) == expected_num_samples - if not drop_last: - assert len(set(sample_order)) == num_samples - - shutil.rmtree(local) - shutil.rmtree(remote) - -@pytest.mark.parametrize('batch_size', [128]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('seed', [2222]) +@pytest.mark.parametrize('shuffle', [False]) @pytest.mark.parametrize('drop_last', [False, True]) -@pytest.mark.parametrize('shuffle', [False, True]) -@pytest.mark.parametrize('num_workers', [0, 4]) -@pytest.mark.parametrize('num_samples', [9867, 30_000]) -@pytest.mark.parametrize('size_limit', [8_192]) -@pytest.mark.parametrize('seed', [1234]) +@pytest.mark.parametrize('num_workers', [3, 6]) +@pytest.mark.parametrize('num_canonical_nodes', [4, 8]) +@pytest.mark.parametrize('epoch_size', [10, 200]) @pytest.mark.usefixtures('local_remote_dir') -def test_dataloader_single_device3(local_remote_dir: Tuple[str, str], batch_size: int, - drop_last: bool, shuffle: bool, num_workers: int, - num_samples: int, size_limit: int, seed: int): +def test_dataloader_epoch_size_no_streams(local_remote_dir: Tuple[str, str], batch_size: int, seed: int, + shuffle: bool, drop_last: bool, num_workers: int, + num_canonical_nodes: int, epoch_size: int): local, remote = local_remote_dir convert_to_mds(out_root=remote, dataset_name='sequencedataset', - num_samples=num_samples, - size_limit=size_limit) + num_samples=117, + size_limit=1 << 8) - # Build a StreamingDataset + # Build StreamingDataset dataset = StreamingDataset(local=local, remote=remote, shuffle=shuffle, batch_size=batch_size, - shuffle_seed=seed) + shuffle_seed=seed, + num_canonical_nodes=num_canonical_nodes, + epoch_size=epoch_size) # Build DataLoader - dataloader = DataLoader(dataset=dataset, - batch_size=batch_size, - num_workers=num_workers, - drop_last=drop_last) - - # Expected number of batches based on batch_size and drop_last - expected_num_batches = (num_samples // batch_size) if drop_last else math.ceil(num_samples / - batch_size) - expected_num_samples = expected_num_batches * batch_size if drop_last else num_samples - - # Iterate over DataLoader - rcvd_batches = 0 - sample_order = [] + dataloader = StreamingDataLoader(dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + drop_last=drop_last) - for batch_ix, batch in enumerate(dataloader): - rcvd_batches += 1 + samples_seen = 0 + for batch in dataloader: + print(batch['sample']) + samples_seen += batch['sample'].size(dim=0) - # Every batch should be complete except (maybe) final one - if batch_ix + 1 < expected_num_batches: - assert len(batch['id']) == batch_size + if epoch_size % num_canonical_nodes != 0: + assert samples_seen == math.ceil(epoch_size / num_canonical_nodes) * num_canonical_nodes + else: + if drop_last: + assert samples_seen == epoch_size - (epoch_size % batch_size) else: - if drop_last: - assert len(batch['id']) == batch_size - else: - assert len(batch['id']) <= batch_size - - sample_order.extend(batch['id'][:]) - - # Test dataloader length - assert len(dataloader) == expected_num_batches - assert rcvd_batches == expected_num_batches - - # Test that all samples arrived with no duplicates - assert len(set(sample_order)) == expected_num_samples - if not drop_last: - assert len(set(sample_order)) == num_samples + assert samples_seen == epoch_size shutil.rmtree(local) shutil.rmtree(remote) @@ -305,6 +175,9 @@ def test_dataloader_determinism(local_remote_dir: Any, batch_size: int, seed: in assert len(sample_order) == len(second_sample_order) assert sample_order == second_sample_order + shutil.rmtree(local_dir) + shutil.rmtree(remote_dir) + @pytest.mark.parametrize('batch_size', [1, 4]) @pytest.mark.parametrize('seed', [2222]) @@ -348,6 +221,9 @@ def test_dataloader_sample_order(local_remote_dir: Any, batch_size: int, seed: i assert expected_sample_order == sample_order + shutil.rmtree(local_dir) + shutil.rmtree(remote_dir) + @pytest.mark.parametrize('batch_size', [1, 4]) @pytest.mark.parametrize('seed', [3456]) @@ -413,6 +289,9 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s assert len(set(sample_order)) == len(set(expected_sample_order)), 'Duplicate samples' assert sample_order == expected_sample_order, 'Incorrect sample order' + shutil.rmtree(local_dir) + shutil.rmtree(remote_dir) + @pytest.mark.parametrize('shuffle_seed', [(9876, 9876), (12345, 1567)]) @pytest.mark.usefixtures('local_remote_dir') @@ -464,4 +343,7 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup val_sample_order.extend(batch['id'][:]) assert len(train_sample_order) == len(val_sample_order), 'Missing samples' - assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples' \ No newline at end of file + assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples' + + shutil.rmtree(local_dir) + shutil.rmtree(remote_dir) \ No newline at end of file