diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 8f962da6b6766..862dfcdf3d1b4 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -15,7 +15,11 @@ from paddle import _C_ops from paddle.base.data_feeder import check_dtype from paddle.base.framework import convert_np_dtype_to_dtype_ -from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.framework import ( + LayerHelper, + in_dynamic_mode, + in_dynamic_or_pir_mode, +) def weight_quantize(x, algo="weight_only_int8"): @@ -217,7 +221,7 @@ def llm_int8_linear( ... print(out.shape) [1, 2, 32] """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold) return out else: diff --git a/test/quantization/test_llm_int8_linear.py b/test/quantization/test_llm_int8_linear.py index e79b802d37243..5a35b0d512461 100644 --- a/test/quantization/test_llm_int8_linear.py +++ b/test/quantization/test_llm_int8_linear.py @@ -23,6 +23,7 @@ from paddle.base import core from paddle.base.framework import default_main_program from paddle.framework import set_default_dtype +from paddle.pir_utils import test_with_pir_api np.random.seed(123) paddle.seed(123) @@ -86,11 +87,12 @@ def get_llm_int8_linear_out(self): ) return out.numpy() + @test_with_pir_api def get_llm_int8_linear_out_static(self): paddle.enable_static() - main = base.Program() - start = base.Program() - with base.program_guard(main, start): + main = base.static.Program() + start = base.static.Program() + with base.static.program_guard(main, start): x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype) weight = paddle.static.data(