Skip to content

Commit

Permalink
[PIR]Migrate any into pir (PaddlePaddle#58211)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored and jiahy0825 committed Oct 26, 2023
1 parent efc39c6 commit 263067b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4833,7 +4833,7 @@ def any(x, axis=None, keepdim=False, name=None):
[True]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.any(x, axis, keepdim)
else:
reduce_all, axis = _get_reduce_axis(axis, x)
Expand Down
37 changes: 22 additions & 15 deletions test/legacy_test/test_reduce_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def setUp(self):
self.attrs = {'reduce_all': True}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyFloatOp(OpTest):
Expand All @@ -977,7 +977,7 @@ def setUp(self):
self.attrs = {'reduce_all': True}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyIntOp(OpTest):
Expand All @@ -989,7 +989,7 @@ def setUp(self):
self.attrs = {'reduce_all': True}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyOp_ZeroDim(OpTest):
Expand All @@ -1001,7 +1001,7 @@ def setUp(self):
self.attrs = {'dim': []}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAny8DOp(OpTest):
Expand All @@ -1017,7 +1017,7 @@ def setUp(self):
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyOpWithDim(OpTest):
Expand All @@ -1029,7 +1029,7 @@ def setUp(self):
self.outputs = {'Out': self.inputs['X'].any(axis=1)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAny8DOpWithDim(OpTest):
Expand All @@ -1045,7 +1045,7 @@ def setUp(self):
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyOpWithKeepDim(OpTest):
Expand All @@ -1061,7 +1061,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAny8DOpWithKeepDim(OpTest):
Expand All @@ -1081,7 +1081,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestAnyOpError(unittest.TestCase):
Expand Down Expand Up @@ -1814,21 +1814,25 @@ def setUp(self):
self.places.append(base.CUDAPlace(0))

def check_static_result(self, place):
with base.program_guard(base.Program(), base.Program()):
main = paddle.static.Program()
startup = paddle.static.Program()
with base.program_guard(main, startup):
input = paddle.static.data(name="input", shape=[4, 4], dtype="bool")
result = paddle.any(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("bool")

exe = base.Executor(place)
fetches = exe.run(
base.default_main_program(),
main,
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.any(input_np)).all())

def check_static_float_result(self, place):
with base.program_guard(base.Program(), base.Program()):
main = paddle.static.Program()
startup = paddle.static.Program()
with base.program_guard(main, startup):
input = paddle.static.data(
name="input", shape=[4, 4], dtype="float"
)
Expand All @@ -1837,26 +1841,29 @@ def check_static_float_result(self, place):

exe = base.Executor(place)
fetches = exe.run(
base.default_main_program(),
main,
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.any(input_np)).all())

def check_static_int_result(self, place):
with base.program_guard(base.Program(), base.Program()):
main = paddle.static.Program()
startup = paddle.static.Program()
with base.program_guard(main, startup):
input = paddle.static.data(name="input", shape=[4, 4], dtype="int")
result = paddle.any(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("int")

exe = base.Executor(place)
fetches = exe.run(
base.default_main_program(),
main,
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.any(input_np)).all())

@test_with_pir_api
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
Expand Down

0 comments on commit 263067b

Please sign in to comment.