[Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention
#17379
Conversation
For `torch.nn.functional.scaled_dot_product_attention`, we can transpose to get the expected result. Thanks for the effort!
overall looks good to me
Force-pushed from 456c72e to 185d28c
The MSC E2E test is failing. It seems we also need to change something beyond the Relax frontend. Link to the CI log: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/PR-17379/6/pipeline/ — `tests/python/contrib/test_msc/test_translate_torch.py::test_attention` FAILED
Force-pushed from 185d28c to 43268e1
Force-pushed from 43268e1 to a783823
Force-pushed from a783823 to a2b29c0
`torch.nn.functional.scaled_dot_product_attention` outputs in the shape `(N, ..., L, E_v)`, but `relax.op.nn.attention` produces `(N, L, ..., E_v)`, so the output should also be transposed. Maybe we should add E2E tests in `tests/python/nightly/` to check the Relax torch frontend. cc: @yongwww
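The layout mismatch above can be illustrated with a small PyTorch-only sketch (the shapes `N`, `H`, `L`, `S`, `E`, `Ev` below are arbitrary example values, not from the PR): torch's `scaled_dot_product_attention` returns `(N, num_heads, L, E_v)`, while `relax.op.nn.attention` uses the `(N, L, num_heads, E_v)` layout, so the frontend needs to transpose axes 1 and 2 of the result.

```python
import torch
import torch.nn.functional as F

# Hypothetical example sizes: batch, heads, query len, key len, head dims.
N, H, L, S, E, Ev = 2, 4, 8, 8, 16, 16

q = torch.randn(N, H, L, E)
k = torch.randn(N, H, S, E)
v = torch.randn(N, H, S, Ev)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 8, 16]) -> (N, H, L, E_v)

# Transpose heads and sequence axes to match relax.op.nn.attention's
# (N, L, num_heads, E_v) output layout.
out_relax_layout = out.transpose(1, 2)
print(out_relax_layout.shape)  # torch.Size([2, 8, 4, 16]) -> (N, L, H, E_v)
```

This is only a shape-level demonstration of why the fix transposes the attention output; the actual change lives in the Relax PyTorch frontend.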