Skip to content

Commit

Permalink
Merge pull request #49383 from ashahab/abin-load-segfault-r2.5
Browse files Browse the repository at this point in the history
Resolves a core dump caused by `tf.data.experimental.save` with prefetch
  • Loading branch information
mihaimaruseac authored Aug 6, 2021
2 parents 0539b34 + ac1fcf2 commit 6f39597
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/data/experimental/io_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,11 @@ class LoadDatasetOp::Dataset : public DatasetBase {
// Constructs the iterator, forwarding the creation params to the
// DatasetIterator base. No work happens here; `input_` is not acquired
// until Initialize() runs (it is nullptr-initialized at declaration).
explicit Iterator(const Params& params)
    : DatasetIterator<Dataset>(params) {}

// Releases the reference on the input dataset, if one was ever taken.
// `input_` remains nullptr when the iterator is destroyed before
// Initialize() completes (see GitHub issue #49165), so the Unref()
// must be guarded to avoid dereferencing null.
~Iterator() override {
  if (input_ != nullptr) input_->Unref();
}

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
Expand Down Expand Up @@ -331,7 +335,7 @@ class LoadDatasetOp::Dataset : public DatasetBase {
}

mutex mu_;
DatasetBase* input_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
};
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import shutil

Expand Down Expand Up @@ -111,6 +112,20 @@ def testOptionalElementSpec(self):
dataset_loaded = io.load(self._test_dir)
self.assertDatasetsEqual(dataset, dataset_loaded)

@combinations.generate(test_base.eager_only_combinations())
def testRepeatAndPrefetch(self):
  """Regression test for github.com/tensorflow/tensorflow/issues/49165.

  Iterating a loaded dataset through repeat() + prefetch() used to
  crash with a segfault; this test passes if draining the pipeline
  completes without crashing.
  """
  # Round-trip a small random dataset through save/load.
  source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
  io.save(source, self._test_dir)
  pipeline = (
      io.load(self._test_dir)
      .shuffle(buffer_size=16)
      .batch(16)
      .repeat()
      .prefetch(1))
  # Drawing a handful of batches is enough to exercise the crash path.
  get_next = self.getNext(pipeline)
  for _ in range(30):
    self.evaluate(get_next())


# Standard TensorFlow test entry point: run the suite when executed directly.
if __name__ == "__main__":
  test.main()

0 comments on commit 6f39597

Please sign in to comment.