diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 8ec2bef5e..ad3aa8da2 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -504,4 +504,9 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { ]; } +def GeneralizeNamedOps : Pass<"generalize-named-ops"> { + let summary = "Selectively convert named ops into generic ops"; + let dependentDialects = ["linalg::LinalgDialect"]; +} + #endif // TPP_DIALECT_TPP_PASSES diff --git a/include/TPP/Transforms/Transforms.h b/include/TPP/Transforms/Transforms.h index 93ed2e176..f6536b97d 100644 --- a/include/TPP/Transforms/Transforms.h +++ b/include/TPP/Transforms/Transforms.h @@ -70,6 +70,9 @@ namespace tpp { void populateLinalgToXsmmPatterns(RewritePatternSet &patterns); void populateSimplifyPacking(RewritePatternSet &patterns); void populateSinkPackPatterns(RewritePatternSet &patterns); +using ControlGeneralizationFn = std::function; +void populateGeneralizeNamedOpsPatterns(RewritePatternSet &patterns, + ControlGeneralizationFn = nullptr); } // namespace tpp namespace linalg { void populateLinalgDeGeneralizationPatterns(RewritePatternSet &patterns); diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index fb7385b5c..f1599e454 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_library(TPPTransforms IntelAMXTileConfigHoisting.cpp LinalgConvertCompareSelectToMaximumfPass.cpp ConvertAddInplacePass.cpp + GeneralizeNamedOps.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP diff --git a/lib/TPP/Transforms/GeneralizeNamedOps.cpp b/lib/TPP/Transforms/GeneralizeNamedOps.cpp new file mode 100644 index 000000000..7aed6ce0e --- /dev/null +++ b/lib/TPP/Transforms/GeneralizeNamedOps.cpp @@ -0,0 +1,76 @@ +//===- GeneralizeNamedOps.cpp ------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" +#include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/TransformUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace tpp; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_GENERALIZENAMEDOPS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +struct LinalgGeneralizationPattern + : public OpInterfaceRewritePattern { + LinalgGeneralizationPattern(MLIRContext *context, ControlGeneralizationFn fun, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + controlFn(std::move(fun)) {} + + /// `matchAndRewrite` implementation that returns the significant + /// transformed pieces of IR. + FailureOr + returningMatchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const { + return linalg::generalizeNamedOp(rewriter, op); + } + + LogicalResult matchAndRewrite(linalg::LinalgOp op, + PatternRewriter &rewriter) const override { + if (controlFn && !controlFn(op)) + return failure(); + + return returningMatchAndRewrite(op, rewriter); + } + +private: + ControlGeneralizationFn controlFn; +}; + +struct GeneralizeNamedOps + : tpp::impl::GeneralizeNamedOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + + ControlGeneralizationFn controlFn = [](linalg::LinalgOp op) -> bool { + return !(isa(op)); + }; + + tpp::populateGeneralizeNamedOpsPatterns(patterns, controlFn); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace + +// TODO: Add control function to Linalg generalization patterns upstream. +void tpp::populateGeneralizeNamedOpsPatterns( + RewritePatternSet &patterns, ControlGeneralizationFn controlFn) { + patterns.add(patterns.getContext(), controlFn); +} diff --git a/test/Passes/pass-generalize-named-ops.mlir b/test/Passes/pass-generalize-named-ops.mlir new file mode 100644 index 000000000..9861607e5 --- /dev/null +++ b/test/Passes/pass-generalize-named-ops.mlir @@ -0,0 +1,34 @@ +// RUN: tpp-opt %s -generalize-named-ops -split-input-file | FileCheck %s + +func.func @generalize_matmul(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) { + linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) + outs(%C: memref<16x32xf32>) + return +} + +// CHECK-LABEL @generalize_matmul( +// CHECK-NOT: linalg.matmul +// CHECK: linalg.generic + +// ----- + +func.func @generalize_add(%A : memref<32x32xf32>, %B: memref<32x32xf32>, %C: memref<32x32xf32>) { + linalg.add ins(%A, %B: memref<32x32xf32>, memref<32x32xf32>) + outs(%C: memref<32x32xf32>) + return +} + +// CHECK-LABEL @generalize_add( +// CHECK-NOT: linalg.add +// CHECK: linalg.generic + +// ----- + +func.func @dont_generalize_fill(%arg0 : memref<32x32xf32>, %val: f32) { + linalg.fill ins(%val : f32) outs(%arg0 : memref<32x32xf32>) + return +} + +// CHECK-LABEL @generalize_add( +// CHECK: linalg.fill +// CHECK-NOT: linalg.generic