Skip to content

Commit

Permalink
Tensordot testing with and without hardcoded types.
Browse files Browse the repository at this point in the history
  • Loading branch information
alxmrs committed Jul 31, 2024
1 parent a8e9998 commit e78cd21
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,13 @@ def test_outer(spec, executor):


@pytest.mark.parametrize("axes", [1, (1, 0)])
def test_tensordot(axes):
x = np.arange(400, dtype=np.float32).reshape((20, 20))
a = xp.asarray(x, chunks=(5, 4), dtype=xp.float32)
y = np.arange(200, dtype=np.float32).reshape((20, 10))
b = xp.asarray(y, chunks=(4, 5), dtype=xp.float32)
@pytest.mark.parametrize("dtypes", [(None, None), (np.float32, xp.float32)])
def test_tensordot(axes, dtypes):
ntype, xtype = dtypes
x = np.arange(400, dtype=ntype).reshape((20, 20))
a = xp.asarray(x, chunks=(5, 4), dtype=xtype)
y = np.arange(200, dtype=ntype).reshape((20, 10))
b = xp.asarray(y, chunks=(4, 5), dtype=xtype)
assert_array_equal(
xp.tensordot(a, b, axes=axes).compute(), np.tensordot(x, y, axes=axes)
)
Expand Down

0 comments on commit e78cd21

Please sign in to comment.