Skip to content

Commit

Permalink
share MemOptVarInfos of external variables into cinn_launch subgraph (#…
Browse files Browse the repository at this point in the history
…39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compile failed

* add share_mem_opt_info_to_subgraph_pass test

* share_mem_opt_info_to_subgraph_pass_test pass

* modify some codes for better style and more robust

* update cmake
  • Loading branch information
CtfGo authored Feb 10, 2022
1 parent 29d3160 commit 35b03e1
Show file tree
Hide file tree
Showing 9 changed files with 360 additions and 14 deletions.
18 changes: 16 additions & 2 deletions paddle/fluid/framework/details/eager_deletion_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() {

std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }

static bool CanBeErased(ir::MemOptVarInfo *var_info) {
if (var_info->IsSkippedAllMemoryOptimization() ||
!var_info->DecreaseRefCnt()) {
return false;
}
#ifdef PADDLE_WITH_CINN
// if parent_holder exists, it should meet deletion condition too.
std::shared_ptr<ir::MemOptVarInfo> parent_holder = var_info->ParentHolder();
if (parent_holder && !CanBeErased(parent_holder.get())) {
return false;
}
#endif
return true;
}

void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size() || is_variant_scope_) {
vars_.clear();
Expand All @@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (size_t i = 0; i < var_infos_.size(); ++i) {
auto *var_info = var_infos_[i];
if (var_info->IsSkippedAllMemoryOptimization() ||
!var_info->DecreaseRefCnt()) {
if (!CanBeErased(var_info)) {
VLOG(4) << "skip memory optimization with var: " << var_info->Name();
continue;
}
Expand Down
14 changes: 10 additions & 4 deletions paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)

cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
SET(EAGER_DELETETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
if (WITH_CINN)
cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler)
cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op)
list(APPEND EAGER_DELETETION_PASS_DEPS share_varinfo_into_cinn_pass)
endif()

cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETETION_PASS_DEPS})

cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)

cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)

cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto recurrent_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass");
recurrent_op_eager_deletion_pass->Apply(graph);

#ifdef PADDLE_WITH_CINN
auto share_varinfo_into_cinn_pass =
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos);
share_varinfo_into_cinn_pass->Apply(graph);
#endif
}

} // namespace ir
Expand All @@ -300,3 +307,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
USE_PASS(conditional_block_op_eager_deletion_pass);
USE_PASS(while_op_eager_deletion_pass);
USE_PASS(recurrent_op_eager_deletion_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(share_varinfo_into_cinn_pass);
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class MemOptVarInfo {
return skip_memory_reuse_ || skip_all_memory_optimization_;
}

void SetParentHolder(std::shared_ptr<MemOptVarInfo> parent) {
parent_holder_ = parent;
}

std::shared_ptr<MemOptVarInfo> ParentHolder() const { return parent_holder_; }

const std::string &Name() const { return name_; }

private:
Expand All @@ -88,6 +94,9 @@ class MemOptVarInfo {
std::atomic<size_t> runtime_ref_cnt_;
bool skip_memory_reuse_{false};
bool skip_all_memory_optimization_{false};
// point to var info of the same variable in the main graph,
// used in external(input/output) variables of a subgraph
std::shared_ptr<MemOptVarInfo> parent_holder_{nullptr};
};

using MemOptVarInfoMapList = std::vector<
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle::framework::ir {

using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>;

static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp(
details::ComputationOpHandle* compute_op) {
for (details::VarHandleBase* var : compute_op->Outputs()) {
if (!var->Node()->IsCtrlVar()) {
continue;
}
for (details::OpHandleBase* op : var->PendingOps()) {
auto* eager_deletion_op =
dynamic_cast<details::EagerDeletionOpHandle*>(op);
if (eager_deletion_op) {
return eager_deletion_op;
}
}
}
return nullptr;
}

static void ShareVarInfoToCinnLaunch(
const MemOptVarInfoMapList& varinfo_maps,
details::ComputationOpHandle* cinn_launch_op) {
details::EagerDeletionOpHandle* followed_eager_deletion_op =
FindFollowedEagerDeletionOp(cinn_launch_op);
if (!followed_eager_deletion_op) {
VLOG(4) << "No eager_deletion op found after this cinn_launch op";
return;
}

std::vector<std::string> vars_to_delete =
followed_eager_deletion_op->VarsToDelete();
if (vars_to_delete.empty()) {
VLOG(4) << "No var to be deleted after this cinn_launch op";
return;
}
VLOG(4) << "Variables would be deleted by the eager_deletion_op"
<< " following the cinn_launch:"
<< paddle::string::join_strings(vars_to_delete, ',');

const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey));
auto& dst_varinfo_map =
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
const Name2VarInfoMap& src_varinfo_map =
varinfo_maps.at(cinn_launch_op->GetScopeIdx());

// collect all MemOptVarInfos of external variables
// that would be eager deleted after the cinn_launch subgraph executed,
// and store them as attribute of the subgraph
for (const auto& var_name : vars_to_delete) {
auto it = src_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
dst_varinfo_map.emplace(var_name, it->second);
}
}

static void TakeVarInfoFromMainGraph(
const Name2VarInfoMap& src_varinfo_map,
const MemOptVarInfoMapList& varinfo_maps,
details::EagerDeletionOpHandle* eager_deletion_op) {
const Name2VarInfoMap& dst_varinfo_map =
varinfo_maps.at(eager_deletion_op->GetScopeIdx());
for (auto&& var_name : eager_deletion_op->VarsToDelete()) {
auto dst_it = dst_varinfo_map.find(var_name);
PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(),
platform::errors::NotFound(
"MemOptVarInfo of var[%s] not found", var_name));
auto src_it = src_varinfo_map.find(var_name);
if (src_it != src_varinfo_map.end()) {
VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder";
dst_it->second->SetParentHolder(src_it->second);
}
}
}

// This pass will be applied on both the main graph and all cinn subgraphs,
// and it distinguishs them according to whether the graph has the
// kMemOptVarInfoFromMainGraph attribute or not.
// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos
// to their subgraphs.
// On a cinn subgraph, it iterates each variable that will be deleted by a
// eager_deletion op, and take the MemOptVarInfo from the main graph
// if such one found.
class ShareMemOptInfoToSubGraphPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);

// the main graph
if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) {
for (details::OpHandleBase* op : all_ops) {
auto compute_op = dynamic_cast<details::ComputationOpHandle*>(op);
if (compute_op && compute_op->Name() == "cinn_launch") {
ShareVarInfoToCinnLaunch(varinfo_maps, compute_op);
}
}
} else { // a cinn subgraph
const auto& parent_varinfo_map =
graph->Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
for (details::OpHandleBase* op : all_ops) {
auto eager_deletion_op =
dynamic_cast<details::EagerDeletionOpHandle*>(op);
if (eager_deletion_op) {
TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps,
eager_deletion_op);
}
}
}
}
};

} // namespace paddle::framework::ir

REGISTER_PASS(share_varinfo_into_cinn_pass,
paddle::framework::ir::ShareMemOptInfoToSubGraphPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList);
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP(mul);
USE_OP(cinn_launch);
USE_OP(elementwise_add);
namespace paddle::framework {

using Name2VarInfoMap =
std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>;

static ProgramDesc BuildProgramInsideCinnLaunchOp() {
ProgramDesc program;
auto* block = program.MutableBlock(0);
block->Var("var1");
block->Var("var2");
block->Var("var3");
block->Var("var4");
block->Var("var5");

auto add_op = std::unique_ptr<OpDesc>(
new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}},
{{"Out", {"var3"}}}, {}));
block->AppendAllocatedOp(std::move(add_op));
auto mul_op = std::unique_ptr<OpDesc>(new OpDesc(
"mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {}));
block->AppendAllocatedOp(std::move(mul_op));
return program;
}

static ProgramDesc BuildProgramWithCinnLaunchOp(
const std::string& compilation_key) {
// create a cinn_launch op
ProgramDesc program;
auto* block = program.MutableBlock(0);
block->Var("var1");
block->Var("var2");
block->Var("var4");
block->Var("var5");

auto cinn_launch_op = std::unique_ptr<OpDesc>(
new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}},
{{"Out", {"var5"}}}, {{"compilation_key", compilation_key}}));
block->AppendAllocatedOp(std::move(cinn_launch_op));
return program;
}

struct TestPassContext {
explicit TestPassContext(const ProgramDesc& program) {
graph = std::make_unique<ir::Graph>(program);
details::BuildStrategy build_strategy;
details::ExecutionStrategy exec_strategy;
exec_strategy.use_device_ = paddle::platform::kCUDA;
executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope,
exec_strategy, build_strategy,
graph.get()));
}

Scope scope;
std::unique_ptr<ir::Graph> graph;
std::unique_ptr<ParallelExecutor> executor;
};

TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
// add a subgraph to CinnCompiler
auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
subgraph->GetOrInit<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
std::string compilation_key =
paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph));

// build test data and apply pass
auto context = std::make_unique<TestPassContext>(
BuildProgramWithCinnLaunchOp(compilation_key));

// check result
const ir::Graph& result_subgraph =
paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key);
const auto& dst_varinfo_map = result_subgraph.Get<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
ASSERT_EQ(dst_varinfo_map.size(), 4);
EXPECT_EQ(dst_varinfo_map.count("var1"), 1);
EXPECT_EQ(dst_varinfo_map.count("var5"), 1);
EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2);
EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2);
}

TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) {
// build test data and apply pass
auto context =
std::make_unique<TestPassContext>(BuildProgramInsideCinnLaunchOp());
auto& varinfo_map_shared = context->graph->GetOrInit<Name2VarInfoMap>(
paddle2cinn::kMemOptVarInfoFromMainGraph);
varinfo_map_shared = {
{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)},
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 2)},
};

ir::MemOptVarInfoMapList varinfo_maps(1);
auto& dst_varinfo_map = varinfo_maps.front();
dst_varinfo_map = {{"var1", std::make_shared<ir::MemOptVarInfo>("var1", 1)},
{"var2", std::make_shared<ir::MemOptVarInfo>("var2", 1)},
{"var3", std::make_shared<ir::MemOptVarInfo>("var3", 1)},
{"var4", std::make_shared<ir::MemOptVarInfo>("var4", 1)},
{"var5", std::make_shared<ir::MemOptVarInfo>("var5", 1)}};
auto share_pass =
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps);
share_pass->Apply(context->graph.get());

// check result
ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr);
ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr);
ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr);
}

} // namespace paddle::framework
Loading

0 comments on commit 35b03e1

Please sign in to comment.