Skip to content

Commit

Permalink
rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed May 25, 2022
1 parent 112cf61 commit dc8dbf2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
21 changes: 12 additions & 9 deletions paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,26 @@ node_list[7]-> node_id:9, neighbor_size:1, neighbor_offset:14
node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleQuery {
int gpu_id, idx;
int64_t *key;
int sample_size;
int gpu_id;
int table_idx;
int64_t *src_nodes;
int len;
void initialize(int gpu_id, int idx, int64_t key, int sample_size, int len) {
this->idx = idx;
int sample_size;
void initialize(int gpu_id, int table_idx, int64_t src_nodes, int sample_size,
int len) {
this->table_idx = table_idx;
this->gpu_id = gpu_id;
this->key = (int64_t *)key;
this->src_nodes = (int64_t *)src_nodes;
this->sample_size = sample_size;
this->len = len;
}
void display() {
int64_t *sample_keys = new int64_t[len];
VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size;
VLOG(0) << "there are " << len << " keys to sample for graph " << idx;
VLOG(0) << "there are " << len << " keys to sample for graph " << table_idx;
std::string key_str;
cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost);
cudaMemcpy(sample_keys, src_nodes, len * sizeof(int64_t),
cudaMemcpyDeviceToHost);

for (int i = 0; i < len; i++) {
if (key_str.size() > 0) key_str += ";";
Expand Down Expand Up @@ -213,7 +216,7 @@ struct NeighborSampleResult {
std::vector<int64_t> graph;
int64_t *sample_keys = new int64_t[q.len];
std::string key_str;
cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t),
cudaMemcpy(sample_keys, q.src_nodes, q.len * sizeof(int64_t),
cudaMemcpyDeviceToHost);
int64_t *res = new int64_t[sample_size * key_size];
cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,9 @@ void GpuPsGraphTable::build_graph_from_cpu(

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
NeighborSampleQuery q, bool cpu_switch) {
return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.idx, q.key,
q.sample_size, q.len, cpu_switch);
return graph_neighbor_sample_v2(global_device_map[q.gpu_id], q.table_idx,
q.src_nodes, q.sample_size, q.len,
cpu_switch);
}
NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int idx,
int64_t* key,
Expand Down

0 comments on commit dc8dbf2

Please sign in to comment.