Merge pull request #9922 from chengduoZH/feature/refine_gather_reduce
Refine gather and broadcast
chengduo authored Apr 20, 2018
2 parents d114d2b + 88f8183 commit 23a21c8
Showing 9 changed files with 262 additions and 108 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -21,8 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)

cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
107 changes: 42 additions & 65 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -13,95 +13,72 @@
// limitations under the License.

#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
namespace details {

Tensor *GetTensorFromVar(Variable *in_var) {
if (in_var->IsType<LoDTensor>()) {
return in_var->GetMutable<LoDTensor>();
} else if (in_var->IsType<SelectedRows>()) {
return in_var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
return nullptr;
}

BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}

void BroadcastOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handle;
for (auto *in : inputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(in);
if (out_handle) {
in_var_handle.push_back(out_handle);
}
}
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
"The number of input should be one.");
// the input and output may have dummy var.
VarHandle *in_var_handle;

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}

auto out_var_handles = DynamicCast<VarHandle>(outputs_);

PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated(*in_var_handle);

//
auto in_scope_idx = in_var_handle[0]->scope_idx_;
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);
auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
->FindVar(in_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

for (auto *out : out_var_handles) {
if (*out == *in_var_handle) {
continue;
}

auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
"Places must be all on CPU or all on CUDA.");

if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>();
auto out_sr = out_var->GetMutable<framework::SelectedRows>();
if (&in_sr == out_sr) continue;
out_sr->set_height(in_sr.height());
out_sr->set_rows(in_sr.rows());
out_sr->mutable_value()->Resize(in_sr.value().dims());
out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto in_lod = in_var->Get<framework::LoDTensor>();
auto out_lod = out_var->GetMutable<framework::LoDTensor>();
if (&in_lod == out_lod) continue;
out_lod->set_lod(in_lod.lod());
out_lod->Resize(in_lod.dims());
out_lod->mutable_data(out_p, in_lod.type());
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims())
.mutable_data(out_p, in_tensor.type());

Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
auto dev_ctx = dev_ctxes_[out_p];
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
paddle::framework::TensorCopy(
in_tensor, out_p, *(dev_ctx),
&VariableVisitor::GetMutableTensor(out_var));
});
}
}

void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
}
}
}

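Note on the refactor above: the per-type if/else blocks for SelectedRows and LoDTensor are replaced by VariableVisitor::ShareDimsAndLoD plus VariableVisitor::GetMutableTensor, and each copy now runs inside RunAndRecordEvent on the output place. The real helper lives in the new variable_visitor.cc, which is not shown in this excerpt; the following is only a rough sketch of what ShareDimsAndLoD is expected to do, reconstructed from the deleted branches above (the name ShareDimsAndLoDSketch and the free-function form are illustrative, not the committed API):

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace details {

// Sketch only: reconstructed from the removed if/else branches, not from
// variable_visitor.cc. Copies shape metadata (rows/height or LoD plus dims)
// from src into trg without copying any tensor data.
inline void ShareDimsAndLoDSketch(const Variable &src, Variable *trg) {
  if (src.IsType<SelectedRows>()) {
    auto &in_sr = src.Get<SelectedRows>();
    auto *out_sr = trg->GetMutable<SelectedRows>();
    out_sr->set_height(in_sr.height());
    out_sr->set_rows(in_sr.rows());
    out_sr->mutable_value()->Resize(in_sr.value().dims());
  } else if (src.IsType<LoDTensor>()) {
    auto &in_lod = src.Get<LoDTensor>();
    auto *out_lod = trg->GetMutable<LoDTensor>();
    out_lod->set_lod(in_lod.lod());
    out_lod->Resize(in_lod.dims());
  } else {
    PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle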
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/broadcast_op_handle.h
@@ -39,12 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {

protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);

private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
};

} // namespace details
} // namespace framework
} // namespace paddle
40 changes: 40 additions & 0 deletions paddle/fluid/framework/details/container_cast.h
@@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include <vector>

namespace paddle {
namespace framework {
namespace details {

template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
static_assert(std::is_base_of<ElemType, ResultType>::value,
"ElementType must be a base class of ResultType");
std::vector<ResultType*> res;
for (auto* ptr : container) {
auto* derived = dynamic_cast<ResultType*>(ptr);
if (derived) {
res.emplace_back(derived);
}
}
return res;
}

} // namespace details
} // namespace framework
} // namespace paddle
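
A usage note on this new helper: DynamicCast filters a container of base-class pointers down to the elements of one derived type and silently drops the rest, which is how the op handles above separate real VarHandles from dummy vars in inputs_ and outputs_. A minimal sketch follows; Base, Derived, and Dummy are illustrative stand-ins rather than Paddle classes, and the example assumes it is built inside the Paddle tree so the new header is available:

#include <cassert>
#include <vector>
#include "paddle/fluid/framework/details/container_cast.h"

struct Base { virtual ~Base() = default; };  // stand-in for VarHandleBase
struct Derived : Base {};                    // stand-in for VarHandle
struct Dummy : Base {};                      // stand-in for DummyVarHandle

int main() {
  Derived d1, d2;
  Dummy dummy;
  std::vector<Base *> mixed{&d1, &dummy, &d2};

  // Only the Derived elements survive; the Dummy pointer is dropped,
  // mirroring DynamicCast<VarHandle>(inputs_) in the op handles above.
  auto only_derived = paddle::framework::details::DynamicCast<Derived>(mixed);
  assert(only_derived.size() == 2);
  return 0;
}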
83 changes: 43 additions & 40 deletions paddle/fluid/framework/details/gather_op_handle.cc
@@ -13,6 +13,8 @@
// limitations under the License.

#include "paddle/fluid/framework/details/gather_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
@@ -23,81 +25,68 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
: local_scopes_(local_scopes), places_(places) {}

void GatherOpHandle::RunImpl() {
// the input may have dummy var.
std::vector<VarHandle *> in_var_handles;
for (auto *in : inputs_) {
auto *in_handle = dynamic_cast<VarHandle *>(in);
if (in_handle) {
in_var_handles.push_back(in_handle);
}
}
// the input and output may have dummy var.
auto in_var_handles = DynamicCast<VarHandle>(inputs_);

PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");

// the output may have dummy var.
std::vector<VarHandle *> out_var_handles;
for (auto *out : outputs_) {
auto *out_handle = dynamic_cast<VarHandle *>(out);
if (out_handle) {
out_var_handles.push_back(out_handle);
}
VarHandle *out_var_handle;
{
auto out_var_handles = DynamicCast<VarHandle>(outputs_);

PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
out_var_handle = out_var_handles.front();
}
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");

auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
auto in_0_handle = in_var_handles[0];
auto pre_in_var =
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
auto pre_place = in_0_handle->place_;

PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
"Currently, gather_op only can gather SelectedRows.");

PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
"The place of input and output should be the same.");

// Wait input done, this Wait is asynchronous operation
for (auto *in : in_var_handles) {
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[in->place_]);
}
}
WaitInputVarGenerated(in_var_handles);

std::vector<int64_t> out_rows;
std::vector<Tensor> in_tensors;
std::vector<platform::Place> in_places;

auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
// gather the inputs
for (auto *in : in_var_handles) {
auto in_handle = static_cast<VarHandle *>(in);
for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
in_places.push_back(in_p);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
auto in_var =
auto *in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();

PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
"The type of input is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
"The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent.");

auto in_sr_rows = in_sr.rows();
auto &in_sr_rows = in_sr.rows();
out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());

in_tensors.emplace_back(in_sr.value());
}

// write the output
auto &out_place = out_var_handles[0]->place_;
auto out_scope_idx = out_var_handles[0]->scope_idx_;
auto out_var =
local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
auto &out_place = out_var_handle->place_;
auto out_scope_idx = out_var_handle->scope_idx_;
auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);

auto out = out_var->GetMutable<framework::SelectedRows>();
out->set_height(pre_in.height());
@@ -110,13 +99,27 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out->mutable_value();

// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place,
*(dev_ctxes_[in_places[j]]), &sub_out);
s = e;
auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
&sub_out);
s = e;
}
});
}

void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}

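To make the row bookkeeping in the copy lambda concrete: each input's rows are copied into a contiguous slice [s, e) of the output tensor, and the offsets advance by that input's row count. A tiny standalone illustration of just this arithmetic (the row counts 2, 3, and 4 are hypothetical):

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical first-dimension sizes of three gathered SelectedRows values.
  std::vector<int> in_rows{2, 3, 4};
  int s = 0, e = 0;
  for (size_t j = 0; j < in_rows.size(); ++j) {
    e += in_rows[j];
    // In GatherOpHandle::RunImpl this is out_tensor->Slice(s, e) followed by
    // TensorCopy of in_tensors[j] into that slice.
    std::printf("input %zu -> output rows [%d, %d)\n", j, s, e);
    s = e;
  }
  // Prints:
  //   input 0 -> output rows [0, 2)
  //   input 1 -> output rows [2, 5)
  //   input 2 -> output rows [5, 9)
  return 0;
}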
1 change: 1 addition & 0 deletions paddle/fluid/framework/details/gather_op_handle.h
@@ -39,6 +39,7 @@ struct GatherOpHandle : public OpHandleBase {

protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);

private:
const std::vector<Scope *> &local_scopes_;
5 changes: 5 additions & 0 deletions paddle/fluid/framework/details/var_handle.h
@@ -61,6 +61,11 @@ struct VarHandle : public VarHandleBase {
size_t scope_idx_;
std::string name_;
platform::Place place_;

bool operator==(const VarHandle& o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_;
}
};

// Dummy Variable. It is used to represent dependencies between operators
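This new operator== is what lets BroadcastOpHandle::RunImpl skip the self-copy via if (*out == *in_var_handle) continue;: two handles compare equal when they share the generating op, the scope index, and the variable name. A minimal sketch with a simplified stand-in struct that keeps only the three compared fields (VarHandleSketch is illustrative, not the real class):

#include <cassert>
#include <string>

struct OpHandleBase {};  // stand-in; only used as an identity tag here

struct VarHandleSketch {
  OpHandleBase *generated_op_;
  size_t scope_idx_;
  std::string name_;

  bool operator==(const VarHandleSketch &o) const {
    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
           o.scope_idx_ == scope_idx_;
  }
};

int main() {
  OpHandleBase op;
  VarHandleSketch in{&op, 0, "w"};    // the broadcast input handle
  VarHandleSketch out0{&op, 0, "w"};  // output handle for the same var/scope
  VarHandleSketch out1{&op, 1, "w"};  // same name, but a different local scope

  assert(in == out0);     // RunImpl would hit `continue`: nothing to copy
  assert(!(in == out1));  // this output is actually broadcast to
  return 0;
}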