diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
index 19c355c671a38..06ec1567f4b47 100644
--- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
+++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
@@ -123,21 +123,25 @@ node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
 */
 struct NeighborSampleQuery {
   int gpu_id;
-  int64_t *key;
-  int sample_size;
+  int table_idx;
+  int64_t *src_nodes;
   int len;
-  void initialize(int gpu_id, int64_t key, int sample_size, int len) {
+  int sample_size;
+  void initialize(int gpu_id, int table_idx, int64_t src_nodes, int sample_size,
+                  int len) {
+    this->table_idx = table_idx;
     this->gpu_id = gpu_id;
-    this->key = (int64_t *)key;
+    this->src_nodes = (int64_t *)src_nodes;
     this->sample_size = sample_size;
     this->len = len;
   }
   void display() {
     int64_t *sample_keys = new int64_t[len];
     VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size;
-    VLOG(0) << "there are " << len << " keys ";
+    VLOG(0) << "there are " << len << " keys to sample for graph " << table_idx;
     std::string key_str;
-    cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost);
+    cudaMemcpy(sample_keys, src_nodes, len * sizeof(int64_t),
+               cudaMemcpyDeviceToHost);
     for (int i = 0; i < len; i++) {
       if (key_str.size() > 0) key_str += ";";
@@ -212,7 +216,7 @@ struct NeighborSampleResult {
     std::vector<int64_t> graph;
     int64_t *sample_keys = new int64_t[q.len];
     std::string key_str;
-    cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t),
+    cudaMemcpy(sample_keys, q.src_nodes, q.len * sizeof(int64_t),
                cudaMemcpyDeviceToHost);
     int64_t *res = new int64_t[sample_size * key_size];
     cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t),
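A query is now pinned to one edge table through `table_idx`, and `initialize` still takes the device key pointer as an `int64_t` handle. A minimal usage sketch against edge table 0 (hypothetical key values; assumes a CUDA context on device 0):

    // Sketch: sample 4 neighbors for each of 3 source nodes from edge table 0.
    void make_query_example() {
      int64_t h_keys[3] = {0, 1, 2};
      int64_t *d_keys = NULL;
      cudaMalloc(reinterpret_cast<void **>(&d_keys), sizeof(h_keys));
      cudaMemcpy(d_keys, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
      NeighborSampleQuery q;
      q.initialize(/*gpu_id=*/0, /*table_idx=*/0, (int64_t)d_keys,
                   /*sample_size=*/4, /*len=*/3);
      q.display();  // copies the keys back to host and logs them
      cudaFree(d_keys);
    }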
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
index ae57c2ebe932f..1281c5ed8a9ee 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -23,19 +23,35 @@
 #ifdef PADDLE_WITH_HETERPS
 namespace paddle {
 namespace framework {
+enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
 class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
  public:
-  GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
+  int get_table_offset(int gpu_id, GraphTableType type, int idx) {
+    int type_id = type;
+    return gpu_id * (graph_table_num_ + feature_table_num_) +
+           type_id * graph_table_num_ + idx;
+  }
+  GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware,
+                  int graph_table_num, int feature_table_num)
       : HeterComm<uint64_t, int64_t, int>(1, resource) {
     load_factor_ = 0.25;
     rw_lock.reset(new pthread_rwlock_t());
+    this->graph_table_num_ = graph_table_num;
+    this->feature_table_num_ = feature_table_num;
     gpu_num = resource_->total_device();
     memset(global_device_map, -1, sizeof(global_device_map));
+    for (auto &table : tables_) {
+      delete table;
+      table = NULL;
+    }
+    tables_ = std::vector<Table *>(
+        gpu_num * (graph_table_num + feature_table_num), NULL);
+    sample_status = std::vector<int *>(gpu_num * graph_table_num, NULL);
     for (int i = 0; i < gpu_num; i++) {
-      gpu_graph_list.push_back(GpuPsCommGraph());
       global_device_map[resource_->dev_id(i)] = i;
-      sample_status.push_back(NULL);
-      tables_.push_back(NULL);
+      for (int j = 0; j < graph_table_num; j++) {
+        gpu_graph_list_.push_back(GpuPsCommGraph());
+      }
     }
     cpu_table_status = -1;
     if (topo_aware) {
@@ -89,21 +105,23 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
   //     end_graph_sampling();
   //   }
   // }
-  void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id);
-  void clear_graph_info(int gpu_id);
-  void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
+  void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx);
+  void clear_graph_info(int gpu_id, int index);
+  void clear_graph_info(int index);
+  void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list,
+                            int idx);
   NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
   NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
                                                 bool cpu_switch);
-  NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t *key,
+  NeighborSampleResult graph_neighbor_sample(int gpu_id, int idx, int64_t *key,
                                              int sample_size, int len);
-  NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key,
-                                                int sample_size, int len,
-                                                bool cpu_query_switch);
+  NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int idx,
+                                                int64_t *key, int sample_size,
+                                                int len, bool cpu_query_switch);
   void init_sample_status();
   void free_sample_status();
-  NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
-  void clear_graph_info();
+  NodeQueryResult query_node_list(int gpu_id, int idx, int start,
+                                  int query_size);
   void display_sample_res(void *key, void *val, int len, int sample_len);
   void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
                                                  int sample_size, int *h_left,
@@ -112,12 +130,13 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
                                                  int *actual_sample_size);
   int init_cpu_table(const paddle::distributed::GraphParameter &graph);
   int gpu_num;
-  std::vector<GpuPsCommGraph> gpu_graph_list;
+  int graph_table_num_, feature_table_num_;
+  std::vector<GpuPsCommGraph> gpu_graph_list_;
   int global_device_map[32];
   std::vector<int *> sample_status;
   const int parallel_sample_size = 1;
   const int dim_y = 256;
-  std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
+  std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table_;
   std::shared_ptr<pthread_rwlock_t> rw_lock;
   mutable std::mutex mutex_;
   std::condition_variable cv_;
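The new `get_table_offset` flattens a (gpu, table type, index) triple into one slot of `tables_`: each GPU owns a contiguous block of `graph_table_num_ + feature_table_num_` tables, edge tables first. Note that `gpu_graph_list_` and `sample_status` use a different stride, `gpu_id * graph_table_num_ + idx`, because they cover only edge tables. A standalone sketch of the arithmetic (illustrative values, not Paddle code):

    #include <cassert>

    enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };

    // Mirrors get_table_offset: per-GPU blocks of
    // (graph_table_num + feature_table_num) slots, edge tables first.
    int table_offset(int gpu_id, GraphTableType type, int idx,
                     int graph_table_num, int feature_table_num) {
      return gpu_id * (graph_table_num + feature_table_num) +
             static_cast<int>(type) * graph_table_num + idx;
    }

    int main() {
      // Two edge tables and one feature table per GPU:
      // [gpu0: e0 e1 f0][gpu1: e0 e1 f0] ...
      assert(table_offset(0, EDGE_TABLE, 1, 2, 1) == 1);
      assert(table_offset(1, EDGE_TABLE, 0, 2, 1) == 3);
      assert(table_offset(1, FEATURE_TABLE, 0, 2, 1) == 5);
      return 0;
    }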
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index 72b9cae41c0fd..703f43a8b3036 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -197,8 +197,8 @@ __global__ void neighbor_sample_example(GpuPsCommGraph graph,
 
 int GpuPsGraphTable::init_cpu_table(
     const paddle::distributed::GraphParameter& graph) {
-  cpu_graph_table.reset(new paddle::distributed::GraphTable);
-  cpu_table_status = cpu_graph_table->Initialize(graph);
+  cpu_graph_table_.reset(new paddle::distributed::GraphTable);
+  cpu_table_status = cpu_graph_table_->Initialize(graph);
   // if (cpu_table_status != 0) return cpu_table_status;
   // std::function<void(std::vector<GpuPsCommGraph>&)> callback =
   //     [this](std::vector<GpuPsCommGraph>& res) {
@@ -212,17 +212,6 @@ int GpuPsGraphTable::init_cpu_table(
   return cpu_table_status;
 }
 
-// int GpuPsGraphTable::load(const std::string& path, const std::string& param)
-// {
-//   int status = cpu_graph_table->load(path, param);
-//   if (status != 0) {
-//     return status;
-//   }
-//   std::unique_lock<std::mutex> lock(mutex_);
-//   cpu_graph_table->start_graph_sampling();
-//   cv_.wait(lock);
-//   return 0;
-// }
 /*
  comment 1
  gpu i triggers a neighbor_sample task,
@@ -445,11 +434,14 @@ __global__ void node_query_example(GpuPsCommGraph graph, int start, int size,
   }
 }
 
-void GpuPsGraphTable::clear_graph_info(int gpu_id) {
-  if (tables_.size() && tables_[gpu_id] != NULL) {
-    delete tables_[gpu_id];
+void GpuPsGraphTable::clear_graph_info(int gpu_id, int idx) {
+  if (idx >= graph_table_num_) return;
+  int offset = get_table_offset(gpu_id, GraphTableType::EDGE_TABLE, idx);
+  if (offset < tables_.size()) {
+    delete tables_[offset];
+    tables_[offset] = NULL;
   }
-  auto& graph = gpu_graph_list[gpu_id];
+  auto& graph = gpu_graph_list_[gpu_id * graph_table_num_ + idx];
   if (graph.neighbor_list != NULL) {
     cudaFree(graph.neighbor_list);
   }
@@ -457,20 +449,8 @@ void GpuPsGraphTable::clear_graph_info(int gpu_id) {
     cudaFree(graph.node_list);
   }
 }
-void GpuPsGraphTable::clear_graph_info() {
-  if (tables_.size()) {
-    for (auto table : tables_) delete table;
-  }
-  tables_.clear();
-  for (auto graph : gpu_graph_list) {
-    if (graph.neighbor_list != NULL) {
-      cudaFree(graph.neighbor_list);
-    }
-    if (graph.node_list != NULL) {
-      cudaFree(graph.node_list);
-    }
-  }
-  gpu_graph_list.clear();
+void GpuPsGraphTable::clear_graph_info(int idx) {
+  for (int i = 0; i < gpu_num; i++) clear_graph_info(i, idx);
 }
 /*
 the parameter std::vector<GpuPsCommGraph> cpu_graph_list is generated by cpu.
@@ -481,34 +461,38 @@
 In this function, memory is allocated on each gpu to save the graphs,
 gpu i saves the ith graph from cpu_graph_list
 */
-void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) {
-  clear_graph_info(i);
+void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i,
+                                                int idx) {
+  clear_graph_info(i, idx);
   platform::CUDADeviceGuard guard(resource_->dev_id(i));
-  // platform::CUDADeviceGuard guard(i);
-  gpu_graph_list[i] = GpuPsCommGraph();
-  sample_status[i] = NULL;
-  tables_[i] = new Table(std::max((int64_t)1, g.node_size) / load_factor_);
+  int offset = i * graph_table_num_ + idx;
+  gpu_graph_list_[offset] = GpuPsCommGraph();
+  sample_status[offset] = NULL;
+  int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
+  size_t capacity = std::max((int64_t)1, g.node_size) / load_factor_;
+  tables_[table_offset] = new Table(capacity);
   if (g.node_size > 0) {
     std::vector<int64_t> keys;
-    std::vector<int64_t> offset;
-    cudaMalloc((void**)&gpu_graph_list[i].node_list,
+    std::vector<int64_t> offsets;
+    cudaMalloc((void**)&gpu_graph_list_[offset].node_list,
                g.node_size * sizeof(GpuPsGraphNode));
-    cudaMemcpy(gpu_graph_list[i].node_list, g.node_list,
+    cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list,
                g.node_size * sizeof(GpuPsGraphNode), cudaMemcpyHostToDevice);
     for (int64_t j = 0; j < g.node_size; j++) {
       keys.push_back(g.node_list[j].node_id);
-      offset.push_back(j);
+      offsets.push_back(j);
     }
-    build_ps(i, (uint64_t*)keys.data(), offset.data(), keys.size(), 1024, 8);
-    gpu_graph_list[i].node_size = g.node_size;
+    build_ps(i, (uint64_t*)keys.data(), offsets.data(), keys.size(), 1024, 8,
+             table_offset);
+    gpu_graph_list_[offset].node_size = g.node_size;
   } else {
-    build_ps(i, NULL, NULL, 0, 1024, 8);
-    gpu_graph_list[i].node_list = NULL;
-    gpu_graph_list[i].node_size = 0;
+    build_ps(i, NULL, NULL, 0, 1024, 8, table_offset);
+    gpu_graph_list_[offset].node_list = NULL;
+    gpu_graph_list_[offset].node_size = 0;
   }
   if (g.neighbor_size) {
     cudaError_t cudaStatus =
-        cudaMalloc((void**)&gpu_graph_list[i].neighbor_list,
+        cudaMalloc((void**)&gpu_graph_list_[offset].neighbor_list,
                    g.neighbor_size * sizeof(int64_t));
     PADDLE_ENFORCE_EQ(cudaStatus, cudaSuccess,
                       platform::errors::InvalidArgument(
@@ -516,82 +500,92 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i) {
     VLOG(0) << "successfully allocate " << g.neighbor_size * sizeof(int64_t)
             << " bytes of memory for graph-edges on gpu "
             << resource_->dev_id(i);
-    cudaMemcpy(gpu_graph_list[i].neighbor_list, g.neighbor_list,
+    cudaMemcpy(gpu_graph_list_[offset].neighbor_list, g.neighbor_list,
                g.neighbor_size * sizeof(int64_t), cudaMemcpyHostToDevice);
-    gpu_graph_list[i].neighbor_size = g.neighbor_size;
+    gpu_graph_list_[offset].neighbor_size = g.neighbor_size;
   } else {
-    gpu_graph_list[i].neighbor_list = NULL;
-    gpu_graph_list[i].neighbor_size = 0;
+    gpu_graph_list_[offset].neighbor_list = NULL;
+    gpu_graph_list_[offset].neighbor_size = 0;
   }
 }
 
 void GpuPsGraphTable::init_sample_status() {
   for (int i = 0; i < gpu_num; i++) {
-    if (gpu_graph_list[i].neighbor_size) {
-      platform::CUDADeviceGuard guard(resource_->dev_id(i));
-      int* addr;
-      cudaMalloc((void**)&addr, gpu_graph_list[i].neighbor_size * sizeof(int));
-      cudaMemset(addr, 0, gpu_graph_list[i].neighbor_size * sizeof(int));
-      sample_status[i] = addr;
+    for (int j = 0; j < graph_table_num_; j++) {
+      int offset = i * graph_table_num_ + j;
+      if (gpu_graph_list_[offset].neighbor_size) {
+        platform::CUDADeviceGuard guard(resource_->dev_id(i));
+        int* addr;
+        cudaMalloc((void**)&addr,
+                   gpu_graph_list_[offset].neighbor_size * sizeof(int));
+        cudaMemset(addr, 0,
+                   gpu_graph_list_[offset].neighbor_size * sizeof(int));
+        sample_status[offset] = addr;
+      }
     }
   }
 }
 
 void GpuPsGraphTable::free_sample_status() {
   for (int i = 0; i < gpu_num; i++) {
-    if (sample_status[i] != NULL) {
-      platform::CUDADeviceGuard guard(resource_->dev_id(i));
-      cudaFree(sample_status[i]);
+    for (int j = 0; j < graph_table_num_; j++) {
+      int offset = i * graph_table_num_ + j;
+      if (sample_status[offset] != NULL) {
+        platform::CUDADeviceGuard guard(resource_->dev_id(i));
+        cudaFree(sample_status[offset]);
+      }
     }
   }
 }
 void GpuPsGraphTable::build_graph_from_cpu(
-    std::vector<GpuPsCommGraph>& cpu_graph_list) {
+    std::vector<GpuPsCommGraph>& cpu_graph_list, int idx) {
  VLOG(0) << "in build_graph_from_cpu cpu_graph_list size = "
          << cpu_graph_list.size();
   PADDLE_ENFORCE_EQ(
       cpu_graph_list.size(), resource_->total_device(),
       platform::errors::InvalidArgument("the cpu node list size doesn't match "
                                         "the number of gpu on your machine."));
-  clear_graph_info();
+  clear_graph_info(idx);
   for (int i = 0; i < cpu_graph_list.size(); i++) {
+    int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
+    int offset = i * graph_table_num_ + idx;
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    gpu_graph_list[i] = GpuPsCommGraph();
-    sample_status[i] = NULL;
-    tables_[i] = new Table(std::max((int64_t)1, cpu_graph_list[i].node_size) /
-                           load_factor_);
+    gpu_graph_list_[offset] = GpuPsCommGraph();
+    sample_status[offset] = NULL;
+    tables_[table_offset] = new Table(
+        std::max((int64_t)1, cpu_graph_list[i].node_size) / load_factor_);
     if (cpu_graph_list[i].node_size > 0) {
       std::vector<int64_t> keys;
-      std::vector<int64_t> offset;
-      cudaMalloc((void**)&gpu_graph_list[i].node_list,
+      std::vector<int64_t> offsets;
+      cudaMalloc((void**)&gpu_graph_list_[offset].node_list,
                  cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode));
-      cudaMemcpy(gpu_graph_list[i].node_list, cpu_graph_list[i].node_list,
+      cudaMemcpy(gpu_graph_list_[offset].node_list, cpu_graph_list[i].node_list,
                  cpu_graph_list[i].node_size * sizeof(GpuPsGraphNode),
                  cudaMemcpyHostToDevice);
       for (int64_t j = 0; j < cpu_graph_list[i].node_size; j++) {
         keys.push_back(cpu_graph_list[i].node_list[j].node_id);
-        offset.push_back(j);
+        offsets.push_back(j);
       }
-      build_ps(i, (uint64_t*)(keys.data()), offset.data(), keys.size(), 1024,
-               8);
-      gpu_graph_list[i].node_size = cpu_graph_list[i].node_size;
+      build_ps(i, (uint64_t*)(keys.data()), offsets.data(), keys.size(), 1024,
+               8, table_offset);
+      gpu_graph_list_[offset].node_size = cpu_graph_list[i].node_size;
     } else {
-      build_ps(i, NULL, NULL, 0, 1024, 8);
-      gpu_graph_list[i].node_list = NULL;
-      gpu_graph_list[i].node_size = 0;
+      build_ps(i, NULL, NULL, 0, 1024, 8, table_offset);
+      gpu_graph_list_[offset].node_list = NULL;
+      gpu_graph_list_[offset].node_size = 0;
     }
     if (cpu_graph_list[i].neighbor_size) {
-      cudaMalloc((void**)&gpu_graph_list[i].neighbor_list,
+      cudaMalloc((void**)&gpu_graph_list_[offset].neighbor_list,
                  cpu_graph_list[i].neighbor_size * sizeof(int64_t));
-      cudaMemcpy(gpu_graph_list[i].neighbor_list,
+      cudaMemcpy(gpu_graph_list_[offset].neighbor_list,
                  cpu_graph_list[i].neighbor_list,
                  cpu_graph_list[i].neighbor_size * sizeof(int64_t),
                  cudaMemcpyHostToDevice);
-      gpu_graph_list[i].neighbor_size = cpu_graph_list[i].neighbor_size;
+      gpu_graph_list_[offset].neighbor_size = cpu_graph_list[i].neighbor_size;
     } else {
-      gpu_graph_list[i].neighbor_list = NULL;
-      gpu_graph_list[i].neighbor_size = 0;
+      gpu_graph_list_[offset].neighbor_list = NULL;
+      gpu_graph_list_[offset].neighbor_size = 0;
    }
   }
   cudaDeviceSynchronize();
@@ -599,10 +593,11 @@ void GpuPsGraphTable::build_graph_from_cpu(
 
 NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
     NeighborSampleQuery q, bool cpu_switch) {
-  return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.key,
-                                  q.sample_size, q.len, cpu_switch);
+  return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.table_idx,
+                                  q.src_nodes, q.sample_size, q.len,
+                                  cpu_switch);
 }
-NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
+NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int idx,
                                                             int64_t* key,
                                                             int sample_size,
                                                             int len) {
@@ -723,12 +718,14 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
                     node.in_stream);
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
-                    reinterpret_cast<int64_t*>(node.val_storage),
-                    h_right[i] - h_left[i] + 1,
-                    resource_->remote_stream(i, gpu_id));
+    int offset = i * graph_table_num_ + idx;
+    int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
+    tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
+                               reinterpret_cast<int64_t*>(node.val_storage),
+                               h_right[i] - h_left[i] + 1,
+                               resource_->remote_stream(i, gpu_id));
     // node.in_stream);
-    auto graph = gpu_graph_list[i];
+    auto graph = gpu_graph_list_[offset];
     int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
     int* actual_size_array = (int*)(id_array + shard_len);
     int64_t* sample_array =
@@ -739,7 +736,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
     neighbor_sample_example<<<grid, block, 0,
                               resource_->remote_stream(i, gpu_id)>>>(
         graph, id_array, actual_size_array, sample_array, sample_size,
-        sample_status[i], shard_len, gpu_id);
+        sample_status[offset], shard_len, gpu_id);
   }
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -766,7 +763,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
 }
 
 NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
-    int gpu_id, int64_t* key, int sample_size, int len, bool cpu_query_switch) {
+    int gpu_id, int idx, int64_t* key, int sample_size, int len,
+    bool cpu_query_switch) {
   NeighborSampleResult result;
   result.initialize(sample_size, len, resource_->dev_id(gpu_id));
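For each remote shard, `node.val_storage` is one packed buffer that the pointer arithmetic above slices into three arrays: the looked-up `node_list` offsets, the per-key actual sample counts, and the sampled neighbor ids. A sketch of the decoding, under the assumption that the three blocks sit back to back exactly as the casts suggest (hypothetical helper, not in the patch):

    // shard_len keys were routed to this shard.
    void decode_val_storage(char *val_storage, int shard_len) {
      // block 1: per-key offsets into node_list (-1 when the key is absent)
      int64_t *id_array = reinterpret_cast<int64_t *>(val_storage);
      // block 2: how many neighbors were actually sampled for each key
      int *actual_size_array = reinterpret_cast<int *>(id_array + shard_len);
      // block 3: sample_size slots per key holding sampled neighbor ids
      int64_t *sample_array =
          reinterpret_cast<int64_t *>(actual_size_array + shard_len);
      (void)sample_array;
    }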
@@ -844,12 +842,14 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
     // If not found, val is -1.
-    tables_[i]->get(reinterpret_cast<uint64_t*>(node.key_storage),
-                    reinterpret_cast<int64_t*>(node.val_storage),
-                    h_right[i] - h_left[i] + 1,
-                    resource_->remote_stream(i, gpu_id));
-
-    auto graph = gpu_graph_list[i];
+    int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
+    int offset = i * graph_table_num_ + idx;
+    tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
+                               reinterpret_cast<int64_t*>(node.val_storage),
+                               h_right[i] - h_left[i] + 1,
+                               resource_->remote_stream(i, gpu_id));
+
+    auto graph = gpu_graph_list_[offset];
     int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
     int* actual_size_array = (int*)(id_array + shard_len);
     int64_t* sample_array =
@@ -872,7 +872,6 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
     }
     cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
   }
-
   move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
                                             h_left, h_right, d_shard_vals_ptr,
                                             d_shard_actual_sample_size_ptr);
@@ -904,8 +903,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
     std::vector<std::shared_ptr<char>> buffers(number_on_cpu);
     std::vector<int> ac(number_on_cpu);
-    auto status = cpu_graph_table->random_sample_neighbors(
-        0, cpu_keys, sample_size, buffers, ac, false);
+    auto status = cpu_graph_table_->random_sample_neighbors(
+        idx, cpu_keys, sample_size, buffers, ac, false);
 
     int total_cpu_sample_size = std::accumulate(ac.begin(), ac.end(), 0);
     total_cpu_sample_size /= sizeof(int64_t);
@@ -1001,7 +1000,7 @@ NodeQueryResult GpuPsGraphTable::graph_node_sample(int gpu_id,
   return NodeQueryResult();
 }
 
-NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int start,
+NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start,
                                                  int query_size) {
   NodeQueryResult result;
   if (query_size <= 0) return result;
@@ -1009,24 +1008,8 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int start,
   actual_size = 0;
   // int dev_id = resource_->dev_id(gpu_id);
   // platform::CUDADeviceGuard guard(dev_id);
-  std::vector<int> idx, gpu_begin_pos, local_begin_pos;
+  std::vector<int> gpu_begin_pos, local_begin_pos;
   int sample_size;
-  /*
-  if idx[i] = a, gpu_begin_pos[i] = p1,
-  gpu_local_begin_pos[i] = p2;
-  sample_size[i] = s;
-  then on gpu a, the nodes of positions [p1,p1 + s) should be returned
-  and saved from the p2 position on the sample_result array
-  for example:
-  suppose
-  gpu 0 saves [0,2,4,6,8], gpu1 saves [1,3,5,7]
-  start = 3, query_size = 5
-  we know [6,8,1,3,5] should be returned;
-  idx = [0,1]
-  gpu_begin_pos = [3,0]
-  local_begin_pos = [0,3]
-  sample_size = [2,3]
-  */
   std::function<int(int, int, int, int, int&, int&)> range_check = [](
       int x, int y, int x1, int y1, int& x2, int& y2) {
     if (y <= x1 || x >= y1) return 0;
@@ -1034,7 +1017,7 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int start,
     x2 = max(x1, x);
     return y2 - x2;
   };
-  auto graph = gpu_graph_list[gpu_id];
+  auto graph = gpu_graph_list_[gpu_id];
   if (graph.node_size == 0) {
     return result;
   }
@@ -1051,60 +1034,13 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int start,
   val = result.val;
   int dev_id_i = resource_->dev_id(gpu_id);
   platform::CUDADeviceGuard guard(dev_id_i);
-  // platform::CUDADeviceGuard guard(i);
   int grid_size = (len - 1) / block_size_ + 1;
+  int offset = gpu_id * graph_table_num_ + idx;
   node_query_example<<<grid_size, block_size_, 0,
                        resource_->remote_stream(gpu_id, gpu_id)>>>(
-      gpu_graph_list[gpu_id], x2, len, (int64_t*)val);
+      gpu_graph_list_[offset], x2, len, (int64_t*)val);
   cudaStreamSynchronize(resource_->remote_stream(gpu_id, gpu_id));
   return result;
-  /*
-  for (int i = 0; i < gpu_graph_list.size() && query_size != 0; i++) {
-    auto graph = gpu_graph_list[i];
-    if (graph.node_size == 0) {
-      continue;
-    }
-    int x2, y2;
-    int len = range_check(start, start + query_size, size,
-                          size + graph.node_size, x2, y2);
-    if (len > 0) {
-      idx.push_back(i);
-      gpu_begin_pos.emplace_back(x2 - size);
-      local_begin_pos.emplace_back(actual_size);
-      sample_size.push_back(len);
-      actual_size += len;
-      create_storage(gpu_id, i, 1, len * sizeof(int64_t));
-    }
-    size += graph.node_size;
-  }
-  for (int i = 0; i < idx.size(); i++) {
-    int dev_id_i = resource_->dev_id(idx[i]);
-    platform::CUDADeviceGuard guard(dev_id_i);
-    // platform::CUDADeviceGuard guard(i);
-    auto& node = path_[gpu_id][idx[i]].nodes_.front();
-    int grid_size = (sample_size[i] - 1) / block_size_ + 1;
-    node_query_example<<<grid_size, block_size_, 0,
-                         resource_->remote_stream(idx[i], gpu_id)>>>(
-        gpu_graph_list[idx[i]], gpu_begin_pos[i], sample_size[i],
-        (int64_t*)node.val_storage);
-  }
-
-  for (int i = 0; i < idx.size(); i++) {
-    cudaStreamSynchronize(resource_->remote_stream(idx[i], gpu_id));
-    auto& node = path_[gpu_id][idx[i]].nodes_.front();
-    cudaMemcpyAsync(reinterpret_cast<char*>(val + local_begin_pos[i]),
-                    node.val_storage, node.val_bytes_len, cudaMemcpyDefault,
-                    node.out_stream);
-  }
-  for (int i = 0; i < idx.size(); i++) {
-    auto& node = path_[gpu_id][idx[i]].nodes_.front();
-    cudaStreamSynchronize(node.out_stream);
-  }
-  for (auto x : idx) {
-    destroy_storage(gpu_id, x);
-  }
-  return result;
-  */
 }
 }
 };
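`query_node_list` keeps the `range_check` helper, and the worked example from the comment this patch deletes still explains it best: if gpu 0 stores [0,2,4,6,8] and gpu 1 stores [1,3,5,7], then start = 3 with query_size = 5 should return [6,8,1,3,5], that is, 2 nodes from gpu 0 and 3 from gpu 1. A self-contained check of the same contract:

    #include <algorithm>
    #include <cstdio>

    // Intersect the query window [x, y) with one shard's global range
    // [x1, y1); x2/y2 receive the overlap, the return value its length.
    int range_check(int x, int y, int x1, int y1, int &x2, int &y2) {
      if (y <= x1 || x >= y1) return 0;
      y2 = std::min(y, y1);
      x2 = std::max(x1, x);
      return y2 - x2;
    }

    int main() {
      int x2, y2;
      // gpu 0 holds global positions [0, 5), gpu 1 holds [5, 9).
      printf("%d\n", range_check(3, 8, 0, 5, x2, y2));  // 2 -> nodes 6, 8
      printf("%d\n", range_check(3, 8, 5, 9, x2, y2));  // 3 -> nodes 1, 3, 5
      return 0;
    }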
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
index 3f6602893bd29..070d2e75381db 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <string>
 #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
 #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 #include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
-#include <string>
 namespace paddle {
 namespace framework {
 #ifdef PADDLE_WITH_HETERPS
@@ -29,7 +29,7 @@ void GraphGpuWrapper::set_device(std::vector<int> ids) {
 std::vector<std::vector<int64_t>> GraphGpuWrapper::get_all_id(int type, int idx,
                                                               int slice_num) {
   return ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->get_all_id(type, idx, slice_num);
+      ->cpu_graph_table_->get_all_id(type, idx, slice_num);
 }
 void GraphGpuWrapper::set_up_types(std::vector<std::string> &edge_types,
                                    std::vector<std::string> &node_types) {
@@ -52,28 +52,28 @@ void GraphGpuWrapper::set_up_types(std::vector<std::string> &edge_types,
 void GraphGpuWrapper::make_partitions(int idx, int64_t byte_size,
                                       int device_len) {
   ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->make_partitions(idx, byte_size, device_len);
+      ->cpu_graph_table_->make_partitions(idx, byte_size, device_len);
 }
 int32_t GraphGpuWrapper::load_next_partition(int idx) {
   return ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->load_next_partition(idx);
+      ->cpu_graph_table_->load_next_partition(idx);
 }
 void GraphGpuWrapper::set_search_level(int level) {
-  ((GpuPsGraphTable *)graph_table)->cpu_graph_table->set_search_level(level);
+  ((GpuPsGraphTable *)graph_table)->cpu_graph_table_->set_search_level(level);
 }
 std::vector<int64_t> GraphGpuWrapper::get_partition(int idx, int num) {
   return ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->get_partition(idx, num);
+      ->cpu_graph_table_->get_partition(idx, num);
 }
 int32_t GraphGpuWrapper::get_partition_num(int idx) {
   return ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->get_partition_num(idx);
+      ->cpu_graph_table_->get_partition_num(idx);
 }
 void GraphGpuWrapper::make_complementary_graph(int idx, int64_t byte_size) {
   ((GpuPsGraphTable *)graph_table)
-      ->cpu_graph_table->make_complementary_graph(idx, byte_size);
+      ->cpu_graph_table_->make_complementary_graph(idx, byte_size);
 }
 void GraphGpuWrapper::load_edge_file(std::string name, std::string filepath,
                                      bool reverse) {
@@ -88,7 +88,7 @@ void GraphGpuWrapper::load_edge_file(std::string name, std::string filepath,
   }
   if (edge_to_id.find(name) != edge_to_id.end()) {
     ((GpuPsGraphTable *)graph_table)
-        ->cpu_graph_table->Load(std::string(filepath), params);
+        ->cpu_graph_table_->Load(std::string(filepath), params);
   }
 }
 
@@ -99,7 +99,7 @@ void GraphGpuWrapper::load_node_file(std::string name, std::string filepath) {
 
   if (feature_to_id.find(name) != feature_to_id.end()) {
     ((GpuPsGraphTable *)graph_table)
-        ->cpu_graph_table->Load(std::string(filepath), params);
+        ->cpu_graph_table_->Load(std::string(filepath), params);
   }
 }
 
@@ -136,7 +136,7 @@ void GraphGpuWrapper::init_search_level(int level) { search_level = level; }
 void GraphGpuWrapper::init_service() {
   table_proto.set_task_pool_size(24);
   table_proto.set_search_level(search_level);
-  table_proto.set_table_name("cpu_graph_table");
+  table_proto.set_table_name("cpu_graph_table_");
   table_proto.set_use_cache(false);
   for (int i = 0; i < id_to_edge.size(); i++)
     table_proto.add_edge_types(id_to_edge[i]);
@@ -153,7 +153,8 @@ void GraphGpuWrapper::init_service() {
   std::shared_ptr<HeterPsResource> resource =
       std::make_shared<HeterPsResource>(device_id_mapping);
   resource->enable_p2p();
-  GpuPsGraphTable *g = new GpuPsGraphTable(resource, 1);
+  GpuPsGraphTable *g =
+      new GpuPsGraphTable(resource, 1, id_to_edge.size(), id_to_feature.size());
   g->init_cpu_table(table_proto);
   graph_table = (char *)g;
 }
@@ -161,16 +162,14 @@ void GraphGpuWrapper::init_service() {
 void GraphGpuWrapper::upload_batch(int idx,
                                    std::vector<std::vector<int64_t>> &ids) {
   GpuPsGraphTable *g = (GpuPsGraphTable *)graph_table;
-  // std::vector<GpuPsCommGraph> vec;
   for (int i = 0; i < ids.size(); i++) {
-    // vec.push_back(g->cpu_graph_table->make_gpu_ps_graph(idx, ids[i]));
     GpuPsCommGraph sub_graph =
-        g->cpu_graph_table->make_gpu_ps_graph(idx, ids[i]);
-    g->build_graph_on_single_gpu(sub_graph, i);
+        g->cpu_graph_table_->make_gpu_ps_graph(idx, ids[i]);
+    // sub_graph.display_on_cpu();
+    g->build_graph_on_single_gpu(sub_graph, i, idx);
     sub_graph.release_on_cpu();
     VLOG(0) << "sub graph on gpu " << i << " is built";
   }
-  // g->build_graph_from_cpu(vec);
 }
 
 // void GraphGpuWrapper::test() {
@@ -209,34 +208,34 @@ NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3(
 }
 
 NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample(
-    int gpu_id, int64_t* device_keys, int walk_degree, int len) {
+    int gpu_id, int64_t *device_keys, int walk_degree, int len) {
   platform::CUDADeviceGuard guard(gpu_id);
   auto neighbor_sample_res =
       ((GpuPsGraphTable *)graph_table)
          ->graph_neighbor_sample(gpu_id, device_keys, walk_degree, len);
-
-  //int64_t *cpu_keys = new int64_t[len];
-  //cudaMemcpy(cpu_keys, device_keys,
+
+  // int64_t *cpu_keys = new int64_t[len];
+  // cudaMemcpy(cpu_keys, device_keys,
   //           len * sizeof(int64_t),
   //           cudaMemcpyDeviceToHost);  // 3, 1, 3
-  //int *actual_sample_size = new int[len];
-  //cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size,
+  // int *actual_sample_size = new int[len];
+  // cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size,
  //           len * sizeof(int),
  //           cudaMemcpyDeviceToHost);  // 3, 1, 3
-  //std::stringstream ss;
-  //ss << len << "\t";
-  //for (int i = 0; i < len; i++) {
+  // std::stringstream ss;
+  // ss << len << "\t";
+  // for (int i = 0; i < len; i++) {
   //  ss << cpu_keys[i] << ":" << actual_sample_size[i] << ",";
   //}
-  //VLOG(0) << ss.str();
-  //free(actual_sample_size);
-  //free(cpu_keys);
+  // VLOG(0) << ss.str();
+  // free(actual_sample_size);
+  // free(cpu_keys);
 
   return neighbor_sample_res;
 }
 
 // this function is contributed by Liwb5
 std::vector<int64_t> GraphGpuWrapper::graph_neighbor_sample(
-    int gpu_id, std::vector<int64_t> &key, int sample_size) {
+    int gpu_id, int idx, std::vector<int64_t> &key, int sample_size) {
   int64_t *cuda_key;
   platform::CUDADeviceGuard guard(gpu_id);
 
@@ -246,7 +245,8 @@ std::vector<int64_t> GraphGpuWrapper::graph_neighbor_sample(
   VLOG(0) << "key_size: " << key.size();
   auto neighbor_sample_res =
       ((GpuPsGraphTable *)graph_table)
-          ->graph_neighbor_sample(gpu_id, cuda_key, sample_size, key.size());
+          ->graph_neighbor_sample(gpu_id, idx, cuda_key, sample_size,
+                                  key.size());
   int *actual_sample_size = new int[key.size()];
   cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size,
             key.size() * sizeof(int),
@@ -283,10 +283,19 @@ void GraphGpuWrapper::init_sample_status() {
 void GraphGpuWrapper::free_sample_status() {
   ((GpuPsGraphTable *)graph_table)->free_sample_status();
 }
-NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int start,
+NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int idx, int start,
                                                  int query_size) {
   return ((GpuPsGraphTable *)graph_table)
-      ->query_node_list(gpu_id, start, query_size);
+      ->query_node_list(gpu_id, idx, start, query_size);
+}
+void GraphGpuWrapper::load_node_weight(int type_id, int idx, std::string path) {
+  return ((GpuPsGraphTable *)graph_table)
+      ->cpu_graph_table_->load_node_weight(type_id, idx, path);
+}
+
+void GraphGpuWrapper::export_partition_files(int idx, std::string file_path) {
+  return ((GpuPsGraphTable *)graph_table)
+      ->cpu_graph_table_->export_partition_files(idx, file_path);
 }
 void GraphGpuWrapper::load_node_weight(int type_id, int idx, std::string path) {
   return ((GpuPsGraphTable *)graph_table)
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
index 2e5aab3f99938..1c0aca12d5ab4 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -31,7 +31,6 @@ class GraphGpuWrapper {
   }
   static std::shared_ptr<GraphGpuWrapper> s_instance_;
   void initialize();
-  void test();
   void set_device(std::vector<int> ids);
   void init_service();
   void set_up_types(std::vector<std::string>& edge_type,
@@ -52,12 +51,13 @@ class GraphGpuWrapper {
   void init_search_level(int level);
   std::vector<std::vector<int64_t>> get_all_id(int type, int idx,
                                                int slice_num);
-  NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
+  NodeQueryResult query_node_list(int gpu_id, int idx, int start,
+                                  int query_size);
   NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
                                                 bool cpu_switch);
-  NeighborSampleResult graph_neighbor_sample(
-      int gpu_id, int64_t* device_keys, int walk_degree, int len);
-  std::vector<int64_t> graph_neighbor_sample(int gpu_id,
+  NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t* device_keys,
+                                             int walk_degree, int len);
+  std::vector<int64_t> graph_neighbor_sample(int gpu_id, int idx,
                                              std::vector<int64_t>& key,
                                              int sample_size);
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
index 5edc218796ef8..935359164bb02 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -322,6 +322,7 @@ template class HashTable<unsigned long, paddle::framework::FeatureValue>;
 template class HashTable<long, int>;
 template class HashTable<unsigned long, int>;
 template class HashTable<long, long>;
+template class HashTable<unsigned long, long>;
 template class HashTable<long, unsigned long>;
 template class HashTable<long, unsigned int>;
 template class HashTable<unsigned long, unsigned int>;
@@ -337,6 +338,8 @@ template void HashTable<long, int>::get(const long* d_keys,
 template void HashTable<unsigned long, int>::get(
     const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
+template void HashTable<unsigned long, long>::get(
+    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);
 template void HashTable<long, unsigned long>::get(
     const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
 template void HashTable<long, long>::get(const long* d_keys,
@@ -366,6 +369,11 @@ template void HashTable<long, int>::insert(const long* d_keys,
 template void HashTable<unsigned long, int>::insert(
     const unsigned long* d_keys, const int* d_vals, size_t len,
     cudaStream_t stream);
+
+template void HashTable<unsigned long, long>::insert(
+    const unsigned long* d_keys, const long* d_vals, size_t len,
+    cudaStream_t stream);
+
 template void HashTable<long, unsigned long>::insert(
     const long* d_keys, const unsigned long* d_vals, size_t len,
     cudaStream_t stream);
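The added `HashTable<unsigned long, long>` instantiations back the per-table lookup from an unsigned 64-bit node id to a signed 64-bit offset into `node_list`. Because the kernel definitions live in this `.cu` file, every key/value combination used elsewhere has to be instantiated there explicitly or the build fails at link time. A simplified illustration of the pattern (hypothetical class, not Paddle's):

    // lookup.h: template declared; member bodies live only in lookup.cu.
    template <typename KeyType, typename ValType>
    class DeviceLookup {
     public:
      void get(const KeyType *d_keys, ValType *d_vals, size_t len);
    };

    // lookup.cu would end with one explicit instantiation per combination:
    //   template class DeviceLookup<unsigned long, long>;
    // Without that line, a caller of DeviceLookup<unsigned long, long>::get
    // compiles fine but fails at link time with an undefined symbol.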
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
index e53a962c5abde..e5384a1bc6787 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -52,7 +52,7 @@ class HeterComm {
                   int& uniq_len);  // NOLINT
   void pull_sparse(int num, KeyType* d_keys, ValType* d_vals, size_t len);
   void build_ps(int num, KeyType* h_keys, ValType* h_vals, size_t len,
-                size_t chunk_size, int stream_num);
+                size_t chunk_size, int stream_num, int offset = -1);
   void dump();
   void show_one_table(int gpu_num);
   int get_index_by_devid(int devid);
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index d23719ea9eb77..8a81d10f59326 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -362,7 +362,7 @@ void HeterComm<KeyType, ValType, GradType>::set_embedx_sgd(
 template <typename KeyType, typename ValType, typename GradType>
 void HeterComm<KeyType, ValType, GradType>::build_ps(
     int dev_num, KeyType* h_keys, ValType* h_vals, size_t len,
-    size_t chunk_size, int stream_num) {
+    size_t chunk_size, int stream_num, int offset) {
   if (len <= 0) {
     return;
   }
@@ -403,8 +403,8 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(
     memory_copy(
         dst_place, reinterpret_cast<char*>(d_val_bufs[cur_stream]->ptr()),
         src_place, h_vals + cur_len, sizeof(ValType) * tmp_len, cur_use_stream);
-
-    tables_[dev_num]->insert(
+    if (offset == -1) offset = dev_num;
+    tables_[offset]->insert(
         reinterpret_cast<KeyType*>(d_key_bufs[cur_stream]->ptr()),
         reinterpret_cast<ValType*>(d_val_bufs[cur_stream]->ptr()), tmp_len,
         cur_use_stream);
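`build_ps` now takes an optional table slot: legacy sparse-table callers pass nothing and keep the historical behavior of filling table `dev_num`, while the graph table passes the slot computed by `get_table_offset`. A minimal model of that dispatch:

    #include <cstdio>
    #include <vector>

    struct CommSketch {
      std::vector<const char *> tables{"t0", "t1", "t2", "t3"};
      // offset = -1 reproduces the old one-table-per-device behavior.
      void build_ps(int dev_num, int offset = -1) {
        if (offset == -1) offset = dev_num;
        printf("inserting into %s\n", tables[offset]);
      }
    };

    int main() {
      CommSketch c;
      c.build_ps(1);     // legacy caller: fills t1 (table index == device)
      c.build_ps(1, 3);  // graph caller: fills the slot get_table_offset chose
      return 0;
    }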
VLOG(0)<<"sample from cpu key->"<set_search_level(2); - // g.cpu_graph_table->Load_to_ssd(edge_file_name,"e>u2u"); - g.cpu_graph_table->Load(edge_file_name, "e>u2u"); - g.cpu_graph_table->make_partitions(0, 64, 2); + g.cpu_graph_table_->clear_graph(0); + g.cpu_graph_table_->set_search_level(2); + g.cpu_graph_table_->Load(edge_file_name, "e>u2u"); + g.cpu_graph_table_->make_partitions(0, 64, 2); int index = 0; - while (g.cpu_graph_table->load_next_partition(0) != -1) { - auto all_ids = g.cpu_graph_table->get_all_id(0, 0, device_len); + while (g.cpu_graph_table_->load_next_partition(0) != -1) { + auto all_ids = g.cpu_graph_table_->get_all_id(0, 0, device_len); for (auto x : all_ids) { for (auto y : x) { VLOG(0) << "part " << index << " " << y; @@ -195,19 +186,19 @@ TEST(TEST_FLEET, test_cpu_cache) { } for (int i = 0; i < all_ids.size(); i++) { GpuPsCommGraph sub_graph = - g.cpu_graph_table->make_gpu_ps_graph(0, all_ids[i]); - g.build_graph_on_single_gpu(sub_graph, i); + g.cpu_graph_table_->make_gpu_ps_graph(0, all_ids[i]); + g.build_graph_on_single_gpu(sub_graph, i, 0); VLOG(2) << "sub graph on gpu " << i << " is built"; } VLOG(0) << "start to iterate gpu graph node"; - g.cpu_graph_table->make_complementary_graph(0, 64); + g.cpu_graph_table_->make_complementary_graph(0, 64); for (int i = 0; i < 2; i++) { // platform::CUDADeviceGuard guard(i); LOG(0) << "query on card " << i; int step = 2; int cur = 0; while (true) { - auto node_query_res = g.query_node_list(i, cur, step); + auto node_query_res = g.query_node_list(i, 0, cur, step); node_query_res.display(); if (node_query_res.get_len() == 0) { VLOG(0) << "no more ids,break"; @@ -215,7 +206,7 @@ TEST(TEST_FLEET, test_cpu_cache) { } cur += node_query_res.get_len(); NeighborSampleQuery query, q1; - query.initialize(i, node_query_res.get_val(), 4, + query.initialize(i, 0, node_query_res.get_val(), 4, node_query_res.get_len()); query.display(); auto c = g.graph_neighbor_sample_v3(query, true); @@ -223,15 +214,15 @@ TEST(TEST_FLEET, test_cpu_cache) { platform::CUDADeviceGuard guard(i); int64_t *key; VLOG(0) << "sample key 1 globally"; - g.cpu_graph_table->set_search_level(2); + g.cpu_graph_table_->set_search_level(2); cudaMalloc((void **)&key, sizeof(int64_t)); int64_t t_key = 1; cudaMemcpy(key, &t_key, sizeof(int64_t), cudaMemcpyHostToDevice); - q1.initialize(i, (int64_t)key, 2, 1); + q1.initialize(i, 0, (int64_t)key, 2, 1); auto d = g.graph_neighbor_sample_v3(q1, true); d.display(); cudaFree(key); - g.cpu_graph_table->set_search_level(1); + g.cpu_graph_table_->set_search_level(1); } } index++;