From 818397a727ce37f9aa2f90349a95af793b807683 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 09:42:55 +0800 Subject: [PATCH 01/16] test improvement --- brainpy/_src/dnn/conv.py | 11 ++++++++++- brainpy/_src/dnn/tests/test_activation.py | 3 ++- brainpy/_src/dnn/tests/test_conv_layers.py | 11 ++++++----- brainpy/_src/dnn/tests/test_function.py | 6 ++---- brainpy/_src/dnn/tests/test_linear.py | 5 +++-- brainpy/_src/dnn/tests/test_mode.py | 5 +++-- brainpy/_src/dnn/tests/test_normalization.py | 5 +++-- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- examples/dynamics_simulation/ei_nets.py | 2 +- 9 files changed, 31 insertions(+), 19 deletions(-) diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index deead1f3b..e4b6e25d2 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = x.unsqueeze(0) + x = bm.unsqueeze(x, 0) w = self.w.value if self.mask is not None: try: @@ -190,6 +190,9 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. + The input should be a 2d array with the shape of ``[H, C]``, or + a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. + Parameters ---------- in_channels: int @@ -282,6 +285,9 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. + The input should be a 3d array with the shape of ``[H, W, C]``, or + a 4d array with the shape of ``[B, H, W, C]``. + Parameters ---------- in_channels: int @@ -375,6 +381,9 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. + The input should be a 4d array with the shape of ``[H, W, D, C]``, or + a 5d array with the shape of ``[B, H, W, D, C]``.
+ Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index ba2a49efd..17054667d 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,5 +1,6 @@ -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 3c9fdfa87..05f523622 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- -from unittest import TestCase -from absl.testing import absltest import jax.numpy as jnp -import brainpy.math as bm +from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): - bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -24,6 +22,7 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -31,7 +30,6 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): - bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -39,6 +37,7 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -54,6 +53,7 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,6 +67,7 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 269fec441..9ad15938d 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- -from unittest import TestCase - -import jax.numpy as jnp -import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 7fc89526c..df5293ab9 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,6 +1,7 @@ -import brainpy as bp -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 0d754976f..3cf923d7b 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,7 +1,8 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import 
absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class Test_Conv(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index fdc5b34e3..de2c9765b 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,7 +1,8 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 34f8f5cd5..5748edd8b 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index f98527458..9c7daff55 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('I') + spk = self.delay.at('delay') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) From d1ad4e9b6fa54198129311920e4e98938f996b13 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 09:43:04 +0800 Subject: [PATCH 02/16] remove pytorch add --- brainpy/math/compat_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index e4570f6fd..3b0c3f517 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - add as add, + # add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, From d2c6d7858f1273bcc687a49eed1a294a2a772164 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 10:02:01 +0800 Subject: [PATCH 03/16] variable evaluation using `brainpy.math.eval_shape` --- .../_src/math/object_transform/autograd.py | 22 +++----- .../_src/math/object_transform/controls.py | 51 +++++++------------ brainpy/_src/math/object_transform/jit.py | 42 ++++++--------- brainpy/_src/math/object_transform/tools.py | 42 ++++++++++++--- 4 files changed, 73 insertions(+), 84 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f5e091675..b868b8076 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -32,6 +32,7 @@ VariableStack, current_transform_number, new_transform) +from .tools import eval_shape __all__ = [ 'grad', # gradient of scalar function @@ -204,22 +205,11 @@ def __call__(self, *args, **kwargs): stack = get_stack_cache(self.target) if stack is None: with new_transform(self): - with VariableStack() as stack: - if current_transform_number() > 1: - rets = self._transform( - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) - else: - rets = jax.eval_shape( - self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical 
variables - *args, - **kwargs - ) + stack, rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + self._dyn_vars.dict_data(), # dynamical variables + *args, + **kwargs) cache_stack(self.target, stack) self._dyn_vars = stack diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 032a0fab6..e38a541e7 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,6 +21,8 @@ cache_stack ) from .tools import ( + eval_shape, + eval_shape_of_multi_funcs, evaluate_dyn_vars, dynvar_deprecation, node_deprecation, @@ -545,12 +547,10 @@ def cond( if not jax.config.jax_disable_jit: if dyn_vars is None: with new_transform('cond'): - dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 + dyn_vars, rets = eval_shape_of_multi_funcs([true_fun, false_fun], *operands) cache_stack((true_fun, false_fun), dyn_vars) if current_transform_number() > 0: - return rets + return rets[0] dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -682,17 +682,13 @@ def ifelse( dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: with new_transform('ifelse'): - with VariableStack() as dyn_vars: - if current_transform_number() > 1: - rets = [branch(*operands) for branch in branches] - else: - rets = [jax.eval_shape(branch, *operands) for branch in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + dyn_vars, rets = eval_shape_of_multi_funcs(branches, *operands) + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) if current_transform_number(): return rets[0] @@ -885,14 +881,9 @@ def for_loop( if dyn_vars is None: # TODO: better cache mechanism? 
with new_transform('for_loop'): - with VariableStack() as dyn_vars: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, - progress_bar, remat, reverse, unroll, - unroll_kwargs) - if current_transform_number() > 1: - rets = transform(operands) - else: - rets = jax.eval_shape(transform, operands) + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, + remat, reverse, unroll, unroll_kwargs) + dyn_vars, rets = eval_shape(transform, operands) cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache if current_transform_number(): return rets[1] @@ -1015,12 +1006,8 @@ def scan( if not jax.config.jax_disable_jit: if dyn_vars is None: with new_transform('scan'): - with VariableStack() as dyn_vars: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - if current_transform_number() > 1: - rets = transform(init, operands) - else: - rets = jax.eval_shape(transform, init, operands) + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + dyn_vars, rets = eval_shape(transform, init, operands) cache_stack(body_fun, dyn_vars) # cache if current_transform_number(): return rets[0][1], rets[1] @@ -1141,12 +1128,10 @@ def while_loop( if not jax.config.jax_disable_jit: if dyn_vars is None: with new_transform('while_loop'): - dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 + dyn_vars, rets = eval_shape_of_multi_funcs([body_fun, cond_fun], *operands) cache_stack((body_fun, cond_fun), dyn_vars) if current_transform_number(): - return rets + return rets[1] dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) for k, v in dyn_vars.items(): diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 7bb36f4e2..f8cd721dc 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,23 +11,18 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax -from jax.sharding import Sharding from brainpy import tools, check +from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - evaluate_dyn_vars_with_cache, - evaluate_dyn_vars, + eval_shape, _partial_fun) -from .base import BrainPyObject, ObjectTransform -from .naming import get_stack_cache, cache_stack -from ..ndarray import Array from .variables import (Variable, - VariableStack, - outermost_transform, - transform_stack, current_transform_number, new_transform) +from ..ndarray import Array RandomState = None @@ -152,15 +147,11 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): def _get_transform(self, *args, **kwargs): with new_transform(self): - self._dyn_vars, rets = evaluate_dyn_vars( - self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - use_eval_shape=current_transform_number() <= 1, - **kwargs - ) - + self._dyn_vars, rets = eval_shape(self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + **kwargs) # in_shardings if self._in_shardings is None: in_shardings = None @@ -477,15 +468,12 @@ def call_fun(self, 
*args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - - with jax.ensure_compile_time_eval(): - if len(static_argnums) or len(static_argnames): - fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) - else: - args_, kwargs_, fun3 = args, kwargs, fun2 - with VariableStack() as stack: - _ = jax.eval_shape(fun3, *args_, **kwargs_) - del args_, kwargs_ + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 + stack, _ = eval_shape(fun3, *args_, **kwargs_) + del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7b519590a..300650f58 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -143,8 +143,8 @@ def eval_shape( Args: fun: The callable function. - *args: - **kwargs: + *args: The positional arguments. + **kwargs: The keyword arguments. static_argnums: The static argument indices. static_argnames: The static argument names. @@ -162,12 +162,38 @@ def eval_shape( # evaluate the function fun_in_eval_shape.append(fun) try: - with jax.ensure_compile_time_eval(): - with VariableStack() as stack: - if len(fun_in_eval_shape) > 1: - returns = fun(*args, **kwargs) - else: - returns = jax.eval_shape(fun, *args, **kwargs) + with VariableStack() as stack: + if len(fun_in_eval_shape) > 1: + returns = fun(*args, **kwargs) + else: + returns = jax.eval_shape(fun, *args, **kwargs) finally: fun_in_eval_shape.pop() return stack, returns + + +def eval_shape_of_multi_funcs( + funs: Sequence[Callable], + *args, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = (), + **kwargs +): + """Compute the shape/dtype of ``funs`` without any FLOPs. + + Args: + fun: The callable function. + *args: The positional arguments. + **kwargs: The keyword arguments. + static_argnums: The static argument indices. + static_argnames: The static argument names. + + Returns: + The variable stack and the functional returns. 
+ """ + stack, returns = VariableStack(), [] + for fun in funs: + st, ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) + stack += st + returns.append(ret) + return stack, returns From d336694555ba671f4c151130e61067d52086ea92 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 10:42:04 +0800 Subject: [PATCH 04/16] fix bugs --- .../_src/math/object_transform/controls.py | 1 - brainpy/_src/math/object_transform/jit.py | 10 +--- brainpy/_src/math/object_transform/tools.py | 54 +++++++++++++++++-- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index e38a541e7..2f43ec421 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -23,7 +23,6 @@ from .tools import ( eval_shape, eval_shape_of_multi_funcs, - evaluate_dyn_vars, dynvar_deprecation, node_deprecation, abstract diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index f8cd721dc..4ad2e2507 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -17,8 +17,7 @@ from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - eval_shape, - _partial_fun) + eval_shape) from .variables import (Variable, current_transform_number, new_transform) @@ -468,12 +467,7 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - if len(static_argnums) or len(static_argnames): - fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) - else: - args_, kwargs_, fun3 = args, kwargs, fun2 - stack, _ = eval_shape(fun3, *args_, **kwargs_) - del args_, kwargs_ + stack, _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 300650f58..e5b137350 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,6 +132,50 @@ def evaluate_dyn_vars_with_cache( return stack +def _partial_fun2( + fun: Callable, + args: tuple, + kwargs: dict, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = () +): + num_args = len(args) + + # arguments + static_args = dict() + dyn_args = [] + dyn_arg_ids = dict() + static_argnums = list(static_argnums) + dyn_i = 0 + for i in range(num_args): + if i in static_argnums: + static_argnums.remove(i) + static_args[i] = args[i] + else: + dyn_args.append(args[i]) + dyn_arg_ids[i] = dyn_i + dyn_i += 1 + if len(static_argnums) > 0: + raise ValueError(f"Invalid static_argnums: {static_argnums}") + + # keyword arguments + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames + + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + return fun(*[dynargs[dyn_arg_ids[i]] if i in dyn_arg_ids else static_args[i] + for i in range(num_args)], + **static_kwargs, **dynkwargs) + + return new_fun, dyn_args, dyn_kwargs + + def eval_shape( fun: Callable, *args, @@ -153,9 +197,9 @@ 
def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(fun, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) + f2, args, kwargs = _partial_fun2(fun, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: f2, args, kwargs = fun, args, kwargs @@ -164,9 +208,9 @@ def eval_shape( try: with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = fun(*args, **kwargs) + returns = f2(*args, **kwargs) else: - returns = jax.eval_shape(fun, *args, **kwargs) + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() return stack, returns From d7035120d1df44a5b6d4958a75377b0800d903bb Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:17:12 +0800 Subject: [PATCH 05/16] update transformations --- .../_src/math/object_transform/autograd.py | 11 +++--- .../_src/math/object_transform/controls.py | 6 ++-- brainpy/_src/math/object_transform/jit.py | 15 ++++---- brainpy/_src/math/object_transform/tools.py | 34 +++++++++---------- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index b868b8076..baa1f7606 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -205,11 +205,12 @@ def __call__(self, *args, **kwargs): stack = get_stack_cache(self.target) if stack is None: with new_transform(self): - stack, rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - self._dyn_vars.dict_data(), # dynamical variables - *args, - **kwargs) + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs) cache_stack(self.target, stack) self._dyn_vars = stack diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 2f43ec421..6ce4210b4 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -882,7 +882,8 @@ def for_loop( with new_transform('for_loop'): transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll, unroll_kwargs) - dyn_vars, rets = eval_shape(transform, operands) + with VariableStack() as dyn_vars: + rets = eval_shape(transform, operands) cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache if current_transform_number(): return rets[1] @@ -1006,7 +1007,8 @@ def scan( if dyn_vars is None: with new_transform('scan'): transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - dyn_vars, rets = eval_shape(transform, init, operands) + with VariableStack() as dyn_vars: + rets = eval_shape(transform, init, operands) cache_stack(body_fun, dyn_vars) # cache if current_transform_number(): return rets[0][1], rets[1] diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 4ad2e2507..3965c1a71 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -19,6 +19,7 @@ node_deprecation, eval_shape) from .variables import (Variable, + VariableStack, current_transform_number, new_transform) from ..ndarray import Array @@ -146,11 +147,12 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): 
def _get_transform(self, *args, **kwargs): with new_transform(self): - self._dyn_vars, rets = eval_shape(self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - **kwargs) + with VariableStack() as self._dyn_vars: + rets = eval_shape(self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + **kwargs) # in_shardings if self._in_shardings is None: in_shardings = None @@ -467,7 +469,8 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - stack, _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + with VariableStack() as stack: + _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index e5b137350..6010fac74 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -169,9 +169,9 @@ def _partial_fun2( @wraps(fun) def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[i]] if i in dyn_arg_ids else static_args[i] - for i in range(num_args)], - **static_kwargs, **dynkwargs) + return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], + **static_kwargs, + **dynkwargs) return new_fun, dyn_args, dyn_kwargs @@ -197,23 +197,21 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) + f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) else: - f2, args, kwargs = fun, args, kwargs + f2 = fun # evaluate the function fun_in_eval_shape.append(fun) try: - with VariableStack() as stack: - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) + pass finally: fun_in_eval_shape.pop() - return stack, returns + return returns def eval_shape_of_multi_funcs( @@ -226,7 +224,7 @@ def eval_shape_of_multi_funcs( """Compute the shape/dtype of ``funs`` without any FLOPs. Args: - fun: The callable function. + funs: A set of callable functions. *args: The positional arguments. **kwargs: The keyword arguments. static_argnums: The static argument indices. @@ -235,9 +233,9 @@ def eval_shape_of_multi_funcs( Returns: The variable stack and the functional returns. 
""" - stack, returns = VariableStack(), [] - for fun in funs: - st, ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) - stack += st + returns = [] + with VariableStack() as stack: + for fun in funs: + ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) returns.append(ret) return stack, returns From bbe5da9504d2966e398d711d68274fcf254639fd Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:32:59 +0800 Subject: [PATCH 06/16] remove `new_transform` API --- .../_src/math/object_transform/autograd.py | 34 ++++------ .../_src/math/object_transform/controls.py | 47 ++++++-------- brainpy/_src/math/object_transform/jit.py | 64 +++++++++---------- brainpy/_src/math/object_transform/tools.py | 7 +- .../_src/math/object_transform/variables.py | 45 +++---------- 5 files changed, 74 insertions(+), 123 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index baa1f7606..ad8a5ccf6 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,10 +28,7 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from .variables import (Variable, VariableStack) from .tools import eval_shape __all__ = [ @@ -204,26 +201,21 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with new_transform(self): - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if current_transform_number(): - return self._return(rets) - else: - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + # if not the outermost transformation + if not stack.is_first_stack(): + return self._return(rets) rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 6ce4210b4..e43a27808 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -27,12 +27,7 @@ node_deprecation, abstract ) -from .variables import ( - Variable, - VariableStack, - new_transform, - current_transform_number, -) +from .variables import (Variable, VariableStack) __all__ = [ 'make_loop', @@ -545,10 +540,10 @@ def cond( dyn_vars = get_stack_cache((true_fun, false_fun)) if not jax.config.jax_disable_jit: if dyn_vars is None: - with new_transform('cond'): - dyn_vars, rets = eval_shape_of_multi_funcs([true_fun, false_fun], *operands) + with VariableStack() as dyn_vars: + rets = eval_shape_of_multi_funcs([true_fun, false_fun], 
*operands) cache_stack((true_fun, false_fun), dyn_vars) - if current_transform_number() > 0: + if not dyn_vars.is_first_stack(): return rets[0] dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) @@ -680,8 +675,8 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with new_transform('ifelse'): - dyn_vars, rets = eval_shape_of_multi_funcs(branches, *operands) + with VariableStack() as dyn_vars: + rets = eval_shape_of_multi_funcs(branches, *operands) trees = [jax.tree_util.tree_structure(ret) for ret in rets] if not _all_equal(trees): msg = 'All returns in branches should have the same tree structure. But we got:\n' @@ -689,7 +684,7 @@ def ifelse( msg += f'- {tree}\n' raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -878,14 +873,13 @@ def for_loop( dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) if jit: if dyn_vars is None: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, + remat, reverse, unroll, unroll_kwargs) # TODO: better cache mechanism? - with new_transform('for_loop'): - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, - remat, reverse, unroll, unroll_kwargs) - with VariableStack() as dyn_vars: - rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache - if current_transform_number(): + with VariableStack() as dyn_vars: + rets = eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache + if not dyn_vars.is_first_stack(): return rets[1] del rets else: @@ -1005,12 +999,11 @@ def scan( dyn_vars = get_stack_cache(body_fun) if not jax.config.jax_disable_jit: if dyn_vars is None: - with new_transform('scan'): - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as dyn_vars: - rets = eval_shape(transform, init, operands) + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + with VariableStack() as dyn_vars: + rets = eval_shape(transform, init, operands) cache_stack(body_fun, dyn_vars) # cache - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[0][1], rets[1] del rets @@ -1128,10 +1121,10 @@ def while_loop( dyn_vars = get_stack_cache((body_fun, cond_fun)) if not jax.config.jax_disable_jit: if dyn_vars is None: - with new_transform('while_loop'): - dyn_vars, rets = eval_shape_of_multi_funcs([body_fun, cond_fun], *operands) + with VariableStack() as dyn_vars: + rets = eval_shape_of_multi_funcs([body_fun, cond_fun], *operands) cache_stack((body_fun, cond_fun), dyn_vars) - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[1] dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 3965c1a71..394a3dd37 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -18,10 +18,7 @@ from .tools import (dynvar_deprecation, node_deprecation, eval_shape) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from 
.variables import (Variable, VariableStack) from ..ndarray import Array RandomState = None @@ -146,37 +143,36 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with new_transform(self): - with VariableStack() as self._dyn_vars: - rets = eval_shape(self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - **kwargs) - # in_shardings - if self._in_shardings is None: - in_shardings = None + with VariableStack() as self._dyn_vars: + rets = eval_shape(self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + **kwargs) + # in_shardings + if self._in_shardings is None: + in_shardings = None + else: + if isinstance(self._in_shardings, (tuple, list)): + in_shardings = tuple(self._in_shardings) else: - if isinstance(self._in_shardings, (tuple, list)): - in_shardings = tuple(self._in_shardings) - else: - in_shardings = (self._in_shardings,) - _dyn_vars_sharing = get_shardings(self._dyn_vars) - in_shardings = (_dyn_vars_sharing,) + in_shardings - - # out_shardings - if self._out_shardings is None: - out_shardings = None + in_shardings = (self._in_shardings,) + _dyn_vars_sharing = get_shardings(self._dyn_vars) + in_shardings = (_dyn_vars_sharing,) + in_shardings + + # out_shardings + if self._out_shardings is None: + out_shardings = None + else: + if isinstance(self._out_shardings, (tuple, list)): + out_shardings = tuple(self._out_shardings) else: - if isinstance(self._out_shardings, (tuple, list)): - out_shardings = tuple(self._out_shardings) - else: - out_shardings = (self._out_shardings,) - global RandomState - if RandomState is None: - from brainpy.math.random import RandomState - _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) - out_shardings = (_dyn_vars_sharing,) + out_shardings + out_shardings = (self._out_shardings,) + global RandomState + if RandomState is None: + from brainpy.math.random import RandomState + _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) + out_shardings = (_dyn_vars_sharing,) + out_shardings # jit self._transform = jax.jit( @@ -199,7 +195,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if current_transform_number(): + if not self._dyn_vars.is_first_stack(): return rets # call the transformed function diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 6010fac74..173733cc7 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -234,8 +234,7 @@ def eval_shape_of_multi_funcs( The variable stack and the functional returns. 
""" returns = [] - with VariableStack() as stack: - for fun in funs: - ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) + for fun in funs: + ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) returns.append(ret) - return stack, returns + return returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 5014da0bf..b7babae8d 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -190,6 +189,14 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id + @classmethod + def num_of_stack(self): + return len(var_stack_list) + + @classmethod + def is_first_stack(self): + return len(var_stack_list) == 0 + def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -210,42 +217,6 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] -transform_stack: List[Callable] = [] - - -@contextmanager -def new_transform(transform: Any): - transform_stack.append(transform) - try: - yield - finally: - transform_stack.pop() - - -def outermost_stack(): - if len(var_stack_list): - return var_stack_list[0] - else: - return None - - -def outermost_transform(): - if len(transform_stack): - return transform_stack[0] - else: - return None - - -def current_transform_number(): - return len(transform_stack) - - -def _stack_add_read(var: 'Variable'): - pass - - -def _stack_add_write(var: 'Variable'): - pass @register_pytree_node_class From 9771bde90760327e700f2b2670da36613030f3d0 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:34:36 +0800 Subject: [PATCH 07/16] update --- brainpy/_src/math/object_transform/base.py | 4 +- .../_src/math/object_transform/parallels.py | 460 ------------------ 2 files changed, 1 insertion(+), 463 deletions(-) delete mode 100644 brainpy/_src/math/object_transform/parallels.py diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index aaf053ae7..c52845a06 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,7 +6,6 @@ """ import numbers -import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -14,14 +13,13 @@ import jax import numpy as np -from brainpy import errors +from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) -from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py deleted file mode 100644 index 1eddce048..000000000 --- a/brainpy/_src/math/object_transform/parallels.py +++ /dev/null @@ -1,460 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -The parallel compilation tools for JAX backend. - -1. 
Vectorize compilation is implemented by the 'vmap()' function -2. Parallel compilation is implemented by the 'pmap()' function - -""" - - -import functools - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters.partial_eval import DynamicJaxprTracer -from jax.interpreters.partial_eval import JaxprTracer -from jax.interpreters.pxla import ShardedDeviceArray - -try: - from jax.errors import UnexpectedTracerError -except ImportError: - from jax.core import UnexpectedTracerError - -from brainpy import errors -from brainpy._src.math.random import RandomState -from brainpy._src.math.ndarray import Array -from brainpy.tools import change_func_name -from .base import BrainPyObject, ArrayCollector - -__all__ = [ - 'vmap', - 'pmap', -] - - -def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, - batch_idx, axis_name, f_name=None): - @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) - def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - out = func(*args, **kwargs) - nonbatched_changes = nonbatched_vars.dict() - batched_changes = batched_vars.dict() - return nonbatched_changes, batched_changes, out - - def call(*args, **kwargs): - n = args[batch_idx[0]].shape[batch_idx[1]] - nonbatched_data = nonbatched_vars.dict() - batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} - try: - out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) - except UnexpectedTracerError as e: - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - raise errors.JaxTracerError() from e - # for key, v in dyn_changes.items(): - # dyn_vars[key] = reduce_func(v) - # for key, v in rand_changes.items(): - # rand_vars[key] = reduce_func(v) - return out - - return change_func_name(name=f_name, f=call) if f_name else call - - -def vmap(func, dyn_vars=None, batched_vars=None, - in_axes=0, out_axes=0, axis_name=None, - reduce_func=None, auto_infer=False): - """Vectorization compilation for class objects. - - Vectorized compile a function or a module to run in parallel on a single device. - - Examples - -------- - - Parameters - ---------- - func : BrainPyObject, function, callable - The function or the module to compile. - dyn_vars : dict, sequence - batched_vars : dict - in_axes : optional, int, sequence of int - Specify which input array axes to map over. If each positional argument to - ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, - or a tuple of integers and Nones with length equal to the number of - positional arguments to ``obj_or_func``. An integer or ``None`` - indicates which array axis to map over for all arguments (with ``None`` - indicating not to map any axis), and a tuple indicates which axis to map - for each corresponding positional argument. Axis integers must be in the - range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of - dimensions (axes) of the corresponding input array. - - If the positional arguments to ``obj_or_func`` are container types, the - corresponding element of ``in_axes`` can itself be a matching container, - so that distinct array axes can be mapped for different container - elements. ``in_axes`` must be a container tree prefix of the positional - argument tuple passed to ``obj_or_func``. - - At least one positional argument must have ``in_axes`` not None. 
The sizes - of the mapped input axes for all mapped positional arguments must all be - equal. - - Arguments passed as keywords are always mapped over their leading axis - (i.e. axis index 0). - out_axes : optional, int, tuple/list/dict - Indicate where the mapped axis should appear in the output. All outputs - with a mapped axis must have a non-None ``out_axes`` specification. Axis - integers must be in the range ``[-ndim, ndim)`` for each output array, - where ``ndim`` is the number of dimensions (axes) of the array returned - by the :func:`vmap`-ed function, which is one more than the number of - dimensions (axes) of the corresponding array returned by ``obj_or_func``. - axis_name : optional - - Returns - ------- - obj_or_func : Any - Batched/vectorized version of ``obj_or_func`` with arguments that correspond to - those of ``obj_or_func``, but with extra array axes at positions indicated by - ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but - with extra array axes at positions indicated by ``out_axes``. - - """ - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # dyn_vars = (dyn_vars or func.vars().unique()) - # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() - # for key, val in dyn_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # in axes - # if in_axes is None: - # in_axes = {key: (None, 0) for key in func.steps.keys()} - # elif isinstance(in_axes, int): - # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, (tuple, list)): - # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in in_axes: - # in_axes = {key: (None, 0, in_axes) for key in keys} - # else: - # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} - # assert isinstance(in_axes, dict) - # - # # batch size index - # batch_idx = {} - # for key, axes in in_axes.items(): - # for i, axis in enumerate(axes[2:]): - # if axis is not None: - # batch_idx[key] = (i, axis) - # break - # else: - # raise ValueError(f'Found no batch axis: {axes}.') - # - # # out axes - # if out_axes is None: - # out_axes = {key: 0 for key in func.steps.keys()} - # elif isinstance(out_axes, int): - # out_axes = {key: out_axes for key in func.steps.keys()} - # elif isinstance(out_axes, (tuple, list)): - # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} - # elif isinstance(out_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in out_axes: - # out_axes = {key: (out_axes, 0, 0) for key in keys} - # else: - # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} - # assert isinstance(out_axes, dict) - # - # # reduce_func - # if reduce_func is None: - # reduce_func = lambda x: x.mean(axis=0) - # - # # vectorized map functions - # for key in func.steps.keys(): - # func.steps[key] = _make_vmap(func=func.steps[key], - # dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # in_axes=in_axes[key], - # out_axes=out_axes[key], - # axis_name=axis_name, - # batch_idx=batch_idx[key], - # reduce_func=reduce_func, - # f_name=key) - # - # return func - - if callable(func): - if auto_infer: - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif 
hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.vmap(func, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name) - - else: - if isinstance(dyn_vars, Array): - dyn_vars = [dyn_vars] - if isinstance(dyn_vars, (tuple, list)): - dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} - assert isinstance(dyn_vars, dict) - - # dynamical variables - _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - _rand_vars[key] = val - else: - _dyn_vars[key] = val - - # in axes - if in_axes is None: - in_axes = (None, 0) - elif isinstance(in_axes, (int, dict)): - in_axes = (None, 0, in_axes) - elif isinstance(in_axes, (tuple, list)): - in_axes = (None, 0) + tuple(in_axes) - assert isinstance(in_axes, (tuple, list)) - - # batch size index - batch_idx = {} - for key, axes in batch_idx.items(): - for i, axis in enumerate(axes[2:]): - if axis is not None: - batch_idx[key] = (i, axis) - break - else: - raise ValueError(f'Found no batch axis: {axes}.') - - # out axes - if out_axes is None: - out_axes = 0 - elif isinstance(out_axes, (int, dict)): - out_axes = (out_axes, 0, 0) - elif isinstance(out_axes, (tuple, list)): - out_axes = tuple(out_axes) + (0, 0) - assert isinstance(out_axes, (list, tuple)) - - # reduce_func - if reduce_func is None: - reduce_func = lambda x: x.mean(axis=0) - - # jit function - return _make_vmap(func=func, - nonbatched_vars=_dyn_vars, - batched_vars=_rand_vars, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name, - batch_idx=batch_idx) - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' - f'function, but we got {type(func)}.') - - -def _device_reshape(x): - """Reshape an input array in order to broadcast to multiple devices.""" - num_device = jax.local_device_count() - - if not hasattr(x, 'ndim'): - raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' - f'parallel, first convert it to a Array, for example np.float(0.5)') - if x.ndim == 0: - return np.broadcast_to(x, [num_device]) - if x.shape[0] % num_device != 0: - raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' - f'{num_device} devices, but does not go equally.') - return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) - - -def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, - out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): - @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, - static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, - backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - def pmapped_func(dyn_data, rand_data, *args, **kwargs): - dyn_vars.assign(dyn_data) - rand_vars.assign(rand_data) - out = func(*args, **kwargs) - dyn_changes = dyn_vars.dict() - rand_changes = rand_vars.dict() - return out, dyn_changes, rand_changes - - def call(*args): - un_replicated = [k for k, v in dyn_vars.items() - if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] - if len(un_replicated): - raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' 
- f'did you forget to call xx.replicate() on them?') - _args = [] - for i, x in enumerate(args): - if i + 2 in static_broadcasted_argnums: - _args.append(x) - else: - _args.append(jax.tree_map(_device_reshape, [x])[0]) - dyn_data = dyn_vars.dict() - rand_data = rand_vars.dict() - output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) - dyn_vars.assign(dyn_changes) - rand_vars.assign(rand_changes) - return jax.tree_map(reduce_func, output) - - return change_func_name(name=f_name, f=call) if f_name else call - - -def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), - devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, - reduce_func=None): - """Parallel compilation for class objects. - - Parallel compile a function or a module to run on multiple devices in parallel. - - Parameters - ---------- - func - axis_name - in_axes - out_axes - static_broadcasted_argnums - devices - backend - axis_size - donate_argnums - global_arg_shapes - - Returns - ------- - - - Examples - -------- - - - """ - - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # all_vars = (dyn_vars or func.vars().unique()) - # dyn_vars = ArrayCollector() - # rand_vars = ArrayCollector() - # for key, val in all_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # reduce function - # if reduce_func is None: - # reduce_func = jnp.concatenate - # - # # static broadcast-ed arguments - # if static_broadcasted_argnums is None: - # static_broadcasted_argnums = () - # elif isinstance(static_broadcasted_argnums, int): - # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - # elif isinstance(static_broadcasted_argnums, (tuple, list)): - # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - # assert isinstance(static_broadcasted_argnums, (tuple, list)) - # - # # jit functions - # for key in func.steps.keys(): - # step = func.steps[key] - # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # func=step, - # axis_name=axis_name, - # in_axes=in_axes, - # out_axes=out_axes, - # static_broadcasted_argnums=static_broadcasted_argnums, - # devices=devices, - # backend=backend, - # axis_size=axis_size, - # donate_argnums=donate_argnums, - # global_arg_shapes=global_arg_shapes, - # reduce_func=reduce_func, - # f_name=key) - # return func - - if callable(func): - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.pmap(func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - else: - # dynamical variables - dyn_vars = ArrayCollector() - rand_vars = ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - rand_vars[key] = val - else: - dyn_vars[key] = val - - # static broadcast-ed arguments - if static_broadcasted_argnums is None: - static_broadcasted_argnums = () - elif isinstance(static_broadcasted_argnums, int): - 
static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - elif isinstance(static_broadcasted_argnums, (tuple, list)): - static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - assert isinstance(static_broadcasted_argnums, (tuple, list)) - - # reduce function - if reduce_func is None: - reduce_func = jnp.concatenate - - # jit function - func.__call__ = _make_pmap(dyn_vars=dyn_vars, - rand_vars=rand_vars, - func=func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes, - reduce_func=reduce_func) - return func - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' - f'but we got {type(func)}.') From da8807ca5c4cb69c8b27ed7a333f4c43887f9d51 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:38:03 +0800 Subject: [PATCH 08/16] update --- .../_src/math/object_transform/controls.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index e43a27808..bfe605130 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -870,23 +870,23 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) + stack = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if dyn_vars is None: + if stack is None: transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll, unroll_kwargs) # TODO: better cache mechanism? - with VariableStack() as dyn_vars: + with VariableStack() as stack: rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache - if not dyn_vars.is_first_stack(): + cache_stack((body_fun, unroll_kwargs), stack) # cache + if not stack.is_first_stack(): return rets[1] del rets else: - dyn_vars = VariableStack() + stack = VariableStack() # TODO: cache mechanism? 
- transform = _get_for_loop_transform(body_fun, dyn_vars, bar, + transform = _get_for_loop_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -894,11 +894,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, dyn_vars + del dyn_vals, stack return out_vals @@ -996,22 +996,22 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - dyn_vars = get_stack_cache(body_fun) + stack = get_stack_cache(body_fun) if not jax.config.jax_disable_jit: - if dyn_vars is None: + if stack is None: transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as dyn_vars: + with VariableStack() as stack: rets = eval_shape(transform, init, operands) - cache_stack(body_fun, dyn_vars) # cache - if not dyn_vars.is_first_stack(): + cache_stack(body_fun, stack) # cache + if not stack.is_first_stack(): return rets[0][1], rets[1] del rets - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) + stack = VariableStack() if stack is None else stack + transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1118,16 +1118,16 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - dyn_vars = get_stack_cache((body_fun, cond_fun)) + stack = get_stack_cache((body_fun, cond_fun)) if not jax.config.jax_disable_jit: - if dyn_vars is None: - with VariableStack() as dyn_vars: + if stack is None: + with VariableStack() as stack: rets = eval_shape_of_multi_funcs([body_fun, cond_fun], *operands) - cache_stack((body_fun, cond_fun), dyn_vars) - if not dyn_vars.is_first_stack(): + cache_stack((body_fun, cond_fun), stack) + if not stack.is_first_stack(): return rets[1] - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) - for k, v in dyn_vars.items(): + stack = VariableStack() if stack is None else stack + dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) + for k, v in stack.items(): v._value = dyn_values[k] return out From 05784b9678c9302aba9002555a614f7f22077be7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:40:46 +0800 Subject: [PATCH 09/16] fix --- brainpy/_src/math/object_transform/jit.py | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 394a3dd37..6c729e1d4 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -146,9 +146,9 @@ def _get_transform(self, *args, **kwargs): with VariableStack() as self._dyn_vars: rets = eval_shape(self.fun, *args, + **kwargs, static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - **kwargs) + static_argnames=self._static_argnames,) # in_shardings if self._in_shardings 
is None: in_shardings = None @@ -174,18 +174,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): From b96c5ad4b7f0b20371bdd258add5211a8c08849c Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 19:51:29 +0800 Subject: [PATCH 10/16] fix --- brainpy/_src/math/object_transform/tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 173733cc7..e487bfce6 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -208,7 +208,6 @@ def eval_shape( returns = f2(*args, **kwargs) else: returns = jax.eval_shape(f2, *args, **kwargs) - pass finally: fun_in_eval_shape.pop() return returns From d51fc98d8df76e8becf77968afa29cde4571da0d Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Feb 2024 23:40:32 +0800 Subject: [PATCH 11/16] fix bugs --- .../_src/math/object_transform/controls.py | 12 +++++----- brainpy/_src/math/object_transform/tools.py | 22 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index bfe605130..7fe6664ff 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -22,7 +22,7 @@ ) from .tools import ( eval_shape, - eval_shape_of_multi_funcs, + eval_shape_with_context, dynvar_deprecation, node_deprecation, abstract @@ -541,7 +541,8 @@ def cond( if not jax.config.jax_disable_jit: if dyn_vars is None: with VariableStack() as dyn_vars: - rets = eval_shape_of_multi_funcs([true_fun, false_fun], *operands) + rets = eval_shape_with_context(true_fun, *operands) + _ = eval_shape_with_context(false_fun, *operands) cache_stack((true_fun, false_fun), dyn_vars) if not dyn_vars.is_first_stack(): return rets[0] @@ -676,7 +677,7 @@ def ifelse( dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: with VariableStack() as dyn_vars: - rets = eval_shape_of_multi_funcs(branches, *operands) + rets = [eval_shape_with_context(fun, *operands) for fun in branches] trees = [jax.tree_util.tree_structure(ret) for ret in rets] if not _all_equal(trees): msg = 'All returns in branches should have the same tree structure. 
But we got:\n' @@ -1122,10 +1123,11 @@ def while_loop( if not jax.config.jax_disable_jit: if stack is None: with VariableStack() as stack: - rets = eval_shape_of_multi_funcs([body_fun, cond_fun], *operands) + _ = eval_shape_with_context(cond_fun, *operands) + rets = eval_shape_with_context(body_fun, *operands) cache_stack((body_fun, cond_fun), stack) if not stack.is_first_stack(): - return rets[1] + return rets stack = VariableStack() if stack is None else stack dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) for k, v in stack.items(): diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index e487bfce6..a7494a01e 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -213,27 +213,31 @@ def eval_shape( return returns -def eval_shape_of_multi_funcs( - funs: Sequence[Callable], +def eval_shape_with_context( + fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + return_context: bool = False, **kwargs ): - """Compute the shape/dtype of ``funs`` without any FLOPs. + """Compute the shape/dtype of ``fun`` without any FLOPs. Args: - funs: A set of callable functions. + fun: The callable function. *args: The positional arguments. **kwargs: The keyword arguments. static_argnums: The static argument indices. static_argnames: The static argument names. + return_context: Whether to return the variable stack. Returns: The variable stack and the functional returns. """ - returns = [] - for fun in funs: - ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs) - returns.append(ret) - return returns + with VariableStack() as stack: + returns = eval_shape(fun, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + if return_context: + return stack, returns + else: + return returns + From bd049b1739594e87e3aa35bd6638b7d38b4b1e11 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 10:14:15 +0800 Subject: [PATCH 12/16] fix bugs --- brainpy/_src/math/object_transform/controls.py | 2 +- brainpy/_src/math/object_transform/tools.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 7fe6664ff..3b5b5a8ac 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -545,7 +545,7 @@ def cond( _ = eval_shape_with_context(false_fun, *operands) cache_stack((true_fun, false_fun), dyn_vars) if not dyn_vars.is_first_stack(): - return rets[0] + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index a7494a01e..7057d047e 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -210,6 +210,7 @@ def eval_shape( returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() + del f2 return returns From 007bae6ddfd9d46d744633c5c08c72ffa6e39314 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 10:18:32 +0800 Subject: [PATCH 13/16] updates --- .../_src/math/object_transform/controls.py | 47 +++++++++---------- brainpy/_src/math/object_transform/naming.py | 3 +- 2 files changed, 
24 insertions(+), 26 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3b5b5a8ac..286a78919 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -538,14 +538,13 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with VariableStack() as dyn_vars: - rets = eval_shape_with_context(true_fun, *operands) - _ = eval_shape_with_context(false_fun, *operands) - cache_stack((true_fun, false_fun), dyn_vars) - if not dyn_vars.is_first_stack(): - return rets + if not jax.config.jax_disable_jit and dyn_vars is None: + with VariableStack() as dyn_vars: + rets = eval_shape_with_context(true_fun, *operands) + _ = eval_shape_with_context(false_fun, *operands) + cache_stack((true_fun, false_fun), dyn_vars) + if not dyn_vars.is_first_stack(): + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -998,15 +997,14 @@ def scan( bar = tqdm(total=num_total) stack = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit: - if stack is None: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as stack: - rets = eval_shape(transform, init, operands) - cache_stack(body_fun, stack) # cache - if not stack.is_first_stack(): - return rets[0][1], rets[1] - del rets + if not jax.config.jax_disable_jit and stack is None: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + with VariableStack() as stack: + rets = eval_shape(transform, init, operands) + cache_stack(body_fun, stack) # cache + if not stack.is_first_stack(): + return rets[0][1], rets[1] + del rets stack = VariableStack() if stack is None else stack transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) @@ -1120,14 +1118,13 @@ def while_loop( operands = (operands,) stack = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit: - if stack is None: - with VariableStack() as stack: - _ = eval_shape_with_context(cond_fun, *operands) - rets = eval_shape_with_context(body_fun, *operands) - cache_stack((body_fun, cond_fun), stack) - if not stack.is_first_stack(): - return rets + if not jax.config.jax_disable_jit and stack is None: + with VariableStack() as stack: + _ = eval_shape_with_context(cond_fun, *operands) + rets = eval_shape_with_context(body_fun, *operands) + cache_stack((body_fun, cond_fun), stack) + if not stack.is_first_stack(): + return rets stack = VariableStack() if stack is None else stack dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) for k, v in stack.items(): diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1c8ca6ef9..1181e003b 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=False): +def clear_name_cache(ignore_warn=True): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,6 +57,7 @@ def cache_stack(func, stack): def clear_stack_cache(): + """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] 
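The three control-flow transforms touched above (``cond``, ``scan``, ``while_loop``) now share one flattened caching pattern: on the first jit-enabled call the branch/body functions are shape-evaluated inside a fresh ``VariableStack``, that stack is cached keyed by the functions, and nested (non-first) stacks return the shape-only results early. Below is a usage sketch of the behaviour these internals support, namely branch functions that read and write ``brainpy.math.Variable`` objects. It is an illustration only: ``counter``, ``true_fun`` and ``false_fun`` are hypothetical names, and the sketch assumes the ``cond(pred, true_fun, false_fun, operands)`` calling convention visible in the hunks above.

    import brainpy.math as bm

    # Illustration only (not part of the patches): a Variable updated inside a
    # ``bm.cond`` branch. The first call traces both branches with ``eval_shape``
    # inside a VariableStack and caches that stack; later calls with the same
    # branch functions reuse the cached stack instead of re-tracing.
    counter = bm.Variable(bm.zeros(1))

    def true_fun(x):
        counter.value += 1          # Variable writes are carried by the collected stack
        return x * 2.0

    def false_fun(x):
        return x * 0.5

    out = bm.cond(True, true_fun, false_fun, (bm.ones(3),))
    print(out, counter.value)
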
From 9d9cd0123e22637eb4045889570c563cbc34326a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 10:20:24 +0800 Subject: [PATCH 14/16] updates --- brainpy/_src/math/object_transform/jit.py | 46 +++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 6c729e1d4..73eab2f91 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -148,31 +148,31 @@ def _get_transform(self, *args, **kwargs): *args, **kwargs, static_argnums=self._static_argnums, - static_argnames=self._static_argnames,) - # in_shardings - if self._in_shardings is None: - in_shardings = None - else: - if isinstance(self._in_shardings, (tuple, list)): - in_shardings = tuple(self._in_shardings) + static_argnames=self._static_argnames) + # in_shardings + if self._in_shardings is None: + in_shardings = None else: - in_shardings = (self._in_shardings,) - _dyn_vars_sharing = get_shardings(self._dyn_vars) - in_shardings = (_dyn_vars_sharing,) + in_shardings - - # out_shardings - if self._out_shardings is None: - out_shardings = None - else: - if isinstance(self._out_shardings, (tuple, list)): - out_shardings = tuple(self._out_shardings) + if isinstance(self._in_shardings, (tuple, list)): + in_shardings = tuple(self._in_shardings) + else: + in_shardings = (self._in_shardings,) + _dyn_vars_sharing = get_shardings(self._dyn_vars) + in_shardings = (_dyn_vars_sharing,) + in_shardings + + # out_shardings + if self._out_shardings is None: + out_shardings = None else: - out_shardings = (self._out_shardings,) - global RandomState - if RandomState is None: - from brainpy.math.random import RandomState - _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) - out_shardings = (_dyn_vars_sharing,) + out_shardings + if isinstance(self._out_shardings, (tuple, list)): + out_shardings = tuple(self._out_shardings) + else: + out_shardings = (self._out_shardings,) + global RandomState + if RandomState is None: + from brainpy.math.random import RandomState + _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) + out_shardings = (_dyn_vars_sharing,) + out_shardings # jit self._transform = jax.jit( From ad47ce8f977e1167e067009127463e12f5b431ad Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 13:05:58 +0800 Subject: [PATCH 15/16] upgrade --- .../_src/math/object_transform/controls.py | 12 +++-- brainpy/_src/math/object_transform/tools.py | 44 ++++++------------- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 286a78919..3edeb08e8 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -22,7 +22,6 @@ ) from .tools import ( eval_shape, - eval_shape_with_context, dynvar_deprecation, node_deprecation, abstract @@ -540,8 +539,8 @@ def cond( dyn_vars = get_stack_cache((true_fun, false_fun)) if not jax.config.jax_disable_jit and dyn_vars is None: with VariableStack() as dyn_vars: - rets = eval_shape_with_context(true_fun, *operands) - _ = eval_shape_with_context(false_fun, *operands) + rets = eval_shape(true_fun, *operands, with_stack=True)[1] + _ = eval_shape(false_fun, *operands, with_stack=True) cache_stack((true_fun, false_fun), dyn_vars) if not dyn_vars.is_first_stack(): return rets @@ -676,7 +675,7 @@ def ifelse( 
dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: with VariableStack() as dyn_vars: - rets = [eval_shape_with_context(fun, *operands) for fun in branches] + rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] trees = [jax.tree_util.tree_structure(ret) for ret in rets] if not _all_equal(trees): msg = 'All returns in branches should have the same tree structure. But we got:\n' @@ -1109,7 +1108,6 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1120,8 +1118,8 @@ def while_loop( stack = get_stack_cache((body_fun, cond_fun)) if not jax.config.jax_disable_jit and stack is None: with VariableStack() as stack: - _ = eval_shape_with_context(cond_fun, *operands) - rets = eval_shape_with_context(body_fun, *operands) + _ = eval_shape(cond_fun, *operands, with_stack=True) + rets = eval_shape(body_fun, *operands, with_stack=True)[1] cache_stack((body_fun, cond_fun), stack) if not stack.is_first_stack(): return rets diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7057d047e..632c6d79e 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -181,6 +181,7 @@ def eval_shape( *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -189,6 +190,7 @@ def eval_shape( fun: The callable function. *args: The positional arguments. **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. static_argnums: The static argument indices. static_argnames: The static argument names. @@ -204,40 +206,22 @@ def eval_shape( # evaluate the function fun_in_eval_shape.append(fun) try: - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + if with_stack: + with VariableStack() as stack: + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() del f2 - return returns - - -def eval_shape_with_context( - fun: Callable, - *args, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = (), - return_context: bool = False, - **kwargs -): - """Compute the shape/dtype of ``fun`` without any FLOPs. - - Args: - fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - static_argnums: The static argument indices. - static_argnames: The static argument names. - return_context: Whether to return the variable stack. - - Returns: - The variable stack and the functional returns. 
- """ - with VariableStack() as stack: - returns = eval_shape(fun, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) - if return_context: + if with_stack: return stack, returns else: return returns From 51eab01fac3ffccea0ba3629ae868ad5ab3a2313 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 13:08:21 +0800 Subject: [PATCH 16/16] add `brainpy.math.VariableStack` --- brainpy/math/oo_transform.py | 4 ++++ docs/apis/brainpy.math.oo_transform.rst | 1 + 2 files changed, 5 insertions(+) diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 548a987d0..7654731d8 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,3 +59,7 @@ eval_shape as eval_shape, ) +from brainpy._src.math.object_transform.variables import ( + VariableStack as VariableStack, +) + diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 754e0d81d..9ed9cf46a 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,4 +77,5 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape + VariableStack