Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dist transpiler support prefetch #9714

Merged

Conversation

jacquesqiao
Copy link
Member

@jacquesqiao jacquesqiao commented Apr 7, 2018

project: #9597
task list: #9211
test code: https://github.com/jacquesqiao/models/tree/dist-lookup-table/dist_lookup_table

remain problem:

prefetch block has to be at that last, or RunPreparedContext will fail.

@jacquesqiao jacquesqiao mentioned this pull request Apr 8, 2018
15 tasks
@jacquesqiao jacquesqiao changed the title [WIP]Dist transpiler support prefetch Dist transpiler support prefetch Apr 10, 2018
@@ -33,7 +35,7 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
const size_t n = ins.size();

PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
// PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May delete the comment.

}
auto prepared = executor.Prepare(*program, block_list);
auto optimize_prepared = executor.Prepare(*program, block_list);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to prepare all the blocks of the program, so maybe the name prepared is more suitable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimize_prepared is used to be different with prefetch_prepared


rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code L106 have already prepared all the blocks, so we don't need to prepare the prefetch_block again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -71,7 +71,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(SelectedRows) result "
"(LoDTensor) result "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type of Output variable is SelectedRows, just because the shape was not a certain value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it should be LoDTensor, because the following op is not certain, most of them can only process LoDTensor, SelectedRows is constructed when backward.

@@ -36,8 +38,8 @@ class SumOp : public framework::OperatorWithKernel {
}

auto x_dims = ctx->GetInputsDim("X");
size_t N = x_dims.size();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
// size_t N = x_dims.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete these comments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add TODO here, maybe this check need to add back in the future.

@@ -252,12 +315,114 @@ def transpile(self,
outputs={"Out": [orig_param]},
attrs={"axis": 0})

if self.has_distributed_lookup_table:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move these following code into an independent function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@Yancey1989
Copy link
Contributor

Awesome! Thanks for PR and make it work!

@@ -55,7 +55,7 @@ static DDim GetDims(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
return var->Get<SelectedRows>().value().dims();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this affect other places like optimization ops?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, will optimize this code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
table_grad_name = framework.grad_var_name(self.table_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad_var_name sometimes may not get the "real" grad var name, for backward may create a different name.

Copy link
Member Author

@jacquesqiao jacquesqiao Apr 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, here the name of the table parameter's gradient will always be table_name@GRAD, the table_name@GRAD@RENAME name will be merged into table_name@GRAD.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks

Copy link
Contributor

@typhoonzero typhoonzero left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@jacquesqiao jacquesqiao merged commit 4c55a60 into PaddlePaddle:develop Apr 12, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants