From a0ea78138c9b4c51e8f17c4527e9e51f3400a4a7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 23 Aug 2024 10:12:20 +0800 Subject: [PATCH] update code and doc for functions that keep units --- brainunit/math/_fun_accept_unitless.py | 232 +------------------- brainunit/math/_fun_keep_unit.py | 183 ++++++++++++++- docs/apis/brainunit.math.array-creation.rst | 40 ---- docs/apis/brainunit.math.change-unit.rst | 41 ---- docs/apis/brainunit.math.einops.rst | 15 -- docs/apis/brainunit.math.keep-unit.rst | 108 --------- docs/apis/brainunit.math.misc.rst | 36 --- docs/apis/brainunit.math.remove-unit.rst | 42 ---- docs/apis/brainunit.math.unitless.rst | 61 ----- docs/conf.py | 2 + 10 files changed, 185 insertions(+), 575 deletions(-) delete mode 100644 docs/apis/brainunit.math.array-creation.rst delete mode 100644 docs/apis/brainunit.math.change-unit.rst delete mode 100644 docs/apis/brainunit.math.einops.rst delete mode 100644 docs/apis/brainunit.math.keep-unit.rst delete mode 100644 docs/apis/brainunit.math.misc.rst delete mode 100644 docs/apis/brainunit.math.remove-unit.rst delete mode 100644 docs/apis/brainunit.math.unitless.rst diff --git a/brainunit/math/_fun_accept_unitless.py b/brainunit/math/_fun_accept_unitless.py index 18e592d..45b77a2 100644 --- a/brainunit/math/_fun_accept_unitless.py +++ b/brainunit/math/_fun_accept_unitless.py @@ -27,11 +27,7 @@ 'exprel', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', - - # math funcs only accept unitless (unary) can return Quantity - 'round', 'around', 'round_', 'rint', - 'floor', 'ceil', 'trunc', 'fix', 'modf', 'frexp', + 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', 'frexp', # math funcs only accept unitless (binary) 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', @@ -748,232 +744,6 @@ def frexp( return _fun_accept_unitless_unary(jnp.frexp, x, unit_to_scale=unit_to_scale) -def _fun_accept_unitless_return_keep_unit( - func: Callable, - x: jax.typing.ArrayLike | Quantity, - *args, - unit_to_scale: Optional[Unit] = None, - **kwargs -): - if isinstance(x, Quantity): - if unit_to_scale is None: - assert x.is_unitless, (f'Input should be unitless for the function "{func}" ' - f'when scaling "unit_to_scale" is not provided.') - x = x.mantissa - return func(x, *args, **kwargs) - else: - assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}' - r = func(x.to_decimal(unit_to_scale), *args, **kwargs) - return jax.tree.map(lambda a: a * unit_to_scale, r) - else: - # assert unit_to_scale is None, f'Unit should be None for the function "{func}" when "x" is not a Quantity.' - return func(x, *args, **kwargs) - - -@set_module_as('brainunit.math') -def round_( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Round an array to the nearest integer. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.round_, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def around( - x: Union[Quantity, jax.typing.ArrayLike], - decimals: int = 0, - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Round an array to the nearest integer. - - Parameters - ---------- - x : array_like, Quantity - Input array. - decimals : int, optional - Number of decimal places to round to (default is 0). - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.around, x, unit_to_scale=unit_to_scale, decimals=decimals) - - -@set_module_as('brainunit.math') -def round( - x: Union[Quantity, jax.typing.ArrayLike], - decimals: int = 0, - unit_to_scale: Optional[Unit] = None, -) -> jax.Array | Quantity: - """ - Round an array to the nearest integer. - - Parameters - ---------- - x : array_like, Quantity - Input array. - decimals : int, optional - Number of decimal places to round to (default is 0). - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.round, x, unit_to_scale=unit_to_scale, decimals=decimals) - - -@set_module_as('brainunit.math') -def rint( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> Union[Quantity, jax.Array]: - """ - Round an array to the nearest integer. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.rint, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def floor( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Return the floor of the argument. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.floor, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def ceil( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Return the ceiling of the argument. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.ceil, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def trunc( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Return the truncated value of the argument. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.trunc, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def fix( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> jax.Array: - """ - Return the nearest integer towards zero. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - out : jax.Array - """ - return _fun_accept_unitless_return_keep_unit(jnp.fix, x, unit_to_scale=unit_to_scale) - - -@set_module_as('brainunit.math') -def modf( - x: Union[Quantity, jax.typing.ArrayLike], - unit_to_scale: Optional[Unit] = None, -) -> Tuple[jax.Array, jax.Array]: - """ - Return the fractional and integer parts of the array elements. - - Parameters - ---------- - x : array_like, Quantity - Input array. - unit_to_scale : Unit, optional - The unit to scale the ``x``. - - Returns - ------- - The fractional and integral parts of the input, both with the same dimension. - """ - return _fun_accept_unitless_return_keep_unit(jnp.modf, x, unit_to_scale=unit_to_scale) - - # math funcs only accept unitless (binary) # ---------------------------------------- diff --git a/brainunit/math/_fun_keep_unit.py b/brainunit/math/_fun_keep_unit.py index a582d2b..b11a8ac 100644 --- a/brainunit/math/_fun_keep_unit.py +++ b/brainunit/math/_fun_keep_unit.py @@ -24,7 +24,8 @@ from ._fun_array_creation import asarray from .._base import (Quantity, - fail_for_dimension_mismatch, get_unit, + fail_for_dimension_mismatch, + get_unit, UNITLESS, unit_scale_align_to_first, remove_unitless) @@ -54,6 +55,9 @@ 'nanmedian', 'nanmean', 'nanstd', 'diff', 'rot90', 'intersect1d', 'nan_to_num', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + # math funcs only accept unitless (unary) can return Quantity + 'round', 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'modf', + # math funcs keep unit (binary) 'fmod', 'mod', 'copysign', 'remainder', 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', 'trace', @@ -3320,3 +3324,180 @@ def unique( equal_nan=equal_nan, size=size, fill_value=fill_value) + + +@set_module_as('brainunit.math') +def round_( + x: Union[Quantity, jax.typing.ArrayLike], +) -> jax.Array: + """ + Round an array to the nearest integer. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.round_, x) + + +@set_module_as('brainunit.math') +def around( + x: Union[Quantity, jax.typing.ArrayLike], + decimals: int = 0, +) -> jax.Array: + """ + Round an array to the nearest integer. + + Parameters + ---------- + x : array_like, Quantity + Input array. + decimals : int, optional + Number of decimal places to round to (default is 0). + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.around, x, decimals=decimals) + + +@set_module_as('brainunit.math') +def round( + x: Union[Quantity, jax.typing.ArrayLike], + decimals: int = 0, +) -> jax.Array | Quantity: + """ + Round an array to the nearest integer. + + Parameters + ---------- + x : array_like, Quantity + Input array. + decimals : int, optional + Number of decimal places to round to (default is 0). + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.round, x, decimals=decimals) + + +@set_module_as('brainunit.math') +def rint( + x: Union[Quantity, jax.typing.ArrayLike], +) -> Union[Quantity, jax.Array]: + """ + Round an array to the nearest integer. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.rint, x) + + +@set_module_as('brainunit.math') +def floor( + x: Union[Quantity, jax.typing.ArrayLike], +) -> jax.Array: + """ + Return the floor of the argument. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.floor, x) + + +@set_module_as('brainunit.math') +def ceil( + x: Union[Quantity, jax.typing.ArrayLike], +) -> jax.Array: + """ + Return the ceiling of the argument. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.ceil, x) + + +@set_module_as('brainunit.math') +def trunc( + x: Union[Quantity, jax.typing.ArrayLike], +) -> jax.Array: + """ + Return the truncated value of the argument. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.trunc, x) + + +@set_module_as('brainunit.math') +def fix( + x: Union[Quantity, jax.typing.ArrayLike], +) -> jax.Array: + """ + Return the nearest integer towards zero. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + out : jax.Array + """ + return _fun_keep_unit_unary(jnp.fix, x) + + +@set_module_as('brainunit.math') +def modf( + x: Union[Quantity, jax.typing.ArrayLike], +) -> Tuple[jax.Array, jax.Array]: + """ + Return the fractional and integer parts of the array elements. + + Parameters + ---------- + x : array_like, Quantity + Input array. + + Returns + ------- + The fractional and integral parts of the input, both with the same dimension. + """ + return _fun_keep_unit_unary(jnp.modf, x) diff --git a/docs/apis/brainunit.math.array-creation.rst b/docs/apis/brainunit.math.array-creation.rst deleted file mode 100644 index abff032..0000000 --- a/docs/apis/brainunit.math.array-creation.rst +++ /dev/null @@ -1,40 +0,0 @@ -Array Creation Functions -======================== - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - full - eye - identity - tri - empty - ones - zeros - full_like - diag - tril - triu - empty_like - ones_like - zeros_like - fill_diagonal - array - asarray - arange - linspace - logspace - meshgrid - vander - tril_indices - tril_indices_from - triu_indices - triu_indices_from - from_numpy - as_numpy - tree_ones_like - tree_zeros_like diff --git a/docs/apis/brainunit.math.change-unit.rst b/docs/apis/brainunit.math.change-unit.rst deleted file mode 100644 index 00b3e8b..0000000 --- a/docs/apis/brainunit.math.change-unit.rst +++ /dev/null @@ -1,41 +0,0 @@ -Functions that Changing Unit -============================ - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - reciprocal - prod - product - nancumprod - nanprod - cumprod - cumproduct - var - nanvar - cbrt - square - sqrt - multiply - divide - power - cross - true_divide - floor_divide - float_power - divmod - convolve - dot - multi_dot - vdot - vecdot - inner - outer - kron - matmul - tensordot - matrix_power diff --git a/docs/apis/brainunit.math.einops.rst b/docs/apis/brainunit.math.einops.rst deleted file mode 100644 index 3405513..0000000 --- a/docs/apis/brainunit.math.einops.rst +++ /dev/null @@ -1,15 +0,0 @@ -Einstein Operations -=================== - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - einreduce - einrearrange - einrepeat - einshape - einsum diff --git a/docs/apis/brainunit.math.keep-unit.rst b/docs/apis/brainunit.math.keep-unit.rst deleted file mode 100644 index e334770..0000000 --- a/docs/apis/brainunit.math.keep-unit.rst +++ /dev/null @@ -1,108 +0,0 @@ -Functions that Keeping Unit -=========================== - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - row_stack - concatenate - stack - vstack - hstack - dstack - column_stack - block - append - split - array_split - dsplit - hsplit - vsplit - atleast_1d - atleast_2d - atleast_3d - broadcast_arrays - broadcast_to - reshape - moveaxis - transpose - swapaxes - tile - repeat - flip - fliplr - flipud - roll - expand_dims - squeeze - sort - max - min - amax - amin - diagflat - diagonal - choose - ravel - flatten - unflatten - remove_diag - real - imag - conj - conjugate - negative - positive - abs - sum - nancumsum - nansum - cumsum - ediff1d - absolute - fabs - median - nanmin - nanmax - ptp - average - mean - std - nanmedian - nanmean - nanstd - diff - rot90 - intersect1d - nan_to_num - percentile - nanpercentile - quantile - nanquantile - fmod - mod - copysign - remainder - maximum - minimum - fmax - fmin - lcm - gcd - trace - add - subtract - nextafter - interp - clip - histogram - compress - extract - take - select - where - unique diff --git a/docs/apis/brainunit.math.misc.rst b/docs/apis/brainunit.math.misc.rst deleted file mode 100644 index 4e8331e..0000000 --- a/docs/apis/brainunit.math.misc.rst +++ /dev/null @@ -1,36 +0,0 @@ -Other Functions -=============== - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - finfo - iinfo - ndim - isreal - isscalar - isfinite - isinf - isnan - shape - size - get_dtype - is_float - is_int - broadcast_shapes - gradient - bartlett - blackman - hamming - hanning - kaiser - dtype - e - pi - inf - nan - euler_gamma diff --git a/docs/apis/brainunit.math.remove-unit.rst b/docs/apis/brainunit.math.remove-unit.rst deleted file mode 100644 index 8892106..0000000 --- a/docs/apis/brainunit.math.remove-unit.rst +++ /dev/null @@ -1,42 +0,0 @@ -Functions that Removing Unit -============================ - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - heaviside - signbit - sign - bincount - digitize - all - any - logical_not - equal - not_equal - greater - greater_equal - less - less_equal - array_equal - isclose - allclose - logical_and - logical_or - logical_xor - alltrue - sometrue - argsort - argmax - argmin - nanargmax - nanargmin - argwhere - nonzero - flatnonzero - searchsorted - count_nonzero diff --git a/docs/apis/brainunit.math.unitless.rst b/docs/apis/brainunit.math.unitless.rst deleted file mode 100644 index cc9e81f..0000000 --- a/docs/apis/brainunit.math.unitless.rst +++ /dev/null @@ -1,61 +0,0 @@ -Functions that Accepting Unitless -================================= - -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math - -.. autosummary:: - :template: classtemplate.rst - :toctree: generated/ - - exprel - exp - exp2 - expm1 - log - log10 - log1p - log2 - arccos - arccosh - arcsin - arcsinh - arctan - arctanh - cos - cosh - sin - sinc - sinh - tan - tanh - deg2rad - rad2deg - degrees - radians - angle - round - around - round_ - rint - floor - ceil - trunc - fix - modf - frexp - hypot - arctan2 - logaddexp - logaddexp2 - corrcoef - correlate - cov - ldexp - bitwise_not - invert - bitwise_and - bitwise_or - bitwise_xor - left_shift - right_shift diff --git a/docs/conf.py b/docs/conf.py index cfe6fa1..5300dff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,6 +33,8 @@ sys.path.insert(0, os.path.abspath('../')) import brainunit +import auto_generater +auto_generater.main() os.makedirs('apis/', exist_ok=True)