Support Add Sub Mul Max Min Pow binary functors in elementwise system #33050

Merged Jun 2, 2021

Changes from all commits (27 commits)
e4d29f3
First Commit.
JamesLim-sy May 21, 2021
3f243ed
First Commit.
JamesLim-sy May 21, 2021
bc12b1b
Debugging the multiplier bugs
JamesLim-sy May 23, 2021
8290346
Adding Max_Min operators
JamesLim-sy May 23, 2021
49c221b
Fix multiplier bugs and support Max/Min OPs
JamesLim-sy May 23, 2021
dd24d12
Fix multiplier bugs and support Max/Min OPs
JamesLim-sy May 23, 2021
173ee57
Delete the useless code in elementwise_mul_op.cu
JamesLim-sy May 23, 2021
0a7bfef
Delete the useless code in elementwise_mul_op.cu
JamesLim-sy May 23, 2021
07b3797
Delete the useless code in elementwise_mul_op.cu
JamesLim-sy May 23, 2021
a16ba39
Merge branch 'Adding_binary_functor_support' of https://github.com/Ja…
JamesLim-sy May 23, 2021
9bca0af
Merge broadcast update with OutType template argument.
JamesLim-sy May 24, 2021
b5182f1
Adjust elementwise-functor location
JamesLim-sy May 24, 2021
9d46543
First commit
JamesLim-sy May 25, 2021
74e4179
Trigger rerun
JamesLim-sy May 25, 2021
656ac99
To avoid partial specialization bugs which happened in PR-CI-ROCM
JamesLim-sy May 26, 2021
585566f
Avoid kUnary instantiation of LaunchElementwiseCudaKernel at compile …
JamesLim-sy May 30, 2021
b9c5ea5
Refine the wrapper of binary ops
JamesLim-sy May 30, 2021
25d290e
Refine the wrapper of binary ops
JamesLim-sy May 30, 2021
0e4a011
Fix bugs
JamesLim-sy May 31, 2021
d9c70ec
Refine wrapper of broadcast and add CUDA op
JamesLim-sy May 31, 2021
ce5a717
Fix bugs
JamesLim-sy May 31, 2021
950965b
adding pow
JamesLim-sy Jun 1, 2021
1f72b51
adding pow
JamesLim-sy Jun 1, 2021
90a0b29
Merge branch 'Adding_binary_functor_support' of https://github.com/Ja…
JamesLim-sy Jun 1, 2021
5b146cd
Fix header quote sort
JamesLim-sy Jun 1, 2021
f5a2ce7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JamesLim-sy Jun 2, 2021
cd40092
Refine wrapper
JamesLim-sy Jun 2, 2021
47 changes: 23 additions & 24 deletions paddle/fluid/operators/controlflow/compare_op.cu
@@ -21,21 +21,21 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {

-#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \
+#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \
template <typename T, typename Enable = void> \
-  struct Func##Functor { \
+  struct func { \
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};

-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
-DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT

template <typename T>
@@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();

-    PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
+    int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
-        ctx, ins, &outs, functor);
+        cuda_ctx, ins, &outs, axis, functor);
}
};

@@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>

#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
-      op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
-                                    ops::func##Functor<int>, void>, \
-      ops::CompareOpKernel<plat::CUDADeviceContext, \
-                           ops::func##Functor<int64_t>, void>, \
-      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
-                           void>, \
-      ops::CompareOpKernel<plat::CUDADeviceContext, \
-                           ops::func##Functor<double>, void>);
+      op_type, \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);

-REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
-REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
-REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
-REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
-REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
-REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
+REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor)
+REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor)
+REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor)
+REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor)
+REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor)
+REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL
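Editor's note: for reference, a minimal sketch of what the renamed macro now generates. The full functor name is passed in verbatim (no `##Functor` suffix is appended anymore), so `DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <)` expands to roughly:

```cpp
// Sketch of one macro expansion; HOSTDEVICE is Paddle's
// __host__ __device__ qualifier. Both operands arrive packed in args.
template <typename T, typename Enable = void>
struct CudaLessThanFunctor {
  using ELEMENT_TYPE = T;
  inline HOSTDEVICE bool operator()(const T* args) const {
    return args[0] < args[1];
  }
};
```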
11 changes: 7 additions & 4 deletions paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -28,11 +28,11 @@ namespace operators {
1. For Unary Op, the length of input array is 1,
e.g. Relu: return args[0] > 0 ? args[0] : 0;
2. For Binary Op, the length of input array is 2,
-      e.g. Add: return args[0] + args[1];
+      e.g. Add: return args[0] expr args[1];
*/
template <typename T>
struct CudaAddFunctor {
-  __device__ __forceinline__ T operator()(const T* args) const {
+  inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
@@ -44,9 +44,12 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
-    PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        ctx, ins, &outs, CudaAddFunctor<T>());
+        cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
}
};

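Editor's note: the comment block above spells out the calling convention these functors share: all operands are packed into a single `args` array. A minimal host-side sketch of that contract, using a hypothetical stand-in type (not the PR's own `CudaAddFunctor`, which lives behind Paddle's headers):

```cpp
#include <cassert>

// Stand-in mirroring CudaAddFunctor's pointer-input contract.
template <typename T>
struct AddFunctorSketch {
  inline T operator()(const T* args) const { return args[0] + args[1]; }
};

int main() {
  float args[2] = {1.5f, 2.5f};  // binary op: both operands packed together
  assert(AddFunctorSketch<float>()(args) == 4.0f);
  return 0;
}
```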
6 changes: 2 additions & 4 deletions paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
if (x->dims() == y->dims()) {
-      SameDimsElemwiseAdd<platform::CPUDeviceContext, T>
-          LaunchElementwiseCpuKernel;
+      SameDimsElemwiseAdd<DeviceContext, T> LaunchElementwiseCpuKernel;
LaunchElementwiseCpuKernel(ctx, x, y, z);
} else {
-      LaunchBroadcastElementwiseCpuKernel<platform::CPUDeviceContext, T>(ctx, x,
-                                                                         y, z);
+      LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, x, y, z);
}
}
};
31 changes: 31 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"

namespace ops = paddle::operators;

+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct CudaMaxFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return (args[0] > args[1] ? args[0] : args[1]);
+  }
+};
+
+template <typename T>
+class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaMaxFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
31 changes: 31 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_min_op.cu
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"

namespace ops = paddle::operators;

+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct CudaMinFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return (args[0] > args[1] ? args[1] : args[0]);
+  }
+};
+
+template <typename T>
+class ElementwiseMinKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaMinFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
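Editor's note: `CudaMinFunctor` mirrors `CudaMaxFunctor` with the ternary branches swapped, and both follow the same packed-operand convention. A host-side sanity sketch with stand-in types (not the PR's code):

```cpp
#include <cassert>

// Stand-ins mirroring the two ternary expressions above.
template <typename T>
struct MaxSketch {
  T operator()(const T* a) const { return a[0] > a[1] ? a[0] : a[1]; }
};
template <typename T>
struct MinSketch {
  T operator()(const T* a) const { return a[0] > a[1] ? a[1] : a[0]; }
};

int main() {
  float v[2] = {3.0f, 7.0f};
  assert(MaxSketch<float>()(v) == 7.0f);
  assert(MinSketch<float>()(v) == 3.0f);
  return 0;
}
```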
85 changes: 57 additions & 28 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -24,37 +25,65 @@ namespace paddle {
namespace operators {

template <typename T>
-struct SameDimsElemwiseMul<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    MulRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
-                                                              x->numel());
-    for_range(functor);
+struct CudaMulFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return args[0] * args[1];
}
};

-template <>
-struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    auto size = x->numel();
-    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
-                              PADDLE_CUDA_THREAD_SIZE,
-                          1);
-    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    const half* x2 =
-        reinterpret_cast<const half*>(x->data<platform::float16>());
-    const half* y2 =
-        reinterpret_cast<const half*>(y->data<platform::float16>());
-    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-    SameDimsElemwiseMulCUDAKernel<<<
-        grid_size, block_size, 0,
-        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
-        x2, y2, z2, size);
+template <typename T>
+class ElementwiseMulKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    int axis = -1;
+    auto x_var = ctx.InputVar("X");
+    PADDLE_ENFORCE_NOT_NULL(
+        x_var, platform::errors::InvalidArgument(
+                   "Cannot get input Variable X, Variable name = %s.",
+                   ctx.InputName("X")));
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+
+    framework::Tensor x, *z;
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    if (x_var->IsType<framework::LoDTensor>()) {
+      x = x_var->Get<framework::LoDTensor>();
+      z = ctx.Output<framework::LoDTensor>("Out");
+      axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    } else if (x_var->IsType<framework::SelectedRows>()) {
+      PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
+                        platform::errors::InvalidArgument(
+                            "For elementwise_op, if X is Sparse, Y must be "
+                            "scalar. But reveived the size of Y = %s.",
+                            y->dims().size()));
+      auto& x_sele = x_var->Get<framework::SelectedRows>();
+      auto out_sele = ctx.Output<framework::SelectedRows>("Out");
+      x = x_sele.value();
+      out_sele->set_rows(x_sele.rows());
+      out_sele->set_height(x_sele.height());
+      out_sele->mutable_value()->Resize(x_sele.value().dims());
+      out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
+      z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
+      z->mutable_data<T>(ctx.GetPlace());
+      outs.emplace_back(z);
+      ins.emplace_back(&x);
+      ins.emplace_back(y);
+
+      axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
+      axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis;
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "X's type[%s] is not supported by elementwise_op. X's type should be "
+          "LoDTensor or SelectedRows.",
+          framework::ToTypeName(x_var->Type())));
+    }
Review comment (Contributor): Don't copy-paste a large block of code; wrap L41 - L70 into a helper function.

Review comment (Contributor): L41 - L70 could also be encapsulated inside the PackTensorsIntoVector function.
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
}
};

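Editor's note: the SelectedRows branch above re-derives the default broadcast axis by hand: when the `axis` attribute is absent (-1), it falls back to the rank difference between X and Y. A standalone restatement of that rule (hypothetical helper; shapes for illustration only):

```cpp
#include <cstdlib>  // std::abs

// Default-axis rule used above: align Y to X at the rank difference.
int DefaultAxis(int x_rank, int y_rank) { return std::abs(y_rank - x_rank); }

// e.g. X of shape [32, 16, 8] (rank 3) and Y of shape [16, 8] (rank 2)
// give DefaultAxis(3, 2) == 1: Y lines up with X's dims starting at dim 1.
```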
1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
}
};
-
template <typename T>
struct MulGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
24 changes: 10 additions & 14 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
-  static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
+  PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
+                    platform::errors::InvalidArgument(
+                        "Currently, only Support binary calculation, "
+                        "but received %d input tensors.\n",
+                        static_cast<int>(ET)));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
@@ -502,26 +506,18 @@

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
-    const framework::ExecutionContext &ctx,
+    const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
-    std::vector<framework::Tensor *> *outs, Functor func) {
-  std::vector<int> dims_size;
+    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
-    dims_size.emplace_back(in->dims().size());
}
-  const auto &cuda_ctx =
-      ctx.template device_context<platform::CUDADeviceContext>();

if (no_broadcast_flag) {
-    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
-        cuda_ctx, ins, outs, func);
+    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
+                                                       func);
} else {
-    int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
-    axis = axis == -1
-               ? *std::max_element(dims_size.begin(), dims_size.end()) -
-                     *std::min_element(dims_size.begin(), dims_size.end())
-               : axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
axis, func);
}
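Editor's note: after this refactor the launcher no longer digs the device context or the axis out of the ExecutionContext; both come from the caller, and the function only chooses between the two kernel paths. A behavioral sketch with stand-in types (simplified dims comparison; the real functions are templated over ET/InT/OutT as shown above):

```cpp
#include <vector>

// Stand-in types for illustration only; not Paddle's real Tensor/Context.
struct Dims {
  int rank;
  bool operator==(const Dims& o) const { return rank == o.rank; }
};
struct Tensor {
  Dims dims;
};

// Mirrors the dispatch above: the caller supplies device context and axis;
// this function only picks the kernel path.
template <typename Functor>
void LaunchElementwiseSketch(const std::vector<const Tensor*>& ins, int axis,
                             Functor func) {
  bool no_broadcast_flag = true;
  for (const Tensor* in : ins) {
    // Same assignment as the loop above; note the flag is overwritten on
    // every iteration, so it reflects only the last input's comparison.
    no_broadcast_flag = (ins[0]->dims == in->dims);
  }
  if (no_broadcast_flag) {
    // Same-dims path: LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ...).
  } else {
    // Broadcast path: LaunchBroadcastElementwiseCudaKernel(..., axis, func).
  }
  (void)axis;
  (void)func;
}
```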
16 changes: 10 additions & 6 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -64,20 +64,24 @@ namespace operators {
* To pack the input and output tnesors into vector for
* LaunchElementwiseCudaKernel
*/
-template <typename T>
-void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
-                           std::vector<const framework::Tensor *> *ins,
-                           std::vector<framework::Tensor *> *outs) {
+template <typename OutT>
+int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
+                          std::vector<const framework::Tensor *> *ins,
+                          std::vector<framework::Tensor *> *outs) {
+  int axis = -1;
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
-  z->mutable_data<T>(ctx.GetPlace());
-  ins->emplace_back(x);
+  z->mutable_data<OutT>(ctx.GetPlace());
  outs->emplace_back(z);
+  ins->emplace_back(x);
+
if (y != nullptr) {
ins->emplace_back(y);
+    axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
+    axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis;
}
+  return axis;
}

/*
Expand Down
Loading