Skip to content

Commit

Permalink
Fix issue when building with hipblasLt on rocm6.1 (#22553)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

The hipblasLt library is released with ROCm 6.x, and onnxruntime's current
code needs some modifications to match the new hipblasLt API.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
kailums authored Oct 28, 2024
1 parent 7ad7873 commit dd28f09
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
5 changes: 0 additions & 5 deletions onnxruntime/core/providers/rocm/rocm_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,4 @@ template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char*
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
#endif

#ifdef USE_HIPBLASLT
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
#endif

} // namespace onnxruntime
24 changes: 12 additions & 12 deletions onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,26 @@ enum ActivationType {
};

template <typename T>
constexpr hipblasltDatatype_t HipBlasDataTypeFor();
constexpr hipDataType HipBlasDataTypeFor();

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLASLT_R_32F;
constexpr hipDataType HipBlasDataTypeFor<float>() {
return HIP_R_32F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
return HIPBLASLT_R_16F;
constexpr hipDataType HipBlasDataTypeFor<half>() {
return HIP_R_16F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLASLT_R_16B;
constexpr hipDataType HipBlasDataTypeFor<BFloat16>() {
return HIP_R_16BF;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLASLT_R_64F;
constexpr hipDataType HipBlasDataTypeFor<double>() {
return HIP_R_64F;
}

template <BlasOp Op>
Expand Down Expand Up @@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp

hipblasOperation_t trans_a = MapBlasOpToHipBlasLt<OpB>();
hipblasOperation_t trans_b = MapBlasOpToHipBlasLt<OpA>();
hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
hipDataType in_out_datatype = HipBlasDataTypeFor<T>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
Expand All @@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLASLT_COMPUTE_F32,
HIPBLAS_COMPUTE_32F,
heuristic_result));
HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle));

Expand Down Expand Up @@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));

int batch = GetBatchCountFromParams<T>(params);
if (batch > 1) {
Expand Down
4 changes: 4 additions & 0 deletions tools/ci_build/amd_hipify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn")
s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut")
s = s.replace("kTotalCudaStreams", "kTotalHipStreams")

# in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want
s = s.replace("rocblas_half", "__half")

# these should be "hip" but it's easier to just use rocm to avoid complicated file renaming
s = s.replace("CudaGraph", "RocmGraph")
s = s.replace("CUDAGraph", "ROCMGraph")
Expand Down

0 comments on commit dd28f09

Please sign in to comment.