Skip to content

Commit

Permalink
[mlir] Fix lower_unpack when dynamic dimensions are involved (llvm#…
Browse files Browse the repository at this point in the history
…68423)

When lowering `tensor.unpack`, we need to use the sizes of the
destination tensor in the final `tensor.extract_slice` operation. Prior
to this patch, when the destination tensor had dynamic dimensions, we
would compute them from the result of the `tensor.unpack` operation
instead of its destination argument.

This would produce invalid IR because the `tensor.dim` operations would
need to appear before the `tensor.extract_slice` operation, but the
input of the `tensor.dim` operations would consume the final result of
the lowering of `tensor.unpack`, which happens after the
`tensor.extract_slice` operation. In other words, the definition
wouldn't dominate its uses.

I.e., we were generating:
```
%dynDim = tensor.dim %defLater, ... <-- %defLater defined below
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res)
```

Note: I checked the implementation of `lower_pack` and the code is
correct as far as I can tell.
  • Loading branch information
qcolombet authored Oct 6, 2023
1 parent 5009d24 commit 7050ff4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));

// 7. Inject a copy to preserve DPS.
Expand Down
39 changes: 38 additions & 1 deletion mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
Expand Down Expand Up @@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}

// -----

// Check that we can lower unpack with dynamic dimensions in the destination.
// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
// CHECK: %[[TRAN:.*]] = linalg.transpose
// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
// CHECK-SAME: permutation = [0, 1, 3, 2, 4]
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32>
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>)
func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
: tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
return %pack : tensor<32x?x?xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
: (!transform.any_op) -> !transform.op<"tensor.unpack">
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
!transform.op<"tensor.extract_slice">)
}

0 comments on commit 7050ff4

Please sign in to comment.