Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge gpugraph new #79

Merged
merged 3 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2133,8 +2133,9 @@ int GraphDataGenerator::FillInferBuf() {
global_infer_node_type_start[infer_cursor] += total_row_;
infer_node_end_ = global_infer_node_type_start[infer_cursor];
cursor_ = infer_cursor;
return 1;
}
return 1;
return 0;
}

void GraphDataGenerator::ClearSampleState() {
Expand Down
21 changes: 17 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
#define ALIGN_INT64(LEN) (uint64_t((LEN) + 7) & uint64_t(~7))
#define HBMPS_MAX_BUFF 1024 * 1024

DECLARE_bool(enable_neighbor_list_use_uva);

namespace paddle {
namespace framework {
/*
Expand Down Expand Up @@ -895,8 +898,14 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g,
gpu_graph_list_[offset].node_size = 0;
}
if (g.neighbor_size) {
cudaError_t cudaStatus = cudaMalloc(&gpu_graph_list_[offset].neighbor_list,
cudaError_t cudaStatus;
if (!FLAGS_enable_neighbor_list_use_uva) {
cudaStatus = cudaMalloc(&gpu_graph_list_[offset].neighbor_list,
g.neighbor_size * sizeof(uint64_t));
} else {
cudaStatus = cudaMallocManaged(&gpu_graph_list_[offset].neighbor_list,
g.neighbor_size * sizeof(uint64_t));
}
PADDLE_ENFORCE_EQ(cudaStatus,
cudaSuccess,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -964,9 +973,13 @@ void GpuPsGraphTable::build_graph_from_cpu(
gpu_graph_list_[offset].node_size = 0;
}
if (cpu_graph_list[i].neighbor_size) {
CUDA_CHECK(
cudaMalloc(&gpu_graph_list_[offset].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(uint64_t)));
if (!FLAGS_enable_neighbor_list_use_uva) {
CUDA_CHECK(cudaMalloc(&gpu_graph_list_[offset].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(uint64_t)));
} else {
CUDA_CHECK(cudaMallocManaged(&gpu_graph_list_[offset].neighbor_list,
cpu_graph_list[i].neighbor_size * sizeof(uint64_t)));
}

CUDA_CHECK(
cudaMemcpyAsync(gpu_graph_list_[offset].neighbor_list,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/core/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,18 @@ PADDLE_DEFINE_EXPORTED_bool(graph_load_in_parallel,
false,
"It controls whether load graph node and edge with "
"mutli threads parallely.");

/**
* Distributed related FLAG
* Name: FLAGS_enable_neighbor_list_use_uva
* Since Version: 2.2.0
* Value Range: bool, default=false
* Example:
* Note: Control whether store neighbor_list with UVA
*/
PADDLE_DEFINE_EXPORTED_bool(enable_neighbor_list_use_uva,
false,
"It controls whether store neighbor_list with UVA");

/**
* Distributed related FLAG
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/distributed/ps/the_one_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,8 @@ def _save_cache_model(self, dirname, **kwargs):
return feasign_num

def _save_cache_table(self, table_id, pass_id, mem_cache_key_threshold):
if self.role_maker._is_first_worker():
fleet.util.barrier()
if self.context['use_ps_gpu'] or self.role_maker._is_first_worker():
self._worker.save_cache_table(
table_id, pass_id, mem_cache_key_threshold
)
Expand Down