Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hoist vector transfers outside reduction and k loop #977

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def VectorizationPass : Pass<"vectorization-pass",
let dependentDialects = [ "memref::MemRefDialect", "linalg::LinalgDialect", "vector::VectorDialect" ];
}



def HoistVectorTransfers : Pass<"hoist-vector-transfer"> {
let summary = "Hoist vector transfer operation outside of reduction and k loop";
let description = [{
Hoists the vector transfer read and write operations of the resultant matrix outside the reduction and k loop for a brgemm operation. This pass should be applied after the BrgemmLinalgTiling Pass.
}];
let dependentDialects = [ "vector::VectorDialect", "scf::SCFDialect" ];
}



def VectorContractToOuterproduct : Pass<
"vector-contract-to-outerproduct"> {
let summary = "Perform outerproduct lowering of vector contraction ops";
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_mlir_library(TPPTransforms
Vectorization.cpp
SplitReductionDim.cpp
VectorContractToOuterproduct.cpp
HoistVectorTransfers.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
164 changes: 164 additions & 0 deletions lib/TPP/Transforms/HoistVectorTransfers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===-HoistVectorTransfers.cpp -----------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements tile configuration hoisting on parallel loops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include <iostream>
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_HOISTVECTORTRANSFERS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace vector;

namespace mlir {
namespace tpp {

struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {

// Check whether the linalg tiling + vector contract pattern matches
auto retriveVectorReadOp = contractOp.getAcc().getDefiningOp<mlir::vector::TransferReadOp>();
if (retriveVectorReadOp == NULL)
return rewriter.notifyMatchFailure(contractOp, "Not a linalg tile + vector contract operation");

auto subviewOp = retriveVectorReadOp.getOperand(0).getDefiningOp<memref::SubViewOp>();
if (subviewOp == NULL)
return rewriter.notifyMatchFailure(contractOp, "Not a linalg tile + vector contract operation");

auto ReductionForOp = llvm::dyn_cast<mlir::scf::ForOp>(subviewOp->getNextNode());
if (ReductionForOp == NULL)
return rewriter.notifyMatchFailure(contractOp, "Not a linalg tile + vector contract operation");

auto KForOp = llvm::dyn_cast<mlir::scf::ForOp>(ReductionForOp.getBody()->front());
if (KForOp == NULL)
return rewriter.notifyMatchFailure(contractOp, "Not a linalg tile + vector contract operation");

// Move the vector transfer read before the reduction and k loop
rewriter.setInsertionPointAfter(subviewOp);
auto *cloneVectorReadOp = rewriter.clone(*retriveVectorReadOp);
retriveVectorReadOp.replaceAllUsesWith(cloneVectorReadOp);

// Code to re-create the reduction and k loop with iter args
auto *nextOp = (*cloneVectorReadOp).getNextNode();
auto oldReductionForOp = llvm::dyn_cast<mlir::scf::ForOp>(*nextOp);
auto oldKForOp = llvm::dyn_cast<mlir::scf::ForOp>(oldReductionForOp.getBody()->front());

auto vectorReadOpValue = (*cloneVectorReadOp).getResult(0);
rewriter.setInsertionPoint(oldReductionForOp);

auto newReductionForOp = rewriter.create<scf::ForOp>(
oldReductionForOp.getLoc(), oldReductionForOp.getLowerBound(), oldReductionForOp.getUpperBound(),
oldReductionForOp.getStep(),ValueRange{vectorReadOpValue},
[&](OpBuilder &rewriterNewReductionForOp, Location locNewReductionForOp, Value ivNewReductionForOp,
ValueRange iterArgsNewReductionForOp) {

auto newKForOp = rewriter.create<scf::ForOp>(
oldKForOp.getLoc(), oldKForOp.getLowerBound(), oldKForOp.getUpperBound(),
oldKForOp.getStep(), iterArgsNewReductionForOp,
[&](OpBuilder &rewriterNewKForOp, Location locNewKForOp, Value ivNewKForOp,
ValueRange iterArgsNewKForOp) {

mlir::IRMapping mapper;
mapper.map(oldReductionForOp.getInductionVar(), ivNewReductionForOp);
mapper.map(oldKForOp.getInductionVar(), ivNewKForOp);

for (auto [origArgReduction, newArgReduction] :
llvm::zip(oldReductionForOp.getRegionIterArgs(), iterArgsNewReductionForOp)) {
mapper.map(origArgReduction, newArgReduction);
}

for (auto [origArgK, newArgK] :
llvm::zip(oldKForOp.getRegionIterArgs(), iterArgsNewKForOp)) {
mapper.map(origArgK, newArgK);
}

for (auto &op : oldKForOp.getBody()->without_terminator()) {
rewriterNewKForOp.clone(op, mapper);
}

rewriterNewKForOp.create<scf::YieldOp>(locNewKForOp, iterArgsNewKForOp);

});
rewriterNewReductionForOp.create<scf::YieldOp>(locNewReductionForOp, newKForOp.getResult(0));
});

//Code to hoist vector transfer write after reduction loop and also to update the yield of k loop
auto newKForOp = llvm::dyn_cast<mlir::scf::ForOp>(newReductionForOp.getBody()->front());
Value newcontractOpValue;
mlir::vector::TransferWriteOp vectorWriteOperation;
mlir::Block *bodyBlock = newKForOp.getBody();
for (auto &op : bodyBlock->getOperations()) {
if (auto vectorContractOp = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
vectorContractOp.setOperand(vectorContractOp.getNumOperands()-1, newKForOp.getRegionIterArgs()[0]);
newcontractOpValue = vectorContractOp.getResult();
}
if (auto yieldOp = llvm::dyn_cast<mlir::scf::YieldOp>(op)) {
if ( newcontractOpValue != NULL)
yieldOp.setOperand(0, newcontractOpValue);
}
if (auto vectorWriteOp = llvm::dyn_cast<mlir::vector::TransferWriteOp>(op)) {
vectorWriteOperation = vectorWriteOp;
}
}

if (vectorWriteOperation != NULL) {
vectorWriteOperation.setOperand(0,newReductionForOp.getResult(0));
vectorWriteOperation->moveBefore(oldReductionForOp);
}

// Erase the old vector contract operation
for (auto result : contractOp->getResults()) {
for (auto *userOp : result.getUsers()) {
userOp->erase();
}
}
contractOp.erase();

return success();
}
};


void populateHoistVectorTransferPatterns(RewritePatternSet &patterns) {
patterns.add<HoistVectorTransferOp>(patterns.getContext());
}

struct HoistVectorTransfers
: public impl::HoistVectorTransfersBase<HoistVectorTransfers> {
using HoistVectorTransfersBase::HoistVectorTransfersBase;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateHoistVectorTransferPatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};
} // namespace tpp
} // namespace mlir
59 changes: 59 additions & 0 deletions test/Integration/hoist-vector-transfer-brgemm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -print > %t.1
// RUN: tpp-opt %s --loop-invariant-code-motion --vectorization-pass --loop-invariant-code-motion --hoist-vector-transfer | tpp-run -e entry --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2

memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @entry(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
%c1 = arith.constant 1 : index
%c24 = arith.constant 24 : index
%c64 = arith.constant 64 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
scf.forall (%arg1, %arg2) in (8, 24) {
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview : memref<32x64xf32, strided<[64, 1], offset: ?>>)
%subview_0 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
scf.for %arg3 = %c0 to %c32 step %c4 {
scf.for %arg4 = %c0 to %c64 step %c64 {
%subview_1 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
scf.for %arg5 = %c0 to %c24 step %c1 {
scf.for %arg6 = %c0 to %c64 step %c1 {
%subview_2 = memref.subview %subview_0[%arg5, %arg3, %arg6] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
%subview_3 = memref.subview %0[%arg5, %arg6, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
linalg.batch_reduce_matmul ins(%subview_2, %subview_3 : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>) outs(%subview_1 : memref<4x64xf32, strided<[64, 1], offset: ?>>)
}
}
}
}
}
return %alloc : memref<8x24x32x64xf32>
}

// -----

// RUN: tpp-opt %s | tpp-run -e nomatch --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --hoist-vector-transfer | tpp-run -e nomatch --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2

#permA0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#permA1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#permA2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @nomatch(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%3 = vector.contract {indexing_maps = [#permA0, #permA1, #permA2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %0, %1, %2
: vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
return %4 : tensor<4x4xf32>
}

Loading