diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 53df7af00aee88..4968c4fc463d04 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes } }; -/// Returns an iterator over the dims (inc scalability) of a VectorType. -static auto getDims(VectorType vType) { - return llvm::zip_equal(vType.getShape(), vType.getScalableDims()); -} - -/// Helper to drop (fixed-size) unit dims from a VectorType. -static VectorType dropUnitDims(VectorType vType) { - SmallVector scalableFlags; - SmallVector dimSizes; - for (auto dim : getDims(vType)) { - if (dim == std::make_tuple(1, false)) - continue; - auto [size, scalableFlag] = dim; - dimSizes.push_back(size); - scalableFlags.push_back(scalableFlag); - } - return VectorType::get(dimSizes, vType.getElementType(), scalableFlags); -} - -/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the -/// shape_cast only drops unit dimensions. -/// -/// This simplifies the transpose making it possible for other legalization -/// rewrites to handle it. -/// -/// Example: -/// -/// BEFORE: -/// ```mlir -/// %0 = vector.transpose %vector, [3, 0, 1, 2] -/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> -/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> -/// ``` -/// -/// AFTER: -/// ```mlir -/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> -/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> -/// ``` -struct SwapShapeCastOfTranspose : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { - auto transposeOp = - shapeCastOp.getSource().getDefiningOp(); - if (!transposeOp) - return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp"); - - auto resultType = shapeCastOp.getResultVectorType(); - if (resultType.getRank() <= 1) - return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low"); - - if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType())) - return rewriter.notifyMatchFailure( - shapeCastOp, "ShapeCastOp changes non-unit dimension(s)"); - - auto transposeSourceVectorType = transposeOp.getSourceVectorType(); - auto transposeSourceDims = - llvm::to_vector(getDims(transposeSourceVectorType)); - - // Construct a map from dimIdx -> number of dims dropped before dimIdx. - SmallVector droppedDimsBefore(transposeSourceVectorType.getRank()); - int64_t droppedDims = 0; - for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) { - droppedDimsBefore[i] = droppedDims; - if (dim == std::make_tuple(1, false)) - ++droppedDims; - } - - // Drop unit dims from transpose permutation. - auto perm = transposeOp.getPermutation(); - SmallVector newPerm; - for (int64_t idx : perm) { - if (transposeSourceDims[idx] == std::make_tuple(1, false)) - continue; - newPerm.push_back(idx - droppedDimsBefore[idx]); - } - - auto loc = shapeCastOp.getLoc(); - auto newShapeCastOp = rewriter.create( - loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector()); - rewriter.replaceOpWithNewOp(shapeCastOp, - newShapeCastOp, newPerm); - return success(); - } -}; - /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use /// the ZA state. This workaround rewrite to support these transposes when ZA is /// available. @@ -1027,8 +939,7 @@ struct VectorLegalizationPass patterns.add( - context); + LowerIllegalTransposeStoreViaZA>(context); // Note: These two patterns are added with a high benefit to ensure: // - Masked outer products are handled before unmasked ones // - Multi-tile writes are lowered as a store loop (if possible) diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index adc02adb6e974c..458906a1879829 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref return } - -// ----- - -// CHECK-LABEL: @swap_shape_cast_of_transpose( -// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>) -func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> { - // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> - // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> - // CHECK: return %[[TRANSPOSE]] - %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> - %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> - return %1 : vector<[4]x4xf32> -} - -// ----- - -// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after( -// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>) -func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> { - // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32> - // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> - // CHECK: return %[[TRANSPOSE]] - %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32> - %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32> - return %1 : vector<[4]x4xf32> -}