change arg of FillWalkBuf to be vector (PaddlePaddle#286)
Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
2 people authored and danleifeng committed Sep 12, 2023
1 parent 35ce03f commit 6eee86d
Showing 4 changed files with 136 additions and 113 deletions.
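In outline, this commit turns several per-generator scalars and single device allocations into per-tensor-pair vectors, sized by conf_.tensor_pair_num and indexed by tensor_pair_idx, so that training can fill one walk buffer per tensor pair. A minimal, hedged sketch of the member-level pattern follows; the struct name and member subset are illustrative stand-ins, not the exact Paddle declarations.

// Sketch of the refactoring pattern only; simplified stand-ins for the real
// GraphDataGenerator members. Scalar state becomes one slot per tensor pair.
#include <cstddef>
#include <cstdint>
#include <vector>

struct WalkStateSketch {
  // Before this commit (conceptually):
  //   uint64_t copy_unique_len_;  int shuffle_seed_;  size_t jump_rows_;
  // After: one entry per tensor pair.
  std::vector<uint64_t> copy_unique_len_;
  std::vector<int> shuffle_seed_;
  std::vector<size_t> jump_rows_;

  void AllocResource(int tensor_pair_num) {
    // Mirrors the .assign(conf_.tensor_pair_num, 0) calls in the diff below.
    copy_unique_len_.assign(tensor_pair_num, 0);
    shuffle_seed_.assign(tensor_pair_num, 0);
    jump_rows_.assign(tensor_pair_num, 0);
  }
};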
206 changes: 112 additions & 94 deletions paddle/fluid/framework/data_feed.cu
@@ -661,7 +661,7 @@ int GraphDataGenerator::FillIdShowClkTensor(int total_instance,
} else {
// infer
uint64_t *d_type_keys =
reinterpret_cast<uint64_t *>(d_device_keys_[infer_cursor_]->ptr());
reinterpret_cast<uint64_t *>(d_device_keys_[0][infer_cursor_]->ptr());
d_type_keys += infer_node_start_;
infer_node_start_ += total_instance / 2;
CopyDuplicateKeys<<<GET_BLOCKS(total_instance / 2),
@@ -3256,68 +3256,71 @@ void GraphDataGenerator::DoWalkandSage() {
}
}
if (conf_.gpu_graph_training || conf_.sage_mode) {
CopyUniqueNodes(table_, copy_unique_len_, place_, d_uniq_node_num_,
CopyUniqueNodes(table_, copy_unique_len_[0], place_, d_uniq_node_num_[0],
&host_vec_, sample_stream_);
}
}

bool GraphDataGenerator::DoWalkForTrain() {
bool train_flag;
uint8_t *walk_ntype = NULL;
if (conf_.need_walk_ntype) {
walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_[0]->ptr());
}

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
if (FLAGS_graph_metapath_split_opt) {
train_flag = FillWalkBufMultiPath(
h_device_keys_len_,
d_device_keys_,
gpu_graph_ptr->meta_path_,
conf_,
&epoch_finish_,
&copy_unique_len_,
place_,
gpu_graph_ptr->first_node_type_,
&(gpu_graph_ptr->node_type_start_[conf_.gpuid]),
reinterpret_cast<uint64_t *>(d_walk_[0]->ptr()),
walk_ntype,
&d_uniq_node_num_,
reinterpret_cast<int *>(d_random_row_[0]->ptr()),
reinterpret_cast<int *>(d_random_row_col_shift_[0]->ptr()),
&host_vec_,
&total_row_[0],
&jump_rows_,
&shuffle_seed_,
reinterpret_cast<uint64_t *>(d_train_metapath_keys_->ptr()),
&h_train_metapath_keys_len_,
table_,
&buf_state_[0],
sample_stream_);
} else {
train_flag = FillWalkBuf(h_device_keys_len_,
d_device_keys_,
gpu_graph_ptr->meta_path_,
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
uint8_t *walk_ntype = NULL;
if (conf_.need_walk_ntype) {
walk_ntype = reinterpret_cast<uint8_t *>(d_walk_ntype_[tensor_pair_idx]->ptr());
}

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
if (FLAGS_graph_metapath_split_opt) {
train_flag = FillWalkBufMultiPath(
h_device_keys_len_[tensor_pair_idx],
d_device_keys_[tensor_pair_idx],
gpu_graph_ptr->meta_path_[tensor_pair_idx],
conf_,
&epoch_finish_,
&copy_unique_len_,
&copy_unique_len_[tensor_pair_idx],
place_,
gpu_graph_ptr->first_node_type_,
&(gpu_graph_ptr->node_type_start_[conf_.gpuid]),
&(gpu_graph_ptr->finish_node_type_[conf_.gpuid]),
reinterpret_cast<uint64_t *>(d_walk_[0]->ptr()),
gpu_graph_ptr->first_node_type_[tensor_pair_idx],
&(gpu_graph_ptr->node_type_start_[tensor_pair_idx][conf_.gpuid]),
reinterpret_cast<uint64_t *>(d_walk_[tensor_pair_idx]->ptr()),
walk_ntype,
&d_uniq_node_num_,
reinterpret_cast<int *>(d_random_row_[0]->ptr()),
reinterpret_cast<int *>(d_random_row_col_shift_[0]->ptr()),
&multi_node_sync_stat_,
&d_uniq_node_num_[tensor_pair_idx],
reinterpret_cast<int *>(d_random_row_[tensor_pair_idx]->ptr()),
reinterpret_cast<int *>(d_random_row_col_shift_[tensor_pair_idx]->ptr()),
&host_vec_,
&total_row_[0],
&jump_rows_,
&shuffle_seed_,
&total_row_[tensor_pair_idx],
&jump_rows_[tensor_pair_idx],
&shuffle_seed_[tensor_pair_idx],
reinterpret_cast<uint64_t *>(d_train_metapath_keys_[tensor_pair_idx]->ptr()),
&h_train_metapath_keys_len_[tensor_pair_idx],
table_,
&buf_state_[0],
&buf_state_[tensor_pair_idx],
sample_stream_);
} else {
train_flag = FillWalkBuf(h_device_keys_len_[tensor_pair_idx],
d_device_keys_[tensor_pair_idx],
gpu_graph_ptr->meta_path_[tensor_pair_idx],
conf_,
&epoch_finish_,
&copy_unique_len_[tensor_pair_idx],
place_,
gpu_graph_ptr->first_node_type_[tensor_pair_idx],
&(gpu_graph_ptr->node_type_start_[tensor_pair_idx][conf_.gpuid]),
&(gpu_graph_ptr->finish_node_type_[tensor_pair_idx][conf_.gpuid]),
reinterpret_cast<uint64_t *>(d_walk_[tensor_pair_idx]->ptr()),
walk_ntype,
&d_uniq_node_num_[tensor_pair_idx],
reinterpret_cast<int *>(d_random_row_[tensor_pair_idx]->ptr()),
reinterpret_cast<int *>(d_random_row_col_shift_[tensor_pair_idx]->ptr()),
&multi_node_sync_stat_,
&host_vec_,
&total_row_[tensor_pair_idx],
&jump_rows_[tensor_pair_idx],
&shuffle_seed_[tensor_pair_idx],
table_,
&buf_state_[tensor_pair_idx],
sample_stream_);
}
}

return train_flag;
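Condensed, the new DoWalkForTrain body above is a loop that issues one fill call per tensor pair, handing each call the [tensor_pair_idx] slot of every per-pair buffer. A hedged control-flow sketch follows; fill_pair is a hypothetical stand-in for FillWalkBuf / FillWalkBufMultiPath, whose full argument lists are shown in the diff above.

#include <functional>

// Control-flow sketch only. As in the diff, the flag returned by the last
// pair's fill call is what DoWalkForTrain ultimately returns.
bool DoWalkForTrainSketch(int tensor_pair_num,
                          const std::function<bool(int)>& fill_pair) {
  bool train_flag = false;
  for (int tensor_pair_idx = 0; tensor_pair_idx < tensor_pair_num;
       ++tensor_pair_idx) {
    train_flag = fill_pair(tensor_pair_idx);
  }
  return train_flag;
}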
@@ -3430,9 +3433,9 @@ void GraphDataGenerator::DoSageForTrain() {
cudaStreamSynchronize(sample_stream_);
InsertTable(reinterpret_cast<uint64_t *>(final_sage_nodes->ptr()),
uniq_instance,
&d_uniq_node_num_,
&d_uniq_node_num_[0],
conf_,
&copy_unique_len_,
&copy_unique_len_[0],
place_,
table_,
&host_vec_,
@@ -3458,7 +3461,7 @@ void GraphDataGenerator::DoSageForInfer() {
total_instance *= 2;
while (total_instance != 0) {
uint64_t *d_type_keys = reinterpret_cast<uint64_t *>(
d_device_keys_[infer_cursor_]->ptr());
d_device_keys_[0][infer_cursor_]->ptr());
d_type_keys += infer_node_start_;
infer_node_start_ += total_instance / 2;
auto node_buf = memory::AllocShared(
@@ -3496,9 +3499,9 @@ void GraphDataGenerator::DoSageForInfer() {
cudaStreamSynchronize(sample_stream_);
InsertTable(reinterpret_cast<uint64_t *>(final_sage_nodes->ptr()),
uniq_instance,
&d_uniq_node_num_,
&d_uniq_node_num_[0],
conf_,
&copy_unique_len_,
&copy_unique_len_[0],
place_,
table_,
&host_vec_,
@@ -3526,16 +3529,16 @@ int GraphDataGenerator::FillInferBuf() {
gpu_graph_ptr->global_infer_node_type_start_[conf_.gpuid];
auto &infer_cursor = gpu_graph_ptr->infer_cursor_[conf_.thread_id];
total_row_[0] = 0;
if (infer_cursor < h_device_keys_len_.size()) {
if (infer_cursor < h_device_keys_len_[0].size()) {
while (global_infer_node_type_start[infer_cursor] >=
h_device_keys_len_[infer_cursor]) {
h_device_keys_len_[0][infer_cursor]) {
infer_cursor++;
if (infer_cursor >= h_device_keys_len_.size()) {
if (infer_cursor >= h_device_keys_len_[0].size()) {
return 0;
}
}
if (!infer_node_type_index_set_.empty()) {
while (infer_cursor < h_device_keys_len_.size()) {
while (infer_cursor < h_device_keys_len_[0].size()) {
if (infer_node_type_index_set_.find(infer_cursor) ==
infer_node_type_index_set_.end()) {
VLOG(2) << "Skip cursor[" << infer_cursor << "]";
@@ -3546,20 +3549,20 @@ int GraphDataGenerator::FillInferBuf() {
break;
}
}
if (infer_cursor >= h_device_keys_len_.size()) {
if (infer_cursor >= h_device_keys_len_[0].size()) {
return 0;
}
}

size_t device_key_size = h_device_keys_len_[infer_cursor];
size_t device_key_size = h_device_keys_len_[0][infer_cursor];
total_row_[0] =
(global_infer_node_type_start[infer_cursor] + conf_.buf_size <=
device_key_size)
? conf_.buf_size
: device_key_size - global_infer_node_type_start[infer_cursor];

uint64_t *d_type_keys =
reinterpret_cast<uint64_t *>(d_device_keys_[infer_cursor]->ptr());
reinterpret_cast<uint64_t *>(d_device_keys_[0][infer_cursor]->ptr());
if (!conf_.sage_mode) {
host_vec_.resize(total_row_[0]);
cudaMemcpyAsync(host_vec_.data(),
@@ -3583,12 +3586,15 @@ int GraphDataGenerator::FillInferBuf() {

void GraphDataGenerator::ClearSampleState() {
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto &finish_node_type = gpu_graph_ptr->finish_node_type_[conf_.gpuid];
auto &node_type_start = gpu_graph_ptr->node_type_start_[conf_.gpuid];
finish_node_type.clear();
for (auto iter = node_type_start.begin(); iter != node_type_start.end();
iter++) {
iter->second = 0;
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
auto &finish_node_type = gpu_graph_ptr->finish_node_type_[tensor_pair_idx][conf_.gpuid];
auto &node_type_start = gpu_graph_ptr->node_type_start_[tensor_pair_idx][conf_.gpuid];
finish_node_type.clear();
for (auto iter = node_type_start.begin(); iter != node_type_start.end();
iter++) {
iter->second = 0;
}
}
}
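Worth noting from the hunks above: only the training path iterates over all tensor pairs, while the inference paths (FillIdShowClkTensor, DoSageForInfer, FillInferBuf) and the post-walk CopyUniqueNodes call still read slot 0 of the new vectors. A small hedged sketch of that split, with a hypothetical per-pair state vector standing in for members such as d_device_keys_ or copy_unique_len_:

#include <cstddef>
#include <vector>

// Indexing sketch: training touches every pair, inference stays pinned to
// pair 0 in this commit.
void TrainUsesAllPairs(std::vector<int>* per_pair_state) {
  for (std::size_t tensor_pair_idx = 0;
       tensor_pair_idx < per_pair_state->size(); ++tensor_pair_idx) {
    (*per_pair_state)[tensor_pair_idx] += 1;  // per-pair work
  }
}

int InferUsesPairZero(const std::vector<int>& per_pair_state) {
  return per_pair_state[0];  // inference path reads only slot 0
}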

@@ -3666,33 +3672,47 @@ void GraphDataGenerator::AllocResource(
// stream_));
// }
if (conf_.gpu_graph_training && FLAGS_graph_metapath_split_opt) {
d_train_metapath_keys_ =
gpu_graph_ptr->d_node_iter_graph_metapath_keys_[thread_id];
h_train_metapath_keys_len_ =
gpu_graph_ptr->h_node_iter_graph_metapath_keys_len_[thread_id];
VLOG(2) << "h train metapaths key len: " << h_train_metapath_keys_len_;
d_train_metapath_keys_.resize(conf_.tensor_pair_num);
h_train_metapath_keys_len_.resize(conf_.tensor_pair_num);
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
d_train_metapath_keys_[tensor_pair_idx] =
gpu_graph_ptr->d_node_iter_graph_metapath_keys_[thread_id];
h_train_metapath_keys_len_[tensor_pair_idx] =
gpu_graph_ptr->h_node_iter_graph_metapath_keys_len_[thread_id];
VLOG(2) << "h train metapaths key len: " << h_train_metapath_keys_len_[tensor_pair_idx];
}
} else {
auto &d_graph_all_type_keys =
gpu_graph_ptr->d_node_iter_graph_all_type_keys_;
auto &h_graph_all_type_keys_len =
gpu_graph_ptr->h_node_iter_graph_all_type_keys_len_;

for (size_t i = 0; i < d_graph_all_type_keys.size(); i++) {
d_device_keys_.push_back(d_graph_all_type_keys[i][thread_id]);
h_device_keys_len_.push_back(h_graph_all_type_keys_len[i][thread_id]);
d_device_keys_.resize(conf_.tensor_pair_num);
h_device_keys_len_.resize(conf_.tensor_pair_num);
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
auto &d_graph_all_type_keys =
gpu_graph_ptr->d_node_iter_graph_all_type_keys_;
auto &h_graph_all_type_keys_len =
gpu_graph_ptr->h_node_iter_graph_all_type_keys_len_;

for (size_t i = 0; i < d_graph_all_type_keys.size(); i++) {
d_device_keys_[tensor_pair_idx].push_back(d_graph_all_type_keys[i][thread_id]);
h_device_keys_len_[tensor_pair_idx].push_back(h_graph_all_type_keys_len[i][thread_id]);
}
VLOG(2) << "h_device_keys size: " << h_device_keys_len_[tensor_pair_idx].size();
}
VLOG(2) << "h_device_keys size: " << h_device_keys_len_.size();
}

infer_cursor_ = 0;
jump_rows_ = 0;
d_uniq_node_num_ = memory::AllocShared(
place_,
sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
cudaMemsetAsync(d_uniq_node_num_->ptr(), 0, sizeof(uint64_t), sample_stream_);
jump_rows_.assign(conf_.tensor_pair_num, 0);
d_uniq_node_num_.resize(conf_.tensor_pair_num);
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
d_uniq_node_num_[tensor_pair_idx] = memory::AllocShared(
place_,
sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
cudaMemsetAsync(d_uniq_node_num_[tensor_pair_idx]->ptr(), 0, sizeof(uint64_t), sample_stream_);
}

total_row_.resize(conf_.tensor_pair_num);
total_row_.assign(conf_.tensor_pair_num, 0);
d_walk_.resize(conf_.tensor_pair_num);
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
@@ -3742,14 +3762,12 @@ void GraphDataGenerator::AllocResource(
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
}

shuffle_seed_ = 0;

ins_buf_pair_len_.resize(conf_.tensor_pair_num);
shuffle_seed_.assign(conf_.tensor_pair_num, 0);
ins_buf_pair_len_.assign(conf_.tensor_pair_num, 0);
d_ins_buf_.resize(conf_.tensor_pair_num);
d_pair_num_.resize(conf_.tensor_pair_num);
for (int tensor_pair_idx = 0; tensor_pair_idx < conf_.tensor_pair_num;
++tensor_pair_idx) {
ins_buf_pair_len_[tensor_pair_idx] = 0;
d_ins_buf_[tensor_pair_idx] = memory::AllocShared(
place_,
(conf_.batch_size * 2 * 2) * sizeof(uint64_t),
@@ -3931,7 +3949,7 @@ void GraphDataGenerator::SetConfig(
int sample_size = std::stoi(samples[i]);
conf_.samples.emplace_back(sample_size);
}
copy_unique_len_ = 0;
copy_unique_len_.assign(conf_.tensor_pair_num, 0);

if (!conf_.gpu_graph_training) {
infer_node_type_ = graph_config.infer_node_type();
16 changes: 8 additions & 8 deletions paddle/fluid/framework/data_feed.h
@@ -1027,7 +1027,7 @@ class GraphDataGenerator {
HashTable<uint64_t, uint64_t>* table_;
GraphDataGeneratorConfig conf_;
size_t infer_cursor_;
size_t jump_rows_;
std::vector<size_t> jump_rows_;
int64_t* id_tensor_ptr_;
int* index_tensor_ptr_;
int64_t* show_tensor_ptr_;
@@ -1041,16 +1041,16 @@
std::vector<phi::DenseTensor*> feed_vec_;
std::vector<UsedSlotInfo>* feed_info_; // adapt for float feature
std::vector<size_t> offset_;
std::vector<std::shared_ptr<phi::Allocation>> d_device_keys_;
std::shared_ptr<phi::Allocation> d_train_metapath_keys_;
std::vector<std::vector<std::shared_ptr<phi::Allocation>>> d_device_keys_;
std::vector<std::shared_ptr<phi::Allocation>> d_train_metapath_keys_;

std::vector<std::shared_ptr<phi::Allocation>> d_walk_;
std::vector<std::shared_ptr<phi::Allocation>> d_walk_ntype_;
std::shared_ptr<phi::Allocation> d_feature_list_;
std::shared_ptr<phi::Allocation> d_feature_;
std::vector<std::shared_ptr<phi::Allocation>> d_random_row_;
std::vector<std::shared_ptr<phi::Allocation>> d_random_row_col_shift_;
std::shared_ptr<phi::Allocation> d_uniq_node_num_;
std::vector<std::shared_ptr<phi::Allocation>> d_uniq_node_num_;
std::shared_ptr<phi::Allocation> d_slot_feature_num_map_;
std::shared_ptr<phi::Allocation> d_actual_slot_id_map_;
std::shared_ptr<phi::Allocation> d_fea_offset_map_;
@@ -1086,13 +1086,13 @@
int uint_slot_num_ = 0; // uint slot num
std::vector<int> h_slot_feature_num_map_;
int fea_num_per_node_;
int shuffle_seed_;
std::vector<int> shuffle_seed_;
bool epoch_finish_;
int pass_end_ = 0;
std::vector<uint64_t> host_vec_;
std::vector<uint64_t> h_device_keys_len_;
uint64_t h_train_metapath_keys_len_;
uint64_t copy_unique_len_;
std::vector<std::vector<uint64_t>> h_device_keys_len_;
std::vector<uint64_t> h_train_metapath_keys_len_;
std::vector<uint64_t> copy_unique_len_;
std::vector<int> total_row_;
size_t infer_node_start_;
size_t infer_node_end_;
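Taken together, the header changes above promote several members from scalars or single allocations to per-tensor-pair containers, with d_device_keys_ and h_device_keys_len_ gaining an extra outer dimension. A short usage sketch with simplified types (a hypothetical struct, not the real class) showing the intended indexing:

#include <cstddef>
#include <cstdint>
#include <vector>

// Usage sketch only; simplified stand-ins for the members declared above.
// Outer index = tensor_pair_idx, inner index = node-type cursor.
struct DeviceKeysSketch {
  std::vector<std::vector<uint64_t>> h_device_keys_len;  // [pair][cursor]
  std::vector<uint64_t> copy_unique_len;                 // [pair]
  std::vector<int> shuffle_seed;                         // [pair]

  uint64_t KeysLen(int tensor_pair_idx, std::size_t cursor) const {
    return h_device_keys_len[tensor_pair_idx][cursor];
  }
};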
(Diffs for the remaining two changed files are not shown.)
