Skip to content

Commit

Permalink
Generalize named ops pass
Browse files Browse the repository at this point in the history
Adds a custom wrapper around Linalg named ops generalization.

The custom pattern exposes control over which ops should be
generalized.
By default, the pass filters out linalg.fill ops as layout propagation
and fusion behave better when the named version is present.
This improves overall IR quality after bufferization.
  • Loading branch information
adam-smnk committed Aug 6, 2024
1 parent 93b822b commit 351098c
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,9 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
];
}

def GeneralizeNamedOps : Pass<"generalize-named-ops"> {
let summary = "Selectively convert named ops into generic ops";
let dependentDialects = ["linalg::LinalgDialect"];
}

#endif // TPP_DIALECT_TPP_PASSES
3 changes: 3 additions & 0 deletions include/TPP/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ namespace tpp {
void populateLinalgToXsmmPatterns(RewritePatternSet &patterns);
void populateSimplifyPacking(RewritePatternSet &patterns);
void populateSinkPackPatterns(RewritePatternSet &patterns);
using ControlGeneralizationFn = std::function<bool(linalg::LinalgOp)>;
void populateGeneralizeNamedOpsPatterns(RewritePatternSet &patterns,
ControlGeneralizationFn = nullptr);
} // namespace tpp
namespace linalg {
void populateLinalgDeGeneralizationPatterns(RewritePatternSet &patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_mlir_library(TPPTransforms
IntelAMXTileConfigHoisting.cpp
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertAddInplacePass.cpp
GeneralizeNamedOps.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
76 changes: 76 additions & 0 deletions lib/TPP/Transforms/GeneralizeNamedOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//===- GeneralizeNamedOps.cpp ------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace tpp;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_GENERALIZENAMEDOPS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

struct LinalgGeneralizationPattern
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
LinalgGeneralizationPattern(MLIRContext *context, ControlGeneralizationFn fun,
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

/// `matchAndRewrite` implementation that returns the significant
/// transformed pieces of IR.
FailureOr<linalg::GenericOp>
returningMatchAndRewrite(linalg::LinalgOp op,
PatternRewriter &rewriter) const {
return linalg::generalizeNamedOp(rewriter, op);
}

LogicalResult matchAndRewrite(linalg::LinalgOp op,
PatternRewriter &rewriter) const override {
if (controlFn && !controlFn(op))
return failure();

return returningMatchAndRewrite(op, rewriter);
}

private:
ControlGeneralizationFn controlFn;
};

struct GeneralizeNamedOps
: tpp::impl::GeneralizeNamedOpsBase<GeneralizeNamedOps> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());

ControlGeneralizationFn controlFn = [](linalg::LinalgOp op) -> bool {
return !(isa<linalg::FillOp>(op));
};

tpp::populateGeneralizeNamedOpsPatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace

// TODO: Add control function to Linalg generalization patterns upstream.
void tpp::populateGeneralizeNamedOpsPatterns(
RewritePatternSet &patterns, ControlGeneralizationFn controlFn) {
patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), controlFn);
}
34 changes: 34 additions & 0 deletions test/Passes/pass-generalize-named-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: tpp-opt %s -generalize-named-ops -split-input-file | FileCheck %s

func.func @generalize_matmul(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
outs(%C: memref<16x32xf32>)
return
}

// CHECK-LABEL @generalize_matmul(
// CHECK-NOT: linalg.matmul
// CHECK: linalg.generic

// -----

func.func @generalize_add(%A : memref<32x32xf32>, %B: memref<32x32xf32>, %C: memref<32x32xf32>) {
linalg.add ins(%A, %B: memref<32x32xf32>, memref<32x32xf32>)
outs(%C: memref<32x32xf32>)
return
}

// CHECK-LABEL @generalize_add(
// CHECK-NOT: linalg.add
// CHECK: linalg.generic

// -----

func.func @dont_generalize_fill(%arg0 : memref<32x32xf32>, %val: f32) {
linalg.fill ins(%val : f32) outs(%arg0 : memref<32x32xf32>)
return
}

// CHECK-LABEL @generalize_add(
// CHECK: linalg.fill
// CHECK-NOT: linalg.generic

0 comments on commit 351098c

Please sign in to comment.