[Fleet Executor] Modified python cache strategy to support multi carriers #38839

Merged: 3 commits, Jan 10, 2022
54 changes: 47 additions & 7 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -19,7 +19,9 @@
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"

namespace paddle {
namespace distributed {
@@ -43,18 +45,24 @@ void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
root_scope_ = root_scope;
root_scope_ = scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);

PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
"root_scope can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope();
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program);
}

// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
@@ -64,10 +72,33 @@ void Carrier::Init(
is_init_ = true;
}

void Carrier::Release() {}
void Carrier::Release() {
if (root_scope_) {
root_scope_->DropKids();
}
}

Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

void Carrier::CopyParameters(int microbatch_id,
const framework::ProgramDesc& program) {
auto& global_block = program.Block(0);

for (auto& var : global_block.AllVars()) {
if (var->Persistable() && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
}
}

bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ(
@@ -116,6 +147,15 @@ void Carrier::Start() {
// TODO(wangxi): async step
Wait();
dev_ctx_->Wait();
for (auto* micro_scope : microbatch_scopes_) {
// By default, we should delete all kid scopes after running the executor because
// some operators, such as while_op, may create local scopes when running.
// But when while_op also creates a local executor to run its sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scope->DropKids();
}
}

bool Carrier::IsInit() const { return is_init_; }
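To make the new scope ownership easier to follow, here is a small self-contained Python sketch of what Carrier::Init and CopyParameters now build: one minibatch scope under the root scope, one child scope per micro batch, persistable variables created once in the root scope, and non-persistable variables created per micro batch. This is a toy model; the Scope class and init_carrier_scopes helper below are made up for illustration and are not the Paddle API.

# Illustrative toy model of the scope tree built by Carrier::Init /
# CopyParameters above. Not the Paddle API.

class Scope:
    def __init__(self, name):
        self.name = name
        self.kids = []
        self.vars = {}

    def new_scope(self, name):
        kid = Scope(name)
        self.kids.append(kid)
        return kid

    def var(self, var_name):
        return self.vars.setdefault(var_name, object())

    def drop_kids(self):
        self.kids.clear()


def init_carrier_scopes(root_scope, program_vars, num_micro_batches):
    # program_vars: list of (name, persistable) pairs taken from block 0.
    minibatch_scope = root_scope.new_scope("minibatch")
    microbatch_scopes = []
    for i in range(num_micro_batches):
        micro = minibatch_scope.new_scope("microbatch_%d" % i)
        for name, persistable in program_vars:
            if persistable and i == 0:
                root_scope.var(name)   # persistables live in the root scope, created once
            elif not persistable:
                micro.var(name)        # one copy per micro batch
        microbatch_scopes.append(micro)
    return minibatch_scope, microbatch_scopes


root = Scope("root")
init_carrier_scopes(root, [("w", True), ("tmp", False)], num_micro_batches=2)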
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/fleet_executor/carrier.h
@@ -34,6 +34,7 @@
namespace paddle {
namespace framework {
class Scope;
class ProgramDesc;
}

namespace distributed {
@@ -55,9 +56,10 @@ class Carrier final {
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place);

void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);

void Release();
void Wait();
55 changes: 8 additions & 47 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -22,8 +22,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"

namespace paddle {
namespace distributed {
@@ -38,7 +36,6 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
}

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
for (const auto& carrier_id : carrier_ids_) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
}
@@ -47,7 +44,7 @@ FleetExecutor::~FleetExecutor() {
void FleetExecutor::Init(
const std::string& carrier_id, const framework::ProgramDesc& program_desc,
framework::Scope* scope, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
int64_t num_micro_batches, const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0,
platform::errors::InvalidArgument(
@@ -72,31 +69,23 @@ void FleetExecutor::Init(
for (auto& unique_op : ops) {
unique_op.release();
}
root_scope_ = scope;
place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
"root_scope_ can not be nullptr"));
minibatch_scope_ = &root_scope_->NewScope();
int64_t num_micro_batches = exe_desc_.num_micro_batches();
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program_desc);
}
VLOG(5) << runtime_graph_->DebugString();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier_ids_.insert(carrier_id);
// Set current running carrier
GlobalVal<std::string>::Set(new std::string(carrier_id));
InitCarrier(carrier);
InitCarrier(carrier, scope, place, num_micro_batches, program_desc);
GlobalVal<MessageBus>::Get()->Barrier();
}

void FleetExecutor::InitCarrier(Carrier* carrier) {
void FleetExecutor::InitCarrier(Carrier* carrier, framework::Scope* scope,
const platform::Place& place,
int64_t num_micro_batches,
const framework::ProgramDesc& program_desc) {
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
runtime_graph_->interceptor_id_to_node(), program_desc, scope,
num_micro_batches, place);
}

void FleetExecutor::InitMessageBus() {
@@ -140,34 +129,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
GlobalVal<MessageBus>::Get()->Barrier();
}
carrier->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// But when while_op also create a local executor to run it's sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scop->DropKids();
}
}

void FleetExecutor::CopyParameters(int microbatch_id,
const framework::ProgramDesc& program) {
auto& global_block = program.Block(0);

for (auto& var : global_block.AllVars()) {
if (var->Persistable() && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
}
}

} // namespace distributed
11 changes: 4 additions & 7 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -39,22 +39,19 @@ class FleetExecutor final {
~FleetExecutor();
void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
const platform::Place& place, int64_t num_micro_batches,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run(const std::string& carrier_id);

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus();
void InitCarrier(Carrier* carrier);
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
void InitCarrier(Carrier* carrier, framework::Scope* scope,
const platform::Place& place, int64_t num_micro_batches,
const framework::ProgramDesc& program_desc);
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
framework::Scope* root_scope_;
framework::Scope* minibatch_scope_;
platform::Place place_;
std::vector<framework::Scope*> microbatch_scopes_;
std::unordered_set<std::string> carrier_ids_;
};

fleet_executor_desc.proto
@@ -23,5 +23,4 @@ message RankInfo {
message FleetExecutorDesc {
optional int64 cur_rank = 1 [ default = 0 ]; // Rank id of current processor
repeated RankInfo cluster_info = 2;
optional int64 num_micro_batches = 3 [ default = 1 ];
}
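Since num_micro_batches is no longer a field of FleetExecutorDesc, the descriptor only carries rank and cluster information, and the micro-batch count travels through init() per carrier instead. Below is a hedged sketch of assembling the descriptor under the new scheme; it mirrors the _prepare_fleet_executor helper added to executor.py further down, and the build_fleet_executor_desc name, rank, and endpoint values are placeholders for illustration.

# Sketch only: builds the descriptor without a num_micro_batches field;
# that value is now passed per carrier through init(). Rank and endpoint
# values below are placeholders.
from paddle.distributed.fleet.proto import fleet_executor_desc_pb2

def build_fleet_executor_desc(cur_rank, trainer_endpoints):
    desc = fleet_executor_desc_pb2.FleetExecutorDesc()
    desc.cur_rank = cur_rank
    for rank, endpoint in enumerate(trainer_endpoints):
        rank_info = fleet_executor_desc_pb2.RankInfo()
        rank_info.rank = rank
        rank_info.ip_port = endpoint
        desc.cluster_info.append(rank_info)
    return desc.SerializeToString()  # fed to core.FleetExecutor(...)

serialized_desc = build_fleet_executor_desc(0, ["127.0.0.1:6170"])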
70 changes: 39 additions & 31 deletions python/paddle/fluid/executor.py
@@ -400,6 +400,23 @@ def _is_enable_standalone_executor():
return flag


def _prepare_fleet_executor():
from ..distributed.fleet.proto import fleet_executor_desc_pb2
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
trainer_endpoints = trainer_endpoints_str.split(',')
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
fleet_exe_desc.cur_rank = cur_rank
nrank = len(trainer_endpoints)
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
return fleet_exe


def _get_strong_program_cache_key(program, feed, fetch_list):
# NOTE(xiongkun) id(program) may be duplicate. So add an additional var_name as cache key.
def _get_varname_from_block(block):
@@ -692,6 +709,8 @@ def __init__(self, place=None):
self._enable_interpreter_core = _is_enable_standalone_executor()
self._executor_cache = _ExecutorCache(self.place)

self._fleet_executor = None

def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)

@@ -1281,6 +1300,9 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name,

if isinstance(program, Program) and program._pipeline_opt:
if "fleet_opt" in program._pipeline_opt:
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
self._fleet_executor = _prepare_fleet_executor()
return self._run_using_fleet_executor(
program=program, feed=feed, fetch_list=fetch_list)
if "startup_program" in program._pipeline_opt:
@@ -1960,27 +1982,16 @@ def _get_real_program_fetch_list():

return ctx

def _prepare_fleet_executor(self,
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
from ..distributed.fleet.proto import fleet_executor_desc_pb2
assert program, "Program for fleet executor should not be None"
assert fleet_opt, "Configurations for fleet executor should not be None"
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
trainer_endpoints = trainer_endpoints_str.split(',')
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
def _prepare_fleet_executor_carrier(self,
carrier_id="",
program=None,
scope=None,
fleet_opt=None):
num_micro_batches = fleet_opt[
"num_micro_batches"] if "num_micro_batches" in fleet_opt else 1
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
fleet_exe_desc.cur_rank = cur_rank
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
nrank = len(trainer_endpoints)
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]

assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, \
"Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. " \
@@ -2019,29 +2030,26 @@ def _prepare_fleet_executor(self,
# NOTE: have to hold these vars, otherwise will be destructed
fleet_opt['tasks'] = tasks
fleet_opt['task_id_to_rank'] = task_id_to_rank
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
fleet_exe.init(carrier_id, program.desc, scope, place, tasks,
task_id_to_rank)
return fleet_exe
self._fleet_executor.init(carrier_id, program.desc, scope, place,
num_micro_batches, tasks, task_id_to_rank)

def _run_using_fleet_executor(self,
program=None,
feed=None,
feed_var_name="feed",
fetch_var_name="fetch",
fetch_list=None):
# TODO(liyurui): Change cache strategy for multi carriers
cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
cached_ctx = self._get_ctx_cache(cache_key)
cached_scope = self._get_scope_cache(cache_key)
cached_program = self._get_program_cache(cache_key)
real_feed = [] if feed is None else feed
cached_scope = self._get_scope_cache(cache_key)
if cached_scope is None:
cached_scope = global_scope()
self._add_scope_cache(cache_key, cached_scope)
if cached_program is None:
assert program._pipeline_opt, "program should have _pipeline_opt to start carrier"
real_feed = [] if feed is None else feed
real_program = program
if "section_program" in program._pipeline_opt:
real_program = program._pipeline_opt["section_program"]
@@ -2060,7 +2068,6 @@ def _run_using_fleet_executor(self,
'op_role',
core.op_proto_and_checker_maker.OpRole.Optimize)
self._add_program_cache(cache_key, cached_program)
if cached_ctx is None:
fleet_opt = program._pipeline_opt["fleet_opt"]
if 'tasks' in fleet_opt:
# Insert feed/fetch op for cloned program in each task node,
@@ -2097,12 +2104,12 @@
core.op_proto_and_checker_maker.OpRole.Optimize)
fetch_task.set_program(fetch_program)

cached_ctx = self._prepare_fleet_executor(
self._prepare_fleet_executor_carrier(
cache_key,
program=cached_program,
scope=cached_scope,
fleet_opt=fleet_opt)
self._add_ctx_cache(cache_key, cached_ctx)

if feed:
# NOTE: don't have to traverse programs in task nodes,
# since they all sub program of cached program and
@@ -2120,7 +2127,8 @@
lr_sheduler._var_name)
tensor.set(data, self.place)

cached_ctx.run(cache_key)
self._fleet_executor.run(cache_key)

if fetch_list:
arr = cached_scope.find_var(fetch_var_name).get_fetch_list()
tensors = arr._move_to_list()
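Taken together, the Python-side change keeps one FleetExecutor per Executor instance and turns every strong program cache key into a carrier id, so repeated runs of the same program reuse a carrier while a different program gets its own. The sketch below is a minimal toy model of that cache flow; StubFleetExecutor, StubExecutor, and run_with_fleet_executor are stand-ins invented for illustration, not the real core.FleetExecutor or Executor internals.

# Toy model of the new cache strategy: one fleet executor per Executor,
# one carrier per program cache key. Stub classes stand in for
# core.FleetExecutor and the Executor's caches; illustrative only.

class StubFleetExecutor:
    def __init__(self):
        self.carriers = set()

    def init_carrier(self, carrier_id):
        # In the real code this is self._fleet_executor.init(carrier_id, ...).
        self.carriers.add(carrier_id)

    def run(self, carrier_id):
        assert carrier_id in self.carriers, "carrier must be initialized first"


class StubExecutor:
    def __init__(self):
        self._fleet_executor = None   # created lazily, as in _run_impl
        self._program_caches = {}

    def run_with_fleet_executor(self, cache_key):
        if self._fleet_executor is None:
            self._fleet_executor = StubFleetExecutor()   # _prepare_fleet_executor()
        if cache_key not in self._program_caches:
            # First time this program/feed/fetch combination is seen:
            # build and cache a carrier keyed by the cache key.
            self._program_caches[cache_key] = True
            self._fleet_executor.init_carrier(cache_key)
        # Every run reuses the carrier whose id is the cache key.
        self._fleet_executor.run(cache_key)


exe = StubExecutor()
exe.run_with_fleet_executor("program_a")
exe.run_with_fleet_executor("program_a")   # reuses carrier "program_a"
exe.run_with_fleet_executor("program_b")   # creates a second carrier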