Skip to content

Commit

Permalink
Support multiple target ops in clone_succeeding_op_into_dispatch_region
Browse files Browse the repository at this point in the history
The target ops are sorted topologically before cloning them one-by-one.
This is to ensure that there are no dominance violations.
  • Loading branch information
matthias-springer committed Aug 23, 2022
1 parent b5bf9d5 commit ed33a5b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,7 @@ transform_dialect::CloneSucceedingOpIntoDispatchRegionOp::apply(
ArrayRef<Operation *> dispatchRegion =
state.getPayloadOps(getDispatchRegion());

// TODO: Multiple targetOps could be allowed.
if (targetOps.size() != 1 || dispatchRegion.size() != 1)
if (dispatchRegion.size() != 1)
return DiagnosedSilenceableFailure(this->emitOpError(
"requires exactly one target/dispatch region handle"));

Expand All @@ -844,15 +843,18 @@ transform_dialect::CloneSucceedingOpIntoDispatchRegionOp::apply(
return DiagnosedSilenceableFailure(
this->emitOpError("expected 'dispatch.region' operand"));

SmallVector<Operation *> orderedTargets = Flow::orderOperations(targetOps);
IRRewriter rewriter(regionOp->getContext());
auto newRegionOp = cloneSucceedingOpIntoDispatchRegion(
rewriter, targetOps.front(), regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(
reportUnknownTransformError(targetOps.front()));
for (Operation *target : orderedTargets) {
auto newRegionOp = cloneSucceedingOpIntoDispatchRegion(
rewriter, target, regionOp, getUpdateUsesOutsideOfRegion());
if (failed(newRegionOp))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
regionOp = *newRegionOp;
}

transformResults.set(getTransformed().cast<OpResult>(),
newRegionOp->getOperation());
regionOp.getOperation());
return DiagnosedSilenceableFailure(success());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def CloneSucceedingOpIntoDispatchRegionOp : Op<
are yielded from the dispatch region and used instead of results of the
original target op.

TODO: Support multiple payload ops for the `target` handle. In that case,
the targets must be sorted topologically before cloning them.

#### Return modes

This transform consumes both the `target` handle and the `dispatch_region`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @single_op(%arg0: tensor<?x?xf32>, %s1: index, %s2: index) -> tensor<?

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
transform.iree.wrap_in_dispatch_region %0
Expand Down Expand Up @@ -46,7 +46,7 @@ func.func @clone_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1:

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
Expand Down Expand Up @@ -79,7 +79,7 @@ func.func @move_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: i

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
Expand All @@ -90,70 +90,6 @@ transform.with_pdl_patterns {

// -----

// CHECK-LABEL: func @create_region_and_convert_to_workgroups
// CHECK: linalg.init_tensor
// CHECK: flow.dispatch.workgroups
// CHECK: linalg.matmul
// CHECK: flow.return
func.func @create_region_and_convert_to_workgroups(
%A: tensor<5x3xf32>, %B: tensor<3x5xf32>) -> tensor<5x5xf32> {
%init = linalg.init_tensor [5, 5] : tensor<5x5xf32>
%matmul = linalg.matmul
ins(%A, %B : tensor<5x3xf32>, tensor<3x5xf32>)
outs(%init : tensor<5x5xf32>) -> tensor<5x5xf32>
return %matmul : tensor<5x5xf32>
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%region_op = transform.iree.wrap_in_dispatch_region %0
transform.iree.region_to_workgroups %region_op
}
}

// -----

// CHECK-LABEL: func @move_multiple_preceding
// CHECK-DAG: arith.constant
// CHECK-DAG: arith.constant
// CHECK-DAG: tensor.dim
// CHECK-DAG: tensor.dim
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.third_user"
// CHECK-NEXT: flow.dispatch.region
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
func.func @move_multiple_preceding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>) {
%0 = "test.dummy_op"(%arg0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%u = "test.third_user"(%0) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
%3 = "test.merge1"(%1, %2) {__tagged__} : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
%4 = "test.merge2"(%1, %3) {__tagged__} : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)
%5 = tensor.insert_slice %4 into %arg1 [5, 16] [%s1, %s2] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match attributes{"__tagged__"} in %arg1
transform.iree.clone_preceding_op_into_dispatch_region %1 into %dispatch_op
}
}

// -----

// CHECK-LABEL: func @move_succeeding(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
func.func @move_succeeding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1: index, %s2: index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
Expand All @@ -176,11 +112,47 @@ func.func @move_succeeding(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %s1:

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
transform.iree.clone_succeeding_op_into_dispatch_region %1 into %dispatch_op
}
}

// -----

// CHECK-LABEL: func @move_multiple_succeeding
// CHECK-NEXT: flow.dispatch.region -> (tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<50x90xf32>, tensor<600x700xf32>)
// CHECK-NEXT: "test.dummy_op"
// CHECK-NEXT: "test.first_user"
// CHECK-NEXT: "test.second_user"
// CHECK-NEXT: "test.merge1"
// CHECK-NEXT: "test.merge2"
// CHECK-NEXT: tensor.insert_slice
// CHECK-NEXT: flow.return
// CHECK-NEXT: }
// CHECK-NEXT: "test.third_user"
func.func @move_multiple_succeeding(%arg0: tensor<50x90xf32>, %arg1: tensor<600x700xf32>) -> (tensor<600x700xf32>, tensor<50x90xf32>) {
%0 = "test.dummy_op"(%arg0) : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%1 = "test.first_user"(%0) {__tagged__} : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%2 = "test.second_user"(%0) {__tagged__} : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%u = "test.third_user"(%0) : (tensor<50x90xf32>) -> (tensor<50x90xf32>)
%3 = "test.merge1"(%1, %2) {__tagged__} : (tensor<50x90xf32>, tensor<50x90xf32>) -> (tensor<50x90xf32>)
%4 = "test.merge2"(%1, %3) {__tagged__} : (tensor<50x90xf32>, tensor<50x90xf32>) -> (tensor<50x90xf32>)
%5 = tensor.insert_slice %4 into %arg1 [5, 16] [50, 90] [1, 1] {__tagged__}
: tensor<50x90xf32> into tensor<600x700xf32>
return %5, %u : tensor<600x700xf32>, tensor<50x90xf32>
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["test.dummy_op"]} in %arg1
%dispatch_op = transform.iree.wrap_in_dispatch_region %0
%1 = transform.structured.match attribute{"__tagged__"} in %arg1
transform.iree.clone_succeeding_op_into_dispatch_region %1 into %dispatch_op
}
}

0 comments on commit ed33a5b

Please sign in to comment.