-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CPU/GPU Memcpy in memory folder #2970
Changes from all commits
028f3dc
e53a48b
b058864
527c859
ca89bfa
00500ee
0897d18
b3115fb
6cae35b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,20 +14,30 @@ limitations under the License. */ | |
|
||
#pragma once | ||
|
||
#include "paddle/platform/gpu_info.h" | ||
#include "paddle/platform/place.h" | ||
|
||
namespace paddle { | ||
namespace memory { | ||
|
||
template <class Place> | ||
template <typename Place> | ||
void* Alloc(Place, size_t); | ||
|
||
template <class Place> | ||
template <typename Place> | ||
void Free(Place, void*); | ||
|
||
template <class Place> | ||
template <typename Place> | ||
size_t Used(Place); | ||
|
||
template <typename DstPlace, typename SrcPlace> | ||
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); | ||
|
||
#ifndef PADDLE_ONLY_CPU | ||
template <typename DstPlace, typename SrcPlace> | ||
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, | ||
cudaStream_t stream); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be great to add a comment telling when would users call this second form of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for replying too late due to we are on duty yesterday. Yeah, I will annotate this function. Thanks. |
||
#endif // PADDLE_ONLY_CPU | ||
|
||
template <typename T, /* must be POD types */ | ||
typename Place /* platform::GPUPlace or platform::CPUPlace */, | ||
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,10 +43,26 @@ namespace platform { | |
// For more details, please check https://stackoverflow.com/a/43870188/724872. | ||
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) | ||
|
||
template <typename T> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix special case of PADDLE_ENFORCE PADDLE_ENFORCE(condition, "hello world"); // OK, if using old implementation
PADDLE_ENFORCE(condition) // Failed, if using old implementation. But, it's addressed. |
||
inline void throw_on_error(T e) { | ||
throw_on_error(e, ""); | ||
} | ||
|
||
template <typename... Args> | ||
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( | ||
int stat, const Args&... args) { | ||
if (UNLIKELY(!(stat))) { | ||
throw std::runtime_error( | ||
string::Sprintf(args...) + | ||
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); | ||
} | ||
} | ||
|
||
#ifndef PADDLE_ONLY_CPU | ||
|
||
template <typename... Args> | ||
inline void throw_on_error(cudaError_t e, const Args&... args) { | ||
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! |
||
cudaError_t e, const Args&... args) { | ||
if (UNLIKELY(e)) { | ||
// clang-format off | ||
throw thrust::system_error( | ||
|
@@ -58,7 +74,8 @@ inline void throw_on_error(cudaError_t e, const Args&... args) { | |
} | ||
|
||
template <typename... Args> | ||
inline void throw_on_error(curandStatus_t stat, const Args&... args) { | ||
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( | ||
curandStatus_t stat, const Args&... args) { | ||
if (stat != CURAND_STATUS_SUCCESS) { | ||
// clang-format off | ||
throw thrust::system_error( | ||
|
@@ -70,7 +87,8 @@ inline void throw_on_error(curandStatus_t stat, const Args&... args) { | |
} | ||
|
||
template <typename... Args> | ||
inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { | ||
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( | ||
cudnnStatus_t stat, const Args&... args) { | ||
if (stat == CUDNN_STATUS_SUCCESS) { | ||
return; | ||
} else { | ||
|
@@ -84,7 +102,8 @@ inline void throw_on_error(cudnnStatus_t stat, const Args&... args) { | |
} | ||
|
||
template <typename... Args> | ||
inline void throw_on_error(cublasStatus_t stat, const Args&... args) { | ||
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( | ||
cublasStatus_t stat, const Args&... args) { | ||
std::string err; | ||
if (stat == CUBLAS_STATUS_SUCCESS) { | ||
return; | ||
|
@@ -113,28 +132,16 @@ inline void throw_on_error(cublasStatus_t stat, const Args&... args) { | |
|
||
#endif // PADDLE_ONLY_CPU | ||
|
||
template <typename... Args> | ||
inline void throw_on_error(int stat, const Args&... args) { | ||
if (UNLIKELY(!(stat))) { | ||
throw std::runtime_error( | ||
string::Sprintf(args...) + | ||
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); | ||
} | ||
} | ||
|
||
#define PADDLE_THROW(...) \ | ||
do { \ | ||
throw std::runtime_error( \ | ||
string::Sprintf(__VA_ARGS__) + \ | ||
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ | ||
} while (0) | ||
|
||
/** | ||
* @brief Enforce a condition, otherwise throw an EnforceNotMet | ||
*/ | ||
#define PADDLE_ENFORCE(condition, ...) \ | ||
do { \ | ||
::paddle::platform::throw_on_error(condition, __VA_ARGS__); \ | ||
#define PADDLE_ENFORCE(...) \ | ||
do { \ | ||
::paddle::platform::throw_on_error(__VA_ARGS__); \ | ||
} while (0) | ||
|
||
} // namespace platform | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ limitations under the License. */ | |
|
||
#ifndef PADDLE_ONLY_CPU | ||
|
||
#include <cuda_runtime.h> | ||
#include <stddef.h> | ||
|
||
namespace paddle { | ||
|
@@ -31,7 +32,7 @@ int GetCurrentDeviceId(); | |
void SetDeviceId(int device_id); | ||
|
||
//!Get the memory usage of current GPU device. | ||
void GpuMemoryUsage(size_t& available, size_t& total); | ||
void GpuMemoryUsage(size_t &available, size_t &total); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we should unify the code style
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't change this. Clang-format takes this job |
||
|
||
//! Get the maximum allocation size of current GPU device. | ||
size_t GpuMaxAllocSize(); | ||
|
@@ -42,6 +43,18 @@ size_t GpuMinChunkSize(); | |
//! Get the maximum chunk size for GPU buddy allocator. | ||
size_t GpuMaxChunkSize(); | ||
|
||
//! Copy memory from address src to dst asynchronously. | ||
void GpuMemcpyAsync(void *dst, const void *src, size_t count, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we move these copying functions into a new source file, say, |
||
enum cudaMemcpyKind kind, cudaStream_t stream); | ||
|
||
//! Copy memory from address src to dst synchronously. | ||
void GpuMemcpySync(void *dst, const void *src, size_t count, | ||
enum cudaMemcpyKind kind); | ||
|
||
//! Copy memory from one device to another device. | ||
void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, | ||
size_t count, cudaStream_t stream); | ||
|
||
} // namespace platform | ||
} // namespace paddle | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should use platform::GPUPlaceGuard here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's unnecessary to use the guard to implicitly roll back the device id.
For GPU device, it's better to explicitly set device id.