Skip to content

Commit

Permalink
Compilation optimization for dist_kernel (PaddlePaddle#57541)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored and jiahy0825 committed Oct 16, 2023
1 parent 7fe7acc commit 4506991
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions paddle/phi/kernels/gpu/dist_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/legacy/reduce_max_kernel.h"
#include "paddle/phi/kernels/p_norm_kernel.h"
#include "paddle/phi/kernels/reduce_min_kernel.h"

namespace phi {

Expand Down Expand Up @@ -149,16 +151,16 @@ void DistKernel(const Context& dev_ctx,
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n);
phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
phi::MaxRawKernel<T, Context>(
dev_ctx, intermediate, reduce_axis, true, true, out);

} else if (p == -INFINITY) {
ReduceMinWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n);

phi::funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
phi::MinRawKernel<T, Context>(
dev_ctx, intermediate, reduce_axis, true, true, out);

} else {
MT p_order = static_cast<MT>(p);
Expand Down

0 comments on commit 4506991

Please sign in to comment.