Skip to content

Commit

Permalink
Add support for downloading models from huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Dec 6, 2024
1 parent b1d2560 commit e34a404
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Keep it human-readable, your future self will thank you!
- Add support for unstructured grids
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface

### Changed
- Change `write_initial_state` default value to `true`
Expand Down
11 changes: 9 additions & 2 deletions docs/configs/top-level.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@ The following options control the inference process:
checkpoint:
===========

The only compulsory option is ``checkpoint``, which specifies the path
to the checkpoint file.
The only compulsory option is ``checkpoint``, which specifies the checkpoint file.
It can be a path to a local file, or a huggingface config.

.. code:: yaml

   checkpoint: /path/to/checkpoint.ckpt

.. code:: yaml

   checkpoint:
     huggingface:
       repo_id: "ecmwf/aifs-single"
       filename: "aifs_single_v0.2.1.ckpt"

device:
=======

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ optional-dependencies.docs = [
]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.huggingface = [ "huggingface_hub" ]


optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ]

Expand Down
31 changes: 29 additions & 2 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,42 @@

LOG = logging.getLogger(__name__)

def _download_huggingfacehub(huggingface_config):
    """Fetch a model file from the Hugging Face Hub.

    ``huggingface_config`` is unpacked as keyword arguments into
    ``huggingface_hub.hf_hub_download`` (e.g. ``repo_id`` and ``filename``),
    and that call's return value is handed back unchanged.

    Raises
    ------
    ImportError
        If ``huggingface_hub`` cannot be imported.
    """
    # Imported lazily so the package stays an optional dependency.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e

    return hf_hub_download(**huggingface_config)

class Checkpoint:
"""Represents an inference checkpoint."""

def __init__(self, path):
self.path = path
def __init__(self, model):
self._model = model

def __repr__(self):
    """Return ``Checkpoint(<path>)``.

    Formatting evaluates ``self.path``.
    """
    resolved = self.path
    return f"Checkpoint({resolved})"

@cached_property
def path(self):
    """Resolve the checkpoint source to a file path.

    ``self._model`` may be:

    * a plain path string (or an ``os.PathLike`` object), returned as is;
    * a dict with a ``"huggingface"`` key, whose value is passed to
      ``_download_huggingfacehub`` and the result returned;
    * a JSON-encoded string (or bytes) of such a dict, which is decoded
      and stored back on ``self._model`` before being handled as a dict.

    Raises
    ------
    TypeError
        If the source is none of the above.
    """
    import json
    import os

    model = self._model

    if isinstance(model, os.PathLike):
        return os.fspath(model)

    if isinstance(model, (str, bytes, bytearray)):
        # An ordinary path is not valid JSON: json.loads raises
        # JSONDecodeError (a ValueError), which must not escape here.
        try:
            decoded = json.loads(model)
        except ValueError:
            decoded = None

        if isinstance(decoded, dict):
            # Keep the decoded form, as the string was only a transport encoding.
            model = self._model = decoded
        elif isinstance(model, str):
            # Not a JSON object (including JSON scalars such as "123"):
            # treat the string itself as the path.
            return model

    if isinstance(model, dict) and "huggingface" in model:
        return _download_huggingfacehub(model["huggingface"])

    raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict")


@cached_property
def _metadata(self):
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import datetime
import logging
import os
from typing import Dict
from typing import Dict, Literal, Any

import yaml
from pydantic import BaseModel
Expand All @@ -27,8 +27,8 @@ class Config:

description: str | None = None

checkpoint: str
"""A path an Anemoi checkpoint file."""
checkpoint: str | Dict[Literal['huggingface'], Dict[str, Any]]
"""A path to an Anemoi checkpoint file."""

date: str | int | datetime.datetime | None = None
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`.
Expand Down

0 comments on commit e34a404

Please sign in to comment.