Commit 44c3e4b
fix test bugs
chaoming0625 committed Oct 8, 2023
1 parent 4de1acd commit 44c3e4b
Showing 5 changed files with 22 additions and 10 deletions.
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/linear.py
@@ -612,7 +612,8 @@ def update_STDP(self, dW, constraints=None):
      raise ValueError(f'Cannot update the weight of a constant node.')
    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
      raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}')
-   pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+   with jax.ensure_compile_time_eval():
+     pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
    sparse_dW = dW[pre_ids, post_ids]
    if self.weight.shape != sparse_dW.shape:
      raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
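The change above wraps the CSR-to-COO index conversion in jax.ensure_compile_time_eval(), so pre_ids/post_ids are computed eagerly as concrete arrays rather than being staged into the traced STDP update. A minimal sketch of the same pattern in plain JAX (the indptr/indices data and the gather_sparse name are illustrative, not BrainPy's):

# Sketch (plain JAX, illustrative data): evaluate COO row/col ids eagerly
# inside a jitted function via jax.ensure_compile_time_eval().
import jax
import jax.numpy as jnp

indptr = jnp.array([0, 2, 3])   # hypothetical CSR row pointers (2 rows)
indices = jnp.array([0, 1, 1])  # hypothetical CSR column indices

@jax.jit
def gather_sparse(dW):
  with jax.ensure_compile_time_eval():
    # indptr/indices are constants here, so these ops run eagerly and
    # pre_ids/post_ids become concrete arrays, not tracers.
    pre_ids = jnp.repeat(jnp.arange(indptr.size - 1), jnp.diff(indptr))
    post_ids = indices
  return dW[pre_ids, post_ids]

# e.g. gather_sparse(jnp.arange(4.).reshape(2, 2)) returns the 3 stored entries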
9 changes: 7 additions & 2 deletions brainpy/_src/dyn/projections/tests/test_STDP.py
@@ -1,14 +1,17 @@
# -*- coding: utf-8 -*-


+import os
+os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class Test_STDP(parameterized.TestCase):
  def test_STDP(self):
    bm.random.seed()

    class STDPNet(bp.DynamicalSystem):
      def __init__(self, num_pre, num_post):
        super().__init__()
@@ -45,10 +48,12 @@ def update(self, I_pre, I_post):
        [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])

    net = STDPNet(1, 1)

    def run(i, I_pre, I_post):
      pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
      return pre_spike, post_spike, g, Apre, Apost, current, W

    indices = bm.arange(0, duration, bm.dt)
    bm.for_loop(run, [indices, I_pre, I_post], jit=True)
    bm.clear_buffer_memory()

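The test now builds its time axis with bm.arange(0, duration, bm.dt) and drives net.step_run through bm.for_loop. The sketch below shows the same run-a-step-function-over-time pattern with plain jax.lax.scan; the toy weight update is illustrative and is not BrainPy's STDP rule.

# Sketch (plain JAX): scan a per-step update over pre/post spike trains.
import jax
import jax.numpy as jnp

def step(w, inputs):
  pre, post = inputs
  w = w + 0.01 * pre - 0.005 * post  # toy plasticity update, illustrative only
  return w, w                        # new carry, plus a recorded trace value

pre_spikes = jnp.array([1., 0., 1., 0., 1.])
post_spikes = jnp.array([0., 1., 0., 1., 0.])
final_w, w_trace = jax.lax.scan(step, 0.5, (pre_spikes, post_spikes))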
14 changes: 10 additions & 4 deletions brainpy/_src/initialize/generic.py
@@ -28,6 +28,12 @@ def _is_scalar(x):
  return isinstance(x, (float, int, bool, complex))


+def _check_var(x):
+  if isinstance(x, bm.Variable):
+    x.ready_to_trace = True
+  return x
+
+
def parameter(
    param: Union[Callable, Initializer, bm.Array, np.ndarray, jax.Array, float, int, bool],
    sizes: Shape,
@@ -74,10 +80,10 @@ def parameter(
    return param

  if callable(param):
-   # param = param(sizes)  # TODO
-   return bm.jit(param,
-                 static_argnums=0,
-                 out_shardings=bm.sharding.get_sharding(sharding))(sizes)
+   v = bm.jit(param,
+              static_argnums=0,
+              out_shardings=bm.sharding.get_sharding(sharding))(sizes)
+   return _check_var(v)  # TODO: checking the Variable need to be traced

  elif isinstance(param, (np.ndarray, jnp.ndarray)):
    param = bm.asarray(param)
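With this change, a callable initializer is run under bm.jit with the shape passed as a static argument, and the result goes through the new _check_var helper so a returned bm.Variable is marked ready to trace. A rough sketch of the static-shape-jit part in plain JAX (the uniform_init function is made up for illustration):

# Sketch (plain JAX): jit an initializer whose first argument is the shape
# tuple; marking it static lets the tuple be used as a concrete shape.
import jax

def uniform_init(shape):
  key = jax.random.PRNGKey(0)
  return jax.random.uniform(key, shape, minval=-0.1, maxval=0.1)

w = jax.jit(uniform_init, static_argnums=0)((4, 4))  # shape (4, 4) is static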
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/base.py
@@ -181,7 +181,7 @@ def fun(self):
    from brainpy.initialize import variable_
    with jax.ensure_compile_time_eval():
      value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name)
-     value._ready_to_trace = True
+     value.ready_to_trace = True
    self.setattr(name, value)
    return value

4 changes: 2 additions & 2 deletions tests/training/test_ESN.py
@@ -6,7 +6,7 @@
class ESN(bp.DynamicalSystem):
  def __init__(self, num_in, num_hidden, num_out):
    super(ESN, self).__init__()
-   self.r = bp.dnn.Reservoir(num_in,
+   self.r = bp.dyn.Reservoir(num_in,
                              num_hidden,
                              Win_initializer=bp.init.Uniform(-0.1, 0.1),
                              Wrec_initializer=bp.init.Normal(scale=0.1),
@@ -26,7 +26,7 @@ class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in, num_out):
    super(NGRC, self).__init__()

-   self.r = bp.dnn.NVAR(num_in, delay=2, order=2)
+   self.r = bp.dyn.NVAR(num_in, delay=2, order=2)
    self.o = bp.dnn.Dense(self.r.num_out, num_out,
                          W_initializer=bp.init.Normal(0.1),
                          mode=bm.training_mode)
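After this commit the ESN test takes Reservoir and NVAR from bp.dyn, while Dense stays in bp.dnn. A one-line check of that assumed layout:

import brainpy as bp
# Assumed module layout after this commit: recurrent layers in bp.dyn,
# feedforward layers in bp.dnn.
print(bp.dyn.Reservoir, bp.dyn.NVAR, bp.dnn.Dense)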
