From ce94035f185f5c832ff5b7e61f99052831e57ff4 Mon Sep 17 00:00:00 2001
From: xiayanming
Date: Fri, 22 Dec 2023 19:20:26 +0800
Subject: [PATCH] support fused cvm and slice for bscvrq

---
 paddle/fluid/framework/device_worker.cc            | 10 ++++++++++
 .../fluid/operators/fused/fused_seqpool_cvm_op.cc  |  1 +
 .../fluid/operators/fused/fused_seqpool_cvm_op.cu  | 14 ++++++++++----
 .../operators/fused/fused_seqpool_cvm_op_xpu.cc    |  5 ++++-
 python/paddle/fluid/contrib/layers/nn.py           |  6 ++++--
 5 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc
index b6216b9572c6a..d173148718ae8 100644
--- a/paddle/fluid/framework/device_worker.cc
+++ b/paddle/fluid/framework/device_worker.cc
@@ -183,6 +183,10 @@ std::string PrintLodTensor(const phi::DenseTensor* tensor,
              proto::VarType::FP64) {
     out_val = PrintLodTensorType<double>(
         tensor, start, end, separator, need_leading_separator);
+  } else if (framework::TransToProtoVarType(tensor->dtype()) ==
+             proto::VarType::BOOL) {
+    out_val = PrintLodTensorType<bool>(
+        tensor, start, end, separator, need_leading_separator);
   } else {
     out_val = "unsupported type";
   }
@@ -206,6 +210,10 @@ void PrintLodTensor(const phi::DenseTensor* tensor,
              proto::VarType::FP64) {
     PrintLodTensorType<double>(
         tensor, start, end, out_val, separator, need_leading_separator);
+  } else if (framework::TransToProtoVarType(tensor->dtype()) ==
+             proto::VarType::BOOL) {
+    PrintLodTensorType<bool>(
+        tensor, start, end, out_val, separator, need_leading_separator);
   } else {
     out_val += "unsupported type";
   }
@@ -512,6 +520,8 @@ void PrintLodTensor(const Tensor* tensor,
     PrintLodTensorFmtType<int>(tensor, start, end, ":%d", out);
   } else if (dtype == proto::VarType::INT16) {
     PrintLodTensorFmtType<int16_t>(tensor, start, end, ":%d", out);
+  } else if (dtype == proto::VarType::BOOL) {
+    PrintLodTensorFmtType<bool>(tensor, start, end, ":%d", out);
   } else {
     out->append("unsupported type");
   }
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index 1dfbc30d06606..474863cad18b9 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -142,6 +142,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("embed_thres_size", "(int, default 0)").SetDefault(0);
     AddAttr<int>("embedx_concate_size", "(int, default 1)").SetDefault(1);
     AddAttr<bool>("embedx_concate_filter", "(bool, default false)").SetDefault(false);
+    AddAttr<bool>("fix_ctr_to_click", "(bool, default false)").SetDefault(false);
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
index 58d9478a7b77f..d7ba888aa1dd5 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
@@ -233,7 +233,7 @@ __global__ void FusedCVMKernelWithCVM(const size_t N, T **output_values,
                                       T **seqpool_output_values,
                                       const int batch_size,
                                       const int embedding_size,
-                                      const int cvm_offset) {
+                                      const int cvm_offset, bool fix_ctr_to_click) {
   CUDA_KERNEL_LOOP(i, N) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
@@ -246,6 +246,10 @@
       *out = log(in[0] + 1);
     } else if (offset == 1) {  // ctr = log(click + 1) - log(show + 1)
       *out = log(in[1] + 1) - log(in[0] + 1);
+      // fix_ctr_to_click: output log(click + 1) instead of ctr
+      if (fix_ctr_to_click) {
+        *out = log(in[1] + 1);
+      }
     } else {
       *out = in[offset];
     }
@@ -352,7 +356,7 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
                      float clk_coeff, float threshold, float embed_threshold,
                      const int quant_ratio, const bool clk_filter,
                      const int embed_thres_size, const int embedx_concate_size,
-                     bool embedx_concate_filter) {
+                     bool embedx_concate_filter, bool fix_ctr_to_click) {
   auto stream = dynamic_cast<phi::GPUContext *>(
                     platform::DeviceContextPool::Instance().Get(place))
                     ->stream();
@@ -433,7 +437,7 @@
       FusedCVMKernelWithCVM<<<GET_BLOCK(N), CUDA_NUM_THREADS, 0,
                               stream>>>(N, gpu_output_values,
                                         gpu_seqpool_output_values, batch_size,
-                                        embedding_size, cvm_offset);
+                                        embedding_size, cvm_offset, fix_ctr_to_click);
     }
   } else {
     // not need show click input
@@ -690,6 +694,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
     const int embed_thres_size = ctx.Attr<int>("embed_thres_size");
     const int embedx_concate_size = ctx.Attr<int>("embedx_concate_size");
     bool embedx_concate_filter = ctx.Attr<bool>("embedx_concate_filter");
+    bool fix_ctr_to_click = ctx.Attr<bool>("fix_ctr_to_click");
     framework::GPULodVector gpu_lods[slot_size];
     auto place = ctx.GetPlace();
@@ -737,7 +742,8 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
                     embedding_size, padding_value, use_cvm, cvm_offset,
                     need_filter, embed_threshold_filter, show_coeff,
                     clk_coeff, threshold, embed_threshold, quant_ratio, clk_filter,
-                    embed_thres_size, embedx_concate_size, embedx_concate_filter);
+                    embed_thres_size, embedx_concate_size, embedx_concate_filter,
+                    fix_ctr_to_click);
   }
 };
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc
index b38b8aa82be9b..b037c8ca63f02 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc
@@ -79,6 +79,7 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel<T> {
     bool embed_threshold_filter = ctx.Attr<bool>("embed_threshold_filter");
     float embed_threshold = ctx.Attr<float>("embed_threshold");
     int embed_thres_size = ctx.Attr<int>("embed_thres_size");
+    bool fix_ctr_to_click = ctx.Attr<bool>("fix_ctr_to_click");
     auto x0_lod = ins[0]->lod();
     auto x0_dims = ins[0]->dims();
@@ -153,7 +154,8 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel<T> {
                     cvm_offset,
                     embed_threshold_filter,
                     embed_threshold,
-                    embed_thres_size);
+                    embed_thres_size,
+                    fix_ctr_to_click);
     PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                       platform::errors::External(
                           "The sequence_sum_pool_cvm XPU OP return wrong value[%d %s]",
@@ -178,6 +180,7 @@ class FusedSeqpoolCVMGradOpXPUKernel : public framework::OpKernel<T> {
     bool clk_filter = ctx.Attr<bool>("clk_filter");
     auto cvm_offset = ctx.Attr<int>("cvm_offset");
     int embed_thres_size = ctx.Attr<int>("embed_thres_size");
+    int slot_num = dxs.size();
     auto xpu_context = ctx.template device_context<DeviceContext>().x_context();
     auto place = ctx.GetPlace();
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 84b422fe1722e..26b0b4a9f661b 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1759,7 +1759,8 @@ def fused_seqpool_cvm(input,
                       clk_filter=False,
                       embed_thres_size=0,
                       embedx_concate_size=1,
-                      embedx_concate_filter=False):
+                      embedx_concate_filter=False,
+                      fix_ctr_to_click=False):
     """
     **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now.
     :attr:`input`.
@@ -1818,7 +1819,8 @@ def fused_seqpool_cvm(input,
             "clk_filter": clk_filter,
             "embed_thres_size": embed_thres_size,
             "embedx_concate_size": embedx_concate_size,
-            "embedx_concate_filter": embedx_concate_filter
+            "embedx_concate_filter": embedx_concate_filter,
+            "fix_ctr_to_click": fix_ctr_to_click
         })
     return outs
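
---
For illustration only (not part of the patch): the new fix_ctr_to_click
attribute changes only the ctr slot computed by FusedCVMKernelWithCVM.
Below is a minimal NumPy sketch of the per-row transform, assuming
cvm_offset == 2 and the usual [show, click, embedding...] column layout
of the pooled output; names here are illustrative, not Paddle API.

    import numpy as np

    def cvm_transform(row, fix_ctr_to_click=False):
        # Mirrors the use_cvm branch of FusedCVMKernelWithCVM for one pooled row.
        out = np.array(row, dtype=np.float64)
        out[0] = np.log(row[0] + 1)  # show slot: log(show + 1)
        if fix_ctr_to_click:
            out[1] = np.log(row[1] + 1)  # ctr slot rewritten to log(click + 1)
        else:
            out[1] = np.log(row[1] + 1) - np.log(row[0] + 1)  # default ctr
        return out  # remaining embedding columns pass through unchanged

    row = [100.0, 10.0, 0.5, -0.2]  # show, click, then embedding values
    print(cvm_transform(row))                         # default ctr behaviour
    print(cvm_transform(row, fix_ctr_to_click=True))  # with the new flag

From the Python side the flag is simply forwarded as an op attribute, e.g.
fused_seqpool_cvm(input, pool_type, cvm, fix_ctr_to_click=True); the default
False keeps the existing ctr = log(click + 1) - log(show + 1) behaviour.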