[MLIR] Improve compose expand(collapse) pattern (#117768)
If an expand(collapse) pair has a dimension group that is collapsed and then
expanded back to the same shape, the pattern would fail to canonicalize the
pair into a single collapse shape. Line 341 was changed because the
expand(collapse) could be a reinterpret-cast-like sequence in which the
shapes differ but the rank is the same; such a sequence cannot be
represented by a single `collapse_shape` op. Both cases are sketched below.
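
For illustration, a minimal sketch of both cases. The first mirrors the new
tests below; the second uses hypothetical shapes chosen only to show the
rank-preserving sequence:

    // Group [0, 1] is collapsed and then expanded back to the same sub-shape
    // [4, 32]; after this patch the pair folds into a single collapse_shape:
    %c = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
        : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
    %e = tensor.expand_shape %c [[0, 1], [2], [3]] output_shape [4, 32, 10, 128]
        : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
    // ==> folds to:
    // %r = tensor.collapse_shape %arg0 [[0], [1], [2], [3, 4]]
    //     : tensor<4x32x10x64x2xf16> into tensor<4x32x10x128xf16>

    // Reinterpret-cast-like sequence: source and result have the same rank
    // but different shapes, so no single collapse_shape can express it and
    // the pattern now bails on rank equality rather than type equality:
    %c2 = tensor.collapse_shape %a [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
    %e2 = tensor.expand_shape %c2 [[0, 1]] output_shape [3, 2]
        : tensor<6xf32> into tensor<3x2xf32>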

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
IanWood1 authored and pull[bot] committed Dec 2, 2024
1 parent c71d7ae commit 2019112
Showing 2 changed files with 40 additions and 10 deletions.
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (12 additions, 10 deletions)
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          composedReassociation.emplace_back(1, srcIndices.front() + index);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +405,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
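
Note that the llvm::count_if(..., ShapedType::isDynamic) >= 2 guard predates
this patch (it appears on both sides of the diff): with two or more dynamic
extents in one reassociation group, equal sub-shapes no longer prove that the
expansion restores the original dimensions, so composition is conservatively
rejected. A hypothetical pair that stays untouched for this reason (shapes
are illustrative, not taken from the tests):

    // Group [0, 1] collapses ?x? and re-expands to ?x?: the two dynamic
    // extents cannot be proven to match, so no folding happens even though
    // the second group ([2, 3] -> [2]) changes the rank.
    %c = tensor.collapse_shape %t [[0, 1], [2, 3]]
        : tensor<?x?x10x2xf32> into tensor<?x20xf32>
    %e = tensor.expand_shape %c [[0, 1], [2]] output_shape [%d0, %d1, 20]
        : tensor<?x20xf32> into tensor<?x?x20xf32>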
mlir/test/Dialect/Tensor/canonicalize.mlir (28 additions, 0 deletions)
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//       CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     [0], [1], [2], [3, 4]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//       CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     [0], [1], [2], [3, 4]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
