[AMD backend] Fix unit test test_dot_without_load
#3338
Conversation
@@ -87,7 +87,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
     return elemTy;
   if (auto mfmaParent =
           dotOpLayout.getParent().dyn_cast<AMDMfmaEncodingAttr>()) {
-    return vec_ty(elemTy, dotOpLayout.getKWidth());
+    return elemTy;
Since this is not a special case anymore, we can just remove it.
removed.
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
@@ -43,7 +43,15 @@ namespace gpu {
 unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
                                 Type eltTy) {
   if (auto tritonGPUAttr = layout.dyn_cast<TritonGPU_AttrTrait>()) {
-    return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy);
+    unsigned elemNum = tritonGPUAttr.getTotalElemsPerThread(shape, eltTy);
I think it's better not to touch this "interface" function. The changes are only for a dotOp with an mfma parent anyway. Can we move the changes into `DotOperandEncodingAttr::getTotalElemsPerThread`?
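A rough sketch of where that change could live, assuming a hypothetical helper `getMfmaDotOperandElemsPerThread` for the mfma-specific element count introduced by this PR (the exact formula is not visible in the hunk above, and the rest of the function is elided):

```cpp
// Sketch only: keep the generic getTotalElemsPerThread() entry point untouched
// and handle the mfma-parent case inside DotOperandEncodingAttr itself.
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
                                                        Type eltTy) const {
  if (auto mfmaParent = getParent().dyn_cast<AMDMfmaEncodingAttr>())
    // Hypothetical helper standing in for the mfma-specific count from this PR.
    return getMfmaDotOperandElemsPerThread(mfmaParent, shape, eltTy, getOpIdx());
  // ... existing computation for all other parents stays as-is ...
}
```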
Thanks. Changed accordingly.
Force-pushed from 694cd67 to eda0398.
Hi @ThomasRaoux, could you please take a look at this PR to see if you have any comments? Thanks.
LGTM
Looks like this fails CI:
OK, let me check that. Sorry about that.
This reverts commit e0e5a36.
This PR fixes the regression introduced by the reverted PR #3338, which broke the test `test_masked_load_shared_memory`. The root cause is the type used when packing the dot op for bfloat16: we should use `i16` for `bf16` when packing the dot op for mfma. This time I ran all the tests in `test_core.py` locally, and they all pass.
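A minimal sketch of the type choice described above, assuming a standalone helper (the name `getMfmaPackedElemTy` is hypothetical; the real fix lives in the AMD dot-operand packing code):

```cpp
#include "mlir/IR/BuiltinTypes.h" // Type, IntegerType

using namespace mlir;

// Sketch: when packing dot_op values for mfma, carry bf16 as i16 so the packed
// vector uses a 16-bit integer element type instead of bf16 itself.
static Type getMfmaPackedElemTy(Type elemTy) {
  if (elemTy.isBF16())
    return IntegerType::get(elemTy.getContext(), /*width=*/16); // i16 for bf16
  return elemTy; // all other element types stay as-is
}
```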
This PR fixes the unit test `test_dot_without_load` by changing the dot op element type from `vector<vector<type>>` to `vector<type>`, so the constantOp can be converted to the mfma `dot_op` with the existing lowering code.
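For illustration, a minimal sketch (not the actual lowering code; `elemTy`, `valueAttr`, `structTy`, and `numElems` are assumed context) of why plain scalar struct fields matter: the generic splat/constant lowering builds one scalar constant per element and inserts it into the converted LLVM struct, which only works when each field is the element type rather than a vector of elements.

```cpp
// Sketch: materialize a splat constant for a dot-operand tensor by inserting
// the same scalar into every field of the converted LLVM struct. This relies
// on the struct fields being plain element types (vector<type> per thread),
// not vectors of vectors.
Value scalar = rewriter.create<LLVM::ConstantOp>(loc, elemTy, valueAttr);
Value packed = rewriter.create<LLVM::UndefOp>(loc, structTy);
for (int64_t i = 0; i < numElems; ++i)
  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, scalar, i);
```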