Skip to content

Commit

Permalink
Revert tensordot change.
Browse files Browse the repository at this point in the history
  • Loading branch information
alxmrs committed Jul 31, 2024
1 parent e189c09 commit 9a7a419
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,11 @@ def test_outer(spec, executor):


@pytest.mark.parametrize("axes", [1, (1, 0)])
@pytest.mark.parametrize("dtypes", [(None, None), (np.float32, xp.float32)])
def test_tensordot(axes, dtypes):
ntype, xtype = dtypes
x = np.arange(400, dtype=ntype).reshape((20, 20))
a = xp.asarray(x, chunks=(5, 4), dtype=xtype)
y = np.arange(200, dtype=ntype).reshape((20, 10))
b = xp.asarray(y, chunks=(4, 5), dtype=xtype)
def test_tensordot(axes):
x = np.arange(400).reshape((20, 20))
a = xp.asarray(x, chunks=(5, 4))
y = np.arange(200).reshape((20, 10))
b = xp.asarray(y, chunks=(4, 5))
assert_array_equal(
xp.tensordot(a, b, axes=axes).compute(), np.tensordot(x, y, axes=axes)
)
Expand Down

0 comments on commit 9a7a419

Please sign in to comment.