-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Work queue group #35470
Work queue group #35470
Changes from 5 commits
256dc2e
5fd3aaf
c6e1fb5
849cd0d
b772cd5
dff2edf
ec70e5c
13234c6
2306737
c0f91fc
57791b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,63 +6,153 @@ | |
|
||
#include "paddle/fluid/framework/new_executor/workqueue.h" | ||
#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" | ||
#include "paddle/fluid/framework/new_executor/workqueue_utils.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace { | ||
|
||
class SingleThreadedWorkQueue : public WorkQueue { | ||
class WorkQueueImpl : public WorkQueue { | ||
public: | ||
SingleThreadedWorkQueue() : queue_(1) {} | ||
|
||
SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete; | ||
|
||
SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete; | ||
explicit WorkQueueImpl(const WorkQueueOptions& options) | ||
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) { | ||
if (options_.track_task) { | ||
tracker_ = new TaskTracker; | ||
} | ||
queue_ = new NonblockingThreadPool(options_.num_threads, | ||
options_.allow_spinning); | ||
} | ||
|
||
virtual ~SingleThreadedWorkQueue() = default; | ||
virtual ~WorkQueueImpl() { | ||
delete tracker_; | ||
delete queue_; | ||
} | ||
|
||
void AddTask(std::function<void()> fn) override { | ||
queue_.AddTask(std::move(fn)); | ||
if (tracker_ != nullptr) { | ||
fn = [ | ||
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_) | ||
]() mutable { | ||
task(); | ||
}; | ||
} | ||
queue_->AddTask(std::move(fn)); | ||
} | ||
|
||
void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } | ||
void WaitQueueEmpty() override { | ||
if (tracker_ == nullptr) { | ||
abort(); | ||
} | ||
tracker_->WaitTaskNumToZero(); | ||
} | ||
|
||
size_t NumThreads() override { return queue_.NumThreads(); } | ||
size_t NumThreads() const override { return queue_->NumThreads(); } | ||
|
||
private: | ||
NonblockingThreadPool queue_; | ||
NonblockingThreadPool* queue_; | ||
TaskTracker* tracker_; | ||
}; | ||
|
||
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue() { | ||
std::unique_ptr<WorkQueue> ptr(new SingleThreadedWorkQueue); | ||
return std::move(ptr); | ||
class WorkQueueGroupImpl : public WorkQueueGroup { | ||
public: | ||
explicit WorkQueueGroupImpl( | ||
const std::vector<WorkQueueOptions>& queue_options); | ||
|
||
~WorkQueueGroupImpl(); | ||
|
||
void AddTask(size_t queue_idx, std::function<void()> fn) override; | ||
|
||
void WaitQueueGroupEmpty() override; | ||
|
||
size_t QueueNumThreads(size_t queue_idx) const override; | ||
|
||
size_t QueueGroupNumThreads() const override; | ||
|
||
private: | ||
std::vector<NonblockingThreadPool*> queues_; | ||
NonblockingThreadPool* queues_storage_; | ||
TaskTracker* tracker_; | ||
}; | ||
|
||
WorkQueueGroupImpl::WorkQueueGroupImpl( | ||
const std::vector<WorkQueueOptions>& queues_options) | ||
: WorkQueueGroup(queues_options), | ||
queues_storage_(nullptr), | ||
tracker_(nullptr) { | ||
size_t num_queues = queues_options_.size(); | ||
queues_.resize(num_queues); | ||
void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues); | ||
queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer); | ||
for (size_t idx = 0; idx < num_queues; ++idx) { | ||
const auto& options = queues_options_[idx]; | ||
if (options.track_task && tracker_ == nullptr) { | ||
tracker_ = new TaskTracker; | ||
} | ||
queues_[idx] = new (&queues_storage_[idx]) | ||
NonblockingThreadPool(options.num_threads, options.allow_spinning); | ||
} | ||
} | ||
|
||
class MultiThreadedWorkQueue : public WorkQueue { | ||
public: | ||
explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) { | ||
assert(num_threads > 1); | ||
WorkQueueGroupImpl::~WorkQueueGroupImpl() { | ||
for (auto queue : queues_) { | ||
queue->~NonblockingThreadPool(); | ||
} | ||
delete tracker_; | ||
free(queues_storage_); | ||
} | ||
|
||
MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete; | ||
void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) { | ||
assert(queue_idx < queues_.size()); | ||
if (queues_options_[queue_idx].track_task) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是不是可以直接 if(nullptr == track_task_) ? 这里是支持 queues_options_ 部分track_task 为False, 部分为true? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的 |
||
fn = [ | ||
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_) | ||
]() mutable { | ||
task(); | ||
}; | ||
} | ||
queues_[queue_idx]->AddTask(std::move(fn)); | ||
} | ||
|
||
MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete; | ||
void WorkQueueGroupImpl::WaitQueueGroupEmpty() { | ||
if (nullptr == tracker_) { | ||
abort(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果tracker_为空,这里是不是给一些error信息,表示我们不支持在未指定track_task时,调用WaitQueueGroupEmpty接口,直接abort会让开发者很困惑? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 修改为PADDLE_THROW |
||
} | ||
tracker_->WaitTaskNumToZero(); | ||
} | ||
|
||
virtual ~MultiThreadedWorkQueue() = default; | ||
size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const { | ||
assert(queue_idx < queues_.size()); | ||
return queues_[queue_idx]->NumThreads(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Release版本不会执行执行assert,这里会存在越界风险,添加检查或者改为 queues_.at(queue_idx) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
} | ||
|
||
void AddTask(std::function<void()> fn) override { | ||
queue_.AddTask(std::move(fn)); | ||
size_t WorkQueueGroupImpl::QueueGroupNumThreads() const { | ||
size_t total_num = 0; | ||
for (auto queue : queues_) { | ||
total_num += queue->NumThreads(); | ||
} | ||
return total_num; | ||
} | ||
|
||
void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } | ||
} // namespace | ||
|
||
size_t NumThreads() override { return queue_.NumThreads(); } | ||
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue( | ||
const WorkQueueOptions& options) { | ||
assert(options.num_threads == 1); | ||
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options)); | ||
return std::move(ptr); | ||
} | ||
|
||
private: | ||
NonblockingThreadPool queue_; | ||
}; | ||
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue( | ||
const WorkQueueOptions& options) { | ||
assert(options.num_threads > 1); | ||
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options)); | ||
return std::move(ptr); | ||
} | ||
|
||
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(int num_threads) { | ||
std::unique_ptr<WorkQueue> ptr(new MultiThreadedWorkQueue(num_threads)); | ||
std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup( | ||
const std::vector<WorkQueueOptions>& queues_options) { | ||
assert(queues_options.size() > 1); | ||
std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options)); | ||
return std::move(ptr); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
options不需要赋值吗?我看默认线程数是0。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修复