diff --git a/brainunit/_base.py b/brainunit/_base.py index 71cf3f8..bde716f 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -52,6 +52,10 @@ 'assert_quantity', ] +StaticScalar = Union[ + np.bool_, np.number, # NumPy scalar types + bool, int, float, complex, # Python scalar types +] PyTree = Any _all_slice = slice(None, None, None) @@ -1728,6 +1732,11 @@ def _process_list_with_units(value: List) -> Tuple[jax.typing.ArrayLike, 'Unit'] return values, unit +def _element_not_quantity(x): + assert not isinstance(x, Quantity), f"Expected not a Quantity object, but got {x}" + return x + + @register_pytree_node_class class Quantity: """ @@ -1788,6 +1797,96 @@ def __init__( # dimension self._unit = unit + @property + def at(self): + """ + Helper property for index update functionality. + + The ``at`` property provides a functionally pure equivalent of in-place + array modifications. + + In particular: + + ============================== ================================ + Alternate syntax Equivalent In-place expression + ============================== ================================ + ``x = x.at[idx].set(y)`` ``x[idx] = y`` + ``x = x.at[idx].add(y)`` ``x[idx] += y`` + ``x = x.at[idx].multiply(y)`` ``x[idx] *= y`` + ``x = x.at[idx].divide(y)`` ``x[idx] /= y`` + ``x = x.at[idx].power(y)`` ``x[idx] **= y`` + ``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)`` + ``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)`` + ``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)`` + ``x = x.at[idx].get()`` ``x = x[idx]`` + ============================== ================================ + + None of the ``x.at`` expressions modify the original ``x``; instead they return + a modified copy of ``x``. However, inside a :py:func:`~jax.jit` compiled function, + expressions like :code:`x = x.at[idx].set(y)` are guaranteed to be applied in-place. + + Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple + indices refer to the same location, all updates will be applied (NumPy would + only apply the last update, rather than applying all updates.) The order + in which conflicting updates are applied is implementation-defined and may be + nondeterministic (e.g., due to concurrency on some hardware platforms). + + By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound + index semantics can be specified via the ``mode`` parameter (see below). + + Arguments + --------- + mode : str + Specify out-of-bound indexing mode. Options are: + + - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. + No additional checking will be performed. In practice, this means that + out-of-bounds indices in ``get()`` will be clipped, and out-of-bounds indices + in ``set()``, ``add()``, etc. will be dropped. + - ``"clip"``: clamp out of bounds indices into valid range. + - ``"drop"``: ignore out-of-bound indices. + - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` + argument specifies the value that will be returned. + indices_are_sorted : bool + If True, the implementation will assume that the indices passed to ``at[]`` + are sorted in ascending order, which can lead to more efficient execution + on some backends. + unique_indices : bool + If True, the implementation will assume that the indices passed to ``at[]`` + are unique, which can result in more efficient execution on some backends. + fill_value : Any + Only applies to the ``get()`` method: the fill value to return for out-of-bounds + slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for + inexact types, the largest negative value for signed types, the largest positive + value for unsigned types, and ``True`` for booleans. + + Examples + -------- + >>> import brainunit as bu + >>> x = jnp.arange(5.0) * bu.mV + >>> x + Array([0., 1., 2., 3., 4.], dtype=float32) * mvolt + >>> x.at[2].add(10) + brainunit.UnitMismatchError: Cannot convert to a unit with different dimensions. (units are Unit(1.0) and mV). + >>> x.at[2].add(10 * bu.mV) + ArrayImpl([ 0., 1., 12., 3., 4.], dtype=float32) * mvolt + >>> x.at[10].add(10 * bu.mV) # out-of-bounds indices are ignored + ArrayImpl([0., 1., 2., 3., 4.], dtype=float32) * mvolt + >>> x.at[20].add(10 * bu.mV, mode='clip') + ArrayImpl([ 0., 1., 2., 3., 14.], dtype=float32) * mvolt + >>> x.at[2].get() + 2. * mvolt + >>> x.at[20].get() # out-of-bounds indices clipped + 4. * mvolt + >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + nan * mvolt + >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value + brainunit.UnitMismatchError: Cannot convert to a unit with different dimensions. (units are Unit(1.0) and mV). + >>> x.at[20].get(mode='fill', fill_value=-1 * bu.mV) # custom fill value + -1. * mvolt + """ + return _IndexUpdateHelper(self) + @property def mantissa(self) -> jax.typing.ArrayLike: r""" @@ -2129,30 +2228,224 @@ def __getitem__(self, index) -> 'Quantity': return Quantity(self._mantissa[index], unit=self.unit) def __setitem__(self, index, value: 'Quantity' | jax.typing.ArrayLike): + # check value if not isinstance(value, Quantity): if self.is_unitless: value = Quantity(value) else: raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") - fail_for_unit_mismatch(self, value, "Inconsistent units in assignment") value = value.in_unit(self.unit) - # index is a tuple - _assert_not_quantity(index) - if isinstance(index, (tuple, list)): - index = tuple(_assert_not_quantity(x) for x in index) - # index is numpy.ndarray - elif isinstance(index, np.ndarray): - index = jnp.asarray(index) + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) # update - self_value = self._check_tracer() - if isinstance(self_value, jax.Array): - self_value = self_value.at[index].set(value.mantissa) - else: - self_value[index] = value.mantissa + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].set(value.mantissa) self.update_value(self_value) + def scatter_add( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-add the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-add the value to. + value : Quantity + The value to scatter-add. + + Returns + ------- + out : Quantity + The scatter-added value. + """ + # check value + if not isinstance(value, Quantity): + if self.is_unitless: + value = Quantity(value) + else: + raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") + value = value.in_unit(self.unit) + + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + + # scatter-add + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].add(value.mantissa) + return Quantity(self_value, unit=self.unit) + + def scatter_sub( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-sub the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-add the value to. + value : Quantity + The value to scatter-add. + + Returns + ------- + out : Quantity + The scatter-subbed value. + """ + return self.scatter_add(index, -value) + + def scatter_mul( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-mul the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-mul the value to. + value : Quantity + The value to scatter-mul. + + Returns + ------- + out : Quantity + The scatter-multiplied value. + """ + # check value + if not isinstance(value, Quantity): + if self.is_unitless: + value = Quantity(value) + else: + raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") + value = value.in_unit(self.unit) + + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + + # scatter-mul + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].mul(value.mantissa) + return Quantity(self_value, unit=self.unit) + + def scatter_div( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-div the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-div the value to. + value : Quantity + The value to scatter-div. + + Returns + ------- + out : Quantity + The scatter-divided value. + """ + # check value + if not isinstance(value, Quantity): + if self.is_unitless: + value = Quantity(value) + else: + raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") + value = value.in_unit(self.unit) + + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + + # scatter-div + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].divide(value.mantissa) + return Quantity(self_value, unit=self.unit) + + def scatter_max( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-max the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-max the value to. + value : Quantity + The value to scatter-max. + + Returns + ------- + out : Quantity + The scatter-maximum value. + """ + # check value + if not isinstance(value, Quantity): + if self.is_unitless: + value = Quantity(value) + else: + raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") + value = value.in_unit(self.unit) + + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + + # scatter-max + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].max(value.mantissa) + return Quantity(self_value, unit=self.unit) + + def scatter_min( + self, + index: jax.typing.ArrayLike, + value: Union['Quantity', jax.typing.ArrayLike] + ) -> 'Quantity': + """ + Scatter-min the given value to the given index. + + Parameters + ---------- + index : int or array_like + The index to scatter-min the value to. + value : Quantity + The value to scatter-min. + + Returns + ------- + out : Quantity + The scatter-minimum value. + """ + # check value + if not isinstance(value, Quantity): + if self.is_unitless: + value = Quantity(value) + else: + raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}") + value = value.in_unit(self.unit) + + # check index + index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + + # scatter-min + self_value = jnp.asarray(self._check_tracer()) + self_value = self_value.at[index].min(value.mantissa) + return Quantity(self_value, unit=self.unit) + # ---------- # # operations # # ---------- # @@ -3109,6 +3402,276 @@ def double(self) -> 'Quantity': return Quantity(jnp.asarray(self._mantissa, dtype=jnp.float64), unit=self.unit) +class _IndexUpdateHelper: + """ + Helper property for index update functionality. + """ + __slots__ = ("quantity",) + + def __init__(self, quantity: Quantity): + assert isinstance(quantity, Quantity), f"quantity must be a Quantity object, but got {quantity}" + self.quantity = quantity + + def __getitem__(self, index: Any) -> _IndexUpdateRef: + return _IndexUpdateRef(index, self.quantity) + + def __repr__(self): + return f"_IndexUpdateHelper({self.quantity})" + + +class _IndexUpdateRef: + """ + Helper object to call indexed update functions for an (advanced) index. + + This object references a source array and a specific indexer into that array. + Methods on this object return copies of the source array that have been + modified at the positions specified by the indexer. + """ + __slots__ = ("quantity", "index", "mantissa_at", "unit") + + def __init__(self, index, quantity: Quantity): + self.index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity)) + self.quantity = quantity + self.mantissa_at = jnp.asarray(quantity.mantissa).at + self.unit = quantity.unit + + def __repr__(self) -> str: + return f"_IndexUpdateRef({self.quantity}, {self.index!r})" + + def get( + self, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None, + fill_value: StaticScalar | None = None + ) -> Quantity: + """Equivalent to ``x[idx]``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexing ` ``x[idx]``. This function differs from + the usual array indexing syntax in that it allows additional keyword + arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. + """ + if fill_value is not None: + fill_value = Quantity(fill_value).in_unit(self.unit).mantissa.item() + return Quantity( + self.mantissa_at[self.index].get( + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + fill_value=fill_value + ), + unit=self.unit + ) + + def set( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None, + fill_value: StaticScalar | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] = y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:`indexed assignment ` ``x[idx] = y``. + """ + values = Quantity(values).in_unit(self.unit).mantissa + if fill_value is not None: + fill_value = Quantity(fill_value).in_unit(self.unit).mantissa.item() + return Quantity( + self.mantissa_at[self.index].set( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + fill_value=fill_value + ), + unit=self.unit + ) + + def add( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] += y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] += y``. + + """ + values = Quantity(values).in_unit(self.unit).mantissa + return Quantity( + self.mantissa_at[self.index].add( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit + ) + + def multiply( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] *= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] *= y``. + + """ + values = Quantity(values) + return Quantity( + self.mantissa_at[self.index].multiply( + values.mantissa, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit * values.unit + ) + + mul = multiply + + def divide( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] /= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] /= y``. + + """ + values = Quantity(values) + return Quantity( + self.mantissa_at[self.index].divide( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit / values.unit + ) + + div = divide + + def power( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] **= y``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` ``x[idx] **= y``. + + """ + assert isinstance(values, int), f"values must be an integer, but got {values}" + return Quantity( + self.mantissa_at[self.index].power( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit ** values + ) + + def min( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` + ``x[idx] = minimum(x[idx], y)``. + + """ + values = Quantity(values).in_unit(self.unit).mantissa + return Quantity( + self.mantissa_at[self.index].min( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit + ) + + def max( + self, + values: Any, + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. + + Returns the value of ``x`` that would result from the NumPy-style + :mod:indexed assignment ` + ``x[idx] = maximum(x[idx], y)``. + + """ + values = Quantity(values).in_unit(self.unit).mantissa + return Quantity( + self.mantissa_at[self.index].max( + values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=self.unit + ) + + def apply( + self, + mantissa_fun: Callable[[jax.typing.ArrayLike], jax.typing.ArrayLike], + unit_fun: Callable[[Unit], Unit], + indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | None = None + ) -> Quantity: + """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. + + Returns the value of ``x`` that would result from applying the unary + function ``func`` to ``x`` at the given indices. This is similar to + ``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated indices: + in ``x.at[idx].apply(func)``, repeated indices result in the function being + applied multiple times. + + Note that in the current implementation, ``scatter_apply`` is not compatible + with automatic differentiation. + + """ + return Quantity( + self.mantissa_at[self.index].apply( + mantissa_fun, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode + ), + unit=unit_fun(self.unit) + ) + + def _replace_with_array(seq, unit): """ Replace all the elements in the list with an equivalent `Array` diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index c876459..f565197 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -842,6 +842,33 @@ def f(b): f(a) + def test_setiterm(self): + u = bu.Quantity([0, 0, 0.]) + u[jnp.asarray([0, 1, 1])] += jnp.asarray([1., 1., 1.]) + assert_quantity(u, [1., 1., 0.]) + + u = bu.Quantity([0, 0, 0.]) + u = u.scatter_add(jnp.asarray([0, 1, 1]), jnp.asarray([1., 1., 1.])) + assert_quantity(u, [1., 2., 0.]) + + nu = np.asarray([0, 0, 0.]) + nu[np.asarray([0, 1, 1])] += np.asarray([1., 1., 1.]) + self.assertTrue(np.allclose(nu, np.asarray([1., 1., 0.]))) + + def test_at(self): + x = jnp.arange(5.0) * bu.mV + with self.assertRaises(bu.UnitMismatchError): + x.at[2].add(10) + x.at[2].add(10 * bu.mV) + x.at[10].add(10 * bu.mV) # out-of-bounds indices are ignored + x.at[20].add(10 * bu.mV, mode='clip') + x.at[2].get() + x.at[20].get() # out-of-bounds indices clipped + x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + with self.assertRaises(bu.UnitMismatchError): + x.at[20].get(mode='fill', fill_value=-1) # custom fill value + x.at[20].get(mode='fill', fill_value=-1 * bu.mV) # custom fill value + class TestNumPyFunctions(unittest.TestCase): def test_special_case_numpy_functions(self): diff --git a/docs/mathematical_functions.ipynb b/docs/mathematical_functions.ipynb deleted file mode 100644 index 69dbfbd..0000000 --- a/docs/mathematical_functions.ipynb +++ /dev/null @@ -1,36 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Mathematical Functions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{toctree}\n", - ":maxdepth: 2\n", - "\n", - "mathematical_functions/numpy_functions\n", - "mathematical_functions/customize_functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}