Discuss Batch Standards in TFIO with Keras #382

Open
BryanCutler opened this issue Jul 26, 2019 · 10 comments

Comments

@BryanCutler
Member

Following the discussion in #366, batching can serve different purposes, and optimizing for each is not always done the same way.

Previously, we thought we would reuse the same batch concept to serve two purposes:

  1. Batch as a way to improve performance for record-based files such as parquet, feather, etc., since we really don't want to "read one integer or one float32 at a time". This is a caching issue.
  2. Batch as what is fed directly to tf.keras, reusing the tf.keras batch concept (the number of samples pushed through the neural network per step). This is a model parameter.

But those two batch concepts are different. @BryanCutler I am wondering if it makes sense to split them out. In other words, for 1) we should "read as many records as possible, as long as they fit in memory", and for 2) we should do a rebatch() to adjust the batch_size that is fed to tf.keras?

@yongtang
Member

To add additional information: tf.data actually has a tf.data.Dataset.concatenate API, which effectively solves the issue of multiple files:

a = Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
a.concatenate(b)  # ==> [ 1, 2, 3, 4, 5, 6, 7 ]

I think we only need to focus on the single-file case (multiple files could be processed at a higher level with concatenate). After the Datasets have been concatenated, batching could be reapplied:

a.concatenate(b).batch(batch_size, drop_remainder)

# or

a.batch(batch_size, drop_remainder).concatenate(b.batch(batch_size, drop_remainder))

I do think rebatch could be a useful feature that is not available in tf.data.Dataset. rebatch could be emulated with:

d.unbatch().batch(batch_size)  # <= d.rebatch(batch_size)

but I suspect d.unbatch().batch(batch_size) is not going to perform well.
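
For reference, a minimal sketch of that emulation against the stock tf.data API (assuming TF 2.x, where Dataset.unbatch() and Dataset.batch() are available):

import tensorflow as tf

def rebatch(dataset, batch_size, drop_remainder=False):
    # Flatten whatever cache-sized batches the reader produced,
    # then regroup into the batch size the model expects.
    return dataset.unbatch().batch(batch_size, drop_remainder=drop_remainder)

# Records were read in large "cache" batches of 1000,
# but the model should see batches of 32.
cached = tf.data.Dataset.range(10000).batch(1000)
for batch in rebatch(cached, 32).take(2):
    print(batch.shape)  # (32,)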

yongtang added a commit to yongtang/io that referenced this issue Jul 28, 2019
This PR is part of the effort to enhance performance and
ease of use of the tf.data pipeline, as was discussed in tensorflow#382 and tensorflow#366.

Previously, HDF5Dataset was relatively manual and the user
had to find out the datasets (columns) in the hdf5 file.

In this PR, the idea is to allow the user to use list_hdf5_datasets
to probe the shape, dtype, and name of the datasets within HDF5.
A subsequent call to read_hdf5 will bring the content into a shaped Tensor
so that it can be used later in TensorFlow.

read_hdf5 has the option to specify a slice (or a subblock) of the
dataset. This should open up the possibility in the future of binding
a class to an hdf5 file by implementing `__len__` and `__getitem__`.

With the list_hdf5_datasets and read_hdf5 ops, it is also possible to
ease the use of HDF5Dataset in eager mode. In eager mode, HDF5Dataset
could just call list_hdf5_datasets to find out all the necessary
information, then call read_hdf5 in pieces to maintain the `batch_size`
to be fed into tf.keras.

The limitation is in graph mode, where the user still has to
specify almost everything (dtype, shape, name) for HDF5Dataset to work.

This PR has not changed the HDF5Dataset implementation to use the
list_hdf5_datasets and read_hdf5 ops, but this could be easily done;
see #384 for similar changes.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
yongtang added a commit to yongtang/io that referenced this issue Jul 28, 2019
@yongtang
Member

Added PR #393 to introduce rebatch(batch_size) ~= unbatch().batch(batch_size).

@terrytangyuan @BryanCutler Here is my thinking on the tf.keras batch vs. cache batch issue:

  1. We want to max out the cache batch (unless it does not fit into memory).
  2. The tf.keras batch_size could always be applied as the last step before feeding into tf.keras (so that the cached batches in the middle are not split into small chunks of a few float32 numbers).

So overall I think we could:

  1. Not expose a batch_size in any Dataset API; internally the Dataset is always batched (up to the cache size).
  2. Provide a rebatch(batch_size) that the user calls before passing the dataset to tf.keras, as in the sketch below.
# Note:
#     - each individual record is [h, w, c]
#     - the total records would fill [n, h, w, c]
#     - our Dataset shape will be [k, h, w, c] where k <= n (k = cache size)
class Dataset:
    def __init__(self):
        # use a large cache of size k
        ...

dataset = Dataset().map(...).filter(...)  # shape is [None, h, w, c]

model.fit(dataset.rebatch(batch_size))
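
A runnable sketch of the same flow, using a synthetic in-memory dataset as a stand-in for a file-backed reader and emulating the proposed rebatch() with unbatch().batch() (the names here are illustrative, not the actual tfio API):

import tensorflow as tf

# Stand-in for a reader that batches internally up to its cache size.
features = tf.random.normal([1024, 8])
labels = tf.random.uniform([1024], maxval=2, dtype=tf.int32)
cached = tf.data.Dataset.from_tensor_slices((features, labels)).batch(256)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# The model-facing batch_size is applied last, just before fit().
model.fit(cached.unbatch().batch(32), epochs=1)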

yongtang added a commit to yongtang/io that referenced this issue Jul 31, 2019
@BryanCutler
Member Author

This sounds like a good paradigm to me. Do you plan some way for the user to set the cache batch size, or is it determined by the Dataset?

For the Arrow datasets, I'll have to think a bit more about the best way to handle this. Currently, records are read as a batch, which is effectively the "cache batch", just not in the form of tensors. If this outputs the entire batch as tensors, then it would be a second copy of all the batched data, which might not be good.

Still, it might be useful to cache the batch as tensors and release the original record batch. Then the user could call a rebatch(), like above, in the call to the keras model.fit().
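
A rough sketch of that idea, assuming pyarrow is available (the record batch below is synthetic, and this conversion path is only one possibility):

import pyarrow as pa
import tensorflow as tf

# A small Arrow record batch standing in for one "cache batch".
batch = pa.RecordBatch.from_arrays(
    [pa.array([1.0, 2.0, 3.0, 4.0], type=pa.float32()),
     pa.array([0, 1, 0, 1], type=pa.int32())],
    names=["x", "y"])

# Copy each column into a Tensor once, then release the Arrow buffers.
tensors = tuple(tf.constant(col.to_numpy()) for col in batch.columns)
del batch

# The tensor copy can now be rebatched freely before model.fit().
dataset = tf.data.Dataset.from_tensor_slices(tensors).batch(2)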

@yongtang
Member

@BryanCutler For the internal cache size I am thinking "try to fine tune automatically" unless the user overrides it (through **kwargs). In one PR for Parquet https://github.com/tensorflow/io/pull/384/files I am using a capacity to fine tune it. It is currently hardcoded as 65536, but I plan to add more fine tuning later.

The capacity could be overridden in **kwargs.
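
A minimal sketch of that override pattern (the class and keyword names are illustrative, not the actual signature in PR #384):

class ParquetLikeDataset:
    DEFAULT_CAPACITY = 65536  # current hardcoded default; could be auto-tuned later

    def __init__(self, filename, **kwargs):
        # A user-supplied capacity wins; otherwise fall back to the default.
        self.capacity = kwargs.get("capacity", self.DEFAULT_CAPACITY)
        self.filename = filename

# e.g. ParquetLikeDataset("train.parquet", capacity=1 << 20)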

@yongtang
Member

@BryanCutler I haven't spent much time looking into Arrow yet, though it looks like there are two different types: the feather file format and the Arrow streaming format.

The file format probably fits the Tensor case we discussed here, as in theory we could easily distribute a list of cached tensors to different nodes and improve performance by concurrently running ops on the list of chunked (cached) tensors.

The streaming format, depending on whether it is replayable or not, may not fit the file-format handling.

@yongtang
Member

@BryanCutler @terrytangyuan Thinking again, I think the problem probably comes down to whether tf.data should be closer to the source (reading a file format into TensorFlow's graph) or closer to the destination (as tf.keras input).

The overall tf.data pipeline works well with tf.keras (closer to the destination). But it actually has many limitations when it is used closer to the source, as it lacks many fundamental operations that are required for feature engineering.

This became very obvious when I played with the pandas API in PR #356. I could not achieve many operations with tf.data, as tf.data is just an iterable. But I could easily implement anything with plain Tensors and ops. (I am still actively working on #356, but I do need tools beyond tf.data.)

That is actually why I want to rework quite a few file formats (hdf5, parquet, Avro, text in #399, #392, #384): I would like to be able to read a file into a Tensor so that I can just apply additional ops for feature engineering. (The PRs have been done in a way that allows the user to read data into both tf.data and a Tensor with the same code base.)

On the other hand, there are cases where the Dataset is closer to the destination, meaning data is read from the data file and fed directly into tf.keras. In those situations I think a batch that matches tf.keras makes much more sense.

@BryanCutler For the batch in Arrow I tend to think we could default to one way but allow the user to override it to optimize for the other (depending on whether the usage is closer to the source or the destination).

@yongtang
Member

yongtang commented Aug 1, 2019

@BryanCutler Overall my experience reworking some of the formats (hdf5, parquet, Avro) is that, if a file is splittable by nature, then we can just write primitive ops to read the file. The primitive ops can be used in a normal graph to read data into a Tensor (which many powerful operations can then work on directly), or they can be pieced together into a tf.data pipeline where memory is limited (TBs of data vs. memory). See the sketch below.

If a file is not splittable by nature (e.g., a PCAP file, which is just a concatenation of variable-length packets), then it probably fits a C++ implementation of Dataset.

But even when a file is not splittable, we should still support reading the whole file into a Tensor in addition to a C++ dataset implementation. The benefit is that manipulating a Tensor is very easy, while manipulating tf.data is quite limited.
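
To illustrate the two uses of the same primitive op, here is a hedged sketch; read_chunk() stands in for an op such as read_hdf5 or a Parquet reader (its name and signature are assumptions, and an in-memory array substitutes for real file access):

import tensorflow as tf

_FAKE_FILE = tf.range(100, dtype=tf.float32)  # pretend these records live on disk

def read_chunk(start, count):
    # Returns a Tensor holding records [start, start + count).
    return _FAKE_FILE[start:start + count]

# 1) Closer to the source: materialize a Tensor and use any TensorFlow op on it.
whole = read_chunk(0, 100)
print(tf.reduce_mean(whole))

# 2) Closer to the destination: piece the same op into a memory-bounded
#    tf.data pipeline by reading in ranges.
chunk_size = 16
starts = tf.data.Dataset.range(0, 100, chunk_size)
dataset = starts.map(lambda s: read_chunk(s, chunk_size))
for chunk in dataset:
    print(chunk.shape)  # (16,), with a smaller remainder chunk at the end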

@yongtang
Member

yongtang commented Aug 2, 2019

@BryanCutler @terrytangyuan Some additional discussion about the ways to process input formats. It seems we are mostly dealing with the following:

  1. Strict streaming input, where the input is just a concatenation of a variable number of records (no way to seek or locate):
    • Apache Arrow Stream (@BryanCutler to confirm?)
    • Kafka/PubSub streams
    • Compression without metadata (gzip)
    • JSON (not including line-separated json/ndjson)
    • Network capture PCAP files
    • Streaming audio and video input (e.g., RTP)
    • Some audio and video file formats (non-seekable)
  2. Splittable files, where the input can be read in chunks (not necessarily random chunks, just chunks):
    • Text files (splittable by searching for separators within a small range)
    • Fixed-length records (record locations are determined by file offset), e.g., NumPy npy and npz (non-pickle)
    • Parquet files (splittable through RowGroups)
    • Avro files (splittable through sync positions)
    • HDF5 files (sort of, due to memory layout)
    • ndjson (similar to text)
    • csv (similar to text)
    • Archives with a metadata index (e.g., zip)
    • Some audio and video files (e.g., WAV, which can be seekable)
    • Apache Arrow Batch (@BryanCutler to confirm?)
  3. Whole file in memory, due to normal use or by design:
    • Most "normal" image files: JPEG/PNG/WEBP
    • Special image files such as TIFF (consisting of multiple frames, but we rarely see TIFF files too big to fit into memory)
    • Feather files (fit in memory by design, as there is only one record batch)

I came up with this list because, during the feather PR #404, I noticed that the feather file format was designed to support zero-copy and expects everything to be in memory (at least per column). I still expect future Arrow Feather format versions may support splittable chunks (multiple record batches).

We discussed the limitations of tf.data.Dataset being an iterable (no __len__ or __getitem__). Fitting every format into tf.data.Dataset actually forces the user to lose information when reading files that have natural indexing (__len__ or __getitem__).

Given the above scenarios, I am thinking we could do the following:

  1. Implement a C++ tf.data.Dataset for strict streaming input.
  2. Implement C++ custom ops with Tensor output for splittable and in-memory inputs, so that the user can use the rich Tensor manipulation functions.
  3. Implement a Python tf.data.Dataset for splittable and in-memory inputs.
    • Splittable files could be wired up to tf.data.Dataset in ranges (see several pending PRs).
    • In-memory files will always need to read everything into memory (unless zero-copy is in place with memory-mapped files).

Any comments or suggestions?

@BryanCutler
Member Author

Thanks @yongtang for the very detailed list! For Arrow, there are 2 memory formats: Arrow Stream and Arrow File. Both support chunking as record batches. Arrow Stream is a strict stream. Arrow File is random access, but does not necessarily need to be read entirely into memory. Feather files use the Arrow File format on disk, and the file can be chunked as well. I think reading the file one chunk at a time is just not currently implemented, but it could be in the future.

@BryanCutler
Member Author

The above proposal sounds pretty good to me. I would just want to be careful that for splittable (or chunked) files we keep a code path available that keeps memory usage at a minimum, while also supporting reading everything into memory if the user wants.

yongtang added a commit to yongtang/io that referenced this issue Aug 4, 2019
yongtang added a commit that referenced this issue Aug 4, 2019
* Rework on HDF5: add list_hdf5_datasets and read_hdf5 ops

* Process default value of count and start

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Support HDF5Dataset in graph mode

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
i-ony pushed a commit to i-ony/io that referenced this issue Feb 8, 2021

* Rework on HDF5: add list_hdf5_datasets and read_hdf5 ops