Commit 724e56c

code style changed

XGZhang11 committed Apr 13, 2021
1 parent 216e946 commit 724e56c

Showing 3 changed files with 29 additions and 45 deletions.
@@ -265,7 +265,9 @@ def __init__(self,
             'range_abs_max', 'moving_average_abs_max', 'abs_max'
         ]
         self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
-        self._support_algo_type = ['KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max']
+        self._support_algo_type = [
+            'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
+        ]
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \
             list(set(QuantizationTransformPass._supported_quantizable_op_type +
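For orientation, the list above is what callers select through the algo argument. A minimal usage sketch, assuming the 2.x-era PostTrainingQuantization API; the model path, output path, and calibration reader are illustrative placeholders:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

def calib_reader():
    # Illustrative calibration data; a real reader yields preprocessed samples.
    for _ in range(10):
        yield [np.random.random((3, 224, 224)).astype('float32')]

exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='./mobilenetv1_fp32',    # assumed path
    sample_generator=calib_reader,
    batch_nums=10,
    algo='mse',                        # any entry of _support_algo_type
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
ptq.quantize()
ptq.save_quantized_model('./mobilenetv1_int8')    # assumed output path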
@@ -389,11 +391,8 @@ def quantize(self):
             batch_id += 1
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
-
         _logger.info("Finish sampling stage, all batch: " + str(batch_id))
-
         self._reset_activation_persistable()
-
         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
                 self._quantized_threshold[var_name] = \
@@ -573,8 +572,6 @@ def _sample_mse(self):
                     abs_max_value.append(
                         float(np.max(np.abs(var_tensor[i]))))
                 self._quantized_threshold[var_name] = abs_max_value
-
-        #Search for the best threshold for activations
         _logger.info("MSE searching stage ...")
         for var_name in self._quantized_act_var_name:
             var_tensor = _load_variable_data(self._scope, var_name)
@@ -586,13 +583,15 @@
             while s <= 1.0:
                 scale = s * abs_max_value
                 s += 0.02
-                bins = 2 ** (self._activation_bits -1) - 1
-                quant_dequant_var = np.round(np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale
-                mse_loss = ((var_tensor - quant_dequant_var) ** 2).mean()
+                bins = 2**(self._activation_bits - 1) - 1
+                quant_dequant_var = np.round(
+                    np.clip(var_tensor, 0.0, scale) / scale *
+                    bins) / bins * scale
+                mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
                 if mse_loss <= self._best_mse_loss[var_name]:
                     self._best_mse_loss[var_name] = mse_loss
                     self._quantized_threshold[var_name] = scale

     def _sample_avg(self):
         if self._quantized_threshold == {}:
             for var_name in self._quantized_weight_var_name:
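As a plain-NumPy companion to the reformatted _sample_mse hunk above: sweep candidate scales as fractions of the tensor's absolute max and keep the one minimizing quantize-dequantize MSE. A minimal sketch with illustrative names; the sweep's starting fraction is an assumption, since the diff does not show where s is initialized:

import numpy as np

def mse_search_scale(var_tensor, activation_bits=8, s_start=0.3, s_step=0.02):
    # Candidate threshold = s * abs_max for s in [s_start, 1.0].
    abs_max_value = float(np.max(np.abs(var_tensor)))
    bins = 2**(activation_bits - 1) - 1
    best_scale, best_loss = abs_max_value, float('inf')
    s = s_start
    while s <= 1.0:
        scale = s * abs_max_value
        s += s_step
        # Quantize-dequantize round trip at this threshold, as in the hunk.
        quant_dequant_var = np.round(
            np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale
        mse_loss = float(((var_tensor - quant_dequant_var)**2).mean())
        if mse_loss <= best_loss:
            best_loss, best_scale = mse_loss, scale
    return best_scale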
@@ -611,7 +610,7 @@ def _sample_avg(self):
                     abs_max_value.append(
                         float(np.max(np.abs(var_tensor[i]))))
                 self._quantized_threshold[var_name] = abs_max_value

         for var_name in self._quantized_act_var_name:
             var_tensor = _load_variable_data(self._scope, var_name)
             abs_max_value = float(np.max(np.abs(var_tensor)))
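For reference, the avg path shown around here records one statistic per calibration batch: the mean over the batch of each sample's absolute max, which quantize() later averages over all sampled batches. A hedged sketch of that per-batch statistic, names illustrative:

import numpy as np

def sample_avg_value(var_tensor):
    # Flatten each sample, take its abs-max, then average across the batch.
    per_sample_abs_max = np.max(
        np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=1)
    return float(np.mean(per_sample_abs_max))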
@@ -621,8 +620,8 @@
                 np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=(1))))
             self._quantized_var_avg[var_name].append(abs_avg_value)
             continue
-    def _sample_abs_max(self):

+    def _sample_abs_max(self):
         if self._quantized_threshold == {}:
             for var_name in self._quantized_weight_var_name:
                 var_tensor = _load_variable_data(self._scope, var_name)
@@ -640,7 +639,7 @@ def _sample_abs_max(self):
                     abs_max_value.append(
                         float(np.max(np.abs(var_tensor[i]))))
                 self._quantized_threshold[var_name] = abs_max_value

         for var_name in self._quantized_act_var_name:
             var_tensor = _load_variable_data(self._scope, var_name)
             abs_max_value = float(np.max(np.abs(var_tensor)))
@@ -856,18 +855,16 @@ def analysis_and_save_info(op_node, out_var_name):
                 out_var_name + " is not the output of the op"
             if self._algo == "KL":
                 # For compatibility, we save output threshold by two methods.
-                save_info(op_node, out_var_name,
-                          self._quantized_var_threshold, "out_threshold",
-                          "post_kl")
+                save_info(op_node, out_var_name, self._quantized_var_threshold,
+                          "out_threshold", "post_kl")
                 save_info(
                     op_node, out_var_name, self._quantized_var_threshold,
                     argname_index[0] + str(argname_index[1]) + "_threshold",
                     "post_kl")
             elif self._algo == "hist":
                 # For compatibility, we save output threshold by two methods.
-                save_info(op_node, out_var_name,
-                          self._quantized_var_threshold, "out_threshold",
-                          "post_hist")
+                save_info(op_node, out_var_name, self._quantized_var_threshold,
+                          "out_threshold", "post_hist")
                 save_info(
                     op_node, out_var_name, self._quantized_var_threshold,
                     argname_index[0] + str(argname_index[1]) + "_threshold",
@@ -926,7 +923,7 @@ def _get_hist_scaling_factor(self, hist, hist_edges):
         '''
         Using the hist method to get the scaling factor.
         '''
-        threshold_rate = self._hist_percent
+        threshold_rate = self._hist_percent
         hist = hist / float(sum(hist))
         hist_sum = 0
         hist_index = 0
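The hist method above walks the normalized histogram until a target fraction of all samples is covered and takes that bin's edge as the threshold. A standalone sketch; the 0.99999 default and the choice of returned edge are assumptions, since the body of the loop is not visible in this hunk:

import numpy as np

def hist_scaling_factor(hist, hist_edges, threshold_rate=0.99999):
    # Normalize, then accumulate bins until threshold_rate is reached.
    hist = hist / float(np.sum(hist))
    hist_sum = 0.0
    for hist_index, bin_mass in enumerate(hist):
        hist_sum += bin_mass
        if hist_sum >= threshold_rate:
            return float(hist_edges[hist_index + 1])
    return float(hist_edges[-1])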
@@ -942,7 +939,7 @@ def _get_kl_scaling_factor(self, hist, hist_edeges):
         '''
         Using the KL-divergenc method to get the more precise scaling factor.
         '''
-        num_quantized_bins = 2 ** (self._activation_bits - 1) - 1
+        num_quantized_bins = 2**(self._activation_bits - 1) - 1
         ending_iter = self._histogram_bins - 1
         starting_iter = int(ending_iter * 0.7)
         bin_width = hist_edeges[1] - hist_edeges[0]
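The KL method searches clip thresholds over roughly the upper 30% of the histogram (starting_iter = int(ending_iter * 0.7)) and keeps the one whose quantized distribution diverges least from the original. A deliberately simplified sketch, assuming a histogram of absolute values starting at zero; it expands the quantized histogram uniformly rather than only over non-empty bins, and all names are illustrative:

import numpy as np

def kl_scaling_factor(hist, hist_edges, activation_bits=8):
    num_quantized_bins = 2**(activation_bits - 1) - 1
    bin_width = hist_edges[1] - hist_edges[0]
    ending_iter = len(hist) - 1
    starting_iter = max(int(ending_iter * 0.7), num_quantized_bins)
    best_kl, best_iter = float('inf'), ending_iter
    for i in range(starting_iter, ending_iter + 1):
        # P: reference distribution with the tail beyond bin i folded in.
        p = hist[:i + 1].astype(np.float64).copy()
        p[-1] += hist[i + 1:].sum()
        # Q: collapse the i+1 bins onto the quantized grid, then expand
        # back uniformly so P and Q share the same support.
        idx = (np.arange(i + 1) * num_quantized_bins // (i + 1)).astype(int)
        coarse = np.bincount(idx, weights=p, minlength=num_quantized_bins)
        counts = np.bincount(idx, minlength=num_quantized_bins)
        q = coarse[idx] / counts[idx]
        # Smooth, normalize, and accumulate KL(P || Q).
        p += 1e-10
        q += 1e-10
        p /= p.sum()
        q /= q.sum()
        kl = float(np.sum(p * np.log(p / q)))
        if kl < best_kl:
            best_kl, best_iter = kl, i
    return (best_iter + 0.5) * bin_width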
@@ -1372,6 +1372,7 @@ def _is_float(self, v):
     def _quant(self, x, scale, num_bits, quant_axis):
         assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
         bnt = (1 << (num_bits - 1)) - 1
+
         def _clip(x, scale):
             x[x > scale] = scale
             x[x < -scale] = -scale
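The _quant helper above is symmetric linear quantization: clip to plus/minus scale, then map onto the signed integer grid. A per-tensor sketch of the same arithmetic (names illustrative); the per-channel mode selected by quant_axis applies the identical formula with one scale per channel:

import numpy as np

def quant_per_tensor(x, scale, num_bits=8):
    # bnt is the largest representable magnitude, e.g. 127 for 8 bits.
    bnt = (1 << (num_bits - 1)) - 1
    x_clipped = np.clip(x, -scale, scale)
    return np.round(x_clipped / scale * bnt)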
@@ -1412,9 +1413,12 @@ def _bias_correction_w(self, x, x_quant, scale_v, quant_axis):
             for i, s in enumerate(scale_v):
                 x_dequant[:, i] = x_quant[:, i] * s / bnt
             quant_bias = x - x_dequant
-            mean_bias = np.array([quant_bias[:, i].mean() for i in range(quant_bias.shape[1])])
+            mean_bias = np.array([
+                quant_bias[:, i].mean() for i in range(quant_bias.shape[1])
+            ])
             std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
-            std_quant = np.array([x_dequant[:, i].std() for i in range(x_dequant.shape[1])])
+            std_quant = np.array(
+                [x_dequant[:, i].std() for i in range(x_dequant.shape[1])])
             std_bias = std_orig / (std_quant + eps)
         else:
             x_dequant = x_quant * scale_v / bnt
@@ -1425,7 +1429,8 @@ def _bias_correction_w(self, x, x_quant, scale_v, quant_axis):
             mean_bias = np.resize(mean_bias, x.shape)

         x_dequant = (mean_bias + x_dequant) * std_bias
-        quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits, quant_axis)
+        quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits,
+                                        quant_axis)
         return quantized_param_v

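_bias_correction_w above counters quantization-induced drift: dequantize, re-center by the mean quantization error, rescale to restore the original standard deviation, then requantize. A per-tensor sketch of the idea; the per-channel branch in the diff does the same thing column-wise, and the names and eps value here are illustrative:

import numpy as np

def bias_correct_per_tensor(x, x_quant, scale, weight_bits=8, eps=1e-8):
    bnt = (1 << (weight_bits - 1)) - 1
    x_dequant = x_quant * scale / bnt
    mean_bias = float((x - x_dequant).mean())    # mean quantization error
    std_bias = float(x.std() / (x_dequant.std() + eps))
    x_corrected = (x_dequant + mean_bias) * std_bias
    # Requantize the corrected weights with the original scale.
    return np.round(np.clip(x_corrected, -scale, scale) / scale * bnt)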
@@ -327,6 +327,7 @@ def test_post_training_kl_mobilenetv1(self):
                       is_full_quantize, is_use_cache_file, is_optimize_model,
                       diff_threshold)

+
 class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
     def test_post_training_avg_mobilenetv1(self):
         model = "MobileNet-V1"
@@ -348,6 +349,7 @@ def test_post_training_avg_mobilenetv1(self):
                       is_full_quantize, is_use_cache_file, is_optimize_model,
                       diff_threshold)

+
 class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
     def test_post_training_hist_mobilenetv1(self):
         model = "MobileNet-V1"
@@ -369,26 +371,6 @@ def test_post_training_hist_mobilenetv1(self):
                       is_full_quantize, is_use_cache_file, is_optimize_model,
                       diff_threshold)

-class TestPostTrainingmseForMobilenetv1(TestPostTrainingQuantization):
-    def test_post_training_mse_mobilenetv1(self):
-        model = "MobileNet-V1"
-        algo = "mse"
-        data_urls = [
-            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
-        ]
-        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
-        quantizable_op_type = [
-            "conv2d",
-            "depthwise_conv2d",
-            "mul",
-        ]
-        is_full_quantize = False
-        is_use_cache_file = False
-        is_optimize_model = True
-        diff_threshold = 0.025
-        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file, is_optimize_model,
-                      diff_threshold)

 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
     def test_post_training_abs_max_mobilenetv1(self):

1 comment on commit 724e56c

@paddle-bot-old commented:
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
