diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 7a11cf81eb..06a8205e21 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -28,6 +28,7 @@ cc_library(
     name = "dataset_ops",
     srcs = [
         "kernels/dataset_ops.h",
+        "kernels/stream.h",
     ],
     copts = tf_io_copts(),
     includes = [
diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h
new file mode 100644
index 0000000000..e812babf47
--- /dev/null
+++ b/tensorflow_io/core/kernels/stream.h
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+
+// Note: This SizedRandomAccessFile should only live within Compute()
+// of the kernel, as the buffer could be released outside.
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+  SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size)
+    : file_(nullptr)
+    , size_(optional_memory_size)
+    , buff_((const char *)(optional_memory_buff))
+    , size_status_(Status::OK()) {
+    if (size_ == 0) {
+      size_status_ = env->GetFileSize(filename, &size_);
+      if (size_status_.ok()) {
+        size_status_ = env->NewRandomAccessFile(filename, &file_);
+      }
+    }
+  }
+
+  virtual ~SizedRandomAccessFile() {}
+  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
+    if (file_.get() != nullptr) {
+      return file_.get()->Read(offset, n, result, scratch);
+    }
+    size_t bytes_to_read = 0;
+    if (offset < size_) {
+      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
+    }
+    if (bytes_to_read > 0) {
+      memcpy(scratch, &buff_[offset], bytes_to_read);
+    }
+    *result = StringPiece(scratch, bytes_to_read);
+    if (bytes_to_read < n) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+  Status GetFileSize(uint64* size) {
+    if (size_status_.ok()) {
+      *size = size_;
+    }
+    return size_status_;
+  }
+ private:
+  std::unique_ptr<tensorflow::RandomAccessFile> file_;
+  uint64 size_;
+  const char *buff_;
+  Status size_status_;
+};
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc
index df40937fb1..e6caf99a83 100644
--- a/tensorflow_io/parquet/kernels/parquet_kernels.cc
+++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -13,64 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#define EIGEN_USE_THREADS
-
-#include "kernels/dataset_ops.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/stream.h"
 #include "parquet/api/reader.h"
 
 namespace tensorflow {
 namespace data {
 namespace {
 
-// Note: This SizedRandomAccessFile should only lives within Compute()
-// of the kernel as buffer could be released by outside.
-class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
- public:
-  SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory)
-    : file_(nullptr)
-    , size_status_(Status::OK())
-    , size_(optional_memory.size())
-    , buffer_(optional_memory) {
-    if (size_ == 0) {
-      size_status_ = env->GetFileSize(filename, &size_);
-      if (size_status_.ok()) {
-        size_status_ = env->NewRandomAccessFile(filename, &file_);
-      }
-    }
-  }
-
-  virtual ~SizedRandomAccessFile() {}
-  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
-    if (file_.get() != nullptr) {
-      return file_.get()->Read(offset, n, result, scratch);
-    }
-    size_t bytes_to_read = 0;
-    if (offset < size_) {
-      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
-    }
-    if (bytes_to_read > 0) {
-      memcpy(scratch, buffer_.data(), bytes_to_read);
-    }
-    *result = StringPiece(scratch, bytes_to_read);
-    if (bytes_to_read < n) {
-      return errors::OutOfRange("EOF reached");
-    }
-    return Status::OK();
-  }
-  Status GetFileSize(uint64* size) {
-    if (size_status_.ok()) {
-      *size = size_;
-    }
-    return size_status_;
-  }
- private:
-  std::unique_ptr<RandomAccessFile> file_;
-  Status size_status_;
-  uint64 size_;
-  const string& buffer_;
-};
-
 class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
  public:
   explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
@@ -138,7 +88,7 @@ class ListParquetColumnsOp : public OpKernel {
     const Tensor& memory_tensor = context->input(1);
     const string& memory = memory_tensor.scalar<string>()();
 
-    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory));
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
     uint64 size;
     OP_REQUIRES_OK(context, file->GetFileSize(&size));
 
@@ -218,16 +168,16 @@ class ReadParquetOp : public OpKernel {
     const Tensor& column_tensor = context->input(1);
     const string& column = column_tensor.scalar<string>()();
 
-    const Tensor& start_tensor = context->input(2);
-    const int64 start = start_tensor.scalar<int64>()();
+    const Tensor& memory_tensor = context->input(2);
+    const string& memory = memory_tensor.scalar<string>()();
 
-    const Tensor& count_tensor = context->input(3);
-    int64 count = count_tensor.scalar<int64>()();
+    const Tensor& start_tensor = context->input(3);
+    int64 start = start_tensor.scalar<int64>()();
 
-    const Tensor& memory_tensor = context->input(4);
-    const string& memory = memory_tensor.scalar<string>()();
+    const Tensor& stop_tensor = context->input(4);
+    int64 stop = stop_tensor.scalar<int64>()();
 
-    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory));
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
     uint64 size;
     OP_REQUIRES_OK(context, file->GetFileSize(&size));
 
@@ -243,11 +193,17 @@ class ReadParquetOp : public OpKernel {
     }
     OP_REQUIRES(context, (column_index < file_metadata->num_columns()), errors::InvalidArgument("unable to find column: ", column));
 
-    if (start + count > file_metadata->num_rows()) {
-      count = file_metadata->num_rows() - start;
+    if (start > file_metadata->num_rows()) {
+      start = file_metadata->num_rows();
+    }
+    if (stop < 0) {
+      stop = file_metadata->num_rows();
+    }
+    if (stop > file_metadata->num_rows()) {
+      stop = file_metadata->num_rows();
     }
 
-    TensorShape output_shape({count});
+    TensorShape output_shape({stop - start});
     Tensor* output_tensor;
     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
 
@@ -255,14 +211,14 @@
     int64 row_group_offset = 0;
     for (int row_group = 0; row_group < file_metadata->num_row_groups(); row_group++) {
       std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader->RowGroup(row_group);
-      // Skip if row group is not within [start..start+count]
-      if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (start + count <= row_group_offset)) {
+      // Skip if row group is not within [start..stop]
+      if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (stop <= row_group_offset)) {
        row_group_offset += row_group_reader->metadata()->num_rows();
        continue;
      }
      // Find row_to_read range
      int64 row_to_read_start = row_group_offset > start ? row_group_offset : start;
-      int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (start + count) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (start + count);
+      int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (stop) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (stop);
      int64 row_to_read_count = row_to_read_final - row_to_read_start;
 
      std::shared_ptr<parquet::ColumnReader> column_reader = row_group_reader->Column(column_index);
diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc
index d12b9b7b25..32c318f8ca 100644
--- a/tensorflow_io/parquet/ops/parquet_ops.cc
+++ b/tensorflow_io/parquet/ops/parquet_ops.cc
@@ -27,15 +27,17 @@ REGISTER_OP("ListParquetColumns")
     .Output("shapes: int64")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
       return Status::OK();
     });
 
 REGISTER_OP("ReadParquet")
     .Input("filename: string")
     .Input("column: string")
-    .Input("start: int64")
-    .Input("count: int64")
     .Input("memory: string")
+    .Input("start: int64")
+    .Input("stop: int64")
     .Attr("dtype: type")
     .Output("output: dtype")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py
index 1d97f3a210..a0cb9407af 100644
--- a/tensorflow_io/parquet/python/ops/parquet_ops.py
+++ b/tensorflow_io/parquet/python/ops/parquet_ops.py
@@ -24,7 +24,7 @@
 def list_parquet_columns(filename, **kwargs):
   """list_parquet_columns"""
   if not tf.executing_eagerly():
-    raise NotImplementedError("read_parquet_spect only support eager mode")
+    raise NotImplementedError("list_parquet_columns only supports eager mode")
   memory = kwargs.get("memory", "")
   columns, dtypes, shapes = parquet_ops.list_parquet_columns(
       filename, memory=memory)
@@ -33,18 +33,23 @@
       shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
           column, dtype, shape) in entries])
 
-def read_parquet(filename, column, start=0, **kwargs):
+def read_parquet(filename, column, **kwargs):
   """read_parquet"""
   memory = kwargs.get("memory", "")
+  start = kwargs.get("start", 0)
+  stop = kwargs.get("stop", None)
kwargs.get("stop", None) + if stop is None and column.shape[0] is not None: + stop = column.shape[0] - start + if stop is None: + stop = -1 return parquet_ops.read_parquet( - filename, column.name, - start=start, count=column.shape[0] - start, dtype=column.dtype, - memory=memory) + filename, column.name, memory=memory, + start=start, stop=-1, dtype=column.dtype) class ParquetDataset(data_ops.BaseDataset): """A Parquet Dataset that reads the parquet file.""" - def __init__(self, filename, column, batch=None, **kwargs): + def __init__(self, filename, column, **kwargs): """Create a `ParquetDataset`. `ParquetDataset` allows a user to read data from a parquet file. @@ -53,35 +58,28 @@ def __init__(self, filename, column, batch=None, **kwargs): filename: filename of the parquet file to read. column: column name to read. """ - # Note: count and dtype could be in kwargs if in graph mode. + # Note: start, stop and dtype could be in kwargs if in graph mode. if not tf.executing_eagerly(): - count = kwargs.get("count") + start = kwargs.get("start") + stop = kwargs.get("stop") dtype = kwargs.get("dtype") else: columns = list_parquet_columns(filename) - count = columns[column].shape[0] + start = 0 + stop = columns[column].shape[0] dtype = columns[column].dtype - batch = 0 if batch is None else batch - shape = tf.TensorShape([]) if ( - batch is None or batch == 0) else tf.TensorShape([None]) + shape = tf.TensorShape([None]) # capacity is the rough count for each chunk in dataset - # not directly related to batch, will be padded to batch though capacity = kwargs.get("capacity", 65536) - if batch is not None and batch != 0 and capacity > batch: - capacity = (capacity // batch) * batch - entry_start = range(0, count, capacity) - entry_count = [min(capacity, count - start) for start in entry_start] + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] dataset = data_ops.BaseDataset.from_tensor_slices( - (tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64)) - ).map(lambda start, count: parquet_ops.read_parquet( - filename, column, start, count, dtype=dtype, memory="")) - if batch is None or batch == 0: - self._dataset = dataset.apply(tf.data.experimental.unbatch()) - else: - # TODO: convert to rebatch for performance - self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch) + (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) + ).map(lambda start, stop: parquet_ops.read_parquet( + filename, column, memory="", start=start, stop=stop, dtype=dtype)) + self._dataset = dataset super(ParquetDataset, self).__init__( self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access diff --git a/tests/test_parquet.py b/tests/test_parquet.py new file mode 100644 index 0000000000..e66d674611 --- /dev/null +++ b/tests/test_parquet.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ==============================================================================
+"""Tests for ParquetDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pytest
+import numpy as np
+
+import tensorflow as tf
+tf.compat.v1.disable_eager_execution()
+import tensorflow_io.parquet as parquet_io  # pylint: disable=wrong-import-position
+
+# Note: The sample file is generated from:
+# `parquet-cpp/examples/low-level-api/reader_writer`
+# This test extracts columns of [0, 1, 2, 4, 5]
+# with column data types of [bool, int32, int64, float, double].
+# Please check `parquet-cpp/examples/low-level-api/reader-writer.cc`
+# to find details of how records are generated:
+# Column 0 (bool): True for even rows and False otherwise.
+# Column 1 (int32): Equal to row_index.
+# Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000.
+# Column 4 (float): Equal to row_index * 1.1.
+# Column 5 (double): Equal to row_index * 1.1111111.
+def test_parquet():
+  """Test case for ParquetDataset."""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  filename = "file://" + filename
+
+  columns = [
+      'boolean_field',
+      'int32_field',
+      'int64_field',
+      'float_field',
+      'double_field']
+  dtypes = [
+      tf.bool,
+      tf.int32,
+      tf.int64,
+      tf.float32,
+      tf.double]
+
+  dataset = tf.compat.v2.data.Dataset.zip(
+      tuple([parquet_io.ParquetDataset(
+          filename, column, dtype=dtype,
+          start=0, stop=500) for (
+              column, dtype) in zip(columns, dtypes)])).apply(
+                  tf.data.experimental.unbatch())
+
+  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+  init_op = iterator.initializer
+  get_next = iterator.get_next()
+  with tf.compat.v1.Session() as sess:
+    sess.run(init_op)
+    for i in range(500):
+      v0 = ((i % 2) == 0)
+      v1 = i
+      v2 = i * 1000 * 1000 * 1000 * 1000
+      v4 = 1.1 * i
+      v5 = 1.1111111 * i
+      p0, p1, p2, p4, p5 = sess.run(get_next)
+      assert v0 == p0
+      assert v1 == p1
+      assert v2 == p2
+      assert np.isclose(v4, p4)
+      assert np.isclose(v5, p5)
+    with pytest.raises(tf.errors.OutOfRangeError):
+      sess.run(get_next)
+
+if __name__ == "__main__":
+  pytest.main([__file__])
diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py
index ff3e6363ca..d80440a203 100644
--- a/tests/test_parquet_eager.py
+++ b/tests/test_parquet_eager.py
@@ -74,7 +74,8 @@ def test_parquet():
 
   dataset = tf.compat.v2.data.Dataset.zip(
       tuple(
-          [parquet_io.ParquetDataset(filename, column) for column in columns]))
+          [parquet_io.ParquetDataset(filename, column) for column in columns])
+  ).apply(tf.data.experimental.unbatch())
   i = 0
   for p in dataset:
     v0 = ((i % 2) == 0)
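
Usage sketch (not part of the patch): a minimal eager-mode example of the reworked start/stop API introduced above. The file name "example.parquet" is an illustrative placeholder; the column name "int64_field" is taken from the sample file used by the tests.

    import tensorflow as tf
    import tensorflow_io.parquet as parquet_io

    # Discover the columns (eager mode only); values are tf.TensorSpec entries.
    columns = parquet_io.list_parquet_columns("example.parquet")

    # Read rows [100, 200) of one column via the new start/stop arguments.
    values = parquet_io.read_parquet(
        "example.parquet", columns["int64_field"], start=100, stop=200)

    # ParquetDataset now always yields rank-1 chunks; unbatch to iterate scalars.
    dataset = parquet_io.ParquetDataset(
        "example.parquet", "int64_field").apply(tf.data.experimental.unbatch())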