From d7007fcb949409482daf371626fde009ed9a2159 Mon Sep 17 00:00:00 2001 From: Zeyuan Tang Date: Wed, 17 Mar 2021 11:04:49 +0100 Subject: [PATCH] Specifying devices when loading system state Making sure all tensors are on the same device during a MD simulation --- src/schnetpack/md/parsers/md_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/md/parsers/md_setup.py b/src/schnetpack/md/parsers/md_setup.py index 9246182ad..f4fb1df3c 100644 --- a/src/schnetpack/md/parsers/md_setup.py +++ b/src/schnetpack/md/parsers/md_setup.py @@ -86,11 +86,11 @@ def _build_simulator(self): # If requested, read restart data if self.restart and (self.restart is not None): - state_dict = torch.load(self.restart) + state_dict = torch.load(self.restart, map_location=self.device) simulator.restart_simulation(state_dict, soft=False) logging.info(f"Restarting simulation from {self.restart}...") elif self.load_system_state: - state_dict = torch.load(self.load_system_state) + state_dict = torch.load(self.load_system_state, map_location=self.device) simulator.load_system_state(state_dict) logging.info(f"Loaded system state from {self.load_system_state}...")