Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow creation of TCP groups where an op has multiple uses #74

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 127 additions & 52 deletions lib/Dialect/Transforms/FusionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,145 @@
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/Debug.h"

#ifndef NDEBUG
#define DEBUG_TYPE "tcp-fusion-patterns"
#endif

namespace mlir::tcp {

LogicalResult
GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Operation *use = op;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}
// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(op->getParentOp())) {
return failure();
}
srinathava marked this conversation as resolved.
Show resolved Hide resolved

if (op->use_empty())
return failure();

// We can only fuse a def with multiple uses if all the uses belong to the
// same region and can be fused with the defining op
Region *usesParentRegion = nullptr;
SmallVector<Operation *> uses;
llvm::DenseSet<Operation *> usesSet;

for (auto &use : op->getUses()) {
auto parentRegion = use.getOwner()->getParentRegion();
if (usesParentRegion && usesParentRegion != parentRegion) {
return failure();
}
srinathava marked this conversation as resolved.
Show resolved Hide resolved
usesParentRegion = parentRegion;

if (!canFuse(op, use.getOwner())) {
return failure();
}
if (usesSet.insert(use.getOwner()).second) {
uses.push_back(use.getOwner());
}
srinathava marked this conversation as resolved.
Show resolved Hide resolved
}

// Sorting by dominance ensures that the first element of this vector is
// the first use of the def. Used below when we want to move the op into
// an existing group.
LLVM_DEBUG(llvm::dbgs() << "Processing op: " << *op << " with " << uses.size()
<< " uses\n");
DominanceInfo domInfo;
llvm::stable_sort(uses, [&](Operation *a, Operation *b) {
return !domInfo.dominates(a, b);
});

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}
if (op->getParentRegion() == usesParentRegion) {
LLVM_DEBUG(llvm::dbgs() << "Creating new group\n");
// this case can only happen when all ops belong to the function.
SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;
for (auto use : uses) {
allResultTypes.append(use->getResultTypes().begin(),
use->getResultTypes().end());
allResults.append(use->getResults().begin(), use->getResults().end());
}

// Fuse the def and use ops into a group.
auto groupOp = rewriter.create<tcp::GroupOp>(op->getLoc(), allResultTypes);
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
isChanged = true;
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp = rewriter.create<tcp::GroupOp>(use->getLoc(),
use->getResultTypes());
if (postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<tcp::YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
size_t groupResultNum = 0;
for (auto use : uses) {
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(groupResultNum));
groupResultNum++;
}
}

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp = rewriter.create<tcp::YieldOp>(op->getLoc(), allResults);
// This is where we are using the sorted-ness of `uses`. We are
// guaranteed that if the users of the op themselves depend on each
// other, then we'll move them in the correct order.
for (auto use : uses) {
use->moveBefore(yieldOp);
}
op->moveBefore(*uses.begin());
}
} else if (auto groupOp =
dyn_cast<tcp::GroupOp>(usesParentRegion->getParentOp())) {
// Given that we iterate over the funcop in a bottom up manner, when moving
// into an existing group, we would be guaranteed that this op does not use
// any of the ops already in the group. So we can move it to the very
// beginning of the group. This ensures that the order of operands is
// preserved when creating a group. Otherwise, if we start with
srinathava marked this conversation as resolved.
Show resolved Hide resolved
srinathava marked this conversation as resolved.
Show resolved Hide resolved
// something like:
//
// %0 = op1(%in1)
// %1 = op2(%in2)
// %2 = op3(%0, %1)
//
// we'll first create a group containing %1 and %2:
//
// %0 = op1(%in1)
// %3 = tcp.group {
// %1 = op2(%in2)
// %2 = op3(%0, %1)
// }
//
// if we try to move %0 to right before its use in the group, then we'd
// end up with:
//
// %3 = tcp.group {
// %1 = op2(%in2)
// %0 = op1(%in1)
// %2 = op3(%0, %1)
// }
//
// While this is not incorrect, it is a bit annoying that the MLIR gets
// reordered.
auto &firstOp = *usesParentRegion->getOps().begin();
op->moveBefore(&firstOp);
} else {
op->emitError("Unhandled case during fusion");
llvm_unreachable("Unhandled case during fusion");
}
return isChanged ? success() : failure();
LLVM_DEBUG(llvm::dbgs() << "Function after transformation:\n"
<< op->getParentOfType<func::FuncOp>() << "\n");
return success();
}

} // namespace mlir::tcp
40 changes: 40 additions & 0 deletions test/Dialect/tcp_fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,43 @@ func.func @test_multiple_fusions(%arg0 : tensor<?x?xf32>,
%6 = tcp.sub %arg1, %5 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %6 : tensor<?x?xf32>
}

// -----

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[V0:.+]] = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>
// CHECK: return %[[V0]] : tensor<?x?xf32>
// CHECK: }
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: call this something different from below (is this single use fusion?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments explaining what the test intent is. These are both for testing fusion with multiple uses. The difference is in whether the multiple uses are already part of a tcp group or whether we are creating a new group from the multiple uses for the first time.

%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tcp.add %0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = tcp.sub %1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.mul %1, %2 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
return %3 : tensor<?x?xf32>
}

// -----

// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: %[[V0:.+]]:2 = tcp.group {
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V3:.+]] = tcp.mul %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[V4:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: tcp.yield %[[V3]], %[[V4]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: } : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK: return %[[V0]]#1, %[[V0]]#0 : tensor<?x?xf32>, tensor<?x?xf32>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea why the order is swapped here?
Should it be return %[[V0]]#0, %[[V0]]#1 : tensor<?x?xf32>, tensor<?x?xf32> instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the close eyes! This actually uncovered a bug in the algorithm... I fixed it and added a more robust test which was what the intent of the second test was all along.

// CHECK: }
// Multi-use input for the fusion pass: %0 (tanh) feeds both operands of
// %1 (add), and %1 in turn has two distinct users, %2 (sub) and %3 (mul),
// both of which are returned from the function.
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
%1 = tcp.add %0, %0 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = tcp.sub %1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
%3 = tcp.mul %1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
"func.return" (%2, %3) : (tensor<?x?xf32>, tensor<?x?xf32>) -> ()
}