From 9211ad251b66faf7b985fcc8fb452a78e87d8ba2 Mon Sep 17 00:00:00 2001 From: liutiexing Date: Wed, 22 Dec 2021 14:12:58 +0000 Subject: [PATCH] update --- .../new_executor/interpretercore_util.h | 6 ++- .../new_executor/workqueue/workqueue.cc | 42 ++++++++++++++----- .../new_executor/workqueue/workqueue.h | 13 ++++-- .../new_executor/workqueue/workqueue_test.cc | 19 ++++++--- 4 files changed, 59 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index feb4b0c898a0d..14c27c94f8394 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -61,12 +61,14 @@ class AsyncWorkQueue { group_options.emplace_back(/*num_threads*/ host_num_threads, /*allow_spinning*/ true, /*track_task*/ true, - /*queue_empty_waiter*/ waiter); + /*detached*/ true, + /*events_waiter*/ waiter); // for launch device Kernel group_options.emplace_back(/*num_threads*/ 1, /*allow_spinning*/ true, /*track_task*/ true, - /*queue_empty_waiter*/ waiter); + /*detached*/ true, + /*events_waiter*/ waiter); queue_group_ = CreateWorkQueueGroup(group_options); } diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc index 801acf4b89ddc..3f06f3db23118 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.cc @@ -18,24 +18,34 @@ using TaskTracker = TaskTracker; class WorkQueueImpl : public WorkQueue { public: explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) { - if (options_.track_task && options.queue_empty_waiter != nullptr) { + if (options_.track_task && options.events_waiter != nullptr) { void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); TaskTracker* tracker = reinterpret_cast(storage); - notifier_ = options.queue_empty_waiter->RegisterEvent( + empty_notifier_ = options.events_waiter->RegisterEvent( kQueueEmptyEvent, [tracker]() { return tracker->PendingTaskNum() == 0; }); - tracker_ = new (storage) TaskTracker(*notifier_.get()); + tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); + } + if (options_.detached == false && options.events_waiter != nullptr) { + destruct_notifier_ = + options.events_waiter->RegisterEvent(kQueueDestructEvent); } queue_ = new NonblockingThreadPool(options_.num_threads, options_.allow_spinning); } virtual ~WorkQueueImpl() { + if (empty_notifier_) { + empty_notifier_->UnregisterEvent(); + } delete queue_; if (tracker_ != nullptr) { tracker_->~TaskTracker(); AlignedFree(tracker_); - notifier_->UnregisterEvent(); + } + if (destruct_notifier_) { + destruct_notifier_->NotifyEvent(); + destruct_notifier_->UnregisterEvent(); } } @@ -60,7 +70,8 @@ class WorkQueueImpl : public WorkQueue { private: NonblockingThreadPool* queue_{nullptr}; TaskTracker* tracker_{nullptr}; - std::shared_ptr notifier_; + std::shared_ptr empty_notifier_; + std::shared_ptr destruct_notifier_; }; class WorkQueueGroupImpl : public WorkQueueGroup { @@ -82,7 +93,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup { std::vector queues_; NonblockingThreadPool* queues_storage_; TaskTracker* tracker_; - std::shared_ptr notifier_; + std::shared_ptr empty_notifier_; + std::shared_ptr destruct_notifier_; }; WorkQueueGroupImpl::WorkQueueGroupImpl( @@ -97,13 +109,17 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( for (size_t idx = 0; idx < num_queues; ++idx) { const auto& options = queues_options_[idx]; if (options.track_task && tracker_ == nullptr && - options.queue_empty_waiter != nullptr) { + options.events_waiter != nullptr) { void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker)); TaskTracker* tracker = reinterpret_cast(storage); - notifier_ = options.queue_empty_waiter->RegisterEvent( + empty_notifier_ = options.events_waiter->RegisterEvent( kQueueEmptyEvent, [tracker]() { return tracker->PendingTaskNum() == 0; }); - tracker_ = new (storage) TaskTracker(*notifier_.get()); + tracker_ = new (storage) TaskTracker(*empty_notifier_.get()); + } + if (options.detached == false && options.events_waiter != nullptr) { + destruct_notifier_ = + options.events_waiter->RegisterEvent(kQueueDestructEvent); } queues_[idx] = new (&queues_storage_[idx]) NonblockingThreadPool(options.num_threads, options.allow_spinning); @@ -111,15 +127,21 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( } WorkQueueGroupImpl::~WorkQueueGroupImpl() { + if (empty_notifier_) { + empty_notifier_->UnregisterEvent(); + } for (auto queue : queues_) { queue->~NonblockingThreadPool(); } if (tracker_ != nullptr) { tracker_->~TaskTracker(); AlignedFree(tracker_); - notifier_->UnregisterEvent(); } free(queues_storage_); + if (destruct_notifier_) { + destruct_notifier_->NotifyEvent(); + destruct_notifier_->UnregisterEvent(); + } } void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function fn) { diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue.h b/paddle/fluid/framework/new_executor/workqueue/workqueue.h index a299d0aaed7d2..068c54a21a452 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.h @@ -22,6 +22,7 @@ namespace paddle { namespace framework { constexpr const char* kQueueEmptyEvent = "QueueEmpty"; +constexpr const char* kQueueDestructEvent = "QueueDestruct"; class EventsWaiter; @@ -32,20 +33,24 @@ struct WorkQueueOptions { track_task(track_task) {} WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task, - EventsWaiter* waiter) + bool detached, EventsWaiter* waiter) : num_threads(num_threads), allow_spinning(allow_spinning), track_task(track_task), - queue_empty_waiter(waiter) {} + detached(detached), + events_waiter(waiter) {} size_t num_threads; bool allow_spinning; // If you need to blocking the calling thread to wait "queue empty", set - // track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will + // track_task = true and set events_waiter. EventsWaiter::WaitEvent will // block the calling thread until any of events (including "queue empty") // occured. bool track_task; - EventsWaiter* queue_empty_waiter{nullptr}; // not owned + // If you need to be noticed when a WorkQueue Destruct() , set detached = + // false and set events_waiter. + bool detached{true}; + EventsWaiter* events_waiter{nullptr}; // not owned }; class WorkQueue { diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc index eb47b4cac399f..e06beb623be4c 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc @@ -45,7 +45,8 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { // CreateSingleThreadedWorkQueue EventsWaiter events_waiter; WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ true, + &events_waiter); auto work_queue = CreateSingleThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 1u); @@ -78,7 +79,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { // CreateMultiThreadedWorkQueue EventsWaiter events_waiter; WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); auto work_queue = CreateMultiThreadedWorkQueue(options); // NumThreads EXPECT_EQ(work_queue->NumThreads(), 10u); @@ -95,11 +97,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { } // WaitQueueEmpty EXPECT_EQ(finished.load(), false); - events_waiter.WaitEvent(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueEmptyEvent); EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); // Cancel work_queue->Cancel(); + work_queue.reset(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); } TEST(WorkQueue, TestWorkQueueGroup) { @@ -114,9 +118,11 @@ TEST(WorkQueue, TestWorkQueueGroup) { // ThreadedWorkQueueGroup EventsWaiter events_waiter; WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true, - /*track_task*/ true, &events_waiter); + /*track_task*/ true, /*detached*/ false, + &events_waiter); auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); // NumThreads EXPECT_EQ(queue_group->QueueNumThreads(0), 1u); @@ -141,4 +147,7 @@ TEST(WorkQueue, TestWorkQueueGroup) { EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); // Cancel queue_group->Cancel(); + events_waiter.WaitEvent(); + queue_group.reset(); + EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); }