From cf479f1bffb068356b9a0a5cbbc27695ce70a5a3 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sun, 28 Jul 2019 05:49:06 +0000
Subject: [PATCH] Fix build failures

Signed-off-by: Yong Tang
---
 tensorflow_io/parquet/__init__.py        |   6 +-
 .../parquet/kernels/parquet_input.cc     | 315 ------------------
 .../parquet/kernels/parquet_kernels.cc   | 114 +++----
 tensorflow_io/parquet/ops/parquet_ops.cc |   3 +-
 .../parquet/python/ops/parquet_ops.py    |  28 +-
 tests/test_parquet_eager.py              |   2 +-
 6 files changed, 79 insertions(+), 389 deletions(-)
 delete mode 100644 tensorflow_io/parquet/kernels/parquet_input.cc

diff --git a/tensorflow_io/parquet/__init__.py b/tensorflow_io/parquet/__init__.py
index 6e66d7cf4c..e64ed8b83a 100644
--- a/tensorflow_io/parquet/__init__.py
+++ b/tensorflow_io/parquet/__init__.py
@@ -16,7 +16,7 @@
 
 @@ParquetDataset
 @@read_parquet
-@@read_parquet_specs
+@@read_parquet_columns
 """
 
 from __future__ import absolute_import
@@ -25,14 +25,14 @@
 
 from tensorflow_io.parquet.python.ops.parquet_ops import ParquetDataset
 from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet
-from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_specs
+from tensorflow_io.parquet.python.ops.parquet_ops import read_parquet_columns
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     "ParquetDataset",
     "read_parquet",
-    "read_parquet_specs",
+    "read_parquet_columns",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

diff --git a/tensorflow_io/parquet/kernels/parquet_input.cc b/tensorflow_io/parquet/kernels/parquet_input.cc
deleted file mode 100644
index e1808d559b..0000000000
--- a/tensorflow_io/parquet/kernels/parquet_input.cc
+++ /dev/null
@@ -1,315 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "kernels/dataset_ops.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "parquet/api/reader.h"
-
-namespace tensorflow {
-namespace data {
-
-class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
-public:
-  explicit ParquetRandomAccessFile(io::InputStreamInterface* s)
-    : input_stream_(nullptr)
-    , buffered_stream_(nullptr) {
-    input_stream_ = dynamic_cast<SizedRandomAccessInputStreamInterface*>(s);
-    if (input_stream_ == nullptr) {
-      buffered_stream_.reset(new SizedRandomAccessBufferedStream(s));
-      input_stream_ = buffered_stream_.get();
-    }
-  }
-  ~ParquetRandomAccessFile() {}
-  arrow::Status Close() override {
-    return arrow::Status::OK();
-  }
-  arrow::Status Tell(int64_t* position) const override {
-    return arrow::Status::NotImplemented("Tell");
-  }
-  arrow::Status Seek(int64_t position) override {
-    return arrow::Status::NotImplemented("Seek");
-  }
-  arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
-    return arrow::Status::NotImplemented("Read (void*)");
-  }
-  arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
-    return arrow::Status::NotImplemented("Read (Buffer*)");
-  }
-  arrow::Status GetSize(int64_t* size) override {
-    uint64 size_value = 0;
-    Status status = input_stream_->GetFileSize(&size_value);
-    if (!status.ok()) {
-      return arrow::Status::IOError(status.error_message());
-    }
-    *size = size_value;
-    return arrow::Status::OK();
-  }
-  bool supports_zero_copy() const override {
-    return false;
-  }
-  arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override {
-    StringPiece result;
-    Status status = input_stream_->Read(position, nbytes, &result, (char *)out);
-    if (!(status.ok() || errors::IsOutOfRange(status))) {
-      return arrow::Status::IOError(status.error_message());
-    }
-    *bytes_read = result.size();
-    return arrow::Status::OK();
-  }
-  arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
-    string buffer;
-    buffer.resize(nbytes);
-    StringPiece result;
-    Status status = input_stream_->Read(position, nbytes, &result, &buffer[0]);
-    if (!(status.ok() || errors::IsOutOfRange(status))) {
-      return arrow::Status::IOError(status.error_message());
-    }
-    buffer.resize(result.size());
-    return arrow::Buffer::FromString(buffer, out);
-  }
-private:
-  SizedRandomAccessInputStreamInterface* input_stream_;
-  std::unique_ptr<SizedRandomAccessBufferedStream> buffered_stream_;
-};
-
-class ParquetInputStream {
-public:
-  explicit ParquetInputStream(io::InputStreamInterface* s, const std::vector<string>& columns)
-    : input_stream_(new ParquetRandomAccessFile(s))
-    , column_names_(columns) {
-  }
-  Status ReadHeader() {
-    parquet_reader_ = parquet::ParquetFileReader::Open(input_stream_);
-    file_metadata_ = parquet_reader_->metadata();
-    columns_ = std::vector<int64>(column_names_.size(), -1);
-    dtypes_ = std::vector<DataType>(column_names_.size());
-    for (size_t i = 0; i < column_names_.size(); i++) {
-      for (int j = 0; j < file_metadata_->schema()->num_columns(); j++) {
-        if (column_names_[i] == file_metadata_->schema()->Column(j)->path().get()->ToDotString()) {
-          columns_[i] = j;
-          switch(file_metadata_->schema()->Column(j)->physical_type()) {
-          case parquet::Type::BOOLEAN:
-            dtypes_[i] = DT_BOOL;
-            break;
-          case parquet::Type::INT32:
-            dtypes_[i] = DT_INT32;
-            break;
-          case parquet::Type::INT64:
-            dtypes_[i] = DT_INT64;
-            break;
-          case parquet::Type::FLOAT:
-            dtypes_[i] = DT_FLOAT;
-            break;
-          case parquet::Type::DOUBLE:
-            dtypes_[i] = DT_DOUBLE;
-            break;
-          default:
-            return errors::InvalidArgument("data type is not supported for column ", column_names_[i]);
-          }
-          break;
-        }
-      }
-      if (columns_[i] < 0) {
-        return errors::InvalidArgument("unable to find column ", column_names_[i]);
-      }
-    }
-    current_row_group_ = 0;
-    TF_RETURN_IF_ERROR(ReadRowGroup());
-    return Status::OK();
-  }
-  DataType DType(int64 i) {
-    return dtypes_[i];
-  }
-  int64 Columns() {
-    return (int64)columns_.size();
-  }
-  Status ReadRowGroup() {
-    if (current_row_group_ < file_metadata_->num_row_groups()) {
-      row_group_reader_ = parquet_reader_->RowGroup(current_row_group_);
-      column_readers_.clear();
-      for (size_t i = 0; i < columns_.size(); i++) {
-        int64 column = columns_[i];
-        std::shared_ptr<parquet::ColumnReader> column_reader =
-            row_group_reader_->Column(column);
-        column_readers_.emplace_back(column_reader);
-      }
-    }
-    current_row_ = 0;
-    return Status::OK();
-  }
-  ~ParquetInputStream() {
-    current_row_ = 0;
-    column_readers_.clear();
-    row_group_reader_.reset();
-    current_row_group_ = 0;
-    file_metadata_.reset();
-    parquet_reader_.reset();
-  }
-  Status ReadRecord(int64 index, int64 record_to_read, std::vector<Tensor>* out_tensors, int64* record_read) {
-    while (current_row_group_ < file_metadata_->num_row_groups()) {
-      if (current_row_ < row_group_reader_->metadata()->num_rows()) {
-        // Read columns to outputs.
-        // TODO: Read more than one value at a time.
-        for (size_t i = 0; i < columns_.size(); i++) {
-          DataType dtype = dtypes_[i];
-          std::shared_ptr<parquet::ColumnReader> column_reader = column_readers_[i];
-          TF_RETURN_IF_ERROR(GetTensorValue(current_row_, dtype, column_reader.get(), &(*out_tensors)[i], index));
-        }
-        ++current_row_;
-        *record_read = 1;
-        return Status::OK();
-      }
-      // We have reached the end of the current row group, so maybe
-      // move on to next row group.
-      current_row_ = 0;
-      row_group_reader_.reset();
-      ++current_row_group_;
-      TF_RETURN_IF_ERROR(ReadRowGroup());
-    }
-    return Status::OK();
-  }
-private:
-  template <typename DType>
-  Status FillTensorValue(parquet::ColumnReader* column_reader,
-                         typename DType::c_type* value) {
-    parquet::TypedColumnReader<DType>* reader =
-        static_cast<parquet::TypedColumnReader<DType>*>(column_reader);
-    // Read one value at a time. The number of rows read is returned.
-    // values_read contains the number of non-null rows
-    int64_t values_read = 0;
-    int64_t rows_read = reader->ReadBatch(1, nullptr, nullptr, value, &values_read);
-    // Ensure only one value is read and there are no NULL values in the
-    // rows read
-    if (rows_read != 1) {
-      return errors::Internal("rows_read (", rows_read, ") != 1 or values_read (", values_read, ") != 1");
-    }
-    return Status::OK();
-  }
-  Status GetTensorValue(int64 row, const DataType& data_type, parquet::ColumnReader* column_reader, Tensor* tensor, int64 index) {
-    switch (data_type) {
-    case DT_INT32: {
-      parquet::TypedColumnReader<parquet::Int32Type>* reader =
-          static_cast<parquet::TypedColumnReader<parquet::Int32Type>*>(
-              column_reader);
-      int32_t value;
-      TF_RETURN_IF_ERROR(
-          FillTensorValue<parquet::Int32Type>(reader, &value));
-      tensor->flat<int32>()(index) = value;
-    } break;
-    case DT_INT64: {
-      parquet::TypedColumnReader<parquet::Int64Type>* reader =
-          static_cast<parquet::TypedColumnReader<parquet::Int64Type>*>(
-              column_reader);
-      int64_t value;
-      TF_RETURN_IF_ERROR(
-          FillTensorValue<parquet::Int64Type>(reader, &value));
-      tensor->flat<int64>()(index) = value;
-    } break;
-    case DT_FLOAT: {
-      parquet::TypedColumnReader<parquet::FloatType>* reader =
-          static_cast<parquet::TypedColumnReader<parquet::FloatType>*>(
-              column_reader);
-      float value;
-      TF_RETURN_IF_ERROR(
-          FillTensorValue<parquet::FloatType>(reader, &value));
-      tensor->flat<float>()(index) = value;
-    } break;
-    case DT_DOUBLE: {
-      parquet::TypedColumnReader<parquet::DoubleType>* reader =
-          static_cast<parquet::TypedColumnReader<parquet::DoubleType>*>(
-              column_reader);
-      double value;
-      TF_RETURN_IF_ERROR(
-          FillTensorValue<parquet::DoubleType>(reader, &value));
-      tensor->flat<double>()(index) = value;
-    } break;
-    case DT_BOOL: {
-      parquet::TypedColumnReader<parquet::BooleanType>* reader =
-          static_cast<parquet::TypedColumnReader<parquet::BooleanType>*>(
-              column_reader);
-      bool value;
-      TF_RETURN_IF_ERROR(
-          FillTensorValue<parquet::BooleanType>(reader, &value));
-      tensor->flat<bool>()(index) = value;
-    } break;
-    default:
-      return errors::Unimplemented(
-          DataTypeString(data_type),
-          " is currently not supported in ParquetDataset");
-    }
-    return Status::OK();
-  }
-  std::shared_ptr<::arrow::io::RandomAccessFile> input_stream_;
-  std::vector<string> column_names_;
-  std::vector<int64> columns_;
-  std::vector<DataType> dtypes_;
-  std::unique_ptr<parquet::ParquetFileReader> parquet_reader_;
-  std::shared_ptr<parquet::FileMetaData> file_metadata_;
-  int64 current_row_group_ = 0;
-  std::shared_ptr<parquet::RowGroupReader> row_group_reader_;
-  std::vector<std::shared_ptr<parquet::ColumnReader>> column_readers_;
-  int64 current_row_ = 0;
-};
-
-class ParquetInput: public FileInput<ParquetInputStream> {
- public:
-  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<ParquetInputStream>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
-    if (state.get() == nullptr) {
-      state.reset(new ParquetInputStream(s, columns()));
-      TF_RETURN_IF_ERROR(state.get()->ReadHeader());
-    }
-    // Let's allocate enough space for Tensor, if more than read, replace.
-    for (int64 i = 0; i < state.get()->Columns(); i++) {
-      Tensor tensor(ctx->allocator({}), state.get()->DType(i), {record_to_read});
-      out_tensors->emplace_back(std::move(tensor));
-    }
-    while ((*record_read) < record_to_read) {
-      int64 count = 0;
-      TF_RETURN_IF_ERROR(state.get()->ReadRecord((*record_read), record_to_read - (*record_read), out_tensors, &count));
-      (*record_read) += count;
-      if (count == 0) {
-        break;
-      }
-    }
-    if (*record_read < record_to_read) {
-      if (*record_read == 0) {
-        out_tensors->clear();
-      }
-      for (size_t i = 0; i < out_tensors->size(); i++) {
-        Tensor tensor = (*out_tensors)[i].Slice(0, *record_read);
-        (*out_tensors)[i] = std::move(tensor);
-      }
-    }
-    return Status::OK();
-  }
-  Status FromStream(io::InputStreamInterface* s) override {
-    return Status::OK();
-  }
-  void EncodeAttributes(VariantTensorData* data) const override {
-  }
-  bool DecodeAttributes(const VariantTensorData& data) override {
-    return true;
-  }
- protected:
-};
-
-REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ParquetInput, "tensorflow::data::ParquetInput");
-
-REGISTER_KERNEL_BUILDER(Name("ParquetInput").Device(DEVICE_CPU),
-                        FileInputOp<ParquetInput>);
-REGISTER_KERNEL_BUILDER(Name("ParquetDataset").Device(DEVICE_CPU),
-                        FileInputDatasetOp<ParquetInput, ParquetInputStream>);
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc
index 00fc4624d0..13c85ac869 100644
--- a/tensorflow_io/parquet/kernels/parquet_kernels.cc
+++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -23,6 +23,54 @@ namespace tensorflow {
 namespace data {
 namespace {
 
+// Note: This SizedRandomAccessFile should only lives within Compute()
+// of the kernel as buffer could be released by outside.
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+  SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory)
+    : file_(nullptr)
+    , size_status_(Status::OK())
+    , size_(optional_memory.size())
+    , buffer_(optional_memory) {
+    if (size_ == 0) {
+      size_status_ = env->GetFileSize(filename, &size_);
+      if (size_status_.ok()) {
+        size_status_ = env->NewRandomAccessFile(filename, &file_);
+      }
+    }
+  }
+
+  virtual ~SizedRandomAccessFile() {}
+  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
+    if (file_.get() != nullptr) {
+      return file_.get()->Read(offset, n, result, scratch);
+    }
+    size_t bytes_to_read = 0;
+    if (offset < size_) {
+      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
+    }
+    if (bytes_to_read > 0) {
+      memcpy(scratch, buffer_.data(), bytes_to_read);
+    }
+    *result = StringPiece(scratch, bytes_to_read);
+    if (bytes_to_read < n) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+  Status GetFileSize(uint64* size) {
+    if (size_status_.ok()) {
+      *size = size_;
+    }
+    return size_status_;
+  }
+ private:
+  std::unique_ptr<tensorflow::RandomAccessFile> file_;
+  Status size_status_;
+  uint64 size_;
+  const string& buffer_;
+};
+
 class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
  public:
   explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
@@ -77,9 +125,9 @@ class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
   int64 size_;
 };
 
-class ReadParquetSpecsOp : public OpKernel {
+class ReadParquetColumnsOp : public OpKernel {
  public:
-  explicit ReadParquetSpecsOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit ReadParquetColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
     env_ = context->env();
   }
 
@@ -87,10 +135,12 @@ class ReadParquetSpecsOp : public OpKernel {
     const Tensor& filename_tensor = context->input(0);
     const string filename = filename_tensor.scalar<string>()();
 
-    std::unique_ptr<tensorflow::RandomAccessFile> file;
-    OP_REQUIRES_OK(context, env_->NewRandomAccessFile(filename, &file));
-    uint64 size = 0;
-    OP_REQUIRES_OK(context, env_->GetFileSize(filename, &size));
+    const Tensor& memory_tensor = context->input(1);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
 
     std::shared_ptr<ParquetRandomAccessFile> parquet_file(new ParquetRandomAccessFile(file.get(), size));
     std::shared_ptr<::parquet::FileMetaData> metadata = ::parquet::ReadMetaData(parquet_file);
@@ -155,54 +205,6 @@ class ReadParquetSpecsOp : public OpKernel {
   Env* env_ GUARDED_BY(mu_);
 };
 
-// Note: This SizedRandomAccessFile should only lives within Compute()
-// of the kernel as buffer could be released by outside.
-class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
- public:
-  SizedRandomAccessFile(Env* env, const string& filename, const string& optional_memory)
-    : file_(nullptr)
-    , size_status_(Status::OK())
-    , size_(optional_memory.size())
-    , buffer_(optional_memory) {
-    if (size_ == 0) {
-      size_status_ = env->GetFileSize(filename, &size_);
-      if (size_status_.ok()) {
-        size_status_ = env->NewRandomAccessFile(filename, &file_);
-      }
-    }
-  }
-
-  virtual ~SizedRandomAccessFile() {}
-  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
-    if (file_.get() != nullptr) {
-      return file_.get()->Read(offset, n, result, scratch);
-    }
-    size_t bytes_to_read = 0;
-    if (offset < size_) {
-      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
-    }
-    if (bytes_to_read > 0) {
-      memcpy(scratch, buffer_.data(), bytes_to_read);
-    }
-    *result = StringPiece(scratch, bytes_to_read);
-    if (bytes_to_read < n) {
-      return errors::OutOfRange("EOF reached");
-    }
-    return Status::OK();
-  }
-  Status GetFileSize(uint64* size) {
-    if (size_status_.ok()) {
-      *size = size_;
-    }
-    return size_status_;
-  }
- private:
-  std::unique_ptr<tensorflow::RandomAccessFile> file_;
-  Status size_status_;
-  uint64 size_;
-  const string& buffer_;
-};
-
 class ReadParquetOp : public OpKernel {
  public:
   explicit ReadParquetOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -307,8 +309,8 @@ class ReadParquetOp : public OpKernel {
   Env* env_ GUARDED_BY(mu_);
 };
 
-REGISTER_KERNEL_BUILDER(Name("ReadParquetSpecs").Device(DEVICE_CPU),
-                        ReadParquetSpecsOp);
+REGISTER_KERNEL_BUILDER(Name("ReadParquetColumns").Device(DEVICE_CPU),
+                        ReadParquetColumnsOp);
 REGISTER_KERNEL_BUILDER(Name("ReadParquet").Device(DEVICE_CPU),
                         ReadParquetOp);
 
diff --git a/tensorflow_io/parquet/ops/parquet_ops.cc b/tensorflow_io/parquet/ops/parquet_ops.cc
index 8d7a81cafe..dd5e58c894 100644
--- a/tensorflow_io/parquet/ops/parquet_ops.cc
+++ b/tensorflow_io/parquet/ops/parquet_ops.cc
@@ -19,8 +19,9 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("ReadParquetSpecs")
+REGISTER_OP("ReadParquetColumns")
     .Input("filename: string")
+    .Input("memory: string")
     .Output("columns: string")
     .Output("dtypes: string")
     .Output("shapes: int64")
diff --git a/tensorflow_io/parquet/python/ops/parquet_ops.py b/tensorflow_io/parquet/python/ops/parquet_ops.py
index 76acd8b695..7d4b07781e 100644
--- a/tensorflow_io/parquet/python/ops/parquet_ops.py
+++ b/tensorflow_io/parquet/python/ops/parquet_ops.py
@@ -21,22 +21,24 @@
 from tensorflow_io.core.python.ops import core_ops as parquet_ops
 from tensorflow_io.core.python.ops import data_ops
 
-def read_parquet_specs(filename):
-  """read_parquet_specs"""
+def read_parquet_columns(filename, **kwargs):
+  """read_parquet_columns"""
   if not tf.executing_eagerly():
     raise NotImplementedError("read_parquet_spect only support eager mode")
-  columns, dtypes, shapes = parquet_ops.read_parquet_specs(filename)
+  memory = kwargs.get("memory", "")
+  columns, dtypes, shapes = parquet_ops.read_parquet_columns(
+      filename, memory=memory)
   entries = zip(tf.unstack(columns), tf.unstack(dtypes), tf.unstack(shapes))
-  return dict([(column.numpy(), tf.TensorSpec(
-      shape.numpy(), dtype.numpy(), column.numpy())) for (
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
           column, dtype, shape) in entries])
 
-def read_parquet(filename, spec, start=0, **kwargs):
+def read_parquet(filename, column, start=0, **kwargs):
   """read_parquet"""
   memory = kwargs.get("memory", "")
   return parquet_ops.read_parquet(
-      filename, spec.name,
-      start=start, count=spec.shape[0] - start, dtype=spec.dtype,
+      filename, column.name,
+      start=start, count=column.shape[0] - start, dtype=column.dtype,
       memory=memory)
 
 class ParquetDataset(data_ops.BaseDataset):
@@ -56,9 +58,9 @@ def __init__(self, filename, column, batch=None, **kwargs):
       count = kwargs.get("count")
      dtype = kwargs.get("dtype")
     else:
-      specs = read_parquet_specs(filename)
-      count = specs[column].shape[0]
-      dtype = specs[column].dtype
+      columns = read_parquet_columns(filename)
+      count = columns[column].shape[0]
+      dtype = columns[column].dtype
 
     batch = 0 if batch is None else batch
     shape = tf.TensorShape([]) if (
@@ -76,10 +78,10 @@ def __init__(self, filename, column, batch=None, **kwargs):
     ).map(lambda start, count: parquet_ops.read_parquet(
         filename, column, start, count, dtype=dtype, memory=""))
     if batch is None or batch == 0:
-      self._dataset = dataset.unbatch()
+      self._dataset = dataset.apply(tf.data.experimental.unbatch())
     else:
       # TODO: convert to rebatch for performance
-      self._dataset = dataset.unbatch().batch(batch)
+      self._dataset = dataset.apply(tf.data.experimental.unbatch()).batch(batch)
     super(ParquetDataset, self).__init__(
         self._dataset._variant_tensor, [dtype], [shape])  # pylint: disable=protected-access
 
diff --git a/tests/test_parquet_eager.py b/tests/test_parquet_eager.py
index 4dd67ed7fe..d49c5c2368 100644
--- a/tests/test_parquet_eager.py
+++ b/tests/test_parquet_eager.py
@@ -47,7 +47,7 @@ def test_parquet():
       "parquet_cpp_example.parquet")
   filename = "file://" + filename
 
-  specs = parquet_io.read_parquet_specs(filename)
+  specs = parquet_io.read_parquet_columns(filename)
   columns = [
       'boolean_field',
       'int32_field',
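
Usage sketch (illustrative, not part of the commit): a minimal eager-mode example of the renamed API, assuming it behaves as exercised by tests/test_parquet_eager.py above; the import alias matches the test, while the file path and column choice are hypothetical.

    import tensorflow_io.parquet as parquet_io

    filename = "file:///tmp/parquet_cpp_example.parquet"  # hypothetical path
    # read_parquet_columns returns a dict: column name -> tf.TensorSpec(shape, dtype, name).
    columns = parquet_io.read_parquet_columns(filename)
    # read_parquet reads one column into a tensor, taking the name, row count and
    # dtype from the TensorSpec returned above.
    int32_values = parquet_io.read_parquet(filename, columns["int32_field"])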