diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 8500b4d7..b8c68dbe 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -424,11 +424,13 @@ def test_outer(spec, executor): @pytest.mark.parametrize("axes", [1, (1, 0)]) -def test_tensordot(axes): - x = np.arange(400, dtype=np.float32).reshape((20, 20)) - a = xp.asarray(x, chunks=(5, 4), dtype=xp.float32) - y = np.arange(200, dtype=np.float32).reshape((20, 10)) - b = xp.asarray(y, chunks=(4, 5), dtype=xp.float32) +@pytest.mark.parametrize("dtypes", [(None, None), (np.float32, xp.float32)]) +def test_tensordot(axes, dtypes): + ntype, xtype = dtypes + x = np.arange(400, dtype=ntype).reshape((20, 20)) + a = xp.asarray(x, chunks=(5, 4), dtype=xtype) + y = np.arange(200, dtype=ntype).reshape((20, 10)) + b = xp.asarray(y, chunks=(4, 5), dtype=xtype) assert_array_equal( xp.tensordot(a, b, axes=axes).compute(), np.tensordot(x, y, axes=axes) )