Skip to content

Commit

Permalink
Merge pull request #489 from chaoming0625/master
Browse files Browse the repository at this point in the history
Decouple Online and Offline training algorithms as ``brainpy.mixin.SupportOnline`` and `brainpy.mixin.SupportOffline`
  • Loading branch information
chaoming0625 authored Sep 11, 2023
2 parents 29f2262 + 531811f commit 7b1faf2
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 65 deletions.
14 changes: 4 additions & 10 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
from brainpy.check import is_initializer
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
from brainpy._src.dnn.base import Layer
from brainpy._src.mixin import SupportOnline, SupportOffline

__all__ = [
'Dense', 'Linear',
Expand All @@ -29,7 +29,7 @@
]


class Dense(Layer):
class Dense(Layer, SupportOnline, SupportOffline):
r"""A linear transformation applied over the last dimension of the input.
Mathematically, this node can be defined as:
Expand All @@ -52,12 +52,6 @@ class Dense(Layer):
Enable training this node or not. (default True)
"""

online_fit_by: Optional[OnlineAlgorithm]
'''Online fitting method.'''

offline_fit_by: Optional[OfflineAlgorithm]
'''Offline fitting method.'''

def __init__(
self,
num_in: int,
Expand Down Expand Up @@ -95,8 +89,8 @@ def __init__(
self.b = b

# fitting parameters
self.online_fit_by = None
self.offline_fit_by = None
self.online_fit_by = None # support online training
self.offline_fit_by = None # support offline training
self.fit_record = dict()

def __repr__(self):
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-

from brainpy._src.dynsys import Dynamic
from brainpy._src.mixin import AutoDelaySupp, ParamDesc
from brainpy._src.mixin import SupportAutoDelay, ParamDesc

__all__ = [
'NeuDyn', 'SynDyn', 'IonChaDyn',
]


class NeuDyn(Dynamic, AutoDelaySupp):
class NeuDyn(Dynamic, SupportAutoDelay):
"""Neuronal Dynamics."""
pass


class SynDyn(Dynamic, AutoDelaySupp, ParamDesc):
class SynDyn(Dynamic, SupportAutoDelay, ParamDesc):
"""Synaptic Dynamics."""
pass

Expand Down
28 changes: 14 additions & 14 deletions brainpy/_src/dyn/projections/aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo,
AutoDelaySupp, BindCondData, AlignPost)
SupportAutoDelay, BindCondData, AlignPost)

__all__ = [
'VanillaProj',
Expand Down Expand Up @@ -297,7 +297,7 @@ def update(self, inp):

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
Expand All @@ -310,7 +310,7 @@ def __init__(
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
Expand Down Expand Up @@ -507,7 +507,7 @@ def update(self, inp):

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
syn: JointType[DynamicalSystem, AlignPost],
Expand All @@ -520,7 +520,7 @@ def __init__(
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
Expand Down Expand Up @@ -631,7 +631,7 @@ def update(self, inp):
def __init__(
self,
pre: DynamicalSystem,
syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]],
syn: ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]],
delay: Union[None, int, float],
comm: DynamicalSystem,
out: JointType[DynamicalSystem, BindCondData],
Expand All @@ -644,7 +644,7 @@ def __init__(

# synaptic models
check.is_instance(pre, DynamicalSystem)
check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
check.is_instance(post, DynamicalSystem)
Expand All @@ -654,7 +654,7 @@ def __init__(
self._syn_id = f'{syn.identifier} // Delay'
if not pre.has_aft_update(self._syn_id):
# "syn_cls" needs an instance of "ProjAutoDelay"
syn_cls: AutoDelaySupp = syn()
syn_cls: SupportAutoDelay = syn()
delay_cls = init_delay_by_return(syn_cls.return_info())
# add to "after_updates"
pre.add_aft_update(self._syn_id, _AlignPre(syn_cls, delay_cls))
Expand Down Expand Up @@ -755,7 +755,7 @@ def update(self, inp):

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: ParamDescInit[DynamicalSystem],
comm: DynamicalSystem,
Expand All @@ -768,7 +768,7 @@ def __init__(
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, ParamDescInit[DynamicalSystem])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
Expand Down Expand Up @@ -884,7 +884,7 @@ def update(self, inp):
def __init__(
self,
pre: DynamicalSystem,
syn: JointType[DynamicalSystem, AutoDelaySupp],
syn: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
out: JointType[DynamicalSystem, BindCondData],
Expand All @@ -897,7 +897,7 @@ def __init__(

# synaptic models
check.is_instance(pre, DynamicalSystem)
check.is_instance(syn, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
check.is_instance(post, DynamicalSystem)
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def update(self, inp):

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: DynamicalSystem,
comm: DynamicalSystem,
Expand All @@ -1015,7 +1015,7 @@ def __init__(
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, DynamicalSystem)
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, JointType,
AutoDelaySupp, BindCondData, ReturnInfo)
SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy.errors import UnsupportedError
from brainpy.types import ArrayType

Expand Down Expand Up @@ -109,7 +109,7 @@ def update(self):
pass


class _SynSTP(_SynapseComponent, ParamDesc, AutoDelaySupp):
class _SynSTP(_SynapseComponent, ParamDesc, SupportAutoDelay):
"""Base class for synaptic short-term plasticity."""

def update(self, pre_spike):
Expand Down
8 changes: 4 additions & 4 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import AutoDelaySupp, Container, ReceiveInputProj, DelayRegister, global_delay_data
from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
from brainpy._src.deprecations import _update_deprecate_msg
Expand Down Expand Up @@ -487,7 +487,7 @@ class Network(DynSysGroup):
pass


class Sequential(DynamicalSystem, AutoDelaySupp, Container):
class Sequential(DynamicalSystem, SupportAutoDelay, Container):
"""A sequential `input-output` module.
Modules will be added to it in the order they are passed in the
Expand Down Expand Up @@ -557,9 +557,9 @@ def update(self, x):

def return_info(self):
last = self[-1]
if not isinstance(last, AutoDelaySupp):
if not isinstance(last, SupportAutoDelay):
raise UnsupportedError(f'Does not support "return_info()" because the last node is '
f'not instance of {AutoDelaySupp.__name__}')
f'not instance of {SupportAutoDelay.__name__}')
return last.return_info()

def __getitem__(self, key: Union[int, slice, str]):
Expand Down
Loading

0 comments on commit 7b1faf2

Please sign in to comment.