Add list_feather_columns function in eager mode (#404)
* Add list_feather_columns function in eager mode

This PR adds a `list_feather_columns` function in eager mode, so that
it is possible to get the column name and spec information for the
Feather format.
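
For illustration, a minimal eager-mode usage sketch (the path `data.feather` is a placeholder; the function returns a dict mapping column names to `tf.TensorSpec`):

```python
import tensorflow_io.arrow as arrow_io

# Eager mode only: inspect column names and specs without loading the data.
columns = arrow_io.list_feather_columns("data.feather")
for name, spec in columns.items():
  print(name, spec.dtype, spec.shape)
```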

This PR implements an `::arrow::io::RandomAccessFile` interface so that
it is possible to read files through scheme file systems,
e.g., s3, gcs, azfs, etc.

The `::arrow::io::RandomAccessFile` implementation is the same as in Parquet PR 384,
so the two could be combined.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Use flatbuffers to read the Feather metadata, to avoid reading the whole file through the Feather API.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Keep unsupported data types so that it is possible to process them in Python, based on review comments

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Combine .so files into one place to reduce the whl package size

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Combine ArrowRandomAccessFile and ParquetRandomAccessFile as they are the same

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
yongtang authored Aug 5, 2019
1 parent 1642da1 commit d0fe60c
Showing 10 changed files with 317 additions and 68 deletions.
11 changes: 6 additions & 5 deletions tensorflow_io/arrow/BUILD
@@ -7,21 +7,22 @@ load(
     "tf_io_copts",
 )
 
-cc_binary(
-    name = "python/ops/_arrow_ops.so",
+cc_library(
+    name = "arrow_ops",
     srcs = [
         "kernels/arrow_dataset_ops.cc",
+        "kernels/arrow_kernels.cc",
+        "kernels/arrow_kernels.h",
         "kernels/arrow_stream_client.h",
         "kernels/arrow_stream_client_unix.cc",
         "kernels/arrow_util.cc",
         "kernels/arrow_util.h",
         "ops/dataset_ops.cc",
     ],
     copts = tf_io_copts(),
-    linkshared = 1,
+    linkstatic = True,
     deps = [
+        "//tensorflow_io/core:dataset_ops",
         "@arrow",
-        "@local_config_tf//:libtensorflow_framework",
-        "@local_config_tf//:tf_header_lib",
     ],
 )
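
This BUILD change is what realizes the ".so combining" noted in the commit message: the Arrow ops become a statically linked `cc_library` that is folded into the single shared object built in `tensorflow_io/core/BUILD` (see below), instead of producing a standalone `_arrow_ops.so`.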
3 changes: 3 additions & 0 deletions tensorflow_io/arrow/__init__.py
@@ -17,6 +17,7 @@
 @@ArrowDataset
 @@ArrowFeatherDataset
 @@ArrowStreamDataset
+@@list_feather_columns
 """
 
 from __future__ import absolute_import
@@ -26,13 +27,15 @@
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowFeatherDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowStreamDataset
+from tensorflow_io.arrow.python.ops.arrow_dataset_ops import list_feather_columns
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     "ArrowDataset",
     "ArrowFeatherDataset",
     "ArrowStreamDataset",
+    "list_feather_columns",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
159 changes: 159 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.cc
@@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
#include "arrow/io/api.h"
#include "arrow/ipc/feather.h"
#include "arrow/ipc/feather_generated.h"
#include "arrow/buffer.h"

namespace tensorflow {
namespace data {
namespace {

class ListFeatherColumnsOp : public OpKernel {
 public:
  explicit ListFeatherColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
    env_ = context->env();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filename_tensor = context->input(0);
    const string filename = filename_tensor.scalar<string>()();

    const Tensor& memory_tensor = context->input(1);
    const string& memory = memory_tensor.scalar<string>()();
    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
    uint64 size;
    OP_REQUIRES_OK(context, file->GetFileSize(&size));

    // FEA1.....[metadata][uint32 metadata_length]FEA1
    static constexpr const char* kFeatherMagicBytes = "FEA1";

    size_t header_length = strlen(kFeatherMagicBytes);
    size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);

    string buffer;
    buffer.resize(header_length > footer_length ? header_length : footer_length);

    StringPiece result;

    OP_REQUIRES_OK(context, file->Read(0, header_length, &result, &buffer[0]));
    OP_REQUIRES(context, !memcmp(buffer.data(), kFeatherMagicBytes, header_length), errors::InvalidArgument("not a feather file"));

    OP_REQUIRES_OK(context, file->Read(size - footer_length, footer_length, &result, &buffer[0]));
    OP_REQUIRES(context, !memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)), errors::InvalidArgument("incomplete feather file"));

    uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());

    buffer.resize(metadata_length);

    OP_REQUIRES_OK(context, file->Read(size - footer_length - metadata_length, metadata_length, &result, &buffer[0]));

    const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());

    OP_REQUIRES(context, (table->version() >= ::arrow::ipc::feather::kFeatherVersion), errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion));

    std::vector<string> columns;
    std::vector<string> dtypes;
    std::vector<int64> counts;
    columns.reserve(table->columns()->size());
    dtypes.reserve(table->columns()->size());
    counts.reserve(table->columns()->size());

    for (int64 i = 0; i < table->columns()->size(); i++) {
      DataType dtype = ::tensorflow::DataType::DT_INVALID;
      switch (table->columns()->Get(i)->values()->type()) {
        case ::arrow::ipc::feather::fbs::Type_BOOL:
          dtype = ::tensorflow::DataType::DT_BOOL;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT8:
          dtype = ::tensorflow::DataType::DT_INT8;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT16:
          dtype = ::tensorflow::DataType::DT_INT16;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT32:
          dtype = ::tensorflow::DataType::DT_INT32;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT64:
          dtype = ::tensorflow::DataType::DT_INT64;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT8:
          dtype = ::tensorflow::DataType::DT_UINT8;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT16:
          dtype = ::tensorflow::DataType::DT_UINT16;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT32:
          dtype = ::tensorflow::DataType::DT_UINT32;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT64:
          dtype = ::tensorflow::DataType::DT_UINT64;
          break;
        case ::arrow::ipc::feather::fbs::Type_FLOAT:
          dtype = ::tensorflow::DataType::DT_FLOAT;
          break;
        case ::arrow::ipc::feather::fbs::Type_DOUBLE:
          dtype = ::tensorflow::DataType::DT_DOUBLE;
          break;
        case ::arrow::ipc::feather::fbs::Type_UTF8:
        case ::arrow::ipc::feather::fbs::Type_BINARY:
        case ::arrow::ipc::feather::fbs::Type_CATEGORY:
        case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
        case ::arrow::ipc::feather::fbs::Type_DATE:
        case ::arrow::ipc::feather::fbs::Type_TIME:
        // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
        // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
        default:
          break;
      }
      columns.push_back(table->columns()->Get(i)->name()->str());
      dtypes.push_back(::tensorflow::DataTypeString(dtype));
      counts.push_back(table->num_rows());
    }

    TensorShape output_shape = filename_tensor.shape();
    output_shape.AddDim(columns.size());

    Tensor* columns_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
    Tensor* dtypes_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));

    output_shape.AddDim(1);

    Tensor* shapes_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor));

    for (size_t i = 0; i < columns.size(); i++) {
      columns_tensor->flat<string>()(i) = columns[i];
      dtypes_tensor->flat<string>()(i) = dtypes[i];
      shapes_tensor->flat<int64>()(i) = counts[i];
    }
  }

 private:
  mutex mu_;
  Env* env_ GUARDED_BY(mu_);
};

REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
                        ListFeatherColumnsOp);


} // namespace
} // namespace data
} // namespace tensorflow
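
The kernel above reads only three small pieces of the file — the 4-byte header magic, the 8-byte footer, and the flatbuffer metadata block — rather than the whole file. A Python sketch of the same read pattern, assuming a little-endian host (as the kernel's `reinterpret_cast` does) and leaving the `CTable` flatbuffer decoding out:

```python
import struct

FEATHER_MAGIC = b"FEA1"

def read_feather_metadata(path):
  # Layout: FEA1 ... [metadata][uint32 metadata_length] FEA1
  with open(path, "rb") as f:
    if f.read(4) != FEATHER_MAGIC:
      raise ValueError("not a feather file")
    size = f.seek(0, 2)  # seek to end to learn the file size
    f.seek(size - 8)
    footer = f.read(8)
    if footer[4:] != FEATHER_MAGIC:
      raise ValueError("incomplete feather file")
    (metadata_length,) = struct.unpack("<I", footer[:4])
    f.seek(size - 8 - metadata_length)
    return f.read(metadata_length)  # flatbuffer-encoded CTable bytes
```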
80 changes: 80 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.h
@@ -0,0 +1,80 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "kernels/stream.h"
#include "arrow/io/api.h"
#include "arrow/buffer.h"

namespace tensorflow {
namespace data {

// NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap
// with another PR. Will remove the duplicates once that PR is merged.

class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
 public:
  explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
      : file_(file)
      , size_(size) { }

  ~ArrowRandomAccessFile() {}
  arrow::Status Close() override {
    return arrow::Status::OK();
  }
  arrow::Status Tell(int64_t* position) const override {
    return arrow::Status::NotImplemented("Tell");
  }
  arrow::Status Seek(int64_t position) override {
    return arrow::Status::NotImplemented("Seek");
  }
  arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
    return arrow::Status::NotImplemented("Read (void*)");
  }
  arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
    return arrow::Status::NotImplemented("Read (Buffer*)");
  }
  arrow::Status GetSize(int64_t* size) override {
    *size = size_;
    return arrow::Status::OK();
  }
  bool supports_zero_copy() const override {
    return false;
  }
  arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override {
    StringPiece result;
    Status status = file_->Read(position, nbytes, &result, (char*)out);
    if (!(status.ok() || errors::IsOutOfRange(status))) {
      return arrow::Status::IOError(status.error_message());
    }
    *bytes_read = result.size();
    return arrow::Status::OK();
  }
  arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
    string buffer;
    buffer.resize(nbytes);
    StringPiece result;
    Status status = file_->Read(position, nbytes, &result, (char*)(&buffer[0]));
    if (!(status.ok() || errors::IsOutOfRange(status))) {
      return arrow::Status::IOError(status.error_message());
    }
    buffer.resize(result.size());
    return arrow::Buffer::FromString(buffer, out);
  }

 private:
  tensorflow::RandomAccessFile* file_;
  int64 size_;
};
} // namespace data
} // namespace tensorflow
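
Only `GetSize` and the positional `ReadAt` methods do real work here; the streaming `Read`/`Seek`/`Tell` methods stay `NotImplemented` because the Feather metadata path only issues positional reads. A duck-typed Python sketch of the same adapter shape, purely illustrative (`tf_file` and its `read_at` method are hypothetical stand-ins, not a real pyarrow API):

```python
class RandomAccessFileAdapterSketch:
  """Illustrative mirror of ArrowRandomAccessFile: size + positional reads."""

  def __init__(self, tf_file, size):
    self._file = tf_file  # assumed to expose read_at(offset, nbytes) -> bytes
    self._size = size

  def get_size(self):
    return self._size

  def read_at(self, position, nbytes):
    # Short reads near EOF are tolerated, mirroring the OutOfRange
    # handling in the C++ ReadAt above.
    return self._file.read_at(position, nbytes)

  def seek(self, position):
    raise NotImplementedError("Seek")  # not needed for metadata reads
```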
14 changes: 14 additions & 0 deletions tensorflow_io/arrow/ops/dataset_ops.cc
@@ -67,4 +67,18 @@ Creates a dataset that connects to a host serving Arrow RecordBatches in stream
 endpoints: One or more host addresses that are serving an Arrow stream.
 )doc");
 
+
+REGISTER_OP("ListFeatherColumns")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
+    .Output("shapes: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
+      return Status::OK();
+    });
+
 } // namespace tensorflow
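
Note that the `shapes` output is rank 2 because the kernel returns each column's shape as a vector of dimensions — for Feather that is a single dimension holding the row count, hence the trailing `AddDim(1)` in the kernel above.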
21 changes: 16 additions & 5 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -30,8 +30,7 @@
 from tensorflow.compat.v2 import data
 from tensorflow.python.data.ops.dataset_ops import flat_structure
 from tensorflow.python.data.util import structure as structure_lib
-from tensorflow_io import _load_library
-arrow_ops = _load_library('_arrow_ops.so')
+from tensorflow_io.core.python.ops import core_ops
 
 if hasattr(tf, "nest"):
   from tensorflow import nest # pylint: disable=ungrouped-imports
@@ -183,7 +182,7 @@ def __init__(self,
         "auto" (size to number of records in Arrow record batch)
     """
     super(ArrowDataset, self).__init__(
-        partial(arrow_ops.arrow_dataset, serialized_batches),
+        partial(core_ops.arrow_dataset, serialized_batches),
         columns,
         output_types,
         output_shapes,
@@ -316,7 +315,7 @@ def __init__(self,
         dtype=dtypes.string,
         name="filenames")
     super(ArrowFeatherDataset, self).__init__(
-        partial(arrow_ops.arrow_feather_dataset, filenames),
+        partial(core_ops.arrow_feather_dataset, filenames),
         columns,
         output_types,
         output_shapes,
@@ -401,7 +400,7 @@ def __init__(self,
         dtype=dtypes.string,
         name="endpoints")
     super(ArrowStreamDataset, self).__init__(
-        partial(arrow_ops.arrow_stream_dataset, endpoints),
+        partial(core_ops.arrow_stream_dataset, endpoints),
         columns,
         output_types,
         output_shapes,
@@ -594,3 +593,15 @@ def gen_record_batches():
       batch_size=batch_size,
       batch_mode='keep_remainder',
       record_batch_iter_factory=gen_record_batches)
+
+def list_feather_columns(filename, **kwargs):
+  """List the column names and specs of a Feather file (eager mode only)."""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_feather_columns only supports eager mode")
+  memory = kwargs.get("memory", "")
+  columns, dtypes_, shapes = core_ops.list_feather_columns(
+      filename, memory=memory)
+  entries = zip(tf.unstack(columns), tf.unstack(dtypes_), tf.unstack(shapes))
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      shape.numpy(), dtype.numpy().decode(), column.numpy().decode()))
+               for (column, dtype, shape) in entries])
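
An end-to-end sketch of what this wrapper returns — a hypothetical round trip, with pandas/pyarrow used only to produce a small Feather file (column names and values are made up):

```python
import pandas as pd
import pyarrow.feather as feather
from tensorflow_io.arrow import list_feather_columns

# Write a tiny Feather file (hypothetical example data).
df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "label": [0, 1, 1]})
feather.write_feather(df, "example.feather")

specs = list_feather_columns("example.feather")
# Expected result (approximately):
# {"x": TensorSpec(shape=(3,), dtype=tf.float64, name="x"),
#  "label": TensorSpec(shape=(3,), dtype=tf.int64, name="label")}
```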
1 change: 1 addition & 0 deletions tensorflow_io/core/BUILD
@@ -125,6 +125,7 @@ cc_binary(
     linkshared = 1,
     deps = [
         ":core_ops",
+        "//tensorflow_io/arrow:arrow_ops",
         "//tensorflow_io/audio:audio_ops",
         "//tensorflow_io/avro:avro_ops",
         "//tensorflow_io/azure:azfs_ops",
2 changes: 1 addition & 1 deletion tensorflow_io/parquet/BUILD
@@ -16,7 +16,7 @@ cc_library(
     copts = tf_io_copts(),
     linkstatic = True,
     deps = [
+        "//tensorflow_io/arrow:arrow_ops",
-        "//tensorflow_io/core:dataset_ops",
         "@arrow",
     ],
 )
(2 more changed files not shown)
