From 26ec8b15abae3ae3abd6ccce20c400056f21bd84 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 14 Sep 2021 17:00:37 +0800 Subject: [PATCH] Refactor StreamAnalyzer and EventManager from InterpreterCore (#35711) --- .../framework/new_executor/CMakeLists.txt | 4 +- .../framework/new_executor/event_manager.cc | 58 +++++++ .../framework/new_executor/event_manager.h | 35 ++++ .../framework/new_executor/interpretercore.cc | 157 +----------------- .../framework/new_executor/interpretercore.h | 20 +-- .../new_executor/interpretercore_util.h | 3 - .../new_executor/new_executor_defs.h | 5 + .../framework/new_executor/stream_analyzer.cc | 125 ++++++++++++++ .../framework/new_executor/stream_analyzer.h | 52 ++++++ 9 files changed, 288 insertions(+), 171 deletions(-) create mode 100644 paddle/fluid/framework/new_executor/event_manager.cc create mode 100644 paddle/fluid/framework/new_executor/event_manager.h create mode 100644 paddle/fluid/framework/new_executor/stream_analyzer.cc create mode 100644 paddle/fluid/framework/new_executor/stream_analyzer.h diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index e4a67f191b023..09744bf60032e 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor) cc_library(workqueue SRCS workqueue.cc DEPS enforce) cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS}) cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS}) -cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector) +cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog) +cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context) +cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) diff --git a/paddle/fluid/framework/new_executor/event_manager.cc b/paddle/fluid/framework/new_executor/event_manager.cc new file mode 100644 index 0000000000000..a3eb1abaa6127 --- /dev/null +++ b/paddle/fluid/framework/new_executor/event_manager.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/event_manager.h" + +namespace paddle { +namespace framework { + +void EventManager::WaitEvent(const Instruction& instruction, + const platform::Place& place) { + // If InterpreterCore in on CPUPlace, do nothing. + if (platform::is_cpu_place(place)) return; + + VLOG(3) << "Deal StreamWaitEventOrSync for " + << instruction.kernel_func_.operator_base_->Type(); + auto* dev_ctx = instruction.dev_ctx_; + + WaitOrSync(instruction.intput_events_, dev_ctx); +} + +void EventManager::RecordEvent(const Instruction& instruction, + const OpFuncNode& op_func_node, + const platform::Place& place) { + // If InterpreterCore in on CPUPlace, do nothing. + if (platform::is_cpu_place(place)) return; + + for (auto& event : instruction.output_events_) { + VLOG(3) << "Record event in out_var_id: " << event.var_id_; + event.event_->Record(instruction.dev_ctx_); + } +} + +void EventManager::WaitOrSync(const std::vector& events, + const platform::DeviceContext* dev_ctx) { + for (auto& event_iter : events) { + if (event_iter.is_sync_) { + VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_; + event_iter.event_->Wait(platform::kCPU, dev_ctx); + } else { + VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_; + event_iter.event_->Wait(platform::kCUDA, dev_ctx); + } + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/event_manager.h b/paddle/fluid/framework/new_executor/event_manager.h new file mode 100644 index 0000000000000..a2f7b52732ee2 --- /dev/null +++ b/paddle/fluid/framework/new_executor/event_manager.h @@ -0,0 +1,35 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/new_executor/new_executor_defs.h" + +namespace paddle { +namespace framework { + +class EventManager { + public: + void RecordEvent(const Instruction& instruction, + const OpFuncNode& op_func_node, + const platform::Place& place); + + void WaitEvent(const Instruction& instruction, const platform::Place& place); + + private: + void WaitOrSync(const std::vector& events, + const platform::DeviceContext* dev_ctx); +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 864a0a45366f3..d6f305acfb875 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -20,101 +20,6 @@ namespace paddle { namespace framework { -namespace { - -/* - * Parse the var_ids that need to be associated with an event. - * The caller should guarantee front_op and back_op satisfy the - * following conditions: - * 1. kQueueAsync -> kQueueAsync - * 2. kQueueAsync -> kQueueSync - * - * For example: matmul(gpu) -> out_var -> memcpy_d2h - * out_var should be associated with an event. - */ -std::vector ParseEventVarIds(const Instruction& cur_instr, - const Instruction& next_instr) { - std::unordered_set unique_var_ids; - for (auto& item : cur_instr.output_index_) { - unique_var_ids.insert(item.second.begin(), item.second.end()); - } - - std::vector new_event_var_ids; - for (auto& item : next_instr.input_index_) { - for (auto var_id : item.second) { - if (unique_var_ids.count(var_id) > 0) { - new_event_var_ids.push_back(var_id); - } - } - } - return new_event_var_ids; -} - -void AssociateInputWithEvents( - const platform::Place& place, const std::vector& new_event_var_id, - Instruction* next_instr, - std::map>* var_id2event, - bool is_sync) { - for (auto var_id : new_event_var_id) { - if (var_id2event->count(var_id) == 0) { - auto device_event = std::make_shared( - place, platform::GenerateDeviceEventFlag()); - var_id2event->emplace(var_id, std::move(device_event)); - } - // Add events for next_instr.inputs - next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id), - is_sync); - } -} - -void ParseDirectAndEventRunOps( - const platform::Place& place, const std::vector& op_func_nodes, - const std::vector& downstream_ops, size_t op_index, - std::map>* var_id2event, - std::vector* instructions) { - auto& op_func_type = op_func_nodes[op_index].type_; - auto& cur_instr = instructions->at(op_index); - auto& next_instruction = cur_instr.next_instruction_; - - if (op_func_type == OpFuncType::kQueueSync) { - // all downstream ops of kQueueSync can directly run, such as CPU -> Any - next_instruction.direct_run_ = downstream_ops; - } else { // kQueueAsync - std::vector event_var_ids; - for (auto next_op_id : downstream_ops) { - auto& next_instr = instructions->at(next_op_id); - // case 1: GPU -> GPU(same stream) - if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) { - next_instruction.direct_run_.emplace_back(next_op_id); - continue; - } - // Always insert events between different stream - auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); - event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(), - new_event_var_ids.end()); - - bool is_sync = - (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); - AssociateInputWithEvents(place, new_event_var_ids, &next_instr, - var_id2event, is_sync); - - if (is_sync) { // GPU -> CPU - next_instruction.synchronize_run_.emplace_back(next_op_id); - } else { // GPU -> GPU(different stream) - next_instruction.event_wait_run_.emplace_back(next_op_id); - } - } - // Create events for these cross-stream vars - VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() - << " event_var_ids.size: " << event_var_ids.size(); - for (auto var_id : event_var_ids) { - cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id), - false /*not used*/); - } - } -} -} // namespace - InterpreterCore::InterpreterCore(const platform::Place& place, const ProgramDesc& main_prog, VariableScope* global_scope, @@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, : place_(place), main_program_(main_prog), global_scope_(global_scope), - d2h_ctx_pool_({place}), - h2d_ctx_pool_({place}) { + stream_analyzer_(place) { is_build_ = false; feed_names_ = feed_names; @@ -199,7 +103,7 @@ void InterpreterCore::Convert() { Instruction temp_inst; auto* op_base = op_list_[i]; temp_inst.dev_ctx_ = - ParseDeviceContextForInstruction(vec_func_list_[i], *op_base); + stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base); temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_; temp_inst.kernel_func_.operator_base_ = op_base; temp_inst.input_index_ = vec_func_list_[i].input_index; @@ -270,8 +174,8 @@ void InterpreterCore::Convert() { } } - ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i, - &var_id2event_, &vec_instruction_); + stream_analyzer_.Schedule(vec_func_list_, filter_next, i, + &vec_instruction_); for (auto inst_id : filter_next) { dependecy_count_[inst_id]++; @@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList( working_queue.pop(); auto& instr_node = vec_instr[instr_id]; // step1 : stream_wait (non-block host) or sync (block host) - StreamWaitEventOrSync(instr_node); + event_manager_.WaitEvent(instr_node, place_); // step2: run instruction RunInstruction(instr_node); ++run_op_number; @@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList( } // step3: insert event for out_vars if needed - RecordEventInstruction(instr_node, vec_func_list_[instr_id]); + event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_); // step4: update working_queue auto& next_instr = instr_node.next_instruction_.all_next_ops_; @@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun( return dry_run_profiler_.GetCostInfo(); } -platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction( - const OpFuncNode& op_func_node, const OperatorBase& op_base) { - auto& op_type = op_base.Type(); - auto* dev_ctx = op_func_node.dev_ctx_; - if (op_type == interpretercore::kMemcpyH2D) { - VLOG(3) << "Get dev_ctx from d2h_context_pool_"; - dev_ctx = d2h_ctx_pool_.Get(place_); - } else if (op_type == interpretercore::kMemcpyD2H) { - VLOG(3) << "Get dev_ctx from h2d_context_pool_"; - dev_ctx = h2d_ctx_pool_.Get(place_); - } - - return dev_ctx; -} - -void InterpreterCore::RecordEventInstruction(const Instruction& instruction, - const OpFuncNode& op_func_node) { - // If InterpreterCore in on CPUPlace, do nothing. - if (platform::is_cpu_place(place_)) return; - - for (auto& event : instruction.output_events_) { - VLOG(3) << "Record event in out_var_id: " << event.var_id_; - event.event_->Record(instruction.dev_ctx_); - } -} - -void InterpreterCore::WaitOrSync(const std::vector& events, - const platform::DeviceContext* dev_ctx) { - for (auto& event_iter : events) { - if (event_iter.is_sync_) { - VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_; - event_iter.event_->Wait(platform::kCPU, dev_ctx); - } else { - VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_; - event_iter.event_->Wait(platform::kCUDA, dev_ctx); - } - } -} - -void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) { - // If InterpreterCore in on CPUPlace, do nothing. - if (platform::is_cpu_place(place_)) return; - - VLOG(3) << "Deal StreamWaitEventOrSync for " - << instruction.kernel_func_.operator_base_->Type(); - auto* dev_ctx = instruction.dev_ctx_; - - WaitOrSync(instruction.intput_events_, dev_ctx); -} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index fef2c47bac2e8..76d005aee7e99 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -19,10 +19,12 @@ #include #include +#include "paddle/fluid/framework/new_executor/event_manager.h" #include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h" #include "paddle/fluid/framework/new_executor/interpretercore_util.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/profiler.h" +#include "paddle/fluid/framework/new_executor/stream_analyzer.h" #include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor.h" @@ -64,17 +66,6 @@ class InterpreterCore { const VariableScope& var_scope, const platform::Place& place, std::vector& working_var_ref); // NOLINT - platform::DeviceContext* ParseDeviceContextForInstruction( - const OpFuncNode& op_func_node, const OperatorBase& op_base); - - void RecordEventInstruction(const Instruction& instruction, - const OpFuncNode& op_func_node); - - void WaitOrSync(const std::vector& events, - const platform::DeviceContext* dev_ctx); - - void StreamWaitEventOrSync(const Instruction& instruction); - void AddFetch(const std::vector& fetch_names); bool is_build_; @@ -83,9 +74,6 @@ class InterpreterCore { ProgramDesc main_program_; VariableScope* global_scope_; - platform::DeviceContextPool d2h_ctx_pool_; - platform::DeviceContextPool h2d_ctx_pool_; - std::vector vec_instruction_; InstructionInfo instruction_info_; std::vector dependecy_count_; @@ -99,8 +87,8 @@ class InterpreterCore { std::vector feed_names_; InterpreterProfiler dry_run_profiler_; - - std::map> var_id2event_; + StreamAnalyzer stream_analyzer_; + EventManager event_manager_; InterpreterCoreGarbageCollector gc_; std::vector gc_event_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index db7f35fb7ce86..db21a0ebca4da 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext { namespace interpretercore { -static constexpr char kMemcpyH2D[] = "memcpy_h2d"; -static constexpr char kMemcpyD2H[] = "memcpy_d2h"; - std::string get_memcpy_type(const platform::Place& src_place, const platform::Place& dst_place); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 0b0148e6baddb..c08104dd95882 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -25,6 +25,11 @@ namespace paddle { namespace framework { +namespace interpretercore { +static constexpr char kMemcpyH2D[] = "memcpy_h2d"; +static constexpr char kMemcpyD2H[] = "memcpy_d2h"; +} // namespace interpretercore + using OpKernelComputeFunc = std::function; using OpKernelMap = std::unordered_map; diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc new file mode 100644 index 0000000000000..13bbda0f31f42 --- /dev/null +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/stream_analyzer.h" +#include + +namespace paddle { +namespace framework { + +/* + * Parse the var_ids that need to be associated with an event. + * The caller should guarantee front_op and back_op satisfy the + * following conditions: + * 1. kQueueAsync -> kQueueAsync + * 2. kQueueAsync -> kQueueSync + * + * For example: matmul(gpu) -> out_var -> memcpy_d2h + * out_var should be associated with an event. + */ +std::vector StreamAnalyzer::ParseEventVarIds( + const Instruction& cur_instr, const Instruction& next_instr) { + std::unordered_set unique_var_ids; + for (auto& item : cur_instr.output_index_) { + unique_var_ids.insert(item.second.begin(), item.second.end()); + } + + std::vector new_event_var_ids; + for (auto& item : next_instr.input_index_) { + for (auto var_id : item.second) { + if (unique_var_ids.count(var_id) > 0) { + new_event_var_ids.push_back(var_id); + } + } + } + return new_event_var_ids; +} + +void StreamAnalyzer::AssociateInputWithEvents( + const std::vector& new_event_var_id, Instruction* next_instr, + bool is_sync) { + for (auto var_id : new_event_var_id) { + if (var_id2event_.count(var_id) == 0) { + auto device_event = std::make_shared( + place_, platform::GenerateDeviceEventFlag()); + var_id2event_.emplace(var_id, std::move(device_event)); + } + // Add events for next_instr.inputs + next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id), + is_sync); + } +} + +void StreamAnalyzer::Schedule(const std::vector& op_func_nodes, + const std::vector& downstream_ops, + size_t op_index, + std::vector* instructions) { + auto& op_func_type = op_func_nodes[op_index].type_; + auto& cur_instr = instructions->at(op_index); + auto& next_instruction = cur_instr.next_instruction_; + + if (op_func_type == OpFuncType::kQueueSync) { + // all downstream ops of kQueueSync can directly run, such as CPU -> Any + next_instruction.direct_run_ = downstream_ops; + } else { // kQueueAsync + std::vector event_var_ids; + for (auto next_op_id : downstream_ops) { + auto& next_instr = instructions->at(next_op_id); + // case 1: GPU -> GPU(same stream) + if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) { + next_instruction.direct_run_.emplace_back(next_op_id); + continue; + } + // Always insert events between different stream + auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); + event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(), + new_event_var_ids.end()); + + bool is_sync = + (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); + AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync); + + if (is_sync) { // GPU -> CPU + next_instruction.synchronize_run_.emplace_back(next_op_id); + } else { // GPU -> GPU(different stream) + next_instruction.event_wait_run_.emplace_back(next_op_id); + } + } + // Create events for these cross-stream vars + VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() + << " event_var_ids.size: " << event_var_ids.size(); + for (auto var_id : event_var_ids) { + cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), + false /*not used*/); + } + } +} + +platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( + const OpFuncNode& op_func_node, const OperatorBase& op_base) { + auto& op_type = op_base.Type(); + auto* dev_ctx = op_func_node.dev_ctx_; + if (op_type == interpretercore::kMemcpyH2D) { + VLOG(3) << "Get dev_ctx from d2h_context_pool_"; + dev_ctx = d2h_ctx_pool_.Get(place_); + } else if (op_type == interpretercore::kMemcpyD2H) { + VLOG(3) << "Get dev_ctx from h2d_context_pool_"; + dev_ctx = h2d_ctx_pool_.Get(place_); + } + + return dev_ctx; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.h b/paddle/fluid/framework/new_executor/stream_analyzer.h new file mode 100644 index 0000000000000..ee94c21fc529a --- /dev/null +++ b/paddle/fluid/framework/new_executor/stream_analyzer.h @@ -0,0 +1,52 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/framework/new_executor/new_executor_defs.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/device_event.h" + +namespace paddle { +namespace framework { + +class StreamAnalyzer { + public: + explicit StreamAnalyzer(const platform::Place& place) + : place_(place), d2h_ctx_pool_({place}), h2d_ctx_pool_({place}) {} + + ~StreamAnalyzer() {} + + void Schedule(const std::vector& op_func_nodes, + const std::vector& downstream_ops, size_t op_index, + std::vector* instructions); + + platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node, + const OperatorBase& op_base); + + private: + std::vector ParseEventVarIds(const Instruction& cur_instr, + const Instruction& next_instr); + + void AssociateInputWithEvents(const std::vector& new_event_var_id, + Instruction* next_instr, bool is_sync); + platform::Place place_; + platform::DeviceContextPool d2h_ctx_pool_; + platform::DeviceContextPool h2d_ctx_pool_; + std::map> var_id2event_; +}; + +} // namespace framework +} // namespace paddle