diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc
index df40937fb1..5504cda682 100644
--- a/tensorflow_io/parquet/kernels/parquet_kernels.cc
+++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -13,9 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#define EIGEN_USE_THREADS
-
-#include "kernels/dataset_ops.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
 #include "parquet/api/reader.h"
@@ -243,6 +240,9 @@ class ReadParquetOp : public OpKernel {
     }
     OP_REQUIRES(context, (column_index < file_metadata->num_columns()),
                 errors::InvalidArgument("unable to find column: ", column));
+    if (count < 0) {
+      count = file_metadata->num_rows();
+    }
     if (start + count > file_metadata->num_rows()) {
       count = file_metadata->num_rows() - start;
     }
diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py
index 1d97f3a210..0800ab12b8 100644
--- a/tensorflow_io/parquet/python/ops/parquet_ops.py
+++ b/tensorflow_io/parquet/python/ops/parquet_ops.py
@@ -24,7 +24,7 @@ def list_parquet_columns(filename, **kwargs):
   """list_parquet_columns"""
   if not tf.executing_eagerly():
-    raise NotImplementedError("read_parquet_spect only support eager mode")
+    raise NotImplementedError("list_parquet_columns only supports eager mode")
   memory = kwargs.get("memory", "")
   columns, dtypes, shapes = parquet_ops.list_parquet_columns(
       filename, memory=memory)
@@ -33,18 +33,24 @@ def list_parquet_columns(filename, **kwargs):
       shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
           column, dtype, shape) in entries])
 
-def read_parquet(filename, column, start=0, **kwargs):
+def read_parquet(filename, column, **kwargs):
   """read_parquet"""
+  start = kwargs.get("start", 0)
+  count = kwargs.get("count", None)
   memory = kwargs.get("memory", "")
+  if count is None and column.shape[0] is not None:
+    count = column.shape[0] - start
+  if count is None:
+    count = -1
   return parquet_ops.read_parquet(
       filename, column.name,
-      start=start, count=column.shape[0] - start, dtype=column.dtype,
+      start=start, count=count, dtype=column.dtype,
       memory=memory)
 
 class ParquetDataset(data_ops.BaseDataset):
   """A Parquet Dataset that reads the parquet file."""
 
-  def __init__(self, filename, column, batch=None, **kwargs):
+  def __init__(self, filename, column, **kwargs):
     """Create a `ParquetDataset`.
 
     `ParquetDataset` allows a user to read data from a parquet file.
@@ -62,26 +68,17 @@ def __init__(self, filename, column, batch=None, **kwargs):
     count = columns[column].shape[0]
     dtype = columns[column].dtype
 
-    batch = 0 if batch is None else batch
-    shape = tf.TensorShape([]) if (
-        batch is None or batch == 0) else tf.TensorShape([None])
+    shape = tf.TensorShape([None])
 
     # capacity is the rough count for each chunk in dataset
-    # not directly related to batch, will be padded to batch though
     capacity = kwargs.get("capacity", 65536)
-    if batch is not None and batch != 0 and capacity > batch:
-      capacity = (capacity // batch) * batch
-    entry_start = range(0, count, capacity)
+    entry_start = list(range(0, count, capacity))
     entry_count = [min(capacity, count - start) for start in entry_start]
     dataset = data_ops.BaseDataset.from_tensor_slices(
         (tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64))
     ).map(lambda start, count: parquet_ops.read_parquet(
         filename, column, start, count, dtype=dtype, memory=""))
-    if batch is None or batch == 0:
-      self._dataset = dataset.apply(tf.data.experimental.unbatch())
-    else:
-      # TODO: convert to rebatch for performance
-      self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch)
+    self._dataset = dataset
 
     super(ParquetDataset, self).__init__(
         self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access
diff --git a/tests/test_parquet.py b/tests/test_parquet.py
new file mode 100644
index 0000000000..37090f2497
--- /dev/null
+++ b/tests/test_parquet.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for ParquetDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pytest
+import numpy as np
+
+import tensorflow as tf
+tf.compat.v1.disable_eager_execution()
+import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position
+
+# Note: The sample file is generated from:
+# `parquet-cpp/examples/low-level-api/reader_writer`
+# This test extracts columns of [0, 1, 2, 4, 5]
+# with column data types of [bool, int32, int64, float, double].
+# Please check `parquet-cpp/examples/low-level-api/reader-writer.cc`
+# to find details of how records are generated:
+# Column 0 (bool): True for even rows and False otherwise.
+# Column 1 (int32): Equal to row_index.
+# Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000.
+# Column 4 (float): Equal to row_index * 1.1.
+# Column 5 (double): Equal to row_index * 1.1111111.
+def test_parquet():
+  """Test case for ParquetDataset."""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_parquet",
+      "parquet_cpp_example.parquet")
+  filename = "file://" + filename
+
+  columns = [
+      'boolean_field',
+      'int32_field',
+      'int64_field',
+      'float_field',
+      'double_field']
+  dtypes = [
+      tf.bool,
+      tf.int32,
+      tf.int64,
+      tf.float32,
+      tf.double]
+
+  dataset = tf.compat.v2.data.Dataset.zip(
+      tuple([parquet_io.ParquetDataset(
+          filename, column, dtype=dtype,
+          start=0, count=500) for (
+              column, dtype) in zip(columns, dtypes)])).apply(
+                  tf.data.experimental.unbatch())
+
+  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+  init_op = iterator.initializer
+  get_next = iterator.get_next()
+  with tf.compat.v1.Session() as sess:
+    sess.run(init_op)
+    for i in range(500):
+      v0 = ((i % 2) == 0)
+      v1 = i
+      v2 = i * 1000 * 1000 * 1000 * 1000
+      v4 = 1.1 * i
+      v5 = 1.1111111 * i
+      p0, p1, p2, p4, p5 = sess.run(get_next)
+      assert v0 == p0
+      assert v1 == p1
+      assert v2 == p2
+      assert np.isclose(v4, p4)
+      assert np.isclose(v5, p5)
+    with pytest.raises(tf.errors.OutOfRangeError):
+      sess.run(get_next)
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py
index ff3e6363ca..d80440a203 100644
--- a/tests/test_parquet_eager.py
+++ b/tests/test_parquet_eager.py
@@ -74,7 +74,8 @@ def test_parquet():
   dataset = tf.compat.v2.data.Dataset.zip(
       tuple(
-          [parquet_io.ParquetDataset(filename, column) for column in columns]))
+          [parquet_io.ParquetDataset(filename, column) for column in columns])
+  ).apply(tf.data.experimental.unbatch())
 
   i = 0
   for p in dataset:
     v0 = ((i % 2) == 0)
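Usage sketch (not part of the patch): with this change, ParquetDataset yields chunked 1-D tensors of up to `capacity` rows instead of individual records, so callers unbatch explicitly, as the updated tests do. The snippet below is a minimal illustration under that assumption; the file path is hypothetical and the column name is taken from the sample file used in the tests.

import tensorflow as tf
import tensorflow_io.parquet as parquet_io

# Hypothetical path; any parquet file with an 'int64_field' column would do.
filename = "file:///tmp/parquet_cpp_example.parquet"

# In eager mode the column dtype and row count are discovered from the file;
# each dataset element is now a batch of up to `capacity` (default 65536) rows.
dataset = parquet_io.ParquetDataset(filename, "int64_field")

# Unbatch to iterate record by record, mirroring tests/test_parquet_eager.py.
for value in dataset.apply(tf.data.experimental.unbatch()):
  print(value.numpy())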