
Commit 4c68b94
[dyn] add STDP_Song2020 LTP model
ztqakita committed Sep 10, 2023
1 parent 6cfe6f8
Showing 9 changed files with 640 additions and 25 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/delay.py
@@ -327,7 +327,7 @@ def retrieve(self, delay_step, *indices):
 
     if self.method == ROTATE_UPDATE:
       i = share.load('i')
-      delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length)
+      delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length, dtype=jnp.int32)
       delay_idx = jax.lax.stop_gradient(delay_idx)
 
     elif self.method == CONCAT_UPDATE:
@@ -358,7 +358,7 @@ def update(
 
     # update the delay data at the rotation index
     if self.method == ROTATE_UPDATE:
       i = share.load('i')
-      idx = bm.as_jax((-i - 1) % self.max_length)
+      idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
       self.data[idx] = latest_value
 
     # update the delay data at the first position
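Not part of the commit: a minimal sketch of the ring-buffer index arithmetic the two hunks above pin to int32. The buffer length and step values are illustrative.

import jax.numpy as jnp

# With a rotating buffer of `max_length` slots at simulation step `i`, a value
# delayed by `delay_step` steps is read from (delay_step - i - 1) % max_length,
# and the newest value is written at (-i - 1) % max_length.  The explicit
# int32 cast pins the index dtype instead of leaving it to default promotion
# (which yields int64 when jax_enable_x64 is on).
max_length = 5
i = 7            # current simulation step
delay_step = 3   # read the value stored 3 steps ago
read_idx = jnp.asarray((delay_step - i - 1) % max_length, dtype=jnp.int32)
write_idx = jnp.asarray((-i - 1) % max_length, dtype=jnp.int32)
print(read_idx, write_idx)  # prints: 0 2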
135 changes: 114 additions & 21 deletions brainpy/_src/dnn/linear.py
@@ -13,9 +13,11 @@
 from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
 from brainpy.check import is_initializer
 from brainpy.errors import MathError
-from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
+from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter, variable_
 from brainpy.types import ArrayType, Sharding
 from brainpy._src.dnn.base import Layer
+from brainpy._src.mixin import SupportPlasticity
+from brainpy._src.connect import mat2coo

 __all__ = [
   'Dense', 'Linear',
@@ -29,22 +31,22 @@
 ]


-class Dense(Layer):
+class Dense(Layer, SupportPlasticity):
   r"""A linear transformation applied over the last dimension of the input.
 
   Mathematically, this node can be defined as:
 
   .. math::
 
-     y = x \cdot W + b
+     y = x \cdot weight + b
   Parameters
   ----------
   num_in: int
     The number of the input feature. A positive integer.
   num_out: int
     The number of the output features. A positive integer.
-  W_initializer: optional, Initializer
+  weight_initializer: optional, Initializer
     The weight initialization.
   b_initializer: optional, Initializer
     The bias initialization.
@@ -62,7 +64,7 @@ def __init__(
       self,
       num_in: int,
       num_out: int,
-      W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
+      weight_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
       b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
@@ -80,18 +82,18 @@ def __init__(
                        f'a positive integer. Received: num_out={num_out}')
 
     # weight initializer
-    self.weight_initializer = W_initializer
+    self.weight_initializer = weight_initializer
     self.bias_initializer = b_initializer
-    is_initializer(W_initializer, 'weight_initializer')
+    is_initializer(weight_initializer, 'weight_initializer')
     is_initializer(b_initializer, 'bias_initializer', allow_none=True)
 
     # parameter initialization
-    W = parameter(self.weight_initializer, (num_in, self.num_out))
+    weight = parameter(self.weight_initializer, (num_in, self.num_out))
     b = parameter(self.bias_initializer, (self.num_out,))
     if isinstance(self.mode, bm.TrainingMode):
-      W = bm.TrainVar(W)
+      weight = bm.TrainVar(weight)
       b = None if (b is None) else bm.TrainVar(b)
-    self.W = W
+    self.weight = weight
     self.b = b
 
     # fitting parameters
@@ -107,7 +109,7 @@ def __repr__(self):

   def update(self, x):
     x = bm.as_jax(x)
-    res = x @ self.W
+    res = x @ self.weight
     if self.b is not None:
       res += self.b

@@ -158,11 +160,11 @@ def online_fit(self,

     # assign trained weights
     if self.b is None:
-      self.W += dW
+      self.weight += dW
     else:
       db, dW = jnp.split(dW, [1])
       self.b += db[0]
-      self.W += dW
+      self.weight += dW

   def offline_fit(self,
                   target: ArrayType,
@@ -198,12 +200,26 @@ def offline_fit(self,

     # assign trained weights
     if self.b is None:
-      self.W.value = weights
+      self.weight.value = weights
     else:
       bias, Wff = jnp.split(weights, [1])
-      self.W.value = Wff
+      self.weight.value = Wff
       self.b.value = bias[0]

+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)


Linear = Dense
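A hedged usage sketch of the new `plasticity` hook (not from the commit): the public path bp.dnn.Dense is the assumed export of this class, and the sizes, delta matrix, and clipping constraint are made up for illustration.

import brainpy as bp
import brainpy.math as bm

# Illustrative only: push an externally computed weight change through the
# plasticity hook, then clamp the weights to stay non-negative.
layer = bp.dnn.Dense(3, 4)                      # assumed public path
dW = bm.random.normal(size=(3, 4)) * 0.01       # hypothetical STDP-style update
layer.plasticity(dW, constraints=lambda w: bm.maximum(w, 0.))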

@@ -219,7 +235,7 @@ def update(self, x):
     return x


-class AllToAll(Layer):
+class AllToAll(Layer, SupportPlasticity):
   """Synaptic matrix multiplication with All2All connections.
 
   Args:
@@ -281,8 +297,23 @@ def update(self, pre_val):
       post_val = pre_val @ self.weight
     return post_val

+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)



-class OneToOne(Layer):
+class OneToOne(Layer, SupportPlasticity):
   """Synaptic matrix multiplication with One2One connection.
 
   Args:
@@ -315,8 +346,23 @@ def __init__(
   def update(self, pre_val):
     return pre_val * self.weight


-class MaskedLinear(Layer):
+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    # OneToOne holds one weight per neuron: collapse the full (pre, post)
+    # delta matrix over the presynaptic axis before applying it
+    dW = dW.sum(axis=0)
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)
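A hedged usage sketch (not from the commit); bp.dnn.OneToOne and its argument names `num` and `weight` are assumptions about this class's public interface, and the delta values are illustrative.

import brainpy as bp
import brainpy.math as bm

# Illustrative only: OneToOne keeps one weight per neuron, so the hook accepts
# a full (pre, post) delta matrix and collapses it over axis 0 before applying.
layer = bp.dnn.OneToOne(num=5, weight=bm.ones(5))  # assumed signature
dW = bm.random.uniform(0., 0.01, size=(5, 5))      # hypothetical per-pair deltas
layer.plasticity(dW)                               # applies dW.sum(axis=0)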


+class MaskedLinear(Layer, SupportPlasticity):
r"""Synaptic matrix multiplication with masked dense computation.
It performs the computation of:
@@ -369,8 +415,23 @@ def __init__(
   def update(self, x):
     return x @ self.mask_fun(self.weight * self.mask)

+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)


-class CSRLinear(Layer):
+class CSRLinear(Layer, SupportPlasticity):
r"""Synaptic matrix multiplication with CSR sparse computation.
It performs the computation of:
@@ -438,6 +499,22 @@ def _batch_csrmv(self, x):
                      transpose=self.transpose,
                      method=self.method)

+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    # gather the dense (pre, post) delta matrix down to the stored synapses
+    pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+    sparse_dW = dW[pre_ids, post_ids]
+    if self.weight.shape != sparse_dW.shape:
+      raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+                       f'should be the same as the shape of sparse weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += sparse_dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)
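A short sketch (not from the commit) of the dense-to-sparse gather used above; the CSR connectivity arrays are illustrative.

import brainpy.math as bm

# Illustrative only: csr_to_coo recovers the (pre, post) coordinate of every
# stored synapse, so indexing a dense delta matrix with those coordinates
# yields one delta per stored weight, aligned with the CSR value vector.
indices = bm.asarray([1, 2, 0])            # hypothetical CSR column indices
indptr = bm.asarray([0, 2, 3, 3])          # hypothetical CSR row pointers
pre_ids, post_ids = bm.sparse.csr_to_coo(indices, indptr)
dW_dense = bm.random.rand(3, 3)            # full (pre, post) delta matrix
dW_sparse = dW_dense[pre_ids, post_ids]    # shape (3,): one delta per synapse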


class CSCLinear(Layer):
r"""Synaptic matrix multiplication with CSC sparse computation.
@@ -474,7 +551,7 @@ def __init__(
self.sharding = sharding


-class EventCSRLinear(Layer):
+class EventCSRLinear(Layer, SupportPlasticity):
   r"""Synaptic matrix multiplication with event CSR sparse computation.
 
   It performs the computation of:
@@ -538,6 +615,22 @@ def _batch_csrmv(self, x):
                      shape=(self.conn.pre_num, self.conn.post_num),
                      transpose=self.transpose)

+  def plasticity(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+    sparse_dW = dW[pre_ids, post_ids]
+    if self.weight.shape != sparse_dW.shape:
+      raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+                       f'should be the same as the shape of sparse weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += sparse_dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)


class BcsrMM(Layer):
r"""Synaptic matrix multiplication with BCSR sparse computation.
(Diffs for the remaining 7 changed files are not shown.)
