
[MFMA] Support 64x4 and 4x64 tile size #469

Merged — 3 commits merged into ROCm:triton-mlir on Jan 22, 2024

Conversation

binarman

This PR enables two new MxN tile sizes: 64x4 and 4x64. Both of them use mfma 4x4 instructions.

@@ -0,0 +1,800 @@
// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx908 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s

@alefimov-amd Jan 17, 2024


This file (and 2 others) is generated. I will add the generation scripts in the next PR; I just don't want to put too many things in one PR.

{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
// mfma_f32_4x4x1f32
{{4, MfmaTypeId::Fp32TyId, 1},
{{4, 4, MfmaTypeId::Fp32TyId, 1},
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},

@zhanglx13 Jan 22, 2024


I think I made a mistake here. It should be 4, 4, 1, 1 instead of 4, 4, 16, 1.


Oh, I realized this is for (4x16) x (16x4) --> 4x4. nvm

@@ -1233,10 +1239,15 @@ class ConvertTritonGPUOpToLLVMPatternBase {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

SmallVector<unsigned> numWarpsPerDim(2);
unsigned mDim = mfmaLayout.getMDim();
unsigned nDim = mfmaLayout.getNDim();
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||


Do we really need this kind of assert everywhere? Can we only check once at the earliest from the codegen?

Author


I agree, this code is redundant most of the time. It just helps when adding new layouts, so we do not forget anything crucial.

You can set a new (unsupported) m/n combination in the accelerate-matmul pass, run the tests, and see where these asserts fire.


@zhanglx13 left a comment


I think this PR is good to go.

@zhanglx13

@scxiao @vgokhale
The heuristic for picking the mfma instruction size is as follows:

  1. If the result tile shape is larger than 32x32, pick mfma32
  2. If the tile shape is smaller than 32x32 but larger than 16x16, pick mfma16
  3. If the tile shape is smaller than 4x64 or 64x4, pick mfma4x4
  4. Otherwise, pick mfma4x64 or mfma64x4

However, in the case of the FA decode kernel, the tile shape is 16x128, so mfma16 will be picked according to the heuristic.
The tile shape refers to the result tensor shape of tt.dot. This heuristic does not take num_warps into consideration, but we do not have warp layout information when choosing mfma dimensions. Therefore, the only solution here is to enable some user input to enforce the choice of mfma4x64.

@alefimov-amd In the next PR, can you change chooseMfmaDimensions to pick 4x64 or 64x4 based on the tile shape when matrix_instr_nonkdim is 4?

scxiao commented Jan 22, 2024


Can you specify a unique value of chooseMfmaDimensions to choose mfma4x64 and mfma64x4, like 464 and 644?

@binarman (Author)

> In the next PR, can you change chooseMfmaDimensions to pick 4x64 or 64x4 based on the tile shape when matrix_instr_nonkdim is 4?

sure

> Can you specify a unique value of chooseMfmaDimensions to choose mfma4x64 and mfma64x4, like 464 and 644?

This is a useful idea, thank you!

@alefimov-amd merged commit 6bb04d1 into ROCm:triton-mlir on Jan 22, 2024
2 checks passed
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Mar 1, 2024
This PR refactors the logic of mfma instruction selection. It brings
everything from ROCm#441 and parts of
ROCm#469 so that we should have full
support of mfma32 and mfma16 with all types. But support for mfma4 is
not complete yet. We leave it to future PRs.
Also in a future PR, we'll add tests for AMD f8 inputs.
zhanglx13 added a commit to triton-lang/triton that referenced this pull request Mar 4, 2024
This PR updates SharedToDotOperandMFMA.cpp and MFMA.cpp.
- SharedToDotOperandMFMA.cpp is up to date with triton-mlir as of today,
which includes changes until ROCm#482
  - Fixed issue with opaque pointers
- Fixed API for `getMFMAElemsPerInstrForOperands` and
`getMFMARepForOperands`
- MFMA.cpp is synced with triton-mlir@6bb04d, which includes changes
until ROCm#469

Note to @binarman: changes in other files from
ROCm#469 are not included in this PR. We
can bring up the support for mfma 64x4 and 4x64 later.
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024