[AMD] Enable test_dot3d on AMD backend #3298
Conversation
@ThomasRaoux @zahimoud Gentle ping for review :)
```diff
@@ -524,12 +524,13 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
     return multiDimOffset;
   }
   if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
     // TODO: extend to support dot3d
```
Do we still need this?
```diff
     auto multiDimBase =
         emitBaseIndexForLayout(loc, rewriter, layout, type, false);
     SmallVector<SmallVector<unsigned>> offsets;
     assert(rank == 2);
     SmallVector<Value> multiDimOffset(rank);
-    emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0],
+    emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
```
Is there a way to not hardcode this?
This part of the code is used by lowerDistributedToDistributed, which I feel is not used at all.
- One use case is in the epilogue, when we need to convert mfma to blocked before tt.store. The optimize_epilogue pass actually removes that conversion and uses the mfma layout to do tt.store.
- For other distributed layouts, like blocked->dotOp or mfma->dotOp, the conversions should either be decomposed or use a shortcut.

Do you have a use case for this conversion?
Well, at least for mma, some ops do not support the mma layout, so if we have a transpose or scan on the result of a dot, we would have to convert mma to blocked. It might be a similar case for mfma. I would keep the support.
This is interesting.
If this conversion happens inside a loop, the traffic to and from shared memory will harm perf a lot. So I guess in that case you'd rather pay the perf price than support mma for those ops.
I can support it anyway, but do you have a test so that I can verify the results?
Not sure if we have a test, maybe @ThomasRaoux knows.
I believe this should be tested in test_convert2d, as we test different combinations of conversions there, including mma layouts. If it is not tested, we should definitely add a case there.

> If this conversion happens inside a loop, the traffic to and from shared memory will harm perf a lot. So I guess in that case you'd rather pay the perf price than support mma for those ops.

Well, there are always different ways to propagate layouts and reduce the cost, but we want functionality first, so we should support this no matter what. When we run into performance problems, we will address them.
Fair enough. I get the point that we should support conversions between distributed layouts. Therefore, we should keep this file.

> I believe this should be tested in test_convert2d, as we test different combinations of conversions there, including mma layouts. If it is not tested, we should definitely add a case there.

I'd suggest we do it in future PRs, since this PR is about dot3d. WDYT?
```cpp
      return true;
    else
      return false;
    if ((rank == 3) && (order[0] + opIdx == 2))
```
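One way to read this condition, assuming opIdx 0 is the A operand (batch x M x K) and opIdx 1 is B (batch x K x N): `order[0] + opIdx == 2` holds exactly when the fastest-varying dimension (`order[0]`) is that operand's k dimension. A small standalone check of that reading (illustrative values, not from the PR):

```cpp
#include <array>
#include <cassert>

int main() {
  // A operand (opIdx = 0, shape batch x M x K): k is dim 2, so the
  // condition holds when order[0] == 2, i.e. k varies fastest.
  std::array<unsigned, 3> orderA = {2, 1, 0};
  assert(orderA[0] + 0u == 2u);

  // B operand (opIdx = 1, shape batch x K x N): k is dim 1, so the
  // condition holds when order[0] == 1.
  std::array<unsigned, 3> orderB = {1, 2, 0};
  assert(orderB[0] + 1u == 2u);
  return 0;
}
```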
The way I got over special-casing for rank==2 and rank==3 in SharedToDotOperandMMAv2 is to create a dummy 3D tensor out of the input tensor, do the codegen, and then throw away that 3D tensor, so we only have to think about 3D codegen rather than having if/else everywhere. Can we do the same here?
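A minimal sketch of that trick; expandTo3d and collapseTo2d are hypothetical names, not actual Triton helpers. The idea: prepend a unit batch dim, run only the rank-3 codegen path, then drop the batch dim again.

```cpp
#include <cassert>
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Prepend a unit batch dim so downstream codegen only sees rank 3.
llvm::SmallVector<int64_t> expandTo3d(llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> shape3d(shape.begin(), shape.end());
  if (shape3d.size() == 2)
    shape3d.insert(shape3d.begin(), 1);
  return shape3d;
}

// Drop the unit batch dim again once codegen is done.
llvm::SmallVector<int64_t> collapseTo2d(llvm::ArrayRef<int64_t> shape3d) {
  assert(shape3d.size() == 3 && shape3d[0] == 1);
  return llvm::SmallVector<int64_t>(shape3d.begin() + 1, shape3d.end());
}
```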
When I checked your PR I did not pay attention to what you did in this file, since I'm not familiar with the mma layout. I think this is indeed a good way to get rid of the if (rank == ...) special cases as much as possible. However, given the size of this PR, can we do it in a future PR as a refactor?
I think we should just do it now. We may forget about it if we delay refactoring to a future PR :)
Fair enough. I'll find some time next week to refactor this part and address the review comments.
```cpp
if (rank == 3)
  multiDimBase[0] = urem(warpId, i32_val(shape[0]));
```
Looks like this function is generalized already, or am I missing something?
If it is not, could you add an assert for this here? Something like:
```diff
-if (rank == 3)
-  multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+if (rank == 3) {
+  assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1);
+  multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+}
```
If you mean that this particular check is not general: I think we can safely declare that dim 0 is the slowest one, i.e. the warp order is [2, 1, 0] or [2, 0, 1], so you need something like:
```diff
-if (rank == 3)
-  multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+if (rank == 3) {
+  auto singleDotWarps = _warpsPerCTA[rank - 1] * _warpsPerCTA[rank - 2];
+  multiDimBase[0] = urem(udiv(warpId, i32_val(singleDotWarps)), i32_val(shape[0]));
+}
```
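To sanity-check that warp-id decomposition outside of codegen, a small standalone sketch; the warp counts and batch size below are illustrative, not taken from the PR.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Illustrative config: 2 batch warps and 4 x 2 warps per 2D dot,
  // with warp order [2, 1, 0] (dim 0 varies slowest in the linear id).
  const int64_t warpsPerCTA[3] = {2, 4, 2};
  const int64_t batchSize = 2; // stands in for shape[0]
  const int64_t singleDotWarps = warpsPerCTA[1] * warpsPerCTA[2];
  const int64_t totalWarps = warpsPerCTA[0] * singleDotWarps;
  for (int64_t warpId = 0; warpId < totalWarps; ++warpId) {
    // Mirrors urem(udiv(warpId, singleDotWarps), i32_val(shape[0])).
    const int64_t batchIdx = (warpId / singleDotWarps) % batchSize;
    assert(batchIdx >= 0 && batchIdx < batchSize);
  }
  return 0;
}
```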
Waiting for #3171 to land, and then I'll rebase to add the refactors.
Are you still working on this?
Closing until @zhanglx13 can pick up again.
Port #3056 onto AMD backend