diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py index ecf14ec02a7..966baf0e53a 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py @@ -21,7 +21,7 @@ from torch.ao.quantization.quantizer import QuantizationSpec from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer -from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2 +from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec: @@ -102,8 +102,8 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86Induct # set global global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic) quantizer.set_global(global_config) - # need torch >= 2.3.2 - if GT_TORCH_VERSION_2_3_2: # pragma: no cover + # need torch >= 2.5 + if GT_OR_EQUAL_TORCH_VERSION_2_5: # pragma: no cover op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() if op_type_config_dict: for op_type, config in op_type_config_dict.items(): diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py index aa6b9affba0..444aaa95f3d 100644 --- a/neural_compressor/torch/utils/environ.py +++ b/neural_compressor/torch/utils/environ.py @@ -104,7 +104,7 @@ def get_torch_version(): return version -GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2") +GT_OR_EQUAL_TORCH_VERSION_2_5 = get_torch_version() >= Version("2.5") def get_accelerator(device_name="auto"): diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py index d55e9004a3a..ab80e991203 100644 --- a/test/3x/torch/quantization/test_pt2e_quant.py +++ b/test/3x/torch/quantization/test_pt2e_quant.py @@ -15,7 +15,7 @@ prepare, quantize, ) -from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version +from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, TORCH_VERSION_2_2_2, get_torch_version torch.manual_seed(0) @@ -131,7 +131,7 @@ def calib_fn(model): logger.warning("out shape is %s", out.shape) assert out is not None - @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.2") + @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5") def test_quantize_simple_model_with_set_local(self, force_not_import_ipex): model, example_inputs = self.build_simple_torch_model_and_example_inputs() float_model_output = model(*example_inputs) @@ -243,7 +243,7 @@ def get_node_in_graph(graph_module): nodes_in_graph[n] = 1 return nodes_in_graph - @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.0") + @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5") def test_mixed_fp16_and_int8(self, force_not_import_ipex): model, example_inputs = self.build_model_include_conv_and_linear() model = export(model, example_inputs=example_inputs)