
support nf4 channel wise quant
ceci3 committed Dec 13, 2023
1 parent 9c9b6a6 commit b5be6ca
Showing 2 changed files with 90 additions and 3 deletions.
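This commit adds a channel-wise path to the 4-bit (NF4/FP4) block-wise quantization ops: passing blocksize = -1 selects it, the new kernels compute one absmax scale per channel (the channel count is taken as A.shape()[1]) instead of one per fixed-size block, and two 4-bit codes are packed into each output byte. Below is a minimal host-side sketch of the shape bookkeeping the new kernels assume; the sizes and names are illustrative, not taken from the commit.

// Illustrative shape contract for the channel-wise 4-bit path (assumption
// inferred from the kernels below, not code from this commit).
#include <cstdint>
#include <vector>

int main() {
    const int rows = 4096, cols = 1024;        // cols plays the role of A.shape()[1], the channel count
    const int n = rows * cols;                 // number of weight elements
    std::vector<uint8_t> packed((n + 1) / 2);  // two 4-bit codes per byte, cf. out_shape = {(n + 1) / 2, 1}
    std::vector<float> absmax(cols);           // one scale per channel, filled by FindChannelAbsMaxKernelQuantAxis0
    const int elems_per_channel = n / cols;    // the effective block that blocksize == -1 stands for
    (void)packed; (void)absmax; (void)elems_per_channel;
    return 0;
}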
6 changes: 5 additions & 1 deletion csrc/lc/dequantize_blockwise.cu
@@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(const floa
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n);



template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
{
int num_blocks = n/blocksize;
@@ -229,6 +228,11 @@ template void dequantize_blockwise<float, NF4>(const float *code, const unsigned
std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) {
int64_t input_numel = input.numel();
int n = input_numel;
if (blocksize == -1) {
    if (quant_type == "8bit")
        PD_THROW("blocksize == -1 only supports NF4 and FP4.");
    else
        blocksize = n / absmax.numel() * 2;  // packed input stores two elements per byte
}
std::vector<int64_t> out_shape = input.shape();
if (quant_type != "8bit") { // 4bit
out_shape = {input_numel * 2, 1};
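When blocksize arrives as -1 in the 4-bit dequantize wrapper above, the per-channel block length is recovered from the tensor sizes: the packed input stores two elements per byte, so the true element count is 2 * n, and dividing by the number of stored scales gives the channel length. A worked example with assumed sizes (not taken from the commit):

// Worked example of the blocksize recovery in DequantizeBlockwise; the sizes
// are assumptions for illustration only.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t packed_numel = 4096 * 1024 / 2;  // n: bytes of packed 4-bit data
    const int64_t absmax_numel = 1024;             // one scale per channel
    const int64_t blocksize = packed_numel / absmax_numel * 2;  // n / absmax.numel() * 2 == 4096
    std::printf("recovered blocksize = %lld\n", (long long)blocksize);
    return 0;
}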
87 changes: 85 additions & 2 deletions csrc/lc/quantize_blockwise.cu
@@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A,
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit = 0;
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
@@ -360,7 +361,63 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)

template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in,
int n,
int c,
float *out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T *in_c = in + blockIdx.x * channel_size;
extern __shared__ char shared_max_data_tmp[];
auto shared_max_data = reinterpret_cast<float *>(shared_max_data_tmp);
float local_max_data = 0.0f;
for (int i = tid; i < channel_size; i += blockDim.x) {
    local_max_data = fmaxf(local_max_data, fabsf((float)in_c[i]));
}
shared_max_data[tid] = local_max_data;
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
}
}
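// (Illustrative summary, not part of the committed file.) In the kernel above,
// each block owns one channel: threads stride over the channel accumulating a
// local absolute maximum, stash the per-thread maxima in shared memory, then a
// tree reduction halves the active thread count each step until
// shared_max_data[0] holds the channel maximum, which thread 0 writes out.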

template <typename T, int DATA_TYPE>
__global__ void kQuantizeChannelwise(const float *code,
const T* A,
unsigned char* out,
float *absmax,
int n,
int cout) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

int num = n / 2;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
float local_absmax = absmax[(i * 2) % cout];
float inv_local_absmax = 1.0f/local_absmax;

unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case FP4:
packed_4bit |= dQuantizeFP4(((float)A[2*i])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeFP4(((float)A[2*i+1])*inv_local_absmax);
out[i] = packed_4bit;
break;
case NF4:
packed_4bit |= dQuantizeNF4(((float)A[2*i])*inv_local_absmax) << 4;
packed_4bit |= dQuantizeNF4(((float)A[2*i+1])*inv_local_absmax);
out[i] = packed_4bit;
break;
}
}
}

template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
{
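kQuantizeChannelwise above produces one byte per loop iteration: it scales the element pair 2*i and 2*i+1 by the reciprocal of absmax[(i * 2) % cout] and packs the two 4-bit codes into a single byte, even element in the high nibble. A CPU reference sketch of that loop, assuming the 16-entry NF4 table from the QLoRA paper and a hypothetical nearest_nf4 helper (the repo's dQuantizeNF4 device function may resolve ties differently):

// CPU reference sketch of the channel-wise NF4 packing; illustrative only, not
// part of this commit.
#include <cmath>
#include <cstdint>

static const float kNF4Levels[16] = {
    -1.0f, -0.6961928f, -0.5250731f, -0.3949175f,
    -0.2844414f, -0.1847734f, -0.0910500f, 0.0f,
    0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f,
    0.4407098f, 0.5626170f, 0.7229568f, 1.0f};

static uint8_t nearest_nf4(float x) {
    int best = 0;
    for (int k = 1; k < 16; ++k)
        if (std::fabs(x - kNF4Levels[k]) < std::fabs(x - kNF4Levels[best])) best = k;
    return static_cast<uint8_t>(best);
}

// A: n elements, absmax: cout per-channel scales, out: n/2 packed bytes.
void quantize_channelwise_ref(const float* A, const float* absmax,
                              uint8_t* out, int n, int cout) {
    for (int i = 0; i < n / 2; ++i) {
        const float inv = 1.0f / absmax[(i * 2) % cout];  // same scale lookup as the kernel
        uint8_t packed = 0;
        packed |= nearest_nf4(A[2 * i] * inv) << 4;       // even element -> high nibble
        packed |= nearest_nf4(A[2 * i + 1] * inv);        // odd element  -> low nibble
        out[i] = packed;
    }
}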
@@ -386,8 +443,31 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
else if(blocksize == 64)
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
-    else
-      PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
+    else {
if (DATA_TYPE == General8bit)
    PD_THROW("blocksize == -1 only supports NF4 and FP4.");

int cout = A.shape()[1];
int grid = cout;
int64_t max_threads = 1024;
int block = max_threads;
cudaMemset(absmax, 0, sizeof(float) * cout);

FindChannelAbsMaxKernelQuantAxis0<DataType_>
    <<<grid, block, block * sizeof(float)>>>(
        A_data, n, cout, absmax);
int64_t block_size =
std::min(static_cast<int64_t>(n),
static_cast<int64_t>(max_threads/ 4));

const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (n + block_size - 1) / block_size);

kQuantizeChannelwise<DataType_, DATA_TYPE><<<grid_size, block_size, 0>>>(
code, A_data, out, absmax, n, cout);
}


CUDA_CHECK_RETURN(cudaPeekAtLastError());
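The channel-wise quantize launch above relies on kQuantizeChannelwise's grid-stride loop, so the grid is kept deliberately small: block_size is capped at max_threads / 4 threads and the grid at a handful of blocks, and each thread then strides over the remaining element pairs. Plugging assumed numbers into that arithmetic (illustrative only):

// Worked example of the launch sizing; n is an assumption for illustration.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n = 4096 * 1024;    // weight elements
    const int64_t max_threads = 1024;
    const int64_t block_size = std::min(n, max_threads / 4);                              // 256
    const int64_t max_blocks = std::max((max_threads - 1) / block_size + 1, int64_t(1));  // 4
    const int64_t grid_size = std::min(max_blocks, (n + block_size - 1) / block_size);    // 4 blocks; the strided loop covers the rest
    std::printf("block=%lld grid=%lld\n", (long long)block_size, (long long)grid_size);
    return 0;
}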
@@ -399,6 +479,9 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
if (quant_type != "8bit") { // 4bit
out_shape = {(n + 1) / 2, 1};
}
if (blocksize == -1){
blocksize = out_shape[0];
}
auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
int64_t absmax_shape = n / blocksize;
auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
