
Commit: upgrade
chaoming0625 committed Oct 8, 2023
1 parent fd83b5c commit 4de1acd
Showing 12 changed files with 146 additions and 103 deletions.
2 changes: 1 addition & 1 deletion brainpy/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.5"
__version__ = "2.4.5.post4"
_minimal_brainpylib_version = '0.1.10'

# fundamental supporting modules
37 changes: 8 additions & 29 deletions brainpy/_src/initialize/generic.py
@@ -29,7 +29,7 @@ def _is_scalar(x):


def parameter(
param: Union[Callable, Initializer, bm.ndarray, np.ndarray, jnp.ndarray, float, int, bool],
param: Union[Callable, Initializer, bm.Array, np.ndarray, jax.Array, float, int, bool],
sizes: Shape,
allow_none: bool = True,
allow_scalar: bool = True,
@@ -74,8 +74,10 @@ def parameter(
return param

if callable(param):
param = param(sizes) # TODO
# return bm.jit(param, static_argnums=0, out_shardings=bm.sharding.get_sharding(axis_names))(size)
# param = param(sizes) # TODO
return bm.jit(param,
static_argnums=0,
out_shardings=bm.sharding.get_sharding(sharding))(sizes)

elif isinstance(param, (np.ndarray, jnp.ndarray)):
param = bm.asarray(param)
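
The rewritten branch above no longer calls the initializer eagerly; it jits the callable and lets out_shardings place the result according to bm.sharding.get_sharding(sharding). Below is a minimal plain-JAX sketch of the same mechanism, not taken from this commit; the initializer init_fn and the one-axis mesh name 'neuron' are illustrative assumptions.

# Plain-JAX sketch of jitting an initializer with an output sharding
# (illustrative names; brainpy routes this through bm.jit and bm.sharding.get_sharding).
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def init_fn(size):                       # hypothetical stand-in for the `param` callable
    return jnp.zeros(size)

mesh = Mesh(np.array(jax.devices()), axis_names=('neuron',))
out_sharding = NamedSharding(mesh, PartitionSpec('neuron'))

size = (8 * jax.device_count(),)         # keep the length divisible by the device count
# `size` is static, so the shape is known at trace time; the output is created
# directly with the requested layout instead of being gathered onto one device.
value = jax.jit(init_fn, static_argnums=0, out_shardings=out_sharding)(size)
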
@@ -104,32 +106,9 @@ def variable_(
):
"""Initialize a :math:`~.Variable` from a callable function or a data.
Parameters
----------
init: callable, function, ArrayType
The data to be initialized as a ``Variable``.
batch_or_mode: int, bool, Mode, optional
The batch size, model ``Mode``, boolean state.
This is used to specify the batch size of this variable.
If it is a boolean or an instance of ``Mode``, the batch size will be 1.
If it is None, the variable has no batch axis.
sizes: Shape
The shape of the variable.
batch_axis: int
The batch axis.
axis_names: sequence of str
The name for each axis. These names should match the given ``axes``.
batch_axis_name: str
The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given.
Returns
-------
variable: bm.Variable
The target ``Variable`` instance.
See Also
--------
variable, parameter, noise, delay
variable
"""
return variable(init,
@@ -152,10 +131,10 @@ def variable(
Parameters
----------
init: callable, function, ArrayType
init: callable, ArrayType
The data to be initialized as a ``Variable``.
batch_or_mode: int, bool, Mode, optional
The batch size, model ``Mode``, boolean state.
The batch size, mode ``Mode``, boolean state.
This is used to specify the batch size of this variable.
If it is a boolean or an instance of ``Mode``, the batch size will be 1.
If it is None, the variable has no batch axis.
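
The docstring above spells out how batch_or_mode is interpreted: None means no batch axis, a bool or Mode means a batch size of 1, and an int gives the batch size. The following is a standalone illustration of that dispatch only, not brainpy's variable implementation; Mode handling is reduced to a plain bool.

# Standalone illustration of the documented `batch_or_mode` semantics
# (not brainpy's implementation; a `Mode` instance is represented by a bool here).
import jax.numpy as jnp

def with_batch_axis(data, batch_or_mode=None, batch_axis=0):
    if batch_or_mode is None:                # no batch axis at all
        return data
    if isinstance(batch_or_mode, bool):      # mode-like flag -> batch size 1
        batch_size = 1
    else:                                    # explicit integer batch size
        batch_size = int(batch_or_mode)
    data = jnp.expand_dims(data, batch_axis)
    reps = [1] * data.ndim
    reps[batch_axis] = batch_size
    return jnp.tile(data, reps)

assert with_batch_axis(jnp.zeros((10,)), 4).shape == (4, 10)
assert with_batch_axis(jnp.zeros((10,)), True).shape == (1, 10)
assert with_batch_axis(jnp.zeros((10,))).shape == (10,)
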
21 changes: 11 additions & 10 deletions brainpy/_src/math/event/_csr_matvec.py
@@ -95,9 +95,9 @@ def csrmv(
raise ValueError('indices should be a 1D vector with integer type.')
if np.ndim(indptr) != 1:
raise ValueError('indptr should be a 1D vector with integer type.')
if indices.dtype not in [jnp.int32, jnp.int64]:
if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
raise ValueError('indices should be a 1D vector with int32 or int64 type.')
if indptr.dtype not in [jnp.int32, jnp.int64]:
if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
raise ValueError('indptr should be a 1D vector with int32 or int64 type.')
if np.ndim(events) != 1:
raise ValueError('events should be a 1D vector.')
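
The two dtype checks above are widened to also accept unsigned index types (uint32 and uint64). The snippet below is a standalone sketch of the same validation, not the brainpy csrmv function itself; the helper name and error wording are assumptions.

# Standalone sketch of the widened CSR index validation (not brainpy's csrmv).
import numpy as np
import jax.numpy as jnp

ALLOWED_INDEX_DTYPES = (jnp.int32, jnp.int64, jnp.uint32, jnp.uint64)

def check_csr_indices(indices, indptr):
    indices, indptr = jnp.asarray(indices), jnp.asarray(indptr)
    if indices.ndim != 1 or indices.dtype not in ALLOWED_INDEX_DTYPES:
        raise ValueError('indices should be a 1D vector with a 32/64-bit integer dtype.')
    if indptr.ndim != 1 or indptr.dtype not in ALLOWED_INDEX_DTYPES:
        raise ValueError('indptr should be a 1D vector with a 32/64-bit integer dtype.')

check_csr_indices(np.array([0, 2, 1], dtype=np.uint32),   # column indices
                  np.array([0, 2, 3], dtype=np.uint32))   # row pointers
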
@@ -328,36 +328,37 @@ def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins):
res_val.fill(0)
values, indices, indptr, events, shape, _ = ins
if values.shape[0] > 1: # heter
for row_i in range(shape[0]):
if events[row_i]:
for row_i, event in enumerate(events):
if event:
for j in range(indptr[row_i], indptr[row_i + 1]):
col_i = indices[j]
res_val[col_i] += values[j]

else: # homo
values = values[0]
for row_i in range(shape[0]):
if events[row_i]:
for row_i, event in enumerate(events):
if event:
for j in range(indptr[row_i], indptr[row_i + 1]):
col_i = indices[j]
res_val[col_i] += values


@numba.njit(fastmath=True)
def _event_csr_matvec_transpose_numba_imp2(outs, ins):
res_val = outs
res_val.fill(0)
values, indices, indptr, events, shape, _ = ins
if values.shape[0] > 1: # heter
for row_i in range(shape[0]):
if events[row_i] > 0.:
for row_i, event in enumerate(events):
if event > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
col_i = indices[j]
res_val[col_i] += values[j]

else: # homo
values = values[0]
for row_i in range(shape[0]):
if events[row_i] > 0.:
for row_i, event in enumerate(events):
if event > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
col_i = indices[j]
res_val[col_i] += values
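
Both Numba kernels above compute the transposed, event-driven CSR matrix-vector product: every row whose event fires scatters its non-zero values into the corresponding output columns, with a single shared scalar in the homogeneous case. The following plain-NumPy reference mirrors that loop and is only a checking sketch, not the production kernel.

# NumPy reference for the event-driven transposed CSR mat-vec (sketch only;
# the production kernels are the Numba functions above).
import numpy as np

def event_csr_matvec_transpose_ref(values, indices, indptr, events, shape):
    """Return y = A.T @ events for A given in CSR form with `shape` = (rows, cols)."""
    out = np.zeros(shape[1], dtype=values.dtype)
    homo = values.shape[0] == 1                       # one value shared by all non-zeros
    for row_i, event in enumerate(events):
        fired = event if events.dtype == np.bool_ else event > 0.
        if fired:
            for j in range(indptr[row_i], indptr[row_i + 1]):
                out[indices[j]] += values[0] if homo else values[j]
    return out
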
2 changes: 2 additions & 0 deletions brainpy/_src/math/index_tricks.py
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-

import abc

from jax import core
5 changes: 3 additions & 2 deletions brainpy/_src/math/object_transform/_tools.py
@@ -1,6 +1,6 @@
import warnings
from functools import wraps
from typing import Sequence
from typing import Sequence, Tuple, Any

import jax

@@ -79,11 +79,12 @@ def abstract(x):
def evaluate_dyn_vars(
f,
*args,
transform: str = None,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
use_eval_shape: bool = True,
**kwargs
):
) -> Tuple[VariableStack, Any]:
# arguments
if len(static_argnums) or len(static_argnames):
f2, args, kwargs = _partial_fun(f, args, kwargs,
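
When static arguments are present, they are bound into a new callable before the dynamical variables are evaluated. The same general pattern in plain JAX, sketched with functools.partial and jax.eval_shape; this is a generic illustration, not brainpy's _partial_fun.

# Generic illustration: freeze static arguments, then trace the rest abstractly.
import functools
import jax
import jax.numpy as jnp

def f(x, n):                    # `x` is an array argument, `n` a static Python int
    return jnp.tile(x, n)

f_static = functools.partial(f, n=3)                   # bind the static argument
out = jax.eval_shape(f_static, jnp.ones((4,)))         # ShapeDtypeStruct((12,), float32)
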
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/autograd.py
@@ -225,15 +225,15 @@ def __call__(self, *args, **kwargs):
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True

# if not the outermost transformation
if current_transform_number():
return self._return(rets)
else:
self._dyn_vars = stack
self._dyn_vars.remove_var_by_id(*[id(v) for v in self._grad_vars])
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True

rets = self._transform(
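
The renamed call drops the gradient variables, identified by object id, from the stack of dynamical variables collected during tracing, since the gradient transform handles them separately. The dict-based stand-in below only illustrates that id-keyed exclusion; VariableStack.remove_by_id is brainpy's actual API and the Var class here is a hypothetical placeholder.

# Dict-based stand-in for id-keyed exclusion (illustrative only).
class Var:                                  # hypothetical placeholder for a Variable
    def __init__(self, value):
        self.value = value

a, b, c = Var(1), Var(2), Var(3)
stack = {id(v): v for v in (a, b, c)}       # variables found while tracing
grad_vars = (b,)                            # variables we differentiate w.r.t.
for v in grad_vars:
    stack.pop(id(v), None)                  # analogous to remove_by_id(id(b))
assert list(stack.values()) == [a, c]
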
4 changes: 2 additions & 2 deletions brainpy/_src/math/random.py
@@ -447,7 +447,7 @@ def __init__(
self,
seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None,
seed: Optional[int] = None,
_ready_to_trace: bool = True,
ready_to_trace: bool = True,
):
"""RandomState constructor.
@@ -482,7 +482,7 @@ def __init__(
raise ValueError('key must be an array with dtype uint32. '
f'But we got {seed_or_key}')
key = seed_or_key
super(RandomState, self).__init__(key, _ready_to_trace=_ready_to_trace)
super(RandomState, self).__init__(key, ready_to_trace=ready_to_trace)

def __repr__(self) -> str:
print_code = repr(self.value)
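
The constructor accepts either an integer seed or an existing uint32 key, and the renamed keyword ready_to_trace is now part of the public signature. A key that satisfies the dtype check above can be built with jax.random.PRNGKey; this is a sketch, and RandomState itself is brainpy's class.

# Building a key that satisfies the uint32 check above (sketch).
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)                 # shape (2,), dtype uint32
assert key.shape == (2,) and key.dtype == jnp.uint32
# Either form should pass the constructor's validation:
#   RandomState(seed_or_key=42)     # seeded from an int
#   RandomState(seed_or_key=key)    # uses the given uint32 key
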
110 changes: 87 additions & 23 deletions brainpy/_src/math/sharding.py
@@ -1,20 +1,22 @@
# -*- coding: utf-8 -*-

from functools import partial
from typing import Optional, Any, Union, Sequence
from contextlib import contextmanager

import jax
import numpy as np
from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED
from jax.sharding import PartitionSpec, Mesh, NamedSharding, Sharding

from .ndarray import Array
from .ndarray import Array, ShardedArray

__all__ = [
'device_mesh',
'get_sharding',
'partition_by_axname',
'partition_by_sharding',
'partition',
'keep_constraint',

'NEU_AXIS',
'PRE_AXIS',
@@ -39,6 +41,10 @@
_default_mesh: Optional[Mesh] = None


def is_bp_array(x):
return isinstance(x, Array)


@contextmanager
def device_mesh(
devices: Any,
@@ -61,15 +67,33 @@ def device_mesh(

def _device_put(x: Union[Array, jax.Array, np.ndarray],
device: Union[None, jax.Device, Sharding] = None):
"""Transfers ``x`` to ``device``.
Note that this function can only transfer ``brainpy.math.Array``, ``jax.Array``,
and ``numpy.ndarray``. Other values are returned unchanged.
Args:
x: The input array.
device: The given device.
Returns:
A copy of ``x`` that resides on ``device``.
"""
if isinstance(x, Array):
x.value = jax.device_put(x, device=device)
return x
x.value = jax.device_put(x.value, device=device)
return x
else:
if isinstance(x, (jax.Array, np.ndarray)):
# wrap the data as brainpy.math.Array is important (experimental)
return ShardedArray(jax.device_put(x, device=device), keep_sharding=True)
else:
return x


def get_sharding(
axis_names: Optional[Sequence[str]] = None,
mesh: Optional[Mesh] = None
) -> Union[UnspecifiedValue, NamedSharding]:
) -> Optional[NamedSharding]:
"""Get sharding according to the given axes information.
Args:
@@ -80,11 +104,11 @@ def get_sharding(
The instance of NamedSharding.
"""
if axis_names is None:
return UNSPECIFIED
return None
if mesh is None:
mesh = _default_mesh
if mesh is None:
return UNSPECIFIED
return None
else:
axis_names = [(name if name in mesh.axis_names else None) for name in axis_names]
return NamedSharding(mesh, PartitionSpec(*axis_names))
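
get_sharding resolves logical axis names against the active mesh, replacing names the mesh does not define with None, and now returns None instead of UNSPECIFIED when nothing can be resolved. The underlying JAX objects look like this; the mesh and the axis names 'pre' and 'neu' are illustrative assumptions.

# Mapping logical axis names to a NamedSharding (mesh and axis names are illustrative).
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('neu',))

requested = ('pre', 'neu')                    # one logical name per array dimension
resolved = [n if n in mesh.axis_names else None for n in requested]
sharding = NamedSharding(mesh, PartitionSpec(*resolved))   # PartitionSpec(None, 'neu')
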
@@ -108,8 +132,11 @@ def partition_by_axname(
if axis_names is None:
return x
else:
for _leaf in jax.tree_util.tree_leaves(x, is_leaf=lambda a: isinstance(a, Array)):
assert np.ndim(_leaf) == len(axis_names)
for _leaf in jax.tree_util.tree_leaves(x, is_leaf=is_bp_array):
if np.ndim(_leaf) != len(axis_names):
raise ValueError(f'The input array shape is {np.shape(_leaf)}, '
f'while the given axis names are {axis_names}. '
f'The dimensions do not match.')
if mesh is None:
if _default_mesh is None:
return x
@@ -118,41 +145,78 @@
if sharding is None:
return x
else:
f = partial(_device_put, device=sharding)
return jax.tree_util.tree_map(f, x, is_leaf=lambda a: isinstance(a, Array))
return jax.tree_util.tree_map(partial(_device_put, device=sharding),
x, is_leaf=is_bp_array)


def partition_by_sharding(
x: Any,
sharding: Optional[Sharding] = None,
):
"""Partition inputs with the given sharding strategy."""
"""Partition inputs with the given sharding strategy.
Args:
x: The input arrays. It can be a pyTree of arrays.
sharding: The `jax.sharding.Sharding` instance.
Returns:
The sharded ``x``, which has been partitioned by the given sharding strategy.
"""
if sharding is None:
return x
else:
assert isinstance(sharding, Sharding)
if isinstance(x, (Array, jax.Array)):
return _device_put(x, device=sharding)
if not isinstance(sharding, Sharding):
raise TypeError(f'sharding must be instance of jax.sharding.Sharding. While we got {sharding}.')
return jax.tree_util.tree_map(partial(_device_put, device=sharding),
x,
is_leaf=lambda a: isinstance(a, Array))
is_leaf=is_bp_array)
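
partition_by_sharding applies one sharding to every array leaf of a pytree through tree_map. The snippet below is a standalone version of that traversal in plain JAX; it omits brainpy's Array and ShardedArray wrapping, and the mesh and axis name are illustrative.

# Standalone pytree placement with one sharding (omits Array/ShardedArray wrapping).
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('neu',))
sharding = NamedSharding(mesh, PartitionSpec('neu'))

n = 4 * jax.device_count()                               # divisible along 'neu'
params = {'w': jnp.ones((n,)), 'b': jnp.zeros((n,))}     # a pytree of arrays
sharded = jax.tree_util.tree_map(partial(jax.device_put, device=sharding), params)
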


def partition(
x: Any,
sharding: Optional[Union[Sequence[str], jax.Device, Sharding]] = None,
):
"""Partition the input arrays onto devices by the given sharding strategies.
Args:
x: Any input arrays. It can also be a PyTree of arrays.
sharding: The sharding strategy.
Returns:
The partitioned arrays.
Notably, the
"""
if sharding is None:
return x
if isinstance(sharding, UnspecifiedValue):
return x
elif isinstance(sharding, (jax.Device, Sharding)):
if isinstance(x, (Array, jax.Array)):
return _device_put(x, device=sharding)
return jax.tree_util.tree_map(partial(_device_put, device=sharding),
x,
is_leaf=lambda a: isinstance(a, Array))
x, is_leaf=is_bp_array)
elif isinstance(sharding, (tuple, list)) and any([isinstance(s, str) for s in sharding]):
return partition_by_axname(x, sharding)
else:
raise TypeError
raise TypeError('"sharding" only supports jax.sharding.Sharding or a sequence of axis names. \n'
f'But we got {sharding}')


def _keep_constraint(x: Any):
if isinstance(x, Array):
x = x.value
if isinstance(x, jax.Array):
if hasattr(x, 'sharding'):
if x.sharding is not None:
return jax.lax.with_sharding_constraint(x, x.sharding)
return x
else:
return x


def keep_constraint(x: Any):
"""Keep the sharding constraint of the given inputs during computation.
Args:
x: Any.
Returns:
constraint_x: Same as ``x``.
"""
return jax.tree_util.tree_map(_keep_constraint, x, is_leaf=is_bp_array)
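
keep_constraint re-asserts each array's current sharding so the layout is preserved through jit-compiled computation; the primitive doing the work is jax.lax.with_sharding_constraint. A small example of that primitive follows, with an illustrative mesh and axis name.

# Pinning an intermediate's layout inside jit (mesh and axis name are illustrative).
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('neu',))
sharding = NamedSharding(mesh, PartitionSpec('neu'))

@jax.jit
def step(x):
    x = jax.lax.with_sharding_constraint(x, sharding)    # keep x laid out over 'neu'
    return x * 2.0

y = step(jnp.arange(4.0 * jax.device_count()))
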