Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
ceci3 committed Dec 19, 2023
1 parent 1896d1f commit 5dc64e1
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions csrc/lc/quantize_blockwise.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -393,7 +393,7 @@ __global__ void kQuantizeChannelwise(const float *code,
}
}

template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n)
template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise)
{
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
Expand All @@ -403,20 +403,22 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
if (channelwise == 0) {
if(blocksize == 4096)
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 2048)
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 1024)
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 512)
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 256)
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 128)
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
}
else {
if (DATA_TYPE == General8bit)
PD_THROW("blocksize is -1 only support NF4 and FP4.");
Expand Down Expand Up @@ -445,42 +447,44 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float

std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
int n = input.numel();
int channelwise = 0;
std::vector<int64_t> out_shape = input.shape();
if (quant_type != "8bit") { // 4bit
out_shape = {(n + 1) / 2, 1};
}
if (blocksize == -1){
blocksize = input.shape()[0];
out_shape = {input.shape()[0]/2, input.shape()[1]};
channelwise = 1;
}
auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
int64_t absmax_shape = n / blocksize;
auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
switch(input.type()) {
case paddle::DataType::FLOAT32:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4") {
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
}
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::FLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};
case paddle::DataType::BFLOAT16:
if (quant_type == "8bit")
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "nf4")
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
else if (quant_type == "fp4")
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n);
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
return {out, absmax};

default:
Expand Down

0 comments on commit 5dc64e1

Please sign in to comment.