diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 0241f2c24df6a..501d6c987172a 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1437,15 +1437,15 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor -- op : llm_int8_matmul - args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0) +- op : llm_int8_linear + args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, float threshold=6.0) output : Tensor(out) infer_meta : - func : LLMInt8MatmulInferMeta - param : [x, weight] + func : LLMInt8LinearInferMeta kernel : - func : llm_int8_matmul + func : llm_int8_linear data_type : x + optional: bias - op : log args : (Tensor x) @@ -2013,15 +2013,6 @@ func : qr backward : qr_grad -- op : quant_for_compress - args : (Tensor x, int bits = 8, str layout = "weight_only") - output : Tensor(out), Tensor(scale) - infer_meta : - func : QuantForCompressInferMeta - kernel : - func : quant_for_compress - data_type: x - - op : real args : (Tensor x) output : Tensor (out) @@ -2768,14 +2759,24 @@ intermediate: warprnntgrad backward : warprnnt_grad -- op : weight_only_matmul - args : (Tensor x, Tensor weight, Tensor weight_scale) +- op : weight_only_linear + args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype) output : Tensor(out) infer_meta : - func : WeightOnlyMatmulInferMeta + func : WeightOnlyLinearInferMeta kernel : - func : weight_only_matmul + func : weight_only_linear data_type : x + optional: bias + +- op : weight_quantize + args : (Tensor x, str algo = "weight_only_int8") + output : Tensor(out), Tensor(scale) + infer_meta : + func : WeightQuantizeInferMeta + kernel : + func : weight_quantize + data_type: x - op : weighted_sample_neighbors args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 481e788a0edf1..31a9296721d82 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2468,6 +2468,53 @@ void LambInferMeta(const MetaTensor& param, } } +void LLMInt8LinearInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& weight_scale, + const float threshold, + MetaTensor* out) { + auto x_dims = x.dims(); + auto w_dims = weight.dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2UL, + errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + x_dims[x_dims.size() - 1], + w_dims[1], + errors::InvalidArgument( + "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." + "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", + x_dims[x_dims.size() - 1], + w_dims[1])); + PADDLE_ENFORCE_EQ( + w_dims[0] % 16, + 0, + phi::errors::InvalidArgument( + "The first dimension of input must be divisible by 16, but got[%d]", + w_dims[0])); + PADDLE_ENFORCE_EQ( + w_dims[1] % 16, + 0, + phi::errors::InvalidArgument( + "The second dimension of input must be divisible by 16, but got[%d]", + w_dims[1])); + PADDLE_ENFORCE_EQ( + weight_scale.dims()[0], + w_dims[0], + errors::InvalidArgument( + "Input(weight_scale) dim[0] and Input(Weight) dim[0] should be euqal." 
+ "But received Input(weight_scale) dim[0](%s) != Input(Weight) " + "dim[0](%s)", + weight_scale.dims()[0], + w_dims[0])); + auto out_dims = x_dims; + out_dims[out_dims.size() - 1] = w_dims[0]; + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void LogspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, @@ -3598,6 +3645,52 @@ void WarprnntInferMeta(const MetaTensor& input, loss->set_dtype(input.dtype()); } +void WeightOnlyLinearInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& weight_scale, + const std::string& weight_dtype, + MetaTensor* out) { + auto x_dims = x.dims(); + auto w_dims = weight.dims(); + auto n = weight_scale.dims()[0]; + PADDLE_ENFORCE( + weight_dtype == "int8" || weight_dtype == "int4", + errors::InvalidArgument("quant_method must be 'int8' or 'int4'.")); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2UL, + errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + weight_scale.dims().size(), + 1UL, + errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor.")); + PADDLE_ENFORCE_EQ( + w_dims[0] % 16, + 0, + phi::errors::InvalidArgument( + "The first dimension of input must be divisible by 16, but got[%d]", + w_dims[0])); + PADDLE_ENFORCE_EQ( + w_dims[1] % 16, + 0, + phi::errors::InvalidArgument( + "The second dimension of input must be divisible by 16, but got[%d]", + w_dims[1])); + PADDLE_ENFORCE_EQ( + x_dims[x_dims.size() - 1], + w_dims[1], + errors::InvalidArgument( + "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." + "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", + x_dims[x_dims.size() - 1], + w_dims[1])); + auto out_dims = x_dims; + out_dims[out_dims.size() - 1] = n; + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, @@ -3931,58 +4024,6 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, out_count->set_dtype(DataType::INT32); } -void LLMInt8MatmulInferMeta(const MetaTensor& x, - const MetaTensor& weight, - MetaTensor* out) { - auto x_dims = x.dims(); - auto w_dims = weight.dims(); - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2UL, - errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - x_dims[x_dims.size() - 1], - w_dims[1], - errors::InvalidArgument( - "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." - "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", - x_dims[x_dims.size() - 1], - w_dims[1])); - auto out_dims = x_dims; - out_dims[out_dims.size() - 1] = w_dims[0]; - out->set_dims(out_dims); - out->set_dtype(x.dtype()); -} - -void WeightOnlyMatmulInferMeta(const MetaTensor& x, - const MetaTensor& weight, - const MetaTensor& weight_scale, - MetaTensor* out) { - auto x_dims = x.dims(); - auto w_dims = weight.dims(); - auto n = weight_scale.dims()[0]; - PADDLE_ENFORCE_EQ( - w_dims.size(), - 2UL, - errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - weight_scale.dims().size(), - 1UL, - errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor.")); - PADDLE_ENFORCE_EQ( - x_dims[x_dims.size() - 1], - w_dims[1], - errors::InvalidArgument( - "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." 
- "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", - x_dims[x_dims.size() - 1], - w_dims[1])); - auto out_dims = x_dims; - out_dims[out_dims.size() - 1] = n; - out->set_dims(out_dims); - out->set_dtype(x.dtype()); -} - void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, const MetaTensor& cache_kv, const MetaTensor& src_mask, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e1bd5bd1fe1e8..6f27d27368c69 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -426,6 +426,13 @@ void LambInferMeta(const MetaTensor& param, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs); +void LLMInt8LinearInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& weight_scale, + const float threshold, + MetaTensor* out); + void LogspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, @@ -656,6 +663,13 @@ void WarprnntInferMeta(const MetaTensor& input, MetaTensor* loss, MetaTensor* warpctcgrad); +void WeightOnlyLinearInferMeta(const MetaTensor& x, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& weight_scale, + const std::string& weight_dtype, + MetaTensor* out); + void WeightedSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& edge_weight, @@ -755,15 +769,6 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, bool causal, MetaTensor* out); -void LLMInt8MatmulInferMeta(const MetaTensor& x, - const MetaTensor& weight, - MetaTensor* out); - -void WeightOnlyMatmulInferMeta(const MetaTensor& x, - const MetaTensor& weight, - const MetaTensor& weight_scale, - MetaTensor* out); - void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& k, const MetaTensor& v, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1116b9cedf1e6..ed2e8859733c8 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5030,6 +5030,48 @@ void UnStackInferMeta(const MetaTensor& x, } } +void WeightQuantizeInferMeta(const MetaTensor& x, + const std::string& algo, + MetaTensor* out, + MetaTensor* scale) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2UL, + phi::errors::InvalidArgument( + "The x tensor of quant op must be 2D, but got[%d]", x_dims.size())); + PADDLE_ENFORCE_EQ( + x_dims[0] % 64, + 0, + phi::errors::InvalidArgument( + "The first dimension of input must be divisible by 64, but got[%d]", + x_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[1] % 16, + 0, + phi::errors::InvalidArgument( + "The second dimension of input must be divisible by 16, but got[%d]", + x_dims[1])); + std::vector dim_scale({x_dims[1]}); + std::vector dim_out; + if (algo == "weight_only_int8" || algo == "llm.int8") { + dim_out = std::vector({x_dims[1], x_dims[0]}); + } else if (algo == "weight_only_int4") { + dim_out = std::vector({x_dims[1] / 2, x_dims[0]}); + } else { + phi::errors::InvalidArgument( + "The algo must be in ['weight_only_int8', 'weight_only_int4', " + "'llm.int8'], but got[%s]", + algo); + } + out->set_dims(phi::make_ddim(dim_out)); + + out->set_dtype(DataType::INT8); + + scale->set_dims(phi::make_ddim(dim_scale)); + scale->set_dtype(DataType::FLOAT32); +} + void ChannelShuffleInferMeta(const MetaTensor& x, int groups, const std::string& data_format, @@ -5090,46 +5132,6 @@ void CheckNumericsInferMeta(const MetaTensor& tensor, values->set_dims(phi::make_ddim({3})); } -void QuantForCompressInferMeta(const 
MetaTensor& x, - int bits, - const std::string& layout, - MetaTensor* out, - MetaTensor* scale) { - auto x_dims = x.dims(); - PADDLE_ENFORCE_EQ( - x_dims.size(), - 2UL, - phi::errors::InvalidArgument( - "The x tensor of quant op must be 2D, but got[%d]", x_dims.size())); - PADDLE_ENFORCE_GE( - x_dims[0], - 64, - phi::errors::OutOfRange("The first dimension of input is out of range " - "(expected at least 64, but got %ld).", - x_dims[0])); - PADDLE_ENFORCE_EQ( - x_dims[0] % 64, - 0, - phi::errors::InvalidArgument( - "The first dimension of input must be divisible by 64, but got[%d]", - x_dims[0])); - std::vector dim_scale({x_dims[1]}); - std::vector dim_out; - if (bits == 8) { - dim_out = std::vector({x_dims[1], x_dims[0]}); - } else if (bits == 4) { - dim_out = std::vector({x_dims[1] / 2, x_dims[0]}); - } else { - phi::errors::InvalidArgument("The bit must be 8 or 4, but got %d", bits); - } - out->set_dims(phi::make_ddim(dim_out)); - - out->set_dtype(DataType::INT8); - - scale->set_dims(phi::make_ddim(dim_scale)); - scale->set_dtype(DataType::FLOAT32); -} - void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out) { out->share_meta(x); out->set_strides(x.strides()); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ded1b5b6a9f2c..cd824ebc5737e 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -436,6 +436,11 @@ void QrInferMeta(const MetaTensor& x, MetaTensor* q, MetaTensor* r); +void WeightQuantizeInferMeta(const MetaTensor& x, + const std::string& algo, + MetaTensor* out, + MetaTensor* scale); + void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out); void ReduceInferMeta(const MetaTensor& x, @@ -728,12 +733,6 @@ void UnStackInferMeta(const MetaTensor& x, int num, std::vector outs); -void QuantForCompressInferMeta(const MetaTensor& x, - int bits, - const std::string& layout, - MetaTensor* out, - MetaTensor* scale); - void StridedUnChangedInferMeta(const MetaTensor& x, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc similarity index 68% rename from paddle/phi/kernels/cpu/quant_for_compress_kernel.cc rename to paddle/phi/kernels/cpu/weight_quantize_kernel.cc index 3d21371f4fd05..2539f37d12197 100644 --- a/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc +++ b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/quant_for_compress_kernel.h" +#include "paddle/phi/kernels/weight_quantize_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" -#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h" +#include "paddle/phi/kernels/impl/weight_quantize_kernel_impl.h" namespace phi { @@ -27,7 +27,7 @@ void quant_compute(const DeviceContext& dev_ctx, const DenseTensor& x, DenseTensor* out, DenseTensor* scale, - const std::string& layout) { + const std::string& algo) { const auto x_dims = x.dims(); PADDLE_ENFORCE_EQ( x_dims.size(), @@ -56,57 +56,49 @@ void quant_compute(const DeviceContext& dev_ctx, int_processed_2.Resize(out->dims()); dev_ctx.template Alloc(&int_processed_2); D* int_processed_2_data = int_processed_2.data(); - - per_channel_scale(scale_data, x_data, m, n); + per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f); per_channel_quant(x_int_data, x_data, scale_data, m, n); - if (layout == "weight_only") { + if (algo == "llm.int8") { + std::vector axis = {1, 0}; + funcs::Transpose trans; + trans(dev_ctx, x_int, out, axis); + } else { permute_B_rows_for_mixed_gemm( - int_processed_data, x_int_data, std::vector{m, n}, (int64_t)80); + int_processed_data, x_int_data, std::vector{m, n}); subbyte_transpose_impl( int_processed_2_data, int_processed_data, std::vector{m, n}); interleave_column_major_tensor( out_data, int_processed_2_data, std::vector{m, n}); add_bias_and_interleave_inplace(out_data, num); - } else if (layout == "llm.int8") { - std::vector axis = {1, 0}; - funcs::Transpose trans; - trans(dev_ctx, x_int, out, axis); - } else { - phi::errors::InvalidArgument( - "The layout must be weight_only or llm.int8, but got %s", layout); } } template -void QuantForCompressKernel(const Context& dev_ctx, - const DenseTensor& x, - int bits, - const std::string& layout, - DenseTensor* out, - DenseTensor* scale) { - if (bits == 8) { - dev_ctx.template Alloc(out); - dev_ctx.template Alloc(scale); - quant_compute(dev_ctx, x, out, scale, layout); - } else if (bits == 4 && layout == "weight_only") { - dev_ctx.template Alloc(out); - dev_ctx.template Alloc(scale); - quant_compute(dev_ctx, x, out, scale, layout); +void WeightQuantizeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::string& algo, + DenseTensor* out, + DenseTensor* scale) { + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(scale); + if (algo == "weight_only_int8" || algo == "llm.int8") { + quant_compute(dev_ctx, x, out, scale, algo); + } else if (algo == "weight_only_int4") { + quant_compute(dev_ctx, x, out, scale, algo); } else { phi::errors::Unimplemented( - "The bits only support 8 or weight_only 4, but got[%s] [%d]", - layout, - bits); + "The algo must be in ['weight_only_int8', 'weight_only_int4', " + "'llm.int8'], but got[%s]", + algo); } - // VLOG(0) << "x: " << x.dtype() << x; - // VLOG(0) << "out: " << out->dtype() << *out; } } // namespace phi -PD_REGISTER_KERNEL(quant_for_compress, +PD_REGISTER_KERNEL(weight_quantize, CPU, ALL_LAYOUT, - phi::QuantForCompressKernel, - phi::dtype::float16) {} + phi::WeightQuantizeKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.cu b/paddle/phi/kernels/funcs/weight_only_gemv.cu new file mode 100644 index 0000000000000..a1c746bd49ce1 --- /dev/null +++ 
b/paddle/phi/kernels/funcs/weight_only_gemv.cu @@ -0,0 +1,435 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/funcs/weight_only_gemv.h" + +#include +#include +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +namespace { + +#ifdef PADDLE_WITH_CUDA +constexpr int kWarpSize = 32; +constexpr int kPerBlockWarpNum = 16; + +///////////////////////////////////////////////////////////////////// +template +__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s( + T halves[4], int8_t signed_chars[4]) { + assert(false); +} + +// Specialization for fast cast from FP16 -> int8 +template <> +__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s( + half halves[4], int8_t signed_chars[4]) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + uint32_t* h = reinterpret_cast(halves); + uint32_t i8s = *reinterpret_cast(signed_chars); + + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +#endif +} + +// Specialization for fast cast from BF16 -> int8 +#ifdef PADDLE_CUDA_BF16 +template <> +__device__ inline void fast_cvt_4_packed_signed_i8s_to_2_half2s( + __nv_bfloat16 halves[4], int8_t signed_chars[4]) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t* bf16_result_ptr = reinterpret_cast(halves); + uint32_t i8s = *reinterpret_cast(signed_chars); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +// Subtract out fp32_base + 128 to make the unsigned integer signed. +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + +// Truncate the fp32 representation and pack up as bfloat16s. 
+#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + assert(false); +#endif +} +#endif + +/* Gelu Activation */ + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__inline__ __device__ float tanh_opt(float x) { +#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) + float r; + asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); + return r; +#else + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#endif +} + +template +struct GeluActivation { + using return_type = T; + static __device__ __forceinline__ T apply(const T& val) { + if (!EnableFastGelu) return val; + const float cdf = + 0.5f * (1.0f + tanh_opt((0.7978845608028654f * + (val + 0.044715f * val * val * val)))); + return val * cdf; + } +}; + +template +struct ConvertFloatFunc { + ConvertFloatFunc() {} + static __device__ __forceinline__ float apply(const T& val) { + assert(false); + return 0.0f; + } +}; + +template <> +struct ConvertFloatFunc { + static __device__ __forceinline__ float apply(const half& val) { + return __half2float(val); + } +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +struct ConvertFloatFunc<__nv_bfloat16> { + static __device__ __forceinline__ float apply(const __nv_bfloat16& val) { + return __bfloat162float(val); + } +}; +#endif + +template +struct ConvertDstFunc { + static __device__ __forceinline__ T apply(const float& val) { assert(false); } +}; + +template <> +struct ConvertDstFunc { + static __device__ __forceinline__ half apply(const float& val) { + return __float2half_rn(val); + } +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +struct ConvertDstFunc<__nv_bfloat16> { + static __device__ __forceinline__ __nv_bfloat16 apply(const float& val) { + return __float2bfloat16_rn(val); + } +}; +#endif + +template +struct HalfMul { + static __device__ __forceinline__ T apply(const T& x, const T& y) { + return __hmul(x, y); + } +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +struct HalfMul<__nv_bfloat16> { + static __device__ __forceinline__ __nv_bfloat16 + apply(const __nv_bfloat16& x, const __nv_bfloat16& y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmul(x, y); +#else + return __float2bfloat16_rn(__bfloat162float(x) * __bfloat162float(y)); +#endif + } +}; +#endif + +/* +Int8 Weightonly GEMV. +X: 1 x k +Weight(ColMajor): n x k +Each Warp Process: 1 x k matmul 1 x k +*/ +template +__global__ void int8_weight_only_gemv(const T* input, + const int8_t* weight, + const float* scale_list, + const T* bias, + T* output, + const int k, + const int n) { + constexpr int kWarpSize = 32; + constexpr int kVecSize = 16; + T vec_input[kVecSize]; + int8_t vec_weight[kVecSize]; + T vec_weight_f16[kVecSize]; + + const int warp_id = threadIdx.x / kWarpSize; + const int lane_id = threadIdx.x % kWarpSize; + const int tile_id = blockIdx.x * blockDim.x / kWarpSize + warp_id; + const int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 
1 : 0); + weight += tile_id * k * 2; + + float v = 0.f, scale = scale_list[row_id], v_bias; + + if (Bias) { + v_bias = ConvertFloatFunc::apply(bias[row_id]); + } + +#pragma unroll + for (int i = lane_id * kVecSize; i < k * 2; i += kVecSize * kWarpSize) { + *(int4*)vec_weight = *(int4*)(weight + i); // NOLINT + *(float4*)vec_input = // NOLINT + *(float4*)(input + i / 128 * 64 + (i % 64)); // NOLINT + *(float4*)(vec_input + 8) = // NOLINT + *(float4*)(input + i / 128 * 64 + (i % 64) + 8); // NOLINT +#pragma unroll + for (int p = 0; p < kVecSize; p += 4) { + fast_cvt_4_packed_signed_i8s_to_2_half2s(vec_weight_f16 + p, + vec_weight + p); + } +#pragma unroll + for (int p = 0; p < kVecSize; ++p) { + v += ConvertFloatFunc::apply( + HalfMul::apply(vec_input[p], vec_weight_f16[p / 8 + (p % 8) * 2])); + } + } + // Do WarpReduceSum. + v += __shfl_xor_sync(0xffffffff, v, 16); + v += __shfl_xor_sync(0xffffffff, v, 8); + v += __shfl_xor_sync(0xffffffff, v, 2); + v += __shfl_xor_sync(0xffffffff, v, 1); + if (lane_id == 0 || lane_id == 4) { + if (Bias) { + output[row_id] = ConvertDstFunc::apply( + GeluActivation::apply(v * scale + v_bias)); + } else { + output[row_id] = ConvertDstFunc::apply( + GeluActivation::apply(v * scale)); + } + } +} +#endif + +template +void int8_weight_only_gemv_launcher(const T* input, + const int8_t* weight, + const float* scale_list, + const T* bias, + T* output, + const int k, + const int n, + const bool gelu, + gpuStream_t stream) { +#ifdef PADDLE_WITH_CUDA + dim3 block(kWarpSize * kPerBlockWarpNum); // equal to 512; + dim3 grid(n / kPerBlockWarpNum / + 2); // Note(zhengzekang): Since each warp process 2 rows of matrix. + if (bias) { + if (gelu) { + int8_weight_only_gemv<<>>( + input, weight, scale_list, bias, output, k, n); + } else { + int8_weight_only_gemv<<>>( + input, weight, scale_list, bias, output, k, n); + } + } else { + if (gelu) { + int8_weight_only_gemv<<>>( + input, weight, scale_list, bias, output, k, n); + } else { + int8_weight_only_gemv<<>>( + input, weight, scale_list, bias, output, k, n); + } + } +#endif +} + +template <> +void int8_weight_only_gemv_launcher(const float* input, + const int8_t* weight, + const float* scale_list, + const float* bias, + float* output, + const int k, + const int n, + const bool gelu, + gpuStream_t stream) { + // Weightonly GEMV do not support float. + assert(false); +} + +template <> +void int8_weight_only_gemv_launcher(const phi::dtype::bfloat16* input, + const int8_t* weight, + const float* scale_list, + const phi::dtype::bfloat16* bias, + phi::dtype::bfloat16* output, + const int k, + const int n, + const bool gelu, + gpuStream_t stream) { + // Environment do not support bf16. + assert(false); +} + +} // namespace + +template +void GemvWeightonlyInt8Wrapper(const Context& ctx, + const T* x, + const int8_t* weight, + const T* bias, + const float* weight_scale, + const int n, + const int k, + const std::string& act_method, + T* output) { + using DataType = typename PDDataTypeTraits::DataType; + + bool gelu = false; + if (act_method == "gelu") { + gelu = true; + } else if (act_method == "None") { + gelu = false; + } else { + PADDLE_THROW( + errors::InvalidArgument("Currently, Int8 weightonly GEMV act_method " + "only support `gelu`, `None`. 
")); + } + + int8_weight_only_gemv_launcher( + reinterpret_cast(x), + weight, + weight_scale, + reinterpret_cast(bias), + reinterpret_cast(output), + k, + n, + gelu, + ctx.stream()); +} + +template +void GemvWeightonlyInt8Kernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const std::string& act_method, + DenseTensor* out) { + const T* x_data = x.data(); + const int8_t* weight_data = + weight.data(); // Actually, we pass the weight datatype is + // uint8_t type. + const T* bias_data = bias ? bias.get().data() : nullptr; + const float* weight_scale_data = weight_scale.data(); + T* out_data = dev_ctx.template Alloc(out); + + int k = x.dims()[1]; + int n = weight.dims()[0]; + GemvWeightonlyInt8Wrapper(dev_ctx, + x_data, + weight_data, + bias_data, + weight_scale_data, + n, + k, + act_method, + out_data); +} + +template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx, + const phi::dtype::float16* x, + const int8_t* weight, + const phi::dtype::float16* bias, + const float* weight_scale, + const int n, + const int k, + const std::string& act_method, + phi::dtype::float16* output); + +template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx, + const phi::dtype::bfloat16* x, + const int8_t* weight, + const phi::dtype::bfloat16* bias, + const float* weight_scale, + const int n, + const int k, + const std::string& act_method, + phi::dtype::bfloat16* output); + +template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx, + const float* x, + const int8_t* weight, + const float* bias, + const float* weight_scale, + const int n, + const int k, + const std::string& act_method, + float* output); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.h b/paddle/phi/kernels/funcs/weight_only_gemv.h new file mode 100644 index 0000000000000..8f61ab22ba6ea --- /dev/null +++ b/paddle/phi/kernels/funcs/weight_only_gemv.h @@ -0,0 +1,32 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GemvWeightonlyInt8Wrapper(const Context& ctx, + const T* x, + const int8_t* weight, + const T* bias, + const float* weight_scale, + const int n, + const int k, + const std::string& act_method, + T* output); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h index 2c96d39b6beab..55c5ac37b11b7 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h @@ -1,10 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h index 1726bc9054dda..b023998002cf2 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -1,10 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index 5fc621b0b8bf6..21dee49a9792d 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -1,4 +1,18 @@ - +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index b6d880d4de5d1..0ff18f45e1a87 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -102,8 +102,6 @@ struct GemmFpAIntB { /// Parameters structure struct Arguments : UniversalArgumentsBase { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - cutlass::gemm::GemmCoord problem_size; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::TensorRef ref_B; @@ -111,9 +109,6 @@ struct GemmFpAIntB { typename Epilogue::OutputTileIterator::TensorRef ref_C; typename Epilogue::OutputTileIterator::TensorRef ref_D; - // Control serial split-k - int batch_count; - typename EpilogueOutputOp::Params output_op; // For gather+scatter operations @@ -121,9 +116,6 @@ struct GemmFpAIntB { int const* gather_B_indices; int const* scatter_D_indices; - // Included so we can use Gemm Universal - int batch_stride_D = 0; - // // Methods // @@ -144,10 +136,13 @@ struct GemmFpAIntB { int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, int const* scatter_D_indices = nullptr) - : UniversalArgumentsBase(mode, - problem_size, - /*serial_split_k_factor=*/1, - /*batch_stride_D=*/0), + : // TODO(wangbojun) hard code here for GemmUniversalMode::kGemm and + // batch_stride_D + UniversalArgumentsBase( + GemmUniversalMode::kGemm, + problem_size, + /*serial_split_k_factor=*/serial_split_k_factor, + /*batch_stride_D=*/0), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), @@ -505,10 +500,9 @@ struct GemmFpAIntB { void operator()(Params const& params, SharedStorage& shared_storage) { // NOLINT #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - // static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, - // shared_storage); - CUTLASS_NOT_IMPLEMENTED(); + static constexpr bool compile_needed = + platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) // static constexpr bool compile_needed = platform::is_same::value || #ifdef PADDLE_CUDA_BF16 cutlass::platform::is_same::value || #endif cutlass::platform::is_same::value, "Specialized for bfloat16, half, float"); + static_assert( cutlass::platform::is_same::value || cutlass::platform::is_same::value || @@ -71,13 +88,23 @@ void generic_mixed_gemm_kernelLauncher(const T* A, cutlass::platform::is_same::value, cutlass::half_t, T>::type; - using ElementType = ElementType_; - +#ifdef PADDLE_CUDA_BF16 + using ElementType = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::bfloat16_t, + ElementType_>::type; +#endif using CutlassWeightType_ = typename cutlass::platform::conditional< cutlass::platform::is_same::value, cutlass::half_t, WeightType>::type; - using CutlassWeightType = CutlassWeightType_; + +#ifdef PADDLE_CUDA_BF16 + using CutlassWeightType = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::bfloat16_t, + CutlassWeightType_>::type; +#endif // We need separate config for each architecture since we will target // different tensorcore instructions. For float, we do not target TCs. 
@@ -156,10 +183,17 @@ void generic_mixed_gemm_kernelLauncher(const T* A, Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { + // TODO(wangbojun) here to reset the split-k in gemm args, but no work for + // now to run bf16 mixgemm, we have set the split-k factor to 1 VLOG(1) << "Requested split-k but workspace size insufficient. Falling " "back to non-split-k implementation."; + VLOG(1) << "need workspace sizoe of: " << gemm.get_workspace_size(args) + << ", but got " << workspace_bytes; + VLOG(1) << "args.batch_stride_D:" << args.batch_stride_D; + VLOG(1) << "args.batch_count:" << args.batch_count; // If requested split-k factor will require more workspace bytes, revert to // standard gemm. + // args.batch_count = 1; } @@ -209,13 +243,13 @@ struct dispatch_stages { size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { + // VLOG(3)<<__PRETTY_FUNCTION__; std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); throw std::runtime_error("[dispatch_stages::dispatch] " + err_msg); } }; - template CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + // VLOG(3)<<__PRETTY_FUNCTION__; int device{-1}; check_cuda_error(cudaGetDevice(&device)); sm_ = getSMVersion(); @@ -559,7 +596,9 @@ CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { } template -CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() {} +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + // VLOG(3)<<__PRETTY_FUNCTION__; +} template template @@ -577,20 +616,38 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { - // if (sm_ >= 70 && sm_ < 75) { - // dispatch_gemm_to_cutlass( - // A, B, weight_scales, biases, C, m, n, k, workspace_ptr, - // workspace_bytes, gemm_config, stream, occupancy); - // } - // else if (sm_ >= 75 && sm_ < 80) { - // dispatch_gemm_to_cutlass( - // A, B, weight_scales, biases, C, m, n, k, workspace_ptr, - // workspace_bytes, gemm_config, stream, occupancy); - // } - // else - if (sm_ >= 80 && sm_ < 90) { + // VLOG(3)<<__PRETTY_FUNCTION__; + if (sm_ >= 70 && sm_ < 75) { + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); + } else if (sm_ >= 80 && sm_ < 90) { dispatch_gemm_to_cutlass( A, B, @@ -607,8 +664,8 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( occupancy); } else { throw std::runtime_error( - "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported " - "for CUTLASS mixed type GEMM"); + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " + "CUTLASS mixed type GEMM"); } } @@ -626,6 +683,7 @@ void CutlassFpAIntBGemmRunner::run_gemm( char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; static constexpr bool is_weight_only = !std::is_same::value; const bool is_weight_only_encoder = m >= 512 ? 
true : false; std::vector candidate_configs = @@ -690,6 +748,7 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; if (activation_type == "gelu") { run_gemm(A, B, @@ -742,6 +801,7 @@ void CutlassFpAIntBGemmRunner::gemm(const T* A, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; run_gemm(A, B, weight_scales, @@ -759,7 +819,8 @@ template int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, const int n, const int k) { - // sizes for each config, which would launch the maximum number of blocks + // VLOG(3)<<__PRETTY_FUNCTION__; // These are the min tile sizes for each + // config, which would launch the maximum number of blocks const int max_grid_m = (m + 31) / 32; const int max_grid_n = (n + 127) / 128; // We need 4 bytes per block in the worst case. We launch split_k_limit in z @@ -811,4 +872,14 @@ int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, return 0; } +template class CutlassFpAIntBGemmRunner; +template class CutlassFpAIntBGemmRunner; +#ifdef PADDLE_CUDA_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t>; +#endif +template class CutlassFpAIntBGemmRunner; +template class CutlassFpAIntBGemmRunner; +#ifdef PADDLE_CUDA_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +#endif } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h index 6d2fa3ed33743..69e737fa21157 100644 --- a/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h +++ b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h @@ -14,6 +14,20 @@ * limitations under the License. */ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include diff --git a/paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu b/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu similarity index 66% rename from paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu rename to paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu index ed68d2d1e0177..3435a450ffd43 100644 --- a/paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/llm_int8_linear_kernel.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/llm_int8_matmul_kernel.h" +#include "paddle/phi/kernels/llm_int8_linear_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" -#ifndef PADDLE_WITH_HIP +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 #include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h" #endif @@ -26,12 +28,11 @@ template void llm_int8_compute(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, + const paddle::optional& bias, const DenseTensor& weight_scale, const float threshold, DenseTensor* out) { -#if defined(PADDLE_WITH_HIP) - LOG(ERROR) << "Please compile with cublaslt, ROCM platform isn't support it"; -#else +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 DenseTensor cublaslt_workspace; cublaslt_workspace.Resize({{3000000}}); dev_ctx.template Alloc(&cublaslt_workspace); @@ -52,24 +53,35 @@ void llm_int8_compute(const Context& dev_ctx, m, k, n); + if (bias) { + std::vector ins = {out, &(bias.get())}; + std::vector outs = {out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, phi::funcs::AddFunctor()); + } +#else + PADDLE_THROW(phi::errors::Unimplemented( + "llm_int8_linear op needs paddle with cuda and cuda version >= 11.2")); #endif } template -void LLMInt8MatmulKernel(const Context& dev_ctx, +void LLMInt8LinearKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, + const paddle::optional& bias, const DenseTensor& weight_scale, const float threshold, DenseTensor* out) { dev_ctx.template Alloc(out); llm_int8_compute( - dev_ctx, x, weight, weight_scale, threshold, out); + dev_ctx, x, weight, bias, weight_scale, threshold, out); } } // namespace phi -PD_REGISTER_KERNEL(llm_int8_matmul, +PD_REGISTER_KERNEL(llm_int8_linear, GPU, ALL_LAYOUT, - phi::LLMInt8MatmulKernel, - phi::dtype::float16) {} + phi::LLMInt8LinearKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu new file mode 100644 index 0000000000000..65af9f9c6c2b5 --- /dev/null +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -0,0 +1,160 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/phi/kernels/weight_only_linear_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/weight_only_gemv.h" +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" +#endif + +namespace phi { + +template +void WeightOnlyLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const std::string& weight_dtype, + DenseTensor* out) { + dev_ctx.template Alloc(out); + const T* x_data = x.data(); + const int8_t* weight_data = weight.data(); + const T* bias_data = bias ? bias.get().data() : nullptr; + const float* weight_scale_data = weight_scale.data(); + T* out_data = out->data(); + const auto x_dims = x.dims(); + const auto w_dims = weight.dims(); + int n = weight_scale.dims()[0]; + int k = w_dims[1]; + int m = x.numel() / k; + + // m > 1: run gemm + if (m > 1 || weight_dtype == "int4") { +#if defined(PADDLE_WITH_CUTLASS) + if (weight_dtype == "int8") { + auto mixed_gemm_runner = + CutlassFpAIntBGemmRunner::DataType, + uint8_t>(); + int mixgemm_max_size = std::max(m, k); + DenseTensor mixgemm_workspace; + int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( + m, mixgemm_max_size, mixgemm_max_size); + + mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); + dev_ctx.template Alloc(&mixgemm_workspace); + char* mixgemm_workspace_data = + reinterpret_cast(mixgemm_workspace.data()); + if (bias_data) { + mixed_gemm_runner.gemm_bias_act( + reinterpret_cast::DataType*>( + x_data), + reinterpret_cast(weight_data), + weight_scale_data, + reinterpret_cast::DataType*>( + bias_data), + reinterpret_cast::DataType*>(out_data), + m, + n, + k, + "none", + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } else { + mixed_gemm_runner.gemm( + reinterpret_cast::DataType*>( + x_data), + reinterpret_cast(weight_data), + weight_scale_data, + reinterpret_cast::DataType*>(out_data), + m, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } + } else { + auto mixed_gemm_runner = + CutlassFpAIntBGemmRunner::DataType, + cutlass::uint4b_t>(); + int mixgemm_max_size = std::max(m, k); + DenseTensor mixgemm_workspace; + int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( + m, mixgemm_max_size, mixgemm_max_size); + + mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); + dev_ctx.template Alloc(&mixgemm_workspace); + char* mixgemm_workspace_data = + reinterpret_cast(mixgemm_workspace.data()); + if (bias_data) { + mixed_gemm_runner.gemm_bias_act( + reinterpret_cast::DataType*>( + x_data), + reinterpret_cast(weight_data), + weight_scale_data, + reinterpret_cast::DataType*>( + bias_data), + reinterpret_cast::DataType*>(out_data), + m, + n, + k, + "none", + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } else { + mixed_gemm_runner.gemm( + reinterpret_cast::DataType*>( + x_data), + reinterpret_cast(weight_data), + weight_scale_data, + reinterpret_cast::DataType*>(out_data), + m, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } + } +#else + PADDLE_THROW(phi::errors::Unimplemented( + "Please compile with cutlass to make cutlass available")); +#endif + } else { // m == 1: gemv + if (weight_dtype == "int8") { + 
GemvWeightonlyInt8Wrapper(dev_ctx, + x_data, + weight_data, + bias_data, + weight_scale_data, + n, + k, + "None", + out->data()); + } // TODO(lizhenyun) support weight_only_gemv for int4. + } +} +} // namespace phi + +PD_REGISTER_KERNEL(weight_only_linear, + GPU, + ALL_LAYOUT, + phi::WeightOnlyLinearKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu b/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu deleted file mode 100644 index ad88315875bb1..0000000000000 --- a/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/weight_only_matmul_kernel.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/datatype_traits.h" -#include "paddle/phi/core/kernel_registry.h" -#if defined(PADDLE_WITH_CUTLASS) -#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" -#endif - -namespace phi { - -template -void WeightOnlyMatmulKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& weight, - const DenseTensor& weight_scale, - DenseTensor* out) { -#if defined(PADDLE_WITH_CUTLASS) - dev_ctx.template Alloc(out); - const auto x_dims = x.dims(); - const auto w_dims = weight.dims(); - int n = weight_scale.dims()[0]; - int quant_bit = 0; - if (n % w_dims[0] == 0) { - quant_bit = w_dims[0] * 8 / n; - } else { - errors::InvalidArgument( - "w_dims[0] must be divisible by weight_scale.dims()[0]"); - } - - int k = w_dims[1]; - int m = x.numel() / k; - switch (quant_bit) { - case 8: { - auto mixed_gemm_runner = - CutlassFpAIntBGemmRunner::DataType, - uint8_t>(); - int mixgemm_max_size = std::max(n, k); - DenseTensor mixgemm_workspace; - int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( - m, mixgemm_max_size, mixgemm_max_size); - - mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); - dev_ctx.template Alloc(&mixgemm_workspace); - char* mixgemm_workspace_data = - reinterpret_cast(mixgemm_workspace.data()); - mixed_gemm_runner.gemm( - reinterpret_cast::DataType*>( - x.data()), - reinterpret_cast(weight.data()), - reinterpret_cast(weight_scale.data()), - reinterpret_cast::DataType*>( - out->data()), - m, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - dev_ctx.stream()); - } break; - case 4: { - auto mixed_gemm_runner = - CutlassFpAIntBGemmRunner::DataType, - cutlass::uint4b_t>(); - int mixgemm_max_size = std::max(n, k); - DenseTensor mixgemm_workspace; - int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( - m, mixgemm_max_size, mixgemm_max_size); - - mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); - dev_ctx.template Alloc(&mixgemm_workspace); - char* mixgemm_workspace_data = - reinterpret_cast(mixgemm_workspace.data()); - mixed_gemm_runner.gemm( - reinterpret_cast::DataType*>( - x.data()), - 
reinterpret_cast(weight.data()), - reinterpret_cast(weight_scale.data()), - reinterpret_cast::DataType*>( - out->data()), - m, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - dev_ctx.stream()); - } break; - default: - PADDLE_THROW(errors::Unimplemented( - "Quant_bits (%d) is not supported when gemm ", quant_bit)); - break; - } - -#else - LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()"; -#endif -} -} // namespace phi - -PD_REGISTER_KERNEL(weight_only_matmul, - GPU, - ALL_LAYOUT, - phi::WeightOnlyMatmulKernel, - phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h similarity index 93% rename from paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h rename to paddle/phi/kernels/impl/weight_quantize_kernel_impl.h index 096600bac0b7d..500efadd17df7 100644 --- a/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,9 +28,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ -#define PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ -#include +#pragma once + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" @@ -29,13 +44,14 @@ inline T xabs(const T x) { } template -void per_channel_scale(float* scale, const T* input, size_t m, size_t n) { +void per_channel_scale( + float* scale, const T* input, size_t m, size_t n, float bound) { for (size_t i = 0; i < n; ++i) { T max = input[i]; for (size_t j = 0; j < m; ++j) { max = xabs(input[j * n + i]) > max ? xabs(input[j * n + i]) : max; } - scale[i] = static_cast(max) / 127.0; + scale[i] = static_cast(max) / bound; } } @@ -144,8 +160,7 @@ void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) { template void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, - const int64_t arch_version) { + const std::vector& shape) { // We only want to run this step for weight only quant. const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; @@ -321,7 +336,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const size_t num_vec_rows = num_rows / elts_in_int32; const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; - const size_t interleave = 2; + const size_t interleave = 128 * 8 / quant_bit / rows_per_tile; for (size_t read_col = 0; read_col < num_cols; ++read_col) { const size_t write_col = read_col / interleave; for (size_t base_vec_row = 0; base_vec_row < num_vec_rows; @@ -345,4 +360,3 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, } } } // namespace phi -#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ diff --git a/paddle/phi/kernels/llm_int8_matmul_kernel.h b/paddle/phi/kernels/llm_int8_linear_kernel.h similarity index 83% rename from paddle/phi/kernels/llm_int8_matmul_kernel.h rename to paddle/phi/kernels/llm_int8_linear_kernel.h index 0d6229ea5af54..4e9251cab1f13 100644 --- a/paddle/phi/kernels/llm_int8_matmul_kernel.h +++ b/paddle/phi/kernels/llm_int8_linear_kernel.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -16,9 +16,10 @@ limitations under the License. */ namespace phi { template -void LLMInt8MatmulKernel(const Context& dev_ctx, +void LLMInt8LinearKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, + const paddle::optional& bias, const DenseTensor& weight_scale, const float threshold, DenseTensor* out); diff --git a/paddle/phi/kernels/weight_only_matmul_kernel.h b/paddle/phi/kernels/weight_only_linear_kernel.h similarity index 77% rename from paddle/phi/kernels/weight_only_matmul_kernel.h rename to paddle/phi/kernels/weight_only_linear_kernel.h index f2f20294021e2..19d4d274964b8 100644 --- a/paddle/phi/kernels/weight_only_matmul_kernel.h +++ b/paddle/phi/kernels/weight_only_linear_kernel.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -16,9 +16,11 @@ limitations under the License. */ namespace phi { template -void WeightOnlyMatmulKernel(const Context& dev_ctx, +void WeightOnlyLinearKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, + const paddle::optional& bias, const DenseTensor& weight_scale, + const std::string& weight_dtype, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/quant_for_compress_kernel.h b/paddle/phi/kernels/weight_quantize_kernel.h similarity index 65% rename from paddle/phi/kernels/quant_for_compress_kernel.h rename to paddle/phi/kernels/weight_quantize_kernel.h index 474589f60234c..ea49b3ffb2dce 100644 --- a/paddle/phi/kernels/quant_for_compress_kernel.h +++ b/paddle/phi/kernels/weight_quantize_kernel.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -16,10 +16,9 @@ limitations under the License. 
*/ namespace phi { template -void QuantForCompressKernel(const Context& dev_ctx, - const DenseTensor& x, - int bits, - const std::string& layout, - DenseTensor* out, - DenseTensor* scale); +void WeightQuantizeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::string& algo, + DenseTensor* out, + DenseTensor* scale); } // namespace phi diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 477a3b7c57b4e..dbef7079c1bf3 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -71,7 +71,6 @@ from .layer.common import Unfold # noqa: F401 from .layer.common import Fold # noqa: F401 from .layer.common import Unflatten # noqa: F401 -from .layer.common import LinearCompress # noqa: F401 from .layer.pooling import AvgPool1D # noqa: F401 from .layer.pooling import AvgPool2D # noqa: F401 from .layer.pooling import AvgPool3D # noqa: F401 diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index b3e17f5fbd34a..87f2eabba1f59 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -65,8 +65,6 @@ from .common import upsample # noqa: F401 from .common import bilinear # noqa: F401 from .common import class_center_sample # noqa: F401 -from .common import quant_for_compress # noqa: F401 -from .common import linear_compress # noqa: F401 from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index a111d8384d1a2..e513fb0670ef7 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1877,86 +1877,6 @@ def linear(x, weight, bias=None, name=None): return res -def quant_for_compress(x, bits=8, layout="weight_only"): - return _C_ops.quant_for_compress(x, bits, layout) - - -def linear_compress( - x, - weight, - weight_scale, - bias=None, - bits=8, - algo="llm.int8", - name=None, - config=None, -): - if in_dynamic_mode(): - if algo == "llm.int8": - y = _C_ops.llm_int8_matmul( - x, weight, weight_scale, config['threshold'] - ) - elif algo == "weight_only": - y = _C_ops.weight_only_matmul(x, weight, weight_scale) - else: - raise ValueError( - "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format( - algo - ) - ) - if bias is not None: - y = paddle.add(y, bias) - return y - else: - helper = LayerHelper('linear_compress', **locals()) - dtype = x.dtype - - check_variable_and_dtype(x, 'x', ['float16'], 'linear_compress') - check_dtype(dtype, 'dtype', ['float16'], 'linear_compress') - - if algo == "llm.int8": - type = "llm_int8_matmul" - inputs = { - 'x': [x], - 'weight': [weight], - 'weight_scale': [weight_scale], - } - attrs = {'algo': algo, 'threshold': config['threshold']} - elif algo == "weight_only": - type = "weight_only_matmul" - inputs = { - 'x': [x], - 'weight': [weight], - 'weight_scale': [weight_scale], - } - attrs = {} - else: - raise ValueError( - "Unknown algo: '{}'. 
It can only be 'llm.int8' or 'weight_only'.".format( - algo - ) - ) - tmp = helper.create_variable_for_type_inference(dtype) - - helper.append_op( - type=type, - inputs=inputs, - outputs={'Out': tmp}, - attrs=attrs, - ) - if bias is not None: - res = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='elementwise_add', - inputs={'X': [tmp], 'Y': [bias]}, - outputs={'Out': [res]}, - attrs={'axis': -1}, - ) - else: - res = tmp - return res - - def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): r""" Label smoothing is a mechanism to regularize the classifier layer and is called diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 539a030ad21e0..c6bca2efa78a7 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -194,173 +194,6 @@ def extra_repr(self): ) -class LinearCompress(Layer): - r""" - - Fully-connected linear transformation layer. For each input :math:`X` , - the equation is: - - .. math:: - - Out = XW + b - - where :math:`W` is the weight and :math:`b` is the bias. - - Linear layer takes only one multi-dimensional tensor as input with the - shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any - number of additional dimensions. It multiplies input tensor with the weight - (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces - an output tensor of shape :math:`[batch\_size, *, out\_features]` . - If :math:`bias\_attr` is not False, the bias (a 1-D tensor of - shape :math:`[out\_features]` ) will be created and added to the output. - - Parameters: - in_features (int): The number of input units. - out_features (int): The number of output units. - weight_attr (ParamAttr, optional): The attribute for the weight of this layer. - The default value is None. If the Initializer of the - param_attr is not set, the parameter is initialized with Xavier. - For detailed information, please refer to paddle.ParamAttr. - bias_attr (ParamAttr|bool, optional): The attribute for the bias of this layer. - If it is set to False, no bias will be added to the output. - If it is set to None or one kind of ParamAttr, a bias parameter will - be created according to ParamAttr. For detailed information, please refer - to paddle.ParamAttr. The default value is None and the bias will be - initialized to zero. - name (str, optional): Normally there is no need for user to set this parameter. - For detailed information, please refer to :ref:`api_guide_Name` . - bits (int, optional): The attribute to set num of bits in quant during weight_only, - it must be set as 8, default: 8. - algo (str, optional): The attribute to set algorithm of cpmoress, it must be set as 'weight_only' - or 'llm.int8', default: weight_only. - config (dict, optional): The parameter config for algorithm of cpmoress. - For llm.int8, it should be set as {'threshold': 6.0}, default: {'threshold': 6.0}. - - Attribute: - **weight** (Parameter): the learnable weight of this layer. - - **bias** (Parameter): the learnable bias of this layer. - - Shape: - - input: Multi-dimentional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data types are float16. - - output: Multi-dimentional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input . - - Examples: - .. code-block:: python - - >>> import paddle - >>> paddle.seed(100) - - >>> # Define the linear layer. - >>> paddle.set_default_dtype('float16') - >>> weight_attr = paddle.ParamAttr( - ... 
name="weight", - ... initializer=paddle.nn.initializer.Constant(value=0.5)) - - >>> bias_attr = paddle.ParamAttr( - ... name="bias", - ... initializer=paddle.nn.initializer.Constant(value=1.0)) - - >>> linear = paddle.nn.LinearCompress(128, 64, weight_attr=weight_attr, bias_attr=bias_attr, bits=8, algo='weight_only') - >>> x = paddle.randn((3, 128), dtype="float16") - >>> y = linear(x) - """ - - def __init__( - self, - in_features, - out_features, - weight_attr=None, - bias_attr=None, - name=None, - bits=8, - algo="weight_only", - config={'threshold': 6.0}, - ): - super().__init__() - self._dtype = self._helper.get_default_dtype() - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self.weight = self.create_parameter( - shape=[in_features, out_features], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - self.bias = self.create_parameter( - shape=[out_features], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - self.weight_scale = self.create_parameter( - shape=[out_features], - attr=None, - dtype=self._dtype, - is_bias=False, - ) - self.is_weight_quanted = False - self.name = (name,) - self.bits = bits - self.layout = algo - self.algo = algo - self.config = config - - def forward(self, input): - if in_dynamic_mode(): - if not self.is_weight_quanted: - weight_tensor, weight_scale_tensor = F.quant_for_compress( - self.weight, self.bits, self.layout - ) - weight_attr = paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Assign(weight_tensor) - ) - weight_shape = ( - [self.weight.shape[1], self.weight.shape[0]] - if self.bits == 8 - else [self.weight.shape[1] / 2, self.weight.shape[0]] - ) - self.weight = self.create_parameter( - shape=weight_shape, - attr=weight_attr, - dtype="int8", - is_bias=False, - ) - weight_scale_attr = paddle.framework.ParamAttr( - initializer=paddle.nn.initializer.Assign( - weight_scale_tensor - ) - ) - self.weight_scale = self.create_parameter( - shape=self.weight_scale.shape, - attr=weight_scale_attr, - dtype="float32", - is_bias=False, - ) - self.is_weight_quanted = True - out = F.linear_compress( - x=input, - weight=self.weight, - weight_scale=self.weight_scale, - bias=self.bias, - bits=self.bits, - algo=self.algo, - name=self.name, - config=self.config, - ) - return out - - def extra_repr(self): - name_str = f', name={self.name}' if self.name else '' - return 'in_features={}, out_features={}, dtype={}{}, algo={}'.format( - self.weight.shape[0], - self.weight.shape[1], - self._dtype, - name_str, - self.algo, - ) - - class Upsample(Layer): """ This op resizes a batch of images. diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py index cd221dd29bcfd..4962aacb4a5bd 100644 --- a/python/paddle/nn/quant/__init__.py +++ b/python/paddle/nn/quant/__init__.py @@ -22,8 +22,11 @@ from .functional_layers import concat # noqa: F401 from .functional_layers import flatten # noqa: F401 from .functional_layers import matmul # noqa: F401 +from .quantized_linear import weight_only_linear # noqa: F401 +from .quantized_linear import llm_int8_linear # noqa: F401 +from .quantized_linear import weight_quantize # noqa: F401 from .quant_layers import QuantStub # noqa: F401 from . 
import qat from .stub import Stub -__all__ = ["Stub"] +__all__ = ["Stub", "weight_only_linear", "llm_int8_linear", "weight_quantize"] diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py new file mode 100644 index 0000000000000..0d234e7de0c5f --- /dev/null +++ b/python/paddle/nn/quant/quantized_linear.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import _C_ops +from paddle.framework import LayerHelper, in_dynamic_mode + + +def weight_quantize(x, algo="weight_only_int8"): + """ + Quantization function for the weights used by weight_only and llm.int8. + + Args: + x (Tensor): The input Tensor to be quantized, the data type is float16 or bfloat16. + algo (str|None): The quantization algorithm to apply to x, must be one of 'weight_only_int8', + 'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'. + + Returns: + out (Tensor): The quantized Tensor, the data type is int8. + scale (Tensor): The per-channel scale Tensor, the data type is float32. + Examples: + .. code-block:: python + + import paddle + import numpy as np + from paddle.nn.quant import weight_quantize + + paddle.device.set_device("cpu") + x = np.random.randn(64, 32).astype('float16') + x = paddle.to_tensor(x, dtype=paddle.float16, place=paddle.CPUPlace()) + out, scale = weight_quantize(x, algo='weight_only_int8') + print(out.shape) # [32, 64] + print(scale.shape) # [32] + """ + + if in_dynamic_mode(): + return _C_ops.weight_quantize(x, algo) + else: + type = "weight_quantize" + helper = LayerHelper(type, **locals()) + out = helper.create_variable_for_type_inference('int8') + scale = helper.create_variable_for_type_inference('float') + + helper.append_op( + type=type, + inputs={"x": x}, + outputs={'out': out, "scale": scale}, + attrs={"algo": algo}, + ) + return (out, scale) + + +def weight_only_linear( + x, + weight, + bias=None, + weight_scale=None, + weight_dtype="int8", +): + """ + Applies matrix multiplication of the two input tensors, then adds the bias if provided. + This method requires CUDA version >= 11.2. + + Args: + x (Tensor): The first input Tensor to be multiplied, the data type is float16 or bfloat16. + weight (Tensor): The second input Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): The input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, the bias is added to the matrix multiplication result. + weight_scale (Tensor|None): The input scale Tensor provided for weight dequantization. Its rank must be 1. + weight_dtype(str): The dtype of the weight Tensor, must be either 'int8' or 'int4', default: 'int8'. + Returns: + Tensor: the output Tensor, the data type is the same as that of x. + + Examples: + ..
code-block:: python + + # required: gpu + import paddle + from paddle.nn.quant import weight_only_linear + + x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16') + weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8') + scale = paddle.randn([32], dtype='float32') + bias = paddle.cast(paddle.randn([32]), dtype='float16') + if paddle.device.cuda.get_device_capability()[0] >= 8: + out = weight_only_linear(x, weight, bias=bias, weight_scale=scale, weight_dtype='int8') + print(out.shape) # [1, 2, 32] + """ + if in_dynamic_mode(): + out = _C_ops.weight_only_linear( + x, weight, bias, weight_scale, weight_dtype + ) + return out + else: + type = "weight_only_linear" + helper = LayerHelper(type, **locals()) + dtype = x.dtype + + inputs = { + 'x': [x], + 'weight': [weight], + 'bias': [bias], + 'weight_scale': [weight_scale], + } + attrs = {'weight_dtype': weight_dtype} + + out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=type, + inputs=inputs, + outputs={'out': out}, + attrs=attrs, + ) + return out + + +def llm_int8_linear( + x, + weight, + bias=None, + weight_scale=None, + threshold=6.0, +): + """ + Applies matrix multiplication of the two input tensors, then adds the bias if provided. + This method requires CUDA version >= 11.2. + + Args: + x (Tensor): the first input Tensor to be multiplied, the data type is float16 or bfloat16. + weight (Tensor): the second input Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, the bias is added to the matrix multiplication result. + weight_scale (Tensor|None): the input scale Tensor provided for weight dequantization. Its rank must be 1. + threshold(float): The threshold for detecting outlier channels in the activation; channels whose + absolute activation exceeds it are computed in the data type of x instead of int8. + + Returns: + Tensor: the output Tensor, the data type is the same as that of x. + + Examples: + ..
code-block:: python + + # required: gpu + import paddle + from paddle.nn.quant import llm_int8_linear + + x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16') + weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8') + scale = paddle.randn([32], dtype='float32') + bias = paddle.cast(paddle.randn([32]), dtype='float16') + if paddle.device.cuda.get_device_capability()[0] >= 8: + out = llm_int8_linear(x, weight, bias=bias, weight_scale=scale, threshold=6.0) + print(out.shape) # [1, 2, 32] + """ + if in_dynamic_mode(): + out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold) + return out + else: + type = "llm_int8_linear" + helper = LayerHelper(type, **locals()) + dtype = x.dtype + + inputs = { + 'x': [x], + 'weight': [weight], + 'bias': [bias], + 'weight_scale': [weight_scale], + } + attrs = {'threshold': threshold} + + out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=type, + inputs=inputs, + outputs={'out': out}, + attrs=attrs, + ) + return out diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index b5ad632433da1..5f7bea43b307d 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -88,7 +88,6 @@ endif() list(REMOVE_ITEM TEST_OPS test_audio_logmel_feature test_audio_mel_feature) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) -list(REMOVE_ITEM TEST_OPS test_linear_compress) list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) @@ -159,7 +158,6 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op) - list(REMOVE_ITEM TEST_OPS test_linear_compress) list(REMOVE_ITEM TEST_OPS test_matmul_int8_op) list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention) endif() diff --git a/test/legacy_test/test_linear_compress.py b/test/legacy_test/test_linear_compress.py deleted file mode 100644 index 438e42a9891df..0000000000000 --- a/test/legacy_test/test_linear_compress.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np - -import paddle -from paddle import fluid -from paddle.fluid.framework import default_main_program -from paddle.framework import set_default_dtype - -np.random.seed(123) -paddle.seed(123) -default_main_program().random_seed = 42 -paddle.disable_static() - - -class LinearTestCase(unittest.TestCase): - def config(self): - self.dtype = 'float16' - self.rtol = 1e-5 - self.atol = 1e-2 - self.bias = True - self.in_features = 64 - self.out_features = 64 - self.algo = "weight_only" - self.bits = 8 - - def setUp(self): - self.config() - input = np.random.random((2, 4, self.in_features)) - self.input = paddle.to_tensor(input, dtype=self.dtype) - if self.bias: - bias_attr = fluid.ParamAttr( - learning_rate=0.8, - trainable=False, - regularizer=None, - initializer=paddle.nn.initializer.Constant(value=1.0), - ) - else: - bias_attr = None - set_default_dtype(self.dtype) - self.linear = paddle.nn.Linear( - self.in_features, self.out_features, bias_attr=bias_attr - ) - if self.algo == "llm.int8": - self.config = {"threshold": 6.0} - else: - self.config = None - self.linear_compress = paddle.nn.LinearCompress( - self.in_features, - self.out_features, - bias_attr=bias_attr, - bits=8, - algo=self.algo, - config=self.config, - ) - self.linear_compress(self.input) - - def get_linear_out(self): - out = self.linear(self.input) - return out.numpy() - - def get_linear_compress_out(self): - out = self.linear_compress(self.input) - return out.numpy() - - def test_linear_compress(self): - out_real = self.get_linear_compress_out() - out_expect = self.get_linear_out() - np.testing.assert_allclose( - out_real, out_expect, rtol=self.rtol, atol=self.atol - ) - - -class LinearTestCase1(LinearTestCase): - def config(self): - super().config() - self.dtype = 'float16' - self.bias = True - self.in_features = 128 - self.out_features = 64 - - -class LinearTestCase2(LinearTestCase): - def config(self): - super().config() - self.dtype = 'float16' - self.bias = False - self.in_features = 64 - self.out_features = 64 - - -class LinearTestCase3(LinearTestCase): - def config(self): - super().config() - self.dtype = 'float16' - self.bias = False - self.in_features = 64 - self.out_features = 64 - self.algo = "llm.int8" - self.atol = 1e-1 - - -class LinearTestCase4(LinearTestCase): - def config(self): - super().config() - self.dtype = 'float16' - self.bias = True - self.in_features = 128 - self.out_features = 64 - self.bits = 4 - - -if __name__ == '__main__': - unittest.main() diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt index 8a8596beb868d..310ee93824dfa 100644 --- a/test/quantization/CMakeLists.txt +++ b/test/quantization/CMakeLists.txt @@ -227,6 +227,8 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq) list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul) + list(REMOVE_ITEM TEST_OPS test_weight_only_linear) + list(REMOVE_ITEM TEST_OPS test_llm_int8_linear) list(REMOVE_ITEM TEST_OPS test_quant_aware) list(REMOVE_ITEM TEST_OPS test_quant_post_quant_aware) list(REMOVE_ITEM TEST_OPS test_quant_aware_user_defined) @@ -235,6 +237,11 @@ if(WIN32) endif() +if(NOT WITH_GPU) + list(REMOVE_ITEM TEST_OPS test_weight_only_linear) + list(REMOVE_ITEM TEST_OPS test_llm_int8_linear) +endif() + if(LINUX AND WITH_MKLDNN) #### Image classification dataset: ImageNet (small) diff --git a/test/quantization/test_llm_int8_linear.py b/test/quantization/test_llm_int8_linear.py new file mode 100644 index 
0000000000000..b26285c3049f0 --- /dev/null +++ b/test/quantization/test_llm_int8_linear.py @@ -0,0 +1,308 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from test_weight_only_linear import convert_uint16_to_float, get_cuda_version + +import paddle +import paddle.nn.quant as Q +from paddle import fluid +from paddle.fluid import core +from paddle.fluid.framework import default_main_program +from paddle.framework import set_default_dtype + +np.random.seed(123) +paddle.seed(123) +default_main_program().random_seed = 42 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase(unittest.TestCase): + def config(self): + self.dtype = 'float16' + self.rtol = 1e-5 + self.atol = 1e-1 + self.bias = True + self.batch = 1 + self.token = 32 + self.in_features = 64 + self.out_features = 256 + self.threshold = 6.0 + self.static = False + + def setUp(self): + self.config() + x = np.random.random((self.batch, self.token, self.in_features)) + self.x = paddle.to_tensor(x, dtype=self.dtype) + if self.bias: + bias_attr = fluid.ParamAttr( + trainable=False, + regularizer=None, + initializer=paddle.nn.initializer.Constant(value=1.0), + ) + else: + bias_attr = None + set_default_dtype(self.dtype) + self.linear = paddle.nn.Linear( + self.in_features, self.out_features, bias_attr=bias_attr + ) + + self.bias = self.linear.bias + self.weight = self.linear.weight + self.weight_scale = None + self.weight, self.weight_scale = Q.weight_quantize( + self.weight, algo="llm.int8" + ) + + def get_linear_out(self): + out = self.linear(self.x) + return out.numpy() + + def get_llm_int8_linear_out(self): + out = Q.llm_int8_linear( + self.x, + self.weight, + bias=self.bias, + weight_scale=self.weight_scale, + threshold=self.threshold, + ) + return out.numpy() + + def get_llm_int8_linear_out_static(self): + paddle.enable_static() + main = fluid.Program() + start = fluid.Program() + with fluid.program_guard(main, start): + x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype) + + weight = paddle.static.data( + "weight", self.weight.shape, dtype=self.weight.dtype + ) + bias = paddle.static.data( + "bias", self.bias.shape, dtype=self.bias.dtype + ) + x_np = self.x.numpy() + weight_np = self.weight.numpy() + bias_np = self.bias.numpy() + if self.weight_scale is not None: + weight_scale = paddle.static.data( + "weight_scale", + self.weight_scale.shape, + dtype=self.weight_scale.dtype, + ) + weight_scale_np = self.weight_scale.numpy() + else: + weight_scale = None + weight_scale_np = None + + out = Q.llm_int8_linear( + x, + weight, + bias, + weight_scale, + self.threshold, + ) + feed_dict = { + 'x': x_np, + 'weight': weight_np, + 'bias': bias_np, + "weight_scale": weight_scale_np, + } + exe = fluid.Executor(paddle.CUDAPlace(0)) + exe.run(start) + 
(out,) = exe.run(main, feed=feed_dict, fetch_list=[out]) + paddle.disable_static() + return out + + def test_llm_int8_linear(self): + out_expect = self.get_linear_out() + if self.static: + out_real = self.get_llm_int8_linear_out_static() + else: + out_real = self.get_llm_int8_linear_out() + + if self.dtype == "bfloat16": + out_real = convert_uint16_to_float(out_real) + out_expect = convert_uint16_to_float(out_expect) + np.testing.assert_allclose( + out_real, out_expect, rtol=self.rtol, atol=self.atol + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase1(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase2(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.bias = False + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase3(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class LLMInt8LinearTestCase4(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase5(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.bias = False + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class LLMInt8LinearTestCase6(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase7(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase8(LLMInt8LinearTestCase): + def 
config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase9(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCase10(LLMInt8LinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class LLMInt8LinearTestCaseStatic(LLMInt8LinearTestCase): + def config(self): + super().config() + self.static = True + + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py new file mode 100644 index 0000000000000..6c30c13ec21d1 --- /dev/null +++ b/test/quantization/test_weight_only_linear.py @@ -0,0 +1,337 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re +import struct +import unittest + +import numpy as np + +import paddle +import paddle.nn.quant as Q +from paddle import fluid +from paddle.fluid import core +from paddle.fluid.framework import default_main_program +from paddle.framework import set_default_dtype + +np.random.seed(123) +paddle.seed(123) +default_main_program().random_seed = 42 + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def convert_uint16_to_float(in_list): + in_list = np.asarray(in_list) + out = np.vectorize( + lambda x: struct.unpack( + '<f', struct.pack('<I', np.uint32(x) << np.uint32(16)) + )[0], + otypes=[np.float32], + )(in_list.flat) + return np.reshape(out, in_list.shape) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase(unittest.TestCase): + def config(self): + self.dtype = 'float16' + self.rtol = 1e-5 + self.atol = 1e-2 + self.bias = True + self.batch = 1 + self.token = 32 + self.in_features = 64 + self.out_features = 256 + self.weight_dtype = "int8" + self.static = False + + def setUp(self): + self.config() + if self.dtype == "bfloat16" or self.weight_dtype == "int4": + self.atol = 1e-1 + x = np.random.random((self.batch, self.token, self.in_features)) + self.x = paddle.to_tensor(x, dtype=self.dtype) + if self.bias: + bias_attr = fluid.ParamAttr( + trainable=False, + regularizer=None, + initializer=paddle.nn.initializer.Constant(value=1.0), + ) + else: + bias_attr = None + set_default_dtype(self.dtype) + self.linear = paddle.nn.Linear( + self.in_features, self.out_features, bias_attr=bias_attr + ) + + self.bias = self.linear.bias + self.weight = self.linear.weight + self.weight_scale = None + self.weight, self.weight_scale = Q.weight_quantize( + self.weight, + algo="weight_only_int8" + if self.weight_dtype == "int8" + else "weight_only_int4", + ) + + def get_linear_out(self): + out = self.linear(self.x) + return out.numpy() + + def get_weight_only_linear_out(self): + out = Q.weight_only_linear( + self.x, + self.weight, + bias=self.bias, + weight_scale=self.weight_scale, + weight_dtype=self.weight_dtype, + ) + return out.numpy() + + def get_weight_only_linear_out_static(self): + paddle.enable_static() + main = fluid.Program() + start = fluid.Program() + with fluid.program_guard(main, start): + x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype) + + weight = paddle.static.data( + "weight", self.weight.shape, dtype=self.weight.dtype + ) + bias = paddle.static.data( + "bias", self.bias.shape, dtype=self.bias.dtype + ) + x_np = self.x.numpy() + weight_np = self.weight.numpy() + bias_np = self.bias.numpy() + if self.weight_scale is not None: + weight_scale = paddle.static.data( + "weight_scale", + self.weight_scale.shape, + dtype=self.weight_scale.dtype, + ) + weight_scale_np = self.weight_scale.numpy() + else: + weight_scale = None + weight_scale_np = None + + out = Q.weight_only_linear( + x, + weight, + bias, + weight_scale, + self.weight_dtype, + ) + feed_dict = { + 'x': x_np, + 'weight': weight_np, + 'bias': bias_np, + "weight_scale": weight_scale_np, + } + exe = fluid.Executor(paddle.CUDAPlace(0)) + exe.run(start) + (out,) = exe.run(main, feed=feed_dict, fetch_list=[out]) + paddle.disable_static() + return out + + def test_weight_only_linear(self): + out_expect = self.get_linear_out() + if self.static: + out_real = self.get_weight_only_linear_out_static() + else: + out_real = self.get_weight_only_linear_out() + + if self.dtype == "bfloat16": + out_real =
convert_uint16_to_float(out_real) + out_expect = convert_uint16_to_float(out_expect) + np.testing.assert_allclose( + out_real, out_expect, rtol=self.rtol, atol=self.atol + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase1(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase2(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.bias = False + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase3(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase4(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase5(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.bias = False + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase6(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase7(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase8(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + 
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase9(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase10(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCaseStatic(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.static = True + + +if __name__ == '__main__': + unittest.main()