Wint8 gemm and gemv opt (#59291)
* fpAintB split-k

* workspace

* fix error

* just_for_llama13b_bsz64-128

* llama13 opt

* fix scale type of weight only quant

* draft gemv batched

* accuracy fix

* m size dispatch for gemv and gemm

* fit dispatch

* refine gemv

* remove useless kernel

* refine

* fix bug for split-k-limit

* fix bug for half scale

* weight quant kernel fit for half scale

* fix bf16 compile

* fix sm70 autogen

* fix sm70 compile error

* fix code style

* update

* update

* code-style

* code-style

* windows compile fix

* code-style

* fix merge bug

---------

Co-authored-by: wwbitejotunn <wwbitejotunn@outlook.com>
2 people authored and pull[bot] committed Jun 18, 2024
1 parent 238d99d commit 6489006
Showing 25 changed files with 2,331 additions and 638 deletions.
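For context on the "fpAintB split-k" and "workspace" items in the commit message, here is a minimal CPU sketch of the split-k idea. It is illustrative only, not the CUTLASS fpA_intB kernel: the function name, row-major layout, and per-output-column scale placement are assumptions.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Hypothetical split-k reference: the K dimension of a float-A / int8-B GEMM
// is split into chunks, each chunk's partial sums go to a caller-provided
// workspace, and a second pass reduces the partials and applies the
// per-column dequantization scale.
void FpAIntBSplitKRef(const float* A, const int8_t* B, const float* scale,
                      float* workspace,  // holds splits * m * n partial sums
                      float* C, int m, int n, int k, int splits) {
  const int k_chunk = (k + splits - 1) / splits;
  for (int s = 0; s < splits; ++s) {  // each split is an independent partial GEMM
    float* part = workspace + static_cast<std::size_t>(s) * m * n;
    const int k0 = s * k_chunk;
    const int k1 = std::min(k, k0 + k_chunk);
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int kk = k0; kk < k1; ++kk) {
          acc += A[i * k + kk] * static_cast<float>(B[kk * n + j]);
        }
        part[i * n + j] = acc;
      }
    }
  }
  // Reduce over the splits; the per-column scale factors out of the K sum,
  // so it is applied once here.
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int s = 0; s < splits; ++s) {
        acc += workspace[static_cast<std::size_t>(s) * m * n + i * n + j];
      }
      C[i * n + j] = acc * scale[j];
    }
  }
}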
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.cc
@@ -5221,8 +5221,8 @@ void WeightQuantizeInferMeta(const MetaTensor& x,

   out->set_dtype(DataType::INT8);

-  scale->set_dims(common::make_ddim(dim_scale));
-  scale->set_dtype(DataType::FLOAT32);
+  scale->set_dims(phi::make_ddim(dim_scale));
+  scale->set_dtype(x.dtype());
 }

 void ChannelShuffleInferMeta(const MetaTensor& x,
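The inferred dtype of the scale tensor now follows the input x (e.g. float16 or bfloat16) instead of being fixed to float32; the kernel changes below carry the same T-typed scale through both quantization and dequantization.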
3 changes: 3 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -145,6 +145,9 @@ if(WITH_CUTLASS)
   )

   execute_process(
+    COMMAND
+      ${CMAKE_COMMAND} -E remove_directory
+      "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen"
     COMMAND
       ${CMAKE_COMMAND} -E make_directory
       "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen"
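Deleting the autogen directory before recreating it ensures stale generated fpA_intB_gemm kernels are not left behind when the generator changes; this presumably relates to the "fix sm70 autogen" item in the commit message.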
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -47,7 +47,7 @@ void quant_compute(const DeviceContext& dev_ctx,
   DDim dims = {num};
   const T* x_data = x.data<T>();
   D* out_data = out->data<D>();
-  float* scale_data = scale->data<float>();
+  T* scale_data = scale->data<T>();

   DenseTensor x_int(out->type());

@@ -108,7 +108,7 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           DenseTensor* out,
                           DenseTensor* scale) {
   dev_ctx.template Alloc<int8_t>(out);
-  dev_ctx.template Alloc<float>(scale);
+  dev_ctx.template Alloc<T>(scale);
   if (algo == "weight_only_int8" || algo == "llm.int8") {
     quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo, arch);
   } else if (algo == "weight_only_int4") {
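A minimal scalar sketch of what the CPU path above now does with the T-typed scale, assuming per-row abs-max int8 quantization; the helper name and layout are illustrative, not the actual quant_compute internals.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical per-row abs-max int8 quantization in which the scale is
// stored in the weight's own dtype T (e.g. phi::dtype::float16) rather than
// float, mirroring the Alloc<T>(scale) change above.
template <typename T>
void QuantRowAbsMax(const T* w, int8_t* out, T* scale, int n, int k) {
  for (int row = 0; row < n; ++row) {
    float max_abs = 0.0f;
    for (int col = 0; col < k; ++col) {
      max_abs = std::max(max_abs, std::fabs(static_cast<float>(w[row * k + col])));
    }
    const float s = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
    scale[row] = static_cast<T>(s);  // narrowed to T on store
    for (int col = 0; col < k; ++col) {
      out[row * k + col] = static_cast<int8_t>(
          std::lround(static_cast<float>(w[row * k + col]) / s));
    }
  }
}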
12 changes: 6 additions & 6 deletions paddle/phi/kernels/funcs/weight_dequant_functor.h
@@ -120,7 +120,7 @@ struct FastWeightOnlyHalfConverter<__nv_bfloat16, 4> {

 template <typename T>
 __global__ void int8_weight_only_dequant(const uint8_t* weight,
-                                         const float* scale_list,
+                                         const T* scale_list,
                                          T* output,
                                          const int n,
                                          const int k) {
@@ -145,7 +145,7 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight,
   int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 1 : 0);
   weight += tile_id * k * 2;
   output += row_id * k;
-  float scale = scale_list[row_id];
+  float scale = static_cast<float>(scale_list[row_id]);
 #pragma unroll
   for (int i = lane_id * 16; i < k * 2; i += 16 * 32) {
     Load<uint8_t, 16>(&weight[i], &vec_weight);
@@ -175,7 +175,7 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight,

 template <typename T>
 __global__ void int4_weight_only_dequant(const uint8_t* weight,
-                                         const float* scale_list,
+                                         const T* scale_list,
                                          T* output,
                                          const int n,
                                          const int k) {
@@ -201,7 +201,7 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,
   int row_id = tile_id * 4 + ((lane_id % 8) / 2);
   weight += tile_id * k / 2 * 4;
   output += row_id * k;
-  float scale = scale_list[row_id];
+  float scale = static_cast<float>(scale_list[row_id]);
 #pragma unroll
   for (int i = lane_id * 32; i < k * 4; i += 32 * 32) {
     Load<uint8_t, 16>(&weight[i / 2], &vec_weight);
@@ -249,15 +249,15 @@ void WeightDequantize(const Context& dev_ctx,
   if (algo == "weight_only_int8") {
     int8_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
         reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
-        scale.data<float>(),
+        reinterpret_cast<const DataType*>(scale.data<T>()),
         reinterpret_cast<DataType*>(out->data<T>()),
         n,
         k);
   } else if (algo == "weight_only_int4") {
     grid.x /= 2;
     int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
         reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
-        scale.data<float>(),
+        reinterpret_cast<const DataType*>(scale.data<T>()),
         reinterpret_cast<DataType*>(out->data<T>()),
         n,
         k);
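The read side of the same change, as a scalar sketch: the per-row scale is loaded as T, promoted to float for the multiply, and the product is narrowed back to T, which is the pattern the static_cast<float>(scale_list[row_id]) edits above introduce. This simplification assumes a plain row-major int8 layout and ignores the interleaved, vectorized access the real CUDA kernels use.

#include <cstdint>

// Hypothetical scalar dequant: out[row, :] = int8 weight * scale[row],
// with the arithmetic done in float and the scale read as T.
template <typename T>
void DequantRow(const int8_t* w, const T* scale_list, T* out, int row, int k) {
  const float scale = static_cast<float>(scale_list[row]);
  for (int i = 0; i < k; ++i) {
    out[row * k + i] = static_cast<T>(static_cast<float>(w[row * k + i]) * scale);
  }
}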