Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 6, 2024
1 parent e34a404 commit bfe73f0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
7 changes: 4 additions & 3 deletions docs/configs/top-level.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ The following options control the inference process:
checkpoint:
===========

The only compulsory option is ``checkpoint``, which specifies the checkpoint file.
It can be a path to a local file, or a huggingface config.
The only compulsory option is ``checkpoint``, which specifies the
checkpoint file. It can be a path to a local file, or a huggingface
config.

.. code:: yaml
checkpoint: /path/to/checkpoint.ckpt
.. code:: yaml
checkpoint:
checkpoint:
huggingface:
repo_id: "ecmwf/aifs-single"
filename: "aifs_single_v0.2.1.ckpt"
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ optional-dependencies.docs = [
"sphinx-rtd-theme",
]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.huggingface = [ "huggingface_hub" ]

optional-dependencies.huggingface = [ "huggingface-hub" ]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ]

urls.Documentation = "https://anemoi-inference.readthedocs.io/"
Expand Down
9 changes: 5 additions & 4 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

LOG = logging.getLogger(__name__)


def _download_huggingfacehub(huggingface_config):
"""Download model from huggingface"""
try:
Expand All @@ -30,6 +31,7 @@ def _download_huggingfacehub(huggingface_config):
config_path = hf_hub_download(**huggingface_config)
return config_path


class Checkpoint:
"""Represents an inference checkpoint."""

Expand All @@ -38,7 +40,7 @@ def __init__(self, model):

def __repr__(self):
return f"Checkpoint({self.path})"

@cached_property
def path(self):
import json
Expand All @@ -51,12 +53,11 @@ def path(self):
if isinstance(self._model, str):
return self._model
elif isinstance(self._model, dict):
if 'huggingface' in self._model:
return _download_huggingfacehub(self._model['huggingface'])
if "huggingface" in self._model:
return _download_huggingfacehub(self._model["huggingface"])
pass
raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict")


@cached_property
def _metadata(self):
try:
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import datetime
import logging
import os
from typing import Dict, Literal, Any
from typing import Any
from typing import Dict
from typing import Literal

import yaml
from pydantic import BaseModel
Expand All @@ -27,7 +29,7 @@ class Config:

description: str | None = None

checkpoint: str | Dict[Literal['huggingface'], Dict[str, Any]]
checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]]
"""A path to an Anemoi checkpoint file."""

date: str | int | datetime.datetime | None = None
Expand Down

0 comments on commit bfe73f0

Please sign in to comment.