
[Relay][Pytorch] Add support for aten::unflatten #16131

Merged
merged 5 commits into apache:main on Nov 22, 2023

Conversation

@mshr-h (Contributor) commented Nov 15, 2023

@vvchernov (Contributor) left a comment

Hello @mshr-h! LGTM, but one check is needed.

unflattened_size = tuple(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
assert len(dshape) > dim
new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
Contributor

Add a check that dshape[dim] equals the product of the dimensions in unflattened_size.
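For illustration, a minimal sketch of the check suggested here, reusing the names from the snippet above (this is not the code that was eventually merged, and it ignores a possible -1 entry, which comes up below):

import math

# Hypothetical check: the listed sizes must multiply up exactly to the extent of dim.
assert dshape[dim] == math.prod(unflattened_size), (
    f"unflatten: sizes {list(unflattened_size)} don't multiply up to "
    f"the size of dim {dim} ({dshape[dim]})"
)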

@mshr-h (Contributor Author) Nov 15, 2023

@vvchernov Thanks!
I don't think we have to check it because torch.jit.trace does it.
It raises a RuntimeError like the one below when the shape is wrong.

RuntimeError: unflatten: Provided sizes [3, 5, 3, -1] don't multiply up to the size of dim 0 (60) in the input tensor

Should we add the check in TVM's PyTorch frontend?
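For context, a small eager-mode example along these lines reproduces that error; since torch.jit.trace executes the operators while tracing, it surfaces the same RuntimeError (the tensor sizes below are made up to match the quoted message):

import torch

x = torch.randn(60, 2)
# dim 0 has size 60, but 3 * 5 * 3 = 45 does not divide 60 evenly,
# so PyTorch raises a RuntimeError like the one quoted above.
y = x.unflatten(0, (3, 5, 3, -1))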

Contributor

Hello @mshr-h! OK, torch.jit.trace does it, but in that case we also do not need assert len(dshape) > dim.
It looks like TVM usually rechecks all corner cases anyway.
About -1 in unflattened_size: in that case we can multiply the other dimensions together and check that dshape[dim] % mult == 0.
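A rough sketch of what that could look like, again reusing the names from the snippet above (illustrative only):

# Product of the known entries of unflattened_size (everything except -1).
mult = 1
for s in unflattened_size:
    if s != -1:
        mult *= s

if -1 in unflattened_size:
    # -1 is inferred by reshape, so the known entries only need to divide the extent of dim.
    assert dshape[dim] % mult == 0
else:
    assert dshape[dim] == mult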

Contributor Author

@vvchernov Thanks. Added the assertion.

Contributor

Hello @mshr-h! There are two cases: unflattened_size either contains -1 or it does not, and you check only the first one. Example: dshape[dim] = 8 with unflattened_size = [2, 2, 1, 1] passes your assert, but 2 * 2 * 1 * 1 = 4, not 8, so it is a failing case.

Contributor

I've checked that _op.reshape does not check size correctness by itself. Nevertheless, I found that a newshape entry can be not only -1 but any of {0, -1, -2, -3, -4}; see

def reshape(data, newshape, allowzero=False):

I suggest checking the -1 case only; it looks like torch does not have the other options.
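For reference, a small illustrative example of how relay's reshape treats a -1 entry (per the discussion above, the only special value PyTorch's unflatten produces); the variable names are made up for the example:

from tvm import relay

# -1 in newshape is inferred from the remaining extent: 60 * 7 / (3 * 5 * 7) = 4.
x = relay.var("x", shape=(60, 7), dtype="float32")
y = relay.reshape(x, newshape=(3, 5, -1, 7))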

Contributor Author

@vvchernov
Thanks! Updated the shape check.

Contributor Author

Ah, dim can be negative but I didn't check that. I'll add the check.

Contributor Author

Done.
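Putting the pieces from this thread together, here is a rough sketch of how the converter could combine the negative-dim handling with the size checks (illustrative only; the method name and exact structure are assumptions, and the merged code may differ):

def unflatten(self, inputs, input_types):
    # Illustrative sketch, not necessarily the merged implementation.
    data = inputs[0]
    dim = int(inputs[1])
    unflattened_size = tuple(inputs[2])

    dshape = get_const_tuple(self.infer_shape_with_prelude(data))
    assert -len(dshape) <= dim < len(dshape)
    if dim < 0:
        dim += len(dshape)  # normalize a negative dim

    # Product of the known entries; a -1 entry (if any) is inferred by reshape.
    mult = 1
    for s in unflattened_size:
        if s != -1:
            mult *= s
    if -1 in unflattened_size:
        assert dshape[dim] % mult == 0
    else:
        assert dshape[dim] == mult

    new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
    return _op.reshape(data, new_shape)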

@mshr-h mshr-h marked this pull request as ready for review November 16, 2023 03:13
@jikechao (Contributor) left a comment

LGTM! @mshr-h Thanks for your PR.

@mshr-h mshr-h marked this pull request as draft November 17, 2023 04:33
@vvchernov (Contributor) left a comment

Hello @mshr-h! Thank you for the careful work, and sorry for the many small, tedious corrections. LGTM

@mshr-h mshr-h marked this pull request as ready for review November 17, 2023 05:53
@github-actions github-actions bot requested a review from junrushao November 20, 2023 01:40
@masahi masahi merged commit bce8243 into apache:main Nov 22, 2023
18 checks passed
@mshr-h mshr-h deleted the pytorch-unflatten branch November 22, 2023 05:02

Successfully merging this pull request may close these issues.

[FRONTEND][PYTORCH] aten:unflatten not supported
4 participants