Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine device context and fix GetPlace() #3084

Merged
merged 4 commits into from
Jul 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/memory/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_subdirectory(detail)

cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context)
cc_library(memcpy SRCS memcpy.cc)

cc_library(paddle_memory
DEPS
Expand Down
6 changes: 3 additions & 3 deletions paddle/memory/memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::GPUPlaceGuard g(src_place.device);
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

Expand All @@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::GPUPlaceGuard g(dst_place.device);
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

Expand All @@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const void* src, size_t num,
cudaStream_t stream) {
if (dst_place == src_place) {
platform::GPUPlaceGuard g(src_place.device);
platform::SetDeviceId(src_place.device);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个SetDeviceId是不是要写在if外面

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不需要了 因为GpuMemcpyPeer函数里面的两个参数 就是相应的device id

platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
Expand Down
86 changes: 85 additions & 1 deletion paddle/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
}

CPUDeviceContext::CPUDeviceContext() {
Copy link
Contributor Author

@gangliao gangliao Jul 27, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move implementation into *.cc to make the definition clearer.

eigen_device_.reset(new Eigen::DefaultDevice());
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep unified cpu and gpu interface

eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}

Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }

#ifndef PADDLE_ONLY_CPU

template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
#endif

CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add SetDeviceId(place_.device); in here

Wait();
if (cublas_handle_) {
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
}

if (cudnn_handle_) {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}

if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}

Place CUDADeviceContext::GetPlace() const { return place_; }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix bug in here, return place_


cudaStream_t CUDADeviceContext::stream() const { return stream_; }

void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
}

Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}

cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
return cublas_handle_;
}

cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}

curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}

#endif // PADDLE_ONLY_CPU

} // namespace platform
} // namespace paddle
147 changes: 39 additions & 108 deletions paddle/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,134 +39,65 @@ class DeviceContext {

class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
CPUDeviceContext();
CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}

Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
Eigen::DefaultDevice* eigen_device() const;

Place GetPlace() const override {
Place retv = CPUPlace();
return retv;
}
Place GetPlace() const override;

private:
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

#ifndef PADDLE_ONLY_CPU

class GPUPlaceGuard {
class CUDADeviceContext : public DeviceContext {
public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
if (previous_ != new_place) {
paddle::platform::SetDeviceId(new_place.device);
}
}
explicit CUDADeviceContext(GPUPlace);
virtual ~CUDADeviceContext();

~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
/*! \brief Wait for all operations completion in the stream. */
void Wait() const;

private:
GPUPlace previous_;
};
/*! \brief Return CUDA stream in the device context. */
cudaStream_t stream() const;

class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}

Place GetPlace() const override {
Place retv = GPUPlace();
return retv;
}

void Wait() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}

cudaStream_t stream() const { return stream_; }

Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }

cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}

cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}

curandGenerator_t curand_generator() {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}

~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}

if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}

if (rand_generator_) {
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}
/*! \brief Return place in the device context. */
Place GetPlace() const override;

/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;

// clang-format off
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle ();

/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle ();

/*! \brief Return curand handle in the device context. */
curandGenerator_t curand_generator();
// clang-format on

private:
GPUPlace gpu_place_;
cudaStream_t stream_;
GPUPlace place_;

std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;

cublasHandle_t blas_handle_{nullptr};
private:
uint64_t seed_;
Copy link
Contributor Author

@gangliao gangliao Jul 27, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to set seed to uint64_t and I will also add rand method for CPU in next PR


cudnnHandle_t dnn_handle_{nullptr};
cudaStream_t stream_;

int random_seed_;
curandGenerator_t rand_generator_{nullptr};
// clang-format off
cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// clang-format on
};

#endif
Expand Down
10 changes: 5 additions & 5 deletions paddle/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ struct EnforceNotMet : public std::exception {
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}

template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) {
Expand Down Expand Up @@ -132,6 +127,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(

#endif // PADDLE_ONLY_CPU

template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}

#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
Expand Down