From 3de2e68a7fb3070f6f85e4c9a5a633034b113378 Mon Sep 17 00:00:00 2001 From: Joe Early Date: Wed, 14 Feb 2024 18:15:32 +0000 Subject: [PATCH] Allow writers to overwrite existing data (#594) * Allow writers to overwrite existing data * Add exist_ok arg to writer docs * Fix linting * Fix linting * Move removal code into writer from cloud uploader * Remove old function --------- Co-authored-by: Xiaohan Zhang --- streaming/base/format/base/writer.py | 13 +++- streaming/base/format/json/writer.py | 3 + streaming/base/format/mds/writer.py | 3 + streaming/base/format/xsv/writer.py | 3 + tests/test_writer.py | 103 +++++++++++++++++++++++++++ 5 files changed, 124 insertions(+), 1 deletion(-) diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index f9128d8a7..c699e32c0 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -64,6 +64,9 @@ class Writer(ABC): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. retry (int): Number of times to retry uploading a file to a remote location. Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to overwrite + the content or raise an error. `False` raises an error. `True` deletes the + content and starts fresh. Defaults to `False`. """ format: str = '' # Name of the format (like "mds", "csv", "json", etc). @@ -100,7 +103,8 @@ def __init__(self, # Validate keyword arguments invalid_kwargs = [ - arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry') + arg for arg in kwargs.keys() + if arg not in ('progress_bar', 'max_workers', 'retry', 'exist_ok') ] if invalid_kwargs: raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ') @@ -116,6 +120,13 @@ def __init__(self, self.shards = [] + # Remove local directory if requested prior to creating writer + local = out if isinstance(out, str) else out[0] + if os.path.exists(local) and kwargs.get('exist_ok', False): + logger.warning( + f'Directory {local} exists and is not empty; exist_ok is set to True so will remove contents.' + ) + shutil.rmtree(local) self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False), kwargs.get('retry', 2)) self.local = self.cloud_writer.local diff --git a/streaming/base/format/json/writer.py b/streaming/base/format/json/writer.py index c8a0f8c10..3fc4bd42a 100644 --- a/streaming/base/format/json/writer.py +++ b/streaming/base/format/json/writer.py @@ -45,6 +45,9 @@ class JSONWriter(SplitWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + exist_ok (bool): If the local directory exists and is not empty, whether to overwrite + the content or raise an error. `False` raises an error. `True` deletes the + content and starts fresh. Defaults to `False`. """ format = 'json' diff --git a/streaming/base/format/mds/writer.py b/streaming/base/format/mds/writer.py index 4b20a6d76..494c7e8ad 100644 --- a/streaming/base/format/mds/writer.py +++ b/streaming/base/format/mds/writer.py @@ -45,6 +45,9 @@ class MDSWriter(JointWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + exist_ok (bool): If the local directory exists and is not empty, whether to overwrite + the content or raise an error. `False` raises an error. `True` deletes the + content and starts fresh. Defaults to `False`. """ format = 'mds' diff --git a/streaming/base/format/xsv/writer.py b/streaming/base/format/xsv/writer.py index ea0e7b6e9..fcd60641a 100644 --- a/streaming/base/format/xsv/writer.py +++ b/streaming/base/format/xsv/writer.py @@ -46,6 +46,9 @@ class XSVWriter(SplitWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + exist_ok (bool): If the local directory exists and is not empty, whether to overwrite + the content or raise an error. `False` raises an error. `True` deletes the + content and starts fresh. Defaults to `False`. """ format = 'xsv' diff --git a/tests/test_writer.py b/tests/test_writer.py index 5aa2be00b..a0caab36d 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -121,6 +121,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s for before, after in zip(dataset, mds_dataset): assert before == after + def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None: + num_samples = 1000 + size_limit = 4096 + local, _ = local_remote_dir + dataset = SequenceDataset(num_samples) + columns = dict(zip(dataset.column_names, dataset.column_encodings)) + + # Write entire dataset initially + with MDSWriter(out=local, columns=columns, size_limit=size_limit) as out: + for sample in dataset: + out.write(sample) + num_orig_files = len(os.listdir(local)) + + # Write single sample with exist_ok set to True + with MDSWriter(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out: + out.write(dataset[0]) + num_files = len(os.listdir(local)) + + # Two files for single sample (index.json and one shard) + assert num_files == 2 + # Should be more files generated for the entire dataset, which are then deleted as exist_ok is True + assert num_orig_files > num_files + + # Check exception is raised when exist_ok is False and local already exists + with pytest.raises(FileExistsError, match='Directory is not empty'): + with MDSWriter(out=local, columns=columns, size_limit=size_limit) as out: + out.write(dataset[0]) + class TestJSONWriter: @@ -177,6 +205,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s for before, after in zip(dataset, mds_dataset): assert before == after + def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None: + num_samples = 1000 + size_limit = 4096 + local, _ = local_remote_dir + dataset = SequenceDataset(num_samples) + columns = dict(zip(dataset.column_names, dataset.column_encodings)) + + # Write entire dataset initially + with JSONWriter(out=local, columns=columns, size_limit=size_limit) as out: + for sample in dataset: + out.write(sample) + num_orig_files = len(os.listdir(local)) + + # Write single sample with exist_ok set to True + with JSONWriter(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out: + out.write(dataset[0]) + num_files = len(os.listdir(local)) + + # Three files for single sample (index.json, one shard, and one shard metadata) + assert num_files == 3 + # Should be more files generated for the entire dataset, which are then deleted as exist_ok is True + assert num_orig_files > num_files + + # Check exception is raised when exist_ok is False and local already exists + with pytest.raises(FileExistsError, match='Directory is not empty'): + with JSONWriter(out=local, columns=columns, size_limit=size_limit) as out: + out.write(dataset[0]) + class TestXSVWriter: @@ -256,3 +312,50 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s # Ensure sample iterator is deterministic for before, after in zip(dataset, mds_dataset): assert before == after + + @pytest.mark.parametrize('writer', [XSVWriter, TSVWriter, CSVWriter]) + def test_exist_ok(self, local_remote_dir: Tuple[str, str], writer: Any) -> None: + num_samples = 1000 + size_limit = 4096 + local, _ = local_remote_dir + dataset = SequenceDataset(num_samples) + columns = dict(zip(dataset.column_names, dataset.column_encodings)) + + # Write entire dataset initially + if writer.__name__ == XSVWriter.__name__: + with writer(out=local, columns=columns, size_limit=size_limit, separator=',') as out: + for sample in dataset: + out.write(sample) + else: + with writer(out=local, columns=columns, size_limit=size_limit) as out: + for sample in dataset: + out.write(sample) + num_orig_files = len(os.listdir(local)) + + # Write single sample with exist_ok set to True + if writer.__name__ == XSVWriter.__name__: + with writer(out=local, + columns=columns, + size_limit=size_limit, + separator=',', + exist_ok=True) as out: + out.write(dataset[0]) + else: + with writer(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out: + out.write(dataset[0]) + num_files = len(os.listdir(local)) + + # Three files for single sample (index.json, one shard, and one shard metadata) + assert num_files == 3 + # Should be more files generated for the entire dataset, which are then deleted as exist_ok is True + assert num_orig_files > num_files + + # Check exception is raised when exist_ok is False and local already exists + with pytest.raises(FileExistsError, match='Directory is not empty'): + if writer.__name__ == XSVWriter.__name__: + with writer(out=local, columns=columns, size_limit=size_limit, + separator=',') as out: + out.write(dataset[0]) + else: + with writer(out=local, columns=columns, size_limit=size_limit) as out: + out.write(dataset[0])