From 8076cd96e9bd0a83e98a3f5a69995c1531786b86 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 13 Jul 2019 21:49:32 +0000 Subject: [PATCH 1/6] Add rebatch method for Dataset This PR adds rebatch method for Dataset where ``` dataset.apply(rebatch(n)) = dataset.unbatch().batch(n) ``` The motivation for rebatch is that there are situations where we read the data in big batches but then we want to adjust the batch size to fit different scenarios. Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 17 ++ .../core/kernels/rebatch_dataset_op.cc | 250 ++++++++++++++++++ tensorflow_io/core/ops/core_ops.cc | 37 +++ tensorflow_io/core/python/ops/data_ops.py | 34 +++ tests/test_text_eager.py | 8 + 5 files changed, 346 insertions(+) create mode 100644 tensorflow_io/core/kernels/rebatch_dataset_op.cc create mode 100644 tensorflow_io/core/ops/core_ops.cc diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 94fed00ca..4ecacb2d7 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -57,6 +57,22 @@ cc_library( ], ) +cc_library( + name = "core_ops", + srcs = [ + "kernels/rebatch_dataset_op.cc", + "ops/core_ops.cc", + ], + copts = tf_io_copts(), + includes = [ + ".", + ], + linkstatic = True, + deps = [ + ":dataset_ops", + ], +) + cc_library( name = "ffmpeg_3.4", srcs = [ @@ -107,6 +123,7 @@ cc_binary( copts = tf_io_copts(), linkshared = 1, deps = [ + ":core_ops", "//tensorflow_io/audio:audio_ops", "//tensorflow_io/avro:avro_ops", "//tensorflow_io/azure:azfs_ops", diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc new file mode 100644 index 000000000..b1d908505 --- /dev/null +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -0,0 +1,250 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/batch_util.h" + +namespace tensorflow { +namespace data { +namespace { + +class RebatchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit RebatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "batch_size", &batch_size)); + OP_REQUIRES( + ctx, batch_size > 0, + errors::InvalidArgument("Batch size must be greater than zero.")); + + string batch_mode = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "batch_mode", &batch_mode)); + + *output = + new Dataset(ctx, batch_size, batch_mode, input); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, int64 batch_size, string batch_mode, + const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), + batch_size_(batch_size), + batch_mode_(batch_mode), + input_(input) { + input_->Ref(); + + const auto& input_shapes = input_->output_shapes(); + output_shapes_.reserve(input_shapes.size()); + // Always set the first dim 
as None unless batch_mode is specified. + for (const auto& input_shape : input_shapes) { + output_shapes_.emplace_back( + PartialTensorShape({-1}).Concatenate(input_shape)); + output_shapes_.back().RemoveDim(1); + } + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + Iterator::Params{this, strings::StrCat(prefix, "::Rebatch")}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return strings::StrCat("RebatchDatasetOp(", batch_size_, ")::Dataset"); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + Node* batch_mode = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_mode_, &batch_mode)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, batch_size, batch_mode}, + output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + current_index_(0), + current_batch_size_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + *end_of_sequence = false; + + int64 chunk_read = 0; + + out_tensors->clear(); + std::vector elements; + while (!*end_of_sequence) { + if (current_index_ < 
current_batch_size_) { + if (out_tensors->size() == 0) { + out_tensors->reserve(tensors_.size()); + elements.reserve(tensors_.size()); + for (size_t i = 0; i < tensors_.size(); ++i) { + TensorShape shape = tensors_[i].shape(); + shape.RemoveDim(0); + elements.emplace_back(ctx->allocator({}), tensors_[i].dtype(), shape); + shape.InsertDim(0, dataset()->batch_size_); + out_tensors->emplace_back(ctx->allocator({}), tensors_[i].dtype(), shape); + } + } + if (out_tensors->size() != tensors_.size()) { + return errors::InvalidArgument("number tensors should match previous one, ", tensors_.size(), " vs. ", out_tensors->size()); + } + int64 chunk_to_read = (current_batch_size_ - current_index_) < (dataset()->batch_size_ - chunk_read) ? (current_batch_size_ - current_index_) : (dataset()->batch_size_ - chunk_read); + for (int i = 0; i < tensors_.size(); ++i) { + // TODO: concurrent copy? + for (int64 r = 0; r < chunk_to_read; r++) { + TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement( + &tensors_[i], &elements[i], current_index_ + r)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( + elements[i], &(*out_tensors)[i], chunk_read + r)); + } + } + chunk_read += chunk_to_read; + current_index_ += chunk_to_read; + if (chunk_read == dataset()->batch_size_) { + *end_of_sequence = false; + return Status::OK(); + } + } + current_index_ = 0; + current_batch_size_ = 0; + tensors_.clear(); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &tensors_, end_of_sequence)); + if (!*end_of_sequence) { + for (size_t i = 0; i < tensors_.size(); ++i) { + if (tensors_[i].dims() == 0) { + return errors::InvalidArgument( + "Input element must have a non-scalar value in each " + "component."); + } + if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) { + return errors::InvalidArgument( + "Input element must have the same batch size in each " + "component. 
Component 0 had size ", + tensors_[0].dim_size(0), " but component ", i, + " had size, ", tensors_[i].dim_size(0), "."); + } + } + current_batch_size_ = tensors_[0].dim_size(0); + } + } + // Finally, resize if needed + if (chunk_read > 0) { + if (chunk_read < dataset()->batch_size_) { + for (int i = 0; i < out_tensors->size(); ++i) { + TensorShape shape = (*out_tensors)[i].shape(); + shape.set_dim(0, chunk_read); + Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); + for (int64 r = 0; r < chunk_read; r++) { + TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement( + &(*out_tensors)[i], &elements[i], r)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( + elements[i], &value_tensor, r)); + } + (*out_tensors)[i] = std::move(value_tensor); + } + } + *end_of_sequence = false; + return Status::OK(); + } + out_tensors->clear(); + input_impl_.reset(); + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + dataset()->batch_size_); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented("RestoreInternal is currently not supported"); + } + + private: + mutex mu_; + int64 current_index_ GUARDED_BY(mu_); + int64 current_batch_size_ GUARDED_BY(mu_); + std::vector tensors_ GUARDED_BY(mu_); + std::unique_ptr input_impl_ GUARDED_BY(mu_); + }; + + const int64 batch_size_; + const string batch_mode_; + const DatasetBase* const input_; + std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU), + RebatchDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc 
new file mode 100644 index 000000000..c7ae8c925 --- /dev/null +++ b/tensorflow_io/core/ops/core_ops.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("RebatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("batch_mode: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // batch_mode should be a scalar. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::ScalarShape(c); + }); +} // namespace tensorflow diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py index d9d20541b..392d91542 100644 --- a/tensorflow_io/core/python/ops/data_ops.py +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -18,6 +18,40 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + +class _RebatchDataset(tf.compat.v2.data.Dataset): + """RebatchDataset""" + + def __init__(self, input_dataset, batch_size, batch_mode=""): + """Create a RebatchDataset.""" + self._input_dataset = input_dataset + self._batch_size = batch_size + self._batch_mode = batch_mode + + self._structure = input_dataset._element_structure._unbatch()._batch(None) # pylint: disable=protected-access + + variant_tensor = core_ops.rebatch_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access + batch_size=self._batch_size, + batch_mode=self._batch_mode, + output_types=self._structure._flat_types, # pylint: disable=protected-access + output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access + + super(_RebatchDataset, self).__init__(variant_tensor) + + def _inputs(self): + return [self._input_dataset] + + @property + def _element_structure(self): + return self._structure + +def rebatch(batch_size, batch_mode=""): + def _apply_fn(dataset): + return _RebatchDataset(dataset, batch_size, batch_mode) + + return _apply_fn # Note: BaseDataset could be used by Dataset implementations # that does not utilize DataInput implementation. 
diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index 647a7c831..af8d67bb4 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -27,6 +27,7 @@ if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() import tensorflow_io.text as text_io # pylint: disable=wrong-import-position +import tensorflow_io.core.python.ops.data_ops as core_io # pylint: disable=wrong-import-position def test_text_input(): """test_text_input @@ -53,6 +54,13 @@ def test_text_input(): i += 1 assert i == len(lines) + rebatch_dataset = text_dataset.apply(core_io.rebatch(5)) + i = 0 + for v in rebatch_dataset: + for vv in v.numpy(): + assert lines[i] == vv + i += 1 + assert i == len(lines) def test_text_output_sequence(): """Test case based on fashion mnist tutorial""" From 12946d35a7e3a264a332d1f9ca3e44ba05c48e61 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 30 Jul 2019 02:54:54 +0000 Subject: [PATCH 2/6] Add additional tests, also add batch_mode = "keep", "drop", "pad" mode Signed-off-by: Yong Tang --- .../core/kernels/rebatch_dataset_op.cc | 32 ++++++++++++------- tests/test_text_eager.py | 24 ++++++++++++-- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc index b1d908505..25eec1aab 100644 --- a/tensorflow_io/core/kernels/rebatch_dataset_op.cc +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -150,7 +150,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { int64 chunk_to_read = (current_batch_size_ - current_index_) < (dataset()->batch_size_ - chunk_read) ? (current_batch_size_ - current_index_) : (dataset()->batch_size_ - chunk_read); for (int i = 0; i < tensors_.size(); ++i) { // TODO: concurrent copy? 
- for (int64 r = 0; r < chunk_to_read; r++) { + for (int64 r = 0; r < chunk_to_read; ++r) { TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement( &tensors_[i], &elements[i], current_index_ + r)); TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( @@ -190,18 +190,28 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { // Finally, resize if needed if (chunk_read > 0) { if (chunk_read < dataset()->batch_size_) { - for (int i = 0; i < out_tensors->size(); ++i) { - TensorShape shape = (*out_tensors)[i].shape(); - shape.set_dim(0, chunk_read); - Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); - for (int64 r = 0; r < chunk_read; r++) { - TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement( - &(*out_tensors)[i], &elements[i], r)); - TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( - elements[i], &value_tensor, r)); + // "keep" remainder will need to resize + if (dataset()->batch_mode_ == "" || dataset()->batch_mode_ == "keep") { + for (int i = 0; i < out_tensors->size(); ++i) { + TensorShape shape = (*out_tensors)[i].shape(); + shape.set_dim(0, chunk_read); + Tensor value_tensor(ctx->allocator({}), (*out_tensors)[i].dtype(), shape); + for (int64 r = 0; r < chunk_read; r++) { + TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement( + &(*out_tensors)[i], &elements[i], r)); + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( + elements[i], &value_tensor, r)); + } + (*out_tensors)[i] = std::move(value_tensor); } - (*out_tensors)[i] = std::move(value_tensor); + // "drop" the remainder + } else if (dataset()->batch_mode_ == "drop") { + out_tensors->clear(); + input_impl_.reset(); + *end_of_sequence = true; + return Status::OK(); } + // otherwise "pad" means keep the size } *end_of_sequence = false; return Status::OK(); diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py index af8d67bb4..838ce24e7 100644 --- a/tests/test_text_eager.py +++ b/tests/test_text_eager.py @@ -54,13 +54,33 @@ def test_text_input(): i += 1 assert i == 
len(lines) - rebatch_dataset = text_dataset.apply(core_io.rebatch(5)) + for batch in [1, 2, 3, 4, 5]: + rebatch_dataset = text_dataset.apply(core_io.rebatch(batch)) + i = 0 + for v in rebatch_dataset: + for vv in v.numpy(): + assert lines[i] == vv + i += 1 + assert i == len(lines) + + rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "drop")) i = 0 for v in rebatch_dataset: for vv in v.numpy(): assert lines[i] == vv i += 1 - assert i == len(lines) + assert i == 145 + + rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "pad")) + i = 0 + for v in rebatch_dataset: + for vv in v.numpy(): + if i < len(lines): + assert lines[i] == vv + else: + assert vv.decode() == "" + i += 1 + assert i == 150 def test_text_output_sequence(): """Test case based on fashion mnist tutorial""" From 159ad590f2cc4543dda2bf00373bc71a201ce164 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 30 Jul 2019 02:59:28 +0000 Subject: [PATCH 3/6] Rename RebatchDataset to AdjustBatchDataset Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/rebatch_dataset_op.cc | 10 +++++----- tensorflow_io/core/ops/core_ops.cc | 2 +- tensorflow_io/core/python/ops/data_ops.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc index 25eec1aab..496f72112 100644 --- a/tensorflow_io/core/kernels/rebatch_dataset_op.cc +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -25,9 +25,9 @@ namespace tensorflow { namespace data { namespace { -class RebatchDatasetOp : public UnaryDatasetOpKernel { +class AdjustBatchDatasetOp : public UnaryDatasetOpKernel { public: - explicit RebatchDatasetOp(OpKernelConstruction* ctx) + explicit AdjustBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { } @@ -86,7 +86,7 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { } string DebugString() const override { - return strings::StrCat("RebatchDatasetOp(", batch_size_, 
")::Dataset"); + return strings::StrCat("AdjustBatchDatasetOp(", batch_size_, ")::Dataset"); } protected: @@ -252,8 +252,8 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU), - RebatchDatasetOp); +REGISTER_KERNEL_BUILDER(Name("AdjustBatchDataset").Device(DEVICE_CPU), + AdjustBatchDatasetOp); } // namespace } // namespace data diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc index c7ae8c925..051b18a33 100644 --- a/tensorflow_io/core/ops/core_ops.cc +++ b/tensorflow_io/core/ops/core_ops.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("RebatchDataset") +REGISTER_OP("AdjustBatchDataset") .Input("input_dataset: variant") .Input("batch_size: int64") .Input("batch_mode: string") diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py index 392d91542..bbd3b5d8b 100644 --- a/tensorflow_io/core/python/ops/data_ops.py +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -20,25 +20,25 @@ import tensorflow as tf from tensorflow_io.core.python.ops import core_ops -class _RebatchDataset(tf.compat.v2.data.Dataset): - """RebatchDataset""" +class _AdjustBatchDataset(tf.compat.v2.data.Dataset): + """AdjustBatchDataset""" def __init__(self, input_dataset, batch_size, batch_mode=""): - """Create a RebatchDataset.""" + """Create a AdjustBatchDataset.""" self._input_dataset = input_dataset self._batch_size = batch_size self._batch_mode = batch_mode self._structure = input_dataset._element_structure._unbatch()._batch(None) # pylint: disable=protected-access - variant_tensor = core_ops.rebatch_dataset( + variant_tensor = core_ops.adjust_batch_dataset( self._input_dataset._variant_tensor, # pylint: disable=protected-access batch_size=self._batch_size, batch_mode=self._batch_mode, output_types=self._structure._flat_types, # pylint: disable=protected-access 
output_shapes=self._structure._flat_shapes) # pylint: disable=protected-access - super(_RebatchDataset, self).__init__(variant_tensor) + super(_AdjustBatchDataset, self).__init__(variant_tensor) def _inputs(self): return [self._input_dataset] @@ -49,7 +49,7 @@ def _element_structure(self): def rebatch(batch_size, batch_mode=""): def _apply_fn(dataset): - return _RebatchDataset(dataset, batch_size, batch_mode) + return _AdjustBatchDataset(dataset, batch_size, batch_mode) return _apply_fn From 5f862346fd66858f192fe7c57882e0fe1ecae349 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 30 Jul 2019 03:05:08 +0000 Subject: [PATCH 4/6] Add additional processing in case shape is unknown Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/rebatch_dataset_op.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc index 496f72112..15bc3ef79 100644 --- a/tensorflow_io/core/kernels/rebatch_dataset_op.cc +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -63,9 +63,13 @@ class AdjustBatchDatasetOp : public UnaryDatasetOpKernel { output_shapes_.reserve(input_shapes.size()); // Always set the first dim as None unless batch_mode is specified. 
for (const auto& input_shape : input_shapes) { - output_shapes_.emplace_back( - PartialTensorShape({-1}).Concatenate(input_shape)); - output_shapes_.back().RemoveDim(1); + if (!input_shape.unknown_rank()) { + output_shapes_.emplace_back( + PartialTensorShape({-1}).Concatenate(input_shape)); + output_shapes_.back().RemoveDim(1); + } else { + output_shapes_.emplace_back(); + } } } From 8f12b8862951fc957a094b81aaeb4321790e7b51 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 30 Jul 2019 18:15:28 +0000 Subject: [PATCH 5/6] Address review comments Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/rebatch_dataset_op.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc index 15bc3ef79..382a4527a 100644 --- a/tensorflow_io/core/kernels/rebatch_dataset_op.cc +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -43,6 +43,12 @@ class AdjustBatchDatasetOp : public UnaryDatasetOpKernel { string batch_mode = ""; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_mode", &batch_mode)); + OP_REQUIRES( + ctx, !(batch_mode == "" || + batch_mode == "keep" || + batch_mode == "drop" || + batch_mode == "pad"), errors::InvalidArgument("invalid batch_mode: ", batch_mode)); + *output = new Dataset(ctx, batch_size, batch_mode, input); @@ -216,6 +222,10 @@ class AdjustBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } // otherwise "pad" means keep the size + // TODO: at the moment the remainder of the Tensor will + // be filled with default values, so nothing + // needs to be done. If non-default values are needed + // then it will need to be filled. 
} *end_of_sequence = false; return Status::OK(); From 95166c54d61d29661eb3a9cb4fb5c2635740c074 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 30 Jul 2019 20:43:54 +0000 Subject: [PATCH 6/6] Fix failed tests Signed-off-by: Yong Tang --- tensorflow_io/core/kernels/rebatch_dataset_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc index 382a4527a..44a4b9882 100644 --- a/tensorflow_io/core/kernels/rebatch_dataset_op.cc +++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc @@ -44,7 +44,7 @@ class AdjustBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_mode", &batch_mode)); OP_REQUIRES( - ctx, !(batch_mode == "" || + ctx, (batch_mode == "" || batch_mode == "keep" || batch_mode == "drop" || batch_mode == "pad"), errors::InvalidArgument("invalid batch_mode: ", batch_mode));