Skip to content

Commit

Permalink
Merge pull request #3 from seemingwang/gpu_graph_engine2
Browse files Browse the repository at this point in the history
split graph table
  • Loading branch information
xuewujiao authored May 25, 2022
2 parents 63e501d + a9b5445 commit fc28b23
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 268 deletions.
18 changes: 11 additions & 7 deletions paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,25 @@ node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
*/
struct NeighborSampleQuery {
int gpu_id;
int64_t *key;
int sample_size;
int table_idx;
int64_t *src_nodes;
int len;
void initialize(int gpu_id, int64_t key, int sample_size, int len) {
int sample_size;
void initialize(int gpu_id, int table_idx, int64_t src_nodes, int sample_size,
int len) {
this->table_idx = table_idx;
this->gpu_id = gpu_id;
this->key = (int64_t *)key;
this->src_nodes = (int64_t *)src_nodes;
this->sample_size = sample_size;
this->len = len;
}
void display() {
int64_t *sample_keys = new int64_t[len];
VLOG(0) << "device_id " << gpu_id << " sample_size = " << sample_size;
VLOG(0) << "there are " << len << " keys ";
VLOG(0) << "there are " << len << " keys to sample for graph " << table_idx;
std::string key_str;
cudaMemcpy(sample_keys, key, len * sizeof(int64_t), cudaMemcpyDeviceToHost);
cudaMemcpy(sample_keys, src_nodes, len * sizeof(int64_t),
cudaMemcpyDeviceToHost);

for (int i = 0; i < len; i++) {
if (key_str.size() > 0) key_str += ";";
Expand Down Expand Up @@ -212,7 +216,7 @@ struct NeighborSampleResult {
std::vector<int64_t> graph;
int64_t *sample_keys = new int64_t[q.len];
std::string key_str;
cudaMemcpy(sample_keys, q.key, q.len * sizeof(int64_t),
cudaMemcpy(sample_keys, q.src_nodes, q.len * sizeof(int64_t),
cudaMemcpyDeviceToHost);
int64_t *res = new int64_t[sample_size * key_size];
cudaMemcpy(res, val, sample_size * key_size * sizeof(int64_t),
Expand Down
49 changes: 34 additions & 15 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,35 @@
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace framework {
enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
public:
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware)
int get_table_offset(int gpu_id, GraphTableType type, int idx) {
int type_id = type;
return gpu_id * (graph_table_num_ + feature_table_num_) +
type_id * graph_table_num_ + idx;
}
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware,
int graph_table_num, int feature_table_num)
: HeterComm<uint64_t, int64_t, int>(1, resource) {
load_factor_ = 0.25;
rw_lock.reset(new pthread_rwlock_t());
this->graph_table_num_ = graph_table_num;
this->feature_table_num_ = feature_table_num;
gpu_num = resource_->total_device();
memset(global_device_map, -1, sizeof(global_device_map));
for (auto &table : tables_) {
delete table;
table = NULL;
}
tables_ = std::vector<Table *>(
gpu_num * (graph_table_num + feature_table_num), NULL);
sample_status = std::vector<int *>(gpu_num * graph_table_num, NULL);
for (int i = 0; i < gpu_num; i++) {
gpu_graph_list.push_back(GpuPsCommGraph());
global_device_map[resource_->dev_id(i)] = i;
sample_status.push_back(NULL);
tables_.push_back(NULL);
for (int j = 0; j < graph_table_num; j++) {
gpu_graph_list_.push_back(GpuPsCommGraph());
}
}
cpu_table_status = -1;
if (topo_aware) {
Expand Down Expand Up @@ -89,21 +105,23 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
// end_graph_sampling();
// }
}
void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id);
void clear_graph_info(int gpu_id);
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list);
void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx);
void clear_graph_info(int gpu_id, int index);
void clear_graph_info(int index);
void build_graph_from_cpu(std::vector<GpuPsCommGraph> &cpu_node_list,
int idx);
NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch);
NeighborSampleResult graph_neighbor_sample(int gpu_id, int64_t *key,
NeighborSampleResult graph_neighbor_sample(int gpu_id, int idx, int64_t *key,
int sample_size, int len);
NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int64_t *key,
int sample_size, int len,
bool cpu_query_switch);
NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int idx,
int64_t *key, int sample_size,
int len, bool cpu_query_switch);
void init_sample_status();
void free_sample_status();
NodeQueryResult query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info();
NodeQueryResult query_node_list(int gpu_id, int idx, int start,
int query_size);
void display_sample_res(void *key, void *val, int len, int sample_len);
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
int sample_size, int *h_left,
Expand All @@ -112,12 +130,13 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
int *actual_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);
int gpu_num;
std::vector<GpuPsCommGraph> gpu_graph_list;
int graph_table_num_, feature_table_num_;
std::vector<GpuPsCommGraph> gpu_graph_list_;
int global_device_map[32];
std::vector<int *> sample_status;
const int parallel_sample_size = 1;
const int dim_y = 256;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table_;
std::shared_ptr<pthread_rwlock_t> rw_lock;
mutable std::mutex mutex_;
std::condition_variable cv_;
Expand Down
Loading

0 comments on commit fc28b23

Please sign in to comment.