From 37190f14a73c6fa93d1b46735b5d39fdfb2f104b Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 6 Aug 2024 13:04:09 +0200 Subject: [PATCH 1/6] Generalize linalg ops earlier Allows better pack-unpack propagation by first generalizing linalg named ops. This is mainly due to limited support for named ops in layout propagation patterns. --- lib/TPP/PassBundles/TppMapping.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index d39b62530..1c8684b98 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "TPP/PassUtils.h" @@ -63,6 +64,11 @@ struct TppMapping : public tpp::impl::TppMappingBase, pm.addPass(createPackMatmul()); pm.addPass(createPackVNNI()); + // TODO: Remove when layout propagation and tile-and-fuse have better + // support for named ops. + pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + pm.addPass(createCanonicalizerPass()); + // Postprocess packing. // Run only canonicalizer at this stage as full cleanup (mostly CSE) can // mess up tensor producer-consumer chains used for analysis in the @@ -71,7 +77,6 @@ struct TppMapping : public tpp::impl::TppMappingBase, pm.addPass(createConstantFoldPack()); pm.addPass(createSimplifyAndCanonicalizePack()); - pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); pm.addPass(createCleanup()); pm.addNestedPass( createLinalgConvertCompareSelectToMaximumfPass()); From 2c909e29fdf8c6704bd160da3bb35d5c9a26a951 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 6 Aug 2024 16:46:19 +0200 Subject: [PATCH 2/6] Generalize named ops pass Adds a custom wrapper around Linalg named ops generalization. The custom pattern exposes control over which ops should be generalized. By default, the pass filters out linalg.fill ops as layout propagation and fusion behave better when the named version is present. This improves overall IR quality after bufferization. --- include/TPP/Passes.td | 5 ++ include/TPP/Transforms/Transforms.h | 3 + lib/TPP/Transforms/CMakeLists.txt | 1 + lib/TPP/Transforms/GeneralizeNamedOps.cpp | 76 ++++++++++++++++++++++ test/Passes/pass-generalize-named-ops.mlir | 34 ++++++++++ 5 files changed, 119 insertions(+) create mode 100644 lib/TPP/Transforms/GeneralizeNamedOps.cpp create mode 100644 test/Passes/pass-generalize-named-ops.mlir diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index a650a8e7f..a59eade4e 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -517,4 +517,9 @@ def FoldIntoEltwise : Pass<"fold-into-eltwise", "ModuleOp"> { "affine::AffineDialect"]; } +def GeneralizeNamedOps : Pass<"generalize-named-ops"> { + let summary = "Selectively convert named ops into generic ops"; + let dependentDialects = ["linalg::LinalgDialect"]; +} + #endif // TPP_DIALECT_TPP_PASSES diff --git a/include/TPP/Transforms/Transforms.h b/include/TPP/Transforms/Transforms.h index 93ed2e176..f6536b97d 100644 --- a/include/TPP/Transforms/Transforms.h +++ b/include/TPP/Transforms/Transforms.h @@ -70,6 +70,9 @@ namespace tpp { void populateLinalgToXsmmPatterns(RewritePatternSet &patterns); void populateSimplifyPacking(RewritePatternSet &patterns); void populateSinkPackPatterns(RewritePatternSet &patterns); +using ControlGeneralizationFn = std::function; +void populateGeneralizeNamedOpsPatterns(RewritePatternSet &patterns, + ControlGeneralizationFn = nullptr); } // namespace tpp namespace linalg { void populateLinalgDeGeneralizationPatterns(RewritePatternSet &patterns); diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 708ffb5eb..d6ff5d171 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -22,6 +22,7 @@ add_mlir_library(TPPTransforms LinalgConvertCompareSelectToMaximumfPass.cpp ConvertLinalgToInplace.cpp FoldIntoEltwise.cpp + GeneralizeNamedOps.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/GeneralizeNamedOps.cpp b/lib/TPP/Transforms/GeneralizeNamedOps.cpp new file mode 100644 index 000000000..7aed6ce0e --- /dev/null +++ b/lib/TPP/Transforms/GeneralizeNamedOps.cpp @@ -0,0 +1,76 @@ +//===- GeneralizeNamedOps.cpp ------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" +#include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/TransformUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace tpp; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_GENERALIZENAMEDOPS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +struct LinalgGeneralizationPattern + : public OpInterfaceRewritePattern { + LinalgGeneralizationPattern(MLIRContext *context, ControlGeneralizationFn fun, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + controlFn(std::move(fun)) {} + + /// `matchAndRewrite` implementation that returns the significant + /// transformed pieces of IR. + FailureOr + returningMatchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const { + return linalg::generalizeNamedOp(rewriter, op); + } + + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + if (controlFn && !controlFn(op)) + return failure(); + + return returningMatchAndRewrite(op, rewriter); + } + +private: + ControlGeneralizationFn controlFn; +}; + +struct GeneralizeNamedOps + : tpp::impl::GeneralizeNamedOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + + ControlGeneralizationFn controlFn = [](linalg::LinalgOp op) -> bool { + return !(isa(op)); + }; + + tpp::populateGeneralizeNamedOpsPatterns(patterns, controlFn); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace + +// TODO: Add control function to Linalg generalization patterns upstream. +void tpp::populateGeneralizeNamedOpsPatterns( + RewritePatternSet &patterns, ControlGeneralizationFn controlFn) { + patterns.add(patterns.getContext(), controlFn); +} diff --git a/test/Passes/pass-generalize-named-ops.mlir b/test/Passes/pass-generalize-named-ops.mlir new file mode 100644 index 000000000..9861607e5 --- /dev/null +++ b/test/Passes/pass-generalize-named-ops.mlir @@ -0,0 +1,34 @@ +// RUN: tpp-opt %s -generalize-named-ops -split-input-file | FileCheck %s + +func.func @generalize_matmul(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) { + linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) + outs(%C: memref<16x32xf32>) + return +} + +// CHECK-LABEL @generalize_matmul( +// CHECK-NOT: linalg.matmul +// CHECK: linalg.generic + +// ----- + +func.func @generalize_add(%A : memref<32x32xf32>, %B: memref<32x32xf32>, %C: memref<32x32xf32>) { + linalg.add ins(%A, %B: memref<32x32xf32>, memref<32x32xf32>) + outs(%C: memref<32x32xf32>) + return +} + +// CHECK-LABEL @generalize_add( +// CHECK-NOT: linalg.add +// CHECK: linalg.generic + +// ----- + +func.func @dont_generalize_fill(%arg0 : memref<32x32xf32>, %val: f32) { + linalg.fill ins(%val : f32) outs(%arg0 : memref<32x32xf32>) + return +} + +// CHECK-LABEL @generalize_add( +// CHECK: linalg.fill +// CHECK-NOT: linalg.generic From 2a74fe5a8db083e205c7bb595adf0ee70fcaf013 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 6 Aug 2024 16:50:52 +0200 Subject: [PATCH 3/6] Use custom generalization in mapping + test case --- lib/TPP/PassBundles/TppMapping.cpp | 2 +- test/Passes/tpp-mapping.mlir | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index 1c8684b98..43d363175 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -66,7 +66,7 @@ struct TppMapping : public tpp::impl::TppMappingBase, // TODO: Remove when layout propagation and tile-and-fuse have better // support for named ops. - pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + pm.addNestedPass(createGeneralizeNamedOps()); pm.addPass(createCanonicalizerPass()); // Postprocess packing. diff --git a/test/Passes/tpp-mapping.mlir b/test/Passes/tpp-mapping.mlir index e50c297bc..ed2f19aa3 100644 --- a/test/Passes/tpp-mapping.mlir +++ b/test/Passes/tpp-mapping.mlir @@ -207,3 +207,25 @@ func.func @tile_and_fuse(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, // CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: arith.maximumf + +// ----- + +func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, + %arg2: tensor<64x64xf32>, %arg3: tensor<64x64xf32>) -> tensor<64x64xf32> { + %e = tensor.empty() : tensor<64x64xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32> + %1 = linalg.add ins(%0, %arg3 : tensor<64x64xf32>, tensor<64x64xf32>) + outs(%e : tensor<64x64xf32>) -> tensor<64x64xf32> + return %1 : tensor<64x64xf32> +} + +// CHECK-LABEL: tile_and_fuse_named( +// CHECK-COUNT-3: tensor.pack +// Fused matmul and relu +// CHECK: scf.forall +// CHECK: linalg.batch_reduce_matmul{{.*}}ins(%{{.+}}, %{{.+}} : tensor<2x32x32xf32>, tensor<2x32x32xf32>) +// CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>) +// CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>) +// CHECK: arith.addf +// CHECK: tensor.unpack From 001a8b3cc9440c871456cbb983ba6b53e8fd8a3c Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Tue, 6 Aug 2024 16:55:10 +0200 Subject: [PATCH 4/6] Remove extra canonicalization --- lib/TPP/PassBundles/TppMapping.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index 43d363175..edb93d39b 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -67,7 +67,6 @@ struct TppMapping : public tpp::impl::TppMappingBase, // TODO: Remove when layout propagation and tile-and-fuse have better // support for named ops. pm.addNestedPass(createGeneralizeNamedOps()); - pm.addPass(createCanonicalizerPass()); // Postprocess packing. // Run only canonicalizer at this stage as full cleanup (mostly CSE) can From 5d40f545dfa436c3a365ccbc73aa5910bb088fb7 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 7 Aug 2024 14:11:23 +0200 Subject: [PATCH 5/6] Cleanup generalized adds --- lib/TPP/PassBundles/TppMapping.cpp | 1 + test/Passes/tpp-mapping.mlir | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index edb93d39b..a7f3b60f9 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -67,6 +67,7 @@ struct TppMapping : public tpp::impl::TppMappingBase, // TODO: Remove when layout propagation and tile-and-fuse have better // support for named ops. pm.addNestedPass(createGeneralizeNamedOps()); + pm.addNestedPass(createConvertAddInplacePass()); // Postprocess packing. // Run only canonicalizer at this stage as full cleanup (mostly CSE) can diff --git a/test/Passes/tpp-mapping.mlir b/test/Passes/tpp-mapping.mlir index ed2f19aa3..1b05b13ed 100644 --- a/test/Passes/tpp-mapping.mlir +++ b/test/Passes/tpp-mapping.mlir @@ -228,4 +228,4 @@ func.func @tile_and_fuse_named(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32 // CHECK-SAME:{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: linalg.generic{{.*}}outs(%{{.+}} : tensor<32x32xf32>) // CHECK: arith.addf -// CHECK: tensor.unpack +// CHECK-NOT: tensor.unpack From 799d1ccec4247d15e67c1d659991ea19cecc8686 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 9 Aug 2024 14:51:53 +0200 Subject: [PATCH 6/6] Fix after rebase --- lib/TPP/PassBundles/TppMapping.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index a7f3b60f9..98a86996d 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -67,7 +67,7 @@ struct TppMapping : public tpp::impl::TppMappingBase, // TODO: Remove when layout propagation and tile-and-fuse have better // support for named ops. pm.addNestedPass(createGeneralizeNamedOps()); - pm.addNestedPass(createConvertAddInplacePass()); + pm.addNestedPass(createConvertLinalgToInplace()); // Postprocess packing. // Run only canonicalizer at this stage as full cleanup (mostly CSE) can