Skip to content

Commit

Permalink
Switch to numpy assert_allclose
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Oct 12, 2024
1 parent 035139e commit abc44ab
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,23 +2511,23 @@ def test_local_sum_prod_all_to_none(self):
# test sum
f = function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.sum())
np.testing.assert_allclose(f(input), input.sum())
# test prod
f = function([a], a.prod(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.prod())
np.testing.assert_allclose(f(input), input.prod())
# test sum
f = function([a], a.sum([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.sum())
np.testing.assert_allclose(f(input), input.sum())
# test prod
f = function([a], a.prod([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.prod())
np.testing.assert_allclose(f(input), input.prod())

f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.sum())
np.testing.assert_allclose(f(input), input.sum())

def test_local_sum_sum_prod_prod(self):
a = tensor3()
Expand Down Expand Up @@ -2582,54 +2582,54 @@ def my_sum_prod(data, d, dd):
for d, dd in dims:
expected = my_sum(input, d, dd)
f = function([a], a.sum(d).sum(dd), mode=self.mode)
utt.assert_allclose(f(input), expected)
np.testing.assert_allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
np.testing.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = function([a], a.sum(d).sum(None), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum())
np.testing.assert_allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1
f = function([a], a.sum(None).sum(), mode=self.mode)
utt.assert_allclose(f(input), input.sum())
np.testing.assert_allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1

# test prod
for d, dd in dims:
expected = my_prod(input, d, dd)
f = function([a], a.prod(d).prod(dd), mode=self.mode)
utt.assert_allclose(f(input), expected)
np.testing.assert_allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode)
utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
np.testing.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = function([a], a.prod(d).prod(None), mode=self.mode)
utt.assert_allclose(f(input), input.prod(d).prod())
np.testing.assert_allclose(f(input), input.prod(d).prod())
assert len(f.maker.fgraph.apply_nodes) == 1
f = function([a], a.prod(None).prod(), mode=self.mode)
utt.assert_allclose(f(input), input.prod())
np.testing.assert_allclose(f(input), input.prod())
assert len(f.maker.fgraph.apply_nodes) == 1

# Test that sum prod didn't get rewritten.
for d, dd in dims:
expected = my_sum_prod(input, d, dd)
f = function([a], a.sum(d).prod(dd), mode=self.mode)
utt.assert_allclose(f(input), expected)
np.testing.assert_allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 2
for d, dd in dims[:6]:
f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
np.testing.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
assert len(f.maker.fgraph.apply_nodes) == 2
for d in [0, 1, 2]:
f = function([a], a.sum(d).prod(None), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).prod())
np.testing.assert_allclose(f(input), input.sum(d).prod())
assert len(f.maker.fgraph.apply_nodes) == 2
f = function([a], a.sum(None).prod(), mode=self.mode)
utt.assert_allclose(f(input), input.sum())
np.testing.assert_allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1

def test_local_sum_sum_int8(self):
Expand Down

0 comments on commit abc44ab

Please sign in to comment.