diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index 93f9c0db8..f8d2ad5db 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -11,7 +11,6 @@
 from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable
 
 import jax
-from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED
 from jax.sharding import Sharding
 
 from brainpy import tools, check
@@ -22,6 +21,7 @@
                      _partial_fun)
 from .base import BrainPyObject, ObjectTransform
 from .naming import get_stack_cache, cache_stack
+from ..ndarray import Array
 from .variables import (Variable,
                         VariableStack,
                         outermost_transform,
@@ -29,19 +29,47 @@
                         current_transform_number,
                         new_transform)
 
+RandomState = None
+
 __all__ = [
   'jit',
 ]
 
 
+def _is_bp_array(a):
+  return isinstance(a, Array)
+
+
 def _get_sharding(a):
-  pass
+  if isinstance(a, Array):
+    a = a.value
+  if hasattr(a, 'sharding'):
+    return a.sharding
+  return None
+
+
+def get_shardings(args):
+  return jax.tree_util.tree_map(lambda a: a.sharding,
+                                args,
+                                is_leaf=_is_bp_array)
 
 
-def _get_sharding_of_dyn_vars(dyn_vars: dict):
-  leaves, tree = jax.tree_util.tree_flatten(dyn_vars)
+def _is_rng(a):
+  global RandomState
+  if RandomState is None:
+    from brainpy.math.random import RandomState
+  return isinstance(a, RandomState)
+
+
+def _is_not_rng(a):
+  global RandomState
+  if RandomState is None:
+    from brainpy.math.random import RandomState
+  return not isinstance(a, RandomState)
+
+
+def _rng_split_key(a):
+  return a.split_key()
 
 
 def _seq_of_int(static_argnums):
@@ -81,8 +109,8 @@
       keep_unused: bool = False,
       abstracted_axes: Optional[Any] = None,
       name: Optional[str] = None,
-      in_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED,
-      out_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED,
+      in_shardings: Any = None,
+      out_shardings: Any = None,
 
       # deprecated
       dyn_vars: Dict[str, Variable] = None,
@@ -110,16 +138,8 @@
     self._abstracted_axes = abstracted_axes
     self._in_shardings = in_shardings
     self._out_shardings = out_shardings
-    # if isinstance(in_shardings, UnspecifiedValue):
-    #   pass
-    # else:
-    #   self._in_shardings = (UNSPECIFIED, in_shardings)
-    # if isinstance(out_shardings, UnspecifiedValue):
-    #   pass
-    # else:
-    #   self._out_shardings = (AUTO, out_shardings)
-
-    # transformation function
+
+    # OO transformation parameters
     self._transform = None
     self._dyn_vars = None
 
@@ -127,43 +147,76 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs):
     for key, v in self._dyn_vars.items():
       v._value = variable_data[key]
     out = self.fun(*args, **kwargs)
-    changes = self._dyn_vars.dict_data()
+    changes = self._dyn_vars.dict_data_of_subset(_is_not_rng)
     return changes, out
 
+  def _get_transform(self, *args, **kwargs):
+    with new_transform(self):
+      self._dyn_vars, rets = evaluate_dyn_vars(
+        self.fun,
+        *args,
+        static_argnums=self._static_argnums,
+        static_argnames=self._static_argnames,
+        use_eval_shape=current_transform_number() <= 1,
+        **kwargs
+      )
+
+      # in_shardings
+      if self._in_shardings is None:
+        in_shardings = None
+      else:
+        if isinstance(self._in_shardings, (tuple, list)):
+          in_shardings = tuple(self._in_shardings)
+        else:
+          in_shardings = (self._in_shardings,)
+        _dyn_vars_sharing = get_shardings(self._dyn_vars)
+        in_shardings = (_dyn_vars_sharing,) + in_shardings
+
+      # out_shardings
+      if self._out_shardings is None:
+        out_shardings = None
+      else:
+        if isinstance(self._out_shardings, (tuple, list)):
+          out_shardings = tuple(self._out_shardings)
+        else:
+          out_shardings = (self._out_shardings,)
+        global RandomState
+        if RandomState is None:
+          from brainpy.math.random import RandomState
+        _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
+        out_shardings = (_dyn_vars_sharing,) + out_shardings
+
+      # jit
+      self._transform = jax.jit(
+        self._transform_function,
+        static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
+        static_argnames=self._static_argnames,
+        donate_argnums=self._donate_argnums,
+        inline=self._inline,
+        keep_unused=self._keep_unused,
+        abstracted_axes=self._abstracted_axes,
+        in_shardings=in_shardings,
+        out_shardings=out_shardings,
+      )
+    return rets
+
   def __call__(self, *args, **kwargs):
     if jax.config.jax_disable_jit:  # support to disable JIT for debugging
       return self.fun(*args, **kwargs)
 
     if self._transform is None:  # initialize the transformation
-      with new_transform(self):
-        self._dyn_vars, rets = evaluate_dyn_vars(
-          self.fun,
-          *args,
-          static_argnums=self._static_argnums,
-          static_argnames=self._static_argnames,
-          use_eval_shape=current_transform_number() <= 1,
-          **kwargs
-        )
-        self._transform = jax.jit(
-          self._transform_function,
-          static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
-          static_argnames=self._static_argnames,
-          donate_argnums=self._donate_argnums,
-          inline=self._inline,
-          keep_unused=self._keep_unused,
-          abstracted_axes=self._abstracted_axes,
-          in_shardings=self._in_shardings,
-          out_shardings=self._out_shardings,
-        )
-
+      rets = self._get_transform(*args, **kwargs)
       # if not the outermost transformation
       if current_transform_number():
         return rets
 
     # call the transformed function
+    rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key)
     changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs)
-    for key, v in self._dyn_vars.items():
-      v._value = changes[key]
+    for key, v in changes.items():
+      self._dyn_vars[key]._value = v
+    for key, v in rng_keys.items():
+      self._dyn_vars[key]._value = v
     return out
 
   def __repr__(self):
@@ -174,6 +227,18 @@
                   f'{" " * len(name)} num_of_vars={len(self.vars().unique())})')
     return format_ref
 
+  # def compile(self, *args, **kwargs):
+  #   if self._transform is None:  # initialize the transformation
+  #     _ = self._get_transform(*args, **kwargs)
+  #   # call the transformed function
+  #   rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key)
+  #   changes, out = self._transform.lower(self._dyn_vars.dict_data(), *args, **kwargs)
+  #   for key, v in changes.items():
+  #     self._dyn_vars[key]._value = v
+  #   for key, v in rng_keys.items():
+  #     self._dyn_vars[key]._value = v
+  #   return out
+
 
 _jit_par = '''
   func : BrainPyObject, function, callable
@@ -412,7 +477,7 @@ def call_fun(self, *args, **kwargs):
     cache = get_stack_cache(hash_v)  # TODO: better cache mechanism
     if cache is None:
      fun2 = partial(fun, self)
-
+      with jax.ensure_compile_time_eval():
      if len(static_argnums) or len(static_argnames):
        fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames)
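
Usage sketch (not part of the diff): a minimal example of how the new in_shardings / out_shardings keywords could be exercised, assuming the public wrapper brainpy.math.jit forwards them to this transform; the mesh layout and the 'batch' axis name are invented for illustration only.

    import jax
    import jax.numpy as jnp
    import numpy as np
    import brainpy.math as bm
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # A 1-D device mesh; on a single device this still runs and simply
    # yields a trivial sharding.
    mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch'))

    # in_shardings describes the user's positional arguments; per the diff,
    # the transform prepends the shardings of the collected Variables before
    # handing everything to jax.jit, and does the same for out_shardings.
    f = bm.jit(lambda x: x * 2.0, in_shardings=sharding, out_shardings=sharding)
    y = f(jnp.ones((8, 4)))

On a multi-device mesh the leading axis of x and of the result would be partitioned across devices; with one device the call behaves like a plain bm.jit.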