
Commit

【Hackathon 5th No.103】 fix the bug in moving fc_mkldnn to phi -part (#59531)

* try to fix the bug in fc_mkldnn

* fix the missing attr bug

* fix the parameters bug

* remove the parameters in pir

* roll back and add attr

* add the scale_out

---------

Co-authored-by: zeroRains <linjunlu@zerorains.com>
zeroRains authored Dec 27, 2023
1 parent 2dfa0f7 commit 5cbf32f
Showing 11 changed files with 93 additions and 190 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.h
@@ -103,7 +103,7 @@ constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";
/// TODO(luotao): Note that this temporal attribute would be deleted after all
/// ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
"@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";
"ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE";

// define some kernel priority
/* Define multiple kernel type fallback order*/
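For context: kAllKernelsMustComputeRuntimeShape tags ops whose kernels compute output shapes themselves at run time, so the framework can skip the usual InferShape step for them; the surrounding "@" is dropped here so the flag uses the same plain spelling as the extra attribute registered in op_compat.yaml below. A minimal, self-contained sketch of how such a flag is typically consulted (OpDesc and HasAttrFlag are illustrative stand-ins, not Paddle APIs):

#include <string>
#include <unordered_set>

constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE";

// Illustrative stand-in for an op description carrying boolean attributes.
struct OpDesc {
  std::unordered_set<std::string> true_attrs;
  bool HasAttrFlag(const std::string& name) const {
    return true_attrs.count(name) > 0;
  }
};

// Ops tagged with the flag defer shape computation to their kernels,
// so the runtime skips InferShape for them.
bool ShouldRunInferShape(const OpDesc& op) {
  return !op.HasAttrFlag(kAllKernelsMustComputeRuntimeShape);
}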
2 changes: 1 addition & 1 deletion paddle/fluid/operators/compat/fc.pbtxt
@@ -27,7 +27,7 @@ extra {
type: BOOLEAN
}
attrs {
name: "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@"
name: "ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE"
type: BOOLEAN
}
attrs {
@@ -31,14 +31,7 @@ class FcElementwiseLayerNormFusePattern
{
{"in_num_col_dims", pat.Attr("in_num_col_dims")},
{"activation_type", pat.Attr("activation_type")},
{"use_mkldnn", pat.Attr("use_mkldnn")},
{"padding_weights", pat.Attr("padding_weights")},
{"use_quantizer", pat.Attr("use_quantizer")},
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
{"scale_in", pat.Attr("scale_in")},
{"scale_weights", pat.Attr("scale_weights")},
{"scale_out", pat.Attr("scale_out")},
{"force_fp32_output", pat.Attr("force_fp32_output")},
});
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &layernorm =
@@ -104,14 +97,7 @@ class FcElementwiseLayerNormFuse2Pattern
{
{"in_num_col_dims", pat.Attr("in_num_col_dims")},
{"activation_type", pat.Attr("activation_type")},
{"use_mkldnn", pat.Attr("use_mkldnn")},
{"padding_weights", pat.Attr("padding_weights")},
{"use_quantizer", pat.Attr("use_quantizer")},
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
{"scale_in", pat.Attr("scale_in")},
{"scale_weights", pat.Attr("scale_weights")},
{"scale_out", pat.Attr("scale_out")},
{"force_fp32_output", pat.Attr("force_fp32_output")},
});
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &layernorm =
49 changes: 9 additions & 40 deletions paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc
@@ -65,32 +65,15 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase<MatmulAddPattern> {
const auto &false_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> bool { return false; });

const auto &fc = res.Op(
paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"use_mkldnn", false_attr},
{"padding_weights", false_attr},
{"use_quantizer", false_attr},
{"mkldnn_data_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return "float32"; })},
{"scale_in",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"scale_weights",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::vector<float> { return {1.0f}; })},
{"scale_out",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"force_fp32_output", false_attr},
}});
const auto &fc =
res.Op(paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"padding_weights", false_attr},
}});
fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
{&res.Tensor("add_out")});
}
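The hunk above only shows the rewrite (ResultPattern) side of MatmulAddPattern. For orientation, here is a rough sketch of what the matching (SourcePattern) half of such a matmul+add pattern looks like with the same pir::drr API visible in this diff; the matmul attribute names and the trailing constraint step are assumptions, not lines from the file:

void operator()(pir::drr::DrrPatternContext *ctx) const {
  pir::drr::SourcePattern pat = ctx->SourcePattern();
  // Match matmul(x, w) followed by add(matmul_out, y).
  const auto &matmul =
      pat.Op(paddle::dialect::MatmulOp::name(),
             {{"transpose_x", pat.Attr("transpose_x")},
              {"transpose_y", pat.Attr("transpose_y")}});
  const auto &add = pat.Op(paddle::dialect::AddOp::name());
  matmul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")});
  add({&pat.Tensor("matmul_out"), &pat.Tensor("y")}, {&pat.Tensor("add_out")});
  // ... shape/transpose constraints, then the ResultPattern shown above
  // replaces the pair with a single pd_op.fc carrying only in_num_col_dims,
  // activation_type and padding_weights.
}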
@@ -105,14 +88,7 @@ class FcWithReluPattern : public pir::drr::DrrPatternBase<FcWithReluPattern> {
{{
{"in_num_col_dims", pat.Attr("in_num_col_dims")},
{"activation_type", pat.Attr("activation_type")},
{"use_mkldnn", pat.Attr("use_mkldnn")},
{"padding_weights", pat.Attr("padding_weights")},
{"use_quantizer", pat.Attr("use_quantizer")},
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
{"scale_in", pat.Attr("scale_in")},
{"scale_weights", pat.Attr("scale_weights")},
{"scale_out", pat.Attr("scale_out")},
{"force_fp32_output", pat.Attr("force_fp32_output")},
}});
fc({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("y")},
{&pat.Tensor("fc_out")});
@@ -133,14 +109,7 @@ class FcWithReluPattern : public pir::drr::DrrPatternBase<FcWithReluPattern> {
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return "relu"; })},
{"use_mkldnn", pat.Attr("use_mkldnn")},
{"padding_weights", pat.Attr("padding_weights")},
{"use_quantizer", pat.Attr("use_quantizer")},
{"mkldnn_data_type", pat.Attr("mkldnn_data_type")},
{"scale_in", pat.Attr("scale_in")},
{"scale_weights", pat.Attr("scale_weights")},
{"scale_out", pat.Attr("scale_out")},
{"force_fp32_output", pat.Attr("force_fp32_output")},
}});
fc_with_relu({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
{&res.Tensor("relu_out")});
105 changes: 27 additions & 78 deletions paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc
@@ -94,32 +94,15 @@ class SqueezeFcFusePattern
const auto &false_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> bool { return false; });

const auto &fc = res.Op(
paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"use_mkldnn", false_attr},
{"padding_weights", false_attr},
{"use_quantizer", false_attr},
{"mkldnn_data_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return "float32"; })},
{"scale_in",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"scale_weights",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::vector<float> { return {1.0f}; })},
{"scale_out",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"force_fp32_output", false_attr},
}});
const auto &fc =
res.Op(paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"padding_weights", false_attr},
}});
fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("add_out")});
}
@@ -248,32 +231,15 @@ class ReshapeFcFusePattern
const auto &false_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> bool { return false; });

const auto &fc = res.Op(
paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"use_mkldnn", false_attr},
{"padding_weights", false_attr},
{"use_quantizer", false_attr},
{"mkldnn_data_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return "float32"; })},
{"scale_in",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"scale_weights",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::vector<float> { return {1.0f}; })},
{"scale_out",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"force_fp32_output", false_attr},
}});
const auto &fc =
res.Op(paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"padding_weights", false_attr},
}});
fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("add_out")});
}
@@ -336,32 +302,15 @@ class FlattenFcFusePattern
const auto &false_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> bool { return false; });

const auto &fc = res.Op(
paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"use_mkldnn", false_attr},
{"padding_weights", false_attr},
{"use_quantizer", false_attr},
{"mkldnn_data_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return "float32"; })},
{"scale_in",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"scale_weights",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::vector<float> { return {1.0f}; })},
{"scale_out",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 1.0f;
})},
{"force_fp32_output", false_attr},
}});
const auto &fc =
res.Op(paddle::dialect::FcOp::name(),
{{
{"in_num_col_dims", in_num_col_dims_attr},
{"activation_type",
res.Attr([](const pir::drr::MatchContext &match_ctx)
-> std::string { return ""; })},
{"padding_weights", false_attr},
}});
fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("add_out")});
}
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/fused_ops.yaml
@@ -122,7 +122,7 @@
data_type : x

- op : fc
args : (Tensor input, Tensor w, Tensor bias, int in_num_col_dims = 1, str activation_type = "", bool use_mkldnn = false, bool padding_weights = false, bool use_quantizer = false, str mkldnn_data_type = "float32", float scale_in = 1.0f, float[] scale_weights = {1.0f}, float scale_out = 1.0f, bool force_fp32_output = false)
args : (Tensor input, Tensor w, Tensor bias, int in_num_col_dims = 1, str activation_type = "", bool padding_weights = false)
output : Tensor(out)
infer_meta :
func : FCInferMeta
6 changes: 1 addition & 5 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -1050,12 +1050,8 @@
bias : Bias
outputs :
out : Out
attrs :
scale_in : Scale_in
scale_weights : Scale_weights
scale_out : Scale_out
extra :
[bool @ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@ = true]
attrs : [bool ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE = true, bool use_mkldnn = false, bool use_quantizer = false, str mkldnn_data_type = "float32", float Scale_in = 1.0f, "float[] Scale_weights = {1.0f}", float Scale_out = 1.0f, bool force_fp32_output = false]

- op : feed
outputs: {out: Out}
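With the oneDNN-specific attributes moved out of the op signature into the extra list above, a kernel that still needs them reads them from the execution context rather than from parameters. A hedged sketch of that lookup, assuming the OneDNNContext::HasDnnAttr / GetDnnAttr accessors used by other phi oneDNN kernels (the header paths and their use for fc here are assumptions, not code from this PR):

#include "paddle/phi/backends/onednn/onednn_context.h"  // assumed header path
#include "paddle/utils/variant.h"  // for PADDLE_GET_CONST (assumed location)

// "Scale_in" matches the extra-attribute spelling registered above; the
// default mirrors the value declared in op_compat.yaml.
float GetScaleInFromExtras(const phi::OneDNNContext &dev_ctx) {
  return dev_ctx.HasDnnAttr("Scale_in")
             ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_in"))
             : 1.0f;
}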
29 changes: 1 addition & 28 deletions paddle/phi/infermeta/fusion.cc
@@ -3425,14 +3425,7 @@ void FCInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const int in_num_col_dims,
const std::string& activation_type,
const bool use_mkldnn,
const bool padding_weights,
const bool use_quantizer,
const std::string& mkldnn_data_type,
const float scale_in,
const std::vector<float>& sclae_weights,
const float scale_out,
const bool force_fp32_output,
MetaTensor* out) {
PADDLE_ENFORCE_GE(
in_num_col_dims,
@@ -3441,15 +3434,7 @@
"The in_num_col_dims is expected to equal or greater than 1. "
"But received the in_num_col_dims is %d. ",
in_num_col_dims));
std::string mkldnn_data_type_list[] = {"float32", "int8", "bfloat16"};
PADDLE_ENFORCE_EQ(
std::find(std::begin(mkldnn_data_type_list),
std::end(mkldnn_data_type_list),
mkldnn_data_type) != std::end(mkldnn_data_type_list),
true,
phi::errors::InvalidArgument("The mkldnn_data_type shoule be [float32, "
"int8, bfloat16], but found %s.",
mkldnn_data_type.c_str()));

auto w_dims = w.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
@@ -3522,18 +3507,6 @@
activation_type.c_str()));
}

if (use_mkldnn) {
PADDLE_ENFORCE_EQ(
in_dims.size() >= 2 && in_dims.size() <= 4,
true,
phi::errors::Unimplemented(
"The Input of fc is expected to be a 2-D, 3-D or 4-D tensor when "
"use_mkldnn is set. But received the number of Input's "
"dimensions is %d, Input's shape is %s.",
in_dims.size(),
in_dims));
}

std::vector<int64_t> output_dims;
phi::funcs::FCOutputSize(
in_dims, w_dims, output_dims, in_num_col_dims, padding_weights);
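FCOutputSize (unchanged by this PR) is what turns the input and weight dims plus in_num_col_dims and padding_weights into the output shape: the first in_num_col_dims input dimensions are kept and the weight's column count is appended, with the 4 padding columns stripped when padding_weights is set. A self-contained sketch of that rule, not the phi::funcs implementation itself:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> FcOutputDims(const std::vector<int64_t> &in_dims,
                                  const std::vector<int64_t> &w_dims,  // {K, N}
                                  int in_num_col_dims,
                                  bool padding_weights) {
  assert(w_dims.size() == 2);
  assert(in_num_col_dims >= 1 &&
         in_num_col_dims < static_cast<int>(in_dims.size()));
  // Keep the leading in_num_col_dims dims, append the weight's column count.
  std::vector<int64_t> out(in_dims.begin(), in_dims.begin() + in_num_col_dims);
  out.push_back(padding_weights ? w_dims[1] - 4 : w_dims[1]);
  return out;
}
// Example: in_dims = {8, 16, 32}, w_dims = {32, 64}, in_num_col_dims = 2
// -> output dims {8, 16, 64}.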
7 changes: 0 additions & 7 deletions paddle/phi/infermeta/fusion.h
@@ -807,14 +807,7 @@ void FCInferMeta(const MetaTensor& input,
const MetaTensor& bias,
const int in_num_col_dims,
const std::string& activation_type,
const bool use_mkldnn,
const bool padding_weights,
const bool use_quantizer,
const std::string& mkldnn_data_type,
const float scale_in,
const std::vector<float>& sclae_weights,
const float scale_out,
const bool force_fp32_output,
MetaTensor* out);

void VariableLengthMemoryEfficientAttentionInferMeta(