
Add tune_cublaslt_gemm operator by cublaslt gemm algorithm and generate algo cache file #8799

Merged: 8 commits merged into PaddlePaddle:develop on Jul 31, 2024

Conversation


@Hanyonggong (Contributor) commented Jul 24, 2024

PR types

New features

PR changes

Description

Add tune_cublaslt_gemm operator by cublaslt gemm algorithm and generate algo cache file

Usage

tune_cublaslt_gemm(M, K, N, dtype, is_test, is_read_from_file, path)
Parameters:

  • M, K, N: paddle::Tensor; specify the matrix shapes [M, K] * [K, N] = [M, N]
  • dtype: data type of the matrix elements
  • is_test: whether to generate the matmul configuration by searching
  • is_read_from_file: whether to load the matmul configuration from an offline file
  • path: path of the generated file or of the offline file
# Example
# tune_cublaslt_gemm is provided by the compiled custom-ops library.
import paddle

M_tensor = paddle.to_tensor([1024])
K_tensor = paddle.to_tensor([2048])
N_tensor = paddle.to_tensor([11008])
dtype = "int8"
path = "./search.csv"

tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, dtype, True, False, path)

Cache file format

Format (each line): M,K,N,algoId,swizzle,customOption,tile,splitK_val,reductionScheme,stages,workspaceSize,time
Example: 2048,2048,4096,21,0,0,24,8,3,15,0,0.0701338
Fields (mirrored by the struct sketch after this list):

  • M,K,N: matrix shapes [M, K] * [K, N] = [M, N]
  • algoId: ID of the algorithm to initialize
  • swizzle: whether CTA swizzling is enabled
  • customOption: custom option value
  • tile: tile ID, used to set the tile size (rows * columns)
  • splitK_val: number of splits along K
  • reductionScheme: reduction scheme used when splitK_val > 1
  • stages: stages ID, configuring the size and number of shared-memory buffers used to stage input elements
  • workspaceSize: required workspace memory size
  • time: measured run time
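
For illustration only, a minimal C++ struct that mirrors one line of the cache file (the struct name is hypothetical and not part of the operator's code):

// Hypothetical helper: one parsed line of the algo cache file.
struct GemmAlgoCacheEntry {
  int m, k, n;          // matrix shapes [M, K] * [K, N] = [M, N]
  int algoId;           // cuBLASLt algorithm ID
  int swizzle;          // CTA swizzling enabled or not
  int customOption;     // custom option value
  int tile;             // tile ID (tile size = rows * columns)
  int splitK_val;       // number of splits along K
  int reductionScheme;  // reduction scheme when splitK_val > 1
  int stages;           // stages ID (shared-memory staging buffers)
  int workspaceSize;    // required workspace size
  float time;           // measured run time
};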


paddle-bot bot commented Jul 24, 2024

Thanks for your contribution!


CLAassistant commented Jul 24, 2024

CLA assistant check
All committers have signed the CLA.


codecov bot commented Jul 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 55.50%. Comparing base (da1eb9c) to head (ca515f2).
Report is 238 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8799      +/-   ##
===========================================
+ Coverage    55.37%   55.50%   +0.13%     
===========================================
  Files          631      631              
  Lines        99707    98544    -1163     
===========================================
- Hits         55211    54701     -510     
+ Misses       44496    43843     -653     



@yuanlehome (Collaborator) left a comment


Please tidy up the LOG output at the key places so that the text logged while the operator runs is easy to read.

@@ -49,6 +49,7 @@ def get_gencode_flags():


gencode_flags = get_gencode_flags()
library_path = os.environ.get('LD_LIBRARY_PATH')
Collaborator


What if this environment variable cannot be found?

Comment on lines 22 to 28
M.extend(range(1, 4, 1))
M.extend(range(4, 16, 4))
M.extend(range(16, 64, 16))
M.extend(range(64, 256, 32))
M.extend(range(256, 512, 64))
M.extend(range(512, 1024, 128))
M.extend(range(1024, 8193, 1024))
Collaborator


It looks like M does not need to be passed in as a list; a single integer value would be enough.

Dtype = "int8"
Path = "./search.csv"

Tune_gemm(M, K, N, Dtype, Path)
Collaborator


Let's rename the operator to tune_cublaslt_gemm for now and discuss a better name later; the first letter should not be capitalized.

Comment on lines 546 to 548
cudaMalloc((void**)&A_dev, A.size() * sizeof(int8_t));
cudaMalloc((void**)&B_dev, B.size() * sizeof(int8_t));
cudaMalloc((void**)&C_dev, m * n * sizeof(int32_t));
Collaborator


For the internal device-memory management (cudaMalloc/cudaFree, etc.), please use paddle::Tensor instead, for example as sketched below.
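
A rough sketch of how the device buffers could be allocated through paddle::Tensor instead of raw cudaMalloc (assuming the custom-op extension API paddle::empty is usable here; names are illustrative):

// Sketch only: let paddle::Tensor own the device memory instead of cudaMalloc/cudaFree.
auto A_tensor = paddle::empty({static_cast<int64_t>(A.size())},
                              paddle::DataType::INT8, paddle::GPUPlace());
auto B_tensor = paddle::empty({static_cast<int64_t>(B.size())},
                              paddle::DataType::INT8, paddle::GPUPlace());
auto C_tensor = paddle::empty({static_cast<int64_t>(m), static_cast<int64_t>(n)},
                              paddle::DataType::INT32, paddle::GPUPlace());

int8_t* A_dev = A_tensor.data<int8_t>();
int8_t* B_dev = B_tensor.data<int8_t>();
int32_t* C_dev = C_tensor.data<int32_t>();
// No explicit cudaFree is needed; the tensors release the memory when they go out of scope.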

See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Collaborator


A .cu file does not need this.

Comment on lines 27 to 29
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
#include <cuda_fp8.h>
#endif
Collaborator


FP8 is not supported yet; please remove this for now.

Comment on lines 744 to 748
void TestBench(const std::vector<int64_t>& M,
const std::vector<int64_t>& K,
const std::vector<int64_t>& N,
const std::string dtype,
const std::string path) {
Collaborator


Keep the function name consistent with the operator name, and change these function parameters into operator inputs rather than attributes.

Comment on lines 786 to 790
.Attrs({"M :std::vector<int64_t>",
"K :std::vector<int64_t>",
"N :std::vector<int64_t>",
"dtype: std::string",
"path: std::string",})
Collaborator


These should be operator inputs rather than attributes; M, K and N are all paddle::Tensor objects. A registration sketch follows below.
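
A rough sketch of what the registration could look like with M, K, N as inputs instead of attributes (illustrative only; the output list and the kernel function name TuneCublasltGemm are placeholders, and the final signature may differ):

// Sketch: M, K, N become operator inputs (paddle::Tensor); only scalar options stay as attributes.
PD_BUILD_OP(tune_cublaslt_gemm)
    .Inputs({"M", "K", "N"})
    .Outputs({"Out"})  // placeholder output; the real op may differ
    .Attrs({"dtype: std::string",
            "is_test: bool",
            "is_read_from_file: bool",
            "path: std::string"})
    .SetKernelFn(PD_KERNEL(TuneCublasltGemm));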

Comment on lines 629 to 652
int m_tmp, k_tmp, n_tmp;
FILE* fp;
fp = fopen(path.c_str(), "r");
if (!fp) {
using_default_config();
} else {
bool match = false;
int find_cnt = 0;
while (1) {
fscanf(fp,
"%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%f",
&m_tmp,
&k_tmp,
&n_tmp,
&algoId,
&swizzle,
&customOption,
&tile,
&splitK_val,
&reductionScheme,
&stages,
&work_space_size,
&time_ref);
if (feof(fp)) break;
Collaborator


Please rewrite this in C++ style; you can refer to the implementation in the PR I sent you earlier. See the sketch below.
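
For instance, the fscanf loop could be replaced by something along these lines, using std::ifstream and std::getline (a sketch only, not the exact implementation referenced in the comment; requires <fstream>, <sstream>, <string>, <vector>):

// Sketch: read the cache file line by line with C++ streams instead of fscanf.
std::ifstream fin(path);
if (!fin.is_open()) {
  using_default_config();
} else {
  std::string line;
  while (std::getline(fin, line)) {
    std::stringstream ss(line);
    std::string field;
    std::vector<std::string> fields;
    while (std::getline(ss, field, ',')) {
      fields.push_back(field);
    }
    if (fields.size() != 12) continue;  // skip malformed lines
    int m_tmp = std::stoi(fields[0]);
    int k_tmp = std::stoi(fields[1]);
    int n_tmp = std::stoi(fields[2]);
    // ... remaining columns: algoId, swizzle, customOption, tile, splitK_val,
    //     reductionScheme, stages, workspaceSize (std::stoi) and time (std::stof)
  }
}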

Comment on lines 251 to 267
cublasLtMatmulAlgoCapGetAttribute(&algo,
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT,
&splitkSupport,
sizeof(splitkSupport),
&sizeWritten);
std::clog << "splitkSupport " << splitkSupport;
cublasLtMatmulAlgoCapGetAttribute(&algo,
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK,
&redMask,
sizeof(redMask),
&sizeWritten);
cublasLtMatmulAlgoCapGetAttribute(&algo,
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT,
&swizzlingMax,
sizeof(swizzlingMax),
&sizeWritten);
cublasLtMatmulAlgoCapGetAttribute(&algo,
Collaborator


Wrap all CUDA API calls in a macro that checks the return value, similar to CUDA_CHECK.

Collaborator


#define CUDA_CHECK(call)                           \
  do {                                             \
    const cudaError_t error_code = call;           \
    if (error_code != cudaSuccess) {               \
      std::printf("at %s:%d - %s.\n",              \
                  __FILE__,                        \
                  __LINE__,                        \
                  cudaGetErrorString(error_code)); \
      exit(1);                                     \
    }                                              \
  } while (0)
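
Since the cublasLt calls return cublasStatus_t rather than cudaError_t, a matching check macro could look like the following (a sketch along the same lines as CUDA_CHECK above, printing the raw status code):

#define CUBLAS_CHECK(call)                                \
  do {                                                    \
    const cublasStatus_t status_code = call;              \
    if (status_code != CUBLAS_STATUS_SUCCESS) {           \
      std::printf("cuBLAS error at %s:%d - status %d.\n", \
                  __FILE__,                               \
                  __LINE__,                               \
                  static_cast<int>(status_code));         \
      exit(1);                                            \
    }                                                     \
  } while (0)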

@Hanyonggong changed the title from "add tune_gemm feature" to "Add tune_cublaslt_gemm operator by cublaslt gemm algorithm and generate algo cache file" on Jul 26, 2024
N_tensor = paddle.to_tensor([4096, 8192])

Dtype = "int8"
Path = "./search.csv"
Contributor


Is this file only an example that is never actually run? There does not seem to be a search.csv file.

Collaborator


This is the output path; the file is generated after the run finishes.

public:
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }

cublasLtHandle_t handle_;
Contributor


Is this meant to be a private member? Shouldn't it be moved down? The private section is currently empty. See the sketch below.
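
For example, the handle could be moved into the private section along the following lines (a sketch of the suggested change, with a destructor and an accessor added for illustration):

class CUBLASLTContext {
 public:
  CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle_)); }
  ~CUBLASLTContext() { cublasLtDestroy(handle_); }

  // Accessor so callers no longer touch the member directly.
  cublasLtHandle_t handle() const { return handle_; }

 private:
  cublasLtHandle_t handle_;
};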

Collaborator


@Hanyonggong please change this~

Contributor Author


Updated.


@DesmonDay (Contributor) left a comment


LGTM

@DesmonDay DesmonDay merged commit 99d9bcb into PaddlePaddle:develop Jul 31, 2024
11 of 12 checks passed
DrownFish19 pushed a commit to DrownFish19/PaddleNLP that referenced this pull request Aug 2, 2024
5 participants