Skip to content

Commit

Permalink
new implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgsavage committed Jun 23, 2023
1 parent 50c5ed6 commit 2ab9f0f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 28 deletions.
56 changes: 30 additions & 26 deletions pint/facets/numpy/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,36 +171,31 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
"""
all_args = all_args or []

def multiply_units(args):
product = first_input_units._REGISTRY.parse_units("")
for x in args:
if hasattr(x, "units"):
product *= x.units
elif _is_sequence_with_quantity_elements(x):
product *= x[0].units
return product

def get_numerator_unit(arg):
if _is_sequence_with_quantity_elements(arg):
return getattr(arg[0], "units", first_input_units._REGISTRY.parse_units(""))
else:
return getattr(arg, "units", first_input_units._REGISTRY.parse_units(""))

if unit_op == "sum":
result_unit = (1 * first_input_units + 1 * first_input_units).units
elif unit_op == "mul":
result_unit = multiply_units(all_args)
product = first_input_units._REGISTRY.parse_units("")
for x in all_args:
if hasattr(x, "units"):
product *= x.units
result_unit = product
elif unit_op == "delta":
result_unit = (1 * first_input_units - 1 * first_input_units).units
elif unit_op == "delta,div":
numerator_unit = (1 * first_input_units - 1 * first_input_units).units
denominator_unit = multiply_units(all_args[1:])
result_unit = numerator_unit / denominator_unit
product = (1 * first_input_units - 1 * first_input_units).units
for x in all_args[1:]:
if hasattr(x, "units"):
product /= x.units
result_unit = product
elif unit_op == "div":
# Start with first arg in numerator, all others in denominator
numerator_unit = get_numerator_unit(all_args[0])
denominator_unit = multiply_units(all_args[1:])
result_unit = numerator_unit / denominator_unit
product = getattr(
all_args[0], "units", first_input_units._REGISTRY.parse_units("")
)
for x in all_args[1:]:
if hasattr(x, "units"):
product /= x.units
result_unit = product
elif unit_op == "variance":
result_unit = ((1 * first_input_units + 1 * first_input_units) ** 2).units
elif unit_op == "square":
Expand All @@ -217,9 +212,13 @@ def get_numerator_unit(arg):
result_unit = first_input_units**size
elif unit_op == "invdiv":
# Start with first arg in numerator, all others in denominator
numerator_unit = get_numerator_unit(all_args[0])
denominator_unit = multiply_units(all_args[1:])
result_unit = (numerator_unit / denominator_unit) ** -1
product = getattr(
all_args[0], "units", first_input_units._REGISTRY.parse_units("")
)
for x in all_args[1:]:
if hasattr(x, "units"):
product /= x.units
result_unit = product**-1
else:
raise ValueError("Output unit method {} not understood".format(unit_op))

Expand Down Expand Up @@ -285,6 +284,12 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):

@implements(func_str, func_type)
def implementation(*args, **kwargs):
if (func_str in ["multiply", "true_divide", "divide", "floor_divide"] and
any([_is_sequence_with_quantity_elements(arg) and not _is_quantity(arg) for arg in args])):
# the sequence may contain different units, so fall back to element-wise
print(func_str,args, kwargs)
return np.array([func(args[0][i],args[1][i]) for i in range(len(args[0]))],dtype=object)

first_input_units = _get_first_input_units(args, kwargs)
if input_units == "all_consistent":
# Match all input args/kwargs to same units
Expand Down Expand Up @@ -422,7 +427,6 @@ def implementation(*args, **kwargs):
"nextafter",
"trunc",
"absolute",
"positive",
"negative",
"maximum",
"minimum",
Expand Down
14 changes: 12 additions & 2 deletions pint/testsuite/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,8 +892,18 @@ def test_issue1674(self, module_registry):
arr_of_q = np.array([Q_(2, "m"), Q_(4, "m")], dtype="object")
q_arr = Q_(np.array([1, 2]), "m")

helpers.assert_quantity_almost_equal(arr_of_q * q_arr, Q_([2, 8], "m^2"))
helpers.assert_quantity_almost_equal(arr_of_q / q_arr, Q_([2, 2], ""))
helpers.assert_quantity_equal(arr_of_q * q_arr, np.array([Q_(2, "m^2"), Q_(8, "m^2")], dtype="object"))
helpers.assert_quantity_equal(arr_of_q / q_arr, np.array([Q_(2, ""), Q_(2, "")], dtype="object"))


arr_of_q = np.array([Q_(2, "m"), Q_(4, "s")], dtype="object")
q_arr = Q_(np.array([1, 2]), "m")

helpers.assert_quantity_equal(
arr_of_q * q_arr,
np.array([Q_(2, "m^2"), Q_(8, "m s")], dtype="object")
)



if np is not None:
Expand Down

0 comments on commit 2ab9f0f

Please sign in to comment.