-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature/parallel_do #6730
feature/parallel_do #6730
Changes from 13 commits
9d2c77e
d34ed7b
973aec2
aea5ccc
b2ee919
2f56d4b
f899150
f879ef2
cb0b81f
6004a2e
9313233
8ee17e9
7411df3
fccbc2f
97dc451
60e27d1
8496b2e
0156066
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,10 +182,13 @@ static const Tensor* GetTensorFromVar(const Variable* var) { | |
const Tensor* t = nullptr; | ||
if (var->IsType<LoDTensor>()) { | ||
t = &(var->Get<LoDTensor>()); | ||
} else if (var->IsType<Tensor>()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A little confused about it: do the ParallelDo operator need it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a leftover from debugging. Sorry for the confusion. |
||
t = &(var->Get<Tensor>()); | ||
} else if (var->IsType<SelectedRows>()) { | ||
t = &(var->Get<SelectedRows>().value()); | ||
} else { | ||
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); | ||
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. expect LoDTensor/SelectedRows => expect LoDTensor/SelectedRows/Tensor? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a leftover from debugging. Sorry for the confusion. |
||
var->Type().name()); | ||
} | ||
return t; | ||
} | ||
|
@@ -194,10 +197,13 @@ static Tensor* GetMutableTensorFromVar(Variable* var) { | |
Tensor* t = nullptr; | ||
if (var->IsType<LoDTensor>()) { | ||
t = var->GetMutable<LoDTensor>(); | ||
} else if (var->IsType<Tensor>()) { | ||
t = var->GetMutable<Tensor>(); | ||
} else if (var->IsType<SelectedRows>()) { | ||
t = var->GetMutable<SelectedRows>()->mutable_value(); | ||
} else { | ||
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); | ||
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. expect LoDTensor/SelectedRows => expect LoDTensor/SelectedRows/Tensor? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a leftover from debugging. Sorry for the confusion. |
||
var->Type().name()); | ||
} | ||
return t; | ||
} | ||
|
@@ -356,21 +362,27 @@ class RuntimeInferShapeContext : public InferShapeContext { | |
Variable* var = scope_.FindVar(name); | ||
if (var->IsType<LoDTensor>()) { | ||
return var->Get<LoDTensor>().dims(); | ||
} else if (var->IsType<Tensor>()) { | ||
return var->Get<Tensor>().dims(); | ||
} else if (var->IsType<SelectedRows>()) { | ||
return var->Get<SelectedRows>().GetCompleteDims(); | ||
} else { | ||
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); | ||
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a leftover from debugging. Sorry for the confusion. |
||
name, var->Type().name()); | ||
} | ||
} | ||
|
||
void SetDim(const std::string& name, const DDim& dim) override { | ||
Variable* var = scope_.FindVar(name); | ||
if (var->IsType<LoDTensor>()) { | ||
var->GetMutable<LoDTensor>()->Resize(dim); | ||
} else if (var->IsType<Tensor>()) { | ||
var->GetMutable<Tensor>()->Resize(dim); | ||
} else if (var->IsType<SelectedRows>()) { | ||
var->GetMutable<SelectedRows>()->set_height(dim[0]); | ||
} else { | ||
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); | ||
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a leftover from debugging. Sorry for the confusion. |
||
name, var->Type().name()); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add comments about this function?