From 376b1115f043db67552a73240dfa4ff1b8cf8174 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Sat, 14 Dec 2024 14:33:38 +0000 Subject: [PATCH] Hotfix: huggingface loading --- src/anemoi/inference/checkpoint.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 5a1754d..3e52130 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -48,7 +48,7 @@ class Checkpoint: """Represents an inference checkpoint.""" def __init__(self, path, *, patch_metadata=None): - self.path = path + self._path = path self.patch_metadata = patch_metadata def __repr__(self): @@ -59,17 +59,17 @@ def path(self): import json try: - self._model = json.loads(self._model) - except TypeError: + path = json.loads(self._path) + except Exception: + path = self._path + + if isinstance(path, (Path, str)): + return path + elif isinstance(path, dict): + if "huggingface" in path: + return _download_huggingfacehub(path["huggingface"]) pass - - if isinstance(self._model, str): - return self._model - elif isinstance(self._model, dict): - if "huggingface" in self._model: - return _download_huggingfacehub(self._model["huggingface"]) - pass - raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict") + raise TypeError(f"Cannot parse model path: {path}. It must be a path or dict") @cached_property def _metadata(self):