Skip to content

Commit

Permalink
[math] fix bugs in cond and while_loop when variables are used in…
Browse files Browse the repository at this point in the history
… both branches
  • Loading branch information
chaoming0625 committed Sep 12, 2023
1 parent af8d41b commit 751ea08
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 84 deletions.
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.4.post3"
__version__ = "2.4.5"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
91 changes: 46 additions & 45 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,57 +40,58 @@ class STDP_Song2000(Projection):
where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
Example:
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> class STDPNet(bp.DynamicalSystem):
>>> def __init__(self, num_pre, num_post):
>>> super().__init__()
>>> self.pre = bp.dyn.LifRef(num_pre, name='neu1')
>>> self.post = bp.dyn.LifRef(num_post, name='neu2')
>>> self.syn = bp.dyn.STDP_Song2000(
>>> pre=self.pre,
>>> delay=1.,
>>> comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
>>> weight=lambda s: bm.Variable(bm.random.rand(*s) * 0.1)),
>>> syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
>>> out=bp.dyn.COBA.desc(E=0.),
>>> post=self.post,
>>> tau_s=16.8,
>>> tau_t=33.7,
>>> A1=0.96,
>>> A2=0.53,
>>> )
>>>
>>> def update(self, I_pre, I_post):
>>> self.syn()
>>> self.pre(I_pre)
>>> self.post(I_post)
>>> conductance = self.syn.refs['syn'].g
>>> Apre = self.syn.refs['pre_trace'].g
>>> Apost = self.syn.refs['post_trace'].g
>>> current = self.post.sum_inputs(self.post.V)
>>> return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
>>> duration = 300.
>>> I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
>>> [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
>>> I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
>>> [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
>>>
>>> net = STDPNet(1, 1)
>>> def run(i, I_pre, I_post):
>>> pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
>>> return pre_spike, post_spike, g, Apre, Apost, current, W
>>>
>>> indices = bm.arange(0, duration, bm.dt)
>>> pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post], jit=True)
Example::
import brainpy as bp
import brainpy.math as bm
class STDPNet(bp.DynamicalSystem):
def __init__(self, num_pre, num_post):
super().__init__()
self.pre = bp.dyn.LifRef(num_pre, name='neu1')
self.post = bp.dyn.LifRef(num_post, name='neu2')
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
weight=bp.init.Uniform(max_val=0.1)),
syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
out=bp.dyn.COBA.desc(E=0.),
post=self.post,
tau_s=16.8,
tau_t=33.7,
A1=0.96,
A2=0.53,
)
def update(self, I_pre, I_post):
self.syn()
self.pre(I_pre)
self.post(I_post)
conductance = self.syn.refs['syn'].g
Apre = self.syn.refs['pre_trace'].g
Apost = self.syn.refs['post_trace'].g
current = self.post.sum_inputs(self.post.V)
return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
duration = 300.
I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
[5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
[10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
net = STDPNet(1, 1)
def run(i, I_pre, I_post):
pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
return pre_spike, post_spike, g, Apre, Apost, current, W
indices = bm.arange(0, duration, bm.dt)
pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post], jit=True)
Args:
tau_s: float, ArrayType, Callable. The time constant of :math:`A_{pre}`.
tau_t: float, ArrayType, Callable. The time constant of :math:`A_{post}`.
A1: float, ArrayType, Callable. The increment of :math:`A_{pre}` produced by a spike.
A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike.
%s
"""

def __init__(
Expand Down
41 changes: 16 additions & 25 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,25 +526,21 @@ def cond(
node_deprecation(child_objs)

dyn_vars = get_stack_cache((true_fun, false_fun))
_transform = _get_cond_transform(VariableStack() if dyn_vars is None else dyn_vars,
pred,
true_fun,
false_fun)
if jax.config.jax_disable_jit:
dyn_values, res = _transform(operands)

else:
if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('cond'):
dyn_vars, rets = evaluate_dyn_vars(
_transform,
operands,
use_eval_shape=current_transform_number() <= 1
dyn_vars1, rets = evaluate_dyn_vars(
true_fun, *operands, use_eval_shape=current_transform_number() <= 1
)
dyn_vars2, rets = evaluate_dyn_vars(
false_fun, *operands, use_eval_shape=current_transform_number() <= 1
)
dyn_vars = dyn_vars1 + dyn_vars2
cache_stack((true_fun, false_fun), dyn_vars)
if current_transform_number() > 0:
return rets[1]
dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
for k in dyn_values.keys():
dyn_vars[k]._value = dyn_values[k]
return res
Expand Down Expand Up @@ -1009,22 +1005,17 @@ def while_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)

if jax.config.jax_disable_jit:
dyn_vars = VariableStack()

else:
dyn_vars = get_stack_cache(body_fun)

dyn_vars = get_stack_cache((body_fun, cond_fun))
if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('while_loop'):
dyn_vars, rets = evaluate_dyn_vars(
_get_while_transform(cond_fun, body_fun, VariableStack()),
operands
)
cache_stack(body_fun, dyn_vars)
dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1)
dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1)
dyn_vars = dyn_vars1 + dyn_vars2
cache_stack((body_fun, cond_fun), dyn_vars)
if current_transform_number():
return rets[1]

dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands)
for k, v in dyn_vars.items():
v._value = dyn_values[k]
Expand Down
56 changes: 43 additions & 13 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def update(self):
self.assertTrue(bm.allclose(cls.a, 10.))


class TestCond(unittest.TestCase):
def test1(self):
bm.random.seed(1)
bm.cond(True, lambda: bm.random.random(10), lambda: bm.random.random(10), ())
bm.cond(False, lambda: bm.random.random(10), lambda: bm.random.random(10), ())


class TestIfElse(unittest.TestCase):
def test1(self):
def f(a):
Expand Down Expand Up @@ -221,6 +228,34 @@ def body(x, y):
print()
print(res)


def test2(self):
bm.random.seed()

a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))

def cond(x, y):
return x < 6.

def body(x, y):
a.value += x
b.value *= y
return x + b[0], y + 1.

res = bm.while_loop(body, cond, operands=(1., 1.))
print()
print(res)

with jax.disable_jit():
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))

res2 = bm.while_loop(body, cond, operands=(1., 1.))
print(res2)
self.assertTrue(bm.array_equal(res2[0], res[0]))
self.assertTrue(bm.array_equal(res2[1], res[1]))

def test3(self):
bm.random.seed()

Expand All @@ -242,32 +277,27 @@ def body(x, y):
print(a)
print(b)

def test2(self):
def test4(self):
bm.random.seed()

a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))

def cond(x, y):
return x < 6.
a.value += 1
return bm.all(a.value < 6.)

def body(x, y):
a.value += x
b.value *= y
return x + b[0], y + 1.

res = bm.while_loop(body, cond, operands=(1., 1.))
print()
self.assertTrue(bm.allclose(a, 5.))
self.assertTrue(bm.allclose(b, 1.))
print(res)

with jax.disable_jit():
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))

res2 = bm.while_loop(body, cond, operands=(1., 1.))
print(res2)
self.assertTrue(bm.array_equal(res2[0], res[0]))
self.assertTrue(bm.array_equal(res2[1], res[1]))
print(a)
print(b)
print()


class TestDebugAndCompile(parameterized.TestCase):
Expand Down

0 comments on commit 751ea08

Please sign in to comment.