Skip to content

Commit

Permalink
Add a fold-add-into-dest pattern and pass (#966)
Browse files Browse the repository at this point in the history
Fold a linalg.add into the dest/out of the contraction op of one of
its operands. The sum with the `other` operand will be performed by
having the contraction accumulate on top of `other`. Note that this
changes the result of the contraction - hence the linalg.add must be
its single use.

Main value is likely to come from that a buffer allocation can be 
elided, in case the contraction op needed a buffer to be allocated 
for its result.
  • Loading branch information
rolfmorel authored Sep 12, 2024
1 parent 706e710 commit f4068c1
Show file tree
Hide file tree
Showing 7 changed files with 348 additions and 2 deletions.
12 changes: 12 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ def ConstantFoldPack : Pass<"constant-fold-pack", "ModuleOp"> {
"arith::ArithDialect"];
}

def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> {
let summary = "Fold linalg.add into dest of contraction op";
let description = [{
Replace a linalg.add with one operand the single user of a contraction,
which has a zero-filled, "identity-mapped" destination and is dominated by
the `other` operand, by the contraction with `other` as its dest.
}];
let dependentDialects = ["linalg::LinalgDialect",
"tensor::TensorDialect",
"arith::ArithDialect"];
}

def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> {
let summary = "Run linalg element-wise fusion";
}
Expand Down
4 changes: 4 additions & 0 deletions include/TPP/Transforms/Utils/ValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
namespace mlir {
class Value;
class OpBuilder;
class Operation;
namespace utils {

// Returns true if the value is a constant float or integer.
Expand All @@ -20,6 +21,9 @@ bool isValConstZero(Value val);
// Returns true if the op defining `val` represents a zero filled tensor.
bool isZeroTensor(Value val);

// Returns true if the operation represents a zero filled tensor.
bool isZeroOp(Operation *);

// Returns the strides of `val`. The method returns something usefull
// only if the `val` type is a strided memref and the strides are statically
// known.
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct DefaultTppPasses

private:
void constructPipeline() override {
pm.addPass(createFoldAddIntoDest());
if (linalgToLoops) {
// Lower linalg directly to loops.
// Skip all TPP transformations.
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_library(TPPTransforms
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertLinalgToInplace.cpp
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp
Vectorization.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
129 changes: 129 additions & 0 deletions lib/TPP/Transforms/FoldAddIntoDest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_FOLDADDINTODEST
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

/// Replace a linalg.add with one operand the single user of a contraction,
/// which has a zero-filled, "identity-mapped" destination and is dominated by
/// the `other` operand, by the contraction with `other` as its dest.
struct FoldAddIntoDestRewrite : public OpRewritePattern<linalg::AddOp> {
using OpRewritePattern<linalg::AddOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::AddOp addOp,
PatternRewriter &rewriter) const override {
Value dominatingOperand = nullptr;
linalg::LinalgOp dominatedOp = nullptr;
{
auto firstOperand = addOp.getOperand(0);
auto secondOperand = addOp.getOperand(1);

// Can only put one of addOp's operands in the dest/out arg of the other's
// defining op based on suitable dominance.
if (auto secondOp = secondOperand.getDefiningOp<linalg::LinalgOp>()) {
DominanceInfo domInfo(secondOp);
if (domInfo.properlyDominates(firstOperand, secondOp)) {
dominatingOperand = firstOperand;
dominatedOp = secondOp;
}
}
if (auto firstOp = firstOperand.getDefiningOp<linalg::LinalgOp>()) {
DominanceInfo domInfo(firstOp);
if (domInfo.properlyDominates(secondOperand, firstOp)) {
dominatingOperand = secondOperand;
dominatedOp = firstOp;
}
}
if (!dominatingOperand || !dominatedOp)
return failure();
// NB: As linalg.add's generalisation ignores the out argument in its
// region there is no need to perform checks on addOp's out argument.
}

// Dominated op must be a contraction for it to accumulate on its out arg.
// E.g., AddOp is not a contraction and hence ignores its out arg's value.
auto dominatedDestOp =
dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
if (dominatedOp->getNumResults() != 1 ||
!linalg::isaContractionOpInterface(dominatedOp) ||
(!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op to be single-result "
"destination-passing contraction");

// To change the contraction's result, `addOp` must be its only user.
if (!dominatedOp->getResult(0).hasOneUse())
return rewriter.notifyMatchFailure(
dominatedOp,
"expected linalg.add to be single user of contraction's result");

// As `dominatedOp` was already accumulating on its out argument, it is only
// safe to no longer use its current out arg when it is the additive zero.
auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
if (!mlir::utils::isZeroOp(destOperand->get().getDefiningOp()))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op's dest to be additive zero");
// TODO: If the other op is a contraction and has additive zero as dest, we
// can swap the dests and achieve the proper sum, given suitable dominance.

// As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
// Hence, we can only substitute `dominatingOperand` for the dest of the
// contraction when dest's index_map corresponds to an identity map w.r.t.
// just the dimensions of dest, i.e. is an ordered projection.
SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
int prevDimPos = -1;
for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
auto dim = dyn_cast<AffineDimExpr>(expr);
if (!dim || prevDimPos >= (int)dim.getPosition())
return rewriter.notifyMatchFailure(
dominatedOp, "expected index_map for contraction's dest to be an "
"ordered projection");
prevDimPos = dim.getPosition();
}

// Replace the additive-zero out argument of the dominated op by the
// dominating summand. This makes the dominated op's result the sum of both
// of addOp's arguments - therefore we replace addOp and it uses by it.
rewriter.modifyOpInPlace(
dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
return success();
}
};

/// Replace linalg.add when destination passing suffices for achieving the sum.
struct FoldAddIntoDest
: public tpp::impl::FoldAddIntoDestBase<FoldAddIntoDest> {

void runOnOperation() override {
auto *ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<FoldAddIntoDestRewrite>(ctx);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace
4 changes: 2 additions & 2 deletions lib/TPP/Transforms/Utils/ValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static bool isZeroAttr(Attribute attribute) {
}

// Prototypes
static bool isZeroOp(Operation *);
bool isZeroOp(Operation *);

// Returns true if the value represents a zero filled tensor.
// Recurse into isZeroOp for defining ops if not immediately obvious
Expand Down Expand Up @@ -70,7 +70,7 @@ bool isZeroTensor(Value val) {

// Returns true if the operation represents a zero filled tensor
// Recurses into isZeroTensor for operands and isZeroAttr for attributes
static bool isZeroOp(Operation *defOp) {
bool isZeroOp(Operation *defOp) {
if (!defOp)
return false;

Expand Down
Loading

0 comments on commit f4068c1

Please sign in to comment.