Skip to content

Commit

Permalink
Add cublas_handle() to expose cublas_handle to ops (PaddlePaddle#31157)
Browse files Browse the repository at this point in the history
* add get_cublas_handle() api

* update format

* add unittests

* alter function name
  • Loading branch information
FrostML committed Feb 24, 2021
1 parent b0ec6e8 commit 2d305b0
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddle/fluid/platform/cuda_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ class CublasHandleHolder {
#endif // CUDA_VERSION >= 9000
}

const cublasHandle_t& GetCublasHandle() const { return handle_; }

~CublasHandleHolder() PADDLE_MAY_THROW {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_));
}

template <typename Callback>
inline void Call(Callback &&callback) const {
inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}

cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}

CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;

/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;

/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/device_context_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context;
}
}
Expand Down

0 comments on commit 2d305b0

Please sign in to comment.