From 14a38f89bef9ff805920e8df106060ccf72a8f0e Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 6 Jul 2023 14:39:04 +0800 Subject: [PATCH 01/19] add init value for CudaSwishFunctor --- paddle/phi/kernels/funcs/activation_functor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 203f6837d4611..0b4f051bb4599 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -3923,7 +3923,7 @@ template struct CudaSwishFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); - float beta; + float beta = 1.0; typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"beta", &beta}}; From b79a8cb257f82b169c4a1e9ca23030b9c11c9728 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 6 Jul 2023 15:18:18 +0800 Subject: [PATCH 02/19] add new phi kernel fusedBiasActKernel --- paddle/phi/api/yaml/fused_ops.yaml | 10 + paddle/phi/infermeta/multiary.cc | 114 ++++ paddle/phi/infermeta/multiary.h | 15 + paddle/phi/kernels/funcs/load_store_util.h | 220 +++++++ .../fusion/gpu/fused_bias_act_kernel.cu | 589 ++++++++++++++++++ 5 files changed, 948 insertions(+) create mode 100644 paddle/phi/kernels/funcs/load_store_util.h create mode 100644 paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 64a5d2bb00aae..8c5e06c81e454 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -54,6 +54,16 @@ data_type : x optional : bias, x_max +- op : fused_bias_act + args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method, int rows, int cols, str compute_type , float quant_scale=-1, int quant_round_type=0, float quant_max_bound=0.0, float quant_min_bound=0.0) + output : Tensor(out) + infer_meta : + func: FusedBiasActInferMeta + kernel : + func : fused_bias_act + data_type : x + optional : bias, dequant_scales, shift, smooth + - op : fused_dropout_add args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false) optional : seed_tensor diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 71bbfaa333a0a..28e8214670117 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1335,6 +1335,120 @@ void EditDistanceInferMeta(const MetaTensor& hyps, sequencenum->set_dtype(DataType::FLOAT32); } +void FusedBiasActInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& dequant_scales, + const MetaTensor& shift, + const MetaTensor& smooth, + const std::string& act_method, + const std::string& compute_dtype, + int rows, + int cols, + float quant_scale, + int quant_rount_type, + float quant_max_bound, + float quant_min_bound, + MetaTensor* out) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "The size of Input(x) must be 2: %s", x_dims)); + auto token_num = x_dims[0]; + auto dim = x_dims[1]; + + if (act_method == "geglu" || act_method == "swiglu") { + PADDLE_ENFORCE_EQ( + dim % 2, + 0, + phi::errors::InvalidArgument( + "The seconde dimension of x must be even, but receive %d", dim)); + dim /= 2; + out->set_dims(phi::make_ddim({token_num, dim})); + } else if (act_method == "gelu") { + out->set_dims(phi::make_ddim({token_num, dim})); + } else 
{ + PADDLE_THROW( + errors::InvalidArgument("act_method must be geglu, swiglu or gelu, " + "but get act_method (%s)", + act_method)); + } + + auto FBADtypeCheck = [](const MetaTensor& check_tensor, + const std::string& tensor_name, + const std::string& compute_dtype) { + if (compute_dtype == "bf16") { + PADDLE_ENFORCE_EQ( + check_tensor.dtype(), + phi::DataType::BFLOAT16, + phi::errors::InvalidArgument( + "Input(%s) dtype must be the same with Attr(compute_dtype)", + tensor_name)); + } else if (compute_dtype == "fp16") { + PADDLE_ENFORCE_EQ( + check_tensor.dtype(), + phi::DataType::FLOAT16, + phi::errors::InvalidArgument( + "Input(%s) dtype must be the same with Attr(compute_dtype)", + tensor_name)); + } else if (compute_dtype == "fp32") { + PADDLE_ENFORCE_EQ( + check_tensor.dtype(), + phi::DataType::FLOAT32, + phi::errors::InvalidArgument( + "Input(%s) dtype must be the same with Attr(compute_dtype)", + tensor_name)); + } + }; + + // In the case of quantization enabled, the dtype for computation is + // determined based on compute_type. + if (x.dtype() == phi::DataType::INT32) { + PADDLE_ENFORCE_NE( + compute_dtype, + "default", + phi::errors::InvalidArgument( + "If Input(x) dtype is INT32, Attr(compute_dtype) must be set.")); + + if (bias) { + FBADtypeCheck(bias, "bias", compute_dtype); + } + + if (compute_dtype == "bf16") { + out->set_dtype(phi::DataType::BFLOAT16); + } else if (compute_dtype == "fp16") { + out->set_dtype(phi::DataType::FLOAT16); + } else if (compute_dtype == "fp32") { + out->set_dtype(phi::DataType::FLOAT32); + } else { + PADDLE_THROW( + "In the case of quantization enabled with Input(x) INT32, " + "Attr(compute_dtype) must be set in (bf16, fp16, fp32)"); + } + } else { + // x.dtype() != phi::DataType::INT32 + if (bias) { + if (compute_dtype != "default") { + FBADtypeCheck(bias, "bias", compute_dtype); + FBADtypeCheck(x, "x", compute_dtype); + } else { + PADDLE_ENFORCE_EQ( + x.dtype(), + bias.dtype(), + phi::errors::InvalidArgument("Input(x) and Input(bias) must be the " + "same dtype in this situation")); + } + } else { + // bias not exist + if (compute_dtype != "default") { + FBADtypeCheck(x, "x", compute_dtype); + } + } + out->set_dtype(x.dtype()); + } + out->set_layout(x.layout()); +} + void FusedLinearParamGradAddInferMeta(const MetaTensor& x, const MetaTensor& dout, const MetaTensor& dweight, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 67a39780aa9c2..5e4578c9652f1 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -279,6 +279,21 @@ void EditDistanceInferMeta(const MetaTensor& hyps, MetaTensor* sequencenum, MetaTensor* out); +void FusedBiasActInferMeta(const MetaTensor& x, + const MetaTensor& bias, + const MetaTensor& dequant_scales, + const MetaTensor& shift, + const MetaTensor& smooth, + const std::string& act_method, + const std::string& compute_dtype, + int rows, + int cols, + float quant_scale, + int quant_rount_type, + float quant_max_bound, + float quant_min_bound, + MetaTensor* out); + void FusedLinearParamGradAddInferMeta(const MetaTensor& x, const MetaTensor& dout, const MetaTensor& dweight, diff --git a/paddle/phi/kernels/funcs/load_store_util.h b/paddle/phi/kernels/funcs/load_store_util.h new file mode 100644 index 0000000000000..0fe7a3ce7a348 --- /dev/null +++ b/paddle/phi/kernels/funcs/load_store_util.h @@ -0,0 +1,220 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +namespace phi { +namespace funcs { + +template +__device__ __inline__ T ClipFunc(const T v, const T min, const T max) { + if (v > max) return max; + if (v < min) return min; + return v; +} + +template +__forceinline__ __device__ OutType QuantHelperFunc(const InType input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * input; + + if (round_type == 0) { + quant_value = static_cast(rint(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + return static_cast( + ClipFunc(quant_value, min_bound, max_bound)); +} + +template +struct Load { + explicit Load(const T *src) : src_(src) {} + + template + __device__ void load(phi::AlignedVector *dst, int idx) { + phi::Load(src_ + idx, dst); + } + + const T *src_; +}; + +template +struct Store { + explicit Store(T *dst) : dst_(dst) {} + + template + __device__ void store(phi::AlignedVector &src, int idx) { + phi::Store(src, dst_ + idx); + } + + T *dst_; +}; + +template +struct Store { + Store(T *dst, const T *shift, const T *smooth, const int cols) + : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {} + + template + __device__ void store(phi::AlignedVector &src, int idx) { + using Vec = phi::AlignedVector; + Vec shift_vec; + Vec smooth_vec; + + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src[i] = (src[i] + shift_vec[i]) * smooth_vec[i]; + } + phi::Store(src, dst_ + idx); + } + + T *dst_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +template +struct DequantLoad { + DequantLoad(const int32_t *src, const float *dequant_scales, const int cols) + : src_(src), dequant_scales_(dequant_scales), cols_(cols) {} + + template + __device__ void load(phi::AlignedVector *dst, int idx) { + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + using ScaleVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + ScaleVec scale_vec; + + phi::Load(src_ + idx, &src_vec); + phi::Load(dequant_scales_ + idx % cols_, &scale_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + static_cast(static_cast(src_vec[i]) * scale_vec[i]); + } + *dst = dst_vec; + } + + const int32_t *src_; + const float *dequant_scales_; + const int cols_; +}; + +template +struct QuantStore { + QuantStore(int8_t *dst, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + __device__ void store(phi::AlignedVector &src, // NOLINT + int idx) { // NOLINT + using DstVec = phi::AlignedVector; + + DstVec 
dst_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = QuantHelperFunc(static_cast(src[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +struct QuantStore { + QuantStore(int8_t *dst, + const T *shift, + const T *smooth, + const int cols, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + shift_(shift), + smooth_(smooth), + cols_(cols), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + __device__ void store(phi::AlignedVector &src, // NOLINT + int idx) { // NOLINT + using DstVec = phi::AlignedVector; + using Vec = phi::AlignedVector; + + DstVec dst_vec; + Vec shift_vec; + Vec smooth_vec; + + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src[i] = (src[i] + shift_vec[i]) * smooth_vec[i]; + dst_vec[i] = QuantHelperFunc(static_cast(src[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000..3c13634d18382 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -0,0 +1,589 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
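The kernel added below fuses the bias add, the activation, and, for the GLU variants, the elementwise gating into a single pass over each row. As a reference for what the CUDA code computes, here is a minimal scalar sketch; the function and variable names are illustrative only and are not part of the patch.

// Reference-only sketch of the per-row math implemented by fused_bias_act.
// The real kernels below vectorize this with phi::AlignedVector and add
// optional dequantize-on-load / quantize-on-store paths.
#include <cmath>
#include <string>
#include <vector>

std::vector<float> BiasActReference(const std::vector<float> &row,   // cols elements
                                    const std::vector<float> &bias,  // cols elements
                                    const std::string &act_method) {
  auto gelu = [](float v) { return 0.5f * v * (1.0f + std::erf(v * 0.70710678f)); };
  auto swish = [](float v) { return v / (1.0f + std::exp(-v)); };  // beta = 1
  const size_t cols = row.size();
  if (act_method == "geglu" || act_method == "swiglu") {
    // GLU variants: the first half of the columns is activated and then
    // gates the second half, so the output has cols / 2 elements.
    const size_t half = cols / 2;
    std::vector<float> out(half);
    for (size_t i = 0; i < half; ++i) {
      const float a = row[i] + bias[i];
      const float b = row[i + half] + bias[i + half];
      out[i] = (act_method == "geglu" ? gelu(a) : swish(a)) * b;
    }
    return out;
  }
  // Plain "gelu": activate every element after the bias add.
  std::vector<float> out(cols);
  for (size_t i = 0; i < cols; ++i) {
    out[i] = gelu(row[i] + bias[i]);
  }
  return out;
}

The quantized paths in the kernel wrap this same computation: DequantLoad scales the int32 input before the bias add, and QuantStore rounds and clips the activated result into int8 on the way out.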
+ +#include + +#include "glog/logging.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/phi/kernels/funcs/load_store_util.h" +#include "paddle/phi/kernels/gpu/gelu_funcs.h" + +PHI_DECLARE_bool(use_fast_math); + +namespace phi { +namespace fusion { + +using phi::funcs::CudaSwishFunctor; +using phi::funcs::DequantLoad; +using phi::funcs::Load; +using phi::funcs::QuantStore; +using phi::funcs::Store; + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; +template +using LayerNormParamType = typename CudnnDataType::BatchNormParamType; + +// TODO(lzc): transfer to phi::funcs +template +struct GeluFunctor { + inline __host__ __device__ T operator()(const T x) const { + using U = LayerNormParamType; + const U casted_x = static_cast(x); + const U temp = erf(casted_x * static_cast(M_SQRT1_2)); + const U out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); + return static_cast(out); + } +}; + +template +struct FastGeluFunctor { + inline __device__ T operator()(const T x) const { + return phi::GeluFwd(x); + } +}; + +inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { + constexpr int kBlockSize = 128; + constexpr int kNumWaves = 16; + + const int device_id = phi::backends::gpu::GetCurrentDeviceId(); + const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id); + const int max_thread_per_multiprocessor = + phi::backends::gpu::GetGPUMultiProcessors(device_id); + + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * max_thread_per_multiprocessor / + kBlockSize * kNumWaves)); + return cudaSuccess; +} + +template +__global__ void ActFFNGlu(const T *bias, + Functor act_functor, + const int token_num, + const int hid_dim, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec1; + LoadT src_vec2; + LoadT bias_vec1; + LoadT bias_vec2; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int bi = i / hid_dim; + int idx = i % hid_dim; + + load_func.template load(&src_vec1, bi * hid_dim * 2 + idx); + load_func.template load(&src_vec2, + bi * hid_dim * 2 + idx + hid_dim); + + if (bias) { + phi::Load(&bias[idx], &bias_vec1); + phi::Load(&bias[idx + hid_dim], &bias_vec2); + } +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec1[j] += bias_vec1[j]; + src_vec2[j] += bias_vec2[j]; + } + src_vec1[j] = act_functor(src_vec1[j]); + src_vec1[j] *= src_vec2[j]; + } + // phi::Store(src_vec1, &output_this_thread[idx]); + store_func.template store(src_vec1, bi * hid_dim + idx); + } +} + +template +void LaunchActFFNGlu(const Context &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + ActFFNGlu + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + 
GetNumBlocks(elem_cnt, &grid_size); + ActFFNGlu<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; + } +} + +template +__global__ void BiasAct(const T *bias, + Functor act_functor, + const int rows, + const int cols, + const int elem_num, + LoadFunc load_func, + StoreFunc store_func) { + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + +// Zero Initialize BiasVec. +#pragma unroll + for (int unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + bias_vec[unroll_idx] = 0; + } + + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int row_idx = i / cols; + int col_idx = i % cols; + int linear_idx = row_idx * cols + col_idx; + load_func.template load(&src_vec, linear_idx); + if (bias) { + phi::Load(&bias[col_idx], &bias_vec); + } +#pragma unroll + for (int j = 0; j < VecSize; j++) { + if (bias) { + src_vec[j] += bias_vec[j]; + } + src_vec[j] = act_functor(src_vec[j]); + } + store_func.template store(src_vec, linear_idx); + } +} + +template +void LaunchBiasAct(const Context &dev_ctx, + const T *bias, + const int token_num, + const int hid_dim, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(LoadT); + const int elem_cnt = token_num * hid_dim; + const int blocksize = 128; + int grid_size = 1; + Functor functor; + switch (hid_dim % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + BiasAct + <<>>(bias, + functor, + token_num, + hid_dim, + elem_cnt, + load_func, + store_func); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + BiasAct<<>>( + bias, functor, token_num, hid_dim, elem_cnt, load_func, store_func); + break; + } +} + +template +void ComputeImpl(const Context &dev_ctx, + const T *bias_data, + const std::string &act_method, + int rows, + int cols, + LoadFunc load_func, + StoreFunc store_func) { + if (act_method == "geglu") { + // Note(Zhengzekang): For GLU structure, we need divide the cols by 2. + VLOG(8) << "Doing geglu"; + LaunchActFFNGlu, LoadFunc, StoreFunc, LoadT>( + dev_ctx, bias_data, rows, cols / 2, load_func, store_func); + } else if (act_method == "swiglu") { + VLOG(8) << "Doing swiglu"; + LaunchActFFNGlu, + LoadFunc, + StoreFunc, + LoadT>( + dev_ctx, bias_data, rows, cols / 2, load_func, store_func); + } else if (act_method == "gelu") { + if (FLAGS_use_fast_math) { + VLOG(8) << "Doing Fast GELU"; + LaunchBiasAct, LoadFunc, StoreFunc, LoadT>( + dev_ctx, bias_data, rows, cols, load_func, store_func); + } else { + VLOG(8) << "Doing GELU"; + LaunchBiasAct, LoadFunc, StoreFunc, LoadT>( + dev_ctx, bias_data, rows, cols, load_func, store_func); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently Only Support GeGLU, SwiGLU, GeLU")); + } +} + +template +void DispatchComputeImpl(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor *bias, + const DenseTensor *dequant_scales, + const std::string &act_method, + int rows, + int cols, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out) { + const T *bias_data = bias == nullptr ? 
nullptr : bias->data(); + if (dequant_scales != nullptr && quant_scale > 0) { + DequantLoad load_func( + x.data(), dequant_scales->data(), cols); + QuantStore store_func(out->data(), + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl, QuantStore, int32_t>( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else if (dequant_scales == nullptr && quant_scale > 0) { + Load load_func(x.data()); + QuantStore store_func(out->data(), + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else if (dequant_scales != nullptr && quant_scale <= 0) { + DequantLoad load_func( + x.data(), dequant_scales->data(), cols); + Store store_func(out->data()); + ComputeImpl, Store, int32_t>( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else { + Load load_func(x.data()); + Store store_func(out->data()); + ComputeImpl( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } +} + +template +void DispatchComputeImpl(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor *bias, + const DenseTensor *dequant_scales, + const DenseTensor *shift, + const DenseTensor *smooth, + const std::string &act_method, + int rows, + int cols, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out) { + bool use_glu = (act_method == "geglu" || act_method == "swiglu"); + const T *bias_data = bias == nullptr ? nullptr : bias->data(); + if (dequant_scales != nullptr && quant_scale > 0) { + DequantLoad load_func( + x.data(), dequant_scales->data(), cols); + QuantStore store_func(out->data(), + shift->data(), + smooth->data(), + use_glu ? cols / 2 : cols, + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl, QuantStore, int32_t>( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else if (dequant_scales == nullptr && quant_scale > 0) { + Load load_func(x.data()); + QuantStore store_func(out->data(), + shift->data(), + smooth->data(), + use_glu ? cols / 2 : cols, + quant_round_type, + quant_scale, + quant_max_bound, + quant_min_bound); + ComputeImpl( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else if (dequant_scales != nullptr && quant_scale <= 0) { + DequantLoad load_func( + x.data(), dequant_scales->data(), cols); + Store store_func(out->data(), + shift->data(), + smooth->data(), + use_glu ? cols / 2 : cols); + ComputeImpl, Store, int32_t>( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } else { + Load load_func(x.data()); + Store store_func(out->data(), + shift->data(), + smooth->data(), + use_glu ? 
cols / 2 : cols); + ComputeImpl( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } +} + +struct NormalVersion {}; +struct UnusedVersion {}; + +template +struct DispatchDtypeTrait { + using FuncVersion = NormalVersion; +}; + +template <> +struct DispatchDtypeTrait { + using FuncVersion = UnusedVersion; +}; + +template +void DispatchWithDtype(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &bias, + const paddle::optional &dequant_scales, + const paddle::optional &shift, + const paddle::optional &smooth, + const std::string &act_method, + int rows, + int cols, + float quant_scale, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + DenseTensor *out, + NormalVersion) { + auto *bias_p = bias.get_ptr(); + auto *dequant_scales_p = dequant_scales.get_ptr(); + auto *shift_p = shift.get_ptr(); + auto *smooth_p = smooth.get_ptr(); + if (dequant_scales_p != nullptr) { + if (shift_p != nullptr) { + DispatchComputeImpl(dev_ctx, + x, + bias_p, + dequant_scales_p, + shift_p, + smooth_p, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out); + } else { + DispatchComputeImpl(dev_ctx, + x, + bias_p, + dequant_scales_p, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out); + } + } else { + const T *bias_data = bias_p == nullptr ? nullptr : bias_p->data(); + Load load_func(x.data()); + Store store_func(out->data()); + ComputeImpl( + dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); + } +} + +// (not use) only for registering int32_t +template +void DispatchWithDtype(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &bias, + const paddle::optional &dequant_scales, + const paddle::optional &shift, + const paddle::optional &smooth, + const std::string &act_method, + int rows, + int cols, + float quant_scale, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + DenseTensor *out, + UnusedVersion) {} + +template +void FusedBiasActKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &bias, + const paddle::optional &dequant_scales, + const paddle::optional &shift, + const paddle::optional &smooth, + const std::string &act_method, + const std::string &compute_dtype, + int rows, + int cols, + float quant_scale, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + DenseTensor *out) { + if (x.dtype() == phi::DataType::INT32) { + if (compute_dtype == "bf16") { + DispatchWithDtype( + dev_ctx, + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + typename DispatchDtypeTrait::FuncVersion{}); + + } else if (compute_dtype == "fp16") { + DispatchWithDtype( + dev_ctx, + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + typename DispatchDtypeTrait::FuncVersion{}); + } else if (compute_dtype == "fp32") { + DispatchWithDtype( + dev_ctx, + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + typename DispatchDtypeTrait::FuncVersion{}); + + } else { + PADDLE_THROW("Only bf16, fp16 and fp32 are supported. 
"); + } + } else { + DispatchWithDtype( + dev_ctx, + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + typename DispatchDtypeTrait::FuncVersion{}); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_bias_act, + GPU, + ALL_LAYOUT, + phi::fusion::FusedBiasActKernel, + float, + phi::dtype::bfloat16, + phi::dtype::float16, + int32_t) {} From 5ba65664f1cb91f0076d3ad138a35b757d4f6dc9 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 6 Jul 2023 15:21:39 +0800 Subject: [PATCH 03/19] fix bug --- paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 3c13634d18382..3c0491ac570ec 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -551,7 +551,7 @@ void FusedBiasActKernel(const Context &dev_ctx, quant_max_bound, quant_min_bound, out, - typename DispatchDtypeTrait::FuncVersion{}); + typename DispatchDtypeTrait::FuncVersion{}); } else { PADDLE_THROW("Only bf16, fp16 and fp32 are supported. "); From f973746aa9a26c52274c15ab156593973c1301a5 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 6 Jul 2023 16:57:57 +0800 Subject: [PATCH 04/19] fix name bug --- paddle/phi/infermeta/multiary.cc | 2 +- paddle/phi/infermeta/multiary.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 28e8214670117..791e2e2f7d534 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1345,7 +1345,7 @@ void FusedBiasActInferMeta(const MetaTensor& x, int rows, int cols, float quant_scale, - int quant_rount_type, + int quant_round_type, float quant_max_bound, float quant_min_bound, MetaTensor* out) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 5e4578c9652f1..468714aa638ae 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -289,7 +289,7 @@ void FusedBiasActInferMeta(const MetaTensor& x, int rows, int cols, float quant_scale, - int quant_rount_type, + int quant_round_type, float quant_max_bound, float quant_min_bound, MetaTensor* out); From c3427b89d924564e53f05f5d9ada59c81ad8d786 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 6 Jul 2023 19:33:25 +0800 Subject: [PATCH 05/19] fix ci build error(local hasn't) --- paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 3c0491ac570ec..1819e7b2d8777 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -551,7 +551,7 @@ void FusedBiasActKernel(const Context &dev_ctx, quant_max_bound, quant_min_bound, out, - typename DispatchDtypeTrait::FuncVersion{}); + typename DispatchDtypeTrait::FuncVersion{}); } else { PADDLE_THROW("Only bf16, fp16 and fp32 are supported. 
"); From c00538244704b1b130db72867b99f5bbf4d175a8 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Fri, 7 Jul 2023 19:51:30 +0800 Subject: [PATCH 06/19] convert to phi api --- paddle/phi/api/yaml/fused_ops.yaml | 3 +- paddle/phi/infermeta/multiary.cc | 28 +++++++++++++------ .../fusion/gpu/fused_bias_act_kernel.cu | 20 +++++++------ 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 8c5e06c81e454..c5a2af934207a 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -55,7 +55,7 @@ optional : bias, x_max - op : fused_bias_act - args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method, int rows, int cols, str compute_type , float quant_scale=-1, int quant_round_type=0, float quant_max_bound=0.0, float quant_min_bound=0.0) + args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_type = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0) output : Tensor(out) infer_meta : func: FusedBiasActInferMeta @@ -63,6 +63,7 @@ func : fused_bias_act data_type : x optional : bias, dequant_scales, shift, smooth + support_dygraph_mode : true - op : fused_dropout_add args : (Tensor x, Tensor y, Tensor seed_tensor, Scalar p, bool is_test, str mode, int seed = 0, bool fix_seed = false) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 791e2e2f7d534..67b4f4718360b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1357,6 +1357,12 @@ void FusedBiasActInferMeta(const MetaTensor& x, auto token_num = x_dims[0]; auto dim = x_dims[1]; + PADDLE_ENFORCE_GT( + rows, 0, phi::errors::InvalidArgument("The size of Attr(rows) must > 0")); + + PADDLE_ENFORCE_GT( + cols, 0, phi::errors::InvalidArgument("The size of Attr(cols) must > 0")); + if (act_method == "geglu" || act_method == "swiglu") { PADDLE_ENFORCE_EQ( dim % 2, @@ -1414,16 +1420,20 @@ void FusedBiasActInferMeta(const MetaTensor& x, FBADtypeCheck(bias, "bias", compute_dtype); } - if (compute_dtype == "bf16") { - out->set_dtype(phi::DataType::BFLOAT16); - } else if (compute_dtype == "fp16") { - out->set_dtype(phi::DataType::FLOAT16); - } else if (compute_dtype == "fp32") { - out->set_dtype(phi::DataType::FLOAT32); + if (quant_scale > 0) { + out->set_dtype(phi::DataType::INT8); } else { - PADDLE_THROW( - "In the case of quantization enabled with Input(x) INT32, " - "Attr(compute_dtype) must be set in (bf16, fp16, fp32)"); + if (compute_dtype == "bf16") { + out->set_dtype(phi::DataType::BFLOAT16); + } else if (compute_dtype == "fp16") { + out->set_dtype(phi::DataType::FLOAT16); + } else if (compute_dtype == "fp32") { + out->set_dtype(phi::DataType::FLOAT32); + } else { + PADDLE_THROW( + "In the case of quantization enabled with Input(x) INT32, " + "Attr(compute_dtype) must be set in (bf16, fp16, fp32)"); + } } } else { // x.dtype() != phi::DataType::INT32 diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 1819e7b2d8777..610fe9338885a 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -301,7 +301,7 @@ void DispatchComputeImpl(const Context &dev_ctx, if (dequant_scales != nullptr && quant_scale > 0) { DequantLoad 
load_func( x.data(), dequant_scales->data(), cols); - QuantStore store_func(out->data(), + QuantStore store_func(dev_ctx.template Alloc(out), quant_round_type, quant_scale, quant_max_bound, @@ -310,7 +310,7 @@ void DispatchComputeImpl(const Context &dev_ctx, dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } else if (dequant_scales == nullptr && quant_scale > 0) { Load load_func(x.data()); - QuantStore store_func(out->data(), + QuantStore store_func(dev_ctx.template Alloc(out), quant_round_type, quant_scale, quant_max_bound, @@ -320,12 +320,12 @@ void DispatchComputeImpl(const Context &dev_ctx, } else if (dequant_scales != nullptr && quant_scale <= 0) { DequantLoad load_func( x.data(), dequant_scales->data(), cols); - Store store_func(out->data()); + Store store_func(dev_ctx.template Alloc(out)); ComputeImpl, Store, int32_t>( dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } else { Load load_func(x.data()); - Store store_func(out->data()); + Store store_func(dev_ctx.template Alloc(out)); ComputeImpl( dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } @@ -348,10 +348,12 @@ void DispatchComputeImpl(const Context &dev_ctx, DenseTensor *out) { bool use_glu = (act_method == "geglu" || act_method == "swiglu"); const T *bias_data = bias == nullptr ? nullptr : bias->data(); + if (dequant_scales != nullptr && quant_scale > 0) { + int8_t *out_data = dev_ctx.template Alloc(out); DequantLoad load_func( x.data(), dequant_scales->data(), cols); - QuantStore store_func(out->data(), + QuantStore store_func(dev_ctx.template Alloc(out), shift->data(), smooth->data(), use_glu ? cols / 2 : cols, @@ -363,7 +365,7 @@ void DispatchComputeImpl(const Context &dev_ctx, dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } else if (dequant_scales == nullptr && quant_scale > 0) { Load load_func(x.data()); - QuantStore store_func(out->data(), + QuantStore store_func(dev_ctx.template Alloc(out), shift->data(), smooth->data(), use_glu ? cols / 2 : cols, @@ -376,7 +378,7 @@ void DispatchComputeImpl(const Context &dev_ctx, } else if (dequant_scales != nullptr && quant_scale <= 0) { DequantLoad load_func( x.data(), dequant_scales->data(), cols); - Store store_func(out->data(), + Store store_func(dev_ctx.template Alloc(out), shift->data(), smooth->data(), use_glu ? cols / 2 : cols); @@ -384,7 +386,7 @@ void DispatchComputeImpl(const Context &dev_ctx, dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } else { Load load_func(x.data()); - Store store_func(out->data(), + Store store_func(dev_ctx.template Alloc(out), shift->data(), smooth->data(), use_glu ? cols / 2 : cols); @@ -459,7 +461,7 @@ void DispatchWithDtype(const Context &dev_ctx, } else { const T *bias_data = bias_p == nullptr ? 
nullptr : bias_p->data(); Load load_func(x.data()); - Store store_func(out->data()); + Store store_func(dev_ctx.template Alloc(out)); ComputeImpl( dev_ctx, bias_data, act_method, rows, cols, load_func, store_func); } From ab8d69b3d99145c74e900bfd305d31c970caa041 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 10 Jul 2023 21:06:55 +0800 Subject: [PATCH 07/19] add unit test for fused_bias_act --- paddle/phi/api/yaml/fused_ops.yaml | 2 +- test/legacy_test/test_fused_bias_act_op.py | 543 +++++++++++++++++++++ 2 files changed, 544 insertions(+), 1 deletion(-) create mode 100644 test/legacy_test/test_fused_bias_act_op.py diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index c5a2af934207a..324f723880eb2 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -55,7 +55,7 @@ optional : bias, x_max - op : fused_bias_act - args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_type = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0) + args : (Tensor x, Tensor bias, Tensor dequant_scales, Tensor shift, Tensor smooth, str act_method = "gelu", str compute_dtype = "default", int rows = -1, int cols = -1, float quant_scale = -1, int quant_round_type = 1, float quant_max_bound = 127.0, float quant_min_bound = -127.0) output : Tensor(out) infer_meta : func: FusedBiasActInferMeta diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py new file mode 100644 index 0000000000000..caabe8ef17919 --- /dev/null +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -0,0 +1,543 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from eager_op_test import convert_float_to_uint16 +from scipy.special import erf, expit + +import paddle +from paddle.fluid import core +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.layer_helper import LayerHelper + + +def round_type_1_process(val): + dtype = type(val) + if val >= 0: + return dtype(np.floor(val + 0.5)) + return dtype(np.ceil(val - 0.5)) + + +# rounding to nearest ties away from zero +round_type_1 = np.vectorize(round_type_1_process) + +M_SQRT1_2 = 0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc + + +def gelu(x): + out = ( + 0.5 * x.astype('float32') * (1.0 + erf(x.astype('float32') * M_SQRT1_2)) + ) + return out.astype(x.dtype) + + +def swish(x): + out = x.astype('float32') * expit(x.astype('float32')) + return out.astype(x.dtype) + + +def fake_dequant(values, dequant_scales): + out = values * dequant_scales.astype('float32') + return out + + +def fake_quant( + values, shift, smooth, quant_sacle, max_bound, min_bound, round_type +): + values_tmp = (values + shift) * smooth + values_tmp = max_bound * quant_sacle * values_tmp + if round_type == 0: + values_tmp = np.rint(values_tmp) + elif round_type == 1: + values_tmp = round_type_1(values_tmp) + return np.clip(values_tmp, min_bound, max_bound).astype(np.int8) + + +def fused_act_bias_wrapper( + x, + bias, + dequant_scales=None, + shift=None, + smooth=None, + act_method='gelu', + compute_dtype='default', + rows=0, + cols=0, + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0, +): + if in_dygraph_mode(): + return paddle._C_ops.fused_bias_act( + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + + helper = LayerHelper('fused_act_bias', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='fused_act_bias', + inputs={ + 'x': x, + 'bias': bias, + 'dequant_scales': dequant_scales, + 'shift': shift, + 'smooth': smooth, + }, + attrs={ + 'act_method': act_method, + 'compute_dtype': compute_dtype, + 'rows': rows, + 'cols': cols, + 'quant_scale': quant_scale, + 'quant_round_type': quant_round_type, + 'quant_max_bound': quant_max_bound, + 'quant_min_bound': quant_min_bound, + }, + outputs={'out': out}, + ) + return + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestFusedBiasActOp(unittest.TestCase): + def setUp(self): + paddle.seed(2017) + np.random.seed(2017) + + self.op_type = "fused_bias_act" + self.rtol = 1e-5 + self.atol = 1e-3 + + self.rows = 20 + self.cols = 512 + + self.dtype = 'float32' + self.act_method = 'gelu' + + self.use_glu = False + + self.init_test_case() + self.generate_inputs() + + def init_test_case(self): + pass + + def generate_inputs(self): + self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype) + self.bias = np.random.rand(self.cols).astype(self.dtype) + + def compute_baseline_output(self): + out = gelu(self.x + self.bias).astype(self.dtype) + return out + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + bias = paddle.to_tensor(self.bias) + + return fused_act_bias_wrapper( + x=x, + bias=bias, + rows=self.rows, + cols=self.cols, + act_method=self.act_method, + ) + + def test_check_output(self): + final_out_ref = self.compute_baseline_output() + final_out = self.compute_paddle_output() + 
np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) + + +class TestBaseFP16(TestFusedBiasActOp): + def init_test_case(self): + self.dtype = np.float16 + self.act_method = 'gelu' + + +class TestGegluFP16(TestFusedBiasActOp): + def init_test_case(self): + self.dtype = np.float16 + self.act_method = 'geglu' + + def compute_baseline_output(self): + res_tmp = (self.x + self.bias).astype(self.dtype) + res_tmp_head = res_tmp[:, : self.cols // 2] + res_tmp_tail = res_tmp[:, self.cols // 2 :] + res_tmp_head_act = gelu(res_tmp_head) + out = res_tmp_head_act * res_tmp_tail + return out + + +class TestSwigluFP16(TestFusedBiasActOp): + def init_test_case(self): + self.dtype = np.float16 + self.act_method = 'swiglu' + + def compute_baseline_output(self): + res_tmp = (self.x + self.bias).astype(self.dtype) + res_tmp_head = res_tmp[:, : self.cols // 2] + res_tmp_tail = res_tmp[:, self.cols // 2 :] + res_tmp_head_act = swish(res_tmp_head) + out = res_tmp_head_act * res_tmp_tail + return out + + +class TestQuantFP32(TestFusedBiasActOp): + def init_test_case(self): + self.atol = 1 + + self.dtype = 'float32' + self.compute_dtype = 'fp32' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + def generate_inputs(self): + self.x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + self.bias = np.random.rand(self.cols).astype(self.dtype) + self.dequant_scales = np.random.rand(self.cols).astype('float32') + quant_params_cols = self.cols // 2 if self.use_glu else self.cols + self.shift = np.zeros(quant_params_cols).astype(self.dtype) + self.smooth = np.ones(quant_params_cols).astype(self.dtype) + + def compute_baseline_output(self): + input_dequanted = fake_dequant(self.x, self.dequant_scales) + output_tmp = gelu(input_dequanted + self.bias).astype(self.dtype) + out = fake_quant( + output_tmp, + self.shift, + self.smooth, + self.quant_scale, + self.quant_max_bound, + self.quant_min_bound, + self.quant_round_type, + ) + return out + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + bias = paddle.to_tensor(self.bias) + dequant_scales = paddle.to_tensor(self.dequant_scales) + shift = paddle.to_tensor(self.shift) + smooth = paddle.to_tensor(self.smooth) + + out = fused_act_bias_wrapper( + x=x, + bias=bias, + dequant_scales=dequant_scales, + shift=shift, + smooth=smooth, + act_method=self.act_method, + compute_dtype=self.compute_dtype, + rows=self.rows, + cols=self.cols, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + ) + + return out + + +class TestQuantFP16(TestQuantFP32): + def init_test_case(self): + self.atol = 1 + + self.dtype = 'float16' + self.compute_dtype = 'fp16' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + +class TestQuantGegluFP16(TestQuantFP32): + def init_test_case(self): + self.atol = 1 + + self.dtype = 'float16' + self.compute_dtype = 'fp16' + self.act_method = 'geglu' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + self.use_glu = True + + def compute_baseline_output(self): + input_dequanted = fake_dequant(self.x, self.dequant_scales) + tmp = (input_dequanted + self.bias).astype('float32') + tmp_head = tmp[:, : self.cols // 2] + tmp_tail = 
tmp[:, self.cols // 2 :] + out_tmp = gelu(tmp_head).astype('float32') * tmp_tail + + out = fake_quant( + out_tmp, + self.shift, + self.smooth, + self.quant_scale, + self.quant_max_bound, + self.quant_min_bound, + self.quant_round_type, + ) + return out + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestFusedBiasActOpBF16(unittest.TestCase): + def setUp(self): + paddle.seed(2019) + np.random.seed(2019) + + self.op_type = "fused_bias_act" + self.rtol = 1e-3 + self.atol = 1e-3 + + self.rows = 20 + self.cols = 512 + + self.act_method = 'gelu' + self.compute_dtype = 'default' + + self.init_test_case() + self.generate_inputs() + + def init_test_case(self): + pass + + def generate_inputs(self): + self.x = np.random.rand(self.rows, self.cols).astype('float32') * 16 + self.bias = np.random.rand(self.cols).astype('float32') + + def compute_baseline_output(self): + out = gelu(self.x.astype('float32') + self.bias) + return convert_float_to_uint16(out) + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(convert_float_to_uint16(self.x)) + bias = paddle.to_tensor(convert_float_to_uint16(self.bias)) + + out = fused_act_bias_wrapper( + x=x, + bias=bias, + act_method=self.act_method, + compute_dtype=self.compute_dtype, + rows=self.rows, + cols=self.cols, + ) + return out + + def test_check_output(self): + final_out_ref = self.compute_baseline_output() + final_out = self.compute_paddle_output() + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) + + +class TestGegluBF16(TestFusedBiasActOpBF16): + def init_test_case(self): + self.act_method = 'geglu' + self.compute_dtype = 'default' + + def compute_baseline_output(self): + res_tmp = self.x + self.bias + res_tmp_head = res_tmp[:, : self.cols // 2] + res_tmp_tail = res_tmp[:, self.cols // 2 :] + res_tmp_head_act = gelu(res_tmp_head) + out = res_tmp_head_act * res_tmp_tail + return convert_float_to_uint16(out) + + +class TestSwigluBF16(TestFusedBiasActOpBF16): + def init_test_case(self): + self.act_method = 'swiglu' + self.compute_dtype = 'default' + + def compute_baseline_output(self): + res_tmp = self.x + self.bias + res_tmp_head = res_tmp[:, : self.cols // 2] + res_tmp_tail = res_tmp[:, self.cols // 2 :] + res_tmp_head_act = swish(res_tmp_head) + out = res_tmp_head_act * res_tmp_tail + return convert_float_to_uint16(out) + + +class TestQuantBF16(TestFusedBiasActOpBF16): + def init_test_case(self): + self.atol = 1 + + self.compute_dtype = 'bf16' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + self.use_glu = False + + def generate_inputs(self): + self.x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + self.bias = np.zeros(self.cols).astype('float32') + self.dequant_scales = np.random.rand(self.cols).astype('float32') + quant_params_cols = self.cols // 2 if self.use_glu else self.cols + self.shift = np.zeros(quant_params_cols).astype('float32') + self.smooth = np.ones(quant_params_cols).astype('float32') + + def compute_baseline_output(self): + input_dequanted = fake_dequant( + self.x.astype('float32'), self.dequant_scales + ) + output_tmp = gelu(input_dequanted + self.bias) + out = fake_quant( + output_tmp, + self.shift, + self.smooth, + self.quant_scale, + self.quant_max_bound, + self.quant_min_bound, + self.quant_round_type, + ) + + return out + + def compute_paddle_output(self): + 
paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + bias = paddle.to_tensor(convert_float_to_uint16(self.bias)) + dequant_scales = paddle.to_tensor(self.dequant_scales) + shift = paddle.to_tensor(convert_float_to_uint16(self.shift)) + smooth = paddle.to_tensor(convert_float_to_uint16(self.smooth)) + + out = fused_act_bias_wrapper( + x=x, + bias=bias, + dequant_scales=dequant_scales, + shift=shift, + smooth=smooth, + act_method=self.act_method, + compute_dtype=self.compute_dtype, + rows=self.rows, + cols=self.cols, + quant_scale=self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + ) + return out + + +class TestQuantGegluBF16(TestQuantBF16): + def init_test_case(self): + self.atol = 1 + + self.compute_dtype = 'bf16' + self.act_method = 'geglu' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + self.use_glu = True + + def compute_baseline_output(self): + input_dequanted = fake_dequant( + self.x.astype('float32'), self.dequant_scales + ) + tmp = (input_dequanted + self.bias).astype('float32') + tmp_head = tmp[:, : self.cols // 2] + tmp_tail = tmp[:, self.cols // 2 :] + out_tmp = gelu(tmp_head).astype('float32') * tmp_tail + + out = fake_quant( + out_tmp, + self.shift, + self.smooth, + self.quant_scale, + self.quant_max_bound, + self.quant_min_bound, + self.quant_round_type, + ) + + return out + + +class TestQuantSwigluBF16(TestQuantBF16): + def init_test_case(self): + self.atol = 1 + + self.compute_dtype = 'bf16' + self.act_method = 'swiglu' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + self.use_glu = True + + def compute_baseline_output(self): + input_dequanted = fake_dequant( + self.x.astype('float32'), self.dequant_scales + ) + tmp = (input_dequanted + self.bias).astype('float32') + tmp_head = tmp[:, : self.cols // 2] + tmp_tail = tmp[:, self.cols // 2 :] + out_tmp = swish(tmp_head).astype('float32') * tmp_tail + + out = fake_quant( + out_tmp, + self.shift, + self.smooth, + self.quant_scale, + self.quant_max_bound, + self.quant_min_bound, + self.quant_round_type, + ) + + return out + + +if __name__ == '__main__': + unittest.main() From ccb6145d5ba08bf3731e32a58e3db7d9296041b7 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 10 Jul 2023 22:09:38 +0800 Subject: [PATCH 08/19] fix infershape bug --- paddle/phi/infermeta/multiary.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 67b4f4718360b..5a8c37f59746a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1408,7 +1408,7 @@ void FusedBiasActInferMeta(const MetaTensor& x, }; // In the case of quantization enabled, the dtype for computation is - // determined based on compute_type. + // determined based on compute_dtype. 
if (x.dtype() == phi::DataType::INT32) { PADDLE_ENFORCE_NE( compute_dtype, @@ -1454,7 +1454,11 @@ void FusedBiasActInferMeta(const MetaTensor& x, FBADtypeCheck(x, "x", compute_dtype); } } - out->set_dtype(x.dtype()); + if (quant_scale > 0) { + out->set_dtype(phi::DataType::INT8); + } else { + out->set_dtype(x.dtype()); + } } out->set_layout(x.layout()); } From 14924b39be1cbb14a37c3ac73b935a968e87d121 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 10 Jul 2023 23:18:20 +0800 Subject: [PATCH 09/19] add cudnn header file --- paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 610fe9338885a..4fe3eb635c59d 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -16,6 +16,7 @@ #include "glog/logging.h" +#include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/data_type.h" From 88aca3406be31ae58a80ce7e6168bae505d83a5b Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 10 Jul 2023 23:47:14 +0800 Subject: [PATCH 10/19] try to solve windows build errer --- .../fusion/gpu/fused_bias_act_kernel.cu | 65 +-------- .../kernels/fusion/gpu/fused_bias_act_utils.h | 129 ++++++++++++++++++ 2 files changed, 130 insertions(+), 64 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 4fe3eb635c59d..8545a380c2209 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -12,73 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include - -#include "glog/logging.h" - -#include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_dnn.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/flags.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/activation_functor.h" -#include "paddle/phi/kernels/funcs/load_store_util.h" -#include "paddle/phi/kernels/gpu/gelu_funcs.h" - -PHI_DECLARE_bool(use_fast_math); +#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h" namespace phi { namespace fusion { -using phi::funcs::CudaSwishFunctor; -using phi::funcs::DequantLoad; -using phi::funcs::Load; -using phi::funcs::QuantStore; -using phi::funcs::Store; - -template -using CudnnDataType = phi::backends::gpu::CudnnDataType; -template -using LayerNormParamType = typename CudnnDataType::BatchNormParamType; - -// TODO(lzc): transfer to phi::funcs -template -struct GeluFunctor { - inline __host__ __device__ T operator()(const T x) const { - using U = LayerNormParamType; - const U casted_x = static_cast(x); - const U temp = erf(casted_x * static_cast(M_SQRT1_2)); - const U out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); - return static_cast(out); - } -}; - -template -struct FastGeluFunctor { - inline __device__ T operator()(const T x) const { - return phi::GeluFwd(x); - } -}; - -inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { - constexpr int kBlockSize = 128; - constexpr int kNumWaves = 16; - - const int device_id = phi::backends::gpu::GetCurrentDeviceId(); - const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id); - const int max_thread_per_multiprocessor = - phi::backends::gpu::GetGPUMultiProcessors(device_id); - - *num_blocks = - std::max(1, - std::min((n + kBlockSize - 1) / kBlockSize, - sm_count * max_thread_per_multiprocessor / - kBlockSize * kNumWaves)); - return cudaSuccess; -} - template (src_vec1, &output_this_thread[idx]); store_func.template store(src_vec1, bi * hid_dim + idx); } } diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h new file mode 100644 index 0000000000000..cb62d39ebd166 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -0,0 +1,129 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
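The new header below moves the activation functors, the GELU helpers, and the grid-size helper out of the .cu file so the Windows build can compile them. The Load/Store functors it re-exports from load_store_util.h are consumed by kernels as template parameters; a minimal sketch of that pattern, with hypothetical names and assuming the includes from this header, looks like this.

// Sketch only (not part of the patch): a kernel written against the load/store
// functor interface, so Load<T> / DequantLoad<T> and Store<T> / QuantStore<T>
// can be swapped in without changing the kernel body.
template <typename T, typename LoadFunc, typename StoreFunc, int VecSize>
__global__ void ElementwiseScaleSketch(const int elem_num,
                                       const float scale,
                                       LoadFunc load_func,
                                       StoreFunc store_func) {
  using VecT = phi::AlignedVector<T, VecSize>;
  VecT vec;
  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = global_tid * VecSize; i < elem_num;
       i += gridDim.x * blockDim.x * VecSize) {
    load_func.template load<VecSize>(&vec, i);  // e.g. Load<T> or DequantLoad<T>
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      vec[j] = static_cast<T>(static_cast<float>(vec[j]) * scale);
    }
    store_func.template store<VecSize>(vec, i);  // e.g. Store<T> or QuantStore<T>
  }
}

This is why DispatchComputeImpl in the kernel file only has to choose a load/store pair per quantization case; the activation kernels themselves stay unchanged across the float, int32 input, and int8 output paths.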
+ +#pragma once + +#include +#include + +#include "glog/logging.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/load_store_util.h" +#include "paddle/phi/kernels/gpu/gelu_funcs.h" + +PHI_DECLARE_bool(use_fast_math); + +namespace phi { +namespace fusion { + +template +struct GeluComputeType; + +template <> +struct GeluComputeType { + using Type = float; +}; + +template <> +struct GeluComputeType { + using Type = float; +}; + +template <> +struct GeluComputeType { + using Type = float; +}; + +template +using GeluType = typename GeluComputeType::Type; + +using phi::funcs::DequantLoad; +using phi::funcs::Load; +using phi::funcs::QuantStore; +using phi::funcs::Store; + +template +struct BaseActivationFunctor { + using ELEMENT_TYPE = T; + + using AttrPair = std::vector>; + + AttrPair GetAttrs() { return AttrPair(); } +}; + +// For windows build +template +struct CudaSwishFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float beta = 1.0; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"beta", &beta}}; + } + + // swish(x) = x / (1 + exp(-beta * x)) + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + return static_cast(x / (one + exp(-b * x))); + } +}; + +// TODO(lzc): transfer to phi::funcs +template +struct GeluFunctor { + inline __host__ __device__ T operator()(const T x) const { + using U = GeluType; + const U casted_x = static_cast(x); + const U temp = erf(casted_x * static_cast(M_SQRT1_2)); + const U out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); + return static_cast(out); + } +}; + +template +struct FastGeluFunctor { + inline __device__ T operator()(const T x) const { + return phi::GeluFwd(x); + } +}; + +inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { + constexpr int kBlockSize = 128; + constexpr int kNumWaves = 16; + + const int device_id = phi::backends::gpu::GetCurrentDeviceId(); + const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id); + const int max_thread_per_multiprocessor = + phi::backends::gpu::GetGPUMultiProcessors(device_id); + + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * max_thread_per_multiprocessor / + kBlockSize * kNumWaves)); + return cudaSuccess; +} + +} // namespace fusion +} // namespace phi From 013970ffe9699a8e360878996278ca683332210c Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 01:54:36 +0800 Subject: [PATCH 11/19] trasnfer macro M_SQRT1_2 --- paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index cb62d39ebd166..e5e0dc2586125 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include "glog/logging.h" @@ -30,6 +29,9 @@ #include "paddle/phi/kernels/funcs/load_store_util.h" #include "paddle/phi/kernels/gpu/gelu_funcs.h" +// 
for windows build +#define M_SQRT1_2 0.70710678118654752440 + PHI_DECLARE_bool(use_fast_math); namespace phi { From bc3355f0c26e805ea29a760ba0db2c7a085cbe06 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 11:46:37 +0800 Subject: [PATCH 12/19] fix CI bug --- .../fusion/gpu/fused_bias_act_kernel.cu | 2 + .../kernels/fusion/gpu/fused_bias_act_utils.h | 2 + test/legacy_test/test_fused_bias_act_op.py | 69 +++++++++++++++++-- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 8545a380c2209..033cd0402174c 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h" namespace phi { @@ -527,3 +528,4 @@ PD_REGISTER_KERNEL(fused_bias_act, phi::dtype::bfloat16, phi::dtype::float16, int32_t) {} +#endif diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index e5e0dc2586125..05dfab7071d49 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef PADDLE_WITH_HIP #pragma once #include @@ -129,3 +130,4 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { } // namespace fusion } // namespace phi +#endif diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index caabe8ef17919..e536fc2a7b05e 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -19,6 +19,7 @@ from scipy.special import erf, expit import paddle +import paddle.nn.functional as F from paddle.fluid import core from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.layer_helper import LayerHelper @@ -34,7 +35,7 @@ def round_type_1_process(val): # rounding to nearest ties away from zero round_type_1 = np.vectorize(round_type_1_process) -M_SQRT1_2 = 0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc +M_SQRT1_2 = 0.70710678118654752440 def gelu(x): @@ -186,6 +187,37 @@ def init_test_case(self): self.act_method = 'gelu' +class TestFastGeluFP16(TestFusedBiasActOp): + def use_fast_math(self, enabled): + paddle.set_flags({'FLAGS_use_fast_math': enabled}) + + def init_test_case(self): + self.dtype = np.float16 + self.act_method = 'gelu' + + def compute_baseline_output(self): + out = F.gelu( + paddle.to_tensor(self.x) + paddle.to_tensor(self.bias), + approximate=True, + ) + return out + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + bias = paddle.to_tensor(self.bias) + self.use_fast_math(True) + out = fused_act_bias_wrapper( + x=x, + bias=bias, + rows=self.rows, + cols=self.cols, + act_method=self.act_method, + ) + self.use_fast_math(False) + return out + + class TestGegluFP16(TestFusedBiasActOp): def init_test_case(self): self.dtype = np.float16 @@ -322,7 +354,9 @@ def compute_baseline_output(self): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() + or not 
core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", ) class TestFusedBiasActOpBF16(unittest.TestCase): def setUp(self): @@ -376,6 +410,11 @@ def test_check_output(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestGegluBF16(TestFusedBiasActOpBF16): def init_test_case(self): self.act_method = 'geglu' @@ -390,6 +429,11 @@ def compute_baseline_output(self): return convert_float_to_uint16(out) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestSwigluBF16(TestFusedBiasActOpBF16): def init_test_case(self): self.act_method = 'swiglu' @@ -404,11 +448,17 @@ def compute_baseline_output(self): return convert_float_to_uint16(out) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestQuantBF16(TestFusedBiasActOpBF16): def init_test_case(self): self.atol = 1 self.compute_dtype = 'bf16' + self.act_method = 'gelu' self.quant_scale = 0.5 self.quant_round_type = 1 self.quant_max_bound = 127.0 @@ -418,10 +468,11 @@ def init_test_case(self): def generate_inputs(self): self.x = np.random.randint( - low=-16, high=16, size=(self.rows, self.cols) + low=-1000, high=1000, size=(self.rows, self.cols) ).astype('int32') self.bias = np.zeros(self.cols).astype('float32') - self.dequant_scales = np.random.rand(self.cols).astype('float32') + self.dequant_scales = np.ones(self.cols).astype('float32') + quant_params_cols = self.cols // 2 if self.use_glu else self.cols self.shift = np.zeros(quant_params_cols).astype('float32') self.smooth = np.ones(quant_params_cols).astype('float32') @@ -469,6 +520,11 @@ def compute_paddle_output(self): return out +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestQuantGegluBF16(TestQuantBF16): def init_test_case(self): self.atol = 1 @@ -504,6 +560,11 @@ def compute_baseline_output(self): return out +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestQuantSwigluBF16(TestQuantBF16): def init_test_case(self): self.atol = 1 From ba98d65ab1b20f3dce8b0103728f3070515e971b Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 15:05:05 +0800 Subject: [PATCH 13/19] fix rocm build bug && add some unit test --- .../fusion/gpu/fused_bias_act_kernel.cu | 6 +- .../kernels/fusion/gpu/fused_bias_act_utils.h | 6 +- test/legacy_test/test_fused_bias_act_op.py | 159 +++++++++++++----- 3 files changed, 123 insertions(+), 48 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 033cd0402174c..a14d9f43b4d86 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
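The quantized tests above exercise quant_round_type == 1 (round half away from zero, per the round_type_1_process helper) together with clipping to [quant_min_bound, quant_max_bound]. The NumPy sketch below is illustrative only; the fake_quant helper in the test and the QuantStore functor are authoritative, and the max_bound * scale * x scaling used here is an assumption:

    import numpy as np

    def round_half_away_from_zero(x):
        # quant_round_type == 1: ties round away from zero.
        return np.where(x >= 0, np.floor(x + 0.5), np.ceil(x - 0.5))

    def quantize(x, quant_scale, quant_max_bound=127.0, quant_min_bound=-127.0):
        # Assumed convention: scale into the int8 range, round, then clip.
        q = round_half_away_from_zero(x * quant_max_bound * quant_scale)
        return np.clip(q, quant_min_bound, quant_max_bound).astype('int8')
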
-#ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h" namespace phi { @@ -193,6 +192,7 @@ void ComputeImpl(const Context &dev_ctx, int cols, LoadFunc load_func, StoreFunc store_func) { +#ifndef PADDLE_WITH_HIP if (act_method == "geglu") { // Note(Zhengzekang): For GLU structure, we need divide the cols by 2. VLOG(8) << "Doing geglu"; @@ -221,6 +221,7 @@ void ComputeImpl(const Context &dev_ctx, PADDLE_THROW(phi::errors::Unimplemented( "Currently Only Support GeGLU, SwiGLU, GeLU")); } +#endif } template @@ -440,6 +441,7 @@ void FusedBiasActKernel(const Context &dev_ctx, float quant_max_bound, float quant_min_bound, DenseTensor *out) { +#ifndef PADDLE_WITH_HIP if (x.dtype() == phi::DataType::INT32) { if (compute_dtype == "bf16") { DispatchWithDtype( @@ -515,6 +517,7 @@ void FusedBiasActKernel(const Context &dev_ctx, out, typename DispatchDtypeTrait::FuncVersion{}); } +#endif } } // namespace fusion @@ -528,4 +531,3 @@ PD_REGISTER_KERNEL(fused_bias_act, phi::dtype::bfloat16, phi::dtype::float16, int32_t) {} -#endif diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index 05dfab7071d49..a5a326f6de599 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef PADDLE_WITH_HIP #pragma once #include @@ -28,7 +27,9 @@ #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/load_store_util.h" +#ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/gpu/gelu_funcs.h" +#endif // for windows build #define M_SQRT1_2 0.70710678118654752440 @@ -104,12 +105,14 @@ struct GeluFunctor { } }; +#ifndef PADDLE_WITH_HIP template struct FastGeluFunctor { inline __device__ T operator()(const T x) const { return phi::GeluFwd(x); } }; +#endif inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { constexpr int kBlockSize = 128; @@ -130,4 +133,3 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { } // namespace fusion } // namespace phi -#endif diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index e536fc2a7b05e..f56bfa5e93fe3 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -21,8 +21,6 @@ import paddle import paddle.nn.functional as F from paddle.fluid import core -from paddle.fluid.framework import in_dygraph_mode -from paddle.fluid.layer_helper import LayerHelper def round_type_1_process(val): @@ -69,7 +67,7 @@ def fake_quant( def fused_act_bias_wrapper( x, - bias, + bias=None, dequant_scales=None, shift=None, smooth=None, @@ -82,48 +80,21 @@ def fused_act_bias_wrapper( quant_max_bound=0, quant_min_bound=0, ): - if in_dygraph_mode(): - return paddle._C_ops.fused_bias_act( - x, - bias, - dequant_scales, - shift, - smooth, - act_method, - compute_dtype, - rows, - cols, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - ) - - helper = LayerHelper('fused_act_bias', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type='fused_act_bias', - inputs={ - 'x': x, - 'bias': bias, - 'dequant_scales': dequant_scales, - 'shift': shift, - 'smooth': smooth, - }, - attrs={ - 'act_method': act_method, - 'compute_dtype': compute_dtype, - 'rows': rows, - 'cols': cols, - 
'quant_scale': quant_scale, - 'quant_round_type': quant_round_type, - 'quant_max_bound': quant_max_bound, - 'quant_min_bound': quant_min_bound, - }, - outputs={'out': out}, + return paddle._C_ops.fused_bias_act( + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + rows, + cols, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, ) - return @unittest.skipIf( @@ -600,5 +571,105 @@ def compute_baseline_output(self): return out +class TestAssert(unittest.TestCase): + def setUp(self): + self.rows = 20 + self.cols = 512 + + self.dtype = 'float32' + self.act_method = 'gelu' + + def test_assert_case1(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + + bias = np.random.rand(self.cols).astype(self.dtype) + + try: + out = fused_act_bias_wrapper( + x=paddle.to_tensor(x), + bias=paddle.to_tensor(bias), + rows=self.rows, + cols=self.cols, + ) + except ValueError as e: + print(e) + + def test_assert_case2(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + + bias = np.random.rand(self.cols).astype(self.dtype) + + try: + out = fused_act_bias_wrapper( + x=paddle.to_tensor(x), + bias=paddle.to_tensor(bias), + rows=self.rows, + cols=self.cols, + compute_dtype='fp16', + ) + except ValueError as e: + print(e) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestWithoutBias(unittest.TestCase): + def setUp(self): + paddle.seed(2017) + np.random.seed(2017) + + self.op_type = "fused_bias_act" + self.rtol = 1e-5 + self.atol = 1e-3 + + self.rows = 20 + self.cols = 512 + + self.dtype = 'float32' + self.act_method = 'gelu' + + self.use_glu = False + + self.init_test_case() + self.generate_inputs() + + def init_test_case(self): + pass + + def generate_inputs(self): + self.x = (np.random.rand(self.rows, self.cols) * 16).astype(self.dtype) + # self.bias = np.random.rand(self.cols).astype(self.dtype) + + def compute_baseline_output(self): + out = gelu(self.x).astype(self.dtype) + return out + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + + return fused_act_bias_wrapper( + x=x, + bias=None, + rows=self.rows, + cols=self.cols, + act_method=self.act_method, + ) + + def test_check_output(self): + final_out_ref = self.compute_baseline_output() + final_out = self.compute_paddle_output() + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) + + if __name__ == '__main__': unittest.main() From 76b9a182366aca93de4e21a895fd5403428e5aa1 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 16:00:22 +0800 Subject: [PATCH 14/19] add cpu skip in unit test --- test/legacy_test/test_fused_bias_act_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index f56bfa5e93fe3..624a554458d27 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -571,6 +571,9 @@ def compute_baseline_output(self): return out +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) class TestAssert(unittest.TestCase): def setUp(self): self.rows = 20 From c5cdc4b9c4f29c5c6270083a48f9f93ea3d8f1d9 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 20:42:03 
+0800 Subject: [PATCH 15/19] fix ci problem --- paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu | 4 ++-- paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h | 6 ++---- test/legacy_test/test_fused_bias_act_op.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index a14d9f43b4d86..3dab19057cb85 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -17,6 +17,7 @@ namespace phi { namespace fusion { +#ifndef PADDLE_WITH_HIP template @@ -424,6 +423,7 @@ void DispatchWithDtype(const Context &dev_ctx, float quant_min_bound, DenseTensor *out, UnusedVersion) {} +#endif template void FusedBiasActKernel(const Context &dev_ctx, diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index a5a326f6de599..823ced9baffbc 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -18,6 +18,7 @@ #include "glog/logging.h" +#ifndef PADDLE_WITH_HIP #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" @@ -27,9 +28,7 @@ #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/load_store_util.h" -#ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/gpu/gelu_funcs.h" -#endif // for windows build #define M_SQRT1_2 0.70710678118654752440 @@ -105,14 +104,12 @@ struct GeluFunctor { } }; -#ifndef PADDLE_WITH_HIP template struct FastGeluFunctor { inline __device__ T operator()(const T x) const { return phi::GeluFwd(x); } }; -#endif inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { constexpr int kBlockSize = 128; @@ -133,3 +130,4 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { } // namespace fusion } // namespace phi +#endif diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index 624a554458d27..215a55d959934 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -598,7 +598,7 @@ def test_assert_case1(self): cols=self.cols, ) except ValueError as e: - print(e) + pass def test_assert_case2(self): paddle.disable_static(place=paddle.CUDAPlace(0)) @@ -617,7 +617,7 @@ def test_assert_case2(self): compute_dtype='fp16', ) except ValueError as e: - print(e) + pass @unittest.skipIf( From 48469c19d04ce0f3cc9b6a54492cf17afc2dde41 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Tue, 11 Jul 2023 21:41:26 +0800 Subject: [PATCH 16/19] all trigger ci --- test/legacy_test/test_fused_bias_act_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index 215a55d959934..2800b9f505f28 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -403,7 +403,7 @@ def compute_baseline_output(self): @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the bfloat16", + "core is not complied with CUDA and not support the bfloat16 ", ) class TestSwigluBF16(TestFusedBiasActOpBF16): def init_test_case(self): From 171ac090f0ad1720e63a1710cfa35e42c501cb4f Mon Sep 17 00:00:00 
2001 From: freeliuzc Date: Wed, 12 Jul 2023 10:51:57 +0800 Subject: [PATCH 17/19] add unit_test for ci coverage test --- test/legacy_test/test_fused_bias_act_op.py | 104 +++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/test/legacy_test/test_fused_bias_act_op.py b/test/legacy_test/test_fused_bias_act_op.py index 2800b9f505f28..7d90fc36bbb31 100644 --- a/test/legacy_test/test_fused_bias_act_op.py +++ b/test/legacy_test/test_fused_bias_act_op.py @@ -114,6 +114,7 @@ def setUp(self): self.dtype = 'float32' self.act_method = 'gelu' + self.compute_dtype = 'default' self.use_glu = False @@ -142,6 +143,7 @@ def compute_paddle_output(self): rows=self.rows, cols=self.cols, act_method=self.act_method, + compute_dtype=self.compute_dtype, ) def test_check_output(self): @@ -158,6 +160,20 @@ def init_test_case(self): self.act_method = 'gelu' +class TestWithComTypeFP32(TestFusedBiasActOp): + def init_test_case(self): + self.dtype = 'float32' + self.act_method = 'gelu' + self.compute_dtype = 'fp32' + + +class TestWithComTypeFP16(TestFusedBiasActOp): + def init_test_case(self): + self.dtype = 'float16' + self.act_method = 'gelu' + self.compute_dtype = 'fp16' + + class TestFastGeluFP16(TestFusedBiasActOp): def use_fast_math(self, enabled): paddle.set_flags({'FLAGS_use_fast_math': enabled}) @@ -279,6 +295,49 @@ def compute_paddle_output(self): return out +class TestDequantFP32(TestQuantFP32): + def init_test_case(self): + self.rows = 10 + self.cols = 10 + self.atol = 1 + + self.dtype = 'float32' + self.compute_dtype = 'fp32' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + def generate_inputs(self): + self.x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + self.bias = np.random.rand(self.cols).astype(self.dtype) + self.dequant_scales = np.ones(self.cols).astype('float32') + + def compute_baseline_output(self): + input_dequanted = fake_dequant(self.x, self.dequant_scales) + out = gelu(input_dequanted + self.bias).astype(self.dtype) + return out + + def compute_paddle_output(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = paddle.to_tensor(self.x) + bias = paddle.to_tensor(self.bias) + dequant_scales = paddle.to_tensor(self.dequant_scales) + + out = fused_act_bias_wrapper( + x=x, + bias=bias, + dequant_scales=dequant_scales, + act_method=self.act_method, + compute_dtype=self.compute_dtype, + rows=self.rows, + cols=self.cols, + ) + return out + + class TestQuantFP16(TestQuantFP32): def init_test_case(self): self.atol = 1 @@ -291,6 +350,20 @@ def init_test_case(self): self.quant_min_bound = -127.0 +class TestDequantFP16(TestDequantFP32): + def init_test_case(self): + self.rows = 10 + self.cols = 10 + self.atol = 1 + + self.dtype = 'float16' + self.compute_dtype = 'fp16' + self.quant_scale = 0.5 + self.quant_round_type = 1 + self.quant_max_bound = 127.0 + self.quant_min_bound = -127.0 + + class TestQuantGegluFP16(TestQuantFP32): def init_test_case(self): self.atol = 1 @@ -381,6 +454,17 @@ def test_check_output(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestWithComTypeBF16(unittest.TestCase): + def init_test_case(self): + self.act_method = 'geglu' + self.compute_dtype = 'bf16' + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), @@ -619,6 +703,26 @@ def 
test_assert_case2(self): except ValueError as e: pass + def test_assert_case3(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + x = np.random.randint( + low=-16, high=16, size=(self.rows, self.cols) + ).astype('int32') + + bias = np.random.rand(self.cols).astype(self.dtype) + act_method = "error_type" + try: + out = fused_act_bias_wrapper( + x=paddle.to_tensor(x), + bias=paddle.to_tensor(bias), + rows=self.rows, + cols=self.cols, + compute_dtype='fp16', + act_method=act_method, + ) + except ValueError as e: + pass + @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA" From 69e79c2a640eef3ab611d3c6f6b4d0018365cd66 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 12 Jul 2023 10:55:13 +0800 Subject: [PATCH 18/19] for ci rocm build --- paddle/phi/kernels/funcs/load_store_util.h | 3 ++- paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/funcs/load_store_util.h b/paddle/phi/kernels/funcs/load_store_util.h index 0fe7a3ce7a348..848f7d0b40bd8 100644 --- a/paddle/phi/kernels/funcs/load_store_util.h +++ b/paddle/phi/kernels/funcs/load_store_util.h @@ -20,6 +20,7 @@ namespace phi { namespace funcs { +#ifndef PADDLE_WITH_HIP template __device__ __inline__ T ClipFunc(const T v, const T min, const T max) { if (v > max) return max; @@ -215,6 +216,6 @@ struct QuantStore { const T *smooth_; const int cols_; }; - +#endif } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index 823ced9baffbc..9457a7b23dbb5 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -18,7 +18,6 @@ #include "glog/logging.h" -#ifndef PADDLE_WITH_HIP #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" @@ -27,9 +26,10 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" +#ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/funcs/load_store_util.h" #include "paddle/phi/kernels/gpu/gelu_funcs.h" - +#endif // for windows build #define M_SQRT1_2 0.70710678118654752440 @@ -38,6 +38,7 @@ PHI_DECLARE_bool(use_fast_math); namespace phi { namespace fusion { +#ifndef PADDLE_WITH_HIP template struct GeluComputeType; @@ -127,7 +128,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { kBlockSize * kNumWaves)); return cudaSuccess; } +#endif } // namespace fusion } // namespace phi -#endif From 88ce8d8c8a8916ae542a615919bc4d041f8b639d Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 12 Jul 2023 19:27:14 +0800 Subject: [PATCH 19/19] fix CI-APPROVAL problem --- paddle/phi/infermeta/multiary.cc | 6 ++++-- .../phi/kernels/fusion/gpu/fused_bias_act_kernel.cu | 11 +++++++++-- paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h | 5 ----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index abc3b01c3e932..28ccfbb7fc30c 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1430,9 +1430,11 @@ void FusedBiasActInferMeta(const MetaTensor& x, } else if (compute_dtype == "fp32") { out->set_dtype(phi::DataType::FLOAT32); } else { - PADDLE_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "In the case of quantization enabled with Input(x) INT32, 
" - "Attr(compute_dtype) must be set in (bf16, fp16, fp32)"); + "Attr(compute_dtype) must be set in (bf16, fp16, fp32), " + "but get compute_dtype (%s)", + compute_dtype)); } } } else { diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu index 3dab19057cb85..ff722a0dfdff4 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu @@ -12,8 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "glog/logging.h" +#include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h" +PHI_DECLARE_bool(use_fast_math); + namespace phi { namespace fusion { @@ -495,9 +499,12 @@ void FusedBiasActKernel(const Context &dev_ctx, quant_min_bound, out, typename DispatchDtypeTrait::FuncVersion{}); - } else { - PADDLE_THROW("Only bf16, fp16 and fp32 are supported. "); + PADDLE_THROW(phi::errors::InvalidArgument( + "In the case of quantization enabled with Input(x) INT32, " + "Attr(compute_dtype) must be set in (bf16, fp16, fp32), " + "but get compute_dtype (%s)", + compute_dtype)); } } else { DispatchWithDtype( diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h index 9457a7b23dbb5..93ed50ec4e0df 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h @@ -16,15 +16,12 @@ #include -#include "glog/logging.h" - #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/funcs/load_store_util.h" @@ -33,8 +30,6 @@ // for windows build #define M_SQRT1_2 0.70710678118654752440 -PHI_DECLARE_bool(use_fast_math); - namespace phi { namespace fusion {