Skip to content

Commit

Permalink
BUG: apply_over_axes no longer assumes user-supplied function prese…
Browse files Browse the repository at this point in the history
…rves units (#548)
  • Loading branch information
kyleaoman authored Dec 19, 2024
1 parent 3fb5a44 commit e8718fc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def savetxt(fname, X, *args, **kwargs):

@implements(np.apply_over_axes)
def apply_over_axes(func, a, axes):
res = func(np.asarray(a), axes[0]) * a.units
res = func(a, axes[0])
if len(axes) > 1:
# this function is recursive by nature,
# here we intentionally do not call the base _implementation
Expand Down
13 changes: 9 additions & 4 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,11 +1654,16 @@ def test_apply_along_axis():
assert ret.units == cm**2


def test_apply_over_axes():
@pytest.mark.parametrize("axes, expected_units", [((0, 1), cm**4), ((0,), cm**2)])
def test_apply_over_axes(axes, expected_units):
# the user-supplied function must be trusted to treat units
# sensibly (mainly that it doesn't give a mix of units across
# the resulting array), but we can check that units are
# propagated correctly for well-behaved functions.
a = np.eye(3) * cm
ret = np.apply_over_axes(lambda x, axis: x * cm, a, (0, 1))
assert type(ret) is unyt_array
assert ret.units == cm**3
ret = np.apply_over_axes(lambda x, axis: x[axis] ** 2, a, axes)
assert isinstance(ret, unyt_array) # could be unyt_quantity
assert ret.units == expected_units


def test_array_equal():
Expand Down

0 comments on commit e8718fc

Please sign in to comment.