Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drafting the load_model method with conversion between spk versions #646

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/schnetpack/interfaces/ase_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from schnetpack.data.loader import _atoms_collate_fn
from schnetpack.transform import CastTo32, CastTo64
from schnetpack.units import convert_units
from schnetpack.utils import load_model
from schnetpack.md.utils import activate_model_stress

from typing import Optional, List, Union, Dict
Expand Down Expand Up @@ -262,7 +263,7 @@ def _load_model(self, model_file: str) -> schnetpack.model.AtomisticModel:

log.info("Loading model from {:s}".format(model_file))
# load model and keep it on CPU, device can be changed afterwards
model = torch.load(model_file, map_location="cpu").to(torch.float64)
model = load_model(model_file, device=torch.device("cpu")).to(torch.float64)
model = model.eval()

if self.stress_key is not None:
Expand Down
3 changes: 2 additions & 1 deletion src/schnetpack/md/calculators/schnetpack_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, List, Dict, TYPE_CHECKING

import schnetpack.atomistic.response
from schnetpack.utils import load_model

if TYPE_CHECKING:
from schnetpack.md import System
Expand Down Expand Up @@ -94,7 +95,7 @@ def _load_model(self, model_file: str) -> AtomisticModel:

log.info("Loading model from {:s}".format(model_file))
# load model and keep it on CPU, device can be changed afterwards
model = torch.load(model_file, map_location="cpu").to(torch.float64)
model = load_model(model_file, device=torch.device("cpu")).to(torch.float64)
model = model.eval()

if self.stress_key is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/schnetpack/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Dict, Optional, List

import schnetpack as spk
from schnetpack.transform import Transform
import schnetpack.properties as properties
from schnetpack.utils import as_dtype
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.postprocessors = nn.ModuleList(postprocessors)
self.required_derivatives: Optional[List[str]] = None
self.model_outputs: Optional[List[str]] = None
self.spk_version = spk.__version__

def collect_derivatives(self) -> List[str]:
self.required_derivatives = None
Expand Down
1 change: 1 addition & 0 deletions src/schnetpack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .compatibility import *
import importlib
import torch
from typing import Type, Union, List
Expand Down
48 changes: 48 additions & 0 deletions src/schnetpack/utils/compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import warnings


__all__ = ["load_model"]


def load_model(model_path, device="cpu", **kwargs):
"""
Load a SchNetPack model from a Torch file, enabling compatibility with models trained using earlier versions of
SchNetPack. This function imports the old model and automatically updates it to the format used in the current
SchNetPack version. To ensure proper functionality, the Torch model object must include a version tag, such as
spk_version="2.0.4".

Args:
model_path (str): Path to the model file.
device (torch.device): Device on which the model should be loaded.
**kwargs: Additional arguments for the model loading.

Returns:
torch.nn.Module: Loaded model.
"""

def _convert_from_older(model):
model.spk_version = "2.0.4"
return model

def _convert_from_v2_0_4(model):
if not hasattr(model.representation, "electronic_embeddings"):
model.representation.electronic_embeddings = []
model.spk_version = (
"latest" # TODO: replace by latest pypi version once available
)
return model

model = torch.load(model_path, map_location=device, **kwargs)

if not hasattr(model, "spk_version"):
# make warning that model has no version information
warnings.warn(
"Model was saved without version information. Conversion to current version may fail."
)
model = _convert_from_older(model)

if model.spk_version == "2.0.4":
model = _convert_from_v2_0_4(model)

return model
Loading