Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dpnp.linalg.tensorinv() implementation #1752

Merged
merged 4 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"solve",
"svd",
"slogdet",
"tensorinv",
]


Expand Down Expand Up @@ -897,3 +898,61 @@ def slogdet(a):
check_stacked_square(a)

return dpnp_slogdet(a)


def tensorinv(a, ind=2):
"""
Compute the `inverse` of a tensor.
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved

For full documentation refer to :obj:`numpy.linalg.tensorinv`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Tensor to `invert`. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
Number of first indices that are involved in the inverse sum.
Must be a positive integer.
Default: 2.

Returns
-------
out : dpnp.ndarray
The inverse of a tensor whose shape is equivalent to
``a.shape[ind:] + a.shape[:ind]``.

See Also
--------
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
:obj:`dpnp.linalg.tensorsolve` : Solve the tensor equation ``a x = b`` for x.

Examples
--------
>>> import dpnp as np
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)

>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)

"""

dpnp.check_supported_arrays_type(a)

if ind <= 0:
raise ValueError("Invalid ind argument")

old_shape = a.shape
inv_shape = old_shape[ind:] + old_shape[:ind]
prod = numpy.prod(old_shape[ind:])
a = a.reshape(prod, -1)
a_inv = inv(a)

return a_inv.reshape(*inv_shape)
39 changes: 39 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,3 +1409,42 @@ def test_pinv_errors(self):
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)


class TestTensorinv:
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"shape, ind",
[
((4, 6, 8, 3), 2),
((24, 8, 3), 1),
],
ids=[
"(4, 6, 8, 3)",
"(24, 8, 3)",
],
)
def test_tensorinv(self, dtype, shape, ind):
a = numpy.eye(24, dtype=dtype).reshape(shape)
a_dp = inp.array(a)

ainv = numpy.linalg.tensorinv(a, ind=ind)
ainv_dp = inp.linalg.tensorinv(a_dp, ind=ind)

assert ainv.shape == ainv_dp.shape
assert_dtype_allclose(ainv_dp, ainv)

def test_test_tensorinv_errors(self):
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)

# unsupported type `a`
a_np = inp.asnumpy(a_dp)
assert_raises(TypeError, inp.linalg.pinv, a_np)

# unsupported type `ind`
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, 2.0)
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, [2.0])
assert_raises(ValueError, inp.linalg.tensorinv, a_dp, -1)

# non-square
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)
18 changes: 18 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,3 +1849,21 @@ def test_pinv(shape, hermitian, rcond_as_array, device):
B_queue = B_result.sycl_queue

assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_tensorinv(device):
a_np = numpy.eye(12).reshape(12, 4, 3)
a_dp = dpnp.array(a_np, device=device)

result = dpnp.linalg.tensorinv(a_dp, ind=1)
expected = numpy.linalg.tensorinv(a_np, ind=1)
assert_dtype_allclose(result, expected)

result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
8 changes: 8 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,11 @@ def test_qr(shape, mode, usm_type):

assert a.usm_type == dp_q.usm_type
assert a.usm_type == dp_r.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_tensorinv(usm_type):
a = dp.eye(12, usm_type=usm_type).reshape(12, 4, 3)
ainv = dp.linalg.tensorinv(a, ind=1)

assert a.usm_type == ainv.usm_type
44 changes: 44 additions & 0 deletions tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,47 @@ def test_pinv_size_0(self):
self.check_x((0, 0), rcond=1e-15)
self.check_x((0, 2, 3), rcond=1e-15)
self.check_x((2, 0, 3), rcond=1e-15)


class TestTensorInv(unittest.TestCase):
@testing.for_dtypes("ifdFD")
@_condition.retry(10)
def check_x(self, a_shape, ind, dtype):
a_cpu = numpy.random.randint(0, 10, size=a_shape).astype(dtype)
a_gpu = cupy.asarray(a_cpu)
a_gpu_copy = a_gpu.copy()
result_cpu = numpy.linalg.tensorinv(a_cpu, ind=ind)
result_gpu = cupy.linalg.tensorinv(a_gpu, ind=ind)
assert_dtype_allclose(result_gpu, result_cpu)
testing.assert_array_equal(a_gpu_copy, a_gpu)

def check_shape(self, a_shape, ind):
a = cupy.random.rand(*a_shape)
with self.assertRaises(
(numpy.linalg.LinAlgError, cupy.linalg.LinAlgError)
):
cupy.linalg.tensorinv(a, ind=ind)

def check_ind(self, a_shape, ind):
a = cupy.random.rand(*a_shape)
with self.assertRaises(ValueError):
cupy.linalg.tensorinv(a, ind=ind)

def test_tensorinv(self):
self.check_x((12, 3, 4), ind=1)
self.check_x((3, 8, 24), ind=2)
self.check_x((18, 3, 3, 2), ind=1)
self.check_x((1, 4, 2, 2), ind=2)
self.check_x((2, 3, 5, 30), ind=3)
self.check_x((24, 2, 2, 3, 2), ind=1)
self.check_x((3, 4, 2, 3, 2), ind=2)
self.check_x((1, 2, 3, 2, 3), ind=3)
self.check_x((3, 2, 1, 2, 12), ind=4)

def test_invalid_shape(self):
self.check_shape((2, 3, 4), ind=1)
self.check_shape((1, 2, 3, 4), ind=3)

def test_invalid_index(self):
self.check_ind((12, 3, 4), ind=-1)
self.check_ind((18, 3, 3, 2), ind=0)
Loading