Skip to content

Commit

Permalink
[AMD] ElementwiseOpToLLVM: Do not convert types if they are equal (triton-lang#3091)
Browse files Browse the repository at this point in the history

This commit fixes a failure in
python/tutorials/03-matrix-multiplication.py for FMA cases,
and also fixes mixed dot for FMA cases.
Tested on Navi31.

---------

Signed-off-by: joviliast <iveselov.nn@gmail.com>
  • Loading branch information
joviliast authored and htyu committed Mar 20, 2024
1 parent b4e8896 commit d2d2e8f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
11 changes: 11 additions & 0 deletions test/Conversion/amd/fp_to_fp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm | FileCheck %s

// CHECK-LABEL: f16_to_f32
// Blocked layout alias used as the parent layout of the dot-operand tensors below.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// Regression test: tt.fp_to_fp on an f16 dot-operand tensor must still emit the
// f16 -> f32 conversion (v_cvt_f32_f16 via inline asm) when src and dst element
// types differ -- one conversion per element of the 8x8 tensor held by a thread.
tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) {
// CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}v_cvt_f32_f16 {{.*}}: (f16) -> f32
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}
10 changes: 8 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,6 @@ struct FpToFpOpConversion
bool isDstFP32 = dstElementType.isF32();
Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
Type dstType = isDstFP32 ? f16_ty : dstElementType;
auto cvtFunc = getConversionFunc(srcType, dstType);
SmallVector<Value> inVals;
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
Expand All @@ -1239,7 +1238,14 @@ struct FpToFpOpConversion
for (Value &v : inVals)
v = convertFp32ToFp16NZ(loc, rewriter, v);
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
SmallVector<Value> outVals;
if (srcType != dstType) {
auto cvtFunc = getConversionFunc(srcType, dstType);
outVals = cvtFunc(loc, rewriter, inVals);
} else {
outVals = inVals;
}

assert(outVals.size() == inVals.size());
outVals.resize(std::min(numElements, operands.size()));
if (isDstFP32)
Expand Down

0 comments on commit d2d2e8f

Please sign in to comment.