Skip to content

Commit

Permalink
update code and doc for functions that keep units (#42)
Browse files Browse the repository at this point in the history
* update code and doc for functions that keep units

* fix bugs
  • Loading branch information
chaoming0625 committed Aug 23, 2024
1 parent 6289562 commit 8608358
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 638 deletions.
232 changes: 1 addition & 231 deletions brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
'exprel', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2',
'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',

# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', 'frexp',

# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
Expand Down Expand Up @@ -748,232 +744,6 @@ def frexp(
return _fun_accept_unitless_unary(jnp.frexp, x, unit_to_scale=unit_to_scale)


def _fun_accept_unitless_return_keep_unit(
func: Callable,
x: jax.typing.ArrayLike | Quantity,
*args,
unit_to_scale: Optional[Unit] = None,
**kwargs
):
if isinstance(x, Quantity):
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.mantissa
return func(x, *args, **kwargs)
else:
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
r = func(x.to_decimal(unit_to_scale), *args, **kwargs)
return jax.tree.map(lambda a: a * unit_to_scale, r)
else:
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)


@set_module_as('brainunit.math')
def round_(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round_, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def around(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.around, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def round(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array | Quantity:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def rint(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.rint, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def floor(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the floor of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.floor, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def ceil(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the ceiling of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.ceil, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def trunc(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the truncated value of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.trunc, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def fix(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the nearest integer towards zero.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.fix, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def modf(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Tuple[jax.Array, jax.Array]:
"""
Return the fractional and integer parts of the array elements.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
The fractional and integral parts of the input, both with the same dimension.
"""
return _fun_accept_unitless_return_keep_unit(jnp.modf, x, unit_to_scale=unit_to_scale)


# math funcs only accept unitless (binary)
# ----------------------------------------

Expand Down
64 changes: 2 additions & 62 deletions brainunit/math/_fun_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh',
]
fun_accept_unitless_unary_2_results = [
'modf', 'frexp',
]

fun_accept_unitless_binary = [
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
'corrcoef', 'correlate', 'cov',
]
fun_accept_unitless_binary_ldexp = [
'ldexp',
]
fun_accept_unitless_unary_can_return_quantity = [
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix',
]

fun_elementwise_bit_operation_unary = [
'bitwise_not', 'invert',
]
Expand Down Expand Up @@ -70,36 +65,6 @@ def test_fun_accept_unitless_unary_1(self, value):
with pytest.raises(bu.UnitMismatchError):
result = fun(q, unit_to_scale=bu.nS)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_binary_2_results(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_2_results]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_2_results]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')
result1, result2 = bm_fun(jnp.array(value))
expected1, expected2 = jnp_fun(jnp.array(value))
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

q = value * meter
result1, result2 = bm_fun(q, unit_to_scale=meter)
expected1, expected2 = jnp_fun(jnp.array(value))
if bm_fun.__name__ == 'modf':
assert_quantity(result1, expected1, meter)
assert_quantity(result2, expected2, meter)
else:
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

with pytest.raises(AssertionError):
result1, result2 = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result1, result2 = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[[(1.0, 2.0), (3.0, 4.0), ],
[(1.23, 2.34, 3.45), (4.56, 5.67, 6.78)]]
Expand Down Expand Up @@ -153,31 +118,6 @@ def test_func_accept_unitless_binary_ldexp(self, value):
with pytest.raises(AssertionError):
result = bm_fun(q1, q2)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_unary_can_return_quantity(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_can_return_quantity]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_can_return_quantity]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')

result = bm_fun(jnp.array(value))
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected)

q = value * meter
result = bm_fun(q, unit_to_scale=meter)
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected, meter)

with pytest.raises(AssertionError):
result = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[(1, 2), (1, 2, 3)]
)
Expand Down
Loading

0 comments on commit 8608358

Please sign in to comment.