From 5fee9282fcff455d4bc0283bb40f709984d4f500 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 May 2023 09:19:26 +0800 Subject: [PATCH] format code (3) --- paddle/fluid/framework/data_feed.cu | 272 +++++++++++++++------------- paddle/fluid/framework/data_feed.h | 2 +- 2 files changed, 145 insertions(+), 129 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index ffc96dc79f12f..607222fd02143 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -31,9 +31,9 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_ps/hashtable.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/platform/collective_helper.h" #include "paddle/phi/kernels/gpu/graph_reindex_funcs.h" #include "paddle/phi/kernels/graph_reindex_kernel.h" -#include "paddle/fluid/platform/collective_helper.h" DECLARE_bool(enable_opt_get_features); DECLARE_bool(graph_metapath_split_opt); @@ -343,7 +343,7 @@ __global__ void FillSlotValueOffsetKernel(const int ins_num, struct RandInt { int low, high; - __host__ __device__ RandInt(int low, int high) : low(low), high(high){}; + __host__ __device__ RandInt(int low, int high) : low(low), high(high) {} __host__ __device__ int operator()(const unsigned int n) const { thrust::default_random_engine rng; @@ -851,11 +851,11 @@ int MakeInsPair(const std::shared_ptr &d_walk, // input const GraphDataGeneratorConfig &conf, const std::shared_ptr &d_random_row, const std::shared_ptr &d_random_row_col_shift, - BufState &buf_state, - std::shared_ptr &d_ins_buf, // output - std::shared_ptr &d_pair_label_buf, // output - std::shared_ptr &d_pair_num_ptr, // output - int &ins_buf_pair_len, + BufState *buf_state, + uint64_t *ins_buf, // output + int32_t *pair_label_buf, // output + int *d_pair_num, // output + int *ins_buf_pair_len_ptr, cudaStream_t stream) { uint64_t *walk = 
reinterpret_cast(d_walk->ptr()); uint8_t *walk_ntype = NULL; @@ -867,20 +867,17 @@ int MakeInsPair(const std::shared_ptr &d_walk, // input excluded_train_pair = reinterpret_cast(conf.d_excluded_train_pair->ptr()); } - uint64_t *ins_buf = reinterpret_cast(d_ins_buf->ptr()); - int32_t *pair_label_buf = NULL; int32_t *pair_label_conf = NULL; if (conf.enable_pair_label) { - pair_label_buf = reinterpret_cast(d_pair_label_buf->ptr()); pair_label_conf = reinterpret_cast(conf.d_pair_label_conf->ptr()); } int *random_row = reinterpret_cast(d_random_row->ptr()); int *random_row_col_shift = reinterpret_cast(d_random_row_col_shift->ptr()); - int *d_pair_num = reinterpret_cast(d_pair_num_ptr->ptr()); cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream); - int len = buf_state.len; + int len = buf_state->len; + int &ins_buf_pair_len = *ins_buf_pair_len_ptr; // make pair GraphFillIdKernel<<>>( @@ -889,10 +886,10 @@ int MakeInsPair(const std::shared_ptr &d_walk, // input d_pair_num, walk, walk_ntype, - random_row + buf_state.cursor, - random_row_col_shift + buf_state.cursor, - buf_state.central_word, - conf.window_step[buf_state.step], + random_row + buf_state->cursor, + random_row_col_shift + buf_state->cursor, + buf_state->central_word, + conf.window_step[buf_state->step], len, conf.walk_len, excluded_train_pair, @@ -926,19 +923,20 @@ int FillInsBuf(const std::shared_ptr &d_walk, // input const GraphDataGeneratorConfig &conf, const std::shared_ptr &d_random_row, const std::shared_ptr &d_random_row_col_shift, - BufState &buf_state, - std::shared_ptr &d_ins_buf, // output - std::shared_ptr &d_pair_label_buf, // output - std::shared_ptr &d_pair_num, // output - int &ins_buf_pair_len, + BufState *buf_state, + uint64_t *ins_buf, // output + int32_t *pair_label_buf, // output + int *pair_num_ptr, // output + int *ins_buf_pair_len_ptr, cudaStream_t stream) { + int &ins_buf_pair_len = *ins_buf_pair_len_ptr; if (ins_buf_pair_len >= conf.batch_size) { return conf.batch_size; } - int 
total_instance = AcquireInstance(&buf_state); + int total_instance = AcquireInstance(buf_state); VLOG(2) << "total_ins: " << total_instance; - buf_state.Debug(); + buf_state->Debug(); if (total_instance == 0) { return -1; @@ -949,10 +947,10 @@ int FillInsBuf(const std::shared_ptr &d_walk, // input d_random_row, d_random_row_col_shift, buf_state, - d_ins_buf, - d_pair_label_buf, - d_pair_num, - ins_buf_pair_len, + ins_buf, + pair_label_buf, + pair_num_ptr, + ins_buf_pair_len_ptr, stream); } @@ -985,16 +983,21 @@ int GraphDataGenerator::GenerateBatch() { // train if (!conf_.sage_mode) { while (ins_buf_pair_len_ < conf_.batch_size) { + int32_t *pair_label_buf = NULL; + if (d_pair_label_buf_ != NULL) { + pair_label_buf = + reinterpret_cast(d_pair_label_buf_->ptr()); + } res = FillInsBuf(d_walk_, d_walk_ntype_, conf_, d_random_row_, d_random_row_col_shift_, - buf_state_, - d_ins_buf_, - d_pair_label_buf_, - d_pair_num_, - ins_buf_pair_len_, + &buf_state_, + reinterpret_cast(d_ins_buf_->ptr()), + pair_label_buf, + reinterpret_cast(d_pair_num_->ptr()), + &ins_buf_pair_len_, train_stream_); if (res == -1) { if (ins_buf_pair_len_ == 0) { @@ -1260,14 +1263,14 @@ void FillOneStep( uint64_t *walk, uint8_t *walk_ntype, int len, - NeighborSampleResult &sample_res, + NeighborSampleResult *sample_res, int cur_degree, int step, const GraphDataGeneratorConfig &conf, - std::shared_ptr &d_sample_keys_ptr, - std::shared_ptr &d_prefix_sum_ptr, - std::vector> &d_sampleidx2rows, - int &cur_sampleidx2row, + uint64_t *d_sample_keys, + int *d_prefix_sum, + std::vector> *d_sampleidx2rows, + int *cur_sampleidx2row, const paddle::platform::Place &place, cudaStream_t stream) { auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); @@ -1275,15 +1278,12 @@ void FillOneStep( uint8_t edge_src_id = node_id >> 32; uint8_t edge_dst_id = node_id; size_t temp_storage_bytes = 0; - int *d_actual_sample_size = sample_res.actual_sample_size; - uint64_t *d_neighbors = sample_res.val; - int *d_prefix_sum = 
reinterpret_cast(d_prefix_sum_ptr->ptr()); - uint64_t *d_sample_keys = - reinterpret_cast(d_sample_keys_ptr->ptr()); + int *d_actual_sample_size = sample_res->actual_sample_size; + uint64_t *d_neighbors = sample_res->val; int *d_sampleidx2row = - reinterpret_cast(d_sampleidx2rows[cur_sampleidx2row]->ptr()); - int *d_tmp_sampleidx2row = - reinterpret_cast(d_sampleidx2rows[1 - cur_sampleidx2row]->ptr()); + reinterpret_cast((*d_sampleidx2rows)[*cur_sampleidx2row]->ptr()); + int *d_tmp_sampleidx2row = reinterpret_cast( + (*d_sampleidx2rows)[1 - *cur_sampleidx2row]->ptr()); CUDA_CHECK(cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, @@ -1371,7 +1371,7 @@ void FillOneStep( delete[] h_offset2idx; } cudaStreamSynchronize(stream); - cur_sampleidx2row = 1 - cur_sampleidx2row; + *cur_sampleidx2row = 1 - *cur_sampleidx2row; } int GraphDataGenerator::FillSlotFeature(uint64_t *d_walk, size_t key_num) { @@ -1382,14 +1382,13 @@ int GraphDataGenerator::FillSlotFeature(uint64_t *d_walk, size_t key_num) { size_t temp_bytes = (key_num + 1) * sizeof(uint32_t); if (d_feature_size_list_buf_ == NULL || - d_feature_size_list_buf_->size() < temp_bytes) { - d_feature_size_list_buf_ = - memory::AllocShared(this->place_, temp_bytes); + d_feature_size_list_buf_->size() < temp_bytes) { + d_feature_size_list_buf_ = memory::AllocShared(this->place_, temp_bytes); } if (d_feature_size_prefixsum_buf_ == NULL || - d_feature_size_prefixsum_buf_->size() < temp_bytes) { - d_feature_size_prefixsum_buf_ = - memory::AllocShared(this->place_, temp_bytes); + d_feature_size_prefixsum_buf_->size() < temp_bytes) { + d_feature_size_prefixsum_buf_ = + memory::AllocShared(this->place_, temp_bytes); } int fea_num = @@ -1432,9 +1431,9 @@ int GraphDataGenerator::FillSlotFeature(uint64_t *d_walk, size_t key_num) { reinterpret_cast(d_feature_list->ptr()); uint8_t *d_slot_list_ptr = reinterpret_cast(d_slot_list->ptr()); uint32_t *d_feature_size_list_ptr = - reinterpret_cast(d_feature_size_list_buf_->ptr()); + 
reinterpret_cast(d_feature_size_list_buf_->ptr()); uint32_t *d_feature_size_prefixsum_ptr = - reinterpret_cast(d_feature_size_prefixsum_buf_->ptr()); + reinterpret_cast(d_feature_size_prefixsum_buf_->ptr()); VLOG(2) << "end trans feature list and slot list"; CUDA_CHECK(cudaStreamSynchronize(train_stream_)); @@ -1649,24 +1648,24 @@ uint64_t CopyUniqueNodes( uint64_t copy_unique_len, const paddle::platform::Place &place, const std::shared_ptr &d_uniq_node_num_ptr, - std::vector &host_vec, // output + std::vector *host_vec_ptr, // output cudaStream_t stream); // 对于deepwalk模式,尝试插入table,0表示插入成功,1表示插入失败; // 对于sage模式,尝试插入table,table数量不够则清空table重新插入,返回值无影响。 int InsertTable(const uint64_t *d_keys, // Input uint64_t len, // Input - std::shared_ptr &d_uniq_node_num, + std::shared_ptr *d_uniq_node_num, const GraphDataGeneratorConfig &conf, - uint64_t &copy_unique_len, + uint64_t *copy_unique_len_ptr, const paddle::platform::Place &place, HashTable *table, - std::vector &host_vec, // Output + std::vector *host_vec_ptr, // Output cudaStream_t stream) { // Used under NOT WHOLE_HBM. uint64_t h_uniq_node_num = 0; uint64_t *d_uniq_node_num_ptr = - reinterpret_cast(d_uniq_node_num->ptr()); + reinterpret_cast((*d_uniq_node_num)->ptr()); cudaMemcpyAsync(&h_uniq_node_num, d_uniq_node_num_ptr, sizeof(uint64_t), @@ -1682,9 +1681,13 @@ int InsertTable(const uint64_t *d_keys, // Input return 1; } else { // Copy unique nodes first. - uint64_t copy_len = CopyUniqueNodes( - table, copy_unique_len, place, d_uniq_node_num, host_vec, stream); - copy_unique_len += copy_len; + uint64_t copy_len = CopyUniqueNodes(table, + *copy_unique_len_ptr, + place, + *d_uniq_node_num, + host_vec_ptr, + stream); + *copy_unique_len_ptr += copy_len; table->clear(stream); cudaMemsetAsync(d_uniq_node_num_ptr, 0, sizeof(uint64_t), stream); } @@ -1692,9 +1695,13 @@ int InsertTable(const uint64_t *d_keys, // Input } else { // used only for sage_mode.
if (h_uniq_node_num + len >= conf.infer_table_cap) { - uint64_t copy_len = CopyUniqueNodes( - table, copy_unique_len, place, d_uniq_node_num, host_vec, stream); - copy_unique_len += copy_len; + uint64_t copy_len = CopyUniqueNodes(table, + *copy_unique_len_ptr, + place, + *d_uniq_node_num, + host_vec_ptr, + stream); + *copy_unique_len_ptr += copy_len; table->clear(stream); cudaMemsetAsync(d_uniq_node_num_ptr, 0, sizeof(uint64_t), stream); } @@ -2096,7 +2103,7 @@ uint64_t CopyUniqueNodes( uint64_t copy_unique_len, const paddle::platform::Place &place, const std::shared_ptr &d_uniq_node_num_ptr, - std::vector &host_vec, // output + std::vector *host_vec_ptr, // output cudaStream_t stream) { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { uint64_t h_uniq_node_num = 0; @@ -2128,8 +2135,8 @@ uint64_t CopyUniqueNodes( cudaStreamSynchronize(stream); - host_vec.resize(h_uniq_node_num + copy_unique_len); - cudaMemcpyAsync(host_vec.data() + copy_unique_len, + host_vec_ptr->resize(h_uniq_node_num + copy_unique_len); + cudaMemcpyAsync(host_vec_ptr->data() + copy_unique_len, d_uniq_node_ptr, sizeof(uint64_t) * h_uniq_node_num, cudaMemcpyDeviceToHost, @@ -2163,16 +2170,21 @@ void GraphDataGenerator::DoWalkandSage() { while (ins_pair_flag) { int res = 0; while (ins_buf_pair_len_ < conf_.batch_size) { + int32_t *pair_label_buf = NULL; + if (d_pair_label_buf_ != NULL) { + pair_label_buf = + reinterpret_cast(d_pair_label_buf_->ptr()); + } res = FillInsBuf(d_walk_, d_walk_ntype_, conf_, d_random_row_, d_random_row_col_shift_, - buf_state_, - d_ins_buf_, - d_pair_label_buf_, - d_pair_num_, - ins_buf_pair_len_, + &buf_state_, + reinterpret_cast(d_ins_buf_->ptr()), + pair_label_buf, + reinterpret_cast(d_pair_num_->ptr()), + &ins_buf_pair_len_, sample_stream_); if (res == -1) { if (ins_buf_pair_len_ == 0) { @@ -2180,7 +2192,8 @@ void GraphDataGenerator::DoWalkandSage() { sage_pass_end = 1; if (total_row_ != 0) { buf_state_.Reset(total_row_); - VLOG(1) << "reset buf 
state to make batch num equal in multi node"; + VLOG(1) << "reset buf state to make batch num equal in " + "multi node"; } } else { ins_pair_flag = false; @@ -2248,12 +2261,12 @@ void GraphDataGenerator::DoWalkandSage() { reinterpret_cast(final_sage_nodes->ptr()); InsertTable(final_sage_nodes_ptr, uniq_instance, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_); } final_sage_nodes_vec_.emplace_back(final_sage_nodes); @@ -2267,7 +2280,7 @@ void GraphDataGenerator::DoWalkandSage() { copy_unique_len_, place_, d_uniq_node_num_, - host_vec_, + &host_vec_, sample_stream_); VLOG(1) << "train sage_batch_num: " << sage_batch_num_; } @@ -2326,12 +2339,12 @@ void GraphDataGenerator::DoWalkandSage() { reinterpret_cast(final_sage_nodes->ptr()); InsertTable(final_sage_nodes_ptr, uniq_instance, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_); } final_sage_nodes_vec_.emplace_back(final_sage_nodes); @@ -2351,7 +2364,7 @@ void GraphDataGenerator::DoWalkandSage() { copy_unique_len_, place_, d_uniq_node_num_, - host_vec_, + &host_vec_, sample_stream_); VLOG(1) << "infer sage_batch_num: " << sage_batch_num_; } @@ -2630,12 +2643,12 @@ int GraphDataGenerator::FillWalkBuf() { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { if (InsertTable(d_type_keys + start, tmp_len, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) { VLOG(2) << "gpu:" << conf_.gpuid << " in step 0, insert key stage, table is full"; @@ -2645,12 +2658,12 @@ int GraphDataGenerator::FillWalkBuf() { } if (InsertTable(sample_res.actual_val, sample_res.total_sample_size, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) {
VLOG(2) << "gpu:" << conf_.gpuid << " in step 0, insert sample res, table is full"; @@ -2665,14 +2678,14 @@ int GraphDataGenerator::FillWalkBuf() { cur_walk, cur_walk_ntype, tmp_len, - sample_res, + &sample_res, conf_.walk_degree, step, conf_, - d_sample_keys_, - d_prefix_sum_, - d_sampleidx2rows_, - cur_sampleidx2row_, + reinterpret_cast(d_sample_keys_->ptr()), + reinterpret_cast(d_prefix_sum_->ptr()), + &d_sampleidx2rows_, + &cur_sampleidx2row_, place_, sample_stream_); ///////// @@ -2726,12 +2739,12 @@ int GraphDataGenerator::FillWalkBuf() { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { if (InsertTable(sample_res.actual_val, sample_res.total_sample_size, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) { VLOG(0) << "gpu:" << conf_.gpuid << " in step: " << step << ", table is full"; @@ -2746,14 +2759,14 @@ int GraphDataGenerator::FillWalkBuf() { cur_walk, cur_walk_ntype, sample_key_len, - sample_res, + &sample_res, 1, step, conf_, - d_sample_keys_, - d_prefix_sum_, - d_sampleidx2rows_, - cur_sampleidx2row_, + reinterpret_cast(d_sample_keys_->ptr()), + reinterpret_cast(d_prefix_sum_->ptr()), + &d_sampleidx2rows_, + &cur_sampleidx2row_, place_, sample_stream_); if (conf_.debug_mode) { @@ -2831,7 +2844,7 @@ int GraphDataGenerator::FillWalkBuf() { copy_unique_len_, place_, d_uniq_node_num_, - host_vec_, + &host_vec_, sample_stream_); VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ << ", d_walk_offset:" << i << ", total_rows:" << total_row_ @@ -2943,12 +2956,12 @@ int GraphDataGenerator::FillWalkBufMultiPath() { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { if (InsertTable(d_type_keys + start, tmp_len, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) { VLOG(2) << "in step 0, insert key stage,
table is full"; update = false; @@ -2956,12 +2969,12 @@ int GraphDataGenerator::FillWalkBufMultiPath() { } if (InsertTable(sample_res.actual_val, sample_res.total_sample_size, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) { VLOG(2) << "in step 0, insert sample res stage, table is full"; update = false; @@ -2975,14 +2988,14 @@ int GraphDataGenerator::FillWalkBufMultiPath() { cur_walk, cur_walk_ntype, tmp_len, - sample_res, + &sample_res, conf_.walk_degree, step, conf_, - d_sample_keys_, - d_prefix_sum_, - d_sampleidx2rows_, - cur_sampleidx2row_, + reinterpret_cast(d_sample_keys_->ptr()), + reinterpret_cast(d_prefix_sum_->ptr()), + &d_sampleidx2rows_, + &cur_sampleidx2row_, place_, sample_stream_); ///////// @@ -3024,12 +3037,12 @@ int GraphDataGenerator::FillWalkBufMultiPath() { if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) { if (InsertTable(sample_res.actual_val, sample_res.total_sample_size, - d_uniq_node_num_, + &d_uniq_node_num_, conf_, - copy_unique_len_, + &copy_unique_len_, place_, table_, - host_vec_, + &host_vec_, sample_stream_) != 0) { VLOG(2) << "in step: " << step << ", table is full"; update = false; @@ -3042,14 +3055,14 @@ int GraphDataGenerator::FillWalkBufMultiPath() { cur_walk, cur_walk_ntype, sample_key_len, - sample_res, + &sample_res, 1, step, conf_, - d_sample_keys_, - d_prefix_sum_, - d_sampleidx2rows_, - cur_sampleidx2row_, + reinterpret_cast(d_sample_keys_->ptr()), + reinterpret_cast(d_prefix_sum_->ptr()), + &d_sampleidx2rows_, + &cur_sampleidx2row_, place_, sample_stream_); if (conf_.debug_mode) { @@ -3121,7 +3134,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() { copy_unique_len_, place_, d_uniq_node_num_, - host_vec_, + &host_vec_, sample_stream_); VLOG(1) << "sample_times:" << sample_times << ", d_walk_size:" << buf_size_ << ", d_walk_offset:" << i << ", total_rows:" << total_row_ @@ -3528,9 +3541,11 @@ int
GraphDataGenerator::multi_node_sync_sample(int flag, int GraphDataGenerator::dynamic_adjust_batch_num_for_sage() { int batch_num = (total_row_ + conf_.batch_size - 1) / conf_.batch_size; - auto send_buff = memory::Alloc(place_, 2 * sizeof(int), - phi::Stream(reinterpret_cast(sample_stream_))); - int* send_buff_ptr = reinterpret_cast(send_buff->ptr()); + auto send_buff = memory::Alloc( + place_, + 2 * sizeof(int), + phi::Stream(reinterpret_cast(sample_stream_))); + int *send_buff_ptr = reinterpret_cast(send_buff->ptr()); cudaMemcpyAsync(send_buff_ptr, &batch_num, sizeof(int), @@ -3554,10 +3569,11 @@ int GraphDataGenerator::dynamic_adjust_batch_num_for_sage() { sample_stream_); cudaStreamSynchronize(sample_stream_); - int new_batch_size = (total_row_ + thread_max_batch_num - 1) / thread_max_batch_num; + int new_batch_size = + (total_row_ + thread_max_batch_num - 1) / thread_max_batch_num; VLOG(2) << conf_.gpuid << " dynamic adjust sage batch num " - << " max_batch_num: " << thread_max_batch_num - << " new_batch_size: " << new_batch_size; + << " max_batch_num: " << thread_max_batch_num + << " new_batch_size: " << new_batch_size; return new_batch_size; } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 8842037fa45a5..d3b56d2c31b58 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -952,7 +952,7 @@ class GraphDataGenerator { int FillSlotFeature(uint64_t* d_walk, size_t key_num); int GetPathNum() { return total_row_; } void ResetPathNum() { total_row_ = 0; } - int GetGraphBatchsize() { return conf_.batch_size; }; + int GetGraphBatchsize() { return conf_.batch_size; } void SetNewBatchsize(int batch_num) { if (!conf_.gpu_graph_training) { conf_.batch_size = (total_row_ + batch_num - 1) / batch_num;