From 198d11bec8115efb37363010684476f27a818c51 Mon Sep 17 00:00:00 2001 From: liutiexing <74819124+liutiexing@users.noreply.github.com> Date: Thu, 23 Dec 2021 16:18:32 +0800 Subject: [PATCH] Upgrade work queue (#38335) * add align for WorkQueue * add spinlock * merge develop * merge * Add EventsWaiter * Revert "Add EventsWaiter" This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2. * update EventsWater * fix * split workqueue files * add more tests * fix * bugfix * bugfix * update Co-authored-by: liutiexing --- .../framework/new_executor/CMakeLists.txt | 4 +- .../framework/new_executor/interpretercore.cc | 3 +- .../framework/new_executor/interpretercore.h | 2 - .../interpretercore_garbage_collector.h | 2 +- .../new_executor/interpretercore_util.h | 10 +- .../new_executor/workqueue/CMakeLists.txt | 2 + .../{ => workqueue}/event_count.h | 4 + .../new_executor/workqueue/events_waiter.cc | 147 ++++++++++++++++++ .../new_executor/workqueue/events_waiter.h | 111 +++++++++++++ .../{ => workqueue}/nonblocking_threadpool.h | 37 +---- .../new_executor/{ => workqueue}/run_queue.h | 7 +- .../{ => workqueue}/thread_environment.h | 0 .../new_executor/{ => workqueue}/workqueue.cc | 46 ++++-- .../new_executor/{ => workqueue}/workqueue.h | 13 +- .../{ => workqueue}/workqueue_test.cc | 38 ++++- .../{ => workqueue}/workqueue_utils.cc | 59 +------ .../{ => workqueue}/workqueue_utils.h | 60 +++---- 17 files changed, 380 insertions(+), 165 deletions(-) create mode 100644 paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt rename paddle/fluid/framework/new_executor/{ => workqueue}/event_count.h (98%) create mode 100644 paddle/fluid/framework/new_executor/workqueue/events_waiter.cc create mode 100644 paddle/fluid/framework/new_executor/workqueue/events_waiter.h rename paddle/fluid/framework/new_executor/{ => workqueue}/nonblocking_threadpool.h (94%) rename paddle/fluid/framework/new_executor/{ => workqueue}/run_queue.h (97%) rename paddle/fluid/framework/new_executor/{ => workqueue}/thread_environment.h (100%) rename paddle/fluid/framework/new_executor/{ => workqueue}/workqueue.cc (77%) rename paddle/fluid/framework/new_executor/{ => workqueue}/workqueue.h (87%) rename paddle/fluid/framework/new_executor/{ => workqueue}/workqueue_test.cc (73%) rename paddle/fluid/framework/new_executor/{ => workqueue}/workqueue_utils.cc (50%) rename paddle/fluid/framework/new_executor/{ => workqueue}/workqueue_utils.h (59%) diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index e21588da7fdd8..3a1ce59fba995 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -2,8 +2,9 @@ set(INTERPRETERCORE_DEPS op_registry device_context scope framework_proto data_f lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor nan_inf_utils) +add_subdirectory(workqueue) + cc_library(data_transfer SRCS data_transfer.cc DEPS enforce scope glog) -cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc DEPS enforce) cc_library(new_executor_defs SRCS new_executor_defs.cc DEPS enforce glog scope) cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS} executor_gc_helper) cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS} workqueue new_executor_defs data_transfer) @@ -11,7 +12,6 @@ cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog ne cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context new_executor_defs) cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) -cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) # skip win32 since wget is not installed by default on windows machine. # skip COVERAGE_CI since the test runs slowly because of instrumentation. diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index dcbdd12f88fb7..5a4caf6af441c 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -48,8 +48,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, new interpreter::AsyncWorkQueue(kHostNumThreads, &main_thread_blocker_)); gc_.reset(new InterpreterCoreGarbageCollector()); - exception_notifier_ = main_thread_blocker_.RegisterEvent( - kExceptionCaught, [this]() { return exception_holder_.IsCaught(); }); + exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); create_local_scope_ = FLAGS_new_executor_use_local_scope; if (FLAGS_new_executor_use_local_scope) { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 656262d6381f6..93ac7c0294349 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -26,8 +26,6 @@ #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/profiler.h" #include "paddle/fluid/framework/new_executor/stream_analyzer.h" -#include "paddle/fluid/framework/new_executor/workqueue.h" -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable.h" diff --git a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h index 166139a73c8f9..ffb22092701b8 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h +++ b/paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h @@ -23,7 +23,7 @@ #include #include -#include "paddle/fluid/framework/new_executor/workqueue.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" #include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/platform/device_event.h" diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 8f27c7e1811fb..14c27c94f8394 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -32,8 +32,8 @@ #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h" -#include "paddle/fluid/framework/new_executor/workqueue.h" -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" @@ -61,12 +61,14 @@ class AsyncWorkQueue { group_options.emplace_back(/*num_threads*/ host_num_threads, /*allow_spinning*/ true, /*track_task*/ true, - /*queue_empty_waiter*/ waiter); + /*detached*/ true, + /*events_waiter*/ waiter); // for launch device Kernel group_options.emplace_back(/*num_threads*/ 1, /*allow_spinning*/ true, /*track_task*/ true, - /*queue_empty_waiter*/ waiter); + /*detached*/ true, + /*events_waiter*/ waiter); queue_group_ = CreateWorkQueueGroup(group_options); } diff --git a/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt b/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt new file mode 100644 index 0000000000000..77130102d52e5 --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue/CMakeLists.txt @@ -0,0 +1,2 @@ +cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc events_waiter.cc DEPS enforce glog) +cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) diff --git a/paddle/fluid/framework/new_executor/event_count.h b/paddle/fluid/framework/new_executor/workqueue/event_count.h similarity index 98% rename from paddle/fluid/framework/new_executor/event_count.h rename to paddle/fluid/framework/new_executor/workqueue/event_count.h index 7f1e3670056fc..893c6d2d54ac7 100644 --- a/paddle/fluid/framework/new_executor/event_count.h +++ b/paddle/fluid/framework/new_executor/workqueue/event_count.h @@ -41,6 +41,10 @@ // and won't block, or notifying thread will see state_ change and will unblock // the waiter, or both. But it can't happen that both threads don't see each // other changes, which would lead to deadlock. +// +// What changed by PaddlePaddle +// 1. Allocate aligned storage for Waiters to get better performance. +// 2. Replace Eigen utils with std utils. #pragma once diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc new file mode 100644 index 0000000000000..ac45e7b5fdfe9 --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/workqueue/events_waiter.h" +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +EventsWaiter::EventsWaiter() + : trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {} + +std::shared_ptr EventsWaiter::RegisterEvent( + const std::string& name, EventChecker checker) { + auto counter = counter_.fetch_add(1); + auto id = std::hash()(name + std::to_string(counter)); + VLOG(10) << "Register event id:" << id << " name:" << name; + auto notifier = std::shared_ptr(new EventNotifier(id, this)); + EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)}; + std::lock_guard guard(events_lock_); + events_[id] = std::move(evt); + return notifier; +} + +std::shared_ptr EventsWaiter::RegisterEvent( + const std::string& name) { + auto counter = counter_.fetch_add(1); + auto id = std::hash()(name + std::to_string(counter)); + VLOG(10) << "Register event id:" << id << " name:" << name; + auto notifier = std::shared_ptr(new EventNotifier(id, this)); + EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }}; + std::lock_guard guard(events_lock_); + events_[id] = std::move(evt); + return notifier; +} + +void EventsWaiter::UnregisterEvent(const EventId& id) { + VLOG(10) << "Unregister event id:" << id; + std::lock_guard guard(events_lock_); + events_.erase(id); +} + +std::string EventsWaiter::WaitEvent() { + // only one user can wait at any time + bool waiting = false; + if (!waiting_.compare_exchange_strong(waiting, true, + std::memory_order_seq_cst, + std::memory_order_relaxed)) { + PADDLE_THROW( + platform::errors::ResourceExhausted("Another thread is waiting.")); + } + auto w = cv_.GetWaiter(0); + cv_.Prewait(); + std::string* triggered = trigger_event_; + if (triggered == nullptr) { + // checkers + { + std::lock_guard guard(events_lock_); + for (auto& kv : events_) { + auto& evt = kv.second; + if (TriggerType::LevelTriggered == evt.type && evt.checker()) { + triggered = new std::string(evt.name); + break; + } + } + } + if (triggered != nullptr) { + std::string* prev = nullptr; + if (!trigger_event_.compare_exchange_strong(prev, triggered, + std::memory_order_seq_cst, + std::memory_order_relaxed)) { + delete triggered; + triggered = prev; + } + } + } + if (triggered) { + cv_.CancelWait(); + } else { + cv_.CommitWait(w); + triggered = trigger_event_; + } + trigger_event_.store(nullptr, std::memory_order_relaxed); + waiting_.store(false); + auto trigger_event = *triggered; + delete triggered; + return trigger_event; +} + +int EventsWaiter::Clear() { + bool waiting = false; + if (!waiting_.compare_exchange_strong(waiting, true, + std::memory_order_seq_cst, + std::memory_order_relaxed)) { + return -1; + } + trigger_event_.store(nullptr, std::memory_order_relaxed); + waiting_.store(false); + return 0; +} + +void EventsWaiter::TriggerEvent(const EventId& id) { + VLOG(10) << "Try to trigger event id:" << id; + std::string* trigger_event = new std::string; + { + std::lock_guard guard(events_lock_); + auto iter = events_.find(id); + if (iter == events_.end()) { + delete trigger_event; + return; + } + *trigger_event = iter->second.name; + } + std::string* prev = nullptr; + if (!trigger_event_.compare_exchange_strong(prev, trigger_event, + std::memory_order_seq_cst, + std::memory_order_relaxed)) { + delete trigger_event; + return; + } + VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event; + cv_.Notify(true); +} + +std::string EventsWaiter::GetEventName(const EventId& id) { + std::lock_guard guard(events_lock_); + auto iter = events_.find(id); + if (iter == events_.end()) { + return "Unregistered"; + } + return iter->second.name; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.h b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h new file mode 100644 index 0000000000000..5ffed15155d59 --- /dev/null +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.h @@ -0,0 +1,111 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/new_executor/workqueue/event_count.h" +#include "paddle/fluid/memory/allocation/spin_lock.h" + +namespace paddle { +namespace framework { + +// A multiplexing waiter, be able to wait multiple kinds of events +// simultaneously. +// Muti-Producer single-consumer single-slot message-queue. +class EventsWaiter { + public: + using EventId = std::size_t; + + using EventChecker = std::function; + + // Make sure EventsWaiter has a longer lifetime than EventNotifier. + class EventNotifier { + public: + void NotifyEvent() { waiter_.TriggerEvent(id_); } + + void UnregisterEvent() { waiter_.UnregisterEvent(id_); } + + EventId GetEventId() { return id_; } + + // return "Unregistered" if the corresponding event was unregistered. + std::string GetEventName() { return waiter_.GetEventName(id_); } + + private: + friend EventsWaiter; + EventNotifier(EventId id, EventsWaiter* waiter) + : id_(id), waiter_(*waiter) {} + EventNotifier(const EventNotifier&) = delete; + void operator=(const EventNotifier&) = delete; + + EventId id_; + EventsWaiter& waiter_; + }; + + EventsWaiter(); + EventsWaiter(const EventsWaiter&) = delete; + EventsWaiter& operator=(const EventsWaiter&) = delete; + + // Register a level-triggered event. If the checker returns true or + // EventNotifier::NotifyEvent is called, the corresponding event will be + // distributed. + std::shared_ptr RegisterEvent(const std::string& name, + EventChecker checker); + + // Register an edge-triggered event. The corresponding event will be + // distributed when EventNotifier::NotifyEvent is called. + std::shared_ptr RegisterEvent(const std::string& name); + + void UnregisterEvent(const EventId& id); + + // Blocking the calling thread to wait any of the registered events. + std::string WaitEvent(); + + // Nonblocking. + // Clear the slot, no matter whether there is an event. + // Return value: + // -1 : another thread is waiting. + // 0 : succ. + int Clear(); + + private: + friend EventNotifier; + + enum class TriggerType { LevelTriggered, EdgeTriggered }; + + struct EventInfo { + EventId id; + std::string name; + TriggerType type; + EventChecker checker; + }; + + void TriggerEvent(const EventId& id); + + std::string GetEventName(const EventId& id); + + std::unordered_map events_; + paddle::memory::SpinLock events_lock_; + std::atomic trigger_event_; + std::atomic counter_; + std::atomic waiting_; + EventCount cv_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h similarity index 94% rename from paddle/fluid/framework/new_executor/nonblocking_threadpool.h rename to paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h index cdcdbbb445185..37044d3c19b35 100644 --- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h +++ b/paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h @@ -12,43 +12,14 @@ #include #include #include -#include "paddle/fluid/framework/new_executor/event_count.h" -#include "paddle/fluid/framework/new_executor/run_queue.h" -#include "paddle/fluid/framework/new_executor/thread_environment.h" +#include "paddle/fluid/framework/new_executor/workqueue/event_count.h" +#include "paddle/fluid/framework/new_executor/workqueue/run_queue.h" +#include "paddle/fluid/framework/new_executor/workqueue/thread_environment.h" +#include "paddle/fluid/platform/os_info.h" namespace paddle { namespace framework { -template -class TaskTracker { - public: - TaskTracker() = default; - - explicit TaskTracker(Notifier& notifier) : notifier_(¬ifier) {} - - TaskTracker(const TaskTracker&) = delete; - - TaskTracker& operator=(const TaskTracker&) = delete; - - ~TaskTracker() = default; - - void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } - - void SubCounter() { - if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { - if (notifier_ != nullptr) { - notifier_->NotifyEvent(); - } - } - } - - uint64_t PendingTaskNum() { return num_tasks_.load(); } - - private: - alignas(64) std::atomic num_tasks_{0}; - Notifier* notifier_{nullptr}; -}; - template class ThreadPoolTempl { public: diff --git a/paddle/fluid/framework/new_executor/run_queue.h b/paddle/fluid/framework/new_executor/workqueue/run_queue.h similarity index 97% rename from paddle/fluid/framework/new_executor/run_queue.h rename to paddle/fluid/framework/new_executor/workqueue/run_queue.h index e457b20a3c35d..2fc42cf308ab8 100644 --- a/paddle/fluid/framework/new_executor/run_queue.h +++ b/paddle/fluid/framework/new_executor/workqueue/run_queue.h @@ -29,6 +29,11 @@ // separate state variable as null/non-null pointer value would serve as state, // but that would require malloc/free per operation for large, complex values // (and this is designed to store std::function<()>). +// +// What changed by PaddlePaddle +// 1. Use paddle::memory::SpinLock instead of std::mutex to protect back_. +// 2. Make front_/back_ aligned to get better performance. +// 3. Replace Eigen utils with std utils. #pragma once @@ -37,7 +42,7 @@ #include #include #include -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include "paddle/fluid/memory/allocation/spin_lock.h" namespace paddle { diff --git a/paddle/fluid/framework/new_executor/thread_environment.h b/paddle/fluid/framework/new_executor/workqueue/thread_environment.h similarity index 100% rename from paddle/fluid/framework/new_executor/thread_environment.h rename to paddle/fluid/framework/new_executor/workqueue/thread_environment.h diff --git a/paddle/fluid/framework/new_executor/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc similarity index 77% rename from paddle/fluid/framework/new_executor/workqueue.cc rename to paddle/fluid/framework/new_executor/workqueue/workqueue.cc index 7607b3a297f84..3f06f3db23118 100644 --- a/paddle/fluid/framework/new_executor/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc @@ -4,9 +4,9 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -#include "paddle/fluid/framework/new_executor/workqueue.h" -#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" +#include "paddle/fluid/framework/new_executor/workqueue/nonblocking_threadpool.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -18,24 +18,35 @@ using TaskTracker = TaskTracker; class WorkQueueImpl : public WorkQueue { public: explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { - if (options_.track_task && options.queue_empty_waiter != nullptr) { + if (options_.track_task && options.events_waiter != nullptr) { void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); TaskTracker* tracker = reinterpret_cast(storage); - auto notifier = options.queue_empty_waiter->RegisterEvent( + empty_notifier_ = options.events_waiter->RegisterEvent( kQueueEmptyEvent, [tracker]() { return tracker->PendingTaskNum() == 0; }); - tracker_ = new (storage) TaskTracker(*notifier.get()); + tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); + } + if (options_.detached == false && options.events_waiter != nullptr) { + destruct_notifier_ = + options.events_waiter->RegisterEvent(kQueueDestructEvent); } queue_ = new NonblockingThreadPool(options_.num_threads, options_.allow_spinning); } virtual ~WorkQueueImpl() { + if (empty_notifier_) { + empty_notifier_->UnregisterEvent(); + } + delete queue_; if (tracker_ != nullptr) { tracker_->~TaskTracker(); AlignedFree(tracker_); } - delete queue_; + if (destruct_notifier_) { + destruct_notifier_->NotifyEvent(); + destruct_notifier_->UnregisterEvent(); + } } void AddTask(std::function fn) override { @@ -59,6 +70,8 @@ class WorkQueueImpl : public WorkQueue { private: NonblockingThreadPool* queue_{nullptr}; TaskTracker* tracker_{nullptr}; + std::shared_ptr empty_notifier_; + std::shared_ptr destruct_notifier_; }; class WorkQueueGroupImpl : public WorkQueueGroup { @@ -80,6 +93,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { std::vector queues_; NonblockingThreadPool* queues_storage_; TaskTracker* tracker_; + std::shared_ptr empty_notifier_; + std::shared_ptr destruct_notifier_; }; WorkQueueGroupImpl::WorkQueueGroupImpl( @@ -94,13 +109,17 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( for (size_t idx = 0; idx < num_queues; ++idx) { const auto& options = queues_options_[idx]; if (options.track_task && tracker_ == nullptr && - options.queue_empty_waiter != nullptr) { + options.events_waiter != nullptr) { void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); TaskTracker* tracker = reinterpret_cast(storage); - auto notifier = options.queue_empty_waiter->RegisterEvent( + empty_notifier_ = options.events_waiter->RegisterEvent( kQueueEmptyEvent, [tracker]() { return tracker->PendingTaskNum() == 0; }); - tracker_ = new (storage) TaskTracker(*notifier.get()); + tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); + } + if (options.detached == false && options.events_waiter != nullptr) { + destruct_notifier_ = + options.events_waiter->RegisterEvent(kQueueDestructEvent); } queues_[idx] = new (&queues_storage_[idx]) NonblockingThreadPool(options.num_threads, options.allow_spinning); @@ -108,6 +127,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( } WorkQueueGroupImpl::~WorkQueueGroupImpl() { + if (empty_notifier_) { + empty_notifier_->UnregisterEvent(); + } for (auto queue : queues_) { queue->~NonblockingThreadPool(); } @@ -116,6 +138,10 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() { AlignedFree(tracker_); } free(queues_storage_); + if (destruct_notifier_) { + destruct_notifier_->NotifyEvent(); + destruct_notifier_->UnregisterEvent(); + } } void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function fn) { diff --git a/paddle/fluid/framework/new_executor/workqueue.h b/paddle/fluid/framework/new_executor/workqueue/workqueue.h similarity index 87% rename from paddle/fluid/framework/new_executor/workqueue.h rename to paddle/fluid/framework/new_executor/workqueue/workqueue.h index a299d0aaed7d2..068c54a21a452 100644 --- a/paddle/fluid/framework/new_executor/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.h @@ -22,6 +22,7 @@ namespace paddle { namespace framework { constexpr const char* kQueueEmptyEvent = "QueueEmpty"; +constexpr const char* kQueueDestructEvent = "QueueDestruct"; class EventsWaiter; @@ -32,20 +33,24 @@ struct WorkQueueOptions { track_task(track_task) {} WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task, - EventsWaiter* waiter) + bool detached, EventsWaiter* waiter) : num_threads(num_threads), allow_spinning(allow_spinning), track_task(track_task), - queue_empty_waiter(waiter) {} + detached(detached), + events_waiter(waiter) {} size_t num_threads; bool allow_spinning; // If you need to blocking the calling thread to wait "queue empty", set - // track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will + // track_task = true and set events_waiter. EventsWaiter::WaitEvent will // block the calling thread until any of events (including "queue empty") // occured. bool track_task; - EventsWaiter* queue_empty_waiter{nullptr}; // not owned + // If you need to be noticed when a WorkQueue Destruct() , set detached = + // false and set events_waiter. + bool detached{true}; + EventsWaiter* events_waiter{nullptr}; // not owned }; class WorkQueue { diff --git a/paddle/fluid/framework/new_executor/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc similarity index 73% rename from paddle/fluid/framework/new_executor/workqueue_test.cc rename to paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc index 3ea0096b631e8..e06beb623be4c 100644 --- a/paddle/fluid/framework/new_executor/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc @@ -12,11 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/workqueue.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue.h" #include #include "glog/logging.h" #include "gtest/gtest.h" -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" + +TEST(WorkQueueUtils, TestEventsWaiter) { + using paddle::framework::EventsWaiter; + EventsWaiter events_waiter; + auto notifier = + events_waiter.RegisterEvent("test_register_lt", []() { return true; }); + EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); + EXPECT_EQ(notifier->GetEventName(), "test_register_lt"); + EXPECT_EQ(events_waiter.WaitEvent(), "test_register_lt"); + notifier->UnregisterEvent(); + EXPECT_EQ(notifier->GetEventName(), "Unregistered"); + notifier = events_waiter.RegisterEvent("test_register_et"); + notifier->NotifyEvent(); + EXPECT_EQ(events_waiter.WaitEvent(), "test_register_et"); +} TEST(WorkQueue, TestSingleThreadedWorkQueue) { VLOG(1) << "In Test"; @@ -30,7 +45,8 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { // CreateSingleThreadedWorkQueue EventsWaiter events_waiter; WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ true, + &events_waiter); auto work_queue = CreateSingleThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 1u); @@ -63,7 +79,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { // CreateMultiThreadedWorkQueue EventsWaiter events_waiter; WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); auto work_queue = CreateMultiThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 10u); @@ -80,11 +97,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { } // WaitQueueEmpty EXPECT_EQ(finished.load(), false); - events_waiter.WaitEvent(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueEmptyEvent); EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); // Cancel work_queue->Cancel(); + work_queue.reset(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -99,9 +118,11 @@ TEST(WorkQueue, TestWorkQueueGroup) { // ThreadedWorkQueueGroup EventsWaiter events_waiter; WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); // NumThreads EXPECT_EQ(queue_group->QueueNumThreads(0), 1u); @@ -126,4 +147,7 @@ TEST(WorkQueue, TestWorkQueueGroup) { EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); // Cancel queue_group->Cancel(); + events_waiter.WaitEvent(); + queue_group.reset(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); } diff --git a/paddle/fluid/framework/new_executor/workqueue_utils.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.cc similarity index 50% rename from paddle/fluid/framework/new_executor/workqueue_utils.cc rename to paddle/fluid/framework/new_executor/workqueue/workqueue_utils.cc index 2c81cffb49d82..82dcbbd509dd5 100644 --- a/paddle/fluid/framework/new_executor/workqueue_utils.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/workqueue_utils.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" #include #include @@ -55,62 +55,5 @@ void AlignedFree(void* mem_ptr) { #endif } -constexpr EventsWaiter::EventId kEmptyEventId = -1; - -EventsWaiter::EventsWaiter() - : trigger_event_(kEmptyEventId), waiting_(false), cv_(1) {} - -std::shared_ptr EventsWaiter::RegisterEvent( - const std::string& name, EventChecker checker) { - names_.emplace_back(name); - checkers_.emplace_back(std::move(checker)); - EventId id = checkers_.size() - 1; - auto notifier = std::shared_ptr(new EventNotifier(id, this)); - notifiers_.emplace_back(notifier); - return notifier; -} - -std::string EventsWaiter::WaitEvent() { - // only one user can wait at any time - bool waiting = false; - if (!waiting_.compare_exchange_strong(waiting, true, - std::memory_order_seq_cst, - std::memory_order_relaxed)) { - PADDLE_THROW( - platform::errors::ResourceExhausted("Another thread is waiting.")); - } - EventId id = kEmptyEventId; - auto w = cv_.GetWaiter(0); - cv_.Prewait(); - int64_t event_num = checkers_.size(); - for (int64_t i = 0; id == kEmptyEventId && i < event_num; ++i) { - if (checkers_[i]()) { - id = i; - } - } - if (id != kEmptyEventId) { - cv_.CancelWait(); - } else { - cv_.CommitWait(w); - id = trigger_event_.load(std::memory_order_relaxed); - } - trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); - waiting_.store(false); - return names_.at(id); -} - -void EventsWaiter::SetTriggerEvent(const EventId& id) { - trigger_event_.store(id, std::memory_order_relaxed); - cv_.Notify(true); -} - -std::string EventsWaiter::EventNotifier::GetEventName() { - return waiter_.names_.at(id_); -} - -void EventsWaiter::EventNotifier::NotifyEvent() { - waiter_.SetTriggerEvent(id_); -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/workqueue_utils.h b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h similarity index 59% rename from paddle/fluid/framework/new_executor/workqueue_utils.h rename to paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h index a06d9f319dfee..eee64df285dcb 100644 --- a/paddle/fluid/framework/new_executor/workqueue_utils.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h @@ -21,8 +21,7 @@ #include #include #include -#include -#include "paddle/fluid/framework/new_executor/event_count.h" +#include "paddle/fluid/framework/new_executor/workqueue/events_waiter.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -69,55 +68,34 @@ void* AlignedMalloc(size_t size, size_t alignment); void AlignedFree(void* memory_ptr); -// A multiplexing waiter, be able to wait multi events simultaneously. -// Blocking the calling thread to wait any of the registered events. -// Non-thread-safe. -class EventsWaiter { +template +class TaskTracker { public: - using EventId = int64_t; + TaskTracker() = default; - using EventChecker = std::function; + explicit TaskTracker(Notifier& notifier) : notifier_(¬ifier) {} - class EventNotifier { - public: - void NotifyEvent(); + TaskTracker(const TaskTracker&) = delete; - EventId GetEventId() { return id_; } + TaskTracker& operator=(const TaskTracker&) = delete; - std::string GetEventName(); + ~TaskTracker() = default; - private: - friend EventsWaiter; - EventNotifier(EventId id, EventsWaiter* waiter) - : id_(id), waiter_(*waiter) {} + void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); } - EventId id_; - EventsWaiter& waiter_; - }; - - EventsWaiter(); - - EventsWaiter(const EventsWaiter&) = delete; - - EventsWaiter& operator=(const EventsWaiter&) = delete; - - // All the RegisterEvent functions must be called before any WaitEvent - std::shared_ptr RegisterEvent(const std::string& name, - EventChecker checker); + void SubCounter() { + if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) { + if (notifier_ != nullptr) { + notifier_->NotifyEvent(); + } + } + } - // Wait any of the registered events - std::string WaitEvent(); + uint64_t PendingTaskNum() { return num_tasks_.load(); } private: - friend EventNotifier; - void SetTriggerEvent(const EventId& id); - - std::vector names_; - std::vector checkers_; - std::vector> notifiers_; - std::atomic trigger_event_; - std::atomic waiting_; - EventCount cv_; + alignas(64) std::atomic num_tasks_{0}; + Notifier* notifier_{nullptr}; }; } // namespace framework