Optimizing the zero key problem in the push phase #40
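Summary of the change: merge_gradients_kernel now receives the de-duplicated shard keys and skips the per-duplicate merge loop when a key is 0, so key 0 keeps only its first gradient (written by update_one) instead of merging what is typically a very long duplicate list. The patch also replaces the direct_access_ member with a new exported flag, FLAGS_gpugraph_enable_gpu_direct_access (default false), which gates the staged-copy paths in pull_sparse and push_sparse.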

Merged · 1 commit · Jun 21, 2022
21 changes: 11 additions & 10 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -22,6 +22,7 @@ limitations under the License. */
#endif

DECLARE_double(gpugraph_hbm_table_load_factor);
+DECLARE_bool(gpugraph_enable_gpu_direct_access);

namespace paddle {
namespace framework {
@@ -682,7 +683,7 @@ void HeterComm<KeyType, ValType, GradType>::dynamic_merge_grad(
uniq_len, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
heter_comm_kernel_->merge_gradient(
-    d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
+    d_keys, d_offset, d_fea_num_info_ptr, d_index, (char*)d_grads,
(char*)d_merge_grads_ptr, uniq_len, grad_value_size, merger_, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr,
@@ -802,7 +803,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
@@ -818,12 +819,12 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
continue;
}
auto& node = path_[num][i].nodes_.back();
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
sync_stream(node.in_stream);
}
AnyDeviceGuard guard(resource_->dev_id(i));
ptr_tables_[i]->rwlock_->RDLock();
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
@@ -842,7 +843,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
ptr_tables_[i]->rwlock_->UNLock();
}
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
for (int i = 0; i < total_device; ++i) {
@@ -855,7 +856,7 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
val_type_size, stream);

sync_stream(stream);
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
@@ -946,7 +947,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
@@ -965,13 +966,13 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
continue;
}
auto& node = path_[dev_num][i].nodes_.back();
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
sync_stream(node.in_stream);
}

AnyDeviceGuard guard(resource_->dev_id(i));
ptr_tables_[i]->rwlock_->WRLock();
-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
sgd, resource_->remote_stream(i, dev_num));
@@ -995,7 +996,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
}
}

-if (!direct_access_) {
+if (!FLAGS_gpugraph_enable_gpu_direct_access) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
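For readers skimming the heter_comm_inl.h changes: the new flag simply replaces the old direct_access_ member everywhere it gated the staged-copy path. A minimal sketch of that pattern, assuming a gflags-style build — lookup(), staged_query(), and direct_query() are hypothetical stand-ins for the pull_sparse/push_sparse machinery, not PaddlePaddle APIs:

```cpp
#include <cstdio>
#include <gflags/gflags.h>

// Stand-in for the flag exported in paddle/fluid/platform/flags.cc below.
DEFINE_bool(gpugraph_enable_gpu_direct_access, false,
            "access remote GPU tables directly instead of staging copies");

// Hypothetical stand-ins for walk_to_dest/walk_to_src plus the table call.
static void staged_query() { std::puts("copy keys over, query, copy values back"); }
static void direct_query() { std::puts("query the remote table via peer access"); }

void lookup() {
  if (!FLAGS_gpugraph_enable_gpu_direct_access) {
    staged_query();  // default path, unchanged from before this PR
  } else {
    direct_query();  // opt-in fast path over peer-to-peer access
  }
}
```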
28 changes: 21 additions & 7 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
@@ -146,7 +148,9 @@ __global__ void dy_mf_fill_shard_grads_kernel(
}
}

-__global__ void merge_gradients_kernel(const uint32_t* offset,
+template <typename KeyType>
+__global__ void merge_gradients_kernel(const KeyType* d_keys,
+                                       const uint32_t* offset,
const uint32_t* fea_num,
const uint32_t* index, const char* input,
char* output, int n,
@@ -163,10 +165,13 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
float* in =
(float*)(input + size_t(ori_index) * grad_value_size);
merger_.update_one(out, in, feature_value_accessor);
-for (int j = 1; j < num; ++j) {
-  ori_index = index[start + j];
-  in = (float*)(input + size_t(ori_index) * grad_value_size);
-  merger_.merge_one(out, in, feature_value_accessor);
+KeyType key = d_keys[i];
+if (key != 0) {
+  for (int j = 1; j < num; ++j) {
+    ori_index = index[start + j];
+    in = (float*)(input + size_t(ori_index) * grad_value_size);
+    merger_.merge_one(out, in, feature_value_accessor);
+  }
+}
}
}
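The heart of the PR is the `if (key != 0)` guard above: key 0 still gets its first duplicate's gradient via update_one, and only the merge of the remaining duplicates is skipped — presumably because key 0's duplicate list is by far the largest, which is what made the push phase slow. A self-contained sketch of the same idea with simplified fixed-size float gradients (the real kernel uses DynamicGradMerger and variable-length values):

```cuda
#include <cstdint>

// n de-duplicated keys; the duplicates of key i live at
// index[offset[i] .. offset[i] + fea_num[i]).
__global__ void merge_skip_zero(const uint64_t* keys, const uint32_t* offset,
                                const uint32_t* fea_num, const uint32_t* index,
                                const float* grads_in, float* grads_out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  const uint32_t start = offset[i];
  float acc = grads_in[index[start]];  // first duplicate always contributes
  if (keys[i] != 0) {                  // key 0: skip the (often huge) merge loop
    for (uint32_t j = 1; j < fea_num[i]; ++j) {
      acc += grads_in[index[start + j]];
    }
  }
  grads_out[i] = acc;
}
```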
@@ -316,13 +321,15 @@ void HeterCommKernel::dy_mf_fill_shard_grads(
grad_value_size, feature_value_accessor_);
}

-template <typename StreamType>
+template <typename KeyType, typename StreamType>
void HeterCommKernel::merge_gradient(
+    const KeyType* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const StreamType& stream) {
int grid_size = (n - 1) / block_size_ + 1;
merge_gradients_kernel<<<grid_size, block_size_, 0, stream>>>(
+      d_keys,
offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_);
}

@@ -407,7 +414,14 @@ template void HeterCommKernel::dy_mf_fill_shard_grads<
float* d_shard_grads, float* d_grads, int* idx, long long len,
size_t grad_value_size, const cudaStream_t& stream);

-template void HeterCommKernel::merge_gradient<cudaStream_t>(
+template void HeterCommKernel::merge_gradient<uint32_t, cudaStream_t>(
+    const uint32_t* d_keys,
const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
const char* input, char* output, int n, size_t grad_value_size,
DynamicGradMerger& merger_, const cudaStream_t& stream);

+template void HeterCommKernel::merge_gradient<uint64_t, cudaStream_t>(
+    const uint64_t* d_keys,
+    const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index,
+    const char* input, char* output, int n, size_t grad_value_size,
+    DynamicGradMerger& merger_, const cudaStream_t& stream);
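Why the explicit instantiations: merge_gradient is a template defined in a .cu file, so every KeyType/StreamType combination used elsewhere must be instantiated here for the linker to resolve it. The PR therefore stamps out both the uint32_t and uint64_t key variants alongside the existing cudaStream_t stream type.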
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h
@@ -135,8 +135,8 @@ class HeterCommKernel {
T* idx, long long len, size_t grad_value_size,
const StreamType& stream);

-template <typename StreamType>
-void merge_gradient(const uint32_t* offset, const uint32_t* fea_num,
+template <typename KeyType, typename StreamType>
+void merge_gradient(const KeyType* d_shard_keys, const uint32_t* offset, const uint32_t* fea_num,
const uint32_t* index, const char* input, char* output,
int n, size_t grad_value_size, DynamicGradMerger& merger_,
const StreamType& stream);
3 changes: 3 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -857,6 +857,9 @@ PADDLE_DEFINE_EXPORTED_bool(
PADDLE_DEFINE_EXPORTED_double(
gpugraph_hbm_table_load_factor, 0.75,
"the load factor of hbm table, default 0.75");
+PADDLE_DEFINE_EXPORTED_bool(
+    gpugraph_enable_gpu_direct_access, false,
+    "enable direct access to remote gpu memory for the hbm table, default false");

/**
* ProcessGroupNCCL related FLAG
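Usage note (an assumption about the surrounding build, not part of this diff): PADDLE_DEFINE_EXPORTED flags are gflags under the hood, so the new switch can be pulled into any translation unit with DECLARE_bool(gpugraph_enable_gpu_direct_access) — as heter_comm_inl.h does above — and is typically toggled at launch time through the environment, e.g. FLAGS_gpugraph_enable_gpu_direct_access=true. Since it defaults to false, merged behavior is unchanged unless the fast path is explicitly enabled.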