
[Vulkan] Add cooperative matrix support #14817

Merged 1 commit into apache:main on May 20, 2023
Conversation

mei-ye (Contributor) commented May 10, 2023

[Vulkan] Add SPIR-V code generation for the "SPV_NV_cooperative_matrix" extension.

- Add an im2col implementation for direct Conv2D. Currently only 16x16x16 FP16 wmma fragments with FP32 intermediates are supported.
- Add "min_design_space" as a parameter that sets a minimum design space for meta-schedule tuning.
- Add "use_int32_const" as a parameter to use the int32 type for constants.
- Allow target queries to be called from the schedules so that samplings are constrained to produce legal schedules.
- Do not allow the reuse of buffers with different dtypes.
- Add a unit test, test_wmma.py.
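The "16x16x16 FP16 wmma fragments with FP32 intermediates" can be illustrated numerically. Below is a minimal NumPy sketch of the fragment semantics, not TVM or SPIR-V code; the rank-1-update loop and the 1e-3 tolerance are illustrative choices:

```python
import numpy as np

# Sketch of the numerics of one 16x16x16 wmma fragment:
# FP16 operands, FP32 accumulator (the "FP32 intermediates" above).
M = N = K = 16
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float16)  # FP16 A fragment
b = rng.standard_normal((K, N)).astype(np.float16)  # FP16 B fragment
c = np.zeros((M, N), dtype=np.float32)              # FP32 accumulator

# Accumulate one rank-1 update per k in FP32, mimicking MMA semantics:
# FP16 inputs are multiplied and summed into an FP32 accumulator.
for k in range(K):
    c += a[:, k:k + 1].astype(np.float32) @ b[k:k + 1, :].astype(np.float32)

# Cross-check against a full-precision reference on the same FP16 inputs.
ref = a.astype(np.float64) @ b.astype(np.float64)
assert np.allclose(c, ref, rtol=1e-3)
```

Accumulating in FP32 rather than FP16 is what keeps the result close to the full-precision reference even as K grows.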

tvm-bot (Collaborator) commented May 10, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users found to tag in teams: vulkan. See #10317 for details.

Generated by tvm-bot

junrushao (Member):

CC @masahi: this seems similar to your PR, #14770.

masahi (Member) commented May 10, 2023

Yeah, hi Mei, it's interesting to know that we worked on this extension around the same time! I'll look through this PR and think about how to integrate our work.

BTW, I'm aware that AMD's Windows VK driver supports this extension. Is the Linux driver going to get support for this extension as well?

mei-ye (Contributor, Author) commented May 10, 2023

> Is the Linux driver going to get support for this extension as well?

Yes, the Linux driver supports it, though let me check whether the support has made it into a release.

mei-ye (Contributor, Author) commented May 10, 2023

> Yes, the Linux driver supports it, though let me check whether the support has made it into a release.

The Linux driver was released in April with support for this extension.

masahi (Member) commented May 10, 2023

> The Linux driver was released in April with support for this extension.

Do you mean AMDVLK or AMDGPU Pro? I looked into the AMDVLK source code but didn't find support for this extension.

mei-ye (Contributor, Author) commented May 11, 2023

> Do you mean AMDVLK or AMDGPU Pro?

AMDGPU Pro.

masahi (Member) commented May 14, 2023

@mei-ye Can you split the topi / relay / ms changes from this PR and send them later? And please add a minimal test case for cooperative matrix, such as the matmul test in my PR.

mei-ye (Contributor, Author) commented May 15, 2023

> Can you split the topi / relay / ms changes from this PR and send them later? And please add a minimal test case for cooperative matrix.

@masahi: are you going to check in your SPIR-V codegen? In this patch, I added a unit test: test_wmma.py.

masahi (Member) commented May 15, 2023

> @masahi: are you going to check in your SPIR-V codegen?

That's still open for discussion. If people prefer your approach, I'll close my PR, in which case I'd want your PR to be easier to review.

Resolved review threads on src/target/spirv/codegen_spirv.cc and src/target/spirv/ir_builder.cc.
Add SPIR-V code generation for "SPV_NV_cooperative_matrix" extension. Add a matrix multiplication unit test.
mei-ye reopened this May 19, 2023
mei-ye changed the title from "[SPIR-V] Add cooperative matrix support" to "[Vulkan] Add cooperative matrix support" May 19, 2023
masahi (Member) commented May 19, 2023

@tvm-bot rerun

Review comment on the diff:

```cpp
ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported";
spirv::SType ele_stype = builder_->GetSType(ele_dtype);
spirv::SType& fragment_type = fragment_info_[buffer_node].stype;
double init = static_cast<uint64_t>(Downcast<FloatImm>(op->args[5])->value);
```

Member:

Why cast to uint64?

Contributor (Author):

I can't recall a good reason. Removing this cast works fine. Should I reset and re-patch?

Review comment on the diff:

```cpp
 *
 * If support is present, can perform cooperative matrix operations. If
 * support is not present, codegen will throw exception on
 * attempting to perform cooperative matrix.
```

Member:

Nit: "perform cooperative matrix operations"

masahi merged commit b91c2f2 into apache:main May 20, 2023
masahi (Member) commented May 20, 2023

I'll address the nit issues in my upcoming PR.

tqchen (Member) commented May 20, 2023

Awesome, thank you @mei-ye @masahi!

masahi (Member) commented May 27, 2023

I've updated my VK 4K matmul test https://github.com/masahi/tensorir-experiment/blob/master/vk_cooperative_matrix_nv/test_4k.py to use this extension support. It gets ~90 TFLOPS on an RTX 4080.

The cool part is that, by changing the target from vk to cuda, the exact same schedule / script works with the same performance. https://github.com/masahi/tensorir-experiment/blob/master/vk_cooperative_matrix_nv/test_4k.py#L171-L172
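For context, a TFLOPS figure like the one above follows from the standard 2·M·N·K FLOP count for a matmul divided by kernel time. A quick sketch; the 1.5 ms timing here is hypothetical, chosen only because it lands near the reported ~90 TFLOPS, and is not a measurement from the linked script:

```python
# Effective throughput of a 4096x4096x4096 matmul from kernel time.
M = N = K = 4096
flops = 2 * M * N * K       # one multiply + one add per MAC
elapsed_s = 1.5e-3          # hypothetical kernel time, for illustration only
tflops = flops / elapsed_s / 1e12
print(f"{tflops:.1f} TFLOPS")  # prints "91.6 TFLOPS" for a 1.5 ms kernel
```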

masahi (Member) commented May 28, 2023

Auto tensorization on VK seems to work as well, but the result is not correct after tuning. For CUDA the result is correct after auto-tensorization tuning, so this is a VK-specific issue.

@mei-ye You can try the auto tensorization experiment using my branch https://github.com/masahi/tvm/tree/vk-auto-tensorize and this script https://github.com/masahi/tensorir-experiment/blob/vk-auto-tensorize/vk_cooperative_matrix_nv/test_4k.py. I'm curious whether the accuracy issue is specific to NV or applies to AMD as well.
