Add a `ShardBasedBuilder` that creates shards directly.
In some cases, users already have their data split into shards and want to keep the same number of shards, with the same order of examples within each shard (or they do not care about the ordering). In those cases, our current dataset builder classes are much slower than necessary.

The `ShardBasedBuilder` lets users create dataset builders that process the source data shard by shard. It can be run with or without Beam; when Beam is used, the resulting pipeline is significantly simpler than the one built by the existing builders, and therefore faster.
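
As a rough illustration, a minimal subclass could look like the sketch below. It mirrors the `ShardBuilder` test fixtures added in this commit; the class name, feature spec, and shard ranges are illustrative only, not part of the change.

import functools

import numpy as np
import tensorflow_datasets as tfds
from tensorflow_datasets.core import dataset_builder


class MyShardedDataset(dataset_builder.ShardBasedBuilder):
  """Writes one output shard per source shard, preserving example order."""

  VERSION = tfds.core.Version('1.0.0')

  def _info(self) -> tfds.core.DatasetInfo:
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({'x': np.int64}),
    )

  def _shard_iterators_per_split(self, dl_manager):
    del dl_manager  # This sketch generates its data inline.

    def gen_examples(start: int, end: int):
      # Each callable yields (key, example) tuples for exactly one shard.
      for i in range(start, end):
        yield i, {'x': i}

    # The order of the callables is the order in which the shards are written.
    return {
        'train': [
            functools.partial(gen_examples, start=0, end=10),
            functools.partial(gen_examples, start=10, end=20),
        ],
        'test': [functools.partial(gen_examples, start=100, end=110)],
    }

Preparing and reading such a builder then goes through the usual `download_and_prepare` / `as_data_source` flow, as the new tests below show.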

PiperOrigin-RevId: 676820294
tomvdw authored and The TensorFlow Datasets Authors committed Sep 20, 2024
1 parent 3b0dab2 commit f3e94fa
Showing 5 changed files with 328 additions and 15 deletions.
94 changes: 81 additions & 13 deletions tensorflow_datasets/core/dataset_builder.py
@@ -19,14 +19,14 @@

import abc
import collections
from collections.abc import Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
import dataclasses
import functools
import inspect
import json
import os
import sys
from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union

from absl import logging
from etils import epy
@@ -1445,6 +1445,17 @@ def builder_configs(cls) -> dict[str, BuilderConfig]:
)
return config_dict

def _get_filename_template(
self, split_name: str
) -> naming.ShardedFileTemplate:
"""Returns a filename template for the given split."""
return naming.ShardedFileTemplate(
split=split_name,
dataset_name=self.name,
data_dir=self.data_path,
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
)


class FileReaderBuilder(DatasetBuilder):
"""Base class for datasets reading files.
@@ -1675,17 +1686,6 @@ def _example_writer(self) -> writer_lib.ExampleWriter:
"""
return writer_lib.ExampleWriter(file_format=self.info.file_format)

def _get_filename_template(
self, split_name: str
) -> naming.ShardedFileTemplate:
"""Returns a filename template for the given split."""
return naming.ShardedFileTemplate(
split=split_name,
dataset_name=self.name,
data_dir=self.data_path,
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
)

def _generate_splits(
self,
dl_manager: download.DownloadManager,
@@ -1852,6 +1852,74 @@ def read_tfrecord_beam(
)


class ShardBasedBuilder(FileReaderBuilder):
"""Base class for datasets with data generated shard by shard."""

def _download_and_prepare(
self,
dl_manager: download.DownloadManager,
download_config: download.DownloadConfig | None = None,
) -> None:
download_config = download_config or download.DownloadConfig()

split_builder = split_builder_lib.SplitBuilder(
split_dict=self.info.splits,
features=self.info.features,
dataset_size=self.info.dataset_size,
beam_options=download_config.beam_options,
beam_runner=download_config.beam_runner,
example_writer=self._example_writer(),
# The following options are ignored by `ShardBasedBuilder`.
ignore_duplicates=None,
max_examples_per_split=None,
shard_config=None,
)

shard_iterators_per_split = self._shard_iterators_per_split(dl_manager)
split_info_futures = []
for split_name, example_gen_per_shard in shard_iterators_per_split.items():
logging.info("Generating split %s", split_name)
split_info_future = split_builder.submit_shard_based_generation(
split_name=split_name,
example_gen_per_shard=example_gen_per_shard,
filename_template=self._get_filename_template(split_name=split_name),
)
split_info_futures.append(split_info_future)

# Update the info object with the splits.
split_infos: list[splits_lib.SplitInfo] = [
future.result() for future in split_info_futures
]
split_dict = splits_lib.SplitDict(split_infos)
self.info.set_splits(split_dict)

@abc.abstractmethod
@utils.docs.do_not_doc_in_subclasses
@utils.docs.doc_private
def _shard_iterators_per_split(
self, dl_manager: download.DownloadManager
) -> Mapping[str, Sequence[split_builder_lib.ExampleGeneratorFn]]:
"""Returns a mapping from split name to example generators per shard.
The example generators are functions that take no parameters and return
an iterator of tuples of key + example. The order of the example generators
is the order in which the shards will be written.
Args:
dl_manager: `tfds.download.DownloadManager` used to download/extract the
data.
"""
raise NotImplementedError()

def _example_writer(self) -> writer_lib.ExampleWriter:
"""Returns an example writer.
If datasets should be written to a custom storage, e.g., a database, then
implement a custom `ExampleWriter` and inject it here.
"""
return writer_lib.ExampleWriter(file_format=self.info.file_format)


@utils.docs.deprecated
class BeamBasedBuilder(GeneratorBasedBuilder):
"""Beam based Builder.
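
Aside: the `_get_filename_template` helper (moved up into `DatasetBuilder` above) builds per-shard output paths via `naming.ShardedFileTemplate`, which `submit_shard_based_generation` later resolves with `sharded_filepath`. A hedged sketch of that resolution, with a made-up dataset name, directory and file suffix:

from etils import epath
from tensorflow_datasets.core import naming

# Hypothetical values; inside the builder they come from self.name,
# self.data_path and self.info.file_format.
template = naming.ShardedFileTemplate(
    split='train',
    dataset_name='my_sharded_dataset',
    data_dir=epath.Path('/tmp/my_sharded_dataset/1.0.0'),
    filetype_suffix='tfrecord',
)
shard_path = template.sharded_filepath(shard_index=0, num_shards=2)
# Expected to resolve to something like:
#   /tmp/my_sharded_dataset/1.0.0/my_sharded_dataset-train.tfrecord-00000-of-00002
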
47 changes: 47 additions & 0 deletions tensorflow_datasets/core/dataset_builder_beam_test.py
@@ -15,6 +15,8 @@

"""Tests for tensorflow_datasets.core.dataset_builder."""

from collections.abc import Iterator, Mapping, Sequence
import functools
import pathlib
from typing import Callable
from unittest import mock
@@ -102,6 +104,31 @@ def _generate_examples(self, examples, num_examples):
return examples


class ShardBuilder(dataset_builder.ShardBasedBuilder):
VERSION = utils.Version('0.0.1')

def _info(self):
return dataset_info.DatasetInfo(
builder=self,
features=features.FeaturesDict({'x': np.int64}),
)

def _shard_iterators_per_split(self, dl_manager):
del dl_manager

def gen_examples(start: int, end: int):
for i in range(start, end):
yield i, {'x': i}

return {
'train': [
functools.partial(gen_examples, start=0, end=10),
functools.partial(gen_examples, start=10, end=20),
],
'test': [functools.partial(gen_examples, start=100, end=110)],
}


def _gen_example(x):
return (
x,
@@ -198,6 +225,26 @@ def _assert_values_equal(nested_lhs, nested_rhs):
np.testing.assert_array_equal(lhs, rhs)


@pytest.mark.parametrize(
'make_dl_config',
[
make_default_config,
],
)
def test_beam_shard_builder_dataset(
tmp_path: pathlib.Path,
make_dl_config: Callable[[], download.DownloadConfig],
):
builder = ShardBuilder(data_dir=tmp_path, version='0.0.1')
builder.download_and_prepare(
file_format='array_record', download_config=make_dl_config()
)
actual_train_data = list(builder.as_data_source(split='train'))
assert actual_train_data == [{'x': i} for i in range(20)]
actual_test_data = list(builder.as_data_source(split='test'))
assert actual_test_data == [{'x': i} for i in range(100, 110)]


def test_read_tfrecord_beam():
builder = DummyBeamDataset()
with mock.patch.object(
48 changes: 48 additions & 0 deletions tensorflow_datasets/core/dataset_builder_test.py
@@ -15,7 +15,9 @@

"""Tests for tensorflow_datasets.core.dataset_builder."""

from collections.abc import Iterator, Mapping, Sequence
import dataclasses
import functools
import os
import tempfile
from unittest import mock
@@ -37,9 +39,11 @@
from tensorflow_datasets.core import load
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import read_only_builder
from tensorflow_datasets.core import split_builder
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.data_sources import array_record
from tensorflow_datasets.core.download import download_manager
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import read_config as read_config_lib
from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder
@@ -147,6 +151,50 @@ def _split_generators(self, _):
return {"all": self._generate_examples(range(5))}


class ShardBuilder(dataset_builder.ShardBasedBuilder):
VERSION = utils.Version("0.0.1")
BUILDER_CONFIGS = [DummyBuilderConfig(name="cfg1")]

def _info(self):
return dataset_info.DatasetInfo(
builder=self,
features=features.FeaturesDict({"x": np.int64}),
)

def _shard_iterators_per_split(
self, dl_manager: download_manager.DownloadManager
) -> Mapping[str, Sequence[Iterator[split_builder.KeyExample]]]:
del dl_manager

def gen_examples(
start: int, end: int
) -> Iterator[split_builder.KeyExample]:
for i in range(start, end):
yield i, {"x": i}

return {
# train split has 2 shards
"train": [
functools.partial(gen_examples, start=0, end=10),
functools.partial(gen_examples, start=10, end=20),
],
"test": [functools.partial(gen_examples, start=100, end=110)],
}


class ShardBuilderTest(testing.TestCase):

def test_download_and_prepare(self):
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
builder = ShardBuilder(data_dir=tmp_dir, config="cfg1", version="0.0.1")
builder.download_and_prepare(file_format="array_record")
actual_data = list(builder.as_data_source(split="train"))
self.assertEqual(
actual_data,
[{"x": i} for i in range(20)],
)


class GetBuilderDatadirPathTest(testing.TestCase):

def test_builder_data_dir_path_is_correct(self):
94 changes: 93 additions & 1 deletion tensorflow_datasets/core/split_builder.py
@@ -15,15 +15,17 @@

"""Dataset generator code."""

from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, Sequence
import contextlib
import dataclasses
import functools
import itertools
import json
import sys
from typing import Any, Callable, Optional, Union

from absl import logging
from etils import epath
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import naming
@@ -49,6 +51,7 @@
'beam.PTransform',
'beam.PCollection[KeyExample]',
]
ExampleGeneratorFn = Callable[[], Iterator[KeyExample]]


@utils.docs.deprecated
@@ -147,6 +150,95 @@ def __init__(
self._ignore_duplicates = ignore_duplicates
self._example_writer = example_writer

def submit_shard_based_generation(
self,
split_name: str,
filename_template: naming.ShardedFileTemplate,
example_gen_per_shard: Sequence[ExampleGeneratorFn],
) -> _SplitInfoFuture:
"""Creates the shards for the split with the given example generators.
If a Beam runner was added when initializing the `SplitBuilder`, then
the `example_gen_per_shard` will be run in parallel using Beam. Otherwise,
they will be run sequentially in the current process.
Args:
split_name: Name of the split to generate
filename_template: Template to format the filename for a shard.
example_gen_per_shard: List of example generators, one per shard. Must be
in the same order as the shards.
Returns:
a future with the split info.
"""
num_shards = len(example_gen_per_shard)
filename_template = filename_template.replace(split=split_name)
serialized_info = self._features.get_serialized_info()
serializer = example_serializer.ExampleSerializer(serialized_info)

shard_writer = writer_lib.ShardWriter(
serializer=serializer,
example_writer=self._example_writer,
)

shard_paths = []
shard_lengths = []
if self._beam_runner is None:
for shard_index, example_gen in enumerate(example_gen_per_shard):
shard_path = filename_template.sharded_filepath(
shard_index=shard_index, num_shards=num_shards
)
shard_paths.append(shard_path)
num_examples = shard_writer.write(
path=shard_path, examples=example_gen()
)
shard_lengths.append(num_examples)
else:
shard_infos_path = filename_template.data_dir / 'shard_infos.json'
with self.maybe_beam_pipeline():
shard_infos = []
for shard_index, example_gen in enumerate(example_gen_per_shard):
shard_path = filename_template.sharded_filepath(
shard_index=shard_index, num_shards=num_shards
)
shard_paths.append(shard_path)
shard_info = shard_writer.write_with_beam(
path=shard_path,
example_gen=example_gen,
shard_index=shard_index,
pipeline=self.beam_pipeline,
)
shard_infos.append(shard_info)

def write_shard_infos(
shard_infos: list[tuple[int, int]], path: epath.Path
) -> None:
shard_infos_dict = {index: length for index, length in shard_infos}
path.write_text(json.dumps(shard_infos_dict))

_ = (
shard_infos
| f'FlattenShardInfos_{split_name}' >> beam.Flatten()
| f'CombineShardInfos_{split_name}'
>> beam.CombineGlobally(beam.combiners.ToListCombineFn())
| f'WriteShardInfos_{split_name}'
>> beam.Map(write_shard_infos, path=shard_infos_path)
)

shard_infos_dict = json.loads(shard_infos_path.read_text())
shard_lengths = [
num_examples for _, num_examples in sorted(shard_infos_dict.items())
]

total_size = sum([shard_path.stat().length for shard_path in shard_paths])
split_info = splits_lib.SplitInfo(
name=split_name,
shard_lengths=shard_lengths,
num_bytes=total_size,
filename_template=filename_template,
)
return _SplitInfoFuture(lambda: split_info)

@contextlib.contextmanager
def maybe_beam_pipeline(self) -> Iterator[PipelineProxy]:
"""Context manager wrapping the beam pipeline.
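
Finally, a hedged sketch of preparing such a dataset with and without a Beam runner, mirroring the new tests above; the `MyShardedDataset` class from the earlier sketch, the DirectRunner choice, and the data directories are assumptions, not part of the commit.

import apache_beam as beam
import tensorflow_datasets as tfds

# Without a Beam runner: the per-shard generators run sequentially in the
# current process.
builder = MyShardedDataset(data_dir='/tmp/tfds')
builder.download_and_prepare(file_format='array_record')

# With a Beam runner: each shard is written from within a Beam pipeline.
dl_config = tfds.download.DownloadConfig(
    beam_runner=beam.runners.DirectRunner(),
)
beam_builder = MyShardedDataset(data_dir='/tmp/tfds_beam')
beam_builder.download_and_prepare(
    file_format='array_record', download_config=dl_config
)

train_examples = list(beam_builder.as_data_source(split='train'))
assert train_examples == [{'x': i} for i in range(20)]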
