Skip to content

Commit

Permalink
[math] fix bugs in cond and while_loop when variables are used in…
Browse files Browse the repository at this point in the history
… both branches
  • Loading branch information
chaoming0625 committed Sep 12, 2023
1 parent 751ea08 commit db6e376
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 10 deletions.
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.5"
__version__ = "2.4.4.post4"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
12 changes: 4 additions & 8 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,16 +529,12 @@ def cond(
if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('cond'):
dyn_vars1, rets = evaluate_dyn_vars(
true_fun, *operands, use_eval_shape=current_transform_number() <= 1
)
dyn_vars2, rets = evaluate_dyn_vars(
false_fun, *operands, use_eval_shape=current_transform_number() <= 1
)
dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1)
dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1)
dyn_vars = dyn_vars1 + dyn_vars2
cache_stack((true_fun, false_fun), dyn_vars)
if current_transform_number() > 0:
return rets[1]
return rets
dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
for k in dyn_values.keys():
Expand Down Expand Up @@ -1014,7 +1010,7 @@ def while_loop(
dyn_vars = dyn_vars1 + dyn_vars2
cache_stack((body_fun, cond_fun), dyn_vars)
if current_transform_number():
return rets[1]
return rets
dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands)
for k, v in dyn_vars.items():
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ def body(x, y):
print()
print(res)


def test2(self):
bm.random.seed()

Expand Down

0 comments on commit db6e376

Please sign in to comment.