From abc44ab1b0898360c9d74921cf8cee3d40aec08d Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 12 Oct 2024 17:16:33 +0530 Subject: [PATCH] Switch to numpy assert_allclose --- tests/tensor/rewriting/test_math.py | 34 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 2a10279d62..6f5afe7779 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2511,23 +2511,23 @@ def test_local_sum_prod_all_to_none(self): # test sum f = function([a], a.sum(), mode=self.mode) assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) + np.testing.assert_allclose(f(input), input.sum()) # test prod f = function([a], a.prod(), mode=self.mode) assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.prod()) + np.testing.assert_allclose(f(input), input.prod()) # test sum f = function([a], a.sum([0, 1, 2]), mode=self.mode) assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) + np.testing.assert_allclose(f(input), input.sum()) # test prod f = function([a], a.prod([0, 1, 2]), mode=self.mode) assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.prod()) + np.testing.assert_allclose(f(input), input.prod()) f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode) assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) + np.testing.assert_allclose(f(input), input.sum()) def test_local_sum_sum_prod_prod(self): a = tensor3() @@ -2582,54 +2582,54 @@ def my_sum_prod(data, d, dd): for d, dd in dims: expected = my_sum(input, d, dd) f = function([a], a.sum(d).sum(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) + np.testing.assert_allclose(f(input), expected) assert len(f.maker.fgraph.apply_nodes) == 1 for d, dd in dims[:6]: f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0)) + np.testing.assert_allclose(f(input), input.sum(d).sum(dd).sum(0)) assert len(f.maker.fgraph.apply_nodes) == 1 for d in [0, 1, 2]: f = function([a], a.sum(d).sum(None), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).sum()) + np.testing.assert_allclose(f(input), input.sum(d).sum()) assert len(f.maker.fgraph.apply_nodes) == 1 f = function([a], a.sum(None).sum(), mode=self.mode) - utt.assert_allclose(f(input), input.sum()) + np.testing.assert_allclose(f(input), input.sum()) assert len(f.maker.fgraph.apply_nodes) == 1 # test prod for d, dd in dims: expected = my_prod(input, d, dd) f = function([a], a.prod(d).prod(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) + np.testing.assert_allclose(f(input), expected) assert len(f.maker.fgraph.apply_nodes) == 1 for d, dd in dims[:6]: f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode) - utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0)) + np.testing.assert_allclose(f(input), input.prod(d).prod(dd).prod(0)) assert len(f.maker.fgraph.apply_nodes) == 1 for d in [0, 1, 2]: f = function([a], a.prod(d).prod(None), mode=self.mode) - utt.assert_allclose(f(input), input.prod(d).prod()) + np.testing.assert_allclose(f(input), input.prod(d).prod()) assert len(f.maker.fgraph.apply_nodes) == 1 f = function([a], a.prod(None).prod(), mode=self.mode) - utt.assert_allclose(f(input), input.prod()) + np.testing.assert_allclose(f(input), input.prod()) assert len(f.maker.fgraph.apply_nodes) == 1 # Test that sum prod didn't get rewritten. for d, dd in dims: expected = my_sum_prod(input, d, dd) f = function([a], a.sum(d).prod(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) + np.testing.assert_allclose(f(input), expected) assert len(f.maker.fgraph.apply_nodes) == 2 for d, dd in dims[:6]: f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0)) + np.testing.assert_allclose(f(input), input.sum(d).prod(dd).prod(0)) assert len(f.maker.fgraph.apply_nodes) == 2 for d in [0, 1, 2]: f = function([a], a.sum(d).prod(None), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).prod()) + np.testing.assert_allclose(f(input), input.sum(d).prod()) assert len(f.maker.fgraph.apply_nodes) == 2 f = function([a], a.sum(None).prod(), mode=self.mode) - utt.assert_allclose(f(input), input.sum()) + np.testing.assert_allclose(f(input), input.sum()) assert len(f.maker.fgraph.apply_nodes) == 1 def test_local_sum_sum_int8(self):