Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retire downstream constant packing folder #921

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
def ConstantFoldPack : Pass<"constant-fold-pack", "ModuleOp"> {
  let description = [{
    Reduce pack overhead by folding tensor.pack into constant tensors.
  }];
  // Dialects whose operations may be created while lowering/folding the
  // packed constants (linalg.transpose, tensor reshapes, arith constants).
  let dependentDialects = ["linalg::LinalgDialect",
                           "tensor::TensorDialect",
                           "arith::ArithDialect"];
}

def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> {
Expand Down
210 changes: 50 additions & 160 deletions lib/TPP/Transforms/ConstantFoldPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Threading.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

Expand All @@ -25,179 +25,69 @@ namespace tpp {
} // namespace tpp
} // namespace mlir

#define DEBUG_TYPE "fold-pack-into-cst"

namespace {

struct ConstantFoldPack
: public tpp::impl::ConstantFoldPackBase<ConstantFoldPack> {
// Helper pattern - lower tensor.pack operations that pack constants.
struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

// Collect a packed constantOp and its attribute if any.
static FailureOr<std::pair<arith::ConstantOp, DenseElementsAttr>>
getDenseAttributeAndConstant(tensor::PackOp packOp) {
if (packOp.getPaddingValue())
return failure();
Value sourcePack = packOp.getSource();
auto cstOp = sourcePack.getDefiningOp<arith::ConstantOp>();
if (!cstOp)
LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
if (!constOp)
return failure();
auto cst = cstOp.getValue();
if (!isa<DenseElementsAttr>(cst))
// Must be a dense constant.
auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!denseAttr)
return failure();
auto oldDense = cast<DenseElementsAttr>(cst);
return std::make_pair(cstOp, oldDense);
}

static bool areStaticValues(ArrayRef<int64_t> tilesSizes) {
return !llvm::is_contained(tilesSizes, ShapedType::kDynamic);
}

void foldPackIntoCst(RewriterBase &rewriter, tensor::PackOp packOp) {
// Bail out if the user uses pack as a writable operation
// (i.e., the destination is not a tensor.empty).
// Bail out if the pack is used as a writing operation i.e., the destination
// is not a tensor.empty.
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
return;
OpBuilder::InsertionGuard guard(rewriter);
auto cstAndAttribute = getDenseAttributeAndConstant(packOp);
if (failed(cstAndAttribute))
return;
auto [cstOp, oldDense] = *(cstAndAttribute);
// Happy path, splat constant.
if (oldDense.isSplat()) {
auto newDense = oldDense.reshape(packOp.getDestType());
rewriter.setInsertionPoint(cstOp);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, newDense);
return;
}
LLVM_DEBUG(llvm::dbgs()
<< "NUM ELEMENT: " << oldDense.getNumElements() << "\n");
const int64_t bytes =
oldDense.getRawData().size() / oldDense.getNumElements();

// The original buffer.
ArrayRef<char> rawData = oldDense.getRawData();
// The new buffer.
SmallVector<char> destRawData(rawData.size());

int64_t numberOfElements = oldDense.getNumElements();
SmallVector<int64_t> strides =
computeStrides(packOp.getDestType().getShape());
LLVM_DEBUG(llvm::dbgs() << "#STRIDES: " << strides.size() << "\n";
for (int64_t stride
: strides) llvm::dbgs()
<< stride << " ";
llvm::dbgs() << "\n";);

parallelFor(
packOp.getContext(), 0, numberOfElements,
[&](size_t destLinearizedIdx) {
// Step1. De-linearize destination index.
// f(lin) = tmp[A][B][C]
SmallVector<int64_t> delDestIndexes =
delinearize(destLinearizedIdx, strides);
assert(delDestIndexes.size() ==
static_cast<size_t>(packOp.getDestType().getRank()));

// Step2. Arrange the indexes based on the packing
// information. Step 2.1: Compute inverse of outerDimsPerm to
// bring the loops into the canonical form tmp[A][B][a][b].
if (!packOp.getOuterDimsPerm().empty()) {
SmallVector<int64_t> inversePermutation =
invertPermutationVector(packOp.getOuterDimsPerm());
SmallVector<int64_t> tileLoops;
for (auto i = 0; i < packOp.getSourceType().getRank(); i++)
tileLoops.push_back(delDestIndexes[i]);
applyPermutationToVector(tileLoops, inversePermutation);
SmallVector<int64_t> pointLoops;
for (size_t i = packOp.getSourceType().getRank();
i < delDestIndexes.size(); i++) {
pointLoops.push_back(delDestIndexes[i]);
}
delDestIndexes = tileLoops;
delDestIndexes.append(pointLoops.begin(), pointLoops.end());
assert(delDestIndexes.size() ==
static_cast<size_t>(packOp.getDestType().getRank()));
}
// Step 2.2
// After interchanging the outermost tiled loop we end up in
// the canonical form tmp[A][B][a][b]. Squash the point loops
// with the tiled ones.
llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
// Map the position of the tiled loops with the point one. Example:
// [A][B] -> [A][B][a][b]
// entry: [A : 0] [a : 2]
// entry: [B : 1] [b : 3]
// [A][B] -> [A][B][b]
// entry: [B : 1] [b : 2]
for (auto tileLoop : llvm::enumerate(packOp.getInnerDimsPos()))
mappingTileToPointLoops[tileLoop.value()] = tileLoop.index();

SmallVector<int64_t> delSourceIndexes;
size_t tilePosIdx = 0;
SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
if (!areStaticValues(tilesSizes))
return;
int numberOfTileLoops = packOp.getSourceType().getRank();
for (int i = 0; i < numberOfTileLoops; i++) {
// Loop is not tiled.
if (!tiledLoops.count(i)) {
delSourceIndexes.push_back(delDestIndexes[i]);
// Loop is tiled, the point loop is at distance:
// numberOfTileLoops + mappingTileToPointLoops[i].
} else {
delSourceIndexes.push_back(
delDestIndexes[i] * tilesSizes[tilePosIdx] +
delDestIndexes[numberOfTileLoops +
mappingTileToPointLoops[i]]);
tilePosIdx++;
}
}
assert(delSourceIndexes.size() ==
static_cast<size_t>(packOp.getSourceType().getRank()));
int64_t sourceLinearizedIdx =
linearize(delSourceIndexes,
computeStrides(packOp.getSourceType().getShape()));
assert(sourceLinearizedIdx < numberOfElements);
LLVM_DEBUG(llvm::dbgs() << "dest index: " << destLinearizedIdx
<< " map to source index: "
<< sourceLinearizedIdx << "\n");
return rewriter.notifyMatchFailure(packOp,
"expects empty tensor destination");
// Pack destination must have static shape.
if (!packOp.getDestType().hasStaticShape())
return rewriter.notifyMatchFailure(
packOp, "expects destination with static shape");

// Pack with padding is not supported currently.
// TODO: Add tensor.pad folder pattern when available and lower the pack.
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp,
"NYI, expects no padding value");

// Step3. Do the packing.
for (int j = 0; j < bytes; j++) {
destRawData[destLinearizedIdx * bytes + j] =
rawData[sourceLinearizedIdx * bytes + j];
}
});
// If it is a splat constant, skip and let tensor.pack folder to handle this
// case.
if (denseAttr.isSplat())
return rewriter.notifyMatchFailure(
packOp, "skip pack - existing folder covers constant splats");

[[maybe_unused]] bool detectSpalt = false;
assert(DenseElementsAttr::isValidRawBuffer(packOp.getDestType(),
destRawData, detectSpalt));
auto newDense =
DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
rewriter.setInsertionPoint(cstOp);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, newDense);
return linalg::lowerPack(rewriter, packOp);
}
};

// Rewrite constant packing operation as a compile-time packed constant.
//
// Drives LowerConstantPacking together with upstream canonicalization and
// linalg constant folders until no foldable packs remain.
struct ConstantFoldPack
    : public tpp::impl::ConstantFoldPackBase<ConstantFoldPack> {

  void runOnOperation() override {
    auto module = getOperation();
    auto *ctx = &getContext();

    // TODO: Add tensor.pad folder pattern when available.
    RewritePatternSet patterns(ctx);
    // Temporarily lower constant packing operation to allow other existing
    // patterns to fold the operation completely.
    patterns.add<LowerConstantPacking>(ctx);
    // Apply canonicalization to fold trivial cases and linalg constant folders
    // to cleanup lowered packs.
    linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
    tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
    linalg::populateConstantFoldLinalgOperations(
        patterns, [](OpOperand *) -> bool { return true; });

    // Best-effort greedy rewrite; failure to converge is not fatal here.
    (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
  }
};

Expand Down
2 changes: 1 addition & 1 deletion test/Passes/fold-pack-and-constant.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tpp-opt %s -constant-fold-pack -canonicalize -cse -split-input-file | FileCheck %s
// RUN: tpp-opt %s -constant-fold-pack -cse -split-input-file | FileCheck %s

func.func @expect_to_fold_cst() -> tensor<8x2x1x1x32x32xi64> {
%cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
Expand Down
Loading