diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index c512b6e6865f3..f10b59edf4d77 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -55,6 +55,11 @@ class HeterComm { void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len, int* left, int* right, int gpu_num); + void merge_keys(int gpu_num, const KeyType* d_keys, size_t len, + KeyType* d_sorted_keys, + KeyType* d_merged_keys, + uint32_t* d_restore_idx, + size_t & uniq_len); void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len, int& uniq_len); // NOLINT void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 972188e1bd651..68f73531928a5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -635,7 +635,6 @@ void HeterComm::dynamic_merge_grad( auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); - auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); @@ -643,15 +642,16 @@ void HeterComm::dynamic_merge_grad( uint32_t* d_idx = (uint32_t*)&d_index[len]; int* d_merged_size = (int*)&d_idx[len]; heter_comm_kernel_->fill_idx(d_idx, len, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); - void* d_buff = NULL; auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + temp_storage_bytes = 0; PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr, @@ -853,6 +853,73 @@ void HeterComm::split_input_to_shard( sync_stream(stream); } +template +void HeterComm::merge_keys( + int gpu_num, const KeyType* d_keys, size_t len, // input + KeyType* d_sorted_keys, // output + KeyType* d_merged_keys, // output + uint32_t* d_restore_idx, // output + size_t& uniq_len) { // output + int dev_id = resource_->dev_id(gpu_num); + platform::CUDAPlace place = platform::CUDAPlace(dev_id); + platform::CUDADeviceGuard guard(dev_id); + auto stream = resource_->local_stream(gpu_num, 0); + + size_t grad_dim = max_mf_dim_; + size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + + auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 4 + 1)); + uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); + uint32_t* d_idx = (uint32_t*)&d_fea_num_info_ptr[len]; + uint32_t* d_index = (uint32_t*)&d_idx[len]; + uint32_t* d_offset = (uint32_t*)&d_index[len]; + uint32_t* d_merged_size = (uint32_t*)&d_offset[len]; + heter_comm_kernel_->fill_idx(d_idx, len, stream); + + size_t temp_storage_bytes; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + NULL, temp_storage_bytes, d_keys, d_sorted_keys, d_idx, d_index, len, + 0, 8 * sizeof(KeyType), stream)); + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_sorted_keys, + d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + NULL, temp_storage_bytes, d_sorted_keys, d_merged_keys, d_fea_num_info_ptr, + d_merged_size, len, stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode( + d_temp_storage->ptr(), temp_storage_bytes, d_sorted_keys, d_merged_keys, + d_fea_num_info_ptr, d_merged_size, len, stream)); + cudaMemcpyAsync((void*)&uniq_len, d_merged_size, sizeof(int), + cudaMemcpyDeviceToHost, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + temp_storage_bytes = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, + stream)); + if (d_temp_storage->size() < temp_storage_bytes) { + d_temp_storage = NULL; + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + } + PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum( + d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + heter_comm_kernel_->fill_restore_idx( + d_index, d_offset, d_fea_num_info_ptr, d_merged_keys, uniq_len, + d_restore_idx, stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); +} + template void HeterComm::pull_sparse(int num, KeyType* d_keys, @@ -897,21 +964,33 @@ void HeterComm::pull_sparse(int num, XPUAPIErrorMsg[r2])); #endif - auto d_idx = memory::Alloc(place, len * sizeof(int)); - int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; + auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_sorted_keys_ptr = reinterpret_cast(d_sorted_keys->ptr()); + auto d_merged_keys = memory::Alloc(place, len * sizeof(KeyType)); + auto d_merged_keys_ptr = reinterpret_cast(d_merged_keys->ptr()); + auto d_restore_idx = memory::Alloc(place, len * sizeof(uint32_t)); + auto d_restore_idx_ptr = reinterpret_cast(d_restore_idx->ptr()); auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); - KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = memory::Alloc(place, len * val_type_size); - float* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); - - split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); - - heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len, - stream); + auto d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); + + size_t uniq_len = 0; + merge_keys(num, d_keys, len, + d_sorted_keys_ptr, + d_merged_keys_ptr, + d_restore_idx_ptr, + uniq_len); + sync_stream(stream); + auto d_idx = memory::Alloc(place, uniq_len * sizeof(int)); + auto d_idx_ptr = reinterpret_cast(d_idx->ptr()); + split_input_to_shard(d_merged_keys_ptr, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, num); + heter_comm_kernel_->fill_shard_key( + d_shard_keys_ptr, d_merged_keys_ptr, d_idx_ptr, uniq_len, + stream); sync_stream(stream); auto dst_place = platform::CPUPlace(); @@ -933,6 +1012,7 @@ void HeterComm::pull_sparse(int num, } walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL); } + for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1) { continue; @@ -962,6 +1042,7 @@ void HeterComm::pull_sparse(int num, } ptr_tables_[i]->rwlock_->UNLock(); } + if (!FLAGS_gpugraph_enable_gpu_direct_access) { walk_to_src(num, total_device, h_left, h_right, reinterpret_cast(d_shard_vals_ptr), val_type_size); @@ -971,10 +1052,21 @@ void HeterComm::pull_sparse(int num, } } - heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len, - val_type_size, stream); + auto d_merged_vals = memory::Alloc(place, uniq_len * val_type_size); + auto d_merged_vals_ptr = reinterpret_cast(d_merged_vals->ptr()); + heter_comm_kernel_->dy_mf_fill_dvals( + d_shard_vals_ptr, d_merged_vals_ptr, + d_idx_ptr, uniq_len, + val_type_size, stream); + sync_stream(stream); + heter_comm_kernel_->unpack_merged_vals( + len, d_keys, + d_merged_vals_ptr, + d_restore_idx_ptr, + d_vals, val_type_size, stream); sync_stream(stream); + if (!FLAGS_gpugraph_enable_gpu_direct_access) { for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 0304ee6812537..abb5cff60f5f8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -259,6 +259,80 @@ __global__ void shrink_keys_kernel( d_segments_keys[tx] = d_keys[d_segments_offset[tx]]; } +template +__global__ void fill_restore_idx_kernel( + const T *d_sorted_idx, + const T *d_offset, + const T *d_merged_cnts, + const KeyType *d_merged_keys, + T *d_restore_idx, + size_t n) { + const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; + if (tx >= n) { + return; + } + + const KeyType & key = d_merged_keys[tx]; + if (key == 0) { + return; + } + + const T &off = d_offset[tx]; + const T &num = d_merged_cnts[tx]; + for (size_t k = 0; k < num; ++k) { + d_restore_idx[d_sorted_idx[off + k]] = tx; + } +} + +template +__global__ void unpack_merged_vals_kernel( + const KeyType* d_keys, + const float* d_merged_vals, + const uint32_t* d_restored_idx, + float* d_out, size_t val_size, const size_t n, + CommonFeatureValueAccessor feature_value_accessor) { + const size_t tx = blockIdx.x * blockDim.x + threadIdx.x; + if (tx >= n) { + return; + } + + size_t src_val_idx = 0; + const KeyType & key = d_keys[tx]; + if (key != 0) { + src_val_idx = d_restored_idx[tx]; + } + + uint64_t dst_offset = uint64_t(tx) * val_size; + float* dst = (float*)((char*)d_out + dst_offset); + float* src_val = (float*)((char*)d_merged_vals + uint64_t(src_val_idx) * val_size); + int mf_dim = int(src_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + + *(reinterpret_cast(dst + feature_value_accessor.common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(src_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); + dst[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + src_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; + dst[feature_value_accessor.common_feature_value.ShowIndex()] = + src_val[feature_value_accessor.common_feature_value.ShowIndex()]; + dst[feature_value_accessor.common_feature_value.ClickIndex()] = + src_val[feature_value_accessor.common_feature_value.ClickIndex()]; + dst[feature_value_accessor.common_feature_value.EmbedWIndex()] = + src_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; + for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { + dst[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = + src_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; + } + dst[feature_value_accessor.common_feature_value.SlotIndex()] = + src_val[feature_value_accessor.common_feature_value.SlotIndex()]; + dst[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim; + dst[feature_value_accessor.common_feature_value.MfSizeIndex()] = + src_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; + + for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + dst[x] = src_val[x]; + } +} + template __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, T* idx, size_t len, size_t val_size, @@ -450,11 +524,31 @@ void HeterCommKernel::expand_segments(const uint32_t* d_fea_num_info, template void HeterCommKernel::shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset, KeyType* d_segments_keys, size_t n, const StreamType& stream) { - int grid_size = (n - 1) / block_size_ + 1; + int grid_size = (n - 1) / block_size_ + 1; shrink_keys_kernel<<>>( d_keys, d_segments_offset, d_segments_keys, n); } +template +void HeterCommKernel::fill_restore_idx( + const uint32_t* d_sorted_idx, const uint32_t* d_offset, + const uint32_t* d_merged_cnts, const KeyType* d_merged_keys, + const size_t n, uint32_t *d_restore_idx, const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + fill_restore_idx_kernel<<>>( + d_sorted_idx, d_offset, d_merged_cnts, d_merged_keys, d_restore_idx, n); +} + +template +void HeterCommKernel::unpack_merged_vals(size_t n, const KeyType* d_keys, + const void* d_merged_vals, const uint32_t* d_restore_idx, + void* d_vals, size_t val_size, const StreamType& stream) { + int grid_size = (n - 1) / block_size_ + 1; + unpack_merged_vals_kernel<<>>( + d_keys, (const float *)d_merged_vals, d_restore_idx, + (float *)d_vals, val_size, n, feature_value_accessor_); +} + template void HeterCommKernel::fill_idx( int* idx, long long len, const cudaStream_t& stream); template void HeterCommKernel::fill_idx( @@ -561,6 +655,26 @@ template void HeterCommKernel::shrink_keys( template void HeterCommKernel::shrink_keys( const uint64_t* d_keys, const uint32_t* d_segments, uint64_t* d_segments_keys, size_t total_segment_num, const cudaStream_t& stream); + +template void HeterCommKernel::fill_restore_idx( + const uint32_t* d_sorted_idx, const uint32_t* d_offset, + const uint32_t* d_merged_cnts, const uint64_t* d_merged_keys, + const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream); + +template void HeterCommKernel::fill_restore_idx( + const uint32_t* d_sorted_idx, const uint32_t* d_offset, + const uint32_t* d_merged_cnts, const uint32_t* d_merged_keys, + const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream); + +template void HeterCommKernel::unpack_merged_vals( + size_t n, const uint64_t* d_keys, const void* d_merged_vals, + const uint32_t* d_restore_idx, void* d_vals, size_t val_size, + const cudaStream_t& stream); + +template void HeterCommKernel::unpack_merged_vals( + size_t n, const uint32_t* d_keys, const void* d_merged_vals, + const uint32_t* d_restore_idx, void* d_vals, size_t val_size, + const cudaStream_t& stream); #endif } // namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 0931c4e1cd34a..1cde86e64b6bc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -201,10 +201,22 @@ class HeterCommKernel { void shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset, KeyType* d_segments_keys, size_t segments_num, const StreamType& stream); - CommonFeatureValueAccessor feature_value_accessor_; + template + void fill_restore_idx(const uint32_t* d_sorted_idx, const uint32_t* d_offset, + const uint32_t* d_merged_cnts, const KeyType* d_merged_keys, + const size_t len, uint32_t* d_restore_idx, const StreamType& stream); + + template + void unpack_merged_vals(size_t n, + const KeyType* d_keys, + const void* d_merged_vals, + const uint32_t* d_restore_idx, + void* d_vals, size_t val_size, + const StreamType& stream); private: int block_size_{256}; + CommonFeatureValueAccessor feature_value_accessor_; }; } // end namespace framework