Skip to content

Commit

Permalink
update code and doc for functions that keep units
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 23, 2024
1 parent 6289562 commit a0ea781
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 575 deletions.
232 changes: 1 addition & 231 deletions brainunit/math/_fun_accept_unitless.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
'exprel', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2',
'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle',

# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'round_', 'rint',
'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp',
'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', 'frexp',

# math funcs only accept unitless (binary)
'hypot', 'arctan2', 'logaddexp', 'logaddexp2',
Expand Down Expand Up @@ -748,232 +744,6 @@ def frexp(
return _fun_accept_unitless_unary(jnp.frexp, x, unit_to_scale=unit_to_scale)


def _fun_accept_unitless_return_keep_unit(
func: Callable,
x: jax.typing.ArrayLike | Quantity,
*args,
unit_to_scale: Optional[Unit] = None,
**kwargs
):
if isinstance(x, Quantity):
if unit_to_scale is None:
assert x.is_unitless, (f'Input should be unitless for the function "{func}" '
f'when scaling "unit_to_scale" is not provided.')
x = x.mantissa
return func(x, *args, **kwargs)
else:
assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}'
r = func(x.to_decimal(unit_to_scale), *args, **kwargs)
return jax.tree.map(lambda a: a * unit_to_scale, r)
else:
# assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.'
return func(x, *args, **kwargs)


@set_module_as('brainunit.math')
def round_(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round_, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def around(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.around, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def round(
x: Union[Quantity, jax.typing.ArrayLike],
decimals: int = 0,
unit_to_scale: Optional[Unit] = None,
) -> jax.Array | Quantity:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
decimals : int, optional
Number of decimal places to round to (default is 0).
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.round, x, unit_to_scale=unit_to_scale, decimals=decimals)


@set_module_as('brainunit.math')
def rint(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
"""
Round an array to the nearest integer.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.rint, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def floor(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the floor of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.floor, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def ceil(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the ceiling of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.ceil, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def trunc(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the truncated value of the argument.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.trunc, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def fix(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> jax.Array:
"""
Return the nearest integer towards zero.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
out : jax.Array
"""
return _fun_accept_unitless_return_keep_unit(jnp.fix, x, unit_to_scale=unit_to_scale)


@set_module_as('brainunit.math')
def modf(
x: Union[Quantity, jax.typing.ArrayLike],
unit_to_scale: Optional[Unit] = None,
) -> Tuple[jax.Array, jax.Array]:
"""
Return the fractional and integer parts of the array elements.
Parameters
----------
x : array_like, Quantity
Input array.
unit_to_scale : Unit, optional
The unit to scale the ``x``.
Returns
-------
The fractional and integral parts of the input, both with the same dimension.
"""
return _fun_accept_unitless_return_keep_unit(jnp.modf, x, unit_to_scale=unit_to_scale)


# math funcs only accept unitless (binary)
# ----------------------------------------

Expand Down
Loading

0 comments on commit a0ea781

Please sign in to comment.