From db6e376da1cbeb1170823383ae2c4fc79b03b0f8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 12 Sep 2023 11:20:58 +0800 Subject: [PATCH] [math] fix bugs in `cond` and `while_loop` when variables are used in both branches --- brainpy/__init__.py | 2 +- brainpy/_src/math/object_transform/controls.py | 12 ++++-------- .../math/object_transform/tests/test_controls.py | 1 - 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index b02b1d426..97f5aa304 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.4.5" +__version__ = "2.4.4.post4" # fundamental supporting modules from brainpy import errors, check, tools diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 81cb9e440..61c7b7f0d 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -529,16 +529,12 @@ def cond( if not jax.config.jax_disable_jit: if dyn_vars is None: with new_transform('cond'): - dyn_vars1, rets = evaluate_dyn_vars( - true_fun, *operands, use_eval_shape=current_transform_number() <= 1 - ) - dyn_vars2, rets = evaluate_dyn_vars( - false_fun, *operands, use_eval_shape=current_transform_number() <= 1 - ) + dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) dyn_vars = dyn_vars1 + dyn_vars2 cache_stack((true_fun, false_fun), dyn_vars) if current_transform_number() > 0: - return rets[1] + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -1014,7 +1010,7 @@ def while_loop( dyn_vars = dyn_vars1 + dyn_vars2 cache_stack((body_fun, cond_fun), dyn_vars) if current_transform_number(): - return rets[1] + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) for k, v in dyn_vars.items(): diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index 96edefcb2..5295d80db 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -228,7 +228,6 @@ def body(x, y): print() print(res) - def test2(self): bm.random.seed()