Skip to content

Commit

Permalink
Add dpnp.linalg.tensorsolve() implementation (#1753)
Browse files Browse the repository at this point in the history
* add dpnp.linalg.tensorsolve impl

* Add tests for tensorsolve

* Add test_tensorsolve_axes

* Address remarks

* Address remarks for #1752
  • Loading branch information
vlad-perevezentsev authored Mar 23, 2024
1 parent e5d3127 commit 7ca1aff
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 2 deletions.
76 changes: 74 additions & 2 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"svd",
"slogdet",
"tensorinv",
"tensorsolve",
]


Expand Down Expand Up @@ -935,7 +936,7 @@ def slogdet(a):

def tensorinv(a, ind=2):
"""
Compute the `inverse` of a tensor.
Compute the 'inverse' of an N-dimensional array.
For full documentation refer to :obj:`numpy.linalg.tensorinv`.
Expand All @@ -944,7 +945,7 @@ def tensorinv(a, ind=2):
a : {dpnp.ndarray, usm_ndarray}
Tensor to `invert`. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int
ind : int, optional
Number of first indices that are involved in the inverse sum.
Must be a positive integer.
Default: 2.
Expand Down Expand Up @@ -989,3 +990,74 @@ def tensorinv(a, ind=2):
a_inv = inv(a)

return a_inv.reshape(*inv_shape)


def tensorsolve(a, b, axes=None):
"""
Solve the tensor equation ``a x = b`` for x.
For full documentation refer to :obj:`numpy.linalg.tensorsolve`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
the shape of that sub-tensor of `a` consisting of the appropriate
number of its rightmost indices, and must be such that
``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
'square').
b : {dpnp.ndarray, usm_ndarray}
Right-hand tensor, which can be of any shape.
axes : tuple of ints, optional
Axes in `a` to reorder to the right, before inversion.
If ``None`` , no reordering is done.
Default: ``None``.
Returns
-------
out : dpnp.ndarray
The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.
See Also
--------
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
:obj:`dpnp.linalg.tensorinv` : Compute the 'inverse' of an N-dimensional array.
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
Examples
--------
>>> import dpnp as np
>>> a = np.eye(2*3*4)
>>> a.shape = (2*3, 4, 2, 3, 4)
>>> b = np.random.randn(2*3, 4)
>>> x = np.linalg.tensorsolve(a, b)
>>> x.shape
(2, 3, 4)
>>> np.allclose(np.tensordot(a, x, axes=3), b)
array([ True])
"""

dpnp.check_supported_arrays_type(a, b)
a_ndim = a.ndim

if axes is not None:
all_axes = list(range(a_ndim))
for k in axes:
all_axes.remove(k)
all_axes.insert(a_ndim, k)
a = a.transpose(tuple(all_axes))

old_shape = a.shape[-(a_ndim - b.ndim) :]
prod = numpy.prod(old_shape)

if a.size != prod**2:
raise dpnp.linalg.LinAlgError(
"Input arrays must satisfy the requirement \
prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
)

a = a.reshape(-1, prod)
b = b.ravel()
res = solve(a, b)
return res.reshape(old_shape)
44 changes: 44 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,3 +1498,47 @@ def test_test_tensorinv_errors(self):

# non-square
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)


class TestTensorsolve:
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"axes",
[None, (1,), (2,)],
ids=[
"None",
"(1,)",
"(2,)",
],
)
def test_tensorsolve_axes(self, dtype, axes):
a = numpy.eye(12).reshape(12, 3, 4).astype(dtype)
b = numpy.ones(a.shape[0], dtype=dtype)

a_dp = inp.array(a)
b_dp = inp.array(b)

res_np = numpy.linalg.tensorsolve(a, b, axes=axes)
res_dp = inp.linalg.tensorsolve(a_dp, b_dp, axes=axes)

assert res_np.shape == res_dp.shape
assert_dtype_allclose(res_dp, res_np)

def test_tensorsolve_errors(self):
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")

# unsupported type `a` and `b`
a_np = inp.asnumpy(a_dp)
b_np = inp.asnumpy(b_dp)
assert_raises(TypeError, inp.linalg.tensorsolve, a_np, b_dp)
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, b_np)

# unsupported type `axes`
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, 2.0)
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, -2)

# incorrect axes
assert_raises(
inp.linalg.LinAlgError, inp.linalg.tensorsolve, a_dp, b_dp, (1,)
)
21 changes: 21 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,3 +1891,24 @@ def test_tensorinv(device):
result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_tensorsolve(device):
a_np = numpy.random.randn(3, 2, 6).astype(dpnp.default_float_type())
b_np = numpy.ones(a_np.shape[:2], dtype=a_np.dtype)

a_dp = dpnp.array(a_np, device=device)
b_dp = dpnp.array(b_np, device=device)

result = dpnp.linalg.tensorsolve(a_dp, b_dp)
expected = numpy.linalg.tensorsolve(a_np, b_np)
assert_dtype_allclose(result, expected)

result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
14 changes: 14 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,3 +1035,17 @@ def test_tensorinv(usm_type):
ainv = dp.linalg.tensorinv(a, ind=1)

assert a.usm_type == ainv.usm_type


@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_b", list_of_usm_types, ids=list_of_usm_types)
def test_tensorsolve(usm_type_a, usm_type_b):
data = numpy.random.randn(3, 2, 6)
a = dp.array(data, usm_type=usm_type_a)
b = dp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)

result = dp.linalg.tensorsolve(a, b)

assert a.usm_type == usm_type_a
assert b.usm_type == usm_type_b
assert result.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_b])
20 changes: 20 additions & 0 deletions tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def test_invalid_shape(self):
self.check_shape((0, 3, 4), (3,), linalg_errors)


@testing.parameterize(
*testing.product(
{
"a_shape": [(2, 3, 6), (3, 4, 4, 3)],
"axes": [None, (0, 2)],
}
)
)
@testing.fix_random()
class TestTensorSolve(unittest.TestCase):
@testing.for_dtypes("ifdFD")
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
def test_tensorsolve(self, xp, dtype):
a_shape = self.a_shape
b_shape = self.a_shape[:2]
a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0)
b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1)
return xp.linalg.tensorsolve(a, b, axes=self.axes)


@testing.parameterize(
*testing.product(
{
Expand Down

0 comments on commit 7ca1aff

Please sign in to comment.