Skip to content

Commit

Permalink
Added on_step_finalize in MD simulation hooks (#390)
Browse files Browse the repository at this point in the history
* added `on_step_finalize` hook for logging etc.

* switched to new hook for logging
  • Loading branch information
mgastegger authored Feb 21, 2022
1 parent 83f3fef commit cb554ef
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/schnetpack/md/simulation_hooks/basic_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def on_step_middle(self, simulator: Simulator):
def on_step_end(self, simulator: Simulator):
pass

def on_step_finalize(self, simulator: Simulator):
pass

def on_step_failed(self, simulator: Simulator):
pass

Expand All @@ -51,7 +54,7 @@ def __init__(self, every_n_steps: int, remove_rotation: bool, wrap_positions: bo
self.remove_rotation = remove_rotation
self.wrap_positions = wrap_positions

def on_step_end(self, simulator: Simulator):
def on_step_finalize(self, simulator: Simulator):
if simulator.step % self.every_n_steps == 0:
simulator.system.remove_center_of_mass()
simulator.system.remove_translation()
Expand Down
8 changes: 4 additions & 4 deletions src/schnetpack/md/simulation_hooks/callback_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, checkpoint_file: str, every_n_steps: int):
self.every_n_steps = every_n_steps
self.checkpoint_file = checkpoint_file

def on_step_end(self, simulator: Simulator):
def on_step_finalize(self, simulator: Simulator):
"""
Store state_dict at specified intervals.
Expand Down Expand Up @@ -526,7 +526,7 @@ def on_simulation_start(self, simulator: Simulator):
# Enable single writer, multiple reader flag
self.file.swmr_mode = True

def on_step_end(self, simulator: Simulator):
def on_step_finalize(self, simulator: Simulator):
"""
Update the buffer of each stream after each specified interval and flush the buffer to the main file if full.
Expand Down Expand Up @@ -604,7 +604,7 @@ def on_simulation_start(self, simulator):
self.n_replicas = simulator.system.n_replicas
self.n_molecules = simulator.system.n_molecules

def on_step_end(self, simulator):
def on_step_finalize(self, simulator: Simulator):
"""
Routine for collecting and storing scalar properties of replicas and molecules during the simulation. Needs to
be adapted based on the properties.
Expand Down Expand Up @@ -678,7 +678,7 @@ def __init__(self, log_file: str, properties: List, every_n_steps: int = 100):

self.properties = properties

def on_step_end(self, simulator):
def on_step_finalize(self, simulator: Simulator):
"""
Log the systems properties the given intervals.
Expand Down
4 changes: 4 additions & 0 deletions src/schnetpack/md/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def simulate(self, n_steps: int):
for hook in self.simulator_hooks[::-1]:
hook.on_step_end(self)

# Logging hooks etc
for hook in self.simulator_hooks:
hook.on_step_finalize(self)

self.step += 1
self.effective_steps += 1

Expand Down

0 comments on commit cb554ef

Please sign in to comment.