Skip to content

Commit

Permalink
Merge ed8e307 into 726738d
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-perevezentsev authored Mar 26, 2024
2 parents 726738d + ed8e307 commit 241888a
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 204 deletions.
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/lapack/heevd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ extern std::pair<sycl::event, sycl::event>
const std::int8_t upper_lower,
dpctl::tensor::usm_ndarray eig_vecs,
dpctl::tensor::usm_ndarray eig_vals,
const std::vector<sycl::event> &depends);
const std::vector<sycl::event> &depends = {});

extern void init_heevd_dispatch_table(void);
} // namespace lapack
Expand Down
97 changes: 80 additions & 17 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"eig",
"eigh",
"eigvals",
"eigvalsh",
"inv",
"matrix_power",
"matrix_rank",
Expand Down Expand Up @@ -246,6 +247,19 @@ def eigh(a, UPLO="L"):
For full documentation refer to :obj:`numpy.linalg.eigh`.
Parameters
----------
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
A complex- or real-valued array whose eigenvalues and eigenvectors are to be computed.
UPLO : {"L", "U"}, optional
Specifies the calculation uses either the lower ("L") or upper ("U")
triangular part of the matrix.
Regardless of this choice, only the real parts of the diagonal are
considered to preserve the Hermitian matrix property.
It therefore follows that the imaginary part of the diagonal
will always be treated as zero.
Default: "L".
Returns
-------
w : (..., M) dpnp.ndarray
Expand All @@ -255,15 +269,13 @@ def eigh(a, UPLO="L"):
The column ``v[:, i]`` is the normalized eigenvector corresponding
to the eigenvalue ``w[i]``.
Limitations
-----------
Parameter `a` is supported as :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
:obj:`dpnp.eig` : eigenvalues and right eigenvectors for non-symmetric arrays.
:obj:`dpnp.eigvals` : eigenvalues of non-symmetric arrays.
:obj:`dpnp.linalg.eigvalsh` : Compute the eigenvalues of a complex Hermitian or
real symmetric matrix.
:obj:`dpnp.linalg.eig` : Compute the eigenvalues and right eigenvectors of
a square array.
:obj:`dpnp.linalg.eigvals` : Compute the eigenvalues of a general matrix.
Examples
--------
Expand All @@ -281,20 +293,13 @@ def eigh(a, UPLO="L"):
"""

dpnp.check_supported_arrays_type(a)
check_stacked_2d(a)
check_stacked_square(a)

UPLO = UPLO.upper()
if UPLO not in ("L", "U"):
raise ValueError("UPLO argument must be 'L' or 'U'")

if a.ndim < 2:
raise ValueError(
"%d-dimensional array given. Array must be "
"at least two-dimensional" % a.ndim
)

m, n = a.shape[-2:]
if m != n:
raise ValueError("Last 2 dimensions of the array must be square")

return dpnp_eigh(a, UPLO=UPLO)


Expand Down Expand Up @@ -326,6 +331,64 @@ def eigvals(input):
return call_origin(numpy.linalg.eigvals, input)


def eigvalsh(a, UPLO="L"):
    """
    eigvalsh(a, UPLO="L")

    Compute the eigenvalues of a complex Hermitian or real symmetric matrix.

    Main difference from :obj:`dpnp.linalg.eigh`: the eigenvectors are not
    computed.

    For full documentation refer to :obj:`numpy.linalg.eigvalsh`.

    Parameters
    ----------
    a : (..., M, M) {dpnp.ndarray, usm_ndarray}
        A complex- or real-valued array whose eigenvalues are to be computed.
    UPLO : {"L", "U"}, optional
        Specifies whether the calculation uses the lower ("L") or upper ("U")
        triangular part of the matrix. The value is upper-cased before it is
        validated, so lowercase letters are accepted as well.
        Regardless of this choice, only the real parts of the diagonal are
        considered to preserve the Hermitian matrix property.
        It therefore follows that the imaginary part of the diagonal
        will always be treated as zero.
        Default: "L".

    Returns
    -------
    w : (..., M) dpnp.ndarray
        The eigenvalues in ascending order, each repeated according to
        its multiplicity.

    Raises
    ------
    ValueError
        If `UPLO`, after upper-casing, is neither "L" nor "U".

    See Also
    --------
    :obj:`dpnp.linalg.eigh` : Return the eigenvalues and eigenvectors of
                              a complex Hermitian (conjugate symmetric) or
                              a real symmetric matrix.
    :obj:`dpnp.linalg.eigvals` : Compute the eigenvalues of a general matrix.
    :obj:`dpnp.linalg.eig` : Compute the eigenvalues and right eigenvectors of
                             a general matrix.

    Examples
    --------
    >>> import dpnp as np
    >>> from dpnp import linalg as LA
    >>> a = np.array([[1, -2j], [2j, 5]])
    >>> LA.eigvalsh(a)
    array([0.17157288, 5.82842712])

    """

    # Validate the input array first: supported array type, then the
    # stacked-2d and stacked-square checks, in that order.
    dpnp.check_supported_arrays_type(a)
    check_stacked_2d(a)
    check_stacked_square(a)

    # Normalise the triangle selector so "l"/"u" are accepted too.
    triangle = UPLO.upper()
    if triangle not in {"L", "U"}:
        raise ValueError("UPLO argument must be 'L' or 'U'")

    # eigen_mode="N" requests eigenvalues only from the shared eigh backend.
    return dpnp_eigh(a, UPLO=triangle, eigen_mode="N")


def inv(a):
"""
Compute the (multiplicative) inverse of a matrix.
Expand Down
148 changes: 91 additions & 57 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,72 +857,83 @@ def dpnp_det(a):
return det.reshape(shape)


def dpnp_eigh(a, UPLO):
def dpnp_eigh(a, UPLO, eigen_mode="V"):
"""
dpnp_eigh(a, UPLO)
dpnp_eigh(a, UPLO, eigen_mode="V")
Return the eigenvalues and eigenvectors of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.
Can return both eigenvalues and eigenvectors (`eigen_mode="V"`) or
only eigenvalues (`eigen_mode="N"`).
The main calculation is done by calling an extension function
for LAPACK library of OneMKL. Depending on input type of `a` array,
it will be either ``heevd`` (for complex types) or ``syevd`` (for others).
"""

a_usm_type = a.usm_type
a_sycl_queue = a.sycl_queue
a_order = "C" if a.flags.c_contiguous else "F"
a_usm_arr = dpnp.get_usm_ndarray(a)
# get resulting type of arrays with eigenvalues and eigenvectors
v_type = _common_type(a)
w_type = _real_type(v_type)

# 'V' means both eigenvectors and eigenvalues will be calculated
jobz = _jobz["V"]
if a.size == 0:
w = dpnp.empty_like(a, shape=a.shape[:-1], dtype=w_type)
if eigen_mode == "V":
v = dpnp.empty_like(a, dtype=v_type)
return w, v
return w

# `eigen_mode` can be either "N" or "V", specifying the computation mode
# for OneMKL LAPACK `syevd` and `heevd` routines.
# "V" (default) means both eigenvectors and eigenvalues will be calculated
# "N" means only eigenvalues will be calculated
jobz = _jobz[eigen_mode]
uplo = _upper_lower[UPLO]

# get resulting type of arrays with eigenvalues and eigenvectors
a_dtype = a.dtype
lapack_func = "_syevd"
if dpnp.issubdtype(a_dtype, dpnp.complexfloating):
lapack_func = "_heevd"
v_type = a_dtype
w_type = dpnp.float64 if a_dtype == dpnp.complex128 else dpnp.float32
elif dpnp.issubdtype(a_dtype, dpnp.floating):
v_type = w_type = a_dtype
elif a_sycl_queue.sycl_device.has_aspect_fp64:
v_type = w_type = dpnp.float64
else:
v_type = w_type = dpnp.float32
# Get LAPACK function (_syevd for real or _heevd for complex data types)
# to compute all eigenvalues and, optionally, all eigenvectors
lapack_func = (
"_heevd" if dpnp.issubdtype(v_type, dpnp.complexfloating) else "_syevd"
)

a_sycl_queue = a.sycl_queue
a_order = "C" if a.flags.c_contiguous else "F"

if a.ndim > 2:
w = dpnp.empty(
a.shape[:-1],
is_cpu_device = a.sycl_device.has_aspect_cpu
orig_shape = a.shape
# get 3d input array by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a_usm_arr = dpnp.get_usm_ndarray(a)

# allocate a memory for dpnp array of eigenvalues
w = dpnp.empty_like(
a,
shape=orig_shape[:-1],
dtype=w_type,
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
w_orig_shape = w.shape
# get 2d dpnp array with eigenvalues by reshape
w = w.reshape(-1, w_orig_shape[-1])

# need to loop over the 1st dimension to get eigenvalues and eigenvectors of 3d matrix A
op_count = a.shape[0]
if op_count == 0:
return w, dpnp.empty_like(a, dtype=v_type)

eig_vecs = [None] * op_count
ht_copy_ev = [None] * op_count
ht_lapack_ev = [None] * op_count
for i in range(op_count):
batch_size = a.shape[0]
eig_vecs = [None] * batch_size
ht_list_ev = [None] * batch_size * 2
for i in range(batch_size):
# oneMKL LAPACK assumes fortran-like array as input, so
# allocate a memory with 'F' order for dpnp array of eigenvectors
eig_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=v_type)

# use DPCTL tensor function to fill the array of eigenvectors with content of input array
ht_copy_ev[i], copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
ht_list_ev[2 * i], copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr[i],
dst=eig_vecs[i].get_array(),
sycl_queue=a_sycl_queue,
)

# call LAPACK extension function to get eigenvalues and eigenvectors of a portion of matrix A
ht_lapack_ev[i], _ = getattr(li, lapack_func)(
ht_list_ev[2 * i + 1], _ = getattr(li, lapack_func)(
a_sycl_queue,
jobz,
uplo,
Expand All @@ -931,29 +942,52 @@ def dpnp_eigh(a, UPLO):
depends=[copy_ev],
)

for i in range(op_count):
ht_lapack_ev[i].wait()
ht_copy_ev[i].wait()
# TODO: Remove this w/a when MKLD-17201 is solved.
# Waiting for a host task executing an OneMKL LAPACK syevd call
# on CPU causes deadlock due to serialization of all host tasks
# in the queue.
# We need to wait for each host task before calling _syevd again
# to avoid deadlock.
if lapack_func == "_syevd" and is_cpu_device:
ht_list_ev[2 * i + 1].wait()

dpctl.SyclEvent.wait_for(ht_list_ev)

w = w.reshape(w_orig_shape)

if eigen_mode == "V":
# combine the list of eigenvectors into a single array
v = dpnp.array(eig_vecs, order=a_order).reshape(orig_shape)
return w, v
return w

# combine the list of eigenvectors into a single array
v = dpnp.array(eig_vecs, order=a_order)
return w, v
else:
# oneMKL LAPACK assumes fortran-like array as input, so
# allocate a memory with 'F' order for dpnp array of eigenvectors
v = dpnp.empty_like(a, order="F", dtype=v_type)
a_usm_arr = dpnp.get_usm_ndarray(a)
ht_list_ev = []
copy_ev = dpctl.SyclEvent()

# use DPCTL tensor function to fill the array of eigenvectors with content of input array
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr, dst=v.get_array(), sycl_queue=a_sycl_queue
)
# When `eigen_mode == "N"` (jobz == 0), OneMKL LAPACK does not overwrite the input array.
# If the input array 'a' is already F-contiguous and matches the target data type,
# we can avoid unnecessary memory allocation and data copying.
if eigen_mode == "N" and a_order == "F" and a.dtype == v_type:
v = a

else:
# oneMKL LAPACK assumes fortran-like array as input, so
# allocate a memory with 'F' order for dpnp array of eigenvectors
v = dpnp.empty_like(a, order="F", dtype=v_type)

# use DPCTL tensor function to fill the array of eigenvectors with content of input array
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr, dst=v.get_array(), sycl_queue=a_sycl_queue
)
ht_list_ev.append(ht_copy_ev)

# allocate a memory for dpnp array of eigenvalues
w = dpnp.empty(
a.shape[:-1],
w = dpnp.empty_like(
a,
shape=a.shape[:-1],
dtype=w_type,
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)

# call LAPACK extension function to get eigenvalues and eigenvectors of matrix A
Expand All @@ -965,8 +999,9 @@ def dpnp_eigh(a, UPLO):
w.get_array(),
depends=[copy_ev],
)
ht_list_ev.append(ht_lapack_ev)

if a_order != "F":
if eigen_mode == "V" and a_order != "F":
# need to align order of eigenvectors with one of input matrix A
out_v = dpnp.empty_like(v, order=a_order)
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
Expand All @@ -975,14 +1010,13 @@ def dpnp_eigh(a, UPLO):
sycl_queue=a_sycl_queue,
depends=[lapack_ev],
)
ht_copy_out_ev.wait()
ht_list_ev.append(ht_copy_out_ev)
else:
out_v = v

ht_lapack_ev.wait()
ht_copy_ev.wait()
dpctl.SyclEvent.wait_for(ht_list_ev)

return w, out_v
return (w, out_v) if eigen_mode == "V" else w


def dpnp_inv_batched(a, res_type):
Expand Down
Loading

0 comments on commit 241888a

Please sign in to comment.