Commit

Add matmul_int8 op (PaddlePaddle#55228)
* add matmul int8
RichardWooSJTU authored and wz1qqx committed Jul 31, 2023
1 parent 5268f2e commit 8996f03
Showing 13 changed files with 542 additions and 134 deletions.
68 changes: 34 additions & 34 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
@@ -57,29 +57,29 @@ class AttnMatmulINT8 {
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
-    quantize_kernel_launcher<T>(input->data<T>(),
-                                input_tmp->data<int8_t>(),
-                                quant_in_scale,
-                                m_,
-                                k_,
-                                quant_round_type,
-                                quant_max_bound,
-                                quant_min_bound,
-                                dev_ctx_.stream());
+    LaunchQuantKernel<T>(input->data<T>(),
+                         input_tmp->data<int8_t>(),
+                         quant_in_scale,
+                         m_,
+                         k_,
+                         quant_round_type,
+                         quant_max_bound,
+                         quant_min_bound,
+                         dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

-    dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
-                                  output->data<T>(),
-                                  m_,
-                                  n_,
-                                  dev_ctx_.stream(),
-                                  gpu_config_.get(),
-                                  quant_in_scale,
-                                  dequant_out_scale->data<float>());
+    LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
+                           output->data<T>(),
+                           m_,
+                           n_,
+                           dev_ctx_.stream(),
+                           gpu_config_.get(),
+                           quant_in_scale,
+                           dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
@@ -126,14 +126,14 @@ class AttnMatmulINT8 {
output_tmp->data<int32_t>(),
dev_ctx_.stream());

-    dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
-                                  output->data<T>(),
-                                  m_,
-                                  n_,
-                                  dev_ctx_.stream(),
-                                  gpu_config_.get(),
-                                  quant_in_scale,
-                                  dequant_out_scale->data<float>());
+    LaunchDequantKernel<T>(output_tmp->data<int32_t>(),
+                           output->data<T>(),
+                           m_,
+                           n_,
+                           dev_ctx_.stream(),
+                           gpu_config_.get(),
+                           quant_in_scale,
+                           dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
@@ -162,15 +162,15 @@ class AttnMatmulINT8 {
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
-    quantize_kernel_launcher<T>(input->data<T>(),
-                                input_tmp->data<int8_t>(),
-                                quant_in_scale,
-                                m_,
-                                k_,
-                                quant_round_type,
-                                quant_max_bound,
-                                quant_min_bound,
-                                dev_ctx_.stream());
+    LaunchQuantKernel<T>(input->data<T>(),
+                         input_tmp->data<int8_t>(),
+                         quant_in_scale,
+                         m_,
+                         k_,
+                         quant_round_type,
+                         quant_max_bound,
+                         quant_min_bound,
+                         dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
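The hunks above only rename the quantize/dequantize launch helpers; the surrounding int8 attention GEMM flow is unchanged. Below is a condensed sketch of that flow, assuming the headers from this file are included; the function name Int8GemmPath and the raw-pointer parameters are illustrative only, and tensor allocation, bias addition, and error handling are elided.

// Illustrative only: mirrors the call sequence in AttnMatmulINT8 using the
// renamed launchers; all pointers are assumed to be valid device buffers.
template <typename T>
void Int8GemmPath(cudaStream_t stream,
                  CublasLtHelper* gemm_helper,     // int8 GEMM helper for (m, k, n)
                  GpuLaunchConfig* gpu_config,     // dequant launch config
                  const T* x_fp, int8_t* x_int8,   // activation, fp and int8
                  const int8_t* w_int8,            // pre-quantized weight
                  int32_t* y_int32, T* y_fp,       // GEMM output, fp output
                  const float* dequant_out_scale,  // per-channel output scale
                  int m, int n, int k, float quant_in_scale) {
  // 1. Quantize the activation to int8 (m x k).
  LaunchQuantKernel<T>(x_fp, x_int8, quant_in_scale, m, k,
                       /*round_type=*/1, /*max_bound=*/127.0f,
                       /*min_bound=*/-127.0f, stream);
  // 2. int8 x int8 -> int32 GEMM via cuBLASLt.
  gemm_helper->GEMM(x_int8, w_int8, y_int32, stream);
  // 3. Dequantize the int32 accumulator back to T.
  LaunchDequantKernel<T>(y_int32, y_fp, m, n, stream, gpu_config,
                         quant_in_scale, dequant_out_scale);
}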
80 changes: 40 additions & 40 deletions paddle/fluid/operators/fused/quant_dequant_kernel.h
@@ -47,14 +47,14 @@ __forceinline__ __device__ int8_t quant_helper(const T input,
}

template <typename T>
-__global__ void quantize_kernel(const T* input,
-                                char4* output,
-                                const float scale,
-                                const int m,
-                                const int n,
-                                const int round_type,
-                                const float max_bound,
-                                const float min_bound) {
+__global__ void QuantKernel(const T* input,
+                            char4* output,
+                            const float scale,
+                            const int m,
+                            const int n,
+                            const int round_type,
+                            const float max_bound,
+                            const float min_bound) {
int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
int m_id = blockIdx.y * blockDim.y + threadIdx.y;

@@ -74,36 +74,36 @@ __global__ void quantize_kernel(const T* input,
}

template <typename T>
-void quantize_kernel_launcher(const T* input,
-                              int8_t* output,
-                              const float scale,
-                              const int m,
-                              const int n,
-                              const int round_type,
-                              const float max_bound,
-                              const float min_bound,
-                              gpuStream_t stream) {
+void LaunchQuantKernel(const T* input,
+                       int8_t* output,
+                       const float scale,
+                       const int m,
+                       const int n,
+                       const int round_type,
+                       const float max_bound,
+                       const float min_bound,
+                       gpuStream_t stream) {
  // TODO(minghaoBD): optimize the kernel launch times when m==1 or n==1
  dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);

-  quantize_kernel<<<grid, block, 0, stream>>>(input,
-                                              (char4*)output,  // NOLINT
-                                              scale,
-                                              m,
-                                              n,
-                                              round_type,
-                                              max_bound,
-                                              min_bound);
+  QuantKernel<<<grid, block, 0, stream>>>(input,
+                                          (char4*)output,  // NOLINT
+                                          scale,
+                                          m,
+                                          n,
+                                          round_type,
+                                          max_bound,
+                                          min_bound);
}

template <typename T, int VecSize>
-__global__ void dequantize_kernel(T* output,
-                                  const int32_t* input,
-                                  const int m,  // batch size
-                                  const int n,  // hidden
-                                  const float quant_in_scale,
-                                  const float* dequant_out_scale_data) {
+__global__ void DequantKernel(T* output,
+                              const int32_t* input,
+                              const int m,  // batch size
+                              const int n,  // hidden
+                              const float quant_in_scale,
+                              const float* dequant_out_scale_data) {
int numel = m * n;
int stride = blockDim.x * gridDim.x * VecSize;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
@@ -128,15 +128,15 @@ __global__ void dequantize_kernel(T* output,
}

template <typename T>
-void dequantize_kernel_launcher(const int32_t* input,
-                                T* output,
-                                const int m,  // m
-                                const int n,  // n
-                                gpuStream_t stream,
-                                GpuLaunchConfig* gpu_config,
-                                const float quant_in_scale,
-                                const float* dequant_out_scale_data) {
-  dequantize_kernel<T, DequantKernelVecSize>
+void LaunchDequantKernel(const int32_t* input,
+                         T* output,
+                         const int m,  // m
+                         const int n,  // n
+                         gpuStream_t stream,
+                         GpuLaunchConfig* gpu_config,
+                         const float quant_in_scale,
+                         const float* dequant_out_scale_data) {
+  DequantKernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
}
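For reference, the quantization these launchers drive is ordinary symmetric per-tensor int8 quantization; the exact arithmetic lives in quant_helper, which is outside the hunks shown. Below is a host-side sketch under that assumption — the name QuantRef is hypothetical, and the scale convention (whether scale is the tensor's abs-max or its reciprocal) should be checked against quant_helper.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical reference: scale into [min_bound, max_bound], round, clamp.
// round_type in the real kernel selects the rounding mode; std::round
// (ties away from zero) is used here only as a stand-in.
inline int8_t QuantRef(float x, float scale, float max_bound, float min_bound) {
  float v = std::round(x * max_bound / scale);  // assumes scale == abs-max of x
  v = std::min(std::max(v, min_bound), max_bound);
  return static_cast<int8_t>(v);
}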
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -523,6 +523,14 @@
func : matmul
backward : matmul_grad

+- op : matmul_int8
+  args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
+  output : Tensor
+  infer_meta :
+    func : MatmulInt8InferMeta
+  kernel :
+    func : matmul_int8

- op : matrix_rank
args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
output : Tensor(out)
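The new yaml entry wires the matmul_int8 API to MatmulInt8InferMeta for shape/dtype inference and to a matmul_int8 kernel for execution; the kernel itself is presumably registered in one of the changed files not rendered above. Based on the entry, its phi-style signature would look roughly like the sketch below — a guess from the yaml args, not code taken from this diff.

// Presumed signature implied by the yaml entry; x and y are int8 tensors,
// out is int32 as set by MatmulInt8InferMeta.
template <typename T, typename Context>
void MatmulInt8Kernel(const Context& ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      bool transpose_x,
                      bool transpose_y,
                      DenseTensor* out);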
70 changes: 70 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

+void MatmulInt8InferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         bool trans_x,
+                         bool trans_y,
+                         MetaTensor* out) {
+  std::vector<int64_t> dims_x = phi::vectorize(x.dims());
+  std::vector<int64_t> dims_y = phi::vectorize(y.dims());
+  auto ndims_x = dims_x.size();
+  auto ndims_y = dims_y.size();
+  PADDLE_ENFORCE_GT(ndims_x,
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The Input(x) dims size must be greater than 0,"
+                        " but received dims size is 0. "));
+  PADDLE_ENFORCE_GT(ndims_y,
+                    0UL,
+                    phi::errors::InvalidArgument(
+                        "The Input(y) dims size must be greater than 0,"
+                        " but received dims size is 0. "));
+
+  bool x_broadcasted = false, y_broadcasted = false;
+  if (ndims_x == 1) {
+    dims_x.insert(dims_x.begin(), 1);
+    ndims_x = 2;
+    x_broadcasted = true;
+  }
+
+  if (ndims_y == 1) {
+    dims_y.push_back(1);
+    ndims_y = 2;
+    y_broadcasted = true;
+  }
+
+  size_t M, N;
+  if (trans_x) {
+    M = dims_x[ndims_x - 1];
+  } else {
+    M = dims_x[ndims_x - 2];
+  }
+  if (trans_y) {
+    N = dims_y[ndims_y - 2];
+  } else {
+    N = dims_y[ndims_y - 1];
+  }
+
+  std::vector<int64_t> new_dims;
+  if (ndims_x > ndims_y) {
+    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
+  } else if (ndims_x < ndims_y) {
+    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
+  } else {
+    new_dims.reserve(ndims_x);
+    for (size_t i = 0; i < ndims_x - 2; ++i) {
+      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+    }
+  }
+  if (!x_broadcasted) {
+    new_dims.push_back(M);
+  }
+  if (!y_broadcasted) {
+    new_dims.push_back(N);
+  }
+
+  auto ddim_out = phi::make_ddim(new_dims);
+
+  out->set_dims(ddim_out);
+  out->set_dtype(phi::DataType::INT32);
+  out->set_layout(x.layout());
+}

void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y,
int x_num_col_dims,
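To make the broadcasting rule concrete, the sketch below mirrors the dimension logic of MatmulInt8InferMeta as a standalone function (InferMatmulInt8Shape is an illustrative name, not part of the patch); the real function additionally sets the output dtype to INT32 and copies the input layout. For example, an int8 x of shape [2, 3, 4] and an int8 y of shape [4, 5] infer an int32 output of shape [2, 3, 5].

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative re-implementation of the shape rule above.
std::vector<int64_t> InferMatmulInt8Shape(std::vector<int64_t> x,
                                          std::vector<int64_t> y,
                                          bool trans_x, bool trans_y) {
  bool x_bc = false, y_bc = false;
  if (x.size() == 1) { x.insert(x.begin(), 1); x_bc = true; }  // [K] -> [1, K]
  if (y.size() == 1) { y.push_back(1); y_bc = true; }          // [K] -> [K, 1]
  int64_t M = trans_x ? x.back() : x[x.size() - 2];
  int64_t N = trans_y ? y[y.size() - 2] : y.back();
  std::vector<int64_t> out;
  if (x.size() > y.size()) {         // batch dims come from x
    out.assign(x.begin(), x.end() - 2);
  } else if (x.size() < y.size()) {  // batch dims come from y
    out.assign(y.begin(), y.end() - 2);
  } else {                           // elementwise max of batch dims
    for (size_t i = 0; i + 2 < x.size(); ++i) {
      out.push_back(std::max(x[i], y[i]));
    }
  }
  if (!x_bc) out.push_back(M);
  if (!y_bc) out.push_back(N);
  return out;  // the real InferMeta also sets dtype INT32 and x's layout
}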
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x,
bool trans_y,
MetaTensor* out);

+void MatmulInt8InferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         bool trans_x,
+                         bool trans_y,
+                         MetaTensor* out);

void MatmulWithFlattenInferMeta(const MetaTensor& x,
const MetaTensor& y,
int x_num_col_dims,
29 changes: 10 additions & 19 deletions paddle/phi/kernels/funcs/cublaslt.h
@@ -39,25 +39,16 @@ const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};

class CublasLtHelper {
public:
-  CublasLtHelper(int m, int k, int n)
-      : alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
+  CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle)
+      : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
-    // handle and matmul desc
-    status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif

-    PADDLE_ENFORCE_EQ(
-        status,
-        CUBLAS_STATUS_SUCCESS,
-        phi::errors::External(
-            "cublasLtMatrixLayoutCreate execution error"
-            "refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
-            "information"));

// matmul desc
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
@@ -179,7 +170,7 @@ class CublasLtHelper {
}
~CublasLtHelper() {}

-  void GEMM(int8_t* A_dev,
+  void GEMM(const int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
@@ -226,14 +217,14 @@

cublasLtMatmulAlgo_t algo_;

-  int32_t alpha_;
-  int32_t beta_;
+  int32_t alpha_ = 1;
+  int32_t beta_ = 0;

-  int m_;
-  int k_;
-  int n_;
+  int m_ = 0;
+  int k_ = 0;
+  int n_ = 0;

-  size_t workspace_size_;
+  size_t workspace_size_ = 0;
};

} // namespace phi
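With the constructor change above, CublasLtHelper no longer creates the cuBLASLt handle; the caller owns it and injects it. A hedged usage sketch follows — the function and variable names and the phi::dynload::cublasLtCreate wrapper are assumptions about the surrounding code (the removed line called it through the dyl alias), and handle destruction plus any optional trailing GEMM arguments are omitted.

// Assumed caller: A is m x k int8, B is k x n int8 (laid out as the helper
// expects), C is m x n int32, all already resident on the device.
void Int8GemmWithInjectedHandle(const int8_t* a_dev, const int8_t* b_dev,
                                int32_t* c_dev, int m, int k, int n,
                                cudaStream_t stream) {
  cublasLtHandle_t lt_handle;
  phi::dynload::cublasLtCreate(&lt_handle);  // handle created by the caller now
  phi::CublasLtHelper helper(m, k, n, lt_handle);
  helper.GEMM(a_dev, b_dev, c_dev, stream);  // const int8_t* now accepted
}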