diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 3bdc3a31c..45e784a50 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -10,12 +10,12 @@ from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share -from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm from brainpy.check import is_initializer from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding from brainpy._src.dnn.base import Layer +from brainpy._src.mixin import SupportOnline, SupportOffline __all__ = [ 'Dense', 'Linear', @@ -29,7 +29,7 @@ ] -class Dense(Layer): +class Dense(Layer, SupportOnline, SupportOffline): r"""A linear transformation applied over the last dimension of the input. Mathematically, this node can be defined as: @@ -52,12 +52,6 @@ class Dense(Layer): Enable training this node or not. (default True) """ - online_fit_by: Optional[OnlineAlgorithm] - '''Online fitting method.''' - - offline_fit_by: Optional[OfflineAlgorithm] - '''Offline fitting method.''' - def __init__( self, num_in: int, @@ -95,8 +89,8 @@ def __init__( self.b = b # fitting parameters - self.online_fit_by = None - self.offline_fit_by = None + self.online_fit_by = None # support online training + self.offline_fit_by = None # support offline training self.fit_record = dict() def __repr__(self): diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index e318eee4b..e18ac2a82 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -1,19 +1,19 @@ # -*- coding: utf-8 -*- from brainpy._src.dynsys import Dynamic -from brainpy._src.mixin import AutoDelaySupp, ParamDesc +from brainpy._src.mixin import SupportAutoDelay, ParamDesc __all__ = [ 'NeuDyn', 'SynDyn', 'IonChaDyn', ] -class NeuDyn(Dynamic, AutoDelaySupp): +class NeuDyn(Dynamic, SupportAutoDelay): """Neuronal Dynamics.""" pass -class SynDyn(Dynamic, AutoDelaySupp, ParamDesc): +class SynDyn(Dynamic, SupportAutoDelay, ParamDesc): """Synaptic Dynamics.""" pass diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index 2dfa2dd14..d0ff37d64 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -4,7 +4,7 @@ from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return from brainpy._src.dynsys import DynamicalSystem, Projection from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo, - AutoDelaySupp, BindCondData, AlignPost) + SupportAutoDelay, BindCondData, AlignPost) __all__ = [ 'VanillaProj', @@ -297,7 +297,7 @@ def update(self, inp): def __init__( self, - pre: JointType[DynamicalSystem, AutoDelaySupp], + pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], comm: DynamicalSystem, syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]], @@ -310,7 +310,7 @@ def __init__( super().__init__(name=name, mode=mode) # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(comm, DynamicalSystem) check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]]) check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) @@ -507,7 +507,7 @@ def update(self, inp): def __init__( self, - pre: JointType[DynamicalSystem, AutoDelaySupp], + pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], comm: DynamicalSystem, syn: JointType[DynamicalSystem, AlignPost], @@ -520,7 +520,7 @@ def __init__( super().__init__(name=name, mode=mode) # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(comm, DynamicalSystem) check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) @@ -631,7 +631,7 @@ def update(self, inp): def __init__( self, pre: DynamicalSystem, - syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]], + syn: ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]], delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], @@ -644,7 +644,7 @@ def __init__( # synaptic models check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]]) + check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) @@ -654,7 +654,7 @@ def __init__( self._syn_id = f'{syn.identifier} // Delay' if not pre.has_aft_update(self._syn_id): # "syn_cls" needs an instance of "ProjAutoDelay" - syn_cls: AutoDelaySupp = syn() + syn_cls: SupportAutoDelay = syn() delay_cls = init_delay_by_return(syn_cls.return_info()) # add to "after_updates" pre.add_aft_update(self._syn_id, _AlignPre(syn_cls, delay_cls)) @@ -755,7 +755,7 @@ def update(self, inp): def __init__( self, - pre: JointType[DynamicalSystem, AutoDelaySupp], + pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], syn: ParamDescInit[DynamicalSystem], comm: DynamicalSystem, @@ -768,7 +768,7 @@ def __init__( super().__init__(name=name, mode=mode) # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(syn, ParamDescInit[DynamicalSystem]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) @@ -884,7 +884,7 @@ def update(self, inp): def __init__( self, pre: DynamicalSystem, - syn: JointType[DynamicalSystem, AutoDelaySupp], + syn: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], @@ -897,7 +897,7 @@ def __init__( # synaptic models check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) @@ -1002,7 +1002,7 @@ def update(self, inp): def __init__( self, - pre: JointType[DynamicalSystem, AutoDelaySupp], + pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], syn: DynamicalSystem, comm: DynamicalSystem, @@ -1015,7 +1015,7 @@ def __init__( super().__init__(name=name, mode=mode) # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(syn, DynamicalSystem) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index 145eec585..c212884b7 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -11,7 +11,7 @@ from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter from brainpy._src.mixin import (ParamDesc, JointType, - AutoDelaySupp, BindCondData, ReturnInfo) + SupportAutoDelay, BindCondData, ReturnInfo) from brainpy.errors import UnsupportedError from brainpy.types import ArrayType @@ -109,7 +109,7 @@ def update(self): pass -class _SynSTP(_SynapseComponent, ParamDesc, AutoDelaySupp): +class _SynSTP(_SynapseComponent, ParamDesc, SupportAutoDelay): """Base class for synaptic short-term plasticity.""" def update(self, pre_spike): diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 78ea721c7..a7e7d86d9 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -10,7 +10,7 @@ from brainpy import tools, math as bm from brainpy._src.initialize import parameter, variable_ -from brainpy._src.mixin import AutoDelaySupp, Container, ReceiveInputProj, DelayRegister, global_delay_data +from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data from brainpy.errors import NoImplementationError, UnsupportedError from brainpy.types import ArrayType, Shape from brainpy._src.deprecations import _update_deprecate_msg @@ -487,7 +487,7 @@ class Network(DynSysGroup): pass -class Sequential(DynamicalSystem, AutoDelaySupp, Container): +class Sequential(DynamicalSystem, SupportAutoDelay, Container): """A sequential `input-output` module. Modules will be added to it in the order they are passed in the @@ -557,9 +557,9 @@ def update(self, x): def return_info(self): last = self[-1] - if not isinstance(last, AutoDelaySupp): + if not isinstance(last, SupportAutoDelay): raise UnsupportedError(f'Does not support "return_info()" because the last node is ' - f'not instance of {AutoDelaySupp.__name__}') + f'not instance of {SupportAutoDelay.__name__}') return last.return_info() def __getitem__(self, key: Union[int, slice, str]): diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py index 055f0fef9..23f151ee0 100644 --- a/brainpy/_src/math/surrogate/_one_input.py +++ b/brainpy/_src/math/surrogate/_one_input.py @@ -38,7 +38,9 @@ class Sigmoid(Surrogate): """Spike function with the sigmoid-shaped surrogate gradient. - Also see :py:class:`~.sigmoid`. + See Also + -------- + sigmoid """ def __init__(self, alpha=4., origin=False): @@ -125,7 +127,9 @@ def grad(dz): class PiecewiseQuadratic(Surrogate): """Judge spiking state with a piecewise quadratic function. - Also see :py:class:`~.piecewise_quadratic`. + See Also + -------- + piecewise_quadratic """ def __init__(self, alpha=1., origin=False): @@ -232,7 +236,9 @@ def grad(dz): class PiecewiseExp(Surrogate): """Judge spiking state with a piecewise exponential function. - Also see :py:class:`~.piecewise_exp`. + See Also + -------- + piecewise_exp """ def __init__(self, alpha=1., origin=False): self.alpha = alpha @@ -324,7 +330,9 @@ def grad(dz): class SoftSign(Surrogate): """Judge spiking state with a soft sign function. - Also see :py:class:`~.soft_sign`. + See Also + -------- + soft_sign """ def __init__(self, alpha=1., origin=False): self.alpha = alpha @@ -411,7 +419,9 @@ def grad(dz): class Arctan(Surrogate): """Judge spiking state with an arctan function. - Also see :py:class:`~.arctan`. + See Also + -------- + arctan """ def __init__(self, alpha=1., origin=False): self.alpha = alpha @@ -497,7 +507,9 @@ def grad(dz): class NonzeroSignLog(Surrogate): """Judge spiking state with a nonzero sign log function. - Also see :py:class:`~.nonzero_sign_log`. + See Also + -------- + nonzero_sign_log """ def __init__(self, alpha=1., origin=False): self.alpha = alpha @@ -596,7 +608,9 @@ def grad(dz): class ERF(Surrogate): """Judge spiking state with an erf function. - Also see :py:class:`~.erf`. + See Also + -------- + erf """ def __init__(self, alpha=1., origin=False): self.alpha = alpha @@ -692,7 +706,9 @@ def grad(dz): class PiecewiseLeakyRelu(Surrogate): """Judge spiking state with a piecewise leaky relu function. - Also see :py:class:`~.piecewise_leaky_relu`. + See Also + -------- + piecewise_leaky_relu """ def __init__(self, c=0.01, w=1., origin=False): self.c = c @@ -807,7 +823,9 @@ def grad(dz): class SquarewaveFourierSeries(Surrogate): """Judge spiking state with a squarewave fourier series. - Also see :py:class:`~.squarewave_fourier_series`. + See Also + -------- + squarewave_fourier_series """ def __init__(self, n=2, t_period=8., origin=False): self.n = n @@ -903,7 +921,9 @@ def grad(dz): class S2NN(Surrogate): """Judge spiking state with the S2NN surrogate spiking function. - Also see :py:class:`~.s2nn`. + See Also + -------- + s2nn """ def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False): self.alpha = alpha @@ -1013,7 +1033,9 @@ def grad(dz): class QPseudoSpike(Surrogate): """Judge spiking state with the q-PseudoSpike surrogate function. - Also see :py:class:`~.q_pseudo_spike`. + See Also + -------- + q_pseudo_spike """ def __init__(self, alpha=2., origin=False): self.alpha = alpha @@ -1110,7 +1132,9 @@ def grad(dz): class LeakyRelu(Surrogate): """Judge spiking state with the Leaky ReLU function. - Also see :py:class:`~.leaky_relu`. + See Also + -------- + leaky_relu """ def __init__(self, alpha=0.1, beta=1., origin=False): self.alpha = alpha @@ -1208,7 +1232,9 @@ def grad(dz): class LogTailedRelu(Surrogate): """Judge spiking state with the Log-tailed ReLU function. - Also see :py:class:`~.log_tailed_relu`. + See Also + -------- + log_tailed_relu """ def __init__(self, alpha=0., origin=False): self.alpha = alpha @@ -1316,7 +1342,9 @@ def grad(dz): class ReluGrad(Surrogate): """Judge spiking state with the ReLU gradient function. - Also see :py:class:`~.relu_grad`. + See Also + -------- + relu_grad """ def __init__(self, alpha=0.3, width=1.): self.alpha = alpha @@ -1397,7 +1425,9 @@ def grad(dz): class GaussianGrad(Surrogate): """Judge spiking state with the Gaussian gradient function. - Also see :py:class:`~.gaussian_grad`. + See Also + -------- + gaussian_grad """ def __init__(self, sigma=0.5, alpha=0.5): self.sigma = sigma @@ -1477,7 +1507,9 @@ def grad(dz): class MultiGaussianGrad(Surrogate): """Judge spiking state with the multi-Gaussian gradient function. - Also see :py:class:`~.multi_gaussian_grad`. + See Also + -------- + multi_gaussian_grad """ def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): self.h = h @@ -1571,7 +1603,9 @@ def grad(dz): class InvSquareGrad(Surrogate): """Judge spiking state with the inverse-square surrogate gradient function. - Also see :py:class:`~.inv_square_grad`. + See Also + -------- + inv_square_grad """ def __init__(self, alpha=100.): self.alpha = alpha @@ -1643,7 +1677,9 @@ def grad(dz): class SlayerGrad(Surrogate): """Judge spiking state with the slayer surrogate gradient function. - Also see :py:class:`~.slayer_grad`. + See Also + -------- + slayer_grad """ def __init__(self, alpha=1.): self.alpha = alpha diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index fce2aca18..124bf3d20 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -28,7 +28,7 @@ 'ParamDesc', 'ParamDescInit', 'AlignPost', - 'AutoDelaySupp', + 'SupportAutoDelay', 'Container', 'TreeNode', 'BindCondData', @@ -207,7 +207,7 @@ def get_data(self): return init -class AutoDelaySupp(MixIn): +class SupportAutoDelay(MixIn): """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" def return_info(self) -> Union[bm.Variable, ReturnInfo]: @@ -347,7 +347,7 @@ def register_delay_at( if delay_identifier is None: from brainpy._src.delay import delay_identifier if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem - assert isinstance(self, AutoDelaySupp), f'self must be an instance of {AutoDelaySupp.__name__}' + assert isinstance(self, SupportAutoDelay), f'self must be an instance of {SupportAutoDelay.__name__}' assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' if not self.has_aft_update(delay_identifier): self.add_aft_update(delay_identifier, init_delay_by_return(self.return_info())) @@ -549,6 +549,27 @@ def get_delay_var(self, name): return global_delay_data[name] +class SupportOnline(MixIn): + """:py:class:`~.MixIn` to support the online training methods.""" + + online_fit_by: Optional # methods for online fitting + + def online_init(self): + raise NotImplementedError + + def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): + raise NotImplementedError + + +class SupportOffline(MixIn): + """:py:class:`~.MixIn` to support the offline training methods.""" + + offline_fit_by: Optional # methods for offline fitting + + def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): + raise NotImplementedError + + class BindCondData(MixIn): """Bind temporary conductance data. """ @@ -598,7 +619,7 @@ class UnionType2(MixIn): >>> import brainpy as bp >>> - >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.AutoDelaySupp]) + >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay]) """ @classmethod diff --git a/brainpy/mixin.py b/brainpy/mixin.py index a3f17c7aa..82fd9f6ff 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -1,13 +1,13 @@ from brainpy._src.mixin import ( - MixIn as MixIn, - ReceiveInputProj as ReceiveInputProj, - AlignPost as AlignPost, - AutoDelaySupp as AutoDelaySupp, - ParamDesc as ParamDesc, - ParamDescInit as ParamDescInit, - BindCondData as BindCondData, - Container as Container, - TreeNode as TreeNode, - JointType as JointType, + MixIn as MixIn, + ReceiveInputProj as ReceiveInputProj, + AlignPost as AlignPost, + SupportAutoDelay as AutoDelaySupp, + ParamDesc as ParamDesc, + ParamDescInit as ParamDescInit, + BindCondData as BindCondData, + Container as Container, + TreeNode as TreeNode, + JointType as JointType, )