Merge pull request PaddlePaddle#37 from xymyeah/bscvrq_opt
support fused svm and slice for bscvrq
xymyeah committed Dec 26, 2023
2 parents bb54a08 + ce94035, commit 44f6191
Showing 5 changed files with 29 additions and 7 deletions.
10 changes: 10 additions & 0 deletions paddle/fluid/framework/device_worker.cc
@@ -183,6 +183,10 @@ std::string PrintLodTensor(const phi::DenseTensor* tensor,
              proto::VarType::FP64) {
     out_val = PrintLodTensorType<double>(
         tensor, start, end, separator, need_leading_separator);
+  } else if (framework::TransToProtoVarType(tensor->dtype()) ==
+             proto::VarType::BOOL) {
+    out_val = PrintLodTensorType<bool>(
+        tensor, start, end, separator, need_leading_separator);
   } else {
     out_val = "unsupported type";
   }
@@ -206,6 +210,10 @@ void PrintLodTensor(const phi::DenseTensor* tensor,
              proto::VarType::FP64) {
     PrintLodTensorType<double>(
         tensor, start, end, out_val, separator, need_leading_separator);
+  } else if (framework::TransToProtoVarType(tensor->dtype()) ==
+             proto::VarType::BOOL) {
+    PrintLodTensorType<bool>(
+        tensor, start, end, out_val, separator, need_leading_separator);
   } else {
     out_val += "unsupported type";
   }
@@ -512,6 +520,8 @@ void PrintLodTensor(const Tensor* tensor,
     PrintLodTensorFmtType<int, int>(tensor, start, end, ":%d", out);
   } else if (dtype == proto::VarType::INT16) {
     PrintLodTensorFmtType<int16_t, int16_t>(tensor, start, end, ":%d", out);
+  } else if (dtype == proto::VarType::BOOL) {
+    PrintLodTensorFmtType<bool, bool>(tensor, start, end, ":%d", out);
   } else {
     out->append("unsupported type");
   }
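The three hunks above extend the dtype dispatch in PrintLodTensor so that BOOL tensors take the same integer formatting path as INT16 (":%d"), printing each element as 0 or 1. A minimal Python sketch of the expected output format follows; the helper below is an illustration of the separator-joined result, not Paddle's actual C++ implementation.

import numpy as np

def print_lod_tensor(values, separator=":", need_leading_separator=False):
    # Sketch only: mimics PrintLodTensorType's separator-joined output
    # for a bool tensor; True/False are printed as 1/0, like ":%d".
    out = ""
    for i, v in enumerate(np.asarray(values).astype(np.int64)):
        if i > 0 or need_leading_separator:
            out += separator
        out += str(v)
    return out

print(print_lod_tensor(np.array([True, False, True])))  # prints 1:0:1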
1 change: 1 addition & 0 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -142,6 +142,7 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("embed_thres_size", "(int, default 0)").SetDefault(0);
     AddAttr<int>("embedx_concate_size", "(int, default 1)").SetDefault(1);
     AddAttr<bool>("embedx_concate_filter", "(bool, default false)").SetDefault(false);
+    AddAttr<bool>("fix_ctr_to_click", "(bool, default false)").SetDefault(false);
 
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
14 changes: 10 additions & 4 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
@@ -233,7 +233,7 @@ __global__ void FusedCVMKernelWithCVM(const size_t N, T **output_values,
                                       T **seqpool_output_values,
                                       const int batch_size,
                                       const int embedding_size,
-                                      const int cvm_offset) {
+                                      const int cvm_offset, bool fix_ctr_to_click) {
   CUDA_KERNEL_LOOP(i, N) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
@@ -246,6 +246,10 @@
       *out = log(in[0] + 1);
     } else if (offset == 1) {  // ctr = log(click + 1) - log(show + 1)
       *out = log(in[1] + 1) - log(in[0] + 1);
+      // fix_ctr_to_click: overwrite the ctr slot with log(click + 1)
+      if (fix_ctr_to_click) {
+        *out = log(in[1] + 1);
+      }
     } else {
       *out = in[offset];
     }
@@ -352,7 +356,7 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
                      float clk_coeff, float threshold, float embed_threshold,
                      const int quant_ratio, const bool clk_filter,
                      const int embed_thres_size, const int embedx_concate_size,
-                     bool embedx_concate_filter) {
+                     bool embedx_concate_filter, bool fix_ctr_to_click) {
   auto stream = dynamic_cast<phi::GPUContext*>(
                     platform::DeviceContextPool::Instance().Get(place))
                     ->stream();
@@ -433,7 +437,7 @@
       FusedCVMKernelWithCVM<<<GET_BLOCK(N), PADDLE_CUDA_NUM_THREADS, 0,
                               stream>>>(N, gpu_output_values,
                                         gpu_seqpool_output_values, batch_size,
-                                        embedding_size, cvm_offset);
+                                        embedding_size, cvm_offset, fix_ctr_to_click);
     }
   } else {
     // not need show click input
@@ -690,6 +694,7 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
     const int embed_thres_size = ctx.Attr<int>("embed_thres_size");
     const int embedx_concate_size = ctx.Attr<int>("embedx_concate_size");
     bool embedx_concate_filter = ctx.Attr<bool>("embedx_concate_filter");
+    bool fix_ctr_to_click = ctx.Attr<bool>("fix_ctr_to_click");
 
     framework::GPULodVector gpu_lods[slot_size];
     auto place = ctx.GetPlace();
@@ -737,7 +742,8 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
                     embedding_size, padding_value, use_cvm, cvm_offset,
                     need_filter, embed_threshold_filter, show_coeff, clk_coeff,
                     threshold, embed_threshold, quant_ratio, clk_filter,
-                    embed_thres_size, embedx_concate_size, embedx_concate_filter);
+                    embed_thres_size, embedx_concate_size, embedx_concate_filter,
+                    fix_ctr_to_click);
   }
 };
 
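For context, FusedCVMKernelWithCVM rewrites the leading cvm_offset columns of each pooled row: column 0 becomes log(show + 1), and column 1 ordinarily becomes the ctr term log(click + 1) - log(show + 1); with the new flag set, the ctr slot instead carries log(click + 1) directly. The NumPy sketch below is an approximation of that transform for illustration, not the kernel itself.

import numpy as np

def cvm_transform(pooled, fix_ctr_to_click=False):
    # pooled: [batch, embedding_size]; columns 0/1 are pooled show/click.
    out = pooled.copy()
    show, click = pooled[:, 0], pooled[:, 1]
    out[:, 0] = np.log(show + 1)
    if fix_ctr_to_click:
        out[:, 1] = np.log(click + 1)                      # new behavior
    else:
        out[:, 1] = np.log(click + 1) - np.log(show + 1)   # default ctr
    return out

pooled = np.array([[9.0, 3.0, 0.5],
                   [4.0, 1.0, -0.2]])
print(cvm_transform(pooled))
print(cvm_transform(pooled, fix_ctr_to_click=True))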
5 changes: 4 additions & 1 deletion paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc
@@ -79,6 +79,7 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel<T> {
     bool embed_threshold_filter = ctx.Attr<bool>("embed_threshold_filter");
     float embed_threshold = ctx.Attr<float>("embed_threshold");
     int embed_thres_size = ctx.Attr<int>("embed_thres_size");
+    bool fix_ctr_to_click = ctx.Attr<bool>("fix_ctr_to_click");
 
     auto x0_lod = ins[0]->lod();
     auto x0_dims = ins[0]->dims();
@@ -153,7 +154,8 @@
                                        cvm_offset,
                                        embed_threshold_filter,
                                        embed_threshold,
-                                       embed_thres_size);
+                                       embed_thres_size,
+                                       fix_ctr_to_click);
     PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                       platform::errors::External(
                           "The sequence_sum_pool_cvm XPU OP return wrong value[%d %s]",
@@ -178,6 +180,7 @@ class FusedSeqpoolCVMGradOpXPUKernel : public framework::OpKernel<T> {
     bool clk_filter = ctx.Attr<bool>("clk_filter");
     auto cvm_offset = ctx.Attr<int>("cvm_offset");
     int embed_thres_size = ctx.Attr<int>("embed_thres_size");
+
     int slot_num = dxs.size();
     auto xpu_context = ctx.template device_context<DeviceContext>().x_context();
     auto place = ctx.GetPlace();
6 changes: 4 additions & 2 deletions python/paddle/fluid/contrib/layers/nn.py
@@ -1759,7 +1759,8 @@ def fused_seqpool_cvm(input,
                       clk_filter=False,
                       embed_thres_size=0,
                       embedx_concate_size=1,
-                      embedx_concate_filter=False):
+                      embedx_concate_filter=False,
+                      fix_ctr_to_click=False):
     """
     **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now.
     :attr:`input`.
@@ -1818,7 +1819,8 @@ def fused_seqpool_cvm(input,
             "clk_filter": clk_filter,
             "embed_thres_size": embed_thres_size,
             "embedx_concate_size": embedx_concate_size,
-            "embedx_concate_filter": embedx_concate_filter
+            "embedx_concate_filter": embedx_concate_filter,
+            "fix_ctr_to_click": fix_ctr_to_click
         })
 
     return outs
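A hedged usage sketch of the extended Python wrapper: the slot names, shapes, and show/click layout below are assumptions for illustration; only the fix_ctr_to_click flag itself comes from this commit.

import paddle.fluid as fluid
from paddle.fluid.contrib.layers import fused_seqpool_cvm

# Hypothetical inputs: each slot is a LoDTensor of pooled embeddings, and
# cvm carries the per-instance show/click pair consumed by the CVM step.
cvm = fluid.data(name="cvm", shape=[-1, 2], dtype="float32", lod_level=1)
slots = [
    fluid.data(name="slot_%d" % i, shape=[-1, 11], dtype="float32", lod_level=1)
    for i in range(3)
]
outs = fused_seqpool_cvm(
    input=slots,
    pool_type="sum",        # the op only supports SUM pooling
    cvm=cvm,
    use_cvm=True,
    fix_ctr_to_click=True,  # new flag: ctr slot becomes log(click + 1)
)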
