diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index aba6005f0cfdf..bc2e2dc9b6562 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -55,7 +55,7 @@ def _set_variable_data(scope, place, var_name, np_value):
     Set the value of var node by name, if the node exits,
     '''
     assert isinstance(np_value, np.ndarray), \
-       'The type of value should be numpy array.'
+        'The type of value should be numpy array.'
     var_node = scope.find_var(var_name)
     if var_node != None:
         tensor = var_node.get_tensor()
@@ -138,8 +138,10 @@ def __init__(self,
                  batch_size=10,
                  batch_nums=None,
                  algo="KL",
+                 hist_percent=0.99999,
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 bias_correction=False,
                  activation_bits=8,
                  weight_bits=8,
                  activation_quantize_type='range_abs_max',
@@ -180,7 +182,13 @@ def __init__(self,
                 get the KL threshold for quantized activations and get the abs_max
                 value for quantized weights. If algo='abs_max', get the abs max
                 value for activations and weights. If algo= 'min_max', get the min
-                and max value for quantized activations and weights. Default is KL.
+                and max value for quantized activations and weights. If algo='avg',
+                get the average value among the max values for activations. If
+                algo='hist', get the value of 'hist_percent' quantile as the threshold.
+                If algo='mse', get the value which makes the quantization mse loss
+                minimal. Default is KL.
+            hist_percent(float, optional): The percentile of the activation histogram
+                that algo 'hist' uses as the threshold. Default is 0.99999.
             quantizable_op_type(list[str], optional): List the type of ops
                 that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                 "mul"].
@@ -188,6 +196,8 @@ def __init__(self,
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            bias_correction(bool, optional): If set to True, use the bias correction
+                method of https://arxiv.org/abs/1810.05723. Default is False.
             activation_bits(int): quantization bit number for activation.
             weight_bits(int, optional): quantization bit number for weights.
             activation_quantize_type(str): quantization type for activation,
@@ -255,7 +265,9 @@ def __init__(self,
             'range_abs_max', 'moving_average_abs_max', 'abs_max'
         ]
         self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
-        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_algo_type = [
+            'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
+        ]
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \
             list(set(QuantizationTransformPass._supported_quantizable_op_type +
@@ -270,7 +282,7 @@ def __init__(self,
             "cannot be None in the same time."
         assert batch_size > 0, "The batch_size should be greater than 0."
         assert algo in self._support_algo_type, \
-            "The algo should be KL, abs_max or min_max."
+            "The algo should be KL, hist, mse, avg, abs_max or min_max."
         assert activation_quantize_type in self._support_activation_quantize_type, \
            "The activation_quantize_type ({}) should in ({}).".format(
            activation_quantize_type, self._support_activation_quantize_type)
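Note: a minimal usage sketch of the options this patch adds. Only the constructor arguments visible in this diff are taken from the source; the model path, the `sample_generator` callable, and the save path are illustrative placeholders:

    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

    exe = fluid.Executor(fluid.CPUPlace())
    ptq = PostTrainingQuantization(
        executor=exe,
        model_dir="./fp32_model",           # hypothetical model directory
        sample_generator=sample_generator,  # hypothetical calibration reader
        batch_size=10,
        batch_nums=10,
        algo="hist",            # new algos: 'hist', 'avg', 'mse'
        hist_percent=0.99999,   # new: percentile used when algo='hist'
        bias_correction=True)   # new: weight correction per arXiv:1810.05723
    ptq.quantize()
    ptq.save_quantized_model("./quant_model")  # assumed save API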
@@ -279,6 +291,7 @@ def __init__(self,
                 weight_quantize_type, self._support_weight_quantize_type)
 
         # Save input params
+        self._bias_correction = bias_correction
         self._executor = executor
         self._scope = global_scope() if scope == None else scope
         self._model_dir = model_dir
@@ -289,6 +302,7 @@ def __init__(self,
         self._batch_size = batch_size
         self._batch_nums = batch_nums
         self._algo = algo
+        self._hist_percent = hist_percent
         self._activation_bits = activation_bits
         self._weight_bits = weight_bits
         self._activation_quantize_type = activation_quantize_type
@@ -314,17 +328,21 @@ def __init__(self,
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._weight_op_pairs = {}
-        # The vars for alog = KL
+        # The vars for algo = KL or hist
         self._sampling_act_abs_min_max = {}
         self._sampling_act_histogram = {}
         self._sampling_data = {}
-        self._quantized_var_kl_threshold = {}
+        self._quantized_var_threshold = {}
         self._histogram_bins = 2048
         # The vars for algo = min_max
         self._quantized_var_min = {}
         self._quantized_var_max = {}
-        # The vars for algo = abs_max
-        self._quantized_var_abs_max = {}
+        # The vars for algo = avg
+        self._quantized_var_avg = {}
+        # The best loss of algo = mse
+        self._best_mse_loss = {}
+        # The threshold for algo = abs_max, mse or avg
+        self._quantized_threshold = {}
 
     def quantize(self):
         '''
@@ -341,7 +359,7 @@ def quantize(self):
         self._collect_target_varnames()
         self._set_activation_persistable()
 
-        if self._algo == "KL":
+        if self._algo in ["KL", "hist"]:
             _logger.info("Preparation stage ...")
             batch_id = 0
             for data in self._data_loader():
@@ -374,13 +392,14 @@ def quantize(self):
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
         _logger.info("Finish sampling stage, all batch: " + str(batch_id))
-        self._reset_activation_persistable()
-
-        if self._algo == "KL":
-            self._calculate_kl_threshold()
-
-        if self._algo in ["KL", "abs_max"]:
+        if self._algo == 'avg':
+            for var_name in self._quantized_act_var_name:
+                self._quantized_threshold[var_name] = \
+                    np.array(self._quantized_var_avg[var_name]).mean()
+        if self._algo in ["KL", "hist"]:
+            self._calculate_kl_hist_threshold()
+        if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
             self._update_program()
         else:
             self._save_input_threhold()
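Note: for algo='avg', _sample_avg (below) records, per batch, the mean over examples of each activation's max absolute value, and the quantize() change above averages those records into the final threshold. A standalone NumPy sketch of that aggregation on made-up activations:

    import numpy as np

    batches = [np.random.randn(8, 3, 32, 32) for _ in range(5)]  # fake sampled batches
    samples = []
    for var_tensor in batches:
        # per-example max of |x|, then the mean across the batch dimension
        samples.append(float(np.mean(np.max(
            np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=1))))
    threshold = np.array(samples).mean()  # mirrors _sample_avg plus the avg branch above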
@@ -526,14 +545,84 @@ def _sampling(self):
         '''
         if self._algo == "abs_max":
             self._sample_abs_max()
+        elif self._algo == "avg":
+            self._sample_avg()
         elif self._algo == "min_max":
             self._sample_min_max()
-        elif self._algo == "KL":
+        elif self._algo == "mse":
+            self._sample_mse()
+        elif self._algo in ["KL", "hist"]:
             self._sample_histogram()
 
+    def _sample_mse(self):
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+        _logger.info("MSE searching stage ...")
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor = var_tensor.flatten()
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            s = 0.3
+            if var_name not in self._best_mse_loss:
+                self._best_mse_loss[var_name] = float('inf')
+            while s <= 1.0:
+                scale = s * abs_max_value
+                s += 0.02
+                bins = 2**(self._activation_bits - 1) - 1
+                quant_dequant_var = np.round(
+                    np.clip(var_tensor, 0.0, scale) / scale *
+                    bins) / bins * scale
+                mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
+                if mse_loss <= self._best_mse_loss[var_name]:
+                    self._best_mse_loss[var_name] = mse_loss
+                    self._quantized_threshold[var_name] = scale
+
+    def _sample_avg(self):
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            if (var_name not in self._quantized_var_avg):
+                self._quantized_var_avg[var_name] = []
+            abs_avg_value = float(np.mean(np.max( \
+                np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=(1))))
+            self._quantized_var_avg[var_name].append(abs_avg_value)
+            continue
+
     def _sample_abs_max(self):
-        # Only calculate abs_max value for weight for once
-        if self._quantized_var_abs_max == {}:
+        if self._quantized_threshold == {}:
             for var_name in self._quantized_weight_var_name:
                 var_tensor = _load_variable_data(self._scope, var_name)
                 if self._weight_quantize_type == "abs_max":
@@ -549,14 +638,14 @@ def _sample_abs_max(self):
                     else:
                         for i in range(var_tensor.shape[0]):
                             abs_max_value.append(
                                 float(np.max(np.abs(var_tensor[i]))))
-                self._quantized_var_abs_max[var_name] = abs_max_value
+                self._quantized_threshold[var_name] = abs_max_value
 
         for var_name in self._quantized_act_var_name:
             var_tensor = _load_variable_data(self._scope, var_name)
            abs_max_value = float(np.max(np.abs(var_tensor)))
-            if (var_name not in self._quantized_var_abs_max) or \
-                    (abs_max_value > self._quantized_var_abs_max[var_name]):
-                self._quantized_var_abs_max[var_name] = abs_max_value
+            if (var_name not in self._quantized_threshold) or \
+                    (abs_max_value > self._quantized_threshold[var_name]):
+                self._quantized_threshold[var_name] = abs_max_value
 
     def _sample_min_max(self):
         if self._quantized_var_min == {} and self._quantized_var_max == {}:
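Note: _sample_mse above sweeps candidate scales from 0.3x to 1.0x of the absolute maximum in steps of 0.02 and keeps the scale with the smallest quantize-dequantize MSE. A self-contained NumPy version of that search on made-up data follows; one deliberate difference is flagged in the comment: it clips symmetrically to [-scale, scale], while the patch clips to [0.0, scale], which only matches non-negative (e.g. post-ReLU) activations:

    import numpy as np

    var_tensor = np.random.randn(10000).astype("float32")  # fake flattened activation
    abs_max_value = float(np.max(np.abs(var_tensor)))
    bins = 2 ** (8 - 1) - 1        # 127 for activation_bits=8
    best_loss = float("inf")
    best_scale = abs_max_value
    s = 0.3
    while s <= 1.0:
        scale = s * abs_max_value
        s += 0.02
        # simulate symmetric quantize-dequantize (the patch clips at 0.0 instead)
        quant_dequant = np.round(
            np.clip(var_tensor, -scale, scale) / scale * bins) / bins * scale
        mse_loss = ((var_tensor - quant_dequant) ** 2).mean()
        if mse_loss <= best_loss:
            best_loss, best_scale = mse_loss, scale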
+ _logger.info("Calculate {} threshold ...".format(self._algo)) + assert self._algo in ["KL", "hist"], "The algo should be KL or hist." # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: @@ -669,18 +758,22 @@ def _calculate_kl_threshold(self): for i in range(weight_data.shape[0]): weight_threshold.append( float(np.max(np.abs(weight_data[i])))) - self._quantized_var_kl_threshold[var_name] = weight_threshold + self._quantized_var_threshold[var_name] = weight_threshold for var_name in self._quantized_act_var_name: hist, hist_edeges = self._sampling_act_histogram[var_name] - self._quantized_var_kl_threshold[var_name] = \ - self._get_kl_scaling_factor(hist, hist_edeges) + if self._algo == "KL": + self._quantized_var_threshold[var_name] = \ + self._get_kl_scaling_factor(hist, hist_edeges) + elif self._algo == "hist": + self._quantized_var_threshold[var_name] = \ + self._get_hist_scaling_factor(hist, hist_edeges) def _update_program(self): ''' Use QuantizationTransformPass and AddQuantDequantPass to insert fake_quantize, fake_dequantize and fake_quant_dequant op. - Besides, save all kl threshold to the scale var node. + Besides, save all threshold to the scale var node. ''' _logger.info("Update the program ...") graph = IrGraph(core.Graph(self._program.desc), for_test=True) @@ -711,11 +804,11 @@ def _update_program(self): quantizable_op_type=minor_quantizable_op_types) add_quant_dequant_pass.apply(graph) - # save abs_max or KL threshold to scale var node - if self._algo == "KL": - scale_dict = self._quantized_var_kl_threshold + # save threshold to scale var node + if self._algo in ["KL", "hist"]: + scale_dict = self._quantized_var_threshold else: - scale_dict = self._quantized_var_abs_max + scale_dict = self._quantized_threshold for key, val in scale_dict.items(): _set_variable_data( self._scope, @@ -734,6 +827,7 @@ def _update_program(self): freeze_pass = QuantizationFreezePass( scope=self._scope, place=self._place, + bias_correction=self._bias_correction, weight_bits=self._weight_bits, activation_bits=self._activation_bits, weight_quantize_type=self._weight_quantize_type, @@ -761,20 +855,28 @@ def analysis_and_save_info(op_node, out_var_name): out_var_name + " is not the output of the op" if self._algo == "KL": # For compatibility, we save output threshold by two methods. - save_info(op_node, out_var_name, - self._quantized_var_kl_threshold, "out_threshold", - "post_kl") + save_info(op_node, out_var_name, self._quantized_var_threshold, + "out_threshold", "post_kl") save_info( - op_node, out_var_name, self._quantized_var_kl_threshold, + op_node, out_var_name, self._quantized_var_threshold, argname_index[0] + str(argname_index[1]) + "_threshold", "post_kl") - elif self._algo == "abs_max": - save_info(op_node, out_var_name, self._quantized_var_abs_max, - "out_threshold", "post_abs_max") + elif self._algo == "hist": + # For compatibility, we save output threshold by two methods. 
+ save_info(op_node, out_var_name, self._quantized_var_threshold, + "out_threshold", "post_hist") save_info( - op_node, out_var_name, self._quantized_var_abs_max, + op_node, out_var_name, self._quantized_var_threshold, argname_index[0] + str(argname_index[1]) + "_threshold", - "post_kl") + "post_hist") + + elif self._algo in ["avg", "abs_max", "mse"]: + save_info(op_node, out_var_name, self._quantized_threshold, + "out_threshold", "post_" + str(self._algo)) + save_info( + op_node, out_var_name, self._quantized_threshold, + argname_index[0] + str(argname_index[1]) + "_threshold", + "post_" + str(self._algo)) elif self._algo == "min_max": save_info(op_node, out_var_name, self._quantized_var_min, "out_min", "post_min_max") @@ -817,10 +919,27 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type): op._set_attr("quantization_type", quantization_type) op._set_attr("bit_length", self._weight_bits) - def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255): + def _get_hist_scaling_factor(self, hist, hist_edges): + ''' + Using the hist method to get the scaling factor. + ''' + threshold_rate = self._hist_percent + hist = hist / float(sum(hist)) + hist_sum = 0 + hist_index = 0 + for i in range(len(hist)): + hist_sum += hist[i] + if hist_sum >= threshold_rate: + hist_index = i + 1 + break + bin_width = hist_edges[1] - hist_edges[0] + return (hist_index - 0.5) * bin_width + + def _get_kl_scaling_factor(self, hist, hist_edeges): ''' Using the KL-divergenc method to get the more precise scaling factor. ''' + num_quantized_bins = 2**(self._activation_bits - 1) - 1 ending_iter = self._histogram_bins - 1 starting_iter = int(ending_iter * 0.7) bin_width = hist_edeges[1] - hist_edeges[0] diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 3f9ff7295dd6b..79aad8c8bc53d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1070,6 +1070,7 @@ class QuantizationFreezePass(object): def __init__(self, scope, place, + bias_correction=False, weight_bits=8, activation_bits=8, weight_quantize_type='abs_max', @@ -1085,6 +1086,8 @@ def __init__(self, scope(fluid.Scope): scope is used to get the weight tensor values. place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors. If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs. + bias_correction(bool): whether use bias correction for post-training quantization. + https://arxiv.org/abs/1810.05723. weight_bits(int): quantization bit number for weights. activation_bits(int): quantization bit number for activation. weight_quantize_type(str): quantization type for weights, support 'abs_max' and @@ -1098,6 +1101,7 @@ def __init__(self, assert place is not None, \ 'The place cannot be set None.' 
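Note: a standalone check of the rule implemented in _get_hist_scaling_factor above: normalize the histogram of absolute activation values, find the first bin where the cumulative mass reaches hist_percent, and return that bin's midpoint. The data and the 2048-bin choice (self._histogram_bins in the patch) are illustrative:

    import numpy as np

    acts = np.abs(np.random.randn(100000))  # fake sampled |activations|
    hist, hist_edges = np.histogram(acts, bins=2048, range=(0.0, float(acts.max())))
    hist = hist / float(sum(hist))
    cum = np.cumsum(hist)
    hist_index = int(np.argmax(cum >= 0.99999)) + 1   # first bin reaching the percentile
    bin_width = hist_edges[1] - hist_edges[0]
    threshold = (hist_index - 0.5) * bin_width        # bin midpoint, as returned above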
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 3f9ff7295dd6b..79aad8c8bc53d 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1070,6 +1070,7 @@ class QuantizationFreezePass(object):
     def __init__(self,
                  scope,
                  place,
+                 bias_correction=False,
                  weight_bits=8,
                  activation_bits=8,
                  weight_quantize_type='abs_max',
@@ -1085,6 +1086,8 @@ def __init__(self,
         scope(fluid.Scope): scope is used to get the weight tensor values.
         place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors.
             If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
+        bias_correction(bool): whether to use bias correction for post-training quantization.
+            https://arxiv.org/abs/1810.05723.
         weight_bits(int): quantization bit number for weights.
         activation_bits(int): quantization bit number for activation.
         weight_quantize_type(str): quantization type for weights, support 'abs_max' and
@@ -1098,6 +1101,7 @@ def __init__(self,
         assert place is not None, \
             'The place cannot be set None.'
         self._scope = scope
+        self._bias_correction = bias_correction
         self._place = _get_paddle_place(place)
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
@@ -1154,7 +1158,10 @@ def apply(self, graph):
                     else:
                         quant_axis = 0
                     quantized_param_v = self._quant(
-                        param_v, scale_v, self._weight_bits, quant_axis)
+                        param_v.copy(), scale_v, self._weight_bits, quant_axis)
+                    if self._bias_correction == True:
+                        quantized_param_v = self._bias_correction_w(
+                            param_v, quantized_param_v, scale_v, quant_axis)
                     self._restore_var(input_arg_name, quantized_param_v)
                     self._remove_fake_quant_and_dequant_op(graph, op_node)
@@ -1373,6 +1380,8 @@ def _clip(x, scale):
 
         if isinstance(scale, list):
             for i, s in enumerate(scale):
+                if s == 0.0:
+                    s = 1e-8
                 if quant_axis == 0:
                     x[i] = _clip(x[i], s)
                     x[i] = np.round(x[i] / s * bnt)
@@ -1384,6 +1393,46 @@ def _clip(x, scale):
             x = np.round(x / scale * bnt)
         return x
 
+    def _bias_correction_w(self, x, x_quant, scale_v, quant_axis):
+        '''
+        Bias correction for weights
+        '''
+        eps = 1e-8
+        bnt = (1 << (self._weight_bits - 1)) - 1
+        x_dequant = x_quant.copy()
+        if isinstance(scale_v, list):
+            if quant_axis == 0:
+                for i, s in enumerate(scale_v):
+                    x_dequant[i] = x_dequant[i] * s / bnt
+                quant_bias = x - x_dequant
+                mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1)
+                std_orig = x.reshape(x.shape[0], -1).std(-1)
+                std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1)
+                std_bias = std_orig / (std_quant + eps)
+            else:
+                for i, s in enumerate(scale_v):
+                    x_dequant[:, i] = x_quant[:, i] * s / bnt
+                quant_bias = x - x_dequant
+                mean_bias = np.array([
+                    quant_bias[:, i].mean() for i in range(quant_bias.shape[1])
+                ])
+                std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
+                std_quant = np.array(
+                    [x_dequant[:, i].std() for i in range(x_dequant.shape[1])])
+                std_bias = std_orig / (std_quant + eps)
+        else:
+            x_dequant = x_quant * scale_v / bnt
+            mean_bias = (x - x_dequant).mean()
+            std_bias = x.std() / (x_dequant.std() + eps)
+        if mean_bias.ndim == 1:
+            std_bias = np.resize(std_bias, x.shape)
+            mean_bias = np.resize(mean_bias, x.shape)
+
+        x_dequant = (mean_bias + x_dequant) * std_bias
+        quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits,
+                                        quant_axis)
+        return quantized_param_v
+
 class ConvertToInt8Pass(object):
     def __init__(self, scope, place, quantizable_op_type=None):
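Note: a scalar-scale walk-through of _bias_correction_w above on made-up weights, following https://arxiv.org/abs/1810.05723: dequantize, shift by the mean quantization error, rescale by the ratio of standard deviations, then re-quantize. The per-channel branches do the same per output (quant_axis=0) or input (quant_axis=1) channel:

    import numpy as np

    bits, eps = 8, 1e-8
    bnt = (1 << (bits - 1)) - 1                    # 127
    x = np.random.randn(64, 64).astype("float32")  # fake weight tensor
    scale = float(np.max(np.abs(x)))
    x_quant = np.round(x / scale * bnt)            # symmetric abs_max quantization
    x_dequant = x_quant * scale / bnt
    mean_bias = (x - x_dequant).mean()             # first-moment correction
    std_bias = x.std() / (x_dequant.std() + eps)   # second-moment correction
    x_corrected = (mean_bias + x_dequant) * std_bias
    quantized_param = np.round(np.clip(x_corrected, -scale, scale) / scale * bnt)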
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
index 3ea1c84f976a8..da5c5d6dc9441 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
@@ -204,6 +204,66 @@ def test_post_training_kl(self):
                       quant_iterations)
 
 
+class TestPostTraininghistForMnist(TestPostTrainingQuantization):
+    def test_post_training_hist(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "hist"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
+class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
+    def test_post_training_mse(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "mse"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
+class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
+    def test_post_training_avg(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "avg"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
 class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
     def test_post_training_abs_max(self):
         model_name = "mnist_model"
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
index 18389d9433b9a..7161104861006 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
@@ -328,6 +328,50 @@ def test_post_training_kl_mobilenetv1(self):
                       diff_threshold)
 
 
+class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_avg_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "avg"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
+class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_hist_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "hist"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
     def test_post_training_abs_max_mobilenetv1(self):
         model = "MobileNet-V1"
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
index 768a9ba7cfc3e..790213d4b0292 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
@@ -257,6 +257,7 @@ def freeze_graph(self,
                      use_cuda,
                      seed,
                      activation_quant_type,
+                     bias_correction=False,
                      weight_quant_type='abs_max',
                      for_ci=True,
                      quant_skip_pattern='skip_quant'):
@@ -355,7 +356,8 @@ def build_program(main, startup, is_test):
         # Freeze graph for inference, but the weight of fc/conv is still float type.
         freeze_pass = QuantizationFreezePass(
-            scope=scope, place=place, weight_quantize_type=weight_quant_type)
+            scope=scope, place=place, bias_correction=bias_correction, \
+            weight_quantize_type=weight_quant_type)
         freeze_pass.apply(test_graph)
         if not for_ci:
             marked_nodes = set()
@@ -472,6 +474,13 @@ def test_freeze_graph_cpu_dynamic(self):
     def test_freeze_graph_cuda_static(self):
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
+                self.freeze_graph(
+                    True,
+                    seed=1,
+                    activation_quant_type='range_abs_max',
+                    bias_correction=True,
+                    weight_quant_type='abs_max',
+                    for_ci=True)
                 self.freeze_graph(
                     True,
                     seed=1,
@@ -496,6 +505,13 @@ def test_freeze_graph_cuda_static(self):
                     activation_quant_type='moving_average_abs_max',
                     weight_quant_type='channel_wise_abs_max',
                     for_ci=True)
+                self.freeze_graph(
+                    True,
+                    seed=1,
+                    activation_quant_type='moving_average_abs_max',
+                    bias_correction=True,
+                    weight_quant_type='channel_wise_abs_max',
+                    for_ci=True)
 
     def test_freeze_graph_cpu_static(self):
         with fluid.unique_name.guard():