Skip to content

Commit

Permalink
add model is None in the input options (#120)
Browse files Browse the repository at this point in the history
* add model is None in the input options

---------

Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
  • Loading branch information
federicazanca and ElliottKasoar authored May 23, 2024
1 parent 3ffb47d commit 32478fe
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
4 changes: 4 additions & 0 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,16 @@ def prepare_for_submission(

model_path = None
if "model" in self.inputs:
# Raise error if model is None
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
else:
if "config" in self.inputs and "model" in self.inputs.config:
model_path = None
else:
if "arch" in self.inputs:
# if model is not given (which is different than it being None)
model_path = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
Expand Down
12 changes: 9 additions & 3 deletions aiida_mlip/helpers/help_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ def load_model(
Load a model from a file path or URL.
If the string represents a file path, the model will be loaded from that path.
Otherwise, the model will be downloaded from the specified location.
If it's a URL, the model will be downloaded from the specified location.
If the input model is None it returns a default model corresponding to the
default used in the Calcjobs.
Parameters
----------
model : Optional[Union[str, Path]]
Model file path or a URL for downloading the model.
Model file path or a URL for downloading the model or None to use the default.
architecture : str
The architecture of the model.
cache_dir : Optional[Union[str, Path]]
Expand All @@ -40,7 +42,11 @@ def load_model(
The loaded model.
"""
if model is None:
loaded_model = None
loaded_model = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
cache_dir=cache_dir,
)
elif (file_path := Path(model)).is_file():
loaded_model = ModelData.local_file(file_path, architecture=architecture)
else:
Expand Down

0 comments on commit 32478fe

Please sign in to comment.