Skip to content

Commit

Permalink
Add graphsage data process (PaddlePaddle#94)
Browse files Browse the repository at this point in the history
* Add graphsage withoud self-loop

* change next_num_nodes line

* temp commit

* move id tensor position

* change vlog num

* merge gpugraph, add graphsage optimization, refactor reindex, add graphsage infer

* add review

* add graphsage optimize, add geometric

* update fill_dvalue

* add geometric

* v2 version

* v2 version

* delete unused code

* update fill_dvalue
  • Loading branch information
DesmonDay authored and root committed Nov 26, 2022
1 parent ba44351 commit 619cee9
Show file tree
Hide file tree
Showing 19 changed files with 730 additions and 186 deletions.
487 changes: 440 additions & 47 deletions paddle/fluid/framework/data_feed.cu

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,17 @@ class GraphDataGenerator {
type_to_index_[type] = h_device_keys_.size();
h_device_keys_.push_back(device_keys);
}
std::vector<std::shared_ptr<phi::Allocation>> SampleNeighbors(
int64_t* uniq_nodes, int len, int sample_size,
std::vector<int64_t>& edges_split_num, int64_t* neighbor_len);

std::shared_ptr<phi::Allocation> GetReindexResult(
int64_t* reindex_src_data, int64_t* reindex_dst_data,
const int* count, const int64_t* center_nodes,
int* final_nodes_len, int node_len, int64_t neighbor_len);

std::shared_ptr<phi::Allocation> GenerateSampleGraph(
uint64_t* node_ids, int len, int* uniq_len, phi::DenseTensor* inverse);

protected:
int walk_degree_;
Expand All @@ -930,9 +941,13 @@ class GraphDataGenerator {
// point to device_keys_
size_t cursor_;
size_t jump_rows_;
int edge_to_id_len_;
int uniq_instance_;
int64_t* id_tensor_ptr_;
int* index_tensor_ptr_;
int64_t* show_tensor_ptr_;
int64_t* clk_tensor_ptr_;

cudaStream_t stream_;
paddle::platform::Place place_;
std::vector<phi::DenseTensor*> feed_vec_;
Expand Down Expand Up @@ -960,6 +975,10 @@ class GraphDataGenerator {
std::shared_ptr<phi::Allocation> d_pair_num_;
std::shared_ptr<phi::Allocation> d_slot_tensor_ptr_;
std::shared_ptr<phi::Allocation> d_slot_lod_tensor_ptr_;
std::shared_ptr<phi::Allocation> d_reindex_table_key_;
std::shared_ptr<phi::Allocation> d_reindex_table_value_;
std::shared_ptr<phi::Allocation> d_reindex_table_index_;
int64_t reindex_table_size_;
int ins_buf_pair_len_;
// size of a d_walk buf
size_t buf_size_;
Expand All @@ -973,6 +992,8 @@ class GraphDataGenerator {
std::vector<int> first_node_type_;
std::vector<std::vector<int>> meta_path_;
bool gpu_graph_training_;
bool sage_mode_;
std::vector<int> samples_;
};

class DataFeed {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/data_feed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ message GraphConfig {
optional string first_node_type = 8;
optional string meta_path = 9;
optional bool gpu_graph_training = 10 [ default = true ];
optional bool sage_mode = 11 [ default = false ];
optional string samples = 12;
}

message DataFeedDesc {
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,20 @@ struct NeighborSampleResult {
delete[] ac_size;
VLOG(0) << " ------------------";
}
void display2() {
VLOG(0) << "in node sample result display -----";
uint64_t *res = new uint64_t[total_sample_size];
cudaMemcpy(res, actual_val, total_sample_size * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
std::string sample_str;
for (int i = 0; i < total_sample_size; i++) {
if (sample_str.size() > 0) sample_str += ";";
sample_str += std::to_string(res[i]);
}
VLOG(0) << "sample result: " << sample_str;
delete[] res;
}

std::vector<uint64_t> get_sampled_graph(NeighborSampleQuery q) {
std::vector<uint64_t> graph;
int64_t *sample_keys = new int64_t[q.len];
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class GpuPsGraphTable
const std::vector<GpuPsCommGraphFea> &cpu_node_list, int idx);
NodeQueryResult graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch);
bool cpu_switch,
bool compress);
NeighborSampleResult graph_neighbor_sample(int gpu_id,
uint64_t *key,
int sample_size,
Expand All @@ -136,7 +137,8 @@ class GpuPsGraphTable
uint64_t *key,
int sample_size,
int len,
bool cpu_query_switch);
bool cpu_query_switch,
bool compress);

int get_feature_of_nodes(
int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num);
Expand Down
69 changes: 38 additions & 31 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,10 @@ __global__ void fill_dvalues(uint64_t* d_shard_vals,
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
d_actual_sample_size[idx[i]] = d_shard_actual_sample_size[i];
for (int j = 0; j < sample_size; j++) {
d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j];
size_t offset1 = idx[i] * sample_size;
size_t offset2 = i * sample_size;
for (int j = 0; j < d_shard_actual_sample_size[i]; j++) {
d_vals[offset1 + j] = d_shard_vals[offset2 + j];
}
}
}
Expand All @@ -323,8 +325,10 @@ __global__ void fill_actual_vals(uint64_t* vals,
int len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
int offset1 = cumsum_actual_sample_size[i];
int offset2 = sample_size * i;
for (int j = 0; j < actual_sample_size[i]; j++) {
actual_vals[cumsum_actual_sample_size[i] + j] = vals[sample_size * i + j];
actual_vals[offset1 + j] = vals[offset2 + j];
}
}
}
Expand Down Expand Up @@ -655,20 +659,21 @@ void GpuPsGraphTable::build_graph_from_cpu(
}

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v3(
NeighborSampleQuery q, bool cpu_switch) {
NeighborSampleQuery q, bool cpu_switch, bool compress = true) {
return graph_neighbor_sample_v2(global_device_map[q.gpu_id],
q.table_idx,
q.src_nodes,
q.sample_size,
q.len,
cpu_switch);
cpu_switch,
compress);
}

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
uint64_t* key,
int sample_size,
int len) {
return graph_neighbor_sample_v2(gpu_id, 0, key, sample_size, len, false);
return graph_neighbor_sample_v2(gpu_id, 0, key, sample_size, len, false, true);
}

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
Expand All @@ -677,7 +682,9 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
uint64_t* key,
int sample_size,
int len,
bool cpu_query_switch) {
bool cpu_query_switch,
bool compress) {

NeighborSampleResult result;
result.initialize(sample_size, len, resource_->dev_id(gpu_id));

Expand Down Expand Up @@ -812,7 +819,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
d_idx_ptr,
sample_size,
len);
CUDA_CHECK(cudaStreamSynchronize(stream));
if (cpu_query_switch) {
Expand Down Expand Up @@ -912,31 +919,31 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size,
actual_sample_size + len,
t_actual_sample_size.begin());
int total_sample_size = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(uint64_t));
result.actual_val = (uint64_t*)(result.actual_val_mem)->ptr();
int total_sample_size = thrust::reduce(
thrust::device_pointer_cast(actual_sample_size),
thrust::device_pointer_cast(actual_sample_size) + len);
result.set_total_sample_size(total_sample_size);
thrust::device_vector<int> cumsum_actual_sample_size(len);
thrust::exclusive_scan(t_actual_sample_size.begin(),
t_actual_sample_size.end(),
cumsum_actual_sample_size.begin(),
0);
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
val,
result.actual_val,
actual_sample_size,
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()),
sample_size,
len);
if (compress) {
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(uint64_t));
result.actual_val = (uint64_t*)(result.actual_val_mem)->ptr();
thrust::device_vector<int> cumsum_actual_sample_size(len);
thrust::exclusive_scan(thrust::device_pointer_cast(actual_sample_size),
thrust::device_pointer_cast(actual_sample_size) + len,
cumsum_actual_sample_size.begin(),
0);
fill_actual_vals<<<grid_size, block_size_, 0, stream>>>(
val,
result.actual_val,
actual_sample_size,
thrust::raw_pointer_cast(cumsum_actual_sample_size.data()),
sample_size,
len);
}
}
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ void GraphGpuWrapper::upload_batch(int type, int slice_num, int slot_num) {
}

NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3(
NeighborSampleQuery q, bool cpu_switch) {
NeighborSampleQuery q, bool cpu_switch, bool compress = true) {
return ((GpuPsGraphTable *)graph_table)
->graph_neighbor_sample_v3(q, cpu_switch);
->graph_neighbor_sample_v3(q, cpu_switch, compress);
}

int GraphGpuWrapper::get_feature_of_nodes(int gpu_id,
Expand Down Expand Up @@ -326,7 +326,7 @@ std::vector<uint64_t> GraphGpuWrapper::graph_neighbor_sample(
auto neighbor_sample_res =
((GpuPsGraphTable *)graph_table)
->graph_neighbor_sample_v2(
gpu_id, idx, cuda_key, sample_size, key.size(), false);
gpu_id, idx, cuda_key, sample_size, key.size(), false, true);
int *actual_sample_size = new int[key.size()];
cudaMemcpy(actual_sample_size,
neighbor_sample_res.actual_sample_size,
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class GraphGpuWrapper {
int start,
int query_size);
NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q,
bool cpu_switch);
bool cpu_switch,
bool compress);
NeighborSampleResult graph_neighbor_sample(int gpu_id,
uint64_t* device_keys,
int walk_degree,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ TEST(TEST_FLEET, test_cpu_cache) {
query.initialize(
i, 0, node_query_res.get_val(), 1, node_query_res.get_len());
query.display();
auto c = g.graph_neighbor_sample_v3(query, false);
auto c = g.graph_neighbor_sample_v3(query, false, true);
c.display();
}
}
Expand Down Expand Up @@ -219,7 +219,7 @@ TEST(TEST_FLEET, test_cpu_cache) {
query.initialize(i, 0, node_query_res.get_val(), 4,
node_query_res.get_len());
query.display();
auto c = g.graph_neighbor_sample_v3(query, true);
auto c = g.graph_neighbor_sample_v3(query, true, true);
c.display();
platform::CUDADeviceGuard guard(i);
uint64_t *key;
Expand All @@ -229,7 +229,7 @@ TEST(TEST_FLEET, test_cpu_cache) {
uint64_t t_key = 1;
cudaMemcpy(key, &t_key, sizeof(uint64_t), cudaMemcpyHostToDevice);
q1.initialize(i, 0, (uint64_t)key, 2, 1);
auto d = g.graph_neighbor_sample_v3(q1, true);
auto d = g.graph_neighbor_sample_v3(q1, true, true);
d.display();
cudaFree(key);
g.cpu_graph_table_->set_search_level(1);
Expand Down
58 changes: 58 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,64 @@ void InstanceNormInferMeta(const MetaTensor& x,
}
}

void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}

auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));

auto dims = x.dims();
std::vector<int64_t> dims_ = phi::vectorize(dims);
dims_[0] = -1;
out->set_dims(phi::make_ddim(dims_));
out->set_dtype(x.dtype());

if (pool_type == "MEAN") {
dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32);
}
}

void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ void InstanceNormInferMeta(const MetaTensor& x,
MetaTensor* saved_variance,
MetaConfig config = MetaConfig());

void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);

void GroupNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/graph_reindex_funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace phi {

template <typename T>
inline __device__ size_t Hash(T id, int64_t size) {
return id % size;
return static_cast<unsigned long long int>(id) % size;
}

template <typename T>
Expand Down Expand Up @@ -169,7 +169,7 @@ __global__ void FillUniqueItems(const T* items,

template <typename T>
__global__ void ReindexSrcOutput(T* src_output,
int num_items,
int64_t num_items,
int64_t size,
const T* keys,
const int* values) {
Expand Down
Loading

0 comments on commit 619cee9

Please sign in to comment.