Skip to content

Commit

Permalink
add ProggramPass and DecomposerPass (PaddlePaddle#464)
Browse files Browse the repository at this point in the history
* add decomposer pass and relu kernel

* optimize decomposer pass and decomposer registry
  • Loading branch information
zhangting2020 authored Oct 14, 2021
1 parent 53bc5c2 commit 4ec9bb7
Show file tree
Hide file tree
Showing 13 changed files with 405 additions and 16 deletions.
4 changes: 3 additions & 1 deletion cinn/frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ gather_srcs(cinnapi_src SRCS
base_builder.cc
net_builder.cc
cinn_builder.cc
paddle_model_to_netbuilder.cc)
paddle_model_to_netbuilder.cc
program_pass.cc)

if(NOT WITH_CUDA)
cc_test(test_frontend_syntax
Expand Down Expand Up @@ -42,5 +43,6 @@ cc_test(test_decomposer_registry
add_subdirectory(paddle)
add_subdirectory(decomposer)
add_subdirectory(op_mappers)
add_subdirectory(pass)

cc_test(test_op_mapper_registry SRCS op_mapper_registry_test.cc DEPS cinncore)
7 changes: 7 additions & 0 deletions cinn/frontend/base_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ Placeholder BaseBuilder::CreateInput(const Type& type, const std::vector<int>& s
return Placeholder(var);
}

Placeholder BaseBuilder::CreateInput(const Variable& var) {
CHECK(!var->shape.empty()) << "The input's shape is not set yet";
CHECK(!var->type.is_unk()) << "The input's type is not set yet";
inputs_.push_back(var);
return Placeholder(var);
}

void BaseBuilder::InferShape(Instruction instr) const {
using shape_func_t = std::function<std::vector<shape_t>(const std::vector<shape_t>&, const AttrMapType&)>;
using type_func_t = std::function<std::vector<Type>(const std::vector<Type>&, const AttrMapType&)>;
Expand Down
3 changes: 2 additions & 1 deletion cinn/frontend/base_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ class BaseBuilder {
Program Build();

Placeholder CreateInput(const common::Type& type, const std::vector<int>& shape, const std::string& id_hint = "");
Placeholder CreateInput(const Variable& input);

// name of this builder
const std::string& name() { return name_; }

virtual ~BaseBuilder() {}

protected:
void AppendInstruction(const Instruction& instr) { instrs_.push_back(instr); }

protected:
void InferShape(Instruction instr) const;

std::string name_;
Expand Down
18 changes: 16 additions & 2 deletions cinn/frontend/decomposer/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,33 @@
// limitations under the License.

#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {
namespace decomposer {

void relu(const Instruction& instr, const DecomposerContext& context) { LOG(FATAL) << "not implemented"; }
void relu(const Instruction& instr, const DecomposerContext& context) {
CHECK_EQ(instr->inputs.size(), 1UL) << " 1 input tensor for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type;
auto x = instr->inputs[0];
auto output = instr->outputs[0];
auto* builder = context.builder_;

auto zero_var = builder->ConstScalar<float>(0.f, common::UniqName("zero"));
auto bcast_zero = builder->BroadcastTo(zero_var, x->shape, {0});
auto out = builder->Max(x, bcast_zero);

// map the the output of decomposed operator to the original.
context.MapVarToOrigin(out, output);
}

} // namespace decomposer
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(activation) {
CINN_DECOMPOSER_REGISTER(relu, ::cinn::common::DefaultHostTarget()).set_body(cinn::frontend::decomposer::relu);
CINN_DECOMPOSER_REGISTER(relu, cinn::frontend::decomposer::relu);

return true;
}
54 changes: 43 additions & 11 deletions cinn/frontend/decomposer_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <unordered_map>

#include "cinn/common/target.h"
#include "cinn/frontend/cinn_builder.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
Expand All @@ -28,9 +29,16 @@ class Decomposer;

class DecomposerContext {
public:
explicit DecomposerContext(Program* prog) : program(prog) {}
explicit DecomposerContext(CinnBuilder* builder, absl::flat_hash_map<std::string, Variable>* var_map)
: builder_(builder), var_map_(var_map) {}

Program* program{nullptr};
CinnBuilder* builder_{nullptr};

// Map the new var to the original var.
void MapVarToOrigin(const Variable& new_var, const Variable& ori_var) const { (*var_map_)[new_var->id] = ori_var; }

private:
absl::flat_hash_map<std::string, Variable>* var_map_{nullptr};
};

class InstrDecomposerRegistry : public Registry<Decomposer> {
Expand All @@ -50,10 +58,6 @@ class InstrDecomposerRegistry : public Registry<Decomposer> {
return Registry<Decomposer>::Find(name + "_" + target.arch_str());
}

inline Decomposer& __REGISTER__(const std::string& name, const common::Target& target) {
return Registry<Decomposer>::__REGISTER__(name + "_" + target.arch_str());
}

private:
InstrDecomposerRegistry() = default;
CINN_DISALLOW_COPY_AND_ASSIGN(InstrDecomposerRegistry);
Expand All @@ -63,22 +67,50 @@ class Decomposer {
public:
using DecomposerKernel = std::function<void(const Instruction& instr, const DecomposerContext&)>;

Decomposer& set_body(const DecomposerKernel& kernel) {
Decomposer& SetBody(const DecomposerKernel& kernel) {
kernel_ = kernel;
return *this;
}

void Run(const Instruction& instr, const DecomposerContext& context) { kernel_(instr, context); }
void Run(const Instruction& instr, const DecomposerContext& context) const { kernel_(instr, context); }

std::string name;

private:
DecomposerKernel kernel_;
};

#define CINN_DECOMPOSER_REGISTER(name, target) \
static ::cinn::frontend::Decomposer& CINN_STR_CONCAT(__make_Decomposer_name, __COUNTER__) = \
::cinn::frontend::InstrDecomposerRegistry::Global()->__REGISTER__(#name, target)
#define CINN_DECOMPOSER_REGISTER_CORE(name, target, kernel) \
::cinn::frontend::InstrDecomposerRegistry::Global() \
->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \
.SetBody(kernel)

#define CINN_DECOMPOSER_REGISTER_ALL(name, kernel) \
static std::vector<::cinn::common::Target> all_targets = {::cinn::common::DefaultHostTarget(), \
::cinn::common::DefaultNVGPUTarget()}; \
for (auto& target : all_targets) { \
::cinn::frontend::InstrDecomposerRegistry::Global() \
->__REGISTER__(std::string(#name) + "_" + target.arch_str()) \
.SetBody(kernel); \
}

/**
* @def CINN_DECOMPOSER_REGISTER
* \brief Register a decomposer kernel
*
* Register a decomposer on the specific target:
* \code
* CINN_DECOMPOSER_REGISTER(name, target, kernel);
* \endcode
*
* Register a decomposer on all default targets:
* \code
* CINN_DECOMPOSER_REGISTER(name, kernel);
* \endcode
*/
#define GET_MACRO(_0, _1, _2, FUNC, ...) FUNC
#define CINN_DECOMPOSER_REGISTER(...) \
GET_MACRO(__VA_ARGS__, CINN_DECOMPOSER_REGISTER_CORE, CINN_DECOMPOSER_REGISTER_ALL)(__VA_ARGS__)

} // namespace frontend
} // namespace cinn
8 changes: 8 additions & 0 deletions cinn/frontend/pass/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS
decomposer.cc
)


cc_test(test_decomposer_pass SRCS decomposer_test.cc DEPS cinncore)
68 changes: 68 additions & 0 deletions cinn/frontend/pass/decomposer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/program_pass.h"

namespace cinn {
namespace frontend {
namespace pass {

class DecomposerPass : public ProgramPass {
public:
using ProgramPass::ProgramPass;

void ApplyImpl(Program* prog, const common::Target& target) const {
// step 1: set the inputs of the origin program to the new program
CinnBuilder builder("decomposer_builder");
for (auto& var : prog->GetInputs()) {
builder.CreateInput(var);
}

// step 2: use primitive instructions to build the new program
absl::flat_hash_map<std::string, Variable> var_map;
DecomposerContext context(&builder, &var_map);
for (size_t i = 0; i < prog->size(); i++) {
auto instr = (*prog)[i];
auto decomposer = InstrDecomposerRegistry::Global()->Find(instr->op_type, target);
if (decomposer) {
decomposer->Run(instr, context);
} else {
builder.AppendInstruction(instr);
}
}
*prog = builder.Build();

// step 3: set the origin output to the output of decomposed operator.
for (size_t i = 0; i < prog->size(); i++) {
auto& outputs = (*prog)[i]->outputs;
for (size_t j = 0; j < outputs.size(); j++) {
auto it = var_map.find(outputs[j]->id);
if (it != var_map.end()) {
outputs[j] = it->second;
}
}
}
}
};

} // namespace pass
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(Decomposer) {
CINN_REGISTER_PROGRAM_PASS(Decomposer, ::cinn::frontend::pass::DecomposerPass);

return true;
}
104 changes: 104 additions & 0 deletions cinn/frontend/pass/decomposer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include <random>

#include "cinn/frontend/decomposer/use_decomposer.h"
#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/hlir/op/use_ops.h"
#include "cinn/hlir/pass/use_pass.h"

namespace cinn::frontend {

Program CreateAddProgram() {
constexpr int M = 32;
constexpr int N = 24;

NetBuilder builder("net_builder");
auto a = builder.CreateInput(Float(32), {M, N});
auto b = builder.CreateInput(Float(32), {M, N});
auto c = builder.relu(a);
auto d = builder.add(b, c);
auto program = builder.Build();

return program;
}

void SetRandData(hlir::framework::Tensor tensor, Target target) {
auto* data = tensor->mutable_data<float>(target);
std::random_device seed;
std::default_random_engine engine(seed());
std::uniform_real_distribution<float> dist(0.f, 1.f);
size_t num_ele = tensor->shape().numel();
std::vector<float> random_data(num_ele);
for (size_t i = 0; i < num_ele; i++) {
random_data[i] = dist(engine); // All random data
}

#ifdef CINN_WITH_CUDA
cudaMemcpy(data, random_data.data(), num_ele * sizeof(float), cudaMemcpyHostToDevice);
#else
std::copy(random_data.begin(), random_data.end(), data);
#endif
}

TEST(DecomposePassRegistry, basic) {
ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), nullptr);
ASSERT_EQ(cinn::frontend::ProgramPassRegistry::Global()->Find("Test"), nullptr);
}

TEST(DecomposePass, basic) {
auto prog = CreateAddProgram();
for (int i = 0; i < prog.size(); i++) {
LOG(INFO) << "instruction: " << prog[i];
}

#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif

ProgramPass::Apply(&prog, target, {"Decomposer"});
for (int i = 0; i < prog.size(); i++) {
LOG(INFO) << "new instruction: " << prog[i];
}

auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPass(graph.get(), "OpFusion");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>("A");
scope->Var<hlir::framework::Tensor>("B");

auto A = scope->GetTensor("A");
auto B = scope->GetTensor("B");
SetRandData(A, target);
SetRandData(B, target);

runtime_program->Execute();
}

} // namespace cinn::frontend
19 changes: 19 additions & 0 deletions cinn/frontend/pass/use_program_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cinn/common/macros.h"

CINN_USE_REGISTER(Decomposer)
Loading

0 comments on commit 4ec9bb7

Please sign in to comment.