add mkldnn int8 related passes and config #38643
Conversation
Thanks for your contribution!
Sorry to inform you that it has been more than 7 days since f97739b's CIs passed. To prevent PR conflicts, you need to re-run all CIs manually.
@wozna Hi, please review or continue this PR.
@baoachun You did a lot of great work. I added a few comments to the code.
I also have a question about the next steps. The last step of quantization is the three passes cpu_quantize_placement_pass, cpu_quantize_pass, and cpu_quantize_squash_pass. For cpu_quantize_pass, you need the scales collected by your passes. To be sure, do you plan to do it exactly the same way: save them as attributes on one of the ops and read them in cpu_quantize_pass?
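Just to illustrate the pattern I am asking about (a rough sketch only; the attribute suffix and variable names here are placeholders, not the final API), assuming var_quant_scales is a std::unordered_map<std::string, std::vector<float>>:

// in the pass that collects the scales:
for (const auto& item : var_quant_scales) {
  first_op_desc->SetAttr(item.first + "_var_quant_scales", item.second);
}
// later, in cpu_quantize_pass, read them back by variable name:
auto scales =
    op_desc->GetAttrIfExists<std::vector<float>>(var_name + "_var_quant_scales");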
if (op_desc->HasAttr("fuse_relu")) {
  const bool fuse_relu =
      BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
  if (fuse_relu) activation = "relu";
I think you can use GetAttrIfExists to simplify it:
if (op_desc->GetAttrIfExists<bool>("fuse_relu"))
activation = "relu";
  const bool fuse_relu =
      BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
  if (fuse_relu) activation = "relu";
} else if (op_desc->HasAttr("fuse_brelu")) {
Ditto
     ++iter) {
  op_node->Op()->SetAttr(iter->first + "_var_quant_scales", iter->second);
}
break;
To be sure, the first operator is found here and we enter all scales into it as separate attributes?
auto* scale_tensor = var->GetMutable<LoDTensor>();
auto scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
float scale = 1.0 / scale_data[0];
In quant2_int8_mkldnn_pass there was a check: if an obtained scale had the value of infinity, the scale was set to 0. Maybe it is better to add such a check here too?
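A minimal sketch of the check I mean, reusing the variables from the snippet above (std::isinf needs <cmath>):

float scale = 1.0f / scale_data[0];
if (std::isinf(scale)) scale = 0.0f;  // mirror the quant2_int8_mkldnn_pass behavior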
@@ -0,0 +1,75 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
I think the date should be 2022.
passes_.push_back("conv_relu_mkldnn_fuse_pass"); | ||
passes_.push_back("conv_relu6_mkldnn_fuse_pass"); | ||
// need input params? | ||
/// passes_.push_back("fc_fuse_pass"); |
This pass should have the attributes "use_gpu" and "use_fc_padding" set to false.
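For reference, something like this is what I mean, assuming the pass is fetched and applied the usual way (a sketch only; I have not checked where exactly these attributes should be set in this code path):

auto fc_pass = framework::ir::PassRegistry::Instance().Get("fc_fuse_pass");
fc_pass->Set("use_gpu", new bool(false));
fc_pass->Set("use_fc_padding", new bool(false));
graph.reset(fc_pass->Apply(graph.release()));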
passes_.push_back("repeated_fc_relu_fuse_pass"); | ||
// enable or disable? | ||
// passes_.push_back("fc_mkldnn_pass"); | ||
// passes_.push_back("fc_act_mkldnn_fuse_pass"); |
The best option here would be to make this a parameter that the user can change. Depending on the model, these two passes can significantly accelerate it; unfortunately, for some models they cause a decrease in performance.
passes_.push_back("mul_gru_fuse_pass"); | ||
passes_.push_back("multi_gru_fuse_pass"); | ||
passes_.push_back("multi_gru_seq_fuse_pass"); | ||
passes_.push_back("seq_concat_fc_fuse_pass"); |
Please update the passes to match the recent quant2_int8_mkldnn_pass, because I can see that there were some changes there, e.g. in PR #39369.
// }
// }
// }
// };
Calculation of scales for weights is already implemented in mkldnn_quantizer.cc; it is done in GetMaxChGRUScalingFactor and GetMaxChLSTMScalingFactor, so you can reuse it because the functionality is exactly the same.
AnalysisPredictor::MkldnnQuantizer::GetMaxChGRUScalingFactor( |
auto* scope = param_scope();
GetQuantInfo(graph, scope, weight_thresholds, var_quant_scales);

// ComputeWeightScales(graph, scope);
I understand that this pass will functionally correspond to graph = self._compute_weight_scales(graph). Do you also plan to add the functionality of graph = self._propagate_scales(graph) here?
for (auto* op_node :
     ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
  if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
      op_node->Op()->Type() == "feth")
Typo: "fetch"
  }
}

// void RequantMkldnnFusePass::ComputeWeightScales(ir::Graph* graph, Scope* scope,
Is the commented-out block of code still needed?
@@ -0,0 +1,45 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
Please change to 2022
  if (fuse_relu) activation = "relu";
} else if (op_desc->HasAttr("fuse_brelu")) {
  const bool fuse_brelu =
      BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
Please double-check which attribute name is being read; I suppose it should be fuse_brelu.
std::string output_act_name = fake_quant_out->Var()->Name();
auto outlinks = fake_quant_out->outputs;
for (auto* next_node : outlinks) {
  next_node->Op()->RenameInput(output_act_name, input_act_name);
Is next_node guaranteed to be an op here?
void RequantMkldnnFusePass::ComputeWeightScales(
    ir::Graph* graph, Scope* scope, StringPairMap& var_quant_scales) const {
  auto get_scales = [&](Tensor* tensor, int axis) -> std::vector<float> {
Could you change all the lambdas to functions? It will simplify the testing process.
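For example (a sketch only; the exact signature is up to you), the get_scales lambda could become a method of the pass so a unit test can call it directly:

// in the header:
std::vector<float> GetScales(Tensor* tensor, int axis) const;

// in the .cc file, keeping the body of the former lambda:
std::vector<float> RequantMkldnnFusePass::GetScales(Tensor* tensor,
                                                    int axis) const {
  // ... body of the old get_scales lambda ...
}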
OK, I will do it as soon as possible.
Hi @wozna, please continue your review~
@@ -0,0 +1,75 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
2022
OK~
const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
     ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
  if (!op_node->IsOp()) continue;
Should this condition be the same as in SaveInfoInTheFirstOp line 32-33?
Yep, do you have any suggestions, please?
I don't know; there you skip the feed and fetch ops, so I thought maybe you have to skip them here too:
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
if (fake_name.find(suffix) != std::string::npos) {
  size_t pos = fake_name.find(suffix);
Suggested change: replace

if (fake_name.find(suffix) != std::string::npos) {
  size_t pos = fake_name.find(suffix);

with

size_t pos = fake_name.find(suffix);
if (pos != std::string::npos) {
OK~
///
/// \brief Turn on MKLDNN bfloat16.
///
///
void EnableMkldnnBfloat16();

I think you could revert this newline
Done
for (auto* node_input : op_node->inputs) {
  for (auto* node_input_input : node_input->inputs) {
    if (!node_input_input->IsOp()) continue;
    if (op_node->Name().find("quantize_dequantize") ==
Shouldn't it be like this?
Suggested change: replace

if (op_node->Name().find("quantize_dequantize") ==

with

if (node_input_input->Name().find("quantize_dequantize") ==
Yes, you are right!
    std::unordered_map<std::string, std::vector<float>>* info_map) const {
  for (auto iter = var_quant_scales->begin(); iter != var_quant_scales->end();
       iter++) {
    auto* data = iter->second.second.mutable_data<float>(platform::CPUPlace());
Suggested change: replace

auto* data = iter->second.second.mutable_data<float>(platform::CPUPlace());

with

auto* data = iter->second.second.data<float>(platform::CPUPlace());
Done
}

void ComputePropagateScalesMkldnnPass::ConvertStringPairMap(
    StringPairMap* var_quant_scales,
Can this be const?
I will fix it in the next PR.
void ComputePropagateScalesMkldnnPass::PropagateScales(
    ir::Graph* graph, StringPairMap* var_quant_scales,
    const std::unordered_set<std::string> scale_immutable_ops) const {
Could this be a reference?
Done
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
    ir::Graph* graph, StringPairMap* var_quant_scales,
    const std::unordered_set<std::string> scale_immutable_ops) const {
Could this be a reference?
Done
var, "The input persistable var of %s op is not found.", op_desc->Type()); | ||
|
||
auto* weight_tensor = var->GetMutable<LoDTensor>(); | ||
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace()); |
Suggested change: replace

auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());

with

auto* weight_data = weight_tensor->data<float>(platform::CPUPlace());
Done
@baoachun In this branch https://github.com/wozna/Paddle/tree/mkldnn_int8_test, in the last commit, I added tests for your changes. There is a UT for the scale calculation for functions from
  if (iter == std::end(passes_)) return -1;
  return std::distance(std::begin(passes_), iter);
}
@Silv3S may consider adding such a function.
Thank you. This is exactly what I needed
Hi @wozna, what is the test scope of the new unit test you added? It takes a long time for me to execute the
namespace framework {
namespace ir {

using StringPairMap = std::unordered_map<std::string, std::pair<bool, Tensor>>;
Is std::unordered_map<std::string, std::pair<bool, Tensor>> still kept here?
What do you mean? Does it need to be changed?
@baoachun It looks like all tests were run. Did you use
Yes~
  }
}

void ComputePropagateScalesMkldnnPass::GetQuantInfo(
Hi Achun, there is a GetQuantInfo function definition in both compute_propagate_scales_mkldnn_pass and cpu_quantize_pass. We can consider unifying them in the next PR since this PR has almost passed all CIs.
Yes, var_quant_scales needs to be passed to cpu_quantize_pass through the graph.
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
LGTM
std::vector<float> ComputePropagateScalesMkldnnPass::GetScales(Tensor* tensor,
                                                               int axis) const {
  PADDLE_ENFORCE_LT(axis, 2, "The input axis is required to be less than 2.");
Here you need to specify the error type; please see https://github.com/PaddlePaddle/Paddle/wiki/Paddle-Error-Message-Writing-Specification
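For example, something along these lines (a sketch; the exact error type and message wording per the specification are up to you):

PADDLE_ENFORCE_LT(
    axis, 2,
    platform::errors::InvalidArgument(
        "The input axis is required to be less than 2, but got %d.", axis));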
  PADDLE_ENFORCE_LT(axis, 2, "The input axis is required to be less than 2.");
  auto* data = tensor->data<float>();
  const auto dims = tensor->dims();
  PADDLE_ENFORCE_EQ(dims.size(), 2,
Same as above.
if (ops.count(op_desc->Type())) {
  auto var_name = op_desc->Input(weight_name)[0];
  auto* var = scope->FindVar(var_name);
  PADDLE_ENFORCE_NOT_NULL(
Same as above.
    Scope* scope, const std::string& wx_var_name,
    const std::string& wh_var_name, Tensor* tensor) const {
  auto* wx_var = scope->FindVar(wx_var_name);
  PADDLE_ENFORCE_NOT_NULL(wx_var, "The input persistable var %s is not found.",
Same as above; please refer to the results of the approval CI check and modify each one in turn.
LGTM
PR types
New features
PR changes
Others
Describe
Add MKL-DNN int8 configuration options to the inference config, along with the quantization passes.