diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index b04946ab46b50..11ab20df5e1b7 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -928,6 +928,7 @@ class DataFeed { return ins_content_vec_; } virtual int GetCurBatchSize() { return batch_size_; } + virtual void SetCurBatchSize(const int batch_size) { batch_size_ = batch_size; } virtual void LoadIntoMemory() { PADDLE_THROW(platform::errors::Unimplemented( "This function(LoadIntoMemory) is not implemented.")); diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc index 27f458f475a14..34ef939133ebf 100644 --- a/paddle/fluid/framework/ps_gpu_worker.cc +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -458,7 +458,7 @@ void PSGPUWorker::TrainFiles() { if (cur_batch <= 0) { break; } - + device_reader_->SetCurBatchSize(cur_batch); total_ins_num += cur_batch; if (shape_check_flag_.load()) {