Skip to content

Commit

Permalink
[Lang] Sort coo to build correct csr format sparse matrix on GPU (#6050)
Browse files Browse the repository at this point in the history
Related issue = #2906 

When a sparse matrix is built in COO format, its indices are not
necessarily in order. To build a valid CSR-format sparse matrix, we
first need to sort the COO index arrays.
  • Loading branch information
FantasyVR authored Sep 17, 2022
1 parent 5c7d0eb commit 33a606a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 5 deletions.
38 changes: 36 additions & 2 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,40 @@ void CuSparseMatrix::build_csr_from_coo(void *coo_row_ptr,
void *coo_values_ptr,
int nnz) {
#if defined(TI_WITH_CUDA)
// Step 1: Sort coo first
cusparseHandle_t cusparse_handle = NULL;
CUSPARSEDriver::get_instance().cpCreate(&cusparse_handle);
cusparseSpVecDescr_t vec_permutation;
cusparseDnVecDescr_t vec_values;
void *d_permutation = NULL, *d_values_sorted = NULL;
CUDADriver::get_instance().malloc(&d_permutation, nnz * sizeof(int));
CUDADriver::get_instance().malloc(&d_values_sorted, nnz * sizeof(float));
CUSPARSEDriver::get_instance().cpCreateSpVec(
&vec_permutation, nnz, nnz, d_permutation, d_values_sorted,
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
CUSPARSEDriver::get_instance().cpCreateDnVec(&vec_values, nnz, coo_values_ptr,
CUDA_R_32F);
size_t bufferSize = 0;
CUSPARSEDriver::get_instance().cpXcoosort_bufferSizeExt(
cusparse_handle, rows_, cols_, nnz, coo_row_ptr, coo_col_ptr,
&bufferSize);
void *dbuffer = NULL;
if (bufferSize > 0)
CUDADriver::get_instance().malloc(&dbuffer, bufferSize);
// Setup permutation vector to identity
CUSPARSEDriver::get_instance().cpCreateIdentityPermutation(
cusparse_handle, nnz, d_permutation);
CUSPARSEDriver::get_instance().cpXcoosortByRow(cusparse_handle, rows_, cols_,
nnz, coo_row_ptr, coo_col_ptr,
d_permutation, dbuffer);
CUSPARSEDriver::get_instance().cpGather(cusparse_handle, vec_values,
vec_permutation);
CUDADriver::get_instance().memcpy_device_to_device(
coo_values_ptr, d_values_sorted, nnz * sizeof(float));
// Step 2: coo to csr
void *csr_row_offset_ptr = NULL;
CUDADriver::get_instance().malloc(&csr_row_offset_ptr,
sizeof(int) * (rows_ + 1));
cusparseHandle_t cusparse_handle;
CUSPARSEDriver::get_instance().cpCreate(&cusparse_handle);
CUSPARSEDriver::get_instance().cpCoo2Csr(
cusparse_handle, (void *)coo_row_ptr, nnz, rows_,
(void *)csr_row_offset_ptr, CUSPARSE_INDEX_BASE_ZERO);
Expand All @@ -216,9 +245,14 @@ void CuSparseMatrix::build_csr_from_coo(void *coo_row_ptr,
&matrix_, rows_, cols_, nnz, csr_row_offset_ptr, coo_col_ptr,
coo_values_ptr, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
CUSPARSEDriver::get_instance().cpDestroySpVec(vec_permutation);
CUSPARSEDriver::get_instance().cpDestroyDnVec(vec_values);
CUSPARSEDriver::get_instance().cpDestroy(cusparse_handle);
// TODO: free csr_row_offset_ptr
// CUDADriver::get_instance().mem_free(csr_row_offset_ptr);
CUDADriver::get_instance().mem_free(d_values_sorted);
CUDADriver::get_instance().mem_free(d_permutation);
CUDADriver::get_instance().mem_free(dbuffer);
#endif
}

Expand Down
2 changes: 2 additions & 0 deletions taichi/rhi/cuda/cuda_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,10 @@ typedef struct cusparseContext *cusparseHandle_t;
struct cusparseMatDescr;
typedef struct cusparseMatDescr *cusparseMatDescr_t;

struct cusparseSpVecDescr;
struct cusparseDnVecDescr;
struct cusparseSpMatDescr;
typedef struct cusparseSpVecDescr *cusparseSpVecDescr_t;
typedef struct cusparseDnVecDescr *cusparseDnVecDescr_t;
typedef struct cusparseSpMatDescr *cusparseSpMatDescr_t;
typedef enum {
Expand Down
6 changes: 6 additions & 0 deletions taichi/rhi/cuda/cusparse_functions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ PER_CUSPARSE_FUNCTION(cpCreateMatDescr, cusparseCreateMatDescr, cusparseMatDescr
PER_CUSPARSE_FUNCTION(cpSetMatType, cusparseSetMatType, cusparseMatDescr_t, cusparseMatrixType_t);
PER_CUSPARSE_FUNCTION(cpSetMatIndexBase, cusparseSetMatIndexBase, cusparseMatDescr_t, cusparseIndexBase_t);
PER_CUSPARSE_FUNCTION(cpDestroySpMat, cusparseDestroySpMat, cusparseSpMatDescr_t);
PER_CUSPARSE_FUNCTION(cpCreateSpVec, cusparseCreateSpVec, cusparseSpVecDescr_t* ,int ,int,void*,void*,cusparseIndexType_t,cusparseIndexBase_t,cudaDataType);
PER_CUSPARSE_FUNCTION(cpDestroySpVec, cusparseDestroySpVec, cusparseSpVecDescr_t);
PER_CUSPARSE_FUNCTION(cpCreateIdentityPermutation, cusparseCreateIdentityPermutation, cusparseHandle_t, int, void*);
PER_CUSPARSE_FUNCTION(cpXcoosort_bufferSizeExt, cusparseXcoosort_bufferSizeExt, cusparseHandle_t,int ,int,int, void* ,void* ,void*);
PER_CUSPARSE_FUNCTION(cpXcoosortByRow, cusparseXcoosortByRow, cusparseHandle_t,int,int,int,void* ,void* ,void* ,void*);
PER_CUSPARSE_FUNCTION(cpGather, cusparseGather, cusparseHandle_t, cusparseDnVecDescr_t, cusparseSpVecDescr_t);

// cusparse dense vector description
PER_CUSPARSE_FUNCTION(cpCreateDnVec, cusparseCreateDnVec, cusparseDnVecDescr_t*, int, void*, cudaDataType);
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def fill(Abuilder: ti.types.sparse_matrix_builder(),

@test_utils.test(arch=ti.cuda)
def test_gpu_sparse_matrix():
h_coo_row = np.asarray([0, 0, 0, 1, 2, 2, 2, 3, 3], dtype=np.int32)
h_coo_col = np.asarray([0, 2, 3, 1, 0, 2, 3, 1, 3], dtype=np.int32)
h_coo_val = np.asarray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
h_coo_row = np.asarray([1, 0, 0, 0, 2, 2, 2, 3, 3], dtype=np.int32)
h_coo_col = np.asarray([1, 0, 2, 3, 0, 2, 3, 1, 3], dtype=np.int32)
h_coo_val = np.asarray([4.0, 1.0, 2.0, 3.0, 5.0, 6.0, 7.0, 8.0, 9.0],
dtype=np.float32)
h_X = np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
h_Y = np.asarray([19.0, 8.0, 51.0, 52.0], dtype=np.float32)
Expand Down

0 comments on commit 33a606a

Please sign in to comment.