diff --git a/lib/Dialect/Transforms/FusionPatterns.cpp b/lib/Dialect/Transforms/FusionPatterns.cpp
index 048d7c03..64919052 100644
--- a/lib/Dialect/Transforms/FusionPatterns.cpp
+++ b/lib/Dialect/Transforms/FusionPatterns.cpp
@@ -12,70 +12,170 @@
 #include "mlir-tcp/Dialect/IR/TcpOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/Support/Debug.h"
+
+#ifndef NDEBUG
+#define DEBUG_TYPE "tcp-fusion-patterns"
+#endif
 
 namespace mlir::tcp {
+
+/// Fuses `op` (the def) with the ops that use its results.
+///
+/// Fails when `op`'s parent is not a func op, when `op` has no uses, when its
+/// users span more than one region, or when `canFuse` rejects any user.
+/// Otherwise, if the users share `op`'s region, `op` and its users are wrapped
+/// in a new tcp.group; if the users already sit inside a group, `op` is moved
+/// to the start of that group. Returns success() once the IR has been changed.
 LogicalResult
 GenericBottomUpFuser::matchAndRewrite(Operation *op,
                                       PatternRewriter &rewriter) const {
-  Operation *use = op;
-  bool isChanged = false;
-  for (auto operand : op->getOperands()) {
-    if (operand.getDefiningOp()) {
-      Operation *def = operand.getDefiningOp();
-      if (canFuse(def, use)) {
-
-        // Currently we are only fusing ops at the top-level.
-        // This is to avoid recursing inside a group and ending up with
-        // nested groups that contain the same ops.
-        // Since we are iterating bottom up in a block, we only need to check
-        // if the def op has a func parent.
-        //
-        // TODO: Remove this restriction to allow fusing in nested regions.
-        if (!isa<func::FuncOp>(def->getParentOp())) {
-          continue;
-        }
-
-        // We only support fusing def ops that have exactly one use, for now.
-        if (!def->hasOneUse()) {
-          continue;
-        }
-
-        // Fuse the def and use ops into a group.
-
-        // * If both the ops have the same parent region, they must be part
-        //   of the top-level func. So, we need to create a new group.
-        // * The only other case is when the def op is part of the top-level
-        //   func and the use is already inside a group.
-        isChanged = true;
-        if (def->getParentRegion() == use->getParentRegion()) {
-          auto groupOp = rewriter.create<tcp::GroupOp>(use->getLoc(),
-                                                       use->getResultTypes());
-          if (postFunc) {
-            postFunc(groupOp, rewriter);
-          }
-          Block *groupBlock = new Block();
-          groupOp.getBody().push_back(groupBlock);
-          for (unsigned num = 0; num < use->getNumResults(); ++num) {
-            rewriter.replaceAllUsesWith(use->getResult(num),
-                                        groupOp->getResult(num));
-          }
-          {
-            OpBuilder::InsertionGuard guard(rewriter);
-            rewriter.setInsertionPointToStart(groupBlock);
-            auto yieldOp =
-                rewriter.create<tcp::YieldOp>(use->getLoc(), use->getResults());
-            use->moveBefore(yieldOp);
-            operand.getDefiningOp()->moveBefore(use);
-          }
-        } else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
-          def->moveBefore(use);
-        } else {
-          llvm_unreachable("Unhandled case during fusion");
-        }
+
+  // Currently we are only fusing ops at the top-level.
+  // This is to avoid recursing inside a group and ending up with
+  // nested groups that contain the same ops.
+  // Since we are iterating bottom up in a block, we only need to check
+  // if the def op has a func parent.
+  //
+  // TODO: Remove this restriction to allow fusing in nested regions.
+  if (!isa<func::FuncOp>(op->getParentOp()))
+    return failure();
+
+  if (op->use_empty())
+    return failure();
+
+  // We can only fuse a def with multiple uses if all the uses belong to the
+  // same region and can be fused with the defining op.
+  Region *usesParentRegion = nullptr;
+  SmallVector<Operation *> uses;
+  llvm::DenseSet<Operation *> usesSet;
+
+  for (auto &use : op->getUses()) {
+    auto parentRegion = use.getOwner()->getParentRegion();
+    if (usesParentRegion && usesParentRegion != parentRegion)
+      return failure();
+    usesParentRegion = parentRegion;
+
+    if (!canFuse(op, use.getOwner()))
+      return failure();
+
+    if (usesSet.insert(use.getOwner()).second)
+      uses.push_back(use.getOwner());
+  }
+
+  // Sorting by dominance ensures that the first element of this vector is
+  // the first use of the def. Used below when the op is moved in front of
+  // its first user inside a newly created group.
+  LLVM_DEBUG(llvm::dbgs() << "Processing op: " << *op << " with " << uses.size()
+                          << " uses\n");
+  DominanceInfo domInfo;
+  // `properlyDominates` (unlike `dominates`) is false for a == b, which keeps
+  // the comparator irreflexive as the sort's strict weak ordering requires.
+  llvm::stable_sort(uses, [&](Operation *a, Operation *b) {
+    return domInfo.properlyDominates(a, b);
+  });
+
+#ifndef NDEBUG
+  for (auto use : uses) {
+    LLVM_DEBUG(llvm::dbgs() << "Use: " << *use << "\n");
+  }
+#endif
+
+  if (op->getParentRegion() == usesParentRegion) {
+    LLVM_DEBUG(llvm::dbgs() << "Creating new group\n");
+    // This case can only happen when all ops belong to the function.
+    SmallVector<Type> allResultTypes;
+    SmallVector<Value> allResults;
+    for (auto use : uses) {
+      allResultTypes.append(use->getResultTypes().begin(),
+                            use->getResultTypes().end());
+      allResults.append(use->getResults().begin(), use->getResults().end());
+    }
+
+    // NOTE(review): the group is created at the rewriter's current insertion
+    // point and the users are then moved into it; confirm that every other
+    // operand of those users is defined before that point.
+    auto groupOp = rewriter.create<tcp::GroupOp>(op->getLoc(), allResultTypes);
+    if (postFunc) {
+      postFunc(groupOp, rewriter);
+    }
+    Block *groupBlock = new Block();
+    groupOp.getBody().push_back(groupBlock);
+
+    // First move all uses into the group in the dominance order.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(groupBlock);
+      auto yieldOp = rewriter.create<tcp::YieldOp>(op->getLoc(), allResults);
+      // This is where we are using the sorted-ness of `uses`. We are
+      // guaranteed that if the users of the op themselves depend on each
+      // other, then we'll move them in the correct order.
+      for (auto use : uses) {
+        use->moveBefore(yieldOp);
       }
+      op->moveBefore(*uses.begin());
     }
+
+    // We then replace all uses of the uses which lie outside the group
+    // with the group's results. We should not replace uses inside the
+    // group otherwise ops inside the group will end up depending on the
+    // group's results causing dominance issues.
+    size_t groupResultNum = 0;
+    for (auto use : uses) {
+      for (unsigned num = 0; num < use->getNumResults(); ++num) {
+        auto useIsOutsideGroup = [&](OpOperand &operand) {
+          return operand.getOwner()->getParentOp() != groupOp;
+        };
+        rewriter.replaceUsesWithIf(use->getResult(num),
+                                   groupOp->getResult(groupResultNum),
+                                   useIsOutsideGroup);
+        groupResultNum++;
+      }
+    }
+
+  } else if (auto groupOp =
+                 dyn_cast<tcp::GroupOp>(usesParentRegion->getParentOp())) {
+    // Given that we iterate over the funcop in a bottom up manner, when moving
+    // into an existing group, we would be guaranteed that this op does not use
+    // any of the ops already in the group. So we can move it to the very
+    // beginning of the group. This ensures that the order of operands is
+    // preserved when creating a group. For example, if we start with
+    // something like:
+    //
+    // %0 = op1(%in1)
+    // %1 = op2(%in2)
+    // %2 = op3(%0, %1)
+    //
+    // we'll first create a group containing %1 and %2:
+    //
+    // %0 = op1(%in1)
+    // %3 = tcp.group {
+    //   %1 = op2(%in2)
+    //   %2 = op3(%0, %1)
+    // }
+    //
+    // If we try to move %0 to right before its use in the group, then we'd
+    // end up with:
+    //
+    // %3 = tcp.group {
+    //   %1 = op2(%in2)
+    //   %0 = op1(%in1)
+    //   %2 = op3(%0, %1)
+    // }
+    //
+    // While this is not incorrect, it is a bit annoying that the MLIR gets
+    // reordered.
+    auto &firstOp = *usesParentRegion->getOps().begin();
+    op->moveBefore(&firstOp);
+  } else {
+    op->emitError("Unhandled case during fusion");
+    llvm_unreachable("Unhandled case during fusion");
   }
-  return isChanged ? success() : failure();
+  LLVM_DEBUG(llvm::dbgs() << "Function after transformation:\n"
+                          << op->getParentOfType<func::FuncOp>() << "\n");
+  return success();
 }
+
 } // namespace mlir::tcp
diff --git a/test/Dialect/tcp_fusion.mlir b/test/Dialect/tcp_fusion.mlir
index bebd4a0c..7e420f9f 100644
--- a/test/Dialect/tcp_fusion.mlir
+++ b/test/Dialect/tcp_fusion.mlir
@@ -51,3 +51,49 @@ func.func @test_multiple_fusions(%arg0 : tensor,
   %6 = tcp.sub %arg1, %5 : tensor, tensor -> tensor
   return %6 : tensor
 }
+
+// -----
+
+// Fusion with multiple uses where the def with multiple uses moves into an
+// already created group.
+
+// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> tensor {
+// CHECK:   %[[V0:.+]] = tcp.group {
+// CHECK:     %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor -> tensor
+// CHECK:     %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor, tensor -> tensor
+// CHECK:     %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor, tensor -> tensor
+// CHECK:     %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor, tensor -> tensor
+// CHECK:     tcp.yield %[[V4]] : tensor
+// CHECK:   } : tensor
+// CHECK:   return %[[V0]] : tensor
+// CHECK: }
+func.func @test_multi_use_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor {
+  %0 = tcp.tanh %arg0 : tensor -> tensor
+  %1 = tcp.add %0, %arg1 : tensor, tensor -> tensor
+  %2 = tcp.sub %1, %arg1 : tensor, tensor -> tensor
+  %3 = tcp.mul %1, %2 : tensor, tensor -> tensor
+  return %3 : tensor
+}
+
+// -----
+
+// Fusion with multiple uses where the def and the multiple uses create a
+// new group. Here we test that the moves use the dominance correctly.
+
+// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor) -> (tensor, tensor) {
+// CHECK:   %[[V0:.+]]:2 = tcp.group {
+// CHECK:     %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor -> tensor
+// CHECK:     %[[V2:.+]] = tcp.add %[[V1]], %[[V1]] : tensor, tensor -> tensor
+// CHECK:     %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor, tensor -> tensor
+// CHECK:     %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor, tensor -> tensor
+// CHECK:     tcp.yield %[[V3]], %[[V4]] : tensor, tensor
+// CHECK:   } : tensor, tensor
+// CHECK:   return %[[V0]]#0, %[[V0]]#1 : tensor, tensor
+// CHECK: }
+func.func @test_multi_use_fusion(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) {
+  %0 = tcp.tanh %arg0 : tensor -> tensor
+  %1 = tcp.add %0, %0 : tensor, tensor -> tensor
+  %2 = tcp.sub %1, %arg1 : tensor, tensor -> tensor
+  %3 = tcp.mul %1, %2 : tensor, tensor -> tensor
+  "func.return" (%2, %3) : (tensor, tensor) -> ()
+}