Skip to content

Commit

Permalink
Make sure to obtain the right DPS init/out/dest operand
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfmorel committed Sep 7, 2024
1 parent e11e15c commit 1983ee8
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions lib/TPP/Transforms/FoldAddIntoDest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

Expand Down Expand Up @@ -62,14 +64,19 @@ struct FoldAddIntoDestRewrite : public OpRewritePattern<linalg::AddOp> {
// When the dominated op has a single-result and is a contraction, the op
// accumulates on the out argument, starting from the supplied out argument.
// E.g., AddOp is not a contraction and hence ignores its out arg's value.
auto dominatedDestOp =
dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
if (dominatedOp->getNumResults() != 1 ||
!linalg::isaContractionOpInterface(dominatedOp))
!linalg::isaContractionOpInterface(dominatedOp) ||
(!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op to be single-result contraction");
dominatedOp, "expected dominated op to be single-result "
"destination-passing contraction");

// As the dominated op was already accumulating on its out argument, it is
// only safe to discard its current out arg when it is the additive zero.
auto *dominatedDest = dominatedOp->getOperand(2).getDefiningOp();
auto *dominatedDest =
dominatedDestOp.getDpsInitOperand(0)->get().getDefiningOp();
if (!mlir::utils::isZeroOp(dominatedDest))
return rewriter.notifyMatchFailure(
dominatedOp, "expected dominated op's dest to be additive zero");
Expand Down

0 comments on commit 1983ee8

Please sign in to comment.