From 88f13ec7aa2bd3db85a58a0bdcd307edf3745a94 Mon Sep 17 00:00:00 2001 From: Thunderbrook Date: Tue, 24 May 2022 21:50:43 +0800 Subject: [PATCH 1/2] deepwalk --- paddle/fluid/framework/data_feed.cc | 31 ++- paddle/fluid/framework/data_feed.cu | 326 +++++++++++++++++++---- paddle/fluid/framework/data_feed.h | 163 +++++++++++- paddle/fluid/framework/data_feed.proto | 11 + paddle/fluid/framework/data_set.cc | 13 + paddle/fluid/framework/data_set.h | 2 + paddle/fluid/framework/hogwild_worker.cc | 5 + paddle/fluid/pybind/data_set_py.cc | 2 + python/paddle/fluid/dataset.py | 18 ++ 9 files changed, 504 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index c094b8049c509..0a87aebe8da6a 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -52,17 +52,42 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place& place, device_key_size_ = h_device_keys_->size(); d_device_keys_ = memory::AllocShared(place_, device_key_size_ * sizeof(int64_t)); + for(size_t i = 0; i < h_device_keys_->size(); i++){ + VLOG(2) << "h_device_keys_[" << i << "] = " << (*h_device_keys_)[i]; + } CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(), device_key_size_ * sizeof(int64_t), cudaMemcpyHostToDevice, stream_)); + size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; d_prefix_sum_ = - memory::AllocShared(place_, (sample_key_size_ + 1) * sizeof(int64_t)); + memory::AllocShared(place_, (once_max_sample_keynum + 1) * sizeof(int64_t)); int64_t* d_prefix_sum_ptr = reinterpret_cast(d_prefix_sum_->ptr()); - cudaMemsetAsync(d_prefix_sum_ptr, 0, (sample_key_size_ + 1) * sizeof(int64_t), + cudaMemsetAsync(d_prefix_sum_ptr, 0, (once_max_sample_keynum + 1) * sizeof(int64_t), stream_); cursor_ = 0; + jump_rows_ = 0; device_keys_ = reinterpret_cast(d_device_keys_->ptr()); - ; + VLOG(2) << "device_keys_ = " << (uint64_t)device_keys_; + d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(int64_t)); + cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(int64_t), + stream_); + d_sample_keys_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(int64_t)); + + d_sampleidx2rows_.push_back(memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); + d_sampleidx2rows_.push_back(memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); + cur_sampleidx2row_ = 0; + + d_len_per_row_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(int)); + for (int i = -window_; i < 0; i++) { + window_step_.push_back(i); + } + for (int i = 0; i < window_; i++) { + window_step_.push_back(i + 1); + } + buf_state_.Init(batch_size_, walk_len_, &window_step_); + d_random_row_ = + memory::AllocShared(place_, (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int)); + shuffle_seed_ = 0; cudaStreamSynchronize(stream_); } diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index 30da43258dd12..eedb4cee9fc09 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -21,6 +21,10 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" +#include +#include +#include + namespace paddle { namespace framework { @@ -146,29 +150,143 @@ void SlotRecordInMemoryDataFeed::CopyForTensor( cudaStreamSynchronize(stream); } -__global__ void GraphFillIdKernel(int64_t *id_tensor, int *actual_sample_size, - int64_t *prefix_sum, int64_t *device_key, - int64_t *neighbors, int sample_size, - int len) { +__global__ void GraphFillCVMKernel(int64_t *tensor, int len) { + CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; } +} + + +__global__ void GraphFillIdKernel(int64_t *id_tensor, int64_t *walk, int *row, + int central_word, int step, + int len, int col_num) { + CUDA_KERNEL_LOOP(idx, len) { + int dst = idx * 2; + int src = row[idx] * col_num + central_word; + id_tensor[dst] = walk[src]; + id_tensor[dst + 1] = walk[src + step]; + } +} + +int GraphDataGenerator::AcquireInstance(BufState* state) { + // + if (state->GetNextStep()) { + state->Debug(); + return state->len; + } else if (state->GetNextCentrolWord()) { + state->Debug(); + return state->len; + } else if (state->GetNextBatch()) { + state->Debug(); + return state->len; + } + return 0; +} + +int GraphDataGenerator::GenerateBatch() { + platform::CUDADeviceGuard guard(gpuid_); + + int total_instance = AcquireInstance(&buf_state_); + + VLOG(2) << "total_ins: " << total_instance; + buf_state_.Debug(); + + if (total_instance == 0) { + int res = FillWalkBuf(d_walk_); + if (!res) { + return 0; + } else { + total_instance = buf_state_.len; + VLOG(2) << "total_ins: " << total_instance; + buf_state_.Debug(); + if (total_instance == 0) { + return 0; + } + } + } + + total_instance *= 2; + id_tensor_ptr_ = + feed_vec_[0]->mutable_data({total_instance, 1}, this->place_); + show_tensor_ptr_ = + feed_vec_[1]->mutable_data({total_instance}, this->place_); + clk_tensor_ptr_ = + feed_vec_[2]->mutable_data({total_instance}, this->place_); + + int64_t *walk = reinterpret_cast(d_walk_->ptr()); + int *random_row = reinterpret_cast(d_random_row_->ptr()); + int len = buf_state_.len; + GraphFillIdKernel<<>>( + id_tensor_ptr_, walk, random_row + buf_state_.cursor, buf_state_.central_word, window_step_[buf_state_.step], len, walk_len_); + GraphFillCVMKernel<<>>( + show_tensor_ptr_, total_instance); + GraphFillCVMKernel<<>>( + clk_tensor_ptr_, total_instance); + + offset_.clear(); + offset_.push_back(0); + offset_.push_back(total_instance); + LoD lod{offset_}; + feed_vec_[0]->set_lod(lod); + cudaStreamSynchronize(stream_); + return 1; +} + +__global__ void GraphFillSampleKeysKernel(int64_t* neighbors, int64_t* sample_keys, + int64_t *prefix_sum, int* sampleidx2row, int* tmp_sampleidx2row, + int* actual_sample_size, + int cur_degree, int len) { CUDA_KERNEL_LOOP(idx, len) { for (int k = 0; k < actual_sample_size[idx]; k++) { - int offset = (prefix_sum[idx] + k) * 2; - id_tensor[offset] = device_key[idx]; - id_tensor[offset + 1] = neighbors[idx * sample_size + k]; + size_t offset = prefix_sum[idx] + k; + sample_keys[offset] = neighbors[idx * cur_degree + k]; + tmp_sampleidx2row[offset] = sampleidx2row[idx] + k; } } } -__global__ void GraphFillCVMKernel(int64_t *tensor, int len) { - CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; } + +__global__ void GraphDoWalkKernel(int64_t* neighbors, int64_t* walk, + int64_t *d_prefix_sum, int* actual_sample_size, + int cur_degree, int step, int len, + int *id_cnt, int *sampleidx2row, int col_size) { + CUDA_KERNEL_LOOP(i, len) { + for (int k = 0; k < actual_sample_size[i]; k++) { + //int idx = sampleidx2row[i]; + size_t row = sampleidx2row[k + d_prefix_sum[i]]; + //size_t row = idx * cur_degree + k; + size_t col = step; + size_t offset = (row * col_size + col); + walk[offset] = neighbors[i * cur_degree + k]; + id_cnt[row] += 1; + } + } +} + +// Fill keys to the first column of walk +__global__ void GraphFillFirstStepKernel(int64_t* prefix_sum, int* sampleidx2row, int64_t* walk, int64_t* keys, int len, int walk_degree, int col_size, int* actual_sample_size, int64_t* neighbors, int64_t* sample_keys) { + CUDA_KERNEL_LOOP(idx, len) { + for (int k = 0; k < actual_sample_size[idx]; k++) { + size_t row = prefix_sum[idx] + k; + sample_keys[row] = neighbors[idx * walk_degree + k]; + sampleidx2row[row] = row; + + size_t offset = col_size * row; + walk[offset] = keys[idx]; + walk[offset + 1] = neighbors[idx * walk_degree + k]; + } + } } -void GraphDataGenerator::FeedGraphIns(size_t cursor, int len, - NeighborSampleResult &sample_res) { +// Fill sample_res to the stepth column of walk +void GraphDataGenerator::FillOneStep(int64_t* walk, int len, + NeighborSampleResult &sample_res, int cur_degree, int step, int* len_per_row) { size_t temp_storage_bytes = 0; int *d_actual_sample_size = sample_res.actual_sample_size; int64_t *d_neighbors = sample_res.val; int64_t *d_prefix_sum = reinterpret_cast(d_prefix_sum_->ptr()); + int64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); + int *d_sampleidx2row = reinterpret_cast(d_sampleidx2rows_[cur_sampleidx2row_]->ptr()); + int *d_tmp_sampleidx2row = reinterpret_cast(d_sampleidx2rows_[1 - cur_sampleidx2row_]->ptr()); + CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, d_actual_sample_size, d_prefix_sum + 1, len, stream_)); @@ -177,56 +295,160 @@ void GraphDataGenerator::FeedGraphIns(size_t cursor, int len, CUDA_CHECK(cub::DeviceScan::InclusiveSum( d_temp_storage->ptr(), temp_storage_bytes, d_actual_sample_size, d_prefix_sum + 1, len, stream_)); - cudaStreamSynchronize(stream_); - int64_t total_ins = 0; - cudaMemcpyAsync(&total_ins, d_prefix_sum + len, sizeof(int64_t), + + int64_t next_start_id_len = 0; + cudaMemcpyAsync(&next_start_id_len, d_prefix_sum + len, sizeof(int64_t), cudaMemcpyDeviceToHost, stream_); + cudaStreamSynchronize(stream_); + + if (step == 1) { + GraphFillFirstStepKernel<<>>(d_prefix_sum, d_tmp_sampleidx2row, walk, device_keys_ + cursor_, len, walk_degree_, walk_len_, d_actual_sample_size, d_neighbors, d_sample_keys); + jump_rows_ = next_start_id_len; - total_ins *= 2; - id_tensor_ptr_ = - feed_vec_[0]->mutable_data({total_ins, 1}, this->place_); - show_tensor_ptr_ = - feed_vec_[1]->mutable_data({total_ins}, this->place_); - clk_tensor_ptr_ = - feed_vec_[2]->mutable_data({total_ins}, this->place_); - - GraphFillIdKernel<<>>( - id_tensor_ptr_, d_actual_sample_size, d_prefix_sum, - device_keys_ + cursor_, d_neighbors, walk_degree_, len); - GraphFillCVMKernel<<>>( - show_tensor_ptr_, total_ins); - GraphFillCVMKernel<<>>( - clk_tensor_ptr_, total_ins); + } else { + GraphFillSampleKeysKernel<<>>( + d_neighbors, d_sample_keys, d_prefix_sum, d_sampleidx2row, d_tmp_sampleidx2row, d_actual_sample_size, cur_degree, len); + + GraphDoWalkKernel<<>>( + d_neighbors, walk, d_prefix_sum, d_actual_sample_size, cur_degree, step, len, len_per_row, d_tmp_sampleidx2row, walk_len_); + } + if (debug_mode_) { + size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; + int64_t* h_prefix_sum = new int64_t[len + 1]; + int* h_actual_size = new int[len]; + int* h_offset2idx = new int[once_max_sample_keynum]; + int64_t* h_sample_keys = new int64_t[once_max_sample_keynum]; + cudaMemcpy(h_offset2idx, d_tmp_sampleidx2row, once_max_sample_keynum * sizeof(int), + cudaMemcpyDeviceToHost); - offset_.clear(); - offset_.push_back(0); - offset_.push_back(total_ins); - LoD lod{offset_}; - feed_vec_[0]->set_lod(lod); - // feed_vec_[1]->set_lod(lod); - // feed_vec_[2]->set_lod(lod); + cudaMemcpy(h_prefix_sum, d_prefix_sum, (len+1) * sizeof(int64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_offset2idx[" << xx << "]: " << h_offset2idx[xx]; + } + for (int xx = 0; xx < len+1; xx++) { + VLOG(2) << "h_prefix_sum[" << xx << "]: " << h_prefix_sum[xx]; + } + delete[] h_prefix_sum; + delete[] h_actual_size; + delete[] h_offset2idx; + delete[] h_sample_keys; + } + sample_keys_len_ = next_start_id_len; + cur_sampleidx2row_ = 1 - cur_sampleidx2row_; cudaStreamSynchronize(stream_); } -int GraphDataGenerator::GenerateBatch() { - // GpuPsGraphTable *g = (GpuPsGraphTable *)(gpu_graph_ptr->graph_table); +int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { platform::CUDADeviceGuard guard(gpuid_); + size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; + //////// + int64_t *h_walk; + int64_t* h_sample_keys; + int* h_offset2idx; + int *h_len_per_row; + int64_t* h_prefix_sum; + if (debug_mode_) { + h_walk = new int64_t[buf_size_]; + h_sample_keys = new int64_t[once_max_sample_keynum]; + h_offset2idx = new int[once_max_sample_keynum]; + h_len_per_row = new int[once_max_sample_keynum]; + h_prefix_sum = new int64_t[100]; + } + /////// auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - int tmp_len = cursor_ + sample_key_size_ > device_key_size_ - ? device_key_size_ - cursor_ - : sample_key_size_; - VLOG(3) << "device key size: " << device_key_size_ - << " this batch: " << tmp_len << " cursor: " << cursor_ - << " sample_key_size_: " << sample_key_size_; - if (tmp_len == 0) { - return 0; + int64_t *walk = reinterpret_cast(d_walk->ptr()); + int *len_per_row = reinterpret_cast(d_len_per_row_->ptr()); + int64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); + cudaMemsetAsync(walk, 0, buf_size_ * sizeof(sizeof(int64_t)), stream_); + cudaMemsetAsync(len_per_row, 0, once_max_sample_keynum * sizeof(int), stream_); + int i = 0; + int total_row = 0; + while (i < buf_size_) { + int tmp_len = cursor_ + once_sample_startid_len_ > device_key_size_ + ? device_key_size_ - cursor_ + : once_sample_startid_len_; + if (tmp_len == 0) { + break; + } + VLOG(2) << "i = " << i << " buf_size_ = " << buf_size_ << " tmp_len = " << tmp_len << " cursor = " << cursor_ << " once_max_sample_keynum = " << once_max_sample_keynum; + int64_t* cur_walk = walk + i; + len_per_row += once_max_sample_keynum; + + if (debug_mode_) { + cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < buf_size_; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + } + auto sample_res = gpu_graph_ptr->graph_neighbor_sample( + gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len); + + int step = 1; + jump_rows_ = 0; + FillOneStep(cur_walk, tmp_len, sample_res, walk_degree_, step, len_per_row); + ///////// + if (debug_mode_) { + cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_len_per_row, len_per_row, once_max_sample_keynum * sizeof(int), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < buf_size_; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; + } + + } + ///////// + step++; + for (; step < walk_len_; step ++) { + if (sample_keys_len_ == 0) { + break; + } + sample_res = gpu_graph_ptr->graph_neighbor_sample( + gpuid_, d_sample_keys, 1, sample_keys_len_); + + FillOneStep(cur_walk, sample_keys_len_, sample_res, 1, step, len_per_row); + if (debug_mode_) { + cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < buf_size_; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + cudaMemcpy(h_len_per_row, len_per_row, once_max_sample_keynum * sizeof(int), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; + } + } + } + cursor_ += tmp_len; + i += jump_rows_ * walk_len_; + total_row += jump_rows_; } - int total_instance = 1; - auto sample_res = gpu_graph_ptr->graph_neighbor_sample( - gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len); - FeedGraphIns(cursor_, tmp_len, sample_res); - cursor_ += tmp_len; - return 1; + buf_state_.Reset(total_row); + int *d_random_row = reinterpret_cast(d_random_row_->ptr()); + + thrust::random::default_random_engine engine(shuffle_seed_); + const auto &exec_policy = thrust::cuda::par.on(stream_); + thrust::counting_iterator cnt_iter(0); + thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + total_row, + thrust::device_pointer_cast(d_random_row), engine); + + cudaStreamSynchronize(stream_); + shuffle_seed_ = engine(); + + if (debug_mode_) { + delete[] h_walk; + delete[] h_sample_keys; + delete[] h_offset2idx; + delete[] h_len_per_row; + delete[] h_prefix_sum; + } + return total_row != 0; } } // namespace framework diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index fe370ae115847..8e846691ab659 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -30,6 +30,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/blocking_queue.h" @@ -775,36 +776,174 @@ class DLManager { std::map handle_map_; }; +struct engine_wrapper_t { + std::default_random_engine engine; + engine_wrapper_t() { + struct timespec tp; + clock_gettime(CLOCK_REALTIME, &tp); + double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9; + static std::atomic x(0); + std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)}; + engine.seed(sseq); + } +}; + +struct BufState { + int left; + int right; + int central_word; + int step; + engine_wrapper_t random_engine_; + + int len; + int cursor; + int row_num; + + int batch_size; + int walk_len; + std::vector* window; + + BufState() {} + ~BufState() {} + + void Init(int graph_batch_size, int graph_walk_len, std::vector* graph_window) { + batch_size = graph_batch_size; + walk_len = graph_walk_len; + window = graph_window; + + left = 0; + right = window->size() - 1; + central_word = -1; + step = -1; + + len = 0; + cursor = 0; + for (size_t i = 0; i < graph_window->size(); i++) { + VLOG(2) << "graph_window[" << i << "] = " << (*graph_window)[i]; + } + } + + void Reset(int total_rows) { + cursor = 0; + row_num = total_rows; + int tmp_len = cursor + batch_size > row_num + ? row_num - cursor + : batch_size; + len = tmp_len; + central_word = -1; + step = -1; + GetNextCentrolWord(); + } + + int GetNextStep() { + step++; + if (step <= right && central_word + (*window)[step] < walk_len) { + return 1; + } + return 0; + } + + void Debug() { + VLOG(2) << "left: " << left << " right: " << right << " central_word: " << central_word << " step: " << step << " cursor: " << cursor << " len: " << len << " row_num: " << row_num; + } + + int GetNextCentrolWord() { + if (++central_word >= walk_len) { + return 0; + } + int window_size = window->size() / 2; + int random_window = random_engine_.engine() % window_size + 1; + left = window_size - random_window; + right = window_size + random_window - 1; + VLOG(2) << "random window: " << random_window << " window[" << left << "] = " << (*window)[left] << " window[" << right << "] = " << (*window)[right]; + + for (step = left; step <= right; step ++) { + if (central_word + (*window)[step] >= 0) { + return 1; + } + } + return 0; + } + + int GetNextBatch() { + int tmp_len = cursor + batch_size > row_num + ? row_num - cursor + : batch_size; + cursor += tmp_len; + len = tmp_len; + central_word = -1; + step = -1; + GetNextCentrolWord(); + return tmp_len != 0; + } + +}; + class GraphDataGenerator { public: GraphDataGenerator() {}; - ~GraphDataGenerator() {}; + virtual ~GraphDataGenerator() {}; void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc) { - walk_degree_ = 1; - walk_len_ = 1; - sample_key_size_ = 8000; + auto graph_config = data_feed_desc.graph_config(); + walk_degree_ = graph_config.walk_degree(); + walk_len_ = graph_config.walk_len(); + window_ = graph_config.window(); + once_sample_startid_len_ = graph_config.once_sample_startid_len(); + debug_mode_ = graph_config.debug_mode(); + if (debug_mode_) { + batch_size_ = graph_config.batch_size(); + } else { + batch_size_ = once_sample_startid_len_; + } + repeat_time_ = graph_config.sample_times_one_chunk(); + buf_size_ = once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_; + VLOG(2) << "Confirm GraphConfig, walk_degree : " << walk_degree_ << ", walk_len : " << walk_len_ << ", window : " << window_ << ", once_sample_startid_len : " << once_sample_startid_len_ << ", sample_times_one_chunk : " << repeat_time_ << ", batch_size: " << batch_size_; }; void AllocResource(const paddle::platform::Place& place, std::vector feed_vec, std::vector* h_device_keys); - void FeedGraphIns(size_t cursor, int len, NeighborSampleResult& sample_res); + int AcquireInstance(BufState* state); int GenerateBatch(); + int FillWalkBuf(std::shared_ptr d_walk); + void FillOneStep(int64_t* walk, int len, NeighborSampleResult &sample_res, int cur_degree, int step, int *len_per_row); protected: - int walk_degree_ = 1; - int walk_len_ = 1; - int sample_key_size_; + int walk_degree_; + int walk_len_; + int window_; + int once_sample_startid_len_; int gpuid_; + // start ids + int64_t* device_keys_; size_t device_key_size_; + std::vector* h_device_keys_; + // point to device_keys_ size_t cursor_; - int64_t* device_keys_; + size_t jump_rows_; int64_t* id_tensor_ptr_; int64_t* show_tensor_ptr_; int64_t* clk_tensor_ptr_; cudaStream_t stream_; paddle::platform::Place place_; std::vector feed_vec_; - std::vector* h_device_keys_; std::vector offset_; - std::shared_ptr d_prefix_sum_ = nullptr; - std::shared_ptr d_device_keys_ = nullptr; + std::shared_ptr d_prefix_sum_; + std::shared_ptr d_device_keys_; + + std::shared_ptr d_walk_; + std::shared_ptr d_len_per_row_; + std::shared_ptr d_random_row_; + // + std::vector> d_sampleidx2rows_; + int cur_sampleidx2row_; + // record the keys to call graph_neighbor_sample + std::shared_ptr d_sample_keys_; + int sample_keys_len_; + // size of a d_walk buf + size_t buf_size_; + int repeat_time_; + std::vector window_step_; + BufState buf_state_; + int batch_size_; + int shuffle_seed_; + int debug_mode_; }; class DataFeed { diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index 6964446f20946..143bc266b9c1e 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -27,6 +27,16 @@ message MultiSlotDesc { optional string uid_slot = 2; } +message GraphConfig { + optional int32 walk_degree = 1 [ default = 1 ]; + optional int32 walk_len = 2 [ default = 20 ]; + optional int32 window = 3 [ default = 5 ]; + optional int32 once_sample_startid_len = 4 [ default = 8000 ]; + optional int32 sample_times_one_chunk = 5 [ default = 10 ]; + optional int32 batch_size = 6 [ default = 1 ]; + optional int32 debug_mode = 7 [ default = 0 ]; +} + message DataFeedDesc { optional string name = 1; optional int32 batch_size = 2 [ default = 32 ]; @@ -37,4 +47,5 @@ message DataFeedDesc { optional int32 pv_batch_size = 7 [ default = 32 ]; optional int32 input_type = 8 [ default = 0 ]; optional string so_parser_name = 9; + optional GraphConfig graph_config = 10; } diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 79dd67f811e9e..f55b73a4cab65 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -410,6 +410,19 @@ void MultiSlotDataset::PrepareTrain() { return; } +template +void DatasetImpl::SetGraphDeviceKeys(const std::vector& h_device_keys) { + + + for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) { + gpu_graph_device_keys_[i].clear(); + } + size_t device_num = gpu_graph_device_keys_.size(); + for (size_t i = 0; i < h_device_keys.size(); i++) { + int shard = h_device_keys[i] % device_num; + gpu_graph_device_keys_[shard].push_back(h_device_keys[i]); + } +} // load data into memory, Dataset hold this memory, // which will later be fed into readers' channel template diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 8457336560ee6..fe9f0fa32614c 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -158,6 +158,7 @@ class Dataset { virtual void DynamicAdjustReadersNum(int thread_num) = 0; // set fleet send sleep seconds virtual void SetFleetSendSleepSeconds(int seconds) = 0; + virtual void SetGraphDeviceKeys(const std::vector& h_device_keys) = 0; protected: virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg) = 0; @@ -237,6 +238,7 @@ class DatasetImpl : public Dataset { int read_thread_num, int consume_thread_num, int shard_num) {} + virtual void SetGraphDeviceKeys(const std::vector& h_device_keys); virtual void ClearLocalTables() {} virtual void CreatePreLoadReaders(); virtual void DestroyPreLoadReaders(); diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 1d8727c81692f..ed33175a2edc7 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -118,6 +118,11 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) { void HogwildWorker::TrainFilesWithProfiler() { platform::SetNumThreads(1); +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + platform::SetDeviceId(thread_id_); +#elif defined(PADDLE_WITH_XPU_BKCL) + platform::SetXPUDeviceId(thread_id_); +#endif device_reader_->Start(); std::vector op_total_time; std::vector op_name; diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 5e2274cb65138..5aac6ada05b18 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -298,6 +298,8 @@ void BindDataset(py::module *m) { py::call_guard()) .def("set_preload_thread_num", &framework::Dataset::SetPreLoadThreadNum, py::call_guard()) + .def("set_graph_device_keys", &framework::Dataset::SetGraphDeviceKeys, + py::call_guard()) .def("create_preload_readers", &framework::Dataset::CreatePreLoadReaders, py::call_guard()) .def("destroy_preload_readers", diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 84064669c0dc6..002823bd9c320 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -1041,6 +1041,24 @@ def _set_heter_ps(self, enable_heter_ps=False): user no need to call this function. """ self.dataset.set_heter_ps(enable_heter_ps) + + def set_graph_device_keys(self, device_keys): + """ + Set heter ps mode + user no need to call this function. + """ + self.dataset.set_graph_device_keys(device_keys) + + def set_graph_config(self, config): + """ + """ + self.proto_desc.graph_config.walk_degree = config.get("walk_degree", 1) + self.proto_desc.graph_config.walk_len = config.get("walk_len", 20) + self.proto_desc.graph_config.once_sample_startid_len = config.get("once_sample_startid_len", 8000) + self.proto_desc.graph_config.sample_times_one_chunk = config.get("sample_times_one_chunk", 10) + self.proto_desc.graph_config.batch_size = config.get("batch_size", 1) + self.proto_desc.graph_config.debug_mode = config.get("debug_mode", 0) + class QueueDataset(DatasetBase): From ed569b449f29f0e765587c606e2e50b9c7a87404 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 24 May 2022 15:24:37 +0000 Subject: [PATCH 2/2] format --- paddle/fluid/framework/data_feed.cc | 32 ++-- paddle/fluid/framework/data_feed.cu | 246 +++++++++++++++------------- paddle/fluid/framework/data_feed.h | 70 ++++---- paddle/fluid/framework/data_set.cc | 26 +-- paddle/fluid/framework/data_set.h | 8 +- 5 files changed, 206 insertions(+), 176 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 0a87aebe8da6a..6c74f127c7fd6 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -52,32 +52,35 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place& place, device_key_size_ = h_device_keys_->size(); d_device_keys_ = memory::AllocShared(place_, device_key_size_ * sizeof(int64_t)); - for(size_t i = 0; i < h_device_keys_->size(); i++){ + for (size_t i = 0; i < h_device_keys_->size(); i++) { VLOG(2) << "h_device_keys_[" << i << "] = " << (*h_device_keys_)[i]; } CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(), device_key_size_ * sizeof(int64_t), cudaMemcpyHostToDevice, stream_)); size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; - d_prefix_sum_ = - memory::AllocShared(place_, (once_max_sample_keynum + 1) * sizeof(int64_t)); + d_prefix_sum_ = memory::AllocShared( + place_, (once_max_sample_keynum + 1) * sizeof(int64_t)); int64_t* d_prefix_sum_ptr = reinterpret_cast(d_prefix_sum_->ptr()); - cudaMemsetAsync(d_prefix_sum_ptr, 0, (once_max_sample_keynum + 1) * sizeof(int64_t), - stream_); + cudaMemsetAsync(d_prefix_sum_ptr, 0, + (once_max_sample_keynum + 1) * sizeof(int64_t), stream_); cursor_ = 0; jump_rows_ = 0; device_keys_ = reinterpret_cast(d_device_keys_->ptr()); VLOG(2) << "device_keys_ = " << (uint64_t)device_keys_; d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(int64_t)); - cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(int64_t), - stream_); - d_sample_keys_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(int64_t)); - - d_sampleidx2rows_.push_back(memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); - d_sampleidx2rows_.push_back(memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); + cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(int64_t), stream_); + d_sample_keys_ = + memory::AllocShared(place_, once_max_sample_keynum * sizeof(int64_t)); + + d_sampleidx2rows_.push_back( + memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); + d_sampleidx2rows_.push_back( + memory::AllocShared(place_, once_max_sample_keynum * sizeof(int))); cur_sampleidx2row_ = 0; - d_len_per_row_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(int)); + d_len_per_row_ = + memory::AllocShared(place_, once_max_sample_keynum * sizeof(int)); for (int i = -window_; i < 0; i++) { window_step_.push_back(i); } @@ -85,8 +88,9 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place& place, window_step_.push_back(i + 1); } buf_state_.Init(batch_size_, walk_len_, &window_step_); - d_random_row_ = - memory::AllocShared(place_, (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int)); + d_random_row_ = memory::AllocShared( + place_, + (once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int)); shuffle_seed_ = 0; cudaStreamSynchronize(stream_); } diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index eedb4cee9fc09..83af4d6ea1621 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -17,13 +17,13 @@ limitations under the License. */ #endif #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) +#include +#include +#include #include "cub/cub.cuh" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" -#include -#include -#include namespace paddle { namespace framework { @@ -154,10 +154,9 @@ __global__ void GraphFillCVMKernel(int64_t *tensor, int len) { CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; } } - __global__ void GraphFillIdKernel(int64_t *id_tensor, int64_t *walk, int *row, - int central_word, int step, - int len, int col_num) { + int central_word, int step, int len, + int col_num) { CUDA_KERNEL_LOOP(idx, len) { int dst = idx * 2; int src = row[idx] * col_num + central_word; @@ -166,7 +165,7 @@ __global__ void GraphFillIdKernel(int64_t *id_tensor, int64_t *walk, int *row, } } -int GraphDataGenerator::AcquireInstance(BufState* state) { +int GraphDataGenerator::AcquireInstance(BufState *state) { // if (state->GetNextStep()) { state->Debug(); @@ -183,23 +182,23 @@ int GraphDataGenerator::AcquireInstance(BufState* state) { int GraphDataGenerator::GenerateBatch() { platform::CUDADeviceGuard guard(gpuid_); - + int total_instance = AcquireInstance(&buf_state_); - + VLOG(2) << "total_ins: " << total_instance; buf_state_.Debug(); - + if (total_instance == 0) { int res = FillWalkBuf(d_walk_); if (!res) { - return 0; + return 0; } else { - total_instance = buf_state_.len; - VLOG(2) << "total_ins: " << total_instance; - buf_state_.Debug(); - if (total_instance == 0) { - return 0; - } + total_instance = buf_state_.len; + VLOG(2) << "total_ins: " << total_instance; + buf_state_.Debug(); + if (total_instance == 0) { + return 0; + } } } @@ -210,12 +209,13 @@ int GraphDataGenerator::GenerateBatch() { feed_vec_[1]->mutable_data({total_instance}, this->place_); clk_tensor_ptr_ = feed_vec_[2]->mutable_data({total_instance}, this->place_); - + int64_t *walk = reinterpret_cast(d_walk_->ptr()); int *random_row = reinterpret_cast(d_random_row_->ptr()); int len = buf_state_.len; GraphFillIdKernel<<>>( - id_tensor_ptr_, walk, random_row + buf_state_.cursor, buf_state_.central_word, window_step_[buf_state_.step], len, walk_len_); + id_tensor_ptr_, walk, random_row + buf_state_.cursor, + buf_state_.central_word, window_step_[buf_state_.step], len, walk_len_); GraphFillCVMKernel<<>>( show_tensor_ptr_, total_instance); GraphFillCVMKernel<<>>( @@ -230,10 +230,10 @@ int GraphDataGenerator::GenerateBatch() { return 1; } -__global__ void GraphFillSampleKeysKernel(int64_t* neighbors, int64_t* sample_keys, - int64_t *prefix_sum, int* sampleidx2row, int* tmp_sampleidx2row, - int* actual_sample_size, - int cur_degree, int len) { +__global__ void GraphFillSampleKeysKernel( + int64_t *neighbors, int64_t *sample_keys, int64_t *prefix_sum, + int *sampleidx2row, int *tmp_sampleidx2row, int *actual_sample_size, + int cur_degree, int len) { CUDA_KERNEL_LOOP(idx, len) { for (int k = 0; k < actual_sample_size[idx]; k++) { size_t offset = prefix_sum[idx] + k; @@ -243,18 +243,18 @@ __global__ void GraphFillSampleKeysKernel(int64_t* neighbors, int64_t* sample_ke } } - -__global__ void GraphDoWalkKernel(int64_t* neighbors, int64_t* walk, - int64_t *d_prefix_sum, int* actual_sample_size, - int cur_degree, int step, int len, - int *id_cnt, int *sampleidx2row, int col_size) { +__global__ void GraphDoWalkKernel(int64_t *neighbors, int64_t *walk, + int64_t *d_prefix_sum, + int *actual_sample_size, int cur_degree, + int step, int len, int *id_cnt, + int *sampleidx2row, int col_size) { CUDA_KERNEL_LOOP(i, len) { for (int k = 0; k < actual_sample_size[i]; k++) { - //int idx = sampleidx2row[i]; + // int idx = sampleidx2row[i]; size_t row = sampleidx2row[k + d_prefix_sum[i]]; - //size_t row = idx * cur_degree + k; + // size_t row = idx * cur_degree + k; size_t col = step; - size_t offset = (row * col_size + col); + size_t offset = (row * col_size + col); walk[offset] = neighbors[i * cur_degree + k]; id_cnt[row] += 1; } @@ -262,13 +262,16 @@ __global__ void GraphDoWalkKernel(int64_t* neighbors, int64_t* walk, } // Fill keys to the first column of walk -__global__ void GraphFillFirstStepKernel(int64_t* prefix_sum, int* sampleidx2row, int64_t* walk, int64_t* keys, int len, int walk_degree, int col_size, int* actual_sample_size, int64_t* neighbors, int64_t* sample_keys) { +__global__ void GraphFillFirstStepKernel( + int64_t *prefix_sum, int *sampleidx2row, int64_t *walk, int64_t *keys, + int len, int walk_degree, int col_size, int *actual_sample_size, + int64_t *neighbors, int64_t *sample_keys) { CUDA_KERNEL_LOOP(idx, len) { for (int k = 0; k < actual_sample_size[idx]; k++) { size_t row = prefix_sum[idx] + k; sample_keys[row] = neighbors[idx * walk_degree + k]; sampleidx2row[row] = row; - + size_t offset = col_size * row; walk[offset] = keys[idx]; walk[offset + 1] = neighbors[idx * walk_degree + k]; @@ -277,15 +280,19 @@ __global__ void GraphFillFirstStepKernel(int64_t* prefix_sum, int* sampleidx2row } // Fill sample_res to the stepth column of walk -void GraphDataGenerator::FillOneStep(int64_t* walk, int len, - NeighborSampleResult &sample_res, int cur_degree, int step, int* len_per_row) { +void GraphDataGenerator::FillOneStep(int64_t *walk, int len, + NeighborSampleResult &sample_res, + int cur_degree, int step, + int *len_per_row) { size_t temp_storage_bytes = 0; int *d_actual_sample_size = sample_res.actual_sample_size; int64_t *d_neighbors = sample_res.val; int64_t *d_prefix_sum = reinterpret_cast(d_prefix_sum_->ptr()); int64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); - int *d_sampleidx2row = reinterpret_cast(d_sampleidx2rows_[cur_sampleidx2row_]->ptr()); - int *d_tmp_sampleidx2row = reinterpret_cast(d_sampleidx2rows_[1 - cur_sampleidx2row_]->ptr()); + int *d_sampleidx2row = + reinterpret_cast(d_sampleidx2rows_[cur_sampleidx2row_]->ptr()); + int *d_tmp_sampleidx2row = + reinterpret_cast(d_sampleidx2rows_[1 - cur_sampleidx2row_]->ptr()); CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, d_actual_sample_size, @@ -295,44 +302,50 @@ void GraphDataGenerator::FillOneStep(int64_t* walk, int len, CUDA_CHECK(cub::DeviceScan::InclusiveSum( d_temp_storage->ptr(), temp_storage_bytes, d_actual_sample_size, d_prefix_sum + 1, len, stream_)); - - int64_t next_start_id_len = 0; + + int64_t next_start_id_len = 0; cudaMemcpyAsync(&next_start_id_len, d_prefix_sum + len, sizeof(int64_t), cudaMemcpyDeviceToHost, stream_); cudaStreamSynchronize(stream_); - + if (step == 1) { - GraphFillFirstStepKernel<<>>(d_prefix_sum, d_tmp_sampleidx2row, walk, device_keys_ + cursor_, len, walk_degree_, walk_len_, d_actual_sample_size, d_neighbors, d_sample_keys); + GraphFillFirstStepKernel<<>>( + d_prefix_sum, d_tmp_sampleidx2row, walk, device_keys_ + cursor_, len, + walk_degree_, walk_len_, d_actual_sample_size, d_neighbors, + d_sample_keys); jump_rows_ = next_start_id_len; } else { - GraphFillSampleKeysKernel<<>>( - d_neighbors, d_sample_keys, d_prefix_sum, d_sampleidx2row, d_tmp_sampleidx2row, d_actual_sample_size, cur_degree, len); - + GraphFillSampleKeysKernel<<>>( + d_neighbors, d_sample_keys, d_prefix_sum, d_sampleidx2row, + d_tmp_sampleidx2row, d_actual_sample_size, cur_degree, len); + GraphDoWalkKernel<<>>( - d_neighbors, walk, d_prefix_sum, d_actual_sample_size, cur_degree, step, len, len_per_row, d_tmp_sampleidx2row, walk_len_); + d_neighbors, walk, d_prefix_sum, d_actual_sample_size, cur_degree, step, + len, len_per_row, d_tmp_sampleidx2row, walk_len_); } if (debug_mode_) { - size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; - int64_t* h_prefix_sum = new int64_t[len + 1]; - int* h_actual_size = new int[len]; - int* h_offset2idx = new int[once_max_sample_keynum]; - int64_t* h_sample_keys = new int64_t[once_max_sample_keynum]; - cudaMemcpy(h_offset2idx, d_tmp_sampleidx2row, once_max_sample_keynum * sizeof(int), - cudaMemcpyDeviceToHost); - - cudaMemcpy(h_prefix_sum, d_prefix_sum, (len+1) * sizeof(int64_t), + size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; + int64_t *h_prefix_sum = new int64_t[len + 1]; + int *h_actual_size = new int[len]; + int *h_offset2idx = new int[once_max_sample_keynum]; + int64_t *h_sample_keys = new int64_t[once_max_sample_keynum]; + cudaMemcpy(h_offset2idx, d_tmp_sampleidx2row, + once_max_sample_keynum * sizeof(int), cudaMemcpyDeviceToHost); + + cudaMemcpy(h_prefix_sum, d_prefix_sum, (len + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost); - for (int xx = 0; xx < once_max_sample_keynum; xx++) { - VLOG(2) << "h_offset2idx[" << xx << "]: " << h_offset2idx[xx]; - } - for (int xx = 0; xx < len+1; xx++) { - VLOG(2) << "h_prefix_sum[" << xx << "]: " << h_prefix_sum[xx]; - } - delete[] h_prefix_sum; - delete[] h_actual_size; - delete[] h_offset2idx; - delete[] h_sample_keys; + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_offset2idx[" << xx << "]: " << h_offset2idx[xx]; + } + for (int xx = 0; xx < len + 1; xx++) { + VLOG(2) << "h_prefix_sum[" << xx << "]: " << h_prefix_sum[xx]; + } + delete[] h_prefix_sum; + delete[] h_actual_size; + delete[] h_offset2idx; + delete[] h_sample_keys; } sample_keys_len_ = next_start_id_len; cur_sampleidx2row_ = 1 - cur_sampleidx2row_; @@ -344,10 +357,10 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; //////// int64_t *h_walk; - int64_t* h_sample_keys; - int* h_offset2idx; + int64_t *h_sample_keys; + int *h_offset2idx; int *h_len_per_row; - int64_t* h_prefix_sum; + int64_t *h_prefix_sum; if (debug_mode_) { h_walk = new int64_t[buf_size_]; h_sample_keys = new int64_t[once_max_sample_keynum]; @@ -361,92 +374,95 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { int *len_per_row = reinterpret_cast(d_len_per_row_->ptr()); int64_t *d_sample_keys = reinterpret_cast(d_sample_keys_->ptr()); cudaMemsetAsync(walk, 0, buf_size_ * sizeof(sizeof(int64_t)), stream_); - cudaMemsetAsync(len_per_row, 0, once_max_sample_keynum * sizeof(int), stream_); + cudaMemsetAsync(len_per_row, 0, once_max_sample_keynum * sizeof(int), + stream_); int i = 0; int total_row = 0; while (i < buf_size_) { int tmp_len = cursor_ + once_sample_startid_len_ > device_key_size_ - ? device_key_size_ - cursor_ - : once_sample_startid_len_; + ? device_key_size_ - cursor_ + : once_sample_startid_len_; if (tmp_len == 0) { - break; + break; } - VLOG(2) << "i = " << i << " buf_size_ = " << buf_size_ << " tmp_len = " << tmp_len << " cursor = " << cursor_ << " once_max_sample_keynum = " << once_max_sample_keynum; - int64_t* cur_walk = walk + i; + VLOG(2) << "i = " << i << " buf_size_ = " << buf_size_ + << " tmp_len = " << tmp_len << " cursor = " << cursor_ + << " once_max_sample_keynum = " << once_max_sample_keynum; + int64_t *cur_walk = walk + i; len_per_row += once_max_sample_keynum; - - if (debug_mode_) { + + if (debug_mode_) { cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), - cudaMemcpyDeviceToHost); + cudaMemcpyDeviceToHost); for (int xx = 0; xx < buf_size_; xx++) { - VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; } } auto sample_res = gpu_graph_ptr->graph_neighbor_sample( - gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len); - + gpuid_, device_keys_ + cursor_, walk_degree_, tmp_len); + int step = 1; jump_rows_ = 0; FillOneStep(cur_walk, tmp_len, sample_res, walk_degree_, step, len_per_row); - ///////// - if (debug_mode_) { - cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), - cudaMemcpyDeviceToHost); - cudaMemcpy(h_len_per_row, len_per_row, once_max_sample_keynum * sizeof(int), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < buf_size_; xx++) { - VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; - } - for (int xx = 0; xx < once_max_sample_keynum; xx++) { - VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; - } - + ///////// + if (debug_mode_) { + cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + cudaMemcpy(h_len_per_row, len_per_row, + once_max_sample_keynum * sizeof(int), cudaMemcpyDeviceToHost); + for (int xx = 0; xx < buf_size_; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; + } + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; + } } ///////// step++; - for (; step < walk_len_; step ++) { - if (sample_keys_len_ == 0) { - break; + for (; step < walk_len_; step++) { + if (sample_keys_len_ == 0) { + break; + } + sample_res = gpu_graph_ptr->graph_neighbor_sample(gpuid_, d_sample_keys, + 1, sample_keys_len_); + + FillOneStep(cur_walk, sample_keys_len_, sample_res, 1, step, len_per_row); + if (debug_mode_) { + cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < buf_size_; xx++) { + VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; } - sample_res = gpu_graph_ptr->graph_neighbor_sample( - gpuid_, d_sample_keys, 1, sample_keys_len_); - - FillOneStep(cur_walk, sample_keys_len_, sample_res, 1, step, len_per_row); - if (debug_mode_) { - cudaMemcpy(h_walk, walk, buf_size_ * sizeof(int64_t), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < buf_size_; xx++) { - VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx]; - } - cudaMemcpy(h_len_per_row, len_per_row, once_max_sample_keynum * sizeof(int), - cudaMemcpyDeviceToHost); - for (int xx = 0; xx < once_max_sample_keynum; xx++) { - VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; - } + cudaMemcpy(h_len_per_row, len_per_row, + once_max_sample_keynum * sizeof(int), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < once_max_sample_keynum; xx++) { + VLOG(2) << "h_len_per_row[" << xx << "]: " << h_len_per_row[xx]; } + } } cursor_ += tmp_len; i += jump_rows_ * walk_len_; - total_row += jump_rows_; + total_row += jump_rows_; } buf_state_.Reset(total_row); int *d_random_row = reinterpret_cast(d_random_row_->ptr()); - + thrust::random::default_random_engine engine(shuffle_seed_); const auto &exec_policy = thrust::cuda::par.on(stream_); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + total_row, thrust::device_pointer_cast(d_random_row), engine); - + cudaStreamSynchronize(stream_); shuffle_seed_ = engine(); - + if (debug_mode_) { delete[] h_walk; delete[] h_sample_keys; delete[] h_offset2idx; delete[] h_len_per_row; - delete[] h_prefix_sum; + delete[] h_prefix_sum; } return total_row != 0; } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 8e846691ab659..a3c49b16cfd98 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -23,6 +23,7 @@ limitations under the License. */ #include // NOLINT #include #include // NOLINT +#include #include #include #include // NOLINT @@ -30,7 +31,6 @@ limitations under the License. */ #include #include #include -#include #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/blocking_queue.h" @@ -794,47 +794,46 @@ struct BufState { int central_word; int step; engine_wrapper_t random_engine_; - + int len; int cursor; int row_num; - + int batch_size; int walk_len; std::vector* window; - + BufState() {} ~BufState() {} - - void Init(int graph_batch_size, int graph_walk_len, std::vector* graph_window) { + + void Init(int graph_batch_size, int graph_walk_len, + std::vector* graph_window) { batch_size = graph_batch_size; walk_len = graph_walk_len; window = graph_window; - + left = 0; right = window->size() - 1; central_word = -1; step = -1; - + len = 0; cursor = 0; - for (size_t i = 0; i < graph_window->size(); i++) { + for (size_t i = 0; i < graph_window->size(); i++) { VLOG(2) << "graph_window[" << i << "] = " << (*graph_window)[i]; } } - + void Reset(int total_rows) { cursor = 0; row_num = total_rows; - int tmp_len = cursor + batch_size > row_num - ? row_num - cursor - : batch_size; + int tmp_len = cursor + batch_size > row_num ? row_num - cursor : batch_size; len = tmp_len; central_word = -1; step = -1; GetNextCentrolWord(); } - + int GetNextStep() { step++; if (step <= right && central_word + (*window)[step] < walk_len) { @@ -844,9 +843,12 @@ struct BufState { } void Debug() { - VLOG(2) << "left: " << left << " right: " << right << " central_word: " << central_word << " step: " << step << " cursor: " << cursor << " len: " << len << " row_num: " << row_num; + VLOG(2) << "left: " << left << " right: " << right + << " central_word: " << central_word << " step: " << step + << " cursor: " << cursor << " len: " << len + << " row_num: " << row_num; } - + int GetNextCentrolWord() { if (++central_word >= walk_len) { return 0; @@ -855,20 +857,20 @@ struct BufState { int random_window = random_engine_.engine() % window_size + 1; left = window_size - random_window; right = window_size + random_window - 1; - VLOG(2) << "random window: " << random_window << " window[" << left << "] = " << (*window)[left] << " window[" << right << "] = " << (*window)[right]; + VLOG(2) << "random window: " << random_window << " window[" << left + << "] = " << (*window)[left] << " window[" << right + << "] = " << (*window)[right]; - for (step = left; step <= right; step ++) { + for (step = left; step <= right; step++) { if (central_word + (*window)[step] >= 0) { return 1; } } return 0; } - + int GetNextBatch() { - int tmp_len = cursor + batch_size > row_num - ? row_num - cursor - : batch_size; + int tmp_len = cursor + batch_size > row_num ? row_num - cursor : batch_size; cursor += tmp_len; len = tmp_len; central_word = -1; @@ -876,13 +878,12 @@ struct BufState { GetNextCentrolWord(); return tmp_len != 0; } - }; class GraphDataGenerator { public: - GraphDataGenerator() {}; - virtual ~GraphDataGenerator() {}; + GraphDataGenerator(){}; + virtual ~GraphDataGenerator(){}; void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc) { auto graph_config = data_feed_desc.graph_config(); walk_degree_ = graph_config.walk_degree(); @@ -896,14 +897,23 @@ class GraphDataGenerator { batch_size_ = once_sample_startid_len_; } repeat_time_ = graph_config.sample_times_one_chunk(); - buf_size_ = once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_; - VLOG(2) << "Confirm GraphConfig, walk_degree : " << walk_degree_ << ", walk_len : " << walk_len_ << ", window : " << window_ << ", once_sample_startid_len : " << once_sample_startid_len_ << ", sample_times_one_chunk : " << repeat_time_ << ", batch_size: " << batch_size_; + buf_size_ = + once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_; + VLOG(2) << "Confirm GraphConfig, walk_degree : " << walk_degree_ + << ", walk_len : " << walk_len_ << ", window : " << window_ + << ", once_sample_startid_len : " << once_sample_startid_len_ + << ", sample_times_one_chunk : " << repeat_time_ + << ", batch_size: " << batch_size_; }; - void AllocResource(const paddle::platform::Place& place, std::vector feed_vec, std::vector* h_device_keys); + void AllocResource(const paddle::platform::Place& place, + std::vector feed_vec, + std::vector* h_device_keys); int AcquireInstance(BufState* state); int GenerateBatch(); int FillWalkBuf(std::shared_ptr d_walk); - void FillOneStep(int64_t* walk, int len, NeighborSampleResult &sample_res, int cur_degree, int step, int *len_per_row); + void FillOneStep(int64_t* walk, int len, NeighborSampleResult& sample_res, + int cur_degree, int step, int* len_per_row); + protected: int walk_degree_; int walk_len_; @@ -930,7 +940,7 @@ class GraphDataGenerator { std::shared_ptr d_walk_; std::shared_ptr d_len_per_row_; std::shared_ptr d_random_row_; - // + // std::vector> d_sampleidx2rows_; int cur_sampleidx2row_; // record the keys to call graph_neighbor_sample diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index f55b73a4cab65..64e529f609fd3 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -411,17 +411,16 @@ void MultiSlotDataset::PrepareTrain() { } template -void DatasetImpl::SetGraphDeviceKeys(const std::vector& h_device_keys) { - - - for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) { - gpu_graph_device_keys_[i].clear(); - } - size_t device_num = gpu_graph_device_keys_.size(); - for (size_t i = 0; i < h_device_keys.size(); i++) { - int shard = h_device_keys[i] % device_num; - gpu_graph_device_keys_[shard].push_back(h_device_keys[i]); - } +void DatasetImpl::SetGraphDeviceKeys( + const std::vector& h_device_keys) { + for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) { + gpu_graph_device_keys_[i].clear(); + } + size_t device_num = gpu_graph_device_keys_.size(); + for (size_t i = 0; i < h_device_keys.size(); i++) { + int shard = h_device_keys[i] % device_num; + gpu_graph_device_keys_[shard].push_back(h_device_keys[i]); + } } // load data into memory, Dataset hold this memory, // which will later be fed into readers' channel @@ -435,9 +434,10 @@ void DatasetImpl::LoadIntoMemory() { VLOG(0) << "in gpu_graph_mode"; auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); gpu_graph_device_keys_ = gpu_graph_ptr->get_all_id(0, 0, thread_num_); - + for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) { - VLOG(0) << "gpu_graph_device_keys_[" << i << "] = " << gpu_graph_device_keys_[i].size(); + VLOG(0) << "gpu_graph_device_keys_[" << i + << "] = " << gpu_graph_device_keys_[i].size(); for (size_t j = 0; j < gpu_graph_device_keys_[i].size(); j++) { gpu_graph_total_keys_.push_back(gpu_graph_device_keys_[i][j]); } diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index fe9f0fa32614c..fb133ff4895b0 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -158,7 +158,9 @@ class Dataset { virtual void DynamicAdjustReadersNum(int thread_num) = 0; // set fleet send sleep seconds virtual void SetFleetSendSleepSeconds(int seconds) = 0; - virtual void SetGraphDeviceKeys(const std::vector& h_device_keys) = 0; + virtual void SetGraphDeviceKeys( + const std::vector& h_device_keys) = 0; + protected: virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg) = 0; @@ -264,9 +266,7 @@ class DatasetImpl : public Dataset { return multi_consume_channel_; } } - std::vector& GetGpuGraphTotalKeys() { - return gpu_graph_total_keys_; - } + std::vector& GetGpuGraphTotalKeys() { return gpu_graph_total_keys_; } Channel& GetInputChannelRef() { return input_channel_; } protected: