
[Backend] Refactor mfma selection #441

Merged: 5 commits into triton-mlir on Jan 17, 2024
Conversation

@zhanglx13 commented Dec 31, 2023

This PR refactors the logic of MFMA instruction selection:

  • Put the mapping from {mfmaVersion, elemTy, nonKDim} to mfmaInfo into a static table
  • Add versionMajor, versionMinor, MDim, and NDim to MfmaEncodingAttr and remove nonKDim
    • versionMajor is used to pick arch-specific mfma instructions.
    • versionMinor is not used for now.
    • MDim and NDim indicate the two output dimensions of the mfma instruction. This helps the mfma4 case, where the output tensor has shape 4 x 64.
  • Search the table the same way in both the AccelerateAMDMatmul pass and the DotOpToLLVM lowering pass
    • The mfmaVersion comes from the frontend in AccelerateAMDMatmul, and it is kept in the mfmaLayout.
    • The mfmaVersion comes from the mfmaLayout in DotOpToLLVM
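The table-driven selection described above might look like the following standalone sketch. The key fields mirror the bullet list; the `selectMfma` helper and the entry names are illustrative assumptions, not the actual backend code:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Hypothetical static table: {mfmaVersion, element type, MDim, NDim} -> mfma
// instruction name. The real table maps to richer per-instruction info.
using MfmaKey = std::tuple<int /*version*/, std::string /*elemTy*/,
                           int /*mDim*/, int /*nDim*/>;

static const std::map<MfmaKey, std::string> mfmaTable = {
    {{2, "f32", 32, 32}, "mfma_f32_32x32x2f32"},
    {{2, "f32", 16, 16}, "mfma_f32_16x16x4f32"},
    {{2, "f16", 32, 32}, "mfma_f32_32x32x8f16"},
    {{2, "f16", 16, 16}, "mfma_f32_16x16x16f16"},
};

// Both AccelerateAMDMatmul and DotOpToLLVM would search the table the same
// way; only the source of mfmaVersion differs (frontend vs. mfmaLayout).
std::string selectMfma(int version, const std::string &elemTy, int mDim,
                       int nDim) {
  auto it = mfmaTable.find({version, elemTy, mDim, nDim});
  if (it == mfmaTable.end())
    return ""; // no matching instruction for this combination
  return it->second;
}
```

Keeping the lookup in one shared table means the two passes cannot drift apart in how they pick an instruction, which is the point of the refactor.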

TODO:

  • Some logic in TritonGPU/IR/Dialect.cpp about mfma4 needs revisiting
  • For now, versionMajor=2 limits codegen to MI2xx GPUs. We may want to extend or replace it with versionMajor=3 for MI300 GPUs.

@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch 5 times, most recently from 277515d to 7573bd2 Compare January 2, 2024 17:14
@zhanglx13 zhanglx13 changed the title [Backend] Refactor mfma dim selection [Backend] Refactor mfma selection Jan 2, 2024
@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch 4 times, most recently from 597edd4 to 337e878 Compare January 6, 2024 02:54
@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch from 337e878 to 83bcc25 Compare January 9, 2024 21:55
@zhanglx13 zhanglx13 requested a review from jayfurmanek January 10, 2024 04:11
@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch from 83bcc25 to d0c32d5 Compare January 11, 2024 14:50
@alefimov-amd left a comment
Looks promising, but it conflicts with #432 in many places.

@zhanglx13 (Author)

@alefimov-amd
I'd suggest we merge this one first since it lays a solid foundation for mfma instructions. You can rebase your mfma4x64 PR and make changes accordingly. I believe this can benefit us in the long run.

@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch 2 times, most recently from 3993100 to ec07fe8 Compare January 15, 2024 21:02
@alefimov-amd left a comment

We discussed my concerns with @zhanglx13.

LGTM if this PR needs to be merged ASAP.

@zhanglx13 zhanglx13 changed the title [Backend] Refactor mfma selection [DO NOT MERGE][Backend] Refactor mfma selection Jan 15, 2024
@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch from ec07fe8 to 309dfd6 Compare January 16, 2024 02:32
@zhanglx13 zhanglx13 changed the title [DO NOT MERGE][Backend] Refactor mfma selection [Backend] Refactor mfma selection Jan 16, 2024
Comment on lines -352 to +365
-    if (mfmaLayout.getNonKDim() == 32) {
-      threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
-                 2 * mfmaLayout.getWarpsPerCTA()[1]};
-    } else {
-      threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
-                 4 * mfmaLayout.getWarpsPerCTA()[1]};
-    }
+    int mfmaMDim = mfmaLayout.getMDim();
+    SmallVector<unsigned> threadsPerWarp;
+    if (32 == mfmaMDim) {
+      threadsPerWarp = {2, 32};
+    } else {
+      threadsPerWarp = {4, 16};
+    }
+    if (mfmaLayout.getIsTransposed())
+      threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
+                 threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
+    else
+      threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
+                 threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};


This part changes behavior, is it intentional?

@zhanglx13 (Author)

@alefimov-amd Yes.
Previously the code did not take isTransposed into consideration.
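The new behavior can be sketched as a standalone function (hypothetical helper name and signature, not the Triton API): threadsPerWarp is chosen from MDim, and the two components are swapped when the layout is transposed, which the old code never did.

```cpp
#include <array>
#include <cassert>

// Sketch of the new per-warp thread layout selection. 32x32 output tiles
// use a 2x32 thread arrangement; 16x16 tiles use 4x16. The transposed
// layout swaps which component multiplies which warpsPerCTA dimension.
std::array<unsigned, 2> computeThreads(int mDim,
                                       std::array<unsigned, 2> warpsPerCTA,
                                       bool isTransposed) {
  std::array<unsigned, 2> threadsPerWarp =
      (mDim == 32) ? std::array<unsigned, 2>{2u, 32u}
                   : std::array<unsigned, 2>{4u, 16u};
  if (isTransposed)
    return {threadsPerWarp[1] * warpsPerCTA[0],
            threadsPerWarp[0] * warpsPerCTA[1]};
  return {threadsPerWarp[0] * warpsPerCTA[0],
          threadsPerWarp[1] * warpsPerCTA[1]};
}
```

For example, with mDim = 32 and warpsPerCTA = {4, 1}, the non-transposed layout yields {8, 32} threads while the transposed layout yields {128, 2}, which matches what the old code produced only for one of the two cases.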

@alefimov-amd left a comment

LGTM

@zhanglx13 zhanglx13 force-pushed the refactor_mfma_selection branch from 309dfd6 to e7943c7 Compare January 17, 2024 02:27
@zhanglx13 zhanglx13 merged commit 02a2f24 into triton-mlir Jan 17, 2024
2 checks passed
oplavsic pushed a commit that referenced this pull request Jan 17, 2024
oplavsic pushed a commit that referenced this pull request Jan 18, 2024
oplavsic pushed a commit that referenced this pull request Jan 22, 2024
jtang10 pushed a commit that referenced this pull request Jan 30, 2024
* Select mfma dimensions and instruction from static table

* Extend mfmaLayout to include version and instrShape

* Simplify generateMFMAOp by searching the mfma instruction in the table

* Fix getNonKDim() and non_k_dim

* Break instrShape into MDim and NDim
zhanglx13 added a commit to triton-lang/triton that referenced this pull request Feb 29, 2024
- Add `versionMajor`, `versionMinor`, `MDim`, and `NDim` to MfmaEncodingAttr and remove `nonKDim`
  - `versionMajor` is used to pick arch-specific mfma instructions.
  - `versionMinor` is not used for now.
  - `MDim` and `NDim` indicate the two output dimensions of the mfma instruction. This helps the mfma4 case, where the output tensor has shape 4 x 64.
zhanglx13 added a commit to triton-lang/triton that referenced this pull request Feb 29, 2024
- Add mfma instruction table. The proper mfma instruction can be
selected by
  - mDim and nDim
  - elemType: this includes both typeA and typeB, since some mfma
  instructions support mixed input types, such as bf8 and fp8
  - mfmaVersion: arch version
- Clean up and rename Utility.h to MfmaGroup.h
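The mixed-input lookup this commit message describes could be sketched like so. The key carries both typeA and typeB so that mixed-type instructions resolve to distinct entries; every name below is an illustrative assumption, not the actual MfmaGroup.h code:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Hypothetical table keyed on {mfmaVersion, mDim, nDim, typeA, typeB}.
// Keying on both operand types lets mixed-input instructions (e.g. fp8 A
// operand with bf8 B operand) get their own entries.
using MixedKey = std::tuple<int /*mfmaVersion*/, int /*mDim*/, int /*nDim*/,
                            std::string /*typeA*/, std::string /*typeB*/>;

static const std::map<MixedKey, std::string> mixedTable = {
    {{3, 16, 16, "fp8", "fp8"}, "mfma_f32_16x16x32_fp8_fp8"},
    {{3, 16, 16, "fp8", "bf8"}, "mfma_f32_16x16x32_fp8_bf8"},
    {{3, 16, 16, "bf8", "fp8"}, "mfma_f32_16x16x32_bf8_fp8"},
    {{3, 16, 16, "bf8", "bf8"}, "mfma_f32_16x16x32_bf8_bf8"},
};

std::string selectMixedMfma(int version, int mDim, int nDim,
                            const std::string &typeA,
                            const std::string &typeB) {
  auto it = mixedTable.find({version, mDim, nDim, typeA, typeB});
  if (it == mixedTable.end())
    return ""; // this {version, shape, typeA, typeB} combination is unsupported
  return it->second;
}
```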
zhanglx13 added a commit to triton-lang/triton that referenced this pull request Feb 29, 2024
- Refactor mfma selection logic in shape selection phase
zhanglx13 added a commit to triton-lang/triton that referenced this pull request Feb 29, 2024
- Refactor mfma selection logic in codegen phase
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Mar 1, 2024
This PR refactors the logic of mfma instruction selection. It brings everything from ROCm#441 and parts of ROCm#469, so we should have full support of mfma32 and mfma16 with all types. Support for mfma4 is not complete yet; we leave it to future PRs. In a future PR, we'll also add tests for AMD f8 inputs.
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024
Labels: None yet
Projects: None yet
3 participants