[Backend] Refactor mfma selection #441
looks promising, but it conflicts with #432 in many places.
@alefimov-amd
We discussed my concerns with @zhanglx13.
LGTM if this PR needs to be merged ASAP.
-    if (mfmaLayout.getNonKDim() == 32) {
-      threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
-                 2 * mfmaLayout.getWarpsPerCTA()[1]};
+    int mfmaMDim = mfmaLayout.getMDim();
+    SmallVector<unsigned> threadsPerWarp;
+    if (32 == mfmaMDim) {
+      threadsPerWarp = {2, 32};
     } else {
-      threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
-                 4 * mfmaLayout.getWarpsPerCTA()[1]};
+      threadsPerWarp = {4, 16};
     }
+    if (mfmaLayout.getIsTransposed())
+      threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
+                 threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
+    else
+      threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
+                 threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};
This part changes behavior, is it intentional?
@alefimov-amd Yes. Previously the code did not take `isTransposed` into consideration.
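To make the behavior change concrete, here is a minimal standalone sketch of the new thread-count computation from the diff. `MfmaLayoutSketch` and `threadsPerCTATile` are hypothetical names introduced only for illustration; they stand in for the `getMDim()`, `getIsTransposed()`, and `getWarpsPerCTA()` accessors used in the diff and are not part of the actual codebase.

```cpp
#include <array>

// Hypothetical stand-in for the few mfmaLayout getters used in the diff.
struct MfmaLayoutSketch {
  unsigned mDim;                        // mirrors getMDim(): 32 or 16 here
  bool isTransposed;                    // mirrors getIsTransposed()
  std::array<unsigned, 2> warpsPerCTA;  // mirrors getWarpsPerCTA()
};

// Same arithmetic as the added lines in the diff: pick a per-warp thread
// shape from mDim, then swap its two components when the layout is transposed.
std::array<unsigned, 2> threadsPerCTATile(const MfmaLayoutSketch &layout) {
  std::array<unsigned, 2> threadsPerWarp =
      (layout.mDim == 32) ? std::array<unsigned, 2>{2, 32}
                          : std::array<unsigned, 2>{4, 16};
  if (layout.isTransposed)
    return {threadsPerWarp[1] * layout.warpsPerCTA[0],
            threadsPerWarp[0] * layout.warpsPerCTA[1]};
  return {threadsPerWarp[0] * layout.warpsPerCTA[0],
          threadsPerWarp[1] * layout.warpsPerCTA[1]};
}
```

Under this sketch, the transposed case reproduces the removed formulas (`32 * warps[0], 2 * warps[1]` for mDim 32 and `16 * warps[0], 4 * warps[1]` for mDim 16), while the non-transposed case yields the swapped factors. That difference appears to be the behavior change discussed above.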
LGTM
This reverts commit 02a2f24.
* Select mfma dimensions and instruction from static table
* Extend mfmaLayout to include version and instrShape
* Simplify generateMFMAOp by searching the mfma instruction in the table
* Fix getNonKDim() and non_k_dim
* Break instrShape into MDim and NDim
- Add `versionMajor`, `versionMinor`, `MDim`, and `NDim` to MfmaEncodingAttr and remove `nonKDim`
  - `versionMajor` is used to pick arch-specific mfma instructions.
  - `versionMinor` is not used for now.
  - `MDim` and `NDim` indicate the two dimensions of the output of the mfma instruction. This helps the mfma4 case where the output tensor has shape 4 x 64.
- Add mfma instruction table. The proper mfma instruction can be selected by
  - mDim and nDim
  - elemType: this includes both typeA and typeB, since some mfma instructions support mixed input types, such as bf8 and fp8
  - mfmaVersion: arch version
- Clean up and rename Utility.h to MfmaGroup.h
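As a rough illustration of the table-driven selection described in this commit message, the sketch below keys a static table on (mDim, nDim, typeA, typeB, version) and searches it linearly. Every name here (`ElemType`, `MfmaKey`, `MfmaEntry`, `mfmaTable`, `selectMfma`) is an assumption made for the example, the instruction strings are placeholders rather than real mnemonics, and the version numbers attached to each row are illustrative only. The actual table in MfmaGroup.h may be organized quite differently.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical element types; the real code works with MLIR types.
enum class ElemType { F32, F16, BF16, FP8, BF8 };

// Hypothetical lookup key: output dims, both operand types, arch version.
struct MfmaKey {
  unsigned mDim;
  unsigned nDim;
  ElemType typeA;
  ElemType typeB;
  int version;
};

bool operator==(const MfmaKey &l, const MfmaKey &r) {
  return l.mDim == r.mDim && l.nDim == r.nDim && l.typeA == r.typeA &&
         l.typeB == r.typeB && l.version == r.version;
}

struct MfmaEntry {
  MfmaKey key;
  std::string instrName;  // placeholder string, NOT a real mnemonic
};

// Illustrative rows only; contents and version numbers are assumptions.
const std::vector<MfmaEntry> &mfmaTable() {
  static const std::vector<MfmaEntry> table = {
      {{32, 32, ElemType::F16, ElemType::F16, 2}, "placeholder_32x32_f16"},
      {{16, 16, ElemType::F16, ElemType::F16, 2}, "placeholder_16x16_f16"},
      // Mixed operand types are why typeA and typeB are separate key fields.
      {{32, 32, ElemType::FP8, ElemType::BF8, 3}, "placeholder_32x32_fp8_bf8"},
  };
  return table;
}

// Returns the matching instruction name, or nullopt if none is listed.
std::optional<std::string> selectMfma(const MfmaKey &want) {
  for (const MfmaEntry &entry : mfmaTable())
    if (entry.key == want)
      return entry.instrName;
  return std::nullopt;
}
```

Returning `std::nullopt` for an unlisted combination lets a caller fall back to a non-mfma path instead of asserting. Whether the real lowering does that is not stated in this PR.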
This PR refactors the logic of mfma instruction selection. It brings in everything from ROCm#441 and parts of ROCm#469, so we should now have full support for mfma32 and mfma16 with all types. Support for mfma4 is not complete yet and is left to future PRs. Tests for AMD f8 inputs will also be added in a future PR.
This PR refactors the logic of MFMA selection

- Add `versionMajor`, `versionMinor`, `MDim`, and `NDim` to MfmaEncodingAttr and remove `nonKDim`
  - `versionMajor` is used to pick arch-specific mfma instructions.
  - `versionMinor` is not used for now.
  - `MDim` and `NDim` indicate the two dimensions of the output of the mfma instruction. This helps the mfma4 case where the output tensor has shape 4 x 64.
- `AccelerateAMDMatmul` pass and `DotOpToLLVM` lowering pass
  - `mfmaVersion` comes from the frontend in `AccelerateAMDMatmul`, and it is kept in the mfmaLayout.
  - `mfmaVersion` comes from the mfmaLayout in `DotOpToLLVM`.

TODO:

- `versionMajor=2` limits the codegen to MI2xx GPUs. We might want to extend/replace it to/with `versionMajor=3` for MI300 GPUs.
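The version hand-off described above can be sketched as follows. This is only an illustration under stated assumptions: `MfmaEncodingSketch`, `makeEncoding`, `versionForLowering`, and `isMfma4Shape` are hypothetical names, and the real MfmaEncodingAttr is an MLIR attribute with more parameters than shown here.

```cpp
// Hypothetical plain-struct mirror of the fields this PR adds to the encoding.
struct MfmaEncodingSketch {
  int versionMajor;  // used to pick arch-specific mfma instructions
  int versionMinor;  // carried along but not used for now
  unsigned mDim;     // first dimension of the mfma output
  unsigned nDim;     // second dimension of the mfma output
};

// AccelerateAMDMatmul side (sketch): the version arrives from the frontend
// and is stored in the layout.
MfmaEncodingSketch makeEncoding(int mfmaVersionFromFrontend, unsigned mDim,
                                unsigned nDim) {
  return {mfmaVersionFromFrontend, /*versionMinor=*/0, mDim, nDim};
}

// DotOpToLLVM side (sketch): the lowering reads the version back from the
// layout instead of receiving it from the frontend again.
int versionForLowering(const MfmaEncodingSketch &enc) {
  return enc.versionMajor;
}

// A single nonKDim value cannot describe a 4 x 64 output, which is why the
// encoding carries MDim and NDim separately.
bool isMfma4Shape(const MfmaEncodingSketch &enc) {
  return enc.mDim == 4 && enc.nDim == 64;
}
```

Keeping the version in the layout means every later pass that sees the encoding can recover it without extra plumbing, which is the design the description outlines.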