[AMD] [MFMA] Support dot3d in MFMA layout #3600
Conversation
binarman
commented
Apr 8, 2024
- Support 3d tensor when emitting offsets for mfma layouts
- Support 3d tensors in Shared to dot operand conversion
- Support dot3d in Dialect.cpp
- Replace amd::DecomposeConversion with common::ReduceDataDuplication
Force-pushed from 352331c to 8ca5404
Force-pushed from 8f989af to 0466f46
Notes: This PR is a continuation of #3298. It fixes dot3d for the MFMA layout only; I am going to prepare an additional patch for the WMMA (Navi) layout.
third_party/amd/backend/compiler.py
Outdated
@@ -23,7 +23,7 @@ class HIPOptions:
     arch: str = None
     allow_fp8e4nv: bool = False
     default_dot_input_precision: str = "ieee"
-    allowed_dot_input_precisions: Tuple[str] = ("ieee", )
+    allowed_dot_input_precisions: Tuple[str] = ("tf32", "ieee")
@zhanglx13
I have a question about this part.
I see two possibilities here:
- assume that TF32 is an optional low-precision mode of float32, so ordinary float32 can be used even when TF32 is requested (this is what happens in this PR)
- treat TF32 as unsupported by the AMD backend and simply skip the related tests
From what I saw in test_dot, tf32 is skipped, so I assume it's ok to just skip those tests on the AMD backend as well.
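The skip logic discussed here could be sketched as a small predicate. The helper name and the set of tf32-capable backends are assumptions for illustration:

```python
# Hypothetical sketch: decide whether a tf32 test case should be skipped
# on a given backend. Assumption: only the NVIDIA ("cuda") path runs tf32.
def should_skip_tf32(input_precision: str, backend: str) -> bool:
    tf32_backends = {"cuda"}
    return input_precision == "tf32" and backend not in tf32_backends
```

In a pytest suite this predicate would typically feed a `pytest.skip(...)` call at the top of the test body.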
@@ -3181,8 +3181,8 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
 @pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'),
                                                          ('float16', 'float32'), ('float32', 'float32')])
 def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device):
-    if is_hip():
-        pytest.skip('TODO test_dot3d not supported on HIP.')
+    if in_dtype_str == 'int8' and is_interpreter():
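For reference, the semantics test_dot3d checks is a batched matmul: `out[b] = x[b] @ y[b]` for each batch index. A pure-Python reference (written here only to make the expected semantics concrete):

```python
# Reference for 3D dot: out[b][m][n] = sum_k x[b][m][k] * y[b][k][n].
# Pure Python, nested lists, for illustration only.
def dot3d_ref(x, y):
    B, M, K = len(x), len(x[0]), len(x[0][0])
    N = len(y[0][0])
    return [[[sum(x[b][m][k] * y[b][k][n] for k in range(K))
              for n in range(N)] for m in range(M)] for b in range(B)]
```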
Is this also true for the nv path? If so, why was it not caught before?
This is a new change: #3566
This part leaked in here during a rebase; I will remove it.
        const int uniqueValuesPerWarp = 4;
        effectiveWarpSize = i32_val(uniqueValuesPerWarp);
      }
      Value laneId = urem(threadId, effectiveWarpSize);

  // Note: here we assume warpId goes along the M dim first
This is not the case anymore. We should remove it.
Force-pushed from 955c17d to fae2f6e
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
@@ -1424,7 +1428,12 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 SmallVector<unsigned>
 AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   auto nonKDim = getMDim();
Should we be using mDim and nDim here?
Good idea!
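The suggestion above can be sketched in Python: use mDim and nDim separately instead of a single nonKDim when scaling warps-per-CTA into a tile shape. The exact shape conventions ([B, M, N] for rank 3, [M, N] for rank 2) are assumptions based on the diff, not a copy of the real implementation:

```python
# Hedged sketch: per-CTA tile shape from the MFMA instruction dims and
# the warp grid. The last two dims are always M and N; a rank-3 tensor
# carries an extra leading batch dim that is not scaled.
def shape_per_cta_tile(m_dim: int, n_dim: int, warps_per_cta: list[int]) -> list[int]:
    shape = list(warps_per_cta)
    shape[-2] *= m_dim  # M extent of the tile
    shape[-1] *= n_dim  # N extent of the tile
    return shape
```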
  int64_t mfmaInstrK;
  // TODO(Lixun): make it simpler
  // getMFMAInstrShapeForOperands always returns a 2D vector
  if (rank == 3) {
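The rank-3 special case above appears to pad the 2D instruction shape with a batch dimension. A sketch of that idea (the assumed semantics, not the actual helper):

```python
# Hypothetical sketch: the instruction-shape helper returns a 2D vector
# (e.g. [m, k] or [k, n]); for rank-3 tensors, prepend a batch dim of 1
# since one MFMA instruction never spans batches.
def instr_shape_for_rank(instr_shape_2d: list[int], rank: int) -> list[int]:
    if rank == 3:
        return [1] + list(instr_shape_2d)
    return list(instr_shape_2d)
```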
I thought we were going to do this: #3298 (comment)
Let's leave this for later.
I am going to enable dot3d for the WMMA layout as well; after that I'll try to refactor them uniformly.
Are you enabling dot3d for the wmma layout in this PR?
No, this will be a separate PR.
In my opinion it is easier to review and fix things step by step.