From 0e04d151d0eebb1b165c6a11f1cbfda3b079ba5d Mon Sep 17 00:00:00 2001 From: miaoli06 <106585574+miaoli06@users.noreply.github.com> Date: Mon, 29 Aug 2022 19:26:51 +0800 Subject: [PATCH] split generate batch into multi stage (#92) * split generate batch into multi stage * fix conflict Co-authored-by: root --- paddle/fluid/framework/data_feed.cu | 428 +++++++++++----------------- paddle/fluid/framework/data_feed.h | 3 + paddle/fluid/framework/data_set.h | 1 + 3 files changed, 173 insertions(+), 259 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index de7fe49c442f5..dc4e6d1f07654 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -342,66 +342,150 @@ __global__ void GraphFillSlotLodKernel(int64_t *id_tensor, int len) { CUDA_KERNEL_LOOP(idx, len) { id_tensor[idx] = idx; } } -int GraphDataGenerator::FillInsBuf() { - if (ins_buf_pair_len_ >= batch_size_) { - return batch_size_; +int GraphDataGenerator::FillIdShowClkTensor(int total_instance, bool gpu_graph_training, size_t cursor) { + id_tensor_ptr_ = + feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); + show_tensor_ptr_ = + feed_vec_[1]->mutable_data({total_instance}, this->place_); + clk_tensor_ptr_ = + feed_vec_[2]->mutable_data({total_instance}, this->place_); + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + if (gpu_graph_training) { + uint64_t *ins_cursor, *ins_buf; + ins_buf = reinterpret_cast(d_ins_buf_->ptr()); + ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; + cudaMemcpyAsync(id_tensor_ptr_, + ins_cursor, + sizeof(uint64_t) * total_instance, + cudaMemcpyDeviceToDevice, + train_stream_); + } else { + uint64_t *d_type_keys = + reinterpret_cast(d_device_keys_[cursor]->ptr()); + auto &infer_node_type_start = gpu_graph_ptr->infer_node_type_start_[gpuid_]; + d_type_keys += infer_node_type_start[cursor]; + infer_node_type_start[cursor] += total_instance / 2; + CopyDuplicateKeys<<>>( + id_tensor_ptr_, d_type_keys, total_instance / 2); } - int total_instance = AcquireInstance(&buf_state_); - - VLOG(2) << "total_ins: " << total_instance; - buf_state_.Debug(); - if (total_instance == 0) { - if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { - return -1; - } - int res = FillWalkBuf(); - if (!res) { - // graph iterate complete - return -1; - } else { - total_instance = buf_state_.len; - VLOG(2) << "total_ins: " << total_instance; - buf_state_.Debug(); - // if (total_instance == 0) { - // return -1; - //} - } + GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); + GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + return 0; +} - if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - FillFeatureBuf(d_walk_, d_feature_); - if (debug_mode_) { - int len = buf_size_ > 5000 ? 
5000 : buf_size_; - uint64_t h_walk[len]; - cudaMemcpy(h_walk, - d_walk_->ptr(), - len * sizeof(uint64_t), +int GraphDataGenerator::FillGraphSlotFeature(int total_instance, bool gpu_graph_training) { + int64_t *slot_tensor_ptr_[slot_num_]; + int64_t *slot_lod_tensor_ptr_[slot_num_]; + for (int i = 0; i < slot_num_; ++i) { + slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data( + {total_instance, 1}, this->place_); + slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data( + {total_instance + 1}, this->place_); + } + uint64_t *ins_cursor, *ins_buf; + if (gpu_graph_training) { + ins_buf = reinterpret_cast(d_ins_buf_->ptr()); + ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; + } else { + id_tensor_ptr_ = + feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); + ins_cursor = (uint64_t *)id_tensor_ptr_; + } + + cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), + slot_tensor_ptr_, + sizeof(uint64_t *) * slot_num_, + cudaMemcpyHostToDevice, + train_stream_); + cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), + slot_lod_tensor_ptr_, + sizeof(uint64_t *) * slot_num_, + cudaMemcpyHostToDevice, + train_stream_); + uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); + FillFeatureBuf(ins_cursor, feature_buf, total_instance); + GraphFillSlotKernel<<>>((uint64_t *)d_slot_tensor_ptr_->ptr(), + feature_buf, + total_instance * slot_num_, + total_instance, + slot_num_); + GraphFillSlotLodKernelOpt<<>>( + (uint64_t *)d_slot_lod_tensor_ptr_->ptr(), + (total_instance + 1) * slot_num_, + total_instance + 1); + if (debug_mode_) { + uint64_t h_walk[total_instance]; + cudaMemcpy(h_walk, + ins_cursor, + total_instance * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + uint64_t h_feature[total_instance * slot_num_]; + cudaMemcpy(h_feature, + feature_buf, + total_instance * slot_num_ * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + for (int i = 0; i < total_instance; ++i) { + std::stringstream ss; + for (int j = 0; j < slot_num_; ++j) { + ss << h_feature[i * slot_num_ + j] << " "; + } + VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i + << "] = " << (uint64_t)h_walk[i] << " feature[" + << i * slot_num_ << ".." << (i + 1) * slot_num_ + << "] = " << ss.str(); + } + + uint64_t h_slot_tensor[slot_num_][total_instance]; + uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1]; + for (int i = 0; i < slot_num_; ++i) { + cudaMemcpy(h_slot_tensor[i], + slot_tensor_ptr_[i], + total_instance * sizeof(uint64_t), cudaMemcpyDeviceToHost); - uint64_t h_feature[len * slot_num_]; - cudaMemcpy(h_feature, - d_feature_->ptr(), - len * slot_num_ * sizeof(uint64_t), + int len = total_instance > 5000 ? 5000 : total_instance; + for (int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i << "][" << j + << "] = " << h_slot_tensor[i][j]; + } + + cudaMemcpy(h_slot_lod_tensor[i], + slot_lod_tensor_ptr_[i], + (total_instance + 1) * sizeof(uint64_t), cudaMemcpyDeviceToHost); - for (int i = 0; i < len; ++i) { - std::stringstream ss; - for (int j = 0; j < slot_num_; ++j) { - ss << h_feature[i * slot_num_ + j] << " "; - } - VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i - << "] = " << (uint64_t)h_walk[i] << " feature[" - << i * slot_num_ << ".." << (i + 1) * slot_num_ - << "] = " << ss.str(); - } - } - } - } + len = total_instance + 1 > 5000 ? 
5000 : total_instance + 1; + for (int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_lod_tensor[" << i << "][" << j + << "] = " << h_slot_lod_tensor[i][j]; + } + } + } + return 0; +} +int GraphDataGenerator::MakeInsPair() { uint64_t *walk = reinterpret_cast(d_walk_->ptr()); uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); int *random_row = reinterpret_cast(d_random_row_->ptr()); int *d_pair_num = reinterpret_cast(d_pair_num_->ptr()); cudaMemsetAsync(d_pair_num, 0, sizeof(int), train_stream_); int len = buf_state_.len; + //make pair GraphFillIdKernel<<>>( ins_buf + ins_buf_pair_len_ * 2, d_pair_num, @@ -417,29 +501,6 @@ int GraphDataGenerator::FillInsBuf() { sizeof(int), cudaMemcpyDeviceToHost, train_stream_); - if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); - uint64_t *feature = reinterpret_cast(d_feature_->ptr()); - cudaMemsetAsync(d_pair_num, 0, sizeof(int), train_stream_); - int len = buf_state_.len; - VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_ - << "] len[" << len << "]"; - GraphFillFeatureKernel<<>>( - feature_buf + ins_buf_pair_len_ * 2 * slot_num_, - d_pair_num, - walk, - feature, - random_row + buf_state_.cursor, - buf_state_.central_word, - window_step_[buf_state_.step], - len, - walk_len_, - slot_num_); - } - cudaStreamSynchronize(train_stream_); ins_buf_pair_len_ += h_pair_num; @@ -455,21 +516,37 @@ int GraphDataGenerator::FillInsBuf() { VLOG(2) << "h_ins_buf[" << xx << "]: " << h_ins_buf[xx]; } delete[] h_ins_buf; + } + return ins_buf_pair_len_; +} - if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - uint64_t *feature_buf = - reinterpret_cast(d_feature_buf_->ptr()); - uint64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_]; - cudaMemcpy(h_feature_buf, - feature_buf, - (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < (batch_size_ * 2 * 2) * slot_num_; xx++) { - VLOG(2) << "h_feature_buf[" << xx << "]: " << h_feature_buf[xx]; - } +int GraphDataGenerator::FillInsBuf() { + if (ins_buf_pair_len_ >= batch_size_) { + return batch_size_; + } + int total_instance = AcquireInstance(&buf_state_); + + VLOG(2) << "total_ins: " << total_instance; + buf_state_.Debug(); + + if (total_instance == 0) { + if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { + return -1; + } + int res = FillWalkBuf(); + if (!res) { + // graph iterate complete + return -1; + } else { + total_instance = buf_state_.len; + VLOG(2) << "total_ins: " << total_instance; + buf_state_.Debug(); + // if (total_instance == 0) { + // return -1; + //} } } - return ins_buf_pair_len_; + return MakeInsPair(); } int GraphDataGenerator::GenerateBatch() { @@ -489,32 +566,10 @@ int GraphDataGenerator::GenerateBatch() { (infer_node_type_start[cursor_] + batch_size_ <= device_key_size) ? 
batch_size_ : device_key_size - infer_node_type_start[cursor_]; - uint64_t *d_type_keys = - reinterpret_cast(d_device_keys_[cursor_]->ptr()); - d_type_keys += infer_node_type_start[cursor_]; - infer_node_type_start[cursor_] += total_instance; VLOG(1) << "in graph_data generator:batch_size = " << batch_size_ << " instance = " << total_instance; total_instance *= 2; - id_tensor_ptr_ = feed_vec_[0]->mutable_data({total_instance, 1}, - this->place_); - show_tensor_ptr_ = - feed_vec_[1]->mutable_data({total_instance}, this->place_); - clk_tensor_ptr_ = - feed_vec_[2]->mutable_data({total_instance}, this->place_); - CopyDuplicateKeys<<>>( - id_tensor_ptr_, d_type_keys, total_instance / 2); - GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); - GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + FillIdShowClkTensor(total_instance, gpu_graph_training_, cursor_); break; } if (total_instance == 0) { @@ -533,133 +588,14 @@ int GraphDataGenerator::GenerateBatch() { } total_instance = ins_buf_pair_len_ < batch_size_ ? ins_buf_pair_len_ : batch_size_; - total_instance *= 2; - VLOG(2) << "total ins: " << total_instance << " gpuid: " << gpuid_ - << " feed_vec: " << feed_vec_[0]; - id_tensor_ptr_ = - feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); - show_tensor_ptr_ = - feed_vec_[1]->mutable_data({total_instance}, this->place_); - clk_tensor_ptr_ = - feed_vec_[2]->mutable_data({total_instance}, this->place_); - } - - int64_t *slot_tensor_ptr_[slot_num_]; - int64_t *slot_lod_tensor_ptr_[slot_num_]; - if (slot_num_ > 0) { - for (int i = 0; i < slot_num_; ++i) { - slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data( - {total_instance, 1}, this->place_); - slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data( - {total_instance + 1}, this->place_); - } - if (FLAGS_enable_opt_get_features || !gpu_graph_training_) { - cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), - slot_tensor_ptr_, - sizeof(uint64_t *) * slot_num_, - cudaMemcpyHostToDevice, - train_stream_); - cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), - slot_lod_tensor_ptr_, - sizeof(uint64_t *) * slot_num_, - cudaMemcpyHostToDevice, - train_stream_); - } - } - - uint64_t *ins_cursor, *ins_buf; - if (gpu_graph_training_) { VLOG(2) << "total_instance: " << total_instance - << ", ins_buf_pair_len = " << ins_buf_pair_len_; - // uint64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); - // uint64_t *ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; - ins_buf = reinterpret_cast(d_ins_buf_->ptr()); - ins_cursor = ins_buf + ins_buf_pair_len_ * 2 - total_instance; - cudaMemcpyAsync(id_tensor_ptr_, - ins_cursor, - sizeof(uint64_t) * total_instance, - cudaMemcpyDeviceToDevice, - train_stream_); - - GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); - GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); - } else { - ins_cursor = (uint64_t *)id_tensor_ptr_; + << ", ins_buf_pair_len = " << ins_buf_pair_len_; + FillIdShowClkTensor(total_instance, gpu_graph_training_); } if (slot_num_ > 0) { - uint64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); - if (FLAGS_enable_opt_get_features || !gpu_graph_training_) { - FillFeatureBuf(ins_cursor, feature_buf, total_instance); - // FillFeatureBuf(id_tensor_ptr_, feature_buf, total_instance); - if (debug_mode_) { - uint64_t h_walk[total_instance]; - cudaMemcpy(h_walk, - ins_cursor, - total_instance * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - uint64_t h_feature[total_instance * slot_num_]; - cudaMemcpy(h_feature, - feature_buf, - total_instance * 
slot_num_ * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for (int i = 0; i < total_instance; ++i) { - std::stringstream ss; - for (int j = 0; j < slot_num_; ++j) { - ss << h_feature[i * slot_num_ + j] << " "; - } - VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i - << "] = " << (uint64_t)h_walk[i] << " feature[" - << i * slot_num_ << ".." << (i + 1) * slot_num_ - << "] = " << ss.str(); - } - } - - GraphFillSlotKernel<<>>( - (uint64_t *)d_slot_tensor_ptr_->ptr(), - feature_buf, - total_instance * slot_num_, - total_instance, - slot_num_); - GraphFillSlotLodKernelOpt<<>>( - (uint64_t *)d_slot_lod_tensor_ptr_->ptr(), - (total_instance + 1) * slot_num_, - total_instance + 1); - } else { - for (int i = 0; i < slot_num_; ++i) { - int feature_buf_offset = - (ins_buf_pair_len_ * 2 - total_instance) * slot_num_ + i * 2; - for (int j = 0; j < total_instance; j += 2) { - VLOG(2) << "slot_tensor[" << i << "][" << j << "] <- feature_buf[" - << feature_buf_offset + j * slot_num_ << "]"; - VLOG(2) << "slot_tensor[" << i << "][" << j + 1 << "] <- feature_buf[" - << feature_buf_offset + j * slot_num_ + 1 << "]"; - cudaMemcpyAsync(slot_tensor_ptr_[i] + j, - &feature_buf[feature_buf_offset + j * slot_num_], - sizeof(uint64_t) * 2, - cudaMemcpyDeviceToDevice, - train_stream_); - } - GraphFillSlotLodKernel<<>>(slot_lod_tensor_ptr_[i], - total_instance + 1); - } - } + FillGraphSlotFeature(total_instance, gpu_graph_training_); } offset_.clear(); @@ -676,31 +612,6 @@ int GraphDataGenerator::GenerateBatch() { cudaStreamSynchronize(train_stream_); if (!gpu_graph_training_) return 1; ins_buf_pair_len_ -= total_instance / 2; - if (debug_mode_) { - uint64_t h_slot_tensor[slot_num_][total_instance]; - uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1]; - for (int i = 0; i < slot_num_; ++i) { - cudaMemcpy(h_slot_tensor[i], - slot_tensor_ptr_[i], - total_instance * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - int len = total_instance > 5000 ? 5000 : total_instance; - for (int j = 0; j < len; ++j) { - VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i << "][" << j - << "] = " << h_slot_tensor[i][j]; - } - - cudaMemcpy(h_slot_lod_tensor[i], - slot_lod_tensor_ptr_[i], - (total_instance + 1) * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - len = total_instance + 1 > 5000 ? 
5000 : total_instance + 1; - for (int j = 0; j < len; ++j) { - VLOG(2) << "gpu[" << gpuid_ << "] slot_lod_tensor[" << i << "][" << j - << "] = " << h_slot_lod_tensor[i][j]; - } - } - } return 1; } @@ -1245,12 +1156,11 @@ void GraphDataGenerator::AllocResource(int thread_id, place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t)); } d_pair_num_ = memory::AllocShared(place_, sizeof(int)); - if (FLAGS_enable_opt_get_features && slot_num_ > 0) { - d_slot_tensor_ptr_ = - memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); - d_slot_lod_tensor_ptr_ = - memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); - } + + d_slot_tensor_ptr_ = + memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); + d_slot_lod_tensor_ptr_ = + memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *)); cudaStreamSynchronize(sample_stream_); } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 32b06a1ceffa1..7aa429cd22450 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -913,6 +913,9 @@ class GraphDataGenerator { int step, int* len_per_row); int FillInsBuf(); + int FillIdShowClkTensor(int total_instance, bool gpu_graph_training, size_t cursor = 0); + int FillGraphSlotFeature(int total_instance, bool gpu_graph_training); + int MakeInsPair(); int GetPathNum() { return total_row_; } void SetDeviceKeys(std::vector* device_keys, int type) { type_to_index_[type] = h_device_keys_.size(); diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 458a08f4f0a3d..3c736b56dfb27 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -338,6 +338,7 @@ class DatasetImpl : public Dataset { bool enable_heterps_ = false; int gpu_graph_mode_ = 0; std::vector>> gpu_graph_type_keys_; + std::vector gpu_graph_total_keys_; }; // use std::vector or Record as data type
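
For context on the staged structure this patch introduces -- MakeInsPair() builds the instance pairs from the walk buffer, FillIdShowClkTensor() fills the id/show/clk tensors (from the ins buffer when training, from the device key list when inferring), and FillGraphSlotFeature() fills the per-slot feature and lod tensors -- below is a minimal standalone CUDA sketch of the same staged fill pattern. Everything in the sketch (kernel names, launch sizes, the 1-filled show/clk convention) is illustrative only and is not the Paddle implementation; it just shows how the batch fill decomposes into independent stages that a GenerateBatch-style driver can call in sequence.

// Hypothetical standalone sketch of the staged batch fill; not Paddle code.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Stage 1 (MakeInsPair analogue): build the instance buffer.
__global__ void MakePairsKernel(uint64_t* ins, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) ins[i] = static_cast<uint64_t>(i);
}

// Stage 2 (FillIdShowClkTensor analogue): copy ids, fill show/clk with 1s.
__global__ void FillIdsKernel(uint64_t* ids, const uint64_t* ins, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) ids[i] = ins[i];
}

__global__ void FillOnesKernel(int64_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 1;
}

int main() {
  const int total_instance = 8;
  uint64_t *d_ins, *d_ids;
  int64_t *d_show, *d_clk;
  cudaMalloc(&d_ins, total_instance * sizeof(uint64_t));
  cudaMalloc(&d_ids, total_instance * sizeof(uint64_t));
  cudaMalloc(&d_show, total_instance * sizeof(int64_t));
  cudaMalloc(&d_clk, total_instance * sizeof(int64_t));

  const int block = 256;
  const int grid = (total_instance + block - 1) / block;
  // Stage 1: build the instance buffer.
  MakePairsKernel<<<grid, block>>>(d_ins, total_instance);
  // Stage 2: fill id/show/clk tensors from the instance buffer.
  FillIdsKernel<<<grid, block>>>(d_ids, d_ins, total_instance);
  FillOnesKernel<<<grid, block>>>(d_show, total_instance);
  FillOnesKernel<<<grid, block>>>(d_clk, total_instance);
  // Stage 3 (FillGraphSlotFeature analogue) would fill slot/lod tensors here.
  cudaDeviceSynchronize();

  uint64_t h_ids[total_instance];
  cudaMemcpy(h_ids, d_ids, sizeof(h_ids), cudaMemcpyDeviceToHost);
  for (int i = 0; i < total_instance; ++i)
    printf("id[%d] = %lu\n", i, (unsigned long)h_ids[i]);

  cudaFree(d_ins); cudaFree(d_ids); cudaFree(d_show); cudaFree(d_clk);
  return 0;
}

Keeping each stage as its own helper is what lets the patch share FillIdShowClkTensor and FillGraphSlotFeature between the training path (reading from d_ins_buf_) and the inference path (reading from d_device_keys_), which is the branch visible in FillIdShowClkTensor's gpu_graph_training argument.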