Skip to content

Commit

Permalink
drafting the load_model method with conversion between spk versions (a…
Browse files Browse the repository at this point in the history
…tomistic-machine-learning#646)

* added: load_model function for automatically converting older models to latest version

* storing spk version in model metadata
  • Loading branch information
jnsLs authored and Maltimore committed Nov 15, 2024
1 parent 2a56d3d commit 860d0dc
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/schnetpack/interfaces/ase_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from schnetpack.data.loader import _atoms_collate_fn
from schnetpack.transform import CastTo32, CastTo64
from schnetpack.units import convert_units
from schnetpack.utils import load_model
from schnetpack.md.utils import activate_model_stress

from typing import Optional, List, Union, Dict
Expand Down Expand Up @@ -261,7 +262,7 @@ def _load_model(self, model_file: str) -> schnetpack.model.AtomisticModel:

log.info("Loading model from {:s}".format(model_file))
# load model and keep it on CPU, device can be changed afterwards
model = torch.load(model_file, map_location="cpu").to(torch.float64)
model = load_model(model_file, device=torch.device("cpu")).to(torch.float64)
model = model.eval()

if self.stress_key is not None:
Expand Down
3 changes: 2 additions & 1 deletion src/schnetpack/md/calculators/schnetpack_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, List, Dict, TYPE_CHECKING

import schnetpack.atomistic.response
from schnetpack.utils import load_model

if TYPE_CHECKING:
from schnetpack.md import System
Expand Down Expand Up @@ -94,7 +95,7 @@ def _load_model(self, model_file: str) -> AtomisticModel:

log.info("Loading model from {:s}".format(model_file))
# load model and keep it on CPU, device can be changed afterwards
model = torch.load(model_file, map_location="cpu").to(torch.float64)
model = load_model(model_file, device=torch.device("cpu")).to(torch.float64)
model = model.eval()

if self.stress_key is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/schnetpack/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Dict, Optional, List

import schnetpack as spk
from schnetpack.transform import Transform
import schnetpack.properties as properties
from schnetpack.utils import as_dtype
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.postprocessors = nn.ModuleList(postprocessors)
self.required_derivatives: Optional[List[str]] = None
self.model_outputs: Optional[List[str]] = None
self.spk_version = spk.__version__

def collect_derivatives(self) -> List[str]:
self.required_derivatives = None
Expand Down
1 change: 1 addition & 0 deletions src/schnetpack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .compatibility import *
import importlib
import torch
from typing import Type, Union, List
Expand Down
48 changes: 48 additions & 0 deletions src/schnetpack/utils/compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import warnings


__all__ = ["load_model"]


def load_model(model_path, device="cpu", **kwargs):
    """
    Load a SchNetPack model from a Torch file, enabling compatibility with models trained using earlier versions of
    SchNetPack. This function imports the old model and automatically updates it to the format used in the current
    SchNetPack version. To ensure proper functionality, the Torch model object must include a version tag, such as
    spk_version="2.0.4"; untagged models are assumed to predate version tagging and converted on a best-effort basis.

    Args:
        model_path (str): Path to the model file.
        device (str or torch.device): Device on which the model should be loaded
            (anything accepted by ``torch.load``'s ``map_location``).
        **kwargs: Additional arguments passed through to ``torch.load``.

    Returns:
        torch.nn.Module: Loaded model, converted to the current version format.
    """

    def _convert_from_older(model):
        # Untagged models predate version tagging; treat them as 2.0.4 so they
        # flow through the regular conversion chain below.
        model.spk_version = "2.0.4"
        return model

    def _convert_from_v2_0_4(model):
        # presumably representations saved before electronic embeddings were
        # introduced lack this attribute — add an empty placeholder so the
        # current forward pass does not fail (TODO confirm against callers)
        if not hasattr(model.representation, "electronic_embeddings"):
            model.representation.electronic_embeddings = []
        model.spk_version = (
            "latest"  # TODO: replace by latest pypi version once available
        )
        return model

    # NOTE(security): torch.load unpickles arbitrary Python objects — only
    # load model files from trusted sources.
    model = torch.load(model_path, map_location=device, **kwargs)

    if not hasattr(model, "spk_version"):
        # stacklevel=2 points the warning at the caller's load_model() call
        # site rather than at this module.
        warnings.warn(
            "Model was saved without version information. Conversion to current version may fail.",
            stacklevel=2,
        )
        model = _convert_from_older(model)

    if model.spk_version == "2.0.4":
        model = _convert_from_v2_0_4(model)

    # Any other version tag (including "latest") passes through unchanged.
    return model

0 comments on commit 860d0dc

Please sign in to comment.