forked from atomistic-machine-learning/schnetpack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
drafting the load_model method with conversion between spk versions (a…
…tomistic-machine-learning#646) * added: load_model function for automatically converting older models to latest version * storing spk version in model metadata
- Loading branch information
Showing
5 changed files
with
55 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .compatibility import * | ||
import importlib | ||
import torch | ||
from typing import Type, Union, List | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import warnings | ||
|
||
|
||
__all__ = ["load_model"] | ||
|
||
|
||
def load_model(model_path, device="cpu", **kwargs): | ||
""" | ||
Load a SchNetPack model from a Torch file, enabling compatibility with models trained using earlier versions of | ||
SchNetPack. This function imports the old model and automatically updates it to the format used in the current | ||
SchNetPack version. To ensure proper functionality, the Torch model object must include a version tag, such as | ||
spk_version="2.0.4". | ||
Args: | ||
model_path (str): Path to the model file. | ||
device (torch.device): Device on which the model should be loaded. | ||
**kwargs: Additional arguments for the model loading. | ||
Returns: | ||
torch.nn.Module: Loaded model. | ||
""" | ||
|
||
def _convert_from_older(model): | ||
model.spk_version = "2.0.4" | ||
return model | ||
|
||
def _convert_from_v2_0_4(model): | ||
if not hasattr(model.representation, "electronic_embeddings"): | ||
model.representation.electronic_embeddings = [] | ||
model.spk_version = ( | ||
"latest" # TODO: replace by latest pypi version once available | ||
) | ||
return model | ||
|
||
model = torch.load(model_path, map_location=device, **kwargs) | ||
|
||
if not hasattr(model, "spk_version"): | ||
# make warning that model has no version information | ||
warnings.warn( | ||
"Model was saved without version information. Conversion to current version may fail." | ||
) | ||
model = _convert_from_older(model) | ||
|
||
if model.spk_version == "2.0.4": | ||
model = _convert_from_v2_0_4(model) | ||
|
||
return model |