Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NFC: Convert util transforms to declarative registration. #10143

Merged
merged 1 commit into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -20,15 +21,8 @@ namespace iree_compiler {
namespace IREE {
namespace Util {

class ApplyPatternsPass
: public PassWrapper<ApplyPatternsPass, OperationPass<void>> {
class ApplyPatternsPass : public ApplyPatternsBase<ApplyPatternsPass> {
public:
StringRef getArgument() const override { return "iree-util-apply-patterns"; }

StringRef getDescription() const override {
return "Applies some risky/IREE-specific canonicalization patterns.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<BuiltinDialect, func::FuncDialect, IREE::Util::UtilDialect>();
Expand Down Expand Up @@ -60,8 +54,6 @@ std::unique_ptr<OperationPass<void>> createApplyPatternsPass() {
return std::make_unique<ApplyPatternsPass>();
}

static PassRegistration<ApplyPatternsPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
22 changes: 21 additions & 1 deletion compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
load("//build_tools/bazel:iree_tablegen.bzl", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
Expand All @@ -23,18 +24,22 @@ iree_compiler_cc_library(
"FoldGlobals.cpp",
"FuseGlobals.cpp",
"HoistIntoGlobals.cpp",
"PassDetail.h",
"Passes.cpp",
"Patterns.cpp",
"PropagateSubrange.cpp",
"PropagateSubranges.cpp",
"SimplifyGlobalAccesses.cpp",
"StripDebugOps.cpp",
"TestConversion.cpp",
"TestFloatRangeAnalysis.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
"Patterns.h",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Dialect/Util/Analysis",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant",
Expand All @@ -59,3 +64,18 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Transforms",
],
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_cc_library(
Transforms
HDRS
"Passes.h"
"Passes.h.inc"
"Patterns.h"
SRCS
"ApplyPatterns.cpp"
Expand All @@ -25,13 +26,16 @@ iree_cc_library(
"FoldGlobals.cpp"
"FuseGlobals.cpp"
"HoistIntoGlobals.cpp"
"PassDetail.h"
"Passes.cpp"
"Patterns.cpp"
"PropagateSubrange.cpp"
"PropagateSubranges.cpp"
"SimplifyGlobalAccesses.cpp"
"StripDebugOps.cpp"
"TestConversion.cpp"
"TestFloatRangeAnalysis.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRAffineDialect
MLIRAnalysis
Expand All @@ -57,4 +61,13 @@ iree_cc_library(
PUBLIC
)

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
Expand All @@ -30,19 +31,8 @@ namespace Util {
namespace {

class CombineInitializersPass
: public PassWrapper<CombineInitializersPass,
OperationPass<mlir::ModuleOp>> {
: public CombineInitializersBase<CombineInitializersPass> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CombineInitializersPass)

StringRef getArgument() const override {
return "iree-util-combine-initializers";
}

StringRef getDescription() const override {
return "Combines global initializers into one.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Util::UtilDialect>();
}
Expand Down Expand Up @@ -92,8 +82,6 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createCombineInitializersPass() {
return std::make_unique<CombineInitializersPass>();
}

static PassRegistration<CombineInitializersPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -215,8 +216,8 @@ struct ConvertTypeSensitiveArithCastOp : public OpConversionPattern<OpTy> {
}
};

template <typename T, typename Converter>
struct ConvertTypesPass : public PassWrapper<T, OperationPass<mlir::ModuleOp>> {
template <typename Base, typename Converter>
struct ConvertTypesPass : public Base {
void runOnOperation() override {
MLIRContext *context = &this->getContext();
RewritePatternSet patterns(context);
Expand Down Expand Up @@ -285,22 +286,13 @@ struct DemoteI64ToI32Converter
}
};
struct DemoteI64ToI32Pass
: public ConvertTypesPass<DemoteI64ToI32Pass, DemoteI64ToI32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteI64ToI32Pass)

StringRef getArgument() const override {
return "iree-util-demote-i64-to-i32";
}
StringRef getDescription() const override {
return "Demotes i64 types to i32 types.";
}
};
: public ConvertTypesPass<DemoteI64ToI32Base<DemoteI64ToI32Pass>,
DemoteI64ToI32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass() {
return std::make_unique<DemoteI64ToI32Pass>();
}
static PassRegistration<DemoteI64ToI32Pass> demoteI64ToI32Pass;

namespace {
struct DemoteF32ToF16Converter
Expand All @@ -310,22 +302,13 @@ struct DemoteF32ToF16Converter
}
};
struct DemoteF32ToF16Pass
: public ConvertTypesPass<DemoteF32ToF16Pass, DemoteF32ToF16Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteF32ToF16Pass)

StringRef getArgument() const override {
return "iree-util-demote-f32-to-f16";
}
StringRef getDescription() const override {
return "Demotes f32 types to f16 types.";
}
};
: public ConvertTypesPass<DemoteF32ToF16Base<DemoteF32ToF16Pass>,
DemoteF32ToF16Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass() {
return std::make_unique<DemoteF32ToF16Pass>();
}
static PassRegistration<DemoteF32ToF16Pass> demoteF32ToF16Pass;

namespace {
struct PromoteF16ToF32Converter
Expand All @@ -335,22 +318,13 @@ struct PromoteF16ToF32Converter
}
};
struct PromoteF16ToF32Pass
: public ConvertTypesPass<PromoteF16ToF32Pass, PromoteF16ToF32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PromoteF16ToF32Pass)

StringRef getArgument() const override {
return "iree-util-promote-f16-to-f32";
}
StringRef getDescription() const override {
return "Promotes f16 types to f32 types.";
}
};
: public ConvertTypesPass<PromoteF16ToF32Base<PromoteF16ToF32Pass>,
PromoteF16ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass() {
return std::make_unique<PromoteF16ToF32Pass>();
}
static PassRegistration<PromoteF16ToF32Pass> promoteF16ToF32Pass;

namespace {
struct DemoteF64ToF32Converter
Expand All @@ -360,22 +334,13 @@ struct DemoteF64ToF32Converter
}
};
struct DemoteF64ToF32Pass
: public ConvertTypesPass<DemoteF64ToF32Pass, DemoteF64ToF32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteF64ToF32Pass)

StringRef getArgument() const override {
return "iree-util-demote-f64-to-f32";
}
StringRef getDescription() const override {
return "Demotes f64 types to f32 types.";
}
};
: public ConvertTypesPass<DemoteF64ToF32Base<DemoteF64ToF32Pass>,
DemoteF64ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF64ToF32Pass() {
return std::make_unique<DemoteF64ToF32Pass>();
}
static PassRegistration<DemoteF64ToF32Pass> demoteF64ToF32Pass;

} // namespace Util
} // namespace IREE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -16,18 +17,8 @@ namespace IREE {
namespace Util {

class DropCompilerHintsPass
: public PassWrapper<DropCompilerHintsPass, OperationPass<void>> {
: public DropCompilerHintsBase<DropCompilerHintsPass> {
public:
StringRef getArgument() const override {
return "iree-util-drop-compiler-hints";
}

StringRef getDescription() const override {
return "Deletes operations that have no runtime equivalent and are only "
"used in the compiler. This should be performed after all other "
"compiler passes.";
}

void runOnOperation() override {
// We can't use patterns and applyPatternsAndFoldGreedily because that
// automatically does canonicalization.
Expand All @@ -42,8 +33,6 @@ std::unique_ptr<OperationPass<void>> createDropCompilerHintsPass() {
return std::make_unique<DropCompilerHintsPass>();
}

static PassRegistration<DropCompilerHintsPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -24,25 +26,16 @@ namespace {
// iteration terminates. If a sub-pass removes it, then iteration will
// continue.
class FixedPointIteratorPass
: public PassWrapper<FixedPointIteratorPass, OperationPass<void>> {
: public FixedPointIteratorBase<FixedPointIteratorPass> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FixedPointIteratorPass)

FixedPointIteratorPass() = default;
FixedPointIteratorPass(const FixedPointIteratorPass &other)
: PassWrapper(other) {}
: FixedPointIteratorBase<FixedPointIteratorPass>(other) {}
FixedPointIteratorPass(OpPassManager pipeline);

private:
LogicalResult initializeOptions(StringRef options) override;
void getDependentDialects(DialectRegistry &registry) const override;
StringRef getArgument() const override {
return "iree-util-fixed-point-iterator";
}
StringRef getDescription() const override {
return "Iterates a sub-pipeline to a fixed point";
}

void runOnOperation() override;

Optional<OpPassManager> pipeline;
Expand Down Expand Up @@ -125,7 +118,6 @@ void FixedPointIteratorPass::runOnOperation() {

std::unique_ptr<OperationPass<void>> createFixedPointIteratorPass(
OpPassManager pipeline) {
static PassRegistration<FixedPointIteratorPass> pass;
return std::make_unique<FixedPointIteratorPass>(std::move(pipeline));
}

Expand Down
Loading