From cdf865ca4c8b4e67b03de744b3c2540c34cd8082 Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Wed, 16 Oct 2024 13:49:59 -0400
Subject: [PATCH 1/7] [mlir][linalg] unfold projected permutation.

Identify folded 'projected permutations' (i.e. a mixture of transpose
and/or broadcast) in linalg.generic operands. The projected permutations
are unfolded into separate linalg.transpose and linalg.broadcast ops so
that the generic operates on simple identity maps, which is necessary to
replace the generic with a named op.
---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |   2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  |   1 +
 .../Transforms/UnfoldProjectedPermutation.cpp | 270 ++++++++++++++++++
 .../Linalg/unfold_projected_permutation.mlir  |  71 +++++
 6 files changed, 349 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
 create mode 100644 mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 0a404194569c22..b81a4c9c8760cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -252,7 +252,7 @@ def LinalgStructuredInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumParallelLoops() == getNumParallelLoops();
+        return getNumParallelLoops() == getNumLoops();
      }]
    >,
    InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0693e31b4f70af..a110eb88e9f699 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,6 +1786,11 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
+
+/// Add patterns to make explicit broadcasts and transforms in the
+/// input operands of a genericOp.
+void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);
+
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d7c63cdd8198d7..dfe6d7a54c8f14 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
+  UnfoldProjectedPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index dfafffce9d9b60..a911286d5d44b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
+  populateUnfoldProjectedPermutationPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
new file mode 100644
index 00000000000000..56d6bd23b2343a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -0,0 +1,270 @@
+//===- UnfoldProjectedPermutation.cpp - unfold projected permutations ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern to decompose the operand of a GenericOp that
+// has `transpose+broadcast` juxtaposed via its affine map into separate
+// transpose and broadcast ops.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include <utility>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include <map>
+#include <optional>
+#include <tuple>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Projected permutations are effectively a folding-in of a mixture of
+/// transpose and broadcast into the affine map of the operand.
+/// While folding of transpose and broadcast into the affine map of the
+/// linalg.generic operand is a very effective optimization, sometimes
+/// we may want to unfold that, for instance when recognizing named ops.
+///
+/// Example
+///
+/// ```mlir
+///
+/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+/// ...
+///   %res = linalg.generic
+///      { indexing_maps = [#projection, #identity, #identity],
+///      iterator_types = ["parallel", "parallel", "parallel",
+///                        "parallel", "parallel"]}
+///      ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
+///      outs(%z : tensor<5x9x7x8x10xf32>) {
+///      ^bb0(%in: f32, %in_1: f32, %out: f32):
+///           %div = arith.divf %in, %in_1 : f32
+///           linalg.yield %div : f32
+///   } -> tensor<5x9x7x8x10xf32>
+/// ```
+///
+/// In the above IR, operand `%x`'s map is a projected permutation. This can
+/// be unfolded as:
+///
+/// ```mlir
+///   ...
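+///   // Init tensors for the unfolded ops (illustrative; the pattern
+///   // materializes these with tensor.empty):
+///   %e1 = tensor.empty() : tensor<9x7x8xf32>
+///   %e2 = tensor.empty() : tensor<5x9x7x8x10xf32>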
+///   %transposed = linalg.transpose ins(%x : tensor<7x8x9xf32>)
+///                    outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+///   ...
+///   %broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
+///                    outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+///   %2 = linalg.div
+///           ins(%broadcasted, %y :
+///                tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
+///           outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+///
+/// Note that linalg.generic has been 'specialized' to linalg.div.
+/// To unfold it is more effective to transpose first and then do the broadcast.
+/// However, if transpose is done first, the permutation map needs to be
+/// expressed in terms of reduced dimension (as broadcast hasn't happened yet).
+/// Also, the broadcast dimensions in a linalg.generic come from other operands
+/// (those not broadcasted along that particular dimension). We work this out
+/// by computing the polytope shape of the linalg.generic from shapes of all the
+/// operands (inputs and outputs).
+
+struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Calculate shape (dimensions) of the iteration space polytope.
+/// This is calculated by concatenating the indexing maps of all operands
+/// of the generic; inverting the concatenation; concatenating all the
+/// shapes of the operands; and then doing `apply map` to those two.
+SmallVector<int64_t> getPolytopeDims(GenericOp op) {
+  assert(op.hasPureTensorSemantics() && "works only on tensors");
+
+  /// Concat indexing maps of all operands and invert the mapping.
+  auto maps = op.getIndexingMapsArray();
+  auto concat = concatAffineMaps(maps);
+  auto inverse = inversePermutation(concat);
+
+  /// Concat the size of each dim of all operands.
+  SmallVector<int64_t> dims;
+  for (auto &operand : op->getOpOperands()) {
+    auto rankedType = cast<RankedTensorType>(operand.get().getType());
+    for (auto size : rankedType.getShape())
+      dims.push_back(size);
+  }
+
+  /// Match the inverse map with dims to get polytope dimensions.
+  /// Note that some may be 'kDynamic'.
+  return applyPermutationMap(inverse, dims);
+}
+
+/// For the given `map` determine what dimensions are transposed
+/// and what dimensions are broadcasted.
+/// Returns :
+///   `isTransposed, isBroadcast,
+///    transpose-permutation, broadcast-dimensions`
+///
+std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
+computeTransposeBroadcast(AffineMap &map) {
+  assert(map.isProjectedPermutation(false) && "not a projection");
+
+  // Dimensions that don't appear on result are broadcast.
+  int64_t minorSize = map.getNumResults();
+
+  // Convert affine expr to int64_t.
+  SmallVector<int64_t> minorResult;
+  for (int64_t i = 0; i < minorSize; ++i) {
+    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
+    minorResult.push_back(expr.getPosition());
+  }
+
+  // If dims are not monotonically increasing then transpose is present.
+  SmallVector<int64_t> sorted(minorResult);
+  std::sort(sorted.begin(), sorted.end());
+  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
+                                  sorted.begin(), sorted.end());
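+  // E.g. for the doc example map (d0, d1, d2, d3, d4) -> (d2, d3, d1):
+  // minorResult = [2, 3, 1] and sorted = [1, 2, 3], so a transpose is
+  // present; d0 and d4 never appear and become the broadcast dimensions.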
+
+  // Walk the sorted map result to determine which dimensions are broadcasted.
+  SmallVector<int64_t> broadcast;
+  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
+    if (j < minorSize && sorted[j] == i) {
+      j++;
+      continue;
+    }
+    broadcast.push_back(i);
+  }
+  bool hasBroadcast = broadcast.size();
+
+  /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+  /// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
+  /// `x`s access is both transposed and broadcast. But when specifying
+  /// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+  /// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
+  /// referring to d3, d4. Therefore, re-base the transpose dimensions so
+  /// that they start from d0.
+  std::map<int64_t, int64_t> minorMap;
+  for (int64_t i = 0; i < minorSize; ++i)
+    minorMap.insert({sorted[i], i});
+
+  // Re-map the dimensions.
+  SmallVector<int64_t> remappedResult(minorSize);
+  for (int64_t i = 0; i < minorSize; ++i)
+    remappedResult[i] = minorMap[minorResult[i]];
+
+  /// Calculate the permutation for the transpose.
+  SmallVector<int64_t> permutation(minorSize);
+  for (unsigned i = 0; i < minorSize; ++i) {
+    permutation[remappedResult[i]] = i;
+  }
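+  // Note: the loop above inverts remappedResult. E.g. for the doc example
+  // map (d2, d3, d1): remappedResult = [1, 2, 0], giving permutation
+  // [2, 0, 1].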
+
+  return {hasTranspose, hasBroadcast, permutation, broadcast};
+}
+
+LogicalResult
+UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
+                                            PatternRewriter &rewriter) const {
+  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
+      op.isSingleYieldOp() || !op.isAllParallelLoops())
+    return failure();
+
+  // All maps need to be projected permutations.
+  for (auto &opOperand : op->getOpOperands()) {
+    auto map = op.getMatchingIndexingMap(&opOperand);
+    if (!map.isProjectedPermutation(false))
+      return failure();
+  }
+
+  // Currently we handle only static shapes.
+  for (auto &operand : op->getOpOperands()) {
+    auto rankedType = cast<RankedTensorType>(operand.get().getType());
+    for (auto size : rankedType.getShape())
+      if (size == ShapedType::kDynamic)
+        return failure();
+  }
+
+  // Calculate polytope bounds from affine maps and operand(s) shapes.
+  auto polytope = getPolytopeDims(op);
+
+  auto loc = op.getLoc();
+  bool isChanged = false;
+  SmallVector<Value> newInitValues = op.getDpsInputs();
+  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
+
+  // Walk over each input operand and unfold if it is transposed, broadcast
+  // or mix of two via operand's affine-map.
+  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
+    auto &map = newMap[i];
+    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
+    auto elType = inputRTType.getElementType();
+
+    /// Nothing to do if map is already an identity.
+    if (map.isIdentity())
+      continue;
+
+    auto [hasTranspose, hasBroadcast, permutation, broadcastedDims] =
+        computeTransposeBroadcast(map);
+
+    if (hasTranspose) {
+      /// linalg.transpose permutes the dimensions of input using
+      /// rule: dim(result, i) = dim(input, permutation[i])
+      SmallVector<int64_t> transposedShape(map.getNumResults());
+      for (int64_t i = 0; i < map.getNumResults(); ++i)
+        transposedShape[i] = inputRTType.getShape()[permutation[i]];
+
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
+
+      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
+                                                      emptyTensor, permutation);
+      newInitValues[i] = transposeOp->getResult(0);
+      isChanged = true;
+    }
+
+    // Does it require broadcast
+    if (hasBroadcast) {
+      assert(broadcastedDims.size() && "should have non-empty broadcast");
+      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, polytope, inputRTType.getElementType());
+
+      auto broadcastOp = rewriter.create<BroadcastOp>(
+          loc, newInitValues[i], emptyTensor, broadcastedDims);
+
+      newInitValues[i] = broadcastOp->getResult(0);
+      isChanged = true;
+    }
+    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
+  }
+
+  if (isChanged) {
+    SmallVector<Value> operands = op->getOperands();
+    ValueRange operandsRef(operands);
+
+    auto newOp = rewriter.create<GenericOp>(
+        /*location=*/op.getLoc(),
+        /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/newInitValues,
+        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+        /*indexingMaps=*/newMap,
+        /*iteratorTypes=*/op.getIteratorTypesArray());
+
+    newOp.getRegion().takeBody(op->getRegion(0));
+    rewriter.replaceOp(op, newOp->getResults());
+  }
+  return success();
+}
+
+} // namespace
+
+void mlir::linalg::populateUnfoldProjectedPermutationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<UnfoldProjectedPermutation>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir b/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
new file mode 100644
index 00000000000000..4efa07b2de12e3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+func.func @test_mixed(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+  %res = linalg.generic
+         { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+         ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
+         ^bb0(%in: f32, %in_1: f32, %out: f32):
+            %div = arith.divf %in, %in_1 : f32
+            linalg.yield %div : f32
+  } -> tensor<5x9x7x8x10xf32>
+  return %res : tensor<5x9x7x8x10xf32>
+}
+
+// CHECK-LABEL: test_mixed
+// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
+// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
+// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Transposed]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+// CHECK: {{.*}} = linalg.div ins(%[[Broadcasted]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+
+func.func @test_transposed(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+         { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
+         outs(%z : tensor<2x16x32xf32>) {
+         ^bb0(%in: f32, %in_1: f32, %out: f32):
+            %div = arith.divf %in, %in_1 : f32
+            linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: test_transposed
+// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
+// CHECK: {{.*}} = linalg.div ins(%[[Transposed]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @test_broadcast(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+         { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
+         outs(%z : tensor<2x16x32xf32>) {
+         ^bb0(%in: f32, %in_1: f32, %out: f32):
+            %div = arith.divf %in, %in_1 : f32
+            linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: test_broadcast
+// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
+// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Broadcasted]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic

From ce582389abe691a0e6ae078c75dc4b11df7567fd Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Sun, 3 Nov 2024 07:28:53 -0500
Subject: [PATCH 2/7] [mlir][linalg] fix clang format issue
---
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a110eb88e9f699..22af75e473681c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,7 +1786,6 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
-
 /// Add patterns to make explicit broadcasts and transforms in the
 /// input operands of a genericOp.
 void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);

From b9094dc6959caee3099cee167449aa4aacdcfeb1 Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Tue, 5 Nov 2024 11:29:12 -0500
Subject: [PATCH 3/7] [mlir][linalg] Address review comments
---
 .../Transforms/UnfoldProjectedPermutation.cpp | 12 +++++----
 ...mlir => unfold-projected-permutation.mlir} | 26 +++++++++----------
 2 files changed, 20 insertions(+), 18 deletions(-)
 rename mlir/test/Dialect/Linalg/{unfold_projected_permutation.mlir => unfold-projected-permutation.mlir} (63%)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
index 56d6bd23b2343a..49425a8443d4d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -56,13 +56,15 @@ namespace {
 ///
 /// ```mlir
 /// ...
-///   %transposed = linalg.transpose ins(%x : tensor<7x8x9xf32>)
-///                    outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
-///   ...
-///   %broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
-///                    outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+///   %x_trans = linalg.transpose
+///                ins(%x : tensor<7x8x9xf32>)
+///                outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+///   ...
+///   %x_trans_bc = linalg.broadcast
+///                   ins(%x_trans : tensor<9x7x8xf32>)
+///                   outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
 ///   %2 = linalg.div
-///           ins(%broadcasted, %y :
+///           ins(%x_trans_bc, %y :
 ///                tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
 ///           outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
diff --git a/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir b/mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir
similarity index 63%
rename from mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
rename to mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir
index 4efa07b2de12e3..c85eab02e792f9 100644
--- a/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
+++ b/mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir
@@ -3,7 +3,7 @@
 #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
 #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
-func.func @test_mixed(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
   %res = linalg.generic
          { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
          ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
@@ -14,13 +14,13 @@ func.func @test_mixed(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z :
   return %res : tensor<5x9x7x8x10xf32>
 }
 
-// CHECK-LABEL: test_mixed
+// CHECK-LABEL: transpose_and_broadcast
 // CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
 // CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
-// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
 // CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
-// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Transposed]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
-// CHECK: {{.*}} = linalg.div ins(%[[Broadcasted]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
 // CHECK-NOT: linalg.generic
 
 // -----
@@ -28,7 +28,7 @@
 #identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 
-func.func @test_transposed(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %res = linalg.generic
          { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
@@ -40,18 +40,18 @@ func.func @test_transposed(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %
   return %res : tensor<2x16x32xf32>
 }
 
-// CHECK-LABEL: test_transposed
+// CHECK-LABEL: transpose_only
 // CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
-// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
-// CHECK: {{.*}} = linalg.div ins(%[[Transposed]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic
 
 // -----
 
 #identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
-func.func @test_broadcast(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %res = linalg.generic
          { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
@@ -63,9 +63,9 @@ func.func @test_broadcast(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z :
   return %res : tensor<2x16x32xf32>
 }
 
-// CHECK-LABEL: test_broadcast
+// CHECK-LABEL: broadcast_only
 // CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
-// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
-// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Broadcasted]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
+// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK-NOT: linalg.generic

From 3b238c652dd0f0f89040ace0c224ff0a618ff88a Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Wed, 6 Nov 2024 14:51:50 -0500
Subject: [PATCH 4/7] Revise more based on review comments.
---
 .../Transforms/UnfoldProjectedPermutation.cpp | 77 ++++++-------------
 1 file changed, 24 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
index 49425a8443d4d3..c93a6737ae37ef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -6,12 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a pattern to decompose the operand of a GenericOp that
-// has `transpose+broadcast` juxtaposed via its affine map into separate
-// transpose and broadcast ops.
-//
-//===----------------------------------------------------------------------===//
-//
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include <utility>
 
@@ -26,11 +20,16 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
 
-/// Projected permutations are effectively a folding-in of a mixture of
-/// transpose and broadcast into the affine map of the operand.
-/// While folding of transpose and broadcast into the affine map of the
-/// linalg.generic operand is a very effective optimization, sometimes
-/// we may want to unfold that, for instance when recognizing named ops.
+/// This file implements a pattern to decompose the input operand(s) of a
+/// linalg.generic that has a `transpose`, `broadcast` or a mixture of the two,
+/// into explicit transpose and broadcast. Having them folded into the
+/// linalg.generic is a good optimization but sometimes we may want to unwrap
+/// i.e. `unfold` them as explicit transpose and broadcast. This rewrite
+/// pattern helps do it for each input operand. This is useful for instance
+/// when trying to recognize named ops.
+///
+/// The transpose, broadcast, or mixture of both, is expressed in the affine
+/// map of the operand. Technically it is essentially a `projected permutation`.
 ///
 /// Example
 ///
@@ -69,28 +68,27 @@ namespace {
 ///
 /// Note that linalg.generic has been 'specialized' to linalg.div.
+///
 /// To unfold it is more effective to transpose first and then do the broadcast.
 /// However, if transpose is done first, the permutation map needs to be
 /// expressed in terms of reduced dimension (as broadcast hasn't happened yet).
 /// Also, the broadcast dimensions in a linalg.generic come from other operands
 /// (those not broadcasted along that particular dimension). We work this out
-/// by computing the polytope shape of the linalg.generic from shapes of all the
-/// operands (inputs and outputs).
-
+/// by computing the convex-polyhedron shape of the linalg.generic
+/// iteration space from shapes of all the operands (inputs and outputs).
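+/// For instance, in the example above transposing first moves only
+/// 7x8x9 = 504 elements, whereas broadcasting first would mean transposing
+/// all 5x9x7x8x10 = 25200 elements of the broadcast result.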
+///
 struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Calculate shape (dimensions) of the iteration space polytope.
-/// This is calculated by concatenating the indexing maps of all operands
-/// of the generic; inverting the concatenation; concatenating all the
-/// shapes of the operands; and then doing `apply map` to those two.
-SmallVector<int64_t> getPolytopeDims(GenericOp op) {
-  assert(op.hasPureTensorSemantics() && "works only on tensors");
-
-  /// Concat indexing maps of all operands and invert the mapping.
-  auto maps = op.getIndexingMapsArray();
-  auto concat = concatAffineMaps(maps);
-  auto inverse = inversePermutation(concat);
-
-  /// Concat the size of each dim of all operands.
-  SmallVector<int64_t> dims;
-  for (auto &operand : op->getOpOperands()) {
-    auto rankedType = cast<RankedTensorType>(operand.get().getType());
-    for (auto size : rankedType.getShape())
-      dims.push_back(size);
-  }
-
-  /// Match the inverse map with dims to get polytope dimensions.
-  /// Note that some may be 'kDynamic'.
-  return applyPermutationMap(inverse, dims);
-}
-
 /// For the given `map` determine what dimensions are transposed
 /// and what dimensions are broadcasted.
 /// Returns :
 ///   `isTransposed, isBroadcast,
 ///    transpose-permutation, broadcast-dimensions`
 ///
 std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
 computeTransposeBroadcast(AffineMap &map) {
   assert(map.isProjectedPermutation(false) && "not a projection");
-
-  // Dimensions that don't appear on result are broadcast.
   int64_t minorSize = map.getNumResults();
 
-  // Convert affine expr to int64_t.
   SmallVector<int64_t> minorResult;
   for (int64_t i = 0; i < minorSize; ++i) {
     auto expr = cast<AffineDimExpr>(map.getResults()[i]);
     minorResult.push_back(expr.getPosition());
   }
 
   // If dims are not monotonically increasing then transpose is present.
-  SmallVector<int64_t> sorted(minorResult);
-  std::sort(sorted.begin(), sorted.end());
+  SmallVector<int64_t> sortedResMap(minorResult);
+  std::sort(sortedResMap.begin(), sortedResMap.end());
   bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
-                                  sorted.begin(), sorted.end());
+                                  sortedResMap.begin(), sortedResMap.end());
 
   // Walk the sorted map result to determine which dimensions are broadcasted.
   SmallVector<int64_t> broadcast;
   for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
-    if (j < minorSize && sorted[j] == i) {
+    if (j < minorSize && sortedResMap[j] == i) {
       j++;
       continue;
     }
     broadcast.push_back(i);
   }
-  bool hasBroadcast = broadcast.size();
+  bool hasBroadcast = !broadcast.empty();
 
   /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
   /// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
   /// `x`s access is both transposed and broadcast. But when specifying
   /// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
   /// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
   /// referring to d3, d4. Therefore, re-base the transpose dimensions so
   /// that they start from d0.
   std::map<int64_t, int64_t> minorMap;
   for (int64_t i = 0; i < minorSize; ++i)
-    minorMap.insert({sorted[i], i});
+    minorMap.insert({sortedResMap[i], i});
 
   // Re-map the dimensions.
   SmallVector<int64_t> remappedResult(minorSize);
 
   // Currently we handle only static shapes.
   for (auto &operand : op->getOpOperands()) {
-    auto rankedType = cast<RankedTensorType>(operand.get().getType());
-    for (auto size : rankedType.getShape())
+    auto opType = cast<RankedTensorType>(operand.get().getType());
+    for (auto size : opType.getShape())
       if (size == ShapedType::kDynamic)
         return failure();
   }
 
-  // Calculate polytope bounds from affine maps and operand(s) shapes.
-  auto polytope = getPolytopeDims(op);
+  auto outputShape = op.getStaticLoopRanges();
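+  // getStaticLoopRanges() returns the iteration-space shape directly from
+  // the op, replacing the deleted getPolytopeDims() helper.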
 
   auto loc = op.getLoc();
   bool isChanged = false;
   SmallVector<Value> newInitValues = op.getDpsInputs();
   SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
@@ -235,7 +206,7 @@ UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
     if (hasBroadcast) {
       assert(broadcastedDims.size() && "should have non-empty broadcast");
       Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, polytope, inputRTType.getElementType());
+          loc, outputShape, inputRTType.getElementType());
 
       auto broadcastOp = rewriter.create<BroadcastOp>(
           loc, newInitValues[i], emptyTensor, broadcastedDims);

From e3373b8b87e52c8fba3c5084f77296d4848fd5b3 Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Fri, 8 Nov 2024 07:56:16 -0500
Subject: [PATCH 5/7] Revise based on 3rd round review comments
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   2 +-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +-
 ...ecomposeGenericByUnfoldingPermutation.cpp} | 115 +++++++++---------
 .../Dialect/Linalg/Transforms/Specialize.cpp  |   2 +-
 ...c-by-unfolding-projected-permutation.mlir} |   0
 5 files changed, 59 insertions(+), 62 deletions(-)
 rename mlir/lib/Dialect/Linalg/Transforms/{UnfoldProjectedPermutation.cpp => DecomposeGenericByUnfoldingPermutation.cpp} (68%)
 rename mlir/test/Dialect/Linalg/{unfold-projected-permutation.mlir => decompose-generic-by-unfolding-projected-permutation.mlir} (100%)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 22af75e473681c..043268e7787fd4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1788,7 +1788,7 @@ void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
 /// Add patterns to make explicit broadcasts and transforms in the
 /// input operands of a genericOp.
-void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);
+void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);
 
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index dfe6d7a54c8f14..3594b084138124 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,7 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
-  UnfoldProjectedPermutation.cpp
+  DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
similarity index 68%
rename from mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 56d6bd23b2343a..c41c2d1b3bb155 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -1,4 +1,4 @@
-//===- UnfoldProjectedPermutation.cpp - unfold projected permutations ---===//
+//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,27 +6,25 @@
 //
 //===----------------------------------------------------------------------===//
 //
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include <utility>
-
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include <map>
 #include <optional>
-#include <tuple>
+#include <utility>
 
 using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
 
-/// This file implements a pattern to decompose the input operand(s) of a
-/// linalg.generic that has a `transpose`, `broadcast` or a mixture of the two,
-/// into explicit transpose and broadcast. Having them folded into the
-/// linalg.generic is a good optimization but sometimes we may want to unwrap
-/// i.e. `unfold` them as explicit transpose and broadcast. This rewrite
-/// pattern helps do it for each input operand. This is useful for instance
-/// when trying to recognize named ops.
+/// This pattern decomposes the input operand(s) of a linalg.generic that has
+/// a `transpose`, `broadcast`, or a mixture of the two, into explicit
+/// transpose and broadcast. Having them folded into the linalg.generic is a
+/// good optimization but sometimes we may want to unwrap, i.e., `unfold`
+/// them as explicit transpose and broadcast. This rewrite pattern helps do
+/// it for each input operand. This is useful for instance when trying to
+/// recognize named ops.
 ///
 /// The transpose, broadcast, or mixture of both, is expressed in the affine
 /// map of the operand. Technically it is essentially a `projected permutation`.
@@ -69,27 +67,26 @@ namespace {
 ///
 /// Note that linalg.generic has been 'specialized' to linalg.div.
 ///
-/// To unfold it is more effective to transpose first and then do the broadcast.
-/// However, if transpose is done first, the permutation map needs to be
-/// expressed in terms of reduced dimension (as broadcast hasn't happened yet).
-/// Also, the broadcast dimensions in a linalg.generic come from other operands
-/// (those not broadcasted along that particular dimension). We work this out
-/// by computing the convex-polyhedron shape of the linalg.generic
-/// iteration space from shapes of all the operands (inputs and outputs).
+/// To unfold it, it is more optimal to transpose first and then do the
+/// broadcast. However, if transpose is done first, the permutation map needs
+/// to be expressed in terms of reduced dimension as broadcast hasn't happened
+/// yet. Also, the broadcast dimensions in a linalg.generic come from other
+/// operands (those not broadcasted along that particular dimension). We work
+/// this out by computing the convex-polyhedron shape of the linalg.generic
+/// iteration space from shapes of all the operands, both inputs and outputs.
 ///
-struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
+struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override;
 };
 
-/// For the given `map` determine what dimensions are transposed
-/// and what dimensions are broadcasted.
+/// For the given `map`, determine what dimensions are transposed and what
+/// dimensions are broadcasted.
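+/// E.g. an operand whose result dims already appear in increasing order
+/// needs no transpose and yields an empty permutation; an operand that uses
+/// every loop dimension yields empty broadcast dimensions.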
 /// Returns :
-///   `isTransposed, isBroadcast,
-///    transpose-permutation, broadcast-dimensions`
+///   transpose-permutation, broadcast-dimensions` (empty if not needed)
 ///
-std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
+std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
 computeTransposeBroadcast(AffineMap &map) {
   assert(map.isProjectedPermutation(false) && "not a projection");
   int64_t minorSize = map.getNumResults();
 
   SmallVector<int64_t> minorResult;
   for (int64_t i = 0; i < minorSize; ++i) {
     auto expr = cast<AffineDimExpr>(map.getResults()[i]);
     minorResult.push_back(expr.getPosition());
   }
 
   // If dims are not monotonically increasing then transpose is present.
   SmallVector<int64_t> sortedResMap(minorResult);
   std::sort(sortedResMap.begin(), sortedResMap.end());
   bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                   sortedResMap.begin(), sortedResMap.end());
 
   // Walk the sorted map result to determine which dimensions are broadcasted.
@@ -116,36 +113,36 @@ computeTransposeBroadcast(AffineMap &map) {
     }
     broadcast.push_back(i);
   }
-  bool hasBroadcast = !broadcast.empty();
-
-  /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
-  /// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
-  /// `x`s access is both transposed and broadcast. But when specifying
-  /// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
-  /// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
-  /// referring to d3, d4. Therefore, re-base the transpose dimensions so
-  /// that they start from d0.
-  std::map<int64_t, int64_t> minorMap;
-  for (int64_t i = 0; i < minorSize; ++i)
-    minorMap.insert({sortedResMap[i], i});
-
-  // Re-map the dimensions.
-  SmallVector<int64_t> remappedResult(minorSize);
-  for (int64_t i = 0; i < minorSize; ++i)
-    remappedResult[i] = minorMap[minorResult[i]];
-
-  /// Calculate the permutation for the transpose.
-  SmallVector<int64_t> permutation(minorSize);
-  for (unsigned i = 0; i < minorSize; ++i) {
-    permutation[remappedResult[i]] = i;
-  }
 
-  return {hasTranspose, hasBroadcast, permutation, broadcast};
+  SmallVector<int64_t> permutation;
+  if (hasTranspose) {
+    /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+    /// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
+    /// `x`s access is both transposed and broadcast. But when specifying
+    /// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+    /// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
+    /// referring to d3, d4. Therefore, re-base the transpose dimensions so
+    /// that they start from d0.
+    permutation.resize(minorSize);
+    std::map<int64_t, int64_t> minorMap;
+    for (int64_t i = 0; i < minorSize; ++i)
+      minorMap.insert({sortedResMap[i], i});
+
+    // Re-map the dimensions.
+    SmallVector<int64_t> remappedResult(minorSize);
+    for (int64_t i = 0; i < minorSize; ++i)
+      remappedResult[i] = minorMap[minorResult[i]];
+
+    /// Calculate the permutation for the transpose.
+    for (unsigned i = 0; i < minorSize; ++i) {
+      permutation[remappedResult[i]] = i;
+    }
+  }
+  return {permutation, broadcast};
 }
 
-LogicalResult
-UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
-                                            PatternRewriter &rewriter) const {
+LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
+    GenericOp op, PatternRewriter &rewriter) const {
   if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
       op.isSingleYieldOp() || !op.isAllParallelLoops())
     return failure();
@@ -183,10 +180,10 @@ UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
     if (map.isIdentity())
       continue;
 
-    auto [hasTranspose, hasBroadcast, permutation, broadcastedDims] =
-        computeTransposeBroadcast(map);
+    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
 
-    if (hasTranspose) {
+    // Does it need transpose?
+    if (!permutation.empty()) {
       /// linalg.transpose permutes the dimensions of input using
       /// rule: dim(result, i) = dim(input, permutation[i])
       SmallVector<int64_t> transposedShape(map.getNumResults());
       for (int64_t i = 0; i < map.getNumResults(); ++i)
         transposedShape[i] = inputRTType.getShape()[permutation[i]];
 
       Value emptyTensor =
           rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
 
       auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                       emptyTensor, permutation);
       newInitValues[i] = transposeOp->getResult(0);
       isChanged = true;
     }
 
-    // Does it require broadcast
+    // Does it require broadcast?
-    if (hasBroadcast) {
+    if (!broadcastedDims.empty()) {
       assert(broadcastedDims.size() && "should have non-empty broadcast");
       Value emptyTensor = rewriter.create<tensor::EmptyOp>(
           loc, outputShape, inputRTType.getElementType());
 
       auto broadcastOp = rewriter.create<BroadcastOp>(
           loc, newInitValues[i], emptyTensor, broadcastedDims);
@@ -237,7 +234,7 @@ UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
 
 } // namespace
 
-void mlir::linalg::populateUnfoldProjectedPermutationPatterns(
+void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<UnfoldProjectedPermutation>(patterns.getContext());
+  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a911286d5d44b2..748e2a1377930d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,7 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
-  populateUnfoldProjectedPermutationPatterns(patterns);
+  populateDecomposeProjectedPermutationPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
diff --git a/mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir
rename to mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir

From 296f805d8046b798fb8ca08de2917523ee9c701c Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Sat, 9 Nov 2024 13:40:17 -0500
Subject: [PATCH 6/7] Revise based on 3rd round review comments
---
 ...DecomposeGenericByUnfoldingPermutation.cpp | 26 ++++++++++++-------
 ...ic-by-unfolding-projected-permutation.mlir |  2 +-
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index c41c2d1b3bb155..119ff045ffaa52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -90,6 +90,9 @@ struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
 std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
 computeTransposeBroadcast(AffineMap &map) {
   assert(map.isProjectedPermutation(false) && "not a projection");
+
+  // As the map is a projection it likely operates on a smaller set of
+  // dimensions as far as the transpose is concerned (rest are broadcast).
   int64_t minorSize = map.getNumResults();
 
   SmallVector<int64_t> minorResult;
@@ -116,13 +119,13 @@ computeTransposeBroadcast(AffineMap &map) {
 
   SmallVector<int64_t> permutation;
   if (hasTranspose) {
-    /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
-    /// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
-    /// `x`s access is both transposed and broadcast. But when specifying
-    /// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
-    /// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
-    /// referring to d3, d4. Therefore, re-base the transpose dimensions so
-    /// that they start from d0.
+    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
+    // `x`s access is both transposed and broadcast. But when specifying
+    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+    // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
+    // referring to d3, d4. Therefore, re-base the transpose dimensions so
+    // that they start from d0.
     permutation.resize(minorSize);
     std::map<int64_t, int64_t> minorMap;
     for (int64_t i = 0; i < minorSize; ++i)
@@ -147,14 +150,19 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
       op.isSingleYieldOp() || !op.isAllParallelLoops())
     return failure();
 
-  // All maps need to be projected permutations.
+  // If the map of an operand is not a `projected permutation` then
+  // it cannot be decomposed to mere transpose and broadcast.
+  // The requirement that all maps be `projected permutation` may be
+  // over-restrictive but since we need to determine shape of the
+  // iteration space as well, reject if any map violates the assumption.
   for (auto &opOperand : op->getOpOperands()) {
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();
   }
 
-  // Currently we handle only static shapes.
+  // Decomposing linalg.generic involves creating `tensor.empty`
+  // which cannot have dynamic shapes.
   for (auto &operand : op->getOpOperands()) {
     auto opType = cast<RankedTensorType>(operand.get().getType());
     for (auto size : opType.getShape())
       if (size == ShapedType::kDynamic)
         return failure();
   }
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
index c85eab02e792f9..38e406a13ec087 100644
--- a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -5,7 +5,7 @@
 
 func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
   %res = linalg.generic
-         { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+         { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
          ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
          ^bb0(%in: f32, %in_1: f32, %out: f32):
            %div = arith.divf %in, %in_1 : f32

From 6f61f9a9d43acda185a417ade320a1ede99cda07 Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Sat, 9 Nov 2024 14:18:11 -0500
Subject: [PATCH 7/7] Revise based on 5th review comments
---
 .../DecomposeGenericByUnfoldingPermutation.cpp    | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 119ff045ffaa52..83c4b5bdf10976 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -162,13 +162,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   }
 
   // Decomposing linalg.generic involves creating `tensor.empty`
-  // which cannot have dynamic shapes.
-  for (auto &operand : op->getOpOperands()) {
-    auto opType = cast<RankedTensorType>(operand.get().getType());
-    for (auto size : opType.getShape())
-      if (size == ShapedType::kDynamic)
-        return failure();
-  }
+  // which can have dynamic shapes but then we would have to work
+  // out which operand can supply that runtime value (tensor.dim).
+  // Leaving it as a future TODO.
+  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
+        auto opType = cast<RankedTensorType>(oper.get().getType());
+        return ShapedType::isDynamicShape(opType.getShape());
+      }))
+    return failure();
 
   auto outputShape = op.getStaticLoopRanges();