From 39b86590d4ed10fc85a9740d66bff15d10c4e3df Mon Sep 17 00:00:00 2001 From: root Date: Tue, 21 Jun 2022 19:34:07 +0800 Subject: [PATCH 1/6] Optimizing the zero key problem in the push phase --- .../framework/fleet/heter_ps/heter_comm_inl.h | 21 +++++++------- .../fleet/heter_ps/heter_comm_kernel.cu | 28 ++++++++++++++----- .../fleet/heter_ps/heter_comm_kernel.h | 4 +-- paddle/fluid/platform/flags.cc | 3 ++ 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 6364a3f7bf4af..f211e15b13e28 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -22,6 +22,7 @@ limitations under the License. */ #endif DECLARE_double(gpugraph_hbm_table_load_factor); +DECLARE_bool(gpugraph_enable_gpu_direct_access); namespace paddle { namespace framework { @@ -682,7 +683,7 @@ void HeterComm::dynamic_merge_grad( uniq_len, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); heter_comm_kernel_->merge_gradient( - d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, + d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, @@ -802,7 +803,7 @@ void HeterComm::pull_sparse(int num, memory_copy(dst_place, h_right, src_place, d_right_ptr, total_device * sizeof(int), stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { @@ -818,12 +819,12 @@ void HeterComm::pull_sparse(int num, continue; } auto& node = path_[num][i].nodes_.back(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { sync_stream(node.in_stream); } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->RDLock(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { ptr_tables_[i]->get(reinterpret_cast(node.key_storage), node.val_storage, h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num)); @@ -842,7 +843,7 @@ void HeterComm::pull_sparse(int num, } ptr_tables_[i]->rwlock_->UNLock(); } - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { walk_to_src(num, total_device, h_left, h_right, reinterpret_cast(d_shard_vals_ptr), val_type_size); for (int i = 0; i < total_device; ++i) { @@ -855,7 +856,7 @@ void HeterComm::pull_sparse(int num, val_type_size, stream); sync_stream(stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; @@ -946,7 +947,7 @@ void HeterComm::push_sparse(int dev_num, memory_copy(dst_place, h_right, src_place, d_right_ptr, total_device * sizeof(int), stream); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { @@ -965,13 +966,13 @@ void HeterComm::push_sparse(int dev_num, continue; } auto& node = path_[dev_num][i].nodes_.back(); - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { sync_stream(node.in_stream); } AnyDeviceGuard guard(resource_->dev_id(i)); ptr_tables_[i]->rwlock_->WRLock(); - 
if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { ptr_tables_[i]->update(reinterpret_cast(node.key_storage), node.val_storage, h_right[i] - h_left[i] + 1, sgd, resource_->remote_stream(i, dev_num)); @@ -995,7 +996,7 @@ void HeterComm::push_sparse(int dev_num, } } - if (!direct_access_) { + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 39a1469d9073c..38566da3990cc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -146,7 +146,9 @@ __global__ void dy_mf_fill_shard_grads_kernel( } } -__global__ void merge_gradients_kernel(const uint32_t* offset, +template +__global__ void merge_gradients_kernel(const KeyType* d_keys, + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, @@ -163,10 +165,13 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, float* in = (float*)(input + size_t(ori_index) * grad_value_size); merger_.update_one(out, in, feature_value_accessor); - for (int j = 1; j < num; ++j) { - ori_index = index[start + j]; - in = (float*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, in, feature_value_accessor); + KeyType key = d_keys[i]; + if (key != 0) { + for (int j = 1; j < num; ++j) { + ori_index = index[start + j]; + in = (float*)(input + size_t(ori_index) * grad_value_size); + merger_.merge_one(out, in, feature_value_accessor); + } } } } @@ -316,13 +321,15 @@ void HeterCommKernel::dy_mf_fill_shard_grads( grad_value_size, feature_value_accessor_); } -template +template void HeterCommKernel::merge_gradient( + const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; merge_gradients_kernel<<>>( + d_keys, offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_); } @@ -407,7 +414,14 @@ template void HeterCommKernel::dy_mf_fill_shard_grads< float* d_shard_grads, float* d_grads, int* idx, long long len, size_t grad_value_size, const cudaStream_t& stream); -template void HeterCommKernel::merge_gradient( +template void HeterCommKernel::merge_gradient( + const uint32_t* d_keys, + const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, + const char* input, char* output, int n, size_t grad_value_size, + DynamicGradMerger& merger_, const cudaStream_t& stream); + +template void HeterCommKernel::merge_gradient( + const uint64_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const cudaStream_t& stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 4bf77aaac202d..d02031f9e7e28 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -135,8 +135,8 @@ class HeterCommKernel { T* idx, long long len, size_t grad_value_size, const StreamType& stream); - template - void merge_gradient(const uint32_t* offset, const uint32_t* fea_num, 
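// Note on the change above: key 0 is treated as the padding/invalid key, so a
// (potentially very long) run of zero keys is no longer reduced element by
// element; the first gradient is kept and the rest of the run is skipped.
// A minimal standalone sketch of that skip, assuming a simplified dense
// gradient layout of `dim` floats per entry (names here are illustrative and
// not part of this patch):
//
//   template <typename KeyType>
//   __global__ void merge_skip_zero(const KeyType* keys, const uint32_t* offset,
//                                   const uint32_t* fea_num, const uint32_t* index,
//                                   const float* in, float* out, int n, int dim) {
//     const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i >= n) return;
//     const uint32_t start = offset[i], num = fea_num[i];
//     for (int d = 0; d < dim; ++d)        // always seed with the first gradient
//       out[i * dim + d] = in[size_t(index[start]) * dim + d];
//     if (keys[i] == 0) return;            // zero key: skip merging the run
//     for (uint32_t j = 1; j < num; ++j)   // normal keys: accumulate the whole run
//       for (int d = 0; d < dim; ++d)
//         out[i * dim + d] += in[size_t(index[start + j]) * dim + d];
//   }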
+ template + void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, DynamicGradMerger& merger_, const StreamType& stream); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 71daee503e49d..b165a678f8f93 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -857,6 +857,9 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_double( gpugraph_hbm_table_load_factor, 0.75, "the load factor of hbm table, default 0.75"); +PADDLE_DEFINE_EXPORTED_bool( + gpugraph_enable_gpu_direct_access, false, + "enable hash collisions stat for hbm table, default false"); /** * ProcessGroupNCCL related FLAG From bc32d9dd0355f3464c01d60b9189f389c54b1295 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Jun 2022 19:14:43 +0800 Subject: [PATCH 2/6] Optimize CUDA thread parallelism in MergeGrad phase --- .../framework/fleet/heter_ps/feature_value.h | 7 +++ .../framework/fleet/heter_ps/heter_comm.h | 1 - .../framework/fleet/heter_ps/heter_comm_inl.h | 3 +- .../fleet/heter_ps/heter_comm_kernel.cu | 58 +++++++++++++++---- .../fleet/heter_ps/heter_comm_kernel.h | 35 ++++++++++- 5 files changed, 90 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 039c4563cccf3..470c2065ae3b9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -196,6 +196,13 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); _accessor_info.mf_size = (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); + + printf("dim:%lu size:%lu update_dim:%lu update_size:%lu mf_size:%lu\n", + _accessor_info.dim, + _accessor_info.size, + _accessor_info.update_dim, + _accessor_info.update_size, + _accessor_info.mf_size); } __host__ __device__ std::string ParseToString(const float* v, int param_size) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index cef917c98f854..45519d37165d2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -247,7 +247,6 @@ class HeterComm { std::vector> path_; float load_factor_{0.75}; int block_size_{256}; - int direct_access_ = 1; std::unique_ptr heter_comm_kernel_; private: diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 9531d74adbaa4..4818996245214 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -628,6 +628,7 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; + size_t grad_dim = feature_value_accessor_.GetAccessorInfo().update_dim; size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); @@ -686,7 +687,7 @@ void HeterComm::dynamic_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); heter_comm_kernel_->merge_gradient( d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, - (char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream); + (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, 
stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, grad_value_size * uniq_len, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 38566da3990cc..512e54f899d94 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -147,13 +147,13 @@ __global__ void dy_mf_fill_shard_grads_kernel( } template -__global__ void merge_gradients_kernel(const KeyType* d_keys, +__global__ void merge_gradients_basic_kernel(const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, int n, size_t grad_value_size, - DynamicGradMerger& merger_, + DynamicGradMerger& merger, CommonFeatureValueAccessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; @@ -164,18 +164,48 @@ __global__ void merge_gradients_kernel(const KeyType* d_keys, float* out = (float*)(output + i * grad_value_size); float* in = (float*)(input + size_t(ori_index) * grad_value_size); - merger_.update_one(out, in, feature_value_accessor); + merger.update_basic(out, in, feature_value_accessor); KeyType key = d_keys[i]; if (key != 0) { for (int j = 1; j < num; ++j) { ori_index = index[start + j]; in = (float*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, in, feature_value_accessor); + merger.merge_basic(out, in, feature_value_accessor); } } } } +template +__global__ void merge_gradients_embedx_kernel(const KeyType* d_keys, + const uint32_t* offset, + const uint32_t* fea_num, + const uint32_t* index, const char* input, + char* output, int n, + size_t grad_dim, + size_t grad_value_size, + DynamicGradMerger& merger, + CommonFeatureValueAccessor& feature_value_accessor) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < n) { + size_t value_idx = i / grad_dim; + size_t field_idx = i % grad_dim; + uint32_t start = offset[value_idx]; + uint32_t num = fea_num[value_idx]; + float* out = (float*)(output + value_idx * grad_value_size); + KeyType key = d_keys[value_idx]; + if (key != 0) { + for (int j = 0; j < num; ++j) { + int ori_index = index[start + j]; + float* in = (float*)(input + size_t(ori_index) * grad_value_size); + merger.merge_embedx(out, in, field_idx, feature_value_accessor); + } + } + //printf("merge kernel, i=%lu num=%u key=%lu\n", value_idx, num, key); + } +} + template __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, T* idx, size_t len, size_t val_size, @@ -325,12 +355,18 @@ template void HeterCommKernel::merge_gradient( const KeyType* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_value_size, - DynamicGradMerger& merger_, const StreamType& stream) { - int grid_size = (n - 1) / block_size_ + 1; - merge_gradients_kernel<<>>( + const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, + DynamicGradMerger& merger, const StreamType& stream) { + int grid_size1 = (n - 1) / block_size_ + 1; + merge_gradients_basic_kernel<<>>( d_keys, - offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_); + offset, fea_num, index, input, output, n, grad_value_size, merger, feature_value_accessor_); + if (grad_dim > 0) { + int grid_size2 = (n * grad_dim - 1) / block_size_ + 1; + 
merge_gradients_embedx_kernel<<>>( + d_keys, + offset, fea_num, index, input, output, n, grad_dim, grad_value_size, merger, feature_value_accessor_); + } } template @@ -417,13 +453,13 @@ template void HeterCommKernel::dy_mf_fill_shard_grads< template void HeterCommKernel::merge_gradient( const uint32_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_value_size, + const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, const cudaStream_t& stream); template void HeterCommKernel::merge_gradient( const uint64_t* d_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, - const char* input, char* output, int n, size_t grad_value_size, + const char* input, char* output, int n, size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger_, const cudaStream_t& stream); template void HeterCommKernel::dy_mf_fill_dvals( diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index d02031f9e7e28..b49964839c824 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -72,6 +72,39 @@ struct DynamicGradMerger { input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; } } + + __device__ __forceinline__ void update_basic(float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.SlotIndex()] = + input[feature_value_accessor.common_push_value.SlotIndex()]; + output[feature_value_accessor.common_push_value.ShowIndex()] = + input[feature_value_accessor.common_push_value.ShowIndex()]; + output[feature_value_accessor.common_push_value.ClickIndex()] = + input[feature_value_accessor.common_push_value.ClickIndex()]; + output[feature_value_accessor.common_push_value.MfDimIndex()] = + input[feature_value_accessor.common_push_value.MfDimIndex()]; + output[feature_value_accessor.common_push_value.EmbedGIndex()] = + input[feature_value_accessor.common_push_value.EmbedGIndex()]; + } + + __device__ __forceinline__ void merge_basic(float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.ShowIndex()] += + input[feature_value_accessor.common_push_value.ShowIndex()]; + output[feature_value_accessor.common_push_value.ClickIndex()] += + input[feature_value_accessor.common_push_value.ClickIndex()]; + output[feature_value_accessor.common_push_value.EmbedGIndex()] += + input[feature_value_accessor.common_push_value.EmbedGIndex()]; + } + + + __device__ __forceinline__ void merge_embedx(float* output, const float* input, size_t embedx_idx, + CommonFeatureValueAccessor& feature_value_accessor) { + if (embedx_idx < output[feature_value_accessor.common_push_value.MfDimIndex()]) { + output[feature_value_accessor.common_push_value.EmbedxGIndex() + embedx_idx] = + input[feature_value_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; + } + } }; class HeterCommKernel { @@ -138,7 +171,7 @@ class HeterCommKernel { template void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, const char* input, char* output, - int n, size_t grad_value_size, DynamicGradMerger& merger_, + int n, size_t grad_dim, size_t grad_value_size, DynamicGradMerger& merger, const StreamType& stream); 
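// Note on the two-kernel split above: merge_gradients_basic_kernel uses one
// thread per unique key (n work items) for the scalar slots (slot, show,
// click, mf_dim, embed_g), while merge_gradients_embedx_kernel uses one
// thread per (key, embedx dimension) pair (n * grad_dim work items), so wide
// embedding gradients no longer serialize inside a single thread. The index
// math of the embedx launch, restated for clarity (illustrative, mirrors the
// kernel in this patch):
//
//   size_t i = blockIdx.x * blockDim.x + threadIdx.x;  // in [0, n * grad_dim)
//   size_t value_idx = i / grad_dim;  // which unique key
//   size_t field_idx = i % grad_dim;  // which embedx dimension of that key
//   // e.g. grad_dim = 8: thread i = 21 merges dimension 5 of key 2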
template From 960469b2cab477ce25411e3039b532539267906c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Jun 2022 20:53:04 +0800 Subject: [PATCH 3/6] Optimize CUDA thread parallelism in MergeGrad phase --- .../framework/fleet/heter_ps/feature_value.h | 3 ++ .../framework/fleet/heter_ps/heter_comm_inl.h | 2 +- .../fleet/heter_ps/heter_comm_kernel.cu | 7 ++- .../fleet/heter_ps/heter_comm_kernel.h | 51 +++++++++++-------- 4 files changed, 38 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 470c2065ae3b9..b488b10e815f9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -32,6 +32,8 @@ struct GpuAccessorInfo { size_t dim; // value各个维度的size size_t size; + // embedx维度 + size_t embedx_dim; // push value维度 size_t update_dim; // push value各个维度的size @@ -192,6 +194,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { ? 8 : int(_config["embedx_dim"]); // VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim; + _accessor_info.embedx_dim = embedx_dim; _accessor_info.update_dim = 5 + embedx_dim; _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); _accessor_info.mf_size = diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 4818996245214..e4b5ca8323ecc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -628,7 +628,7 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; - size_t grad_dim = feature_value_accessor_.GetAccessorInfo().update_dim; + size_t grad_dim = feature_value_accessor_.GetAccessorInfo().embedx_dim; size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 512e54f899d94..82b290559a346 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -173,6 +173,7 @@ __global__ void merge_gradients_basic_kernel(const KeyType* d_keys, merger.merge_basic(out, in, feature_value_accessor); } } + printf("merge kernel, i=%lu num=%u key=%lu\n", i, num, key); } } @@ -193,16 +194,18 @@ __global__ void merge_gradients_embedx_kernel(const KeyType* d_keys, size_t field_idx = i % grad_dim; uint32_t start = offset[value_idx]; uint32_t num = fea_num[value_idx]; + int ori_index = index[start]; + float* in = (float*)(input + size_t(ori_index) * grad_value_size); float* out = (float*)(output + value_idx * grad_value_size); + merger.update_embedx(out, in, field_idx, feature_value_accessor); KeyType key = d_keys[value_idx]; if (key != 0) { - for (int j = 0; j < num; ++j) { + for (int j = 1; j < num; ++j) { int ori_index = index[start + j]; float* in = (float*)(input + size_t(ori_index) * grad_value_size); merger.merge_embedx(out, in, field_idx, feature_value_accessor); } } - //printf("merge kernel, i=%lu num=%u key=%lu\n", value_idx, num, key); } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index b49964839c824..930dafc944371 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ 
b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -74,35 +74,42 @@ struct DynamicGradMerger { } __device__ __forceinline__ void update_basic(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.SlotIndex()] = - input[feature_value_accessor.common_push_value.SlotIndex()]; - output[feature_value_accessor.common_push_value.ShowIndex()] = - input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] = - input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.MfDimIndex()] = - input[feature_value_accessor.common_push_value.MfDimIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] = - input[feature_value_accessor.common_push_value.EmbedGIndex()]; + CommonFeatureValueAccessor& fv_accessor) { + output[fv_accessor.common_push_value.SlotIndex()] = + input[fv_accessor.common_push_value.SlotIndex()]; + output[fv_accessor.common_push_value.ShowIndex()] = + input[fv_accessor.common_push_value.ShowIndex()]; + output[fv_accessor.common_push_value.ClickIndex()] = + input[fv_accessor.common_push_value.ClickIndex()]; + output[fv_accessor.common_push_value.MfDimIndex()] = + input[fv_accessor.common_push_value.MfDimIndex()]; + output[fv_accessor.common_push_value.EmbedGIndex()] = + input[fv_accessor.common_push_value.EmbedGIndex()]; } __device__ __forceinline__ void merge_basic(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.ShowIndex()] += - input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] += - input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] += - input[feature_value_accessor.common_push_value.EmbedGIndex()]; + CommonFeatureValueAccessor& fv_accessor) { + output[fv_accessor.common_push_value.ShowIndex()] += + input[fv_accessor.common_push_value.ShowIndex()]; + output[fv_accessor.common_push_value.ClickIndex()] += + input[fv_accessor.common_push_value.ClickIndex()]; + output[fv_accessor.common_push_value.EmbedGIndex()] += + input[fv_accessor.common_push_value.EmbedGIndex()]; } + __device__ __forceinline__ void update_embedx(float* output, const float* input, size_t embedx_idx, + CommonFeatureValueAccessor& fv_accessor) { + if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) { + output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] = + input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; + } + } __device__ __forceinline__ void merge_embedx(float* output, const float* input, size_t embedx_idx, - CommonFeatureValueAccessor& feature_value_accessor) { - if (embedx_idx < output[feature_value_accessor.common_push_value.MfDimIndex()]) { - output[feature_value_accessor.common_push_value.EmbedxGIndex() + embedx_idx] = - input[feature_value_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; + CommonFeatureValueAccessor& fv_accessor) { + if (embedx_idx < output[fv_accessor.common_push_value.MfDimIndex()]) { + output[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx] += + input[fv_accessor.common_push_value.EmbedxGIndex() + embedx_idx]; } } }; From b19d610471be1971c01d4420c4f2c2b21bcfd75f Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Jun 2022 22:30:18 +0800 Subject: [PATCH 4/6] 
Performance optimization, segment gradient merging --- .../framework/fleet/heter_ps/heter_comm.h | 18 +- .../framework/fleet/heter_ps/heter_comm_inl.h | 183 ++++++++++++++++-- .../fleet/heter_ps/heter_comm_kernel.cu | 99 ++++++++++ .../fleet/heter_ps/heter_comm_kernel.h | 16 ++ paddle/fluid/framework/hogwild_worker.cc | 6 +- paddle/fluid/platform/flags.cc | 6 + 6 files changed, 297 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 45519d37165d2..b2c7a8eca7c0e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/platform/timer.h" #include "thrust/pair.h" #elif defined(PADDLE_WITH_XPU_KP) -// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h" #include #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #endif @@ -55,16 +54,20 @@ class HeterComm { HeterComm& operator=(const HeterComm&) = delete; void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len, - int* left, int* right, int gpu_num); + int* left, int* right, int gpu_num); void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, - int& uniq_len); // NOLINT + int& uniq_len); // NOLINT void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, - size_t len, int& uniq_len); + size_t len, int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad); + void segment_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, + const uint32_t* d_index, size_t len, + const uint32_t* d_fea_num_info, const uint32_t* d_offset, + size_t uniq_len, size_t& segment_len); void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len, - size_t chunk_size, int stream_num, int offset = -1); + size_t chunk_size, int stream_num, int offset = -1); void build_ps(int num, KeyType* h_keys, char* pool, size_t len, - size_t feature_value_size, size_t chunk_size, int stream_num); + size_t feature_value_size, size_t chunk_size, int stream_num); void dump(); void show_one_table(int gpu_num); void show_table_collisions(); @@ -237,7 +240,6 @@ class HeterComm { char* src_val, size_t val_size); - CommonFeatureValueAccessor feature_value_accessor_; protected: using Table = HashTable; using PtrTable = HashTable; @@ -249,6 +251,8 @@ class HeterComm { int block_size_{256}; std::unique_ptr heter_comm_kernel_; + CommonFeatureValueAccessor feature_value_accessor_; + private: int topo_aware_{0}; std::vector storage_; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 21b85acef9e14..fe08f966e2206 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -24,6 +24,8 @@ limitations under the License. 
*/ DECLARE_double(gpugraph_hbm_table_load_factor); DECLARE_bool(gpugraph_enable_gpu_direct_access); +DECLARE_bool(gpugraph_enable_segment_merge_grads); +DECLARE_uint64(gpugraph_merge_grads_segment_size); namespace paddle { namespace framework { @@ -621,7 +623,7 @@ void HeterComm::merge_grad( template void HeterComm::dynamic_merge_grad( int gpu_num, KeyType* d_keys, float* d_grads, size_t len, - int& uniq_len) { + int& uniq_len, size_t& segment_len, bool enable_segment_merge_grad) { int dev_id = resource_->dev_id(gpu_num); platform::CUDAPlace place = platform::CUDAPlace(dev_id); platform::CUDADeviceGuard guard(dev_id); @@ -635,17 +637,12 @@ void HeterComm::dynamic_merge_grad( auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); - auto d_merge_grads = memory::Alloc(place, len * grad_value_size); - float* d_merge_grads_ptr = - reinterpret_cast(d_merge_grads->ptr()); - auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len]; uint32_t* d_idx = (uint32_t*)&d_index[len]; int* d_merged_size = (int*)&d_idx[len]; - int grid_size = (len - 1) / block_size_ + 1; heter_comm_kernel_->fill_idx(d_idx, len, stream); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, @@ -686,13 +683,151 @@ void HeterComm::dynamic_merge_grad( d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + if (enable_segment_merge_grad) { + platform::Timer timeline; + timeline.Start(); + segment_merge_grad( + gpu_num, + d_keys, d_grads, d_index, len, + d_fea_num_info_ptr, + d_offset, uniq_len, + segment_len); + timeline.Pause(); + VLOG(0) << "card:" << dev_id << ", segment_merge_grad cost " << + timeline.ElapsedSec() << "seconds"; + } else { + platform::Timer timeline; + timeline.Start(); + + auto d_merge_grads = memory::Alloc(place, len * grad_value_size); + float* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); + + heter_comm_kernel_->merge_gradient( + d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads, + (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, + grad_value_size * uniq_len, + cudaMemcpyDeviceToDevice, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + timeline.Pause(); + VLOG(0) << "card:" << dev_id << ", raw merge_grad cost " << + timeline.ElapsedSec() << "seconds"; + } +} + +template +void HeterComm::segment_merge_grad( + int gpu_num, // the device number + KeyType* d_keys, // the sorted keys list, which will be modified after merged + float* d_grads, // the raw grads list, which will be modified after merged + const uint32_t* d_index, // the storage position of d_keys, its length is len. 
+ size_t len, // the number of raw input keys + const uint32_t* d_fea_num_info, // prefix sum array, its length is uniq_len+1 + const uint32_t* d_offset, // prefix sum array, its length is uniq_len+1 + size_t uniq_len, // the number of unique keys + size_t& segments_num) { // the number of segment merged keys + + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + auto grad_dim = max_mf_dim_; + auto grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + + auto d_buffer1 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_segments = reinterpret_cast(d_buffer1->ptr()); + auto d_buffer2 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_segments_offset = reinterpret_cast(d_buffer2->ptr()); + auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * len); + auto d_segments_fea_num_info = reinterpret_cast(d_buffer3->ptr()); + auto d_buffer4 = memory::Alloc(place, sizeof(uint32_t)); + auto d_segments_num = reinterpret_cast(d_buffer4->ptr()); + CUDA_CHECK(cudaMemsetAsync(d_segments_num, 0, sizeof(uint32_t), stream)); + + uint32_t segment_size = FLAGS_gpugraph_merge_grads_segment_size; + heter_comm_kernel_->split_segments( + d_fea_num_info, uniq_len, + d_segments, + d_segments_num, + segment_size, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + +#if 0 + size_t temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( + NULL, temp_storage_bytes, d_segments, d_segments_num, + uniq_len, stream)); + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( + d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_num, + uniq_len, stream)); +#else + CUDA_CHECK(cudaMemcpyAsync(&segments_num, d_segments_num, sizeof(uint32_t), + cudaMemcpyDeviceToHost, stream)); +#endif + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + size_t temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_segments, d_segments_offset, + uniq_len, stream)); + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_offset, + uniq_len, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + heter_comm_kernel_->expand_segments( + d_fea_num_info, + d_segments_offset, uniq_len, + d_segments_fea_num_info, segment_size, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_segments_fea_num_info, d_segments_offset, + segments_num, stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_temp_storage->ptr(), temp_storage_bytes, d_segments_fea_num_info, d_segments_offset, + segments_num, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + platform::Timer timeline; + timeline.Start(); + + auto d_segment_grads = memory::Alloc(place, segments_num * grad_value_size); + float* d_segment_grads_ptr = reinterpret_cast(d_segment_grads->ptr()); heter_comm_kernel_->merge_gradient( - d_keys, d_offset, d_fea_num_info_ptr, d_index, 
(char*)d_grads, - (char*)d_merge_grads_ptr, uniq_len, grad_dim, grad_value_size, merger_, stream); + d_keys, d_segments_offset, d_segments_fea_num_info, d_index, + (char*)d_grads, (char*)d_segment_grads_ptr, segments_num, + grad_dim, grad_value_size, merger_, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, - grad_value_size * uniq_len, - cudaMemcpyDeviceToDevice, stream)); + + timeline.Pause(); + VLOG(0) << "card:" << dev_id << ", segment merge_grad cost " << + timeline.ElapsedSec() << "seconds" << ", len=" << len << + ", uniq_len=" << uniq_len << ", segments_num=" << segments_num; + + auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * len); + auto d_segments_keys_ptr = reinterpret_cast(d_segments_keys->ptr()); + heter_comm_kernel_->shrink_keys( + d_keys, d_segments_offset, + d_segments_keys_ptr, segments_num, + stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_keys, d_segments_keys_ptr, + sizeof(KeyType) * segments_num, + cudaMemcpyDeviceToDevice, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_segment_grads_ptr, + grad_value_size * segments_num, + cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } @@ -715,21 +850,17 @@ void HeterComm::split_input_to_shard( auto d_shard_index_tmp = memory::Alloc(place, len * sizeof(int)); int* d_shard_index_tmp_ptr = reinterpret_cast(d_shard_index_tmp->ptr()); - // int grid_size = (len - 1) / block_size_ + 1; - heter_comm_kernel_->fill_idx(d_idx_tmp_ptr, len, stream); heter_comm_kernel_->calc_shard_index(d_keys, len, d_shard_index_tmp_ptr, total_device, stream); size_t temp_storage_bytes; const int num_bits = 1 + log2i(total_device); - heter_comm_kernel_->sort_pairs( NULL, temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - heter_comm_kernel_->sort_pairs( d_temp_storage->ptr(), temp_storage_bytes, d_shard_index_tmp_ptr, d_shard_index_ptr, d_idx_tmp_ptr, d_idx_ptr, len, 0, num_bits, stream); @@ -856,8 +987,10 @@ void HeterComm::pull_sparse(int num, sync_stream(node.out_stream); } } + heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, stream); + sync_stream(stream); if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { @@ -930,9 +1063,20 @@ void HeterComm::push_sparse(int dev_num, d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; - dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len); - - int grid_size = (uniq_len - 1) / block_size_ + 1; + size_t segment_len = 0; + if (FLAGS_gpugraph_enable_segment_merge_grads) { + // do two gradient merge + // 1st. do segmented gradient merge + // 2nd. 
do global gradient merge + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, true); + len = segment_len; + uniq_len = 0; + segment_len = 0; + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false); + } else { + // Perform gradient merge only once + dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, false); + } split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); @@ -1067,8 +1211,6 @@ void HeterComm::push_sparse(int dev_num, int uniq_len = len; merge_grad(dev_num, d_keys, d_grads, len, uniq_len); - // int grid_size = (uniq_len - 1) / block_size_ + 1; - split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); @@ -1242,7 +1384,6 @@ int HeterComm::gather_one_node_grad( cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); - // int grid_size = (h_node_len[i] - 1) / block_size_ + 1; heter_comm_kernel_->fill_shard_grads( storage.local_keys + merge_num, storage.all_keys + index, storage.local_grads + merge_num, storage.all_grads + index, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 415865ebba8dd..0e734a4c58946 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -208,6 +208,58 @@ __global__ void merge_gradients_embedx_kernel(const KeyType* d_keys, } } +__global__ void split_segments_kernel( + const uint32_t* d_fea_num_info, size_t n, + uint32_t* d_segments, uint32_t* d_segments_num, + uint32_t segment_size) { + const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; + if (tx >= n) { + return; + } + + auto fea_num = d_fea_num_info[tx]; + auto seg_num = (uint32_t)((fea_num - 1) / segment_size + 1); + d_segments[tx] = seg_num; + atomicAdd(d_segments_num, seg_num); +} + +__global__ void expand_segments_kernel( + const uint32_t* d_fea_num_info, + const uint32_t* d_segments_offset, size_t n, + uint32_t* d_segments_fea_num_info, uint32_t segment_size) { + const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; + if (tx >= n) { + return; + } + + auto fea_num = d_fea_num_info[tx]; + auto seg_num = (uint32_t)((fea_num - 1) / segment_size + 1); + auto start_pos = d_segments_offset[tx]; + auto remains = fea_num; + int cur_seg_size = 0; + for (size_t i = 0; i < seg_num; ++i) { + if (remains >= segment_size) { + cur_seg_size = segment_size; + } else { + cur_seg_size = remains; + } + d_segments_fea_num_info[start_pos + i] = cur_seg_size; + remains -= cur_seg_size; + } +} + +template +__global__ void shrink_keys_kernel( + const KeyType* d_keys, const uint32_t* d_segments_offset, + KeyType* d_segments_keys, size_t n) { + const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; + if (tx >= n) { + return; + } + + d_segments_keys[tx] = d_keys[d_segments_offset[tx]]; +} + template __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, T* idx, size_t len, size_t val_size, @@ -376,6 +428,34 @@ void HeterCommKernel::dy_mf_fill_dvals(float* d_shard_vals, float* d_vals, d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor_); } +template +void HeterCommKernel::split_segments(const uint32_t* d_fea_num_info, size_t n, + uint32_t* d_segments, uint32_t* d_segments_num, size_t segment_size, const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + split_segments_kernel<<>>( + d_fea_num_info, n, d_segments, d_segments_num, segment_size); +} + 
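// Worked example of the segment pipeline above (a sketch, not part of the
// patch): with segment_size = 2 and per-key counts d_fea_num_info = [5, 1, 2],
//   split_segments  -> d_segments              = [3, 1, 1]   (ceil(count / 2))
//   ExclusiveSum    -> d_segments_offset       = [0, 3, 4]
//   expand_segments -> d_segments_fea_num_info = [2, 2, 1, 1, 2]
//   ExclusiveSum    -> per-segment offsets     = [0, 2, 4, 5, 6]
//   shrink_keys     -> one representative key per segment (a hot key is
//                      repeated once per segment)
// Each segment is then merged independently by merge_gradient, which bounds
// the serial work any one thread spends on a hot key; because the same key
// can still appear in up to ceil(count / segment_size) segments, push_sparse
// runs a second, global dynamic_merge_grad pass over the segment results.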
+template +void HeterCommKernel::expand_segments(const uint32_t* d_fea_num_info, + const uint32_t* d_segments_offset, size_t n, + uint32_t* d_segments_fea_num_info, uint32_t segment_size, + const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + expand_segments_kernel<<>>( + d_fea_num_info, + d_segments_offset, n, + d_segments_fea_num_info, segment_size); +} + +template +void HeterCommKernel::shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset, + KeyType* d_segments_keys, size_t n, const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + shrink_keys_kernel<<>>( + d_keys, d_segments_offset, d_segments_keys, n); +} + template void HeterCommKernel::fill_idx( int* idx, long long len, const cudaStream_t& stream); template void HeterCommKernel::fill_idx( @@ -463,6 +543,25 @@ template void HeterCommKernel::dy_mf_fill_dvals( float* d_shard_vals, float* d_vals, int* idx, long long len, size_t val_size, const cudaStream_t& stream); + +template void HeterCommKernel::split_segments( + const uint32_t* d_fea_num_info, size_t n, + uint32_t* d_segment, uint32_t* d_segments_num, size_t segment_size, + const cudaStream_t& stream); + +template void HeterCommKernel::expand_segments( + const uint32_t* d_fea_num_info, + const uint32_t* d_segments_offset, size_t n, + uint32_t* d_segments_fea_num_info, uint32_t segment_size, + const cudaStream_t& stream); + +template void HeterCommKernel::shrink_keys( + const uint32_t* d_keys, const uint32_t* d_segments_offset, + uint32_t* d_segments_keys, size_t segment_num, const cudaStream_t& stream); + +template void HeterCommKernel::shrink_keys( + const uint64_t* d_keys, const uint32_t* d_segments, + uint64_t* d_segments_keys, size_t total_segment_num, const cudaStream_t& stream); #endif } // namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 930dafc944371..0931c4e1cd34a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -186,7 +186,23 @@ class HeterCommKernel { long long len, size_t val_size, const StreamType& stream); + template + void split_segments(const uint32_t* d_fea_num_info, + size_t len, uint32_t* d_segments, uint32_t* d_segments_num, + size_t segment_size, const StreamType& stream); + + template + void expand_segments(const uint32_t* d_fea_num_info, + const uint32_t* d_segments_offset, size_t segments_num, + uint32_t* d_segments_fea_num_info, uint32_t segment_size, + const StreamType& stream); + + template + void shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset, + KeyType* d_segments_keys, size_t segments_num, const StreamType& stream); + CommonFeatureValueAccessor feature_value_accessor_; + private: int block_size_{256}; }; diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index cee122e540f7e..84bf12ed31a66 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -225,7 +225,7 @@ void HogwildWorker::TrainFiles() { platform::SetXPUDeviceId(thread_id_); #endif - int total_ins_num = 0; + int total_batch_num = 0; // how to accumulate fetched values here device_reader_->Start(); int cur_batch; @@ -255,7 +255,7 @@ void HogwildWorker::TrainFiles() { DumpParam(*thread_scope_, batch_cnt); } - total_ins_num += cur_batch; + total_batch_num += cur_batch; ++batch_cnt; PrintFetchVars(); thread_scope_->DropKids(); @@ -265,7 
+265,7 @@ void HogwildWorker::TrainFiles() { } timeline.Pause(); VLOG(0) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec() - << " seconds, ins_num: " << total_ins_num; + << " seconds, batch_num: " << total_batch_num; if (need_dump_field_ || need_dump_param_) { writer_.Flush(); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index e482020cf97db..33198c11cc2af 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -860,6 +860,12 @@ PADDLE_DEFINE_EXPORTED_double( PADDLE_DEFINE_EXPORTED_bool( gpugraph_enable_gpu_direct_access, false, "enable direct access bwtween multi gpu cards, default false"); +PADDLE_DEFINE_EXPORTED_bool( + gpugraph_enable_segment_merge_grads, false, + "enable segment merge gradients while push sparse, default false"); +PADDLE_DEFINE_EXPORTED_uint64( + gpugraph_merge_grads_segment_size, 128, + "segment size with segment gradient merge, default 128"); /** * ProcessGroupNCCL related FLAG From b277cfc5d2c201206098765d06acc0a63a23c35e Mon Sep 17 00:00:00 2001 From: root Date: Thu, 30 Jun 2022 20:05:58 +0800 Subject: [PATCH 5/6] Performance optimization, segment gradient merging --- .../framework/fleet/heter_ps/heter_comm_inl.h | 59 ++++++++++++++----- .../fleet/heter_ps/heter_comm_kernel.cu | 1 - 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index fe08f966e2206..b36c527ef6065 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -29,6 +29,25 @@ DECLARE_uint64(gpugraph_merge_grads_segment_size); namespace paddle { namespace framework { +template +void show_list(int gpu_id, const T* d_ids, int len, const char* desc) { + static int i = 0; + if (gpu_id != 0) return; + if (++i > 4) return; + + T* h_nodes = nullptr; + size_t size = sizeof(T) * len; + cudaMallocHost((void**)&h_nodes, size); + cudaMemcpy(h_nodes, d_ids, size, cudaMemcpyDeviceToHost); + for (size_t idx = 0; idx < len; ++idx) { + VLOG(0) << "device:" << gpu_id << + ", " << "list[" << idx << "]=" << h_nodes[idx] << + ", desc=" << desc; + } + cudaFree(&h_nodes); + h_nodes = nullptr; +} + template HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { @@ -630,7 +649,6 @@ void HeterComm::dynamic_merge_grad( auto stream = resource_->local_stream(gpu_num, 0); size_t temp_storage_bytes; - size_t grad_dim = max_mf_dim_; size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); @@ -689,7 +707,8 @@ void HeterComm::dynamic_merge_grad( timeline.Start(); segment_merge_grad( gpu_num, - d_keys, d_grads, d_index, len, + //d_keys, d_grads, d_index, len, + d_merge_keys_ptr, d_grads, d_index, len, d_fea_num_info_ptr, d_offset, uniq_len, segment_len); @@ -713,8 +732,9 @@ void HeterComm::dynamic_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); timeline.Pause(); - VLOG(0) << "card:" << dev_id << ", raw merge_grad cost " << - timeline.ElapsedSec() << "seconds"; + VLOG(0) << "card:" << dev_id << ", merge_grad cost " << + timeline.ElapsedSec() << "seconds" << ", len=" << len << + ", uniq_len=" << uniq_len << ", segments_num=" << segment_len; } } @@ -736,13 +756,8 @@ void HeterComm::segment_merge_grad( auto stream = resource_->local_stream(gpu_num, 0); auto grad_dim = max_mf_dim_; auto grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - auto 
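// Note: the call above now passes d_merge_keys_ptr (the radix-sorted key
// array) into segment_merge_grad instead of the caller's d_keys, since
// d_segments_offset holds positions in sorted order and shrink_keys reads
// d_keys[d_segments_offset[tx]]; only the sorted array is valid at those
// positions. Separately, in the show_list debug helper above, memory obtained
// from cudaMallocHost must be released with cudaFreeHost(h_nodes);
// cudaFree(&h_nodes) takes the address of the host pointer and releases
// nothing.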
d_buffer1 = memory::Alloc(place, sizeof(uint32_t) * len); auto d_segments = reinterpret_cast(d_buffer1->ptr()); - auto d_buffer2 = memory::Alloc(place, sizeof(uint32_t) * len); - auto d_segments_offset = reinterpret_cast(d_buffer2->ptr()); - auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * len); - auto d_segments_fea_num_info = reinterpret_cast(d_buffer3->ptr()); auto d_buffer4 = memory::Alloc(place, sizeof(uint32_t)); auto d_segments_num = reinterpret_cast(d_buffer4->ptr()); CUDA_CHECK(cudaMemsetAsync(d_segments_num, 0, sizeof(uint32_t), stream)); @@ -755,7 +770,6 @@ void HeterComm::segment_merge_grad( segment_size, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); -#if 0 size_t temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( NULL, temp_storage_bytes, d_segments, d_segments_num, @@ -764,17 +778,23 @@ void HeterComm::segment_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Sum( d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_num, uniq_len, stream)); -#else CUDA_CHECK(cudaMemcpyAsync(&segments_num, d_segments_num, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream)); -#endif PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); - size_t temp_storage_bytes = 0; + auto d_buffer2 = memory::Alloc(place, sizeof(uint32_t) * segments_num); + auto d_segments_offset = reinterpret_cast(d_buffer2->ptr()); + auto d_buffer3 = memory::Alloc(place, sizeof(uint32_t) * segments_num); + auto d_segments_fea_num_info = reinterpret_cast(d_buffer3->ptr()); + + temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( NULL, temp_storage_bytes, d_segments, d_segments_offset, uniq_len, stream)); - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( d_temp_storage->ptr(), temp_storage_bytes, d_segments, d_segments_offset, uniq_len, stream)); @@ -786,6 +806,7 @@ void HeterComm::segment_merge_grad( d_segments_fea_num_info, segment_size, stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + // reuse d_segments_offset PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( NULL, temp_storage_bytes, d_segments_fea_num_info, d_segments_offset, segments_num, stream)); @@ -810,11 +831,11 @@ void HeterComm::segment_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); timeline.Pause(); - VLOG(0) << "card:" << dev_id << ", segment merge_grad cost " << + VLOG(0) << "card:" << dev_id << ", merge_grad cost " << timeline.ElapsedSec() << "seconds" << ", len=" << len << ", uniq_len=" << uniq_len << ", segments_num=" << segments_num; - auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * len); + auto d_segments_keys = memory::Alloc(place, sizeof(KeyType) * segments_num); auto d_segments_keys_ptr = reinterpret_cast(d_segments_keys->ptr()); heter_comm_kernel_->shrink_keys( d_keys, d_segments_offset, @@ -829,6 +850,11 @@ void HeterComm::segment_merge_grad( grad_value_size * segments_num, cudaMemcpyDeviceToDevice, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + //show_list(dev_id, d_keys, len, "d_keys"); + //show_list(dev_id, d_segments_fea_num_info, segments_num, "d_segments_fea_num_info"); + //show_list(dev_id, d_segments_offset, segments_num, "d_segments_offset"); + //show_list(dev_id, d_segments_keys_ptr, segments_num, "d_segments_keys"); } template 
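// Two patterns in the revised segment_merge_grad are worth restating. First,
// segments_num now comes from cub::DeviceReduce::Sum over d_segments rather
// than per-thread atomicAdd in split_segments_kernel, trading contended
// atomics for one deterministic reduction. Second, the cub scratch buffer is
// grown only when a later call needs more space. A minimal sketch of that
// reuse idiom, assuming the memory::Alloc / ->size() API used in this patch:
//
//   size_t bytes = 0;
//   cub::DeviceScan::ExclusiveSum(nullptr, bytes, d_in, d_out, n, stream);
//   if (d_temp_storage == nullptr || d_temp_storage->size() < bytes) {
//     d_temp_storage = memory::Alloc(place, bytes);  // grow once, then reuse
//   }
//   cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(), bytes, d_in, d_out,
//                                 n, stream);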
@@ -1068,6 +1094,7 @@ void HeterComm::push_sparse(int dev_num, // do two gradient merge // 1st. do segmented gradient merge // 2nd. do global gradient merge + //len = 1000; dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len, segment_len, true); len = segment_len; uniq_len = 0; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 0e734a4c58946..0304ee6812537 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -220,7 +220,6 @@ __global__ void split_segments_kernel( auto fea_num = d_fea_num_info[tx]; auto seg_num = (uint32_t)((fea_num - 1) / segment_size + 1); d_segments[tx] = seg_num; - atomicAdd(d_segments_num, seg_num); } __global__ void expand_segments_kernel( From b7ba6a9225d59f6e9033188ca9d625a31978cbb3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 11 Jul 2022 10:13:14 +0800 Subject: [PATCH 6/6] Optimize pullsparse and increase keys aggregation --- .../framework/fleet/heter_ps/heter_comm.h | 5 + .../framework/fleet/heter_ps/heter_comm_inl.h | 120 ++++++++++++++++-- .../fleet/heter_ps/heter_comm_kernel.cu | 116 ++++++++++++++++- .../fleet/heter_ps/heter_comm_kernel.h | 14 +- 4 files changed, 239 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index c512b6e6865f3..f10b59edf4d77 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -55,6 +55,11 @@ class HeterComm { void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right, int gpu_num); + void merge_keys(int gpu_num, const KeyType* d_keys, size_t len, + KeyType* d_sorted_keys, + KeyType* d_merged_keys, + uint32_t* d_restore_idx, + size_t & uniq_len); void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, int& uniq_len); // NOLINT void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 972188e1bd651..68f73531928a5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -635,7 +635,6 @@ void HeterComm::dynamic_merge_grad( auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); - auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); @@ -643,15 +642,16 @@ void HeterComm::dynamic_merge_grad( uint32_t* d_idx = (uint32_t*)&d_index[len]; int* d_merged_size = (int*)&d_idx[len]; heter_comm_kernel_->fill_idx(d_idx, len, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); - void* d_buff = NULL; auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, 
d_fea_num_info_ptr, @@ -853,6 +853,73 @@ void HeterComm::split_input_to_shard( sync_stream(stream); } +template +void HeterComm::merge_keys( + int gpu_num, const KeyType* d_keys, size_t len, // input + KeyType* d_sorted_keys, // output + KeyType* d_merged_keys, // output + uint32_t* d_restore_idx, // output + size_t& uniq_len) { // output + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + + size_t grad_dim = max_mf_dim_; + size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + + auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 4 + 1)); + uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); + uint32_t* d_idx = (uint32_t*)&d_fea_num_info_ptr[len]; + uint32_t* d_index = (uint32_t*)&d_idx[len]; + uint32_t* d_offset = (uint32_t*)&d_index[len]; + uint32_t* d_merged_size = (uint32_t*)&d_offset[len]; + heter_comm_kernel_->fill_idx(d_idx, len, stream); + + size_t temp_storage_bytes; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_keys, d_sorted_keys, d_idx, d_index, len, + 0, 8 * sizeof(KeyType), stream)); + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_sorted_keys, + d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + NULL, temp_storage_bytes, d_sorted_keys, d_merged_keys, d_fea_num_info_ptr, + d_merged_size, len, stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + d_temp_storage->ptr(), temp_storage_bytes, d_sorted_keys, d_merged_keys, + d_fea_num_info_ptr, d_merged_size, len, stream)); + cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int), + cudaMemcpyDeviceToHost, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, + stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + heter_comm_kernel_->fill_restore_idx( + d_index, d_offset, d_fea_num_info_ptr, d_merged_keys, uniq_len, + d_restore_idx, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +} + template void HeterComm::pull_sparse(int num, KeyType* d_keys, @@ -897,21 +964,33 @@ void HeterComm::pull_sparse(int num, XPUAPIErrorMsg[r2])); #endif - auto d_idx = memory::Alloc(place, len * sizeof(int)); - int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; + auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto 
@@ -897,21 +964,33 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
                                              XPUAPIErrorMsg[r2]));
 #endif
 
-  auto d_idx = memory::Alloc(place, len * sizeof(int));
-  int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-
   size_t val_type_size =
       TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_));
   VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size;
+  auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType));
+  auto d_sorted_keys_ptr = reinterpret_cast<KeyType*>(d_sorted_keys->ptr());
+  auto d_merged_keys = memory::Alloc(place, len * sizeof(KeyType));
+  auto d_merged_keys_ptr = reinterpret_cast<KeyType*>(d_merged_keys->ptr());
+  auto d_restore_idx = memory::Alloc(place, len * sizeof(uint32_t));
+  auto d_restore_idx_ptr = reinterpret_cast<uint32_t*>(d_restore_idx->ptr());
   auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
-  KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
+  auto d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
   auto d_shard_vals = memory::Alloc(place, len * val_type_size);
-  float* d_shard_vals_ptr = reinterpret_cast<float*>(d_shard_vals->ptr());
-
-  split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);
-
-  heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len,
-                                     stream);
+  auto d_shard_vals_ptr = reinterpret_cast<float*>(d_shard_vals->ptr());
+
+  size_t uniq_len = 0;
+  merge_keys(num, d_keys, len,
+             d_sorted_keys_ptr,
+             d_merged_keys_ptr,
+             d_restore_idx_ptr,
+             uniq_len);
+  sync_stream(stream);
+  auto d_idx = memory::Alloc(place, uniq_len * sizeof(int));
+  auto d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
+  split_input_to_shard(d_merged_keys_ptr, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, num);
+  heter_comm_kernel_->fill_shard_key(
+      d_shard_keys_ptr, d_merged_keys_ptr, d_idx_ptr, uniq_len,
+      stream);
   sync_stream(stream);
 
   auto dst_place = platform::CPUPlace();
@@ -933,6 +1012,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     }
     walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);
   }
+
   for (int i = 0; i < total_device; ++i) {
     if (h_left[i] == -1) {
       continue;
@@ -962,6 +1042,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     }
     ptr_tables_[i]->rwlock_->UNLock();
   }
+
   if (!FLAGS_gpugraph_enable_gpu_direct_access) {
     walk_to_src(num, total_device, h_left, h_right,
                 reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
@@ -971,10 +1052,21 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     }
   }
 
-  heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
-                                       val_type_size, stream);
+  auto d_merged_vals = memory::Alloc(place, uniq_len * val_type_size);
+  auto d_merged_vals_ptr = reinterpret_cast<float*>(d_merged_vals->ptr());
+  heter_comm_kernel_->dy_mf_fill_dvals(
+      d_shard_vals_ptr, d_merged_vals_ptr,
+      d_idx_ptr, uniq_len,
+      val_type_size, stream);
+  sync_stream(stream);
+  heter_comm_kernel_->unpack_merged_vals(
+      len, d_keys,
+      d_merged_vals_ptr,
+      d_restore_idx_ptr,
+      d_vals, val_type_size, stream);
   sync_stream(stream);
+
   if (!FLAGS_gpugraph_enable_gpu_direct_access) {
     for (int i = 0; i < total_device; ++i) {
       if (h_left[i] == -1 || h_right[i] == -1) {
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
index 0304ee6812537..abb5cff60f5f8 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
@@ -259,6 +259,80 @@ __global__ void shrink_keys_kernel(
   d_segments_keys[tx] = d_keys[d_segments_offset[tx]];
 }
 
+template <typename KeyType, typename T>
+__global__ void fill_restore_idx_kernel(
+    const T* d_sorted_idx,
+    const T* d_offset,
+    const T* d_merged_cnts,
+    const KeyType* d_merged_keys,
+    T* d_restore_idx,
+    size_t n) {
+  const size_t tx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tx >= n) {
+    return;
+  }
+
+  const KeyType& key = d_merged_keys[tx];
+  if (key == 0) {
+    return;
+  }
+
+  const T& off = d_offset[tx];
+  const T& num = d_merged_cnts[tx];
+  for (size_t k = 0; k < num; ++k) {
+    d_restore_idx[d_sorted_idx[off + k]] = tx;
+  }
+}
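fill_restore_idx_kernel above runs one thread per unique key and deliberately returns early for key 0, which this code path treats as an empty slot; unpack_merged_vals_kernel (next) compensates by reading record 0 for zero keys. A standalone harness, with the same scatter logic re-declared locally and its inputs precomputed on the host, illustrates the effect (illustrative only, compiled with nvcc):

// Illustrative harness for the restore-index scatter, including the key-0 skip.
#include <cstdint>
#include <cstdio>

__global__ void fill_restore_idx_ref(const uint32_t* d_sorted_idx,
                                     const uint32_t* d_offset,
                                     const uint32_t* d_cnts,
                                     const uint64_t* d_merged_keys,
                                     uint32_t* d_restore_idx, size_t n) {
  const size_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  if (tx >= n || d_merged_keys[tx] == 0) return;  // key 0: skipped on purpose
  const uint32_t off = d_offset[tx];
  for (uint32_t k = 0; k < d_cnts[tx]; ++k)
    d_restore_idx[d_sorted_idx[off + k]] = tx;
}

int main() {
  // keys = {7, 3, 7, 0}; ascending sort puts key 0 first: merged = {0, 3, 7}
  const size_t n_orig = 4, n_uniq = 3;
  uint32_t h_sorted_idx[] = {3, 1, 0, 2};  // original positions, sorted by key
  uint32_t h_cnts[] = {1, 1, 2};
  uint32_t h_offset[] = {0, 1, 2};
  uint64_t h_merged[] = {0, 3, 7};
  uint32_t *d_sorted_idx, *d_cnts, *d_offset, *d_restore;
  uint64_t* d_merged;
  cudaMalloc(&d_sorted_idx, sizeof(h_sorted_idx));
  cudaMalloc(&d_cnts, sizeof(h_cnts));
  cudaMalloc(&d_offset, sizeof(h_offset));
  cudaMalloc(&d_merged, sizeof(h_merged));
  cudaMalloc(&d_restore, n_orig * sizeof(uint32_t));
  cudaMemcpy(d_sorted_idx, h_sorted_idx, sizeof(h_sorted_idx), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cnts, h_cnts, sizeof(h_cnts), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offset, h_offset, sizeof(h_offset), cudaMemcpyHostToDevice);
  cudaMemcpy(d_merged, h_merged, sizeof(h_merged), cudaMemcpyHostToDevice);
  cudaMemset(d_restore, 0, n_orig * sizeof(uint32_t));

  fill_restore_idx_ref<<<1, 256>>>(d_sorted_idx, d_offset, d_cnts, d_merged,
                                   d_restore, n_uniq);
  uint32_t h_restore[n_orig];
  cudaMemcpy(h_restore, d_restore, sizeof(h_restore), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < n_orig; ++i) printf("%u ", h_restore[i]);
  printf("\n");  // expect: 2 1 2 0  (the zero key at slot 3 stays 0)
  return 0;
}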
+
+template <typename KeyType>
+__global__ void unpack_merged_vals_kernel(
+    const KeyType* d_keys,
+    const float* d_merged_vals,
+    const uint32_t* d_restored_idx,
+    float* d_out, size_t val_size, const size_t n,
+    CommonFeatureValueAccessor feature_value_accessor) {
+  const size_t tx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tx >= n) {
+    return;
+  }
+
+  size_t src_val_idx = 0;
+  const KeyType& key = d_keys[tx];
+  if (key != 0) {
+    src_val_idx = d_restored_idx[tx];
+  }
+
+  uint64_t dst_offset = uint64_t(tx) * val_size;
+  float* dst = (float*)((char*)d_out + dst_offset);
+  float* src_val =
+      (float*)((char*)d_merged_vals + uint64_t(src_val_idx) * val_size);
+  int mf_dim =
+      int(src_val[feature_value_accessor.common_feature_value.MfDimIndex()]);
+
+  *(reinterpret_cast<uint64_t*>(
+      dst + feature_value_accessor.common_feature_value.CpuPtrIndex())) =
+      *(reinterpret_cast<uint64_t*>(
+          src_val + feature_value_accessor.common_feature_value.CpuPtrIndex()));
+  dst[feature_value_accessor.common_feature_value.DeltaScoreIndex()] =
+      src_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()];
+  dst[feature_value_accessor.common_feature_value.ShowIndex()] =
+      src_val[feature_value_accessor.common_feature_value.ShowIndex()];
+  dst[feature_value_accessor.common_feature_value.ClickIndex()] =
+      src_val[feature_value_accessor.common_feature_value.ClickIndex()];
+  dst[feature_value_accessor.common_feature_value.EmbedWIndex()] =
+      src_val[feature_value_accessor.common_feature_value.EmbedWIndex()];
+  for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim();
+       i++) {
+    dst[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] =
+        src_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() +
+                i];
+  }
+  dst[feature_value_accessor.common_feature_value.SlotIndex()] =
+      src_val[feature_value_accessor.common_feature_value.SlotIndex()];
+  dst[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim;
+  dst[feature_value_accessor.common_feature_value.MfSizeIndex()] =
+      src_val[feature_value_accessor.common_feature_value.MfSizeIndex()];
+
+  for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex();
+       x < int(feature_value_accessor.common_feature_value.Size(mf_dim) /
+               sizeof(float));
+       x++) {
+    dst[x] = src_val[x];
+  }
+}
+
 template <typename T>
 __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals,
                                         T* idx, size_t len, size_t val_size,
@@ -450,11 +524,31 @@ void HeterCommKernel::expand_segments(const uint32_t* d_fea_num_info,
 template <typename KeyType, typename StreamType>
 void HeterCommKernel::shrink_keys(const KeyType* d_keys,
                                   const uint32_t* d_segments_offset,
                                   KeyType* d_segments_keys, size_t n,
                                   const StreamType& stream) {
-   int grid_size = (n - 1) / block_size_ + 1;
+  int grid_size = (n - 1) / block_size_ + 1;
   shrink_keys_kernel<<<grid_size, block_size_, 0, stream>>>(
       d_keys, d_segments_offset, d_segments_keys, n);
 }
 
+template <typename KeyType, typename StreamType>
+void HeterCommKernel::fill_restore_idx(
+    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
+    const uint32_t* d_merged_cnts, const KeyType* d_merged_keys,
+    const size_t n, uint32_t* d_restore_idx, const StreamType& stream) {
+  int grid_size = (n - 1) / block_size_ + 1;
+  fill_restore_idx_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_sorted_idx, d_offset, d_merged_cnts, d_merged_keys, d_restore_idx, n);
+}
+
+template <typename KeyType, typename StreamType>
+void HeterCommKernel::unpack_merged_vals(size_t n, const KeyType* d_keys,
+                                         const void* d_merged_vals,
+                                         const uint32_t* d_restore_idx,
+                                         void* d_vals, size_t val_size,
+                                         const StreamType& stream) {
+  int grid_size = (n - 1) / block_size_ + 1;
+  unpack_merged_vals_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_keys, (const float*)d_merged_vals, d_restore_idx,
+      (float*)d_vals, val_size, n, feature_value_accessor_);
+}
+
 template void HeterCommKernel::fill_idx<int, cudaStream_t>(
     int* idx, long long len, const cudaStream_t& stream);
 template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
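The unpack wrapper above launches one thread per original key; each thread copies one val_size-byte record, gathering through d_restore_idx, and the 8-byte CpuPtr field embedded in the float array must move as a single 64-bit unit, which is what the reinterpret_cast in the kernel is for. A host miniature of that strided record gather, with invented field offsets (the real ones come from CommonFeatureValueAccessor), shows the shape of the operation:

// Miniature of the strided record gather; layout constants are assumptions.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr size_t kCpuPtrIdx = 0;  // 2 floats hold one 8-byte pointer (assumed)
constexpr size_t kShowIdx = 2;    // an ordinary float field (assumed)
constexpr size_t kValFloats = 8;  // record = 8 floats = 32 bytes (assumed)

void gather_record(const float* merged, size_t src_idx,
                   float* out, size_t dst_idx) {
  const float* src = merged + src_idx * kValFloats;
  float* dst = out + dst_idx * kValFloats;
  // The 8-byte field is moved as one unit, never float-by-float;
  // memcpy is the host-safe analogue of the device-side reinterpret_cast.
  uint64_t cpu_ptr;
  std::memcpy(&cpu_ptr, src + kCpuPtrIdx, sizeof(cpu_ptr));
  std::memcpy(dst + kCpuPtrIdx, &cpu_ptr, sizeof(cpu_ptr));
  for (size_t i = kShowIdx; i < kValFloats; ++i) dst[i] = src[i];
}

int main() {
  std::vector<float> merged(3 * kValFloats, 1.5f);  // 3 unique records
  std::vector<float> out(4 * kValFloats, 0.f);      // 4 original slots
  std::vector<uint32_t> restore_idx = {2, 1, 2, 0};
  for (size_t i = 0; i < restore_idx.size(); ++i)
    gather_record(merged.data(), restore_idx[i], out.data(), i);
  assert(out[kShowIdx] == 1.5f);  // slot 0 received record 2's fields
  return 0;
}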
@@ -561,6 +655,26 @@ template void HeterCommKernel::shrink_keys<uint32_t, cudaStream_t>(
 template void HeterCommKernel::shrink_keys<uint64_t, cudaStream_t>(
     const uint64_t* d_keys, const uint32_t* d_segments,
     uint64_t* d_segments_keys, size_t total_segment_num,
     const cudaStream_t& stream);
+
+template void HeterCommKernel::fill_restore_idx<uint64_t, cudaStream_t>(
+    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
+    const uint32_t* d_merged_cnts, const uint64_t* d_merged_keys,
+    const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream);
+
+template void HeterCommKernel::fill_restore_idx<uint32_t, cudaStream_t>(
+    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
+    const uint32_t* d_merged_cnts, const uint32_t* d_merged_keys,
+    const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream);
+
+template void HeterCommKernel::unpack_merged_vals<uint64_t, cudaStream_t>(
+    size_t n, const uint64_t* d_keys, const void* d_merged_vals,
+    const uint32_t* d_restore_idx, void* d_vals, size_t val_size,
+    const cudaStream_t& stream);
+
+template void HeterCommKernel::unpack_merged_vals<uint32_t, cudaStream_t>(
+    size_t n, const uint32_t* d_keys, const void* d_merged_vals,
+    const uint32_t* d_restore_idx, void* d_vals, size_t val_size,
+    const cudaStream_t& stream);
 #endif
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
index 0931c4e1cd34a..1cde86e64b6bc 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
@@ -201,10 +201,22 @@ class HeterCommKernel {
   void shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset,
                    KeyType* d_segments_keys, size_t segments_num,
                    const StreamType& stream);
-  CommonFeatureValueAccessor feature_value_accessor_;
+
+  template <typename KeyType, typename StreamType>
+  void fill_restore_idx(const uint32_t* d_sorted_idx, const uint32_t* d_offset,
+                        const uint32_t* d_merged_cnts,
+                        const KeyType* d_merged_keys, const size_t len,
+                        uint32_t* d_restore_idx, const StreamType& stream);
+
+  template <typename KeyType, typename StreamType>
+  void unpack_merged_vals(size_t n,
+                          const KeyType* d_keys,
+                          const void* d_merged_vals,
+                          const uint32_t* d_restore_idx,
+                          void* d_vals, size_t val_size,
+                          const StreamType& stream);
 
  private:
   int block_size_{256};
+  CommonFeatureValueAccessor feature_value_accessor_;
 };
 
 }  // end namespace framework
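A closing note on the block of explicit instantiations above: the template bodies for fill_restore_idx and unpack_merged_vals live in heter_comm_kernel.cu, invisible to other translation units, so every (KeyType, StreamType) combination used by HeterComm must be instantiated there by hand or the build fails at link time. A minimal illustration of the pattern, collapsed into one listing with our own names:

// One-file illustration (names are ours). In the real tree the declaration
// sits in heter_comm_kernel.h and the definition in heter_comm_kernel.cu.
#include <cstddef>
#include <cstdint>
#include <cstdio>

template <typename KeyType>
void shrink(const KeyType* keys, size_t n);  // header: declaration only

template <typename KeyType>
void shrink(const KeyType* keys, size_t n) {  // .cu: out-of-line definition
  std::printf("shrinking %zu keys of width %zu\n", n, sizeof(KeyType));
}

// Without these two lines, a caller compiled against the header alone
// links against nothing and fails with an undefined-symbol error.
template void shrink<uint32_t>(const uint32_t*, size_t);
template void shrink<uint64_t>(const uint64_t*, size_t);

int main() {
  uint64_t keys[3] = {1, 2, 3};
  shrink(keys, 3);
  return 0;
}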