diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index a9b3b1bda..afbc7bc57 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-__version__ = "2.4.5"
+__version__ = "2.4.5.post4"
 _minimal_brainpylib_version = '0.1.10'

 # fundamental supporting modules
diff --git a/brainpy/_src/initialize/generic.py b/brainpy/_src/initialize/generic.py
index 15381a21f..f5a6fe3f3 100644
--- a/brainpy/_src/initialize/generic.py
+++ b/brainpy/_src/initialize/generic.py
@@ -29,7 +29,7 @@ def _is_scalar(x):


 def parameter(
-    param: Union[Callable, Initializer, bm.ndarray, np.ndarray, jnp.ndarray, float, int, bool],
+    param: Union[Callable, Initializer, bm.Array, np.ndarray, jax.Array, float, int, bool],
     sizes: Shape,
     allow_none: bool = True,
     allow_scalar: bool = True,
@@ -74,8 +74,10 @@ def parameter(
     return param

   if callable(param):
-    param = param(sizes)  # TODO
-    # return bm.jit(param, static_argnums=0, out_shardings=bm.sharding.get_sharding(axis_names))(size)
+    # param = param(sizes)  # TODO
+    return bm.jit(param,
+                  static_argnums=0,
+                  out_shardings=bm.sharding.get_sharding(sharding))(sizes)

   elif isinstance(param, (np.ndarray, jnp.ndarray)):
     param = bm.asarray(param)
@@ -104,32 +106,9 @@ def variable_(
 ):
   """Initialize a :math:`~.Variable` from a callable function or a data.

-  Parameters
-  ----------
-  init: callable, function, ArrayType
-    The data to be initialized as a ``Variable``.
-  batch_or_mode: int, bool, Mode, optional
-    The batch size, model ``Mode``, boolean state.
-    This is used to specify the batch size of this variable.
-    If it is a boolean or an instance of ``Mode``, the batch size will be 1.
-    If it is None, the variable has no batch axis.
-  sizes: Shape
-    The shape of the variable.
-  batch_axis: int
-    The batch axis.
-  axis_names: sequence of str
-    The name for each axis. These names should match the given ``axes``.
-  batch_axis_name: str
-    The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given.
-
-  Returns
-  -------
-  variable: bm.Variable
-    The target ``Variable`` instance.
-
   See Also
   --------
-  variable, parameter, noise, delay
+  variable

   """
   return variable(init,
@@ -152,10 +131,10 @@ def variable(

   Parameters
   ----------
-  init: callable, function, ArrayType
+  init: callable, ArrayType
     The data to be initialized as a ``Variable``.
   batch_or_mode: int, bool, Mode, optional
-    The batch size, model ``Mode``, boolean state.
+    The batch size, mode ``Mode``, boolean state.
     This is used to specify the batch size of this variable.
     If it is a boolean or an instance of ``Mode``, the batch size will be 1.
     If it is None, the variable has no batch axis.
diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py
index 874f0c2b8..377007847 100644
--- a/brainpy/_src/math/event/_csr_matvec.py
+++ b/brainpy/_src/math/event/_csr_matvec.py
@@ -95,9 +95,9 @@ def csrmv(
     raise ValueError('indices should be a 1D vector with integer type.')
   if np.ndim(indptr) != 1:
     raise ValueError('indptr should be a 1D vector with integer type.')
-  if indices.dtype not in [jnp.int32, jnp.int64]:
+  if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
     raise ValueError('indices should be a 1D vector with int32 or int64 type.')
-  if indptr.dtype not in [jnp.int32, jnp.int64]:
+  if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
     raise ValueError('indptr should be a 1D vector with int32 or int64 type.')
   if np.ndim(events) != 1:
     raise ValueError('events should be a 1D vector.')
@@ -328,36 +328,37 @@ def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins):
   res_val.fill(0)
   values, indices, indptr, events, shape, _ = ins
   if values.shape[0] > 1:  # heter
-    for row_i in range(shape[0]):
-      if events[row_i]:
+    for row_i, event in enumerate(events):
+      if event:
         for j in range(indptr[row_i], indptr[row_i + 1]):
           col_i = indices[j]
           res_val[col_i] += values[j]
   else:  # homo
     values = values[0]
-    for row_i in range(shape[0]):
-      if events[row_i]:
+    for row_i, event in enumerate(events):
+      if event:
         for j in range(indptr[row_i], indptr[row_i + 1]):
           col_i = indices[j]
           res_val[col_i] += values


+
 @numba.njit(fastmath=True)
 def _event_csr_matvec_transpose_numba_imp2(outs, ins):
   res_val = outs
   res_val.fill(0)
   values, indices, indptr, events, shape, _ = ins
   if values.shape[0] > 1:  # heter
-    for row_i in range(shape[0]):
-      if events[row_i] > 0.:
+    for row_i, event in enumerate(events):
+      if event > 0.:
         for j in range(indptr[row_i], indptr[row_i + 1]):
           col_i = indices[j]
           res_val[col_i] += values[j]
   else:  # homo
     values = values[0]
-    for row_i in range(shape[0]):
-      if events[row_i] > 0.:
+    for row_i, event in enumerate(events):
+      if event > 0.:
         for j in range(indptr[row_i], indptr[row_i + 1]):
           col_i = indices[j]
           res_val[col_i] += values
diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py
index d10b0d0e5..6c71b4b06 100644
--- a/brainpy/_src/math/index_tricks.py
+++ b/brainpy/_src/math/index_tricks.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
 import abc

 from jax import core
diff --git a/brainpy/_src/math/object_transform/_tools.py b/brainpy/_src/math/object_transform/_tools.py
index c90e631b9..6e126f093 100644
--- a/brainpy/_src/math/object_transform/_tools.py
+++ b/brainpy/_src/math/object_transform/_tools.py
@@ -1,6 +1,6 @@
 import warnings
 from functools import wraps
-from typing import Sequence
+from typing import Sequence, Tuple, Any

 import jax

@@ -79,11 +79,12 @@ def abstract(x):
 def evaluate_dyn_vars(
     f,
     *args,
+    transform: str = None,
     static_argnums: Sequence[int] = (),
     static_argnames: Sequence[str] = (),
     use_eval_shape: bool = True,
     **kwargs
-):
+) -> Tuple[VariableStack, Any]:
   # arguments
   if len(static_argnums) or len(static_argnames):
     f2, args, kwargs = _partial_fun(f, args, kwargs,
diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index 5f06b4e67..f8dd1d8f8 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -225,7 +225,7 @@ def __call__(self, *args, **kwargs):
           cache_stack(self.target, stack)

         self._dyn_vars = stack
-        self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
+        self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
         self._eval_dyn_vars = True

         # if not the outermost transformation
@@ -233,7 +233,7 @@ def __call__(self, *args, **kwargs):
           return self._return(rets)
       else:
         self._dyn_vars = stack
-        self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
+        self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
         self._eval_dyn_vars = True

     rets = self._transform(
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index e3470ef5c..b2b6017c9 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -447,7 +447,7 @@ def __init__(
       self,
      seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None,
       seed: Optional[int] = None,
-      _ready_to_trace: bool = True,
+      ready_to_trace: bool = True,
   ):
     """RandomState constructor.

@@ -482,7 +482,7 @@ def __init__(
        raise ValueError('key must be an array with dtype uint32. '
                         f'But we got {seed_or_key}')
      key = seed_or_key
-    super(RandomState, self).__init__(key, _ready_to_trace=_ready_to_trace)
+    super(RandomState, self).__init__(key, ready_to_trace=ready_to_trace)

  def __repr__(self) -> str:
    print_code = repr(self.value)
diff --git a/brainpy/_src/math/sharding.py b/brainpy/_src/math/sharding.py
index ac41cb34f..2d95e906d 100644
--- a/brainpy/_src/math/sharding.py
+++ b/brainpy/_src/math/sharding.py
@@ -1,13 +1,14 @@
+# -*- coding: utf-8 -*-
+
 from functools import partial
 from typing import Optional, Any, Union, Sequence
 from contextlib import contextmanager

 import jax
 import numpy as np
-from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED
 from jax.sharding import PartitionSpec, Mesh, NamedSharding, Sharding

-from .ndarray import Array
+from .ndarray import Array, ShardedArray

 __all__ = [
   'device_mesh',
@@ -15,6 +16,7 @@
   'partition_by_axname',
   'partition_by_sharding',
   'partition',
+  'keep_constraint',

   'NEU_AXIS',
   'PRE_AXIS',
@@ -39,6 +41,10 @@
 _default_mesh: Optional[Mesh] = None


+def is_bp_array(x):
+  return isinstance(x, Array)
+
+
 @contextmanager
 def device_mesh(
     devices: Any,
@@ -61,15 +67,33 @@ def device_mesh(

 def _device_put(x: Union[Array, jax.Array, np.ndarray],
                 device: Union[None, jax.Device, Sharding] = None):
+  """Transfers ``x`` to ``device``.
+
+  Note that this function can only transfer ``brainpy.math.Array``, ``jax.Array``,
+  and ``numpy.ndarray``. Other values will be returned directly.
+
+  Args:
+    x: The input array.
+    device: The given device.
+
+  Returns:
+    A copy of ``x`` that resides on ``device``.
+  """
   if isinstance(x, Array):
-    x.value = jax.device_put(x, device=device)
-  return x
+    x.value = jax.device_put(x.value, device=device)
+    return x
+  else:
+    if isinstance(x, (jax.Array, np.ndarray)):
+      # wrapping the data as a brainpy.math.Array is important (experimental)
+      return ShardedArray(jax.device_put(x, device=device), keep_sharding=True)
+    else:
+      return x


 def get_sharding(
     axis_names: Optional[Sequence[str]] = None,
     mesh: Optional[Mesh] = None
-) -> Union[UnspecifiedValue, NamedSharding]:
+) -> Optional[NamedSharding]:
   """Get sharding according to the given axes information.

   Args:
@@ -80,11 +104,11 @@
     The instance of NamedSharding.
   """
   if axis_names is None:
-    return UNSPECIFIED
+    return None
   if mesh is None:
     mesh = _default_mesh
     if mesh is None:
-      return UNSPECIFIED
+      return None
     else:
       axis_names = [(name if name in mesh.axis_names else None) for name in axis_names]
       return NamedSharding(mesh, PartitionSpec(*axis_names))
@@ -108,8 +132,11 @@ def partition_by_axname(
   if axis_names is None:
     return x
   else:
-    for _leaf in jax.tree_util.tree_leaves(x, is_leaf=lambda a: isinstance(a, Array)):
-      assert np.ndim(_leaf) == len(axis_names)
+    for _leaf in jax.tree_util.tree_leaves(x, is_leaf=is_bp_array):
+      if np.ndim(_leaf) != len(axis_names):
+        raise ValueError(f'The input array shape is {np.shape(_leaf)}, '
+                         f'while the given axis names are {axis_names}. '
+                         f'Dimensions mismatch.')
   if mesh is None:
     if _default_mesh is None:
       return x
@@ -118,41 +145,78 @@ def partition_by_axname(
   if sharding is None:
     return x
   else:
-    f = partial(_device_put, device=sharding)
-    return jax.tree_util.tree_map(f, x, is_leaf=lambda a: isinstance(a, Array))
+    return jax.tree_util.tree_map(partial(_device_put, device=sharding),
+                                  x, is_leaf=is_bp_array)


 def partition_by_sharding(
     x: Any,
     sharding: Optional[Sharding] = None,
 ):
-  """Partition inputs with the given sharding strategy."""
+  """Partition inputs with the given sharding strategy.
+
+  Args:
+    x: The input arrays. It can be a PyTree of arrays.
+    sharding: The `jax.sharding.Sharding` instance.
+
+  Returns:
+    The sharded ``x``, which has been partitioned by the given sharding strategy.
+  """
   if sharding is None:
     return x
   else:
-    assert isinstance(sharding, Sharding)
-    if isinstance(x, (Array, jax.Array)):
-      return _device_put(x, device=sharding)
+    if not isinstance(sharding, Sharding):
+      raise TypeError(f'sharding must be an instance of jax.sharding.Sharding. While we got {sharding}.')
     return jax.tree_util.tree_map(partial(_device_put, device=sharding),
                                   x,
-                                  is_leaf=lambda a: isinstance(a, Array))
+                                  is_leaf=is_bp_array)


 def partition(
     x: Any,
     sharding: Optional[Union[Sequence[str], jax.Device, Sharding]] = None,
 ):
+  """Partition the input arrays onto devices by the given sharding strategies.
+
+  Args:
+    x: Any input arrays. It can also be a PyTree of arrays.
+    sharding: The sharding strategy.
+
+  Returns:
+    The partitioned arrays.
+    Notably, the
+  """
   if sharding is None:
     return x
-  if isinstance(sharding, UnspecifiedValue):
-    return x
   elif isinstance(sharding, (jax.Device, Sharding)):
-    if isinstance(x, (Array, jax.Array)):
-      return _device_put(x, device=sharding)
     return jax.tree_util.tree_map(partial(_device_put, device=sharding),
-                                  x,
-                                  is_leaf=lambda a: isinstance(a, Array))
+                                  x, is_leaf=is_bp_array)
   elif isinstance(sharding, (tuple, list)) and any([isinstance(s, str) for s in sharding]):
     return partition_by_axname(x, sharding)
   else:
-    raise TypeError
+    raise TypeError('"sharding" only supports jax.sharding.Sharding or a sequence of axis names. \n'
+                    f'But we got {sharding}')
+
+
+def _keep_constraint(x: Any):
+  if isinstance(x, Array):
+    x = x.value
+  if isinstance(x, jax.Array):
+    if hasattr(x, 'sharding'):
+      if x.sharding is not None:
+        return jax.lax.with_sharding_constraint(x, x.sharding)
+    return x
+  else:
+    return x
+
+
+def keep_constraint(x: Any):
+  """Keep the sharding constraint of the given inputs during computation.
+
+  Args:
+    x: Any.
+
+  Returns:
+    constraint_x: Same as ``x``.
+  """
+  return jax.tree_util.tree_map(_keep_constraint, x, is_leaf=is_bp_array)
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index 0554429d9..ff97f7303 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -19,8 +19,7 @@
 from . import surrogate, event, sparse, jitconn

 # Variable and Objects for object-oriented JAX transformations
-from .object_base import *
-from .object_transform import *
+from .oo_transform import *

 # environment settings
 from .modes import *
diff --git a/brainpy/math/object_transform.py b/brainpy/math/object_transform.py
deleted file mode 100644
index d281ec740..000000000
--- a/brainpy/math/object_transform.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.math.object_transform.autograd import (
-  grad as grad,
-  vector_grad as vector_grad,
-  jacobian as jacobian,
-  jacrev as jacrev,
-  jacfwd as jacfwd,
-  hessian as hessian,
-)
-
-from brainpy._src.math.object_transform.controls import (
-  make_loop as make_loop,
-  make_while as make_while,
-  make_cond as make_cond,
-  cond as cond,
-  ifelse as ifelse,
-  for_loop as for_loop,
-  while_loop as while_loop,
-)
-
-
-from brainpy._src.math.object_transform.jit import (
-  jit as jit,
-  cls_jit as cls_jit,
-)
-
-
-from brainpy._src.math.object_transform.function import (
-  to_object as to_object,
-  function as function,
-)
diff --git a/brainpy/math/object_base.py b/brainpy/math/oo_transform.py
similarity index 66%
rename from brainpy/math/object_base.py
rename to brainpy/math/oo_transform.py
index 1faca0d21..94ab09a9d 100644
--- a/brainpy/math/object_base.py
+++ b/brainpy/math/oo_transform.py
@@ -16,5 +16,33 @@
   var_list as var_list,
   var_dict as var_dict,
 )
+from brainpy._src.math.object_transform.autograd import (
+  grad as grad,
+  vector_grad as vector_grad,
+  jacobian as jacobian,
+  jacrev as jacrev,
+  jacfwd as jacfwd,
+  hessian as hessian,
+)
+from brainpy._src.math.object_transform.controls import (
+  make_loop as make_loop,
+  make_while as make_while,
+  make_cond as make_cond,
+  cond as cond,
+  ifelse as ifelse,
+  for_loop as for_loop,
+  while_loop as while_loop,
+)
+
+from brainpy._src.math.object_transform.jit import (
+  jit as jit,
+  cls_jit as cls_jit,
+)
+
+
+from brainpy._src.math.object_transform.function import (
+  to_object as to_object,
+  function as function,
+)
diff --git a/brainpy/math/sharding.py b/brainpy/math/sharding.py
index 328abf6ed..775915672 100644
--- a/brainpy/math/sharding.py
+++ b/brainpy/math/sharding.py
@@ -5,6 +5,7 @@
   partition_by_axname,
   partition_by_sharding,
   partition,
+  keep_constraint,

   NEU_AXIS,
   PRE_AXIS,
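
Reviewer note: below is a minimal, hypothetical usage sketch of the sharding helpers touched by this patch (device_mesh, partition, and the newly exported keep_constraint). The four-device mesh, the NEU_AXIS axis, and the vector length are illustrative assumptions rather than values taken from the patch, and the behaviour of keep_constraint under jit depends on the installed JAX version.

# Sketch only: assumes a machine with at least four XLA devices.
import jax
import numpy as np
import brainpy.math as bm

devices = np.asarray(jax.devices()[:4])
with bm.sharding.device_mesh(devices, (bm.sharding.NEU_AXIS,)):
  # Shard a 1-D state vector along the neuron axis of the active mesh.
  v = bm.random.rand(4000)
  v = bm.sharding.partition(v, sharding=(bm.sharding.NEU_AXIS,))

  @bm.jit
  def update(x):
    # keep_constraint() re-applies the sharding already attached to the
    # inputs, so the compiler keeps their layout instead of re-gathering.
    x = bm.sharding.keep_constraint(x)
    return x * 2.

  out = update(v)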
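
A quick sanity check on the module reorganisation above: brainpy/math/object_transform.py is deleted and its re-exports are folded into the renamed oo_transform.py, which brainpy/math/__init__.py now star-imports, so the public entry points should stay reachable. The snippet below relies only on import paths visible in this patch.

# The transforms are still available from the top-level brainpy.math namespace.
import brainpy.math as bm

assert callable(bm.jit) and callable(bm.grad) and callable(bm.for_loop)

# The renamed module introduced by this patch also works directly ...
from brainpy.math.oo_transform import jit, grad, cond

# ... while the removed flat module does not, so this would now fail:
#   from brainpy.math.object_transform import jit   # ImportError after this patch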