
Commit

Merge branch 'master' of https://github.com/brainpy/BrainPy
chaoming0625 committed Sep 22, 2023
2 parents 474ad2b + e6373e8 commit 572312f
Showing 9 changed files with 434 additions and 39 deletions.
14 changes: 6 additions & 8 deletions brainpy/_src/math/object_transform/autograd.py
@@ -915,9 +915,8 @@ def _valid_jaxtype(arg):
 
 def _check_output_dtype_revderiv(name, holomorphic, x):
   aval = core.get_aval(x)
-  if core.is_opaque_dtype(aval.dtype):
-    raise TypeError(
-      f"{name} with output element type {aval.dtype.name}")
+  # if jnp.issubdtype(aval.dtype, dtypes.extended):
+  #   raise TypeError(f"{name} with output element type {aval.dtype.name}")
   if holomorphic:
     if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
@@ -938,9 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
 def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
   _check_arg(x)
   aval = core.get_aval(x)
-  if core.is_opaque_dtype(aval.dtype):
-    raise TypeError(
-      f"{name} with input element type {aval.dtype.name}")
+  # if jnp.issubdtype(aval.dtype, dtypes.extended):
+  #   raise TypeError(f"{name} with input element type {aval.dtype.name}")
   if holomorphic:
     if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
@@ -972,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x):
 def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
   _check_arg(x)
   aval = core.get_aval(x)
-  if core.is_opaque_dtype(aval.dtype):
-    raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
+  # if jnp.issubdtype(aval.dtype, dtypes.extended):
+  #   raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
   if holomorphic:
     if not dtypes.issubdtype(aval.dtype, np.complexfloating):
       raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
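
The hunks above drop the old `core.is_opaque_dtype` guard and leave the newer `jnp.issubdtype(aval.dtype, dtypes.extended)` form commented out, so the commit effectively disables the check. If one instead wanted to keep the guard while tolerating both older and newer JAX releases, a defensive helper could look like the sketch below. This is an assumption on my part, not part of the commit; the attribute probes exist only because the two APIs are not available in the same JAX versions.

import jax.numpy as jnp
from jax import core, dtypes

def _is_extended_dtype(dtype) -> bool:
  # Newer JAX: "extended" dtypes (formerly "opaque") are detected with
  # jnp.issubdtype against dtypes.extended.
  if hasattr(dtypes, 'extended'):
    return jnp.issubdtype(dtype, dtypes.extended)
  # Older JAX: fall back to the since-removed core.is_opaque_dtype helper.
  if hasattr(core, 'is_opaque_dtype'):
    return core.is_opaque_dtype(dtype)
  return False
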
2 changes: 1 addition & 1 deletion brainpy/_src/math/op_registers/numba_approach/__init__.py
@@ -45,7 +45,7 @@ class XLACustomOp(BrainPyObject):
   cpu_func: callable
     The function defines the computation on CPU backend. Same as ``con_compute``.
   gpu_func: callable
-    The function defines the computation on GPU backend. Currently, this function is not supportted.
+    The function defines the computation on GPU backend. Currently, this function is not supported.
   apply_cpu_func_to_gpu: bool
     Whether allows to apply CPU function on GPU backend. If True, the GPU data will move to CPU,
     and after calculation, the returned outputs on CPU backend will move to GPU.
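
For orientation, the sketch below shows how a CPU-only custom operator is registered through this API. The `abs_eval`/`con_compute` signatures and the `bm.XLACustomOp(eval_shape=..., cpu_func=...)` call mirror the test file updated below; the kernel body and the way `outs`/`ins` are unpacked are illustrative assumptions rather than the test's actual implementation.

import brainpy.math as bm
from jax.core import ShapedArray

def abs_eval(events, indices, indptr, *, weight, post_num):
  # Shape/dtype inference only: tells XLA what the operator returns
  # without running the kernel.
  return [ShapedArray((post_num,), bm.float32), ]

def con_compute(outs, ins):
  # CPU kernel compiled through Numba. `outs` holds pre-allocated output
  # buffers and `ins` holds the operator inputs; the body here is a
  # placeholder (the real test accumulates a sparse event sum).
  post_val, = outs
  post_val.fill(0.)

event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute)
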
39 changes: 15 additions & 24 deletions brainpy/_src/math/op_registers/tests/test_ei_net.py
@@ -3,9 +3,6 @@
 from jax.core import ShapedArray
 
 
-bm.set_platform('cpu')
-
-
 def abs_eval(events, indices, indptr, *, weight, post_num):
   return [ShapedArray((post_num,), bm.float32), ]
 
@@ -25,7 +22,7 @@ def con_compute(outs, ins):
 event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute)
 
 
-class ExponentialV2(bp.TwoEndConn):
+class ExponentialV2(bp.synapses.TwoEndConn):
   """Exponential synapse model using customized operator written in C++."""
 
   def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
@@ -46,8 +43,8 @@ def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
     # function
     self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')
 
-  def update(self, tdi):
-    self.g.value = self.integral(self.g, tdi.t, tdi.dt)
+  def update(self):
+    self.g.value = self.integral(self.g, bp.share['t'])
     self.g += event_sum(self.pre.spike,
                         self.pre2post[0],
                         self.pre2post[1],
@@ -56,31 +53,25 @@ def update(self, tdi):
     self.post.input += self.g * (self.E - self.post.V)
 
 
-class EINet(bp.Network):
+class EINet(bp.DynSysGroup):
   def __init__(self, scale):
+    super().__init__()
     # neurons
     bm.random.seed()
     pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
-                V_initializer=bp.init.Normal(-55., 2.))
-    E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto')
-    I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto')
+                V_initializer=bp.init.Normal(-55., 2.), method='exp_auto')
+    self.E = bp.neurons.LIF(int(3200 * scale), **pars)
+    self.I = bp.neurons.LIF(int(800 * scale), **pars)
 
     # synapses
-    E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
-    E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
-    I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
-    I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
-
-    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
+    self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
+    self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
 
 
 def test1():
   bm.random.seed()
+  bm.set_platform('cpu')
   net2 = EINet(scale=0.1)
-  runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
-  r = runner2.predict(100., eval_time=True)
+  runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
+  r = runner.predict(100., eval_time=True)
   bm.clear_buffer_memory()

2 changes: 2 additions & 0 deletions docs/tutorial_advanced/3_dedicated_operators.rst
@@ -3,3 +3,5 @@ Brain Dynamics Dedicated Operators
 
 .. toctree::
    :maxdepth: 1
+
+   low-level_operator_customization.ipynb