
Commit

brainpy.math.jit supports parallelization of all functions in the `brainpy.math.random` module
chaoming0625 committed Oct 8, 2023
1 parent 0508776 commit fd83b5c
Showing 1 changed file with 106 additions and 41 deletions.
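What the change enables, as a minimal usage sketch (assuming only the public `brainpy.math` API; the function name and the idea of printing two calls are illustrative, not part of this commit): a jitted function can now call `brainpy.math.random` routines and draw fresh randomness on every invocation, because the RandomState key is split outside the compiled function instead of being frozen into the trace.

import brainpy.math as bm

@bm.jit
def noisy_step(x):
  # fresh noise on every call, even under JIT
  return x + bm.random.normal(size=x.shape)

x = bm.zeros(4)
print(noisy_step(x))
print(noisy_step(x))  # different samples: the RNG key advances between calls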
147 changes: 106 additions & 41 deletions brainpy/_src/math/object_transform/jit.py
@@ -11,7 +11,6 @@
from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable

import jax
-from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED
from jax.sharding import Sharding

from brainpy import tools, check
@@ -22,26 +21,55 @@
                     _partial_fun)
from .base import BrainPyObject, ObjectTransform
from .naming import get_stack_cache, cache_stack
+from ..ndarray import Array
from .variables import (Variable,
                         VariableStack,
                         outermost_transform,
                         transform_stack,
                         current_transform_number,
                         new_transform)

+RandomState = None

__all__ = [
  'jit',
]


+def _is_bp_array(a):
+  return isinstance(a, Array)


def _get_sharding(a):
-  pass
+  if isinstance(a, Array):
+    a = a.value
+  if hasattr(a, 'sharding'):
+    return a.sharding
+  return None


+def get_shardings(args):
+  return jax.tree_util.tree_map(lambda a: a.sharding,
+                                args,
+                                is_leaf=_is_bp_array)


-def _get_sharding_of_dyn_vars(dyn_vars: dict):
-  leaves, tree = jax.tree_util.tree_flatten(dyn_vars)
+def _is_rng(a):
+  global RandomState
+  if RandomState is None:
+    from brainpy.math.random import RandomState
+  return isinstance(a, RandomState)


+def _is_not_rng(a):
+  global RandomState
+  if RandomState is None:
+    from brainpy.math.random import RandomState
+  return not isinstance(a, RandomState)


+def _rng_split_key(a):
+  return a.split_key()


def _seq_of_int(static_argnums):
@@ -81,8 +109,8 @@ def __init__(
      keep_unused: bool = False,
      abstracted_axes: Optional[Any] = None,
      name: Optional[str] = None,
-      in_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED,
-      out_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED,
+      in_shardings: Any = None,
+      out_shardings: Any = None,

      # deprecated
      dyn_vars: Dict[str, Variable] = None,
@@ -110,60 +138,85 @@ def __init__(
    self._abstracted_axes = abstracted_axes
    self._in_shardings = in_shardings
    self._out_shardings = out_shardings
-    # if isinstance(in_shardings, UnspecifiedValue):
-    #   pass
-    # else:
-    #   self._in_shardings = (UNSPECIFIED, in_shardings)
-    # if isinstance(out_shardings, UnspecifiedValue):
-    #   pass
-    # else:
-    #   self._out_shardings = (AUTO, out_shardings)

-    # transformation function
+    # OO transformation parameters
    self._transform = None
    self._dyn_vars = None

  def _transform_function(self, variable_data: Dict, *args, **kwargs):
    for key, v in self._dyn_vars.items():
      v._value = variable_data[key]
    out = self.fun(*args, **kwargs)
-    changes = self._dyn_vars.dict_data()
+    changes = self._dyn_vars.dict_data_of_subset(_is_not_rng)
    return changes, out

+  def _get_transform(self, *args, **kwargs):
+    with new_transform(self):
+      self._dyn_vars, rets = evaluate_dyn_vars(
+        self.fun,
+        *args,
+        static_argnums=self._static_argnums,
+        static_argnames=self._static_argnames,
+        use_eval_shape=current_transform_number() <= 1,
+        **kwargs
+      )

+      # in_shardings
+      if self._in_shardings is None:
+        in_shardings = None
+      else:
+        if isinstance(self._in_shardings, (tuple, list)):
+          in_shardings = tuple(self._in_shardings)
+        else:
+          in_shardings = (self._in_shardings,)
+        _dyn_vars_sharing = get_shardings(self._dyn_vars)
+        in_shardings = (_dyn_vars_sharing,) + in_shardings

+      # out_shardings
+      if self._out_shardings is None:
+        out_shardings = None
+      else:
+        if isinstance(self._out_shardings, (tuple, list)):
+          out_shardings = tuple(self._out_shardings)
+        else:
+          out_shardings = (self._out_shardings,)
+        global RandomState
+        if RandomState is None:
+          from brainpy.math.random import RandomState
+        _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState))
+        out_shardings = (_dyn_vars_sharing,) + out_shardings

+      # jit
+      self._transform = jax.jit(
+        self._transform_function,
+        static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
+        static_argnames=self._static_argnames,
+        donate_argnums=self._donate_argnums,
+        inline=self._inline,
+        keep_unused=self._keep_unused,
+        abstracted_axes=self._abstracted_axes,
+        in_shardings=in_shardings,
+        out_shardings=out_shardings,
+      )
+    return rets

  def __call__(self, *args, **kwargs):
    if jax.config.jax_disable_jit:  # support to disable JIT for debugging
      return self.fun(*args, **kwargs)

    if self._transform is None:  # initialize the transformation
-      with new_transform(self):
-        self._dyn_vars, rets = evaluate_dyn_vars(
-          self.fun,
-          *args,
-          static_argnums=self._static_argnums,
-          static_argnames=self._static_argnames,
-          use_eval_shape=current_transform_number() <= 1,
-          **kwargs
-        )
-        self._transform = jax.jit(
-          self._transform_function,
-          static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
-          static_argnames=self._static_argnames,
-          donate_argnums=self._donate_argnums,
-          inline=self._inline,
-          keep_unused=self._keep_unused,
-          abstracted_axes=self._abstracted_axes,
-          in_shardings=self._in_shardings,
-          out_shardings=self._out_shardings,
-        )
+      rets = self._get_transform(*args, **kwargs)
      # if not the outermost transformation
      if current_transform_number():
        return rets

    # call the transformed function
+    rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key)
    changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs)
-    for key, v in self._dyn_vars.items():
-      v._value = changes[key]
+    for key, v in changes.items():
+      self._dyn_vars[key]._value = v
+    for key, v in rng_keys.items():
+      self._dyn_vars[key]._value = v
    return out

@@ -174,6 +227,18 @@ def __repr__(self):
f'{" " * len(name)} num_of_vars={len(self.vars().unique())})')
return format_ref

+  # def compile(self, *args, **kwargs):
+  #   if self._transform is None:  # initialize the transformation
+  #     _ = self._get_transform(*args, **kwargs)
+  #   # call the transformed function
+  #   rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key)
+  #   changes, out = self._transform.lower(self._dyn_vars.dict_data(), *args, **kwargs)
+  #   for key, v in changes.items():
+  #     self._dyn_vars[key]._value = v
+  #   for key, v in rng_keys.items():
+  #     self._dyn_vars[key]._value = v
+  #   return out


_jit_par = '''
func : BrainPyObject, function, callable
Expand Down Expand Up @@ -412,7 +477,7 @@ def call_fun(self, *args, **kwargs):
    cache = get_stack_cache(hash_v)  # TODO: better cache mechanism
    if cache is None:
      fun2 = partial(fun, self)

+      with jax.ensure_compile_time_eval():
        if len(static_argnums) or len(static_argnames):
          fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames)

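The RNG handling introduced in `__call__` above follows a standard JAX pattern: split the key eagerly outside the compiled function, exclude RNG state from the returned changes, and write the fresh key back afterwards. A standalone sketch in plain JAX (hypothetical names such as `_transform` and `call`, not brainpy internals):

import jax
import jax.numpy as jnp

state = {'w': jnp.ones(3)}        # stand-in for non-RNG Variable data
rng_key = jax.random.PRNGKey(0)   # stand-in for a RandomState key, held outside jit

@jax.jit
def _transform(state_data, key, x):
  # the returned changes exclude the RNG key, mirroring dict_data_of_subset(_is_not_rng)
  y = x @ state_data['w'] + jax.random.normal(key, ())
  return {'w': state_data['w'] + 1.0}, y

def call(x):
  global rng_key
  rng_key, subkey = jax.random.split(rng_key)  # analogue of _rng_split_key
  changes, out = _transform(state, subkey, x)
  state.update(changes)                        # write back only non-RNG changes
  return out

print(call(jnp.ones((2, 3))))
print(call(jnp.ones((2, 3))))  # different noise: the key advanced outside jit

The same exclusion shows up in `_get_transform`: output shardings are computed over `self._dyn_vars.subset_by_not_instance(RandomState)`, since RandomState keys are not part of the values returned from the jitted function.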