Skip to content

Commit

Permalink
clean up workaround after upstream fix (#10188)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux authored Aug 24, 2022
1 parent c9e9482 commit 1d55c6c
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,24 +240,22 @@ struct LLVMGPUVectorToGPUPass
}
RewritePatternSet patterns(funcOp.getContext());
populatePrepareVectorToMMAPatterns(patterns, llvmgpuUseMMASync);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}

if (llvmgpuUseMMASync) {
(void)convertVectorToNVVMCompatibleMMASync(funcOp);

// TODO: Remove once populateMmaSyncF32ToTF32Patterns is fixed to not add
// attribute tf32 attributes to none f32 ops.
bool hasFP32mma = false;
funcOp.walk([&hasFP32mma](nvgpu::MmaSyncOp op) {
if (op.getType().cast<VectorType>().getElementType().isF32())
hasFP32mma = true;
});
if (hasFP32mma) {
// Use TF32 for float32 case for now.
RewritePatternSet patterns(funcOp.getContext());
nvgpu::populateMmaSyncF32ToTF32Patterns(
patterns, nvgpu::MmaSyncF32Lowering::TF32);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
if (failed(convertVectorToNVVMCompatibleMMASync(funcOp))) {
return signalPassFailure();
}
// Use TF32 for float32 case for now.
RewritePatternSet f32ToTF32patterns(funcOp.getContext());
nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32patterns,
nvgpu::MmaSyncF32Lowering::TF32);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(f32ToTF32patterns)))) {
return signalPassFailure();
}
} else {
convertVectorToMMAOps(funcOp);
Expand Down

0 comments on commit 1d55c6c

Please sign in to comment.