fix tune_cublaslt_gemm compile bug #8844

Merged · 3 commits · Jul 31, 2024
csrc/generation/tune_cublaslt_gemm.cu — 117 changes: 65 additions & 52 deletions
@@ -15,12 +15,14 @@ limitations under the License. */
 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
 #include <sys/time.h>
+
 #include <algorithm>
 #include <fstream>
 #include <iostream>
 #include <limits>
 #include <list>
 #include <vector>
+
 #include "helper.h"
 
 template <typename T>
@@ -172,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
 }
 
 template <typename InT, typename OutT, typename ScaleT = OutT>
-void FindAlgo(cublasLtHandle_t ltHandle,
+void FindAlgo(const cublasLtHandle_t& ltHandle,
               int m,
               int n,
               int k,
@@ -466,15 +468,14 @@ class DevContext {};
 class CPUContext : public DevContext {};
 
 class CUBLASLTContext : public DevContext {
-public:
-  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }
+ public:
+  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }
 
- private:
-  cublasLtHandle_t handle_;
+  cublasLtHandle_t handle;
 };
 
 template <typename InT, typename OutT, typename DevContext>
-void GEMMInt8(DevContext dev_ctx,
+void GEMMInt8(const DevContext& dev_ctx,
               const std::vector<InT>& A,
               const std::vector<InT>& B,
               std::vector<OutT>& C,
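This hunk is the heart of the compile fix. In the old code `handle_` was a `private` member of `CUBLASLTContext`, yet the free function `GEMMInt8` below reads it as `dev_ctx.handle_`, which the compiler rejects; the patch removes the `private:` section, renames the member to `handle`, and passes the contexts by const reference instead of by value. A minimal sketch of the failure mode and the fix — the class and function names here are illustrative, and `CUDA_CHECK`-style error checking is omitted:

```cpp
#include <cublasLt.h>

// Before: a free function cannot read a private member.
class CtxBefore {
 public:
  CtxBefore() { cublasLtCreate(&handle_); }

 private:
  cublasLtHandle_t handle_;
};
// void Gemm(CtxBefore ctx) { auto h = ctx.handle_; }
// error: 'handle_' is a private member of 'CtxBefore'

// After: the handle is public, and the context is taken by const
// reference, so the member is accessible and no copy is made.
class CtxAfter {
 public:
  CtxAfter() { cublasLtCreate(&handle); }
  cublasLtHandle_t handle;
};

void Gemm(const CtxAfter& ctx) {
  cublasLtHandle_t h = ctx.handle;  // compiles
  (void)h;
}
```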
@@ -488,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
 }
 
 template <>
-void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
+void GEMMInt8<int8_t, int32_t, CPUContext>(const CPUContext& dev_ctx,
                                            const std::vector<int8_t>& A,
                                            const std::vector<int8_t>& B,
                                            std::vector<int32_t>& C,
@@ -502,7 +503,7 @@ void GEMMInt8<int8_t, int32_t, CPUContext>(CPUContext dev_ctx,
 }
 
 template <>
-void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
+void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
                                                 const std::vector<int8_t>& AVec,
                                                 const std::vector<int8_t>& BVec,
                                                 std::vector<int32_t>& CVec,
@@ -528,24 +529,24 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
 
   // init data structure
 
-  cublasLtMatmulDesc_t matmul_desc_;
-  cublasLtMatrixLayout_t A_desc_;
-  cublasLtMatrixLayout_t B_desc_;
-  cublasLtMatrixLayout_t C_desc_;
-  int32_t alpha_ = 1;
-  int32_t beta_ = 0;
+  cublasLtMatmulDesc_t matmul_desc;
+  cublasLtMatrixLayout_t A_desc;
+  cublasLtMatrixLayout_t B_desc;
+  cublasLtMatrixLayout_t C_desc;
+  int32_t alpha = 1;
+  int32_t beta = 0;
 
   cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
   CUDA_CHECK(
-      cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32I));
+      cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I));
   cublasOperation_t op_transpose = CUBLAS_OP_T;
-  CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc_,
+  CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc,
                                             CUBLASLT_MATMUL_DESC_TRANSA,
                                             &op_transpose,
                                             sizeof(op_transpose)));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n));
 
   cublasLtMatmulAlgo_t algo;
   int algoId;
@@ -574,17 +575,17 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
   if (is_test) {
     std::vector<algoSelect_t> algos;
     // Select //
-    FindAlgo(dev_ctx.handle_,
+    FindAlgo(dev_ctx.handle,
              m,
             n,
             k,
             B_dev,
             A_dev,
             C_dev,
-             matmul_desc_,
-             B_desc_,
-             A_desc_,
-             C_desc_,
+             matmul_desc,
+             B_desc,
+             A_desc,
+             C_desc,
             CUBLAS_COMPUTE_32I,
             CUDA_R_32I,
             CUDA_R_8I,
@@ -643,7 +644,7 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
                                  paddle::DataType::UINT8,
                                  paddle::GPUPlace());
   void* workspace_ptr = workspace.data<uint8_t>();
-  CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle_,
+  CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle,
                                     cudaComputeType,
                                     CUDA_R_32I,
                                     CUDA_R_8I,
@@ -677,18 +678,18 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(CUBLASLTContext dev_ctx,
   auto start = std::chrono::high_resolution_clock::now();
   const int repeats = 10;
   for (int loop = 0; loop < repeats; loop++) {
-    CUDA_CHECK(cublasLtMatmul(dev_ctx.handle_,
-                              matmul_desc_,
-                              &alpha_,
+    CUDA_CHECK(cublasLtMatmul(dev_ctx.handle,
+                              matmul_desc,
+                              &alpha,
                               B_dev,
-                              B_desc_,
+                              B_desc,
                               A_dev,
-                              A_desc_,
-                              &beta_,
+                              A_desc,
+                              &beta,
                               C_dev,
-                              C_desc_,
+                              C_desc,
                               C_dev,
-                              C_desc_,
+                              C_desc,
                               &algo,
                               // nullptr,
                               workspace_ptr,
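A note on the operand order in the two calls above: cuBLASLt, like cuBLAS, treats buffers as column-major, and a row-major matrix reinterpreted as column-major is its transpose. By the identity (A·B)ᵀ = Bᵀ·Aᵀ, feeding the operands in reverse order — with `CUBLASLT_MATMUL_DESC_TRANSA` set on the first one — makes the column-major n×m result coincide with the desired row-major m×n output, which is why `B_dev`/`B_desc` occupy the first slot, `A_dev`/`A_desc` the second, and `C_desc` was created as an n×m layout with leading dimension n. For this tuner the buffers are only timing dummies, so dimensional consistency is all that matters.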
@@ -711,8 +712,8 @@ void TuneCublasltGemm(const paddle::Tensor& M,
                       bool is_test,
                       bool is_read_from_file,
                       const std::string& path) {
-
-  // Ensure that M, K, and N are all one-dimensional Tensors. is_test != is_read_from_file
+  // Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
+  // is_read_from_file
   assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
   assert(is_test != is_read_from_file);
 
@@ -730,22 +731,34 @@ void TuneCublasltGemm(const paddle::Tensor& M,
 
   int m_data = (int)M_data[0];
   assert(m_data > 0 && 4 <= 8192);
 
   std::vector<int> mm;
 
   int m = 1, step = 1;
   while (m <= m_data) {
     mm.push_back(m);
     m += step;
 
     // update step
     switch (m) {
-      case 4: step = 4; break;
-      case 16: step = 16; break;
-      case 64: step = 32; break;
-      case 256: step = 64; break;
-      case 512: step = 128; break;
-      case 1024: step = 1024; break;
+      case 4:
+        step = 4;
+        break;
+      case 16:
+        step = 16;
+        break;
+      case 64:
+        step = 32;
+        break;
+      case 256:
+        step = 64;
+        break;
+      case 512:
+        step = 128;
+        break;
+      case 1024:
+        step = 1024;
+        break;
     }
   }
 
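The expanded switch is formatting-only; the tuning schedule is unchanged. The loop enumerates candidate m values densely while m is small and increasingly coarsely as it grows: step 1 up to 4, then 4 up to 16, 16 up to 64, 32 up to 256, 64 up to 512, 128 up to 1024, and 1024 beyond that. A self-contained sketch that reproduces the schedule, taking 8192 as the upper bound for illustration:

```cpp
#include <iostream>
#include <vector>

int main() {
  const int m_data = 8192;  // upper bound normally read from the M tensor
  std::vector<int> mm;
  int m = 1, step = 1;
  while (m <= m_data) {
    mm.push_back(m);
    m += step;
    // Widen the stride at the same thresholds as the kernel above.
    switch (m) {
      case 4: step = 4; break;
      case 16: step = 16; break;
      case 64: step = 32; break;
      case 256: step = 64; break;
      case 512: step = 128; break;
      case 1024: step = 1024; break;
    }
  }
  // Prints: 1 2 3 4 8 12 16 32 48 64 96 ... 448 512 640 768 896 1024 2048 ... 8192
  for (int v : mm) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```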
@@ -761,15 +774,15 @@ void TuneCublasltGemm(const paddle::Tensor& M,
   if (dtype == "int8") {
     CUBLASLTContext dev_ctx;
     GEMMInt8(dev_ctx,
-              A,
-              B,
-              C,
-              m,
-              k,
-              n,
-              is_test,           /*is_test*/
-              is_read_from_file, /*is_read_from_file*/
-              path);
+             A,
+             B,
+             C,
+             m,
+             k,
+             n,
+             is_test,           /*is_test*/
+             is_read_from_file, /*is_read_from_file*/
+             path);
   } else {
     // other dtype
     std::cout << "Not currently supported" << std::endl;
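For reference, `TuneCublasltGemm` takes 1-D `M`, `K`, `N` tensors plus mutually exclusive `is_test`/`is_read_from_file` flags, as asserted above. A hedged sketch of a host-side invocation; the `paddle::full` construction, the int64 dtype, and the output path are assumptions for illustration, not code from this PR:

```cpp
#include "paddle/extension.h"

void RunTuning() {
  // 1-D tensors holding the GEMM dimensions; the kernel reads element 0
  // of M as the largest m to sweep.
  auto M = paddle::full({1}, 1024, paddle::DataType::INT64, paddle::CPUPlace());
  auto K = paddle::full({1}, 4096, paddle::DataType::INT64, paddle::CPUPlace());
  auto N = paddle::full({1}, 4096, paddle::DataType::INT64, paddle::CPUPlace());

  // is_test=true searches and times algorithms; is_read_from_file=false
  // means results are generated fresh rather than loaded from `path`.
  TuneCublasltGemm(M, K, N, "int8",
                   /*is_test=*/true,
                   /*is_read_from_file=*/false,
                   "cublaslt_gemm_algos.csv");  // hypothetical output path
}
```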