Skip to content

Commit

Permalink
Support multiple target ops in clone_succeeding_op_into_dispatch_region
Browse files Browse the repository at this point in the history
The target ops are sorted topoloically before cloning them one-by-one.
This is to ensure that there are no dominance violations.
  • Loading branch information
matthias-springer committed Aug 23, 2022
1 parent b5bf9d5 commit 9796b9e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,7 @@ transform_dialect::CloneSucceedingOpIntoDispatchRegionOp::apply(
ArrayRef<Operation *> dispatchRegion =
state.getPayloadOps(getDispatchRegion());

// TODO: Multiple targetOps could be allowed.
if (targetOps.size() != 1 || dispatchRegion.size() != 1)
if (dispatchRegion.size() != 1)
return DiagnosedSilenceableFailure(this->emitOpError(
"requires exactly one target/dispatch region handle"));

Expand All @@ -844,15 +843,22 @@ transform_dialect::CloneSucceedingOpIntoDispatchRegionOp::apply(
return DiagnosedSilenceableFailure(
this->emitOpError("expected 'dispatch.region' operand"));

SmallVector<Operation *> orderedTargets(targetOps.begin(), targetOps.end());
bool sortResult = computeTopologicalSorting(
dispatchRegion.front()->getBlock(), orderedTargets);
(void)sortResult;
assert(sortResult && "unable to sort topologically");
IRRewriter rewriter(regionOp->getContext());
auto newRegionOp = cloneSucceedingOpIntoDispatchRegion(
rewriter, targetOps.front(), regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(
reportUnknownTransformError(targetOps.front()));
for (Operation *target : orderedTargets) {
auto newRegionOp = cloneSucceedingOpIntoDispatchRegion(
rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
regionOp = *newRegionOp;
}

transformResults.set(getTransformed().cast<OpResult>(),
newRegionOp->getOperation());
regionOp.getOperation());
return DiagnosedSilenceableFailure(success());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def CloneSucceedingOpIntoDispatchRegionOp : Op<
are yielded from the dispatch region and used instead of results of the
original target op.

TODO: Support multiple payload ops for the `target` handle. In that case,
the targets must be sorted topologically before cloning them.

#### Return modes

This transform consumes both the `target` handle and the `dispatch_region`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,39 @@ transform.with_pdl_patterns {
transform.iree.clone_succeeding_op_into_dispatch_region %1 into %dispatch_op
}
}

// -----

// CHECK-LABEL: func @move_multiple_succeeding
// CHECK-NEXT: flow.dispatch.region -> (tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<600x700xf32>)
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
// CHECK-NEXT: tensor.insert_slice
// CHECK-NEXT: flow.return
// CHECK-NEXT: }
// CHECK-NEXT: "test.third_user"
func.func @move_multiple_succeeding(%arg0: tensor<50x90xf32>, %arg1: tensor<600x700xf32>) -> (tensor<600x700xf32>, tensor<50x90xf32>) {
%0 = "test.dummy_op"(%arg0) : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%u = "test.third_user"(%0) : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%3 = "test.merge1"(%1, %2) {__tagged__} : (tensor<50x90xf32>, tensor<50x90xf32>) -> (tensor<50x90xf32>)
%4 = "test.merge2"(%1, %3) {__tagged__} : (tensor<50x90xf32>, tensor<50x90xf32>) -> (tensor<50x90xf32>)
%5 = tensor.insert_slice %4 into %arg1 [5, 16] [50, 90] [1, 1] {__tagged__}
: tensor<50x90xf32> into tensor<600x700xf32>
return %5, %u : tensor<600x700xf32>, tensor<50x90xf32>
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["test.dummy_op"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match attributes{"__tagged__"} in %arg1
transform.iree.clone_succeeding_op_into_dispatch_region %1 into %dispatch_op
}
}

0 comments on commit 9796b9e

Please sign in to comment.