From e8718fce4b74b504ad326032f347c8cc390cd528 Mon Sep 17 00:00:00 2001 From: Kyle Oman Date: Thu, 19 Dec 2024 15:53:05 +0000 Subject: [PATCH] BUG: `apply_over_axes` no longer assumes user-supplied function preserves units (#548) --- unyt/_array_functions.py | 2 +- unyt/tests/test_array_functions.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index c8e92f97..03a4875d 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -850,7 +850,7 @@ def savetxt(fname, X, *args, **kwargs): @implements(np.apply_over_axes) def apply_over_axes(func, a, axes): - res = func(np.asarray(a), axes[0]) * a.units + res = func(a, axes[0]) if len(axes) > 1: # this function is recursive by nature, # here we intentionally do not call the base _implementation diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index fbabac1d..a8b582c6 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -1654,11 +1654,16 @@ def test_apply_along_axis(): assert ret.units == cm**2 -def test_apply_over_axes(): +@pytest.mark.parametrize("axes, expected_units", [((0, 1), cm**4), ((0,), cm**2)]) +def test_apply_over_axes(axes, expected_units): + # the user-supplied function must be trusted to treat units + # sensibly (mainly that it doesn't give a mix of units across + # the resulting array), but we can check that units are + # propagated correctly for well-behaved functions. a = np.eye(3) * cm - ret = np.apply_over_axes(lambda x, axis: x * cm, a, (0, 1)) - assert type(ret) is unyt_array - assert ret.units == cm**3 + ret = np.apply_over_axes(lambda x, axis: x[axis] ** 2, a, axes) + assert isinstance(ret, unyt_array) # could be unyt_quantity + assert ret.units == expected_units def test_array_equal():