Skip to content

Commit

Permalink
NFC: Convert util transforms to declarative registration.
Browse files Browse the repository at this point in the history
  • Loading branch information
Stella Laurenzo committed Aug 19, 2022
1 parent 979d6ea commit 288af43
Show file tree
Hide file tree
Showing 19 changed files with 273 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -20,15 +21,8 @@ namespace iree_compiler {
namespace IREE {
namespace Util {

class ApplyPatternsPass
: public PassWrapper<ApplyPatternsPass, OperationPass<void>> {
class ApplyPatternsPass : public ApplyPatternsBase<ApplyPatternsPass> {
public:
StringRef getArgument() const override { return "iree-util-apply-patterns"; }

StringRef getDescription() const override {
return "Applies some risky/IREE-specific canonicalization patterns.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<BuiltinDialect, func::FuncDialect, IREE::Util::UtilDialect>();
Expand Down Expand Up @@ -60,8 +54,6 @@ std::unique_ptr<OperationPass<void>> createApplyPatternsPass() {
return std::make_unique<ApplyPatternsPass>();
}

static PassRegistration<ApplyPatternsPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
21 changes: 20 additions & 1 deletion compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
load("//build_tools/bazel:iree_tablegen.bzl", "iree_gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
Expand All @@ -23,18 +24,21 @@ iree_compiler_cc_library(
"FoldGlobals.cpp",
"FuseGlobals.cpp",
"HoistIntoGlobals.cpp",
"Passes.cpp",
"Patterns.cpp",
"PropagateSubrange.cpp",
"PropagateSubranges.cpp",
"SimplifyGlobalAccesses.cpp",
"StripDebugOps.cpp",
"TestConversion.cpp",
"TestFloatRangeAnalysis.cpp",
],
hdrs = [
"Passes.h",
"Passes.h.inc",
"Patterns.h",
],
deps = [
":PassesIncGen",
"//compiler/src/iree/compiler/Dialect/Util/Analysis",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant",
Expand All @@ -59,3 +63,18 @@ iree_compiler_cc_library(
"@llvm-project//mlir:Transforms",
],
)

iree_gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["--gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_cc_library(
Transforms
HDRS
"Passes.h"
"Passes.h.inc"
"Patterns.h"
SRCS
"ApplyPatterns.cpp"
Expand All @@ -25,13 +26,15 @@ iree_cc_library(
"FoldGlobals.cpp"
"FuseGlobals.cpp"
"HoistIntoGlobals.cpp"
"Passes.cpp"
"Patterns.cpp"
"PropagateSubrange.cpp"
"PropagateSubranges.cpp"
"SimplifyGlobalAccesses.cpp"
"StripDebugOps.cpp"
"TestConversion.cpp"
"TestFloatRangeAnalysis.cpp"
DEPS
::PassesIncGen
LLVMSupport
MLIRAffineDialect
MLIRAnalysis
Expand All @@ -57,4 +60,13 @@ iree_cc_library(
PUBLIC
)

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
Expand All @@ -30,19 +31,8 @@ namespace Util {
namespace {

class CombineInitializersPass
: public PassWrapper<CombineInitializersPass,
OperationPass<mlir::ModuleOp>> {
: public CombineInitializersBase<CombineInitializersPass> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CombineInitializersPass)

StringRef getArgument() const override {
return "iree-util-combine-initializers";
}

StringRef getDescription() const override {
return "Combines global initializers into one.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Util::UtilDialect>();
}
Expand Down Expand Up @@ -92,8 +82,6 @@ std::unique_ptr<OperationPass<mlir::ModuleOp>> createCombineInitializersPass() {
return std::make_unique<CombineInitializersPass>();
}

static PassRegistration<CombineInitializersPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -215,8 +216,8 @@ struct ConvertTypeSensitiveArithCastOp : public OpConversionPattern<OpTy> {
}
};

template <typename T, typename Converter>
struct ConvertTypesPass : public PassWrapper<T, OperationPass<mlir::ModuleOp>> {
template <typename Base, typename Converter>
struct ConvertTypesPass : public Base {
void runOnOperation() override {
MLIRContext *context = &this->getContext();
RewritePatternSet patterns(context);
Expand Down Expand Up @@ -285,22 +286,13 @@ struct DemoteI64ToI32Converter
}
};
struct DemoteI64ToI32Pass
: public ConvertTypesPass<DemoteI64ToI32Pass, DemoteI64ToI32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteI64ToI32Pass)

StringRef getArgument() const override {
return "iree-util-demote-i64-to-i32";
}
StringRef getDescription() const override {
return "Demotes i64 types to i32 types.";
}
};
: public ConvertTypesPass<DemoteI64ToI32Base<DemoteI64ToI32Pass>,
DemoteI64ToI32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass() {
return std::make_unique<DemoteI64ToI32Pass>();
}
static PassRegistration<DemoteI64ToI32Pass> demoteI64ToI32Pass;

namespace {
struct DemoteF32ToF16Converter
Expand All @@ -310,22 +302,13 @@ struct DemoteF32ToF16Converter
}
};
struct DemoteF32ToF16Pass
: public ConvertTypesPass<DemoteF32ToF16Pass, DemoteF32ToF16Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteF32ToF16Pass)

StringRef getArgument() const override {
return "iree-util-demote-f32-to-f16";
}
StringRef getDescription() const override {
return "Demotes f32 types to f16 types.";
}
};
: public ConvertTypesPass<DemoteF32ToF16Base<DemoteF32ToF16Pass>,
DemoteF32ToF16Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass() {
return std::make_unique<DemoteF32ToF16Pass>();
}
static PassRegistration<DemoteF32ToF16Pass> demoteF32ToF16Pass;

namespace {
struct PromoteF16ToF32Converter
Expand All @@ -335,22 +318,13 @@ struct PromoteF16ToF32Converter
}
};
struct PromoteF16ToF32Pass
: public ConvertTypesPass<PromoteF16ToF32Pass, PromoteF16ToF32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PromoteF16ToF32Pass)

StringRef getArgument() const override {
return "iree-util-promote-f16-to-f32";
}
StringRef getDescription() const override {
return "Promotes f16 types to f32 types.";
}
};
: public ConvertTypesPass<PromoteF16ToF32Base<PromoteF16ToF32Pass>,
PromoteF16ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createPromoteF16ToF32Pass() {
return std::make_unique<PromoteF16ToF32Pass>();
}
static PassRegistration<PromoteF16ToF32Pass> promoteF16ToF32Pass;

namespace {
struct DemoteF64ToF32Converter
Expand All @@ -360,22 +334,13 @@ struct DemoteF64ToF32Converter
}
};
struct DemoteF64ToF32Pass
: public ConvertTypesPass<DemoteF64ToF32Pass, DemoteF64ToF32Converter> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DemoteF64ToF32Pass)

StringRef getArgument() const override {
return "iree-util-demote-f64-to-f32";
}
StringRef getDescription() const override {
return "Demotes f64 types to f32 types.";
}
};
: public ConvertTypesPass<DemoteF64ToF32Base<DemoteF64ToF32Pass>,
DemoteF64ToF32Converter> {};
} // namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF64ToF32Pass() {
return std::make_unique<DemoteF64ToF32Pass>();
}
static PassRegistration<DemoteF64ToF32Pass> demoteF64ToF32Pass;

} // namespace Util
} // namespace IREE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -16,18 +17,8 @@ namespace IREE {
namespace Util {

class DropCompilerHintsPass
: public PassWrapper<DropCompilerHintsPass, OperationPass<void>> {
: public DropCompilerHintsBase<DropCompilerHintsPass> {
public:
StringRef getArgument() const override {
return "iree-util-drop-compiler-hints";
}

StringRef getDescription() const override {
return "Deletes operations that have no runtime equivalent and are only "
"used in the compiler. This should be performed after all other "
"compiler passes.";
}

void runOnOperation() override {
// We can't use patterns and applyPatternsAndFoldGreedily because that
// automatically does canonicalization.
Expand All @@ -42,8 +33,6 @@ std::unique_ptr<OperationPass<void>> createDropCompilerHintsPass() {
return std::make_unique<DropCompilerHintsPass>();
}

static PassRegistration<DropCompilerHintsPass> pass;

} // namespace Util
} // namespace IREE
} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -24,25 +26,16 @@ namespace {
// iteration terminates. If a sub-pass removes it, then iteration will
// continue.
class FixedPointIteratorPass
: public PassWrapper<FixedPointIteratorPass, OperationPass<void>> {
: public FixedPointIteratorBase<FixedPointIteratorPass> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FixedPointIteratorPass)

FixedPointIteratorPass() = default;
FixedPointIteratorPass(const FixedPointIteratorPass &other)
: PassWrapper(other) {}
: FixedPointIteratorBase<FixedPointIteratorPass>(other) {}
FixedPointIteratorPass(OpPassManager pipeline);

private:
LogicalResult initializeOptions(StringRef options) override;
void getDependentDialects(DialectRegistry &registry) const override;
StringRef getArgument() const override {
return "iree-util-fixed-point-iterator";
}
StringRef getDescription() const override {
return "Iterates a sub-pipeline to a fixed point";
}

void runOnOperation() override;

Optional<OpPassManager> pipeline;
Expand Down Expand Up @@ -125,7 +118,6 @@ void FixedPointIteratorPass::runOnOperation() {

std::unique_ptr<OperationPass<void>> createFixedPointIteratorPass(
OpPassManager pipeline) {
static PassRegistration<FixedPointIteratorPass> pass;
return std::make_unique<FixedPointIteratorPass>(std::move(pipeline));
}

Expand Down
Loading

0 comments on commit 288af43

Please sign in to comment.