diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index b0471c416e14..d58324b7946d 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -1590,12 +1590,15 @@ def any(a, axis=None, out=None, keepdims=False): @set_module('mxnet.ndarray.numpy') -def argsort(a, axis=-1, kind=None, order=None): +def argsort(a, axis=-1, descending=False, stable=True): """ - Returns the indices that would sort an array. - Perform an indirect sort along the given axis using the algorithm specified - by the `kind` keyword. It returns an array of indices of the same shape as - `a` that index data along the given axis in sorted order. + Returns the indices that sort an array `a` along a specified axis. + + Notes + ----- + `argsort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. Parameters ---------- @@ -1604,11 +1607,13 @@ def argsort(a, axis=-1, kind=None, order=None): axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort `a` in descending order (by value). + If `False`, the returned indices sort `a` in ascending order (by value). Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of `a` values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of `a` values which compare as equal. Default: True. 
Returns ------- @@ -1659,29 +1664,34 @@ def argsort(a, axis=-1, kind=None, order=None): >>> x[ind] # same as np.sort(x, axis=None) array([0, 2, 2, 3]) """ - if order is not None: - raise NotImplementedError("order not supported here") - - return _api_internal.argsort(a, axis, True, 'int64') + return _api_internal.argsort(a, axis, not descending, 'int64') @set_module('mxnet.ndarray.numpy') -def sort(a, axis=-1, kind=None, order=None): +def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. + Notes + ----- + `sort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. + Parameters ---------- a : ndarray - Array to be sorted. + Array to sort. axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned array is sorted in descending order (by value). + If `False`, the returned array is sorted in ascending order (by value). Default: False. + stable : bool, optional + sort stability. If `True`, the returned array must maintain the relative order + of `a` values which compare as equal. If `False`, the returned array may or may not + maintain the relative order of `a` values which compare as equal. Default: True. 
Returns ------- @@ -1704,9 +1714,7 @@ def sort(a, axis=-1, kind=None, order=None): array([[1, 1], [3, 4]]) """ - if order is not None: - raise NotImplementedError("order not supported here") - return _api_internal.sort(a, axis, True) + return _api_internal.sort(a, axis, not descending) @set_module('mxnet.ndarray.numpy') def dot(a, b, out=None): diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 9d135248e490..0d1a96bcb8a2 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -450,10 +450,16 @@ def svd(a): return tuple(_api_internal.svd(a)) -def cholesky(a): +def cholesky(a, upper=False): r""" Cholesky decomposition. + Notes + ----- + `upper` param is requested by API standardization in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false + instead of parameter in official NumPy operator. + Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be symmetric and positive-definite. Only `L` is actually returned. Complex-valued @@ -463,6 +469,10 @@ def cholesky(a): ---------- a : (..., M, M) ndarray Symmetric, positive-definite input matrix. + upper : bool + If `True`, the result must be the upper-triangular Cholesky factor. + If `False`, the result must be the lower-triangular Cholesky factor. + Default: `False`. 
Returns ------- @@ -506,7 +516,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _api_internal.cholesky(a, True) + return _api_internal.cholesky(a, not upper) def qr(a, mode='reduced'): diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 6f96f094478a..65d210f7aa10 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -28,27 +28,28 @@ __all__ += fallback_linalg.__all__ -def matrix_rank(M, tol=None, hermitian=False): +@wrap_data_api_linalg_func +def matrix_rank(M, rtol=None, hermitian=False): r""" Return matrix rank of array using SVD method Rank of the array is the number of singular values of the array that are - greater than `tol`. + greater than `rtol`. Notes ----- - `matrix_rank` is an alias for `matrix_rank`. It is a standard API in + `rtol` param is requested in array-api-standard in https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-rank-x-rtol-none - instead of an official NumPy operator. + instead of a parameter in official NumPy operator. Parameters ---------- M : {(M,), (..., M, N)} ndarray Input vector or stack of matrices. - tol : (...) ndarray, float, optional - Threshold below which SVD values are considered zero. If `tol` is + rtol : (...) ndarray, float, optional + Threshold below which SVD values are considered zero. If `rtol` is None, and ``S`` is an array with singular values for `M`, and - ``eps`` is the epsilon value for datatype of ``S``, then `tol` is + ``eps`` is the epsilon value for datatype of ``S``, then `rtol` is set to ``S.max() * max(M.shape) * eps``. 
hermitian : bool, optional If True, `M` is assumed to be Hermitian (symmetric if real-valued), @@ -73,7 +74,7 @@ def matrix_rank(M, tol=None, hermitian=False): >>> np.linalg.matrix_rank(np.zeros((4,))) 0 """ - return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian) + return _mx_nd_np.linalg.matrix_rank(M, rtol, hermitian) def matrix_transpose(a): @@ -502,7 +503,8 @@ def lstsq(a, b, rcond='warn'): return _mx_nd_np.linalg.lstsq(a, b, rcond) -def pinv(a, rcond=1e-15, hermitian=False): +@wrap_data_api_linalg_func +def pinv(a, rtol=None, hermitian=False): r""" Compute the (Moore-Penrose) pseudo-inverse of a matrix. @@ -510,14 +512,20 @@ def pinv(a, rcond=1e-15, hermitian=False): singular-value decomposition (SVD) and including all *large* singular values. + Notes + ----- + `rtol` param is requested in array-api-standard in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-pinv-x-rtol-none + instead of a parameter in official NumPy operator. + Parameters ---------- a : (..., M, N) ndarray Matrix or stack of matrices to be pseudo-inverted. - rcond : (...) {float or ndarray of float}, optional + rtol : (...) {float or ndarray of float}, optional Cutoff for small singular values. Singular values less than or equal to - ``rcond * largest_singular_value`` are set to zero. + ``rtol * largest_singular_value`` are set to zero. Broadcasts against the stack of matrices. hermitian : bool, optional If True, `a` is assumed to be Hermitian (symmetric if real-valued), @@ -567,7 +575,7 @@ def pinv(a, rcond=1e-15, hermitian=False): >>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum() array(0.) """ - return _mx_nd_np.linalg.pinv(a, rcond, hermitian) + return _mx_nd_np.linalg.pinv(a, rtol, hermitian) def norm(x, ord=None, axis=None, keepdims=False): @@ -732,10 +740,16 @@ def svdvals(a): return s -def cholesky(a): +def cholesky(a, upper=False): r""" Cholesky decomposition. 
+ Notes + ----- + `upper` param is requested by API standardization in + https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false + instead of parameter in official NumPy operator. + Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`, where `L` is lower-triangular and .T is the transpose operator. `a` must be symmetric and positive-definite. Only `L` is actually returned. Complex-valued @@ -745,6 +759,10 @@ def cholesky(a): ---------- a : (..., M, M) ndarray Symmetric, positive-definite input matrix. + upper : bool + If `True`, the result must be the upper-triangular Cholesky factor. + If `False`, the result must be the lower-triangular Cholesky factor. + Default: `False`. Returns ------- @@ -788,7 +806,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _mx_nd_np.linalg.cholesky(a) + return _mx_nd_np.linalg.cholesky(a, upper) def qr(a, mode='reduced'): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 04bace50da27..e03df8fa04d0 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,8 @@ from ..runtime import Features from ..context import Context from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\ - is_np_default_dtype, wrap_data_api_statical_func + is_np_default_dtype, wrap_data_api_statical_func,\ + wrap_sort_functions from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi @@ -1908,13 +1909,13 @@ def pick(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute pick') - def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ + def sort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sort`. The arguments are the same as for :py:func:`sort`, with this array as data. 
""" - raise sort(self, axis=axis, kind=kind, order=order) + return sort(self, axis=axis, descending=descending, stable=stable) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -1924,13 +1925,13 @@ def topk(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute topk') - def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ + def argsort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`argsort`. The arguments are the same as for :py:func:`argsort`, with this array as data. """ - return argsort(self, axis=axis, kind=kind, order=order) + return argsort(self, axis=axis, descending=descending, stable=stable) def argmax_channel(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax_channel`. @@ -5831,12 +5832,16 @@ def arctanh(x, out=None, **kwargs): @set_module('mxnet.numpy') -def argsort(a, axis=-1, kind=None, order=None): +@wrap_sort_functions +def argsort(a, axis=-1, descending=False, stable=True): """ - Returns the indices that would sort an array. - Perform an indirect sort along the given axis using the algorithm specified - by the `kind` keyword. It returns an array of indices of the same shape as - `a` that index data along the given axis in sorted order. + Returns the indices that sort an array `x` along a specified axis. + + Notes + ----- + `argsort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. Parameters ---------- @@ -5845,11 +5850,13 @@ def argsort(a, axis=-1, kind=None, order=None): axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. 
- kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. If `True`, the returned indices sort `a` in descending order (by value). + If `False`, the returned indices sort `a` in ascending order (by value). Default: False. + stable : bool, optional + sort stability. If `True`, the returned indices must maintain the relative order + of `a` values which compare as equal. If `False`, the returned indices may or may not + maintain the relative order of `a` values which compare as equal. Default: True. Returns ------- @@ -5900,26 +5907,37 @@ def argsort(a, axis=-1, kind=None, order=None): >>> x[ind] # same as np.sort(x, axis=None) array([0, 2, 2, 3]) """ - return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order) + if stable: + warnings.warn("Currently, MXNet only support quicksort in backend, which is not stable") + return _mx_nd_np.argsort(a, axis=axis, descending=descending) @set_module('mxnet.numpy') -def sort(a, axis=-1, kind=None, order=None): +@wrap_sort_functions +def sort(a, axis=-1, descending=False, stable=True): """ Return a sorted copy of an array. + Notes + ----- + `sort` is a standard API in + https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true + instead of an official NumPy operator. + Parameters ---------- a : ndarray - Array to be sorted. + Array to sort. axis : int or None, optional Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - kind : string, optional - This argument can take any string, but it does not have any effect on the - final result. - order : str or list of str, optional - Not supported yet, will raise NotImplementedError if not None. + descending : bool, optional + sort order. 
If `True`, the returned array is sorted in descending order (by value). + If `False`, the returned array is sorted in ascending order (by value). Default: False. + stable : bool, optional + sort stability. If `True`, the returned array must maintain the relative order + of `a` values which compare as equal. If `False`, the returned array may or may not + maintain the relative order of `a` values which compare as equal. Default: True. Returns ------- @@ -5942,7 +5960,7 @@ def sort(a, axis=-1, kind=None, order=None): array([[1, 1], [3, 4]]) """ - return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order) + return _mx_nd_np.sort(a, axis=axis, descending=descending) @set_module('mxnet.numpy') diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 733d4843a76a..63833d203b3a 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -682,17 +682,49 @@ def wrap_data_api_linalg_func(func): """ @functools.wraps(func) - def _wrap_api_creation_func(*args, **kwargs): + def _wrap_linalg_func(*args, **kwargs): if len(kwargs) != 0: upper = kwargs.pop('UPLO', None) + rcond = kwargs.pop('rcond', None) + tol = kwargs.pop('tol', None) if upper is not None: if upper == 'U': kwargs['upper'] = True else: kwargs['upper'] = False + if rcond is not None: + kwargs['rtol'] = rcond + if tol is not None: + kwargs['rtol'] = tol return func(*args, **kwargs) - return _wrap_api_creation_func + return _wrap_linalg_func + + +def wrap_sort_functions(func): + """A convenience decorator for wrapping sort functions + + Parameters + ---------- + func : a numpy-compatible sort function (e.g. `sort` or `argsort`) to be wrapped for parameter keyword change. + + Returns + ------- + Function + A function wrapped with changed keywords. 
+ """ + @functools.wraps(func) + def _wrap_sort_func(*args, **kwargs): + if len(kwargs) != 0: + kind = kwargs.pop('kind', None) + order = kwargs.pop('order', None) + if kind is not None: + kwargs['stable'] = kind == 'stable' + if order is not None: + raise NotImplementedError("order not supported here") + return func(*args, **kwargs) + return _wrap_sort_func + # pylint: disable=exec-used def numpy_fallback(func): diff --git a/src/api/operator/numpy/np_ordering_op.cc b/src/api/operator/numpy/np_ordering_op.cc index 11c00fbfb71e..627e450892af 100644 --- a/src/api/operator/numpy/np_ordering_op.cc +++ b/src/api/operator/numpy/np_ordering_op.cc @@ -39,7 +39,7 @@ MXNET_REGISTER_API("_npi.sort").set_body([](runtime::MXNetArgs args, runtime::MX } else { param.axis = args[1].operator int(); } - param.is_ascend = true; + param.is_ascend = args[2].operator bool(); attrs.parsed = std::move(param); attrs.op = op; @@ -65,7 +65,7 @@ MXNET_REGISTER_API("_npi.argsort") } else { param.axis = args[1].operator int(); } - param.is_ascend = true; + param.is_ascend = args[2].operator bool(); if (args[3].type_code() == kNull) { param.dtype = mshadow::kFloat32; } else { diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 40d852b83b86..7639fd445987 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -712,7 +712,7 @@ void ThreadedEngine::OnCompleteGPU(Engine* engine, void* sync_info, const dmlc:: ThreadedOpr* threaded_opr = static_cast(info->opr_block)->opr; auto* event_pool = static_cast(info->event_pool); - auto [event, event_pool_idx] = event_pool->GetNextEvent(); + auto [event, event_pool_idx] = event_pool->GetNextEvent(); // NOLINT(*) auto ev = event.lock(); MSHADOW_CUDA_CALL(cudaEventRecord(*ev, worker_stream->stream_)); for (auto* read_var : threaded_opr->const_vars) { diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6a2e6acf6183..fe9a66f8aaa2 100644 --- 
a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2175,41 +2175,44 @@ def forward(self, a): @use_np -def test_np_argsort(): +@pytest.mark.parametrize('descending', [True, False]) +@pytest.mark.parametrize('shape', [ + (), + (2, 3), + (1, 0, 2), +]) +@pytest.mark.parametrize('hybrid', [False, True]) +def test_np_argsort(descending, shape, hybrid): class TestArgsort(HybridBlock): - def __init__(self, axis): + def __init__(self, axis, descending): super(TestArgsort, self).__init__() self._axis = axis + self._descending = descending def forward(self, x): - return np.argsort(x, axis=self._axis) - - shapes = [ - (), - (2, 3), - (1, 0, 2), - ] + return np.argsort(x, axis=self._axis, descending=self._descending) - for shape in shapes: - data = np.random.uniform(size=shape) - np_data = data.asnumpy() - - for axis in [None] + [i for i in range(-len(shape), len(shape))]: + data = np.random.uniform(size=shape) + np_data = data.asnumpy() + for axis in [None] + [i for i in range(-len(shape), len(shape))]: + if descending: + np_out = onp.argsort(-1 * np_data, axis) + else: np_out = onp.argsort(np_data, axis) - test_argsort = TestArgsort(axis) - for hybrid in [False, True]: - if hybrid: - test_argsort.hybridize() - mx_out = test_argsort(data) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) + test_argsort = TestArgsort(axis, descending) + + if hybrid: + test_argsort.hybridize() + mx_out = test_argsort(data) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) - mx_out = np.argsort(data, axis) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) + mx_out = np.argsort(data, axis, descending) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) @use_np -@pytest.mark.parametrize('kind', ['quicksort', 'mergesort', 'heapsort']) +@pytest.mark.parametrize('descending', [True, False]) 
@pytest.mark.parametrize('shape', [ (), (1,), @@ -2231,32 +2234,35 @@ def forward(self, x): ]) @pytest.mark.parametrize('dtype', [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64]) @pytest.mark.parametrize('hybridize', [True, False]) -def test_np_sort(kind, shape, dtype, hybridize): +def test_np_sort(shape, dtype, hybridize, descending): class TestSort(HybridBlock): - def __init__(self, axis, kind): + def __init__(self, axis, descending): super(TestSort, self).__init__() self._axis = axis - self._kind = kind + self._descending = descending - def forward(self, x, *args, **kwargs): - return np.sort(x, self._axis, self._kind) + def forward(self, x): + return np.sort(x, self._axis, descending=self._descending) a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) axis_list = list(range(len(shape))) axis_list.append(None) axis_list.append(-1) for axis in axis_list: - test = TestSort(axis, kind) + test = TestSort(axis, descending) if hybridize: test.hybridize() if axis == -1 and len(shape)==0: continue ret = test(a) - expected_ret = onp.sort(a.asnumpy(), axis, kind) + if descending: + expected_ret = -onp.sort(-1 * a.asnumpy(), axis) + else: + expected_ret = onp.sort(a.asnumpy(), axis) assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) # check imperative again - ret = np.sort(a, axis, kind) + ret = np.sort(a, axis=axis, descending=descending) assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) @@ -6172,18 +6178,36 @@ def check_qr(q, r, a_np): @use_np -def test_np_linalg_cholesky(): +@pytest.mark.parametrize('shape', [ + (0, 0), + (1, 1), + (5, 5), + (6, 6), + (10, 10), + (6, 6, 6), + (1, 0, 0), + (0, 1, 1), + (2, 3, 4, 4), +]) +@pytest.mark.parametrize('dtype', ['float32', 'float64']) +@pytest.mark.parametrize('upper', [True, False]) +@pytest.mark.parametrize('hybridize', [True, False]) +def test_np_linalg_cholesky(shape, dtype, upper, hybridize): 
class TestCholesky(HybridBlock): - def __init__(self): + def __init__(self, upper=False): super(TestCholesky, self).__init__() + self._upper = upper def forward(self, data): - return np.linalg.cholesky(data) + return np.linalg.cholesky(data, upper=self._upper) - def get_grad(L): + def get_grad(L, upper): # shape of m is [batch, n, n] if 0 in L.shape: return L + + if upper: + L = onp.swapaxes(L, -1, -2) def copyltu(m): eye = onp.array([onp.eye(m.shape[-1]) for i in range(m.shape[0])]) @@ -6202,11 +6226,14 @@ def copyltu(m): dA = 0.5 * onp.matmul(onp.matmul(L_inv_T, copyltu(onp.matmul(L_T, dL))), L_inv) return dA.reshape(shape) - def check_cholesky(L, data_np): + def check_cholesky(L, data_np, upper): assert L.shape == data_np.shape # catch error if numpy throws rank < 2 try: - L_expected = onp.linalg.cholesky(data_np) + if upper: + L_expected = onp.swapaxes(onp.linalg.cholesky(data_np), -1, -2) + else: + L_expected = onp.linalg.cholesky(data_np) except Exception as e: print(data_np) print(data_np.shape) @@ -6231,64 +6258,52 @@ def newSymmetricPositiveDefineMatrix_nD(shape, ran=(0., 10.), max_cond=4): n = int(onp.prod(shape[:-2])) if len(shape) > 2 else 1 return onp.array([newSymmetricPositiveDefineMatrix_2D(shape[-2:], ran, max_cond) for i in range(n)]).reshape(shape) - shapes = [ - (0, 0), - (1, 1), - (5, 5), - (6, 6), - (10, 10), - (6, 6, 6), - (1, 0, 0), - (0, 1, 1), - (2, 3, 4, 4), - ] - dtypes = ['float32', 'float64'] - for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes): - rtol = 1e-3 - atol = 1e-5 - if dtype == 'float32': - rtol = 1e-2 - atol = 1e-4 - test_cholesky = TestCholesky() - if hybridize: - test_cholesky.hybridize() - - # Numerical issue: - # When backpropagating through Cholesky decomposition, we need to compute the inverse - # of L according to dA = 0.5 * L**(-T) * copyLTU(L**T * dL) * L**(-1) where A = LL^T. - # The inverse is calculated by "trsm" method in CBLAS. 
When the data type is float32, - # this causes numerical instability. It happens when the matrix is ill-conditioned. - # In this example, the issue occurs frequently if the symmetric positive definite input - # matrix A is constructed by A = LL^T + \epsilon * I. A proper way of testing such - # operators involving numerically unstable operations is to use well-conditioned random - # matrices as input. Here we test Cholesky decomposition for FP32 and FP64 separately. - # See rocBLAS: - # https://github.com/ROCmSoftwarePlatform/rocBLAS/wiki/9.Numerical-Stability-in-TRSM - - # generate symmetric PD matrices - if 0 in shape: - data_np = np.ones(shape) - else: - data_np = newSymmetricPositiveDefineMatrix_nD(shape) + rtol = 1e-3 + atol = 1e-5 + if dtype == 'float32': + rtol = 1e-2 + atol = 1e-4 - # When dtype is np.FP32, truncation from FP64 to FP32 could also be a source of - # instability since the ground-truth gradient is computed using FP64 data. - data = np.array(data_np, dtype=dtype) - data.attach_grad() - with mx.autograd.record(): - L = test_cholesky(data) + test_cholesky = TestCholesky(upper) + if hybridize: + test_cholesky.hybridize() + + # Numerical issue: + # When backpropagating through Cholesky decomposition, we need to compute the inverse + # of L according to dA = 0.5 * L**(-T) * copyLTU(L**T * dL) * L**(-1) where A = LL^T. + # The inverse is calculated by "trsm" method in CBLAS. When the data type is float32, + # this causes numerical instability. It happens when the matrix is ill-conditioned. + # In this example, the issue occurs frequently if the symmetric positive definite input + # matrix A is constructed by A = LL^T + \epsilon * I. A proper way of testing such + # operators involving numerically unstable operations is to use well-conditioned random + # matrices as input. Here we test Cholesky decomposition for FP32 and FP64 separately. 
+ # See rocBLAS: + # https://github.com/ROCmSoftwarePlatform/rocBLAS/wiki/9.Numerical-Stability-in-TRSM + + # generate symmetric PD matrices + if 0 in shape: + data_np = np.ones(shape) + else: + data_np = newSymmetricPositiveDefineMatrix_nD(shape) - # check cholesky validity - check_cholesky(L, data_np) - # check backward. backward does not support empty input - if 0 not in L.shape: - mx.autograd.backward(L) - backward_expected = get_grad(L.asnumpy()) - assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) - # check imperative once again - L = np.linalg.cholesky(data) - check_cholesky(L, data_np) + # When dtype is np.FP32, truncation from FP64 to FP32 could also be a source of + # instability since the ground-truth gradient is computed using FP64 data. + data = np.array(data_np, dtype=dtype) + data.attach_grad() + with mx.autograd.record(): + L = test_cholesky(data) + + # check cholesky validity + check_cholesky(L, data_np, upper) + # check backward. backward does not support empty input + if 0 not in L.shape: + mx.autograd.backward(L) + backward_expected = get_grad(L.asnumpy(), upper) + assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol) + # check imperative once again + L = np.linalg.cholesky(data, upper=upper) + check_cholesky(L, data_np, upper) @use_np