Skip to content

Commit

Permalink
add registry
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Mar 23, 2023
1 parent 5edbbd9 commit 50747e6
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 7 deletions.
15 changes: 11 additions & 4 deletions paddle/infra/Pass/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

/// The code and design is mainly from mlir, very thanks to the great project.

#include <algorithm>
#include <cassert>
#include <optional>

Expand Down Expand Up @@ -65,14 +66,16 @@ struct PassInfo {
class Pass {
public:
virtual ~Pass() = default;

PassInfo GetPassInfo() const { return info_; }

protected:
explicit Pass(const std::string& name,
int opt_level,
const std::vector<std::string>& dependents = {})
: info_(name, opt_level, dependents) {}

PassInfo GetPassInfo() const { return info_; }

std::unique_ptr<Pass> Clone() const { return ClonePass(); }

protected:
virtual void Run(mlir::Operation* op) = 0;

virtual inline bool CanScheduleOn(mlir::Operation* op) const {
Expand All @@ -83,6 +86,10 @@ class Pass {
return mlir::success();
}

// TODO(wilber): need to consider pure virtual.
/// A clone method to create a copy of this pass.
virtual std::unique_ptr<Pass> ClonePass() const { return nullptr; }

AnalysisManager GetAnalysisManager() { return pass_state_->am; }

void SignalPassFailure() { pass_state_->pass_failed = true; }
Expand Down
30 changes: 30 additions & 0 deletions paddle/infra/Pass/PassRegistry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "Pass/PassRegistry.h"
#include <memory>
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ManagedStatic.h"

namespace infra {

static llvm::ManagedStatic<llvm::StringMap<PassAllocatorFunction>> PassRegistry;

void RegisterPass(const PassAllocatorFunction& func) {
std::unique_ptr<Pass> pass = func();

PassRegistry->try_emplace(pass->GetPassInfo().name, func);
}

} // namespace infra
36 changes: 36 additions & 0 deletions paddle/infra/Pass/PassRegistry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include "Pass.h"

namespace infra {

using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;

void RegisterPass(const PassAllocatorFunction& func);

template <typename ConcretePass>
struct PassRegistration {
explicit PassRegistration(const PassAllocatorFunction& func) {
RegisterPass(func);
}
PassRegistration()
: PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
};

} // namespace infra
7 changes: 4 additions & 3 deletions paddle/infra/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ git submodule update --init

## 1. compile llvm

cmake > 3.19

llvm commit id: 10939d1d580b9d3c9c2f3539c6bdb39f408179c0
```
cd llvm-project
mkdir build && cd build
cmake -GNinja \
"-H$LLVM_SRC_DIR/llvm" \
"-B$build_dir" \
-DLLVM_INSTALL_UTILS=ON \
-DLLVM_ENABLE_LLD=ON \
-DLLVM_ENABLE_PROJECTS=mlir \
Expand All @@ -25,7 +25,8 @@ cmake -GNinja \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=`which python` \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DLLVM_ENABLE_ASSERTIONS=On
-DLLVM_ENABLE_ASSERTIONS=On \
../llvm
ninja
```

Expand Down
71 changes: 71 additions & 0 deletions paddle/infra/python/pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/detail/common.h>
#include "mlir/IR/Operation.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

#include "Pass/Pass.h"
#include "Pass/PassManager.h"

namespace py = pybind11;

namespace infra {
namespace {

class PassTrampoline : public Pass {
public:
using Pass::Pass;

inline bool CanScheduleOn(mlir::Operation* op) const override {
PYBIND11_OVERRIDE(bool, Pass, CanScheduleOn, op);
}
void Run(mlir::Operation* op) override {
PYBIND11_OVERRIDE_PURE(void, Pass, Run, op);
}
};

class PassPublicist : public Pass {
public:
using Pass::CanScheduleOn;
using Pass::Pass;
using Pass::Run;
};
} // namespace
PYBIND11_MODULE(pass_python, m) {
py::class_<PassInfo>(m, "PassInfo")
.def(py::init<const std::string&, int, const std::vector<std::string>&>())
.def_readwrite("name", &PassInfo::name, "...")
.def_readwrite("opt_level", &PassInfo::opt_level, "...")
.def_readwrite("dependents", &PassInfo::dependents, "...");

py::class_<Pass, PassTrampoline>(m, "Pass")
.def(py::init<const std::string&, int, const std::vector<std::string>&>())
.def("can_schedule_on", &PassPublicist::CanScheduleOn, "...")
.def("run", &PassPublicist::Run, "...")
.def("get_pass_info", &Pass::GetPassInfo, "...");

py::class_<PassManager>(m, "PassManager")
.def(py::init<mlir::MLIRContext*, int>())
.def("run", &PassManager::Run)

// pybind not support `RetType func(unique_ptr)` signature.
// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#std-unique-ptr
.def("add_pass", [](PassManager& pm, Pass* pass) {
auto cp = pass->Clone();
pm.addPass(std::move(cp));
});
}
} // namespace infra

0 comments on commit 50747e6

Please sign in to comment.