Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work queue group #35470

Merged
merged 11 commits into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cc_library(workqueue SRCS workqueue.cc)
cc_library(workqueue SRCS workqueue.cc DEPS enforce)
cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
garbages_.reset(new GarbageQueue());
max_memory_size_ = static_cast<size_t>(GetEagerDeletionThreshold());
cur_memory_size_ = 0;
gc_queue_ = CreateSingleThreadedWorkQueue();
WorkQueueOptions options;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

options不需要赋值吗?我看默认线程数是0。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修复

options.num_threads = 1;
gc_queue_ = CreateSingleThreadedWorkQueue(options);

feed_names_ = feed_names;

Expand Down
102 changes: 34 additions & 68 deletions paddle/fluid/framework/new_executor/nonblocking_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,46 @@
namespace paddle {
namespace framework {

class CounterTracker {
class TaskTracker {
public:
explicit CounterTracker(std::atomic<uint64_t>* counter, EventCount* ec)
: counter_(counter), ec_(ec) {
counter_->fetch_add(1, std::memory_order_relaxed);
}
TaskTracker() : wait_empty_cv_(1) {}

~CounterTracker() {
if (counter_ != nullptr) {
if (1 == counter_->fetch_sub(1, std::memory_order_relaxed)) {
ec_->Notify(true);
}
}
}
TaskTracker(const TaskTracker&) = delete;

CounterTracker(CounterTracker&& other)
: counter_(other.counter_), ec_(other.ec_) {
other.counter_ = nullptr;
other.ec_ = nullptr;
}
TaskTracker& operator=(const TaskTracker&) = delete;

CounterTracker& operator=(CounterTracker&& other) {
counter_ = other.counter_;
ec_ = other.ec_;
other.counter_ = nullptr;
other.ec_ = nullptr;
return *this;
}
~TaskTracker() = default;

CounterTracker(const CounterTracker& other)
: counter_(other.counter_), ec_(other.ec_) {
counter_->fetch_add(1, std::memory_order_relaxed);
void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }

void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
wait_empty_cv_.Notify(true);
}
}

CounterTracker& operator=(const CounterTracker&) = delete;
// only one user can wait at any time
void WaitTaskNumToZero() {
bool waiting = false;
if (!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
abort();
}
EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
wait_empty_cv_.Prewait();
if (num_tasks_.load(std::memory_order_relaxed) == 0) {
wait_empty_cv_.CancelWait();
} else {
wait_empty_cv_.CommitWait(w);
}
wait_empty_.store(false);
}

private:
std::atomic<uint64_t>* counter_{nullptr};
EventCount* ec_{nullptr};
std::atomic<uint64_t> num_tasks_{0};
EventCount wait_empty_cv_;
std::atomic<bool> wait_empty_{false};
};

template <typename Environment>
Expand All @@ -66,9 +67,6 @@ class ThreadPoolTempl {
typedef typename Environment::Task Task;
typedef RunQueue<Task, 1024> Queue;

explicit ThreadPoolTempl(int num_threads, Environment env = Environment())
: ThreadPoolTempl(num_threads, true, env) {}

ThreadPoolTempl(int num_threads, bool allow_spinning,
Environment env = Environment())
: env_(env),
Expand All @@ -80,10 +78,7 @@ class ThreadPoolTempl {
spinning_(0),
done_(false),
cancelled_(false),
ec_(num_threads_),
wait_empty_(false),
wait_empty_ec_(1),
num_tasks_(0) {
ec_(num_threads_) {
// Calculate coprimes of all numbers [1, num_threads].
// Coprimes are used for random walks over all threads in Steal
// and NonEmptyQueueIndex. Iteration is based on the fact that if we take
Expand Down Expand Up @@ -146,15 +141,13 @@ class ThreadPoolTempl {
}

void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
Task t = env_.CreateTask([
task = std::move(fn), raii = CounterTracker(&num_tasks_, &wait_empty_ec_)
]() mutable { task(); });
Task t = env_.CreateTask(std::move(fn));
PerThread* pt = GetPerThread();
if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue.
Queue& q = thread_data_[pt->thread_id].queue;
t = q.PushFront(std::move(t));
} else if (wait_empty_.load() == false) {
} else {
// A free-standing thread (or worker of another pool), push onto a random
// queue.
assert(start < limit);
Expand All @@ -179,29 +172,6 @@ class ThreadPoolTempl {
}
}

void WaitQueueEmpty() {
bool waiting = wait_empty_.load();
assert(waiting == false);
if (waiting ||
!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_acquire)) {
abort();
}
EventCount::Waiter* w = wait_empty_ec_.GetWaiter(0);
wait_empty_ec_.Prewait();
if (num_tasks_.load() == 0) {
wait_empty_ec_.CancelWait();
} else {
wait_empty_ec_.CommitWait(w);
}
waiting = true;
if (!waiting ||
!wait_empty_.compare_exchange_strong(waiting, false,
std::memory_order_acquire)) {
abort();
}
}

void Cancel() {
cancelled_ = true;
done_ = true;
Expand Down Expand Up @@ -300,10 +270,6 @@ class ThreadPoolTempl {
std::atomic<bool> cancelled_;
EventCount ec_;

std::atomic<bool> wait_empty_;
EventCount wait_empty_ec_;
std::atomic<uint64_t> num_tasks_;

// Main worker thread loop.
void WorkerLoop(int thread_id) {
PerThread* pt = GetPerThread();
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/new_executor/run_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class RunQueue {
private:
static const unsigned kMask = kSize - 1;
static const unsigned kMask2 = (kSize << 1) - 1;
struct Elem {
struct alignas(64) Elem {
std::atomic<uint8_t> state;
Work w;
};
Expand All @@ -212,8 +212,8 @@ class RunQueue {
// position, these conditions would be indistinguishable); (2) obtain
// consistent snapshot of front_/back_ for Size operation using the
// modification counters.
std::atomic<unsigned> front_;
std::atomic<unsigned> back_;
alignas(64) std::atomic<unsigned> front_;
alignas(64) std::atomic<unsigned> back_;
Elem array_[kSize];

// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
Expand Down
165 changes: 135 additions & 30 deletions paddle/fluid/framework/new_executor/workqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,63 +6,168 @@

#include "paddle/fluid/framework/new_executor/workqueue.h"

#include <memory>

#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace {

class SingleThreadedWorkQueue : public WorkQueue {
class WorkQueueImpl : public WorkQueue {
public:
SingleThreadedWorkQueue() : queue_(1) {}

SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete;

SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete;
explicit WorkQueueImpl(const WorkQueueOptions& options)
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
if (options_.track_task) {
tracker_ = new TaskTracker;
}
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
}

virtual ~SingleThreadedWorkQueue() = default;
virtual ~WorkQueueImpl() {
delete tracker_;
delete queue_;
}

void AddTask(std::function<void()> fn) override {
queue_.AddTask(std::move(fn));
if (tracker_ != nullptr) {
fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
]() mutable {
task();
};
}
queue_->AddTask(std::move(fn));
}

void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); }
void WaitQueueEmpty() override {
if (tracker_ == nullptr) {
PADDLE_THROW(
platform::errors::Unavailable("set WorkQueueOptions.track_task = "
"true before call this interface."));
}
tracker_->WaitTaskNumToZero();
}

size_t NumThreads() override { return queue_.NumThreads(); }
size_t NumThreads() const override { return queue_->NumThreads(); }

private:
NonblockingThreadPool queue_;
NonblockingThreadPool* queue_;
TaskTracker* tracker_;
};

std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue() {
std::unique_ptr<WorkQueue> ptr(new SingleThreadedWorkQueue);
return std::move(ptr);
// A set of WorkQueues that share one TaskTracker. Tasks are dispatched to a
// queue by index; WaitQueueGroupEmpty() blocks until every tracked task in
// the whole group has finished.
class WorkQueueGroupImpl : public WorkQueueGroup {
 public:
  explicit WorkQueueGroupImpl(
      const std::vector<WorkQueueOptions>& queue_options);

  ~WorkQueueGroupImpl();

  // Enqueue fn on the queue_idx-th queue (asserts queue_idx is in range).
  void AddTask(size_t queue_idx, std::function<void()> fn) override;

  // Blocks until the shared task counter reaches zero. Requires at least one
  // queue with track_task = true.
  void WaitQueueGroupEmpty() override;

  // Thread count of a single queue.
  size_t QueueNumThreads(size_t queue_idx) const override;

  // Sum of thread counts over all queues.
  size_t QueueGroupNumThreads() const override;

 private:
  // Non-owning views into queues_storage_ (constructed there via placement
  // new; destroyed and freed manually in the destructor).
  std::vector<NonblockingThreadPool*> queues_;
  NonblockingThreadPool* queues_storage_;
  // Shared by every queue that requested tracking; nullptr if none did.
  TaskTracker* tracker_;
};

// Builds every pool in one contiguous malloc'd chunk via placement new; the
// destructor runs the destructors manually and frees the chunk.
WorkQueueGroupImpl::WorkQueueGroupImpl(
    const std::vector<WorkQueueOptions>& queues_options)
    : WorkQueueGroup(queues_options),
      queues_storage_(nullptr),
      tracker_(nullptr) {
  size_t num_queues = queues_options_.size();
  queues_.resize(num_queues);
  void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues);
  // malloc can fail; placement new on a null buffer would be UB, so fail
  // loudly instead of crashing later.
  PADDLE_ENFORCE_NOT_NULL(
      buffer, platform::errors::ResourceExhausted(
                  "Failed to allocate storage for WorkQueueGroup."));
  queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
  for (size_t idx = 0; idx < num_queues; ++idx) {
    const auto& options = queues_options_[idx];
    // One tracker is shared by all queues that request task tracking.
    if (options.track_task && tracker_ == nullptr) {
      tracker_ = new TaskTracker;
    }
    queues_[idx] = new (&queues_storage_[idx])
        NonblockingThreadPool(options.num_threads, options.allow_spinning);
  }
}

class MultiThreadedWorkQueue : public WorkQueue {
public:
explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) {
assert(num_threads > 1);
// Tear down the group: the placement-new'd pools need explicit destructor
// calls before their raw storage chunk is freed; the shared tracker is
// released after the pools (and their worker threads) are gone.
WorkQueueGroupImpl::~WorkQueueGroupImpl() {
  for (size_t idx = 0; idx < queues_.size(); ++idx) {
    queues_[idx]->~NonblockingThreadPool();
  }
  delete tracker_;
  free(queues_storage_);
}

MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete;
// Dispatch fn to the queue_idx-th pool, wrapping it with a CounterGuard
// when that queue participates in task tracking.
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
  assert(queue_idx < queues_.size());
  const bool needs_tracking = queues_options_.at(queue_idx).track_task;
  if (needs_tracking) {
    // The guard bumps the shared counter now and drops it when the wrapped
    // callable is destroyed after running.
    std::function<void()> wrapped = [
      task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
    ]() mutable { task(); };
    fn = std::move(wrapped);
  }
  queues_[queue_idx]->AddTask(std::move(fn));
}

MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete;
// Block the caller until the group's shared task counter hits zero.
// Only valid when at least one queue was created with track_task = true.
void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
  if (tracker_ == nullptr) {
    PADDLE_THROW(platform::errors::Unavailable(
        "set WorkQueueOptions.track_task = true for at least one of queues "
        "before call this interface."));
  }
  tracker_->WaitTaskNumToZero();
}

virtual ~MultiThreadedWorkQueue() = default;
// Number of worker threads owned by the queue_idx-th queue.
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
  assert(queue_idx < queues_.size());
  const auto* queue = queues_.at(queue_idx);
  return queue->NumThreads();
}

void AddTask(std::function<void()> fn) override {
queue_.AddTask(std::move(fn));
// Total worker threads across every queue in the group.
size_t WorkQueueGroupImpl::QueueGroupNumThreads() const {
  size_t total = 0;
  for (const auto* queue : queues_) {
    total += queue->NumThreads();
  }
  return total;
}

void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); }
} // namespace

size_t NumThreads() override { return queue_.NumThreads(); }
// Factory for a one-thread WorkQueue. Validates that the options actually
// request a single thread before constructing the queue.
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
    const WorkQueueOptions& options) {
  PADDLE_ENFORCE_EQ(options.num_threads, 1u,
                    platform::errors::InvalidArgument(
                        "For a SingleThreadedWorkQueue, "
                        "WorkQueueOptions.num_threads must equals to 1."));
  // Return the local directly: `return std::move(ptr);` is a pessimizing
  // move that defeats NRVO (clang's -Wpessimizing-move flags it).
  std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
  return ptr;
}

private:
NonblockingThreadPool queue_;
};
// Factory for a multi-thread WorkQueue. Validates that more than one thread
// was requested before constructing the queue.
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
    const WorkQueueOptions& options) {
  PADDLE_ENFORCE_GT(
      options.num_threads, 1u,
      platform::errors::InvalidArgument("For a MultiThreadedWorkQueue, "
                                        "WorkQueueOptions.num_threads must be "
                                        "greater than 1."));
  // Return the local directly: `return std::move(ptr);` is a pessimizing
  // move that defeats NRVO (clang's -Wpessimizing-move flags it).
  std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
  return ptr;
}

std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(int num_threads) {
std::unique_ptr<WorkQueue> ptr(new MultiThreadedWorkQueue(num_threads));
// Factory for a WorkQueueGroup. A group of a single queue would be pointless
// (use CreateSingleThreadedWorkQueue instead), so at least two option sets
// are required.
std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
    const std::vector<WorkQueueOptions>& queues_options) {
  PADDLE_ENFORCE_GT(queues_options.size(), 1u,
                    platform::errors::InvalidArgument(
                        "For a WorkQueueGroup, the number of WorkQueueOptions "
                        "must be greater than 1."));
  // Return the local directly: `return std::move(ptr);` is a pessimizing
  // move that defeats NRVO (clang's -Wpessimizing-move flags it).
  std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options));
  return ptr;
}

Expand Down
Loading