From 2f7c324265ae58136022159b0fc8f227e3118d46 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Wed, 31 Jul 2024 20:38:40 +0100 Subject: [PATCH] Tensordot testing with and without hardcoded types. --- cubed/tests/test_array_api.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 8500b4d7..b8c68dbe 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -424,11 +424,13 @@ def test_outer(spec, executor): @pytest.mark.parametrize("axes", [1, (1, 0)]) -def test_tensordot(axes): - x = np.arange(400, dtype=np.float32).reshape((20, 20)) - a = xp.asarray(x, chunks=(5, 4), dtype=xp.float32) - y = np.arange(200, dtype=np.float32).reshape((20, 10)) - b = xp.asarray(y, chunks=(4, 5), dtype=xp.float32) +@pytest.mark.parametrize("dtypes", [(None, None), (np.float32, xp.float32)]) +def test_tensordot(axes, dtypes): + ntype, xtype = dtypes + x = np.arange(400, dtype=ntype).reshape((20, 20)) + a = xp.asarray(x, chunks=(5, 4), dtype=xtype) + y = np.arange(200, dtype=ntype).reshape((20, 10)) + b = xp.asarray(y, chunks=(4, 5), dtype=xtype) assert_array_equal( xp.tensordot(a, b, axes=axes).compute(), np.tensordot(x, y, axes=axes) )