[mlir][linalg] unfold projected permutation. #114704

Merged
merged 7 commits into from Nov 9, 2024

Changes from 2 commits
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -252,7 +252,7 @@ def LinalgStructuredInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumParallelLoops() == getNumParallelLoops();
return getNumParallelLoops() == getNumLoops();
}]
>,
InterfaceMethod<
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,6 +1786,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Add patterns to make explicit broadcasts and transposes in the
/// input operands of a genericOp.
void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
RewritePatternSet &patterns,
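
A minimal usage sketch of the new entry point (hypothetical standalone helper, not part of this patch; the specialize pass below does the equivalent wiring):

```c++
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collect the new patterns and apply them greedily to
// all linalg.generic ops nested under `root`.
static void unfoldProjectedPermutations(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateUnfoldProjectedPermutationPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```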
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TilingInterfaceImpl.cpp
Transforms.cpp
TransposeConv2D.cpp
UnfoldProjectedPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp

1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);
populateUnfoldProjectedPermutationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
270 changes: 270 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -0,0 +1,270 @@
//===- UnfoldProjectedPermutation.cpp - unfold projected permutations ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pattern to decompose an operand of a GenericOp whose
// affine map folds in a `transpose + broadcast`, rewriting it into separate
// linalg.transpose and linalg.broadcast ops.
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <utility>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <map>
#include <optional>
#include <vector>

Contributor:
Is <vector> needed?

You should also move <utility> to this section. I think if you remove the blank line (11), clang-format would sort the includes in the right order for you :)

Contributor (Author):
done


using namespace mlir;
using namespace mlir::linalg;

namespace {

/// A projected permutation is effectively a folding of a mixture of
/// transpose and broadcast into the affine map of the operand.
/// While folding of transpose and broadcast into the affine map of the
/// linalg.generic operand is a very effective optimization, sometimes
/// we may want to unfold that, for instance when recognizing named ops.

Contributor:
This comment is very insightful, but doesn't really say what the pattern does. Also, I don't believe UnfoldProjectedPermutation is accurate. The pattern (based on this comment and the tests) specialises linalg.generic {op} into linalg.transpose + linalg.broadcast + op?

Unfolding the permutation map is just an implementation detail. A very important one, but not the ultimate goal.

Contributor (Author):
changed.

///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
/// { indexing_maps = [#projection, #identity, #identity],
/// iterator_types = ["parallel", "parallel", "parallel",
/// "parallel", "parallel"]}
/// ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%z : tensor<5x9x7x8x10xf32>) {
/// ^bb0(%in: f32, %in_1: f32, %out: f32):
/// %div = arith.divf %in, %in_1 : f32
/// linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. This can be
/// unfolded as:
///
/// ```mlir
/// ...
/// %transposed = linalg.transpose ins(%x : tensor<7x8x9xf32>)
/// outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
/// outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
banach-space marked this conversation as resolved.
Show resolved Hide resolved
/// %2 = linalg.div
/// ins(%broadcasted, %y :
/// tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that linalg.generic has been 'specialized' to linalg.div.
///
/// To unfold, it is more effective to transpose first and then broadcast.
/// However, if the transpose is done first, the permutation map needs to be
/// expressed in terms of the reduced dimensions (as the broadcast hasn't
/// happened yet). Also, the broadcast dimensions in a linalg.generic come
/// from the other operands (those not broadcast along that particular
/// dimension). We work this out by computing the polytope shape of the
/// linalg.generic from the shapes of all the operands (inputs and outputs).

struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
};

/// Calculate the shape (dimensions) of the iteration-space polytope.
/// This is calculated by concatenating the indexing maps of all operands
/// of the generic, inverting the concatenation, concatenating the shapes
/// of all the operands, and then applying the inverted map to those shapes.
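/// For example (worked illustration added for clarity): with the maps from
/// the file header,
///   %x : tensor<7x8x9xf32>          -> (d2, d3, d1)
///   %y, %z : tensor<5x9x7x8x10xf32> -> (d0, d1, d2, d3, d4)
/// the concatenated operand shapes are [7,8,9, 5,9,7,8,10, 5,9,7,8,10], and
/// applying the inverted concatenation of the maps recovers the loop ranges
/// [5, 9, 7, 8, 10], i.e. the shape of the output tensor.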

Contributor:
Is "polytope" a standard term for these things? To me this is just combining all the dims and inverting them using the corresponding maps. I don't see "polytope" being used anywhere in the context of Linalg?

Also, just based on the test, this simply returns the type of the output tensor. Why not use that instead? In what cases would grabbing the output tensor not be sufficient?

Contributor (Author):
I should have called it 'convex polyhedron' or nDimRectangleDims - that would have been better. But now, as @MaheshRavishankar suggested, getStaticLoopRanges does the same job.


SmallVector<int64_t> getPolytopeDims(GenericOp op) {

Contributor:
Actually this whole thing should be the same as op.getStaticLoopRanges().

Contributor (Author):
Awesome, yes, it does. Thanks for pointing that out. Wish I had seen it earlier, but then I wouldn't have learnt the inversePermutation thingie :)


assert(op.hasPureTensorSemantics() && "works only on tensors");

/// Concat indexing maps of all operands and invert the mapping.
auto maps = op.getIndexingMapsArray();
auto concat = concatAffineMaps(maps);
auto inverse = inversePermutation(concat);

/// Concat the sizes of the dims of all operands.
SmallVector<int64_t> dims;
for (auto &operand : op->getOpOperands()) {
auto rankedType = cast<RankedTensorType>(operand.get().getType());
for (auto size : rankedType.getShape())
dims.push_back(size);
}

/// Match the inverse map with dims to get polytope dimensions.
/// Note that some may be 'kDynamic'.
return applyPermutationMap<int64_t>(inverse, dims);
}
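
As the review thread above points out, this helper should be obtainable directly from the structured-op interface; a minimal sketch of that simplification (assuming `getStaticLoopRanges()` is available on the op, as the reviewers suggest):

```c++
// Sketch only: the interface method already concatenates the operand shapes
// and applies the inverted indexing maps internally, so the manual
// computation above becomes a one-liner. Entries may be ShapedType::kDynamic.
SmallVector<int64_t> getIterationSpaceDims(GenericOp op) {
  assert(op.hasPureTensorSemantics() && "works only on tensors");
  return op.getStaticLoopRanges();
}
```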

/// For the given `map`, determine which dimensions are transposed and
/// which dimensions are broadcast.
/// Returns:
///   `isTransposed, isBroadcast,
///    transpose-permutation, broadcast-dimensions`
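/// For example (illustration added for clarity):
/// `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` yields
/// `{true, true, [2, 0, 1], [0, 4]}`, i.e. transposed, and broadcast
/// along d0 and d4.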
///
std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
assert(map.isProjectedPermutation(false) && "not a projection");

// Dimensions that don't appear in the result are broadcast.
int64_t minorSize = map.getNumResults();

// Convert affine expr to int64_t.
SmallVector<int64_t> minorResult;
for (int64_t i = 0; i < minorSize; ++i) {
auto expr = cast<AffineDimExpr>(map.getResults()[i]);
minorResult.push_back(expr.getPosition());
}

// If dims are not monotonically increasing then transpose is present.
SmallVector<int64_t> sorted(minorResult);
std::sort(sorted.begin(), sorted.end());
bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
sorted.begin(), sorted.end());

// Walk the sorted map result to determine which dimensions are broadcasted.
SmallVector<int64_t> broadcast;
for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
if (j < minorSize && sorted[j] == i) {
j++;
continue;
}
broadcast.push_back(i);
}
bool hasBroadcast = broadcast.size();

Contributor:
Suggested change:
- bool hasBroadcast = broadcast.size();
+ bool hasBroadcast = !broadcast.empty();


/// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
/// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
/// `x`'s access is both transposed and broadcast. But when specifying
/// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
/// specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
/// referring to d3, d4. Therefore, re-base the transpose dimensions so
/// that they start from d0.
std::map<int64_t, int64_t> minorMap;
for (int64_t i = 0; i < minorSize; ++i)
minorMap.insert({sorted[i], i});

// Re-map the dimensions.
SmallVector<int64_t> remappedResult(minorSize);
for (int64_t i = 0; i < minorSize; ++i)
remappedResult[i] = minorMap[minorResult[i]];

/// Calculate the permutation for the transpose.
SmallVector<int64_t> permutation(minorSize);
for (unsigned i = 0; i < minorSize; ++i) {
permutation[remappedResult[i]] = i;
}

return {hasTranspose, hasBroadcast, permutation, broadcast};
}

LogicalResult
UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const {
if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
op.isSingleYieldOp() || !op.isAllParallelLoops())
return failure();

// All maps need to be projected permutations.

Contributor:
Intuitively this makes sense, but ... why? 😅 Which part would break?

Contributor:
Ping
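
  // Presumably: computeTransposeBroadcast() below asserts on anything that is
  // not a projected permutation, and a map that isn't one (e.g. containing a
  // compound expression such as d0 + d1) cannot be rewritten as a pure
  // transpose + broadcast of the operand.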

for (auto &opOperand : op->getOpOperands()) {
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
}

// Currently we handle only static shapes.

Contributor:
Could this work at all for dynamic shapes?

Contributor (Author):
For a start this will assert when trying to create tensor.empty with dynamic shape. https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L874

Contributor:
OK, rather than documenting what the code does, could you add a comment saying "why"? Or what's missing? From what you are saying, we'd need to add logic to compute dynamic sizes of the input tensors for ops like EmptyOp? And probably sth else as well?
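
A hedged sketch of one piece that would be needed for dynamic shapes (hypothetical helper, not part of this patch; the transposed and broadcast shapes would additionally need their sizes permuted or taken from the iteration space):

```c++
// Hypothetical helper: create a tensor.empty matching `source`,
// materializing dynamic extents with tensor.dim.
static Value createEmptyLike(PatternRewriter &rewriter, Location loc,
                             Value source) {
  auto type = cast<RankedTensorType>(source.getType());
  SmallVector<Value> dynSizes;
  for (auto [idx, size] : llvm::enumerate(type.getShape()))
    if (ShapedType::isDynamic(size))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, source, idx));
  return rewriter.create<tensor::EmptyOp>(loc, type.getShape(),
                                          type.getElementType(), dynSizes);
}
```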

for (auto &operand : op->getOpOperands()) {
auto rankedType = cast<RankedTensorType>(operand.get().getType());
for (auto size : rankedType.getShape())
if (size == ShapedType::kDynamic)
return failure();
}

// Calculate polytope bounds from affine maps and operand(s) shapes.
auto polytope = getPolytopeDims(op);

auto loc = op.getLoc();
bool isChanged = false;
SmallVector<Value> newInitValues = op.getDpsInputs();
SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

// Walk over each input operand and unfold it if it is transposed, broadcast,
// or a mix of the two, as indicated by the operand's affine map.
for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
auto &map = newMap[i];
auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
auto elType = inputRTType.getElementType();

/// Nothing to do if map is already an identity.
if (map.isIdentity())
continue;

auto [hasTranspose, hasBroadcast, permutation, broadcastedDims] =
computeTransposeBroadcast(map);

if (hasTranspose) {
/// linalg.transpose permutes the dimensions of input using
/// rule: dim(result, i) = dim(input, permutation[i])
SmallVector<int64_t> transposedShape(map.getNumResults());
for (int64_t i = 0; i < map.getNumResults(); ++i)
transposedShape[i] = inputRTType.getShape()[permutation[i]];

Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
emptyTensor, permutation);
newInitValues[i] = transposeOp->getResult(0);
isChanged = true;
}

// Does it require broadcast
if (hasBroadcast) {
assert(!broadcastedDims.empty() && "should have non-empty broadcast dims");
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, polytope, inputRTType.getElementType());

auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
loc, newInitValues[i], emptyTensor, broadcastedDims);

newInitValues[i] = broadcastOp->getResult(0);
isChanged = true;
}
newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
}

if (isChanged) {
SmallVector<Value> operands = op->getOperands();
ValueRange operandsRef(operands);

auto newOp = rewriter.create<linalg::GenericOp>(
/*location=*/op.getLoc(),
/*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/newInitValues,
/*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
/*indexingMaps=*/newMap,
/*iteratorTypes=*/op.getIteratorTypesArray());

newOp.getRegion().takeBody(op->getRegion(0));
rewriter.replaceOp(op, newOp->getResults());
}
return success();
}

} // namespace

void mlir::linalg::populateUnfoldProjectedPermutationPatterns(
RewritePatternSet &patterns) {
patterns.insert<UnfoldProjectedPermutation>(patterns.getContext());
}
71 changes: 71 additions & 0 deletions mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
@@ -0,0 +1,71 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @test_mixed(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
%res = linalg.generic
{ indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<5x9x7x8x10xf32>
return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: test_mixed
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Transposed]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[Broadcasted]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

Contributor:
I love the descriptive names of the maps, but let's be consistent (broadcast + transpose vs broadcast + transposed?)

Suggested change:
- #transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+ #transpose = affine_map<(d0, d1, d2) -> (d2, d0, d1)>


func.func @test_transposed(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: test_transposed
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[Transposed]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @test_broadcast(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: test_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Broadcasted]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic