diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e7e2e369722d..c01a5d270f0f 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None, num_axes = len(axis) if num_axes == 1: - if ord is None or ord == 2: - return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, - keepdims=keepdims)) - elif ord == jnp.inf: - return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == -jnp.inf: - return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == 0: - return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, - axis=axis, keepdims=keepdims) - elif ord == 1: - # Numpy has a special case for ord == 1 as an optimization. We don't - # really need the optimization (XLA could do it for us), but the Numpy - # code has slightly different type promotion semantics, so we need a - # special case too. - return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif isinstance(ord, str): - msg = f"Invalid order '{ord}' for vector norm." - if ord == "inf": - msg += "Use 'jax.numpy.inf' instead." - if ord == "-inf": - msg += "Use '-jax.numpy.inf' instead." - raise ValueError(msg) - else: - abs_x = ufuncs.abs(x) - ord_arr = lax_internal._const(abs_x, ord) - ord_inv = lax_internal._const(abs_x, 1. / ord_arr) - out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) - return ufuncs.power(out, ord_inv) + return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims) elif num_axes == 2: row_axis, col_axis = axis # pytype: disable=bad-unpacking @@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: @export -def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, +def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. @@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa Array([3.7416575, 9.486833 ], dtype=float32) """ check_arraylike('jnp.linalg.vector_norm', x) - if axis is None: - result = norm(jnp.ravel(x), ord=ord) - if keepdims: - result = lax.expand_dims(result, range(jnp.ndim(x))) - return result - return norm(x, axis=axis, keepdims=keepdims, ord=ord) - + if ord is None or ord == 2: + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, + keepdims=keepdims)) + elif ord == jnp.inf: + return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == -jnp.inf: + return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == 0: + return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, + axis=axis, keepdims=keepdims) + elif ord == 1: + # Numpy has a special case for ord == 1 as an optimization. We don't + # really need the optimization (XLA could do it for us), but the Numpy + # code has slightly different type promotion semantics, so we need a + # special case too. + return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif isinstance(ord, str): + msg = f"Invalid order '{ord}' for vector norm." + if ord == "inf": + msg += "Use 'jax.numpy.inf' instead." + if ord == "-inf": + msg += "Use '-jax.numpy.inf' instead." + raise ValueError(msg) + else: + abs_x = ufuncs.abs(x) + ord_arr = lax_internal._const(abs_x, ord) + ord_inv = lax_internal._const(abs_x, 1. / ord_arr) + out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) + return ufuncs.power(out, ord_inv) @export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 7c135b4ffeca..0da09e232deb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,8 @@ from functools import partial import itertools +from typing import Iterator +from unittest import skipIf import numpy as np import scipy @@ -54,6 +56,20 @@ def _is_required_cuda_version_satisfied(cuda_version): return int(version.split()[-1]) >= cuda_version +def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: + """ + Generate a range of valid axis arguments for a reduction over + an array with a given number of dimensions. + """ + yield from (None, ()) + if ndim > 0: + yield from (0, (-1,)) + if ndim > 1: + yield from (1, (0, 1), (-1, 0)) + if ndim > 2: + yield (-1, 0, 1) + + def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" if scipy_version >= (1, 17, 0): @@ -707,29 +723,25 @@ def testMatrixNorm(self, shape, dtype, keepdims, ord): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) self._CompileAndCheck(jnp_fn, args_maker) + @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0") @jtu.sample_product( - shape=[(3,), (3, 4), (2, 3, 4, 5)], + [ + dict(shape=shape, axis=axis) + for shape in [(3,), (3, 4), (2, 3, 4, 5)] + for axis in _axis_for_ndim(len(shape)) + ], dtype=float_types + complex_types, keepdims=[True, False], - axis=[0, None], ord=[1, -1, 2, -2, np.inf, -np.inf], ) def testVectorNorm(self, shape, dtype, keepdims, axis, ord): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - def np_fn(x, *, ord, keepdims, axis): - x = np.asarray(x) - if axis is None: - result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0) - return np.reshape(result, (1,) * x.ndim) if keepdims else result - return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis) - else: - np_fn = np.linalg.vector_norm - np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis) + np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) - self._CompileAndCheck(jnp_fn, args_maker) + tol = 1E-3 if jtu.test_device_matches(['tpu']) else None + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here. @jtu.sample_product(