[Relay][Pytorch] Add support for aten::unflatten
#16131
Conversation
cfa10bf to bfbef63 (Compare)
bfbef63 to 6a42bff (Compare)
Hello @mshr-h! LGTM, but one check is needed.
```python
unflattened_size = tuple(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
assert len(dshape) > dim
new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
```
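(As an illustration: with dshape = (2, 60), dim = 1, and unflattened_size = (3, 5, 4), new_shape becomes (2, 3, 5, 4).)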
Add a check that dshape[dim] equals the product of the dimensions in unflattened_size.
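A minimal sketch of the suggested check (the helper name and error message are illustrative, not part of the PR):

```python
import numpy as np

def check_unflatten_size(dshape, dim, unflattened_size):
    """Suggested check: the target sizes must multiply up to dshape[dim]."""
    expected = int(np.prod(unflattened_size))
    assert dshape[dim] == expected, (
        f"unflatten: sizes {list(unflattened_size)} don't multiply up to "
        f"the size of dim {dim} ({dshape[dim]})"
    )

check_unflatten_size((2, 60), 1, (3, 5, 4))  # OK: 3 * 5 * 4 == 60
```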
@vvchernov Thanks!
I don't think we have to check it because torch.jit.trace does it.
It raises a RuntimeError like the one below when the shape is wrong:
RuntimeError: unflatten: Provided sizes [3, 5, 3, -1] don't multiply up to the size of dim 0 (60) in the input tensor
Should we add the check in TVM's PyTorch frontend?
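For reference, the quoted error can be reproduced directly in PyTorch (shapes chosen to match the message; this is just a repro, not code from the PR):

```python
import torch

x = torch.randn(60)
# 3 * 5 * 3 = 45 does not divide 60, so PyTorch rejects the sizes itself.
x.unflatten(0, (3, 5, 3, -1))
# RuntimeError: unflatten: Provided sizes [3, 5, 3, -1] don't multiply up
# to the size of dim 0 (60) in the input tensor
```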
Hello @mshr-h! OK, torch.jit.trace does it, but in that case we also don't need assert len(dshape) > dim.
It looks like TVM usually rechecks all corner cases.
About -1 in unflattened_size: in that case we can multiply the other dimensions together and check dshape[dim] % mult == 0.
@vvchernov Thanks. Added the assertion.
Hello @mshr-h! There are two cases: unflattened_size contains -1, or it does not. You check only the first one. Example: dshape[dim] = 8 with unflattened_size = [2, 2, 1, 1] passes your assert, but it is a failure case.
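Extending the earlier sketch to cover both cases discussed above (helper name still hypothetical):

```python
import numpy as np

def check_unflatten_size(dshape, dim, unflattened_size):
    """Validate unflattened_size against dshape[dim], with or without a -1 entry."""
    known = [s for s in unflattened_size if s != -1]
    mult = int(np.prod(known))
    if -1 in unflattened_size:
        # One size is inferred, so dshape[dim] only needs to be divisible by the rest.
        assert dshape[dim] % mult == 0
    else:
        assert dshape[dim] == mult

check_unflatten_size((60,), 0, (3, 5, -1))  # OK: 60 % 15 == 0
try:
    check_unflatten_size((8,), 0, (2, 2, 1, 1))  # the failure case from the comment above
except AssertionError:
    print("rejected: 2 * 2 * 1 * 1 == 4 does not match dim size 8")
```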
I've checked that _op.reshape does not check size correctness by itself. Nevertheless, I found that a dimension in the new shape can be not only -1 but any of {0, -1, -2, -3, -4}; see
tvm/python/tvm/relay/op/transform.py
Line 243 in 748882a
def reshape(data, newshape, allowzero=False):
I suggest checking only the -1 case. It looks like torch does not have the other options.
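A small sketch showing Relay's -1 handling in reshape (outside the frontend code, purely illustrative):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 60), dtype="float32")
# -1 asks Relay to infer this dimension; reshape also accepts the special
# values 0, -2, -3 and -4, which torch.unflatten never produces.
y = relay.reshape(x, newshape=(2, 3, 5, -1))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x], y)))
print(mod)  # output type: Tensor[(2, 3, 5, 4), float32]
```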
@vvchernov
Thanks! Updated the shape check.
Ah, dim can be negative but I didn't check that. I'll add the check.
> Ah, dim can be negative but I didn't check that. I'll add the check.

Done.
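A sketch of the usual normalization for a negative dim (the helper is hypothetical; the actual PR code may differ):

```python
def normalize_dim(dim, ndim):
    """Map a possibly negative dimension index into the range [0, ndim)."""
    if dim < 0:
        dim += ndim
    assert 0 <= dim < ndim, f"dim {dim} is out of range for a {ndim}-d input"
    return dim

assert normalize_dim(-1, 4) == 3  # last dimension
assert normalize_dim(1, 4) == 1
```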
6a42bff to 252a58b (Compare)
LGTM! @mshr-h Thanks for your PR.
Hello @mshr-h! Thank you for the careful work! Sorry for the many small, boring corrections. LGTM
Fix #15663
Support torch.unflatten.
cc @jikechao @vvchernov @Hzfengsy @junrushao
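A minimal end-to-end sketch of exercising the converter (module and shapes are illustrative, not taken from the PR's tests):

```python
import torch
from tvm import relay

class UnflattenModel(torch.nn.Module):
    def forward(self, x):
        # Split dimension 1 (size 12) into (3, 4); a trailing -1 would also work.
        return torch.unflatten(x, 1, (3, 4))

inp = torch.randn(2, 12)
traced = torch.jit.trace(UnflattenModel().eval(), inp)
mod, params = relay.frontend.from_pytorch(traced, [("x", list(inp.shape))])
print(mod)  # the unflatten lowers to a reshape to (2, 3, 4)
```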