diff --git a/cmake/external/box_ps.cmake b/cmake/external/box_ps.cmake index adfc6dba1f083..9747dd85e0875 100644 --- a/cmake/external/box_ps.cmake +++ b/cmake/external/box_ps.cmake @@ -19,7 +19,8 @@ IF((NOT DEFINED BOX_PS_VER) OR (NOT DEFINED BOX_PS_URL)) MESSAGE(STATUS "use pre defined download url") SET(BOX_PS_VER "0.1.1" CACHE STRING "" FORCE) SET(BOX_PS_NAME "box_ps" CACHE STRING "" FORCE) - SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE) + #SET(BOX_PS_URL "http://box-ps.gz.bcebos.com/box_ps.tar.gz" CACHE STRING "" FORCE) + SET(BOX_PS_URL "data-im.baidu.com:/home/work/var/CI_DATA/im/static/box_ps.tar.gz/box_ps.tar.gz.11" CACHE STRING "" FORCE) ENDIF() MESSAGE(STATUS "BOX_PS_NAME: ${BOX_PS_NAME}, BOX_PS_URL: ${BOX_PS_URL}") SET(BOX_PS_SOURCE_DIR "${THIRD_PARTY_PATH}/box_ps") diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index c553f2099747a..4f5b60acc2fb7 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -242,21 +242,6 @@ class DatasetImpl : public Dataset { virtual void SetFleetSendSleepSeconds(int seconds); virtual std::vector& GetInputRecord() { return input_records_; } - virtual std::set GetSlotsIdx( - const std::set& str_slots) { - std::set slots_idx; - - auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); - for (int i = 0; i < multi_slot_desc.slots_size(); ++i) { - std::string cur_slot = multi_slot_desc.slots(i).name(); - if (str_slots.find(cur_slot) != str_slots.end()) { - slots_idx.insert(i); - } - } - - return slots_idx; - } - protected: virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg); @@ -376,7 +361,7 @@ class PadBoxSlotDataset : public DatasetImpl { virtual int64_t GetShuffleDataSize() { return input_records_.size(); } // merge ins from multiple sources and unroll virtual void UnrollInstance(); - + virtual void ReceiveSuffleData(const int client_id, const char* msg, int len); // pre load virtual void LoadIndexIntoMemory() {} @@ -387,15 +372,30 @@ class PadBoxSlotDataset : public DatasetImpl { // shuffle data virtual void ShuffleData(int thread_num = -1); - public: - virtual void ReceiveSuffleData(const int client_id, const char* msg, int len); - public: void SetPSAgent(boxps::PSAgentBase* agent) { p_agent_ = agent; } boxps::PSAgentBase* GetPSAgent(void) { return p_agent_; } double GetReadInsTime(void) { return max_read_ins_span_; } double GetOtherTime(void) { return other_timer_.ElapsedSec(); } double GetMergeTime(void) { return max_merge_ins_span_; } + // aucrunner + std::set GetSlotsIdx(const std::set& str_slots) { + std::set slots_idx; + uint16_t idx = 0; + auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); + for (int i = 0; i < multi_slot_desc.slots_size(); ++i) { + auto slot = multi_slot_desc.slots(i); + if (!slot.is_used() || slot.type().at(0) != 'u') { + continue; + } + if (str_slots.find(slot.name()) != str_slots.end()) { + slots_idx.insert(idx); + } + ++idx; + } + + return slots_idx; + } protected: void MergeInsKeys(const Channel& in); @@ -437,6 +437,7 @@ class InputTableDataset : public PadBoxSlotDataset { index_filelist_ = filelist; } virtual void LoadIndexIntoMemory(); + private: std::vector index_filelist_; }; diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index a4b5eb7156088..fc7b040bcb88a 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -1119,27 +1119,20 @@ class BoxWrapper { random_ins_pool_list[i].Resize(pool_size); } - std::unordered_set slot_set; + slot_eval_set_.clear(); for (size_t i = 0; i < slot_eval.size(); ++i) { for (const auto& slot : slot_eval[i]) { - slot_set.insert(slot); + slot_eval_set_.insert(slot); } } - for (size_t i = 0; i < slot_list.size(); ++i) { - if (slot_set.find(slot_list[i]) != slot_set.end()) { - slot_index_to_replace_.insert(static_cast(i)); - } - } - for (int i = 0; i < auc_runner_thread_num_; ++i) { - random_ins_pool_list[i].SetReplacedSlots(slot_index_to_replace_); - } + VLOG(0) << "AucRunner configuration: thread number[" << thread_num << "], pool size[" << pool_size << "], runner_group[" << phase_num_ - << "]"; - VLOG(0) << "Slots that need to be evaluated:"; - for (auto e : slot_index_to_replace_) { - VLOG(0) << e << ": " << slot_list[e]; - } + << "], eval size:[" << slot_eval_set_.size() << "]"; + // VLOG(0) << "Slots that need to be evaluated:"; + // for (auto e : slot_index_to_replace_) { + // VLOG(0) << e << ": " << slot_list[e]; + // } } void GetRandomReplace(std::vector* records); void PostUpdate(); @@ -1184,15 +1177,23 @@ class BoxWrapper { void RecordReplaceBack(std::vector* records, const std::set& slots); + // aucrunner + void SetReplacedSlots(const std::set& slot_index_to_replace) { + for (int i = 0; i < auc_runner_thread_num_; ++i) { + random_ins_pool_list[i].SetReplacedSlots(slot_index_to_replace); + } + } + const std::set& GetEvalSlotSet() { return slot_eval_set_; } + private: int mode_ = 0; // 0 means train/test 1 means auc_runner int auc_runner_thread_num_ = 1; bool init_done_ = false; paddle::framework::Channel pass_done_semi_; - std::set slot_index_to_replace_; std::vector random_ins_pool_list; std::mutex mutex4random_pool_; + std::set slot_eval_set_; }; /** * @brief file mgr @@ -1250,7 +1251,23 @@ class BoxHelper { } #endif } - +#ifdef PADDLE_WITH_BOX_PS + void LoadAucRunnerData(PadBoxSlotDataset* dataset, + boxps::PSAgentBase* agent) { + auto box_ptr = BoxWrapper::GetInstance(); + // init random pool slots replace + static bool slot_init = false; + if (!slot_init) { + slot_init = true; + auto slots_set = dataset->GetSlotsIdx(box_ptr->GetEvalSlotSet()); + box_ptr->SetReplacedSlots(slots_set); + } + box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum()); + auto& records = dataset->GetInputRecord(); + box_ptr->PushAucRunnerResource(records.size()); + box_ptr->GetRandomReplace(&records); + } +#endif void ReadData2Memory() { platform::Timer timer; VLOG(3) << "Begin ReadData2Memory(), dataset[" << dataset_ << "]"; @@ -1287,10 +1304,7 @@ class BoxHelper { timer.Start(); // auc runner if (box_ptr->Mode() == 1) { - box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum()); - auto& records = dataset->GetInputRecord(); - box_ptr->PushAucRunnerResource(records.size()); - box_ptr->GetRandomReplace(&records); + LoadAucRunnerData(dataset, agent); } box_ptr->EndFeedPass(agent); #endif @@ -1350,10 +1364,7 @@ class BoxHelper { auto box_ptr = BoxWrapper::GetInstance(); // auc runner if (box_ptr->Mode() == 1) { - box_ptr->AddReplaceFeasign(agent, box_ptr->GetFeedpassThreadNum()); - auto& records = dataset->GetInputRecord(); - box_ptr->PushAucRunnerResource(records.size()); - box_ptr->GetRandomReplace(&records); + LoadAucRunnerData(dataset, agent); } box_ptr->EndFeedPass(agent); timer.Pause(); diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc index 1215c0a9fb741..a78219dc5a32c 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc @@ -39,6 +39,8 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { const size_t num_inputs = ins_dims.size(); std::vector outs_dims; outs_dims.resize(num_inputs); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool clk_filter = ctx->Attrs().Get("clk_filter"); // need filter quant_ratio more than zero if (ctx->Attrs().Get("need_filter")) { @@ -66,7 +68,7 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { for (size_t i = 0; i < num_inputs; ++i) { const auto dims = ins_dims[i]; int rank = dims.size(); - if (ctx->Attrs().Get("use_cvm")) { + if (use_cvm) { PADDLE_ENFORCE_GT( dims[rank - 1], 2, "Shape error in %lu id, the last dimension(embedding) of the " @@ -75,8 +77,12 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel { } // input lod is not accessible here std::vector out_dim; - if (ctx->Attrs().Get("use_cvm")) { - out_dim = {-1, dims[rank - 1]}; + if (use_cvm) { + if (clk_filter) { + out_dim = {-1, dims[rank - 1] - 1}; + } else { + out_dim = {-1, dims[rank - 1]}; + } } else { out_dim = {-1, dims[rank - 1] - cvm_offset}; } @@ -122,6 +128,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("threshold", "(float, default 0.96)").SetDefault(0.96); AddAttr("cvm_offset", "(int, default 2)").SetDefault(2); AddAttr("quant_ratio", "(int, default 128)").SetDefault(0); + AddAttr("clk_filter", "(bool, default false)").SetDefault(false); AddComment(R"DOC( Fuse multiple pairs of Sequence Pool and CVM Operator. @@ -139,6 +146,8 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputsDim("X"); auto cvm_dims = ctx->GetInputDim("CVM"); const int cvm_offset = ctx->Attrs().Get("cvm_offset"); + bool use_cvm = ctx->Attrs().Get("use_cvm"); + bool clk_filter = ctx->Attrs().Get("clk_filter"); PADDLE_ENFORCE_EQ( cvm_dims.size(), 2, @@ -151,9 +160,13 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel { "The rank of output grad must equal to Input(X). But " "received: input rank %u, input shape [%s].", og_dims[i].size(), og_dims[i])); - if (ctx->Attrs().Get("use_cvm")) { + if (use_cvm) { + auto o_dim = og_dims[i][og_dims[i].size() - 1]; + if (clk_filter) { // filter clk need + 1 + o_dim = o_dim + 1; + } PADDLE_ENFORCE_EQ( - og_dims[i][og_dims[i].size() - 1], x_dims[i][og_dims[i].size() - 1], + o_dim, x_dims[i][og_dims[i].size() - 1], platform::errors::InvalidArgument( "The dimension mismatch between Input(OUT@GRAD) and " "Input(X). Received Input(OUT@GRAD): input rank %u, " diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu index 1827b1dfc1195..62db66dd3c278 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu @@ -142,6 +142,27 @@ __global__ void FusedCVMKernelWithCVM(const size_t N, T **output_values, } } } +// join only need show input +template +__global__ void FusedCVMKernelWithShow(const size_t N, T **output_values, + T **seqpool_output_values, + const int batch_size, + const int embedding_size, + const int noclk_embedding_size) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / noclk_embedding_size; + int offset = i % noclk_embedding_size; + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + if (offset == 0) { // show + *(output_values[x] + y * noclk_embedding_size) = + log(*(seqpool_output_values[x] + y * embedding_size) + 1); + } else { // skip click offset + 1 + *(output_values[x] + y * noclk_embedding_size + offset) = + *(seqpool_output_values[x] + y * embedding_size + offset + 1); + } + } +} // update not need show click input template __global__ void FusedCVMKernelNoCVM(const size_t N, T **output_values, @@ -170,7 +191,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, const int slot_num, const int embedding_size, const float padding_value, const bool use_cvm, const int cvm_offset, float need_filter, float show_coeff, - float clk_coeff, float threshold, const int quant_ratio) { + float clk_coeff, float threshold, const int quant_ratio, + const bool clk_filter) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) @@ -221,9 +243,18 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place, } // second log if (use_cvm) { - FusedCVMKernelWithCVM<<>>( - N, gpu_output_values, gpu_seqpool_output_values, batch_size, - embedding_size, cvm_offset); + if (clk_filter) { // skip click + N = static_cast(batch_size * slot_num * (embedding_size - 1)); + FusedCVMKernelWithShow<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, embedding_size - 1); + } else { + FusedCVMKernelWithCVM<<>>(N, gpu_output_values, + gpu_seqpool_output_values, batch_size, + embedding_size, cvm_offset); + } } else { // not need show click input N = static_cast(batch_size * slot_num * @@ -256,6 +287,30 @@ __global__ void FusedSeqpoolCVMGradKernelWithCVM( } } } +// join only show not has click +template +__global__ void FusedSeqpoolCVMGradKernelWithShow( + const size_t N, T **out_grads_values, T **in_grads_values, T **cvm_values, + size_t **lods_values, const int batch_size, const int embedding_size, + const int cvm_offset) { + CUDA_KERNEL_LOOP(i, N) { + int key = i / embedding_size; + int offset = i % embedding_size; // embedx offset + int x = key / batch_size; // slot id + int y = key % batch_size; // ins id + + T &val = + (offset < cvm_offset) + ? *(cvm_values[x] + y * cvm_offset + offset) + : *(out_grads_values[x] + y * (embedding_size - 1) + offset - 1); + + auto &start = *(lods_values[x] + y); + auto &end = *(lods_values[x] + y + 1); + for (auto k = start; k < end; ++k) { + *(in_grads_values[x] + k * embedding_size + offset) = val; + } + } +} // update grad template __global__ void FusedSeqpoolCVMGradKernelNoCVM( @@ -288,7 +343,7 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place, const std::vector &lods, const int batch_size, const int slot_num, const int embedding_size, const bool use_cvm, - const int cvm_offset) { + const int cvm_offset, const bool clk_filter) { auto stream = dynamic_cast( platform::DeviceContextPool::Instance().Get( BOOST_GET_CONST(platform::CUDAPlace, place))) @@ -320,11 +375,18 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place, size_t N = static_cast(batch_size * slot_num * embedding_size); if (use_cvm) { - // join grad - FusedSeqpoolCVMGradKernelWithCVM<<>>( - N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, - lods_values, batch_size, embedding_size, cvm_offset); + if (clk_filter) { + FusedSeqpoolCVMGradKernelWithShow<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } else { + // join grad + FusedSeqpoolCVMGradKernelWithCVM<<>>( + N, gpu_out_grads_values, gpu_in_grads_values, gpu_cvm_values, + lods_values, batch_size, embedding_size, cvm_offset); + } } else { // update grad FusedSeqpoolCVMGradKernelNoCVM<< { float threshold = ctx.Attr("threshold"); const int cvm_offset = ctx.Attr("cvm_offset"); const int quant_ratio = ctx.Attr("quant_ratio"); + bool clk_filter = ctx.Attr("clk_filter"); int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0]; int batch_size = -1; @@ -376,7 +439,11 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { input_data[i] = reinterpret_cast(input->data()); auto *output = outputs[i]; if (use_cvm) { - output->Resize({batch_size, embedding_size}); + if (clk_filter) { + output->Resize({batch_size, embedding_size - 1}); + } else { + output->Resize({batch_size, embedding_size}); + } } else { output->Resize({batch_size, embedding_size - cvm_offset}); } @@ -391,7 +458,8 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel { FusedSeqpoolCVM(ctx.GetPlace(), input_data, output_data, seqpool_output_data, lods_data, batch_size, slot_size, embedding_size, padding_value, use_cvm, cvm_offset, - need_filter, show_coeff, clk_coeff, threshold, quant_ratio); + need_filter, show_coeff, clk_coeff, threshold, quant_ratio, + clk_filter); } }; @@ -406,6 +474,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel { std::string pooltype = ctx.Attr("pooltype"); auto use_cvm = ctx.Attr("use_cvm"); const int cvm_offset = ctx.Attr("cvm_offset"); + bool clk_filter = ctx.Attr("clk_filter"); const auto slot_size = in_grads.size(); std::vector out_grads_data(slot_size); @@ -438,7 +507,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel { } FusedSeqpoolCVMGrad(ctx.GetPlace(), out_grads_data, in_grads_data, cvm_data, lods_data, batch_size, slot_size, embedding_size, - use_cvm, cvm_offset); + use_cvm, cvm_offset, clk_filter); } }; diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index 032c991baf457..78ec471f0da87 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -1531,7 +1531,8 @@ def fused_seqpool_cvm(input, clk_coeff=1.0, threshold=0.96, cvm_offset=2, - quant_ratio=0): + quant_ratio=0, + clk_filter=False): """ **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now. :attr:`input`. @@ -1583,7 +1584,8 @@ def fused_seqpool_cvm(input, "show_coeff": show_coeff, "clk_coeff": clk_coeff, "threshold": threshold, - "quant_ratio": quant_ratio + "quant_ratio": quant_ratio, + "clk_filter": clk_filter }) return outs