Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement numpy.nanprod #1369

Merged
merged 11 commits into from
Feb 2, 2022
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Pint Changelog
- Fix casting error when using to_reduced_units with array of int.
(Issue #1184)
- Use default numpy `np.printoptions` available since numpy 1.15.
- Implement `numpy.nanprod` (Issue #1369)
- Fix default_format ignored for measurement (Issue #1456)


Expand Down
65 changes: 40 additions & 25 deletions pint/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,34 +679,49 @@ def _all(a, *args, **kwargs):
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")


@implements("prod", "function")
def _prod(a, *args, **kwargs):
arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where")
all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs)
axis = all_kwargs.get("axis", None)
where = all_kwargs.get("where", None)

registry = a.units._REGISTRY

if axis is not None and where is not None:
_, where_ = np.broadcast_arrays(a._magnitude, where)
exponents = np.unique(np.sum(where_, axis=axis))
if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents):
units = a.units ** np.max(exponents)
def implement_prod_func(name):
if np is None:
return

func = getattr(np, name, None)
if func is None:
return

@implements(name, "function")
def _prod(a, *args, **kwargs):
arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where")
all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs)
axis = all_kwargs.get("axis", None)
where = all_kwargs.get("where", None)

registry = a.units._REGISTRY

if axis is not None and where is not None:
_, where_ = np.broadcast_arrays(a._magnitude, where)
exponents = np.unique(np.sum(where_, axis=axis))
if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents):
units = a.units ** np.max(exponents)
else:
units = registry.dimensionless
a = a.to(units)
elif axis is not None:
units = a.units ** a.shape[axis]
elif where is not None:
exponent = np.sum(where)
units = a.units ** exponent
else:
units = registry.dimensionless
a = a.to(units)
elif axis is not None:
units = a.units ** a.shape[axis]
elif where is not None:
exponent = np.sum(where)
units = a.units ** exponent
else:
units = a.units ** a.size
exponent = (
np.sum(np.logical_not(np.isnan(a))) if name == "nanprod" else a.size
)
units = a.units ** exponent

result = func(a._magnitude, *args, **kwargs)

return registry.Quantity(result, units)

result = np.prod(a._magnitude, *args, **kwargs)

return registry.Quantity(result, units)
for name in ["prod", "nanprod"]:
implement_prod_func(name)


# Implement simple matching-unit or stripped-unit functions based on signature
Expand Down
10 changes: 10 additions & 0 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,16 @@ def test_prod_numpy_func(self):
np.prod(self.q, axis=axis, where=[True, False]), [3, 1] * self.ureg.m ** 2
)

@helpers.requires_array_function_protocol()
def test_nanprod_numpy_func(self):
helpers.assert_quantity_equal(np.nanprod(self.q_nan), 6 * self.ureg.m ** 3)
helpers.assert_quantity_equal(
np.nanprod(self.q_nan, axis=0), [3, 2] * self.ureg.m ** 2
)
helpers.assert_quantity_equal(
np.nanprod(self.q_nan, axis=1), [2, 3] * self.ureg.m ** 2
)

def test_sum(self):
assert self.q.sum() == 10 * self.ureg.m
helpers.assert_quantity_equal(self.q.sum(0), [4, 6] * self.ureg.m)
Expand Down