
Add rebatch method for Dataset #393

Merged: 6 commits into tensorflow:master on Jul 31, 2019
Conversation

@yongtang (Member) commented Jul 29, 2019

This PR adds a rebatch method for Dataset where

```
dataset.apply(rebatch(n)) = dataset.unbatch().batch(n)
```

The motivation for rebatch is that there are situations where we read the data in
big batches but then want to adjust the batch size to fit different scenarios.

This is part of #382.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
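For illustration, a minimal sketch of the equivalence above, written with stock tf.data ops only (it assumes TF 2.x eager execution; the dataset and batch sizes here are arbitrary choices, not from the PR):

```python
import tensorflow as tf

# Data arrives in "big" batches of 32 from the reader.
dataset = tf.data.Dataset.range(100).batch(32)

# The behavior dataset.apply(rebatch(5)) is meant to match, expressed with
# stock tf.data ops: flatten the big batches, then regroup into batches of 5.
rebatched = dataset.unbatch().batch(5)

for batch in rebatched.take(3):
    print(batch.numpy())  # [0 1 2 3 4], [5 6 7 8 9], [10 11 12 13 14]
```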
@yongtang (Member Author)

/cc @feihugis to take a look.

@feihugis (Member) left a comment

@yongtang This PR looks great and will improve the performance of rebatching! Just left a few minor comments here.

There is another RebatchDatasetOp in TensorFlow, which uses grappler to update the batch size, but it only allows rebatching by batch_size/num_workers for the distributed scenario. I think these two dataset ops are different, so we should rename the op to avoid potential confusion.

```
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
// Always set the first dim as None unless batch_mode is specified.
for (const auto& input_shape : input_shapes) {
```
Member:

Do we need to consider the case with unknown rank, like here?

Member Author:

@feihugis Done. PR updated.
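For context on the "first dim as None" comment in the hunk above, a small sketch (stock tf.data only, assuming TF 2.x) of how the static batch dimension behaves when a shorter remainder batch is possible versus when the remainder is dropped:

```python
import tensorflow as tf

# Remainder kept: the final batch of range(10).batch(4) has only 2 elements,
# so the static batch dimension is unknown (None).
print(tf.data.Dataset.range(10).batch(4).element_spec)
# TensorSpec(shape=(None,), dtype=tf.int64, name=None)

# Remainder dropped: every batch has exactly 4 elements, so the batch
# dimension can be static.
print(tf.data.Dataset.range(10).batch(4, drop_remainder=True).element_spec)
# TensorSpec(shape=(4,), dtype=tf.int64, name=None)
```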

```
int64 chunk_to_read = (current_batch_size_ - current_index_) < (dataset()->batch_size_ - chunk_read) ? (current_batch_size_ - current_index_) : (dataset()->batch_size_ - chunk_read);
for (int i = 0; i < tensors_.size(); ++i) {
  // TODO: concurrent copy?
  for (int64 r = 0; r < chunk_to_read; r++) {
```
Member:

r++ -> ++r

Member Author:

Done.

```
@@ -53,6 +54,13 @@ def test_text_input():
i += 1
assert i == len(lines)

rebatch_dataset = text_dataset.apply(core_io.rebatch(5))
```
@feihugis (Member) commented Jul 29, 2019:

More cases can be tested: new_batch_size > cur_batch_size, new_batch_size == cur_batch_size, new_batch_size < cur_batch_size.

Member Author:

Done. Additional tests added.
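A rough sketch of the three cases listed above, expressed with stock tf.data ops rather than the rebatch op itself (the batch_sizes helper below is hypothetical and just collects batch sizes; assumes TF 2.x eager execution):

```python
import tensorflow as tf

def batch_sizes(dataset):
    # Hypothetical helper: collect the leading-dimension size of every batch.
    return [int(tf.shape(batch)[0]) for batch in dataset]

base = tf.data.Dataset.range(12).batch(4)  # current batch size: 4

# new_batch_size > cur_batch_size
assert batch_sizes(base.unbatch().batch(6)) == [6, 6]
# new_batch_size == cur_batch_size
assert batch_sizes(base.unbatch().batch(4)) == [4, 4, 4]
# new_batch_size < cur_batch_size
assert batch_sizes(base.unbatch().batch(3)) == [3, 3, 3, 3]
```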


```
namespace tensorflow {

REGISTER_OP("RebatchDataset")
```
Member:

Do we need a drop_remainder input, which would align with BatchDataset?

Member Author:

The batch_mode input could take a string to specify the batch mode:

  • keep: leave the remainder as is.
  • drop: drop the remainder.
  • pad: pad the remainder.
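A sketch of what keep and drop mean in terms of stock tf.data behavior (pad has no one-line stock equivalent, since it fills the final batch up to the full batch size inside the kernel); the dataset below is an arbitrary example, not from the PR:

```python
import tensorflow as tf

# 10 scalar elements regrouped into batches of 4 leave a remainder of 2.
ds = tf.data.Dataset.range(10)

# "keep": emit the short final batch as-is -> batch sizes 4, 4, 2
for b in ds.batch(4):
    print(b.numpy())  # [0 1 2 3], [4 5 6 7], [8 9]

# "drop": discard the short final batch -> batch sizes 4, 4
for b in ds.batch(4, drop_remainder=True):
    print(b.numpy())  # [0 1 2 3], [4 5 6 7]
```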

```
}
// Finally, resize if needed
if (chunk_read > 0) {
  if (chunk_read < dataset()->batch_size_) {
```
Member:

If I understand correctly, here we assume the remainder needs to be kept. Maybe we can add a comment about that assumption here. Also, if we add a drop_remainder input, users can decide whether to keep the remainder.

Member Author:

Updated. keep, drop, and pad modes have been added.

```
}
}
if (out_tensors->size() != tensors_.size()) {
return errors::InvalidArgument("number tensors should match previous one, ", tensors_.size(), " vs. ", out_tensors->size());
```
Member:

Do we have a sanity check for C++ style? This line exceeds the 80-character limit.

Member Author:

In TensorFlow, at one point the C++ style was enforced with clang-format. The issue was that different versions of clang-format produce different styles, so it is really not easy to figure out which one is the right one. TensorFlow later dropped the C++ style check.

I think we could leave the C++ style check alone until we find a clang-format version that stabilizes.

Member:

Got it. Thanks!

@yongtang (Member Author)

@feihugis Thanks for the review. The batch_mode input takes keep, drop, and pad modes to decide what to do when a remainder surfaces.

Also, the name of the C++ class has been changed to AdjustBatchDataset.

For the Python function name, I really think "rebatch" makes plenty of sense, so I will leave it as is. If this dataset op is added to the TensorFlow core repo in the future, we could rethink the name.

@feihugis (Member) left a comment

Thanks @yongtang! LGTM. Left one minor comment.

```
errors::InvalidArgument("Batch size must be greater than zero."));

string batch_mode = "";
OP_REQUIRES_OK(ctx,
```
Member:

minor: do we need to check if the input batch_mode is valid?
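For illustration, the suggested check could look something like the following at the Python wrapper level; the wrapper shown here is hypothetical and is not the PR's actual code (the real op does its work in the C++ kernel):

```python
import tensorflow as tf

_VALID_BATCH_MODES = ("keep", "drop", "pad")

def rebatch(batch_size, batch_mode="keep"):
    # Hypothetical wrapper sketching the suggested validity checks.
    if batch_size <= 0:
        raise ValueError("batch_size must be greater than zero")
    if batch_mode not in _VALID_BATCH_MODES:
        raise ValueError(
            "batch_mode must be one of %r, got %r"
            % (_VALID_BATCH_MODES, batch_mode))

    def _apply_fn(dataset):
        # Stand-in behavior using stock tf.data ops; "pad" is omitted here
        # and falls back to "keep" in this sketch.
        return dataset.unbatch().batch(
            batch_size, drop_remainder=(batch_mode == "drop"))

    return _apply_fn

# Usage: dataset.apply(rebatch(5, batch_mode="keep"))
```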

```
*end_of_sequence = true;
return Status::OK();
}
// otherwise "pad" means keep the size
```
Member:

Just a reminder that pad is not implemented yet.

@yongtang merged commit 06b38f8 into tensorflow:master on Jul 31, 2019
@yongtang deleted the rebatch branch on Jul 31, 2019 00:18
i-ony pushed a commit to i-ony/io that referenced this pull request on Feb 8, 2021:

* Add rebatch method for Dataset
* Add additional tests, also add batch_mode = "keep", "drop", "pad" modes
* Rename RebatchDataset to AdjustBatchDataset
* Add additional processing in case shape is unknown
* Address review comments
* Fix failed tests