
[Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention #17379

Merged · 3 commits into apache:main on Sep 20, 2024

Conversation

mshr-h (Contributor) commented on Sep 16, 2024

torch.nn.functional.scaled_dot_product_attention produces output of shape (N, ..., L, E_v), whereas relax.op.nn.attention produces (N, L, ..., E_v), so the converter's output also needs to be transposed to match PyTorch.
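For reference, a minimal sketch of the layout difference (the shapes are illustrative, and the permute_dims call is only an example of how the axes could be swapped, not necessarily the exact converter code):

```python
import torch
import torch.nn.functional as F

N, H, L, E = 32, 8, 128, 64
q = torch.randn(N, H, L, E)
k = torch.randn(N, H, L, E)
v = torch.randn(N, H, L, E)

# PyTorch keeps the head axis before the sequence axis: (N, H, L, E_v).
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (N, H, L, E)

# relax.op.nn.attention works in the (N, L, H, E_v) layout, so the converter
# has to swap the head and sequence axes of its result to match PyTorch,
# e.g. with something equivalent to:
#   relax.op.permute_dims(attn_out, axes=[0, 2, 1, 3])
```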

Maybe we should add E2E tests in tests/python/nightly/ to check the Relax PyTorch frontend.
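For example, something along these lines; this is only a rough sketch that assumes the fx frontend (tvm.relax.frontend.torch.from_fx) plus a LegalizeOps + relax.build flow on LLVM, so the exact pipeline and helpers may need adjusting:

```python
# Rough sketch of an E2E check for the fx frontend (names and pipeline are
# illustrative; an actual nightly test may be structured differently).
import torch
import torch.fx as fx
import torch.nn.functional as F
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_fx


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


def test_sdpa_e2e():
    shape, dtype = (32, 8, 128, 64), "float32"
    inputs = [torch.randn(shape) for _ in range(3)]

    # Reference result from PyTorch.
    expected = SDPA()(*inputs).numpy()

    # Import through the fx frontend and run via the Relax VM.
    graph_model = fx.symbolic_trace(SDPA())
    mod = from_fx(graph_model, [(shape, dtype)] * 3)
    mod = relax.transform.LegalizeOps()(mod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    res = vm["main"](*[tvm.nd.array(x.numpy()) for x in inputs])
    # The result may come back wrapped in a tuple depending on frontend options.
    actual = res.numpy() if hasattr(res, "numpy") else res[0].numpy()

    tvm.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
```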

cc: @yongwww

@mshr-h marked this pull request as ready for review on September 16, 2024 14:25
@mshr-h changed the title from "Fix torch sdpa converter" to "[Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention" on Sep 16, 2024
@mshr-h marked this pull request as draft on September 17, 2024 02:38
yongwww (Member) commented on Sep 17, 2024

we can transpose to get the expected result. Thanks for the effort!

yongwww (Member) left a review comment


overall looks good to me

Review comment on tests/python/nightly/relax/test_frontend_from_fx.py (outdated, resolved)
@mshr-h marked this pull request as ready for review on September 17, 2024 04:37
@mshr-h force-pushed the fix-torch-sdpa-converter branch 2 times, most recently from 456c72e to 185d28c on September 17, 2024 07:33
mshr-h (Contributor, Author) commented on Sep 17, 2024

The MSC E2E test is failing. It seems we also need to change something beyond the Relax frontend: in the log below, the actual and desired output shapes differ only in the head and sequence axes.
@Archermmt Do you have any ideas on how to fix the error?

Link to the CI log: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/PR-17379/6/pipeline/

tests/python/contrib/test_msc/test_translate_torch.py::test_attention FAILED

=================================== FAILURES ===================================
________________________________ test_attention ________________________________

    def test_attention():
        """test torch translator for attention"""

        # pylint: disable=import-outside-toplevel
        import torch.nn.functional as F

        class Attention1(Module):
            def forward(self, q_data, k_data, v_data):
                return F.scaled_dot_product_attention(q_data, k_data, v_data)

        class Attention2(Module):
            def forward(self, q_data, k_data, v_data):
                return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

        input_info = [
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
            ([32, 8, 128, 64], "float32"),
        ]
>       verify_model(Attention1(), input_info)

tests/python/contrib/test_msc/test_translate_torch.py:1127:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/contrib/test_msc/test_translate_torch.py:52: in verify_model
    tvm.testing.assert_allclose(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

actual = array([[[[0.5253085 , 0.48211107, 0.52921075, ..., 0.518867 ,
          0.49926636, 0.48493868],
         [0.5294311 ...5],
         [0.47335747, 0.48579183, 0.5360674 , ..., 0.543607 ,
          0.5020893 , 0.47848547]]]], dtype=float32)
desired = array([[[[0.5253085 , 0.48211107, 0.52921075, ..., 0.518867 ,
          0.49926636, 0.48493868],
         [0.49697113... ],
         [0.47335747, 0.48579183, 0.5360674 , ..., 0.543607 ,
          0.5020893 , 0.47848547]]]], dtype=float32)
rtol = 1e-05, atol = 1e-05

    def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
        """Version of np.testing.assert_allclose with atol and rtol fields set
        in reasonable defaults.

        Arguments actual and desired are not interchangeable, since the function
        compares the abs(actual-desired) with atol+rtol*abs(desired). Since we
        often allow desired to be close to zero, we generally want non-zero atol.
        """
        actual = np.asanyarray(actual)
        desired = np.asanyarray(desired)
>       np.testing.assert_allclose(actual.shape, desired.shape)
E       AssertionError:
E       Not equal to tolerance rtol=1e-07, atol=0
E
E       Mismatched elements: 2 / 4 (50%)
E       Max absolute difference: 120
E       Max relative difference: 15.
E       x: array([ 32,   8, 128,  64])
E       y: array([ 32, 128,   8,  64])

python/tvm/testing/utils.py:119: AssertionError

@mshr-h force-pushed the fix-torch-sdpa-converter branch from 185d28c to 43268e1 on September 17, 2024 14:50
@mshr-h force-pushed the fix-torch-sdpa-converter branch from 43268e1 to a783823 on September 19, 2024 05:25
@mshr-h force-pushed the fix-torch-sdpa-converter branch from a783823 to a2b29c0 on September 19, 2024 09:06
@yongwww merged commit 85f2cc3 into apache:main on Sep 20, 2024 (17 of 18 checks passed)
@mshr-h deleted the fix-torch-sdpa-converter branch on September 20, 2024 04:30