Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API STD][LINALG] Standardize sort & linalg operators #20694

Merged
merged 9 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,12 +1590,15 @@ def any(a, axis=None, out=None, keepdims=False):


@set_module('mxnet.ndarray.numpy')
def argsort(a, axis=-1, kind=None, order=None):
def argsort(a, axis=-1, descending=False, stable=True):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Returns the indices that sort an array `x` along a specified axis.

Notes
-----
`argsort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.

Parameters
----------
Expand All @@ -1604,11 +1607,13 @@ def argsort(a, axis=-1, kind=None, order=None):
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.

Returns
-------
Expand Down Expand Up @@ -1659,29 +1664,34 @@ def argsort(a, axis=-1, kind=None, order=None):
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
if order is not None:
raise NotImplementedError("order not supported here")

return _api_internal.argsort(a, axis, True, 'int64')
return _api_internal.argsort(a, axis, not descending, 'int64')


@set_module('mxnet.ndarray.numpy')
def sort(a, axis=-1, kind=None, order=None):
def sort(a, axis=-1, descending=False, stable=True):
"""
Return a sorted copy of an array.

Notes
-----
`sort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.

Parameters
----------
a : ndarray
Array to be sorted.
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.

Returns
-------
Expand All @@ -1704,9 +1714,7 @@ def sort(a, axis=-1, kind=None, order=None):
array([[1, 1],
[3, 4]])
"""
if order is not None:
raise NotImplementedError("order not supported here")
return _api_internal.sort(a, axis, True)
return _api_internal.sort(a, axis, not descending)

@set_module('mxnet.ndarray.numpy')
def dot(a, b, out=None):
Expand Down
14 changes: 12 additions & 2 deletions python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ def svd(a):
return tuple(_api_internal.svd(a))


def cholesky(a):
def cholesky(a, upper=False):
r"""
Cholesky decomposition.

Notes
-----
`upper` param is requested by API standardization in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false
instead of parameter in official NumPy operator.

Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`,
where `L` is lower-triangular and .T is the transpose operator. `a` must be
symmetric and positive-definite. Only `L` is actually returned. Complex-valued
Expand All @@ -463,6 +469,10 @@ def cholesky(a):
----------
a : (..., M, M) ndarray
Symmetric, positive-definite input matrix.
upper : bool
If `True`, the result must be the upper-triangular Cholesky factor.
If `False`, the result must be the lower-triangular Cholesky factor.
Default: `False`.

Returns
-------
Expand Down Expand Up @@ -506,7 +516,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _api_internal.cholesky(a, True)
return _api_internal.cholesky(a, not upper)


def qr(a, mode='reduced'):
Expand Down
46 changes: 32 additions & 14 deletions python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,28 @@
__all__ += fallback_linalg.__all__


def matrix_rank(M, tol=None, hermitian=False):
@wrap_data_api_linalg_func
def matrix_rank(M, rtol=None, hermitian=False):
r"""
Return matrix rank of array using SVD method

Rank of the array is the number of singular values of the array that are
greater than `tol`.
greater than `rtol`.

Notes
-----
`matrix_rank` is an alias for `matrix_rank`. It is a standard API in
`rtol` param is requested in array-api-standard in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-rank-x-rtol-none
instead of an official NumPy operator.
instead of a parameter in official NumPy operator.

Parameters
----------
M : {(M,), (..., M, N)} ndarray
Input vector or stack of matrices.
tol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `tol` is
rtol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `rtol` is
None, and ``S`` is an array with singular values for `M`, and
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
``eps`` is the epsilon value for datatype of ``S``, then `rtol` is
set to ``S.max() * max(M.shape) * eps``.
hermitian : bool, optional
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
Expand All @@ -73,7 +74,7 @@ def matrix_rank(M, tol=None, hermitian=False):
>>> np.linalg.matrix_rank(np.zeros((4,)))
0
"""
return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian)
return _mx_nd_np.linalg.matrix_rank(M, rtol, hermitian)


def matrix_transpose(a):
Expand Down Expand Up @@ -502,22 +503,29 @@ def lstsq(a, b, rcond='warn'):
return _mx_nd_np.linalg.lstsq(a, b, rcond)


def pinv(a, rcond=1e-15, hermitian=False):
@wrap_data_api_linalg_func
def pinv(a, rtol=None, hermitian=False):
r"""
Compute the (Moore-Penrose) pseudo-inverse of a matrix.

Calculate the generalized inverse of a matrix using its
singular-value decomposition (SVD) and including all
*large* singular values.

Notes
-----
`rtol` param is requested in array-api-standard in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-pinv-x-rtol-none
instead of a parameter in official NumPy operator.

Parameters
----------
a : (..., M, N) ndarray
Matrix or stack of matrices to be pseudo-inverted.
rcond : (...) {float or ndarray of float}, optional
rtol : (...) {float or ndarray of float}, optional
Cutoff for small singular values.
Singular values less than or equal to
``rcond * largest_singular_value`` are set to zero.
``rtol * largest_singular_value`` are set to zero.
Broadcasts against the stack of matrices.
hermitian : bool, optional
If True, `a` is assumed to be Hermitian (symmetric if real-valued),
Expand Down Expand Up @@ -567,7 +575,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
>>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum()
array(0.)
"""
return _mx_nd_np.linalg.pinv(a, rcond, hermitian)
return _mx_nd_np.linalg.pinv(a, rtol, hermitian)


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -732,10 +740,16 @@ def svdvals(a):
return s


def cholesky(a):
def cholesky(a, upper=False):
r"""
Cholesky decomposition.

Notes
-----
`upper` param is requested by API standardization in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false
instead of parameter in official NumPy operator.

Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`,
where `L` is lower-triangular and .T is the transpose operator. `a` must be
symmetric and positive-definite. Only `L` is actually returned. Complex-valued
Expand All @@ -745,6 +759,10 @@ def cholesky(a):
----------
a : (..., M, M) ndarray
Symmetric, positive-definite input matrix.
upper : bool
If `True`, the result must be the upper-triangular Cholesky factor.
If `False`, the result must be the lower-triangular Cholesky factor.
Default: `False`.

Returns
-------
Expand Down Expand Up @@ -788,7 +806,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _mx_nd_np.linalg.cholesky(a)
return _mx_nd_np.linalg.cholesky(a, upper)


def qr(a, mode='reduced'):
Expand Down
66 changes: 42 additions & 24 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
from ..runtime import Features
from ..context import Context
from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\
is_np_default_dtype, wrap_data_api_statical_func
is_np_default_dtype, wrap_data_api_statical_func,\
wrap_sort_functions
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
Expand Down Expand Up @@ -1908,13 +1909,13 @@ def pick(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')

def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
def sort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sort`.

The arguments are the same as for :py:func:`sort`, with
this array as data.
"""
raise sort(self, axis=axis, kind=kind, order=order)
return sort(self, axis=axis, descending=descending, stable=stable)

def topk(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`topk`.
Expand All @@ -1924,13 +1925,13 @@ def topk(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')

def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
def argsort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`argsort`.

The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
return argsort(self, axis=axis, kind=kind, order=order)
return argsort(self, axis=axis, descending=descending, stable=stable)

def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
Expand Down Expand Up @@ -5831,12 +5832,16 @@ def arctanh(x, out=None, **kwargs):


@set_module('mxnet.numpy')
def argsort(a, axis=-1, kind=None, order=None):
@wrap_sort_functions
def argsort(a, axis=-1, descending=False, stable=True):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Returns the indices that sort an array `x` along a specified axis.

Notes
-----
`argsort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.

Parameters
----------
Expand All @@ -5845,11 +5850,13 @@ def argsort(a, axis=-1, kind=None, order=None):
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.

Returns
-------
Expand Down Expand Up @@ -5900,26 +5907,37 @@ def argsort(a, axis=-1, kind=None, order=None):
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order)
if stable:
warnings.warn("Currently, MXNet only support quicksort in backend, which is not stable")
return _mx_nd_np.argsort(a, axis=axis, descending=descending)


@set_module('mxnet.numpy')
def sort(a, axis=-1, kind=None, order=None):
@wrap_sort_functions
def sort(a, axis=-1, descending=False, stable=True):
"""
Return a sorted copy of an array.

Notes
-----
`sort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.

Parameters
----------
a : ndarray
Array to be sorted.
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.

Returns
-------
Expand All @@ -5942,7 +5960,7 @@ def sort(a, axis=-1, kind=None, order=None):
array([[1, 1],
[3, 4]])
"""
return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order)
return _mx_nd_np.sort(a, axis=axis, descending=descending)


@set_module('mxnet.numpy')
Expand Down
Loading