compilation optimization for logsumexp_kernel (PaddlePaddle#57817)
tianhaodongbd authored and jiahy0825 committed Oct 16, 2023
1 parent 302bd61 commit 698eb42
Showing 1 changed file with 5 additions and 27 deletions.
32 changes: 5 additions & 27 deletions paddle/phi/kernels/gpu/logsumexp_kernel.cu
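Note: the fallback path changed below is the numerically stable log-sum-exp computation. Each kernel call in the diff corresponds to one term of the standard identity

    \operatorname{logsumexp}(x) = \max_i x_i + \log \sum_i e^{x_i - \max_j x_j}

that is: shift by the max, exponentiate and sum, take the log, then add the max back.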
@@ -18,12 +18,15 @@
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/activation_kernel.h"
 #include "paddle/phi/kernels/elementwise_add_kernel.h"
 #include "paddle/phi/kernels/elementwise_subtract_kernel.h"
 #include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
+#include "paddle/phi/kernels/reduce_max_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace phi {
 
@@ -42,27 +45,6 @@ struct ComputeType<phi::dtype::bfloat16> {
   using type = float;
 };
 
-template <typename T>
-struct LogCUDAFunctor {
-  HOSTDEVICE inline T operator()(const T x) const { return std::log(x); }
-};
-
-template <>
-struct LogCUDAFunctor<float16> {
-  HOSTDEVICE inline float16 operator()(const float16 x) const {
-    auto x_ = static_cast<float>(x);
-    return static_cast<float16>(std::log(x_));
-  }
-};
-
-template <>
-struct LogCUDAFunctor<bfloat16> {
-  HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
-    auto x_ = static_cast<float>(x);
-    return static_cast<bfloat16>(std::log(x_));
-  }
-};
-
 template <typename T, typename Context>
 void LogsumexpFallbackKernel(const Context& dev_ctx,
                              const DenseTensor& x,
@@ -84,18 +66,14 @@ void LogsumexpFallbackKernel(const Context& dev_ctx,
   max_x.Resize(outdim);
   dev_ctx.template Alloc<T>(&max_x);
 
-  phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
-      dev_ctx, *in_x, &max_x, kps::IdentityFunctor<T>(), axis_vec);
+  phi::MaxKernel<T, Context>(dev_ctx, *in_x, axis_vec, false, &max_x);
 
   max_x.Resize(keeped_outdim);
   DenseTensor temp_x = Subtract<T, Context>(dev_ctx, *in_x, max_x);
   phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::ExpFunctor<T>>(
       dev_ctx, temp_x, out_y, kps::ExpFunctor<T>(), axis_vec);
 
-  const std::vector<const DenseTensor*> inputs = {out_y};
-  std::vector<DenseTensor*> outputs = {&temp_x};
-  phi::funcs::ElementwiseKernel<T>(
-      dev_ctx, inputs, &outputs, LogCUDAFunctor<T>());
+  phi::LogKernel<T, Context>(dev_ctx, *out_y, &temp_x);
   temp_x.Resize(outdim);
   out->Resize(outdim);
   phi::AddKernel<T, Context>(dev_ctx, temp_x, max_x, out);

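For readers tracing the refactor, here is a minimal host-side C++ sketch of the same five-step sequence that LogsumexpFallbackKernel now composes from reusable phi kernels (MaxKernel, Subtract, the exp-sum ReduceKernel, LogKernel, AddKernel). This is plain C++ for illustration only, not the phi API; the function name LogSumExp is hypothetical.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Stable log-sum-exp over a flat vector, mirroring the kernel's step order:
// MaxKernel -> Subtract -> exp-sum ReduceKernel -> LogKernel -> AddKernel.
float LogSumExp(const std::vector<float>& x) {
  const float m = *std::max_element(x.begin(), x.end());  // MaxKernel
  float s = 0.0f;
  for (const float v : x) {
    s += std::exp(v - m);  // Subtract, then exp-sum reduction
  }
  return m + std::log(s);  // LogKernel, then AddKernel
}

int main() {
  // Without the max shift, exp(1000) overflows float to inf;
  // with it, the result is finite: ~1002.4076.
  const std::vector<float> x = {1000.0f, 1001.0f, 1002.0f};
  std::printf("%f\n", LogSumExp(x));
  return 0;
}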