[API NEW][LINALG] Add vector_norm, matrix_norm (#20703)
* [API] Add vector_norm, matrix_norm

* fix lint

* fix

* fix
barry-jin authored Nov 18, 2021
1 parent e3c4da9 commit a2ad4db
Showing 2 changed files with 324 additions and 1 deletion.
85 changes: 84 additions & 1 deletion python/mxnet/numpy/linalg.py
@@ -17,14 +17,17 @@

"""Namespace for ops used in imperative programming."""

from functools import reduce

from ..ndarray import numpy as _mx_nd_np
from ..util import wrap_data_api_linalg_func
from .fallback_linalg import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback_linalg

__all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank', 'cross', 'diagonal', 'outer',
'tensordot', 'trace', 'matrix_transpose', 'vecdot', 'svdvals']
'tensordot', 'trace', 'matrix_transpose', 'vecdot', 'svdvals', 'vector_norm', 'matrix_norm']

__all__ += fallback_linalg.__all__


@@ -643,6 +646,86 @@ def norm(x, ord=None, axis=None, keepdims=False):
return _mx_nd_np.linalg.norm(x, ord, axis, keepdims)


def vector_norm(x, ord=None, axis=None, keepdims=False):
r"""
Computes the vector norm of a vector (or batch of vectors) `x`.
Parameters
----------
x : ndarray
Input array. Should have a floating-point data type.
ord : {non-zero int, inf, -inf}, optional
Order of the norm.
axis : {int, n-tuple of ints, None}, optional
If `axis` is an integer, it specifies the axis of `x` along which to
compute the vector norms. If `axis` is a n-tuple, it specifies the
axes along which to compute batched vector norms. If `axis` is None,
the norm of the whole ndarray is returned.
keepdims : bool, optional
If this is set to True, the axes which are normed over are left in the
result as dimensions with size one. With this option the result will
broadcast correctly against the original `x`.
Returns
-------
n : float or ndarray
Norm of the vector(s).
Notes
-----
`vector_norm` is a standard API in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-vector-norm-x-axis-none-keepdims-false-ord-2
instead of an official NumPy operator.
"""
if axis is None:
x = x.flatten()
axis = 0
elif isinstance(axis, tuple):
rest = tuple(i for i in range(x.ndim) if i not in axis)
newshape = axis + rest
x = _mx_nd_np.transpose(x, newshape).\
reshape((reduce(lambda a, b: a * b, [x.shape[a] for a in axis]),\
*[x.shape[i] for i in rest]))
axis = 0
return _mx_nd_np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
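
# Illustrative usage sketch (annotation, not part of the original change). The
# array `x` below is just example data. It shows how the multi-axis reduction
# above behaves: the axes named in `axis` are collapsed into one dimension
# before a single norm call, so a (2, 3, 4) input reduced over axis=(0, 2) is
# treated as three 8-element vectors, one per index along the remaining axis.
from mxnet import np
x = np.arange(24, dtype='float32').reshape((2, 3, 4))
np.linalg.vector_norm(x)                                 # 0-d array: 2-norm of all 24 elements
np.linalg.vector_norm(x, ord=2, axis=(0, 2))             # shape (3,)
np.linalg.vector_norm(x, ord=1, axis=1, keepdims=True)   # shape (2, 1, 4)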


def matrix_norm(x, ord='fro', axis=(-2, -1), keepdims=False):
r"""
Computes the matrix norm of a matrix (or a stack of matrices) `x`.
Parameters
----------
x : ndarray
Input array. Should have a floating-point data type.
ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
Order of the norm.
axis : {2-tuple of ints}
a 2-tuple which specifies the axes (dimensions) defining two-dimensional
matrices for which to compute matrix norms.
keepdims : bool, optional
If this is set to True, the axes which are normed over are left in the
result as dimensions with size one. With this option the result will
broadcast correctly against the original `x`.
Returns
-------
n : float or ndarray
Norm of the matrix.
Notes
-----
`matrix_norm` is a standard API in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-norm-x-axis-2-1-keepdims-false-ord-fro
instead of an official NumPy operator.
"""
if isinstance(axis, tuple) and len(axis) == 2:
return _mx_nd_np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
raise ValueError("The axis of matrix_norm must be a 2-tuple of ints")
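
# Illustrative usage sketch (annotation, not part of the original change). The
# array `x` below is just example data: matrix_norm defaults to the Frobenius
# norm over the last two axes, and any `axis` that is not a 2-tuple of ints
# raises a ValueError.
from mxnet import np
x = np.arange(24, dtype='float32').reshape((2, 3, 4))
np.linalg.matrix_norm(x)                       # Frobenius norm of each 3x4 matrix, shape (2,)
np.linalg.matrix_norm(x, ord=1, axis=(0, 2))   # matrix 1-norm over axes 0 and 2, shape (3,)
# np.linalg.matrix_norm(x, axis=1)             # raises ValueError: axis must be a 2-tuple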


def svd(a):
r"""
Singular Value Decomposition.
240 changes: 240 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -29,6 +29,7 @@
import scipy.special as scipy_special
import pytest
import mxnet.ndarray.numpy._internal as _npi
from functools import reduce
from mxnet import np, npx
from mxnet.gluon import HybridBlock
from mxnet.base import MXNetError
@@ -5940,6 +5941,245 @@ def spectral_norm_grad(data):
assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)


@use_np
@pytest.mark.parametrize('shape,ord,axis', [
((2, 3, 4), 2, (1, 2)),
((2, 3, 4), None, None),
((3,), None, None),
((2, 3), 2, 1),
((2, 3, 4), 1, 1),
((2, 3, 4), -1, 2),
((2, 3, 4), 2, 1),
((2, 3, 4), 4, 1),
((2, 3, 0, 4), -2, 1),
((2, 3, 4, 5), 2, (2, 3)),
((2, 3, 4), 'inf', 1),
((2, 3, 4), '-inf', (1, 0)),
((2, 3), None, (0, 1)),
((3, 2, 3), None, (1, 2)),
((2, 3), None, None),
((2, 3, 4), None, (0, 2)),
((2, 3, 4), -3.2, 2),
((2, 3, 4), 'inf', (0, 2)),
((2, 3, 4), '-inf', (0, 2)),
((2, 3, 4, 5, 7), 2, (2, 3, 1)),
])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('itype', [np.float32, np.float64])
@pytest.mark.parametrize('keepdims', [True, False])
def test_np_linalg_vector_norm(shape, ord, axis, hybridize, itype, keepdims):
class TestLinalgVectNorm(HybridBlock):
def __init__(self, ord=None, axis=None, keepdims=False):
super(TestLinalgVectNorm, self).__init__()
self._ord = ord
self._axis = axis
self._keepdims = keepdims

def forward(self, x):
return np.linalg.vector_norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims)

def spectral_norm_grad(data):
with mx.autograd.record():
UT, S, V = np.linalg.svd(data)
norm = np.max(np.abs(S), axis=-1)
norm.backward()
return data.grad.asnumpy()

def onp_vector_norm(a, axis=None, keepdims=False, ord=2):
if axis is None:
a = a.flatten()
axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas norm()
# only supports a single axis for vector norm.
rest = tuple(i for i in range(a.ndim) if i not in axis)
newshape = axis + rest
a = onp.transpose(a, newshape).reshape((reduce(lambda x, y: x * y, [a.shape[x] for x in axis]), *[a.shape[i] for i in rest]))
axis = 0
return onp.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)

# numpy is flaky under float16, also gesvd does not support fp16
net = TestLinalgVectNorm(ord, axis, keepdims)
rtol = 1e-2
atol = 1e-2
if hybridize:
net.hybridize()
a = mx.np.random.uniform(-10.0, 10.0, size=shape, dtype=itype)
a.attach_grad()
with mx.autograd.record():
mx_ret = net(a)
if ord == 'inf':
np_ret = onp_vector_norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
elif ord == '-inf':
np_ret = onp_vector_norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
else:
np_ret = onp_vector_norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)

assert np_ret.shape == mx_ret.shape
assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)

mx_ret.backward()

grad_axis = axis
if axis is None and len(shape) >= 2 and ord is not None:
grad_axis = (len(shape) - 2, len(shape) - 1)
elif axis is None and ord is None:
grad_axis = tuple([i for i in range(len(shape))])
elif axis is None:
grad_axis = len(shape) - 1

if not keepdims and isinstance(grad_axis, tuple):
if len(grad_axis) == 2 and grad_axis[0] > grad_axis[1] and grad_axis[0] > len(np_ret.shape):
grad_axis = (grad_axis[1], grad_axis[0])
for i in grad_axis:
np_ret = onp.expand_dims(np_ret, axis=i)
elif not keepdims:
np_ret = onp.expand_dims(np_ret, axis=grad_axis)

if ord == 4:
backward_expected = onp.sign(a.asnumpy()) * onp.power(onp.abs(a.asnumpy()) / np_ret, ord - 1)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)

if ord == 2 and not isinstance(grad_axis, tuple):
backward_expected = onp.divide(a.asnumpy(), np_ret)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
elif ord == 2 and isinstance(grad_axis, tuple):
backward_expected = spectral_norm_grad(a)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)

assert a.grad.shape == a.shape

# Test imperative once again
if ord == 'inf':
np_ret = onp_vector_norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
elif ord == '-inf':
np_ret = onp_vector_norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
else:
np_ret = onp_vector_norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
mx_ret = np.linalg.vector_norm(a, ord=ord, axis=axis, keepdims=keepdims)
assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)
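
# Note on the expected gradients above (annotation, not part of the original
# test): for the vector p-norm n = (sum_i |x_i|**p)**(1/p), the analytic
# gradient is dn/dx_i = sign(x_i) * (|x_i| / n)**(p - 1). That is what
# backward_expected computes for ord == 4, and for ord == 2 over a single axis
# it reduces to x_i / n, matching onp.divide(a, np_ret). For ord == 2 with a
# tuple of axes the test instead compares against spectral_norm_grad, i.e. the
# gradient of the largest singular value.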


@use_np
@pytest.mark.parametrize('shape,ord,axis', [
((2, 3, 4), 1, (2, 1)),
((2, 3, 4), 2, (1, 2)),
((2, 3, 4), None, None),
((3,), None, None),
((2, 3), 2, 1),
((2, 3, 4), 1, 1),
((2, 3, 4), -1, 2),
((2, 3, 4), 2, 1),
((2, 3, 4), 4, 1),
((2, 3, 0, 4), -2, 1),
((2, 3, 4, 5), 2, (2, 3)),
((2, 3), -1, None),
((2, 3, 4), 'inf', 1),
((2, 3, 4), '-inf', (1, 0)),
((2, 3), None, (0, 1)),
((3, 2, 3), None, (1, 2)),
((2, 3), None, None),
((2, 3, 4), 'fro', (0, 2)),
((2, 0, 4), 'fro', (0, 2)),
((2, 3, 4), None, (0, 2)),
((2, 3, 4), -3.2, 2),
((2, 3, 4), -1, (0, 1)),
((2, 3, 4), 'inf', (0, 2)),
((2, 3, 4), '-inf', (0, 2)),
((4, 4, 4, 4), -2, (0, 2)),
((2, 3, 4), 'nuc', (0, 2)),
((2, 2), 'nuc', None),
])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('itype', [np.float32, np.float64])
@pytest.mark.parametrize('keepdims', [True, False])
def test_np_linalg_matrix_norm(shape, ord, axis, hybridize, itype, keepdims):
class TestLinalgMatNorm(HybridBlock):
def __init__(self, ord=None, axis=None, keepdims=False):
super(TestLinalgMatNorm, self).__init__()
self._ord = ord
self._axis = axis
self._keepdims = keepdims

def forward(self, x):
return np.linalg.matrix_norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims)

def spectral_norm_grad(data):
with mx.autograd.record():
UT, S, V = np.linalg.svd(data)
norm = np.max(np.abs(S), axis=-1)
norm.backward()
return data.grad.asnumpy()

# numpy is flaky under float16, also gesvd does not support fp16
net = TestLinalgMatNorm(ord, axis, keepdims)
rtol = 1e-2
atol = 1e-2
if hybridize:
net.hybridize()
a = mx.np.random.uniform(-10.0, 10.0, size=shape, dtype=itype)
if not isinstance(axis, tuple) or not len(axis) == 2:
assertRaises(ValueError, np.linalg.matrix_norm, a, ord, axis, keepdims)
return
a.attach_grad()
with mx.autograd.record():
mx_ret = net(a)
if ord == 'inf':
np_ret = onp.linalg.norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
elif ord == '-inf':
np_ret = onp.linalg.norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
else:
np_ret = onp.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)

assert np_ret.shape == mx_ret.shape
assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)

mx_ret.backward()

grad_axis = axis
if axis is None and len(shape) >= 2 and ord is not None:
grad_axis = (len(shape) - 2, len(shape) - 1)
elif axis is None and ord is None:
grad_axis = tuple([i for i in range(len(shape))])
elif axis is None:
grad_axis = len(shape) - 1

if not keepdims and isinstance(grad_axis, tuple):
if len(grad_axis) == 2 and grad_axis[0] > grad_axis[1] and grad_axis[0] > len(np_ret.shape):
grad_axis = (grad_axis[1], grad_axis[0])
for i in grad_axis:
np_ret = onp.expand_dims(np_ret, axis=i)
elif not keepdims:
np_ret = onp.expand_dims(np_ret, axis=grad_axis)

if ord == 4:
backward_expected = onp.sign(a.asnumpy()) * onp.power(onp.abs(a.asnumpy()) / np_ret, ord - 1)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)

if ord == 2 and not isinstance(grad_axis, tuple):
backward_expected = onp.divide(a.asnumpy(), np_ret)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
elif ord == 2 and isinstance(grad_axis, tuple):
backward_expected = spectral_norm_grad(a)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)

if ord == 'fro':
backward_expected = onp.divide(a.asnumpy(), np_ret)
assert_almost_equal(a.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)

assert a.grad.shape == a.shape

# Test imperative once again
if ord == 'inf':
np_ret = onp.linalg.norm(a.asnumpy(), ord=onp.inf, axis=axis, keepdims=keepdims)
elif ord == '-inf':
np_ret = onp.linalg.norm(a.asnumpy(), ord=-onp.inf, axis=axis, keepdims=keepdims)
else:
np_ret = onp.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
mx_ret = np.linalg.matrix_norm(a, ord=ord, axis=axis, keepdims=keepdims)
assert_almost_equal(mx_ret.asnumpy(), np_ret, rtol=rtol, atol=atol)


@use_np
@pytest.mark.parametrize('shape', [
(3, 3),
