diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 513e720fd099b..766bf0ab0c1c5 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -226,15 +226,15 @@ static bool has_fetch_operators( } void Executor::Run(const ProgramDesc& program, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars, const std::string& feed_holder_name, const std::string& fetch_holder_name) { platform::RecordBlock b(kProgramId); bool has_feed_ops = - has_feed_operators(program.Block(0), feed_targets, feed_holder_name); + has_feed_operators(program.Block(0), *feed_targets, feed_holder_name); bool has_fetch_ops = - has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name); + has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name); ProgramDesc* copy_program = const_cast(&program); if (!has_feed_ops || !has_fetch_ops) { @@ -250,7 +250,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, feed_holder->SetPersistable(true); int i = 0; - for (auto& feed_target : feed_targets) { + for (auto& feed_target : (*feed_targets)) { std::string var_name = feed_target.first; VLOG(3) << "feed target's name: " << var_name; @@ -273,7 +273,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, fetch_holder->SetPersistable(true); int i = 0; - for (auto& fetch_target : fetch_targets) { + for (auto& fetch_target : (*fetch_targets)) { std::string var_name = fetch_target.first; VLOG(3) << "fetch target's name: " << var_name; @@ -361,16 +361,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, bool create_vars, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars, const std::string& feed_holder_name, const std::string& fetch_holder_name) { auto& global_block = ctx->prog_.Block(ctx->block_id_); PADDLE_ENFORCE( - has_feed_operators(global_block, feed_targets, feed_holder_name), + has_feed_operators(global_block, *feed_targets, feed_holder_name), "Program in ExecutorPrepareContext should has feed_ops."); PADDLE_ENFORCE( - has_fetch_operators(global_block, fetch_targets, fetch_holder_name), + has_fetch_operators(global_block, *fetch_targets, fetch_holder_name), "Program in the prepared context should has fetch_ops."); // map the data of feed_targets to feed_holder @@ -378,8 +378,8 @@ void Executor::RunPreparedContext( if (op->Type() == kFeedOpType) { std::string feed_target_name = op->Output("Out")[0]; int idx = boost::get(op->GetAttr("col")); - SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, - idx); + SetFeedVariable(scope, *(*feed_targets)[feed_target_name], + feed_holder_name, idx); } } @@ -390,7 +390,7 @@ void Executor::RunPreparedContext( if (op->Type() == kFetchOpType) { std::string fetch_target_name = op->Input("X")[0]; int idx = boost::get(op->GetAttr("col")); - *fetch_targets[fetch_target_name] = + *(*fetch_targets)[fetch_target_name] = GetFetchVariable(*scope, fetch_holder_name, idx); } } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 43defdacf2a1c..4a3d637e2d79f 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -55,8 +55,8 @@ class Executor { bool create_local_scope = true, bool create_vars = true); void Run(const ProgramDesc& program, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars = true, const std::string& feed_holder_name = "feed", const std::string& fetch_holder_name = "fetch"); @@ -74,8 +74,8 @@ class Executor { bool create_vars = true); void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, - std::map& feed_targets, - std::map& fetch_targets, + std::map* feed_targets, + std::map* fetch_targets, bool create_vars = true, const std::string& feed_holder_name = "feed", const std::string& fetch_holder_name = "fetch"); diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 117472599f7c4..af2a7a5620487 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -178,10 +178,10 @@ void TestInference(const std::string& dirname, std::unique_ptr ctx; if (PrepareContext) { ctx = executor.Prepare(*inference_program, 0); - executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, - CreateVars); + executor.RunPreparedContext(ctx.get(), scope, &feed_targets, + &fetch_targets, CreateVars); } else { - executor.Run(*inference_program, scope, feed_targets, fetch_targets, + executor.Run(*inference_program, scope, &feed_targets, &fetch_targets, CreateVars); } @@ -197,10 +197,10 @@ void TestInference(const std::string& dirname, if (PrepareContext) { // Note: if you change the inference_program, you need to call // executor.Prepare() again to get a new ExecutorPrepareContext. - executor.RunPreparedContext(ctx.get(), scope, feed_targets, - fetch_targets, CreateVars); + executor.RunPreparedContext(ctx.get(), scope, &feed_targets, + &fetch_targets, CreateVars); } else { - executor.Run(*inference_program, scope, feed_targets, fetch_targets, + executor.Run(*inference_program, scope, &feed_targets, &fetch_targets, CreateVars); } }