diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index ec33c6377ff22..a586aa0ce5f70 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -1955,6 +1955,20 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector& ins_vec) { #endif } +SlotRecordInMemoryDataFeed::~SlotRecordInMemoryDataFeed() { +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + stop_token_.store(true); + for (auto& thread : pack_threads_) { + if (thread.joinable()) { + thread.join(); + } + } + for (auto* pack : pack_vec_) { + pack->set_use_flag(false); + } +#endif +} + template class InMemoryDataFeed; void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { finish_init_ = false; @@ -2128,6 +2142,9 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) { pull_record_func, lines); } else { int err_no = 0; + if (idx == 0) { + filename = filename.substr(7); + } this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); CHECK(this->fp_ != nullptr); @@ -2162,7 +2179,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) { BufferedLineFileReader::LineFunc line_func = nullptr; while (this->PickOneFile(&filename)) { - VLOG(3) << "PickOneFile, filename=" << filename + VLOG(0) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_; std::vector record_vec; platform::Timer timeline; @@ -2418,8 +2435,13 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line, void SlotRecordInMemoryDataFeed::AssignFeedVar(const Scope& scope) { CheckInit(); + if (scpoe_feed_vec_.count(&scope) > 0) { + return; + } + auto& feed_vec = scpoe_feed_vec_[&scope]; + feed_vec.resize(used_slots_info_.size()); for (int i = 0; i < use_slot_size_; ++i) { - feed_vec_[i] = + feed_vec[i] = scope.FindVar(used_slots_info_[i].slot)->GetMutable(); } } @@ -2427,15 +2449,14 @@ void SlotRecordInMemoryDataFeed::AssignFeedVar(const Scope& scope) { void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec, int num) { #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) - paddle::platform::SetDeviceId(place_.GetDeviceId()); - pack_->pack_instance(ins_vec, num); - BuildSlotBatchGPU(pack_->ins_num()); + // nothing to do #else for (int j = 0; j < use_slot_size_; ++j) { auto& feed = feed_vec_[j]; if (feed == nullptr) { continue; } + auto& slot_offset = offset_[j]; slot_offset.clear(); slot_offset.reserve(num + 1); @@ -2585,80 +2606,126 @@ bool SlotRecordInMemoryDataFeed::Start() { this->finish_start_ = true; #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) CHECK(paddle::platform::is_gpu_place(this->place_)); - pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_); + + for (int i = 0; i < pack_thread_num_ + 1; i++) { + auto pack = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_); + pack_vec_.push_back(pack); + free_pack_queue_.Push(pack); + } + + pack_offset_index_.store(0); + pack_is_end_.store(false); + thread_count_.store(pack_thread_num_); + pack_threads_.reserve(pack_thread_num_); + for (int i = 0; i < pack_thread_num_; i++) { + pack_threads_.emplace_back(std::thread([this]() -> void { + while (!stop_token_.load()) { + uint64_t offset_index = pack_offset_index_.fetch_add(1); + if (offset_index >= batch_offsets_.size()) { + int thread_num = thread_count_.fetch_sub(1); + if (thread_num == 1) { + pack_is_end_.store(true); + } + return; + } + auto* pack = free_pack_queue_.Pop(); + + auto& batch = batch_offsets_[offset_index]; + auto offset = batch.first; + 
auto batch_size = batch.second; + + paddle::platform::SetDeviceId(place_.GetDeviceId()); + pack->pack_instance(&records_[offset], batch_size); + this->BuildSlotBatchGPU(batch_size, pack); + using_pack_queue_.Push(pack); + } + })); + } + + #endif return true; } int SlotRecordInMemoryDataFeed::Next() { -#ifdef _LINUX - this->CheckStart(); - - VLOG(3) << "enable heter next: " << offset_index_ - << " batch_offsets: " << batch_offsets_.size(); - if (offset_index_ >= batch_offsets_.size()) { - VLOG(3) << "offset_index: " << offset_index_ - << " batch_offsets: " << batch_offsets_.size(); - return 0; - } - auto& batch = batch_offsets_[offset_index_++]; - this->batch_size_ = batch.second; - VLOG(3) << "batch_size_=" << this->batch_size_ - << ", thread_id=" << thread_id_; - if (this->batch_size_ != 0) { - PutToFeedVec(&records_[batch.first], this->batch_size_); - } else { - VLOG(3) << "finish reading for heterps, batch size zero, thread_id=" - << thread_id_; + while (true) { + if (last_pack_ != nullptr) { + free_pack_queue_.Push(last_pack_); + last_pack_ = nullptr; + } + if (using_pack_queue_.Size() != 0) { + auto* pack = using_pack_queue_.Pop(); + PackToScope(pack); + last_pack_ = pack; + return pack->ins_num(); + } + bool is_end = pack_is_end_.load(); + if (is_end) { + if (using_pack_queue_.Size() == 0) { + return 0; + } + } + std::this_thread::sleep_for( + std::chrono::microseconds(200)); } - VLOG(3) << "enable heter next: " << offset_index_ - << " batch_offsets: " << batch_offsets_.size() - << " baych_size: " << this->batch_size_; - - return this->batch_size_; -#else - return 0; -#endif } #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) -void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { +void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num, MiniBatchGpuPack* pack) { int offset_cols_size = (ins_num + 1); + // total slot count = instance count * number of slots per instance size_t slot_total_num = (use_slot_size_ * offset_cols_size); - pack_->resize_gpu_slot_offsets(slot_total_num * sizeof(size_t)); + // allocate a byte buffer that holds all slot offsets + pack->resize_gpu_slot_offsets(slot_total_num * sizeof(size_t)); - auto& value = pack_->value(); + // gpu value + auto& value = pack->value(); const UsedSlotGpuType* used_slot_gpu_types = - static_cast(pack_->get_gpu_slots()); + static_cast(pack->get_gpu_slots()); + + // fill gpu_slot_offsets FillSlotValueOffset(ins_num, use_slot_size_, - reinterpret_cast(pack_->gpu_slot_offsets()), + reinterpret_cast(pack->gpu_slot_offsets()), value.d_uint64_offset.data(), uint64_use_slot_size_, value.d_float_offset.data(), float_use_slot_size_, - used_slot_gpu_types); - size_t* d_slot_offsets = reinterpret_cast(pack_->gpu_slot_offsets()); + used_slot_gpu_types, + pack->get_stream()); + + size_t* d_slot_offsets = reinterpret_cast(pack->gpu_slot_offsets()); - HostBuffer& offsets = pack_->offsets(); + HostBuffer& offsets = pack->offsets(); offsets.resize(slot_total_num); - HostBuffer& h_tensor_ptrs = pack_->h_tensor_ptrs(); + HostBuffer& h_tensor_ptrs = pack->h_tensor_ptrs(); h_tensor_ptrs.resize(use_slot_size_); // alloc gpu memory - pack_->resize_tensor(); + pack->resize_tensor(); - LoDTensor& float_tensor = pack_->float_tensor(); - LoDTensor& uint64_tensor = pack_->uint64_tensor(); + LoDTensor& float_tensor = pack->float_tensor(); + LoDTensor& uint64_tensor = pack->uint64_tensor(); + + // copy index + CUDA_CHECK(cudaMemcpyAsync(offsets.data(), d_slot_offsets, + slot_total_num * sizeof(size_t), + cudaMemcpyDeviceToHost, pack->get_stream())); + +
cudaStreamSynchronize(pack->get_stream()); int64_t float_offset = 0; int64_t uint64_offset = 0; + size_t float_zero_slot_index = 0; + size_t uint64_zero_slot_index = 0; - // copy index - CUDA_CHECK(cudaMemcpy(offsets.data(), d_slot_offsets, - slot_total_num * sizeof(size_t), - cudaMemcpyDeviceToHost)); for (int j = 0; j < use_slot_size_; ++j) { - auto& feed = feed_vec_[j]; - if (feed == nullptr) { - h_tensor_ptrs[j] = nullptr; - continue; + if (scpoe_feed_vec_.size() > 0) { + if (scpoe_feed_vec_.begin()->second[j] == nullptr) { + h_tensor_ptrs[j] = nullptr; + continue; + } + } else { + if (feed_vec_[j] == nullptr) { + h_tensor_ptrs[j] = nullptr; + continue; + } } size_t* off_start_ptr = &offsets[j * offset_cols_size]; @@ -2668,6 +2735,69 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { << ", total instance:" << total_instance; auto& info = used_slots_info_[j]; + // fill slot value with default value 0 + if (info.type[0] == 'f') { // float + if (total_instance > 0) { + h_tensor_ptrs[j] = float_tensor.data() + float_offset; + float_offset += total_instance; + } else { + h_tensor_ptrs[j] = pack->float_tensor_vec()[float_zero_slot_index].mutable_data({total_instance, 1}, this->place_); + float_zero_slot_index++; + } + } else if (info.type[0] == 'u') { // uint64 + if (total_instance > 0) { + h_tensor_ptrs[j] = uint64_tensor.data() + uint64_offset; + uint64_offset += total_instance; + } else { + h_tensor_ptrs[j] = pack->uint64_tensor_vec()[uint64_zero_slot_index].mutable_data({total_instance, 1}, this->place_); + uint64_zero_slot_index++; + } + } + } + void** dest_gpu_p = reinterpret_cast(pack->slot_buf_ptr()); + CUDA_CHECK(cudaMemcpyAsync(dest_gpu_p, h_tensor_ptrs.data(), + use_slot_size_ * sizeof(void*), + cudaMemcpyHostToDevice, pack->get_stream())); + + CopyForTensor(ins_num, use_slot_size_, dest_gpu_p, + (const size_t*)pack->gpu_slot_offsets(), + (const uint64_t*)value.d_uint64_keys.data(), + (const int*)value.d_uint64_offset.data(), + (const int*)value.d_uint64_lens.data(), uint64_use_slot_size_, + (const float*)value.d_float_keys.data(), + (const int*)value.d_float_offset.data(), + (const int*)value.d_float_lens.data(), float_use_slot_size_, + used_slot_gpu_types, pack->get_stream()); +} + +void SlotRecordInMemoryDataFeed::PackToScope(MiniBatchGpuPack* pack, const Scope* scope) { + int64_t float_offset = 0; + int64_t uint64_offset = 0; + size_t float_zero_slot_index = 0; + size_t uint64_zero_slot_index = 0; + + int offset_cols_size = (pack->ins_num() + 1); + HostBuffer& offsets = pack->offsets(); + LoDTensor& float_tensor = pack->float_tensor(); + LoDTensor& uint64_tensor = pack->uint64_tensor(); + + auto* feed_vec = &feed_vec_; + if (scope) { + CHECK(scpoe_feed_vec_.count(scope) > 0) << "scope not found."; + feed_vec = &scpoe_feed_vec_[scope]; + } + + CHECK(feed_vec != nullptr) << "feed_vec nullptr."; + + for (int j = 0; j < use_slot_size_; ++j) { + auto& feed = (*feed_vec)[j]; + if (feed == nullptr) { + continue; + } + size_t* off_start_ptr = &offsets[j * offset_cols_size]; + int total_instance = static_cast(off_start_ptr[offset_cols_size - 1]); + auto& info = used_slots_info_[j]; + // fill slot value with default value 0 if (info.type[0] == 'f') { // float if (total_instance > 0) { @@ -2676,10 +2806,9 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { static_cast(float_offset + total_instance))); feed->Resize({total_instance, 1}); float_offset += total_instance; - h_tensor_ptrs[j] = feed->mutable_data(this->place_); } else { - 
h_tensor_ptrs[j] = - feed->mutable_data({total_instance, 1}, this->place_); + feed->ShareDataWith(pack->float_tensor_vec()[float_zero_slot_index++]); + feed->Resize({total_instance, 1}); } } else if (info.type[0] == 'u') { // uint64 if (total_instance > 0) { @@ -2688,10 +2817,9 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { static_cast(uint64_offset + total_instance))); feed->Resize({total_instance, 1}); uint64_offset += total_instance; - h_tensor_ptrs[j] = feed->mutable_data(this->place_); } else { - h_tensor_ptrs[j] = - feed->mutable_data({total_instance, 1}, this->place_); + feed->ShareDataWith(pack->uint64_tensor_vec()[uint64_zero_slot_index++]); + feed->Resize({total_instance, 1}); } } @@ -2710,28 +2838,37 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { offset_cols_size * sizeof(size_t)); } } - void** dest_gpu_p = reinterpret_cast(pack_->slot_buf_ptr()); - CUDA_CHECK(cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(), - use_slot_size_ * sizeof(void*), - cudaMemcpyHostToDevice)); +} - CopyForTensor(ins_num, use_slot_size_, dest_gpu_p, - (const size_t*)pack_->gpu_slot_offsets(), - (const uint64_t*)value.d_uint64_keys.data(), - (const int*)value.d_uint64_offset.data(), - (const int*)value.d_uint64_lens.data(), uint64_use_slot_size_, - (const float*)value.d_float_keys.data(), - (const int*)value.d_float_offset.data(), - (const int*)value.d_float_lens.data(), float_use_slot_size_, - used_slot_gpu_types); +MiniBatchGpuPack* SlotRecordInMemoryDataFeed::get_pack(MiniBatchGpuPack* last_pack) { + if (last_pack != nullptr) { + free_pack_queue_.Push(last_pack); + return nullptr; + } + + std::unique_lock lock(pack_mutex_); + while (true) { + if (using_pack_queue_.Size() != 0) { + auto* pack = using_pack_queue_.Pop(); + return pack; + } + bool is_end = pack_is_end_.load(); + if (is_end) { + if (using_pack_queue_.Size() == 0) { + return nullptr; + } + } + std::this_thread::sleep_for( + std::chrono::microseconds(200)); + } } + MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place, const std::vector& infos) { place_ = place; - stream_ = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); + stream_holder_.reset(new platform::stream::CUDAStream(place)); + stream_ = stream_holder_->raw_stream(); ins_num_ = 0; pv_num_ = 0; @@ -2757,15 +2894,25 @@ MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place, VLOG(3) << "begin get batch pack device id: " << device_id; // sync CUDA_CHECK(cudaStreamSynchronize(stream_)); + float_tensor_vec_.resize(used_slot_size_); + uint64_tensor_vec_.resize(used_slot_size_); } MiniBatchGpuPack::~MiniBatchGpuPack() {} + +bool MiniBatchGpuPack::is_use() { + return is_using_; +} + +void MiniBatchGpuPack::set_use_flag(bool is_use) { + is_using_ = is_use; +} + void MiniBatchGpuPack::reset(const paddle::platform::Place& place) { place_ = place; - stream_ = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); + stream_holder_.reset(new platform::stream::CUDAStream(place)); + stream_ = stream_holder_->raw_stream(); ins_num_ = 0; pv_num_ = 0; } @@ -2904,6 +3051,7 @@ void MiniBatchGpuPack::pack_float_data(const SlotRecord* ins_vec, int num) { } void MiniBatchGpuPack::pack_instance(const SlotRecord* ins_vec, int num) { + // VLOG(0) << "pack_instance, slot_record:" << ins_vec << ", num: " << num; ins_num_ = num; batch_ins_ = ins_vec; CHECK(used_uint64_num_ > 0 || used_float_num_ > 0); diff --git a/paddle/fluid/framework/data_feed.cu 
b/paddle/fluid/framework/data_feed.cu index f9435ec2a32d8..aaf1f1a009402 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -71,11 +71,8 @@ void SlotRecordInMemoryDataFeed::FillSlotValueOffset( const int ins_num, const int used_slot_num, size_t *slot_value_offsets, const int *uint64_offsets, const int uint64_slot_size, const int *float_offsets, const int float_slot_size, - const UsedSlotGpuType *used_slots) { - auto stream = - dynamic_cast( - paddle::platform::DeviceContextPool::Instance().Get(this->place_)) - ->stream(); + const UsedSlotGpuType *used_slots, + cudaStream_t stream) { FillSlotValueOffsetKernel<<>>( ins_num, used_slot_num, slot_value_offsets, uint64_offsets, @@ -130,12 +127,8 @@ void SlotRecordInMemoryDataFeed::CopyForTensor( const int *uint64_offsets, const int *uint64_ins_lens, const int uint64_slot_size, const float *float_feas, const int *float_offsets, const int *float_ins_lens, - const int float_slot_size, const UsedSlotGpuType *used_slots) { - auto stream = - dynamic_cast( - paddle::platform::DeviceContextPool::Instance().Get(this->place_)) - ->stream(); - + const int float_slot_size, const UsedSlotGpuType *used_slots, + cudaStream_t stream) { CopyForTensorKernel<<>>( used_slot_num, ins_num, dest, slot_value_offsets, uint64_feas, diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 6f7f1dac52804..31b733f3a7530 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -530,6 +530,9 @@ class MiniBatchGpuPack { MiniBatchGpuPack(const paddle::platform::Place& place, const std::vector& infos); ~MiniBatchGpuPack(); + + bool is_use(); + void set_use_flag(bool is_use); void reset(const paddle::platform::Place& place); void pack_instance(const SlotRecord* ins_vec, int num); int ins_num() { return ins_num_; } @@ -559,6 +562,8 @@ class MiniBatchGpuPack { } LoDTensor& float_tensor(void) { return float_tensor_; } LoDTensor& uint64_tensor(void) { return uint64_tensor_; } + std::vector& float_tensor_vec(void) { return float_tensor_vec_; } + std::vector& uint64_tensor_vec(void) { return uint64_tensor_vec_; } HostBuffer& offsets(void) { return offsets_; } HostBuffer& h_tensor_ptrs(void) { return h_tensor_ptrs_; } @@ -583,6 +588,10 @@ class MiniBatchGpuPack { return batch_ins_[idx]->ins_id_; } + cudaStream_t get_stream() { + return stream_; + } + private: void transfer_to_gpu(void); void pack_all_data(const SlotRecord* ins_vec, int num); @@ -605,7 +614,9 @@ class MiniBatchGpuPack { } private: + bool is_using_ = false; paddle::platform::Place place_; + std::unique_ptr stream_holder_; cudaStream_t stream_; BatchGPUValue value_; BatchCPUValue buf_; @@ -624,8 +635,10 @@ class MiniBatchGpuPack { // uint64 tensor LoDTensor uint64_tensor_; + std::vector uint64_tensor_vec_; // float tensor LoDTensor float_tensor_; + std::vector float_tensor_vec_; // batch HostBuffer offsets_; HostBuffer h_tensor_ptrs_; @@ -638,33 +651,42 @@ class MiniBatchGpuPackMgr { public: MiniBatchGpuPackMgr() { + pack_list_.resize(MAX_DEIVCE_NUM); for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { - pack_list_[i] = nullptr; + pack_list_[i].clear(); } } ~MiniBatchGpuPackMgr() { for (int i = 0; i < MAX_DEIVCE_NUM; ++i) { - if (pack_list_[i] == nullptr) { - continue; + for (size_t j = 0; j < pack_list_[i].size(); j++) { + if (pack_list_[i][j] == nullptr) { + continue; + } + delete pack_list_[i][j]; + pack_list_[i][j] = nullptr; } - delete pack_list_[i]; - pack_list_[i] = nullptr; } } - // one device one thread + + // 
thread unsafe MiniBatchGpuPack* get(const paddle::platform::Place& place, const std::vector& infos) { int device_id = place.GetDeviceId(); - if (pack_list_[device_id] == nullptr) { - pack_list_[device_id] = new MiniBatchGpuPack(place, infos); - } else { - pack_list_[device_id]->reset(place); + for (size_t i = 0; i < pack_list_[device_id].size(); i++) { + if (!pack_list_[device_id][i]->is_use()) { + pack_list_[device_id][i]->set_use_flag(true); + pack_list_[device_id][i]->reset(place); + return pack_list_[device_id][i]; + } } - return pack_list_[device_id]; + auto* pack = new MiniBatchGpuPack(place, infos); + pack->set_use_flag(true); + pack_list_[device_id].push_back(pack); + return pack; } private: - MiniBatchGpuPack* pack_list_[MAX_DEIVCE_NUM]; + std::vector> pack_list_; }; // global mgr inline MiniBatchGpuPackMgr& BatchGpuPackMgr() { @@ -744,6 +766,19 @@ class DLManager { if (it != handle_map_.end()) { return it->second.parser; } + // load so symbol + // export the libps and core_avx symbols globally so the parser can share them + const std::vector packages {"libps.so", "core_avx.so"}; + for (auto& package : packages) { + if (handle_map_.count(package) == 0) { + DLHandle handle_ps; + handle_ps.module = dlopen(package.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ps.module == nullptr) { + VLOG(0) << "Create so of " << package << " fail, " << dlerror(); + } + handle_map_.insert({package, handle_ps}); + } + } handle.module = dlopen(name.c_str(), RTLD_NOW); if (handle.module == nullptr) { VLOG(0) << "Create so of " << name << " fail"; @@ -812,6 +847,10 @@ class DataFeed { // This function is used for binding feed_vec memory in a given scope virtual void AssignFeedVar(const Scope& scope); + virtual std::vector GetInputVarNames() { + return std::vector(); + } + // This function will do nothing at default virtual void SetInputPvChannel(void* channel) {} // This function will do nothing at default @@ -858,6 +897,14 @@ class DataFeed { } virtual const paddle::platform::Place& GetPlace() const { return place_; } +#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) + virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack) { return nullptr; } + virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope) { + PADDLE_THROW(platform::errors::Unimplemented( + "This function(PackToScope) is not implemented.")); + } +#endif + protected: // The following three functions are used to check if it is executed in this // order: @@ -1390,13 +1437,8 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { public: SlotRecordInMemoryDataFeed() {} - virtual ~SlotRecordInMemoryDataFeed() { -#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) - if (pack_ != nullptr) { - pack_ = nullptr; - } -#endif - } + virtual ~SlotRecordInMemoryDataFeed(); + virtual void Init(const DataFeedDesc& data_feed_desc); virtual void LoadIntoMemory(); void ExpandSlotRecord(SlotRecord* ins); @@ -1420,21 +1462,37 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { bool ParseOneInstance(const std::string& line, SlotRecord* rec); virtual void PutToFeedVec(const SlotRecord* ins_vec, int num); virtual void AssignFeedVar(const Scope& scope); + virtual std::vector GetInputVarNames() { + std::vector var_names; + for (int i = 0; i < use_slot_size_; ++i) { + var_names.push_back(used_slots_info_[i].slot); + } + return var_names; + } + #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) - void BuildSlotBatchGPU(const int ins_num); + void BuildSlotBatchGPU(const int ins_num,
MiniBatchGpuPack* pack); + + // async infershape + virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack); + virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope = nullptr); + void FillSlotValueOffset(const int ins_num, const int used_slot_num, size_t* slot_value_offsets, const int* uint64_offsets, const int uint64_slot_size, const int* float_offsets, const int float_slot_size, - const UsedSlotGpuType* used_slots); + const UsedSlotGpuType* used_slots, + cudaStream_t stream + ); void CopyForTensor(const int ins_num, const int used_slot_num, void** dest, const size_t* slot_value_offsets, const uint64_t* uint64_feas, const int* uint64_offsets, const int* uint64_ins_lens, const int uint64_slot_size, const float* float_feas, const int* float_offsets, const int* float_ins_lens, const int float_slot_size, - const UsedSlotGpuType* used_slots); + const UsedSlotGpuType* used_slots, + cudaStream_t stream); #endif float sample_rate_ = 1.0f; int use_slot_size_ = 0; @@ -1446,7 +1504,21 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { std::vector float_total_dims_without_inductives_; #if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) - MiniBatchGpuPack* pack_ = nullptr; + int pack_thread_num_ {5}; + std::vector pack_threads_; + std::vector pack_vec_; + BlockingQueue free_pack_queue_; + BlockingQueue using_pack_queue_; + std::atomic pack_is_end_ {false}; + std::atomic pack_offset_index_ {0}; + MiniBatchGpuPack* last_pack_ {nullptr}; + std::atomic stop_token_ {false}; + std::atomic thread_count_ {0}; + std::mutex pack_mutex_; + + // async infershape + std::map > scpoe_feed_vec_; + #endif }; diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 718a59a6f72cd..b377dfde6b162 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -528,7 +528,7 @@ class HeterCpuWorker : public HogwildWorker { class PSGPUWorker : public HogwildWorker { public: PSGPUWorker() {} - virtual ~PSGPUWorker() {} + virtual ~PSGPUWorker(); virtual void Initialize(const TrainerDesc& desc); virtual void TrainFiles(); virtual void TrainFilesWithProfiler(); @@ -542,6 +542,10 @@ class PSGPUWorker : public HogwildWorker { virtual void SetEvent(const gpuEvent_t event) { event_ = event; } void ResetStat(); + // async infershape + virtual void CreateDeviceResource(const ProgramDesc& main_prog); + virtual void BindingDataFeedMemory(); + protected: void PushGradients(); void CopySparseTable(); @@ -549,6 +553,16 @@ class PSGPUWorker : public HogwildWorker { void CopyDenseVars(); void PrepareCudaGraph(); + struct InferShapeCheckData { + std::vector> pre_dims; + std::vector> pre_lods; + std::vector> after_dims; + std::vector> after_lods; + }; + + int OpRunAndShapeCheck(OperatorBase& op, + const Scope& scope, + const platform::Place& place); private: int mpi_rank_; std::mutex mutex_; @@ -618,6 +632,28 @@ class PSGPUWorker : public HogwildWorker { double gpu_2_cpu_time_; double cpu_2_gpu_time_; uint64_t total_inst_; + + // async infershape + int task_threads_num_ {6}; + int scope_num_ {task_threads_num_ + 1}; + std::atomic thread_count_ {0}; + std::atomic stop_token_ {false}; + std::atomic pack_is_end_ {false}; + std::vector task_threads_; + std::vector thread_scope_vec_; + std::map> need_reuse_var_vec_; + std::vector need_reuse_var_; + + struct TaskData { + int ins_num; + Scope* scope; + MiniBatchGpuPack* pack; + }; + paddle::framework::BlockingQueue free_task_queue_; + paddle::framework::BlockingQueue 
using_task_queue_; + + static std::atomic shape_check_count_; + static std::atomic shape_check_flag_; }; #endif diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e6577f662ae7b..587cfca069b75 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -661,444 +661,460 @@ bool OpSupportGPU(const std::string& op_type) { return false; } -class RuntimeInferShapeContext : public InferShapeContext { - public: - RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) - : op_(op), ctx_(ctx) {} - - bool HasInput(const std::string& name) const override { - // has only one input - const auto& ins = ctx_.inputs; - auto it = ins.find(name); - if (it == ins.end()) { - return false; - } - const auto& in = it->second; - if (in.size() == 0) return false; - PADDLE_ENFORCE_EQ( - in.size(), 1UL, - platform::errors::InvalidArgument( - "Input %s should not contain more than one inputs.", name)); - return in[0] != nullptr; - } - - bool HasOutput(const std::string& name) const override { - // has only one output - const auto& outs = ctx_.outputs; - auto it = outs.find(name); - if (it == outs.end()) { - return false; - } - const auto& out = it->second; - if (out.size() == 0) { - return false; - } - PADDLE_ENFORCE_EQ( - out.size(), 1UL, - platform::errors::InvalidArgument( - "Output %s should not contain more than one outputs.", name)); - return out[0] != nullptr; +bool RuntimeInferShapeContext::HasInput(const std::string& name) const { + // has only one input + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end()) { + return false; } + const auto& in = it->second; + if (in.size() == 0) return false; + PADDLE_ENFORCE_EQ( + in.size(), 1UL, + platform::errors::InvalidArgument( + "Input %s should not contain more than one inputs.", name)); + return in[0] != nullptr; +} - bool HasAttr(const std::string& name) const override { - return op_.HasAttr(name); +bool RuntimeInferShapeContext::HasOutput(const std::string& name) const { + // has only one output + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end()) { + return false; } + const auto& out = it->second; + if (out.size() == 0) { + return false; + } + PADDLE_ENFORCE_EQ( + out.size(), 1UL, + platform::errors::InvalidArgument( + "Output %s should not contain more than one outputs.", name)); + return out[0] != nullptr; +} + +bool RuntimeInferShapeContext::HasAttr(const std::string& name) const { + return op_.HasAttr(name); +} - bool HasInputs(const std::string& name) const override { - const auto& ins = ctx_.inputs; - auto it = ins.find(name); - if (it == ins.end() || it->second.empty()) { +bool RuntimeInferShapeContext::HasInputs(const std::string& name) const { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end() || it->second.empty()) { + return false; + } + for (auto& input : it->second) { + if (input == nullptr) { return false; } - for (auto& input : it->second) { - if (input == nullptr) { - return false; - } - } - return true; } + return true; +} - bool HasOutputs(const std::string& name) const override { - const auto& outs = ctx_.outputs; - auto it = outs.find(name); - if (it == outs.end() || it->second.empty()) { +bool RuntimeInferShapeContext::HasOutputs(const std::string& name) const { + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end() || it->second.empty()) { + return false; + } + for (auto& output : it->second) { + if (output == nullptr) { 
return false; } - for (auto& output : it->second) { - if (output == nullptr) { - return false; - } - } - return true; } + return true; +} - AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } - - std::vector Inputs(const std::string& name) const override { - return op_.Inputs(name); - } - - std::vector Outputs(const std::string& name) const override { - return op_.Outputs(name); - } - - std::string GetInputNameByIdx(size_t idx) const override { - auto& op_proto = - paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; - PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), - platform::errors::OutOfRange( - "The index should be less than the size of inputs of " - "operator %s, but got index is %d and size is %d", - op_.Type(), idx, op_proto->inputs().size())); - return op_proto->inputs()[idx].name(); - } - - std::string GetOutputNameByIdx(size_t idx) const override { - auto& op_proto = - paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; - PADDLE_ENFORCE_LT( - idx, op_proto->outputs().size(), - platform::errors::OutOfRange( - "The index should be less than the size of outputs of " - "operator %s, but got index is %d and size is %d", - op_.Type(), idx, op_proto->outputs().size())); - return op_proto->outputs()[idx].name(); - } - - void ShareDim(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) override { - auto in_it = ctx_.inputs.find(in); - auto out_it = ctx_.outputs.find(out); - PADDLE_ENFORCE_NE( - in_it, ctx_.inputs.end(), - platform::errors::NotFound("Input %s does not exist.", in)); - PADDLE_ENFORCE_NE( - out_it, ctx_.outputs.end(), - platform::errors::NotFound("Output %s does not exist.", out)); - PADDLE_ENFORCE_LT(i, in_it->second.size(), - platform::errors::InvalidArgument( - "The index of input dimension is out of range, " - "excepted index less than %zu, but received %zu.", - in_it->second.size(), i)); - PADDLE_ENFORCE_LT(j, out_it->second.size(), - platform::errors::InvalidArgument( - "The index of output dimension is out of range, " - "excepted index less than %zu, but received %zu.", - out_it->second.size(), j)); - - Variable* in_var = in_it->second[i]; - Variable* out_var = out_it->second[j]; - - PADDLE_ENFORCE_EQ( - in_var->Type(), out_var->Type(), - platform::errors::InvalidArgument( - "The type of input (%s) and output (%s) are inconsistent.", in, - out)); - - if (in_var->IsType()) { - auto& in_sele_rows = in_var->Get(); - auto out_sele_rows = out_var->GetMutable(); - out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); - out_sele_rows->set_rows(in_sele_rows.rows()); - out_sele_rows->set_height(in_sele_rows.height()); - } else if (in_var->IsType()) { - auto& in_lod_tensor = in_var->Get(); - auto* out_lod_tensor = out_var->GetMutable(); - out_lod_tensor->Resize(in_lod_tensor.dims()); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently, the input type of ShareDim only can be LoDTensor " - "or SelectedRows.")); - } - } +AttrReader RuntimeInferShapeContext::Attrs() const { return AttrReader(op_.Attrs()); } - void ShareAllLoD(const std::string& in, - const std::string& out) const override { - auto in_it = ctx_.inputs.find(in); - auto out_it = ctx_.outputs.find(out); - PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(), - platform::errors::NotFound( - "Input [%s] found error in Op [%s]", in, op_.Type())); - PADDLE_ENFORCE_NE( - out_it, ctx_.outputs.end(), - platform::errors::NotFound("Output [%s] found error in Op [%s]", out, - op_.Type())); +std::vector 
RuntimeInferShapeContext::Inputs(const std::string& name) const { + return op_.Inputs(name); +} - auto& in_var_list = in_it->second; - auto& out_var_list = out_it->second; +std::vector RuntimeInferShapeContext::Outputs(const std::string& name) const { + return op_.Outputs(name); +} - PADDLE_ENFORCE_EQ( - in_var_list.size(), out_var_list.size(), - platform::errors::PreconditionNotMet( - "Op [%s]: Input var size should be equal with output var size", - op_.Type())); +std::string RuntimeInferShapeContext::GetInputNameByIdx(size_t idx) const { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); +} - auto& out_var_names = op_.Outputs(out); +std::string RuntimeInferShapeContext::GetOutputNameByIdx(size_t idx) const { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); +} - for (size_t i = 0; i < in_var_list.size(); ++i) { - if (out_var_names[i] == framework::kEmptyVarName) { - continue; - } +void RuntimeInferShapeContext::ShareDim(const std::string& in, const std::string& out, size_t i, + size_t j) { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); - Variable* in_var = in_var_list[i]; - if (!in_var->IsType()) return; - Variable* out_var = out_var_list[i]; - PADDLE_ENFORCE_EQ(out_var->IsType(), true, - platform::errors::PreconditionNotMet( - "The %d-th output of Output(%s) must be LoDTensor.", - i, out_var_names[i])); - auto& in_tensor = in_var->Get(); - auto* out_tensor = out_var->GetMutable(); - out_tensor->set_lod(in_tensor.lod()); -#ifdef PADDLE_WITH_MKLDNN - if (in_tensor.layout() != DataLayout::kMKLDNN) -#endif - out_tensor->set_layout(in_tensor.layout()); - } + Variable* in_var = in_it->second[i]; + Variable* out_var = out_it->second[j]; + + PADDLE_ENFORCE_EQ( + in_var->Type(), out_var->Type(), + platform::errors::InvalidArgument( + "The type of input (%s) and output (%s) are inconsistent.", in, + out)); + + if (in_var->IsType()) { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } else if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = 
out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, the input type of ShareDim only can be LoDTensor " + "or SelectedRows.")); } +} + +void RuntimeInferShapeContext::ShareAllLoD(const std::string& in, + const std::string& out) const { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(), + platform::errors::NotFound( + "Input [%s] found error in Op [%s]", in, op_.Type())); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output [%s] found error in Op [%s]", out, + op_.Type())); + + auto& in_var_list = in_it->second; + auto& out_var_list = out_it->second; + + PADDLE_ENFORCE_EQ( + in_var_list.size(), out_var_list.size(), + platform::errors::PreconditionNotMet( + "Op [%s]: Input var size should be equal with output var size", + op_.Type())); + + auto& out_var_names = op_.Outputs(out); + + for (size_t i = 0; i < in_var_list.size(); ++i) { + if (out_var_names[i] == framework::kEmptyVarName) { + continue; + } - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const override { - auto in_it = ctx_.inputs.find(in); - auto out_it = ctx_.outputs.find(out); - PADDLE_ENFORCE_NE( - in_it, ctx_.inputs.end(), - platform::errors::NotFound("Input %s does not exist.", in)); - PADDLE_ENFORCE_NE( - out_it, ctx_.outputs.end(), - platform::errors::NotFound("Output %s does not exist.", out)); - PADDLE_ENFORCE_LT(i, in_it->second.size(), - platform::errors::InvalidArgument( - "The index of input dimension is out of range, " - "excepted index less than %zu, but received %zu.", - in_it->second.size(), i)); - PADDLE_ENFORCE_LT(j, out_it->second.size(), - platform::errors::InvalidArgument( - "The index of output dimension is out of range, " - "excepted index less than %zu, but received %zu.", - out_it->second.size(), j)); - - Variable* in_var = in_it->second.at(i); + Variable* in_var = in_var_list[i]; if (!in_var->IsType()) return; - Variable* out_var = out_it->second.at(j); - PADDLE_ENFORCE_EQ( - out_var->IsType(), true, - platform::errors::InvalidArgument( - "The %zu-th output of Output(%s) must be LoDTensor.", j, out)); + Variable* out_var = out_var_list[i]; + PADDLE_ENFORCE_EQ(out_var->IsType(), true, + platform::errors::PreconditionNotMet( + "The %d-th output of Output(%s) must be LoDTensor.", + i, out_var_names[i])); auto& in_tensor = in_var->Get(); auto* out_tensor = out_var->GetMutable(); out_tensor->set_lod(in_tensor.lod()); - -// TODO(dzhwinter) : reuse ShareLoD in most operators. -// Need to call ShareLayout explicitly in sequence related ops. -// Shall we have a better method to shared info between in/out Tensor? #ifdef PADDLE_WITH_MKLDNN - // Fix me: ugly workaround below - // Correct solution: - // set_layout() should NOT be called here (i.e. ShareLoD). Instead, - // layout of output tensor should be set "manually" in Compute() - // of each OPKernel. The reason layout should NOT be shared between - // input and output "automatically" (now by InferShape()->ShareLoD()) - // is that layout transform may occur after InferShape(). - // Workaround: - // Skip set_layout() when input layout is kMKLDNN - // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN - // OPKernel. 
In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called - // in Compute() if (in_tensor.layout() != DataLayout::kMKLDNN) #endif out_tensor->set_layout(in_tensor.layout()); } +} - int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "GetLoDLevel is only used in compile time. The calculation of " - "output's actual lod is different among operators so that should be " - "set in the runtime kernel.")); - } +// async infershape +std::vector RuntimeInferShapeContext::GetOutputsLod(const std::string& out) const { + auto out_it = ctx_.outputs.find(out); + auto& out_var_list = out_it->second; - void SetLoDLevel(const std::string& out, int32_t lod_level, - size_t j = 0) const override { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "SetLoDLevel is only used in compile time. The calculation of " - "output's actual lod is different among operators so that should be " - "set in the runtime kernel.")); + std::vector ret; + for (size_t i = 0; i < out_var_list.size(); ++i) { + Variable* out_var = out_var_list[i]; + if (out_var != nullptr) { + auto* out_tensor = out_var->GetMutable(); + ret.push_back(out_tensor->lod()); + } } + return ret; +} - bool IsRuntime() const override { return true; } +void RuntimeInferShapeContext::ShareLoD(const std::string& in, const std::string& out, + size_t i, size_t j) const { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); + + Variable* in_var = in_it->second.at(i); + if (!in_var->IsType()) return; + Variable* out_var = out_it->second.at(j); + PADDLE_ENFORCE_EQ( + out_var->IsType(), true, + platform::errors::InvalidArgument( + "The %zu-th output of Output(%s) must be LoDTensor.", j, out)); + auto& in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); - bool IsRunMKLDNNKernel() const override { - try { - auto& op_with_kernel = dynamic_cast(op_); - return ((op_with_kernel.kernel_type()) && - (op_with_kernel.kernel_type()->data_layout_ == - framework::DataLayout::kMKLDNN)); - } catch (std::bad_cast exp) { - return false; - } - } +// TODO(dzhwinter) : reuse ShareLoD in most operators. +// Need to call ShareLayout explicitly in sequence related ops. +// Shall we have a better method to shared info between in/out Tensor? +#ifdef PADDLE_WITH_MKLDNN + // Fix me: ugly workaround below + // Correct solution: + // set_layout() should NOT be called here (i.e. ShareLoD). Instead, + // layout of output tensor should be set "manually" in Compute() + // of each OPKernel. The reason layout should NOT be shared between + // input and output "automatically" (now by InferShape()->ShareLoD()) + // is that layout transform may occur after InferShape(). 
+ // Workaround: + // Skip set_layout() when input layout is kMKLDNN + // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN + // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called + // in Compute() + if (in_tensor.layout() != DataLayout::kMKLDNN) +#endif + out_tensor->set_layout(in_tensor.layout()); +} - // TODO(paddle-dev): Can this be template? - std::vector GetInputVarPtrs( - const std::string& name) const override { - const std::vector& vars = InputVars(name); - std::vector res; - res.reserve(vars.size()); - res.insert(res.begin(), vars.begin(), vars.end()); - return res; - } +int32_t RuntimeInferShapeContext::GetLoDLevel(const std::string& in, size_t i) const { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); +} - std::vector GetOutputVarPtrs( - const std::string& name) const override { - const std::vector& vars = OutputVars(name); - std::vector res; - res.reserve(vars.size()); - res.insert(res.begin(), vars.begin(), vars.end()); - return res; - } +void RuntimeInferShapeContext::SetLoDLevel(const std::string& out, int32_t lod_level, + size_t j) const { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); +} - DDim GetInputDim(const std::string& name) const override { - const std::vector& vars = InputVars(name); - PADDLE_ENFORCE_EQ( - vars.size(), 1UL, - platform::errors::InvalidArgument( - "Input(%s) should hold one element, but now it holds %zu elements.", - name, vars.size())); - return this->GetDim(vars[0]); - } +bool RuntimeInferShapeContext::IsRuntime() const { return true; } - std::vector GetInputsDim(const std::string& name) const override { - const std::vector& vars = InputVars(name); - return GetDims(vars); +bool RuntimeInferShapeContext::IsRunMKLDNNKernel() const { + try { + auto& op_with_kernel = dynamic_cast(op_); + return ((op_with_kernel.kernel_type()) && + (op_with_kernel.kernel_type()->data_layout_ == + framework::DataLayout::kMKLDNN)); + } catch (std::bad_cast exp) { + return false; } +} - std::vector GetInputsVarType( - const std::string& name) const override { - return GetVarTypes(InputVars(name)); - } +// TODO(paddle-dev): Can this be template? 
+std::vector RuntimeInferShapeContext::GetInputVarPtrs( + const std::string& name) const { + const std::vector& vars = InputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; +} - std::vector GetOutputsVarType( - const std::string& name) const override { - return GetVarTypes(OutputVars(name)); - } +std::vector RuntimeInferShapeContext::GetOutputVarPtrs( + const std::string& name) const { + const std::vector& vars = OutputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; +} - void SetOutputDim(const std::string& name, const DDim& dim) override { - auto& vars = OutputVars(name); - PADDLE_ENFORCE_EQ( - vars.size(), 1UL, - platform::errors::InvalidArgument("Output(%s) should hold one element, " - "but now it holds %zu elements.", - name, vars.size())); - SetDim(vars[0], dim); - } +DDim RuntimeInferShapeContext::GetInputDim(const std::string& name) const { + const std::vector& vars = InputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + platform::errors::InvalidArgument( + "Input(%s) should hold one element, but now it holds %zu elements.", + name, vars.size())); + return this->GetDim(vars[0]); +} - void SetOutputsDim(const std::string& name, - const std::vector& dims) override { - auto& vars = OutputVars(name); - SetDims(vars, dims); - } +std::vector RuntimeInferShapeContext::GetInputsDim(const std::string& name) const { + const std::vector& vars = InputVars(name); + return GetDims(vars); +} - protected: - DDim GetDim(Variable* var) const { - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::InvalidArgument("Input variable is nullptr.")); - if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - return var->Get().GetCompleteDims(); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Only LoDTensor or SelectedRows support 'GetDim', but input " - "Variable's type is %s.", - ToTypeName(var->Type()))); - } +std::vector RuntimeInferShapeContext::GetOutputsDim(const std::string& name) const { + const std::vector& vars = OutputVars(name); + std::vector vars_res; + for (auto var : vars) { + if (var != nullptr) { + vars_res.push_back(var); + } } + return GetDims(vars_res); +} - std::vector GetDims(const std::vector& vars) const { - std::vector ret; - ret.reserve(vars.size()); - std::transform(vars.begin(), vars.end(), std::back_inserter(ret), - [this](Variable* var) { return this->GetDim(var); }); - return ret; - } +std::vector RuntimeInferShapeContext::GetInputsVarType( + const std::string& name) const { + return GetVarTypes(InputVars(name)); +} - std::vector GetRepeatedDims(const std::string& name) const override { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "GetRepeatedDims method only ban be used in compile time.")); - } +std::vector RuntimeInferShapeContext::GetOutputsVarType( + const std::string& name) const { + return GetVarTypes(OutputVars(name)); +} - void SetDim(Variable* var, const DDim& dim) { - if (var->IsType()) { - var->GetMutable()->Resize(dim); - } else if (var->IsType()) { - var->GetMutable()->set_height(dim[0]); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Variable type error, expect LoDTensor or SelectedRows, but received " - "(%s).", - ToTypeName(var->Type()))); - } - } +void RuntimeInferShapeContext::SetOutputDim(const std::string& name, const DDim& dim) { + auto& vars = OutputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + 
platform::errors::InvalidArgument("Output(%s) should hold one element, " + "but now it holds %zu elements.", + name, vars.size())); + SetDim(vars[0], dim); +} - void SetDims(const std::vector& vars, - const std::vector& dims) { - size_t length = vars.size(); - PADDLE_ENFORCE_EQ(length, dims.size(), - platform::errors::InvalidArgument( - "The number of input variables do not match the " - "number of input dimensions, the number of variables " - "is %zu, the number of dimensions is %zu.", - length, dims.size())); - for (size_t i = 0; i < length; ++i) { - if (vars[i] == nullptr) { - continue; - } - SetDim(vars[i], dims[i]); - } - } +void RuntimeInferShapeContext::SetOutputsDim(const std::string& name, + const std::vector& dims) { + auto& vars = OutputVars(name); + SetDims(vars, dims); +} - void SetRepeatedDims(const std::string& name, - const std::vector& dims) override { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "SetRepeatedDims method only can be used in compile time.")); +DDim RuntimeInferShapeContext::GetDim(Variable* var) const { + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::InvalidArgument("Input variable is nullptr.")); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only LoDTensor or SelectedRows support 'GetDim', but input " + "Variable's type is %s.", + ToTypeName(var->Type()))); } +} - std::vector GetVarTypes( - const std::vector& vars) const { - std::vector retv; - retv.resize(vars.size()); - std::transform(vars.begin(), vars.end(), retv.begin(), - std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), - this, std::placeholders::_1)); - return retv; - } +std::vector RuntimeInferShapeContext::GetDims(const std::vector& vars) const { + std::vector ret; + ret.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(ret), + [this](Variable* var) { return this->GetDim(var); }); + return ret; +} - proto::VarType::Type GetVarType(Variable* var) const { - return ToVarType(var->Type()); - } +std::vector RuntimeInferShapeContext::GetRepeatedDims(const std::string& name) const { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetRepeatedDims method only ban be used in compile time.")); +} - private: - const std::vector& InputVars(const std::string& name) const { - auto it = ctx_.inputs.find(name); - PADDLE_ENFORCE_NE( - it, ctx_.inputs.end(), - platform::errors::NotFound( - "Operator (%s) does not have the input (%s).", op_.Type(), name)); - return it->second; +void RuntimeInferShapeContext::SetDim(Variable* var, const DDim& dim) { + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Variable type error, expect LoDTensor or SelectedRows, but received " + "(%s).", + ToTypeName(var->Type()))); } +} - const std::vector& OutputVars(const std::string& name) const { - auto it = ctx_.outputs.find(name); - PADDLE_ENFORCE_NE( - it, ctx_.outputs.end(), - platform::errors::NotFound( - "Operator (%s) does not have the outputs (%s).", op_.Type(), name)); - return it->second; +void RuntimeInferShapeContext::SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ(length, dims.size(), + platform::errors::InvalidArgument( + "The number of input variables do not match the " + "number of input dimensions, the number of 
variables " + "is %zu, the number of dimensions is %zu.", + length, dims.size())); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); } +} - const OperatorBase& op_; - const RuntimeContext& ctx_; -}; +void RuntimeInferShapeContext::SetRepeatedDims(const std::string& name, + const std::vector& dims) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetRepeatedDims method only can be used in compile time.")); +} + +std::vector RuntimeInferShapeContext::GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; +} + +proto::VarType::Type RuntimeInferShapeContext::GetVarType(Variable* var) const { + return ToVarType(var->Type()); +} + +const std::vector& RuntimeInferShapeContext::InputVars(const std::string& name) const { + auto it = ctx_.inputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.inputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the input (%s).", op_.Type(), name)); + return it->second; +} + +const std::vector& RuntimeInferShapeContext::OutputVars(const std::string& name) const { + auto it = ctx_.outputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.outputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the outputs (%s).", op_.Type(), name)); + return it->second; +} static void CheckTensorNANOrInf(const std::string& op_type, const std::string& name, @@ -1212,6 +1228,12 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, this->Info().infer_shape_(&infer_shape_ctx); } +void OperatorWithKernel::RuntimeInferShape(const Scope& scope) const { + RuntimeContext ctx(Inputs(), Outputs(), scope); + RuntimeInferShapeContext infer_shape_ctx(*this, ctx); + this->Info().infer_shape_(&infer_shape_ctx); +} + void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { // To reduce the elapsed time of HasAttr, we use bool variable to record the @@ -1434,6 +1456,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::RecordEvent record_event("compute", platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); + + + // infershape check + // RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx); + // std::vector> pre_dims; + // std::vector> pre_lod; + // auto outnames = Outputs(); + // for (auto& var_name_item : outnames) { + // pre_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first)); + // pre_lod.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first)); + // } + if (run_phi_kernel_) { phi::KernelContext pt_kernel_context; // Do data transform before building KernelContext @@ -1446,6 +1480,55 @@ void OperatorWithKernel::RunImpl(const Scope& scope, (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); } + + // if (all_kernels_must_compute_runtime_shape_) { + // std::vector> after_dims; + // std::vector> after_lod; + // for (auto& var_name_item : outnames) { + // after_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first)); + // after_lod.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first)); + // } + // if (pre_dims.size() != after_dims.size()) { + // CHECK(false) << "dims error: " << Info().Proto().type(); + // } + // for (size_t i = 0; i < pre_dims.size(); i++) { + // if (pre_dims[i].size() != after_dims[i].size()) { + // CHECK(false) << "dims 
error: " << Info().Proto().type(); + // } + // for (size_t j = 0; j < pre_dims[i].size(); j++) { + // if (pre_dims[i][j] != after_dims[i][j]) { + // CHECK(false) << "dims error: " << Info().Proto().type(); + // } + // } + // } + // if (pre_lod.size() != after_lod.size()) { + // CHECK(false) << "lods error: " << Info().Proto().type(); + // } + // for (size_t i = 0; i < pre_lod.size(); i++) { + // if (pre_lod[i].size() != after_lod[i].size()) { + // CHECK(false) << "lods error: " << Info().Proto().type(); + // } + // for (size_t j = 0; j < pre_lod[i].size(); j++) { + // auto& a = pre_lod[i][j]; + // auto& b = after_lod[i][j]; + // if (a.size() != b.size()) { + // CHECK(false) << "lods error: " << Info().Proto().type(); + // } + // for (size_t i = 0; i < a.size(); i++) { + // const auto &a_level = a[i]; + // const auto &b_level = b[i]; + // if (a_level.size() != b_level.size()) { + // CHECK(false) << "lods error: " << Info().Proto().type(); + // } + // for (size_t j = 0; j < a_level.size(); j++) { + // if (a_level[j] != b_level[j]) { + // CHECK(false) << "lods error: " << Info().Proto().type(); + // } + // } + // } + // } + // } + // } } if (!transfered_inplace_vars.empty()) { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index f7fc83f1d6d30..31a085661c5de 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -31,6 +31,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/tensor.h" @@ -218,6 +219,9 @@ class OperatorBase { void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } + virtual void SetIsRuntimeInferShape(bool x) {} + virtual void RuntimeInferShape(const Scope& scope) const {} + virtual void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const {} @@ -579,9 +583,15 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape(InferShapeContext* ctx) const; + void SetIsRuntimeInferShape(bool x) override { + all_kernels_must_compute_runtime_shape_ = x; + } + void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; + void RuntimeInferShape(const Scope& scope) const override; + proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx, const std::string& name) const; @@ -703,6 +713,84 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr pt_kernel_; }; +class RuntimeInferShapeContext : public InferShapeContext { + public: + RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) + : op_(op), ctx_(ctx) {} + + bool HasInput(const std::string &name) const override; + bool HasOutput(const std::string &name) const override; + bool HasAttr(const std::string &name) const override; + + std::vector GetInputsVarType( + const std::string &name) const override; + std::vector GetOutputsVarType( + const std::string &name) const override; + + bool HasInputs(const std::string &name) const override; + bool HasOutputs(const std::string &name) const override; + + DDim GetInputDim(const std::string &name) const override; + std::vector GetInputsDim(const std::string &name) const override; + + void SetOutputDim(const std::string &name, const DDim &dim) 
diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc
index e0cf860e5bc7b..9a373be92b636 100644
--- a/paddle/fluid/framework/ps_gpu_trainer.cc
+++ b/paddle/fluid/framework/ps_gpu_trainer.cc
@@ -280,12 +280,13 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
       if (var->Persistable()) {
         auto name = var->Name();
         Variable* root_var = root_scope_->FindVar(name);
+
+        auto* ptr = scope->Var(name);
+        InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
         if (!root_var) {
           continue;
         }
         LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
-        auto* ptr = scope->Var(name);
-        InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
         LoDTensor* thread_tensor = ptr->GetMutable<LoDTensor>();
         TensorCopy(*root_tensor, place, thread_tensor);
       }
@@ -301,6 +302,19 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
       }
     }
   }
+
+  for (size_t num = 0; num < places_.size(); ++num) {
+    Scope* scope = workers_[num]->GetThreadScope();
+    for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
+      Variable* thread_var = scope->FindVar(need_merge_var_names_[i]);
+      if (thread_var != nullptr) {
+        continue;
+      }
+      auto* ptr = scope->Var(need_merge_var_names_[i]);
+      InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
+    }
+  }
+
   place_ = place;
   return;
 }
diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc
index 726f3b14c1f01..83eefc57a6dea 100644
--- a/paddle/fluid/framework/ps_gpu_worker.cc
+++ b/paddle/fluid/framework/ps_gpu_worker.cc
@@ -34,6 +34,42 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+std::atomic<int> PSGPUWorker::shape_check_count_(16);
+std::atomic<bool> PSGPUWorker::shape_check_flag_(false);
+
+void PSGPUWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
+  this->HogwildWorker::CreateDeviceResource(main_prog);
+  if (scope_num_ != 1) {
+    auto& block = main_prog.Block(0);
+    for (int i = 0; i < scope_num_; i++) {
+      auto thread_tmp = &thread_scope_->NewScope();
+      thread_scope_vec_.push_back(thread_tmp);
+    }
+    for (auto& scope : thread_scope_vec_) {
+      for (auto& var : block.AllVars()) {
+        std::string name = var->Name();
+        if (!var->Persistable()) {
+          auto* ptr = scope->Var(var->Name());
+          InitializeVariable(ptr, var->GetType());
+        }
+      }
+    }
+    for (auto& op : ops_) {
+      op->SetIsRuntimeInferShape(true);
+    }
+  }
+}
+
+void PSGPUWorker::BindingDataFeedMemory() {
+  if (scope_num_ == 1) {
+    this->HogwildWorker::BindingDataFeedMemory();
+  } else {
+    for (auto& scope : thread_scope_vec_) {
+      device_reader_->AssignFeedVar(*scope);
+    }
+  }
+}
+
 void PSGPUWorker::Initialize(const TrainerDesc& desc) {
   param_ = desc.downpour_param();
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
@@ -163,28 +199,29 @@ void PSGPUWorker::PrepareCudaGraph() {
       }
       if (!need_skip) {
         bool need_capture = false;
-        if (op_blacklist.find(op->Type()) == op_blacklist.end()) {
-          if (op->HasAttr(enable_cuda_graph_capture_attr_name) &&
-              op->Attr<bool>(enable_cuda_graph_capture_attr_name)) {
-            need_capture = true;
-          }
-          if (!need_capture) {
-            need_capture = true;
-            for (auto& input : op->InputVars()) {
-              if (var_whitelist.find(input) == var_whitelist.end()) {
-                need_capture = false;
-                break;
-              }
-            }
-            if (need_capture) {
-              for (auto& output : op->OutputVars(true)) {
-                if (var_whitelist.find(output) == var_whitelist.end()) {
-                  need_capture = false;
-                  break;
-                }
-              }
-            }
-          }
-        }
+        // if (op_blacklist.find(op->Type()) == op_blacklist.end()) {
+        //   if (op->HasAttr(enable_cuda_graph_capture_attr_name) &&
+        //       op->Attr<bool>(enable_cuda_graph_capture_attr_name)) {
+        //     need_capture = true;
+        //   }
+        //   if (!need_capture) {
+        //     need_capture = true;
+        //     for (auto& input : op->InputVars()) {
+        //       if (var_whitelist.find(input) == var_whitelist.end()) {
+        //         need_capture = false;
+        //         break;
+        //       }
+        //     }
+        //     if (need_capture) {
+        //       for (auto& output : op->OutputVars(true)) {
+        //         if (var_whitelist.find(output) == var_whitelist.end()) {
+        //           need_capture = false;
+        //           break;
+        //         }
+        //       }
+        //     }
+        //   }
+        // }
+
         if (op_or_cudagraphs_.empty() ||
             op_or_cudagraphs_.back().need_capture != need_capture) {
           op_or_cudagraphs_.emplace_back();
           op_or_cudagraphs_.back().need_capture = need_capture;
@@ -203,8 +240,89 @@ void PSGPUWorker::PrepareCudaGraph() {
   }
 }
 
+PSGPUWorker::~PSGPUWorker() {
+  stop_token_.store(true);
+  for (auto& thread : task_threads_) {
+    if (thread.joinable()) {
+      thread.join();
+    }
+  }
+}
+
+int PSGPUWorker::OpRunAndShapeCheck(OperatorBase& op,
+                                    const Scope& scope,
+                                    const platform::Place& place) {
+  if (shape_check_flag_.load()) {
+    VLOG(0) << "Begin OpRunAndShapeCheck... " << shape_check_count_.load();
+    if (shape_check_count_.fetch_sub(1) <= 0) {
+      // shape_check_flag_ = false;
+    }
+    // before op run
+    InferShapeCheckData check_data;
+    auto& pre_dims = check_data.pre_dims;
+    auto& pre_lods = check_data.pre_lods;
+    auto& after_dims = check_data.after_dims;
+    auto& after_lods = check_data.after_lods;
+    RuntimeContext ctx(op.Inputs(), op.Outputs(), scope);
+    RuntimeInferShapeContext infer_shape_ctx(op, ctx);
+    auto outnames = op.Outputs();
+    for (auto& var_name_item : outnames) {
+      pre_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
+      pre_lods.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
+    }
+
+    // op run
+    op.Run(scope, place);
+
+    // after op run
+    for (auto& var_name_item : outnames) {
+      after_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
+      after_lods.push_back(infer_shape_ctx.GetOutputsLod(var_name_item.first));
+    }
+    // auto& op_name = op.Info().Proto().type();
+    CHECK(pre_dims.size() == after_dims.size())
+        << "dims error, op name:" << op.Info().Proto().type();
+    for (size_t i = 0; i < pre_dims.size(); i++) {
+      CHECK(pre_dims[i].size() == after_dims[i].size())
+          << "dims error, op name:" << op.Info().Proto().type();
+      for (size_t j = 0; j < pre_dims[i].size(); j++) {
+        CHECK(pre_dims[i][j] == after_dims[i][j])
+            << "dims error, op name:" << op.Info().Proto().type();
+      }
+    }
+
+    CHECK(pre_lods.size() == after_lods.size())
+        << "lods error, op name:" << op.Info().Proto().type();
+    for (size_t i = 0; i < pre_lods.size(); i++) {
+      CHECK(pre_lods[i].size() == after_lods[i].size())
+          << "lods error, op name:" << op.Info().Proto().type();
+      for (size_t j = 0; j < pre_lods[i].size(); j++) {
+        auto& x = pre_lods[i][j];
+        auto& y = after_lods[i][j];
+        CHECK(x.size() == y.size())
+            << "lods error, op name:" << op.Info().Proto().type();
+        for (size_t i = 0; i < x.size(); i++) {
+          const auto &x_level = x[i];
+          const auto &y_level = y[i];
+          CHECK(x_level.size() == y_level.size())
+              << "lods error, op name:" << op.Info().Proto().type();
+          for (size_t j = 0; j < x_level.size(); j++) {
+            CHECK(x_level[j] == y_level[j])
+                << "lods error, op name:" << op.Info().Proto().type();
+          }
+        }
+      }
+    }
+  } else {
+    op.Run(scope, place);
+  }
+  return 0;
+}
+
+
 void PSGPUWorker::TrainFiles() {
-  VLOG(3) << "Begin to train files";
+  VLOG(0) << "Begin to train files";
   platform::SetNumThreads(1);
   platform::Timer timeline;
   timeline.Start();
@@ -219,7 +337,81 @@ void PSGPUWorker::TrainFiles() {
   int graph_batch_size = 0;
 
   platform::SetDeviceId(place_.GetDeviceId());
-  while ((cur_batch = device_reader_->Next()) > 0) {
+
+  // async infershape
+  pack_is_end_.store(false);
+  if (scope_num_ != 1) {
+    for (size_t i = 0; i < thread_scope_vec_.size(); i++) {
+      TaskData task;
+      task.scope = thread_scope_vec_[i];
+      free_task_queue_.Push(task);
+    }
+    // std::atomic* thread_run = new std::atomic(task_threads_);
+    thread_count_.store(task_threads_num_);
+    task_threads_.reserve(task_threads_num_);
+    for (int i = 0; i < task_threads_num_; i++) {
+      task_threads_.emplace_back(std::thread([this]() -> void {
+        while (true) {
+          auto pack = device_reader_->get_pack(nullptr);
+          if (pack == nullptr) {
+            int thread_num = thread_count_.fetch_sub(1);
+            if (thread_num == 1) {
+              pack_is_end_.store(true);
+            }
+            return;
+          }
+          auto task = free_task_queue_.Pop();
+          task.pack = pack;
+          task.ins_num = pack->ins_num();
+          device_reader_->PackToScope(task.pack, task.scope);
+          for (size_t ii = 0; ii < ops_.size(); ii++) {
+            auto& op = ops_[ii];
+            bool need_skip = false;
+            for (auto t = 0u; t < skip_ops_.size(); ++t) {
+              if (op->Type().find(skip_ops_[t]) != std::string::npos) {
+                need_skip = true;
+                break;
+              }
+            }
+            if (!need_skip) {
+              op->RuntimeInferShape(*task.scope);
+            }
+          }
+          using_task_queue_.Push(task);
+        }
+      }));
+    }
+  }
+
+  while (true) {
+    auto thread_scope = thread_scope_;
+    TaskData cur_task;
+    if (scope_num_ == 1) {
+      cur_batch = device_reader_->Next();
+    } else {
+      while (true) {
+        if (using_task_queue_.Size() != 0) {
+          cur_task = using_task_queue_.Pop();
+          cur_batch = cur_task.ins_num;
+          break;
+        }
+        bool is_end = pack_is_end_.load();
+        if (is_end) {
+          if (using_task_queue_.Size() == 0) {
+            cur_batch = 0;
+            break;
+          }
+        }
+        std::this_thread::sleep_for(
+            std::chrono::microseconds(200));
+      }
+      thread_scope = cur_task.scope;
+    }
+
+    if (cur_batch <= 0) {
+      break;
+    }
+
     total_ins_num += cur_batch;
 
     if (op_or_cudagraphs_.empty()) {
@@ -233,7 +425,8 @@ void PSGPUWorker::TrainFiles() {
           }
         }
         if (!need_skip) {
-          op->Run(*thread_scope_, place_);
+          OpRunAndShapeCheck(*op, *thread_scope, place_);
+          // op->Run(*thread_scope, place_);
         }
       }
       graph_batch_size = cur_batch;
@@ -249,7 +442,8 @@ void PSGPUWorker::TrainFiles() {
          }
        }
        if (!need_skip) {
-          op->Run(*thread_scope_, place_);
+          OpRunAndShapeCheck(*op, *thread_scope, place_);
+          // op->Run(*thread_scope, place_);
        }
      }
    } else {
@@ -261,7 +455,8 @@ void PSGPUWorker::TrainFiles() {
          std::lock_guard<std::mutex> lock(_capture_mutex);
          platform::BeginCUDAGraphCapture(place_, cudaStreamCaptureModeThreadLocal);
          for (auto& op : op_or_cuda_graph.ops) {
-            op->Run(*thread_scope_, place_);
+            OpRunAndShapeCheck(*op, *thread_scope, place_);
+            // op->Run(*thread_scope, place_);
          }
          op_or_cuda_graph.cudagraph = platform::EndCUDAGraphCapture();
        }
@@ -271,20 +466,21 @@ void PSGPUWorker::TrainFiles() {
          op_or_cuda_graph.cudagraph->Replay();
        } else {
          for (auto& op : op_or_cuda_graph.ops) {
-            op->Run(*thread_scope_, place_);
+            OpRunAndShapeCheck(*op, *thread_scope, place_);
+            // op->Run(*thread_scope, place_);
          }
        }
      }
    }
    if (need_dump_field_) {
-      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+      DumpField(*thread_scope, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
-      DumpParam(*thread_scope_, batch_cnt);
+      DumpParam(*thread_scope, batch_cnt);
    }
 
    for (std::string& var_name : check_nan_var_names_) {
-      Variable* var = thread_scope_->FindVar(var_name);
+      Variable* var = thread_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }
@@ -299,11 +495,11 @@ void PSGPUWorker::TrainFiles() {
        std::lock_guard<std::mutex> lock(mutex);
        VLOG(0) << "worker " << thread_id_ << ": " << var_name
                << " cantains inf or nan";
-        auto all_vars = thread_scope_->LocalVarNames();
+        auto all_vars = thread_scope->LocalVarNames();
        std::stringstream ss;
        ss << "====== worker " << thread_id_ << "======\n";
        for (auto& local_var : all_vars) {
-          platform::PrintVar(thread_scope_, local_var, local_var, &ss);
+          platform::PrintVar(thread_scope, local_var, local_var, &ss);
          ss << "\n";
        }
        std::cout << ss.str() << std::endl;
@@ -316,8 +512,13 @@ void PSGPUWorker::TrainFiles() {
    dev_ctx_->Wait();
    PrintFetchVars();
-    thread_scope_->DropKids();
+    thread_scope->DropKids();
    ++batch_cnt;
+
+    if (scope_num_ != 1) {
+      device_reader_->get_pack(cur_task.pack);
+      free_task_queue_.Push(cur_task);
+    }
  }
  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
@@ -330,7 +531,7 @@
 
 void PSGPUWorker::TrainFilesWithProfiler() {
   platform::SetNumThreads(1);
-  VLOG(3) << "Begin to train files with profiler";
+  VLOG(0) << "Begin to train files with profiler";
   device_reader_->Start();
   std::vector<double> op_total_time;
   std::vector<std::string> op_name;
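OpRunAndShapeCheck above snapshots every output's dims and LoD, runs the op, and then asserts via CHECK that neither changed, which catches ops whose kernels silently resize or re-LoD their outputs after the asynchronous InferShape. The nested dims comparison could be expressed as a small predicate; the sketch below is illustrative only (SameDims is hypothetical, not part of the patch) and assumes phi::DDim's comparison operators.

    // Sketch only: the dims comparison used in OpRunAndShapeCheck,
    // factored into a reusable predicate.
    #include <vector>
    #include "paddle/phi/core/ddim.h"

    static bool SameDims(const std::vector<std::vector<phi::DDim>>& before,
                         const std::vector<std::vector<phi::DDim>>& after) {
      if (before.size() != after.size()) return false;
      for (size_t i = 0; i < before.size(); ++i) {
        if (before[i].size() != after[i].size()) return false;
        for (size_t j = 0; j < before[i].size(); ++j) {
          // Element-wise DDim comparison, mirroring the CHECKs above.
          if (before[i][j] != after[i][j]) return false;
        }
      }
      return true;
    }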
profiler"; device_reader_->Start(); std::vector op_total_time; std::vector op_name; diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 900fd4d8d292e..2a6d4fce1c798 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -186,6 +186,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel { if (size_of_dtype == -1) { size_of_dtype = framework::SizeOfType(dtype); } + + if (use_align && align_size <= 0) { + align_size = size_of_dtype; + } + GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype, context.GetPlace(), use_align, align_size); @@ -317,9 +322,12 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - if (ctx->IsRuntime()) { - return; - } + + // TODO trick to be fixed + // if (ctx->IsRuntime()) { + // return; + // } + auto use_align = ctx->Attrs().Get("use_align"); auto align_size = ctx->Attrs().Get("align_size"); auto size_of_dtype = ctx->Attrs().Get("user_defined_size_of_dtype"); @@ -330,6 +338,10 @@ class CoalesceTensorOp : public framework::OperatorWithKernel { size_of_dtype = framework::SizeOfType(dtype); } + if (use_align && align_size <= 0) { + align_size = size_of_dtype; + } + auto alignment = [](size_t size, size_t align_size) { size_t remaining = size % align_size; auto aligned_size = diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index adc0842fb3882..63ea3d9410ce6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -319,7 +319,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { auto y_grad_name = framework::GradVarName("Y"); if (ctx->HasOutput(x_grad_name)) { ctx->ShareDim("X", /*->*/ x_grad_name); - ctx->ShareLoD("X", /*->*/ x_grad_name); + ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name); } if (ctx->HasOutput(y_grad_name)) { ctx->ShareDim("Y", /*->*/ y_grad_name); diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc index 4585103538877..e7abf4689b1fc 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc @@ -59,27 +59,63 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { "but received value is %d.", ins_dims[0].size())); - for (size_t i = 0; i < num_inputs; ++i) { - const auto dims = ins_dims[i]; - int rank = dims.size(); - if (use_cvm) { - PADDLE_ENFORCE_GT( - dims[rank - 1], 2, - platform::errors::InvalidArgument( - "Shape error in %lu id, the last dimension(embedding) of the " - "'X' tensor must be larger than 2.", - i)); + if (ctx->IsRuntime()) { + int batch_size = -1; + auto inputs_tensor = ctx->GetInputVarPtrs("X"); + for (size_t i = 0; i < num_inputs; ++i) { + const auto dims = ins_dims[i]; + int rank = dims.size(); + int cur_batch_size = 0; + framework::Variable* x_var = + BOOST_GET(framework::Variable*, inputs_tensor[i]); + const auto& x_tensor = x_var->Get(); + const auto& x_lod = x_tensor.lod(); + if (x_lod.size() > 0) { + cur_batch_size = x_lod[0].size() - 1; + } else { + cur_batch_size = x_tensor.dims()[0]; + } + if (batch_size == -1) { + batch_size = cur_batch_size; + } else { + PADDLE_ENFORCE_EQ(batch_size, cur_batch_size, + platform::errors::PreconditionNotMet( + "The batch 
diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc
index 90e4fc9da0d61..9bf89c95bf222 100644
--- a/paddle/fluid/operators/pull_box_sparse_op.cc
+++ b/paddle/fluid/operators/pull_box_sparse_op.cc
@@ -47,6 +47,7 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
       outs_dims[i] = phi::make_ddim(out_dim);
     }
     ctx->SetOutputsDim("Out", outs_dims);
+    ctx->ShareAllLoD("Ids", "Out");
     for (size_t i = 0; i < n_ids; ++i) {
       ctx->ShareLoD("Ids", "Out", i, i);
     }
diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
index f452d9ffb7e89..60b016063caf3 100644
--- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -129,7 +129,7 @@ void MultiplyGradKernel(const Context& dev_ctx,
                         int axis,
                         DenseTensor* dx,
                         DenseTensor* dy) {
-  funcs::ElementwiseGradPreProcess(dout, dx);
+  // funcs::ElementwiseGradPreProcess(dout, dx);
   auto* out = &dout;  // out is not necessary
   phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
       dev_ctx, x, y, *out, dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index fae7978d3d2ea..e390b84378fcf 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -144,7 +144,7 @@ void MultiplyGradKernel(const Context& dev_ctx,
                         int axis,
                         DenseTensor* dx,
                         DenseTensor* dy) {
-  funcs::ElementwiseGradPreProcess(dout, dx);
+  // funcs::ElementwiseGradPreProcess(dout, dx);
   ElementwiseMulGrad<T>(dev_ctx, x, y, dout, dx, dy, axis);
 }
diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
index aba4a5f5fbd43..e290608db5631 100644
--- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -32,7 +32,7 @@ void AddGradImpl(const Context& dev_ctx,
                  DenseTensor* x_grad,
                  DenseTensor* y_grad,
                  GradFunc grad_func) {
-  phi::funcs::ElementwiseGradPreProcess(out_grad, x_grad);
+  // phi::funcs::ElementwiseGradPreProcess(out_grad, x_grad);
   auto* out = &out_grad;
   // Special case when y_grad is not needed and x_grad doesn't reduce
   if (x_grad != nullptr && y_grad == nullptr &&