Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize before propagation #951

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -517,4 +517,9 @@ def FoldIntoEltwise : Pass<"fold-into-eltwise", "ModuleOp"> {
"affine::AffineDialect"];
}

def GeneralizeNamedOps : Pass<"generalize-named-ops"> {
let summary = "Selectively convert named ops into generic ops";
let dependentDialects = ["linalg::LinalgDialect"];
}

#endif // TPP_DIALECT_TPP_PASSES
3 changes: 3 additions & 0 deletions include/TPP/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ namespace tpp {
void populateLinalgToXsmmPatterns(RewritePatternSet &patterns);
void populateSimplifyPacking(RewritePatternSet &patterns);
void populateSinkPackPatterns(RewritePatternSet &patterns);
using ControlGeneralizationFn = std::function<bool(linalg::LinalgOp)>;
void populateGeneralizeNamedOpsPatterns(RewritePatternSet &patterns,
ControlGeneralizationFn = nullptr);
} // namespace tpp
namespace linalg {
void populateLinalgDeGeneralizationPatterns(RewritePatternSet &patterns);
Expand Down
7 changes: 6 additions & 1 deletion lib/TPP/PassBundles/TppMapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "TPP/PassUtils.h"

Expand Down Expand Up @@ -63,6 +64,11 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
pm.addPass(createPackMatmul());
pm.addPass(createPackVNNI());

// TODO: Remove when layout propagation and tile-and-fuse have better
// support for named ops.
pm.addNestedPass<func::FuncOp>(createGeneralizeNamedOps());
pm.addNestedPass<func::FuncOp>(createConvertLinalgToInplace());

// Postprocess packing.
// Run only canonicalizer at this stage as full cleanup (mostly CSE) can
// mess up tensor producer-consumer chains used for analysis in the
Expand All @@ -71,7 +77,6 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
pm.addPass(createConstantFoldPack());
pm.addPass(createSimplifyAndCanonicalizePack());

pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
pm.addPass(createCleanup());
pm.addNestedPass<func::FuncOp>(
createLinalgConvertCompareSelectToMaximumfPass());
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_library(TPPTransforms
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertLinalgToInplace.cpp
FoldIntoEltwise.cpp
GeneralizeNamedOps.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
76 changes: 76 additions & 0 deletions lib/TPP/Transforms/GeneralizeNamedOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//===- GeneralizeNamedOps.cpp ------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace tpp;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_GENERALIZENAMEDOPS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

struct LinalgGeneralizationPattern
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
LinalgGeneralizationPattern(MLIRContext *context, ControlGeneralizationFn fun,
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

/// `matchAndRewrite` implementation that returns the significant
/// transformed pieces of IR.
FailureOr<linalg::GenericOp>
returningMatchAndRewrite(linalg::LinalgOp op,
PatternRewriter &rewriter) const {
return linalg::generalizeNamedOp(rewriter, op);
}

LogicalResult matchAndRewrite(linalg::LinalgOp op,
PatternRewriter &rewriter) const override {
if (controlFn && !controlFn(op))
return failure();

return returningMatchAndRewrite(op, rewriter);
}

private:
ControlGeneralizationFn controlFn;
};

struct GeneralizeNamedOps
: tpp::impl::GeneralizeNamedOpsBase<GeneralizeNamedOps> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());

ControlGeneralizationFn controlFn = [](linalg::LinalgOp op) -> bool {
return !(isa<linalg::FillOp>(op));
};

tpp::populateGeneralizeNamedOpsPatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace

// TODO: Add control function to Linalg generalization patterns upstream.
void tpp::populateGeneralizeNamedOpsPatterns(
RewritePatternSet &patterns, ControlGeneralizationFn controlFn) {
patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), controlFn);
}
34 changes: 34 additions & 0 deletions test/Passes/pass-generalize-named-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: tpp-opt %s -generalize-named-ops -split-input-file | FileCheck %s

func.func @generalize_matmul(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
outs(%C: memref<16x32xf32>)
return
}

// CHECK-LABEL @generalize_matmul(
// CHECK-NOT: linalg.matmul
// CHECK: linalg.generic

// -----

func.func @generalize_add(%A : memref<32x32xf32>, %B: memref<32x32xf32>, %C: memref<32x32xf32>) {
linalg.add ins(%A, %B: memref<32x32xf32>, memref<32x32xf32>)
outs(%C: memref<32x32xf32>)
return
}

// CHECK-LABEL @generalize_add(
// CHECK-NOT: linalg.add
// CHECK: linalg.generic

// -----

func.func @dont_generalize_fill(%arg0 : memref<32x32xf32>, %val: f32) {
linalg.fill ins(%val : f32) outs(%arg0 : memref<32x32xf32>)
return
}

// CHECK-LABEL @generalize_add(
// CHECK: linalg.fill
// CHECK-NOT: linalg.generic
22 changes: 22 additions & 0 deletions test/Passes/tpp-mapping.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: arith.maximumf

// -----

func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>,
%arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> {
%e = tensor.empty() : tensor<64x64xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
%1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32>
return %1 : tensor<64x64xf32>
}

// CHECK-LABEL: tile_and_fuse_named(
// CHECK-COUNT-3: tensor.pack
// Fused matmul and relu
// CHECK: scf.forall
// CHECK: linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>)
// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>)
// CHECK: arith.addf
// CHECK-NOT: tensor.unpack