-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix pytorch axis #14930
Fix pytorch axis #14930
Conversation
# Conflicts: # python/tvm/relay/frontend/tensorflow.py
…nto mikeseven/fix_pytorch_axis
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a unit test for your changes.
python/tvm/relay/frontend/pytorch.py
Outdated
@@ -399,7 +399,10 @@ def squeeze(self, inputs, input_types): | |||
axis = None | |||
else: | |||
# TODO (t-vi): why is the cast to int needed? similarly elsewhere | |||
axis = [int(inputs[1])] | |||
# [MBS] squeeze axis can be a list | |||
# axis = [int(inputs[1])] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably the old code can be removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test added with other tests for this frontend.
* fix conv transpose import from TF * fix String::fromwe() to String::from() * torch squeeze can use a list of axis * added test for squeeze with multiple axis (pytorch 2) * clean old code * code without reformating --------- Co-authored-by: Mikael Sevenier <mikael.sevenier@sima.ai>
relay.frontend.from_pytorch maps torch.squeeze() using 1 axis while pytorch allows multiple axis.
This fixes that.
Note that this pattern seems to be used for other functions. This frontend may need to be updated elsewhere. I only tested for squeeze for my current needs.
cc @areusch