diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 0a404194569c22..b81a4c9c8760cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -252,7 +252,7 @@ def LinalgStructuredInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumParallelLoops() == getNumParallelLoops();
+        return getNumParallelLoops() == getNumLoops();
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0693e31b4f70af..043268e7787fd4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,6 +1786,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to make broadcasts and transposes explicit in the
+/// input operands of a genericOp.
+void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);
+
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d7c63cdd8198d7..3594b084138124 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
+  DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
new file mode 100644
index 00000000000000..83c4b5bdf10976
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -0,0 +1,253 @@
+//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include <map>
+#include <optional>
+#include <utility>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// This pattern decomposes the input operand(s) of a linalg.generic that has
+/// a `transpose`, `broadcast`, or a mixture of two, into explicit transpose
+/// and broadcast. Having them folded into the linalg.generic is a good
+/// optimization but sometimes we may want to unwrap, i.e., `unfold` them as
+/// explicit transpose and broadcast. This rewrite pattern helps do it for
+/// each input operand. This is useful for instance when trying to recognize
+/// named ops.
+///
+/// The transpose, broadcast, or mixture of both, are expressed in the affine
+/// map of the operand. Technically it is essentially `projected permutation`.
+///
+/// Example
+///
+/// ```mlir
+///
+/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+/// ...
+/// %res = linalg.generic
+///    { indexing_maps = [#projection, #identity, #identity],
+///    iterator_types = ["parallel", "parallel", "parallel",
+///                      "parallel", "parallel"]}
+///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
+///    outs(%z : tensor<5x9x7x8x10xf32>) {
+///      ^bb0(%in: f32, %in_1: f32, %out: f32):
+///        %div = arith.divf %in, %in_1 : f32
+///        linalg.yield %div : f32
+/// } -> tensor<5x9x7x8x10xf32>
+/// ```
+///
+/// In the above IR operand `%x` map is a projected-permutation. This can be
+/// unfolded as:
+///
+/// ```mlir
+/// ...
+/// %x_trans = linalg.transpose
+///              ins(%x : tensor<7x8x9xf32>)
+///              outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+/// ...
+/// %x_trans_bc = linalg.broadcast
+///                 ins(%x_trans : tensor<9x7x8xf32>)
+///                 outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+/// %2 = linalg.div
+///        ins(%x_trans_bc, %y :
+///             tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
+///        outs(%z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+/// ```
+///
+/// Note that linalg.generic has been 'specialized' to linalg.div.
+///
+/// To unfold it, it is more optimal to transpose first and then do the
+/// broadcast. However, if transpose is done first, the permutation map needs
+/// to be expressed in terms of reduced dimension as broadcast hasn't happened
+/// yet. Also, the broadcast dimensions in a linalg.generic come from other
+/// operands (those not broadcasted along that particular dimension). We work
+/// this out by computing the convex-polyhedron shape of the linalg.generic
+/// iteration space from shapes of all the operands, both inputs and outputs.
+///
+struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// For the given `map`, determine what dimensions are transposed and what
+/// dimensions are broadcasted.
+/// Returns :
+///   transpose-permutation, broadcast-dimensions (empty if not needed)
+///
+std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
+computeTransposeBroadcast(AffineMap map) {
+  assert(map.isProjectedPermutation(false) && "not a projection");
+
+  // As the map is a projection it likely operates on a smaller set of
+  // dimensions as far as the transpose is concerned (rest are broadcast).
+  int64_t minorSize = map.getNumResults();
+
+  SmallVector<int64_t> minorResult;
+  for (int64_t i = 0; i < minorSize; ++i) {
+    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
+    minorResult.push_back(expr.getPosition());
+  }
+
+  // If dims are not monotonically increasing then transpose is present.
+  SmallVector<int64_t> sortedResMap(minorResult);
+  std::sort(sortedResMap.begin(), sortedResMap.end());
+  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
+                                  sortedResMap.begin(), sortedResMap.end());
+
+  // Walk the sorted map result to determine which dimensions are broadcasted.
+  SmallVector<int64_t> broadcast;
+  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
+    if (j < minorSize && sortedResMap[j] == i) {
+      j++;
+      continue;
+    }
+    broadcast.push_back(i);
+  }
+
+  SmallVector<int64_t> permutation;
+  if (hasTranspose) {
+    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
+    // `x`s access is both transposed and broadcast. But when specifying
+    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+    // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
+    // referring to d3, d4. Therefore, re-base the transpose dimensions so
+    // that they start from d0.
+    permutation.resize(minorSize);
+    std::map<int64_t, int64_t> minorMap;
+    for (int64_t i = 0; i < minorSize; ++i)
+      minorMap.insert({sortedResMap[i], i});
+
+    // Re-map the dimensions.
+    SmallVector<int64_t> remappedResult(minorSize);
+    for (int64_t i = 0; i < minorSize; ++i)
+      remappedResult[i] = minorMap[minorResult[i]];
+
+    // Calculate the permutation for the transpose.
+    for (int64_t i = 0; i < minorSize; ++i)
+      permutation[remappedResult[i]] = i;
+  }
+  return {permutation, broadcast};
+}
+
+LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
+    GenericOp op, PatternRewriter &rewriter) const {
+  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
+      op.isSingleYieldOp() || !op.isAllParallelLoops())
+    return failure();
+
+  // If the map of an operand is not a `projected permutation` then
+  // it cannot be decomposed to mere transpose and broadcast.
+  // The requirement that all maps be `projected permutation` may be
+  // over-restrictive but since we need to determine shape of the
+  // iteration space as well, reject if any map violates assumption.
+  for (auto &opOperand : op->getOpOperands()) {
+    auto map = op.getMatchingIndexingMap(&opOperand);
+    if (!map.isProjectedPermutation(false))
+      return failure();
+  }
+
+  // Decomposing linalg.generic involves creating `tensor.empty`
+  // which can have dynamic shapes but then we would have to work
+  // out which operand can supply that runtime-value (tensor.dim).
+  // Leaving it as a future TODO. Operands that are not ranked tensors
+  // are rejected here as well, so the casts below cannot fail.
+  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
+        auto opType = dyn_cast<RankedTensorType>(oper.get().getType());
+        return !opType || ShapedType::isDynamicShape(opType.getShape());
+      }))
+    return failure();
+
+  auto outputShape = op.getStaticLoopRanges();
+
+  auto loc = op.getLoc();
+  bool isChanged = false;
+  SmallVector<Value> newInitValues = op.getDpsInputs();
+  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
+
+  // Walk over each input operand and unfold if it is transposed, broadcast
+  // or mix of two via operand's affine-map.
+  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
+    auto &map = newMap[i];
+    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
+    auto elType = inputRTType.getElementType();
+
+    // Nothing to do if map is already an identity.
+    if (map.isIdentity())
+      continue;
+
+    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
+
+    // Does it need transpose?
+    if (!permutation.empty()) {
+      // linalg.transpose permutes the dimensions of input using
+      // rule: dim(result, i) = dim(input, permutation[i])
+      SmallVector<int64_t> transposedShape(map.getNumResults());
+      for (int64_t dim = 0; dim < map.getNumResults(); ++dim)
+        transposedShape[dim] = inputRTType.getShape()[permutation[dim]];
+
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
+
+      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
+                                                      emptyTensor, permutation);
+      newInitValues[i] = transposeOp->getResult(0);
+      isChanged = true;
+    }
+
+    // Does it require broadcast?
+    if (!broadcastedDims.empty()) {
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);
+
+      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
+          loc, newInitValues[i], emptyTensor, broadcastedDims);
+
+      newInitValues[i] = broadcastOp->getResult(0);
+      isChanged = true;
+    }
+    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
+  }
+
+  // Nothing was unfolded (e.g. every input map was already an identity).
+  // Report failure so that the rewrite driver does not treat an untouched
+  // op as a successful rewrite.
+  if (!isChanged)
+    return failure();
+
+  SmallVector<Value> operands = op->getOperands();
+  ValueRange operandsRef(operands);
+
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/op.getLoc(),
+      /*resultTensorTypes=*/op->getResultTypes(),
+      /*inputs=*/newInitValues,
+      /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+      /*indexingMaps=*/newMap,
+      /*iteratorTypes=*/op.getIteratorTypesArray());
+
+  newOp.getRegion().takeBody(op->getRegion(0));
+  rewriter.replaceOp(op, newOp->getResults());
+  return success();
+}
+
+} // namespace
+
+void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index dfafffce9d9b60..748e2a1377930d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
+  populateDecomposeProjectedPermutationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
diff --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
new file mode 100644
index 00000000000000..38e406a13ec087
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+  %res = linalg.generic
+    { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
+      ^bb0(%in: f32, %in_1: f32, %out: f32):
+        %div = arith.divf %in, %in_1 : f32
+        linalg.yield %div : f32
+  } -> tensor<5x9x7x8x10xf32>
+  return %res : tensor<5x9x7x8x10xf32>
+}
+
+// CHECK-LABEL: transpose_and_broadcast
+// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
+// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+
+func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+    { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
+    outs(%z : tensor<2x16x32xf32>) {
+      ^bb0(%in: f32, %in_1: f32, %out: f32):
+        %div = arith.divf %in, %in_1 : f32
+        linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: transpose_only
+// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+    { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
+    outs(%z : tensor<2x16x32xf32>) {
+      ^bb0(%in: f32, %in_1: f32, %out: f32):
+        %div = arith.divf %in, %in_1 : f32
+        linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: broadcast_only
+// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[Y_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
+// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Y_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic