[Vulkan] Add cooperative matrix support #14817
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Yeah, hi Mei, it's interesting that we worked on this extension around the same time! I'll look through this PR and think about how to integrate our work. BTW, I'm aware that AMD's Windows Vulkan driver supports this extension. Is the Linux driver going to get support for it as well?
Yes, the Linux driver supports it, though let me check whether that support has made it into a release.
The Linux driver released in April includes support for this extension.
Do you mean AMDVLK or AMDGPU Pro? I looked into the AMDVLK source code but didn't find support for this extension.
AMDGPU Pro.
@mei-ye Can you split the topi / relay / ms changes from this PR and send them later? And please add a minimal test case for cooperative matrix, such as the matmul test in my PR.
That's still open for discussion. If people prefer your approach, I'll close my PR, in which case I want your PR to be easier to review.
Add SPIR-V code generation for the "SPV_NV_cooperative_matrix" extension. Add a matrix multiplication unit test.
@tvm-bot rerun
ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported";
spirv::SType ele_stype = builder_->GetSType(ele_dtype);
spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
double init = static_cast<uint64_t>(Downcast<FloatImm>(op->args[5])->value);
why cast to uint64?
I can't recall a good reason. Removing this cast works fine. Should I reset and re-patch?
*
* If support is present, can perform cooperative matrix operations. If
* support is not present, codegen will throw exception on
* attempting to perform cooperative matrix.
perform cooperative matrix operations
I'll address the nit issues in my upcoming PR.
I've updated my Vulkan 4K matmul test https://github.com/masahi/tensorir-experiment/blob/master/vk_cooperative_matrix_nv/test_4k.py to use this extension support. It gets ~90 TFLOPS on an RTX 4080. The cool part is that by changing the target from vk to cuda, the exact same schedule / script works with the same performance: https://github.com/masahi/tensorir-experiment/blob/master/vk_cooperative_matrix_nv/test_4k.py#L171-L172
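For readers who don't want to open the script, here is a minimal sketch of the point being made. This is not the actual test_4k.py: the schedule construction is elided, and the helper name, shapes, and target strings are assumptions for illustration only.

```python
# Minimal sketch, assuming `sch` is an already-constructed tvm.tir.Schedule
# of a 4096x4096 fp16 matmul with fp32 accumulation.  Only the target string
# differs between the Vulkan and CUDA runs; the schedule itself is reused.
import numpy as np
import tvm

def build_and_run(sch, a_np, b_np, target_str):
    """Build the scheduled matmul for `target_str` and run it once."""
    target = tvm.target.Target(target_str)
    dev = tvm.device(target.kind.name, 0)
    lib = tvm.build(sch.mod, target=target)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((4096, 4096), dtype="float32"), dev)
    lib(a, b, c)
    return c.numpy()

# Swapping backends is just a different target string:
# out_vk   = build_and_run(sch, a_np, b_np, "vulkan -from_device=0")
# out_cuda = build_and_run(sch, a_np, b_np, "cuda")
```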
Auto tensorization on vk seems to work as well, but the result is not correct after tuning. For CUDA the result is correct after auto-tensorization tuning, so this is a Vulkan-specific issue. @mei-ye You can try the auto tensorization experiment using my branch https://github.com/masahi/tvm/tree/vk-auto-tensorize and this script https://github.com/masahi/tensorir-experiment/blob/vk-auto-tensorize/vk_cooperative_matrix_nv/test_4k.py. I'm curious whether the accuracy issue is specific to NVIDIA or applies to AMD as well.
[Vulkan] Add SPIR-V code generation for "SPV_NV_cooperative_matrix" extension
Add an im2col implementation for direct Conv2D. Currently only 16x16x16 FP16 wmma fragments with FP32 intermediates are supported. Add "min_design_space" as a parameter to specify a minimum design space for meta scheduler tuning. Add "use_int32_const" as a parameter to use the int32 type for constants. Allow target queries to be called from the schedules so that sampling is constrained to produce legal schedules. Do not allow the reuse of buffers with different dtypes. Add a unit test, test_wmma.py.
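For context on the im2col part of this change, here is a hedged NumPy illustration, not the code added in this PR: im2col flattens each receptive field into one row so that a direct Conv2D becomes a single large matmul, which is what lets the 16x16x16 FP16 wmma fragments with FP32 accumulation be reused for convolution. The NHWC layout, stride 1, and lack of padding below are simplifying assumptions for exposition.

```python
# Hedged illustration of im2col-based Conv2D (not the TOPI schedule in this PR).
import numpy as np

def conv2d_im2col(data, weight):
    # data:   (N, H, W, C)   float16 activations
    # weight: (KH, KW, C, O) float16 filters
    n, h, w, c = data.shape
    kh, kw, _, o = weight.shape
    oh, ow = h - kh + 1, w - kw + 1  # stride 1, no padding

    # Gather each receptive field into one row: (N*OH*OW, KH*KW*C)
    cols = np.empty((n * oh * ow, kh * kw * c), dtype=data.dtype)
    idx = 0
    for b in range(n):
        for y in range(oh):
            for x in range(ow):
                cols[idx] = data[b, y:y + kh, x:x + kw, :].reshape(-1)
                idx += 1

    # The convolution is now one matmul with FP32 accumulation,
    # the shape the wmma fragments can tile.
    out = cols.astype("float32") @ weight.reshape(kh * kw * c, o).astype("float32")
    return out.reshape(n, oh, ow, o)
```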