diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 94fed00ca..4ecacb2d7 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -57,6 +57,22 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "core_ops",
+    srcs = [
+        "kernels/rebatch_dataset_op.cc",
+        "ops/core_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    includes = [
+        ".",
+    ],
+    linkstatic = True,
+    deps = [
+        ":dataset_ops",
+    ],
+)
+
 cc_library(
     name = "ffmpeg_3.4",
     srcs = [
@@ -107,6 +123,7 @@ cc_binary(
     copts = tf_io_copts(),
     linkshared = 1,
     deps = [
+        ":core_ops",
         "//tensorflow_io/audio:audio_ops",
         "//tensorflow_io/avro:avro_ops",
         "//tensorflow_io/azure:azfs_ops",
diff --git a/tensorflow_io/core/kernels/rebatch_dataset_op.cc b/tensorflow_io/core/kernels/rebatch_dataset_op.cc
new file mode 100644
index 000000000..44a4b9882
--- /dev/null
+++ b/tensorflow_io/core/kernels/rebatch_dataset_op.cc
@@ -0,0 +1,274 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/batch_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class AdjustBatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit AdjustBatchDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    int64 batch_size = 0;
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
+    OP_REQUIRES(
+        ctx, batch_size > 0,
+        errors::InvalidArgument("Batch size must be greater than zero."));
+
+    string batch_mode = "";
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<string>(ctx, "batch_mode", &batch_mode));
+    OP_REQUIRES(ctx,
+                (batch_mode == "" || batch_mode == "keep" ||
+                 batch_mode == "drop" || batch_mode == "pad"),
+                errors::InvalidArgument("invalid batch_mode: ", batch_mode));
+
+    *output = new Dataset(ctx, batch_size, batch_mode, input);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, int64 batch_size, string batch_mode,
+            const DatasetBase* input)
+        : DatasetBase(DatasetContext(ctx)),
+          batch_size_(batch_size),
+          batch_mode_(batch_mode),
+          input_(input) {
+      input_->Ref();
+
+      const auto& input_shapes = input_->output_shapes();
+      output_shapes_.reserve(input_shapes.size());
+      // Always set the first dim to None unless batch_mode is specified.
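+      // Prepending a {-1} dim and then removing dim 1 below replaces the
+      // input's (possibly static) batch dimension with an unknown one.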
+      for (const auto& input_shape : input_shapes) {
+        if (!input_shape.unknown_rank()) {
+          output_shapes_.emplace_back(
+              PartialTensorShape({-1}).Concatenate(input_shape));
+          output_shapes_.back().RemoveDim(1);
+        } else {
+          output_shapes_.emplace_back();
+        }
+      }
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return absl::make_unique<Iterator>(
+          Iterator::Params{this, strings::StrCat(prefix, "::Rebatch")});
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() const override {
+      return strings::StrCat("AdjustBatchDatasetOp(", batch_size_,
+                             ")::Dataset");
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+      Node* batch_size = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
+      Node* batch_mode = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(batch_mode_, &batch_mode));
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this, {input_graph_node, batch_size, batch_mode}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            current_index_(0),
+            current_batch_size_(0) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+        *end_of_sequence = false;
+
+        int64 chunk_read = 0;
+
+        out_tensors->clear();
+        std::vector<Tensor> elements;
+        while (!*end_of_sequence) {
+          if (current_index_ < current_batch_size_) {
+            if (out_tensors->empty()) {
+              // Allocate the output batch (and a scratch element) lazily,
+              // based on the dtypes and shapes of the first cached batch.
+              out_tensors->reserve(tensors_.size());
+              elements.reserve(tensors_.size());
+              for (size_t i = 0; i < tensors_.size(); ++i) {
+                TensorShape shape = tensors_[i].shape();
+                shape.RemoveDim(0);
+                elements.emplace_back(ctx->allocator({}), tensors_[i].dtype(),
+                                      shape);
+                shape.InsertDim(0, dataset()->batch_size_);
+                out_tensors->emplace_back(ctx->allocator({}),
+                                          tensors_[i].dtype(), shape);
+              }
+            }
+            if (out_tensors->size() != tensors_.size()) {
+              return errors::InvalidArgument(
+                  "number of tensors should match the previous element: ",
+                  tensors_.size(), " vs. ", out_tensors->size());
+            }
+            int64 chunk_to_read =
+                std::min(current_batch_size_ - current_index_,
+                         dataset()->batch_size_ - chunk_read);
+            for (size_t i = 0; i < tensors_.size(); ++i) {
+              // TODO: concurrent copy?
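+              // For each row, move slice (current_index_ + r) of the cached
+              // input batch into the scratch element, then copy it into
+              // slice (chunk_read + r) of the output batch.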
+              for (int64 r = 0; r < chunk_to_read; ++r) {
+                TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement(
+                    &tensors_[i], &elements[i], current_index_ + r));
+                TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
+                    elements[i], &(*out_tensors)[i], chunk_read + r));
+              }
+            }
+            chunk_read += chunk_to_read;
+            current_index_ += chunk_to_read;
+            if (chunk_read == dataset()->batch_size_) {
+              *end_of_sequence = false;
+              return Status::OK();
+            }
+          }
+          current_index_ = 0;
+          current_batch_size_ = 0;
+          tensors_.clear();
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &tensors_, end_of_sequence));
+          if (!*end_of_sequence) {
+            for (size_t i = 0; i < tensors_.size(); ++i) {
+              if (tensors_[i].dims() == 0) {
+                return errors::InvalidArgument(
+                    "Input element must have a non-scalar value in each "
+                    "component.");
+              }
+              if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) {
+                return errors::InvalidArgument(
+                    "Input element must have the same batch size in each "
+                    "component. Component 0 had size ",
+                    tensors_[0].dim_size(0), " but component ", i,
+                    " had size ", tensors_[i].dim_size(0), ".");
+              }
+            }
+            current_batch_size_ = tensors_[0].dim_size(0);
+          }
+        }
+        // Finally, resize the output if only a partial batch was read.
+        if (chunk_read > 0) {
+          if (chunk_read < dataset()->batch_size_) {
+            // "keep" (the default): keep the remainder, which requires
+            // resizing the output tensors to the rows actually read.
+            if (dataset()->batch_mode_ == "" ||
+                dataset()->batch_mode_ == "keep") {
+              for (size_t i = 0; i < out_tensors->size(); ++i) {
+                TensorShape shape = (*out_tensors)[i].shape();
+                shape.set_dim(0, chunk_read);
+                Tensor value_tensor(ctx->allocator({}),
+                                    (*out_tensors)[i].dtype(), shape);
+                for (int64 r = 0; r < chunk_read; ++r) {
+                  TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement(
+                      &(*out_tensors)[i], &elements[i], r));
+                  TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
+                      elements[i], &value_tensor, r));
+                }
+                (*out_tensors)[i] = std::move(value_tensor);
+              }
+              // "drop": discard the remainder and end the sequence.
+            } else if (dataset()->batch_mode_ == "drop") {
+              out_tensors->clear();
+              input_impl_.reset();
+              *end_of_sequence = true;
+              return Status::OK();
+            }
+            // Otherwise "pad" keeps the full batch size. At the moment the
+            // remainder of each tensor is left with default-initialized
+            // values, so nothing needs to be done here. If non-default pad
+            // values are ever needed, they would have to be filled in.
+          }
+          *end_of_sequence = false;
+          return Status::OK();
+        }
+        out_tensors->clear();
+        input_impl_.reset();
+        return Status::OK();
+      }
+
+     protected:
+      std::shared_ptr<model::Node> CreateNode(
+          IteratorContext* ctx, model::Node::Args args) const override {
+        return model::MakeKnownRatioNode(std::move(args),
+                                         dataset()->batch_size_);
+      }
+
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return errors::Unimplemented(
+            "SaveInternal is currently not supported");
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        return errors::Unimplemented(
+            "RestoreInternal is currently not supported");
+      }
+
+     private:
+      mutex mu_;
+      int64 current_index_ GUARDED_BY(mu_);
+      int64 current_batch_size_ GUARDED_BY(mu_);
+      std::vector<Tensor> tensors_ GUARDED_BY(mu_);
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    };
+
+    const int64 batch_size_;
+    const string batch_mode_;
+    const DatasetBase* const input_;
+    std::vector<PartialTensorShape> output_shapes_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdjustBatchDataset").Device(DEVICE_CPU),
+                        AdjustBatchDatasetOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/ops/core_ops.cc b/tensorflow_io/core/ops/core_ops.cc
new file mode 100644
index 000000000..051b18a33
--- /dev/null
+++ b/tensorflow_io/core/ops/core_ops.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("AdjustBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("batch_size: int64")
+    .Input("batch_mode: string")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // batch_size should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      // batch_mode should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
+
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index d9d20541b..bbd3b5d8b 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -18,6 +18,40 @@ from __future__ import print_function
 
 import tensorflow as tf
+from tensorflow_io.core.python.ops import core_ops
+
+
+class _AdjustBatchDataset(tf.compat.v2.data.Dataset):
+  """A Dataset that rebatches its input into a new batch size."""
+
+  def __init__(self, input_dataset, batch_size, batch_mode=""):
+    """Create an AdjustBatchDataset."""
+    self._input_dataset = input_dataset
+    self._batch_size = batch_size
+    self._batch_mode = batch_mode
+
+    self._structure = input_dataset._element_structure._unbatch()._batch(None)  # pylint: disable=protected-access
+
+    variant_tensor = core_ops.adjust_batch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        batch_size=self._batch_size,
+        batch_mode=self._batch_mode,
+        output_types=self._structure._flat_types,  # pylint: disable=protected-access
+        output_shapes=self._structure._flat_shapes)  # pylint: disable=protected-access
+
+    super(_AdjustBatchDataset, self).__init__(variant_tensor)
+
+  def _inputs(self):
+    return [self._input_dataset]
+
+  @property
+  def _element_structure(self):
+    return self._structure
+
+
+def rebatch(batch_size, batch_mode=""):
+  """Returns a transformation that rebatches a dataset into `batch_size`."""
+  def _apply_fn(dataset):
+    return _AdjustBatchDataset(dataset, batch_size, batch_mode)
+
+  return _apply_fn
 
 # Note: BaseDataset could be used by Dataset implementations
 # that does not utilize DataInput implementation.
diff --git a/tests/test_text_eager.py b/tests/test_text_eager.py
index 647a7c831..838ce24e7 100644
--- a/tests/test_text_eager.py
+++ b/tests/test_text_eager.py
@@ -27,6 +27,7 @@ if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
 
 import tensorflow_io.text as text_io  # pylint: disable=wrong-import-position
+import tensorflow_io.core.python.ops.data_ops as core_io  # pylint: disable=wrong-import-position
 
 def test_text_input():
   """test_text_input
@@ -53,6 +54,33 @@ def test_text_input():
       i += 1
   assert i == len(lines)
 
+  for batch in [1, 2, 3, 4, 5]:
+    rebatch_dataset = text_dataset.apply(core_io.rebatch(batch))
+    i = 0
+    for v in rebatch_dataset:
+      for vv in v.numpy():
+        assert lines[i] == vv
+        i += 1
+    assert i == len(lines)
+
+  rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "drop"))
+  i = 0
+  for v in rebatch_dataset:
+    for vv in v.numpy():
+      assert lines[i] == vv
+      i += 1
+  assert i == 145
+
+  rebatch_dataset = text_dataset.apply(core_io.rebatch(5, "pad"))
+  i = 0
+  for v in rebatch_dataset:
+    for vv in v.numpy():
+      if i < len(lines):
+        assert lines[i] == vv
+      else:
+        assert vv.decode() == ""
+      i += 1
+  assert i == 150
 
 def test_text_output_sequence():
   """Test case based on fashion mnist tutorial"""
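
A minimal usage sketch of the new `rebatch` transform, mirroring the tests
above (assumes tensorflow_io is built with this change and eager execution is
enabled; the dataset contents are illustrative):

    import tensorflow as tf
    import tensorflow_io.core.python.ops.data_ops as core_io

    # Seven string elements, initially batched into chunks of 4: shapes [4], [3].
    dataset = tf.data.Dataset.from_tensor_slices(
        [str(i) for i in range(7)]).batch(4)

    # Rebatch into batches of 5. "keep" (or "") keeps the final partial batch
    # of 2; "drop" discards it; "pad" fills it out to 5 elements with default
    # (empty string) values.
    for mode in ["keep", "drop", "pad"]:
      rebatched = dataset.apply(core_io.rebatch(5, mode))
      print(mode, [v.shape for v in rebatched])
    # expected: keep -> [5], [2]; drop -> [5]; pad -> [5], [5]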