diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 6e7d11c370362..06ec1567f4b47 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -122,23 +122,26 @@ node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14 node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15 */ struct NeighborSampleQuery { - int gpu_id, idx; - int64_t *key; - int sample_size; + int gpu_id; + int table_idx; + int64_t *src_nodes; int len; - void initialize(int gpu_id, int idx, int64_t key, int sample_size, int len) { - this->idx = idx; + int sample_size; + void initialize(int gpu_id, int table_idx, int64_t src_nodes, int sample_size, + int len) { + this->table_idx = table_idx; this->gpu_id = gpu_id; - this->key = (int64_t *)key; + this->src_nodes = (int64_t *)src_nodes; this->sample_size = sample_size; this->len = len; } void display() { int64_t *sample_keys = new int64_t[len]; VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size; - VLOG(0) << "there are " << len << " keys to sample for graph " << idx; + VLOG(0) << "there are " << len << " keys to sample for graph " << table_idx; std::string key_str; - cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost); + cudaMemcpy(sample_keys, src_nodes, len * sizeof(int64_t), + cudaMemcpyDeviceToHost); for (int i = 0; i < len; i++) { if (key_str.size() > 0) key_str += ";"; @@ -213,7 +216,7 @@ struct NeighborSampleResult { std::vector graph; int64_t *sample_keys = new int64_t[q.len]; std::string key_str; - cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t), + cudaMemcpy(sample_keys, q.src_nodes, q.len * sizeof(int64_t), cudaMemcpyDeviceToHost); int64_t *res = new int64_t[sample_size * key_size]; cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t), diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index b68bb1e1e60ce..215594f515fab 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -595,8 +595,9 @@ void GpuPsGraphTable::build_graph_from_cpu( NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3( NeighborSampleQuery q, bool cpu_switch) { - return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.idx, q.key, - q.sample_size, q.len, cpu_switch); + return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.table_idx, + q.src_nodes, q.sample_size, q.len, + cpu_switch); } NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int idx, int64_t* key,