[AMD][Navi31] Support WMMA transformations in AccelerateAMDMatmul pass #3309
Conversation
Can be merged after "This PR enables whole pipeline for wmma lowering".

Force-pushed from 51b1bf2 to 287d227
```cpp
MatrixCoreVersion getMatrixCoreVersion(StringRef archGen) {
  if (archGen.contains("gfx11"))
    return MatrixCoreVersion::RDNA_WMMA;
```
Is there a reason why we want to pass `archGen` all the way here to decide on the matrix core version? I thought it's enough to use `matrix_core_version` from the frontend. (cc @zhanglx13)
The RDNA3 arch introduces `wmma` instructions that are not technically a new version of the matrix core (which is only applicable to CDNA), so they can't be described by the `matrix_core_version` parameter. Reserving some value for `wmma` could be dangerous for future versions of CDNA; that's why I'm passing `archGen`. If you have any suggestions, please share.
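To make the concern concrete, here is a minimal sketch of how the capability could be modeled as an enum instead of a reserved integer; the CDNA member names are assumptions for illustration, only `RDNA_WMMA` appears in the PR itself:

```cpp
// Hypothetical sketch: modeling the capability as an enum keeps RDNA's
// WMMA out of the CDNA version sequence, so future CDNA versions can be
// appended without colliding with a reserved WMMA value.
enum class MatrixCoreVersion {
  NONE,
  CDNA_MFMA1, // gfx908
  CDNA_MFMA2, // gfx90a
  CDNA_MFMA3, // gfx940/941/942
  RDNA_WMMA   // gfx11xx
};
```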
One way to unify wmma and mfma in terms of version is to use the gfx number directly, which corresponds to how AMD tracks hardware versions today and in the future. And since we don't use capability on the AMD path anyway, we can just use the gfx number as the version here:
- 908 <-- matrix_core_version 1
- 90a <-- matrix_core_version 2
- 940, 941, 942 <-- matrix_core_version 3
- 1100 <-- wmma

@zahimoud what do you think?
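A minimal sketch of that gfx-number scheme, assuming the mapping listed above (the helper name and return convention are illustrative, not actual PR code):

```cpp
#include <llvm/ADT/StringRef.h>

// Hypothetical sketch: derive the version directly from the gfx number
// in the arch string, per the mapping proposed above.
static int getVersionFromArch(llvm::StringRef arch) {
  if (arch.contains("gfx908"))
    return 1; // matrix_core_version 1
  if (arch.contains("gfx90a"))
    return 2; // matrix_core_version 2
  if (arch.contains("gfx940") || arch.contains("gfx941") ||
      arch.contains("gfx942"))
    return 3; // matrix_core_version 3
  if (arch.contains("gfx11"))
    return 1100; // wmma, per the note above
  return 0; // no matrix core support (assumed fallback)
}
```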
If we want an `int`, we can use 910 for gfx90a.
Yeah, I'm fine with passing the arch to the backend. I would just pass it to the `TargetInfo` object; in the constructor we can map it to an enum class for the arch (probably `MatrixCoreVersion`), and we would have an API like `targetInfo.getMfmaVersion()` to get the mfma version directly.
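A sketch of that idea, reusing the hypothetical `MatrixCoreVersion` enum from earlier (class shape and method names are assumptions based on this comment, not the PR's actual code):

```cpp
#include <llvm/ADT/StringRef.h>

// Hypothetical sketch: the constructor maps the arch string to the enum
// once, and passes query TargetInfo instead of parsing arch strings.
class TargetInfo {
public:
  explicit TargetInfo(llvm::StringRef arch) {
    if (arch.contains("gfx908"))
      version = MatrixCoreVersion::CDNA_MFMA1;
    else if (arch.contains("gfx90a"))
      version = MatrixCoreVersion::CDNA_MFMA2;
    else if (arch.contains("gfx94")) // gfx940/941/942
      version = MatrixCoreVersion::CDNA_MFMA3;
    else if (arch.contains("gfx11"))
      version = MatrixCoreVersion::RDNA_WMMA;
    else
      version = MatrixCoreVersion::NONE;
  }

  // As proposed above; returns 0 on non-MFMA targets, including
  // RDNA/WMMA, which have no MFMA units.
  int getMfmaVersion() const {
    switch (version) {
    case MatrixCoreVersion::CDNA_MFMA1: return 1;
    case MatrixCoreVersion::CDNA_MFMA2: return 2;
    case MatrixCoreVersion::CDNA_MFMA3: return 3;
    default: return 0;
    }
  }

private:
  MatrixCoreVersion version;
};
```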
Well, right now it is part of the TritonAMDGPUToLLVM CMake target. I think it would be better to create a new target, AMDTargetInfo, instead of adding a dependency between TritonAMDGPUTransforms and TritonAMDGPUToLLVM... WDYT?
Yeah, let's move it to amd/include. Would that work?
From what I can see, we also need to move Utility to a separate target, and there may be other NVIDIA-specific pitfalls.
Would it be OK to do that in a separate PR? I can create an issue for it.
Which utility are you referring to, and why do you think we should move it?
https://github.com/openai/triton/blob/main/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp is the one that should be moved to a separate target to avoid a cyclic dependency.
I think it really needs to be done, but this PR is not about that.
Force-pushed from 5964c3a to 5603f47
third_party/amd/backend/compiler.py
Outdated
```python
if (self.matrix_core_version == -1):
    object.__setattr__(self, 'matrix_core_version', self.get_matrix_core_version(self.arch))
if (self.mfma_version == -1):
    object.__setattr__(self, 'mfma_version', self.get_mfma_version(self.arch))
```
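For context, a minimal sketch of what a `get_mfma_version` helper might look like, assuming the version mapping discussed earlier in the thread (this is not the PR's actual code); note that Navi (gfx11) maps to 0, which is what the stream-pipeline question below is about:

```python
def get_mfma_version(arch: str) -> int:
    # Illustrative mapping only, based on the version table above.
    # Navi/RDNA3 (gfx11) has no MFMA units, so it maps to 0.
    if "gfx908" in arch:
        return 1
    if "gfx90a" in arch:
        return 2
    if "gfx94" in arch:  # gfx940/941/942
        return 3
    return 0
```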
So on the Navi arch this is set to `mfma_version = 0`, and then the stream-pipeline pass is disabled.
Is this the expected behavior on Navi?
Good point, thanks.
I just checked again: stream pipelining works fine.
Please check the updated https://github.com/openai/triton/pull/3309/files#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72R39-R44
- Change the input option for the AccelerateAMDMatmul pass from matrix-core-version to arch-generation-name
- Transform dot
- Transform dot operands to the required layouts
- Support datatype conversion
- Enable the stream-pipeline pass for WMMA
- Add a lit test for the AccelerateAMDMatmul pass for the WMMA case

Signed-off-by: joviliast <iveselov.nn@gmail.com>
Force-pushed from 1ce9a46 to 2646616
[AMD][Navi31] Support WMMA transformations in AccelerateAMDMatmul pass (triton-lang#3309)
- Transform dot
- Transform dot operands to the required layouts
- Support datatype conversion
- Add a lit test for the AccelerateAMDMatmul pass for the WMMA case

Signed-off-by: joviliast <iveselov.nn@gmail.com>