Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:ecmwf/anemoi-inference into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Dec 7, 2024
2 parents c2fa83c + 425bf7f commit 8564242
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Keep it human-readable, your future self will thank you!
- Add support for unstructured grids
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface

### Changed
- Change `write_initial_state` default value to `true`
Expand Down
12 changes: 10 additions & 2 deletions docs/configs/top-level.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,21 @@ The following options control the inference process:
checkpoint:
===========

The only compulsory option is ``checkpoint``, which specifies the path
to the checkpoint file.
The only compulsory option is ``checkpoint``, which specifies the
checkpoint file. It can be a path to a local file, or a huggingface
config.

.. code:: yaml
checkpoint: /path/to/checkpoint.ckpt
.. code:: yaml
checkpoint:
huggingface:
repo_id: "ecmwf/aifs-single"
filename: "aifs_single_v0.2.1.ckpt"
device:
=======

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ optional-dependencies.docs = [
"sphinx-rtd-theme",
]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.huggingface = [ "huggingface-hub" ]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ]

urls.Documentation = "https://anemoi-inference.readthedocs.io/"
Expand Down
30 changes: 29 additions & 1 deletion src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
LOG = logging.getLogger(__name__)


def _download_huggingfacehub(huggingface_config):
"""Download model from huggingface"""
try:
from huggingface_hub import hf_hub_download
except ImportError as e:
raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e

config_path = hf_hub_download(**huggingface_config)
return config_path


class Checkpoint:
"""Represents an inference checkpoint."""

Expand All @@ -31,13 +42,30 @@ def __init__(self, path, *, patch_metadata=None):
def __repr__(self):
return f"Checkpoint({self.path})"

@cached_property
def path(self):
import json

try:
self._model = json.loads(self._model)
except TypeError:
pass

if isinstance(self._model, str):
return self._model
elif isinstance(self._model, dict):
if "huggingface" in self._model:
return _download_huggingfacehub(self._model["huggingface"])
pass
raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict")

@cached_property
def _metadata(self):
try:
result = Metadata(*load_metadata(self.path, supporting_arrays=True))
except Exception as e:
LOG.warning("Version for not support `supporting_arrays` (%s)", e)
result= Metadata(load_metadata(self.path))
result = Metadata(load_metadata(self.path))

if self.patch_metadata:
LOG.warning("Patching metadata with %r", self.patch_metadata)
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/commands/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class RetrieveCmd(Command):
def add_arguments(self, command_parser):
command_parser.description = self.__doc__
command_parser.add_argument("config", type=str, help="Path to checkpoint")
command_parser.add_argument("--defaults", action="append", help="Sources of default values.")
command_parser.add_argument("--date", type=str, help="Date")
command_parser.add_argument("--output", type=str, help="Output file")
command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
Expand All @@ -34,7 +35,7 @@ def add_arguments(self, command_parser):

def run(self, args):

config = load_config(args.config, args.overrides)
config = load_config(args.config, args.overrides, defaults=args.defaults)

runner = DefaultRunner(config)

Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import datetime
import logging
import os
from typing import Any
from typing import Dict
from typing import Literal

import yaml
from pydantic import BaseModel
Expand All @@ -27,8 +29,8 @@ class Config:

description: str | None = None

checkpoint: str
"""A path an Anemoi checkpoint file."""
checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]]
"""A path to an Anemoi checkpoint file."""

date: str | int | datetime.datetime | None = None
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`.
Expand Down
18 changes: 1 addition & 17 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import warnings
from collections import defaultdict
from functools import cached_property
from types import MappingProxyType as frozendict
from typing import Literal

import numpy as np
Expand All @@ -37,15 +38,6 @@ def _remove_full_paths(x):
return x


class frozendict(dict):
def __deepcopy__(self, memo):
# As this is a frozendict, we can return the same object
return self

def __setitem__(self, key, value):
raise TypeError("frozendict is immutable")


class Metadata(PatchMixin, LegacyMixin):
"""An object that holds metadata of a checkpoint."""

Expand Down Expand Up @@ -835,7 +827,6 @@ def print_variable_categories(self):
for name, categories in sorted(self.variable_categories().items()):
LOG.info(f" {name:{length}} => {', '.join(categories)}")


###########################################################################

def patch(self, patch):
Expand All @@ -854,13 +845,6 @@ def merge(main, patch):
merge(self._metadata, patch)









class SourceMetadata(Metadata):
"""An object that holds metadata of a source. It is only the `dataset` and `supporting_arrays` parts of the metadata.
The rest is forwarded to the parent metadata object.
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def close(self):
path = self.archive_requests["path"]
extra = self.archive_requests.get("extra", {})
patch = self.archive_requests.get("patch", {})
indent = self.archive_requests.get("indent", None)

def _patch(r):
if self.context.config.use_grib_paramid:
Expand Down Expand Up @@ -147,4 +148,4 @@ def _patch(r):
request.update(extra)
requests.append(request)

json.dump(requests, f, indent=4)
json.dump(requests, f, indent=indent)
4 changes: 4 additions & 0 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

yield result

# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
continue

# Update tensor for next iteration

check[:] = reset
Expand Down

0 comments on commit 8564242

Please sign in to comment.