Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpFunctor and replace cast, scale,clip, bce_loss and abs_grad with elementwise_no_broadcast #38500

Merged
merged 1 commit into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 19 additions & 28 deletions paddle/fluid/operators/bce_loss_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/bce_loss_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
Expand All @@ -23,6 +24,17 @@ namespace operators {

using Tensor = framework::Tensor;

template <typename T>
struct BCELossGradFunctor {
T one = static_cast<T>(1.0f);
T eps = static_cast<T>(1e-12);
__device__ __forceinline__ T operator()(const T& x, const T& label,
const T& dout) const {
T term1 = max((one - x) * x, eps);
return (dout * (x - label) / term1);
}
};

template <typename T>
__global__ void GPUBCELossForward(const T* x_data, const T* label_data,
T* out_data, const int in_numel) {
Expand All @@ -44,23 +56,6 @@ __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
}
}

template <typename T>
__global__ void GPUBCELossBackward(const T* x_data, const T* label_data,
const T* dout_data, T* dx_data,
const int in_numel) {
CUDA_KERNEL_LOOP(i, in_numel) {
T x = x_data[i];
T label = label_data[i];
T dout = dout_data[i];
T one = static_cast<T>(1.);
T eps = static_cast<T>(1e-12);

T term1 = max((one - x) * x, eps);

dx_data[i] = dout * (x - label) / term1;
}
}

template <typename DeviceContext, typename T>
class BCELossCUDAKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -91,17 +86,13 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
auto* labels = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

int x_numel = x->numel();
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());

auto& dev_ctx = ctx.cuda_device_context();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);

GPUBCELossBackward<T><<<config.block_per_grid, config.thread_per_block, 0,
dev_ctx.stream()>>>(
x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
dx->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {x, labels, dout};
std::vector<framework::Tensor*> outs = {dx};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto functor = BCELossGradFunctor<T>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
dev_ctx, ins, &outs, functor);
}
};

Expand Down
26 changes: 9 additions & 17 deletions paddle/fluid/operators/clip_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#endif

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::Transform;

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T, typename UnaryOperation>
__global__ void ClipCudaKernel(const T* input, T* out, int num,
UnaryOperation op) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < num) {
out[idx] = op(input[idx]);
}
}
#endif

template <typename T>
class ClipFunctor {
public:
Expand Down Expand Up @@ -106,12 +98,12 @@ class ClipKernel : public framework::OpKernel<T> {
int64_t numel = x->numel();
if (platform::is_gpu_place(context.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
int threads = 256;
int blocks = (numel + threads - 1) / threads;
ClipCudaKernel<T, ClipFunctor<T>><<<
blocks, threads, 0,
context.template device_context<platform::CUDADeviceContext>()
.stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = ClipFunctor<T>(min, max);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
context.template device_context<platform::CUDADeviceContext>(), ins,
&outs, functor);
#endif
} else {
Transform<DeviceContext> trans;
Expand Down
68 changes: 42 additions & 26 deletions paddle/fluid/operators/label_smooth_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,39 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/label_smooth_op.h"
namespace paddle {
namespace operators {

template <typename T>
__global__ void LabelSmoothRunOriginKernel(const int N, const float epsilon,
const int label_dim, const T* src,
T* dst) {
CUDA_KERNEL_LOOP(idx, N) {
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon / label_dim);
struct LabelSmoothFunctor {
T epsilon;
T label_dim;

__forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
epsilon = static_cast<T>(epsilon_data);
label_dim = static_cast<T>(label_dim_data);
}
}

__device__ __forceinline__ T operator()(const T& x) const {
return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim));
}
};

template <typename T>
struct LabelSmoothGradFunctor {
T epsilon;

__forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
epsilon = static_cast<T>(epsilon_data);
}

__device__ __forceinline__ T operator()(const T& x) const {
return static_cast<T>(1 - epsilon) * x;
}
};

template <typename T>
__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
Expand All @@ -38,14 +58,6 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
}
}

template <typename T>
__global__ void LabelSmoothGradRunKernel(const int N, const float epsilon,
const T* src, T* dst) {
CUDA_KERNEL_LOOP(idx, N) {
dst[idx] = static_cast<T>(1 - epsilon) * src[idx];
}
}

template <typename DeviceContext, typename T>
class LabelSmoothGPUKernel : public framework::OpKernel<T> {
public:
Expand All @@ -69,8 +81,14 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
size_prob, epsilon, dist_numel, in_data, dist_data, out_data);

} else {
LabelSmoothRunOriginKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, label_dim, in_data, out_data);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();

std::vector<const framework::Tensor*> ins = {in_t};
std::vector<framework::Tensor*> outs = {out_t};
auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, functor);
}
}
};
Expand All @@ -84,15 +102,13 @@ class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
d_in_t->mutable_data<T>(ctx.GetPlace());

auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
const T* in_data = d_out_t->data<T>();
auto size_prob = d_out_t->numel();
T* out_data = d_in_t->mutable_data<T>(ctx.GetPlace());
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
LabelSmoothGradRunKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, in_data, out_data);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

std::vector<const framework::Tensor*> ins = {d_out_t};
std::vector<framework::Tensor*> outs = {d_in_t};
auto functor = LabelSmoothGradFunctor<T>(epsilon);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, functor);
}
};
} // namespace operators
Expand Down
61 changes: 12 additions & 49 deletions paddle/pten/kernels/gpu/cast_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/pten/core/kernel_registry.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
Expand All @@ -27,62 +28,24 @@

namespace pten {

template <typename InT, typename OutT, int VecSize>
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;

int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx * VecSize; i < N;
i += blockDim.x * gridDim.x * VecSize) {
LoadT in_val;
paddle::platform::Load<InT, VecSize>(&in[i], &in_val);

StoreT out_val;
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_val[j] = static_cast<OutT>(in_val[j]);
}

paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
}
}

template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}

template <typename InT, typename OutT>
void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
const InT* in_data,
OutT* out_data,
int64_t size) {
paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
in_data, size, out_data);
} else {
CastCUDAKernel<InT, OutT><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(in_data, size, out_data);
struct CastFuctor {
__device__ __forceinline__ OutT operator()(const InT& x) const {
return static_cast<OutT>(x);
}
}
};

template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* in_data = x.data<InT>();
auto size = x.numel();
auto* out_data = out->mutable_data<OutT>();
CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
out->mutable_data<OutT>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, InT, OutT>(
dev_ctx, inputs, &outputs, CastFuctor<InT, OutT>());
}

template <typename T, typename Context>
Expand Down
47 changes: 45 additions & 2 deletions paddle/pten/kernels/gpu/scale_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,54 @@ limitations under the License. */

#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/scale_kernel_impl.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/float16.h"

namespace pten {

template <typename InT>
struct ScaleFunctor {
InT bias;
InT scale;
bool bias_after_scale;

ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) {
scale = scale_data;
bias = bias_data;
bias_after_scale = is_bias_after_sacle;
}

__device__ __forceinline__ InT operator()(const InT& x) const {
if (bias_after_scale) {
return scale * x + bias;
} else {
return scale * (x + bias);
}
}
};

template <typename T, typename ContextT>
void Scale(const ContextT& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
out->mutable_data<T>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
dev_ctx,
inputs,
&outputs,
ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
}

} // namespace pten

PT_REGISTER_CTX_KERNEL(scale,
GPU,
ALL_LAYOUT,
Expand Down