diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d85b47..0fbdcae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ Keep it human-readable, your future self will thank you!
 - Add support for unstructured grids
 - Add CONTRIBUTORS.md file (#36)
 - Add sanetise command
+- Add support for Hugging Face checkpoints
 
 ### Changed
 - Change `write_initial_state` default value to `true`
diff --git a/docs/configs/top-level.rst b/docs/configs/top-level.rst
index 0e3a709..2168825 100644
--- a/docs/configs/top-level.rst
+++ b/docs/configs/top-level.rst
@@ -11,13 +11,21 @@ The following options control the inference process:
 checkpoint:
 ===========
 
-The only compulsory option is ``checkpoint``, which specifies the path
-to the checkpoint file.
+The only compulsory option is ``checkpoint``, which specifies the
+checkpoint file. It can be a path to a local file, or a ``huggingface``
+block pointing to a checkpoint on the Hugging Face Hub.
 
 .. code:: yaml
 
    checkpoint: /path/to/checkpoint.ckpt
 
+.. code:: yaml
+
+   checkpoint:
+     huggingface:
+       repo_id: "ecmwf/aifs-single"
+       filename: "aifs_single_v0.2.1.ckpt"
+
 device:
 =======
diff --git a/pyproject.toml b/pyproject.toml
index 8794e02..a3d2376 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,8 +68,9 @@ optional-dependencies.docs = [
   "sphinx-rtd-theme",
 ]
 
-optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
+optional-dependencies.huggingface = [ "huggingface-hub" ]
 
+optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
 optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ]
 
 urls.Documentation = "https://anemoi-inference.readthedocs.io/"
diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py
index fa789de..9410aca 100644
--- a/src/anemoi/inference/checkpoint.py
+++ b/src/anemoi/inference/checkpoint.py
@@ -21,6 +21,17 @@
 LOG = logging.getLogger(__name__)
 
 
+def _download_huggingfacehub(huggingface_config):
+    """Download a checkpoint from the Hugging Face Hub."""
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError as e:
+        raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e
+
+    config_path = hf_hub_download(**huggingface_config)
+    return config_path
+
+
 class Checkpoint:
     """Represents an inference checkpoint."""
 
@@ -31,13 +42,30 @@ def __init__(self, path, *, patch_metadata=None):
     def __repr__(self):
         return f"Checkpoint({self.path})"
 
+    @cached_property
+    def path(self):
+        import json
+
+        # The checkpoint may be given as a plain path, a JSON string or a dict.
+        try:
+            self._model = json.loads(self._model)
+        except (TypeError, ValueError):
+            pass
+
+        if isinstance(self._model, str):
+            return self._model
+        elif isinstance(self._model, dict):
+            if "huggingface" in self._model:
+                return _download_huggingfacehub(self._model["huggingface"])
+        raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict")
+
     @cached_property
     def _metadata(self):
         try:
             result = Metadata(*load_metadata(self.path, supporting_arrays=True))
         except Exception as e:
             LOG.warning("Version for not support `supporting_arrays` (%s)", e)
-            result= Metadata(load_metadata(self.path))
+            result = Metadata(load_metadata(self.path))
 
         if self.patch_metadata:
             LOG.warning("Patching metadata with %r", self.patch_metadata)
diff --git a/src/anemoi/inference/commands/retrieve.py b/src/anemoi/inference/commands/retrieve.py
index 9537a69..86b8542 100644
--- a/src/anemoi/inference/commands/retrieve.py
+++ b/src/anemoi/inference/commands/retrieve.py
@@ -26,6 +26,7 @@ class RetrieveCmd(Command):
     def add_arguments(self, command_parser):
         command_parser.description = self.__doc__
         command_parser.add_argument("config", type=str, help="Path to checkpoint")
+        command_parser.add_argument("--defaults", action="append", help="Sources of default values.")
         command_parser.add_argument("--date", type=str, help="Date")
         command_parser.add_argument("--output", type=str, help="Output file")
         command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
@@ -34,7 +35,7 @@ def add_arguments(self, command_parser):
 
     def run(self, args):
 
-        config = load_config(args.config, args.overrides)
+        config = load_config(args.config, args.overrides, defaults=args.defaults)
 
         runner = DefaultRunner(config)
diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py
index 3b0cfff..738d2f6 100644
--- a/src/anemoi/inference/config.py
+++ b/src/anemoi/inference/config.py
@@ -12,7 +12,9 @@
 import datetime
 import logging
 import os
+from typing import Any
 from typing import Dict
+from typing import Literal
 
 import yaml
 from pydantic import BaseModel
@@ -27,8 +29,8 @@ class Config:
 
     description: str | None = None
 
-    checkpoint: str
-    """A path an Anemoi checkpoint file."""
+    checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]]
+    """A path to an Anemoi checkpoint file."""
 
     date: str | int | datetime.datetime | None = None
     """The starting date for the forecast. If not provided, the date will depend on the selected Input object.
     If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`.
diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py
index aef45e3..58d9096 100644
--- a/src/anemoi/inference/metadata.py
+++ b/src/anemoi/inference/metadata.py
@@ -13,6 +13,7 @@
 import warnings
 from collections import defaultdict
 from functools import cached_property
+from types import MappingProxyType as frozendict
 from typing import Literal
 
 import numpy as np
@@ -37,15 +38,6 @@ def _remove_full_paths(x):
     return x
 
 
-class frozendict(dict):
-    def __deepcopy__(self, memo):
-        # As this is a frozendict, we can return the same object
-        return self
-
-    def __setitem__(self, key, value):
-        raise TypeError("frozendict is immutable")
-
-
 class Metadata(PatchMixin, LegacyMixin):
     """An object that holds metadata of a checkpoint."""
 
@@ -835,7 +827,6 @@ def print_variable_categories(self):
         for name, categories in sorted(self.variable_categories().items()):
             LOG.info(f"   {name:{length}} => {', '.join(categories)}")
 
-
     ###########################################################################
 
     def patch(self, patch):
@@ -854,13 +845,6 @@ def merge(main, patch):
 
         merge(self._metadata, patch)
 
-
-
-
-
-
-
-
 class SourceMetadata(Metadata):
     """An object that holds metadata of a source. It is only the `dataset` and `supporting_arrays` parts of the metadata.
    The rest is forwarded to the parent metadata object.
diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py
index 0d12066..ff99421 100644
--- a/src/anemoi/inference/outputs/gribfile.py
+++ b/src/anemoi/inference/outputs/gribfile.py
@@ -117,6 +117,7 @@ def close(self):
             path = self.archive_requests["path"]
             extra = self.archive_requests.get("extra", {})
             patch = self.archive_requests.get("patch", {})
+            indent = self.archive_requests.get("indent", None)
 
             def _patch(r):
                 if self.context.config.use_grib_paramid:
@@ -147,4 +148,4 @@ def _patch(r):
                 request.update(extra)
                 requests.append(request)
 
-            json.dump(requests, f, indent=4)
+            json.dump(requests, f, indent=indent)
diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py
index 9ff0d2a..da60915 100644
--- a/src/anemoi/inference/runner.py
+++ b/src/anemoi/inference/runner.py
@@ -297,6 +297,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
 
             yield result
 
+            # No need to prepare next input tensor if we are at the last step
+            if s == steps - 1:
+                continue
+
             # Update tensor for next iteration
             check[:] = reset
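
Note on the new `huggingface` checkpoint option: the mapping given under `huggingface:` in the config is passed straight through to `hf_hub_download` (via `hf_hub_download(**huggingface_config)` in `_download_huggingfacehub` above), so any keyword argument accepted by that function can be supplied. A minimal sketch of the equivalent call, assuming `huggingface-hub` is installed; the repository and file names are taken from the documentation example above:

    # Sketch only: mirrors _download_huggingfacehub(), which unpacks the
    # config's `huggingface:` mapping directly into hf_hub_download(**...).
    # Extra arguments such as `revision` or `cache_dir` could be added
    # under the same block in the YAML config.
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(
        repo_id="ecmwf/aifs-single",
        filename="aifs_single_v0.2.1.ckpt",
    )
    print(checkpoint_path)  # local cached path used as the checkpoint file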