-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refine InferShape for recurrent_network_op #3124
Conversation
* the tensor only contains shape and does not hold memory when inferring shape.
… rnn_infershape
step_input->Resize(step_dims); | ||
} | ||
} | ||
} | ||
|
||
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, | ||
const std::vector<Link>& outlinks, | ||
const size_t seq_len) { | ||
const size_t seq_len, | ||
bool infer_shape) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_infer or infer_mode?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
infer_shape_mode
@@ -97,18 +104,14 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes, | |||
std::shared_ptr<Scope> scope = scopes[step_id]; | |||
std::shared_ptr<Scope> linked_scope = scopes[step_id + offset]; | |||
for (auto& attr : memories) { | |||
auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>(); | |||
auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need an enforce on scope->GetVariable(xxx) != nullptr
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
for (size_t i = 0; i < seq_len_; i++) { | ||
if (i > 0) { | ||
rnn::LinkMemories(step_scopes, arg_->memories, i, -1); | ||
rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
true /* infer_mode */
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>(); | ||
output->Resize(make_ddim(dims_vec)); | ||
} | ||
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as top
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
if (step_id > 0) { | ||
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); | ||
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as the top
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
net->AddOp( | ||
OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); | ||
|
||
net->AddOp( | ||
OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); | ||
OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {})); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about an inline function for the action of adding the @alias
suffix?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This unit test may be moved to Python, so I did not add an inline function for adding the @alias suffix.
@@ -318,7 +313,7 @@ class RecurrentGradientAlgorithmTest : public ::testing::Test { | |||
scope_->GetVariable("step_scopes") | |||
->GetMutable<std::vector<std::shared_ptr<Scope>>>(); | |||
for (int i = 1; i < 10; ++i) { | |||
rnn::LinkMemories(*step_scopes, memories, i, -1); | |||
rnn::LinkMemories(*step_scopes, memories, i, -1, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as the top, `true /* infer_mode */`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
recurrentOp里可以调用infershape,也可以不调用infershape?请问分别是什么时候呢 |
@@ -347,7 +342,7 @@ TEST(RecurrentOp, LinkMemories) { | |||
scope->CreateVariable("pre_h"); | |||
auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>(); | |||
float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace()); | |||
for (int i = 0; i < 15 * 20; ++i) { | |||
for (int j = 0; j < 15 * 20; ++j) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data[i] ==> data[j] ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -29,7 +29,8 @@ namespace rnn { | |||
|
|||
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes, | |||
const std::vector<Link>& inlinks, | |||
const size_t seq_len) { | |||
const size_t seq_len, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const size_t&
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const size_t
is ok.
step_input->Resize(step_dims); | ||
} | ||
} | ||
} | ||
|
||
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, | ||
const std::vector<Link>& outlinks, | ||
const size_t seq_len) { | ||
const size_t seq_len, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const size_t&
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const size_t
is ok.
@@ -79,8 +85,9 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, | |||
|
|||
void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes, | |||
const std::vector<rnn::MemoryAttr>& memories, | |||
size_t step_id, | |||
int offset) { | |||
const size_t step_id, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const size_t&
paddle/operators/recurrent_op.cc
Outdated
} | ||
net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx); | ||
} | ||
LinkBootMemoryGradients(step_scopes[0]); | ||
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); | ||
LinkBootMemoryGradients(step_scopes[0], false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
false add a comment ?
@@ -362,7 +361,7 @@ TEST(RecurrentOp, LinkMemories) { | |||
memories.push_back(mem_attr); | |||
|
|||
for (int i = 1; i < len; ++i) { | |||
rnn::LinkMemories(step_scopes, memories, i, -1); | |||
rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/); | |||
} | |||
// check | |||
for (int i = 0; i < len - 1; ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unify all the array indexes from int
to size_t.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
The tensor only contains the shape and does not hold memory when inferring shape.