[Dlight] Fix GeMV shared memory estimation #16731
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Prior to this PR, there is one part missing in the shared memory estimation of the GeMV rule. The GeMV rule optimizes by using cross-thread reduction. When the target does not support warp reduction primitives, the cross-thread reduction will be further lowered to shared memory implementation, which consumes another part of shared memory.
If we do not consider this part in the GeMV rule, it is possible for the total shared memory usage to exceed the target shared memory limit. For example, mlc-ai/mlc-llm#1841 reports an issue on the Vulkan shared memory limit exceed.
This PR fixes the issue by introducing a flag
SUPPORT_WARP_SHUFFLE
to the GeMV rule. We only enable warp shuffle for CUDA and Metal backend, and turn it off for all other backends. This is basically aligned with the lowering rule of thread allreduce intrinsic.P.S.. ROCm also supports warp shuffle but has some limitation, where not every set of parameters in the GeMV rule can meet. Therefore, we regard ROCm as "not supported". This just mean we will be conservative in the shared memory usage for ROCm, and does not mean we do not use the warp shuffle when the workload is eligible when lowering.