
[AMD] [MFMA] Support dot3d in MFMA layout #3600

Merged
merged 8 commits from enable_dot3d_mfma into triton-lang:main on Apr 9, 2024

Conversation

binarman
Contributor

@binarman binarman commented Apr 8, 2024

  • Support 3d tensor when emitting offsets for mfma layouts
  • Support 3d tensors in Shared to dot operand conversion
  • Support dot3d in Dialect.cpp
  • Replace amd::DecomposeConversion with common::ReduceDataDuplication
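As a rough illustration of the first bullet (hypothetical helper name, not the actual Triton code): emitting offsets for a 3D MFMA layout can be thought of as taking the 2D per-tile offsets and replicating them across a leading batch dimension.

```python
def emit_offsets_3d(batch_size, offsets_2d):
    """Sketch: replicate 2D MFMA offsets across a leading batch dimension.

    offsets_2d is a list of (row, col) offsets as produced by the existing
    2D layout logic; the 3D case prepends the batch index to each offset.
    Hypothetical helper -- not the actual Triton implementation.
    """
    return [(b, i, j) for b in range(batch_size) for (i, j) in offsets_2d]
```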

@binarman binarman force-pushed the enable_dot3d_mfma branch from 352331c to 8ca5404 Compare April 8, 2024 15:18
@binarman binarman marked this pull request as ready for review April 8, 2024 21:02
@binarman binarman requested review from Jokeren and ptillet as code owners April 8, 2024 21:02
@binarman binarman force-pushed the enable_dot3d_mfma branch from 8f989af to 0466f46 Compare April 8, 2024 21:05
@binarman
Contributor Author

binarman commented Apr 8, 2024

Notes:

This PR is a continuation of #3298

This PR fixes dot3d for the MFMA layout only; I am going to prepare an additional patch for the WMMA (Navi) layout.

@@ -23,7 +23,7 @@ class HIPOptions:
arch: str = None
allow_fp8e4nv: bool = False
default_dot_input_precision: str = "ieee"
-allowed_dot_input_precisions: Tuple[str] = ("ieee", )
+allowed_dot_input_precisions: Tuple[str] = ("tf32", "ieee")
Contributor Author

@zhanglx13
I have a question about this part.
I see two possibilities here:

  1. assume that TF32 is an optional low-precision mode of float32, so we can fall back to ordinary float32 even when TF32 is requested (this is what happens in this PR)
  2. treat TF32 as unsupported by the AMD backend and simply skip the related tests

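Option 1 could be sketched like this (hypothetical helper, not the actual HIP backend logic): the backend accepts "tf32" as an allowed dot input precision but lowers it to plain IEEE float32.

```python
def resolve_dot_input_precision(requested, allowed=("tf32", "ieee")):
    """Sketch of option 1: accept tf32 but compute in ordinary fp32.

    Hypothetical helper -- not the actual HIPOptions handling. The AMD
    hardware targeted here has no tf32 mode, so tf32 silently falls back
    to ieee float32 instead of being rejected.
    """
    if requested not in allowed:
        raise ValueError(f"unsupported dot input precision: {requested}")
    return "ieee" if requested == "tf32" else requested
```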
Collaborator

From what I saw in test_dot, tf32 is skipped, so I assume it's ok to just skip them on the AMD backend.

@@ -3181,8 +3181,8 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'),
('float16', 'float32'), ('float32', 'float32')])
def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device):
if is_hip():
pytest.skip('TODO test_dot3d not supported on HIP.')
if in_dtype_str == 'int8' and is_interpreter():
Collaborator

@zhanglx13 zhanglx13 Apr 8, 2024

Is this also true for nv path? Then why was it not caught before?

Contributor Author

This is a new change: #3566

This part leaked in here during a rebase; I will remove it.

const int uniqueValuesPerWarp = 4;
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
}
Value laneId = urem(threadId, effectiveWarpSize);

// Note: here we assume warpId goes along the M dim first
Collaborator

This is not the case anymore. We should remove it.
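The lane-id computation in the snippet above can be sketched in plain Python (assuming a linear thread numbering within the block; names are hypothetical):

```python
def lane_id(thread_id, warp_size=64, unique_values_per_warp=None):
    """Sketch of the snippet above: the lane id is the thread id modulo
    the effective warp size. When only a few values per warp are unique
    (e.g. 4 in the quoted code), the effective warp size shrinks so that
    threads holding duplicate values map onto the same lane id.

    Mirrors `laneId = urem(threadId, effectiveWarpSize)`; not the actual
    Triton lowering code.
    """
    effective_warp_size = unique_values_per_warp or warp_size
    return thread_id % effective_warp_size
```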

zhanglx13 and others added 7 commits April 9, 2024 18:51
- Support 3d tensor when emitting offsets for mfma layouts
- Support 3d tensors in Shared to dot operand conversion
- Support dot3d in Dialect.cpp
- Replace amd::DecomposeConversion with common::ReduceDataDuplication
@binarman binarman force-pushed the enable_dot3d_mfma branch from 955c17d to fae2f6e Compare April 9, 2024 18:51
@@ -1424,7 +1428,12 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
SmallVector<unsigned>
AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
auto nonKDim = getMDim();
Contributor

Should we be using mDim and nDim here ?

Contributor Author

Good idea!
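A minimal sketch of what getShapePerCTATile could look like with mDim and nDim used separately instead of a single nonKDim (hypothetical Python translation; the warp-layout assumptions are mine, not the actual C++ code):

```python
def shape_per_cta_tile(m_dim, n_dim, warps_per_cta):
    """Sketch: per-CTA tile shape for an MFMA layout.

    Assumes warps_per_cta lists the warps along each tensor dim, with the
    last two dims being M and N; a leading batch dim, if present, is
    tiled by warps only (one MFMA tile never spans the batch dim).
    """
    rank = len(warps_per_cta)
    shape = [m_dim * warps_per_cta[-2], n_dim * warps_per_cta[-1]]
    if rank == 3:
        shape.insert(0, warps_per_cta[0])
    return shape
```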

int64_t mfmaInstrK;
// TODO(Lixun): make it simpler
// getMFMAInstrShapeForOperands always returns a 2D vector
if (rank == 3) {
Contributor

I thought we were going to do this: #3298 (comment)

Contributor Author

Let's leave this for later.
I am going to enable dot3d for the WMMA layout as well; after that I'll try to refactor them uniformly.
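Since getMFMAInstrShapeForOperands always returns a 2D vector, the rank-3 handling in the quoted snippet amounts to prepending a unit batch dimension. A sketch (hypothetical helper name):

```python
def instr_shape_for_rank(instr_shape_2d, rank):
    """Sketch: pad the 2D MFMA instruction shape with a unit batch dim
    when the dot operates on rank-3 tensors, since a single MFMA
    instruction never spans the batch dimension. Not the actual
    Dialect.cpp code.
    """
    if rank == 3:
        return [1] + list(instr_shape_2d)
    return list(instr_shape_2d)
```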

Contributor

Are you enabling dot3d for wmma layout in this PR ?

Contributor Author

No, this will be a separate PR.

In my opinion it is easier to review and fix things step by step.

@zahimoud zahimoud merged commit 3c2f88b into triton-lang:main Apr 9, 2024
5 checks passed