Skip to content

Commit

Permalink
Rename the general elementwise and broadcast functions. (#39623)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xreki authored Feb 20, 2022
1 parent 267275d commit 553afc0
Show file tree
Hide file tree
Showing 14 changed files with 612 additions and 669 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/operators/dropout_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"

namespace paddle {
namespace operators {
Expand Down
40 changes: 4 additions & 36 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,13 @@
#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

// only can include the headers in paddle/top/api dirs
#include "paddle/pten/kernels/gpu/elementwise.h"

namespace paddle {
namespace operators {

namespace kps = paddle::operators::kernel_primitives;

template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1>
void LaunchBroadcastElementwiseCudaKernel(
const KPDevice &ctx, const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
}
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
}
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, axis, func);
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1>
void LaunchElementwiseCudaKernel(
Expand Down Expand Up @@ -82,7 +50,7 @@ void LaunchElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
pten::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, axis, func);
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */

// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/kernels/gpu/elementwise.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -53,8 +53,8 @@ void LaunchSameDimsElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::funcs::LaunchSameDimsElementwiseCudaKernel<OutT, Functor, NumOuts>(
ctx, pt_inputs, &pt_outputs, func);
pten::funcs::ElementwiseKernel<OutT, Functor, NumOuts>(ctx, pt_inputs,
&pt_outputs, func);
}

} // namespace operators
Expand Down
Loading

0 comments on commit 553afc0

Please sign in to comment.