From a2ad4db5c7ac0af77b73d29ecc3e558432e3d21b Mon Sep 17 00:00:00 2001
From: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com>
Date: Thu, 18 Nov 2021 09:25:57 -0800
Subject: [PATCH] [API NEW][LINALG] Add vector_norm, matrix_norm (#20703)

* [API] Add vector_norm, matrix_norm

* fix lint

* fix

* fix
---
Review notes (not part of the commit; free text placed between the "---"
separator and the diffstat, outside the commit message and the diff hunks):
 * vector_norm: for a tuple `axis`, `rest` is built with `i not in axis`
   without normalising negative indices, so e.g. axis=(-2, -1) on a 3-d
   input produces the 5-element permutation (-2, -1, 0, 1, 2) that is then
   handed to transpose.
 * vector_norm: an empty tuple `axis` makes `reduce` run on an empty list
   with no initial value, which raises TypeError.
 * vector_norm: for `axis=None` or a tuple `axis` the input is
   flattened/reshaped and the norm is taken over axis 0, so keepdims=True
   presumably yields one leading size-1 axis instead of size-1 axes at the
   reduced positions described in the docstring -- TODO confirm against
   the Array API specification before relying on keepdims broadcasting.

 python/mxnet/numpy/linalg.py           |  85 ++++++++-
 tests/python/unittest/test_numpy_op.py | 240 +++++++++++++++++++++++++
 2 files changed, 324 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 65d210f7aa10..a94b2535aa5b 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -17,6 +17,8 @@
 
 """Namespace for ops used in imperative programming."""
 
+from functools import reduce
+
 from ..ndarray import numpy as _mx_nd_np
 from ..util import wrap_data_api_linalg_func
 from .fallback_linalg import *  # pylint: disable=wildcard-import,unused-wildcard-import
@@ -24,7 +26,8 @@
 __all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve',
            'tensorinv', 'tensorsolve', 'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh',
            'lstsq', 'matrix_rank', 'cross', 'diagonal', 'outer',
-           'tensordot', 'trace', 'matrix_transpose', 'vecdot', 'svdvals']
+           'tensordot', 'trace', 'matrix_transpose', 'vecdot', 'svdvals', 'vector_norm', 'matrix_norm']
+
 __all__ += fallback_linalg.__all__
 
 
@@ -643,6 +646,86 @@ def norm(x, ord=None, axis=None, keepdims=False):
     return _mx_nd_np.linalg.norm(x, ord, axis, keepdims)
 
 
+def vector_norm(x, ord=None, axis=None, keepdims=False):
+    r"""
+    Computes the vector norm of a vector (or batch of vectors) `x`.
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array. Should have a floating-point data type.
+    ord : {non-zero int, inf, -inf}, optional
+        Order of the norm.
+    axis : {int, n-tuple of ints, None}, optional
+        If `axis` is an integer, it specifies the axis of `x` along which to
+        compute the vector norms. If `axis` is a n-tuple, it specifies the
+        axes along which to compute batched vector norms. If `axis` is None,
+        the norm of the whole ndarray is returned.
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one. With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the vector(s).
+
+    Notes
+    -----
+    `vector_norm` is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-vector-norm-x-axis-none-keepdims-false-ord-2
+    instead of an official NumPy operator.
+
+    """
+    if axis is None:
+        x = x.flatten()
+        axis = 0
+    elif isinstance(axis, tuple):
+        rest = tuple(i for i in range(x.ndim) if i not in axis)
+        newshape = axis + rest
+        x = _mx_nd_np.transpose(x, newshape).\
+            reshape((reduce(lambda a, b: a * b, [x.shape[a] for a in axis]),\
+                *[x.shape[i] for i in rest]))
+        axis = 0
+    return _mx_nd_np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
+
+
+def matrix_norm(x, ord='fro', axis=(-2, -1), keepdims=False):
+    r"""
+    Computes the matrix norm of a matrix (or a stack of matrices) `x`.
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array. Should have a floating-point data type.
+    ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
+        Order of the norm.
+    axis : {2-tuple of ints}
+        a 2-tuple which specifies the axes (dimensions) defining two-dimensional
+        matrices for which to compute matrix norms.
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one. With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the matrix.
+
+    Notes
+    -----
+    `matrix_norm` is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-norm-x-axis-2-1-keepdims-false-ord-fro
+    instead of an official NumPy operator.
+
+    """
+    if isinstance(axis, tuple) and len(axis) == 2:
+        return _mx_nd_np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
+    raise ValueError("The axis of matrix_norm must be a 2-tuple of ints")
+
+
 def svd(a):
     r"""
     Singular Value Decomposition.
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 4d2588ac8cf1..cdb20dff578a 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -29,6 +29,7 @@
 import scipy.special as scipy_special
 import pytest
 import mxnet.ndarray.numpy._internal as _npi
+from functools import reduce
 from mxnet import np, npx
 from mxnet.gluon import HybridBlock
 from mxnet.base import MXNetError
@@ -5940,6 +5941,245 @@ def spectral_norm_grad(data):
     assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
 
 
+@use_np
+@pytest.mark.parametrize('shape,ord,axis', [
+    ((2, 3, 4), 2, (1, 2)),
+    ((2, 3, 4), None, None),
+    ((3,), None, None),
+    ((2, 3), 2, 1),
+    ((2, 3, 4), 1, 1),
+    ((2, 3, 4), -1, 2),
+    ((2, 3, 4), 2, 1),
+    ((2, 3, 4), 4, 1),
+    ((2, 3, 0, 4), -2, 1),
+    ((2, 3, 4, 5), 2, (2, 3)),
+    ((2, 3, 4), 'inf', 1),
+    ((2, 3, 4), '-inf', (1, 0)),
+    ((2, 3), None, (0, 1)),
+    ((3, 2, 3), None, (1, 2)),
+    ((2, 3), None, None),
+    ((2, 3, 4), None, (0, 2)),
+    ((2, 3, 4), -3.2, 2),
+    ((2, 3, 4), 'inf', (0, 2)),
+    ((2, 3, 4), '-inf', (0, 2)),
+    ((2, 3, 4, 5, 7), 2, (2, 3, 1)),
+])
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('itype', [np.float32, np.float64])
+@pytest.mark.parametrize('keepdims', [True, False])
+def test_np_linalg_vector_norm(shape, ord, axis, hybridize, itype, keepdims):
+    class TestLinalgVectNorm(HybridBlock):
+        def __init__(self, ord=None, axis=None, keepdims=False):
+            super(TestLinalgVectNorm, self).__init__()
+            self._ord = ord
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def forward(self, x):
+            return np.linalg.vector_norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims)
+
+    def spectral_norm_grad(data):
+        with mx.autograd.record():
+            UT, S, V = np.linalg.svd(data)
+            norm = np.max(np.abs(S), axis=-1)
+        norm.backward()
+        return data.grad.asnumpy()
+
+    def onp_vector_norm(a, axis=None, keepdims=False, ord=2):
+        if axis is None:
+            a = a.flatten()
+            axis = 0
+        elif isinstance(axis, tuple):
+            # Note: The axis argument supports any number of axes, whereas norm()
+            # only supports a single axis for vector norm.
+            rest = tuple(i for i in range(a.ndim) if i not in axis)
+            newshape = axis + rest
+            a = onp.transpose(a, newshape).reshape((reduce(lambda x, y: x * y, [a.shape[x] for x in axis]), *[a.shape[i] for i in rest]))
+            axis = 0
+        return onp.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)
+
+    # numpy is flaky under float16, also gesvd does not support fp16
+    net = TestLinalgVectNorm(ord, axis, keepdims)
+    rtol = 1e-2
+    atol = 1e-2
+    if hybridize:
+        net.hybridize()
+    a = mx.np.random.uniform(-10.0, 10.0, size=shape, dtype=itype)
+    a.attach_grad()
+    with mx.autograd.record():
+        mx_ret = net(a)
+    if ord == 'inf':
+        np_ret = onp_vector_norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
+    elif ord == '-inf':
+        np_ret = onp_vector_norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
+    else:
+        np_ret = onp_vector_norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
+
+    assert np_ret.shape == mx_ret.shape
+    assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
+
+    mx_ret.backward()
+
+    grad_axis = axis
+    if axis is None and len(shape) >= 2 and ord is not None:
+        grad_axis = (len(shape) - 2, len(shape) - 1)
+    elif axis is None and ord is None:
+        grad_axis = tuple([i for i in range(len(shape))])
+    elif axis is None:
+        grad_axis = len(shape) - 1
+
+    if not keepdims and isinstance(grad_axis, tuple):
+        if len(grad_axis) == 2 and grad_axis[0] > grad_axis[1] and grad_axis[0] > len(np_ret.shape):
+            grad_axis = (grad_axis[1], grad_axis[0])
+        for i in grad_axis:
+            np_ret = onp.expand_dims(np_ret, axis=i)
+    elif not keepdims:
+        np_ret = onp.expand_dims(np_ret, axis=grad_axis)
+
+    if ord == 4:
+        backward_expected = onp.sign(a.asnumpy()) * onp.power(onp.abs(a.asnumpy()) / np_ret, ord - 1)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+    if ord == 2 and not isinstance(grad_axis, tuple):
+        backward_expected = onp.divide(a.asnumpy(), np_ret)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+    elif ord == 2 and isinstance(grad_axis, tuple):
+        backward_expected = spectral_norm_grad(a)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+    assert a.grad.shape == a.shape
+
+    # Test imperative once again
+    if ord == 'inf':
+        np_ret = onp_vector_norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
+    elif ord == '-inf':
+        np_ret = onp_vector_norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
+    else:
+        np_ret = onp_vector_norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
+    mx_ret = np.linalg.vector_norm(a, ord=ord, axis=axis, keepdims=keepdims)
+    assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
+
+
+@use_np
+@pytest.mark.parametrize('shape,ord,axis', [
+    ((2, 3, 4), 1, (2, 1)),
+    ((2, 3, 4), 2, (1, 2)),
+    ((2, 3, 4), None, None),
+    ((3,), None, None),
+    ((2, 3), 2, 1),
+    ((2, 3, 4), 1, 1),
+    ((2, 3, 4), -1, 2),
+    ((2, 3, 4), 2, 1),
+    ((2, 3, 4), 4, 1),
+    ((2, 3, 0, 4), -2, 1),
+    ((2, 3, 4, 5), 2, (2, 3)),
+    ((2, 3), -1, None),
+    ((2, 3, 4), 'inf', 1),
+    ((2, 3, 4), '-inf', (1, 0)),
+    ((2, 3), None, (0, 1)),
+    ((3, 2, 3), None, (1, 2)),
+    ((2, 3), None, None),
+    ((2, 3, 4), 'fro', (0, 2)),
+    ((2, 0, 4), 'fro', (0, 2)),
+    ((2, 3, 4), None, (0, 2)),
+    ((2, 3, 4), -3.2, 2),
+    ((2, 3, 4), -1, (0, 1)),
+    ((2, 3, 4), 'inf', (0, 2)),
+    ((2, 3, 4), '-inf', (0, 2)),
+    ((4, 4, 4, 4), -2, (0, 2)),
+    ((2, 3, 4), 'nuc', (0, 2)),
+    ((2, 2), 'nuc', None),
+])
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('itype', [np.float32, np.float64])
+@pytest.mark.parametrize('keepdims', [True, False])
+def test_np_linalg_matrix_norm(shape, ord, axis, hybridize, itype, keepdims):
+    class TestLinalgMatNorm(HybridBlock):
+        def __init__(self, ord=None, axis=None, keepdims=False):
+            super(TestLinalgMatNorm, self).__init__()
+            self._ord = ord
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def forward(self, x):
+            return np.linalg.matrix_norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims)
+
+    def spectral_norm_grad(data):
+        with mx.autograd.record():
+            UT, S, V = np.linalg.svd(data)
+            norm = np.max(np.abs(S), axis=-1)
+        norm.backward()
+        return data.grad.asnumpy()
+
+    # numpy is flaky under float16, also gesvd does not support fp16
+    net = TestLinalgMatNorm(ord, axis, keepdims)
+    rtol = 1e-2
+    atol = 1e-2
+    if hybridize:
+        net.hybridize()
+    a = mx.np.random.uniform(-10.0, 10.0, size=shape, dtype=itype)
+    if not isinstance(axis, tuple) or not len(axis) == 2:
+        assertRaises(ValueError, np.linalg.matrix_norm, a, ord, axis, keepdims)
+        return
+    a.attach_grad()
+    with mx.autograd.record():
+        mx_ret = net(a)
+    if ord == 'inf':
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
+    elif ord == '-inf':
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
+    else:
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
+
+    assert np_ret.shape == mx_ret.shape
+    assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
+
+    mx_ret.backward()
+
+    grad_axis = axis
+    if axis is None and len(shape) >= 2 and ord is not None:
+        grad_axis = (len(shape) - 2, len(shape) - 1)
+    elif axis is None and ord is None:
+        grad_axis = tuple([i for i in range(len(shape))])
+    elif axis is None:
+        grad_axis = len(shape) - 1
+
+    if not keepdims and isinstance(grad_axis, tuple):
+        if len(grad_axis) == 2 and grad_axis[0] > grad_axis[1] and grad_axis[0] > len(np_ret.shape):
+            grad_axis = (grad_axis[1], grad_axis[0])
+        for i in grad_axis:
+            np_ret = onp.expand_dims(np_ret, axis=i)
+    elif not keepdims:
+        np_ret = onp.expand_dims(np_ret, axis=grad_axis)
+
+    if ord == 4:
+        backward_expected = onp.sign(a.asnumpy()) * onp.power(onp.abs(a.asnumpy()) / np_ret, ord - 1)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+    if ord == 2 and not isinstance(grad_axis, tuple):
+        backward_expected = onp.divide(a.asnumpy(), np_ret)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+    elif ord == 2 and isinstance(grad_axis, tuple):
+        backward_expected = spectral_norm_grad(a)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+    if ord == 'fro':
+        backward_expected = onp.divide(a.asnumpy(), np_ret)
+        assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+    assert a.grad.shape == a.shape
+
+    # Test imperative once again
+    if ord == 'inf':
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
+    elif ord == '-inf':
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
+    else:
+        np_ret = onp.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
+    mx_ret = np.linalg.matrix_norm(a, ord=ord, axis=axis, keepdims=keepdims)
+    assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
+
+
 @use_np
 @pytest.mark.parametrize('shape', [
     (3, 3),