From 466f98238b64c4fb1490afbf6d18abc900d82e7d Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 24 Apr 2024 20:54:51 +0800 Subject: [PATCH] Fix --- paddle/fluid/operators/cudnn_rnn_cache.h | 2 +- .../fluid/operators/detection/bbox_util.cu.h | 2 +- .../fluid/operators/fused/attn_bias_add.cu.h | 366 --------- .../fluid/operators/fused/attn_feed_forward.h | 125 --- .../fused/cudnn_bn_stats_finalize.cu.h | 9 +- .../operators/fused/cudnn_norm_conv.cu.h | 9 +- .../fused/cudnn_scale_bias_add_relu.cu.h | 11 +- paddle/fluid/operators/fused/fmha_ref.h | 750 ------------------ .../fused/fused_multi_transformer_int8_op.cu | 10 +- .../fused/fused_multi_transformer_op.cu | 6 +- .../fused/fused_multi_transformer_op.cu.h | 33 +- .../fluid/operators/fused/resnet_unit_op.cu | 4 +- .../fused/xpu_fused_common_function.h | 225 ------ .../operators/grid_sampler_cudnn_op.cu.cc | 2 +- paddle/fluid/operators/math/prelu.h | 2 +- paddle/fluid/operators/miopen_rnn_cache.h | 2 +- .../sequence_ops/sequence_softmax_op.cc | 2 +- 17 files changed, 49 insertions(+), 1511 deletions(-) delete mode 100644 paddle/fluid/operators/fused/attn_bias_add.cu.h delete mode 100644 paddle/fluid/operators/fused/attn_feed_forward.h delete mode 100644 paddle/fluid/operators/fused/fmha_ref.h delete mode 100644 paddle/fluid/operators/fused/xpu_fused_common_function.h diff --git a/paddle/fluid/operators/cudnn_rnn_cache.h b/paddle/fluid/operators/cudnn_rnn_cache.h index eaca6842d350ce..a7fd8ad384608b 100644 --- a/paddle/fluid/operators/cudnn_rnn_cache.h +++ b/paddle/fluid/operators/cudnn_rnn_cache.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/detection/bbox_util.cu.h b/paddle/fluid/operators/detection/bbox_util.cu.h index adb60a8a8d0642..abd34c3c2025a2 100644 --- a/paddle/fluid/operators/detection/bbox_util.cu.h +++ b/paddle/fluid/operators/detection/bbox_util.cu.h @@ -23,8 +23,8 @@ limitations under the License. */ #include namespace cub = hipcub; #endif -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h deleted file mode 100644 index b5eab8f4145503..00000000000000 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ /dev/null @@ -1,366 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" - -#ifdef __HIPCC__ -#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim) -#else -#define LAUNCH_BOUNDS(BlockDim) -#endif - -#include "paddle/fluid/operators/elementwise/elementwise_functor.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/phi/kernels/funcs/fast_divmod.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/primitive/kernel_primitives.h" - -namespace paddle { -namespace operators { - -#define MAX_INPUT_NUM 2 - -template -using CudnnDataType = platform::CudnnDataType; -template -using ReduceParamType = typename CudnnDataType::BatchNormParamType; - -template -__global__ void BroadcastKernelBinary( - const InT* __restrict__ in0, - const InT* __restrict__ in1, - OutT* out, - phi::Array use_broadcast, - uint32_t numel, - phi::Array configlists, - int main_tid, - int tail_tid, - Functor func) { - int fix = blockIdx.x * blockDim.x * VecSize; - int num = tail_tid; - InT arg0[VecSize * DATA_PER_THREAD]; - InT arg1[VecSize * DATA_PER_THREAD]; - OutT result[VecSize * DATA_PER_THREAD]; - if (blockIdx.x < main_tid) { - num = blockDim.x * VecSize; // blockIdx.x < main_tid - } - - // load in0 - if (use_broadcast[0]) { - phi::kps::ReadDataBc( - arg0, in0, fix, configlists[0], numel); - } else { - phi::kps::ReadData(arg0, in0 + fix, num); - } - // load in1 - if (use_broadcast[1]) { - phi::kps::ReadDataBc( - arg1, in1, fix, configlists[1], numel); - } else { - phi::kps::ReadData(arg1, in1 + fix, num); - } - // compute - phi::kps::ElementwiseBinary( - result, arg0, arg1, func); - // store - phi::kps::WriteData(out + fix, result, num); -} - -// bias add forward impl for "[m, n] + [n] = [m, n]" -template -void LaunchBiasAddFwKernel(const phi::GPUContext& ctx, - int m, - int n, - const T* in0, - const T* in1, - T* out) { - uint64_t addr = - (reinterpret_cast(in0) | reinterpret_cast(in1) | - reinterpret_cast(out)); - int vec_size = phi::GetVectorizedSize(reinterpret_cast(addr)); - int numel = m * n; - const int threads = 256; - const int data_per_thread = 1; - int blocks = - ((numel + vec_size * data_per_thread - 1) / (vec_size * data_per_thread) + - threads - 1) / - threads; - int main_tid = numel / (data_per_thread * vec_size * threads); - int tail_tid = numel % (data_per_thread * vec_size * threads); - - phi::Array configlists; - phi::Array use_broadcast; - - use_broadcast[0] = false; - use_broadcast[1] = false; - if (m != 1) { - use_broadcast[1] = true; - } - // Here, dims are transposed due to the logic in BroadcastConfig. - std::vector input1_dims = {n, 1}; - std::vector out_dims = {n, m}; - configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2); - - auto func = AddFunctor(); - auto stream = ctx.stream(); - switch (vec_size) { - case 4: { - BroadcastKernelBinary - <<>>(in0, - in1, - out, - use_broadcast, - numel, - configlists, - main_tid, - tail_tid, - func); - break; - } - case 2: { - BroadcastKernelBinary - <<>>(in0, - in1, - out, - use_broadcast, - numel, - configlists, - main_tid, - tail_tid, - func); - break; - } - case 1: { - BroadcastKernelBinary - <<>>(in0, - in1, - out, - use_broadcast, - numel, - configlists, - main_tid, - tail_tid, - func); - break; - } - default: { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported vectorized size: %d !", vec_size)); - break; - } - } -} - -template -__global__ void LAUNCH_BOUNDS(BlockDim) Compute1DColumnReduceKernel( - const int reduce_num, const int left_num, const T* in, T* out) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage mean_storage; - - for (int i = blockIdx.x; i < left_num; i += gridDim.x) { - ReduceParamType x_sum = static_cast>(0); - for (int j = threadIdx.x; j < reduce_num; j += blockDim.x) { - const int index = j * left_num + i; - ReduceParamType x_i = static_cast>(in[index]); - x_sum += x_i; - } - x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum()); - if (threadIdx.x == 0) { - out[i] = static_cast(x_sum); - } - } -} - -template -void Launch1DColumnReduce(gpuStream_t stream, - const int max_threads, - const int reduce_num, - const int left_num, - const T* d_out, - T* d_bias) { - const int block = 256; - const int max_blocks = std::max(max_threads / block, 1); - const int grid = std::min(left_num, max_blocks); - Compute1DColumnReduceKernel - <<>>(reduce_num, left_num, d_out, d_bias); -} - -void SetConfigForColumnReduce(const int max_threads, - const int reduce_num, - const int left_num, - int* blocking_size, - bool* should_reduce_again, - dim3* block_dim, - dim3* grid_dim) { - block_dim->z = 1; - grid_dim->z = 1; - *should_reduce_again = false; - - int num_block = (max_threads / left_num); - if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) { - *blocking_size = phi::funcs::details::GetLastPow2(reduce_num / num_block); - if (*blocking_size <= 1) { - *blocking_size = phi::funcs::details::GetLastPow2(sqrt(reduce_num)); - } else if (*blocking_size * 2 < reduce_num) { - *blocking_size *= 2; - } - *should_reduce_again = true; - block_dim->x = 32; - block_dim->y = 1; - grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x; - grid_dim->y = (reduce_num + *blocking_size - 1) / *blocking_size; - } else { - block_dim->x = 32; - *blocking_size = reduce_num; - grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x; - grid_dim->y = 1; - } -} - -template -__global__ void BiasAddBwSinglePassKernel(const T* in, - int reduce_num, - int left_num, - T* out) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - ReduceParamType x_sum = static_cast>(0); - if (idx < left_num) { - for (int iy = 0; iy < reduce_num; iy++) { - int id = iy * left_num + idx; - ReduceParamType x_val = static_cast>(in[id]); - x_sum += x_val; - } - out[idx] = static_cast(x_sum); - } -} - -template -__global__ void BiasAddBw2DReduceKernel(const T* x, - int reduce_num, - int left_num, - int workload_per_thread, - ReduceParamType* temp_x_sum) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int idy = blockIdx.y * workload_per_thread; - - T x_val; - ReduceParamType x_sum = static_cast>(0); - if (idx < left_num) { - int loop = reduce_num - idy; - loop = loop > workload_per_thread ? workload_per_thread : loop; - for (int iy = 0; iy < loop; iy++) { - int id = (idy + iy) * left_num + idx; - ReduceParamType x_val = static_cast>(x[id]); - x_sum += x_val; - } - temp_x_sum[idx + blockIdx.y * left_num] = x_sum; - } -} - -template -__global__ void BiasAddBw1DReduceKernel(const ReduceParamType* temp_sum, - int workload_per_thread, - int left_num, - T* out) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - ReduceParamType x_sum = static_cast>(0); - if (idx < left_num) { - for (int iy = 0; iy < workload_per_thread; iy++) { - int id = iy * left_num + idx; - x_sum += temp_sum[id]; - } - out[idx] = static_cast(x_sum); - } -} - -template -void Launch2DColumnReduce(const phi::GPUContext& dev_ctx, - const int max_threads, - const int reduce_num, - const int left_num, - const T* d_out, - T* d_bias) { - dim3 block; - dim3 grid; - bool should_reduce_again = false; - int blocking_size = 1; - SetConfigForColumnReduce(max_threads, - reduce_num, - left_num, - &blocking_size, - &should_reduce_again, - &block, - &grid); - const auto& stream = dev_ctx.stream(); - - if (!should_reduce_again) { - BiasAddBwSinglePassKernel - <<>>(d_out, reduce_num, left_num, d_bias); - } else { - phi::DenseTensor tmp_sum; - tmp_sum.Resize({grid.y, left_num}); - dev_ctx.template Alloc>( - &tmp_sum, tmp_sum.numel() * sizeof(ReduceParamType)); - - BiasAddBw2DReduceKernel<<>>( - d_out, - reduce_num, - left_num, - blocking_size, - tmp_sum.template data>()); - - BiasAddBw1DReduceKernel<<>>( - tmp_sum.template data>(), grid.y, left_num, d_bias); - } -} - -// bias add backward impl whose pattern are column-reduce with d_out[m, n] as -// input -// and d_bias[n] as output. -template -void LaunchBiasAddBwKernel( - const phi::GPUContext& dev_ctx, int m, int n, const T* d_out, T* d_bias) { - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - int reduce_num = m; - int left_num = n; - bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) || - (left_num > REDUCE_SPLIT_BOUNDARY); - if (!is_large_enough) { - Launch1DColumnReduce( - dev_ctx.stream(), max_threads, reduce_num, left_num, d_out, d_bias); - } else { - Launch2DColumnReduce( - dev_ctx, max_threads, reduce_num, left_num, d_out, d_bias); - } -} - -#undef MAX_INPUT_NUM - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/attn_feed_forward.h b/paddle/fluid/operators/fused/attn_feed_forward.h deleted file mode 100644 index 25ba1cc13ead2a..00000000000000 --- a/paddle/fluid/operators/fused/attn_feed_forward.h +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/fused/attn_bias_add.cu.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { - -template -class FeedForward { - public: - FeedForward(const phi::GPUContext& dev_ctx, - int bsz_seq, - int output_size, - int input_size, - bool compute_bias) - : dev_ctx_(dev_ctx), - bsz_seq_(bsz_seq), - output_size_(output_size), - input_size_(input_size), - compute_bias_(compute_bias) {} - - ~FeedForward() {} - - void ComputeForward(const T* weight_data, - const T* input_data, - const T* bias_data, - T* output_data, - T* bias_out_data) { - // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. - // To convert to col-major expression, transa<->transb, A<->B, m<->n. - - // column-major: gemm-tn. - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - - // column-major: (m,n,k) = output_size,bsz_seq,input_size (weight*input=out) - // here: (m,n,k) = bsz_seq,output_size,input_size (input*weight=out) - auto blas = phi::funcs::GetBlas(dev_ctx_); - blas.GEMM(transA, - transB, - bsz_seq_, - output_size_, - input_size_, - alpha, - input_data, - weight_data, - beta, - output_data); - if (compute_bias_) { - LaunchBiasAddFwKernel(dev_ctx_, - bsz_seq_, - output_size_, - output_data, - bias_data, - bias_out_data); - } - } - - void ComputeBackward( - T* input, T* weight, T* d_output, T* d_input, T* d_weight, T* d_bias) { - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - auto blas = phi::funcs::GetBlas(dev_ctx_); - - // column-major: gemm-nt, get d_weight. - CBLAS_TRANSPOSE transA = CblasTrans; - CBLAS_TRANSPOSE transB = CblasNoTrans; - // column-major: (m,n,k): input_size,output_size,bsz (input*dout=dweight) - // here: (m,n,k): output_size,input_size,bsz (dout*input=dweight) - blas.GEMM(transA, - transB, - output_size_, - input_size_, - bsz_seq_, - alpha, - d_output, - input, - beta, - d_weight); - - // column-major: gemm-nn: get d_input. - transA = CblasNoTrans; - // column-major: (m,n,k): input_size,bsz,output_size (weight*dout=dinput) - // here: (m, n, k): bsz, input_size, output_size, (dout*weight=dinput) - blas.GEMM(transA, - transB, - bsz_seq_, - input_size_, - output_size_, - alpha, - d_output, - weight, - beta, - d_input); - if (compute_bias_) { - LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias); - } - } - - private: - const phi::GPUContext& dev_ctx_; - int bsz_seq_, output_size_, input_size_; - bool compute_bias_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h index 5fb6f38b4c682f..22ed80c50e57ea 100644 --- a/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h +++ b/paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { @@ -23,15 +23,16 @@ namespace operators { namespace dynload = platform::dynload; template using BatchNormParamType = - typename platform::CudnnDataType::BatchNormParamType; + typename phi::backends::gpu::CudnnDataType::BatchNormParamType; #if CUDNN_VERSION >= 8000 template struct BNStatsFinalizeArgs { BNStatsFinalizeArgs() { - dtype = platform::CudnnDataType::type; - param_dtype = platform::CudnnDataType>::type; + dtype = phi::backends::gpu::CudnnDataType::type; + param_dtype = + phi::backends::gpu::CudnnDataType>::type; format = CUDNN_TENSOR_NHWC; } diff --git a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h index 5d0e6c44c4e63f..f41307aab5e2eb 100644 --- a/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h +++ b/paddle/fluid/operators/fused/cudnn_norm_conv.cu.h @@ -15,14 +15,15 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { namespace dynload = platform::dynload; template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +using ScalingParamType = + typename phi::backends::gpu::CudnnDataType::ScalingParamType; #if CUDNN_VERSION >= 8000 @@ -31,9 +32,9 @@ static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; } template struct NormConvolutionArgs { NormConvolutionArgs() { - dtype = platform::CudnnDataType::type; + dtype = phi::backends::gpu::CudnnDataType::type; format = CUDNN_TENSOR_NHWC; - compute_type = platform::CudnnDataType::type; + compute_type = phi::backends::gpu::CudnnDataType::type; } void Set(const phi::GPUContext &ctx, diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h index 7f47ea40e6cea0..5f62e3fa69de38 100644 --- a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h +++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h @@ -15,24 +15,25 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { template -using CudnnDataType = platform::CudnnDataType; +using CudnnDataType = phi::backends::gpu::CudnnDataType; namespace dynload = platform::dynload; template using BatchNormParamType = - typename platform::CudnnDataType::BatchNormParamType; + typename phi::backends::gpu::CudnnDataType::BatchNormParamType; #if CUDNN_VERSION >= 8000 template struct ScaleBiasAddReluArgs { ScaleBiasAddReluArgs() { - dtype = platform::CudnnDataType::type; - param_dtype = platform::CudnnDataType>::type; + dtype = phi::backends::gpu::CudnnDataType::type; + param_dtype = + phi::backends::gpu::CudnnDataType>::type; format = CUDNN_TENSOR_NHWC; } diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h deleted file mode 100644 index 2a43eea07535ab..00000000000000 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ /dev/null @@ -1,750 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" -#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/funcs/functors.h" -#include "paddle/phi/kernels/funcs/transpose_function.cu.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" - -namespace paddle { -namespace operators { - -class AttnDropoutParam { - public: - AttnDropoutParam() { - is_test_ = false; - dropout_implementation_ = "downgrade_in_infer"; - dropout_prob_ = 0.5; - is_upscale_in_train_ = false; - is_fix_seed_ = false; - seed_val_ = 0; - seed_ = nullptr; - } - AttnDropoutParam(bool is_test, - const std::string dropout_implementation, - float dropout_prob, - bool is_upscale_in_train, - bool is_fix_seed, - int seed_val, - const phi::DenseTensor* seed) { - is_test_ = is_test; - dropout_implementation_ = dropout_implementation; - dropout_prob_ = dropout_prob; - is_upscale_in_train_ = is_upscale_in_train; - is_fix_seed_ = is_fix_seed; - seed_val_ = seed_val; - seed_ = seed; - } - bool is_test_; - std::string dropout_implementation_; - float dropout_prob_; - bool is_upscale_in_train_; - bool is_fix_seed_; - int seed_val_; - const phi::DenseTensor* seed_; -}; - -template -__global__ void TransposeRemovingPadding(const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int elem_cnt, - const int* padding_offset) { - // transpose and remove padding - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - const int dim_embed = num_head * head_dim; - using LoadT = phi::AlignedVector; - LoadT src_vec; - - for (int32_t linear_index = idx * VecSize, - step = gridDim.x * blockDim.x * VecSize; - linear_index < elem_cnt; - linear_index += step) { - const int token_idx = linear_index / dim_embed; - const int ori_token_idx = - token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); - const int ori_batch_id = ori_token_idx / seq_len; - const int ori_seq_id = ori_token_idx % seq_len; - const int ori_head_id = (linear_index % dim_embed) / head_dim; - const int ori_head_lane = (linear_index % dim_embed) % head_dim; - const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + - ori_head_id * seq_len * head_dim + - ori_seq_id * head_dim + ori_head_lane; - phi::Load(&input_data[ori_idx], &src_vec); - phi::Store(src_vec, &output_data[linear_index]); - } -} - -template -void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, - const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int* padding_offset) { - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - constexpr int VEC_16B = 16; - const int elem_cnt = token_num * num_head * head_dim; - constexpr int PackSize = VEC_16B / sizeof(T); - PADDLE_ENFORCE_EQ( - head_dim % PackSize, - 0, - phi::errors::PreconditionNotMet( - "dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize)); - const int32_t pack_num = elem_cnt / PackSize; - const int32_t block_size = 128; - int32_t grid_size = (pack_num + block_size - 1) / block_size; - TransposeRemovingPadding - <<>>(input_data, - output_data, - batch_size, - num_head, - seq_len, - head_dim, - token_num, - elem_cnt, - padding_offset); -} - -template -class FMHARef { - public: - FMHARef(const phi::GPUContext& dev_ctx, - int64_t batch_size, - int64_t seq_len, - int64_t num_head, - int64_t head_dim, - AttnDropoutParam param) - : dev_ctx_(dev_ctx), - batch_size_(batch_size), - seq_len_(seq_len), - num_head_(num_head), - head_dim_(head_dim), - dropout_param_(param) {} - - ~FMHARef() {} - - void ComputeForward(const phi::DenseTensor& qkv_input_tensor, - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - phi::DenseTensor* transpose_2_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - std::vector perm_1 = {2, 0, 3, 1, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor); - T* qkv_data = transpose_2_out_tensor->data(); - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = qkv_data; - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = q_size; - k_ptr = q_ptr + q_size; - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto q_tensor = transpose_2_out_tensor->Slice(0, 1); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - T* dropout_out_data = dropout_out_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } - - void ComputeForwardWithoutTranspose( - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor* padding_offset_tensor, - phi::DenseTensor* q_transpose_out_tensor, - phi::DenseTensor* kv_transpose_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor, - const int token_num) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, - {*cache_kv_tensor, *kv_transpose_out_tensor}, - 3, - cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = q_transpose_out_tensor->data(); - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = q_size; - k_ptr = kv_transpose_out_tensor->data(); - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {q_transpose_out_tensor}; - std::vector outs = {q_transpose_out_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - if (!padding_offset_tensor) { - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } else { - InvokeTransposeRemovePadding(dev_ctx_, - qktv_out_data, - fmha_out_data, - batch_size_, - num_head_, - seq_len_, - head_dim_, - token_num, - padding_offset_tensor->data()); - } - } - - void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor& softmax_out_tensor, - const phi::DenseTensor& dropout_mask_out_tensor, - const phi::DenseTensor& dropout_out_tensor, - const phi::DenseTensor& qk_out_tensor, - const phi::DenseTensor& src_mask_out_tensor, - const phi::DenseTensor& fmha_out_grad_tensor, - phi::DenseTensor* qktv_out_grad_tensor, - phi::DenseTensor* dropout_out_grad_tensor, - phi::DenseTensor* softmax_out_grad_tensor, - phi::DenseTensor* src_mask_out_grad_tensor, - phi::DenseTensor* qk_out_grad_tensor, - phi::DenseTensor* transpose_2_out_grad_tensor, - phi::DenseTensor* src_mask_grad_tensor, - phi::DenseTensor* qkv_input_grad_tensor) { - auto blas = phi::funcs::GetBlas(dev_ctx_); - int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - int k_size = q_size; - int softmax_axis = -1; - - T* qkv_grad_data = transpose_2_out_grad_tensor->data(); - T* q_grad_ptr = qkv_grad_data; - T* k_grad_ptr = q_grad_ptr + q_size; - T* v_grad_ptr = k_grad_ptr + k_size; - const T* qkv_data = transpose_2_out_tensor.data(); - const T* q_ptr = qkv_data; - const T* k_ptr = q_ptr + q_size; - const T* v_ptr = k_ptr + k_size; - - const T* softmax_out_data = softmax_out_tensor.data(); - T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - T* qktv_out_grad_data = qktv_out_grad_tensor->data(); - - // transpose bw - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor); - - // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) = - // qktv_out_data(out) - CBLAS_TRANSPOSE transA = CblasTrans; - CBLAS_TRANSPOSE transB = CblasNoTrans; - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = head_dim_; - int gemm_k = seq_len_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - // bw: dy = x^t * dout - if (dropout_param_.dropout_prob_) { - const T* dropout_out_data = dropout_out_tensor.data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } - // bw: dx = dout * y^t - transA = CblasNoTrans; - transB = CblasTrans; - gemm_m = seq_len_; - gemm_n = seq_len_; - gemm_k = head_dim_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - if (dropout_param_.dropout_prob_) { - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - dropout_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - softmax_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } - // dropout bw - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutGradGPUKernelDriver( - static_cast(dev_ctx_), - false, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - static_cast(*dropout_out_grad_tensor), - dropout_mask_out_tensor, - softmax_out_grad_tensor, - false); - } - - if (src_mask_tensor != nullptr) { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - src_mask_out_grad_tensor); - // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + - // src_mask - // Special case when dy is not needed and dx doesn't reduce - if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr && - qk_out_tensor.dims() == src_mask_out_tensor.dims()) { - VLOG(4) << "Special case when dy is not needed and dx doesn't " - "reduce"; - framework::TensorCopy(*src_mask_out_grad_tensor, - dev_ctx_.GetPlace(), - dev_ctx_, - qk_out_grad_tensor); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Only used for the backward elementwise_add op when" - "dy is not needed and dx is not reduce")); - return; - } - - } else { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - qk_out_grad_tensor); - } - - T* qk_out_grad_data = qk_out_grad_tensor->data(); - // NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set - // alpha = 1.0 in backward. - alpha = static_cast(1.0); - // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out - // bw: dy (seq_len * head_dim) = (dout)^t * x - transA = CblasTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - q_ptr, - beta, - k_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - // dx (seq_len * head_dim) = dout * y - alpha = static_cast(1.0 / sqrt(head_dim_)); - transA = CblasNoTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - k_ptr, - beta, - q_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - - // transpose bw - std::vector perm_1 = {1, 3, 0, 2, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor); - } - - private: - const phi::GPUContext& dev_ctx_; - - int64_t batch_size_; - int64_t seq_len_; - int64_t num_head_; - int64_t head_dim_; - - AttnDropoutParam dropout_param_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index b696a183170c33..11614d70165d3a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -61,8 +61,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto ln_scales = ctx.MultiInput("LnScale"); auto ln_biases = ctx.MultiInput("LnBias"); - auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + auto ln_compute = phi::fusion::AttnLayerNorm( + dev_ctx, epsilon, bsz_seq, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({{bsz_seq}}); auto *ln_mean_data = @@ -93,10 +93,10 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); // 3. fmha - AttnDropoutParam attn_param( + phi::fusion::AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto fmha_compute = phi::fusion::FMHARef( + dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); auto *src_mask = ctx.Input("SrcMask"); auto cache_kvs = ctx.MultiInput("CacheKV"); auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 75a4c7b275a8a5..b3718dfe1f7d51 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -125,7 +125,8 @@ void FusedMultiTransformerKernel( auto *padding_offset_data = encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + auto ln_compute = + phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({token_num}); auto *ln_mean_data = @@ -800,7 +801,8 @@ void FusedMultiTransformerKernel( // 1. layer norm - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + auto ln_compute = + phi::fusion::AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({token_num}); auto *ln_mean_data = diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 4bf467e9caf8fa..0a57fb9e873414 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -27,17 +27,16 @@ limitations under the License. */ #include "paddle/common/flags.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attention_layer_norm.h" -#include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" +#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group.h" @@ -711,13 +710,13 @@ struct Qk_dot { } }; -template +template inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; + int warp = threadIdx.x / WARP_SIZE_T; + int lane = threadIdx.x % WARP_SIZE_T; #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) { sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } @@ -789,8 +788,8 @@ __global__ void masked_multihead_attention_kernel( static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - constexpr int WARP_SIZE = 32; - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + constexpr int WARP_SIZE_TMP = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE_TMP; extern __shared__ char smem_[]; @@ -824,7 +823,7 @@ __global__ void masked_multihead_attention_kernel( constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); // Use block reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE_TMP, ""); constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; // cache_k, [B, num_head, head_dim / x, max_seq_len, x] @@ -944,16 +943,16 @@ __global__ void masked_multihead_attention_kernel( qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { + if (QK_VECS_PER_WARP <= WARP_SIZE_TMP) { #pragma unroll for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); } } } - if (QK_VECS_PER_WARP > WARP_SIZE) { + if (QK_VECS_PER_WARP > WARP_SIZE_TMP) { constexpr int WARPS_PER_RED = - (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + (QK_VECS_PER_WARP + WARP_SIZE_TMP - 1) / WARP_SIZE_TMP; qk = block_sum(&red_smem[WARPS_PER_RED], qk); } if (tid == 0) { @@ -994,7 +993,7 @@ __global__ void masked_multihead_attention_kernel( } constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE_TMP / THREADS_PER_KEY; T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; @@ -1031,12 +1030,12 @@ __global__ void masked_multihead_attention_kernel( } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + for (int mask = WARP_SIZE_TMP / 2; mask >= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } - const int warp = tid / WARP_SIZE; - const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE_TMP; + const int lane = tid % WARP_SIZE_TMP; if (lane == 0) { red_smem[warp] = qk_max; diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu index 2955fd3b453b4d..f715bda6906951 100644 --- a/paddle/fluid/operators/fused/resnet_unit_op.cu +++ b/paddle/fluid/operators/fused/resnet_unit_op.cu @@ -31,7 +31,7 @@ class ResNetUnitKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, + PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, CUDNN_DATA_HALF, phi::errors::Unavailable( "ResNetUnitOp only supports float16 for now.")); @@ -231,7 +231,7 @@ class ResNetUnitGradKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, + PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, CUDNN_DATA_HALF, phi::errors::Unavailable( "ResNetUnitOp only supports float16 for now.")); diff --git a/paddle/fluid/operators/fused/xpu_fused_common_function.h b/paddle/fluid/operators/fused/xpu_fused_common_function.h deleted file mode 100644 index 63a22838e8c35e..00000000000000 --- a/paddle/fluid/operators/fused/xpu_fused_common_function.h +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -struct XPUDropoutParam { - float dropout_prob; - bool is_upscale_in_train; - bool is_test; - bool fix_seed; - const phi::DenseTensor *tensor_seed; - int seed_val; - - XPUDropoutParam() { - fix_seed = false; - is_test = false; - is_upscale_in_train = false; - dropout_prob = 0.5; - tensor_seed = nullptr; - seed_val = 0; - } - - XPUDropoutParam(const framework::ExecutionContext &context, - const int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? context.Attr(pre_fix + "seed") : 0; - } - } - - void initXPUDropoutParam(float dropout_prob_, - bool is_upscale_in_train_, - bool is_test_, - bool fix_seed_, - const phi::DenseTensor *tensor_seed, - int seed_val_) { - dropout_prob = dropout_prob_; - is_upscale_in_train = is_upscale_in_train_; - is_test = is_test_; - fix_seed = fix_seed_; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? seed_val_ : 0; - } - } - - void initXPUDropoutParam(const framework::ExecutionContext &context, - int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? context.Attr(pre_fix + "seed") : 0; - } - } -}; - -/****************** - * check is l3 - *******************/ - -static bool is_in_l3(const void *addr) { - int64_t addr_int = (int64_t)addr; - int addr_int_high = addr_int >> 32; - return (addr_int_high == 0); -} - -/************************* - * dropout - *************************/ - -template -void Dropout(xpu::Context *xpu_ctx, - const T *x, - T *mask, - T *y, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - int r = XPU_SUCCESS; - if (param.dropout_prob == 0.0f) { - r = xpu::copy(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_test) { - if (param.dropout_prob == 1.0f) { - r = xpu::constant( - xpu_ctx, reinterpret_cast(y), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - r = xpu::constant( - xpu_ctx, reinterpret_cast(mask), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - } else { - r = xpu::dropout(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - reinterpret_cast(mask), - param.seed_val, - len, - param.is_upscale_in_train, - param.dropout_prob); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout"); - } - } else { - float scale = (param.is_upscale_in_train) - ? (1.0) - : (static_cast(1.0f - param.dropout_prob)); - r = xpu::scale(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len, - false, - scale, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); - } -} - -template -void DropoutGrad(xpu::Context *xpu_ctx, - const T *dy, - const T *mask, - T *dx, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - if (param.dropout_prob == 0.0f) { - int r = xpu::copy(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_upscale_in_train) { - int r = xpu::mul(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(mask), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); - } else { - int r = xpu::dropout_grad(xpu_ctx, - reinterpret_cast(mask), - reinterpret_cast(dy), - reinterpret_cast(dx), - param.dropout_prob, - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad"); - } -} - -} // namespace operators -} // namespace paddle -#endif diff --git a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc index 6fdd6d380a7fe2..6000516cab7aa1 100644 --- a/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc +++ b/paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc @@ -16,7 +16,7 @@ limitations under the License. */ // HIP not support cudnnSpatialTfGridGeneratorForward #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace phi { class DenseTensor; diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h index 00ff1fbcbc38db..d809c71f437426 100644 --- a/paddle/fluid/operators/math/prelu.h +++ b/paddle/fluid/operators/math/prelu.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/miopen_rnn_cache.h b/paddle/fluid/operators/miopen_rnn_cache.h index 2a8b38d38d5776..5f633c6f6fd1ee 100644 --- a/paddle/fluid/operators/miopen_rnn_cache.h +++ b/paddle/fluid/operators/miopen_rnn_cache.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 5fbbd49a885210..1d58154d36064b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" #endif namespace paddle {