From 8787cf9fbb9a169bab76fc724a5ef0386276a7d8 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 22 Oct 2021 16:25:55 -0700 Subject: [PATCH 1/5] [API] Standardize sort & linalg operators --- python/mxnet/ndarray/numpy/_op.py | 56 ++++--- python/mxnet/ndarray/numpy/linalg.py | 14 +- python/mxnet/numpy/linalg.py | 46 ++++-- python/mxnet/numpy/multiarray.py | 58 ++++--- python/mxnet/util.py | 36 ++++- src/api/operator/numpy/np_ordering_op.cc | 4 +- tests/python/unittest/test_numpy_op.py | 198 ++++++++++++----------- 7 files changed, 256 insertions(+), 156 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index cf1bd52c2228..c12cd143be11 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1551,12 +1551,15 @@ def any(a, axis=None, out=None, keepdims=False): @set_module('mxnet.ndarray.numpy') -def argsort(a, axis=-1, kind=None, order=None): +def argsort(a, axis=-1, descending=False, kind=None, order=None): """ - Returns the indices that would sort an array. - Perform an indirect sort along the given axis using the algorithm specified - by the `kind` keyword. It returns an array of indices of the same shape as - `a` that index data along the given axis in sorted order. + Returns the indices that sort an array `x` along a specified axis. + + Notes + ----- + `argsort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. Parameters ---------- @@ -1565,11 +1568,13 @@ def argsort(a, axis=-1, kind=None, order=None): axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort x in descending order (by value). + If `False`, the returned indices sort x in ascending order (by value).Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of x values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of x values which compare as equal. Default: True. Returns ------- @@ -1620,29 +1625,34 @@ def argsort(a, axis=-1, kind=None, order=None): >>> x[ind] # same as np.sort(x, axis=None) array([0, 2, 2, 3]) """ - if order is not None: - raise NotImplementedError("order not supported here") - - return _api_internal.argsort(a, axis, True, 'int64') + return _api_internal.argsort(a, axis, not descending, 'int64') @set_module('mxnet.ndarray.numpy') -def sort(a, axis=-1, kind=None, order=None): +def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. + Notes + ----- + `sort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. + Parameters ---------- a : ndarray - Array to be sorted. + Array to sort. axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort x in descending order (by value). + If `False`, the returned indices sort x in ascending order (by value).Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of x values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of x values which compare as equal. Default: True. Returns ------- @@ -1665,9 +1675,7 @@ def sort(a, axis=-1, kind=None, order=None): array([[1, 1], [3, 4]]) """ - if order is not None: - raise NotImplementedError("order not supported here") - return _api_internal.sort(a, axis, True) + return _api_internal.sort(a, axis, not descending) @set_module('mxnet.ndarray.numpy') def dot(a, b, out=None): diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 9d135248e490..76fa152c4bd4 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -450,10 +450,16 @@ def svd(a): return tuple(_api_internal.svd(a)) -def cholesky(a): +def cholesky(a, upper=False): r""" Cholesky decomposition. + Notes + ----- + `upper` param is requested by API standardization in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false + instead of parameter in official NumPy operator. + Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be symmetric and positive-definite. Only `L` is actually returned. Complex-valued @@ -463,6 +469,10 @@ def cholesky(a): ---------- a : (..., M, M) ndarray Symmetric, positive-definite input matrix. + upper : bool + If `True`, the result must be the upper-triangular Cholesky factor. + If `False`, the result must be the lower-triangular Cholesky factor. + Default: `False`. Returns ------- @@ -506,7 +516,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _api_internal.cholesky(a, True) + return _api_internal.cholesky(a, not upper) def qr(a, mode='reduced'): diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index a9c0f9b38313..cef98184ea9b 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -28,27 +28,28 @@ __all__ += fallback_linalg.__all__ -def matrix_rank(M, tol=None, hermitian=False): +@wrap_data_api_linalg_func +def matrix_rank(M, rtol=None, hermitian=False): r""" Return matrix rank of array using SVD method Rank of the array is the number of singular values of the array that are - greater than `tol`. + greater than `rtol`. Notes ----- - `matrix_rank` is an alias for `matrix_rank`. It is a standard API in + `rtol` param is requested in array-api-standard in https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-rank-x-rtol-none - instead of an official NumPy operator. + instead of a parameter in official NumPy operator. Parameters ---------- M : {(M,), (..., M, N)} ndarray Input vector or stack of matrices. - tol : (...) ndarray, float, optional - Threshold below which SVD values are considered zero. If `tol` is + rtol : (...) ndarray, float, optional + Threshold below which SVD values are considered zero. If `rtol` is None, and ``S`` is an array with singular values for `M`, and - ``eps`` is the epsilon value for datatype of ``S``, then `tol` is + ``eps`` is the epsilon value for datatype of ``S``, then `rtol` is set to ``S.max() * max(M.shape) * eps``. hermitian : bool, optional If True, `M` is assumed to be Hermitian (symmetric if real-valued), @@ -73,7 +74,7 @@ def matrix_rank(M, tol=None, hermitian=False): >>> np.linalg.matrix_rank(np.zeros((4,))) 0 """ - return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian) + return _mx_nd_np.linalg.matrix_rank(M, rtol, hermitian) def matrix_transpose(a): @@ -498,7 +499,8 @@ def lstsq(a, b, rcond='warn'): return _mx_nd_np.linalg.lstsq(a, b, rcond) -def pinv(a, rcond=1e-15, hermitian=False): +@wrap_data_api_linalg_func +def pinv(a, rtol=None, hermitian=False): r""" Compute the (Moore-Penrose) pseudo-inverse of a matrix. @@ -506,14 +508,20 @@ def pinv(a, rcond=1e-15, hermitian=False): singular-value decomposition (SVD) and including all *large* singular values. + Notes + ----- + `rtol` param is requested in array-api-standard in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-pinv-x-rtol-none + instead of a parameter in official NumPy operator. + Parameters ---------- a : (..., M, N) ndarray Matrix or stack of matrices to be pseudo-inverted. - rcond : (...) {float or ndarray of float}, optional + rtol : (...) {float or ndarray of float}, optional Cutoff for small singular values. Singular values less than or equal to - ``rcond * largest_singular_value`` are set to zero. + ``rtol * largest_singular_value`` are set to zero. Broadcasts against the stack of matrices. hermitian : bool, optional If True, `a` is assumed to be Hermitian (symmetric if real-valued), @@ -563,7 +571,7 @@ def pinv(a, rcond=1e-15, hermitian=False): >>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum() array(0.) """ - return _mx_nd_np.linalg.pinv(a, rcond, hermitian) + return _mx_nd_np.linalg.pinv(a, rtol, hermitian) def norm(x, ord=None, axis=None, keepdims=False): @@ -703,10 +711,16 @@ def svd(a): return _mx_nd_np.linalg.svd(a) -def cholesky(a): +def cholesky(a, upper=False): r""" Cholesky decomposition. + Notes + ----- + `upper` param is requested by API standardization in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false + instead of parameter in official NumPy operator. + Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be symmetric and positive-definite. Only `L` is actually returned. Complex-valued @@ -716,6 +730,10 @@ def cholesky(a): ---------- a : (..., M, M) ndarray Symmetric, positive-definite input matrix. + upper : bool + If `True`, the result must be the upper-triangular Cholesky factor. + If `False`, the result must be the lower-triangular Cholesky factor. + Default: `False`. Returns ------- @@ -759,7 +777,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _mx_nd_np.linalg.cholesky(a) + return _mx_nd_np.linalg.cholesky(a, upper) def qr(a, mode='reduced'): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a58f1faf5587..2affa21f2427 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,8 @@ from ..runtime import Features from ..context import Context from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\ - is_np_default_dtype, wrap_data_api_statical_func + is_np_default_dtype, wrap_data_api_statical_func,\ + wrap_sort_functions from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi @@ -5756,12 +5757,16 @@ def arctanh(x, out=None, **kwargs): @set_module('mxnet.numpy') -def argsort(a, axis=-1, kind=None, order=None): +@wrap_sort_functions +def argsort(a, axis=-1, descending=False, stable=True): """ - Returns the indices that would sort an array. - Perform an indirect sort along the given axis using the algorithm specified - by the `kind` keyword. It returns an array of indices of the same shape as - `a` that index data along the given axis in sorted order. + Returns the indices that sort an array `x` along a specified axis. + + Notes + ----- + `argsort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. Parameters ---------- @@ -5770,11 +5775,13 @@ def argsort(a, axis=-1, kind=None, order=None): axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort x in descending order (by value). + If `False`, the returned indices sort x in ascending order (by value).Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of x values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of x values which compare as equal. Default: True. Returns ------- @@ -5825,26 +5832,37 @@ def argsort(a, axis=-1, kind=None, order=None): >>> x[ind] # same as np.sort(x, axis=None) array([0, 2, 2, 3]) """ - return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order) + if stable: + warnings.warn("Currently, MXNet only support quicksort in backend, which is not stable") + return _mx_nd_np.argsort(a, axis=axis, descending=descending) @set_module('mxnet.numpy') -def sort(a, axis=-1, kind=None, order=None): +@wrap_sort_functions +def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. + Notes + ----- + `sort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. + Parameters ---------- a : ndarray - Array to be sorted. + Array to sort. axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort x in descending order (by value). + If `False`, the returned indices sort x in ascending order (by value).Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of x values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of x values which compare as equal. Default: True. Returns ------- @@ -5867,7 +5885,7 @@ def sort(a, axis=-1, kind=None, order=None): array([[1, 1], [3, 4]]) """ - return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order) + return _mx_nd_np.sort(a, axis=axis, descending=descending) @set_module('mxnet.numpy') diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 733d4843a76a..ed9affebdf23 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -682,17 +682,49 @@ def wrap_data_api_linalg_func(func): """ @functools.wraps(func) - def _wrap_api_creation_func(*args, **kwargs): + def _wrap_linalg_func(*args, **kwargs): if len(kwargs) != 0: upper = kwargs.pop('UPLO', None) + rcond = kwargs.pop('rcond', None) + tol = kwargs.pop('tol', None) if upper is not None: if upper == 'U': kwargs['upper'] = True else: kwargs['upper'] = False + if rcond is not None: + kwargs['rtol'] = rcond + if tol is not None: + kwargs['rtol'] = tol return func(*args, **kwargs) - return _wrap_api_creation_func + return _wrap_linalg_func + + +def wrap_sort_functions(func): + """A convenience decorator for wrapping sort functions + + Parameters + ---------- + func : a numpy-compatible array creation function to be wrapped for parameter keyword change. + + Returns + ------- + Function + A function wrapped with changed keywords. + """ + @functools.wraps(func) + def _wrap_sort_func(*args, **kwargs): + if len(kwargs) != 0: + kind = kwargs.pop('kind', None) + order = kwargs.pop('order', None) + if kind is not None: + kwargs['stable'] = True if kind == 'stable' else False + if order is not None: + raise NotImplementedError("order not supported here") + return func(*args, **kwargs) + return _wrap_sort_func + # pylint: disable=exec-used def numpy_fallback(func): diff --git a/src/api/operator/numpy/np_ordering_op.cc b/src/api/operator/numpy/np_ordering_op.cc index 11c00fbfb71e..627e450892af 100644 --- a/src/api/operator/numpy/np_ordering_op.cc +++ b/src/api/operator/numpy/np_ordering_op.cc @@ -39,7 +39,7 @@ MXNET_REGISTER_API("_npi.sort").set_body([](runtime::MXNetArgs args, runtime::MX } else { param.axis = args[1].operator int(); } - param.is_ascend = true; + param.is_ascend = args[2].operator bool(); attrs.parsed = std::move(param); attrs.op = op; @@ -65,7 +65,7 @@ MXNET_REGISTER_API("_npi.argsort") } else { param.axis = args[1].operator int(); } - param.is_ascend = true; + param.is_ascend = args[2].operator bool(); if (args[3].type_code() == kNull) { param.dtype = mshadow::kFloat32; } else { diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1010475c605d..82171bfd31c4 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2175,41 +2175,43 @@ def forward(self, a): @use_np -def test_np_argsort(): +@pytest.mark.parametrize('descending', [True, False]) +@pytest.mark.parametrize('shape', [ + (), + (2, 3), + (1, 0, 2), +]) +def test_np_argsort(descending, shape): class TestArgsort(HybridBlock): - def __init__(self, axis): + def __init__(self, axis, descending): super(TestArgsort, self).__init__() self._axis = axis + self._descending = descending def forward(self, x): - return np.argsort(x, axis=self._axis) - - shapes = [ - (), - (2, 3), - (1, 0, 2), - ] + return np.argsort(x, axis=self._axis, descending=self._descending) - for shape in shapes: - data = np.random.uniform(size=shape) - np_data = data.asnumpy() - - for axis in [None] + [i for i in range(-len(shape), len(shape))]: + data = np.random.uniform(size=shape) + np_data = data.asnumpy() + for axis in [None] + [i for i in range(-len(shape), len(shape))]: + if descending: + np_out = onp.argsort(-np_data, axis) + else: np_out = onp.argsort(np_data, axis) - test_argsort = TestArgsort(axis) - for hybrid in [False, True]: - if hybrid: - test_argsort.hybridize() - mx_out = test_argsort(data) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) - - mx_out = np.argsort(data, axis) + test_argsort = TestArgsort(axis, descending) + for hybrid in [False, True]: + if hybrid: + test_argsort.hybridize() + mx_out = test_argsort(data) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) + mx_out = np.argsort(data, axis, descending) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) + @use_np -@pytest.mark.parametrize('kind', ['quicksort', 'mergesort', 'heapsort']) +@pytest.mark.parametrize('descending', [True, False]) @pytest.mark.parametrize('shape', [ (), (1,), @@ -2231,32 +2233,35 @@ def forward(self, x): ]) @pytest.mark.parametrize('dtype', [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64]) @pytest.mark.parametrize('hybridize', [True, False]) -def test_np_sort(kind, shape, dtype, hybridize): +def test_np_sort(shape, dtype, hybridize, descending): class TestSort(HybridBlock): - def __init__(self, axis, kind): + def __init__(self, axis, descending): super(TestSort, self).__init__() self._axis = axis - self._kind = kind + self._descending = descending - def forward(self, x, *args, **kwargs): - return np.sort(x, self._axis, self._kind) + def forward(self, x): + return np.sort(x, self._axis, descending=self._descending) a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) axis_list = list(range(len(shape))) axis_list.append(None) axis_list.append(-1) for axis in axis_list: - test = TestSort(axis, kind) + test = TestSort(axis, descending) if hybridize: test.hybridize() if axis == -1 and len(shape)==0: continue ret = test(a) - expected_ret = onp.sort(a.asnumpy(), axis, kind) + if descending: + expected_ret = -onp.sort(-a.asnumpy(), axis) + else: + expected_ret = onp.sort(a.asnumpy(), axis) assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) # check imperative again - ret = np.sort(a, axis, kind) + ret = np.sort(a, axis=axis, descending=descending) assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) @@ -6068,18 +6073,36 @@ def check_qr(q, r, a_np): @use_np -def test_np_linalg_cholesky(): +@pytest.mark.parametrize('shape', [ + (0, 0), + (1, 1), + (5, 5), + (6, 6), + (10, 10), + (6, 6, 6), + (1, 0, 0), + (0, 1, 1), + (2, 3, 4, 4), +]) +@pytest.mark.parametrize('dtype', ['float32', 'float64']) +@pytest.mark.parametrize('upper', [True, False]) +@pytest.mark.parametrize('hybridize', [True, False]) +def test_np_linalg_cholesky(shape, dtype, upper, hybridize): class TestCholesky(HybridBlock): - def __init__(self): + def __init__(self, upper=False): super(TestCholesky, self).__init__() + self._upper = upper def forward(self, data): - return np.linalg.cholesky(data) + return np.linalg.cholesky(data, upper=self._upper) - def get_grad(L): + def get_grad(L, upper): # shape of m is [batch, n, n] if 0 in L.shape: return L + + if upper: + L = onp.swapaxes(L, -1, -2) def copyltu(m): eye = onp.array([onp.eye(m.shape[-1]) for i in range(m.shape[0])]) @@ -6098,11 +6121,14 @@ def copyltu(m): dA = 0.5 * onp.matmul(onp.matmul(L_inv_T, copyltu(onp.matmul(L_T, dL))), L_inv) return dA.reshape(shape) - def check_cholesky(L, data_np): + def check_cholesky(L, data_np, upper): assert L.shape == data_np.shape # catch error if numpy throws rank < 2 try: - L_expected = onp.linalg.cholesky(data_np) + if upper: + L_expected = onp.swapaxes(onp.linalg.cholesky(data_np), -1, -2) + else: + L_expected = onp.linalg.cholesky(data_np) except Exception as e: print(data_np) print(data_np.shape) @@ -6127,64 +6153,52 @@ def newSymmetricPositiveDefineMatrix_nD(shape, ran=(0., 10.), max_cond=4): n = int(onp.prod(shape[:-2])) if len(shape) > 2 else 1 return onp.array([newSymmetricPositiveDefineMatrix_2D(shape[-2:], ran, max_cond) for i in range(n)]).reshape(shape) - shapes = [ - (0, 0), - (1, 1), - (5, 5), - (6, 6), - (10, 10), - (6, 6, 6), - (1, 0, 0), - (0, 1, 1), - (2, 3, 4, 4), - ] - dtypes = ['float32', 'float64'] - for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes): - rtol = 1e-3 - atol = 1e-5 - if dtype == 'float32': - rtol = 1e-2 - atol = 1e-4 - test_cholesky = TestCholesky() - if hybridize: - test_cholesky.hybridize() - - # Numerical issue: - # When backpropagating through Cholesky decomposition, we need to compute the inverse - # of L according to dA = 0.5 * L**(-T) * copyLTU(L**T * dL) * L**(-1) where A = LL^T. - # The inverse is calculated by "trsm" method in CBLAS. When the data type is float32, - # this causes numerical instability. It happens when the matrix is ill-conditioned. - # In this example, the issue occurs frequently if the symmetric positive definite input - # matrix A is constructed by A = LL^T + \epsilon * I. A proper way of testing such - # operators involving numerically unstable operations is to use well-conditioned random - # matrices as input. Here we test Cholesky decomposition for FP32 and FP64 separately. - # See rocBLAS: - # https://github.com/ROCmSoftwarePlatform/rocBLAS/wiki/9.Numerical-Stability-in-TRSM - - # generate symmetric PD matrices - if 0 in shape: - data_np = np.ones(shape) - else: - data_np = newSymmetricPositiveDefineMatrix_nD(shape) + rtol = 1e-3 + atol = 1e-5 + if dtype == 'float32': + rtol = 1e-2 + atol = 1e-4 - # When dtype is np.FP32, truncation from FP64 to FP32 could also be a source of - # instability since the ground-truth gradient is computed using FP64 data. - data = np.array(data_np, dtype=dtype) - data.attach_grad() - with mx.autograd.record(): - L = test_cholesky(data) + test_cholesky = TestCholesky(upper) + if hybridize: + test_cholesky.hybridize() + + # Numerical issue: + # When backpropagating through Cholesky decomposition, we need to compute the inverse + # of L according to dA = 0.5 * L**(-T) * copyLTU(L**T * dL) * L**(-1) where A = LL^T. + # The inverse is calculated by "trsm" method in CBLAS. When the data type is float32, + # this causes numerical instability. It happens when the matrix is ill-conditioned. + # In this example, the issue occurs frequently if the symmetric positive definite input + # matrix A is constructed by A = LL^T + \epsilon * I. A proper way of testing such + # operators involving numerically unstable operations is to use well-conditioned random + # matrices as input. Here we test Cholesky decomposition for FP32 and FP64 separately. + # See rocBLAS: + # https://github.com/ROCmSoftwarePlatform/rocBLAS/wiki/9.Numerical-Stability-in-TRSM + + # generate symmetric PD matrices + if 0 in shape: + data_np = np.ones(shape) + else: + data_np = newSymmetricPositiveDefineMatrix_nD(shape) - # check cholesky validity - check_cholesky(L, data_np) - # check backward. backward does not support empty input - if 0 not in L.shape: - mx.autograd.backward(L) - backward_expected = get_grad(L.asnumpy()) - assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) - # check imperative once again - L = np.linalg.cholesky(data) - check_cholesky(L, data_np) + # When dtype is np.FP32, truncation from FP64 to FP32 could also be a source of + # instability since the ground-truth gradient is computed using FP64 data. + data = np.array(data_np, dtype=dtype) + data.attach_grad() + with mx.autograd.record(): + L = test_cholesky(data) + + # check cholesky validity + check_cholesky(L, data_np, upper) + # check backward. backward does not support empty input + if 0 not in L.shape: + mx.autograd.backward(L) + backward_expected = get_grad(L.asnumpy(), upper) + assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + # check imperative once again + L = np.linalg.cholesky(data, upper=upper) + check_cholesky(L, data_np, upper) @use_np From c0c6334a3cd8744f48a6f8d4cecc875b722e5da7 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 22 Oct 2021 16:29:12 -0700 Subject: [PATCH 2/5] fix --- python/mxnet/ndarray/numpy/_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index c12cd143be11..cddc661d5f3f 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1551,7 +1551,7 @@ def any(a, axis=None, out=None, keepdims=False): @set_module('mxnet.ndarray.numpy') -def argsort(a, axis=-1, descending=False, kind=None, order=None): +def argsort(a, axis=-1, descending=False, stable=True): """ Returns the indices that sort an array `x` along a specified axis. From 8042ece0c9dc1cd371b18ef7fdc572e03ea559e0 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Sun, 24 Oct 2021 22:31:21 -0700 Subject: [PATCH 3/5] fix tests --- python/mxnet/ndarray/numpy/_op.py | 16 ++++++++-------- python/mxnet/ndarray/numpy/linalg.py | 8 ++++---- python/mxnet/numpy/linalg.py | 8 ++++---- python/mxnet/numpy/multiarray.py | 24 ++++++++++++------------ python/mxnet/util.py | 2 +- tests/python/unittest/test_numpy_op.py | 4 ++-- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index cddc661d5f3f..871d8a529c0b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1555,11 +1555,11 @@ def argsort(a, axis=-1, descending=False, stable=True): """ Returns the indices that sort an array `x` along a specified axis. - Notes - ----- - `argsort` is a standard API in + Notes + ----- + `argsort` is a standard API in https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true - instead of an official NumPy operator. + instead of an official NumPy operator. Parameters ---------- @@ -1633,11 +1633,11 @@ def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. - Notes - ----- - `sort` is a standard API in + Notes + ----- + `sort` is a standard API in https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true - instead of an official NumPy operator. + instead of an official NumPy operator. Parameters ---------- diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 76fa152c4bd4..0d1a96bcb8a2 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -454,11 +454,11 @@ def cholesky(a, upper=False): r""" Cholesky decomposition. - Notes - ----- - `upper` param is requested by API standardization in + Notes + ----- + `upper` param is requested by API standardization in https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false - instead of parameter in official NumPy operator. + instead of parameter in official NumPy operator. Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index cef98184ea9b..ea4bdfef6359 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -715,11 +715,11 @@ def cholesky(a, upper=False): r""" Cholesky decomposition. - Notes - ----- - `upper` param is requested by API standardization in + Notes + ----- + `upper` param is requested by API standardization in https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false - instead of parameter in official NumPy operator. + instead of parameter in official NumPy operator. Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 2affa21f2427..a7f9bfeec6e5 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -1878,13 +1878,13 @@ def pick(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute pick') - def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ + def sort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sort`. The arguments are the same as for :py:func:`sort`, with this array as data. """ - raise sort(self, axis=axis, kind=kind, order=order) + return sort(self, axis=axis, descending=descending, stable=stable) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -1894,13 +1894,13 @@ def topk(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute topk') - def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ + def argsort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`argsort`. The arguments are the same as for :py:func:`argsort`, with this array as data. """ - return argsort(self, axis=axis, kind=kind, order=order) + return argsort(self, axis=axis, descending=descending, stable=stable) def argmax_channel(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax_channel`. @@ -5762,11 +5762,11 @@ def argsort(a, axis=-1, descending=False, stable=True): """ Returns the indices that sort an array `x` along a specified axis. - Notes - ----- - `argsort` is a standard API in + Notes + ----- + `argsort` is a standard API in https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true - instead of an official NumPy operator. + instead of an official NumPy operator. Parameters ---------- @@ -5843,11 +5843,11 @@ def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. - Notes - ----- - `sort` is a standard API in + Notes + ----- + `sort` is a standard API in https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true - instead of an official NumPy operator. + instead of an official NumPy operator. Parameters ---------- diff --git a/python/mxnet/util.py b/python/mxnet/util.py index ed9affebdf23..63833d203b3a 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -719,7 +719,7 @@ def _wrap_sort_func(*args, **kwargs): kind = kwargs.pop('kind', None) order = kwargs.pop('order', None) if kind is not None: - kwargs['stable'] = True if kind == 'stable' else False + kwargs['stable'] = kind == 'stable' if order is not None: raise NotImplementedError("order not supported here") return func(*args, **kwargs) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 82171bfd31c4..c37964fdb169 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2195,7 +2195,7 @@ def forward(self, x): np_data = data.asnumpy() for axis in [None] + [i for i in range(-len(shape), len(shape))]: if descending: - np_out = onp.argsort(-np_data, axis) + np_out = onp.argsort(-1 * np_data, axis) else: np_out = onp.argsort(np_data, axis) @@ -2255,7 +2255,7 @@ def forward(self, x): continue ret = test(a) if descending: - expected_ret = -onp.sort(-a.asnumpy(), axis) + expected_ret = -onp.sort(-1 * a.asnumpy(), axis) else: expected_ret = onp.sort(a.asnumpy(), axis) assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) From 30ce6d2f2c506e665ae0c6f45565424c41f65c78 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Sat, 30 Oct 2021 17:30:35 -0700 Subject: [PATCH 4/5] update tests --- tests/python/unittest/test_numpy_op.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c37964fdb169..fa3c950117ee 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2181,7 +2181,8 @@ def forward(self, a): (2, 3), (1, 0, 2), ]) -def test_np_argsort(descending, shape): +@pytest.mark.parametrize('hybrid', [False, True]) +def test_np_argsort(descending, shape, hybrid): class TestArgsort(HybridBlock): def __init__(self, axis, descending): super(TestArgsort, self).__init__() @@ -2200,11 +2201,11 @@ def forward(self, x): np_out = onp.argsort(np_data, axis) test_argsort = TestArgsort(axis, descending) - for hybrid in [False, True]: - if hybrid: - test_argsort.hybridize() - mx_out = test_argsort(data) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) + + if hybrid: + test_argsort.hybridize() + mx_out = test_argsort(data) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) mx_out = np.argsort(data, axis, descending) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) From aef1112d0ecd591c84076dc39cc6b76914750bd5 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Sun, 31 Oct 2021 11:04:53 -0700 Subject: [PATCH 5/5] fix lint --- src/engine/threaded_engine.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 40d852b83b86..7639fd445987 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -712,7 +712,7 @@ void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc:: ThreadedOpr* threaded_opr = static_cast(info->opr_block)->opr; auto* event_pool = static_cast(info->event_pool); - auto [event, event_pool_idx] = event_pool->GetNextEvent(); + auto [event, event_pool_idx] = event_pool->GetNextEvent(); // NOLINT(*) auto ev = event.lock(); MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_)); for (auto* read_var : threaded_opr->const_vars) {