fix miss key for error dataset (PaddlePaddle#186)
* fix miss key for error dataset

* fix miss key for error dataset

Co-authored-by: yangjunchao <yangjunchao@baidu.com>
chao9527 and yangjunchao authored Dec 8, 2022
1 parent 5e845c5 commit 397242a
Showing 2 changed files with 12 additions and 8 deletions.
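
What the change does, read from the diff: LoadIntoMemory now snapshots the current dataset_ pointer into the data_ready_channel_ message, sending a std::pair<std::shared_ptr<HeterContext>, Dataset*> instead of a bare HeterContext; pre_build_thread unpacks the pair and hands the pointer to the new PreBuildTask(gpu_task, dataset_for_pull) signature, which casts that argument instead of re-reading the mutable dataset_ member. The apparent intent is that each pass pulls keys from the dataset it was loaded with, even if dataset_ has been swapped to a newer dataset by the time the prebuild thread runs — the missing-key symptom named in the title.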
13 changes: 8 additions & 5 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -218,7 +218,7 @@ void PSGPUWrapper::resize_gputask(std::shared_ptr<HeterContext> gpu_task) {
   }
 }
 
-void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
+void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task, Dataset* dataset_for_pull) {
   VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
   platform::Timer timeline;
   timeline.Start();
@@ -341,7 +341,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
             << " seconds.";
     }
   } else {
-    SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_);
+    SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_for_pull);
     const std::vector<uint64_t>& vec_data = dataset->GetGpuGraphTotalKeys();
     timeline.Start();
     add_key_to_local(vec_data);
@@ -1308,7 +1308,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
     std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
     gpu_task->Reset();
     gpu_task->pass_id_ = (uint16_t)(dataset_->GetPassID());
-    data_ready_channel_->Put(gpu_task);
+
+    data_ready_channel_->Put(std::make_pair(gpu_task, dataset_));
   } else if (hbm_sparse_table_initialized_ == false) {
     SparseTableToHbm();
   }
@@ -1326,15 +1327,17 @@ void PSGPUWrapper::start_build_thread() {
 void PSGPUWrapper::pre_build_thread() {
   // prebuild: process load_data
   while (running_) {
+    std::pair<std::shared_ptr<HeterContext>, Dataset*> task = std::make_pair(nullptr, nullptr);
     std::shared_ptr<HeterContext> gpu_task = nullptr;
-    if (!data_ready_channel_->Get(gpu_task)) {
+    if (!data_ready_channel_->Get(task)) {
       continue;
     }
+    gpu_task = task.first;
     VLOG(3) << "thread PreBuildTask start.";
     platform::Timer timer;
     timer.Start();
     // build cpu ps data process
-    PreBuildTask(gpu_task);
+    PreBuildTask(gpu_task, task.second);
     timer.Pause();
     VLOG(0) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
             << " s";
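
Read together, the hunks above change the producer (LoadIntoMemory) and the consumer (pre_build_thread) in lockstep. Below is a minimal, compilable sketch of the underlying pattern — snapshot the shared pointer at Put time and carry it inside the channel message, rather than re-reading the member at Get time — with std::queue plus a mutex standing in for paddle::framework::ChannelObject. SimpleChannel, Produce, Consume, and dataset_member are illustrative names, not Paddle APIs.

#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>

struct Dataset { int pass_id; };
struct HeterContext {};

// Hypothetical stand-in for paddle::framework::ChannelObject: a queue
// of (task, dataset) pairs guarded by a mutex.
class SimpleChannel {
 public:
  using Item = std::pair<std::shared_ptr<HeterContext>, Dataset*>;

  void Put(Item item) {
    std::lock_guard<std::mutex> lock(mu_);
    q_.push(std::move(item));
  }

  bool Get(Item* item) {
    std::lock_guard<std::mutex> lock(mu_);
    if (q_.empty()) return false;
    *item = q_.front();
    q_.pop();
    return true;
  }

 private:
  std::mutex mu_;
  std::queue<Item> q_;
};

SimpleChannel data_ready_channel;
Dataset* dataset_member = nullptr;  // plays the role of the mutable dataset_

// Producer side (cf. LoadIntoMemory): snapshot the dataset pointer at
// enqueue time so it travels with the task it belongs to.
void Produce() {
  auto gpu_task = std::make_shared<HeterContext>();
  data_ready_channel.Put(std::make_pair(gpu_task, dataset_member));
}

// Consumer side (cf. pre_build_thread): unpack the pair and use the
// snapshotted pointer, never the shared member.
void Consume() {
  SimpleChannel::Item task(nullptr, nullptr);
  if (!data_ready_channel.Get(&task)) return;
  std::cout << "prebuild pulls keys from pass " << task.second->pass_id << "\n";
}

int main() {
  Dataset d{7};
  dataset_member = &d;
  Produce();
  dataset_member = nullptr;  // member moves on; the queued pair is unaffected
  Consume();                 // still sees pass 7
  return 0;
}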
7 changes: 4 additions & 3 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include <atomic>
 #include <ctime>
+#include <utility>
 #include <map>
 #include <memory>
 #include <mutex>
@@ -202,7 +203,7 @@ class PSGPUWrapper {
   void divide_to_device(std::shared_ptr<HeterContext> gpu_task);
   void add_slot_feature(std::shared_ptr<HeterContext> gpu_task);
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
-  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task, Dataset* dataset_for_pull);
   void BuildPull(std::shared_ptr<HeterContext> gpu_task);
   void PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task);
   void LoadIntoMemory(bool is_shuffle);
@@ -778,9 +779,9 @@ class PSGPUWrapper {
 #endif
 
   std::shared_ptr<
-      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
+      paddle::framework::ChannelObject<std::pair<std::shared_ptr<HeterContext>, Dataset*>>>
       data_ready_channel_ =
-          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
+          paddle::framework::MakeChannel<std::pair<std::shared_ptr<HeterContext>, Dataset*>>();
   std::shared_ptr<
       paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
       buildcpu_ready_channel_ =
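
Header side, the element type of data_ready_channel_ widens from std::shared_ptr<HeterContext> to std::pair<std::shared_ptr<HeterContext>, Dataset*>, with <utility> included for std::pair; buildcpu_ready_channel_ is left untouched, presumably because the dataset pointer is only needed while PreBuildTask pulls keys.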
