Skip to content

Commit

Permalink
support nf4 channel wise quant & fix bug when blocksize>512 (PaddlePa…
Browse files Browse the repository at this point in the history
  • Loading branch information
ceci3 authored and lizexu123 committed Feb 23, 2024
1 parent e18d916 commit 102c7b0
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 37 deletions.
84 changes: 74 additions & 10 deletions csrc/lc/dequantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(const floa
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n);



template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
{
int num_blocks = n/blocksize;
Expand All @@ -226,6 +225,50 @@ template void dequantize_blockwise<float, NF4>(const float *code, const unsigned
//template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);

template <typename T, int DATA_TYPE>
__global__ void kDequantizeChannelwise(const unsigned char* A,
const float *absmax,
float *out,
int n,
int cout) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;

int num = n / 2;
//int part_n = num / cout;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
float local_absmax = absmax[i%cout];
int idx = 2*(i/cout)* cout + i%cout;
switch(DATA_TYPE)
{
case FP4:
out[i*2 + i%cout] = dDequantizeFP4Tree(A[i] >> 4, local_absmax);
out[i*2 + cout + i%cout] = dDequantizeFP4Tree(A[i] & 0x0F, local_absmax);
break;
case NF4:
out[idx] = dDequantizeNF4(A[i] >> 4)* local_absmax;
out[idx + cout] = dDequantizeNF4(A[i] & 0x0F)* local_absmax;
break;
}
__syncthreads();
}
}

template<typename T, int DATA_TYPE> void dequantize_channelwise(const unsigned char *A, const float *absmax, T *out, int n, int cout)
{
int max_threads = 1024;
int64_t block_size =
std::min(static_cast<int64_t>(n),
static_cast<int64_t>(max_threads/ 4));

const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (n + block_size - 1) / block_size);

kDequantizeChannelwise<T, DATA_TYPE><<<grid_size, block_size>>>(A, absmax, out, n, cout);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) {
int64_t input_numel = input.numel();
int n = input_numel;
Expand All @@ -234,23 +277,44 @@ std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, con
out_shape = {input_numel * 2, 1};
n = n * 2;
}
if (blocksize == -1) {
out_shape = {input.shape()[0] * 2, input.shape()[1]};
}
auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place());

if (quant_type == "8bit")
dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "nf4")
dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "fp4")
dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
if (blocksize == -1) {
if (quant_type == "8bit")
PD_THROW("blocksize is -1 only support NF4 and FP4.");
else
blocksize = n / absmax.numel() * 2;

int cout = input.shape()[1];
if (quant_type == "nf4")
dequantize_channelwise<float, NF4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
else if (quant_type == "fp4")
dequantize_channelwise<float, FP4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
} else {
if (quant_type == "8bit")
dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "nf4")
dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else if (quant_type == "fp4")
dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
else
PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
}
return {out};
};

std::vector<std::vector<int64_t>> GetDequantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, const std::vector<int64_t>& abs_max_shape, int blocksize, std::string quant_type){
int64_t first_shape = input_shape[0] * input_shape[1] * 2;
if (quant_type != "8bit")
return {{first_shape, 1}};
if (blocksize != -1)
return {{first_shape, 1}};
else
return {{input_shape[0] * 2, input_shape[1]}};
else
return {input_shape};
}
Expand Down
115 changes: 88 additions & 27 deletions csrc/lc/quantize_blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A,
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit = 0;
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
Expand Down Expand Up @@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)

template <typename T, int DATA_TYPE>
__global__ void kQuantizeChannelwise(const float *code,
const T* A,
unsigned char* out,
float *absmax,
int n,
int cout) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;

int num = n / 2;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
int idx = 2*(i/cout)* cout + i%cout;
float local_absmax = absmax[i %cout];
float inv_local_absmax = 1.0f/local_absmax;

unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case FP4:
packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax);
out[i] = packed_4bit;
break;
case NF4:
packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax);
out[i] = packed_4bit;
break;
}
}
}

template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise)
{
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
Expand All @@ -372,61 +403,88 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
else
PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
if (channelwise == 0) {
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
}
else {
if (DATA_TYPE == General8bit)
PD_THROW("blocksize is -1 only support NF4 and FP4.");

int cout = A.shape()[1];
int max_threads = 1024;

absmax = A.abs().max({0});

int64_t block_size =
std::min(static_cast<int64_t>(n),
static_cast<int64_t>(max_threads/ 4));

const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (n + block_size - 1) / block_size);

kQuantizeChannelwise<DataType_, DATA_TYPE><<<grid_size, block_size, 0>>>(
code, A_data, out, absmax.data<float>(), n, cout);
}


CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
int n = input.numel();
int channelwise = 0;
std::vector<int64_t> out_shape = input.shape();
if (quant_type != "8bit") { // 4bit
out_shape = {(n + 1) / 2, 1};
}
if (blocksize == -1){
blocksize = input.shape()[0];
out_shape = {input.shape()[0]/2, input.shape()[1]};
channelwise = 1;
}
auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
int64_t absmax_shape = n / blocksize;
auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
switch(input.type()) {
case paddle::DataType::FLOAT32:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4") {
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
}
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::FLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::BFLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};

default:
Expand All @@ -440,7 +498,10 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
std::vector<std::vector<int64_t>> GetQuantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, int blocksize, std::string quant_type){
int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2;
if (quant_type != "8bit")
return {{first_shape, 1}};
if (blocksize != -1)
return {{first_shape, 1}};
else
return {{input_shape[0]/2, input_shape[1]}};
else
return {input_shape};
}
Expand Down

0 comments on commit 102c7b0

Please sign in to comment.