new dataset pipeline, work-in-progress
albertz committed Jun 2, 2020
1 parent 15f5671 commit 9fb0dc7
Showing 4 changed files with 296 additions and 32 deletions.
243 changes: 242 additions & 1 deletion TFDataPipeline.py
@@ -137,7 +137,7 @@ class DataProviderBase(object):
Base class which wraps up the logic in this class. See derived classes.
"""

-def __init__(self, extern_data, data_keys):
+def __init__(self, extern_data, data_keys=None):
"""
:param ExternData extern_data:
:param set(str)|None data_keys:
@@ -454,3 +454,244 @@ def get_complete_frac(self):
"""
return self.batches.completed_frac()


class InputContext(object):
"""
This object will be passed to the dataset pipeline function
(``dataset_pipeline`` in the config)
and provides all relevant information, functions, dataset transformations.
The initial design of this class was discussed here:
https://github.com/rwth-i6/returnn/issues/292
"""

def __init__(self, config, dataset_name, returnn_dataset):
"""
:param Config.Config config:
:param str dataset_name: e.g. "train" or "dev"
:param Dataset returnn_dataset:
"""
self.config = config # TODO do we need that? maybe pass horovod and other context explicitly?
self.dataset_name = dataset_name
self.returnn_dataset = returnn_dataset # TODO this might be unset later (if the dataset lives in a separate proc)

self.num_dataset_producers = 1
self.num_dataset_consumers = 1

self.horovod_enabled = False
self.horovod_rank = None
self.horovod_size = None
if config.is_true("use_horovod"):
self.horovod_enabled = True
import horovod.tensorflow as hvd
self.horovod_rank = hvd.rank() # rank 0 is the chief
self.horovod_size = hvd.size()
self.num_dataset_consumers = self.horovod_size
raise NotImplementedError # TODO...

if config.is_true("distributed_tf"):
raise NotImplementedError # TODO...

# These will be set after init.
self.final_dataset = None # type: typing.Optional[tf.data.Dataset]
self.final_dataset_init_iterator = None # type: typing.Optional[tf.Operation]

def get_returnn_dataset(self):
"""
:return: the RETURNN :class:`Dataset` instance wrapped in a :class:`tf.data.Dataset`
:rtype: tensorflow.data.Dataset
"""
# TODO...
raise NotImplementedError
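# A minimal sketch of one possible implementation, via a Python generator over the
# RETURNN dataset (illustrative only; epoch handling, sequence ordering and the
# "size:..." entries are omitted, and the sequence-wise Dataset API is assumed):
#
#   ds = self.returnn_dataset
#
#   def generator():
#     ds.init_seq_order(epoch=1)  # hypothetical: fixed epoch just for illustration
#     seq_idx = 0
#     while ds.is_less_than_num_seqs(seq_idx):
#       ds.load_seqs(seq_idx, seq_idx + 1)
#       yield {key: ds.get_data(seq_idx, key) for key in ds.get_data_keys()}
#       seq_idx += 1
#
#   return tf.data.Dataset.from_generator(
#     generator,
#     output_types={key: tf.as_dtype(ds.get_data_dtype(key)) for key in ds.get_data_keys()})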

def padded_batch_dataset(self, dataset):
"""
:param tensorflow.data.Dataset dataset:
:rtype: tensorflow.data.Dataset
"""
raise NotImplementedError # TODO
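# A minimal sketch, assuming a simple fixed batch size taken from the config
# (illustrative only; RETURNN's usual bucketing / max-tokens batching is not covered here):
#
#   batch_size = self.config.typed_value("batch_size", 32)  # hypothetical config access
#   return dataset.padded_batch(batch_size, padded_shapes=dataset.output_shapes)  # TF1-style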

def map_producer_to_consumer(self, dataset):
"""
:param tensorflow.data.Dataset dataset:
:rtype: tensorflow.data.Dataset
"""
raise NotImplementedError # TODO
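# A minimal sketch, assuming a plain per-consumer shard when Horovod is enabled
# (illustrative only; a real producer/consumer mapping may be more involved):
#
#   if self.horovod_enabled:
#     return dataset.shard(num_shards=self.num_dataset_consumers, index=self.horovod_rank)
#   return dataset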

def prefetch_to_consumer_device(self, dataset):
"""
:param tensorflow.data.Dataset dataset:
:rtype: tensorflow.data.Dataset
"""
raise NotImplementedError # TODO
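# A minimal sketch, assuming the consumer device is the first GPU (illustrative only;
# the device string would really be derived from the TF device setup):
#
#   return dataset.apply(
#     tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=2))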

def get_dataset_name(self):
"""
:return: e.g. "train" or "dev"
:rtype: str
"""
return self.dataset_name

def make_iterator_initializer(self, iterator):
"""
:param tensorflow.data.Iterator iterator:
:rtype: tf.Operation
"""
assert self.final_dataset
return iterator.make_initializer(self.final_dataset)


class DatasetDataProvider(DataProviderBase):
"""
Use a :class:`tf.data.Dataset` as input.
This will be used if ``dataset_pipeline`` is set in the config.
See the discussion about the new dataset pipeline (https://github.com/rwth-i6/returnn/issues/292).
Note that this also has state: the currently active dataset.
"""

def __init__(self, extern_data, config, datasets=None):
"""
:param ExternData extern_data:
:param Config.Config config:
:param list[str]|dict[str,Dataset|None]|None datasets: e.g. ["train", "dev"]
"""
super(DatasetDataProvider, self).__init__(extern_data=extern_data)
output_types = {} # type: typing.Dict[str,tf.DType]
output_shapes = {} # type: typing.Dict[str,tf.TensorShape]
for key, data in extern_data.data.items():
output_types[key] = tf.as_dtype(data.dtype)
output_shapes[key] = tf.TensorShape(data.batch_shape)
for axis_wo_b, dim in enumerate(data.shape):
if dim is None: # dynamic length -- need size info for it
size_key = "size:%s:%i" % (key, axis_wo_b)
output_types[size_key] = tf.as_dtype(data.size_dtype)
output_shapes[size_key] = tf.TensorShape([None]) # [Batch]
self.iterator = tf.data.Iterator.from_structure(output_types=output_types, output_shapes=output_shapes)
self.iterator_next_element = self.iterator.get_next()
for key, data in extern_data.data.items():
assert data.placeholder is None
assert not data.size_placeholder
data.placeholder = self.iterator_next_element[key]
assert isinstance(data.placeholder, tf.Tensor), "next: %r" % (self.iterator_next_element,)
data.size_placeholder = {}
for axis_wo_b, dim in enumerate(data.shape):
if dim is None: # dynamic length
size_key = "size:%s:%i" % (key, axis_wo_b)
data.size_placeholder[axis_wo_b] = self.iterator_next_element[size_key]
assert isinstance(data.size_placeholder[axis_wo_b], tf.Tensor), "next: %r" % (self.iterator_next_element,)
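# For example, with extern_data entries "data" (shape (None, 40), float32) and
# "classes" (shape (None,), int32), the iterator element is a dict with the keys
# "data", "size:data:0", "classes" and "size:classes:0", where each "size:..." entry
# holds the per-sequence lengths with shape [Batch].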

dataset_pipeline_func = config.typed_value("dataset_pipeline")
if dataset_pipeline_func in [True, 1]:
dataset_pipeline_func = self._dataset_pipeline_default
assert callable(dataset_pipeline_func)

if datasets is None or not datasets: # e.g. in distributed TF
# We do not use the actual dataset instances here; they will be used by the dataset-loader producer workers.
datasets = []
if config.is_true("train"):
datasets.append("train")
if config.is_true("dev"):
datasets.append("dev")
if config.is_true("eval"):
datasets.append("eval")
if config.has("eval_datasets"):
datasets.extend(sorted(config.typed_value("eval_datasets", {}).keys()))
if isinstance(datasets, (list, tuple)):
datasets = {name: None for name in datasets}
self.datasets = datasets # type: typing.Dict[str,typing.Optional[Dataset]]
for dataset_name in datasets:
context = InputContext(dataset_name=dataset_name, config=config) # TODO
dataset = dataset_pipeline_func(context)
assert isinstance(dataset, tf.data.Dataset)
context.final_dataset = dataset
context.final_dataset_init_iterator = context.make_iterator_initializer(self.iterator)

self.current_dataset_name = None # type: typing.Optional[str]

def set_current_dataset(self, dataset_name):
"""
:param str dataset_name:
"""
self.current_dataset_name = dataset_name
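# A minimal usage sketch, assuming the caller keeps a handle on the corresponding
# :class:`InputContext` built in ``__init__`` (hypothetical; the contexts are not yet
# stored on ``self`` in this work-in-progress state):
#
#   data_provider.set_current_dataset("train")
#   session.run(train_context.final_dataset_init_iterator)  # re-binds self.iterator
#   # subsequent session.run() calls on the network then pull from the "train" dataset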

def start_threads(self):
"""
Start background threads.
Currently a no-op. Any background threads of tf.data are started automatically when needed.
"""
# TODO actually it might be nice to start them explicitly in advance...
# I think this is currently not possible though.
# With a custom final prefetcher (see comment in have_reached_end), this would be possible.

def stop_threads(self):
"""
Stop background threads (e.g. prefetching).
(Currently a no-op.)
"""
# I don't think this is currently possible. See e.g.:
# https://stackoverflow.com/questions/62148052/how-to-stop-background-thread-of-prefetchdataset
# Anyway, maybe not relevant.
# We should just make sure that any access to the RETURNN dataset is safe.

def have_more_data(self, session):
"""
:param tf.Session session:
:return: whether the next session.run() can run in the current epoch & dataset
:rtype: bool
"""
# See have_reached_end.
assert self.current_dataset_name
return True

def get_feed_dict(self, single_threaded=False):
"""
:param bool single_threaded: whether to not use the queue (arg name is slightly misleading)
:returns: batch,meta
:rtype: dict[tf.Tensor,tf.Tensor],dict[str]
"""
assert self.current_dataset_name
assert not single_threaded
return {}, {}

def have_reached_end(self):
"""
:rtype: bool
"""
# we will just raise tf.errors.OutOfRangeError otherwise
# TODO: horovod sync on this is likely broken then...
# TODO we could also have sth like an own custom PrefetchDataset in between,
# which runs a background thread which always prefetches elements,
# and an extra function to check whether we reached the end
# (would block if not the case, and not prefetched yet).
assert self.current_dataset_name
return False

def get_dataset_name(self):
"""
:return: current dataset name, e.g. "train" or "dev"
:rtype: str
"""
assert self.current_dataset_name
return self.current_dataset_name

def get_complete_frac(self):
"""
:return: by how much we are through the current dataset, number between 0 and 1, for visual feedback
:rtype: float
"""
# TODO ... this is somewhat tricky...
# we would need some IPC to the original RETURNN dataset...
return 0.

def _dataset_pipeline_default(self, context):
"""
:param InputContext context:
:rtype: tensorflow.data.Dataset
"""
dataset = context.get_returnn_dataset()
dataset = context.padded_batch_dataset(dataset)
dataset = context.map_producer_to_consumer(dataset)
dataset = context.prefetch_to_consumer_device(dataset)
return dataset
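# A minimal config sketch, based on the ``dataset_pipeline`` handling in ``__init__``
# above (illustrative only): setting it to True (or 1) selects this default pipeline,
# anything else must be a callable taking an :class:`InputContext`.
#
#   # in the config file:
#   dataset_pipeline = True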