-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refine device context and fix GetPlace() #3084
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() | |
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); | ||
} | ||
|
||
CPUDeviceContext::CPUDeviceContext() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move implementation into |
||
eigen_device_.reset(new Eigen::DefaultDevice()); | ||
} | ||
|
||
CPUDeviceContext::CPUDeviceContext(CPUPlace place) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep unified cpu and gpu interface |
||
eigen_device_.reset(new Eigen::DefaultDevice()); | ||
} | ||
|
||
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { | ||
return eigen_device_.get(); | ||
} | ||
|
||
Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } | ||
|
||
#ifndef PADDLE_ONLY_CPU | ||
|
||
template <> | ||
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { | ||
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); | ||
} | ||
#endif | ||
|
||
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { | ||
SetDeviceId(place_.device); | ||
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); | ||
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); | ||
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); | ||
} | ||
|
||
CUDADeviceContext::~CUDADeviceContext() { | ||
SetDeviceId(place_.device); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
Wait(); | ||
if (cublas_handle_) { | ||
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); | ||
} | ||
|
||
if (cudnn_handle_) { | ||
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); | ||
} | ||
|
||
if (curand_generator_) { | ||
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); | ||
} | ||
eigen_stream_.reset(); | ||
eigen_device_.reset(); | ||
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); | ||
} | ||
|
||
Place CUDADeviceContext::GetPlace() const { return place_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix bug in here, return |
||
|
||
cudaStream_t CUDADeviceContext::stream() const { return stream_; } | ||
|
||
void CUDADeviceContext::Wait() const { | ||
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); | ||
} | ||
|
||
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { | ||
return eigen_device_.get(); | ||
} | ||
|
||
cublasHandle_t CUDADeviceContext::cublas_handle() { | ||
if (!cublas_handle_) { | ||
SetDeviceId(place_.device); | ||
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); | ||
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); | ||
} | ||
return cublas_handle_; | ||
} | ||
|
||
cudnnHandle_t CUDADeviceContext::cudnn_handle() { | ||
if (!cudnn_handle_) { | ||
SetDeviceId(place_.device); | ||
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); | ||
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); | ||
} | ||
return cudnn_handle_; | ||
} | ||
|
||
curandGenerator_t CUDADeviceContext::curand_generator() { | ||
if (!curand_generator_) { | ||
SetDeviceId(place_.device); | ||
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, | ||
CURAND_RNG_PSEUDO_DEFAULT)); | ||
PADDLE_ENFORCE( | ||
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); | ||
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); | ||
} | ||
return curand_generator_; | ||
} | ||
|
||
#endif // PADDLE_ONLY_CPU | ||
|
||
} // namespace platform | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,134 +39,65 @@ class DeviceContext { | |
|
||
class CPUDeviceContext : public DeviceContext { | ||
public: | ||
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } | ||
CPUDeviceContext(); | ||
CPUDeviceContext(CPUPlace); | ||
virtual ~CPUDeviceContext() {} | ||
|
||
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } | ||
Eigen::DefaultDevice* eigen_device() const; | ||
|
||
Place GetPlace() const override { | ||
Place retv = CPUPlace(); | ||
return retv; | ||
} | ||
Place GetPlace() const override; | ||
|
||
private: | ||
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; | ||
}; | ||
|
||
#ifndef PADDLE_ONLY_CPU | ||
|
||
class GPUPlaceGuard { | ||
class CUDADeviceContext : public DeviceContext { | ||
public: | ||
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { | ||
if (previous_ != new_place) { | ||
paddle::platform::SetDeviceId(new_place.device); | ||
} | ||
} | ||
explicit CUDADeviceContext(GPUPlace); | ||
virtual ~CUDADeviceContext(); | ||
|
||
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } | ||
/*! \brief Wait for all operations completion in the stream. */ | ||
void Wait() const; | ||
|
||
private: | ||
GPUPlace previous_; | ||
}; | ||
/*! \brief Return CUDA stream in the device context. */ | ||
cudaStream_t stream() const; | ||
|
||
class CUDADeviceContext : public DeviceContext { | ||
public: | ||
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { | ||
GPUPlaceGuard guard(gpu_place_); | ||
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); | ||
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); | ||
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); | ||
} | ||
|
||
Place GetPlace() const override { | ||
Place retv = GPUPlace(); | ||
return retv; | ||
} | ||
|
||
void Wait() { | ||
PADDLE_ENFORCE(cudaStreamSynchronize(stream_), | ||
"cudaStreamSynchronize failed"); | ||
} | ||
|
||
cudaStream_t stream() const { return stream_; } | ||
|
||
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } | ||
|
||
cublasHandle_t cublas_handle() { | ||
if (!blas_handle_) { | ||
GPUPlaceGuard guard(gpu_place_); | ||
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_), | ||
"cublasCreate failed"); | ||
PADDLE_ENFORCE( | ||
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_), | ||
"cublasSetStream failed"); | ||
} | ||
return blas_handle_; | ||
} | ||
|
||
cudnnHandle_t cudnn_handle() { | ||
if (!dnn_handle_) { | ||
GPUPlaceGuard guard(gpu_place_); | ||
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_), | ||
"cudnnCreate failed"); | ||
PADDLE_ENFORCE( | ||
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_), | ||
"cudnnSetStream failed"); | ||
} | ||
return dnn_handle_; | ||
} | ||
|
||
curandGenerator_t curand_generator() { | ||
if (!rand_generator_) { | ||
GPUPlaceGuard guard(gpu_place_); | ||
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( | ||
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT), | ||
"curandCreateGenerator failed"); | ||
PADDLE_ENFORCE( | ||
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( | ||
rand_generator_, random_seed_), | ||
"curandSetPseudoRandomGeneratorSeed failed"); | ||
PADDLE_ENFORCE( | ||
paddle::platform::dynload::curandSetStream(rand_generator_, stream_), | ||
"curandSetStream failed"); | ||
} | ||
return rand_generator_; | ||
} | ||
|
||
~CUDADeviceContext() { | ||
Wait(); | ||
if (blas_handle_) { | ||
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_), | ||
"cublasDestroy failed"); | ||
} | ||
|
||
if (dnn_handle_) { | ||
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_), | ||
"cudnnDestroy failed"); | ||
} | ||
|
||
if (rand_generator_) { | ||
PADDLE_ENFORCE( | ||
paddle::platform::dynload::curandDestroyGenerator(rand_generator_), | ||
"curandDestroyGenerator failed"); | ||
} | ||
eigen_stream_.reset(); | ||
eigen_device_.reset(); | ||
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); | ||
} | ||
/*! \brief Return place in the device context. */ | ||
Place GetPlace() const override; | ||
|
||
/*! \brief Return eigen device in the device context. */ | ||
Eigen::GpuDevice* eigen_device() const; | ||
|
||
// clang-format off | ||
/*! \brief Return cublas handle in the device context. */ | ||
cublasHandle_t cublas_handle (); | ||
|
||
/*! \brief Return cudnn handle in the device context. */ | ||
cudnnHandle_t cudnn_handle (); | ||
|
||
/*! \brief Return curand handle in the device context. */ | ||
curandGenerator_t curand_generator(); | ||
// clang-format on | ||
|
||
private: | ||
GPUPlace gpu_place_; | ||
cudaStream_t stream_; | ||
GPUPlace place_; | ||
|
||
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; | ||
private: | ||
std::unique_ptr<Eigen::GpuDevice> eigen_device_; | ||
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; | ||
|
||
cublasHandle_t blas_handle_{nullptr}; | ||
private: | ||
uint64_t seed_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to set seed to |
||
|
||
cudnnHandle_t dnn_handle_{nullptr}; | ||
cudaStream_t stream_; | ||
|
||
int random_seed_; | ||
curandGenerator_t rand_generator_{nullptr}; | ||
// clang-format off | ||
cudnnHandle_t cudnn_handle_ = nullptr; | ||
cublasHandle_t cublas_handle_ = nullptr; | ||
curandGenerator_t curand_generator_ = nullptr; | ||
// clang-format on | ||
}; | ||
|
||
#endif | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个SetDeviceId是不是要写在if外面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不需要了 因为GpuMemcpyPeer函数里面的两个参数 就是相应的device id