Skip to content

Commit

Permalink
【PIR API adaptor No.127】llm_int8_linear (PaddlePaddle#58882)
Browse files (browse the repository at this point in the history)
* Update quantized_linear.py

* Update test_llm_int8_linear.py

* Update quantized_linear.py

* Update test_llm_int8_linear.py
  • Loading branch information
Liyulingyue authored and SecretXV committed Nov 28, 2023
1 parent 0e325ae commit 646454b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
8 changes: 6 additions & 2 deletions python/paddle/nn/quant/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from paddle import _C_ops
from paddle.base.data_feeder import check_dtype
from paddle.base.framework import convert_np_dtype_to_dtype_
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import (
+LayerHelper,
+in_dynamic_mode,
+in_dynamic_or_pir_mode,
+)


def weight_quantize(x, algo="weight_only_int8"):
Expand Down Expand Up @@ -217,7 +221,7 @@ def llm_int8_linear(
... print(out.shape)
[1, 2, 32]
"""
-if in_dynamic_mode():
+if in_dynamic_or_pir_mode():
out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold)
return out
else:
Expand Down
8 changes: 5 additions & 3 deletions test/quantization/test_llm_int8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle.base import core
from paddle.base.framework import default_main_program
from paddle.framework import set_default_dtype
+from paddle.pir_utils import test_with_pir_api

np.random.seed(123)
paddle.seed(123)
Expand Down Expand Up @@ -86,11 +87,12 @@ def get_llm_int8_linear_out(self):
)
return out.numpy()

+@test_with_pir_api
def get_llm_int8_linear_out_static(self):
paddle.enable_static()
-main = base.Program()
-start = base.Program()
-with base.program_guard(main, start):
+main = base.static.Program()
+start = base.static.Program()
+with base.static.program_guard(main, start):
x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)

weight = paddle.static.data(
Expand Down

0 comments on commit 646454b

Please sign in to comment.