[MFMA] Support 64x4 and 4x64 tile size #469
Conversation
Force-pushed from ab65f73 to ae2eff0.
@@ -0,0 +1,800 @@
// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx908 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s
This file (and two others) is generated. I will add the generation scripts in the next PR; I just don't want to add too many things in one PR.
Force-pushed from ae2eff0 to 02fd385.
Force-pushed from 02fd385 to edada86.
This PR enables two new MxN tile sizes: 64x4 and 4x64. Both of them use mfma 4x4 instructions.
Force-pushed from edada86 to ceece88.
  {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
  // mfma_f32_4x4x1f32
- {{4, MfmaTypeId::Fp32TyId, 1},
+ {{4, 4, MfmaTypeId::Fp32TyId, 1},
  {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
I think I made a mistake here. It should be 4, 4, 1, 1 instead of 4, 4, 16, 1.
Oh, I realized this is for (4x16) x (16x4) --> 4x4. nvm
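For readers following the thread: the entry being discussed lives in the table that maps a requested mfma shape to the instruction implementing it. The sketch below is a minimal reconstruction under stated assumptions (`MfmaKey`, `MfmaInsn`, and the field meanings are illustrative guesses, not the file's actual types), meant only to show why k is 16 for `mfma_f32_4x4x1f32`.

```cpp
// Illustrative sketch of the mfma selection table; type and field names
// (MfmaKey, MfmaInsn, kBase, "version") are assumptions, not PR code.
#include <map>
#include <string>
#include <tuple>

enum class MfmaTypeId { Fp32TyId };

// Lookup key: {mDim, nDim, element type, mfma version (assumed meaning)}.
using MfmaKey = std::tuple<unsigned, unsigned, MfmaTypeId, unsigned>;

struct MfmaInsn {
  unsigned m, n;
  unsigned k;     // effective reduction depth of one issued instruction
  unsigned kBase; // k-elements each thread feeds per operand (assumed)
  std::string opName;
};

// mfma_f32_4x4x1f32 executes 16 independent 4x4x1 blocks per issue; chaining
// the blocks along the reduction dimension makes one issue compute a
// (4x16) x (16x4) -> 4x4 product, hence k = 16 despite the "x1" mnemonic.
const std::map<MfmaKey, MfmaInsn> mfmaTable = {
    {{16, 16, MfmaTypeId::Fp32TyId, 1},
     {16, 16, 4, 1, "rocdl.mfma.f32.16x16x4f32"}},
    {{4, 4, MfmaTypeId::Fp32TyId, 1},
     {4, 4, 16, 1, "rocdl.mfma.f32.4x4x1f32"}},
};
```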
lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp (resolved review thread)
@@ -1233,10 +1239,15 @@ class ConvertTritonGPUOpToLLVMPatternBase {
  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

  SmallVector<unsigned> numWarpsPerDim(2);
  unsigned mDim = mfmaLayout.getMDim();
  unsigned nDim = mfmaLayout.getNDim();
  assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
Do we really need this kind of assert everywhere? Can we check this only once, at the earliest point in the codegen?
I agree, this code is redundant most of the time. It just helps when adding new layouts, so we don't forget anything crucial.
You can set a new (unsupported) m/n combination in the accelerate-matmul pass, run the tests, and see where these asserts fire.
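To make the suggestion concrete, one way a single up-front check could look is sketched below. This is a hypothetical helper under assumptions (`verifyMfmaDims` is an invented name, and the skewed arm reflects the 64x4 / 4x64 shapes this PR adds), not code from the PR.

```cpp
// Hypothetical one-time validation at the entry of codegen, replacing the
// per-pattern asserts; verifyMfmaDims is an illustrative name, not PR code.
#include <cassert>

inline void verifyMfmaDims(unsigned mDim, unsigned nDim) {
  // Square shapes supported before this PR, plus the skewed shapes it adds.
  bool squareOk = (mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4);
  bool skewedOk = (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64);
  assert((squareOk || skewedOk) && "unsupported mfma mDim/nDim combination");
  (void)squareOk; // silence unused-variable warnings in release builds
  (void)skewedOk;
}
```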
I think this PR is good to go.
@scxiao @vgokhale
However, in the case of the FA decode kernel, the tile shape is 16x128, and mfma16 will be picked according to the heuristic. @alefimov-amd In the next PR, can you change …
Can you specify a unique value of …
Sure.
This is a useful idea, thank you!
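For readers without the pass open, the heuristic being discussed roughly prefers the largest instruction that fits the tile. The sketch below is an assumed approximation; the name `chooseMfmaDimensions` and the exact thresholds are illustrative, not the pass's actual logic.

```cpp
// Rough sketch of a shape-based mfma selection heuristic; thresholds and the
// chooseMfmaDimensions name are assumptions for illustration only.
#include <cstdint>
#include <utility>

std::pair<unsigned, unsigned> chooseMfmaDimensions(int64_t m, int64_t n) {
  if (m >= 32 && n >= 32)
    return {32, 32};
  if (m >= 16 && n >= 16)
    return {16, 16}; // a 16x128 FA decode tile lands here, as noted above
  if (m >= 64 && n < 16)
    return {64, 4};  // skewed shapes enabled by this PR
  if (m < 16 && n >= 64)
    return {4, 64};
  return {4, 4};
}
```

Under this reading, the exchange above asks for an override knob so a kernel like FA decode can force a different instruction than the one the heuristic picks.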
This PR refactors the logic of mfma instruction selection. It brings everything from ROCm#441 and parts of ROCm#469 so that we should have full support of mfma32 and mfma16 with all types. But support for mfma4 is not complete yet. We leave it to future PRs. Also in a future PR, we'll add tests for AMD f8 inputs.
This PR updates SharedToDotOperandMFMA.cpp and MFMA.cpp.
- SharedToDotOperandMFMA.cpp is up to date with triton-mlir as of today, which includes changes until ROCm#482
- Fixed an issue with opaque pointers
- Fixed the API for `getMFMAElemsPerInstrForOperands` and `getMFMARepForOperands`
- MFMA.cpp is synced with triton-mlir@6bb04d, which includes changes until ROCm#469

Note to @binarman: changes in other files from ROCm#469 are not included in this PR. We can bring up support for mfma 64x4 and 4x64 later.