Skip to content

Commit

Permalink
Refine events waiter (#40876)
Browse files Browse the repository at this point in the history
* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Add EventsWaiter

* update

* Revert "Add EventsWaiter"

This reverts commit e206173.

* update

* update Error MSG

* update EventsWaiter

* update

Co-authored-by: liutiexing <liutiexing@google.com>
  • Loading branch information
liutiexing and liutiexing authored Mar 24, 2022
1 parent 2e8f988 commit 36ee6dd
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 88 deletions.
3 changes: 0 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ InterpreterCore::~InterpreterCore() {
// cancel gc's thread
gc_.reset(nullptr);

exception_notifier_->UnregisterEvent();
completion_notifier_->UnregisterEvent();

async_work_queue_.reset(nullptr);
}

Expand Down
173 changes: 117 additions & 56 deletions paddle/fluid/framework/new_executor/workqueue/events_waiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,79 @@
namespace paddle {
namespace framework {

// Sentinel event id meaning "no event". trigger_event_ is initialized and
// reset to this value, TriggerEvent only succeeds when it can swap this value
// out, and RegisterEvent retries rather than assigning it to a real event.
constexpr EventsWaiter::EventId kEmptyEventId = 0;

EventsWaiter::EventsWaiter()
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {}
: trigger_event_(kEmptyEventId),
counter_(0),
eof_(true),
waiting_(false),
cv_(1) {}

std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name, EventChecker checker) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::LevelTriggered;
evt->checker = std::move(checker);
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}

std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent(
const std::string& name) {
auto counter = counter_.fetch_add(1);
auto id = std::hash<std::string>()(name + std::to_string(counter));
EventId id = kEmptyEventId;
EventInfo* evt = nullptr;
do {
auto counter = counter_.fetch_add(1);
id = std::hash<std::string>()(name + std::to_string(counter));
if (id == kEmptyEventId) {
continue;
}
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (events_.count(id) > 0) {
continue;
}
evt = &(events_[id]);
} while (evt == nullptr);
evt->id = id;
evt->name = name;
evt->type = TriggerType::EdgeTriggered;
evt->checker = []() { return false; };
eof_.store(false, std::memory_order_relaxed);
VLOG(10) << "Register event id:" << id << " name:" << name;
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this));
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }};
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_[id] = std::move(evt);
return notifier;
}

void EventsWaiter::UnregisterEvent(const EventId& id) {
VLOG(10) << "Unregister event id:" << id;
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
events_.erase(id);
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
deleted_events_.insert(id);
if (deleted_events_.size() == events_.size()) {
eof_.store(true, std::memory_order_relaxed);
}
}
if (eof_.load(std::memory_order_relaxed)) {
cv_.Notify(true);
}
}

std::string EventsWaiter::WaitEvent() {
Expand All @@ -61,42 +103,60 @@ std::string EventsWaiter::WaitEvent() {
PADDLE_THROW(
platform::errors::ResourceExhausted("Another thread is waiting."));
}

auto w = cv_.GetWaiter(0);
cv_.Prewait();
std::string* triggered = trigger_event_;
if (triggered == nullptr) {
EventId triggered = trigger_event_;
while (triggered == kEmptyEventId && !eof_) {
cv_.Prewait();

// double check
triggered = trigger_event_;
// checkers
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
for (auto& kv : events_) {
auto& evt = kv.second;
if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
triggered = new std::string(evt.name);
break;
if (triggered == kEmptyEventId) {
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
for (auto& kv : events_) {
auto& evt = kv.second;
if (TriggerType::LevelTriggered == evt.type && evt.checker()) {
triggered = evt.id;
break;
}
}
}
}
if (triggered != nullptr) {
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, triggered,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete triggered;
triggered = prev;
if (triggered != kEmptyEventId) {
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, triggered, std::memory_order_seq_cst,
std::memory_order_relaxed)) {
triggered = prev;
}
}
}

if (triggered != kEmptyEventId || eof_) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;
}
}
if (triggered) {
cv_.CancelWait();
} else {
cv_.CommitWait(w);
triggered = trigger_event_;

trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false, std::memory_order_relaxed);
std::string evt_name =
triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) {
events_.erase(evt);
}
deleted_events_.clear();
}
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
waiting_.store(false);
auto trigger_event = *triggered;
delete triggered;
return trigger_event;
return evt_name;
}

int EventsWaiter::Clear() {
Expand All @@ -106,32 +166,33 @@ int EventsWaiter::Clear() {
std::memory_order_relaxed)) {
return -1;
}
trigger_event_.store(nullptr, std::memory_order_relaxed);
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false);
return 0;
}

void EventsWaiter::TriggerEvent(const EventId& id) {
VLOG(10) << "Try to trigger event id:" << id;
std::string* trigger_event = new std::string;
{
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
auto iter = events_.find(id);
if (iter == events_.end()) {
delete trigger_event;
return;
}
*trigger_event = iter->second.name;
EventId prev = kEmptyEventId;
if (!trigger_event_.compare_exchange_strong(
prev, id, std::memory_order_seq_cst, std::memory_order_relaxed)) {
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
std::string* prev = nullptr;
if (!trigger_event_.compare_exchange_strong(prev, trigger_event,
VLOG(10) << "Triggered event id:" << id;
cv_.Notify(true);
}

void EventsWaiter::CancelEvent(const EventId& id) {
VLOG(10) << "Try to cancel event id:" << id;
EventId prev = id;
if (!trigger_event_.compare_exchange_strong(prev, kEmptyEventId,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
delete trigger_event;
VLOG(10) << "Event id:" << prev << " is pending";
return;
}
VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event;
cv_.Notify(true);
VLOG(10) << "Cancelled event id:" << id;
}

std::string EventsWaiter::GetEventName(const EventId& id) {
Expand Down
14 changes: 9 additions & 5 deletions paddle/fluid/framework/new_executor/workqueue/events_waiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"

Expand All @@ -37,13 +38,12 @@ class EventsWaiter {
// Make sure EventsWaiter has a longer lifetime than EventNotifier.
class EventNotifier {
public:
void NotifyEvent() { waiter_.TriggerEvent(id_); }
~EventNotifier() { waiter_.UnregisterEvent(id_); }

void UnregisterEvent() { waiter_.UnregisterEvent(id_); }
void NotifyEvent() { waiter_.TriggerEvent(id_); }

EventId GetEventId() { return id_; }
void CancelEvent() { waiter_.CancelEvent(id_); }

// return "Unregistered" if the corresponding event was unregistered.
std::string GetEventName() { return waiter_.GetEventName(id_); }

private:
Expand Down Expand Up @@ -97,12 +97,16 @@ class EventsWaiter {

void TriggerEvent(const EventId& id);

void CancelEvent(const EventId& id);

std::string GetEventName(const EventId& id);

std::unordered_map<EventId, EventInfo> events_;
std::unordered_set<EventId> deleted_events_;
paddle::memory::SpinLock events_lock_;
std::atomic<std::string*> trigger_event_;
std::atomic<EventId> trigger_event_;
std::atomic<uint64_t> counter_;
std::atomic<bool> eof_;
std::atomic<bool> waiting_;
EventCount cv_;
};
Expand Down
21 changes: 4 additions & 17 deletions paddle/fluid/framework/new_executor/workqueue/workqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ class WorkQueueImpl : public WorkQueue {
public:
explicit WorkQueueImpl(const WorkQueueOptions& options) : WorkQueue(options) {
if (options_.track_task && options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options_.detached == false && options.events_waiter != nullptr) {
Expand All @@ -47,17 +44,13 @@ class WorkQueueImpl : public WorkQueue {
}

virtual ~WorkQueueImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
delete queue_;
if (tracker_ != nullptr) {
tracker_->~TaskTracker();
AlignedFree(tracker_);
}
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}

Expand Down Expand Up @@ -124,14 +117,12 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr &&
options.events_waiter != nullptr) {
empty_notifier_ = options.events_waiter->RegisterEvent(kQueueEmptyEvent);
void* storage = AlignedMalloc(sizeof(TaskTracker), alignof(TaskTracker));
TaskTracker* tracker = reinterpret_cast<TaskTracker*>(storage);
empty_notifier_ = options.events_waiter->RegisterEvent(
kQueueEmptyEvent,
[tracker]() { return tracker->PendingTaskNum() == 0; });
tracker_ = new (storage) TaskTracker(*empty_notifier_.get());
}
if (options.detached == false && options.events_waiter != nullptr) {
if (options.detached == false && options.events_waiter != nullptr &&
!destruct_notifier_) {
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
Expand All @@ -141,9 +132,6 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
}

WorkQueueGroupImpl::~WorkQueueGroupImpl() {
if (empty_notifier_) {
empty_notifier_->UnregisterEvent();
}
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
Expand All @@ -154,7 +142,6 @@ WorkQueueGroupImpl::~WorkQueueGroupImpl() {
free(queues_storage_);
if (destruct_notifier_) {
destruct_notifier_->NotifyEvent();
destruct_notifier_->UnregisterEvent();
}
}

Expand Down
Loading

0 comments on commit 36ee6dd

Please sign in to comment.