[OPTIMIZER] Take numWarps into account for Hopper mma op #2956
Conversation
Force-pushed from 5491dbb to 768f822
That makes sense to me. Could you add a simple lit test?
Thanks for the review! I will work on adding a test, yes, but it seems some of the existing tests are broken: https://github.com/openai/triton/actions/runs/7556918956/job/20575017710#step:10:28148. I understand the assertion that checks whether the PTX contains the correct wgmma instruction needs to be updated, but there are also tests failing with incorrect results here: https://github.com/openai/triton/blob/main/python/test/unit/hopper/test_gemm.py#L475. Any ideas how this could be?
Our Hopper CI is currently broken due to environment problems in the CI bot; this should be fixed soon.
CI should be fixed. Can you restore the changes?
Done, still seems to be failing with:
Those are not related to your changes?
I think they are; I was just curious if you had a hunch as to why this change would affect that :/ I will take another look tomorrow if not.
Ah, not sure, I can't tell from just the log.
Force-pushed from 2b24b11 to c70573a
I was finally able to reproduce it locally, and it actually only happens when ENABLE_TMA=1, so I am looking at what this does and seeing if there's something that needs to be updated there.
Force-pushed from 2abd29f to 47564cf
This changes the wgmma instruction based on the total number of warps. Using m's shape, it calculates how many warps will be used in the m dimension, then sees how many are left for the n dimension. It then chooses the largest N such that the work is still evenly distributed. This resolves issue triton-lang#2662.
LGTM.
@ThomasRaoux, could you take a look and see if this is good to be merged?
Regarding the failing test we discussed above:
- It was TMA-specific, so the test was removed along with TMA support.
- We were concerned that this PR was somehow still triggering the reduction bug, but Tori figured out that it is actually completely orthogonal to this PR (see Reduction Op on MMA Layout produces incorrect results #3467 (comment)) and happens just the same on main. We'll work on that one separately, as we discussed on the issue.
Sorry for the delay on this PR. The change looks fine to me; however, I'm wondering whether this will cause performance regressions when we have a chain of dot ops, like in attention, and one of them forces the N dimension to be distributed across multiple warps. Is this something you have looked at? I think one solution is to land this for now and revert it later if such cases turn up in real-life workloads.
I'll merge this, but as mentioned above, if we end up needing a more complex heuristic we may have to revert it.
I don't think this should cause performance regressions. The only thing this does is use all the warps available in a block, instead of potentially keeping some of them idle, which is what happened before this landed. If a kernel becomes slower after this, that is an indication that it was already using too many warps, and the right fix would be to make it use fewer warps: we end up doing the same work per warp, but without idle warps needlessly consuming resources.
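To make the idle-warp argument above concrete, here is a small, hedged arithmetic sketch. The specific numbers (a Hopper warp group being 4 warps and one warp group covering 64 rows of M) are assumptions for illustration, not taken from this PR's code.

```python
# Illustrative only: rough warp accounting for a 64x128 tile with numWarps=8,
# assuming a Hopper warp group is 4 warps and covers 64 rows of M.
num_warps = 8
BLOCK_M, BLOCK_N = 64, 128

warp_groups = num_warps // 4          # 2 warp groups in the block
groups_for_m = max(1, BLOCK_M // 64)  # 1 warp group is enough for M

# Before this change: the leftover warp group had no work.
idle_before = warp_groups - groups_for_m    # 1 idle warp group

# After this change: the leftover group takes a slice of N instead,
# so a smaller wgmma N is used and no warp group sits idle.
groups_for_n = warp_groups // groups_for_m  # 2
n_per_group = BLOCK_N // groups_for_n       # 64 columns per group
```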
This changes the `wgmma` instruction shape based on the total number of warps. Instead of always using the largest version of `wgmma`, it honors the user's `numWarps` hint and uses a smaller `wgmma` shape to distribute the work to all warps, rather than having some of them idle.

Using m's shape, it calculates how many warps will be used in the m dimension, then sees how many are left for the n dimension. It then chooses the largest N such that the work is still evenly distributed.

This resolves issue #2662.
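Below is a minimal Python sketch of the selection heuristic described above, assuming a Hopper warp group is 4 warps and one warp group covers 64 rows of the m dimension. The function name and candidate shapes are illustrative; the actual logic lives in the C++ optimizer pass.

```python
def pick_wgmma_n(block_m, block_n, num_warps,
                 candidate_ns=(256, 128, 64, 32, 16, 8)):
    """Hypothetical helper mirroring the heuristic in the description above."""
    warp_groups = max(1, num_warps // 4)
    # Warp groups consumed by the m dimension, capped by what is available.
    groups_m = max(1, min(warp_groups, block_m // 64))
    # Whatever is left over is distributed along the n dimension.
    groups_n = max(1, warp_groups // groups_m)
    n_per_group = block_n // groups_n
    # Pick the largest instruction N that still divides each group's share
    # evenly, so no warp group is left idle.
    for cand in candidate_ns:
        if cand <= n_per_group and n_per_group % cand == 0:
            return cand
    return candidate_ns[-1]

# Example: an 8-warp block on a 64x128 tile splits N across two warp
# groups, so a wgmma shape with N=64 is chosen instead of N=128.
assert pick_wgmma_n(64, 128, 8) == 64
assert pick_wgmma_n(64, 128, 4) == 128
```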