[AMD][WMMA] Support dot3d #3674
Conversation
+cc @joviliast
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
@@ -1649,7 +1676,7 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
-  return rep[0] * rep[1] * kWidth;
+  return rep[0] * rep[1] * rep[2] * kWidth;
Could we use something like return product(rep) * kWidth; ?
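For illustration, a minimal sketch of what such a helper could look like, assuming no product utility already exists in the codebase (LLVM-style projects often provide one):

#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>

// Hypothetical helper: multiplies all entries of a rep vector, so the call
// site reads the same whether or not rep carries a batch dimension.
template <typename Range>
uint64_t product(const Range &r) {
  return std::accumulate(std::begin(r), std::end(r), uint64_t{1},
                         std::multiplies<uint64_t>());
}

With rep = {repBatch, repM, repK} for a dot3d operand, product(rep) * kWidth then matches the rep[0] * rep[1] * rep[2] * kWidth form in the diff above.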
unsigned mfmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
unsigned mfmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0];
Since this has become common logic, could you please rename it?
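A possible shape of that rename, as a sketch only (the names below are illustrative, not taken from the PR):

// The mfma prefix is misleading now that WMMA shares this path; a
// layout-neutral name keeps the ternary selection logic unchanged.
unsigned instrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
unsigned instrK = elemsPerInstr[opIdx == 0 ? 1 : 0];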
LGTM. Have you run test_dot locally on Navi?
Force-pushed from 9bb5b68 to 85a1379
Yes.
Force-pushed from f27d70c to ae72bf9
Thanks. LGTM
Can we add some lit tests? At the moment we don't have CI for RDNA GPUs, so test_core.py is effectively not checked and may regress at any time. Lit tests check the compiler transformation and give us some guarantee. They are also easier to read and fix than full-blown integration runtime tests, so lit tests are typically the first line of defense for quality.
@antiagainst @joviliast is working on the same code, so even if I add a lit test, it will probably break in the near future, adding redundant work for him (or me, depending on who merges first). I can implement some basic test that checks there are no crashes, but in my opinion such a test does not guarantee much. P.S. We have a basic LLIR interpreter that could help check the changes from this PR, but at this point it requires massive work. I would prefer to invest time in that task if correctness on Navi aligns with our team's priorities.
Force-pushed from 40c142e to 8c9ba47
@antiagainst PTAL
Force-pushed from 8c9ba47 to 55988f4
if triton.runtime.driver.active.get_current_target().arch == "gfx1100":
    if in_dtype_str == "int8" or in_dtype_str == "float32":
        pytest.skip(f"{in_dtype_str} is not supported in WMMA dot")
    if out_dtype_str == "float16":
There are float16 accumulate wmma ops? Are they not matching the precision w.r.t. reference pytorch?
I need to check this. At some point they did not match, but maybe that is no longer the case, since a lot of time has passed since I implemented this.
Yes, the precision issue is still there.
I suspect this is a hardware problem, though it requires more investigation of WMMA behavior.
Okay, thanks. Worth understanding more. I think we can also promote to f32 and then downcast if necessary.
@binarman Are we currently using V_WMMA_F16_16X16X16_F16 and seeing an accuracy mismatch with pytorch? If so, can we use V_WMMA_F32_16X16X16_F16 and then cast to fp16 as @antiagainst mentioned?
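A scalar sketch of why the two paths can differ, assuming a compiler and target with _Float16 support (e.g. recent clang); this only illustrates the rounding behavior, it is not the actual WMMA codegen:

#include <cstdio>

int main() {
  // f16 accumulation rounds the running sum to half precision after every
  // step, roughly what V_WMMA_F16_16X16X16_F16 does with an f16 accumulator.
  _Float16 accF16 = 0;
  // f32 accumulation keeps extra precision and downcasts once at the end,
  // matching the promote-to-f32-then-truncate path suggested above.
  float accF32 = 0.0f;
  for (int k = 0; k < 16; ++k) {
    _Float16 a = (_Float16)0.1f;
    _Float16 b = (_Float16)0.1f;
    accF16 = accF16 + a * b;       // rounds to f16 on every iteration
    accF32 += (float)a * (float)b; // accumulates in f32
  }
  printf("f16 accumulate:           %f\n", (double)accF16);
  printf("f32 accumulate, downcast: %f\n", (double)(_Float16)accF32);
  return 0;
}

The two accumulated values drift apart as K grows, which is the kind of mismatch a test comparing against an f32 pytorch reference would flag.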
assert(shape[0] % mnkDim[0] == 0);
multiDimWarpId[0] =
    urem(multiDimWarpId[0], i32_val(ceil<unsigned>(shape[0], mnkDim[0])));
if (shape[rank - 2] >= mnkDim[0]) {
We have quite a few duplicated shape[rank - N] references. What about using some self-documenting local variables for them? Then there is less chance of being inconsistent.
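For instance, a hypothetical form of that suggestion (variable names are illustrative):

// Hoist the repeated rank-relative lookups into named locals so the M/N
// dimensions are computed in exactly one place.
int64_t shapeM = shape[rank - 2];
int64_t shapeN = shape[rank - 1];
if (shapeM >= mnkDim[0]) {
  // ... index math below reads in terms of shapeM / shapeN ...
}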
This PR is intended to support dot3d; I suggest refactoring this code as a separate task (FYI, we discussed this some time ago: #3600 (comment)).
A lot of this code is the same on the MFMA side, and it will be better to refactor both MFMA and WMMA at the same time.
We have two ideas for how to refactor this code (see the sketch after this list):
- always assume we have a batch dimension in dot
- use a structure with named fields, i.e. M/N/K/B, instead of indexes
Choosing one of these paths is a separate task, which will be the next step after test bringup.
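A minimal sketch of the second option, under the assumption that nothing like it exists yet (names are illustrative, not from the PR):

// Named dimension indices instead of raw shape[rank - N] arithmetic.
struct DimIdx {
  int b; // batch index; only meaningful for rank-3 (dot3d) shapes
  int m;
  int n;
};

inline DimIdx dotDimIdx(int rank) {
  bool hasBatch = rank == 3; // dot3d carries a leading batch dimension
  return {/*b=*/0, /*m=*/0 + hasBatch, /*n=*/1 + hasBatch};
}

// Call sites would then read shape[idx.m] instead of shape[rank - 2].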
SG to follow up on this later.
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
  offsets.push_back(
      {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]});
  elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem;
Minus is less mentally straightforward than plus. I'd suggest doing bool hasBatch = rank == 3; and then using [0 + hasBatch] for the M index and [1 + hasBatch] for the N index.
I would like to not change this now: all the other places like this use the minus style.
My suggestion is to make this refactoring a separate task.
Sure, works for me.
(Sorry, clicked the wrong button before.)
Regarding tests, I treat them as an investment to guard against future breakages. We don't have RDNA CI, so this can easily regress. Compared to the effort of writing some tests now (which is mostly one-time), I'm more concerned about the potential time lost in the future debugging all this complex logic only through integration Python tests. And we don't know how many regressions we will see throughout the journey. Also, btw, lit tests don't need to be super detailed and cover every line; we can just cover the important parts so they don't become a change detector. I don't think it's a lot of effort to update them, given that the index calculation doesn't change frequently, I believe. And whatever we change there is deliberate; it can also help folks touching the code to verify their changes. (Keep in mind that there are contributors who only work on the MFMA parts; they will not run their changes on RDNA cards to verify things pass, let alone folks only touching the NVIDIA support. But lit tests run everywhere and can provide us guarantees.)
This PR has extensive indexing calculation, so + @zhanglx13 to double check too.
This PR enables support of 3d dot and fixes tests in test_core.py
Force-pushed from 55988f4 to c824575
This PR enables support of 3d dot for RDNA GPUs. (cherry picked from commit 100e2aa)
Cherry picks for release/3.0.x

General:
- e8bc45d [BACKEND][AMD] Disable linear layout due to perf regression (#4126)
- 9a0a7c2 [AMD] Add basic verification to MFMA encoding (#4117)

For RDNA:
- 100e2aa [AMD][WMMA] Support dot3d (#3674)
- 4a1ea8e [AMD][gfx11] Fix BF16 wmma instr generation (#4135)

Proton HIP PRs:
- 328b86d [PROTON] Refactor GPU profilers (#4056)
- 60613fb [PROTON] Roctracer: convert agent id to gpu id for gpu ops (#4090)
- c1776fa [PROTON][AMD] Add Proton HIP GPU Utilization Metrics (#4119)

Co-authored-by: Lei Zhang <antiagainst@gmail.com>
Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
Co-authored-by: Ilya V <152324710+joviliast@users.noreply.github.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: mwootton <michael.wootton@amd.com>
Co-authored-by: Corbin Robeck <corbin.robeck@amd.com>