Skip to content

Commit

Permalink
Allow creation of TCP groups where an op has multiple uses (cruise-automation#74)
Browse files Browse the repository at this point in the history

We previously only allowed an op to have a single use during the
creation of a group for it. This PR relaxes that to allow multiple uses as
long as all the uses belong to the same region.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
  • Loading branch information
srinathava and Srinath Avadhanula authored Jun 24, 2024
1 parent 6fafa7f commit b6ae56c
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 56 deletions.
200 changes: 144 additions & 56 deletions lib/Dialect/Transforms/FusionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,158 @@
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/Debug.h"

#ifndef NDEBUG
#define DEBUG_TYPE "tcp-fusion-patterns"
#endif

namespace mlir::tcp {

// NOTE(review): This text is a GitHub diff render with the +/- gutter
// stripped: it interleaves the OLD single-use implementation (the
// `Operation *use = op; ... isChanged` section below) with the NEW
// multi-use implementation (from the `if (!isa<func::FuncOp>(...))`
// early-return onward). It is NOT compilable as-is; comments below mark
// the seams. Recover the real pre/post bodies from the repository before
// editing further.
//
// Purpose (per the commit message): fuse a def op with its user(s) into a
// tcp::GroupOp, now allowing a def with multiple uses as long as all uses
// live in the same region and each use passes canFuse().
LogicalResult
GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
// --- BEGIN old (removed) implementation: single-use fusion ---
Operation *use = op;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}

// Fuse the def and use ops into a group.

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
isChanged = true;
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp = rewriter.create<tcp::GroupOp>(use->getLoc(),
use->getResultTypes());
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<tcp::YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
// --- END old (removed) implementation; NEW implementation follows ---

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(op->getParentOp()))
return failure();

// Nothing to fuse if the op's results are unused.
if (op->use_empty())
return failure();

// We can only fuse a def with multiple uses if all the uses belong to the
// same region and can be fused with the defining op
Region *usesParentRegion = nullptr;
SmallVector<Operation *> uses;
// usesSet deduplicates owners: an op that consumes the same result twice
// (e.g. add %0, %0) appears once in `uses`.
llvm::DenseSet<Operation *> usesSet;

for (auto &use : op->getUses()) {
auto parentRegion = use.getOwner()->getParentRegion();
if (usesParentRegion && usesParentRegion != parentRegion)
return failure();
usesParentRegion = parentRegion;

if (!canFuse(op, use.getOwner()))
return failure();

if (usesSet.insert(use.getOwner()).second)
uses.push_back(use.getOwner());
}

// Sorting by dominance ensures that the first element of this vector is
// the first use of the def. Used below when we want to move the op into
// an existing group.
LLVM_DEBUG(llvm::dbgs() << "Processing op: " << *op << " with " << uses.size()
<< " uses\n");
DominanceInfo domInfo;
llvm::stable_sort(uses, [&](Operation *a, Operation *b) {
return domInfo.dominates(a, b);
});

#ifndef NDEBUG
for (auto use : uses) {
LLVM_DEBUG(llvm::dbgs() << "Use: " << *use << "\n");
}
#endif

// Case 1: def and all uses share a region -> wrap them in a new group.
if (op->getParentRegion() == usesParentRegion) {
LLVM_DEBUG(llvm::dbgs() << "Creating new group\n");
// this case can only happen when all ops belong to the function.
// The group yields every result of every use; allResultTypes/allResults
// are kept in the same (dominance) order as `uses`.
SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;
for (auto use : uses) {
allResultTypes.append(use->getResultTypes().begin(),
use->getResultTypes().end());
allResults.append(use->getResults().begin(), use->getResults().end());
}

auto groupOp = rewriter.create<tcp::GroupOp>(op->getLoc(), allResultTypes);
if (postFunc) {
postFunc(groupOp, rewriter);
}
// Block is adopted by the region; the region owns and frees it.
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);

// First move all uses into the group in the dominance order
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp = rewriter.create<tcp::YieldOp>(op->getLoc(), allResults);
// This is where we are using the sorted-ness of `uses`. We are
// guaranteed that if the users of the op themselves depend on each
// other, then we'll move them in the correct order.
for (auto use : uses) {
use->moveBefore(yieldOp);
}
// The def goes before its dominance-first use.
op->moveBefore(*uses.begin());
}

// We then replace all uses of the uses which lie outside the group
// with the group's results. We should not replace uses inside the
// group otherwise ops inside the group will end up depending on the
// group's results causing dominance issues.
size_t groupResultNum = 0;
for (auto use : uses) {
for (unsigned num = 0; num < use->getNumResults(); ++num) {
auto useIsOutsideGroup = [&](OpOperand &operand) {
return operand.getOwner()->getParentOp() != groupOp;
};
rewriter.replaceUsesWithIf(use->getResult(num),
groupOp->getResult(groupResultNum),
useIsOutsideGroup);
groupResultNum++;
}
}

// Case 2: all uses already live inside an existing tcp.group -> move the
// def into that group.
} else if (auto groupOp =
dyn_cast<tcp::GroupOp>(usesParentRegion->getParentOp())) {
// Given that we iterate over the funcop in a bottom up manner, when moving
// into an existing group, we would be guaranteed that this op does not use
// any of the ops already in the group. So we can move it to the very
// beginning of the group. This ensures that the order of operands is
// preserved when creating a group. For example, if we start with
// something like:
//
// %0 = op1(%in1)
// %1 = op2(%in2)
// %2 = op3(%0, %1)
//
// we'll first create a %1 and %2
//
// %0 = op1(%in1)
// %3 = tcp.group {
// %1 = op2(%in2)
// %2 = op3(%0, %1)
// }
//
// if we try to move %0 to right before its use in the group, then we'd
// end up with:
//
// %3 = tcp.group {
// %1 = op2(%in2)
// %0 = op1(%in1)
// %2 = op3(%0, %1)
// }
//
// While this is not incorrect, it is a bit annoying that the MLIR gets
// reordered.
auto &firstOp = *usesParentRegion->getOps().begin();
op->moveBefore(&firstOp);
} else {
// Should be unreachable: the region check above restricts uses to a
// func body or a tcp.group body.
op->emitError("Unhandled case during fusion");
llvm_unreachable("Unhandled case during fusion");
}
// NOTE(review): removed line from the old version, rendered without its
// '-' marker; the new version falls through to `return success()`.
return isChanged ? success() : failure();
LLVM_DEBUG(llvm::dbgs() << "Function after transformation:\n"
<< op->getParentOfType<func::FuncOp>() << "\n");
return success();
}

} // namespace mlir::tcp
46 changes: 46 additions & 0 deletions test/Dialect/tcp_fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,49 @@ func.func @test_multiple_fusions(%arg0 : tensor<?x?xf32>,
%6 = tcp.sub %arg1, %5 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %6 : tensor<?x?xf32>
}

// -----

// Fusion with multiple uses where the def with multiple uses moves into an
// already created group.

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: return %[[V0]] : tensor<?x?xf32>
// CHECK: }
// Exercises the multi-use fusion path: %1 (tcp.add) has two distinct users
// (%2 and %3) in the same region, so the def may be pulled into the group
// that its users form.
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
// %1 is consumed by both %2 and %3 below.
%1 = tcp.add %0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = tcp.sub %1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.mul %1, %2 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %3 : tensor<?x?xf32>
}

// -----

// Fusion with multiple uses where the def and the multiple uses create a
// new group. Here we test that the moves use the dominance correctly.

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: %[[V0:.+]]:2 = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]], %[[V4]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: return %[[V0]]#0, %[[V0]]#1 : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: }
// Multi-use fusion where the def (%0) and all of its uses end up in one
// newly created group; per the comment above, this checks that the moves
// respect dominance order. Two results (%2, %3) escape the group.
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
// %0 feeds both operands of %1; %1 then feeds both %2 and %3.
%1 = tcp.add %0, %0 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = tcp.sub %1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.mul %1, %2 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// Generic-form return (written "func.return") returning both group results.
"func.return" (%2, %3) : (tensor<?x?xf32>, tensor<?x?xf32>) -> ()
}

0 comments on commit b6ae56c

Please sign in to comment.