Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work queue group #35470

Merged
merged 11 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
garbages_.reset(new GarbageQueue());
max_memory_size_ = static_cast<size_t>(GetEagerDeletionThreshold());
cur_memory_size_ = 0;
gc_queue_ = CreateSingleThreadedWorkQueue();
WorkQueueOptions options;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

options不需要赋值吗?我看默认线程数是0。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复

gc_queue_ = CreateSingleThreadedWorkQueue(options);

feed_names_ = feed_names;

Expand Down
102 changes: 34 additions & 68 deletions paddle/fluid/framework/new_executor/nonblocking_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,46 @@
namespace paddle {
namespace framework {

class CounterTracker {
class TaskTracker {
public:
explicit CounterTracker(std::atomic<uint64_t>* counter, EventCount* ec)
: counter_(counter), ec_(ec) {
counter_->fetch_add(1, std::memory_order_relaxed);
}
TaskTracker() : wait_empty_cv_(1) {}

~CounterTracker() {
if (counter_ != nullptr) {
if (1 == counter_->fetch_sub(1, std::memory_order_relaxed)) {
ec_->Notify(true);
}
}
}
TaskTracker(const TaskTracker&) = delete;

CounterTracker(CounterTracker&& other)
: counter_(other.counter_), ec_(other.ec_) {
other.counter_ = nullptr;
other.ec_ = nullptr;
}
TaskTracker& operator=(const TaskTracker&) = delete;

CounterTracker& operator=(CounterTracker&& other) {
counter_ = other.counter_;
ec_ = other.ec_;
other.counter_ = nullptr;
other.ec_ = nullptr;
return *this;
}
~TaskTracker() = default;

CounterTracker(const CounterTracker& other)
: counter_(other.counter_), ec_(other.ec_) {
counter_->fetch_add(1, std::memory_order_relaxed);
void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }

void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
wait_empty_cv_.Notify(true);
}
}

CounterTracker& operator=(const CounterTracker&) = delete;
// only one user can wait at any time
void WaitTaskNumToZero() {
bool waiting = false;
if (!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
abort();
}
EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
wait_empty_cv_.Prewait();
if (num_tasks_.load(std::memory_order_relaxed) == 0) {
wait_empty_cv_.CancelWait();
} else {
wait_empty_cv_.CommitWait(w);
}
wait_empty_.store(false);
}

private:
std::atomic<uint64_t>* counter_{nullptr};
EventCount* ec_{nullptr};
std::atomic<uint64_t> num_tasks_{0};
EventCount wait_empty_cv_;
std::atomic<bool> wait_empty_{false};
};

template <typename Environment>
Expand All @@ -66,9 +67,6 @@ class ThreadPoolTempl {
typedef typename Environment::Task Task;
typedef RunQueue<Task, 1024> Queue;

explicit ThreadPoolTempl(int num_threads, Environment env = Environment())
: ThreadPoolTempl(num_threads, true, env) {}

ThreadPoolTempl(int num_threads, bool allow_spinning,
Environment env = Environment())
: env_(env),
Expand All @@ -80,10 +78,7 @@ class ThreadPoolTempl {
spinning_(0),
done_(false),
cancelled_(false),
ec_(num_threads_),
wait_empty_(false),
wait_empty_ec_(1),
num_tasks_(0) {
ec_(num_threads_) {
// Calculate coprimes of all numbers [1, num_threads].
// Coprimes are used for random walks over all threads in Steal
// and NonEmptyQueueIndex. Iteration is based on the fact that if we take
Expand Down Expand Up @@ -146,15 +141,13 @@ class ThreadPoolTempl {
}

void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
Task t = env_.CreateTask([
task = std::move(fn), raii = CounterTracker(&num_tasks_, &wait_empty_ec_)
]() mutable { task(); });
Task t = env_.CreateTask(std::move(fn));
PerThread* pt = GetPerThread();
if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue.
Queue& q = thread_data_[pt->thread_id].queue;
t = q.PushFront(std::move(t));
} else if (wait_empty_.load() == false) {
} else {
// A free-standing thread (or worker of another pool), push onto a random
// queue.
assert(start < limit);
Expand All @@ -179,29 +172,6 @@ class ThreadPoolTempl {
}
}

void WaitQueueEmpty() {
bool waiting = wait_empty_.load();
assert(waiting == false);
if (waiting ||
!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_acquire)) {
abort();
}
EventCount::Waiter* w = wait_empty_ec_.GetWaiter(0);
wait_empty_ec_.Prewait();
if (num_tasks_.load() == 0) {
wait_empty_ec_.CancelWait();
} else {
wait_empty_ec_.CommitWait(w);
}
waiting = true;
if (!waiting ||
!wait_empty_.compare_exchange_strong(waiting, false,
std::memory_order_acquire)) {
abort();
}
}

void Cancel() {
cancelled_ = true;
done_ = true;
Expand Down Expand Up @@ -300,10 +270,6 @@ class ThreadPoolTempl {
std::atomic<bool> cancelled_;
EventCount ec_;

std::atomic<bool> wait_empty_;
EventCount wait_empty_ec_;
std::atomic<uint64_t> num_tasks_;

// Main worker thread loop.
void WorkerLoop(int thread_id) {
PerThread* pt = GetPerThread();
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/new_executor/run_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class RunQueue {
private:
static const unsigned kMask = kSize - 1;
static const unsigned kMask2 = (kSize << 1) - 1;
struct Elem {
struct alignas(64) Elem {
std::atomic<uint8_t> state;
Work w;
};
Expand All @@ -212,8 +212,8 @@ class RunQueue {
// position, these conditions would be indistinguishable); (2) obtain
// consistent snapshot of front_/back_ for Size operation using the
// modification counters.
std::atomic<unsigned> front_;
std::atomic<unsigned> back_;
alignas(64) std::atomic<unsigned> front_;
alignas(64) std::atomic<unsigned> back_;
Elem array_[kSize];

// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
Expand Down
150 changes: 120 additions & 30 deletions paddle/fluid/framework/new_executor/workqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,63 +6,153 @@

#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"

namespace paddle {
namespace framework {
namespace {

class SingleThreadedWorkQueue : public WorkQueue {
class WorkQueueImpl : public WorkQueue {
public:
SingleThreadedWorkQueue() : queue_(1) {}

SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete;

SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete;
explicit WorkQueueImpl(const WorkQueueOptions& options)
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
if (options_.track_task) {
tracker_ = new TaskTracker;
}
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
}

virtual ~SingleThreadedWorkQueue() = default;
virtual ~WorkQueueImpl() {
delete tracker_;
delete queue_;
}

void AddTask(std::function<void()> fn) override {
queue_.AddTask(std::move(fn));
if (tracker_ != nullptr) {
fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
]() mutable {
task();
};
}
queue_->AddTask(std::move(fn));
}

void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); }
void WaitQueueEmpty() override {
if (tracker_ == nullptr) {
abort();
}
tracker_->WaitTaskNumToZero();
}

size_t NumThreads() override { return queue_.NumThreads(); }
size_t NumThreads() const override { return queue_->NumThreads(); }

private:
NonblockingThreadPool queue_;
NonblockingThreadPool* queue_;
TaskTracker* tracker_;
};

std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue() {
std::unique_ptr<WorkQueue> ptr(new SingleThreadedWorkQueue);
return std::move(ptr);
class WorkQueueGroupImpl : public WorkQueueGroup {
public:
explicit WorkQueueGroupImpl(
const std::vector<WorkQueueOptions>& queue_options);

~WorkQueueGroupImpl();

void AddTask(size_t queue_idx, std::function<void()> fn) override;

void WaitQueueGroupEmpty() override;

size_t QueueNumThreads(size_t queue_idx) const override;

size_t QueueGroupNumThreads() const override;

private:
std::vector<NonblockingThreadPool*> queues_;
NonblockingThreadPool* queues_storage_;
TaskTracker* tracker_;
};

WorkQueueGroupImpl::WorkQueueGroupImpl(
const std::vector<WorkQueueOptions>& queues_options)
: WorkQueueGroup(queues_options),
queues_storage_(nullptr),
tracker_(nullptr) {
size_t num_queues = queues_options_.size();
queues_.resize(num_queues);
void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues);
queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr) {
tracker_ = new TaskTracker;
}
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning);
}
}

class MultiThreadedWorkQueue : public WorkQueue {
public:
explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) {
assert(num_threads > 1);
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
for (auto queue : queues_) {
queue->~NonblockingThreadPool();
}
delete tracker_;
free(queues_storage_);
}

MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete;
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
assert(queue_idx < queues_.size());
if (queues_options_[queue_idx].track_task) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是可以直接 if(nullptr == track_task_) ?

这里是支持 queues_options_ 部分track_task 为False, 部分为true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的

fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
]() mutable {
task();
};
}
queues_[queue_idx]->AddTask(std::move(fn));
}

MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete;
void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
if (nullptr == tracker_) {
abort();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果tracker_为空,这里是不是给一些error信息,表示我们不支持在未指定track_task时,调用WaitQueueGroupEmpty接口,直接abort会让开发者很困惑?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改为PADDLE_THROW

}
tracker_->WaitTaskNumToZero();
}

virtual ~MultiThreadedWorkQueue() = default;
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
assert(queue_idx < queues_.size());
return queues_[queue_idx]->NumThreads();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Release版本不会执行执行assert,这里会存在越界风险,添加检查或者改为 queues_.at(queue_idx)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

}

void AddTask(std::function<void()> fn) override {
queue_.AddTask(std::move(fn));
size_t WorkQueueGroupImpl::QueueGroupNumThreads() const {
size_t total_num = 0;
for (auto queue : queues_) {
total_num += queue->NumThreads();
}
return total_num;
}

void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); }
} // namespace

size_t NumThreads() override { return queue_.NumThreads(); }
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
const WorkQueueOptions& options) {
assert(options.num_threads == 1);
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
return std::move(ptr);
}

private:
NonblockingThreadPool queue_;
};
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
const WorkQueueOptions& options) {
assert(options.num_threads > 1);
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
return std::move(ptr);
}

std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(int num_threads) {
std::unique_ptr<WorkQueue> ptr(new MultiThreadedWorkQueue(num_threads));
std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
const std::vector<WorkQueueOptions>& queues_options) {
assert(queues_options.size() > 1);
std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options));
return std::move(ptr);
}

Expand Down
Loading