Skip to content

Commit

Permalink
Matmul narrow no simd (llvm#1123)
Browse files Browse the repository at this point in the history
* disable simd for matmul when simdization dim is smaller than simd vector in MatMul operator

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
Co-authored-by: Tung D. Le <tungld@gmail.com>
  • Loading branch information
AlexandreEichenberger and tungld authored Jan 26, 2022
1 parent 363f4b1 commit f2171ab
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/Math/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include "llvm/Support/Debug.h"

// Used to trace which op are used, good for profiling apps.
#define DEBUG_TYPE "gemm"
Expand Down
42 changes: 29 additions & 13 deletions src/Conversion/ONNXToKrnl/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Debug.h"

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/ONNX/IndexExpr.hpp"
Expand All @@ -20,7 +22,7 @@

using namespace mlir;

#define DEBUG_TRACE 0
#define DEBUG_TYPE "matmul"

struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
Expand Down Expand Up @@ -128,6 +130,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {

// Initialize alloc/C to zero.
create.krnl.memset(alloc, zeroVal);
bool simdize = true;

// Compute.
// Define blocking, with simdization along the j axis.
Expand All @@ -138,27 +141,39 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
int64_t constI = dimI.getLiteral();
if (constI < iRegTile) {
iRegTile = constI;
if (DEBUG_TRACE)
printf("MatMul: Tiling I is reduced to %d\n", (int)iRegTile);
LLVM_DEBUG({
llvm::dbgs() << "MatMul: Tiling I is reduced to " << iRegTile << "\n";
});
}
}
if (dimJ.isLiteral()) {
int64_t constJ = dimJ.getLiteral();
// When jRegTile does not divide J, but 4 would, use 4, unless J is very
// large, in which case it is better to simdize well the steady state and
// ignore the last partial block.
// large, in which case it is better to simdize well the steady state
// and ignore the last partial block.
if (constJ % jRegTile != 0 && constJ % 4 == 0 && constJ <= 32) {
jRegTile = 4;
if (DEBUG_TRACE)
printf("MatMul: Tiling J is reduced to %d\n", (int)jRegTile);
LLVM_DEBUG({
llvm::dbgs() << "MatMul: Tiling J is reduced to " << jRegTile << "\n";
});
}
// Simdization occurs along j and jRegTile. If dimJ is smaller than
// jRegTile, disable simdization.
if (constJ < jRegTile) {
simdize = false;
LLVM_DEBUG({
llvm::dbgs() << "MatMul: Disable simdization because trip " << constJ
<< " is smaller than reg tile " << jRegTile << "\n";
});
}
}
if (dimK.isLiteral()) {
int64_t constK = dimK.getLiteral();
if (constK < kRegTile) {
kRegTile = constK;
if (DEBUG_TRACE)
printf("MatMul: Tiling K is reduced to %d\n", (int)kRegTile);
LLVM_DEBUG({
llvm::dbgs() << "MatMul: Tiling K is reduced to " << kRegTile << "\n";
});
}
}

Expand All @@ -178,13 +193,14 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
Value i1(indices[0]), j1(indices[1]), k1(indices[2]);
createKrnl.matmul(A, {zero, zero}, B, {zero, zero}, C, {zero, zero},
{ii2, jj2, kk2}, {i1, j1, k1}, {I, J, K},
{iRegTile, jRegTile, kRegTile}, {}, {}, {},
/*simd*/ true, /*unroll*/ true, /*overcompute*/ false);
{iRegTile, jRegTile, kRegTile}, {}, {}, {}, simdize,
/*unroll*/ true, /*overcompute*/ false);
});
}

// Handle the cases with 2x2 matrices both for A, B, and C without broadcast.
// Implementation here uses the efficient 2d tiling plus kernel substitution.
// Handle the cases with 2x2 matrices both for A, B, and C without
// broadcast. Implementation here uses the efficient 2d tiling plus kernel
// substitution.

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand Down

0 comments on commit f2171ab

Please sign in to comment.