From 646454b7df331aba65db223d68031b08f4467cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Thu, 16 Nov 2023 14:15:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.127=E3=80=91l?= =?UTF-8?q?lm=5Fint8=5Flinear=20(#58882)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update quantized_linear.py * Update test_llm_int8_linear.py * Update quantized_linear.py * Update test_llm_int8_linear.py --- python/paddle/nn/quant/quantized_linear.py | 8 ++++++-- test/quantization/test_llm_int8_linear.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 8f962da6b6766..862dfcdf3d1b4 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -15,7 +15,11 @@ from paddle import _C_ops from paddle.base.data_feeder import check_dtype from paddle.base.framework import convert_np_dtype_to_dtype_ -from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.framework import ( + LayerHelper, + in_dynamic_mode, + in_dynamic_or_pir_mode, +) def weight_quantize(x, algo="weight_only_int8"): @@ -217,7 +221,7 @@ def llm_int8_linear( ... print(out.shape) [1, 2, 32] """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold) return out else: diff --git a/test/quantization/test_llm_int8_linear.py b/test/quantization/test_llm_int8_linear.py index e79b802d37243..5a35b0d512461 100644 --- a/test/quantization/test_llm_int8_linear.py +++ b/test/quantization/test_llm_int8_linear.py @@ -23,6 +23,7 @@ from paddle.base import core from paddle.base.framework import default_main_program from paddle.framework import set_default_dtype +from paddle.pir_utils import test_with_pir_api np.random.seed(123) paddle.seed(123) @@ -86,11 +87,12 @@ def get_llm_int8_linear_out(self): ) return out.numpy() + @test_with_pir_api def get_llm_int8_linear_out_static(self): paddle.enable_static() - main = base.Program() - start = base.Program() - with base.program_guard(main, start): + main = base.static.Program() + start = base.static.Program() + with base.static.program_guard(main, start): x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype) weight = paddle.static.data(