[Phi decouple] move layer_norm_kernel.cu.h to phi #50506

Merged
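The change is mechanical across every touched file: the fluid header `paddle/fluid/operators/layer_norm_kernel.cu.h` is replaced by its phi counterpart `paddle/phi/kernels/funcs/layer_norm_impl.cu.h`, and the symbols it supplies (`LayerNormParamType`, `GetDesiredBlockDim`, `LayerNormForward`, `LayerNormBackward`, `WarpReduceSum`) are now spelled with their `phi::funcs::` qualification; one error call also moves from `platform::errors::` to `phi::errors::`. A minimal sketch of the pattern (illustrative only: the function and its `feature_size` parameter are hypothetical, the header and the `phi::funcs` symbols are the ones this diff switches to):

```cpp
// Sketch of the migration pattern applied throughout this PR.
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
// was: #include "paddle/fluid/operators/layer_norm_kernel.cu.h"

template <typename T>
void PickLayerNormBlockDim(int feature_size) {
  // was: using U = LayerNormParamType<T>;
  using U = phi::funcs::LayerNormParamType<T>;  // scale/bias/mean/var precision
  (void)sizeof(U);                              // silence unused-typedef warnings
  // was: switch (GetDesiredBlockDim(feature_size)) { ... }
  switch (phi::funcs::GetDesiredBlockDim(feature_size)) {
    default:
      break;
  }
}
```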
@@ -24,9 +24,9 @@
#include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

namespace paddle {
63 changes: 32 additions & 31 deletions paddle/fluid/operators/fused/attention_layer_norm.h
@@ -14,7 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

namespace paddle {
namespace operators {
@@ -35,11 +35,11 @@ class AttnLayerNorm {
~AttnLayerNorm() {}

void ComputeForward(const InType* x_data,
- const LayerNormParamType<T>* scale_data,
- const LayerNormParamType<T>* bias_data,
+ const phi::funcs::LayerNormParamType<T>* scale_data,
+ const phi::funcs::LayerNormParamType<T>* bias_data,
OutType* y_data,
- LayerNormParamType<T>* mean_data,
- LayerNormParamType<T>* var_data,
+ phi::funcs::LayerNormParamType<T>* mean_data,
+ phi::funcs::LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
@@ -48,14 +48,14 @@ class AttnLayerNorm {
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();

- switch (GetDesiredBlockDim(feature_size_)) {
+ switch (phi::funcs::GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
- LayerNormForward<T,
- LayerNormParamType<T>,
- kBlockDim,
- false,
- InType,
- OutType>
+ phi::funcs::LayerNormForward<T,
+ phi::funcs::LayerNormParamType<T>,
+ kBlockDim,
+ false,
+ InType,
+ OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
@@ -71,32 +71,33 @@
quant_max_bound,
quant_min_bound));
default:
- PADDLE_THROW(platform::errors::InvalidArgument(
- "Feature_size must be larger than 1"));
+ PADDLE_THROW(
+ phi::errors::InvalidArgument("Feature_size must be larger than 1"));
break;
}
}

void ComputeBackward(const T* x_data,
const T* d_y_data,
- const LayerNormParamType<T>* scale_data,
- const LayerNormParamType<T>* mean_data,
- const LayerNormParamType<T>* var_data,
+ const phi::funcs::LayerNormParamType<T>* scale_data,
+ const phi::funcs::LayerNormParamType<T>* mean_data,
+ const phi::funcs::LayerNormParamType<T>* var_data,
T* d_x_data,
- LayerNormParamType<T>* d_scale_data,
- LayerNormParamType<T>* d_bias_data) {
- LayerNormBackward<T, LayerNormParamType<T>>(x_data,
- d_y_data,
- scale_data,
- mean_data,
- var_data,
- d_x_data,
- d_scale_data,
- d_bias_data,
- epsilon_,
- batch_size_,
- feature_size_,
- dev_ctx_);
+ phi::funcs::LayerNormParamType<T>* d_scale_data,
+ phi::funcs::LayerNormParamType<T>* d_bias_data) {
+ phi::funcs::LayerNormBackward<T, phi::funcs::LayerNormParamType<T>>(
+ x_data,
+ d_y_data,
+ scale_data,
+ mean_data,
+ var_data,
+ d_x_data,
+ d_scale_data,
+ d_bias_data,
+ epsilon_,
+ batch_size_,
+ feature_size_,
+ dev_ctx_);
}

private:
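The caller-visible effect of the `attention_layer_norm.h` change is only in how the parameter/statistics type is spelled; the `ComputeForward` call itself is untouched. A hedged caller-side sketch (the `attn_layer_norm` object, the data pointers, and the `DeviceAlloc` helper are hypothetical):

```cpp
// Hypothetical caller of AttnLayerNorm::ComputeForward after this PR.
// Only the spelling of the parameter type changes.
using U = phi::funcs::LayerNormParamType<T>;  // was: LayerNormParamType<T>
U* mean_data = DeviceAlloc<U>(batch_size);    // placeholder allocation helper
U* var_data = DeviceAlloc<U>(batch_size);
attn_layer_norm.ComputeForward(
    x_data, scale_data, bias_data, y_data, mean_data, var_data);
```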
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -26,7 +26,7 @@ namespace operators {
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
@@ -47,7 +47,7 @@ struct FastGeluFunctor {
template <typename T>
struct GeluGradFunctor {
inline __host__ __device__ T UseOut(const T x) const {
- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
auto casted_x = static_cast<U>(x);

auto first =
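The `GeluFunctor` shown above evaluates GELU in the parameter precision `U`, i.e. `out = 0.5 * x * (1 + erf(x / sqrt(2)))`. A self-contained host-side restatement of that formula (plain C++, no Paddle dependency; the function name is ours):

```cpp
#include <cmath>
#include <cstdio>

// Standalone restatement of the GELU formula used by GeluFunctor above:
// out = 0.5 * x * (1 + erf(x * M_SQRT1_2)).
float GeluReference(float x) {
  const float temp = std::erf(x * static_cast<float>(M_SQRT1_2));
  return x * 0.5f * (1.0f + temp);
}

int main() {
  std::printf("gelu(1.0f) = %f\n", GeluReference(1.0f));  // ~0.841345
  return 0;
}
```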
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_dropout_common.h
@@ -21,13 +21,13 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

namespace paddle {
namespace operators {
@@ -138,7 +138,7 @@ inline __device__ void CalculateDBias(const T *tmp_sum,
int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
// reduce 32 to 1
for (int i = 0; i < reduce_num_pre_thread; i++) {
- sum[i] = WarpReduceSum(sum[i]);
+ sum[i] = phi::funcs::WarpReduceSum(sum[i]);
}

// save sum to dbias
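`CalculateDBias` above first reduces partial sums within each warp via `phi::funcs::WarpReduceSum`. For readers unfamiliar with that helper, a shuffle-based warp sum of this kind looks like the sketch below (our own illustrative version, assuming all 32 lanes of the warp are active; the real implementation lives in the phi headers):

```cpp
#include <cuda_runtime.h>

// Illustrative warp-level sum of the kind phi::funcs::WarpReduceSum performs.
// Assumes a full, converged 32-lane warp.
template <typename T>
__device__ T WarpReduceSumSketch(T val) {
  const unsigned kFullMask = 0xffffffffu;
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(kFullMask, val, offset);
  }
  return val;  // lane 0 holds the sum over the warp
}
```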
50 changes: 25 additions & 25 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -418,18 +418,18 @@ class FusedDropoutLayerNormHelper
LayerNormParamType<T>* d_scale,
LayerNormParamType<T>* d_bias) {
using U = LayerNormParamType<T>;
- LayerNormBackward<T, U>(src,
- dout,
- gamma,
- mean,
- variance,
- d_src,
- d_scale,
- d_bias,
- epsilon_,
- this->rows_,
- this->cols_,
- ctx);
+ phi::funcs::LayerNormBackward<T, U>(src,
+ dout,
+ gamma,
+ mean,
+ variance,
+ d_src,
+ d_scale,
+ d_bias,
+ epsilon_,
+ this->rows_,
+ this->cols_,
+ ctx);
}

// out = layernorm(residual + dropout(src + bias))
@@ -457,7 +457,7 @@ class FusedDropoutLayerNormHelper
if (this->cols_ % vec_size != 0) {
vec_size = 1;
}
- int threads = GetDesiredBlockDim(this->cols_ / vec_size);
+ int threads = phi::funcs::GetDesiredBlockDim(this->cols_ / vec_size);
int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
LaunchLayernormResidualDropoutBias<T,
@@ -537,18 +537,18 @@ class FusedDropoutLayerNormHelper
d_residual,
d_dropout_src);
} else {
- LayerNormBackward<T, U, is_same_type>(layernorm_src,
- d_out,
- gamma,
- mean,
- variance,
- d_layernorm_src,
- d_scale,
- d_layernorm_bias,
- epsilon_,
- this->rows_,
- this->cols_,
- ctx);
+ phi::funcs::LayerNormBackward<T, U, is_same_type>(layernorm_src,
+ d_out,
+ gamma,
+ mean,
+ variance,
+ d_layernorm_src,
+ d_scale,
+ d_layernorm_bias,
+ epsilon_,
+ this->rows_,
+ this->cols_,
+ ctx);
this->ResidualDropoutBiasGrad(
ctx, d_layernorm_src, mask, d_dropout_src, d_residual, d_bias);
}
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_dropout_test.h
@@ -23,9 +23,9 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

@@ -37,7 +37,7 @@ USE_OP_ITSELF(dropout);
USE_OP_ITSELF(layer_norm);

template <typename T>
- using CudnnDataType = platform::CudnnDataType<T>;
+ using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

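The test header also spells out what `LayerNormParamType<T>` is: the cuDNN batch-norm parameter type of `T`, now taken from `phi::backends::gpu`. An annotated restatement (our reading; the float16 remark is an assumption, not stated in the diff):

```cpp
// Annotation of the aliases above (not part of the diff).
template <typename T>
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;  // was: platform::CudnnDataType<T>
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
// Assumption: for T = phi::dtype::float16 this resolves to float, so the
// scale/bias/mean/variance buffers throughout this PR are full precision.
```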
10 changes: 5 additions & 5 deletions paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -15,12 +15,12 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
@@ -120,7 +120,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);

- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
const phi::DenseTensor* in = &x;

const U* ln1_scale_ptr =
@@ -238,7 +238,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);

- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
dev_ctx.Alloc<uint8_t>(dropout1_mask,
dropout1_mask->numel() * sizeof(uint8_t));
@@ -369,7 +369,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);

- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
@@ -485,7 +485,7 @@ }
}

void Compute(const framework::ExecutionContext& context) const override {
- using U = LayerNormParamType<T>;
+ using U = phi::funcs::LayerNormParamType<T>;
auto& dev_ctx = context.template device_context<phi::GPUContext>();
auto d_out =
*context.Input<phi::DenseTensor>(framework::GradVarName("Out"));