diff --git a/csrc/generation/tune_cublaslt_gemm.cu b/csrc/generation/tune_cublaslt_gemm.cu
index 0aa64640f11e..74d5a8acea64 100644
--- a/csrc/generation/tune_cublaslt_gemm.cu
+++ b/csrc/generation/tune_cublaslt_gemm.cu
@@ -15,12 +15,14 @@ limitations under the License. */
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
+
 #include "helper.h"

 template
@@ -172,7 +174,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
 }

 template
-void FindAlgo(cublasLtHandle_t ltHandle,
+void FindAlgo(const cublasLtHandle_t& ltHandle,
               int m,
               int n,
               int k,
@@ -466,15 +468,14 @@ class DevContext {};

 class CPUContext : public DevContext {};

 class CUBLASLTContext : public DevContext {
- public:
-  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }
+public:
+  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }

- private:
-  cublasLtHandle_t handle_;
+  cublasLtHandle_t handle;
 };

 template
-void GEMMInt8(DevContext dev_ctx,
+void GEMMInt8(const DevContext& dev_ctx,
               const std::vector& A,
               const std::vector& B,
               std::vector& C,
@@ -488,7 +489,7 @@ void GEMMInt8(DevContext dev_ctx,
 }

 template <>
-void GEMMInt8(CPUContext dev_ctx,
+void GEMMInt8(const CPUContext& dev_ctx,
               const std::vector& A,
               const std::vector& B,
               std::vector& C,
@@ -502,7 +503,7 @@ void GEMMInt8(CPUContext dev_ctx,
 }

 template <>
-void GEMMInt8(CUBLASLTContext dev_ctx,
+void GEMMInt8(const CUBLASLTContext& dev_ctx,
               const std::vector& AVec,
               const std::vector& BVec,
               std::vector& CVec,
@@ -528,24 +529,24 @@ void GEMMInt8(CUBLASLTContext dev_ctx,

   // init data structure
-  cublasLtMatmulDesc_t matmul_desc_;
-  cublasLtMatrixLayout_t A_desc_;
-  cublasLtMatrixLayout_t B_desc_;
-  cublasLtMatrixLayout_t C_desc_;
-  int32_t alpha_ = 1;
-  int32_t beta_ = 0;
+  cublasLtMatmulDesc_t matmul_desc;
+  cublasLtMatrixLayout_t A_desc;
+  cublasLtMatrixLayout_t B_desc;
+  cublasLtMatrixLayout_t C_desc;
+  int32_t alpha = 1;
+  int32_t beta = 0;

   cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
   CUDA_CHECK(
-      cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType, CUDA_R_32I));
+      cublasLtMatmulDescCreate(&matmul_desc, cudaComputeType, CUDA_R_32I));
   cublasOperation_t op_transpose = CUBLAS_OP_T;
-  CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc_,
+  CUDA_CHECK(cublasLtMatmulDescSetAttribute(matmul_desc,
                                             CUBLASLT_MATMUL_DESC_TRANSA,
                                             &op_transpose,
                                             sizeof(op_transpose)));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k));
-  CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&B_desc, CUDA_R_8I, k, n, k));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, k, m, k));
+  CUDA_CHECK(cublasLtMatrixLayoutCreate(&C_desc, CUDA_R_32I, n, m, n));

   cublasLtMatmulAlgo_t algo;
   int algoId;
@@ -574,17 +575,17 @@ void GEMMInt8(CUBLASLTContext dev_ctx,
   if (is_test) {
     std::vector algos;
     // Select //
-    FindAlgo(dev_ctx.handle_,
+    FindAlgo(dev_ctx.handle,
              m,
              n,
              k,
              B_dev,
              A_dev,
              C_dev,
-             matmul_desc_,
-             B_desc_,
-             A_desc_,
-             C_desc_,
+             matmul_desc,
+             B_desc,
+             A_desc,
+             C_desc,
              CUBLAS_COMPUTE_32I,
              CUDA_R_32I,
              CUDA_R_8I,
@@ -643,7 +644,7 @@ void GEMMInt8(CUBLASLTContext dev_ctx,
                                          paddle::DataType::UINT8,
                                          paddle::GPUPlace());
     void* workspace_ptr = workspace.data();
-    CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle_,
+    CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle,
                                       cudaComputeType,
                                       CUDA_R_32I,
                                       CUDA_R_8I,
@@ -677,18 +678,18 @@ void GEMMInt8(CUBLASLTContext dev_ctx,
   auto start = std::chrono::high_resolution_clock::now();
   const int repeats = 10;
   for (int loop = 0; loop < repeats; loop++) {
-    CUDA_CHECK(cublasLtMatmul(dev_ctx.handle_,
-                              matmul_desc_,
-                              &alpha_,
+    CUDA_CHECK(cublasLtMatmul(dev_ctx.handle,
+                              matmul_desc,
+                              &alpha,
                               B_dev,
-                              B_desc_,
+                              B_desc,
                               A_dev,
-                              A_desc_,
-                              &beta_,
+                              A_desc,
+                              &beta,
                               C_dev,
-                              C_desc_,
+                              C_desc,
                               C_dev,
-                              C_desc_,
+                              C_desc,
                               &algo,
                               //  nullptr,
                               workspace_ptr,
@@ -711,8 +712,8 @@ void TuneCublasltGemm(const paddle::Tensor& M,
                       bool is_test,
                       bool is_read_from_file,
                       const std::string& path) {
-
-  // Ensure that M, K, and N are all one-dimensional Tensors. is_test != is_read_from_file
+  // Ensure that M, K, and N are all one-dimensional Tensors. is_test !=
+  // is_read_from_file
   assert(M.dims().size() == 1 && K.dims().size() == 1 &&
          N.dims().size() == 1);
   assert(is_test != is_read_from_file);
@@ -730,22 +731,34 @@ void TuneCublasltGemm(const paddle::Tensor& M,
   int m_data = (int)M_data[0];
   assert(m_data > 0 && 4 <= 8192);
-
+
   std::vector mm;
   int m = 1, step = 1;
-  while (m <= m_data) {
+  while (m <= m_data) {
     mm.push_back(m);
     m += step;

     // update step
     switch (m) {
-      case 4: step = 4; break;
-      case 16: step = 16; break;
-      case 64: step = 32; break;
-      case 256: step = 64; break;
-      case 512: step = 128; break;
-      case 1024: step = 1024; break;
+      case 4:
+        step = 4;
+        break;
+      case 16:
+        step = 16;
+        break;
+      case 64:
+        step = 32;
+        break;
+      case 256:
+        step = 64;
+        break;
+      case 512:
+        step = 128;
+        break;
+      case 1024:
+        step = 1024;
+        break;
     }
   }
@@ -761,15 +774,15 @@ void TuneCublasltGemm(const paddle::Tensor& M,
   if (dtype == "int8") {
     CUBLASLTContext dev_ctx;
     GEMMInt8(dev_ctx,
-            A,
-            B,
-            C,
-            m,
-            k,
-            n,
-            is_test, /*is_test*/
-            is_read_from_file, /*is_read_from_file*/
-            path);
+             A,
+             B,
+             C,
+             m,
+             k,
+             n,
+             is_test,           /*is_test*/
+             is_read_from_file, /*is_read_from_file*/
+             path);
   } else {
     // other dtype
     std::cout << "Not currently supported" << std::endl;
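For context on the `while`/`switch` loop reindented above: `TuneCublasltGemm` builds the list of candidate `m` values by walking from 1 upward with a stride that widens at fixed breakpoints, so small GEMM shapes are sampled densely and large ones coarsely. Below is a minimal standalone sketch of that schedule, not part of the patch; `kMaxM` and the `main` harness are illustrative assumptions (the patch keeps the original `assert(m_data > 0 && 4 <= 8192)`, whose second clause is vacuously true and presumably intends `m_data <= 8192`).

```cpp
// Standalone sketch of the tuner's candidate-m schedule.
// Assumption: an upper bound of 8192 on m, matching the bound the
// patch's assert appears to intend.
#include <iostream>
#include <vector>

int main() {
  const int kMaxM = 8192;  // assumed cap on m (illustrative)
  std::vector<int> mm;     // candidate m values to tune
  int m = 1, step = 1;
  while (m <= kMaxM) {
    mm.push_back(m);
    m += step;
    // Widen the stride at fixed breakpoints. Each case is reached exactly,
    // because every stride divides the distance to the next breakpoint
    // (1 into 4, 4 into 16, 16 into 64, 32 into 256, 64 into 512,
    // 128 into 1024).
    switch (m) {
      case 4:    step = 4;    break;
      case 16:   step = 16;   break;
      case 64:   step = 32;   break;
      case 256:  step = 64;   break;
      case 512:  step = 128;  break;
      case 1024: step = 1024; break;
    }
  }
  for (int v : mm) std::cout << v << ' ';
  std::cout << '\n' << mm.size() << " candidate m values\n";
}
```

Sampling every value below 4, then power-of-two-ish strides above, keeps the number of cublasLt algorithm searches bounded while still covering the batch sizes a serving workload is likely to hit.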