-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
share MemOptVarInfos of external variables into cinn_launch subgraph (#…
…39209) * add a graph pass to share MemOptVarInfos of external variables into subgraph * update pass name * fix compile failed * add share_mem_opt_info_to_subgraph_pass test * share_mem_opt_info_to_subgraph_pass_test pass * modify some codes for better style and more robust * update cmake
- Loading branch information
Showing
9 changed files
with
360 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
147 changes: 147 additions & 0 deletions
147
paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <algorithm> | ||
#include "paddle/fluid/framework/details/computation_op_handle.h" | ||
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" | ||
#include "paddle/fluid/framework/ir/graph_helper.h" | ||
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" | ||
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" | ||
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" | ||
#include "paddle/fluid/operators/cinn/cinn_launch_op.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/string/string_helper.h" | ||
|
||
namespace paddle::framework::ir { | ||
|
||
using Name2VarInfoMap = | ||
std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>; | ||
|
||
static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp( | ||
details::ComputationOpHandle* compute_op) { | ||
for (details::VarHandleBase* var : compute_op->Outputs()) { | ||
if (!var->Node()->IsCtrlVar()) { | ||
continue; | ||
} | ||
for (details::OpHandleBase* op : var->PendingOps()) { | ||
auto* eager_deletion_op = | ||
dynamic_cast<details::EagerDeletionOpHandle*>(op); | ||
if (eager_deletion_op) { | ||
return eager_deletion_op; | ||
} | ||
} | ||
} | ||
return nullptr; | ||
} | ||
|
||
static void ShareVarInfoToCinnLaunch( | ||
const MemOptVarInfoMapList& varinfo_maps, | ||
details::ComputationOpHandle* cinn_launch_op) { | ||
details::EagerDeletionOpHandle* followed_eager_deletion_op = | ||
FindFollowedEagerDeletionOp(cinn_launch_op); | ||
if (!followed_eager_deletion_op) { | ||
VLOG(4) << "No eager_deletion op found after this cinn_launch op"; | ||
return; | ||
} | ||
|
||
std::vector<std::string> vars_to_delete = | ||
followed_eager_deletion_op->VarsToDelete(); | ||
if (vars_to_delete.empty()) { | ||
VLOG(4) << "No var to be deleted after this cinn_launch op"; | ||
return; | ||
} | ||
VLOG(4) << "Variables would be deleted by the eager_deletion_op" | ||
<< " following the cinn_launch:" | ||
<< paddle::string::join_strings(vars_to_delete, ','); | ||
|
||
const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph( | ||
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey)); | ||
auto& dst_varinfo_map = | ||
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph); | ||
const Name2VarInfoMap& src_varinfo_map = | ||
varinfo_maps.at(cinn_launch_op->GetScopeIdx()); | ||
|
||
// collect all MemOptVarInfos of external variables | ||
// that would be eager deleted after the cinn_launch subgraph executed, | ||
// and store them as attribute of the subgraph | ||
for (const auto& var_name : vars_to_delete) { | ||
auto it = src_varinfo_map.find(var_name); | ||
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(), | ||
platform::errors::NotFound( | ||
"MemOptVarInfo of var[%s] not found", var_name)); | ||
dst_varinfo_map.emplace(var_name, it->second); | ||
} | ||
} | ||
|
||
static void TakeVarInfoFromMainGraph( | ||
const Name2VarInfoMap& src_varinfo_map, | ||
const MemOptVarInfoMapList& varinfo_maps, | ||
details::EagerDeletionOpHandle* eager_deletion_op) { | ||
const Name2VarInfoMap& dst_varinfo_map = | ||
varinfo_maps.at(eager_deletion_op->GetScopeIdx()); | ||
for (auto&& var_name : eager_deletion_op->VarsToDelete()) { | ||
auto dst_it = dst_varinfo_map.find(var_name); | ||
PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(), | ||
platform::errors::NotFound( | ||
"MemOptVarInfo of var[%s] not found", var_name)); | ||
auto src_it = src_varinfo_map.find(var_name); | ||
if (src_it != src_varinfo_map.end()) { | ||
VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder"; | ||
dst_it->second->SetParentHolder(src_it->second); | ||
} | ||
} | ||
} | ||
|
||
// This pass will be applied on both the main graph and all cinn subgraphs, | ||
// and it distinguishs them according to whether the graph has the | ||
// kMemOptVarInfoFromMainGraph attribute or not. | ||
// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos | ||
// to their subgraphs. | ||
// On a cinn subgraph, it iterates each variable that will be deleted by a | ||
// eager_deletion op, and take the MemOptVarInfo from the main graph | ||
// if such one found. | ||
class ShareMemOptInfoToSubGraphPass : public ir::Pass { | ||
protected: | ||
void ApplyImpl(ir::Graph* graph) const override { | ||
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); | ||
const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList); | ||
|
||
// the main graph | ||
if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) { | ||
for (details::OpHandleBase* op : all_ops) { | ||
auto compute_op = dynamic_cast<details::ComputationOpHandle*>(op); | ||
if (compute_op && compute_op->Name() == "cinn_launch") { | ||
ShareVarInfoToCinnLaunch(varinfo_maps, compute_op); | ||
} | ||
} | ||
} else { // a cinn subgraph | ||
const auto& parent_varinfo_map = | ||
graph->Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph); | ||
for (details::OpHandleBase* op : all_ops) { | ||
auto eager_deletion_op = | ||
dynamic_cast<details::EagerDeletionOpHandle*>(op); | ||
if (eager_deletion_op) { | ||
TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps, | ||
eager_deletion_op); | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace paddle::framework::ir | ||
|
||
REGISTER_PASS(share_varinfo_into_cinn_pass, | ||
paddle::framework::ir::ShareMemOptInfoToSubGraphPass) | ||
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList); |
142 changes: 142 additions & 0 deletions
142
paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <memory> | ||
#include "gtest/gtest.h" | ||
#include "paddle/fluid/framework/details/computation_op_handle.h" | ||
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" | ||
#include "paddle/fluid/framework/ir/graph.h" | ||
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" | ||
#include "paddle/fluid/framework/ir/pass.h" | ||
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" | ||
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" | ||
#include "paddle/fluid/framework/parallel_executor.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
|
||
USE_OP(mul); | ||
USE_OP(cinn_launch); | ||
USE_OP(elementwise_add); | ||
namespace paddle::framework { | ||
|
||
using Name2VarInfoMap = | ||
std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>; | ||
|
||
static ProgramDesc BuildProgramInsideCinnLaunchOp() { | ||
ProgramDesc program; | ||
auto* block = program.MutableBlock(0); | ||
block->Var("var1"); | ||
block->Var("var2"); | ||
block->Var("var3"); | ||
block->Var("var4"); | ||
block->Var("var5"); | ||
|
||
auto add_op = std::unique_ptr<OpDesc>( | ||
new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}}, | ||
{{"Out", {"var3"}}}, {})); | ||
block->AppendAllocatedOp(std::move(add_op)); | ||
auto mul_op = std::unique_ptr<OpDesc>(new OpDesc( | ||
"mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {})); | ||
block->AppendAllocatedOp(std::move(mul_op)); | ||
return program; | ||
} | ||
|
||
static ProgramDesc BuildProgramWithCinnLaunchOp( | ||
const std::string& compilation_key) { | ||
// create a cinn_launch op | ||
ProgramDesc program; | ||
auto* block = program.MutableBlock(0); | ||
block->Var("var1"); | ||
block->Var("var2"); | ||
block->Var("var4"); | ||
block->Var("var5"); | ||
|
||
auto cinn_launch_op = std::unique_ptr<OpDesc>( | ||
new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}}, | ||
{{"Out", {"var5"}}}, {{"compilation_key", compilation_key}})); | ||
block->AppendAllocatedOp(std::move(cinn_launch_op)); | ||
return program; | ||
} | ||
|
||
struct TestPassContext { | ||
explicit TestPassContext(const ProgramDesc& program) { | ||
graph = std::make_unique<ir::Graph>(program); | ||
details::BuildStrategy build_strategy; | ||
details::ExecutionStrategy exec_strategy; | ||
exec_strategy.use_device_ = paddle::platform::kCUDA; | ||
executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope, | ||
exec_strategy, build_strategy, | ||
graph.get())); | ||
} | ||
|
||
Scope scope; | ||
std::unique_ptr<ir::Graph> graph; | ||
std::unique_ptr<ParallelExecutor> executor; | ||
}; | ||
|
||
TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) { | ||
// add a subgraph to CinnCompiler | ||
auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp()); | ||
subgraph->GetOrInit<Name2VarInfoMap>( | ||
paddle2cinn::kMemOptVarInfoFromMainGraph); | ||
std::string compilation_key = | ||
paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); | ||
|
||
// build test data and apply pass | ||
auto context = std::make_unique<TestPassContext>( | ||
BuildProgramWithCinnLaunchOp(compilation_key)); | ||
|
||
// check result | ||
const ir::Graph& result_subgraph = | ||
paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key); | ||
const auto& dst_varinfo_map = result_subgraph.Get<Name2VarInfoMap>( | ||
paddle2cinn::kMemOptVarInfoFromMainGraph); | ||
ASSERT_EQ(dst_varinfo_map.size(), 4); | ||
EXPECT_EQ(dst_varinfo_map.count("var1"), 1); | ||
EXPECT_EQ(dst_varinfo_map.count("var5"), 1); | ||
EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2); | ||
EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2); | ||
} | ||
|
||
TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) { | ||
// build test data and apply pass | ||
auto context = | ||
std::make_unique<TestPassContext>(BuildProgramInsideCinnLaunchOp()); | ||
auto& varinfo_map_shared = context->graph->GetOrInit<Name2VarInfoMap>( | ||
paddle2cinn::kMemOptVarInfoFromMainGraph); | ||
varinfo_map_shared = { | ||
{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)}, | ||
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 2)}, | ||
}; | ||
|
||
ir::MemOptVarInfoMapList varinfo_maps(1); | ||
auto& dst_varinfo_map = varinfo_maps.front(); | ||
dst_varinfo_map = {{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)}, | ||
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 1)}, | ||
{"var3", std::make_shared<ir::MemOptVarInfo>("var3", 1)}, | ||
{"var4", std::make_shared<ir::MemOptVarInfo>("var4", 1)}, | ||
{"var5", std::make_shared<ir::MemOptVarInfo>("var5", 1)}}; | ||
auto share_pass = | ||
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); | ||
share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps); | ||
share_pass->Apply(context->graph.get()); | ||
|
||
// check result | ||
ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr); | ||
ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr); | ||
ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr); | ||
ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr); | ||
ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr); | ||
} | ||
|
||
} // namespace paddle::framework |
Oops, something went wrong.