Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
liutiexing committed Dec 22, 2021
1 parent 91b289c commit 9211ad2
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 21 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ class AsyncWorkQueue {
group_options.emplace_back(/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*track_task*/ true,
/*queue_empty_waiter*/ waiter);
/*detached*/ true,
/*events_waiter*/ waiter);
// for launch device Kernel
group_options.emplace_back(/*num_threads*/ 1,
/*allow_spinning*/ true,
/*track_task*/ true,
/*queue_empty_waiter*/ waiter);
/*detached*/ true,
/*events_waiter*/ waiter);
queue_group_ = CreateWorkQueueGroup(group_options);
}

Expand Down
42 changes: 32 additions & 10 deletions paddle/fluid/framework/new_executor/workqueue/workqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,34 @@ using TaskTracker = TaskTracker<EventsWaiter::EventNotifier>;
class WorkQueueImpl : public WorkQueue {
public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.queue_empty_waiter != nullptr) {
if (options_.track_task && options.events_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
notifier_ = options.queue_empty_waiter->RegisterEvent(
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier_.get());
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options_.detached == false && options.events_waiter != nullptr) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
}

// Tears down the work queue in this order: unregister the "queue empty"
// notifier (if any), delete the thread pool, destroy the task tracker, and
// finally fire and unregister the "queue destruct" event (if any).
virtual ~WorkQueueImpl() {
// empty_notifier_ is only set when task tracking was enabled with an
// events_waiter; it is unregistered before the pool is deleted.
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_;
if (tracker_ != nullptr) {
// tracker_ was constructed with placement new inside AlignedMalloc storage,
// so its destructor is invoked explicitly and the storage is released with
// AlignedFree rather than delete.
tracker_->~TaskTracker();
AlignedFree(tracker_);
// NOTE(review): this statement looks like the pre-change line of the diff
// (the member appears to have been renamed to empty_notifier_, which is
// already unregistered above) -- confirm against the actual file.
notifier_->UnregisterEvent();
}
// destruct_notifier_ is only registered when options_.detached == false and
// an events_waiter was supplied. The event is signalled after the pool and
// tracker above have been torn down, then unregistered.
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}

Expand All @@ -60,7 +70,8 @@ class WorkQueueImpl : public WorkQueue {
private:
NonblockingThreadPool* queue_{nullptr};
TaskTracker* tracker_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> empty_notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> destruct_notifier_;
};

class WorkQueueGroupImpl : public WorkQueueGroup {
Expand All @@ -82,7 +93,8 @@ class WorkQueueGroupImpl : public WorkQueueGroup {
std::vector<NonblockingThreadPool*> queues_;
NonblockingThreadPool* queues_storage_;
TaskTracker* tracker_;
std::shared_ptr<EventsWaiter::EventNotifier> notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> empty_notifier_;
std::shared_ptr<EventsWaiter::EventNotifier> destruct_notifier_;
};

WorkQueueGroupImpl::WorkQueueGroupImpl(
Expand All @@ -97,29 +109,39 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr &&
options.queue_empty_waiter != nullptr) {
options.events_waiter != nullptr) {
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
notifier_ = options.queue_empty_waiter->RegisterEvent(
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*notifier_.get());
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options.detached == false && options.events_waiter != nullptr) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning);
}
}

// Tears down the queue group in this order: unregister the "queue empty"
// notifier (if any), destroy every thread pool, destroy the shared task
// tracker, free the pool storage, and finally fire and unregister the
// "queue destruct" event (if any).
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
// Each pool was constructed with placement new inside queues_storage_, so the
// destructors are invoked explicitly; the storage itself is released below.
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
if (tracker_ != nullptr) {
// tracker_ lives in AlignedMalloc storage (placement new), hence the explicit
// destructor call followed by AlignedFree.
tracker_->~TaskTracker();
AlignedFree(tracker_);
// NOTE(review): this statement looks like the pre-change line of the diff
// (the member appears to have been renamed to empty_notifier_, which is
// already unregistered above) -- confirm against the actual file.
notifier_->UnregisterEvent();
}
free(queues_storage_);
// destruct_notifier_ is only registered when some queue's options have
// detached == false and an events_waiter. The event is signalled after all
// pools and their storage have been released, then unregistered.
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}

void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
Expand Down
13 changes: 9 additions & 4 deletions paddle/fluid/framework/new_executor/workqueue/workqueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace paddle {
namespace framework {

constexpr const char* kQueueEmptyEvent = "QueueEmpty";
constexpr const char* kQueueDestructEvent = "QueueDestruct";

class EventsWaiter;

Expand All @@ -32,20 +33,24 @@ struct WorkQueueOptions {
track_task(track_task) {}

WorkQueueOptions(size_t num_threads, bool allow_spinning, bool track_task,
EventsWaiter* waiter)
bool detached, EventsWaiter* waiter)
: num_threads(num_threads),
allow_spinning(allow_spinning),
track_task(track_task),
queue_empty_waiter(waiter) {}
detached(detached),
events_waiter(waiter) {}

size_t num_threads;
bool allow_spinning;
// If you need to block the calling thread to wait for "queue empty", set
// track_task = true and set queue_empty_waiter. EventsWaiter::WaitEvent will
// track_task = true and set events_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty")
// occurred.
bool track_task;
EventsWaiter* queue_empty_waiter{nullptr}; // not owned
// If you need to be notified when a WorkQueue is destructed, set detached =
// false and set events_waiter.
bool detached{true};
EventsWaiter* events_waiter{nullptr}; // not owned
};

class WorkQueue {
Expand Down
19 changes: 14 additions & 5 deletions paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
// CreateSingleThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter);
/*track_task*/ true, /*detached*/ true,
&events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options);
// NumThreads
EXPECT_EQ(work_queue->NumThreads(), 1u);
Expand Down Expand Up @@ -78,7 +79,8 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
// CreateMultiThreadedWorkQueue
EventsWaiter events_waiter;
WorkQueueOptions options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter);
/*track_task*/ true, /*detached*/ false,
&events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options);
// NumThreads
EXPECT_EQ(work_queue->NumThreads(), 10u);
Expand All @@ -95,11 +97,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
}
// WaitQueueEmpty
EXPECT_EQ(finished.load(), false);
events_waiter.WaitEvent();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueEmptyEvent);
EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
// Cancel
work_queue->Cancel();
work_queue.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
}

TEST(WorkQueue, TestWorkQueueGroup) {
Expand All @@ -114,9 +118,11 @@ TEST(WorkQueue, TestWorkQueueGroup) {
// ThreadedWorkQueueGroup
EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*num_threads*/ 1, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter);
/*track_task*/ true, /*detached*/ false,
&events_waiter);
WorkQueueOptions mq_options(/*num_threads*/ 10, /*allow_spinning*/ true,
/*track_task*/ true, &events_waiter);
/*track_task*/ true, /*detached*/ false,
&events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
// NumThreads
EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
Expand All @@ -141,4 +147,7 @@ TEST(WorkQueue, TestWorkQueueGroup) {
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
// Cancel
queue_group->Cancel();
events_waiter.WaitEvent();
queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
}

1 comment on commit 9211ad2

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.