From 9fb0dc7c74d74d0ccf28e9773af5c5d6f136c0a2 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 2 Jun 2020 13:33:49 +0200 Subject: [PATCH] new dataset pipeline, work-in-progress https://github.com/rwth-i6/returnn/issues/292 --- TFDataPipeline.py | 243 ++++++++++++++++++++++++++++++++++- TFEngine.py | 78 +++++++---- TFNetwork.py | 5 +- tests/test_TF_determinism.py | 2 +- 4 files changed, 296 insertions(+), 32 deletions(-) diff --git a/TFDataPipeline.py b/TFDataPipeline.py index 46c7ca4d46..6c79504e21 100644 --- a/TFDataPipeline.py +++ b/TFDataPipeline.py @@ -137,7 +137,7 @@ class DataProviderBase(object): Base class which wraps up the logic in this class. See derived classes. """ - def __init__(self, extern_data, data_keys): + def __init__(self, extern_data, data_keys=None): """ :param ExternData extern_data: :param set(str)|None data_keys: @@ -454,3 +454,244 @@ def get_complete_frac(self): """ return self.batches.completed_frac() + +class InputContext(object): + """ + This object will be passed to the dataset pipeline function + (``dataset_pipeline`` in the config) + and provides all relevant information, functions, dataset transformations. + + The initial design of this class was discussed here: + https://github.com/rwth-i6/returnn/issues/292 + """ + + def __init__(self, config, dataset_name, returnn_dataset): + """ + :param Config.Config config: + :param str dataset_name: e.g. "train" or "dev" + :param Dataset returnn_dataset: + """ + self.config = config # TODO do we need that? maybe pass horovod and other context explicitly? + self.dataset_name = dataset_name + self.returnn_dataset = returnn_dataset # TODO this might be unset later (if the dataset lives in a separate proc) + + self.num_dataset_producers = 1 + self.num_dataset_consumers = 1 + + self.horovod_enabled = False + self.horovod_rank = None + self.horovod_size = None + if config.is_true("use_horovod"): + self.horovod_enabled = True + import horovod.tensorflow as hvd + self.horovod_rank = hvd.rank() # rank 0 is the chief + self.horovod_size = hvd.size() + self.num_dataset_consumers = self.horovod_size + raise NotImplementedError # TODO... + + if config.is_true("distributed_tf"): + raise NotImplementedError # TODO... + + # These will be set after init. + self.final_dataset = None # type: typing.Optional[tf.data.Dataset] + self.final_dataset_init_iterator = None # type: typing.Optional[tf.Operation] + + def get_returnn_dataset(self): + """ + :return: the RETURNN :class:`Dataset` instances wrapped in a :class:`tf.data.Dataset` + :rtype: tensorflow.data.Dataset + """ + # TODO... + raise NotImplementedError + + def padded_batch_dataset(self, dataset): + """ + :param tensorflow.data.Dataset dataset: + :rtype: tensorflow.data.Dataset + """ + raise NotImplementedError # TODO + + def map_producer_to_consumer(self, dataset): + """ + :param tensorflow.data.Dataset dataset: + :rtype: tensorflow.data.Dataset + """ + raise NotImplementedError # TODO + + def prefetch_to_consumer_device(self, dataset): + """ + :param tensorflow.data.Dataset dataset: + :rtype: tensorflow.data.Dataset + """ + raise NotImplementedError # TODO + + def get_dataset_name(self): + """ + :return: e.g. 
"train" or "dev" + :rtype: str + """ + return self.dataset_name + + def make_iterator_initializer(self, iterator): + """ + :param tensorflow.data.Iterator iterator: + :rtype: tf.Operation + """ + assert self.final_dataset + return iterator.make_initializer(self.final_dataset) + + +class DatasetDataProvider(DataProviderBase): + """ + Use a :class:`tf.data.Dataset` as input. + This will be used if ``dataset_pipeline`` is set in the config. + See the discussion about the new dataset pipeline (https://github.com/rwth-i6/returnn/issues/292). + + Note that this has also a state: the current active dataset. + """ + + def __init__(self, extern_data, config, datasets=None): + """ + :param ExternData extern_data: + :param list[str]|dict[str,Dataset|None]|None datasets: e.g. ["train", "dev"] + :param Config.Config config: + """ + super(DatasetDataProvider, self).__init__(extern_data=extern_data) + output_types = {} # type: typing.Dict[str,tf.DType] + output_shapes = {} # type: typing.Dict[str,tf.TensorShape] + for key, data in extern_data.data.items(): + output_types[key] = tf.as_dtype(data.dtype) + output_shapes[key] = tf.TensorShape(data.batch_shape) + for axis_wo_b, dim in enumerate(data.shape): + if dim is None: # dynamic length -- need size info for it + size_key = "size:%s:%i" % (key, axis_wo_b) + output_types[size_key] = tf.as_dtype(data.size_dtype) + output_shapes[size_key] = tf.TensorShape([None]) # [Batch] + self.iterator = tf.data.Iterator.from_structure(output_types=output_types, output_shapes=output_shapes) + self.iterator_next_element = self.iterator.get_next() + for key, data in extern_data.data.items(): + assert data.placeholder is None + assert not data.size_placeholder + data.placeholder = self.iterator_next_element[key] + assert isinstance(data.placeholder, tf.Tensor), "next: %r" % (self.iterator_next_element,) + data.size_placeholder = {} + for axis_wo_b, dim in enumerate(data.shape): + if dim is None: # dynamic length + size_key = "size:%s:%i" % (key, axis_wo_b) + data.size_placeholder[axis_wo_b] = self.iterator_next_element[size_key] + assert isinstance(data.size_placeholder[axis_wo_b], tf.Tensor), "next: %r" % (self.iterator_next_element,) + + dataset_pipeline_func = config.typed_value("dataset_pipeline") + if dataset_pipeline_func in [True, 1]: + dataset_pipeline_func = self._dataset_pipeline_default + assert callable(dataset_pipeline_func) + + if datasets is None or not datasets: # e.g. in distributed TF + # We don't use them here. These will be used by the dataset loader producer workers. 
+      datasets = []
+      if config.is_true("train"):
+        datasets.append("train")
+      if config.is_true("dev"):
+        datasets.append("dev")
+      if config.is_true("eval"):
+        datasets.append("eval")
+      if config.has("eval_datasets"):
+        datasets.extend(sorted(config.typed_value("eval_datasets", {}).keys()))
+    if isinstance(datasets, (list, tuple)):
+      datasets = {name: None for name in datasets}
+    self.datasets = datasets  # type: typing.Dict[str,typing.Optional[Dataset]]
+    for dataset_name in datasets:
+      context = InputContext(dataset_name=dataset_name, config=config, returnn_dataset=datasets[dataset_name])  # TODO
+      dataset = dataset_pipeline_func(context)
+      assert isinstance(dataset, tf.data.Dataset)
+      context.final_dataset = dataset
+      context.final_dataset_init_iterator = context.make_iterator_initializer(self.iterator)
+
+    self.current_dataset_name = None  # type: typing.Optional[str]
+
+  def set_current_dataset(self, dataset_name):
+    """
+    :param str dataset_name:
+    """
+    self.current_dataset_name = dataset_name
+
+  def start_threads(self):
+    """
+    Start background threads.
+
+    Currently a no-op. Any background threads of tf.data are started automatically when needed.
+    """
+    # TODO actually it might be nice to start them explicitly in advance...
+    # I think this is currently not possible though.
+    # With a custom final prefetcher (see comment in have_reached_end), this would be possible.
+
+  def stop_threads(self):
+    """
+    Stop background threads (e.g. prefetching).
+    (Currently a no-op.)
+    """
+    # I don't think this is currently possible. See e.g.:
+    # https://stackoverflow.com/questions/62148052/how-to-stop-background-thread-of-prefetchdataset
+    # Anyway, maybe not relevant.
+    # We should just make sure that any access to the RETURNN dataset is safe.
+
+  def have_more_data(self, session):
+    """
+    :param tf.Session session:
+    :return: whether the next session.run() can run in the current epoch & dataset
+    :rtype: bool
+    """
+    # See have_reached_end.
+    assert self.current_dataset_name
+    return True
+
+  def get_feed_dict(self, single_threaded=False):
+    """
+    :param bool single_threaded: whether to not use the queue (arg name is slightly misleading)
+    :returns: batch,meta
+    :rtype: dict[tf.Tensor,tf.Tensor],dict[str]
+    """
+    assert self.current_dataset_name
+    assert not single_threaded
+    return {}, {}
+
+  def have_reached_end(self):
+    """
+    :rtype: bool
+    """
+    # We would just get a tf.errors.OutOfRangeError from session.run() otherwise.
+    # TODO: horovod sync on this is likely broken then...
+    # TODO we could also have something like our own custom PrefetchDataset in between,
+    # which runs a background thread which always prefetches elements,
+    # and an extra function to check whether we reached the end
+    # (which would block if that is not the case and nothing is prefetched yet).
+    assert self.current_dataset_name
+    return False
+
+  def get_dataset_name(self):
+    """
+    :return: current dataset name, e.g. "train" or "dev"
+    :rtype: str
+    """
+    assert self.current_dataset_name
+    return self.current_dataset_name
+
+  def get_complete_frac(self):
+    """
+    :return: how far we are through the current dataset, a number between 0 and 1, for visual feedback
+    :rtype: float
+    """
+    # TODO ... this is somewhat tricky...
+    # we would need some IPC to the original RETURNN dataset...
+    return 0.
+ + def _dataset_pipeline_default(self, context): + """ + :param InputContext context: + :rtype: tensorflow.data.Dataset + """ + dataset = context.get_returnn_dataset() + dataset = context.padded_batch_dataset(dataset) + dataset = context.map_producer_to_consumer(dataset) + dataset = context.prefetch_to_consumer_device(dataset) + return dataset diff --git a/TFEngine.py b/TFEngine.py index e0592054b9..e02dbe5bfb 100644 --- a/TFEngine.py +++ b/TFEngine.py @@ -32,8 +32,9 @@ from LearningRateControl import load_learning_rate_control_from_config, LearningRateControl from Log import log from Pretrain import pretrain_from_config -from TFNetwork import TFNetwork, help_on_tf_exception +from TFNetwork import TFNetwork, ExternData, help_on_tf_exception from TFUpdater import Updater +from TFDataPipeline import FeedDictDataProvider, DatasetDataProvider from Util import hms, NumbersDict, BackendEngine from pprint import pprint @@ -50,12 +51,15 @@ class Runner(object): """ # noinspection PyShadowingBuiltins - def __init__(self, engine, dataset, batches, train, eval=True, train_flag=None, + def __init__(self, engine, + dataset_name=None, dataset=None, batches=None, + train=False, eval=True, train_flag=None, extra_fetches=None, extra_fetches_callback=None): """ :param Engine engine: - :param Dataset.Dataset dataset: - :param BatchSetGenerator batches: + :param str|None dataset_name: "train", "dev" or so + :param Dataset.Dataset|None dataset: + :param BatchSetGenerator|None batches: :param bool train: whether to do updates on the model :param bool|None train_flag: normally just as train. but e.g. maybe you want to have the train_flag but not train :param bool eval: whether to evaluate (i.e. calculate loss/error) @@ -71,7 +75,7 @@ def __init__(self, engine, dataset, batches, train, eval=True, train_flag=None, dataset=dataset, used_data_keys=engine.network.get_used_data_keys()) self.engine = engine # noinspection PyProtectedMember - self.data_provider = self.engine._get_new_data_provider(dataset=dataset, batches=batches) + self.data_provider = self.engine._get_data_provider(dataset_name=dataset_name, dataset=dataset, batches=batches) assert isinstance(self.data_provider, DataProviderBase) if train_flag is None: train_flag = train @@ -688,6 +692,7 @@ def __init__(self, config=None): self._checked_uninitialized_vars = False self._merge_all_summaries = None self.dataset_batches = {} # type: typing.Dict[str,BatchSetGenerator] + self.dataset_provider = None # type: typing.Optional[DatasetDataProvider] self.train_data = None # type: typing.Optional[Dataset] self.eval_datasets = {} # type: typing.Dict[str,Dataset] self.start_epoch = None # type: typing.Optional[int] @@ -1113,12 +1118,19 @@ def _init_network(self, net_desc, epoch=None): train_flag = get_global_train_flag_placeholder() else: train_flag = False - # if False: # TODO ... - # extern_data = ExternData() - # extern_data.init_from_config(self.config) - # TODO... 
+ use_dataset_pipeline = False + if self.config.is_true("dataset_pipeline"): + use_dataset_pipeline = True + extern_data = ExternData() + extern_data.init_from_config(config=self.config, auto_create_placeholders=not use_dataset_pipeline) + if use_dataset_pipeline: + datasets = self.eval_datasets.copy() + if self.train_data: + datasets["train"] = self.train_data + self.dataset_provider = DatasetDataProvider(extern_data=extern_data, datasets=datasets, config=self.config) self.network, self.updater = self.create_network( config=self.config, + extern_data=extern_data, rnd_seed=net_random_seed, train_flag=train_flag, eval_flag=self.use_eval_flag, search_flag=self.use_search_flag, initial_learning_rate=getattr(self, "initial_learning_rate", None), @@ -1135,7 +1147,8 @@ def _init_network(self, net_desc, epoch=None): self.tf_session.run(bcast_op) @classmethod - def create_network(cls, config, rnd_seed, train_flag, eval_flag, search_flag, net_dict, initial_learning_rate=1.0): + def create_network(cls, config, rnd_seed, train_flag, eval_flag, search_flag, net_dict, + extern_data=None, initial_learning_rate=1.0): """ :param Config.Config config: :param int rnd_seed: @@ -1143,6 +1156,7 @@ def create_network(cls, config, rnd_seed, train_flag, eval_flag, search_flag, ne :param float initial_learning_rate: :param bool eval_flag: :param bool search_flag: + :param ExternData|None extern_data: :param dict[str,dict[str]] net_dict: :return: network, updater :rtype: (TFNetwork, Updater|None) @@ -1150,6 +1164,7 @@ def create_network(cls, config, rnd_seed, train_flag, eval_flag, search_flag, ne network = TFNetwork( name="root", config=config, + extern_data=extern_data, rnd_seed=rnd_seed, train_flag=train_flag, eval_flag=eval_flag, @@ -1804,25 +1819,32 @@ def check_uninitialized_vars(self): self.tf_session.run(tf.variables_initializer(uninitialized_vars)) self._checked_uninitialized_vars = True - def _get_new_data_provider(self, dataset, batches): + def _get_data_provider(self, dataset_name=None, dataset=None, batches=None, feed_dict=None): """ - :param Dataset.Dataset dataset: - :param BatchSetGenerator batches: - :rtype: TFDataPipeline.FeedDictDataProvider + :param str|None dataset_name: + :param Dataset.Dataset|None dataset: + :param BatchSetGenerator|None batches: + :param bool|None feed_dict: + :rtype: FeedDictDataProvider|DatasetDataProvider """ - batch_slice = None - if self.config.is_true("use_horovod"): - # noinspection PyPackageRequirements,PyUnresolvedReferences - import horovod.tensorflow as hvd - batch_slice = slice(hvd.rank(), None, hvd.size()) - from TFDataPipeline import FeedDictDataProvider - data_provider = FeedDictDataProvider( - tf_session=self.tf_session, extern_data=self.network.extern_data, - data_keys=self.network.get_used_data_keys(), - dataset=dataset, batches=batches, - batch_slice=batch_slice, - enforce_min_len1=self.config.is_true("enforce_min_len1", False)) - return data_provider + if self.dataset_provider and feed_dict is not True and dataset_name: + self.dataset_provider.set_current_dataset(dataset_name=dataset_name) + return self.dataset_provider + else: + if self.dataset_provider and feed_dict is not False: + print("WARNING: dataset_provider is set (via dataset_pipeline) but not used", file=log.v2) + batch_slice = None + if self.config.is_true("use_horovod"): + # noinspection PyPackageRequirements,PyUnresolvedReferences + import horovod.tensorflow as hvd + batch_slice = slice(hvd.rank(), None, hvd.size()) + data_provider = FeedDictDataProvider( + tf_session=self.tf_session, 
extern_data=self.network.extern_data,
+        data_keys=self.network.get_used_data_keys(),
+        dataset=dataset, batches=batches,
+        batch_slice=batch_slice,
+        enforce_min_len1=self.config.is_true("enforce_min_len1", False))
+      return data_provider
 
   def get_specific_feed_dict(self, dataset, seq_idx):
     """
@@ -1843,7 +1865,7 @@ def get_specific_feed_dict(self, dataset, seq_idx):
     batch.init_with_one_full_sequence(seq_idx=seq_idx, dataset=dataset)
     batch_generator = iter([batch])
     batches = BatchSetGenerator(dataset, generator=batch_generator)
-    data_provider = self._get_new_data_provider(dataset=dataset, batches=batches)
+    data_provider = self._get_data_provider(dataset=dataset, batches=batches, feed_dict=True)
     feed_dict, _ = data_provider.get_feed_dict(single_threaded=True)
     return feed_dict
 
diff --git a/TFNetwork.py b/TFNetwork.py
index b6d4b9154c..f9ccb4f91a 100644
--- a/TFNetwork.py
+++ b/TFNetwork.py
@@ -38,9 +38,10 @@ def __init__(self, data=None, default_input="data", default_target="classes"):
   def __repr__(self):
     return "<ExternData data=%r>" % self.data
 
-  def init_from_config(self, config):
+  def init_from_config(self, config, auto_create_placeholders=True):
     """
     :param Config.Config config:
+    :param bool auto_create_placeholders:
     """
     self._config = config
     from NetworkDescription import LayerNetworkDescription
@@ -50,7 +51,7 @@ def init_from_config(self, config):
       # In TensorFlow, the default is (batch,time,feature).
       # This is also what we use here, i.e.:
       # batch_dim_axis=0, time_dim_axis=1. See TFEngine.DataProvider._get_next_batch().
-      self.data[key] = Data(name=key, auto_create_placeholders=True, **init_args)
+      self.data[key] = Data(name=key, auto_create_placeholders=auto_create_placeholders, **init_args)
     self.default_target = config.value('target', 'classes')
 
   @classmethod
diff --git a/tests/test_TF_determinism.py b/tests/test_TF_determinism.py
index 5d0cb802d3..3597b172b7 100644
--- a/tests/test_TF_determinism.py
+++ b/tests/test_TF_determinism.py
@@ -68,7 +68,7 @@ def create_engine():
   return engine
 
 def train_engine_fetch_vars(engine):
-  data_provider = engine._get_new_data_provider(dataset=engine.train_data, batches=engine.train_batches)
+  data_provider = engine._get_data_provider(dataset=engine.train_data, batches=engine.train_batches, feed_dict=True)
   feed_dict, _ = data_provider.get_feed_dict(single_threaded=True)
   trainer = Runner(engine=engine, dataset=engine.train_data, batches=engine.train_batches, train=True)
   feed_dict, _ = data_provider.get_feed_dict(single_threaded=True)
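
Note: a minimal sketch of how ``dataset_pipeline`` could be set in a config, mirroring
``DatasetDataProvider._dataset_pipeline_default`` above. The hook API is still work-in-progress
(https://github.com/rwth-i6/returnn/issues/292), so treat this as an illustration rather than the
final interface; the comments on the not-yet-implemented context methods describe intended
behavior, not current functionality. ``dataset_pipeline`` can also just be set to ``True`` to get
the default pipeline.

  def dataset_pipeline(context):
    """
    :param TFDataPipeline.InputContext context:
    :rtype: tensorflow.data.Dataset
    """
    # The RETURNN Dataset wrapped as a tf.data.Dataset (not yet implemented above).
    dataset = context.get_returnn_dataset()
    # Pad and batch according to extern_data (not yet implemented above).
    dataset = context.padded_batch_dataset(dataset)
    # Intended to route/shard batches from producer to consumer workers (e.g. Horovod ranks).
    dataset = context.map_producer_to_consumer(dataset)
    # Intended to prefetch onto the consumer (e.g. GPU) device.
    dataset = context.prefetch_to_consumer_device(dataset)
    return dataset

Design note: :class:`DatasetDataProvider` creates one reusable iterator via
``tf.data.Iterator.from_structure`` and, per dataset ("train", "dev", ...), an initializer op via
``iterator.make_initializer(final_dataset)``. Switching the active dataset then only means running
the corresponding ``final_dataset_init_iterator`` op, while the ``extern_data`` placeholders stay
bound to the same ``iterator.get_next()`` tensors.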