Skip to content

Commit

Permalink
[reset] update logics of state reset in DynamicalSystem
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 14, 2023
1 parent db6e376 commit 69df9da
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def update(self):
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))

def __getattr__(self, item):
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')
child_vars = super().__getattribute__('implicit_vars')
if item in child_vars:
return child_vars[item]
else:
return super(TrajectModel, self).__getattribute__(item)
return super().__getattribute__(item)

def run(self, duration):
self.runner.run(duration)
Expand Down
11 changes: 5 additions & 6 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,6 @@ def reset_state(self, batch_size: int = None):
# initialize delay data
if self.data is not None:
self._init_data(self.max_length, batch_size)
for cls in self.before_updates.values():
cls.reset_state(batch_size)
for cls in self.after_updates.values():
cls.reset_state(batch_size)

def _init_data(self, length: int, batch_size: int = None):
if batch_size is not None:
Expand Down Expand Up @@ -468,13 +464,16 @@ def __init__(
*indices
):
super().__init__(mode=delay.mode)
self.delay = delay
self.refs = {'delay': delay}
assert isinstance(delay, Delay)
delay.register_entry(self.name, time)
self.indices = indices

def update(self):
return self.delay.at(self.name, *self.indices)
return self.refs['delay'].at(self.name, *self.indices)

def reset_state(self, *args, **kwargs):
pass


def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay:
Expand Down
12 changes: 10 additions & 2 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,14 @@ def has_aft_update(self, key: Any):
def reset_bef_updates(self, *args, **kwargs):
"""Reset all before updates."""
for node in self.before_updates.values():
node.reset_state(*args, **kwargs)
if isinstance(node, DynamicalSystem):
node.reset(*args, **kwargs)

def reset_aft_updates(self, *args, **kwargs):
"""Reset all after updates."""
for node in self.after_updates.values():
node.reset_state(*args, **kwargs)
if isinstance(node, DynamicalSystem):
node.reset(*args, **kwargs)

def update(self, *args, **kwargs):
"""The function to specify the updating rule.
Expand Down Expand Up @@ -349,6 +351,12 @@ def _compatible_update(self, *args, **kwargs):
return ret
return update_fun(*args, **kwargs)

# def __getattr__(self, item):
# if item == 'update':
# return self._compatible_update # update function compatible with previous ``update()`` function
# else:
# return object.__getattribute__(self, item)

def __getattribute__(self, item):
if item == 'update':
return self._compatible_update # update function compatible with previous ``update()`` function
Expand Down

0 comments on commit 69df9da

Please sign in to comment.