jnp.linalg.vector_norm: properly support multiple axes
jakevdp committed Dec 5, 2024
1 parent 3f5f3e1 commit aaaee63
Showing 2 changed files with 57 additions and 51 deletions.
68 changes: 31 additions & 37 deletions jax/_src/numpy/linalg.py
@@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,

   num_axes = len(axis)
   if num_axes == 1:
-    if ord is None or ord == 2:
-      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
-                                        keepdims=keepdims))
-    elif ord == jnp.inf:
-      return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == -jnp.inf:
-      return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == 0:
-      return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
-                            axis=axis, keepdims=keepdims)
-    elif ord == 1:
-      # Numpy has a special case for ord == 1 as an optimization. We don't
-      # really need the optimization (XLA could do it for us), but the Numpy
-      # code has slightly different type promotion semantics, so we need a
-      # special case too.
-      return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif isinstance(ord, str):
-      msg = f"Invalid order '{ord}' for vector norm."
-      if ord == "inf":
-        msg += "Use 'jax.numpy.inf' instead."
-      if ord == "-inf":
-        msg += "Use '-jax.numpy.inf' instead."
-      raise ValueError(msg)
-    else:
-      abs_x = ufuncs.abs(x)
-      ord_arr = lax_internal._const(abs_x, ord)
-      ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
-      out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
-      return ufuncs.power(out, ord_inv)
+    return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims)

   elif num_axes == 2:
     row_axis, col_axis = axis  # pytype: disable=bad-unpacking
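With this change, the single-axis branch of norm is a thin wrapper around vector_norm. As a quick illustration (a sketch of mine, not part of the commit), the two entry points should now agree for any vector ord:

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
# Both calls take the same code path after this commit.
print(jnp.linalg.norm(x, ord=2, axis=1))         # delegates to vector_norm
print(jnp.linalg.vector_norm(x, ord=2, axis=1))  # called directly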
@@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:


 @export
-def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
+def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False,
                 ord: int | str = 2) -> Array:
   """Compute the vector norm of a vector or batch of vectors.
@@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
   Array([3.7416575, 9.486833 ], dtype=float32)
   """
   check_arraylike('jnp.linalg.vector_norm', x)
-  if axis is None:
-    result = norm(jnp.ravel(x), ord=ord)
-    if keepdims:
-      result = lax.expand_dims(result, range(jnp.ndim(x)))
-    return result
-  return norm(x, axis=axis, keepdims=keepdims, ord=ord)

+  if ord is None or ord == 2:
+    return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
+                                      keepdims=keepdims))
+  elif ord == jnp.inf:
+    return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == -jnp.inf:
+    return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == 0:
+    return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
+                          axis=axis, keepdims=keepdims)
+  elif ord == 1:
+    # Numpy has a special case for ord == 1 as an optimization. We don't
+    # really need the optimization (XLA could do it for us), but the Numpy
+    # code has slightly different type promotion semantics, so we need a
+    # special case too.
+    return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif isinstance(ord, str):
+    msg = f"Invalid order '{ord}' for vector norm."
+    if ord == "inf":
+      msg += "Use 'jax.numpy.inf' instead."
+    if ord == "-inf":
+      msg += "Use '-jax.numpy.inf' instead."
+    raise ValueError(msg)
+  else:
+    abs_x = ufuncs.abs(x)
+    ord_arr = lax_internal._const(abs_x, ord)
+    ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
+    out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
+    return ufuncs.power(out, ord_inv)

 @export
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
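The body relocated into vector_norm is what enables tuple-valued axis: every reduction it calls (sum, amax, amin) already accepts a tuple of axes. A usage sketch (mine, not from the diff) of the newly supported call; reducing over several axes should match flattening those axes by hand:

import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)
a = jnp.linalg.vector_norm(x, axis=(1, 2))            # one norm per leading batch entry
b = jnp.linalg.vector_norm(x.reshape(2, -1), axis=1)  # equivalent flattened form
print(a, b)  # expected to be elementwise equal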
40 changes: 26 additions & 14 deletions tests/linalg_test.py
@@ -16,6 +16,8 @@

 from functools import partial
 import itertools
+from typing import Iterator
+from unittest import skipIf

 import numpy as np
 import scipy
@@ -54,6 +56,20 @@ def _is_required_cuda_version_satisfied(cuda_version):
   return int(version.split()[-1]) >= cuda_version


+def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]:
+  """
+  Generate a range of valid axis arguments for a reduction over
+  an array with a given number of dimensions.
+  """
+  yield from (None, ())
+  if ndim > 0:
+    yield from (0, (-1,))
+  if ndim > 1:
+    yield from (1, (0, 1), (-1, 0))
+  if ndim > 2:
+    yield (-1, 0, 1)
+
+
 def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
   """scipy.linalg.toeplitz with v1.17+ batching semantics."""
   if scipy_version >= (1, 17, 0):
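For intuition, this is what the new helper yields for a rank-2 input, derived by reading the generator above rather than copied from the diff:

assert list(_axis_for_ndim(2)) == [None, (), 0, (-1,), 1, (0, 1), (-1, 0)]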
@@ -707,29 +723,25 @@ def testMatrixNorm(self, shape, dtype, keepdims, ord):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
     self._CompileAndCheck(jnp_fn, args_maker)

+  @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
   @jtu.sample_product(
-      shape=[(3,), (3, 4), (2, 3, 4, 5)],
+      [
+          dict(shape=shape, axis=axis)
+          for shape in [(3,), (3, 4), (2, 3, 4, 5)]
+          for axis in _axis_for_ndim(len(shape))
+      ],
       dtype=float_types + complex_types,
       keepdims=[True, False],
-      axis=[0, None],
       ord=[1, -1, 2, -2, np.inf, -np.inf],
   )
   def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    if jtu.numpy_version() < (2, 0, 0):
-      def np_fn(x, *, ord, keepdims, axis):
-        x = np.asarray(x)
-        if axis is None:
-          result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0)
-          return np.reshape(result, (1,) * x.ndim) if keepdims else result
-        return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
-    else:
-      np_fn = np.linalg.vector_norm
-    np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis)
+    np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
     jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
-    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
-    self._CompileAndCheck(jnp_fn, args_maker)
+    tol = 1E-3 if jtu.test_device_matches(['tpu']) else None
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

   # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
   @jtu.sample_product(
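The rewritten test uses NumPy 2.0's np.linalg.vector_norm as the reference implementation, which accepts the same multi-axis arguments. A minimal sketch of the reference call (mine, assuming NumPy >= 2.0 is installed):

import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
print(np.linalg.vector_norm(x, axis=(0, 2), ord=np.inf))  # reduces axes 0 and 2; result has shape (3,)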
