
Parallel Grid loop creation for gemms (#885)
Co-authored-by: Alexander Heinecke <alexander.heinecke@intel.com>
Co-authored-by: Renato Golin <rengolin@systemcall.eu>
3 people authored Feb 20, 2024
1 parent edded3d commit 6f4b13b
Showing 12 changed files with 527 additions and 80 deletions.
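Summary: the default pipeline now tiles `scf.parallel` loops over a configurable 2D task grid before lowering to OpenMP, so each thread executes a block of GEMM tile iterations rather than a single one. The tiling is a new `SCFParallelLoopTiling` pass, driven from the default pipeline by a new `--parallel-task-grid` option (default `2,8`) when `--def-parallel` is set; the OpenMP benchmark configs pin a grid per thread count, and `CombineXsmmPass` now reads BRGEMM dims from the dispatch op instead of recomputing them.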
16 changes: 8 additions & 8 deletions benchmarks/config/omp/mlir-bf16.json
@@ -5,28 +5,28 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_4_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_8_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_16_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ "(avx2|asimd)" ]
}
}},
@@ -36,28 +36,28 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_4_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_8_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"bf16_dp2_3x1024_omp_16_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=16 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32 --vnni=2" ],
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ "(avx2|asimd)" ]
}
}}
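The same flag update repeats across every OpenMP config in this and the following files: as `OMP_NUM_THREADS` doubles, the task grid shrinks (8,16 → 8,8 → 4,8 → 2,8). The grid values are the `scf.parallel` tile sizes, and they appear chosen so that each thread ends up with exactly one tile of the block iteration space (presumably 256/32 × 1024/32 = 8 × 32 blocks for these shapes): 2·(8·16) = 4·(8·8) = 8·(4·8) = 16·(2·8) = 256. A minimal Python sketch of that mapping follows; the `TASK_GRIDS` table and `run_args` helper are illustrations built from the values in these configs, not an existing API:

```python
# Task grids used by these benchmark configs, keyed by OMP_NUM_THREADS.
TASK_GRIDS = {2: (8, 16), 4: (8, 8), 8: (4, 8), 16: (2, 8)}

def run_args(threads: int) -> str:
    """Build the -run-args string used in the 'flags' entries above."""
    m, n = TASK_GRIDS[threads]
    return f"-run-args='--def-parallel --parallel-task-grid={m},{n}'"

assert run_args(16) == "-run-args='--def-parallel --parallel-task-grid=2,8'"
```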
16 changes: 8 additions & 8 deletions benchmarks/config/omp/mlir-fp32.json
@@ -5,28 +5,28 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_4_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_8_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_16_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ "(avx2|asimd)" ]
}
}},
@@ -36,28 +36,28 @@
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_4_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_8_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ "(avx2|asimd)" ]
},
"fp32_3x1024_omp_16_mlir": {
"type": "IR-GEN",
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-width=32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ "(avx2|asimd)" ]
}
}}
32 changes: 16 additions & 16 deletions benchmarks/config/omp/torch-dynamo.json
@@ -5,28 +5,28 @@
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_4_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_8_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_16_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ ]
}
}},
@@ -36,28 +36,28 @@
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_4_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_8_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_16_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-gemm-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ ]
}
}},
@@ -67,28 +67,28 @@
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_4_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_8_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ ]
},
"fp32_3x1024_omp_16_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-fp32-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ ]
}
}},
@@ -98,28 +98,28 @@
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_4_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_8_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8'" ],
"extensions": [ ]
},
"bf16_3x1024_omp_16_mlir": {
"type": "MLIR",
"benchmark": "pytorch/torch-dynamo-mlp-bf16-3x1024.mlir",
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
"flags": [ "-n", "100", "-run-args='-def-parallel'" ],
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8'" ],
"extensions": [ ]
}
}}
2 changes: 1 addition & 1 deletion benchmarks/harness/controller.py
@@ -155,7 +155,7 @@ def run(self):
if self.args.gpu is not None:
runCmd.extend(["--gpu", self.args.gpu])
if self.args.run_args:
-            runCmd.extend(shlex.split(self.args.run_args))
+            runCmd.extend(shlex.split(self.args.run_args.replace("'", "")))
runResult = executor.run(runCmd, irContents)
if 0 != runResult.returncode:
self.logger.error(f"Error executing tpp-run: {runResult.stderr}")
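The `replace("'", "")` matters because the configs wrap the run arguments in single quotes (`-run-args='...'`), and `shlex.split` treats a quoted string as one token, so tpp-run would receive the whole flag list as a single bogus argument. A quick illustration of the behavior this one-liner fixes:

```python
import shlex

# self.args.run_args as it arrives from the JSON configs above.
run_args = "'--def-parallel --parallel-task-grid=2,8'"

print(shlex.split(run_args))
# ['--def-parallel --parallel-task-grid=2,8']    <- one token: a single bogus argument
print(shlex.split(run_args.replace("'", "")))
# ['--def-parallel', '--parallel-task-grid=2,8'] <- two proper argv entries for tpp-run
```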
13 changes: 13 additions & 0 deletions include/TPP/Passes.td
@@ -470,4 +470,17 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> {
let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

+def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
+  let summary = "Tile parallel loops";
+  let options = [
+    ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
+               "Factors to tile parallel loops by">,
+    Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
+           /*default=*/"false",
+           "Perform tiling with fixed upper bound with inbound check "
+           "inside the internal loops">
+  ];
+  let dependentDialects = ["affine::AffineDialect", "scf::SCFDialect"];
+}

#endif // TPP_DIALECT_TPP_PASSES
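This pass definition mirrors upstream MLIR's `SCFParallelLoopTiling` (same option names and semantics), re-registered under `scf-parallel-loop-tiling-pass` so the TPP pipeline can instantiate it with the `--parallel-task-grid` values. Standalone it should be invocable roughly as `tpp-opt --scf-parallel-loop-tiling-pass="parallel-loop-tile-sizes=2,8"`; that invocation syntax is assumed from MLIR's usual pass-option convention rather than shown by this diff.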
13 changes: 12 additions & 1 deletion lib/TPP/DefaultPipeline.cpp
@@ -46,6 +46,13 @@ llvm::cl::opt<bool>
llvm::cl::desc("Default pipeline - enable parallel execution"),
llvm::cl::init(false));

+// Control grid parallelism sizes.
+llvm::cl::list<int64_t>
+    parallelTaskGrid("parallel-task-grid",
+                     llvm::cl::desc("Grid-sizes for parallel tasks"),
+                     llvm::cl::list_init<int64_t>(SmallVector<int64_t>{2, 8}),
+                     llvm::cl::CommaSeparated);

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_DEFAULTPIPELINE
@@ -133,8 +140,12 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(tpp::createConvertPerfToFunc());
pm.addPass(createConvertTensorToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
-    if (defParallel)
+    if (defParallel) {
+      mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
+      tilingOptions.tileSizes = parallelTaskGrid;
+      pm.addPass(createSCFParallelLoopTiling(tilingOptions));
pm.addPass(createConvertSCFToOpenMPPass());
+    }
pm.addPass(createConvertVectorToSCFPass());
pm.addPass(arith::createArithExpandOpsPass());
pm.addPass(createLowerAffinePass());
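With `--def-parallel` set, the pipeline now tiles the `scf.parallel` loops left by the linalg-to-loops lowering using the requested grid sizes before `convert-scf-to-openmp`, so the OpenMP lowering hands each thread a whole tile of iterations instead of scheduling them one by one; without the flag the pipeline is unchanged.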
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(TPPTransforms
ToBlockLayoutAndBack.cpp
TransformUtils.cpp
CombineXsmmPass.cpp
+SCFParallelLoopTiling.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
51 changes: 5 additions & 46 deletions lib/TPP/Transforms/CombineXsmmPass.cpp
@@ -28,50 +28,6 @@ using namespace mlir;

namespace {

-static FailureOr<DenseI64ArrayAttr>
-getSizesAndLeadingDimForBrgemmOp(RewriterBase &rewriter, xsmm::BrgemmOp opTy) {
-
-  auto memrefA = opTy.getOperand(1).getType();
-  auto memrefB = opTy.getOperand(2).getType();
-  auto memrefC = opTy.getOperand(3).getType();
-
-  int64_t m, n, k;
-  m = memrefC.cast<ShapedType>().getShape()[0];
-  n = memrefC.cast<ShapedType>().getShape()[1];
-  k = memrefA.cast<ShapedType>().getShape()[2];
-
-  auto ldaDim = xsmm::utils::getLeadingDim(memrefA, /*pos=*/1);
-  if (failed(ldaDim))
-    return failure();
-
-  int64_t lda = *ldaDim;
-
-  auto ldbDim = xsmm::utils::getLeadingDim(memrefB, /*pos=*/1);
-  if (failed(ldbDim))
-    return failure();
-
-  int64_t ldb =
-      (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::BRGEMM_INS,
-                                   memrefB.cast<MemRefType>()))
-          ? *ldbDim / *vnni::utils::getVnniBlockingFactor(memrefB)
-          : *ldbDim;
-
-  auto ldcDim = xsmm::utils::getLeadingDim(memrefC);
-  if (failed(ldcDim))
-    return failure();
-
-  int64_t ldc = *ldcDim;
-
-  // If we are dealing with a BRGEMM we need to pass two extra dimensions:
-  // - strideA and strideB that represent the stride between different GEMM
-  //   in BRGEMM.
-  int64_t strideA = lda * m;
-  int64_t strideB = ldb * k;
-  return DenseI64ArrayAttr::get(
-      rewriter.getContext(),
-      ArrayRef<int64_t>{m, n, k, lda, ldb, ldc, strideA, strideB});
-}

static ArrayAttr getBrgemmFlags(RewriterBase &rewriter, xsmm::BrgemmOp opTy) {
auto memrefB = opTy.getOperand(2).getType().cast<MemRefType>();
SmallVector<Attribute, 2> attributes;
@@ -146,14 +102,17 @@ struct CombineXsmmOp : public OpRewritePattern<xsmm::BrgemmOp> {
IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64);

Location loc = brgemmOp.getLoc();
-    auto dims = getSizesAndLeadingDimForBrgemmOp(rewriter, brgemmOp);
+    auto dims = DenseI64ArrayAttr::get(
+        rewriter.getContext(), dyn_cast<mlir::xsmm::BrgemmDispatchOp>(
+                                   brgemmOp.getOperand(0).getDefiningOp())
+                                   .getInputs());
auto memrefB = brgemmOp.getOperand(2);
int64_t batchSize = memrefB.getType().cast<ShapedType>().getShape()[0];

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(fusedMatch.binaryOp);
Value dispatched = rewriter.create<xsmm::FusedBrgemmDispatchOp>(
-        loc, integer64, *dims,
+        loc, integer64, dims,
xsmm::BinaryKindAttr::get(rewriter.getContext(), fusedMatch.binaryKind),
xsmm::UnaryKindAttr::get(rewriter.getContext(), fusedMatch.unaryKind),
getBrgemmFlags(rewriter, brgemmOp),
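Instead of recomputing sizes and leading dimensions from the memref shapes (the `getSizesAndLeadingDimForBrgemmOp` helper deleted above), the rewrite now reads the `inputs` array directly off the defining `BrgemmDispatchOp`, so the fused dispatch is built from exactly the values the original dispatch carried.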
(Diff truncated: 4 of the 12 changed files are not shown.)