From 8af0a01fb42dc521099120c6348cc3002b4af122 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 4 Aug 2019 07:53:07 -0700 Subject: [PATCH] Add read_wav and list_wav_info support (#406) * Add read_wav and list_wav_info support This PR adds read_wav and list_wav_info support so that it is possible to read wav file into a tensor. Since WAV file is splittable, read_wav could take a start and count parameter so that only a slice of the wav file is read. This is also part of the effort to rework on Dataset to move to primitive ops. See See 382 and 366 for related discussions. Signed-off-by: Yong Tang * Rename start, count to start, stop Signed-off-by: Yong Tang --- tensorflow_io/audio/BUILD | 2 +- tensorflow_io/audio/__init__.py | 6 + tensorflow_io/audio/kernels/audio_input.cc | 166 ------------ tensorflow_io/audio/kernels/audio_kernels.cc | 261 +++++++++++++++++++ tensorflow_io/audio/ops/audio_ops.cc | 33 +-- tensorflow_io/audio/python/ops/audio_ops.py | 65 ++++- tensorflow_io/core/BUILD | 1 + tensorflow_io/core/kernels/stream.h | 71 +++++ tensorflow_io/core/python/ops/data_ops.py | 5 +- tests/test_audio_eager.py | 41 ++- 10 files changed, 438 insertions(+), 213 deletions(-) delete mode 100644 tensorflow_io/audio/kernels/audio_input.cc create mode 100644 tensorflow_io/audio/kernels/audio_kernels.cc create mode 100644 tensorflow_io/core/kernels/stream.h diff --git a/tensorflow_io/audio/BUILD b/tensorflow_io/audio/BUILD index 867bddd10..dc56be3fe 100644 --- a/tensorflow_io/audio/BUILD +++ b/tensorflow_io/audio/BUILD @@ -10,7 +10,7 @@ load( cc_library( name = "audio_ops", srcs = [ - "kernels/audio_input.cc", + "kernels/audio_kernels.cc", "ops/audio_ops.cc", ], copts = tf_io_copts(), diff --git a/tensorflow_io/audio/__init__.py b/tensorflow_io/audio/__init__.py index 5792023c0..b95edeedd 100644 --- a/tensorflow_io/audio/__init__.py +++ b/tensorflow_io/audio/__init__.py @@ -15,6 +15,8 @@ """Audio Dataset. @@WAVDataset +@@list_wav_info +@@read_wav """ from __future__ import absolute_import @@ -22,11 +24,15 @@ from __future__ import print_function from tensorflow_io.audio.python.ops.audio_ops import WAVDataset +from tensorflow_io.audio.python.ops.audio_ops import list_wav_info +from tensorflow_io.audio.python.ops.audio_ops import read_wav from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "WAVDataset", + "list_wav_info", + "read_wav", ] remove_undocumented(__name__) diff --git a/tensorflow_io/audio/kernels/audio_input.cc b/tensorflow_io/audio/kernels/audio_input.cc deleted file mode 100644 index 31f1553c0..000000000 --- a/tensorflow_io/audio/kernels/audio_input.cc +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "kernels/dataset_ops.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" - -namespace tensorflow { -namespace data { - -class WAVInputStream : public io::BufferedInputStream { -public: - explicit WAVInputStream(InputStreamInterface* input_stream) - : io::BufferedInputStream(input_stream, 256 * 1024) { - } - Status ReadRecord(int64 samples_to_read, int64* samples_read, string* value) { - while (data_start_ + data_size_ == data_offset_) { - if (Tell() == file_size_ + 8) { - *samples_read = 0; - return Status::OK(); - } - TF_RETURN_IF_ERROR(ReadNBytes(sizeof(struct DataHeader), &buffer_)); - struct DataHeader *p = (struct DataHeader *)buffer_.data(); - if (memcmp(p->mark, "data", 4) == 0 && p->size != 0) { - data_size_ = p->size; - data_start_ = Tell(); - data_offset_ = Tell(); - TF_RETURN_IF_ERROR(ReadNBytes(data_size_, &buffer_)); - continue; - } - TF_RETURN_IF_ERROR(SkipNBytes(p->size)); - } - if (samples_to_read <= 0 || (data_offset_ + samples_to_read * num_channels_ * sizeof(int16) >= data_start_ + data_size_)) { - samples_to_read = (data_start_ + data_size_ - data_offset_) / num_channels_ / sizeof(int16); - } - *value = buffer_.substr(data_offset_ - data_start_, samples_to_read * num_channels_ * sizeof(int16)); - data_offset_ += samples_to_read * num_channels_ * sizeof(int16); - *samples_read = samples_to_read; - return Status::OK(); - } - Status ReadHeader() { - string buffer; - TF_RETURN_IF_ERROR(ReadNBytes(sizeof(struct WAVHeader), &buffer)); - struct WAVHeader *header = (struct WAVHeader *)buffer.data(); - if (memcmp(header->riff, "RIFF", 4) != 0) { - return errors::InvalidArgument("WAV file must starts with `RIFF`"); - } - file_size_ = header->size; - if (memcmp(header->wave, "WAVE", 4) != 0) { - return errors::InvalidArgument("WAV file must contains riff type `WAVE`"); - } - if (memcmp(header->fmt, "fmt ", 4) != 0) { - return errors::InvalidArgument("WAV file must contains `fmt ` mark"); - } - int32 fmt_size_ = header->fmt_size; - if (fmt_size_ != 16 && fmt_size_ != 18) { - return errors::InvalidArgument("WAV file must have `fmt_size ` 16 or 18, received", fmt_size_); - } - int16 fmt_type_ = header->fmt_type; - if (fmt_type_ != 1) { - return errors::InvalidArgument("WAV file must have `fmt_type ` 1, received", fmt_type_); - } - num_channels_ = header->num_channels; - if (num_channels_ <= 0) { - return errors::InvalidArgument("WAV file have invalide channels: ", num_channels_); - } - int32 sample_rate_ = header->sample_rate; - int32 byte_rate_ = header->byte_rate; - int16 sample_alignment_ = header->sample_alignment; - int16 bit_depth_ = header->bit_depth; - if (bit_depth_ != 16) { - return errors::InvalidArgument("WAV file must contains 16 bits data"); - } - if (fmt_size_ == 18) { - TF_RETURN_IF_ERROR(SkipNBytes(2)); - } - do { - TF_RETURN_IF_ERROR(ReadNBytes(sizeof(struct DataHeader), &buffer)); - struct DataHeader *p = (struct DataHeader *)buffer.data(); - if (memcmp(p->mark, "data", 4) == 0) { - data_size_ = p->size; - data_start_ = Tell(); - data_offset_ = Tell(); - TF_RETURN_IF_ERROR(ReadNBytes(data_size_, &buffer_)); - return Status::OK(); - } - TF_RETURN_IF_ERROR(SkipNBytes(p->size)); - } while (Tell() < file_size_ + 8); - - return Status::OK(); - } - int64 Channel() { - return num_channels_; - } -private: - struct WAVHeader { - char riff[4]; // "RIFF" - int32 size; // Size after (file size - 8) - char wave[4]; // "WAVE" - char fmt[4]; // "fmt " - int32 fmt_size; // 16 for PCM - int16 fmt_type; // 1 for PCM. 3 for IEEE Float - int16 num_channels; - int32 sample_rate; - int32 byte_rate; // Number of bytes per second. - int16 sample_alignment; // num_channels * Bytes Per Sample - int16 bit_depth; // Number of bits per sample - }; - struct DataHeader { - char mark[4]; - int32 size; - }; - int64 num_channels_; - int64 file_size_; - int64 data_size_; - int64 data_start_; - int64 data_offset_; - string buffer_; -}; - -class WAVInput: public FileInput { - public: - Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { - if (state.get() == nullptr) { - state.reset(new WAVInputStream(s)); - TF_RETURN_IF_ERROR(state.get()->ReadHeader()); - } - string buffer; - TF_RETURN_IF_ERROR(state.get()->ReadRecord(record_to_read, record_read, &buffer)); - if (*record_read > 0) { - Tensor value_tensor(ctx->allocator({}), DT_INT16, {*record_read, state.get()->Channel()}); - memcpy(value_tensor.flat().data(), buffer.data(), (*record_read) * state.get()->Channel() * sizeof(int16)); - out_tensors->emplace_back(std::move(value_tensor)); - } - return Status::OK(); - } - Status FromStream(io::InputStreamInterface* s) override { - return Status::OK(); - } - void EncodeAttributes(VariantTensorData* data) const override { - } - bool DecodeAttributes(const VariantTensorData& data) override { - return true; - } - protected: -}; - -REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WAVInput, "tensorflow::data::WAVInput"); - -REGISTER_KERNEL_BUILDER(Name("WAVInput").Device(DEVICE_CPU), - FileInputOp); -REGISTER_KERNEL_BUILDER(Name("WAVDataset").Device(DEVICE_CPU), - FileInputDatasetOp); -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc new file mode 100644 index 000000000..ed5e750a1 --- /dev/null +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -0,0 +1,261 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/stream.h" + +namespace tensorflow { +namespace data { +namespace { + +struct WAVHeader { + char riff[4]; // "RIFF" + int32 size; // Size after (file size - 8) + char wave[4]; // "WAVE" + char fmt[4]; // "fmt " + int32 fmt_size; // 16 for PCM + int16 fmt_type; // 1 for PCM. 3 for IEEE Float + int16 num_channels; + int32 sample_rate; + int32 byte_rate; // Number of bytes per second. + int16 sample_alignment; // num_channels * Bytes Per Sample + int16 bit_depth; // Number of bits per sample +}; +struct DataHeader { + char mark[4]; + int32 size; +}; +Status ValidateWAVHeader(struct WAVHeader *header) { + if (memcmp(header->riff, "RIFF", 4) != 0) { + return errors::InvalidArgument("WAV file must starts with `RIFF`"); + } + if (memcmp(header->wave, "WAVE", 4) != 0) { + return errors::InvalidArgument("WAV file must contains riff type `WAVE`"); + } + if (memcmp(header->fmt, "fmt ", 4) != 0) { + return errors::InvalidArgument("WAV file must contains `fmt ` mark"); + } + if (header->fmt_size != 16 && header->fmt_size != 18) { + return errors::InvalidArgument("WAV file must have `fmt_size ` 16 or 18, received", header->fmt_size); + } + if (header->fmt_type != 1) { + return errors::InvalidArgument("WAV file must have `fmt_type ` 1, received", header->fmt_type); + } + if (header->num_channels <= 0) { + return errors::InvalidArgument("WAV file have invalide channels: ", header->num_channels); + } + return Status::OK(); +} + +class ListWAVInfoOp : public OpKernel { + public: + explicit ListWAVInfoOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string filename = filename_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(1); + const string& memory = memory_tensor.scalar()(); + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); + + StringPiece result; + struct WAVHeader header; + OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); + + OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); + if (header.size + 8 != size) { + // corrupted file? + } + int64 filesize = header.size + 8; + + int64 position = result.size(); + + if (header.fmt_size == 18) { + position += 2; + } + + int64 bytes = 0; + + do { + struct DataHeader head; + OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + position += result.size(); + if (memcmp(head.mark, "data", 4) == 0) { + bytes += head.size; + } + position += head.size; + } while (position < filesize); + + string dtype; + switch (header.bit_depth) { + case 8: + dtype = "int8"; + break; + case 16: + dtype = "int16"; + break; + case 24: + dtype = "int32"; + break; + default: + OP_REQUIRES(context, false, errors::InvalidArgument("unsupported bit_depth: ", header.bit_depth)); + } + // bytes = NumSamples * NumChannels * BitsPerSample/8 + int64 num_samples = bytes / header.num_channels / (header.bit_depth / 8); + + Tensor* dtype_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &dtype_tensor)); + Tensor* shape_tensor; + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({2}), &shape_tensor)); + Tensor* rate_tensor; + OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}), &rate_tensor)); + + dtype_tensor->scalar()() = std::move(dtype); + shape_tensor->flat()(0) = num_samples; + shape_tensor->flat()(1) = header.num_channels; + rate_tensor->scalar()() = header.sample_rate; + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +class ReadWAVOp : public OpKernel { + public: + explicit ReadWAVOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string& filename = filename_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(1); + const string& memory = memory_tensor.scalar()(); + + const Tensor& start_tensor = context->input(2); + const int64 start = start_tensor.scalar()(); + + const Tensor& stop_tensor = context->input(3); + const int64 stop = stop_tensor.scalar()(); + + std::unique_ptr file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size())); + uint64 size; + OP_REQUIRES_OK(context, file->GetFileSize(&size)); + + StringPiece result; + struct WAVHeader header; + OP_REQUIRES_OK(context, file->Read(0, sizeof(header), &result, (char *)(&header))); + + OP_REQUIRES_OK(context, ValidateWAVHeader(&header)); + if (header.size + 8 != size) { + // corrupted file? + } + int64 filesize = header.size + 8; + + int64 position = result.size(); + + if (header.fmt_size == 18) { + position += 2; + } + + int64 bytes = 0; + + do { + struct DataHeader head; + OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + position += result.size(); + if (memcmp(head.mark, "data", 4) == 0) { + bytes += head.size; + } + position += head.size; + } while (position < filesize); + + // bytes = NumSamples * NumChannels * BitsPerSample/8 + int64 num_samples = bytes / header.num_channels / (header.bit_depth / 8); + + int64 sample_start = start; + int64 sample_stop = stop; + if (sample_start > num_samples) { + sample_start = num_samples; + } + if (sample_stop < 0) { + sample_stop = num_samples; + } + if (sample_stop < sample_start) { + sample_stop = sample_start; + } + + + Tensor* output_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({sample_stop - sample_start, header.num_channels}), &output_tensor)); + + int64 read_offset = 0; + int64 start_bytes = sample_start * header.num_channels * (header.bit_depth / 8); + int64 final_bytes = sample_stop * header.num_channels * (header.bit_depth / 8); + + bytes = 0; + + position = sizeof(header) + ((header.fmt_size == 18) ? 2 : 0); + do { + struct DataHeader head; + OP_REQUIRES_OK(context, file->Read(position, sizeof(head), &result, (char *)(&head))); + position += result.size(); + if (memcmp(head.mark, "data", 4) == 0) { + // only read if start_bytes and final_bytes within range: + if (start_bytes < bytes + head.size && final_bytes > bytes) { + int64 read_start_bytes = (start_bytes < bytes) ? bytes : start_bytes; + int64 read_final_bytes = (final_bytes < bytes + head.size) ? final_bytes : (bytes + head.size); + string buffer; + buffer.resize(read_final_bytes - read_start_bytes); + OP_REQUIRES_OK(context, file->Read(position + read_start_bytes - bytes, (read_final_bytes - read_start_bytes), &result, &buffer[0])); + + switch (header.bit_depth) { + case 8: + memcpy((char *)(output_tensor->flat().data()) + read_offset, &buffer[0], (read_final_bytes - read_start_bytes)); + read_offset += (read_final_bytes - read_start_bytes); + break; + case 16: + memcpy((char *)(output_tensor->flat().data()) + read_offset, &buffer[0], (read_final_bytes - read_start_bytes)); + read_offset += (read_final_bytes - read_start_bytes); + break; + default: + OP_REQUIRES(context, false, errors::InvalidArgument("unsupported bit_depth: ", header.bit_depth)); + } + } + bytes += head.size; + } + position += head.size; + } while (position < filesize); + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ListWAVInfo").Device(DEVICE_CPU), + ListWAVInfoOp); +REGISTER_KERNEL_BUILDER(Name("ReadWAV").Device(DEVICE_CPU), + ReadWAVOp); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc index 81f1bcff6..2ca5c66a4 100644 --- a/tensorflow_io/audio/ops/audio_ops.cc +++ b/tensorflow_io/audio/ops/audio_ops.cc @@ -19,27 +19,28 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("WAVInput") - .Input("source: string") - .Output("handle: variant") - .Attr("filters: list(string) = []") - .Attr("columns: list(string) = []") - .Attr("schema: string = ''") +REGISTER_OP("ListWAVInfo") + .Input("filename: string") + .Input("memory: string") + .Output("dtype: string") + .Output("shape: int64") + .Output("rate: int32") .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(0, c->MakeShape({})); + c->set_output(1, c->MakeShape({2})); + c->set_output(2, c->MakeShape({})); return Status::OK(); }); -REGISTER_OP("WAVDataset") - .Input("input: T") - .Input("batch: int64") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Attr("T: {string, variant} = DT_VARIANT") - .SetIsStateful() +REGISTER_OP("ReadWAV") + .Input("filename: string") + .Input("memory: string") + .Input("start: int64") + .Input("stop: int64") + .Attr("dtype: type") + .Output("output: dtype") .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({})); + c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim()})); return Status::OK(); }); diff --git a/tensorflow_io/audio/python/ops/audio_ops.py b/tensorflow_io/audio/python/ops/audio_ops.py index fff7bc525..374ef0400 100644 --- a/tensorflow_io/audio/python/ops/audio_ops.py +++ b/tensorflow_io/audio/python/ops/audio_ops.py @@ -18,27 +18,66 @@ from __future__ import print_function import tensorflow as tf -from tensorflow_io.core.python.ops import data_ops as data_ops -from tensorflow_io.core.python.ops import core_ops as audio_ops +from tensorflow_io.core.python.ops import data_ops +from tensorflow_io.core.python.ops import core_ops -class WAVDataset(data_ops.Dataset): +def list_wav_info(filename, **kwargs): + """list_wav_info""" + if not tf.executing_eagerly(): + raise NotImplementedError("list_wav_info only support eager mode") + memory = kwargs.get("memory", "") + dtype, shape, rate = core_ops.list_wav_info( + filename, memory=memory) + return tf.TensorSpec(shape.numpy(), dtype.numpy().decode()), rate + +def read_wav(filename, spec, **kwargs): + """read_wav""" + memory = kwargs.get("memory", "") + start = kwargs.get("start", 0) + stop = kwargs.get("stop", None) + if stop is None and spec.shape[0] is not None: + stop = spec.shape[0] - start + if stop is None: + stop = -1 + return core_ops.read_wav( + filename, memory=memory, + start=start, stop=stop, dtype=spec.dtype) + +class WAVDataset(data_ops.BaseDataset): """A WAV Dataset""" - def __init__(self, filename, batch=None): + def __init__(self, filename, **kwargs): """Create a WAVDataset. Args: - filename: A `tf.string` tensor containing one or more filenames. + filename: A string containing filename. """ - batch = 0 if batch is None else batch - dtypes = [tf.int16] - shapes = [ - tf.TensorShape([None])] if batch == 0 else [ - tf.TensorShape([None, None])] + if not tf.executing_eagerly(): + start = kwargs.get("start") + stop = kwargs.get("stop") + dtype = kwargs.get("dtype") + shape = kwargs.get("shape") + else: + spec, _ = list_wav_info(filename) + start = 0 + stop = spec.shape[0] + dtype = spec.dtype + shape = tf.TensorShape( + [dim if i != 0 else None for i, dim in enumerate( + spec.shape.as_list())]) + + # capacity is the rough count for each chunk in dataset + capacity = kwargs.get("capacity", 65536) + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + dataset = data_ops.BaseDataset.from_tensor_slices( + (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) + ).map(lambda start, stop: core_ops.read_wav( + filename, memory="", start=start, stop=stop, dtype=dtype)) + self._dataset = dataset + super(WAVDataset, self).__init__( - audio_ops.wav_dataset, - audio_ops.wav_input(filename), - batch, dtypes, shapes) + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access class AudioDataset(data_ops.Dataset): """A Audio File Dataset that reads the audio file.""" diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 25c62ec34..07bdc0b36 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -28,6 +28,7 @@ cc_library( name = "dataset_ops", srcs = [ "kernels/dataset_ops.h", + "kernels/stream.h", ], copts = tf_io_copts(), includes = [ diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h new file mode 100644 index 000000000..e812babf4 --- /dev/null +++ b/tensorflow_io/core/kernels/stream.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace data { + +// Note: This SizedRandomAccessFile should only lives within Compute() +// of the kernel as buffer could be released by outside. +class SizedRandomAccessFile : public tensorflow::RandomAccessFile { + public: + SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size) + : file_(nullptr) + , size_(optional_memory_size) + , buff_((const char *)(optional_memory_buff)) + , size_status_(Status::OK()) { + if (size_ == 0) { + size_status_ = env->GetFileSize(filename, &size_); + if (size_status_.ok()) { + size_status_ = env->NewRandomAccessFile(filename, &file_); + } + } + } + + virtual ~SizedRandomAccessFile() {} + Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (file_.get() != nullptr) { + return file_.get()->Read(offset, n, result, scratch); + } + size_t bytes_to_read = 0; + if (offset < size_) { + bytes_to_read = (offset + n < size_) ? n : (size_ - offset); + } + if (bytes_to_read > 0) { + memcpy(scratch, &buff_[offset], bytes_to_read); + } + *result = StringPiece(scratch, bytes_to_read); + if (bytes_to_read < n) { + return errors::OutOfRange("EOF reached"); + } + return Status::OK(); + } + Status GetFileSize(uint64* size) { + if (size_status_.ok()) { + *size = size_; + } + return size_status_; + } + private: + std::unique_ptr file_; + uint64 size_; + const char *buff_; + Status size_status_; +}; + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py index bbd3b5d8b..40107037d 100644 --- a/tensorflow_io/core/python/ops/data_ops.py +++ b/tensorflow_io/core/python/ops/data_ops.py @@ -58,9 +58,8 @@ def _apply_fn(dataset): class BaseDataset(tf.compat.v2.data.Dataset): """A Base Dataset""" - def __init__(self, variant, batch, dtypes, shapes): + def __init__(self, variant, dtypes, shapes): """Create a Base Dataset.""" - self._batch = 0 if batch is None else batch self._dtypes = dtypes self._shapes = shapes super(BaseDataset, self).__init__(variant) @@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes): self._data_input, self._batch, output_types=self._dtypes, - output_shapes=self._shapes), self._batch, self._dtypes, self._shapes) + output_shapes=self._shapes), self._dtypes, self._shapes) diff --git a/tests/test_audio_eager.py b/tests/test_audio_eager.py index a5f2a326a..aaf465df1 100644 --- a/tests/test_audio_eager.py +++ b/tests/test_audio_eager.py @@ -36,17 +36,30 @@ def test_audio_dataset(): f = lambda x: float(x) / (1 << 15) - audio_dataset = audio_io.WAVDataset([audio_path]) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v.numpy()) - i += 1 - assert i == 5760 - - audio_dataset = audio_io.WAVDataset([audio_path], batch=2) - i = 0 - for v in audio_dataset: - assert audio_v.audio[i].numpy() == f(v[0].numpy()) - assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) - i += 2 - assert i == 5760 + for capacity in [10, 100, 500]: + audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( + tf.data.experimental.unbatch()).map(tf.squeeze) + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v.numpy()) + i += 1 + assert i == 5760 + + for capacity in [10, 100, 500]: + audio_dataset = audio_io.WAVDataset(audio_path, capacity=capacity).apply( + tf.data.experimental.unbatch()).batch(2).map(tf.squeeze) + i = 0 + for v in audio_dataset: + assert audio_v.audio[i].numpy() == f(v[0].numpy()) + assert audio_v.audio[i + 1].numpy() == f(v[1].numpy()) + i += 2 + assert i == 5760 + + spec, rate = audio_io.list_wav_info(audio_path) + assert spec.dtype == tf.int16 + assert spec.shape == [5760, 1] + assert rate.numpy() == audio_v.sample_rate.numpy() + + samples = audio_io.read_wav(audio_path, spec) + assert samples.dtype == tf.int16 + assert samples.shape == [5760, 1]