Workaround for matmul kernel crash with i8xf32 operands. #12

Open
wants to merge 1 commit into
base: main
from



@3gx 3gx commented Aug 29, 2024

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

     The BlockedToMMA pass creates a layout with kWidth=4 when one operand is
     i8. However, the TritonGPU to LLVM lowering pass does not support
     lowering f32 with kWidth=4, which is the other operand, causing a
     segmentation fault.
    
     To work around this, if the operands' minBitWidth is 8 and maxBitWidth
     is 32, we use a minBitWidth of 16 instead of 8, creating a layout with
     kWidth=2 for both i8 and f32 operands.
    
    
  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /python/test for end-to-end tests
  • Select one of the following.

    • I have not added any lit tests.


google-cla bot commented Aug 29, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@3gx
Author

3gx commented Aug 30, 2024

I signed CLA.

@gflegar gflegar requested review from Moerafaat and chsigg September 3, 2024 17:46
@Moerafaat
Member

As the title says, this change only works around the issue and does not fix it. In effect, it removes mixed-precision behavior for any matmuls with s8. Also, the current change would regress s8 x . Ideally, we would fix the limitation in Triton's lowering to LLVM and still allow proper mixed-precision mma to happen.

@3gx
Author

3gx commented Sep 5, 2024

Could you elaborate on what you mean by "removes mixed-precision behavior for any matmuls with s8"? I ask because the lowered code contains a cast from i8 to f32 before feeding data to the tf32 mma op, which is necessary since the other operand is already f32. Could you also clarify what you mean by "the current change would regress s8 x "? Perhaps you could provide an example to illustrate this point? Thank you.

@Moerafaat
Member

Moerafaat commented Sep 9, 2024

My apologies for replying late.

Regarding s8 x : We can consider the example of s8 x f16:
This change causes the "kWidth" attribute (which you can inspect by dumping the MLIR at the AccelerateMatmul pass) to differ: it is 4 before the change and 2 after it. This affects how the data is loaded.
I have attached the LLVM IR before and after the change for you to inspect given the HLO below.

HloModule m

ENTRY e {
  p0 = s8[16,32] parameter(0)
  p0c = bf16[16,32] convert(p0)
  p1 = bf16[32,8] parameter(1)
  ROOT _ = bf16[16,8] dot(p0c, p1),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

I haven't looked deeply into the performance impact, but it is clear that the change is not local.
llvm-after-change.txt
llvm-before-change.txt

As you can see, the change will impact other use cases. I'm not sure what the performance impact is (it would be nice if you could profile it). The constraints could be tightened to match only s8 x f32 combinations, but that would still be working around the issue.
I hope this explains it a bit more.

@3gx 3gx force-pushed the xla/egx/bug-2853-v1 branch from 3ff8088 to b313a8b September 10, 2024 13:25
@3gx 3gx changed the title from "Workaround for matmul kernel crash with i8 operand" to "Workaround for matmul kernel crash with i8xf32 operands." Sep 10, 2024
@3gx
Author

3gx commented Sep 10, 2024

Thank you for the details. I think I understand the issue with the proposed workaround. I have updated this MR with changes that should not affect other mixed-precision matrix multiplications. I verified that the i8xf16 kWidth remains 4 with this workaround.

The issue stems from the LLVM lowering pass not supporting f32 with kWidth=4 when lowering for Ampere tensor cores. I am not familiar with Ampere tensor cores and cannot estimate the effort required to fix the issue in the lowering pass.
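
For readers following the thread, the selection rule from the PR description can be sketched in a few lines of Python. This is only an illustration, not the code in the patch: the helper name `k_width` and the `workaround` flag are hypothetical, and the relation kWidth = 32 / minBitWidth is an assumption chosen because it matches the values quoted in this conversation (8 bits gives 4, 16 bits gives 2).

```python
def k_width(bit_widths, workaround=True):
    """Hypothetical helper: pick kWidth for the dot operand layout.

    bit_widths: element bit widths of the two matmul operands, e.g. (8, 32).
    Assumption: kWidth = 32 // minBitWidth, which matches the values
    quoted in this thread (8 -> 4, 16 -> 2).
    """
    min_bits = min(bit_widths)
    max_bits = max(bit_widths)
    # Workaround from the PR description: for an 8-bit x 32-bit pair,
    # treat the minimum bit width as 16 so both operands get kWidth=2.
    if workaround and min_bits == 8 and max_bits == 32:
        min_bits = 16
    return 32 // min_bits


print(k_width((8, 32)))                    # i8 x f32 with the workaround
print(k_width((8, 16)))                    # i8 x f16, left untouched
print(k_width((8, 32), workaround=False))  # i8 x f32 without the workaround
```

Under these assumptions, only the 8-bit by 32-bit combination changes, while i8xf16 keeps the layout it had before the patch.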

@Moerafaat
Member

Thank you for the modifications. We are currently discussing whether to proceed with a workaround. I will get back to you once there is a decision.

@gflegar
Member

gflegar commented Oct 17, 2024

Unfortunately, a workaround is not something we can accept for this issue; we would need a proper fix here.

We already have a different workaround internally, and the performance benefits we would gain from this do not outweigh the cost of maintaining a patch on top of upstream.

3 participants