[Dlight] Fix GeMV shared memory estimation #16731

Merged
merged 1 commit into apache:main on Mar 16, 2024

Conversation

MasterJH5574
Contributor

Prior to this PR, one part is missing from the shared memory estimation of the GeMV rule. The GeMV rule optimizes by using cross-thread reduction. When the target does not support warp reduction primitives, the cross-thread reduction is further lowered to a shared-memory implementation, which consumes an additional chunk of shared memory.
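
To make the missing term concrete, here is a minimal sketch with purely hypothetical names and numbers (`estimate_smem_bytes`, the 48 KB limit, and the tile sizes are illustrative, not the rule's actual code): ignoring the reduction scratchpad can make an over-budget schedule look like it fits.

```python
# Illustrative sketch only (hypothetical names, not the actual dlight code):
# when warp shuffle is unavailable, the cross-thread reduction is lowered to a
# shared-memory scratchpad, adding roughly one element per thread in the block
# on top of the shared memory already used by the GeMV tiles.
def estimate_smem_bytes(tile_bytes, threads_per_block, dtype_bytes, supports_warp_shuffle):
    smem = tile_bytes  # shared memory counted for the GeMV data tiles
    if not supports_warp_shuffle:
        # extra scratchpad consumed by the lowered cross-thread reduction
        smem += threads_per_block * dtype_bytes
    return smem


# Made-up example: a schedule that fits a 48 KB limit when the reduction
# scratchpad is ignored, but exceeds it once the scratchpad is counted.
LIMIT = 48 * 1024
print(estimate_smem_bytes(47 * 1024, 1024, 2, supports_warp_shuffle=True) <= LIMIT)   # True
print(estimate_smem_bytes(47 * 1024, 1024, 2, supports_warp_shuffle=False) <= LIMIT)  # False
```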

If we do not account for this part in the GeMV rule, the total shared memory usage can exceed the target's shared memory limit. For example, mlc-ai/mlc-llm#1841 reports an issue where the Vulkan shared memory limit is exceeded.

This PR fixes the issue by introducing a flag `SUPPORT_WARP_SHUFFLE` to the GeMV rule. We enable warp shuffle only for the CUDA and Metal backends and turn it off for all other backends. This is basically aligned with the lowering rule of the thread allreduce intrinsic.

P.S. ROCm also supports warp shuffle, but with some limitations that not every set of parameters in the GeMV rule can meet. Therefore, we regard ROCm as "not supported" here. This only means we are conservative in the shared memory estimation for ROCm; it does not mean warp shuffle is skipped at lowering time when the workload is eligible.
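
As a rough sketch of the per-target behavior described above (a simple lookup with illustrative names, not the actual TVM/dlight code), only CUDA and Metal report warp-shuffle support, while ROCm, Vulkan, and every other backend fall back to the conservative estimate:

```python
# Hypothetical sketch of the per-target flag (illustrative only): warp shuffle
# is assumed available only on CUDA and Metal, so every other backend,
# including ROCm and Vulkan, is estimated conservatively and has the
# reduction scratchpad counted toward its shared memory usage.
SUPPORT_WARP_SHUFFLE = {"cuda": True, "metal": True}

def supports_warp_shuffle(target_kind: str) -> bool:
    return SUPPORT_WARP_SHUFFLE.get(target_kind, False)

for kind in ("cuda", "metal", "rocm", "vulkan"):
    print(kind, supports_warp_shuffle(kind))
```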

@tqchen merged commit 1c73491 into apache:main on Mar 16, 2024
18 checks passed
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024