Skip to content

Commit

Permalink
Bring vecdot implementation in line with the one in array-api-compat (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Sep 27, 2024
1 parent b4e94b0 commit d386917
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _numeric_dtypes
from cubed.array_api.manipulation_functions import expand_dims
from cubed.array_api.manipulation_functions import (
broadcast_arrays,
expand_dims,
moveaxis,
)
from cubed.backend_array_api import namespace as nxp
from cubed.core import blockwise, reduction, squeeze

Expand Down Expand Up @@ -158,12 +162,21 @@ def _tensordot(a, b, axes):


def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
# based on the implementation in array-api-compat
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in vecdot")
return tensordot(
x1,
x2,
axes=((axis,), (axis,)),

if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

x1_ = moveaxis(x1, axis, -1)
x2_ = moveaxis(x2, axis, -1)
x1_, x2_ = broadcast_arrays(x1_, x2_)

res = matmul(
x1_[..., None, :],
x2_[..., None],
use_new_impl=use_new_impl,
split_every=split_every,
)
return res[..., 0, 0]

0 comments on commit d386917

Please sign in to comment.