Skip to content

Commit

Permalink
Add a fold-add-into-dest pattern and pass
Browse files Browse the repository at this point in the history
Achieves the sum of linalg.add by relying on the accumulation semantics
of contraction ops w.r.t. their out/dest argument. Main value is likely
to come from that a buffer allocation can be elided, in case the
contraction op had its own buffer.
  • Loading branch information
rolfmorel committed Sep 7, 2024
1 parent 9c53a05 commit 93f8831
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 2 deletions.
10 changes: 10 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def ConstantFoldPack : Pass<"constant-fold-pack", "ModuleOp"> {
"arith::ArithDialect"];
}

def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> {
let summary = "TODO";
let description = [{
TODO.
}];
let dependentDialects = ["linalg::LinalgDialect",
"tensor::TensorDialect",
"arith::ArithDialect"];
}

def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> {
let summary = "Run linalg element-wise fusion";
}
Expand Down
4 changes: 4 additions & 0 deletions include/TPP/Transforms/Utils/ValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
namespace mlir {
class Value;
class OpBuilder;
class Operation;
namespace utils {

// Returns true if the value is a constant float or integer.
Expand All @@ -20,6 +21,9 @@ bool isValConstZero(Value val);
// Returns true if the op defining `val` represents a zero filled tensor.
bool isZeroTensor(Value val);

// Returns true if the operation represents a zero filled tensor.
bool isZeroOp(Operation *);

// Returns the strides of `val`. The method returns something usefull
// only if the `val` type is a strided memref and the strides are statically
// known.
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_library(TPPTransforms
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertLinalgToInplace.cpp
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
101 changes: 101 additions & 0 deletions lib/TPP/Transforms/FoldAddIntoDest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_FOLDADDINTODEST
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

/// Replace a linalg.add where its linalg-contraction operand - with a
/// zero-filled destination - is dominated by the `other` linalg operand,
/// by passing `other` as the destination of the contraction.
struct FoldAddIntoDestRewrite : public OpRewritePattern<linalg::AddOp> {
using OpRewritePattern<linalg::AddOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::AddOp addOp,
PatternRewriter &rewriter) const override {
Value dominatingOperand = nullptr;
linalg::LinalgOp dominatedOp = nullptr;
{
auto firstOperand = addOp.getOperand(0);
auto secondOperand = addOp.getOperand(1);

// Can only put one of addOp's operands in the dest/out arg of the other's
// defining op based on suitable dominance.
if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
DominanceInfo domInfo(secondOp);
if (domInfo.properlyDominates(firstOperand, secondOp)) {
dominatingOperand = firstOperand;
dominatedOp = secondOp;
}
}
if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
DominanceInfo domInfo(firstOp);
if (domInfo.properlyDominates(secondOperand, firstOp)) {
dominatingOperand = secondOperand;
dominatedOp = firstOp;
}
}
if (!dominatingOperand || !dominatedOp)
return failure();
// NB: As linalg.add's generalisation ignores the out argument in its
// region there is no need to perform checks on addOp's out argument.
}

// When the dominated op has a single-result and is a contraction, the op
// accumulates on the out argument, starting from the supplied out argument.
// E.g., AddOp is not a contraction and hence ignores its out arg's value.
if (dominatedOp->getNumResults() != 1 ||
!linalg::isaContractionOpInterface(dominatedOp))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op to be single-result contraction");

// As the dominated op was already accumulating on its out argument, it is
// only safe to discard its current out arg when it is the additive zero.
auto *dominatedDest = dominatedOp->getOperand(2).getDefiningOp();
if (!mlir::utils::isZeroOp(dominatedDest))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op's dest to be additive zero");

// Replace the additive-zero out argument of the dominated op by the
// dominating summand. This makes the dominated op's result the sum of both
// of addOp's arguments - therefore we replace addOp and it uses by it.
rewriter.modifyOpInPlace(
dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
return success();
}
};

/// Replace linalg.add when destination passing suffices for achieving the sum.
struct FoldAddIntoDest
: public tpp::impl::FoldAddIntoDestBase<FoldAddIntoDest> {

void runOnOperation() override {
auto *ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<FoldAddIntoDestRewrite>(ctx);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/Utils/ValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static bool isZeroAttr(Attribute attribute) {
}

// Prototypes
static bool isZeroOp(Operation *);
bool isZeroOp(Operation *);

// Returns true if the value represents a zero filled tensor.
// Recurse into isZeroOp for defining ops if not immediately obvious
Expand Down Expand Up @@ -70,7 +70,7 @@ bool isZeroTensor(Value val) {

// Returns true if the operation represents a zero filled tensor
// Recurses into isZeroTensor for operands and isZeroAttr for attributes
static bool isZeroOp(Operation *defOp) {
bool isZeroOp(Operation *defOp) {
if (!defOp)
return false;

Expand Down

0 comments on commit 93f8831

Please sign in to comment.