[BugFix][TIR] Fix multi-grouped multi-warp allreduce #15399
Merged
PR #15327 and #15373 introduced a multi-warp allreduce implementation. At the time, I tested numerical correctness with the workload "take a matrix of ones as input and compute the sum over each row". Both PRs passed this numerical test, but I did not realize that the test is incomplete and cannot guarantee correctness.

The previous implementation has a bug that can be exposed by changing the input matrix from ones to random floating-point numbers.
Therefore, this PR fixes the issue and adds numerical tests for multi-warp allreduce to test_allreduce_cuda.py. By removing some of the redundant tests in that file, we hope to reduce testing time a bit while still guaranteeing correctness. Sorry for not testing the implementation thoroughly before.
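As a small illustration of why the all-ones workload is an incomplete test (this is a hypothetical sketch, not the actual code in test_allreduce_cuda.py): a broken reduction can still produce correct row sums when every element is 1, while random inputs expose it immediately.

```python
import numpy as np

def buggy_row_sum(x):
    # A deliberately broken "allreduce": it sums only the first half of
    # each row and doubles the result. For an all-ones matrix this
    # coincidentally equals the true row sum, masking the bug.
    half = x.shape[1] // 2
    return 2 * x[:, :half].sum(axis=1)

ones = np.ones((4, 8))
# Passes: every row sum is 8, and the buggy kernel also returns 2 * 4 = 8.
assert np.allclose(buggy_row_sum(ones), ones.sum(axis=1))

rng = np.random.default_rng(0)
rand = rng.standard_normal((4, 8))
# Fails the comparison: random inputs reveal that the wrong elements
# were reduced, which the all-ones test could not detect.
assert not np.allclose(buggy_row_sum(rand), rand.sum(axis=1))
```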