Skip to content

Commit

Permalink
Support multiple target ops in clone_preceding_op_into_dispatch_region
Browse files Browse the repository at this point in the history
The target ops are sorted topologically before cloning them one-by-one.
This is to ensure that there are no dominance violations. (Same logic as
with the existing dispatch region formation.)
  • Loading branch information
matthias-springer committed Aug 23, 2022
1 parent dc06d95 commit 9da8d2a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/TopologicalSortUtils.h"

using namespace mlir;
using namespace mlir::iree_compiler;
Expand Down Expand Up @@ -706,8 +707,7 @@ transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
ArrayRef<Operation *> dispatchRegion =
state.getPayloadOps(getDispatchRegion());

// TODO: Multiple targetOps could be allowed.
if (targetOps.size() != 1 || dispatchRegion.size() != 1)
if (dispatchRegion.size() != 1)
return DiagnosedSilenceableFailure(this->emitOpError(
"requires exactly one target/dispatch region handle"));

Expand All @@ -716,15 +716,25 @@ transform_dialect::ClonePrecedingOpIntoDispatchRegionOp::apply(
return DiagnosedSilenceableFailure(
this->emitOpError("expected 'dispatch.region' operand"));

// We are cloning ops one-by-one, so the order must be inversed (as opposed
// to cloning all ops in one go).
SmallVector<Operation *> targetOpsList(targetOps.begin(), targetOps.end());
bool sortResult = computeTopologicalSorting(
dispatchRegion.front()->getBlock(), targetOpsList);
assert(sortResult && "unable to sort topologically");
SmallVector<Operation *> orderedTargets =
llvm::to_vector(llvm::reverse(targetOps));
IRRewriter rewriter(regionOp->getContext());
auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
rewriter, targetOps.front(), regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(
reportUnknownTransformError(targetOps.front()));
for (Operation *target : orderedTargets) {
auto newRegionOp = clonePrecedingOpIntoDispatchRegion(
rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
regionOp = *newRegionOp;
}

transformResults.set(getTransformed().cast<OpResult>(),
newRegionOp->getOperation());
regionOp.getOperation());
return DiagnosedSilenceableFailure(success());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def ClonePrecedingOpIntoDispatchRegionOp : Op<
region. The transform fails if there are uses that appear before the
dispatch region.



TODO: Support multiple payload ops for the `target` handle. In that case,
the targets must be sorted topologically before cloning them.

#### Return modes

This transform consumes both the `target` handle and the `dispatch_region`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,41 @@ transform.with_pdl_patterns {
transform.iree.region_to_workgroups %region_op
}
}

// -----

// CHECK-LABEL: func @move_multiple_preceding
// CHECK-DAG: arith.constant
// CHECK-DAG: arith.constant
// CHECK-DAG: tensor.dim
// CHECK-DAG: tensor.dim
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.third_user"
// CHECK-NEXT: flow.dispatch.region
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
// Test payload: ops carrying the `__tagged__` attribute are the clone targets;
// `test.third_user` is deliberately untagged so it must stay outside the
// dispatch region (see the CHECK lines above). The tagged ops form a DAG
// (dummy_op -> first_user/second_user -> merge1 -> merge2), exercising the
// topological sort done before cloning.
func.func @move_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
%0 = "test.dummy_op"(%arg0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
// Untagged user: must remain outside the dispatch region after the transform.
%u = "test.third_user"(%0) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%3 = "test.merge1"(%1, %2) {__tagged__} : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
%4 = "test.merge2"(%1, %3) {__tagged__} : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
// The insert_slice is the op that gets wrapped into the dispatch region
// by the transform script below.
%5 = tensor.insert_slice %4 into %arg1 [5, 16] [%s1, %s2] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}

// Transform script: wrap the insert_slice in a flow.dispatch.region, then
// clone every `__tagged__` op into that region in one transform application —
// this is the new multi-target capability under test.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
// Match the single insert_slice and wrap it in a dispatch region.
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
// Match ALL tagged ops (a multi-payload handle) and clone them into the
// region; the implementation topologically sorts them before cloning.
%1 = transform.structured.match attributes{"__tagged__"} in %arg1
transform.iree.clone_preceding_op_into_dispatch_region %1 into %dispatch_op
}
}

0 comments on commit 9da8d2a

Please sign in to comment.