Skip to content

Commit

Permalink
Use t_resource.gpu_mem
Browse files Browse the repository at this point in the history
  • Loading branch information
lzhao4ever committed Oct 31, 2016
1 parent 897797d commit 86f6cbe
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions paddle/cuda/src/hl_cuda_cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,45 +175,42 @@ void hl_matrix_inverse(real *A_d, real *C_d, int dimN, int lda, int ldc) {
CHECK_NOTNULL(C_d);

/* Step 1: Compute the LU decomposition of matrix A */
real **inout_h = (real **)hl_malloc_host(sizeof(real *));
inout_h[0] = A_d;

real **inout_h = &A_d;
real **inout_d = (real **)hl_malloc_device(sizeof(real *));
hl_memcpy(inout_d, inout_h, sizeof(real *));
hl_free_mem_host(inout_h);

int *pivot_d = (int *)hl_malloc_device(dimN*sizeof(int));
int *info_h = (int *)hl_malloc_host(sizeof(int));
int *info_d = (int *)hl_malloc_device(sizeof(int));
int *info_d = (int *)t_resource.gpu_mem;

/* Note: cublasSgetrfBatched is designed to factorize a batch of
small matrices, but it is called here with a batch size of 1.
There may be a better way to restructure the API for better performance.
*/
CHECK_CUBLAS(CUBLAS_GETRF(t_resource.handle,
dimN, inout_d, lda, pivot_d,
info_d, 1));

hl_memcpy(info_h, info_d, sizeof(int));
if (info_h[0] != 0) {
int info_h;
hl_memcpy(&info_h, info_d, sizeof(int));
if (info_h != 0) {
LOG(FATAL) << "Factorization of matrix failed: matrix may be singular.\n";
}

/* Step 2: Compute the inverse of the matrix given its LU decomposition */
real **out_h = (real **)hl_malloc_host(sizeof(real *));
out_h[0] = C_d;

real **out_h = &C_d;
real **out_d = (real **)hl_malloc_device(sizeof(real *));
hl_memcpy(out_d, out_h, sizeof(real *));
hl_free_mem_host(out_h);

CHECK_CUBLAS(CUBLAS_GETRI(t_resource.handle,
dimN, (const real **)inout_d, lda, pivot_d,
out_d, ldc, info_d, 1));

hl_memcpy(info_h, info_d, sizeof(int));
if (info_h[0] != 0) {
hl_memcpy(&info_h, info_d, sizeof(int));
if (info_h != 0) {
LOG(FATAL) << "Inversion of matrix failed: matrix may be singular.\n";
}

hl_free_mem_device(inout_d);
hl_free_mem_device(info_d);
hl_free_mem_device(pivot_d);
hl_free_mem_device(out_d);

Expand Down

0 comments on commit 86f6cbe

Please sign in to comment.