diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index 3a68c7079c38..959c5e288cea 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Support/Debug.h" + #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp" -#include "llvm/Support/Debug.h" // Used to trace which op are used, good for profiling apps. #define DEBUG_TYPE "gemm" diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index edb1e3843d34..d6a2affe4e72 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Support/Debug.h" + #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/ONNX/IndexExpr.hpp" @@ -20,7 +22,7 @@ using namespace mlir; -#define DEBUG_TRACE 0 +#define DEBUG_TYPE "matmul" struct ONNXMatMulOpLowering : public ConversionPattern { ONNXMatMulOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) @@ -128,6 +130,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern { // Initialize alloc/C to zero. create.krnl.memset(alloc, zeroVal); + bool simdize = true; // Compute. // Define blocking, with simdization along the j axis. @@ -138,27 +141,39 @@ struct ONNXMatMulOpLowering : public ConversionPattern { int64_t constI = dimI.getLiteral(); if (constI < iRegTile) { iRegTile = constI; - if (DEBUG_TRACE) - printf("MatMul: Tiling I is reduced to %d\n", (int)iRegTile); + LLVM_DEBUG({ + llvm::dbgs() << "MatMul: Tiling I is reduced to " << iRegTile << "\n"; + }); } } if (dimJ.isLiteral()) { int64_t constJ = dimJ.getLiteral(); // When jRegTile does not divide J, but 4 would, use 4, unless J is very - // large, in which case it is better to simdize well the steady state and - // ignore the last partial block. + // large, in which case it is better to simdize well the steady state + // and ignore the last partial block. if (constJ % jRegTile != 0 && constJ % 4 == 0 && constJ <= 32) { jRegTile = 4; - if (DEBUG_TRACE) - printf("MatMul: Tiling J is reduced to %d\n", (int)jRegTile); + LLVM_DEBUG({ + llvm::dbgs() << "MatMul: Tiling J is reduced to " << jRegTile << "\n"; + }); + } + // Simdization occurs along j and jRegTile. If dimJ is smaller than + // jRegTile, disable simdization. + if (constJ < jRegTile) { + simdize = false; + LLVM_DEBUG({ + llvm::dbgs() << "MatMul: Disable simdization because trip " << constJ + << " is smaller than reg tile " << jRegTile << "\n"; + }); } } if (dimK.isLiteral()) { int64_t constK = dimK.getLiteral(); if (constK < kRegTile) { kRegTile = constK; - if (DEBUG_TRACE) - printf("MatMul: Tiling K is reduced to %d\n", (int)kRegTile); + LLVM_DEBUG({ + llvm::dbgs() << "MatMul: Tiling K is reduced to " << kRegTile << "\n"; + }); } } @@ -178,13 +193,14 @@ struct ONNXMatMulOpLowering : public ConversionPattern { Value i1(indices[0]), j1(indices[1]), k1(indices[2]); createKrnl.matmul(A, {zero, zero}, B, {zero, zero}, C, {zero, zero}, {ii2, jj2, kk2}, {i1, j1, k1}, {I, J, K}, - {iRegTile, jRegTile, kRegTile}, {}, {}, {}, - /*simd*/ true, /*unroll*/ true, /*overcompute*/ false); + {iRegTile, jRegTile, kRegTile}, {}, {}, {}, simdize, + /*unroll*/ true, /*overcompute*/ false); }); } - // Handle the cases with 2x2 matrices both for A, B, and C without broadcast. - // Implementation here uses the efficient 2d tiling plus kernel substitution. + // Handle the cases with 2x2 matrices both for A, B, and C without + // broadcast. Implementation here uses the efficient 2d tiling plus kernel + // substitution. LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final {