forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ProggramPass and DecomposerPass (PaddlePaddle#464)
* add decomposer pass and relu kernel * optimize decomposer pass and decomposer registry
- Loading branch information
1 parent
53bc5c2
commit 4ec9bb7
Showing
13 changed files
with
405 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
core_gather_headers() | ||
|
||
gather_srcs(cinnapi_src SRCS | ||
decomposer.cc | ||
) | ||
|
||
|
||
cc_test(test_decomposer_pass SRCS decomposer_test.cc DEPS cinncore) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright (c) 2021 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "cinn/frontend/decomposer_registry.h" | ||
#include "cinn/frontend/program_pass.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
namespace pass { | ||
|
||
class DecomposerPass : public ProgramPass { | ||
public: | ||
using ProgramPass::ProgramPass; | ||
|
||
void ApplyImpl(Program* prog, const common::Target& target) const { | ||
// step 1: set the inputs of the origin program to the new program | ||
CinnBuilder builder("decomposer_builder"); | ||
for (auto& var : prog->GetInputs()) { | ||
builder.CreateInput(var); | ||
} | ||
|
||
// step 2: use primitive instructions to build the new program | ||
absl::flat_hash_map<std::string, Variable> var_map; | ||
DecomposerContext context(&builder, &var_map); | ||
for (size_t i = 0; i < prog->size(); i++) { | ||
auto instr = (*prog)[i]; | ||
auto decomposer = InstrDecomposerRegistry::Global()->Find(instr->op_type, target); | ||
if (decomposer) { | ||
decomposer->Run(instr, context); | ||
} else { | ||
builder.AppendInstruction(instr); | ||
} | ||
} | ||
*prog = builder.Build(); | ||
|
||
// step 3: set the origin output to the output of decomposed operator. | ||
for (size_t i = 0; i < prog->size(); i++) { | ||
auto& outputs = (*prog)[i]->outputs; | ||
for (size_t j = 0; j < outputs.size(); j++) { | ||
auto it = var_map.find(outputs[j]->id); | ||
if (it != var_map.end()) { | ||
outputs[j] = it->second; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace frontend | ||
} // namespace cinn | ||
|
||
CINN_REGISTER_HELPER(Decomposer) { | ||
CINN_REGISTER_PROGRAM_PASS(Decomposer, ::cinn::frontend::pass::DecomposerPass); | ||
|
||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Copyright (c) 2021 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <random> | ||
|
||
#include "cinn/frontend/decomposer/use_decomposer.h" | ||
#include "cinn/frontend/decomposer_registry.h" | ||
#include "cinn/frontend/net_builder.h" | ||
#include "cinn/frontend/pass/use_program_pass.h" | ||
#include "cinn/frontend/program_pass.h" | ||
#include "cinn/hlir/framework/graph.h" | ||
#include "cinn/hlir/framework/graph_compiler.h" | ||
#include "cinn/hlir/framework/pass.h" | ||
#include "cinn/hlir/framework/tensor.h" | ||
#include "cinn/hlir/op/use_ops.h" | ||
#include "cinn/hlir/pass/use_pass.h" | ||
|
||
namespace cinn::frontend { | ||
|
||
Program CreateAddProgram() { | ||
constexpr int M = 32; | ||
constexpr int N = 24; | ||
|
||
NetBuilder builder("net_builder"); | ||
auto a = builder.CreateInput(Float(32), {M, N}); | ||
auto b = builder.CreateInput(Float(32), {M, N}); | ||
auto c = builder.relu(a); | ||
auto d = builder.add(b, c); | ||
auto program = builder.Build(); | ||
|
||
return program; | ||
} | ||
|
||
void SetRandData(hlir::framework::Tensor tensor, Target target) { | ||
auto* data = tensor->mutable_data<float>(target); | ||
std::random_device seed; | ||
std::default_random_engine engine(seed()); | ||
std::uniform_real_distribution<float> dist(0.f, 1.f); | ||
size_t num_ele = tensor->shape().numel(); | ||
std::vector<float> random_data(num_ele); | ||
for (size_t i = 0; i < num_ele; i++) { | ||
random_data[i] = dist(engine); // All random data | ||
} | ||
|
||
#ifdef CINN_WITH_CUDA | ||
cudaMemcpy(data, random_data.data(), num_ele * sizeof(float), cudaMemcpyHostToDevice); | ||
#else | ||
std::copy(random_data.begin(), random_data.end(), data); | ||
#endif | ||
} | ||
|
||
TEST(DecomposePassRegistry, basic) { | ||
ASSERT_NE(cinn::frontend::ProgramPassRegistry::Global()->Find("Decomposer"), nullptr); | ||
ASSERT_EQ(cinn::frontend::ProgramPassRegistry::Global()->Find("Test"), nullptr); | ||
} | ||
|
||
TEST(DecomposePass, basic) { | ||
auto prog = CreateAddProgram(); | ||
for (int i = 0; i < prog.size(); i++) { | ||
LOG(INFO) << "instruction: " << prog[i]; | ||
} | ||
|
||
#ifdef CINN_WITH_CUDA | ||
Target target = common::DefaultNVGPUTarget(); | ||
#else | ||
Target target = common::DefaultHostTarget(); | ||
#endif | ||
|
||
ProgramPass::Apply(&prog, target, {"Decomposer"}); | ||
for (int i = 0; i < prog.size(); i++) { | ||
LOG(INFO) << "new instruction: " << prog[i]; | ||
} | ||
|
||
auto graph = std::make_shared<hlir::framework::Graph>(prog, target); | ||
hlir::framework::ApplyPass(graph.get(), "OpFusion"); | ||
auto scope = BuildScope(target, graph); | ||
hlir::framework::GraphCompiler gc(target, scope, graph); | ||
auto runtime_program = gc.Build(); | ||
|
||
scope->Var<hlir::framework::Tensor>("A"); | ||
scope->Var<hlir::framework::Tensor>("B"); | ||
|
||
auto A = scope->GetTensor("A"); | ||
auto B = scope->GetTensor("B"); | ||
SetRandData(A, target); | ||
SetRandData(B, target); | ||
|
||
runtime_program->Execute(); | ||
} | ||
|
||
} // namespace cinn::frontend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// Copyright (c) 2021 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "cinn/common/macros.h" | ||
|
||
CINN_USE_REGISTER(Decomposer) |
Oops, something went wrong.