Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CPU/GPU Memcpy in memory folder #2970

Merged
merged 9 commits into from
Jul 21, 2017
Merged

Conversation

gangliao
Copy link
Contributor

@gangliao gangliao commented Jul 19, 2017

No description provided.

Copy link
Contributor Author

@gangliao gangliao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments

@@ -43,10 +43,26 @@ namespace platform {
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

template <typename T>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix special case of PADDLE_ENFORCE

PADDLE_ENFORCE(condition, "hello world"); // OK, if using old implementation
PADDLE_ENFORCE(condition) // Failed, if using old implementation. But, it's addressed.

@gangliao gangliao changed the title CPU/GPU Memcpy in memory folder Add CPU/GPU Memcpy in memory folder Jul 20, 2017
@gangliao gangliao requested a review from wangkuiyi July 20, 2017 00:39
const void* src, size_t num,
cudaStream_t stream) {
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, use cudaMemcpyDeviceToHost and cudaMemcpyHostToDevice, cudaMemcpyDeviceToDevice. But the cudaMemcpyKind of cudaMemcpyDefault:

cudaMemcpyDefault = 4
Direction of the transfer is inferred from the pointer values. Requires unified virtual addressing

Can we use cudaMemcpyDefault to simplify code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. But, how to specialize one function to support both cases: (CPUPlace, GPUPlace),
(GPUPlace, CPUPlace).

One way may achieve that is std::enable_if, but it will dump too many annoying code.

@@ -31,7 +32,7 @@ int GetCurrentDeviceId();
void SetDeviceId(int device_id);

//!Get the memory usage of current GPU device.
void GpuMemoryUsage(size_t& available, size_t& total);
void GpuMemoryUsage(size_t &available, size_t &total);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we should unify the code style

size_t& available, size_t& total;

& and * should close to type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't change this. Clang-format takes this job

platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should use platform::GPUPlaceGuard here

Copy link
Contributor Author

@gangliao gangliao Jul 22, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's unnecessary to use the guard to implicitly roll back the device id.
For GPU device, it's better to explicitly set device id.

Copy link
Collaborator

@wangkuiyi wangkuiyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

#ifndef PADDLE_ONLY_CPU
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to add a comment telling when would users call this second form of Copy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for replying too late due to we are on duty yesterday. Yeah, I will annotate this function. Thanks.

#ifndef PADDLE_ONLY_CPU

template <typename... Args>
inline void throw_on_error(cudaError_t e, const Args&... args) {
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

@@ -42,6 +43,18 @@ size_t GpuMinChunkSize();
//! Get the maximum chunk size for GPU buddy allocator.
size_t GpuMaxChunkSize();

//! Copy memory from address src to dst asynchronously.
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move these copying functions into a new source file, say, copy.{h,cc}? I am not sure. Just mention it.

@wangkuiyi wangkuiyi merged commit e1140f2 into PaddlePaddle:develop Jul 21, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants