Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fused_seqpool_cvm_with_conv_op support feedroiq_dxd_fusion #49

Merged
merged 5 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ if (WITH_BOX_PS OR WITH_XPU_KP)
CACHE STRING "" FORCE)
#"https://klx-sdk-release-public.su.bcebos.com/xdnn/release/2.6.0.1/${XPU_XDNN_DIR_NAME}.tar.gz"
set(XPU_XDNN_URL
"https://klx-sdk-release-public.su.bcebos.com/xdnn_train/dev/paddlebox/20240125/${XPU_XDNN_DIR_NAME}.tar.gz"
"https://klx-sdk-release-public.su.bcebos.com/xdnn_train/dev/paddlebox/20240220/${XPU_XDNN_DIR_NAME}.tar.gz"
CACHE STRING "" FORCE)
set(SCALOPUS_URL
"https://klx-sdk-release-public.su.bcebos.com/xdnn_train/dev/paddlebox/20230306/scalopus.tar.gz"
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/fused/fused_concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,15 @@ class FusedConcatOp : public framework::OperatorWithKernel {
"Input length more than zero %d.", length));

auto ins_dims = ctx->GetInputsDim("X");
int batch_size = ins_dims[0][0];
const int input_nums = ins_dims.size();
PADDLE_ENFORCE_GT(input_nums, 1UL,
platform::errors::InvalidArgument(
"Input tensors count should be greater than 0, "
"but received value is %d.",
ins_dims.size()));
std::vector<int64_t> out_dim;
out_dim = {-1, length * input_nums};
out_dim = {batch_size, length * input_nums};
ctx->SetOutputDim("Out", phi::make_ddim(out_dim));
}

Expand Down
17 changes: 12 additions & 5 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class FusedSeqpoolCVMOpWithConv : public framework::OperatorWithKernel {
outs_dims.resize(num_inputs);
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
bool show_filter = ctx->Attrs().Get<bool>("show_filter");
const int embedx_concate_size = ctx->Attrs().Get<int>("embedx_concate_size");

PADDLE_ENFORCE_GT(num_inputs, 0UL,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -65,12 +66,12 @@ class FusedSeqpoolCVMOpWithConv : public framework::OperatorWithKernel {
std::vector<int64_t> out_dim;
if (use_cvm) {
if (show_filter) {
out_dim = {-1, dims[rank - 1] - 1};
out_dim = {-1, (dims[rank - 1] - 1) * embedx_concate_size};
} else {
out_dim = {-1, dims[rank - 1]};
out_dim = {-1, dims[rank - 1] * embedx_concate_size};
}
} else {
out_dim = {-1, dims[rank - 1] - cvm_offset};
out_dim = {-1, (dims[rank - 1] - cvm_offset) * embedx_concate_size};
}
outs_dims[i] = phi::make_ddim(out_dim);
}
Expand Down Expand Up @@ -108,8 +109,13 @@ class FusedSeqpoolCVMOpWithConvMaker : public framework::OpProtoAndCheckerMaker
"(float, default 0.0) The value to pad for empty sequence.")
.SetDefault(0.0);
AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
AddAttr<bool>("need_filter", "(bool, default false)").SetDefault(false);
AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
AddAttr<int>("cvm_offset", "(int, default 3)").SetDefault(3);
AddAttr<bool>("show_filter", "(bool, default false)").SetDefault(false);
AddAttr<int>("embedx_concate_size", "(int, default 1)").SetDefault(1);

AddComment(R"DOC(
Fuse multiple pairs of Sequence Pool and CVM Operator.
Expand All @@ -129,6 +135,7 @@ class FusedSeqpoolCVMGradOpWithConv : public framework::OperatorWithKernel {
const int cvm_offset = ctx->Attrs().Get<int>("cvm_offset");
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
bool show_filter = ctx->Attrs().Get<bool>("show_filter");
const int embedx_concate_size = ctx->Attrs().Get<int>("embedx_concate_size");

PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2,
Expand All @@ -144,7 +151,7 @@ class FusedSeqpoolCVMGradOpWithConv : public framework::OperatorWithKernel {
if (use_cvm) {
auto o_dim = og_dims[i][og_dims[i].size() - 1];
if (show_filter) {
o_dim += 1;
o_dim = o_dim / embedx_concate_size + 1;
}
PADDLE_ENFORCE_EQ(
o_dim, x_dims[i][og_dims[i].size() - 1],
Expand All @@ -157,7 +164,7 @@ class FusedSeqpoolCVMGradOpWithConv : public framework::OperatorWithKernel {
} else {
PADDLE_ENFORCE_EQ(
og_dims[i][og_dims[i].size() - 1],
x_dims[i][og_dims[i].size() - 1] - cvm_offset,
(x_dims[i][og_dims[i].size() - 1] - cvm_offset) * embedx_concate_size,
platform::errors::InvalidArgument(
"The dimension mismatch between Input(OUT@GRAD) and "
"Input(X). Received Input(OUT@GRAD): input rank %u, "
Expand Down
25 changes: 19 additions & 6 deletions paddle/fluid/operators/fused/fused_seqpool_cvm_with_conv_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ class FusedSeqpoolCVMWithConvOpXPUKernel : public framework::OpKernel<T> {
auto out = ctx.MultiOutput<LoDTensor>("Out");
auto padding_value = ctx.Attr<float>("pad_value");
bool use_cvm = ctx.Attr<bool>("use_cvm");
bool need_filter = ctx.Attr<bool>("need_filter");
float show_coeff = ctx.Attr<float>("show_coeff");
float clk_coeff = ctx.Attr<float>("clk_coeff");
float threshold = ctx.Attr<float>("threshold");
auto cvm_offset = ctx.Attr<int>("cvm_offset");
auto show_filter = ctx.Attr<bool>("show_filter");
const int embedx_concate_size = ctx.Attr<int>("embedx_concate_size");

auto x0_lod = ins[0]->lod();
auto x0_dims = ins[0]->dims();
Expand All @@ -71,13 +76,13 @@ class FusedSeqpoolCVMWithConvOpXPUKernel : public framework::OpKernel<T> {
phi::Place l3_place = ctx.template device_context<DeviceContext>().GetL3Place();
int w = ins[0]->numel() / x0_dims[0];
if(use_cvm) {
if (show_filter) w = w - 1;
PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
if (show_filter) w = (w - 1);
PADDLE_ENFORCE_EQ(y_dims[1] % (w * embedx_concate_size), 0,
paddle::platform::errors::InvalidArgument(
"The output of dims[1] should be dividable of w"));
}
else{
PADDLE_ENFORCE_EQ(y_dims[1] % (w - cvm_offset), 0,
PADDLE_ENFORCE_EQ(y_dims[1] % ((w - cvm_offset) * embedx_concate_size), 0,
paddle::platform::errors::InvalidArgument(
"The output of dims[1] should be dividable of (w-2)"));
}
Expand Down Expand Up @@ -112,9 +117,14 @@ class FusedSeqpoolCVMWithConvOpXPUKernel : public framework::OpKernel<T> {
x0_dims[1],
slot_num,
use_cvm,
need_filter,
show_coeff,
clk_coeff,
threshold,
show_filter,
padding_value,
cvm_offset);
cvm_offset,
embedx_concate_size);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The sequence_sum_pool_cvm_with_conv XPU OP return wrong value[%d %s]",
Expand All @@ -137,9 +147,11 @@ class FusedSeqpoolCVMWithConvGradOpXPUKernel : public framework::OpKernel<T> {
auto xs = ctx.MultiInput<LoDTensor>("X");
const framework::Tensor* cvm = ctx.Input<framework::Tensor>("CVM");
auto dxs = ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
auto use_cvm = ctx.Attr<bool>("use_cvm");//TODO:
auto use_cvm = ctx.Attr<bool>("use_cvm");
bool show_filter = ctx.Attr<bool>("show_filter");
auto cvm_offset = ctx.Attr<int>("cvm_offset");
const int embedx_concate_size = ctx.Attr<int>("embedx_concate_size");

int slot_num = dxs.size();
auto xpu_context = ctx.template device_context<DeviceContext>().x_context();
auto place = ctx.GetPlace();
Expand Down Expand Up @@ -194,7 +206,8 @@ class FusedSeqpoolCVMWithConvGradOpXPUKernel : public framework::OpKernel<T> {
show_filter,
item_size,
batch_size,
slot_num);
slot_num,
embedx_concate_size);
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The sequence_sum_pool_cvm_with_conv_grad XPU OP return wrong value[%d %s]",
Expand Down
15 changes: 13 additions & 2 deletions python/paddle/fluid/contrib/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,8 +1910,13 @@ def fused_seqpool_cvm_with_conv(input,
cvm,
pad_value=0.0,
use_cvm=True,
need_filter=False,
show_coeff=0.2,
clk_coeff=1.0,
threshold=0.96,
show_filter=False,
cvm_offset=3):
cvm_offset=3,
embedx_concate_size=1):
"""
**Notes: The Op only receives List of LoDTensor as input, only support SUM pooling now.
:attr:`input`.
Expand All @@ -1921,6 +1926,7 @@ def fused_seqpool_cvm_with_conv(input,
cvm(Variable): cvm Variable.
pad_value(float): padding value of sequence pool.
use_cvm(bool): use cvm or not.
embedx_concate_size(uint): is expand slot's feasign into matrix
Returns:
Variable|list of Variable: The tensor variable storing sequence pool and cvm
of input.
Expand Down Expand Up @@ -1955,7 +1961,12 @@ def fused_seqpool_cvm_with_conv(input,
"pad_value": pad_value,
"use_cvm": use_cvm,
"cvm_offset": cvm_offset,
"show_filter": show_filter
"need_filter": need_filter,
"show_coeff": show_coeff,
"clk_coeff": clk_coeff,
"threshold": threshold,
"show_filter": show_filter,
"embedx_concate_size": embedx_concate_size,
})

return outs
Expand Down