
Commit b52ad9d

Merge pull request #9000 from reyoung/feature/extract_prepare_from_executor_run

Extract Prepare from Executor
reyoung authored Mar 12, 2018
2 parents b628744 + 43d09a1 commit b52ad9d
Showing 2 changed files with 98 additions and 70 deletions.
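This PR splits Executor::Run into two pieces: a static Prepare step that validates the block id and builds the operator list for a block once, and RunPreparedContext, which executes an already-prepared context against a scope. Run keeps its old signature and behavior by calling the two in sequence. A minimal sketch of how a caller might amortize operator creation across repeated runs (the function name, loop, and place below are illustrative, not part of this commit):

    #include "paddle/fluid/framework/executor.h"

    // Hypothetical caller (not part of this commit): prepare once, run many times.
    void RunManyTimes(const paddle::framework::ProgramDesc& program,
                      paddle::framework::Scope* scope, int iterations) {
      paddle::framework::Executor exe(paddle::platform::CPUPlace());
      // Operator instances for block 0 are created exactly once here...
      auto* ctx = paddle::framework::Executor::Prepare(program, /*block_id=*/0);
      for (int i = 0; i < iterations; ++i) {
        // ...and reused on every iteration instead of being rebuilt from op descs.
        exe.RunPreparedContext(ctx, scope);
      }
      delete ctx;  // the caller owns the context, mirroring what Run itself does
    }

One caveat: executor.h only forward-declares ExecutorPrepareContext, so code outside executor.cc cannot destroy the context through a complete type. In this commit only Executor::Run itself deletes one; an external caller like the sketch above would need the definition exposed (or a smart-pointer return type) to do the same safely.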
155 changes: 88 additions & 67 deletions paddle/fluid/framework/executor.cc
@@ -34,6 +34,15 @@ DEFINE_bool(check_nan_inf, false,
namespace paddle {
namespace framework {

+struct ExecutorPrepareContext {
+  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
+      : prog_(prog), block_id_(block_id) {}
+
+  framework::ProgramDesc prog_;
+  size_t block_id_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+};
+
Executor::Executor(const platform::Place& place) : place_(place) {}

static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
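A note on the ExecutorPrepareContext added above: it stores the ProgramDesc by value, so a prepared context keeps its own copy of the program alongside the block id and the instantiated operators, and does not dangle if the caller's ProgramDesc is destroyed.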
@@ -85,73 +94,9 @@ static void CheckTensorNANOrInf(const std::string& name,

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                   bool create_local_scope, bool create_vars) {
-  // TODO(tonyyang-svail):
-  //   - only runs on the first device (i.e. no interdevice communication)
-  //   - will change to use multiple blocks for RNN op and Cond Op
-  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
-  auto& block = pdesc.Block(block_id);
-
-  Scope* local_scope = scope;
-  if (create_vars) {
-    if (create_local_scope) {
-      local_scope = &scope->NewScope();
-      for (auto& var : block.AllVars()) {
-        if (var->Name() == framework::kEmptyVarName) {
-          continue;
-        }
-
-        if (var->Persistable()) {
-          auto* ptr = scope->Var(var->Name());
-          CreateTensor(ptr, var->GetType());
-          VLOG(3) << "Create Variable " << var->Name()
-                  << " global, which pointer is " << ptr;
-        } else {
-          auto* ptr = local_scope->Var(var->Name());
-          CreateTensor(ptr, var->GetType());
-          VLOG(3) << "Create Variable " << var->Name()
-                  << " locally, which pointer is " << ptr;
-        }
-      }
-    } else {
-      for (auto& var : block.AllVars()) {
-        auto* ptr = local_scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
-        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
-                << ptr;
-      }
-    }  // if (create_local_scope)
-  }    // if (create_vars)
-
-  for (auto& op_desc : block.AllOps()) {
-    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-
-    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
-    op->Run(*local_scope, place_);
-    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
-
-    if (FLAGS_benchmark) {
-      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
-              << memory::memory_usage(place_);
-    }
-    if (FLAGS_check_nan_inf) {
-      for (auto& vname : op->OutputVars(true)) {
-        auto* var = local_scope->FindVar(vname);
-        if (var == nullptr) continue;
-        if (var->IsType<framework::LoDTensor>()) {
-          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
-        }
-      }
-    }
-  }
-  if (create_vars && create_local_scope) {
-    scope->DeleteScope(local_scope);
-  }
-  if (FLAGS_benchmark) {
-    VLOG(2) << "-------------------------------------------------------";
-    VLOG(2) << "Memory used after deleting local scope: "
-            << memory::memory_usage(place_);
-    VLOG(2) << "-------------------------------------------------------";
-  }
+  auto* ctx = Prepare(pdesc, block_id);
+  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
+  delete ctx;
}
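Run is now a thin wrapper, so its per-call cost is unchanged: each call still builds and tears down a fresh context. Callers who want to reuse the operator list must call Prepare and RunPreparedContext themselves, as sketched above.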

// Check whether the block already has feed operators and feed_holder.
@@ -313,5 +258,81 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
  delete copy_program;
}

+ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
+                                          int block_id) {
+  auto* ctx = new ExecutorPrepareContext(program, block_id);
+  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
+  auto& block = program.Block(block_id);
+  for (auto& op_desc : block.AllOps()) {
+    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
+  }
+  return ctx;
+}
+
+void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                                  bool create_local_scope, bool create_vars) {
+  auto& block = ctx->prog_.Block(ctx->block_id_);
+
+  Scope* local_scope = scope;
+  if (create_vars) {
+    if (create_local_scope) {
+      local_scope = &scope->NewScope();
+      for (auto& var : block.AllVars()) {
+        if (var->Name() == framework::kEmptyVarName) {
+          continue;
+        }
+
+        if (var->Persistable()) {
+          auto* ptr = scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " global, which pointer is " << ptr;
+        } else {
+          auto* ptr = local_scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " locally, which pointer is " << ptr;
+        }
+      }
+    } else {
+      for (auto& var : block.AllVars()) {
+        auto* ptr = local_scope->Var(var->Name());
+        CreateTensor(ptr, var->GetType());
+        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
+                << ptr;
+      }
+    }  // if (create_local_scope)
+  }    // if (create_vars)
+
+  for (auto& op : ctx->ops_) {
+    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
+    op->Run(*local_scope, place_);
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
+
+    if (FLAGS_benchmark) {
+      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
+              << memory::memory_usage(place_);
+    }
+    if (FLAGS_check_nan_inf) {
+      for (auto& vname : op->OutputVars(true)) {
+        auto* var = local_scope->FindVar(vname);
+        if (var == nullptr) continue;
+        if (var->IsType<framework::LoDTensor>()) {
+          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+        }
+      }
+    }
+  }
+  if (create_vars && create_local_scope) {
+    scope->DeleteScope(local_scope);
+  }
+  if (FLAGS_benchmark) {
+    VLOG(2) << "-------------------------------------------------------";
+    VLOG(2) << "Memory used after deleting local scope: "
+            << memory::memory_usage(place_);
+    VLOG(2) << "-------------------------------------------------------";
+  }
+}
+
} // namespace framework
} // namespace paddle
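RunPreparedContext is the old Run body moved verbatim, with one change: the operator loop iterates over the pre-built ctx->ops_ instead of calling OpRegistry::CreateOp on every block op each time. Also worth noting: Prepare allocates the context before the PADDLE_ENFORCE_LT bounds check, so an out-of-range block_id throws after allocation and leaks the context. The create_local_scope/create_vars flags keep their old meaning, which suggests another reuse pattern (illustrative, not part of this commit): when running without a local scope, variables created in `scope` on the first call persist, so later calls can skip variable creation.

    // Hypothetical follow-up (not in this commit); assumes `exe`, `ctx`, and
    // `scope` from the earlier sketch. The first call creates the block's
    // variables directly in `scope`:
    exe.RunPreparedContext(ctx, scope, /*create_local_scope=*/false,
                           /*create_vars=*/true);
    // Later calls find them already present and skip creation entirely:
    exe.RunPreparedContext(ctx, scope, /*create_local_scope=*/false,
                           /*create_vars=*/false);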
13 changes: 10 additions & 3 deletions paddle/fluid/framework/executor.h
@@ -22,7 +22,7 @@ limitations under the License. */

namespace paddle {
namespace framework {
-
+struct ExecutorPrepareContext;
class Executor {
 public:
  // TODO(dzhwinter) : Do not rely on this function, it will be removed
@@ -38,15 +38,22 @@ class Executor {
* ProgramDesc
* Scope
*/
-  void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
-           bool create_vars = true);
+  void Run(const ProgramDesc& prog, Scope* scope, int block_id,
+           bool create_local_scope = true, bool create_vars = true);

  void Run(const ProgramDesc& program, Scope* scope,
           std::map<std::string, const LoDTensor*>& feed_targets,
           std::map<std::string, LoDTensor*>& fetch_targets,
           const std::string& feed_holder_name = "feed",
           const std::string& fetch_holder_name = "fetch");

+  static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
+                                         int block_id);
+
+  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                          bool create_local_scope = true,
+                          bool create_vars = true);
+
 private:
  const platform::Place place_;
};
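Prepare is static while RunPreparedContext is not: building the operator list needs only the program, but running it needs the executor's place_. That also explains why Prepare can be called without constructing an Executor, and, at least as far as this interface is concerned, why one prepared context could be handed to executors on different places.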
