-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] C++ implementation of parallel executor #9035
Closed
tonyyang-svail
wants to merge
29
commits into
PaddlePaddle:develop
from
tonyyang-svail:cpp_parallel_executor
Closed
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
0621c32
init commit
e67325c
update readme
8f061e4
delete param name
a62d142
better name
65c7ed5
pass run
7fa64ee
switch to larger network
003b165
use all device
b2f4c5a
fix compile
07ee125
add broadcast
c3d6b86
add resnet
0760aaf
Shrink batch_norm_grad's inputs
reyoung 8699e4e
add vgg_bn_drop
4506d39
merge 9299
f07c25e
hook up data reader for multi-gpu executor example
helinwang b343ce3
add embedding
0a6a552
merge helin
5f1127c
take the device context improvement from reyoung; turn on thread safe…
helinwang 8b9884b
run multi gpu with recordio reader
bb07417
add share comment
9e5d957
change name
33ada99
scope: add replicas, used for multi gpu executor
helinwang 27d17e0
add wait on executor
7aad021
Add test parallel executor and transformer model from reyoung's PR
helinwang 069b726
test multi gpu executor in test_parallel_executor.py
helinwang 75d89f6
Merge remote-tracking branch 'pr/9035' into cpp_parallel_executor
924cada
make bn inplace in img_conv_group by default
799446b
add inplace attr to bn
e67070b
merge append act inplace
41f1a87
add in place
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# ParallelExecutor Design Doc | ||
|
||
## Introduction | ||
|
||
We introduce `ParallelExecutor` to run multi-GPU training in PaddlePaddle Fluid. It supports | ||
1. keeping a copy of the parameters on each GPU | ||
1. allreduce on a separate stream allowing computation and communication overlap | ||
|
||
An example of switching single GPU training to multiple GPUs: | ||
```python | ||
cost = your_neural_network() | ||
opt = fluid.optimizer.SGDOptimizer() | ||
opt.minimize(avg_cost) | ||
|
||
# change Executor -> ParallelExecutor | ||
exe = fluid.ParallelExecutor(gpu_list=[0, 1]) | ||
|
||
for iter in xranges(iter_num): | ||
exe.run() | ||
``` | ||
|
||
## Design | ||
|
||
In the constructor, a list of parameter, whose gradients need to be allreduced, is given. | ||
|
||
During the runtime, `ParallelExecutor` starts `#gpu` threads to run each `Executor`. For every | ||
operator run on each GPU, it will automatically sync with different streams when necessary. | ||
|
||
```c++ | ||
// if op's input is params' grad: | ||
// sync with allreduce stream | ||
// e.g. sgd should wait for allreduce to be finished | ||
CallBack->BeforeOp(op); | ||
|
||
op->Run(*local_scope, place_); | ||
|
||
// if op's output is params' grad: | ||
// sync with computation stream | ||
// e.g. allreduce shoudl wait for fc_grad to be finished. | ||
CallBack->AfterOp(op); | ||
``` | ||
|
||
And the `Callback` object can be implemented as the following | ||
|
||
```c++ | ||
struct AllReduceCallBack { | ||
void BeforeOp(framework::OperatorBase* op); | ||
void AfterOp(framework::OperatorBase* op); | ||
|
||
std::unordered_set<std::string> reduced_param_grad_names; | ||
std::unordered_set<std::string> param_grad_names_; | ||
|
||
platform::DeviceContext* computation_dev_ctx; // computation device context | ||
platform::DeviceContext* communication_dev_ctx; // communication device context | ||
|
||
framework::Scope* scope; | ||
platform::NCCL::Communicator* nccl_com; | ||
}; | ||
|
||
AllReduceCallBack::BeforeOp(framework::OperatorBase* op) { | ||
if (op->Input() in reduced_param_grad_names) { | ||
communication_dev_ctx->Wait(); | ||
reduced_param_grad_names.erase(op->Input()) | ||
} | ||
} | ||
|
||
AllReduceCallBack::AfterOp(framework::OperatorBase* op) { | ||
if (op->Output() in param_grad_names) { | ||
computation_dev_ctx->Wait(); | ||
reduced_param_grad_names.insert(op->Output()); | ||
ncclAllreduce(scope, op->Output(), communication_dev_ctx); | ||
} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/framework/multi_gpu_executor.h" | ||
|
||
#include <thread> | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/framework/reader.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
ExecutorWithAllReduce::ExecutorWithAllReduce( | ||
const platform::Place& place, std::unordered_set<std::string>* param_grads, | ||
NCCLContext* nccl_context) | ||
: Executor(place), param_grads_(param_grads) { | ||
int device_id = boost::get<platform::CUDAPlace>(place).device; | ||
comm_ = &nccl_context->comms_[device_id]; | ||
io_ctx_ = nccl_context->ctxs_[device_id]; | ||
} | ||
|
||
// TODO(yy): Move this function somewhere | ||
ncclDataType_t ToNCCLDataType(std::type_index type) { | ||
if (type == typeid(float)) { // NOLINT | ||
return ncclFloat; | ||
} else if (type == typeid(double)) { // NOLINT | ||
return ncclDouble; | ||
} else if (type == typeid(int)) { // NOLINT | ||
return ncclInt; | ||
} else { | ||
PADDLE_THROW("Not supported"); | ||
} | ||
} | ||
|
||
void ExecutorWithAllReduce::RunOperators(const ExecutorPrepareContext* ctx, | ||
const Scope* local_scope) const { | ||
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device); | ||
|
||
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); | ||
auto* dev_ctx = pool.Get(place_); | ||
cudaStream_t computation_stream = | ||
reinterpret_cast<const platform::CUDADeviceContext*>(dev_ctx)->stream(); | ||
cudaStream_t all_reduce_stream = io_ctx_->stream(); | ||
|
||
std::unordered_map<std::string, cudaEvent_t> computation_event; | ||
std::unordered_map<std::string, cudaEvent_t> all_reduce_event; | ||
for (auto& argu : *param_grads_) { | ||
PADDLE_ENFORCE(cudaEventCreateWithFlags(&computation_event[argu], | ||
cudaEventDisableTiming)); | ||
PADDLE_ENFORCE(cudaEventCreateWithFlags(&all_reduce_event[argu], | ||
cudaEventDisableTiming)); | ||
} | ||
|
||
for (auto& op : ctx->ops_) { | ||
// sgd should wait for allreduce finished | ||
for (auto& param2argu : op->Inputs()) { | ||
for (auto& argu : param2argu.second) { | ||
if (param_grads_->count(argu) != 0) { | ||
VLOG(5) << place_ << " " << op->Type() << param2argu.first << " " | ||
<< argu; | ||
PADDLE_ENFORCE(cudaStreamWaitEvent(computation_stream, | ||
all_reduce_event[argu], 0)); | ||
} | ||
} | ||
} | ||
|
||
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); | ||
op->Run(*local_scope, place_); | ||
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); | ||
|
||
for (auto& param2argu : op->Outputs()) { | ||
for (auto& argu : param2argu.second) { | ||
if (param_grads_->count(argu) != 0) { | ||
VLOG(5) << place_ << " " << op->Type() << " Launch allreduce on " | ||
<< argu; | ||
|
||
PADDLE_ENFORCE( | ||
cudaEventRecord(computation_event[argu], computation_stream)); | ||
PADDLE_ENFORCE(cudaStreamWaitEvent(all_reduce_stream, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to block the next computation op. We should profile and see how much it hurts |
||
computation_event[argu], 0)); | ||
|
||
auto& tensor = local_scope->FindVar(argu)->Get<LoDTensor>(); | ||
void* data = const_cast<void*>(tensor.data<void>()); | ||
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( | ||
data, data, tensor.numel(), ToNCCLDataType(tensor.type()), | ||
ncclSum, *comm_, all_reduce_stream)); | ||
|
||
PADDLE_ENFORCE( | ||
cudaEventRecord(all_reduce_event[argu], all_reduce_stream)); | ||
} | ||
} | ||
} | ||
} | ||
|
||
cudaStreamSynchronize(computation_stream); | ||
cudaStreamSynchronize(all_reduce_stream); | ||
for (auto& argu : *param_grads_) { | ||
PADDLE_ENFORCE(cudaEventDestroy(computation_event[argu])); | ||
PADDLE_ENFORCE(cudaEventDestroy(all_reduce_event[argu])); | ||
} | ||
} | ||
|
||
MultiGPUExecutor::MultiGPUExecutor( | ||
const std::vector<platform::Place>& places, | ||
const std::unordered_set<std::string>& params) | ||
: nccl_ctx_(places), params_(params) { | ||
for (auto& param : params) { | ||
param_grads_.insert(GradVarName(param)); | ||
} | ||
for (size_t i = 0; i < places.size(); ++i) { | ||
auto& place = places[i]; | ||
exes_.push_back( | ||
framework::ExecutorWithAllReduce(place, ¶m_grads_, &nccl_ctx_)); | ||
} | ||
} | ||
|
||
void MultiGPUExecutor::Init(const ProgramDesc& prog, Scope* scope, int block_id, | ||
bool create_local_scope, bool create_vars) { | ||
// init parameters on one device | ||
auto replicas = scope->replicas(exes_.size() - 1); | ||
exes_[0].Run(prog, scope, block_id, create_local_scope, create_vars); | ||
|
||
for (auto* var_desc : prog.Block(0).AllVars()) { | ||
if (var_desc->GetType() == proto::VarType::LOD_TENSOR) { | ||
auto& main_tensor = scope->FindVar(var_desc->Name())->Get<LoDTensor>(); | ||
ncclDataType_t data_type = ToNCCLDataType(main_tensor.type()); | ||
auto& dims = main_tensor.dims(); | ||
size_t numel = main_tensor.numel(); | ||
|
||
platform::dynload::ncclGroupStart(); | ||
for (size_t i = 0; i < exes_.size(); ++i) { | ||
void* buffer; | ||
if (i == 0) { | ||
buffer = const_cast<void*>(main_tensor.data<void>()); | ||
} else { | ||
auto local_scope = replicas[i - 1]; | ||
auto* t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>(); | ||
t->Resize(dims); | ||
buffer = t->mutable_data(exes_[i].place_, main_tensor.type()); | ||
} | ||
|
||
platform::dynload::ncclBcast(buffer, numel, data_type, 0, | ||
nccl_ctx_.comms_[i], | ||
nccl_ctx_.ctxs_[i]->stream()); | ||
} | ||
platform::dynload::ncclGroupEnd(); | ||
} else if (var_desc->GetType() == proto::VarType::READER) { | ||
VLOG(4) << "Copy reader " << var_desc->Name(); | ||
auto& reader = | ||
scope->FindVar(var_desc->Name())->Get<framework::ReaderHolder>(); | ||
for (size_t i = 0; i < replicas.size(); ++i) { | ||
auto* reader_dup = replicas[i] | ||
->Var(var_desc->Name()) | ||
->GetMutable<framework::ReaderHolder>(); | ||
*reader_dup = reader; | ||
} | ||
} | ||
} | ||
} | ||
|
||
void MultiGPUExecutor::Run(const ProgramDesc& prog, Scope* scope, int block_id, | ||
bool create_local_scope, bool create_vars) { | ||
// prepare prog in a single thread to avoid race | ||
auto context = exes_[0].Prepare(prog, block_id); | ||
auto replicas = scope->replicas(exes_.size() - 1); | ||
std::vector<std::thread> threads; | ||
for (size_t i = 0; i < exes_.size(); ++i) { | ||
threads.push_back(std::thread([&, i] { | ||
Scope* cur_scope; | ||
if (i == 0) { | ||
cur_scope = scope; | ||
} else { | ||
cur_scope = replicas[i - 1].get(); | ||
} | ||
exes_[i].RunPreparedContext(context.get(), cur_scope, create_local_scope, | ||
create_vars); | ||
})); | ||
} | ||
|
||
for (auto& t : threads) { | ||
t.join(); | ||
} | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of looping every input every time, perhaps the op can cache the param inputs and only wait for them