Skip to content

Commit

Permalink
[PIR]Migrate einsum_v2 into pir (#58501)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Oct 31, 2023
1 parent 7d8053c commit 3f6163b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/paddle/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from paddle import _C_ops

from ..base.data_feeder import check_type, check_variable_and_dtype
from ..base.framework import in_dygraph_mode
from ..base.framework import in_dynamic_or_pir_mode
from ..base.layer_helper import LayerHelper
from .linalg import matmul, transpose
from .manipulation import reshape, squeeze, unsqueeze
Expand Down Expand Up @@ -832,7 +832,7 @@ def gen_einsum_op(equation, *operands):
EinsumOp Python Interface:
"""

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
return _C_ops.einsum(operands, equation)[0]
else:
assert len(operands) <= 2, "Only support two operands in EinsumOp."
Expand Down
7 changes: 5 additions & 2 deletions test/legacy_test/test_einsum_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

os.environ['FLAGS_new_einsum'] = "1"

Expand Down Expand Up @@ -382,7 +383,7 @@ def check_output_equal(self, actual, expect, rtol=1.0e-5, atol=1.0e-8):
rtol=rtol,
atol=atol,
err_msg=error_msg.format(
paddle.get_device(), expect, actual, self.__class__.__name__
self._get_place(False), expect, actual, self.__class__.__name__
),
)

Expand Down Expand Up @@ -465,6 +466,7 @@ def test_sums(self):
self.check_output("i,ij->", y, x)
self.check_output("ij,i->", x, y)

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
base = paddle.base
Expand Down Expand Up @@ -523,11 +525,12 @@ def setUp(self):
def tearDown(self):
paddle.disable_static()

@test_with_pir_api
def test_shape(self):
A = paddle.static.data(name='x', shape=[-1])
B = paddle.static.data(name='y', shape=[384])
C = paddle.einsum('i,d->id', A, B)
self.assertEqual(C.shape, (-1, 384))
self.assertEqual(tuple(C.shape), (-1, 384))


@unittest.skipIf(
Expand Down

0 comments on commit 3f6163b

Please sign in to comment.