Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize pullsparse and increase keys aggregation #61

Merged
merged 17 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class HeterComm {

void split_input_to_shard(KeyType* d_keys, int* d_idx_ptr, size_t len,
int* left, int* right, int gpu_num);
void merge_keys(int gpu_num, const KeyType* d_keys, size_t len,
KeyType* d_sorted_keys,
KeyType* d_merged_keys,
uint32_t* d_restore_idx,
size_t & uniq_len);
void merge_grad(int gpu_num, KeyType* d_keys, GradType* d_grads, size_t len,
int& uniq_len); // NOLINT
void dynamic_merge_grad(int gpu_num, KeyType* d_keys, float* d_grads,
Expand Down
120 changes: 106 additions & 14 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,23 +635,23 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(

auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_merge_keys_ptr = reinterpret_cast<KeyType*>(d_merge_keys->ptr());

auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1));
uint32_t* d_fea_num_info_ptr =
reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
uint32_t* d_index = (uint32_t*)&d_fea_num_info_ptr[len];
uint32_t* d_idx = (uint32_t*)&d_index[len];
int* d_merged_size = (int*)&d_idx[len];
heter_comm_kernel_->fill_idx(d_idx, len, stream);

PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, d_keys, d_merge_keys_ptr, d_idx, d_index, len,
0, 8 * sizeof(KeyType), stream));
void* d_buff = NULL;
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_merge_keys_ptr,
d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
NULL, temp_storage_bytes, d_merge_keys_ptr, d_keys, d_fea_num_info_ptr,
Expand Down Expand Up @@ -853,6 +853,73 @@ void HeterComm<KeyType, ValType, GradType>::split_input_to_shard(
sync_stream(stream);
}

template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::merge_keys(
    int gpu_num, const KeyType* d_keys, size_t len,    // input
    KeyType* d_sorted_keys,                            // output
    KeyType* d_merged_keys,                            // output
    uint32_t* d_restore_idx,                           // output
    size_t& uniq_len) {                                // output
  // Deduplicate `d_keys` on device `gpu_num`:
  //   d_sorted_keys : d_keys sorted ascending
  //   d_merged_keys : the unique keys (first `uniq_len` entries are valid)
  //   d_restore_idx : for each input position, the index of its key inside
  //                   d_merged_keys (scatter map used later to restore
  //                   per-key values back to the original order)
  //   uniq_len      : number of unique keys
  int dev_id = resource_->dev_id(gpu_num);
  platform::CUDAPlace place = platform::CUDAPlace(dev_id);
  platform::CUDADeviceGuard guard(dev_id);
  auto stream = resource_->local_stream(gpu_num, 0);

  // One scratch allocation carved into five regions:
  //   [counts(len) | idx(len) | sorted_idx(len) | offsets(len) | merged_size(1)]
  auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 4 + 1));
  uint32_t* d_fea_num_info_ptr =
      reinterpret_cast<uint32_t*>(d_fea_num_info->ptr());
  uint32_t* d_idx = &d_fea_num_info_ptr[len];
  uint32_t* d_index = &d_idx[len];
  uint32_t* d_offset = &d_index[len];
  uint32_t* d_merged_size = &d_offset[len];
  heter_comm_kernel_->fill_idx(d_idx, len, stream);

  // Sort (key, original-position) pairs so duplicates become adjacent.
  size_t temp_storage_bytes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
      NULL, temp_storage_bytes, d_keys, d_sorted_keys, d_idx, d_index, len,
      0, 8 * sizeof(KeyType), stream));
  auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRadixSort::SortPairs(
      d_temp_storage->ptr(), temp_storage_bytes, d_keys, d_sorted_keys,
      d_idx, d_index, len, 0, 8 * sizeof(KeyType), stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

  // Run-length encode: unique keys + per-key duplicate counts.
  temp_storage_bytes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
      NULL, temp_storage_bytes, d_sorted_keys, d_merged_keys,
      d_fea_num_info_ptr, d_merged_size, len, stream));
  if (d_temp_storage->size() < temp_storage_bytes) {
    // Release before re-allocating to keep peak device memory low.
    d_temp_storage = nullptr;
    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
  }
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceRunLengthEncode::Encode(
      d_temp_storage->ptr(), temp_storage_bytes, d_sorted_keys, d_merged_keys,
      d_fea_num_info_ptr, d_merged_size, len, stream));
  // Copy the 32-bit device count into a host temp of matching width, then
  // widen. Copying sizeof(int) bytes straight into the size_t reference
  // would leave its upper four bytes whatever the caller had there.
  uint32_t h_merged_size = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemcpyAsync(&h_merged_size, d_merged_size, sizeof(uint32_t),
                      cudaMemcpyDeviceToHost, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
  uniq_len = h_merged_size;

  // Exclusive prefix-sum of the counts gives each unique key's start
  // offset inside the sorted array.
  temp_storage_bytes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
      NULL, temp_storage_bytes, d_fea_num_info_ptr, d_offset, uniq_len,
      stream));
  if (d_temp_storage->size() < temp_storage_bytes) {
    d_temp_storage = nullptr;
    d_temp_storage = memory::Alloc(place, temp_storage_bytes);
  }
  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceScan::ExclusiveSum(
      d_temp_storage->ptr(), temp_storage_bytes, d_fea_num_info_ptr, d_offset,
      uniq_len, stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));

  heter_comm_kernel_->fill_restore_idx(
      d_index, d_offset, d_fea_num_info_ptr, d_merged_keys, uniq_len,
      d_restore_idx, stream);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
}

template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
KeyType* d_keys,
Expand Down Expand Up @@ -897,21 +964,33 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
XPUAPIErrorMsg[r2]));
#endif

auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_));
VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size;
auto d_sorted_keys = memory::Alloc(place, len * sizeof(KeyType));
auto d_sorted_keys_ptr = reinterpret_cast<KeyType*>(d_sorted_keys->ptr());
auto d_merged_keys = memory::Alloc(place, len * sizeof(KeyType));
auto d_merged_keys_ptr = reinterpret_cast<KeyType*>(d_merged_keys->ptr());
auto d_restore_idx = memory::Alloc(place, len * sizeof(uint32_t));
auto d_restore_idx_ptr = reinterpret_cast<uint32_t*>(d_restore_idx->ptr());
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
auto d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, len * val_type_size);
float* d_shard_vals_ptr = reinterpret_cast<float*>(d_shard_vals->ptr());

split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num);

heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, d_keys, d_idx_ptr, len,
stream);
auto d_shard_vals_ptr = reinterpret_cast<float*>(d_shard_vals->ptr());

size_t uniq_len = 0;
merge_keys(num, d_keys, len,
d_sorted_keys_ptr,
d_merged_keys_ptr,
d_restore_idx_ptr,
uniq_len);
sync_stream(stream);

auto d_idx = memory::Alloc(place, uniq_len * sizeof(int));
auto d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
split_input_to_shard(d_merged_keys_ptr, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, num);
heter_comm_kernel_->fill_shard_key(
d_shard_keys_ptr, d_merged_keys_ptr, d_idx_ptr, uniq_len,
stream);
sync_stream(stream);

auto dst_place = platform::CPUPlace();
Expand All @@ -933,6 +1012,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);
}

for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1) {
continue;
Expand Down Expand Up @@ -962,6 +1042,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
ptr_tables_[i]->rwlock_->UNLock();
}

if (!FLAGS_gpugraph_enable_gpu_direct_access) {
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
Expand All @@ -971,10 +1052,21 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
}

heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
val_type_size, stream);
auto d_merged_vals = memory::Alloc(place, uniq_len * val_type_size);
auto d_merged_vals_ptr = reinterpret_cast<float*>(d_merged_vals->ptr());
heter_comm_kernel_->dy_mf_fill_dvals(
d_shard_vals_ptr, d_merged_vals_ptr,
d_idx_ptr, uniq_len,
val_type_size, stream);
sync_stream(stream);

heter_comm_kernel_->unpack_merged_vals(
len, d_keys,
d_merged_vals_ptr,
d_restore_idx_ptr,
d_vals, val_type_size, stream);
sync_stream(stream);

if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
Expand Down
116 changes: 115 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,80 @@ __global__ void shrink_keys_kernel(
d_segments_keys[tx] = d_keys[d_segments_offset[tx]];
}

template<typename KeyType, typename T>
__global__ void fill_restore_idx_kernel(
    const T *d_sorted_idx,
    const T *d_offset,
    const T *d_merged_cnts,
    const KeyType *d_merged_keys,
    T *d_restore_idx,
    size_t n) {
  // One thread per unique key: scatter this key's merged position into
  // every original input slot that held it. A key equal to 0 is skipped
  // (its restore entries are left untouched).
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n || d_merged_keys[i] == 0) {
    return;
  }

  const T begin = d_offset[i];
  const T end = begin + d_merged_cnts[i];
  for (T pos = begin; pos < end; ++pos) {
    d_restore_idx[d_sorted_idx[pos]] = i;
  }
}

template<typename KeyType>
__global__ void unpack_merged_vals_kernel(
    const KeyType* d_keys,
    const float* d_merged_vals,
    const uint32_t* d_restored_idx,
    float* d_out, size_t val_size, const size_t n,
    CommonFeatureValueAccessor feature_value_accessor) {
  // One thread per original key: copy the deduplicated value at
  // d_restored_idx[tx] back into the tx-th output slot. Slots whose key is
  // 0 have no restore entry and read merged value 0 as a placeholder.
  const size_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  if (tx >= n) {
    return;
  }

  size_t src_val_idx = 0;
  const KeyType & key = d_keys[tx];
  if (key != 0) {
    src_val_idx = d_restored_idx[tx];
  }

  uint64_t dst_offset = uint64_t(tx) * val_size;
  float* dst = reinterpret_cast<float*>(
      reinterpret_cast<char*>(d_out) + dst_offset);
  // Keep const on the merged-value buffer instead of C-casting it away.
  const float* src_val = reinterpret_cast<const float*>(
      reinterpret_cast<const char*>(d_merged_vals) +
      uint64_t(src_val_idx) * val_size);
  auto& cv = feature_value_accessor.common_feature_value;
  int mf_dim = int(src_val[cv.MfDimIndex()]);

  // CpuPtr is a 64-bit value stored across two float slots.
  *(reinterpret_cast<uint64_t*>(dst + cv.CpuPtrIndex())) =
      *(reinterpret_cast<const uint64_t*>(src_val + cv.CpuPtrIndex()));
  dst[cv.DeltaScoreIndex()] = src_val[cv.DeltaScoreIndex()];
  dst[cv.ShowIndex()] = src_val[cv.ShowIndex()];
  dst[cv.ClickIndex()] = src_val[cv.ClickIndex()];
  dst[cv.EmbedWIndex()] = src_val[cv.EmbedWIndex()];
  for (int i = 0; i < cv.EmbedDim(); i++) {
    dst[cv.EmbedG2SumIndex() + i] = src_val[cv.EmbedG2SumIndex() + i];
  }
  dst[cv.SlotIndex()] = src_val[cv.SlotIndex()];
  dst[cv.MfDimIndex()] = mf_dim;
  dst[cv.MfSizeIndex()] = src_val[cv.MfSizeIndex()];

  // Tail of the value is variable-length (depends on mf_dim); copy it all.
  for (int x = cv.EmbedxG2SumIndex();
       x < int(cv.Size(mf_dim) / sizeof(float)); x++) {
    dst[x] = src_val[x];
  }
}

template <typename T>
__global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals,
T* idx, size_t len, size_t val_size,
Expand Down Expand Up @@ -450,11 +524,31 @@ void HeterCommKernel::expand_segments(const uint32_t* d_fea_num_info,
template <typename KeyType, typename StreamType>
void HeterCommKernel::shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset,
                                  KeyType* d_segments_keys, size_t n, const StreamType& stream) {
  // Gather the key at each segment's start offset; one thread per segment.
  if (n == 0) {
    return;  // (n - 1) on size_t would underflow and overflow grid_size
  }
  int grid_size = (n - 1) / block_size_ + 1;
  shrink_keys_kernel<<<grid_size, block_size_, 0, stream>>>(
      d_keys, d_segments_offset, d_segments_keys, n);
}

template <typename KeyType, typename StreamType>
void HeterCommKernel::fill_restore_idx(
    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
    const uint32_t* d_merged_cnts, const KeyType* d_merged_keys,
    const size_t n, uint32_t *d_restore_idx, const StreamType& stream) {
  // Launch one thread per unique key to build the original-position ->
  // merged-position scatter map (see fill_restore_idx_kernel).
  if (n == 0) {
    return;  // (n - 1) on size_t would underflow and overflow grid_size
  }
  int grid_size = (n - 1) / block_size_ + 1;
  fill_restore_idx_kernel<<<grid_size, block_size_, 0, stream>>>(
      d_sorted_idx, d_offset, d_merged_cnts, d_merged_keys, d_restore_idx, n);
}

template <typename KeyType, typename StreamType>
void HeterCommKernel::unpack_merged_vals(size_t n, const KeyType* d_keys,
    const void* d_merged_vals, const uint32_t* d_restore_idx,
    void* d_vals, size_t val_size, const StreamType& stream) {
  // Launch one thread per original key to scatter deduplicated values back
  // to per-key order (see unpack_merged_vals_kernel).
  if (n == 0) {
    return;  // (n - 1) on size_t would underflow and overflow grid_size
  }
  int grid_size = (n - 1) / block_size_ + 1;
  unpack_merged_vals_kernel<<<grid_size, block_size_, 0, stream>>>(
      d_keys, (const float *)d_merged_vals, d_restore_idx,
      (float *)d_vals, val_size, n, feature_value_accessor_);
}

template void HeterCommKernel::fill_idx<int, cudaStream_t>(
int* idx, long long len, const cudaStream_t& stream);
template void HeterCommKernel::fill_idx<uint32_t, cudaStream_t>(
Expand Down Expand Up @@ -561,6 +655,26 @@ template void HeterCommKernel::shrink_keys<uint32_t, cudaStream_t>(
template void HeterCommKernel::shrink_keys<uint64_t, cudaStream_t>(
const uint64_t* d_keys, const uint32_t* d_segments,
uint64_t* d_segments_keys, size_t total_segment_num, const cudaStream_t& stream);

// Explicit instantiations for the key widths HeterComm uses
// (64-bit feature signs and 32-bit keys) on CUDA streams.
template void HeterCommKernel::fill_restore_idx<uint64_t, cudaStream_t>(
    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
    const uint32_t* d_merged_cnts, const uint64_t* d_merged_keys,
    const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream);

template void HeterCommKernel::fill_restore_idx<uint32_t, cudaStream_t>(
    const uint32_t* d_sorted_idx, const uint32_t* d_offset,
    const uint32_t* d_merged_cnts, const uint32_t* d_merged_keys,
    const size_t n, uint32_t* d_restore_idx, const cudaStream_t& stream);

template void HeterCommKernel::unpack_merged_vals<uint64_t, cudaStream_t>(
    size_t n, const uint64_t* d_keys, const void* d_merged_vals,
    const uint32_t* d_restore_idx, void* d_vals, size_t val_size,
    const cudaStream_t& stream);

template void HeterCommKernel::unpack_merged_vals<uint32_t, cudaStream_t>(
    size_t n, const uint32_t* d_keys, const void* d_merged_vals,
    const uint32_t* d_restore_idx, void* d_vals, size_t val_size,
    const cudaStream_t& stream);
#endif

} // namespace framework
Expand Down
14 changes: 13 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,22 @@ class HeterCommKernel {
void shrink_keys(const KeyType* d_keys, const uint32_t* d_segments_offset,
KeyType* d_segments_keys, size_t segments_num, const StreamType& stream);

CommonFeatureValueAccessor feature_value_accessor_;
template <typename KeyType, typename StreamType>
void fill_restore_idx(const uint32_t* d_sorted_idx, const uint32_t* d_offset,
const uint32_t* d_merged_cnts, const KeyType* d_merged_keys,
const size_t len, uint32_t* d_restore_idx, const StreamType& stream);

template <typename KeyType, typename StreamType>
void unpack_merged_vals(size_t n,
const KeyType* d_keys,
const void* d_merged_vals,
const uint32_t* d_restore_idx,
void* d_vals, size_t val_size,
const StreamType& stream);

private:
int block_size_{256};
CommonFeatureValueAccessor feature_value_accessor_;
};

} // end namespace framework
Expand Down