diff --git a/docs/configs/top-level.rst b/docs/configs/top-level.rst index be43610..2168825 100644 --- a/docs/configs/top-level.rst +++ b/docs/configs/top-level.rst @@ -11,8 +11,9 @@ The following options control the inference process: checkpoint: =========== -The only compulsory option is ``checkpoint``, which specifies the checkpoint file. -It can be a path to a local file, or a huggingface config. +The only compulsory option is ``checkpoint``, which specifies the +checkpoint file. It can be a path to a local file, or a huggingface +config. .. code:: yaml @@ -20,7 +21,7 @@ It can be a path to a local file, or a huggingface config. .. code:: yaml - checkpoint: + checkpoint: huggingface: repo_id: "ecmwf/aifs-single" filename: "aifs_single_v0.2.1.ckpt" diff --git a/pyproject.toml b/pyproject.toml index be0a2f9..a3d2376 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,10 +68,9 @@ optional-dependencies.docs = [ "sphinx-rtd-theme", ] -optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ] -optional-dependencies.huggingface = [ "huggingface_hub" ] - +optional-dependencies.huggingface = [ "huggingface-hub" ] +optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ] optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ] urls.Documentation = "https://anemoi-inference.readthedocs.io/" diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 4b767d9..13b64f1 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -20,6 +20,7 @@ LOG = logging.getLogger(__name__) + def _download_huggingfacehub(huggingface_config): """Download model from huggingface""" try: @@ -30,6 +31,7 @@ def _download_huggingfacehub(huggingface_config): config_path = hf_hub_download(**huggingface_config) return config_path + class Checkpoint: """Represents an inference checkpoint.""" @@ -38,7 +40,7 @@ def __init__(self, model): def __repr__(self): return f"Checkpoint({self.path})" - + @cached_property def path(self): import json @@ -51,12 +53,11 @@ def path(self): if isinstance(self._model, str): return self._model elif isinstance(self._model, dict): - if 'huggingface' in self._model: - return _download_huggingfacehub(self._model['huggingface']) + if "huggingface" in self._model: + return _download_huggingfacehub(self._model["huggingface"]) pass raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict") - @cached_property def _metadata(self): try: diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 11c9198..cb8f53e 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -12,7 +12,9 @@ import datetime import logging import os -from typing import Dict, Literal, Any +from typing import Any +from typing import Dict +from typing import Literal import yaml from pydantic import BaseModel @@ -27,7 +29,7 @@ class Config: description: str | None = None - checkpoint: str | Dict[Literal['huggingface'], Dict[str, Any]] + checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]] """A path to an Anemoi checkpoint file.""" date: str | int | datetime.datetime | None = None