Skip to content

Commit

Permalink
code refine
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Apr 17, 2018
1 parent 2aaa75e commit 4abef50
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 50 deletions.
58 changes: 31 additions & 27 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}

void BroadcastOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
// the input and output may have dummy var.
std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);

PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}

PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
// Wait input done, this Wait is asynchronous operationplatform::Place
// &in_place;
WaitEvents(out_var_handles, in_var_handle);

//
auto in_place = in_var_handle[0]->place_;
auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Expand Down Expand Up @@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() {
}
}

void BroadcastOpHandle::WaitEvents(
const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle) {
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
}

std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
return in_var_handle;
}

std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/details/broadcast_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase {

protected:
void RunImpl() override;

std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs);

void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
const std::vector<VarHandle *> &in_var_handle);
};

} // namespace details
Expand Down
50 changes: 27 additions & 23 deletions paddle/fluid/framework/details/gather_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}

void GatherOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs_) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
// the input and output may have dummy var.
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);

PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");

Expand All @@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
"The place of input and output should be the same.");

// Wait input done, this Wait is asynchronous operation
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
WaitEvents(in_var_handles);

std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors;
Expand Down Expand Up @@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {

// copy
auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
Expand All @@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
});
}

void GatherOpHandle::WaitEvents(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
}

std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
const std::vector<VarHandleBase *> &inputs) {
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
return in_var_handles;
}

std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details
} // namespace framework
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/details/gather_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase {

protected:
void RunImpl() override;

std::vector<VarHandle *> GetValidVarHandles(
const std::vector<VarHandleBase *> &);

void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
};

} // namespace details
Expand Down

0 comments on commit 4abef50

Please sign in to comment.