
[WIP] Add FileDataset to read whole content of file into tf.data pipeline #366

Closed
wants to merge 5 commits

Conversation

yongtang
Member

When we started the project, we assumed data read through the
tf.data pipeline would always be record-like. That is, each file
would contain multiple records. This is the case in many situations,
such as text or video, where the natural boundaries within the file
are newlines or individual frames.

But for many image formats such as WebP, JPEG, etc., each file
is already a single image and there is no need to partition it further.
Moreover, people still prefer a decode_xxx call in many situations.

This PR adds a FileDataset which takes files (and potentially compressed files)
and feeds the whole file content as a string into tf.data.

FileDataset is not a fit where files are really large (e.g., 100GB of text records).
However, it is good enough for typical image files.

As an example, this PR also converts WebPDataset to use FileDataset + decode_webp.
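
As a rough sketch of the intended usage (my illustration, not the actual diff; the import path and exact location of decode_webp are assumptions):

import tensorflow_io as tfio  # hypothetical import alias

files = ["sample1.webp", "sample2.webp"]
dataset = tfio.FileDataset(files)              # each element is the whole file content as a string
dataset = dataset.map(tfio.image.decode_webp)  # decode each WebP payload into an image tensor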

Note: a TIFF file could contain multiple images with different shapes, so it is not a fit for FileDataset either.

Signed-off-by: Yong Tang yong.tang.github@outlook.com


@feihugis left a comment


@yongtang This PR looks great! I just added some minor comments.

@@ -0,0 +1,92 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

2018 -> 2019

// First try to find out the size of the file, depending on if size is available, we will set the chunk to read.
uint64 file_size = 0;
if (sized_stream->GetFileSize(&file_size) == Status::OK()) {
chunk_size = file_size;

Would it be helpful to add a warning log here when file_size is big?

} else {
string buffer;
int64 total_size = 0;
for (size_t i = 0; i < entries.size(); i++) {

nit: If file_size is known, will total_size be equal to file_size? Could we calculate total_size while reading the file:

int64 total_size = 0;
while (status.ok()) {
  string buffer;
  status = s->ReadNBytes(chunk_size, &buffer);
  if (status.ok() || errors::IsOutOfRange(status)) {
    total_size += buffer.size();  // take the size before the buffer is moved
    entries.emplace_back(std::move(buffer));
  }
}

total_size += entries[i].size();
}
buffer.reserve(total_size);
for (size_t i = 0; i < entries.size(); i++) {

++i may be better.

bool DecodeAttributes(const VariantTensorData& data) override {
return true;
}
protected:

protected: is empty. We can delete it here.

@@ -0,0 +1,46 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

2018 -> 2019

@@ -60,3 +61,18 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
self._batch,
output_types=self._dtypes,
output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)

class FileDataset(BaseDataset):
"""A FileDataset that read file content as string"""

nit: read -> reads

@@ -21,39 +21,9 @@
from tensorflow import dtypes
from tensorflow.compat.v1 import data
from tensorflow_io import _load_library
from tensorflow_io.core.python.ops import data_ops as data_ops

nit: could we just use from tensorflow_io.core.python.ops import data_ops without as data_ops?

@damienpontifex
Contributor

How is this different from using

ds = tf.data.Dataset.list_files('file/pattern/*')
ds = ds.map(tf.io.read_file)

or

ds = tf.data.Dataset.from_tensor_slices(['file/path']).map(tf.io.read_file)

@yongtang
Member Author

yongtang commented Jul 24, 2019

@damienpontifex That is a good question. When the project was started, it was a collection of several C++ implementations of tf.data.Dataset (Kafka, Parquet, Ignite, Kinesis) and the IGFS file system that used to be in TensorFlow's core repo (and were spun off due to the deprecation of tf.contrib).

Some of them are file formats; some are not file formats at all but really just streaming IO processing (e.g., Kafka/Kinesis).

Because of that legacy, we used to favor C++ implementations of tf.data.Dataset over high-level Python implementations. We thought that could be advantageous, as the dataset could be part of the TensorFlow graph (save and load) and should perform well without the limitations of py_func in TF.

As we move forward, there are several limitations:

  1. C++ is not exactly a language that attracts lots of contributors.
  2. There is a lot of repeated code in C++, as we are essentially using C++ to implement a TF subgraph for every format. Combined with 1), this makes contribution somewhat prohibitive.

We made some efforts to streamline the C++ code so that contributing could be easier. For example, if you look into some recent Dataset additions such as Pcap (PR #303), you probably noticed that you don't even need to understand TensorFlow's graph concept in C++. You only need to implement a ReadRecord method to extract individual records.

That helps a little. But then we realized that some of the early tf.data.Dataset additions are not really that great, due to the repetitive C++ implementation. For example, as you can see from this PR, we used to implement both WebPDataset and decode_webp just to accommodate two different pipelines: one for the tf.data.Dataset pipeline, another for situations where the user already has the data in memory.

For this reason, I was hoping this PR could at least remove the need to have both WebPDataset and decode_webp in C++.

@yongtang
Member Author

yongtang commented Jul 24, 2019

Now back to the FileDataset discussion. I think

ds = tf.data.Dataset.from_tensor_slices(['file/path']).map(tf.io.read_file)

is a good point.

I am in favor of reducing the repetitive C++ code (I think @terrytangyuan agreed even before this) and replacing it with:

  1. A very small set of C++ code providing the basic functions as building blocks.
  2. High-level Python code wired up on top to add more functionality.

For example, image files could be very straightforward, so let's handle all image files this way (sketched right after this list):

  1. Read the whole file into memory.
  2. Call decode_image to process the data.
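
A minimal sketch of those two steps with stock TensorFlow ops (my illustration, not part of this PR):

import tensorflow as tf

dataset = tf.data.Dataset.list_files("images/*.jpg")
dataset = dataset.map(tf.io.read_file)     # step 1: whole file content into memory
dataset = dataset.map(tf.io.decode_image)  # step 2: decode the in-memory bytes into an image tensor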

That is also why I am in favor of DICOM's decode_dicom_image/decode_dicom_data approach rather than reworking it as a C++ DICOMDataset.

That was the initial motivation for FileDataset. (Apparently I forgot we could use .map(tf.io.read_file), as there are so many things to keep track of.)

@yongtang
Member Author

@damienpontifex I am thinking of closing this PR, or just replacing WebPDataset with .map(tf.io.read_file).map(decode_webp) in this PR.

But there are also some additional items we may have to be concerned with.

@yongtang
Member Author

@damienpontifex For one, even though most image files are straightforward, as they will most likely fit in memory, there are special files that may contain multiple image frames.

For example, a GIF may have multiple frames of the same size, and a TIFF may have multiple images of even different sizes. How do we come up with a decode_image that could capture those?

(Note the TIFF issue is also one of the reasons PR #186 has been in work-in-progress state for so long without any progress.)

@yongtang
Member Author

yongtang commented Jul 24, 2019

@damienpontifex Another issue is compression (note that compression is different from an archive: a compression format such as gzip contains one component, while an archive format such as zip contains multiple components).

In TensorFlow's core repo, you may have noticed that file formats such as TFRecord and CSV actually repeat the C++ implementation of compression manually in both cases. At one point I wanted to avoid the repeated C++ compression code, so I added a C++ template to handle it. For example, in pcap you can add another compression by just adding a filter such as gz, tar, etc. (without adding additional C++ code). The current C++ template also allows iterating through archive files such as zip (more on that later).

Now with .map(tf.io.read_file).map(decode_webp), compression is not supported anymore.

As we discussed, I think we could add a generic decode_compression op to extract one layer of compression at a time, so that we could have

.map(tf.io.read_file).map(decode_compression).map(decode_compression).map(decode_webp)

to peel off any number of compression layers.
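
For the single-layer zlib/gzip case, TensorFlow already ships tf.io.decode_compressed, which could play the role of such a generic op; a rough sketch (my illustration; decode_webp stands for the op discussed in this PR):

import tensorflow as tf

dataset = tf.data.Dataset.list_files("images/*.webp.gz")
dataset = dataset.map(tf.io.read_file)  # raw, still-compressed bytes
dataset = dataset.map(lambda raw: tf.io.decode_compressed(raw, compression_type="GZIP"))  # strip one layer
# dataset = dataset.map(decode_webp)    # then decode the image payload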

But archives are still a problem, as we could not use a decode_zip op because we don't know how many output components there are. It just cannot be a straightforward op like decode_compression.

We could have a ZipDataset to emit the multiple components. But how are we going to chain it together with map? tf.data.Dataset does not seem to allow iterating over nested levels of tf.data.Dataset itself.

@yongtang
Member Author

@damienpontifex Finally, there are also files that just do not fit into memory. For example, a video file should naturally be considered a collection of individual image frames; we cannot just read GBs of video into memory and decode it in one batch. There are also other formats like HDF5 that could be big, where the user wants to extract one record at a time (instead of reading GBs of data into memory in one batch). At the moment we just write the C++ implementation manually to process them, one format after another. I haven't come up with a better way to handle them yet.

@yongtang
Member Author

@damienpontifex Thanks for bringing this issue up. Let me move this PR to work-in-progress for now.

yongtang changed the title from "Add FileDataset to read whole content of file into tf.data pipeline" to "[WIP] Add FileDataset to read whole content of file into tf.data pipeline" on Jul 24, 2019
@damienpontifex
Contributor

Thank you @yongtang! That was a very comprehensive insight. I love the thinking process and I'd love to be able to work on this type of thing more/full time.

Another spanner in all of this is the different language bindings (well outside this discussion). I know C++ helps in this domain, so having things down at a lower level opens up opportunities to use them across languages, as I'm also seeing in the Swift for TensorFlow repo now.

Agreed that there is a lot to keep track of; I had used tf.io.read_file, so I was interested to know how this was going to be different. If it's a lot for those of us who dig into the core code of these repos, it's also a lot for an individual just using the libraries. It's always confusing when there are many ways to achieve the same thing, and someone new to the domain always gets stuck evaluating which one is right for them.

@yongtang
Member Author

Thanks @damienpontifex.

Yes, one advantage of C++ primitive ops is that they can be part of the TF graph. And as long as an op is part of the TF graph, it is always possible to bind it to another language such as Go, R, or Swift. (I have lots of interest in Swift too, just haven't had a chance to play with it yet.)

That is one of the reasons we still prefer C++ for the low level (but we should restrict C++ to low-level primitive ops only, not composite ops, to make development easier for more contributors).

On another note about language bindings and mixing: it does not even necessarily have to be one way (other language => C++). In fact, it is possible to build code from other C-family languages into C++ as well.

For example, in tensorflow-io there is already a Prometheus dataset (PR #265) which actually builds from Prometheus' Go SDK and is embedded into the C++ tf.data.Dataset as a shared library. I was hoping it could serve as an example to draw more interest from Go developers and the Docker/Kubernetes community to the tensorflow-io project. (Documentation is missing but the feature has been there; I should probably spend more time on documentation.)

Swift could be similar, either through cdecl or callbacks, I believe.

Anyway, I think there is plenty of room to make this an interesting project.

@feihugis
Member

feihugis commented Jul 24, 2019

@yongtang @damienpontifex These discussions are great and inspiring!

Finally, there are also files that just do not fit into memory. For example, a video file should naturally be considered a collection of individual image frames; we cannot just read GBs of video into memory and decode it in one batch. There are also other formats like HDF5 that could be big, where the user wants to extract one record at a time (instead of reading GBs of data into memory in one batch). At the moment we just write the C++ implementation manually to process them, one format after another. I haven't come up with a better way to handle them yet.

If I remember correctly, HDF5 and TIFF support chunk/tile structures. The dataset op could read and decompress the chunks one by one instead of reading the whole big file.

To process different data formats, I'm wondering if something like the following would work:

  1. The dataset op for a specific data format reads the header info/metadata (e.g., byte offset, length, compression method for each chunk/file/image/video frame) instead of the data content. This metadata can be organized in a shared metadata class;

  2. Based on this metadata, we could use the basic filesystem API to read the byte stream;

  3. Then the in-memory byte stream can be decompressed into tensors based on the metadata. Here, the decompression method can be shared across different formats;

struct Metadata {
  string path;
  int64 byte_offset;
  int64 byte_length;
  CompressionType compression_type;
};

class Reader {
 public:
  // Env is passed in explicitly so the reader can open the file.
  Reader(Env* env, const Metadata& metadata) {
    TF_CHECK_OK(env->NewRandomAccessFile(metadata.path, &file_));
  }

  Status Read(const Metadata& metadata, Tensor* tensor) {
    string read;
    RandomAccessInputStream in(file_.get());
    TF_RETURN_IF_ERROR(in.Seek(metadata.byte_offset));
    TF_RETURN_IF_ERROR(in.ReadNBytes(metadata.byte_length, &read));
    switch (metadata.compression_type) {
      case CompressionType::ZLIB:
        TF_RETURN_IF_ERROR(Decompression::Decode_compression_zlib(read, tensor));
        break;
      case CompressionType::GZIP:
        TF_RETURN_IF_ERROR(Decompression::Decode_compression_gzip(read, tensor));
        break;
      ...
    }
    return Status::OK();
  }

 private:
  std::unique_ptr<RandomAccessFile> file_;
};

class HDF5DatasetOp {

  ...

  Status ReadMetadata(const string& path) {
    // populate metadata_vec_
    return Status::OK();
  }

  ...

  // This method can be shared between different data formats.
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    if (index_ == metadata_vec_.size()) {
      *end_of_sequence = true;
      return Status::OK();
    }
    Tensor output;
    TF_RETURN_IF_ERROR(reader_.Read(metadata_vec_[index_], &output));
    out_tensors->emplace_back(std::move(output));
    ++index_;
    return Status::OK();
  }

 private:
  std::vector<Metadata> metadata_vec_;
  size_t index_ = 0;
  Reader reader_;
};

This approach supports different data formats stored in different storage systems and can also share the code between different data formats (e.g. compression and decompression, reading and writing).

@yongtang
Member Author

@feihugis Yes, metadata definitely works for archives (e.g., zip). In fact, to go one step further, we could even store the buffer in the metadata itself, so that the next round of map() could go directly to the metadata.

The metadata could be wrapped into a variant so that it can hold any content. @suphoff suggested using a variant to serialize the meta information at one point. But the last time I tried, I think I was too restricted to a single input tensor and didn't find a good way to work around multiple input types for one tensor (filename string vs. variant, disk file vs. memory buffer).

If we allow an optional metadata input, we could do it this way:

struct Metadata {
  string path;
  int64 byte_offset;
  int64 byte_length;
  CompressionType compression_type;
  void* buffer;  // optional
};

dataset = tf.data.Dataset.from_tensor_slices(['sample.zip'])

dataset = dataset.map(unarchive)  # unarchive: filename => (filename, metadata). NOTE: might yield multiple (filename, metadata) pairs:

(sample.zip, sample1.data buffer)
(sample.zip, sample2.data buffer)
(sample.zip, sample3.data buffer)

dataset = dataset.map(read_object)  # read_object: (filename, metadata) => buffer. NOTE: if the metadata already has the buffer in memory, just pass the buffer to the output

(sample1.data buffer)
(sample2.data buffer)
(sample3.data buffer)

If there is no middle step, then read_object will just read the whole file:

dataset = tf.data.Dataset.from_tensor_slices(['sample.zip'])

dataset = dataset.map(read_object)  # read_object: (filename, null) => buffer. No metadata, so just read the whole file

(sample.zip buffer)

@yongtang
Member Author

@feihugis @damienpontifex Overall I think we have several things to fix:

  1. We want a way to iterate through a file that generates multiple outputs (e.g., a zip file).
  2. We also want to read everything into memory, unless it does not fit into CPU or GPU memory.
  3. We want to reuse C++ code for primitive ops so that we do not repeat everything in C++ again and again.

Also adding @BryanCutler here about batching. Previously, we thought we would reuse the same batch concept to serve two purposes:

  1. Batching is a way to improve performance for record-based files such as Parquet, Feather, etc., as we really don't want to "read one integer or one float32 at a time". This is a caching issue.
  2. We also want to feed the dataset directly to tf.keras, so the batch concept in tf.keras is reused (the number of samples fed to the neural network). This is a model parameter.

But those two batch concepts are different. @BryanCutler I am wondering if it makes sense to split them out. In other words, for 1) we should "read as many records as possible, as long as they fit in memory", and for 2) we should do a rebatch() to adjust the batch_size that is fed to tf.keras?
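
A sketch of that split with existing tf.data ops (my illustration; the sizes are made up and I assume Dataset.unbatch()/batch() are available):

import tensorflow as tf

# Purpose 1: the reader emits large, IO-friendly batches so we never read one value at a time.
io_batched = tf.data.Dataset.range(100000).batch(4096)

# Purpose 2: re-batch down to the model's batch_size before feeding tf.keras.
keras_ready = io_batched.unbatch().batch(32)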

@feihugis
Member

@yongtang Yeah, a zip file is a good example. Do you think we can apply the metadata-based approach to other data formats? It may be challenging for some formats, as the libraries may not expose this kind of API to extract metadata from the files.

@feihugis
Member

  1. We also want to feed the dataset directly to tf.keras, so the batch concept in tf.keras is reused (the number of samples fed to the neural network). This is a model parameter.

If we want to feed datasets directly to tf.keras for model training, do we also need to consider data shuffling?

@yongtang
Member Author

@BryanCutler There is already quite a lot of discussion in this thread, but I want to add one more point.

Previously, when we create a dataset, we create it in one shot with every column. For example, for CsvDataset, we have to specify EVERY column beforehand, before it even runs.

That requirement existed because we used to implement against 1.13/1.14, where the TF graph is static, so we needed to know EVERYTHING beforehand in order to run the graph (or pass it to tf.keras).

Now as we move to TF 2.0, knowing everything beforehand is not necessary anymore. We could just parse the file and find the metadata in eager mode, then build the dataset to pass to tf.keras.

In this situation, I am wondering if it makes sense to focus on "building a dataset one column at a time"? Something like:

# read the parquet file and find all columns,
# then build one dataset per column and zip them together
dataset = tf.data.Dataset.zip(tuple(ParquetDataset(filename, column) for column in columns))
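
For concreteness, a runnable stand-in with plain tensors (ParquetDataset above is the tensorflow-io dataset; the columns here are made-up placeholders):

import tensorflow as tf

col_a = tf.data.Dataset.from_tensor_slices([1, 2, 3])        # stand-in for ParquetDataset(filename, "a")
col_b = tf.data.Dataset.from_tensor_slices([4.0, 5.0, 6.0])  # stand-in for ParquetDataset(filename, "b")

dataset = tf.data.Dataset.zip((col_a, col_b))                # one dataset of (a, b) row tuples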

The reason is that when we try to build a dataset from ALL columns, we assume all columns have the same number of records. But this is not the case for many files such as HDF5 or Feather (if I understand correctly).

I noticed this issue when I tried to play with pandas. I just realized that in our current implementation it is hard to handle NA or null fields.

But with TF 2.0 and eager execution, we actually have more freedom to handle those situations. For example, we could do additional bit masking before merging different columns.

From that standpoint, maybe it makes more sense to focus on building a dataset with only one column at a time?

@yongtang
Member Author

@feihugis The good thing is that TensorFlow's core repo does not support a lot of formats anyway (only TFRecord, FixedLengthRecord, and CsvDataset). tensorflow-io actually supports a lot more formats than TensorFlow's core repo, so I think that may not be a big issue for us.

@yongtang
Member Author

@feihugis Shuffling could be kind of tricky, as the tf.data.Dataset pipeline is naturally not an exact fit for shuffling.

A tf.data.Dataset is essentially an iterable, so a full shuffle is very expensive unless the dataset is small.

The provided in-batch shuffle avoids the expensive operation, but it is not a true shuffle.

I think a true shuffle may have to happen outside the tf.data.Dataset pipeline, though I haven't found a good way of achieving it other than reading everything into one tensor and then shuffling.
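
A minimal illustration of why the built-in shuffle is only a windowed shuffle (my example, not from this PR): Dataset.shuffle randomizes within a fixed-size buffer, so it is only a true shuffle when buffer_size covers the whole dataset.

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
windowed = dataset.shuffle(buffer_size=3)   # mixes elements only within a 3-element buffer
full = dataset.shuffle(buffer_size=10)      # buffer spans the dataset: a true (but memory-hungry) shuffle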

@yongtang
Member Author

To reiterate the problems here, we actually have two issues:

  1. We want a good way to process archives like zip files, without additional C++ implementations for individual formats, and ideally just reading into memory like tf.io.read_file.
  2. We also need to think about how to process large columnar data such as Parquet, HDF5, Feather, etc., where each column's data may not fit into memory. We might be forced to split the data into chunks, but the chunk size should be as large as possible.

Overall, tf.data.Dataset has been designed in a record-oriented way from the beginning. This works well with image objects, where each object is relatively large. It does not perform well when each record is small (e.g., one integer or one float), which is very common in columnar data (e.g., HDF5, Parquet, etc.).

With the upcoming TF 2.0 in mind, we really should read everything into memory (unless forced to chunk). /cc @BryanCutler for the column dataset, also @CaptainDuke as it may impact HDF5.

I think I have some ideas and will come up with a PoC soon.

@BryanCutler
Member

@yongtang there is a lot of good discussion here and I agree with both your points on batching and handling NULL values differently for eager mode. I'll make separate issues where we can continue to discuss.

yongtang added a commit to yongtang/io that referenced this pull request Jul 28, 2019
This PR is part of the effort to enhance performance and
ease of use for the tf.data pipeline, as discussed in tensorflow#382 and tensorflow#366.

Previously, HDF5Dataset was relatively manual and the user
had to find out the datasets (columns) in the HDF5 file.

In this PR, the idea is to allow the user to call list_hdf5_datasets
to probe the shape, dtype, and name of the datasets within an HDF5 file.
A subsequent call to read_hdf5 will bring the content into a shaped Tensor
so that it can be used later in TensorFlow.

read_hdf5 has the option to specify a slice (or a sub-block) of the
dataset. This should open up the possibility in the future of binding
a class to an HDF5 file by implementing `__len__` and `__getitem__`.

With the list_hdf5_datasets and read_hdf5 ops, it is also possible to
ease HDF5Dataset usage in eager mode. In eager mode, HDF5Dataset
could just call list_hdf5_datasets to find out all the necessary
information, then call read_hdf5 in pieces to maintain the `batch_size`
to be fed into tf.keras.

The limitation is in graph mode, where the user still has to
specify almost everything (dtype, shape, name) for HDF5Dataset to work.

This PR has not changed the HDF5Dataset implementation to use the
list_hdf5_datasets and read_hdf5 ops, but this could easily be done;
see 384 for similar changes.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
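
A rough usage sketch based on the description above (the import path and exact signatures are assumptions; the actual ops live in tensorflow-io and may differ):

import tensorflow_io as tfio  # hypothetical import alias

# Probe the datasets (columns) inside the HDF5 file in eager mode.
datasets_info = tfio.list_hdf5_datasets("sample.h5")  # names, dtypes, and shapes

# Read one dataset (optionally just a slice, via start/count) into a shaped Tensor.
tensor = tfio.read_hdf5("sample.h5", dataset="/data", start=0, count=128)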
yongtang added a commit to yongtang/io that referenced this pull request Jul 28, 2019
yongtang added a commit to yongtang/io that referenced this pull request Jul 31, 2019
yongtang added a commit to yongtang/io that referenced this pull request Aug 4, 2019
yongtang added a commit that referenced this pull request Aug 4, 2019
* Rework on HDF5: add list_hdf5_datasets and read_hdf5 ops

* Process default value of count and start

* Support HDF5Dataset in graph mode

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@yongtang
Member Author

With the upcoming PR #437, I think this PR is not needed anymore.

yongtang closed this Aug 23, 2019
i-ony pushed a commit to i-ony/io that referenced this pull request Feb 8, 2021