Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update code and doc for functions that keep units #42

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 1 addition & 231 deletions brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
'exprel', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2',
'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',

# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', 'frexp',

# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
Expand Down Expand Up @@ -748,232 +744,6 @@ def frexp(
return _fun_accept_unitless_unary(jnp.frexp, x, unit_to_scale=unit_to_scale)


def _fun_accept_unitless_return_keep_unit(
func: Callable,
x: jax.typing.ArrayLike | Quantity,
*args,
unit_to_scale: Optional[Unit] = None,
**kwargs
):
if isinstance(x, Quantity):
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.mantissa
return func(x, *args, **kwargs)
else:
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
r = func(x.to_decimal(unit_to_scale), *args, **kwargs)
return jax.tree.map(lambda a: a * unit_to_scale, r)
else:
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)


@set_module_as('brainunit.math')
def round_(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round_, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def around(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.

Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.around, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def round(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array | Quantity:
"""
Round an array to the nearest integer.

Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def rint(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
"""
Round an array to the nearest integer.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.rint, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def floor(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the floor of the argument.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.floor, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def ceil(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the ceiling of the argument.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.ceil, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def trunc(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the truncated value of the argument.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.trunc, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def fix(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the nearest integer towards zero.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.fix, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def modf(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Tuple[jax.Array, jax.Array]:
"""
Return the fractional and integer parts of the array elements.

Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.

Returns
-------
The fractional and integral parts of the input, both with the same dimension.
"""
return _fun_accept_unitless_return_keep_unit(jnp.modf, x, unit_to_scale=unit_to_scale)


# math funcs only accept unitless (binary)
# ----------------------------------------

Expand Down
64 changes: 2 additions & 62 deletions brainunit/math/_fun_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh',
]
fun_accept_unitless_unary_2_results = [
'modf', 'frexp',
]

fun_accept_unitless_binary = [
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
'corrcoef', 'correlate', 'cov',
]
fun_accept_unitless_binary_ldexp = [
'ldexp',
]
fun_accept_unitless_unary_can_return_quantity = [
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix',
]

fun_elementwise_bit_operation_unary = [
'bitwise_not', 'invert',
]
Expand Down Expand Up @@ -70,36 +65,6 @@ def test_fun_accept_unitless_unary_1(self, value):
with pytest.raises(bu.UnitMismatchError):
result = fun(q, unit_to_scale=bu.nS)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_binary_2_results(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_2_results]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_2_results]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')
result1, result2 = bm_fun(jnp.array(value))
expected1, expected2 = jnp_fun(jnp.array(value))
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

q = value * meter
result1, result2 = bm_fun(q, unit_to_scale=meter)
expected1, expected2 = jnp_fun(jnp.array(value))
if bm_fun.__name__ == 'modf':
assert_quantity(result1, expected1, meter)
assert_quantity(result2, expected2, meter)
else:
assert_quantity(result1, expected1)
assert_quantity(result2, expected2)

with pytest.raises(AssertionError):
result1, result2 = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result1, result2 = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[[(1.0, 2.0), (3.0, 4.0), ],
[(1.23, 2.34, 3.45), (4.56, 5.67, 6.78)]]
Expand Down Expand Up @@ -153,31 +118,6 @@ def test_func_accept_unitless_binary_ldexp(self, value):
with pytest.raises(AssertionError):
result = bm_fun(q1, q2)

@parameterized.product(
value=[(1.123, 2.567, 3.891), (1.23, 2.34, 3.45)]
)
def test_fun_accept_unitless_unary_can_return_quantity(self, value):
bm_fun_list = [getattr(bm, fun) for fun in fun_accept_unitless_unary_can_return_quantity]
jnp_fun_list = [getattr(jnp, fun) for fun in fun_accept_unitless_unary_can_return_quantity]

for bm_fun, jnp_fun in zip(bm_fun_list, jnp_fun_list):
print(f'fun: {bm_fun.__name__}')

result = bm_fun(jnp.array(value))
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected)

q = value * meter
result = bm_fun(q, unit_to_scale=meter)
expected = jnp_fun(jnp.array(value))
assert_quantity(result, expected, meter)

with pytest.raises(AssertionError):
result = bm_fun(q)

with pytest.raises(bu.UnitMismatchError):
result = bm_fun(q, unit_to_scale=bu.second)

@parameterized.product(
value=[(1, 2), (1, 2, 3)]
)
Expand Down
Loading
Loading