diff --git a/tensorflow_io/arrow/BUILD b/tensorflow_io/arrow/BUILD
index 83836b613..bd5ed8ea0 100644
--- a/tensorflow_io/arrow/BUILD
+++ b/tensorflow_io/arrow/BUILD
@@ -7,10 +7,12 @@ load(
     "tf_io_copts",
 )
 
-cc_binary(
-    name = "python/ops/_arrow_ops.so",
+cc_library(
+    name = "arrow_ops",
     srcs = [
         "kernels/arrow_dataset_ops.cc",
+        "kernels/arrow_kernels.cc",
+        "kernels/arrow_kernels.h",
         "kernels/arrow_stream_client.h",
         "kernels/arrow_stream_client_unix.cc",
         "kernels/arrow_util.cc",
@@ -18,10 +20,9 @@ cc_binary(
         "ops/dataset_ops.cc",
     ],
     copts = tf_io_copts(),
-    linkshared = 1,
+    linkstatic = True,
     deps = [
+        "//tensorflow_io/core:dataset_ops",
         "@arrow",
-        "@local_config_tf//:libtensorflow_framework",
-        "@local_config_tf//:tf_header_lib",
     ],
 )
diff --git a/tensorflow_io/arrow/__init__.py b/tensorflow_io/arrow/__init__.py
index df6aa6b6d..0e36beae6 100644
--- a/tensorflow_io/arrow/__init__.py
+++ b/tensorflow_io/arrow/__init__.py
@@ -17,6 +17,7 @@
 @@ArrowDataset
 @@ArrowFeatherDataset
 @@ArrowStreamDataset
+@@list_feather_columns
 """
 
 from __future__ import absolute_import
@@ -26,6 +27,7 @@
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowFeatherDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowStreamDataset
+from tensorflow_io.arrow.python.ops.arrow_dataset_ops import list_feather_columns
 
 from tensorflow.python.util.all_util import remove_undocumented
 
@@ -33,6 +35,7 @@
     "ArrowDataset",
     "ArrowFeatherDataset",
     "ArrowStreamDataset",
+    "list_feather_columns",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
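For context before the kernel code below: the newly exported symbol returns a dict mapping column name to tf.TensorSpec and requires eager mode. A minimal usage sketch, assuming a hypothetical local file example.feather:

    import tensorflow_io.arrow as arrow_io

    # Maps each column name to a tf.TensorSpec carrying dtype and row count.
    specs = arrow_io.list_feather_columns("example.feather")
    for name, spec in specs.items():
        print(name, spec.dtype, spec.shape)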
diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.cc b/tensorflow_io/arrow/kernels/arrow_kernels.cc
new file mode 100644
index 000000000..46b3bbcc5
--- /dev/null
+++ b/tensorflow_io/arrow/kernels/arrow_kernels.cc
@@ -0,0 +1,159 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
+#include "arrow/io/api.h"
+#include "arrow/ipc/feather.h"
+#include "arrow/ipc/feather_generated.h"
+#include "arrow/buffer.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ListFeatherColumnsOp : public OpKernel {
+ public:
+  explicit ListFeatherColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(1);
+    const string& memory = memory_tensor.scalar<string>()();
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
+
+    // FEA1.....[metadata][uint32 metadata_length]FEA1
+    static constexpr const char* kFeatherMagicBytes = "FEA1";
+
+    size_t header_length = strlen(kFeatherMagicBytes);
+    size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);
+
+    string buffer;
+    buffer.resize(header_length > footer_length ? header_length : footer_length);
+
+    StringPiece result;
+
+    OP_REQUIRES_OK(context, file->Read(0, header_length, &result, &buffer[0]));
+    OP_REQUIRES(context, !memcmp(buffer.data(), kFeatherMagicBytes, header_length), errors::InvalidArgument("not a feather file"));
+
+    OP_REQUIRES_OK(context, file->Read(size - footer_length, footer_length, &result, &buffer[0]));
+    OP_REQUIRES(context, !memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)), errors::InvalidArgument("incomplete feather file"));
+
+    uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());
+
+    buffer.resize(metadata_length);
+
+    OP_REQUIRES_OK(context, file->Read(size - footer_length - metadata_length, metadata_length, &result, &buffer[0]));
+
+    const ::arrow::ipc::feather::fbs::CTable* table = ::arrow::ipc::feather::fbs::GetCTable(buffer.data());
+
+    OP_REQUIRES(context, (table->version() >= ::arrow::ipc::feather::kFeatherVersion), errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion));
+
+    std::vector<string> columns;
+    std::vector<string> dtypes;
+    std::vector<int64> counts;
+    columns.reserve(table->columns()->size());
+    dtypes.reserve(table->columns()->size());
+    counts.reserve(table->columns()->size());
+
+    for (int64 i = 0; i < table->columns()->size(); i++) {
+      DataType dtype = ::tensorflow::DataType::DT_INVALID;
+      switch (table->columns()->Get(i)->values()->type()) {
+      case ::arrow::ipc::feather::fbs::Type_BOOL:
+        dtype = ::tensorflow::DataType::DT_BOOL;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT8:
+        dtype = ::tensorflow::DataType::DT_INT8;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT16:
+        dtype = ::tensorflow::DataType::DT_INT16;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT32:
+        dtype = ::tensorflow::DataType::DT_INT32;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_INT64:
+        dtype = ::tensorflow::DataType::DT_INT64;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT8:
+        dtype = ::tensorflow::DataType::DT_UINT8;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT16:
+        dtype = ::tensorflow::DataType::DT_UINT16;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT32:
+        dtype = ::tensorflow::DataType::DT_UINT32;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UINT64:
+        dtype = ::tensorflow::DataType::DT_UINT64;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_FLOAT:
+        dtype = ::tensorflow::DataType::DT_FLOAT;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_DOUBLE:
+        dtype = ::tensorflow::DataType::DT_DOUBLE;
+        break;
+      case ::arrow::ipc::feather::fbs::Type_UTF8:
+      case ::arrow::ipc::feather::fbs::Type_BINARY:
+      case ::arrow::ipc::feather::fbs::Type_CATEGORY:
+      case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
+      case ::arrow::ipc::feather::fbs::Type_DATE:
+      case ::arrow::ipc::feather::fbs::Type_TIME:
+      // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
+      // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
+      default:
+        break;
+      }
+      columns.push_back(table->columns()->Get(i)->name()->str());
+      dtypes.push_back(::tensorflow::DataTypeString(dtype));
+      counts.push_back(table->num_rows());
+    }
+
+    TensorShape output_shape = filename_tensor.shape();
+    output_shape.AddDim(columns.size());
+
+    Tensor* columns_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
+    Tensor* dtypes_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));
+
+    output_shape.AddDim(1);
+
+    Tensor* shapes_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor));
+
+    for (size_t i = 0; i < columns.size(); i++) {
+      columns_tensor->flat<string>()(i) = columns[i];
+      dtypes_tensor->flat<string>()(i) = dtypes[i];
+      shapes_tensor->flat<int64>()(i) = counts[i];
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
+                        ListFeatherColumnsOp);
+
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
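The footer parsing above can be cross-checked in a few lines of Python. A hedged sketch of the same Feather v1 layout (magic, body, flatbuffer metadata, uint32 metadata length, trailing magic), assuming a hypothetical file on disk and a little-endian length, as the kernel's reinterpret_cast assumes in practice:

    import struct

    with open("example.feather", "rb") as f:
        data = f.read()

    assert data[:4] == b"FEA1", "not a feather file"
    assert data[-4:] == b"FEA1", "incomplete feather file"
    # The last 8 bytes are a uint32 metadata length followed by the magic.
    (metadata_length,) = struct.unpack("<I", data[-8:-4])
    metadata = data[-8 - metadata_length:-8]  # flatbuffer-encoded CTable
    print("metadata bytes:", metadata_length)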
", ::arrow::ipc::feather::kFeatherVersion)); + + std::vector columns; + std::vector dtypes; + std::vector counts; + columns.reserve(table->columns()->size()); + dtypes.reserve(table->columns()->size()); + counts.reserve(table->columns()->size()); + + for (int64 i = 0; i < table->columns()->size(); i++) { + DataType dtype = ::tensorflow::DataType::DT_INVALID; + switch (table->columns()->Get(i)->values()->type()) { + case ::arrow::ipc::feather::fbs::Type_BOOL: + dtype = ::tensorflow::DataType::DT_BOOL; + break; + case ::arrow::ipc::feather::fbs::Type_INT8: + dtype = ::tensorflow::DataType::DT_INT8; + break; + case ::arrow::ipc::feather::fbs::Type_INT16: + dtype = ::tensorflow::DataType::DT_INT16; + break; + case ::arrow::ipc::feather::fbs::Type_INT32: + dtype = ::tensorflow::DataType::DT_INT32; + break; + case ::arrow::ipc::feather::fbs::Type_INT64: + dtype = ::tensorflow::DataType::DT_INT64; + break; + case ::arrow::ipc::feather::fbs::Type_UINT8: + dtype = ::tensorflow::DataType::DT_UINT8; + break; + case ::arrow::ipc::feather::fbs::Type_UINT16: + dtype = ::tensorflow::DataType::DT_UINT16; + break; + case ::arrow::ipc::feather::fbs::Type_UINT32: + dtype = ::tensorflow::DataType::DT_UINT32; + break; + case ::arrow::ipc::feather::fbs::Type_UINT64: + dtype = ::tensorflow::DataType::DT_UINT64; + break; + case ::arrow::ipc::feather::fbs::Type_FLOAT: + dtype = ::tensorflow::DataType::DT_FLOAT; + break; + case ::arrow::ipc::feather::fbs::Type_DOUBLE: + dtype = ::tensorflow::DataType::DT_DOUBLE; + break; + case ::arrow::ipc::feather::fbs::Type_UTF8: + case ::arrow::ipc::feather::fbs::Type_BINARY: + case ::arrow::ipc::feather::fbs::Type_CATEGORY: + case ::arrow::ipc::feather::fbs::Type_TIMESTAMP: + case ::arrow::ipc::feather::fbs::Type_DATE: + case ::arrow::ipc::feather::fbs::Type_TIME: + // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8: + // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY: + default: + break; + } + columns.push_back(table->columns()->Get(i)->name()->str()); + dtypes.push_back(::tensorflow::DataTypeString(dtype)); + counts.push_back(table->num_rows()); + } + + TensorShape output_shape = filename_tensor.shape(); + output_shape.AddDim(columns.size()); + + Tensor* columns_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor)); + Tensor* dtypes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor)); + + output_shape.AddDim(1); + + Tensor* shapes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor)); + + for (size_t i = 0; i < columns.size(); i++) { + columns_tensor->flat()(i) = columns[i]; + dtypes_tensor->flat()(i) = dtypes[i]; + shapes_tensor->flat()(i) = counts[i]; + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU), + ListFeatherColumnsOp); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h new file mode 100644 index 000000000..6c729b29d --- /dev/null +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
diff --git a/tensorflow_io/arrow/ops/dataset_ops.cc b/tensorflow_io/arrow/ops/dataset_ops.cc
index 1139861b9..548e1288b 100644
--- a/tensorflow_io/arrow/ops/dataset_ops.cc
+++ b/tensorflow_io/arrow/ops/dataset_ops.cc
@@ -67,4 +67,18 @@ Creates a dataset that connects to a host serving Arrow RecordBatches in stream
 endpoints: One or more host addresses that are serving an Arrow stream.
 )doc");
+
+REGISTER_OP("ListFeatherColumns")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
+    .Output("shapes: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+       c->set_output(1, c->MakeShape({c->UnknownDim()}));
+       c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
+       return Status::OK();
+    });
+
 }  // namespace tensorflow
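The three outputs are aligned index-wise, with shapes carrying the extra rank declared by the shape function. The wrapper in arrow_dataset_ops.py below calls this op through core_ops; roughly, in eager mode:

    from tensorflow_io.core.python.ops import core_ops

    # memory="" means the file is read from the filesystem.
    columns, dtypes, shapes = core_ops.list_feather_columns(
        "example.feather", memory="")
    # columns: [N] string, dtypes: [N] string, shapes: [N, 1] int64 row counts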
)doc"); + +REGISTER_OP("ListFeatherColumns") + .Input("filename: string") + .Input("memory: string") + .Output("columns: string") + .Output("dtypes: string") + .Output("shapes: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py index f57bf48f7..f67cf73f4 100644 --- a/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py +++ b/tensorflow_io/arrow/python/ops/arrow_dataset_ops.py @@ -30,8 +30,7 @@ from tensorflow.compat.v2 import data from tensorflow.python.data.ops.dataset_ops import flat_structure from tensorflow.python.data.util import structure as structure_lib -from tensorflow_io import _load_library -arrow_ops = _load_library('_arrow_ops.so') +from tensorflow_io.core.python.ops import core_ops if hasattr(tf, "nest"): from tensorflow import nest # pylint: disable=ungrouped-imports @@ -183,7 +182,7 @@ def __init__(self, "auto" (size to number of records in Arrow record batch) """ super(ArrowDataset, self).__init__( - partial(arrow_ops.arrow_dataset, serialized_batches), + partial(core_ops.arrow_dataset, serialized_batches), columns, output_types, output_shapes, @@ -316,7 +315,7 @@ def __init__(self, dtype=dtypes.string, name="filenames") super(ArrowFeatherDataset, self).__init__( - partial(arrow_ops.arrow_feather_dataset, filenames), + partial(core_ops.arrow_feather_dataset, filenames), columns, output_types, output_shapes, @@ -401,7 +400,7 @@ def __init__(self, dtype=dtypes.string, name="endpoints") super(ArrowStreamDataset, self).__init__( - partial(arrow_ops.arrow_stream_dataset, endpoints), + partial(core_ops.arrow_stream_dataset, endpoints), columns, output_types, output_shapes, @@ -594,3 +593,15 @@ def gen_record_batches(): batch_size=batch_size, batch_mode='keep_remainder', record_batch_iter_factory=gen_record_batches) + +def list_feather_columns(filename, **kwargs): + """list_feather_columns""" + if not tf.executing_eagerly(): + raise NotImplementedError("list_feather_columns only support eager mode") + memory = kwargs.get("memory", "") + columns, dtypes_, shapes = core_ops.list_feather_columns( + filename, memory=memory) + entries = zip(tf.unstack(columns), tf.unstack(dtypes_), tf.unstack(shapes)) + return dict([(column.numpy().decode(), tf.TensorSpec( + shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for ( + column, dtype, shape) in entries]) diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 06a8205e2..1ff22fa58 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -125,6 +125,7 @@ cc_binary( linkshared = 1, deps = [ ":core_ops", + "//tensorflow_io/arrow:arrow_ops", "//tensorflow_io/audio:audio_ops", "//tensorflow_io/avro:avro_ops", "//tensorflow_io/azure:azfs_ops", diff --git a/tensorflow_io/parquet/BUILD b/tensorflow_io/parquet/BUILD index f6768043f..0ff2e4ac9 100644 --- a/tensorflow_io/parquet/BUILD +++ b/tensorflow_io/parquet/BUILD @@ -16,7 +16,7 @@ cc_library( copts = tf_io_copts(), linkstatic = True, deps = [ + "//tensorflow_io/arrow:arrow_ops", "//tensorflow_io/core:dataset_ops", - "@arrow", ], ) diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc index e6caf99a8..5e5da6560 100644 --- 
diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 06a8205e2..1ff22fa58 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -125,6 +125,7 @@ cc_binary(
     linkshared = 1,
     deps = [
         ":core_ops",
+        "//tensorflow_io/arrow:arrow_ops",
         "//tensorflow_io/audio:audio_ops",
         "//tensorflow_io/avro:avro_ops",
         "//tensorflow_io/azure:azfs_ops",
diff --git a/tensorflow_io/parquet/BUILD b/tensorflow_io/parquet/BUILD
index f6768043f..0ff2e4ac9 100644
--- a/tensorflow_io/parquet/BUILD
+++ b/tensorflow_io/parquet/BUILD
@@ -16,7 +16,7 @@ cc_library(
     copts = tf_io_copts(),
     linkstatic = True,
     deps = [
+        "//tensorflow_io/arrow:arrow_ops",
         "//tensorflow_io/core:dataset_ops",
-        "@arrow",
     ],
 )
diff --git a/tensorflow_io/parquet/kernels/parquet_kernels.cc b/tensorflow_io/parquet/kernels/parquet_kernels.cc
index e6caf99a8..5e5da6560 100644
--- a/tensorflow_io/parquet/kernels/parquet_kernels.cc
+++ b/tensorflow_io/parquet/kernels/parquet_kernels.cc
@@ -14,67 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
 #include "parquet/api/reader.h"
 
 namespace tensorflow {
 namespace data {
 namespace {
 
-class ParquetRandomAccessFile : public ::arrow::io::RandomAccessFile {
-public:
-  explicit ParquetRandomAccessFile(tensorflow::RandomAccessFile *file, int64 size)
-    : file_(file)
-    , size_(size) { }
-
-  ~ParquetRandomAccessFile() {}
-  arrow::Status Close() override {
-    return arrow::Status::OK();
-  }
-  arrow::Status Tell(int64_t* position) const override {
-    return arrow::Status::NotImplemented("Tell");
-  }
-  arrow::Status Seek(int64_t position) override {
-    return arrow::Status::NotImplemented("Seek");
-  }
-  arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
-    return arrow::Status::NotImplemented("Read (void*)");
-  }
-  arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
-    return arrow::Status::NotImplemented("Read (Buffer*)");
-  }
-  arrow::Status GetSize(int64_t* size) override {
-    *size = size_;
-    return arrow::Status::OK();
-  }
-  bool supports_zero_copy() const override {
-    return false;
-  }
-  arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override {
-    StringPiece result;
-    Status status = file_->Read(position, nbytes, &result, (char*)out);
-    if (!(status.ok() || errors::IsOutOfRange(status))) {
-      return arrow::Status::IOError(status.error_message());
-    }
-    *bytes_read = result.size();
-    return arrow::Status::OK();
-  }
-  arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
-    string buffer;
-    buffer.resize(nbytes);
-    StringPiece result;
-    Status status = file_->Read(position, nbytes, &result, (char*)(&buffer[0]));
-    if (!(status.ok() || errors::IsOutOfRange(status))) {
-      return arrow::Status::IOError(status.error_message());
-    }
-    buffer.resize(result.size());
-    return arrow::Buffer::FromString(buffer, out);
-  }
-private:
-  tensorflow::RandomAccessFile* file_;
-  int64 size_;
-};
-
 class ListParquetColumnsOp : public OpKernel {
  public:
   explicit ListParquetColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -92,7 +38,7 @@ class ListParquetColumnsOp : public OpKernel {
     uint64 size;
     OP_REQUIRES_OK(context, file->GetFileSize(&size));
 
-    std::shared_ptr<ParquetRandomAccessFile> parquet_file(new ParquetRandomAccessFile(file.get(), size));
+    std::shared_ptr<ArrowRandomAccessFile> parquet_file(new ArrowRandomAccessFile(file.get(), size));
     std::shared_ptr<::parquet::FileMetaData> metadata = ::parquet::ReadMetaData(parquet_file);
 
     std::vector<string> columns;
@@ -181,7 +127,7 @@ class ReadParquetOp : public OpKernel {
     uint64 size;
     OP_REQUIRES_OK(context, file->GetFileSize(&size));
 
-    std::shared_ptr<ParquetRandomAccessFile> parquet_file(new ParquetRandomAccessFile(file.get(), size));
+    std::shared_ptr<ArrowRandomAccessFile> parquet_file(new ArrowRandomAccessFile(file.get(), size));
     std::unique_ptr<::parquet::ParquetFileReader> parquet_reader = parquet::ParquetFileReader::Open(parquet_file);
     std::shared_ptr<::parquet::FileMetaData> file_metadata = parquet_reader->metadata();
     int column_index = 0;
diff --git a/tests/test_arrow_eager.py b/tests/test_arrow_eager.py
index 0879399e1..3c12ac76f 100644
--- a/tests/test_arrow_eager.py
+++ b/tests/test_arrow_eager.py
@@ -894,6 +894,40 @@ def test_unsupported_batch_mode(self):
                        truth_data.output_shapes,
                        batch_mode='doh')
 
+  def test_arrow_list_feather_columns(self):
+    """test_arrow_list_feather_columns"""
+    # Feather files currently do not support columns of list types
+    truth_data = TruthData(self.scalar_data, self.scalar_dtypes,
+                           self.scalar_shapes)
+
+    batch = self.make_record_batch(truth_data)
+    df = batch.to_pandas()
+
+    # Create a tempfile that is deleted after tests run
+    with tempfile.NamedTemporaryFile(delete=False) as f:
+      write_feather(df, f)
+
+    # test single file
+    # prefix with "file://" to exercise the scheme-based file system path
+    # (the same resolution used for e.g. s3, gcs, azfs, ignite)
+    columns = arrow_io.list_feather_columns("file://" + f.name)
+    for name, dtype in list(zip(batch.schema.names, batch.schema.types)):
+      assert columns[name].name == name
+      assert columns[name].dtype == dtype
+      assert columns[name].shape == [4]
+
+    # test memory
+    with open(f.name, 'rb') as ff:
+      memory = ff.read()
+    # when memory is provided, the filename does not matter:
+    columns = arrow_io.list_feather_columns("file:///non_exist",
+                                            memory=memory)
+    for name, dtype in list(zip(batch.schema.names, batch.schema.types)):
+      assert columns[name].name == name
+      assert columns[name].dtype == dtype
+      assert columns[name].shape == [4]
+
+    os.unlink(f.name)
+
 
 if __name__ == "__main__":
   test.main()