From 2019112165ed6d4a6daa84ae5312cfb0d9b09e61 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Mon, 2 Dec 2024 16:34:26 +0000
Subject: [PATCH] [MLIR] Improve compose expand(collapse) pattern (#117768)

If expand(collapse) has a dimension that gets collapsed and then
expanded back to the same shape, the pattern would fail to canonicalize
this to a single `collapse_shape`.

Line 341 was changed because the expand(collapse) could be a
reinterpret-cast-like sequence where the shapes differ but the rank is
the same. This cannot be represented by a single `collapse_shape` op.

Signed-off-by: Ian Wood
---
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h   | 22 ++++++++-------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 28 +++++++++++++++++++
 2 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba3..3fa35bf1851a9c0 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
 
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          composedReassociation.emplace_back(1, srcIndices.front() + index);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +405,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84ea..613ec0663372949 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
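
A minimal sketch of the rewrite this patch enables, derived from the
`compose_expand_of_collapse_static` test above (the `%0` name below is
illustrative, not part of the patch):

  // Before: dims 0 and 1 are collapsed to 128 and then expanded back to
  // the same shape [4, 32]; only the [3, 4] group stays collapsed.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]]
      output_shape [4, 32, 10, 128]
      : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>

  // After canonicalization: a single collapse_shape, with dims 0 and 1
  // passed through as unit reassociation groups.
  %0 = tensor.collapse_shape %arg0 [[0], [1], [2], [3, 4]]
      : tensor<4x32x10x64x2xf16> into tensor<4x32x10x128xf16>

By contrast, a same-rank sequence whose shapes differ, e.g. collapsing
tensor<2x3xf32> into tensor<6xf32> and expanding it back into
tensor<3x2xf32>, is reinterpret-cast-like and cannot be expressed as one
collapse_shape; the new `srcRank == resultRank` guard makes the pattern
bail on such cases.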