diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
index 0307fbb1dd350..fff112d095b42 100644
--- a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu
@@ -24,9 +24,9 @@
 #include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
 #include "paddle/fluid/operators/fused/fused_dropout_common.h"
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"

 namespace paddle {
diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h
index e54bca8a89368..92cbc37059eb1 100644
--- a/paddle/fluid/operators/fused/attention_layer_norm.h
+++ b/paddle/fluid/operators/fused/attention_layer_norm.h
@@ -14,7 +14,7 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 namespace paddle {
 namespace operators {
@@ -35,11 +35,11 @@ class AttnLayerNorm {
   ~AttnLayerNorm() {}

   void ComputeForward(const InType* x_data,
-                      const LayerNormParamType<T>* scale_data,
-                      const LayerNormParamType<T>* bias_data,
+                      const phi::funcs::LayerNormParamType<T>* scale_data,
+                      const phi::funcs::LayerNormParamType<T>* bias_data,
                       OutType* y_data,
-                      LayerNormParamType<T>* mean_data,
-                      LayerNormParamType<T>* var_data,
+                      phi::funcs::LayerNormParamType<T>* mean_data,
+                      phi::funcs::LayerNormParamType<T>* var_data,
                       const float* dequant_out_scale_data = nullptr,
                       const int quant_out_scale_offset = 0,
                       const float quant_in_scale = 1.0,
@@ -48,14 +48,14 @@ class AttnLayerNorm {
                       const float quant_min_bound = -127.0) {
     auto stream = dev_ctx_.stream();

-    switch (GetDesiredBlockDim(feature_size_)) {
+    switch (phi::funcs::GetDesiredBlockDim(feature_size_)) {
       FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, LayerNormParamType<T>, kBlockDim, false, InType, OutType>
+          phi::funcs::LayerNormForward<T, phi::funcs::LayerNormParamType<T>, kBlockDim, false, InType, OutType>
           <<<batch_size_, kBlockDim, 0, stream>>>(x_data,
                                                   scale_data,
                                                   bias_data,
@@ -71,32 +71,33 @@ class AttnLayerNorm {
                                                   quant_max_bound,
                                                   quant_min_bound));
       default:
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Feature_size must be larger than 1"));
+        PADDLE_THROW(
+            phi::errors::InvalidArgument("Feature_size must be larger than 1"));
         break;
     }
   }

   void ComputeBackward(const T* x_data,
                        const T* d_y_data,
-                       const LayerNormParamType<T>* scale_data,
-                       const LayerNormParamType<T>* mean_data,
-                       const LayerNormParamType<T>* var_data,
+                       const phi::funcs::LayerNormParamType<T>* scale_data,
+                       const phi::funcs::LayerNormParamType<T>* mean_data,
+                       const phi::funcs::LayerNormParamType<T>* var_data,
                        T* d_x_data,
-                       LayerNormParamType<T>* d_scale_data,
-                       LayerNormParamType<T>* d_bias_data) {
-    LayerNormBackward<T, LayerNormParamType<T>>(x_data,
-                                                d_y_data,
-                                                scale_data,
-                                                mean_data,
-                                                var_data,
-                                                d_x_data,
-                                                d_scale_data,
-                                                d_bias_data,
-                                                epsilon_,
-                                                batch_size_,
-                                                feature_size_,
-                                                dev_ctx_);
+                       phi::funcs::LayerNormParamType<T>* d_scale_data,
+                       phi::funcs::LayerNormParamType<T>* d_bias_data) {
+    phi::funcs::LayerNormBackward<T, phi::funcs::LayerNormParamType<T>>(
+        x_data,
+        d_y_data,
+        scale_data,
+        mean_data,
+        var_data,
+        d_x_data,
+        d_scale_data,
+        d_bias_data,
+        epsilon_,
+        batch_size_,
+        feature_size_,
+        dev_ctx_);
   }

  private:
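The switch on GetDesiredBlockDim plus FIXED_BLOCK_DIM_CASE seen above turns a runtime feature size into a compile-time block dimension before launching the kernel. Below is a minimal, self-contained CUDA sketch of that dispatch idea; CopyScaledKernel, DispatchByBlockDim, and the two-way size heuristic are illustrative assumptions, not part of this diff or of Paddle's macro.

// Illustrative only: mirrors the fixed-block-dim dispatch used by
// AttnLayerNorm::ComputeForward, with a hard-coded two-way choice standing
// in for GetDesiredBlockDim and FIXED_BLOCK_DIM_CASE.
#include <cuda_runtime.h>

template <typename T, int kBlockDim>
__global__ void CopyScaledKernel(const T* x, T* y, int n) {
  int i = blockIdx.x * kBlockDim + threadIdx.x;
  if (i < n) {
    y[i] = x[i] + x[i];  // stand-in for the real per-element work
  }
}

template <typename T>
void DispatchByBlockDim(
    const T* x, T* y, int n, int feature_size, cudaStream_t stream) {
  const int block = feature_size <= 256 ? 256 : 1024;  // assumed heuristic
  const int grid = (n + block - 1) / block;
  switch (block) {
    case 256:
      CopyScaledKernel<T, 256><<<grid, 256, 0, stream>>>(x, y, n);
      break;
    case 1024:
      CopyScaledKernel<T, 1024><<<grid, 1024, 0, stream>>>(x, y, n);
      break;
  }
}

The real macro expands one case per supported block size; making kBlockDim a template parameter lets the compiler specialize per-block loops for each size.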
diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
index 1156d04b8f557..cce4fe911bd9c 100644
--- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h
+++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -26,7 +26,7 @@ namespace operators {
 template <typename T>
 struct GeluFunctor {
   inline __host__ __device__ T operator()(const T x) const {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const U casted_x = static_cast<U>(x);
     const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
     const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
@@ -47,7 +47,7 @@ struct FastGeluFunctor {
 template <typename T>
 struct GeluGradFunctor {
   inline __host__ __device__ T UseOut(const T x) const {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     auto casted_x = static_cast<U>(x);
     auto first =
diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h
index 0fbc14436e914..06606f4a1adad 100644
--- a/paddle/fluid/operators/fused/fused_dropout_common.h
+++ b/paddle/fluid/operators/fused/fused_dropout_common.h
@@ -21,13 +21,13 @@ limitations under the License. */
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 namespace paddle {
 namespace operators {
@@ -138,7 +138,7 @@ inline __device__ void CalculateDBias(const T *tmp_sum,
   int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
   // reduce 32 to 1
   for (int i = 0; i < reduce_num_pre_thread; i++) {
-    sum[i] = WarpReduceSum(sum[i]);
+    sum[i] = phi::funcs::WarpReduceSum(sum[i]);
   }

   // save sum to dbias
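GeluFunctor above evaluates the exact GELU through erf in the parameter type U, i.e. in float even when the tensor type is half precision. A standalone float-only sketch of the same formula, assuming nothing beyond <cmath> (GeluExact is an illustrative name, not an identifier from this diff):

#include <cmath>

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) -- the formula GeluFunctor
// computes in LayerNormParamType<T> above.
inline float GeluExact(float x) {
  const float kSqrt1_2 = 0.70710678f;  // 1 / sqrt(2), i.e. M_SQRT1_2
  return 0.5f * x * (1.0f + std::erf(x * kSqrt1_2));
}

For example, GeluExact(1.0f) is roughly 0.8413f and GeluExact(-1.0f) is roughly -0.1587f.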
diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index 1e6be41315c61..e530ae6b40ac5 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -418,18 +418,18 @@ class FusedDropoutLayerNormHelper
                      LayerNormParamType<T>* d_scale,
                      LayerNormParamType<T>* d_bias) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U>(src,
-                            dout,
-                            gamma,
-                            mean,
-                            variance,
-                            d_src,
-                            d_scale,
-                            d_bias,
-                            epsilon_,
-                            this->rows_,
-                            this->cols_,
-                            ctx);
+    phi::funcs::LayerNormBackward<T, U>(src,
+                                        dout,
+                                        gamma,
+                                        mean,
+                                        variance,
+                                        d_src,
+                                        d_scale,
+                                        d_bias,
+                                        epsilon_,
+                                        this->rows_,
+                                        this->cols_,
+                                        ctx);
   }

   // out = layernorm(residual + dropout(src + bias))
@@ -457,7 +457,7 @@ class FusedDropoutLayerNormHelper
     if (this->cols_ % vec_size != 0) {
       vec_size = 1;
     }
-    int threads = GetDesiredBlockDim(this->cols_ / vec_size);
+    int threads = phi::funcs::GetDesiredBlockDim(this->cols_ / vec_size);
     int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
     increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
     LaunchLayernormResidualDropoutBias(layernorm_src,
-                                       d_out,
-                                       gamma,
-                                       mean,
-                                       variance,
-                                       d_layernorm_src,
-                                       d_scale,
-                                       d_layernorm_bias,
-                                       epsilon_,
-                                       this->rows_,
-                                       this->cols_,
-                                       ctx);
+    phi::funcs::LayerNormBackward(layernorm_src,
+                                  d_out,
+                                  gamma,
+                                  mean,
+                                  variance,
+                                  d_layernorm_src,
+                                  d_scale,
+                                  d_layernorm_bias,
+                                  epsilon_,
+                                  this->rows_,
+                                  this->cols_,
+                                  ctx);
     this->ResidualDropoutBiasGrad(
         ctx, d_layernorm_src, mask, d_dropout_src, d_residual, d_bias);
   }
diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h
index a985d23b483a7..cb3f56302b89f 100644
--- a/paddle/fluid/operators/fused/fused_dropout_test.h
+++ b/paddle/fluid/operators/fused/fused_dropout_test.h
@@ -23,9 +23,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"

@@ -37,7 +37,7 @@
 USE_OP_ITSELF(dropout);
 USE_OP_ITSELF(layer_norm);

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 925ec7d2060a4..2058d9448cdfc 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -15,12 +15,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/matmul_v2_op.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/process_group_nccl.h"
@@ -120,7 +120,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
         ctx, bsz_seq, d_model, dropout_param2, epsilon2);
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const phi::DenseTensor* in = &x;

     const U* ln1_scale_ptr =
@@ -238,7 +238,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     DropoutParam dropout_param1(context, 1);
     DropoutParam dropout_param2(context, 2);

-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     dev_ctx.Alloc<uint8_t>(dropout1_mask, dropout1_mask->numel() * sizeof(uint8_t));
@@ -369,7 +369,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
         ctx, bsz_seq, d_model, dropout_param2, epsilon2);

-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const U* ln1_gamma_ptr =
         ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
     const U* ln1_beta_ptr =
         ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
@@ -485,7 +485,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
   }

   void Compute(const framework::ExecutionContext& context) const override {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     auto& dev_ctx = context.template device_context<phi::GPUContext>();
     auto d_out =
         *context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index 0c4e10fa156f9..a6bd467dc1992 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -22,7 +22,7 @@ namespace operators {
 #define LN_NUM_COLS 1024

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

@@ -174,8 +174,8 @@ __global__ void FusedLayernormResidualDropoutBias(
         relu);
   }

-  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
-  var_val = BlockReduceSum<U>(var_val, shared_var);
+  mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
     auto scale = static_cast<LayerNormParamType<T>>(
         static_cast<float>(1.) / static_cast<float>(cols));
@@ -189,7 +189,7 @@ __global__ void FusedLayernormResidualDropoutBias(
   __syncthreads();

   mean_val = mean_share;
-  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
+  U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));

   // calculate layernorm_dst
   CalcLayernormY(scale,
@@ -358,8 +358,8 @@ __global__ void FusedLayernormResidualDropoutBiasInfer(
         relu);
   }

-  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
-  var_val = BlockReduceSum<U>(var_val, shared_var);
+  mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
     auto scale = static_cast<LayerNormParamType<T>>(
         static_cast<float>(1.) / static_cast<float>(cols));
@@ -372,7 +372,7 @@ __global__ void FusedLayernormResidualDropoutBiasInfer(
   __syncthreads();

   mean_val = mean_share;
-  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
+  U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));

   // calculate layernorm_dst
   CalcLayernormY(scale,
@@ -412,7 +412,7 @@ struct FusedLayernormResidualDropoutBiasFunctor {
                   LayerNormParamType<T> *mean,
                   LayerNormParamType<T> *var,
                   cudaStream_t stream) {
-    int blockDim = GetDesiredBlockDim(cols / VecSize);
+    int blockDim = phi::funcs::GetDesiredBlockDim(cols / VecSize);
     if (mean != nullptr && var != nullptr) {
       LaunchFusedLayernormResidualDropoutBiasCUDAKernel
+          phi::funcs::LayerNormForward
           <<>>(
               dst,
               scale,
@@ -1005,7 +1005,7 @@ void LaunchLayernormResidualDropoutBias(
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
-    int blockDim = GetDesiredBlockDim(cols);
+    int blockDim = phi::funcs::GetDesiredBlockDim(cols);
     LaunchFusedLayernormResidualDropoutBiasCUDAKernel(1.0f);
   }
-  ln_bwd_fast_kernel_driver<T,
-                            U,
-                            LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
-                            MaskType>(dev_ctx,
-                                      rows,
-                                      cols,
-                                      epsilon,
-                                      layernorm_src,
-                                      scale,
-                                      mean,
-                                      var,
-                                      d_out,
-                                      d_residual,
-                                      d_scale,
-                                      d_layernorm_bias,
-                                      mask_data,
-                                      factor,
-                                      d_dropout_src);
+  phi::funcs::ln_bwd_fast_kernel_driver<
+      T,
+      U,
+      LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
+      MaskType>(dev_ctx,
+                rows,
+                cols,
+                epsilon,
+                layernorm_src,
+                scale,
+                mean,
+                var,
+                d_out,
+                d_residual,
+                d_scale,
+                d_layernorm_bias,
+                mask_data,
+                factor,
+                d_dropout_src);
 }

 }  // namespace operators
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
index f4353dd9bd4d7..e57f39ec01e6d 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
@@ -235,7 +235,7 @@ struct TestFusedLayernormResidualDropoutBias {
     if (cols % 4 != 0) {
       VecSize = 1;
     }
-    int threads = paddle::operators::GetDesiredBlockDim(cols / VecSize);
+    int threads = phi::funcs::GetDesiredBlockDim(cols / VecSize);
     const int increment = ((cols - 1) / (threads * VecSize) + 1) * VecSize;

     T *bias_ptr = nullptr;
diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
similarity index 97%
rename from paddle/fluid/operators/layer_norm_kernel.cu.h
rename to paddle/phi/kernels/funcs/layer_norm_impl.cu.h
index f771a93530c90..157f507e68d9e 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
@@ -24,17 +24,17 @@ namespace cub = hipcub;
 #include

-#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"

-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

@@ -331,6 +331,38 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 }
 #endif

+template <typename T>
+inline HOSTDEVICE T roundWithTiesToEven(T x) {
+  T xLower = floor(x);
+  T xUpper = ceil(x);
+  // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
+  // even.
+  T dLower = x - xLower;
+  T dUpper = xUpper - x;
+  return static_cast<T>(
+      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
+          ? xLower
+          : xUpper);
+}
+
+template <typename T>
+__forceinline__ __device__ int8_t quant_helper(const T input,
+                                               const float scale,
+                                               const int round_type,
+                                               const float max_bound,
+                                               const float min_bound) {
+  float quant_value = max_bound * scale * static_cast<float>(input);
+
+  if (round_type == 0) {
+    quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
+  } else {
+    quant_value = static_cast<float>(round(quant_value));
+  }
+  quant_value = quant_value > max_bound ? max_bound : quant_value;
+  quant_value = quant_value < min_bound ? min_bound : quant_value;
+  return static_cast<int8_t>(quant_value);
+}
+
 template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
 using LayerNormScaleBiasT =
     typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
@@ -947,17 +979,17 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
     // get temp space for dscale and dbias.
     phi::DenseTensor dscale_temp;
     dscale_temp.Resize({gridx, cols});
-    dscale_temp.mutable_data<U>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<U>(&dscale_temp);
     U *dscale_temp_ptr = dscale_temp.data<U>();

     phi::DenseTensor dbias_temp;
     dbias_temp.Resize({gridx, cols});
-    dbias_temp.mutable_data<U>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<U>(&dbias_temp);
     U *dbias_temp_ptr = dbias_temp.data<U>();

     if (mask_ptr != nullptr) {
       if (d_dropout_src_ptr == nullptr) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
+        PADDLE_THROW(phi::errors::InvalidArgument(
             "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
             "can't be null"));
       }
@@ -1069,8 +1101,8 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
     // #blocks: 32,#threads_per_block: 512
     // Note: it is not supported for double type.
     if (sizeof(U) > 4) {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Only support float and fp16 type"));
+      PADDLE_THROW(
+          phi::errors::InvalidArgument("Only support float and fp16 type"));
     } else {
       int gridx_2 = 0;
@@ -1103,7 +1135,7 @@
 #undef LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
     }
   } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(phi::errors::InvalidArgument(
         "Fast layer_norm kernel is only used when feature_size is 1024"));
   }
 }
@@ -1891,11 +1923,11 @@ static void LayerNormBackward(
       constexpr int part_size = BDIMY2 * VPT;
       const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

-      auto part_grad_gamma_ptr = memory::Alloc(
+      auto part_grad_gamma_ptr = paddle::memory::Alloc(
           dev_ctx.GetPlace(),
           part_size * feature_size * sizeof(U),
           phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-      auto part_grad_beta_ptr = memory::Alloc(
+      auto part_grad_beta_ptr = paddle::memory::Alloc(
           dev_ctx.GetPlace(),
           part_size * feature_size * sizeof(U),
           phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
@@ -1959,5 +1991,5 @@ static void LayerNormBackward(
   }
 }

-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
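The roundWithTiesToEven and quant_helper functions added above scale by max_bound * scale, round (breaking .5 ties toward the even integer when round_type == 0), clamp to [min_bound, max_bound], and narrow to int8_t. A host-only sketch of that behaviour, assuming the FPU is left in its default round-to-nearest-even mode (QuantizeRoundToEven is an illustrative name, not the header's API):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scale, round ties-to-even, clamp, narrow -- the same steps quant_helper
// performs when round_type == 0. std::nearbyint follows the current FP
// rounding mode, which defaults to round-to-nearest-even (2.5 -> 2, 3.5 -> 4).
inline int8_t QuantizeRoundToEven(float value,
                                  float scale,
                                  float max_bound,
                                  float min_bound) {
  float q = max_bound * scale * value;
  q = std::nearbyint(q);
  q = std::min(std::max(q, min_bound), max_bound);
  return static_cast<int8_t>(q);
}

Ties-to-even avoids the systematic upward bias of round-half-away-from-zero, which is why round_type == 0 is offered alongside plain round().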
"paddle/phi/kernels/funcs/layer_norm_impl.cu.h" #include "paddle/phi/kernels/funcs/layer_norm_util.h" namespace phi { @@ -34,7 +34,7 @@ void LayerNormGradKernel(const Context &dev_ctx, DenseTensor *x_grad, DenseTensor *scale_grad, DenseTensor *bias_grad) { - using U = paddle::operators::LayerNormParamType; + using U = phi::funcs::LayerNormParamType; // d_x, d_scale, d_bias may be nullptr auto *d_x = x_grad; auto *d_scale = scale_grad; @@ -84,7 +84,7 @@ void LayerNormGradKernel(const Context &dev_ctx, : dev_ctx.template Alloc(d_bias)); \ auto *d_x_data = \ (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); \ - paddle::operators::LayerNormBackward( \ + phi::funcs::LayerNormBackward( \ x_data, \ d_y_data, \ scale_data, \ diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu index 1350cb2209c31..cccf93f944640 100644 --- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu @@ -14,9 +14,9 @@ #include "paddle/phi/kernels/layer_norm_kernel.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" #include "paddle/phi/kernels/funcs/layer_norm_util.h" namespace phi { @@ -36,9 +36,9 @@ void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); int64_t batch_size = static_cast(matrix_dim[0]); int64_t feature_size = static_cast(matrix_dim[1]); - switch (paddle::operators::GetDesiredBlockDim(feature_size)) { + switch (phi::funcs::GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( - paddle::operators::LayerNormForward + phi::funcs::LayerNormForward <<>>( input, scale, bias, output, mean, variance, eps, feature_size)); default: @@ -65,7 +65,7 @@ void LayerNormKernel(const Context &dev_ctx, DenseTensor *y, DenseTensor *mean, DenseTensor *var) { - using U = paddle::operators::LayerNormParamType; + using U = phi::funcs::LayerNormParamType; auto *scale = scale_opt.get_ptr(); auto *bias = bias_opt.get_ptr(); @@ -109,9 +109,9 @@ void LayerNormKernel(const Context &dev_ctx, #define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ do { \ - switch (paddle::operators::GetDesiredBlockDim(feature_size)) { \ + switch (phi::funcs::GetDesiredBlockDim(feature_size)) { \ FIXED_BLOCK_DIM_CASE( \ - paddle::operators:: \ + phi::funcs:: \ LayerNormForward \ <<>>( \ x_data, \ @@ -140,13 +140,13 @@ void LayerNormKernel(const Context &dev_ctx, const int ROWS_PER_CTA = WARPS_M; \ const int grid = static_cast( \ std::ceil(batch_size / static_cast(ROWS_PER_CTA))); \ - paddle::operators::fast_ln_fwd_kernel \ + phi::funcs::fast_ln_fwd_kernel \ <<>>( \ batch_size, \ feature_size, \