Add rebatch method for Dataset #393
Conversation
This PR adds a rebatch method for Dataset where

```
dataset.apply(rebatch(n)) = dataset.unbatch().batch(n)
```

The motivation for rebatch is that there are situations where we read the data in big batches but then want to adjust the batch size to fit different scenarios.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
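The identity above can be illustrated with a small, self-contained Python sketch, using plain lists to stand in for a tf.data pipeline (this `rebatch` helper is hypothetical, not the actual dataset op):

```python
def rebatch(batches, n):
    """Sketch of the rebatch semantics: unbatch() followed by batch(n)."""
    flat = [elem for batch in batches for elem in batch]    # unbatch
    return [flat[i:i + n] for i in range(0, len(flat), n)]  # batch(n)

# Two batches of 4 regrouped into batches of 3; the remainder is kept.
print(rebatch([[1, 2, 3, 4], [5, 6, 7, 8]], 3))
# → [[1, 2, 3], [4, 5, 6], [7, 8]]
```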
/cc @feihugis to take a look.
@yongtang This PR looks great and will improve the performance of rebatching! Just left a few minor comments here.
There is another RebatchDatasetOp in TensorFlow, which uses grappler to update the batch size, but it only allows rebatching by batch_size/num_workers for the distributed scenario. I think these two dataset ops are different, so we had better rename the op to avoid potential confusion.
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
// Always set the first dim as None unless batch_mode is specified.
for (const auto& input_shape : input_shapes) {
Do we need to consider the case with unknown rank like here?
@feihugis Done. PR updated.
int64 chunk_to_read = (current_batch_size_ - current_index_) < (dataset()->batch_size_ - chunk_read) ? (current_batch_size_ - current_index_) : (dataset()->batch_size_ - chunk_read);
for (int i = 0; i < tensors_.size(); ++i) {
  // TODO: concurrent copy?
  for (int64 r = 0; r < chunk_to_read; r++) {
r++ -> ++r
Done.
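The ternary in the snippet above is just a minimum: a sketch of the same bookkeeping in Python, where the parameter names mirror the C++ members (my reading of their meaning is an assumption):

```python
def chunk_to_read(current_batch_size, current_index, batch_size, chunk_read):
    # Read whichever is smaller: what remains in the current input batch,
    # or what is still missing from the output batch being assembled.
    return min(current_batch_size - current_index,
               batch_size - chunk_read)
```

For example, with 3 elements left in a 10-element input batch and an empty 5-element output batch, only 3 elements are copied this round.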
tests/test_text_eager.py
Outdated
@@ -53,6 +54,13 @@ def test_text_input():
  i += 1
  assert i == len(lines)

  rebatch_dataset = text_dataset.apply(core_io.rebatch(5))
More cases can be tested: new_batch_size > cur_batch_size, new_batch_size == cur_batch_size, new_batch_size < cur_batch_size.
Done. Additional tests added.
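The three cases the reviewer asked for can be sketched with the same list-based model of rebatching (a hypothetical stand-in for the real dataset op, not the op itself):

```python
def rebatch(batches, n):
    # unbatch, then regroup into batches of n
    flat = [elem for batch in batches for elem in batch]
    return [flat[i:i + n] for i in range(0, len(flat), n)]

batches = [[1, 2, 3], [4, 5, 6]]                        # current batch size: 3
assert rebatch(batches, 6) == [[1, 2, 3, 4, 5, 6]]      # new > cur
assert rebatch(batches, 3) == [[1, 2, 3], [4, 5, 6]]    # new == cur
assert rebatch(batches, 2) == [[1, 2], [3, 4], [5, 6]]  # new < cur
```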
tensorflow_io/core/ops/core_ops.cc
Outdated
namespace tensorflow {

REGISTER_OP("RebatchDataset")
Do we need a drop_remainder input, which would be aligned with BatchDataset?
The batch_mode input could take a string to specify the batch mode:
- keep: leave the remainder as is.
- drop: drop the remainder.
- pad: pad the remainder.
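A minimal sketch of how the three modes might treat a short final batch, in pure Python (the pad_value parameter is an assumption for illustration; the real op would pad with type-appropriate values):

```python
def finalize_remainder(remainder, batch_size, mode, pad_value=0):
    """Apply batch_mode to a final batch shorter than batch_size."""
    if mode == "keep":   # leave the remainder as is
        return [remainder]
    if mode == "drop":   # discard the short batch
        return []
    if mode == "pad":    # extend to full batch_size with pad_value
        return [remainder + [pad_value] * (batch_size - len(remainder))]
    raise ValueError("unknown batch_mode: %s" % mode)
```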
}
// Finally, resize if needed
if (chunk_read > 0) {
  if (chunk_read < dataset()->batch_size_) {
If I understand correctly, here we assume the remainder needs to be kept. Maybe we can add a comment about that assumption here. Also, if we add a drop_remainder input, users can decide whether to keep the remainder.
Updated. keep, drop, and pad modes have been added.
}
}
if (out_tensors->size() != tensors_.size()) {
  return errors::InvalidArgument("number tensors should match previous one, ", tensors_.size(), " vs. ", out_tensors->size());
Do we have the sanity check for C++ style? This line length exceeds the limitation of 80 chars.
In TensorFlow, at one point the C++ style was enforced with clang-format. The issue with clang-format was that different versions have different styles, so it is really not easy to figure out which one is the right one. TensorFlow dropped the C++ style check later.
I think we could leave the C++ style check alone until we find a clang-format version that stabilizes.
Got it. Thanks!
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@feihugis Thanks for the review. The name of the C++ class has been changed to AdjustBatchDataset. For the python function name I really think rebatch is a better fit.
Thanks @yongtang! LGTM. Left one minor comment.
errors::InvalidArgument("Batch size must be greater than zero."));

string batch_mode = "";
OP_REQUIRES_OK(ctx,
minor: do we need to check if the input batch_mode is valid?
*end_of_sequence = true;
return Status::OK();
}
// otherwise "pad" means keep the size
Just a reminder that pad is not implemented yet.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add rebatch method for Dataset. This PR adds a rebatch method for Dataset where ``` dataset.apply(rebatch(n)) = dataset.unbatch().batch(n) ```. The motivation for rebatch is that there are situations where we read the data in big batches but then want to adjust the batch size to fit different scenarios. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add additional tests; also add batch_mode = "keep", "drop", "pad" modes. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Rename RebatchDataset to AdjustBatchDataset. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add additional processing in case shape is unknown. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Address review comments. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix failed tests. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This PR adds a rebatch method for Dataset where

```
dataset.apply(rebatch(n)) = dataset.unbatch().batch(n)
```

The motivation for rebatch is that there are situations where we read the data in big batches but then want to adjust the batch size to fit different scenarios.
This is part of #382.
Signed-off-by: Yong Tang yong.tang.github@outlook.com