Skip to content

Commit

Permalink
Tile-and-fuse - tiling factor option
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Feb 19, 2024
1 parent 0c9f715 commit 49781ad
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 20 deletions.
4 changes: 3 additions & 1 deletion include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def TileConsumerAndFuseProducers : Pass<"tile-consumer-and-fuse-producers",
"Get producers till maxDepth">,
Option<"numIters", "num-iters", "int64_t", "3",
"Run fusion for the given number of iterations">,
Option<"useForAll", "use-for-all", "bool", "true", "Use parallel forAll">
Option<"useForAll", "use-for-all", "bool", "true", "Use parallel forAll">,
Option<"minTileFactor", "min-tile-factor", "int64_t", "2",
"Minimum factor between dimension size and a tile size">
];
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect",
"tensor::TensorDialect"];
Expand Down
6 changes: 5 additions & 1 deletion include/TPP/Transforms/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ isContraction(linalg::LinalgOp linalgOp);
// Specific dims can be passed using 'dims'. If dims is empty the validation
// will start from the outermost dimension, moving to innermost ones up to the
// number of tiles.
// Tiling application can be restricted based on the workload dimension size.
// The tiling is applied only when all dimensions fulfill the predicate:
// '(dimSize[i] / tiles[i]) >= minTileFactor'.
bool validateFullTilesOnDims(TilingInterface tileOp,
ArrayRef<OpFoldResult> tiles,
ArrayRef<size_t> dims = {});
ArrayRef<size_t> dims = {},
int64_t minTileFactor = 2);

// Rewrite scf.for to scf.forall. Assumes the loop to be parallel and
// marked with `kLoopId`.
Expand Down
4 changes: 3 additions & 1 deletion lib/TPP/GPU/GpuPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
// Tile to split the kernel into threads and blocks.
// Use default tiling to handle both packed and unpacked ops.
pm.addPass(createCleanup());
pm.addPass(createTileConsumerAndFuseProducers());
TileConsumerAndFuseProducersOptions tilingOptions;
tilingOptions.minTileFactor = 1;
pm.addPass(createTileConsumerAndFuseProducers(tilingOptions));
pm.addPass(createCleanup());

// Preprocess and bufferize as further conversion requires memref
Expand Down
21 changes: 13 additions & 8 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ static bool isConvolutionLike(Operation *op) {
// Return true if `op` can be tiled using `tileSizes`. Require to statically
// know the range and the tile factor. The tile must be full.
static bool canBeTiledWithCurrentSpec(Operation *op,
ArrayRef<OpFoldResult> tileSizes) {
ArrayRef<OpFoldResult> tileSizes,
int64_t minTileFactor) {
assert(isa<TilingInterface>(op) &&
"expect an op implementing the tiling interface");
assert(!tileSizes.empty() && "expect tile sizes to be non-empty");
Expand All @@ -105,8 +106,8 @@ static bool canBeTiledWithCurrentSpec(Operation *op,
}

LLVM_DEBUG(llvm::dbgs() << "Running tile validations ----\n");
if (!linalgx::utils::validateFullTilesOnDims(cast<TilingInterface>(op),
tileSizes)) {
if (!linalgx::utils::validateFullTilesOnDims(
cast<TilingInterface>(op), tileSizes, /*dim=*/{}, minTileFactor)) {
LLVM_DEBUG(llvm::dbgs() << "FAILED\n");
return false;
}
Expand Down Expand Up @@ -382,7 +383,8 @@ static llvm::SmallDenseSet<Operation *> collectFusableProducers(
static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
RewriterBase &rewriter, TilingInterface consumer,
llvm::DenseMap<Operation *, SmallVector<OpFoldResult>> &tileSizes,
llvm::SmallDenseSet<Operation *> &alreadyFusedOps, int64_t maxDepth) {
llvm::SmallDenseSet<Operation *> &alreadyFusedOps, int64_t maxDepth,
int64_t minTileFactor) {
// Step 0. Early exit if tileSizes are empty.
if (tileSizes.empty() || !tileSizes.count(consumer)) {
LLVM_DEBUG(llvm::dbgs() << "EMPTY TILE SIZES\n");
Expand All @@ -397,7 +399,8 @@ static FailureOr<scf::SCFTileAndFuseResult> fuseWithEltwise(
}

// Step 2. Check if the tile configuration fits the consumer.
if (!canBeTiledWithCurrentSpec(consumer, tileSizes.at(consumer))) {
if (!canBeTiledWithCurrentSpec(consumer, tileSizes.at(consumer),
minTileFactor)) {
LLVM_DEBUG(llvm::dbgs() << "CONSUMER: " << consumer
<< "\nCANNOT BE TILED WITH CURRENT CONFIG\n");
return failure();
Expand Down Expand Up @@ -616,7 +619,8 @@ static Operation *getLastFusableEltWiseConsumer(

// Run `fuseWithEltwise` on contraction-like operations.
static void doFusion(RewriterBase &rewriter, func::FuncOp func,
ArrayRef<int64_t> tileSizes, int64_t maxDepth) {
ArrayRef<int64_t> tileSizes, int64_t maxDepth,
int64_t minTileFactor) {
// Set to keep track of fused ops.
llvm::SmallDenseSet<Operation *> fusedOps;

Expand Down Expand Up @@ -673,7 +677,7 @@ static void doFusion(RewriterBase &rewriter, func::FuncOp func,
LLVM_DEBUG(llvm::dbgs() << "\n\n");
FailureOr<scf::SCFTileAndFuseResult> fuseAndTileResult =
fuseWithEltwise(rewriter, cast<TilingInterface>(linalgOp),
defaultTiles, fusedOps, maxDepth);
defaultTiles, fusedOps, maxDepth, minTileFactor);
LLVM_DEBUG(llvm::dbgs() << "\n\n");
if (succeeded(fuseAndTileResult)) {
rewriter.replaceOp(
Expand Down Expand Up @@ -703,7 +707,8 @@ struct TileConsumerAndFuseProducers
do {
func::FuncOp func = getOperation();
IRRewriter rewriter(&getContext());
doFusion(rewriter, func, this->tileSizes, this->maxDepth);
doFusion(rewriter, func, this->tileSizes, this->maxDepth,
this->minTileFactor);

{
RewritePatternSet patterns(&ctx);
Expand Down
19 changes: 11 additions & 8 deletions lib/TPP/Transforms/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,36 +291,38 @@ static std::optional<int64_t> getConstantRange(const Range &range) {
}

static bool validateFullTilesOnDim(TilingInterface tileOp,
const OpFoldResult &tile, size_t dim) {
const OpFoldResult &tile, size_t dim,
int64_t minTileFactor) {
OpBuilder builder(tileOp);
OpBuilder::InsertionGuard guard(builder);
SmallVector<Range> iterationDomain =
cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
if (dim >= iterationDomain.size())
return false;

auto tileFactor = getConstantIntValue(tile);
auto tileSize = getConstantIntValue(tile);
auto rangeOnDim = getConstantRange(iterationDomain[dim]);

// If the tile size or the range is non-constant, the tile size is
// considered to be valid.
if (!tileFactor || !rangeOnDim)
if (!tileSize || !rangeOnDim)
return true;

// Corner case: Tiling with '0' along 'dim' is valid - no tiling.
if (*tileFactor == 0)
if (*tileSize == 0)
return true;

// Corner case: Tiling '1' with '1' is valid.
if (*tileFactor == 1 && *rangeOnDim == 1)
if (*tileSize == 1 && *rangeOnDim == 1)
return true;

return (*rangeOnDim % *tileFactor == 0);
return (*rangeOnDim % *tileSize == 0) &&
(*rangeOnDim / *tileSize >= minTileFactor);
}

bool validateFullTilesOnDims(TilingInterface tileOp,
ArrayRef<OpFoldResult> tiles,
ArrayRef<size_t> dims) {
ArrayRef<size_t> dims, int64_t minTileFactor) {
if (!dims.empty() && dims.size() != tiles.size())
return false;

Expand All @@ -333,7 +335,8 @@ bool validateFullTilesOnDims(TilingInterface tileOp,
assert(dimsToCheck.size() == tiles.size());

for (auto dim : llvm::enumerate(dimsToCheck)) {
if (!validateFullTilesOnDim(tileOp, tiles[dim.index()], dim.value()))
if (!validateFullTilesOnDim(tileOp, tiles[dim.index()], dim.value(),
minTileFactor))
return false;
}
return true;
Expand Down
5 changes: 4 additions & 1 deletion test/Passes/DefaultPipeline/default-tpp-passes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,11 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x32x64xf32>, %arg1: tensor<512x
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i64
// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : i64
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
// CHECK-DAG: %[[C0_i:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1_i:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C512_i:.+]] = arith.constant 512 : index
// CHECK: %{{.+}} = call @xsmm_gemm_dispatch(%[[C1]], %[[C32]], %[[C32]], %[[C64]], %[[C64]], %[[C32]], %[[C32]], %[[C0]])
// CHECK: scf.parallel
// CHECK: scf.parallel{{.*}}(%[[C0_i]]) to (%[[C512_i]]) step (%[[C1_i]])
// CHECK: xsmm_gemm_invoke
%1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x32x64xf32>, tensor<512x64x32xf32>)
outs(%0 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
Expand Down

0 comments on commit 49781ad

Please sign in to comment.