feature/parallel_do #6730
Changes from 7 commits
@@ -179,10 +179,13 @@ static const Tensor* GetTensorFromVar(const Variable* var) {
   const Tensor* t = nullptr;
   if (var->IsType<LoDTensor>()) {
     t = &(var->Get<LoDTensor>());
+  } else if (var->IsType<Tensor>()) {
+    t = &(var->Get<Tensor>());
   } else if (var->IsType<SelectedRows>()) {
     t = &(var->Get<SelectedRows>().value());
   } else {
-    PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
+                 var->Type().name());
   }
   return t;
 }

Review comment (on the new Tensor branch): A little confused about it: does the ParallelDo operator need it?
Reply: This is a leftover from debugging. Sorry for the confusion.

Review comment (on the new error message): expect LoDTensor/SelectedRows => expect LoDTensor/SelectedRows/Tensor?
Reply: This is a leftover from debugging. Sorry for the confusion.
@@ -191,10 +194,13 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
   Tensor* t = nullptr;
   if (var->IsType<LoDTensor>()) {
     t = var->GetMutable<LoDTensor>();
+  } else if (var->IsType<Tensor>()) {
+    t = var->GetMutable<Tensor>();
   } else if (var->IsType<SelectedRows>()) {
     t = var->GetMutable<SelectedRows>()->mutable_value();
   } else {
-    PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
+                 var->Type().name());
   }
   return t;
 }

Review comment (on the new error message): expect LoDTensor/SelectedRows => expect LoDTensor/SelectedRows/Tensor?
Reply: This is a leftover from debugging. Sorry for the confusion.
@@ -359,21 +365,27 @@ class RuntimeInferShapeContext : public InferShapeContext {
     Variable* var = scope_.FindVar(name);
     if (var->IsType<LoDTensor>()) {
       return var->Get<LoDTensor>().dims();
+    } else if (var->IsType<Tensor>()) {
+      return var->Get<Tensor>().dims();
     } else if (var->IsType<SelectedRows>()) {
       return var->Get<SelectedRows>().GetCompleteDims();
     } else {
-      PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
+                   name, var->Type().name());
     }
   }

   void SetDim(const std::string& name, const DDim& dim) override {
     Variable* var = scope_.FindVar(name);
     if (var->IsType<LoDTensor>()) {
       var->GetMutable<LoDTensor>()->Resize(dim);
+    } else if (var->IsType<Tensor>()) {
+      var->GetMutable<Tensor>()->Resize(dim);
     } else if (var->IsType<SelectedRows>()) {
       var->GetMutable<SelectedRows>()->set_height(dim[0]);
     } else {
-      PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
+                   name, var->Type().name());
     }
   }

Review comment (on both new error messages): Same as above.
Reply: This is a leftover from debugging. Sorry for the confusion.
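Editor's note: the pattern in these hunks is a runtime type dispatch over whichever payload a Variable holds. Below is a minimal, self-contained sketch of the same idea using toy stand-in types (not Paddle's actual Variable API): each branch recovers a Tensor view from a different container, and the fallback error reports the offending type instead of a fixed message.

#include <iostream>
#include <stdexcept>
#include <string>
#include <variant>

struct Tensor { std::string data; };
struct LoDTensor { Tensor base; };        // tensor plus level-of-detail info
struct SelectedRows {                     // sparse rows wrapping a dense value
  Tensor value_;
  const Tensor& value() const { return value_; }
};
using Variable = std::variant<LoDTensor, Tensor, SelectedRows>;

const Tensor* GetTensorFromVar(const Variable& var) {
  if (auto* t = std::get_if<LoDTensor>(&var)) return &t->base;
  if (auto* t = std::get_if<Tensor>(&var)) return t;
  if (auto* t = std::get_if<SelectedRows>(&var)) return &t->value();
  throw std::runtime_error("unsupported variable type");  // fallback branch
}

int main() {
  Variable v = SelectedRows{Tensor{"rows"}};
  std::cout << GetTensorFromVar(v)->data << "\n";  // prints "rows"
}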
@@ -0,0 +1,203 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Review comment: The license format is not correct, please refer to #7022.
Reply: done.
#include <vector>
#include "chunk_eval_op.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace operators {

constexpr char kInputs[] = "inputs";
constexpr char kParameters[] = "parameters";
constexpr char kPlaces[] = "places";

constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "parallel_scopes";

constexpr char kParallelBlock[] = "sub_block";

using ParallelScopeVar = std::vector<framework::Scope *>;
using OperatorBase = framework::OperatorBase;

class ParallelDoOp : public OperatorBase {
 public:
  ParallelDoOp(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
    auto *program = block->Program();

    // TODO(tonyyang-svail): get places from input
    std::vector<platform::Place> places;
    places.emplace_back(platform::CPUPlace());
    places.emplace_back(platform::CPUPlace());

    std::vector<framework::Scope *> sub_scopes;
    for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
      VLOG(3) << "Run " << place_idx;

      sub_scopes.push_back(&scope.NewScope());

      auto &place = places[place_idx];
      auto *cur_scope = sub_scopes[place_idx];

      // copy parameter
      if (dev_ctx.GetPlace() != place) {
        PADDLE_THROW("Not Implemented");
      }

      // feed input
      for (auto &argu : Inputs(kInputs)) {
        auto *var = scope.FindVar(argu);
        const auto &tensor = var->Get<LoDTensor>();
        if (!tensor.lod().empty()) {
          PADDLE_THROW("Disable parallel lod for now");
        } else {
          PADDLE_ENFORCE(tensor.dims()[0] % places.size() == 0,
                         "Batch size should be divided by places size");
          int begin = place_idx * tensor.dims()[0] / places.size();
          int end = (place_idx + 1) * tensor.dims()[0] / places.size();
          auto feed_tensor = tensor.Slice(begin, end);
          feed_tensor.switch_place(place);

          auto *cur_var = cur_scope->Var(argu);
          auto *cur_tensor = cur_var->GetMutable<Tensor>();
          *cur_tensor = feed_tensor;
        }
      }

      // execute
      auto executor = framework::Executor(place);
      executor.Run(*program, cur_scope, block->ID(),
                   false /*create_local_scope*/);
    }

    // merge output
    for (auto &o_name : Outputs(kOutputs)) {
      std::vector<const framework::LoDTensor *> lod_tensors;
      for (auto *sub_scope : sub_scopes) {
        lod_tensors.push_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
      }

      auto *lod_tensor_to_be_merged =
          scope.FindVar(o_name)->GetMutable<LoDTensor>();
      lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
    }
  }
};
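Editor's note: the feed step above splits the batch dimension evenly across places (place i receives rows [i*N/P, (i+1)*N/P) of each input), and the merge step concatenates the per-place outputs back. A standalone sketch of just that scatter/gather arithmetic, with toy values and no Paddle dependencies; it also mirrors the PADDLE_ENFORCE divisibility check:

#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const int batch_size = 8;   // stands in for tensor.dims()[0]
  const int num_places = 2;   // stands in for places.size()
  assert(batch_size % num_places == 0);  // "Batch size should be divided by places size"

  std::vector<int> batch(batch_size);
  std::iota(batch.begin(), batch.end(), 0);  // rows 0..7

  // Scatter: place i is fed rows [i*N/P, (i+1)*N/P), as in the feed loop.
  std::vector<std::vector<int>> per_place(num_places);
  for (int place_idx = 0; place_idx < num_places; ++place_idx) {
    int begin = place_idx * batch_size / num_places;
    int end = (place_idx + 1) * batch_size / num_places;
    per_place[place_idx].assign(batch.begin() + begin, batch.begin() + end);
    std::cout << "place " << place_idx << " feeds rows [" << begin << ", "
              << end << ")\n";  // place 0: [0, 4), place 1: [4, 8)
  }

  // Gather: concatenating the per-place slices restores the original batch,
  // which is what the merge-output step does for the results.
  std::vector<int> merged;
  for (const auto& slice : per_place)
    merged.insert(merged.end(), slice.begin(), slice.end());
  assert(merged == batch);
}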
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ParallelDoOpProtoMaker(framework::OpProto *proto,
                         framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(kInputs, "").AsDuplicable();
    AddInput(kParameters, "").AsDuplicable();
    AddInput(kPlaces, "");
    AddOutput(kOutputs, "").AsDuplicable();
    AddOutput(kParallelScopes, "");
    AddAttr<framework::BlockDescBind *>(kParallelBlock, "");
    AddComment(R"DOC(
ParallelDo Operator.
)DOC");
  }
};
class ParallelDoGradOp : public OperatorBase {
 public:
  ParallelDoGradOp(const std::string &type,
                   const framework::VariableNameMap &inputs,
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {}
};

class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  virtual std::unique_ptr<framework::OpDescBind> Apply() const {
    auto *grad = new framework::OpDescBind();
    grad->SetType("parallel_do_grad");
    for (auto &input_param : this->InputNames()) {
      LOG(INFO) << input_param;
      grad->SetInput(input_param, this->Input(input_param));
      grad->SetOutput(framework::GradVarName(input_param),
                      this->InputGrad(input_param));
    }

    for (auto &output_param : this->OutputNames()) {
      if (output_param == kParallelScopes) {
        grad->SetInput(output_param, this->Output(output_param));
        grad->SetInput(framework::GradVarName(output_param),
                       this->Output(output_param));
      } else {
        grad->SetInput(output_param, this->Output(output_param));
        grad->SetInput(framework::GradVarName(output_param),
                       this->OutputGrad(output_param));
      }
    }
    grad->SetAttrMap(this->Attrs());
    grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);

    return std::unique_ptr<framework::OpDescBind>(grad);
  }
};
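Editor's note: the desc maker wires each forward variable name to a gradient name via framework::GradVarName. A toy sketch of that naming convention; the "@GRAD" suffix is assumed here from the framework's usual convention, and all names are illustrative:

#include <iostream>
#include <string>

// Assumed convention: the gradient of variable "x" is named "x@GRAD".
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  // parallel_do_grad takes each forward input X and emits an output X@GRAD.
  for (const char* in : {"inputs", "parameters"}) {
    std::cout << in << " -> " << GradVarName(in) << "\n";
  }
}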
class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    std::vector<std::string> input{kParameters, kInputs};
    std::vector<std::string> output{kOutputs};
    for (auto &s : input) {
      PADDLE_ENFORCE(ctx->HasInputs(s));
      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
                     "Cannot find the gradient variable %s",
                     framework::GradVarName(s));
    }
    for (auto &s : output) {
      PADDLE_ENFORCE(ctx->HasInputs(s));
    }
    for (auto &s : input) {
      ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
    }
    if (ctx->HasInputs(kParameters)) {
      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
      ctx->SetOutputsDim(framework::GradVarName(kParameters),
                         ctx->GetInputsDim(kParameters));
    }
  }
};
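Editor's note: the invariant this shape inference enforces is that each gradient variable takes exactly the shape of its forward counterpart. A minimal sketch of that rule with stand-in types (DDim reduced to a vector of ints, the "@GRAD" suffix assumed, and all names illustrative):

#include <cassert>
#include <map>
#include <string>
#include <vector>

using DDim = std::vector<int>;  // stand-in for framework::DDim

int main() {
  std::map<std::string, DDim> forward_dims = {{"inputs", {8, 32}},
                                              {"parameters", {32, 16}}};
  std::map<std::string, DDim> grad_dims;
  // Mirrors ctx->SetOutputsDim(GradVarName(s), ctx->GetInputsDim(s)).
  for (const auto& kv : forward_dims) grad_dims[kv.first + "@GRAD"] = kv.second;
  assert(grad_dims["inputs@GRAD"] == forward_dims["inputs"]);
  assert(grad_dims["parameters@GRAD"] == forward_dims["parameters"]);
}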
}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
                  paddle::operators::ParallelDoOpProtoMaker,
                  paddle::operators::ParallelDoGradOpDescMaker);
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
                  paddle::operators::ParallelDoGradOpShapeInference);
Review comment: Do we need to add comments about this function?
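Editor's note: for context on what REGISTER_OPERATOR accomplishes, it binds the op-type string to the classes above so the framework can construct the operator by name at runtime. A self-contained toy of that registry pattern; everything here is illustrative, not Paddle's actual internals:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OpBase {
  virtual ~OpBase() = default;
  virtual void Run() = 0;
};

// Global name -> factory map: the essence of an operator registry.
std::map<std::string, std::function<std::unique_ptr<OpBase>()>>& Registry() {
  static std::map<std::string, std::function<std::unique_ptr<OpBase>()>> r;
  return r;
}

struct ToyParallelDo : OpBase {
  void Run() override { std::cout << "running parallel_do\n"; }
};

// Registration at static-initialization time, roughly what a REGISTER_*
// macro expands to.
const bool registered = [] {
  Registry()["parallel_do"] = [] { return std::make_unique<ToyParallelDo>(); };
  return true;
}();

int main() { Registry()["parallel_do"]()->Run(); }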