Remove batch args, and add test in graph mode
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
yongtang committed Jul 31, 2019
1 parent c493791 commit 378b9c2
Showing 4 changed files with 107 additions and 20 deletions.
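
Taken together, the change drops the batch argument from ParquetDataset and leaves batching to the caller. A minimal sketch of the resulting usage, modeled on the eager-mode test touched below; the file path and column name are placeholders, and it assumes BaseDataset supports the standard tf.data transformations:

import tensorflow as tf
import tensorflow_io.parquet as parquet_io

# Placeholder inputs for illustration only.
filename = "file:///tmp/parquet_cpp_example.parquet"
column = "int32_field"

# Before this commit: ParquetDataset(filename, column, batch=32)
# After this commit: the dataset yields chunked tensors; the caller
# unbatches and re-batches explicitly.
dataset = parquet_io.ParquetDataset(filename, column).apply(
    tf.data.experimental.unbatch()).batch(32)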
6 changes: 3 additions & 3 deletions tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -13,9 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "kernels/dataset_ops.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "parquet/api/reader.h"

@@ -243,6 +240,9 @@ class ReadParquetOp : public OpKernel {
}
OP_REQUIRES(context, (column_index < file_metadata->num_columns()), errors::InvalidArgument("unable to find column: ", column));

if (count < 0) {
count = file_metadata->num_rows();
}
if (start + count > file_metadata->num_rows()) {
count = file_metadata->num_rows() - start;
}
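
With this kernel change a negative count acts as a sentinel for "read every remaining row", and the requested range is still clamped to the file's row count. A small Python restatement of that logic, for illustration only:

def resolve_row_range(start, count, num_rows):
    # Mirrors the kernel change above: count < 0 means "read to the end
    # of the file"; the range is then clamped to the available rows.
    if count < 0:
        count = num_rows
    if start + count > num_rows:
        count = num_rows - start
    return start, count

assert resolve_row_range(0, -1, 500) == (0, 500)
assert resolve_row_range(450, 100, 500) == (450, 50)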
29 changes: 13 additions & 16 deletions tensorflow_io/parquet/python/ops/parquet_ops.py
@@ -24,7 +24,7 @@
def list_parquet_columns(filename, **kwargs):
"""list_parquet_columns"""
if not tf.executing_eagerly():
raise NotImplementedError("read_parquet_spect only support eager mode")
raise NotImplementedError("list_parquet_columns only support eager mode")
memory = kwargs.get("memory", "")
columns, dtypes, shapes = parquet_ops.list_parquet_columns(
filename, memory=memory)
@@ -33,18 +33,24 @@ def list_parquet_columns(filename, **kwargs):
shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
column, dtype, shape) in entries])

def read_parquet(filename, column, start=0, **kwargs):
def read_parquet(filename, column, **kwargs):
"""read_parquet"""
start = kwargs.get("start", 0)
count = kwargs.get("count", None)
memory = kwargs.get("memory", "")
if count is None and column.shape[0] is not None:
count = column.shape[0] - start
if count is None:
count = -1
return parquet_ops.read_parquet(
filename, column.name,
start=start, count=column.shape[0] - start, dtype=column.dtype,
start=start, count=count, dtype=column.dtype,
memory=memory)

class ParquetDataset(data_ops.BaseDataset):
"""A Parquet Dataset that reads the parquet file."""

def __init__(self, filename, column, batch=None, **kwargs):
def __init__(self, filename, column, **kwargs):
"""Create a `ParquetDataset`.
`ParquetDataset` allows a user to read data from a parquet file.
Expand All @@ -62,26 +68,17 @@ def __init__(self, filename, column, batch=None, **kwargs):
count = columns[column].shape[0]
dtype = columns[column].dtype

batch = 0 if batch is None else batch
shape = tf.TensorShape([]) if (
batch is None or batch == 0) else tf.TensorShape([None])
shape = tf.TensorShape([None])

# capacity is the rough count for each chunk in dataset
# not directly related to batch, will be padded to batch though
capacity = kwargs.get("capacity", 65536)
if batch is not None and batch != 0 and capacity > batch:
capacity = (capacity // batch) * batch
entry_start = range(0, count, capacity)
entry_start = list(range(0, count, capacity))
entry_count = [min(capacity, count - start) for start in entry_start]
dataset = data_ops.BaseDataset.from_tensor_slices(
(tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64))
).map(lambda start, count: parquet_ops.read_parquet(
filename, column, start, count, dtype=dtype, memory=""))
if batch is None or batch == 0:
self._dataset = dataset.apply(tf.data.experimental.unbatch())
else:
# TODO: convert to rebatch for performance
self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch)
self._dataset = dataset

super(ParquetDataset, self).__init__(
self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access
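
The constructor now always splits the column's row range into capacity-sized chunks and emits one tensor per chunk, leaving any batching to the caller. A standalone sketch of that arithmetic (the row count is an example value):

count = 150000      # total rows in the column (example value)
capacity = 65536    # rough rows per chunk, the default used above

entry_start = list(range(0, count, capacity))
entry_count = [min(capacity, count - start) for start in entry_start]

print(entry_start)  # [0, 65536, 131072]
print(entry_count)  # [65536, 65536, 18928]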
89 changes: 89 additions & 0 deletions tests/test_parquet.py
@@ -0,0 +1,89 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for ParquetDataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pytest
import numpy as np

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import tensorflow_io.parquet as parquet_io # pylint: disable=wrong-import-position

# Note: The sample file is generated from:
# `parquet-cpp/examples/low-level-api/reader_writer`
# This test extracts columns of [0, 1, 2, 4, 5]
# with column data types of [bool, int32, int64, float, double].
# Please check `parquet-cpp/examples/low-level-api/reader-writer.cc`
# to find details of how records are generated:
# Column 0 (bool): True for even rows and False otherwise.
# Column 1 (int32): Equal to row_index.
# Column 2 (int64): Equal to row_index * 1000 * 1000 * 1000 * 1000.
# Column 4 (float): Equal to row_index * 1.1.
# Column 5 (double): Equal to row_index * 1.1111111.
def test_parquet():
"""Test case for ParquetDataset."""
filename = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"test_parquet",
"parquet_cpp_example.parquet")
filename = "file://" + filename

columns = [
'boolean_field',
'int32_field',
'int64_field',
'float_field',
'double_field']
dtypes = [
tf.bool,
tf.int32,
tf.int64,
tf.float32,
tf.double]

dataset = tf.compat.v2.data.Dataset.zip(
tuple([parquet_io.ParquetDataset(
filename, column, dtype=dtype,
start=0, count=500) for (
column, dtype) in zip(columns, dtypes)])).apply(
tf.data.experimental.unbatch())

iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with tf.compat.v1.Session() as sess:
sess.run(init_op)
for i in range(500):
v0 = ((i % 2) == 0)
v1 = i
v2 = i * 1000 * 1000 * 1000 * 1000
v4 = 1.1 * i
v5 = 1.1111111 * i
p0, p1, p2, p4, p5 = sess.run(get_next)
assert v0 == p0
assert v1 == p1
assert v2 == p2
assert np.isclose(v4, p4)
assert np.isclose(v5, p5)
with pytest.raises(tf.errors.OutOfRangeError):
sess.run(get_next)

if __name__ == "__main__":
tf.test.main()
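
The expected values asserted in the test follow directly from the generator formulas in the comment above; a minimal sketch that reproduces them with plain NumPy, for illustration only and not part of the commit:

import numpy as np

def expected_row(i):
    # Formulas from parquet-cpp's reader_writer example, as described above.
    return ((i % 2) == 0,                              # boolean_field
            np.int32(i),                               # int32_field
            np.int64(i) * 1000 * 1000 * 1000 * 1000,   # int64_field
            np.float32(1.1 * i),                       # float_field
            1.1111111 * i)                             # double_field

for i in range(3):
    print(expected_row(i))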
3 changes: 2 additions & 1 deletion tests/test_parquet_eager.py
@@ -74,7 +74,8 @@ def test_parquet():

dataset = tf.compat.v2.data.Dataset.zip(
tuple(
[parquet_io.ParquetDataset(filename, column) for column in columns]))
[parquet_io.ParquetDataset(filename, column) for column in columns])
).apply(tf.data.experimental.unbatch())
i = 0
for p in dataset:
v0 = ((i % 2) == 0)
