Remove batch args, and add test in graph mode
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
yongtang committed Aug 4, 2019
1 parent bd500b7 commit 2a2f113
Showing 7 changed files with 211 additions and 93 deletions.
1 change: 1 addition & 0 deletions tensorflow_io/core/BUILD
@@ -28,6 +28,7 @@ cc_library(
name = "dataset_ops",
srcs = [
"kernels/dataset_ops.h",
"kernels/stream.h",
],
copts = tf_io_copts(),
includes = [
71 changes: 71 additions & 0 deletions tensorflow_io/core/kernels/stream.h
@@ -0,0 +1,71 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/random_inputstream.h"

namespace tensorflow {
namespace data {

// Note: This SizedRandomAccessFile should only live within Compute()
// of the kernel, as the buffer could be released externally.
class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
public:
SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size)
: file_(nullptr)
, size_(optional_memory_size)
, buff_((const char *)(optional_memory_buff))
, size_status_(Status::OK()) {
if (size_ == 0) {
size_status_ = env->GetFileSize(filename, &size_);
if (size_status_.ok()) {
size_status_ = env->NewRandomAccessFile(filename, &file_);
}
}
}

virtual ~SizedRandomAccessFile() {}
Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
if (file_.get() != nullptr) {
return file_.get()->Read(offset, n, result, scratch);
}
size_t bytes_to_read = 0;
if (offset < size_) {
bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
}
if (bytes_to_read > 0) {
memcpy(scratch, &buff_[offset], bytes_to_read);
}
*result = StringPiece(scratch, bytes_to_read);
if (bytes_to_read < n) {
return errors::OutOfRange("EOF reached");
}
return Status::OK();
}
Status GetFileSize(uint64* size) {
if (size_status_.ok()) {
*size = size_;
}
return size_status_;
}
private:
std::unique_ptr<tensorflow::RandomAccessFile> file_;
uint64 size_;
const char *buff_;
Status size_status_;
};

} // namespace data
} // namespace tensorflow
86 changes: 21 additions & 65 deletions tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -13,64 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "kernels/dataset_ops.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_io/core/kernels/stream.h"
#include "parquet/api/reader.h"

namespace tensorflow {
namespace data {
namespace {

// Note: This SizedRandomAccessFile should only live within Compute()
// of the kernel, as the buffer could be released externally.
class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
public:
SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory)
: file_(nullptr)
, size_status_(Status::OK())
, size_(optional_memory.size())
, buffer_(optional_memory) {
if (size_ == 0) {
size_status_ = env->GetFileSize(filename, &size_);
if (size_status_.ok()) {
size_status_ = env->NewRandomAccessFile(filename, &file_);
}
}
}

virtual ~SizedRandomAccessFile() {}
Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
if (file_.get() != nullptr) {
return file_.get()->Read(offset, n, result, scratch);
}
size_t bytes_to_read = 0;
if (offset < size_) {
bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
}
if (bytes_to_read > 0) {
memcpy(scratch, buffer_.data(), bytes_to_read);
}
*result = StringPiece(scratch, bytes_to_read);
if (bytes_to_read < n) {
return errors::OutOfRange("EOF reached");
}
return Status::OK();
}
Status GetFileSize(uint64* size) {
if (size_status_.ok()) {
*size = size_;
}
return size_status_;
}
private:
std::unique_ptr<tensorflow::RandomAccessFile> file_;
Status size_status_;
uint64 size_;
const string& buffer_;
};

class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
public:
explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
@@ -138,7 +88,7 @@ class ListParquetColumnsOp : public OpKernel {
const Tensor& memory_tensor = context->input(1);
const string& memory = memory_tensor.scalar<string>()();

std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory));
std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
uint64 size;
OP_REQUIRES_OK(context, file->GetFileSize(&size));

@@ -218,16 +168,16 @@ class ReadParquetOp : public OpKernel {
const Tensor& column_tensor = context->input(1);
const string& column = column_tensor.scalar<string>()();

const Tensor& start_tensor = context->input(2);
const int64 start = start_tensor.scalar<int64>()();
const Tensor& memory_tensor = context->input(2);
const string& memory = memory_tensor.scalar<string>()();

const Tensor& count_tensor = context->input(3);
int64 count = count_tensor.scalar<int64>()();
const Tensor& start_tensor = context->input(3);
int64 start = start_tensor.scalar<int64>()();

const Tensor& memory_tensor = context->input(4);
const string& memory = memory_tensor.scalar<string>()();
const Tensor& stop_tensor = context->input(4);
int64 stop = stop_tensor.scalar<int64>()();

std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory));
std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
uint64 size;
OP_REQUIRES_OK(context, file->GetFileSize(&size));

@@ -243,26 +193,32 @@ class ReadParquetOp : public OpKernel {
}
OP_REQUIRES(context, (column_index < file_metadata->num_columns()), errors::InvalidArgument("unable to find column: ", column));

if (start + count > file_metadata->num_rows()) {
count = file_metadata->num_rows() - start;
if (start > file_metadata->num_rows()) {
start = file_metadata->num_rows();
}
if (stop < 0) {
stop = file_metadata->num_rows();
}
if (stop > file_metadata->num_rows()) {
stop = file_metadata->num_rows();
}

TensorShape output_shape({count});
TensorShape output_shape({stop - start});

Tensor* output_tensor;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

int64 row_group_offset = 0;
for (int row_group = 0; row_group < file_metadata->num_row_groups(); row_group++) {
std::shared_ptr<parquet::RowGroupReader> row_group_reader = parquet_reader->RowGroup(row_group);
// Skip if row group is not within [start..start+count]
if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (start + count <= row_group_offset)) {
// Skip if row group is not within [start..stop]
if ((row_group_offset + row_group_reader->metadata()->num_rows() < start) || (stop <= row_group_offset)) {
row_group_offset += row_group_reader->metadata()->num_rows();
continue;
}
// Find row_to_read range
int64 row_to_read_start = row_group_offset > start ? row_group_offset : start;
int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (start + count) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (start + count);
int64 row_to_read_final = (row_group_offset + row_group_reader->metadata()->num_rows()) < (stop) ? (row_group_offset + row_group_reader->metadata()->num_rows()) : (stop);
int64 row_to_read_count = row_to_read_final - row_to_read_start;

std::shared_ptr<parquet::ColumnReader> column_reader = row_group_reader->Column(column_index);
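The kernel resolves the requested [start, stop) row range against the file metadata before allocating the output: start is clamped to the row count, and a negative stop means "read to the end of the file". A minimal sketch of that clamping logic, written in Python purely for illustration (not part of the commit):

    def resolve_row_range(start, stop, num_rows):
      """Mirrors the kernel's clamping of [start, stop) against num_rows."""
      start = min(start, num_rows)
      if stop < 0 or stop > num_rows:
        stop = num_rows
      return start, stop

    # resolve_row_range(0, -1, 500)    -> (0, 500): read the whole file
    # resolve_row_range(200, 800, 500) -> (200, 500): stop clamped to num_rows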
6 changes: 4 additions & 2 deletions tensorflow_io/parquet/ops/parquet_ops.cc
@@ -27,15 +27,17 @@ REGISTER_OP("ListParquetColumns")
.Output("shapes: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({c->UnknownDim()}));
c->set_output(1, c->MakeShape({c->UnknownDim()}));
c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
return Status::OK();
});

REGISTER_OP("ReadParquet")
.Input("filename: string")
.Input("column: string")
.Input("start: int64")
.Input("count: int64")
.Input("memory: string")
.Input("start: int64")
.Input("stop: int64")
.Attr("dtype: type")
.Output("output: dtype")
.SetShapeFn([](shape_inference::InferenceContext* c) {
48 changes: 23 additions & 25 deletions tensorflow_io/parquet/python/ops/parquet_ops.py
@@ -24,7 +24,7 @@
def list_parquet_columns(filename, **kwargs):
"""list_parquet_columns"""
if not tf.executing_eagerly():
raise NotImplementedError("read_parquet_spect only support eager mode")
raise NotImplementedError("list_parquet_columns only support eager mode")
memory = kwargs.get("memory", "")
columns, dtypes, shapes = parquet_ops.list_parquet_columns(
filename, memory=memory)
@@ -33,18 +33,23 @@ def list_parquet_columns(filename, **kwargs):
shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
column, dtype, shape) in entries])

def read_parquet(filename, column, start=0, **kwargs):
def read_parquet(filename, column, **kwargs):
"""read_parquet"""
memory = kwargs.get("memory", "")
start = kwargs.get("start", 0)
stop = kwargs.get("stop", None)
if stop is None and column.shape[0] is not None:
stop = column.shape[0] - start
if stop is None:
stop = -1
return parquet_ops.read_parquet(
filename, column.name,
start=start, count=column.shape[0] - start, dtype=column.dtype,
memory=memory)
filename, column.name, memory=memory,
start=start, stop=-1, dtype=column.dtype)

class ParquetDataset(data_ops.BaseDataset):
"""A Parquet Dataset that reads the parquet file."""

def __init__(self, filename, column, batch=None, **kwargs):
def __init__(self, filename, column, **kwargs):
"""Create a `ParquetDataset`.
`ParquetDataset` allows a user to read data from a parquet file.
@@ -53,35 +58,28 @@ def __init__(self, filename, column, batch=None, **kwargs):
filename: filename of the parquet file to read.
column: column name to read.
"""
# Note: count and dtype could be in kwargs if in graph mode.
# Note: start, stop and dtype could be in kwargs if in graph mode.
if not tf.executing_eagerly():
count = kwargs.get("count")
start = kwargs.get("start")
stop = kwargs.get("stop")
dtype = kwargs.get("dtype")
else:
columns = list_parquet_columns(filename)
count = columns[column].shape[0]
start = 0
stop = columns[column].shape[0]
dtype = columns[column].dtype

batch = 0 if batch is None else batch
shape = tf.TensorShape([]) if (
batch is None or batch == 0) else tf.TensorShape([None])
shape = tf.TensorShape([None])

# capacity is the rough count for each chunk in dataset
# not directly related to batch, will be padded to batch though
capacity = kwargs.get("capacity", 65536)
if batch is not None and batch != 0 and capacity > batch:
capacity = (capacity // batch) * batch
entry_start = range(0, count, capacity)
entry_count = [min(capacity, count - start) for start in entry_start]
entry_start = list(range(start, stop, capacity))
entry_stop = entry_start[1:] + [stop]
dataset = data_ops.BaseDataset.from_tensor_slices(
(tf.constant(entry_start, tf.int64), tf.constant(entry_count, tf.int64))
).map(lambda start, count: parquet_ops.read_parquet(
filename, column, start, count, dtype=dtype, memory=""))
if batch is None or batch == 0:
self._dataset = dataset.apply(tf.data.experimental.unbatch())
else:
# TODO: convert to rebatch for performance
self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch)
(tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64))
).map(lambda start, stop: parquet_ops.read_parquet(
filename, column, memory="", start=start, stop=stop, dtype=dtype))
self._dataset = dataset

super(ParquetDataset, self).__init__(
self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access
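
With the batch argument removed, `ParquetDataset` always emits chunks of shape [None] and leaves batching to the caller; in eager mode the column's dtype and row count are probed through list_parquet_columns, while in graph mode start, stop, and dtype must be passed as keyword arguments. A hedged usage sketch follows (the file name, column name, dtype, and the tensorflow_io.parquet import path are assumptions, not taken from this commit):

    import tensorflow as tf
    import tensorflow_io.parquet as parquet_io  # assumed import path

    # Eager-only helpers: inspect the columns, then read one column as a tensor.
    columns = parquet_io.list_parquet_columns("sample.parquet")
    values = parquet_io.read_parquet("sample.parquet", columns["price"])

    # Eager mode: ParquetDataset discovers dtype and row count automatically.
    dataset = parquet_io.ParquetDataset("sample.parquet", "price")

    # Graph mode: metadata cannot be probed, so supply start/stop/dtype explicitly.
    with tf.Graph().as_default():
      dataset = parquet_io.ParquetDataset(
          "sample.parquet", "price",
          start=0, stop=1000, dtype=tf.float64)
      # Each element is a chunk of up to `capacity` rows; unbatch and re-batch
      # to a fixed size now that the dataset no longer batches internally.
      dataset = dataset.apply(tf.data.experimental.unbatch()).batch(32)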
