Skip to content

Commit

Permalink
【Hackathon 5th No.110】为 Paddle 增强 sparse.matmul API (#59890)
Browse files Browse the repository at this point in the history
  • Loading branch information
MayYouBeProsperous authored Jan 12, 2024
1 parent c526bbb commit dab5512
Show file tree
Hide file tree
Showing 10 changed files with 793 additions and 89 deletions.
47 changes: 27 additions & 20 deletions paddle/fluid/platform/dynload/cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,33 @@ namespace dynload {

#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV);
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV); \
__macro(cusparseSpMatGetSize); \
__macro(cusparseCsrSetPointers); \
__macro(cusparseSpGEMM_createDescr); \
__macro(cusparseSpGEMM_compute); \
__macro(cusparseSpGEMM_workEstimation); \
__macro(cusparseSpGEMM_copy); \
__macro(cusparseSpGEMM_destroyDescr);

CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
Expand Down
47 changes: 27 additions & 20 deletions paddle/phi/backends/dynload/cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,33 @@ extern void *cusparse_dso_handle;

#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV);
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseCreateDnVec); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroyDnVec); \
__macro(cusparseSpMV_bufferSize); \
__macro(cusparseSpMV); \
__macro(cusparseSpMatGetSize); \
__macro(cusparseCsrSetPointers); \
__macro(cusparseSpGEMM_createDescr); \
__macro(cusparseSpGEMM_compute); \
__macro(cusparseSpGEMM_workEstimation); \
__macro(cusparseSpGEMM_copy); \
__macro(cusparseSpGEMM_destroyDescr);

CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
Expand Down
15 changes: 14 additions & 1 deletion paddle/phi/kernels/funcs/sparse/sparse_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
namespace phi {
namespace funcs {
namespace sparse {

template <typename DeviceContext>
class SparseBlas {
public:
Expand Down Expand Up @@ -54,6 +53,15 @@ class SparseBlas {
T beta,
TensorType* mat_out) const;

template <typename T>
void SPGEMM(bool transa,
bool transb,
T alpha,
const SparseCsrTensor& mat_a,
const SparseCsrTensor& mat_b,
T beta,
SparseCsrTensor* mat_out) const;

private:
const DeviceContext& dev_ctx_;
};
Expand All @@ -78,6 +86,11 @@ class SparseBlasT : private SparseBlas<DeviceContext> {
Base()->template SDDMM<T>(args...);
}

template <typename... ARGS>
void SPGEMM(ARGS... args) const {
Base()->template SPGEMM<T>(args...);
}

private:
const SparseBlas<DeviceContext>* Base() const {
return static_cast<const SparseBlas<DeviceContext>*>(this);
Expand Down
Loading

0 comments on commit dab5512

Please sign in to comment.