
[AMD][Navi31] Support WMMA transformations in AccelerateAMDMatmul pass #3309

Merged: 2 commits into triton-lang:main from wmma-accelerate-amd-matmul, Mar 15, 2024

Conversation

joviliast (Contributor)

- Transform dot
- Transform dot operands to required layouts
- Support datatype conversion
- Add lit test for the AccelerateAMDMatmul pass for the WMMA case

@joviliast (Contributor Author)

Can be merged after
#3170
#3171
#3308

This PR enables the whole pipeline for WMMA lowering; tested on the WIP branch: https://github.com/joviliast/triton/tree/wmma-upstream-wip

@joviliast joviliast force-pushed the wmma-accelerate-amd-matmul branch 5 times, most recently from 51b1bf2 to 287d227 Compare March 11, 2024 11:07

MatrixCoreVersion getMatrixCoreVersion(StringRef archGen) {
  if (archGen.contains("gfx11"))
    return MatrixCoreVersion::RDNA_WMMA;
  // ... (excerpt from the PR diff; remaining cases elided)
Contributor

Is there a reason why we want to pass archGen all the way here to decide on the matrix core version? I thought it's enough to use matrix_core_version from the frontend. (cc @zhanglx13)

Contributor Author

The RDNA3 arch introduces WMMA instructions that are not technically a new version of the matrix core (that concept applies only to CDNA), so it can't be described by the matrix_core_version parameter.
Reserving some value for WMMA could be dangerous for future versions of CDNA; that's why I'm passing archGen.
If you have any suggestions, please share.

Collaborator @zhanglx13 (Mar 12, 2024)

One way to unify WMMA and MFMA in terms of version is to use the gfx number directly, which corresponds to how AMD tracks hardware versions today and in the future.
And since we don't use capability on the AMD path anyway, we can just use the gfx number as the version here:

  • 908 <-- matrix_core_version 1
  • 90a <-- matrix_core_version 2
  • 940, 941, 942 <-- matrix_core_version 3
  • 1100 <-- wmma

@zahimoud what do you think?
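
A minimal sketch of that mapping (the function name and the sentinel chosen for WMMA are illustrative assumptions, not code from this PR):

#include "llvm/ADT/StringRef.h"

// Map the arch string to a version number, following the gfx-number idea
// above: CDNA parts keep their matrix_core_version, RDNA3 is marked as WMMA.
int getVersionFromArch(llvm::StringRef archGen) {
  if (archGen.contains("gfx908"))
    return 1;    // matrix_core_version 1
  if (archGen.contains("gfx90a"))
    return 2;    // matrix_core_version 2
  if (archGen.contains("gfx940") || archGen.contains("gfx941") ||
      archGen.contains("gfx942"))
    return 3;    // matrix_core_version 3
  if (archGen.contains("gfx11"))
    return 1100; // RDNA3: WMMA rather than MFMA
  return 0;      // no matrix core support
}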

Collaborator

If we want an int, we can use 910 for gfx90a.

Contributor

Yeah, I'm fine with passing arch to the backend. I would just pass it to the TargetInfo object; in the constructor we can map it to an enum class of the arch (probably MatrixCoreVersion), and we would have an API like targetInfo.getMfmaVersion() to get the MFMA version directly.
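
A hedged sketch of that TargetInfo shape (all names here are illustrative assumptions, not Triton's actual API):

#include <string>

// Illustrative arch enum, built once in the TargetInfo constructor.
enum class MatrixCoreVersion { NONE, CDNA_MFMA1, CDNA_MFMA2, CDNA_MFMA3, RDNA_WMMA };

class TargetInfo {
public:
  explicit TargetInfo(const std::string &arch) : version(mapArch(arch)) {}

  // Returns 0 for targets without MFMA (e.g. RDNA3, which uses WMMA).
  int getMfmaVersion() const {
    switch (version) {
    case MatrixCoreVersion::CDNA_MFMA1: return 1;
    case MatrixCoreVersion::CDNA_MFMA2: return 2;
    case MatrixCoreVersion::CDNA_MFMA3: return 3;
    default: return 0;
    }
  }

private:
  static MatrixCoreVersion mapArch(const std::string &arch) {
    if (arch.find("gfx908") != std::string::npos) return MatrixCoreVersion::CDNA_MFMA1;
    if (arch.find("gfx90a") != std::string::npos) return MatrixCoreVersion::CDNA_MFMA2;
    if (arch.find("gfx94") != std::string::npos)  return MatrixCoreVersion::CDNA_MFMA3;
    if (arch.find("gfx11") != std::string::npos)  return MatrixCoreVersion::RDNA_WMMA;
    return MatrixCoreVersion::NONE;
  }

  MatrixCoreVersion version;
};

Parsing the arch string once in the constructor and switching on the enum afterwards keeps the string matching in a single place.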

Contributor Author @joviliast (Mar 14, 2024)

Well, right now it is part of the TritonAMDGPUToLLVM CMake target.
I think it would be better to create a new target, AMDTargetInfo, instead of adding a dependency between TritonAMDGPUTransforms and TritonAMDGPUToLLVM... WDYT?

Contributor

Yeah, let's move it to amd/include; would that work?

Contributor Author

As far as I can see, we would also need to move Utility to a separate target, and maybe there are other NVIDIA-specific pitfalls.
Would it be OK to do that in a separate PR? I can create an issue for it.

Contributor

Which utility are you referring to, and why do you think we should move it?

Contributor Author

https://github.com/openai/triton/blob/main/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp: this one should be moved to a separate target to avoid a cyclic dependency.

I think it really needs to be done, but this PR is not about that.

@joviliast joviliast force-pushed the wmma-accelerate-amd-matmul branch from 5964c3a to 5603f47 Compare March 14, 2024 12:05
@joviliast joviliast marked this pull request as ready for review March 14, 2024 12:05
@joviliast joviliast requested a review from zahimoud March 14, 2024 14:21
@joviliast joviliast requested a review from zhanglx13 March 15, 2024 11:04
# Resolve -1 (unset) defaults from the target arch; object.__setattr__ is
# used because the options object is a frozen dataclass.
if self.matrix_core_version == -1:
    object.__setattr__(self, 'matrix_core_version', self.get_matrix_core_version(self.arch))
if self.mfma_version == -1:
    object.__setattr__(self, 'mfma_version', self.get_mfma_version(self.arch))
Collaborator

So on the Navi arch this is set to mfma_version = 0, and the stream-pipeline pass is then disabled.
Is this the expected behavior on Navi?

Contributor Author

Good point, thanks.
Just checked again: the stream pipeline works fine.
Please check the updated https://github.com/openai/triton/pull/3309/files#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72R39-R44

@joviliast joviliast requested review from zhanglx13 and zahimoud March 15, 2024 17:14
- Change the input option for the AccelerateAMDMatmul pass from matrix-core-version to arch-generation-name
- Transform dot
- Transform dot operands to required layouts
- Support datatype conversion
- Enable the stream pipeline pass for WMMA
- Add lit test for the AccelerateAMDMatmul pass for the WMMA case

Signed-off-by: joviliast <iveselov.nn@gmail.com>
@joviliast joviliast force-pushed the wmma-accelerate-amd-matmul branch from 1ce9a46 to 2646616 Compare March 15, 2024 17:25
@zhanglx13 zhanglx13 merged commit ce74d42 into triton-lang:main Mar 15, 2024
5 checks passed
karupayun pushed a commit to openxla/triton that referenced this pull request Apr 3, 2024
[AMD][Navi31] Support WMMA transformations in AccelerateAMDMatmul pass (triton-lang#3309)

- Transform dot
- Transform dot operands to required layouts
- Support datatype conversion
- Add lit test for the AccelerateAMDMatmul pass for the WMMA case

Signed-off-by: joviliast <iveselov.nn@gmail.com>