
Refine InferShape for recurrent_network_op #3124

Merged: 8 commits into PaddlePaddle:develop on Aug 3, 2017

Conversation

qingqing01 (Contributor):

The tensor only contains the shape and does not hold memory when inferring shape.

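During InferShape, that is, the rnn helper functions only propagate dimensions; memory is allocated later, when the op actually runs. A minimal sketch of the split, reusing the Tensor calls visible elsewhere in this PR (the branch itself is illustrative, not the literal patch):

if (infer_shape_mode) {
  step_input->Resize(step_dims);  // record dims only; no allocation
} else {
  // the allocating accessor, as used in the tests below
  step_input->mutable_data<float>(step_dims, CPUPlace());
}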
      step_input->Resize(step_dims);
    }
  }
}

void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks,
-                  const size_t seq_len) {
+                  const size_t seq_len,
+                  bool infer_shape) {
Contributor:
is_infer or infer_mode?

Contributor (author):

infer_shape_mode

@@ -97,18 +104,14 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
  std::shared_ptr<Scope> scope = scopes[step_id];
  std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
  for (auto& attr : memories) {
-   auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
+   auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
Contributor:

Need an enforce that scope->GetVariable(xxx) != nullptr?

Contributor (author):

Done.
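The added check presumably follows Paddle's PADDLE_ENFORCE pattern; a sketch with an illustrative message:

auto var = scope->GetVariable(attr.pre_var);
PADDLE_ENFORCE(var != nullptr, "memory variable [%s] does not exist",
               attr.pre_var.c_str());  // message text is illustrative
auto mem = var->GetMutable<Tensor>();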

  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
-     rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
+     rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true);
Contributor:

true /* infer_mode */

Contributor (author):

Done.
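The convention asked for here, and used in the final test code further down, is to label the bare boolean at the call site:

rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);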

      step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
    output->Resize(make_ddim(dims_vec));
  }
+ rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
Contributor:

same as top

Contributor (author):

Done.

  if (step_id > 0) {
-   rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+   rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false);
Contributor:

same as the top

Contributor (author):

Done.

  net->AddOp(
      OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));

  net->AddOp(
-     OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
+     OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
Contributor:

an inline function for the action of adding @alias suffix ?

Contributor (author):

This unit test may be moved to Python, so I won't add an inline function for appending the @alias suffix.
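For reference, the suggested helper would have been a one-liner along these lines (hypothetical; it was never added):

// hypothetical helper, not part of the merged patch
inline std::string WithAlias(const std::string& name) { return name + "@alias"; }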

@@ -318,7 +313,7 @@ class RecurrentGradientAlgorithmTest : public ::testing::Test {
    scope_->GetVariable("step_scopes")
        ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
    for (int i = 1; i < 10; ++i) {
-     rnn::LinkMemories(*step_scopes, memories, i, -1);
+     rnn::LinkMemories(*step_scopes, memories, i, -1, true);
Contributor:

Same as the top: `true /* infer_mode */`.

Contributor (author):

Done.

luotao1 (Contributor) commented Jul 31, 2017:

In recurrentOp, infershape can either be called or not called? When does each case apply?
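The diffs in this PR suggest the answer: the helpers run with infer_shape_mode == true from the operator's InferShape (shape propagation at network-build time) and with false from Run (actual computation). Assuming the standard operator interface, the two call paths look like:

// inside InferShape: tensors get dims only
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
// inside Run: tensors hold real memory
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);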

@@ -347,7 +342,7 @@ TEST(RecurrentOp, LinkMemories) {
  scope->CreateVariable("pre_h");
  auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
  float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
- for (int i = 0; i < 15 * 20; ++i) {
+ for (int j = 0; j < 15 * 20; ++j) {
Member:

data[i] ==> data[j] ?

Contributor (author):

Done
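After the fix the index matches the loop variable; the fill value below is illustrative:

float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
for (int j = 0; j < 15 * 20; ++j) {
  data[j] = static_cast<float>(j);  // data[j], not data[i]
}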

@@ -29,7 +29,8 @@ namespace rnn {

void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks,
-                  const size_t seq_len) {
+                  const size_t seq_len,
Member:

const size_t&

Contributor (author):

const size_t by value is fine here.
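The rationale: size_t is a machine word, so pass-by-value is at least as cheap as a const reference and avoids an indirection. The merged signature therefore stays (parameter name per the agreement above):

void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& inlinks,
                   const size_t seq_len,
                   bool infer_shape_mode);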

      step_input->Resize(step_dims);
    }
  }
}

void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<Link>& outlinks,
-                  const size_t seq_len) {
+                  const size_t seq_len,
Member:

const size_t&

Contributor (author):

const size_t by value is fine here.

@@ -79,8 +85,9 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,

void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
                  const std::vector<rnn::MemoryAttr>& memories,
-                 size_t step_id,
-                 int offset) {
+                 const size_t step_id,
Member:

const size_t&

    }
    net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
  }
- LinkBootMemoryGradients(step_scopes[0]);
- rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+ LinkBootMemoryGradients(step_scopes[0], false);
Contributor:

Add a comment for this false?

@@ -362,7 +361,7 @@ TEST(RecurrentOp, LinkMemories) {
  memories.push_back(mem_attr);

  for (int i = 1; i < len; ++i) {
-   rnn::LinkMemories(step_scopes, memories, i, -1);
+   rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
  }
  // check
  for (int i = 0; i < len - 1; ++i) {
Contributor:

Unify all the array indices from int to size_t?
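What the suggested unification would look like for the check loop (a sketch; the i + 1 form sidesteps unsigned underflow when len is 0):

for (size_t i = 0; i + 1 < static_cast<size_t>(len); ++i) {
  // ... compare step i against step i + 1 ...
}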

Superjomn (Contributor) left a comment:

LGTM

jacquesqiao (Member) left a comment:

LGTM!

@qingqing01 qingqing01 merged commit 017182c into PaddlePaddle:develop Aug 3, 2017
@qingqing01 qingqing01 deleted the rnn_infershape branch March 7, 2018 12:04