
add mkldnn int8 related passes and config #38643

Closed
wants to merge 39 commits into from

Conversation

@baoachun (Contributor) commented Dec 31, 2021

PR types

New features

PR changes

Others

Describe

Add MKLDNN int8 configuration options to the inference config, and the quantization passes.

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot-old commented:

Sorry to inform you that f97739b's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@lidanqing-intel (Contributor) commented:

@wozna Hi, please review or continue this PR.

@wozna (Contributor) left a comment:

@baoachun You did a lot of great work. I added a few comments to the code.

I also have a question about the next steps. The last stage of quantization consists of three passes: cpu_quantize_placement_pass, cpu_quantize_pass, and cpu_quantize_squash_pass. cpu_quantize_pass needs the scales collected by your passes. To be sure: do you plan to do it exactly the same way, i.e. save them as attributes on one of the ops and read them in cpu_quantize_pass?

if (op_desc->HasAttr("fuse_relu")) {
const bool fuse_relu =
BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
if (fuse_relu) activation = "relu";
Contributor:

I think you can use GetAttrIfExists to simplify it

if (op_desc->GetAttrIfExists<bool>("fuse_relu"))
  activation = "relu";

const bool fuse_relu =
BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
if (fuse_relu) activation = "relu";
} else if (op_desc->HasAttr("fuse_brelu")) {
Contributor:

Ditto

++iter) {
op_node->Op()->SetAttr(iter->first + "_var_quant_scales", iter->second);
}
break;
Contributor:

Just to be sure: the first operator is found here, and all scales are stored on it as separate attributes?


auto* scale_tensor = var->GetMutable<LoDTensor>();
auto scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
float scale = 1.0 / scale_data[0];
Contributor:

In quant2_int8_mkldnn_pass there was a check: if an obtained scale had the value of infinity, it was set to 0. Maybe it is better to add such a check here too?

@@ -0,0 +1,75 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Contributor:

I think the date should be 2022.

passes_.push_back("conv_relu_mkldnn_fuse_pass");
passes_.push_back("conv_relu6_mkldnn_fuse_pass");
// need input params?
/// passes_.push_back("fc_fuse_pass");
Contributor:

This pass should have the attributes "use_gpu" and "use_fc_padding" set to false.

passes_.push_back("repeated_fc_relu_fuse_pass");
// enable or disable?
// passes_.push_back("fc_mkldnn_pass");
// passes_.push_back("fc_act_mkldnn_fuse_pass");
Contributor:

The best option here would be to make this a parameter that the user can change. Depending on the model, these two passes can significantly accelerate it; unfortunately, for some models they cause a decrease in performance.

passes_.push_back("mul_gru_fuse_pass");
passes_.push_back("multi_gru_fuse_pass");
passes_.push_back("multi_gru_seq_fuse_pass");
passes_.push_back("seq_concat_fc_fuse_pass");
Contributor:

Please update the passes to match the recent quant2_int8_mkldnn_pass, because I can see that there were some changes there, e.g. in PR #39369.

// }
// }
// }
// };
Contributor:

Calculation of scales for weights is already implemented in mkldnn_quantizer.cc; it is done in GetMaxChGRUScalingFactor and GetMaxChLSTMScalingFactor, so you can reuse it because the functionality is exactly the same.

AnalysisPredictor::MkldnnQuantizer::GetMaxChGRUScalingFactor(

auto* scope = param_scope();
GetQuantInfo(graph, scope, weight_thresholds, var_quant_scales);

//ComputeWeightScales(graph, scope);
Contributor:

I understand that this pass will functionally correspond to graph = self._compute_weight_scales(graph). Do you also plan to add the functionality of graph = self._propagate_scales(graph) here?

@wozna wozna assigned baoachun and lidanqing-intel and unassigned wozna Feb 22, 2022
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "feth")
Contributor:

Typo: "fetch"

}
}

// void RequantMkldnnFusePass::ComputeWeightScales(ir::Graph* graph, Scope* scope,
Contributor:

Is the commented block of code still needed?

@@ -0,0 +1,45 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Please change to 2022

if (fuse_relu) activation = "relu";
} else if (op_desc->HasAttr("fuse_brelu")) {
const bool fuse_brelu =
BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
Contributor:

Please double-check which attribute name is being read; I suppose it should be fuse_brelu.

@lidanqing-intel lidanqing-intel removed their assignment Mar 9, 2022
std::string output_act_name = fake_quant_out->Var()->Name();
auto outlinks = fake_quant_out->outputs;
for (auto* next_node : outlinks) {
next_node->Op()->RenameInput(output_act_name, input_act_name);
Contributor Author:

Is this node an op?


void RequantMkldnnFusePass::ComputeWeightScales(
ir::Graph* graph, Scope* scope, StringPairMap& var_quant_scales) const {
auto get_scales = [&](Tensor* tensor, int axis) -> std::vector<float> {
Contributor:

Could you change all the lambdas to functions? It will simplify the testing process.

Contributor Author:

OK, I will do it as soon as possible.

@baoachun (Contributor, Author) commented:

Hi @wozna, please continue your review~

@@ -0,0 +1,75 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Contributor:

2022

Contributor Author:

OK~

const std::string suffix = "_" + key_suffix + "_" + flag;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
Contributor:

Should this condition be the same as in SaveInfoInTheFirstOp line 32-33?

Contributor Author:

Yep. Do you have any suggestions?

Contributor:

I don't know; there you skip the feed and fetch ops, so I thought maybe you have to skip them here too:

  if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
        op_node->Op()->Type() == "fetch")

Comment on lines 59 to 60
if (fake_name.find(suffix) != std::string::npos) {
size_t pos = fake_name.find(suffix);
Contributor:

Suggested change:
- if (fake_name.find(suffix) != std::string::npos) {
-   size_t pos = fake_name.find(suffix);
+ size_t pos = fake_name.find(suffix);
+ if (pos != std::string::npos) {

Contributor Author:

OK~

///
/// \brief Turn on MKLDNN bfloat16.
///
///
void EnableMkldnnBfloat16();

Contributor:

I think you could revert this newline

Contributor Author:

Done

for (auto* node_input : op_node->inputs) {
for (auto* node_input_input : node_input->inputs) {
if (!node_input_input->IsOp()) continue;
if (op_node->Name().find("quantize_dequantize") ==
Contributor:

Shouldn't it be like this?

Suggested change:
- if (op_node->Name().find("quantize_dequantize") ==
+ if (node_input_input->Name().find("quantize_dequantize") ==

Contributor Author:

Yes, you are right!

std::unordered_map<std::string, std::vector<float>>* info_map) const {
for (auto iter = var_quant_scales->begin(); iter != var_quant_scales->end();
iter++) {
auto* data = iter->second.second.mutable_data<float>(platform::CPUPlace());
Contributor:

Suggested change:
- auto* data = iter->second.second.mutable_data<float>(platform::CPUPlace());
+ auto* data = iter->second.second.data<float>(platform::CPUPlace());

Contributor Author:

Done

}

void ComputePropagateScalesMkldnnPass::ConvertStringPairMap(
StringPairMap* var_quant_scales,
Contributor:

Can this be const?

Contributor Author:

I will fix it in the next PR.


void ComputePropagateScalesMkldnnPass::PropagateScales(
ir::Graph* graph, StringPairMap* var_quant_scales,
const std::unordered_set<std::string> scale_immutable_ops) const {
Contributor:

Could this be a reference?

Contributor Author:

Done


std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
ir::Graph* graph, StringPairMap* var_quant_scales,
const std::unordered_set<std::string> scale_immutable_ops) const {
Contributor:

Could this be a reference?

Contributor Author:

Done

var, "The input persistable var of %s op is not found.", op_desc->Type());

auto* weight_tensor = var->GetMutable<LoDTensor>();
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
Contributor:

Suggested change:
- auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
+ auto* weight_data = weight_tensor->data<float>(platform::CPUPlace());

Contributor Author:

Done

@wozna (Contributor) commented Mar 31, 2022:

@baoachun In this branch https://github.com/wozna/Paddle/tree/mkldnn_int8_test, in the last commit, I added tests for your changes. There is a UT for scale calculation for the functions from compute_propagate_scales_mkldnn_pass, and a model test, but for now it is a performance test only.

if (iter == std::end(passes_)) return -1;
return std::distance(std::begin(passes_), iter);
}

@lidanqing-intel (Contributor) commented Mar 31, 2022:

@Silv3S may consider adding such a function.

Member:

Thank you. This is exactly what I needed

@baoachun (Contributor, Author) commented Apr 2, 2022:

Hi @wozna, what is the scope of the new unit test you added? It takes a long time for me to execute the test_analyzer_quant2_mobilenetv1_mkldnn test, it gets stuck, and there is also GPU information in the output. I see that there are 1775 test cases in this test. Is there a problem?
[screenshot]

namespace framework {
namespace ir {

using StringPairMap = std::unordered_map<std::string, std::pair<bool, Tensor>>;
@lidanqing-intel (Contributor) commented Apr 2, 2022:

We still keep std::unordered_map<std::string, std::pair<bool, Tensor>> here, right?

Contributor Author:

What do you mean? Does it need to be changed?

@wozna (Contributor) commented Apr 5, 2022:

> Hi @wozna, what is the test scope of the new single test you added? It takes a long time for me to execute the test_analyzer_quant2_mobilenetv1_mkldnn single test, and it will get stuck, and there is also GPU information. I see that there are 1775 test cases in this single test. Is there a problem? [screenshot]

@baoachun It looks like all tests were run. Did you use ctest -R test_analyzer_quant2_mobilenetv1_mkldnn -V?

@baoachun (Contributor, Author) commented Apr 6, 2022:

> Hi @wozna, what is the test scope of the new single test you added? It takes a long time for me to execute the test_analyzer_quant2_mobilenetv1_mkldnn single test, and it will get stuck, and there is also GPU information. I see that there are 1775 test cases in this single test. Is there a problem? [screenshot]
>
> @baoachun It looks like all tests were run. Did you use ctest -R test_analyzer_quant2_mobilenetv1_mkldnn -V?

Yes~

}
}

void ComputePropagateScalesMkldnnPass::GetQuantInfo(
@lidanqing-intel (Contributor) commented Apr 7, 2022:

Hi Achun, there is a GetQuantInfo function definition in both compute_propagate_scales_mkldnn_pass and cpu_quantize_pass. We can consider unifying them in the next PR, since this PR has almost passed all CIs.

Contributor Author:

Yes, var_quant_scales needs to be passed to cpu_quantize_pass through the graph.

graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],

@XieYunshen (Contributor) left a comment:

LGTM


std::vector<float> ComputePropagateScalesMkldnnPass::GetScales(Tensor* tensor,
int axis) const {
PADDLE_ENFORCE_LT(axis, 2, "The input axis is required to be less than 2.");
Contributor:

PADDLE_ENFORCE_LT(axis, 2, "The input axis is required to be less than 2.");
auto* data = tensor->data<float>();
const auto dims = tensor->dims();
PADDLE_ENFORCE_EQ(dims.size(), 2,
Contributor:

same above

if (ops.count(op_desc->Type())) {
auto var_name = op_desc->Input(weight_name)[0];
auto* var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
Contributor:

same above

Scope* scope, const std::string& wx_var_name,
const std::string& wh_var_name, Tensor* tensor) const {
auto* wx_var = scope->FindVar(wx_var_name);
PADDLE_ENFORCE_NOT_NULL(wx_var, "The input persistable var %s is not found.",
Contributor:

Same as above; please refer to the result of the approval CI check and modify each occurrence in turn.

@sfraczek (Contributor) left a comment:

LGTM

@baoachun baoachun closed this Apr 14, 2022
9 participants