diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index 3c607015d92ea..ac493a059a013 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -192,23 +192,19 @@ async def load(model: InferenceModel) -> InferenceModel:
         return model
 
     def _load(model: InferenceModel) -> InferenceModel:
+        # model.load() increments load_attempts; give up once the first try and one retry have both been used.
+        if model.load_attempts > 1:
+            raise HTTPException(500, f"Failed to load model '{model.model_name}'")
         with lock:
             model.load()
         return model
 
     try:
-        await run(_load, model)
-        return model
+        return await run(_load, model)
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
-        log.warning(
-            (
-                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
-                "Clearing cache and retrying."
-            )
-        )
+        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
         model.clear_cache()
-        await run(_load, model)
-        return model
+        return await run(_load, model)
 
 
 async def idle_shutdown_task() -> None:
diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py
index f64a873010115..4ad6fd6eb7049 100644
--- a/machine-learning/app/models/base.py
+++ b/machine-learning/app/models/base.py
@@ -31,6 +31,7 @@ def __init__(
         **model_kwargs: Any,
     ) -> None:
         self.loaded = False
+        self.load_attempts = 0
         self.model_name = clean_name(model_name)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
         self.providers = providers if providers is not None else self.providers_default
@@ -48,9 +49,12 @@ def download(self) -> None:
     def load(self) -> None:
         if self.loaded:
             return
+        self.load_attempts += 1
 
         self.download()
-        log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
+        # load_attempts was incremented above, so it is 1 on the first attempt and 2 on the first retry.
+        attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
+        log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
         self.session = self._load()
         self.loaded = True
 
diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index d9d1455bd1cda..2068c7a4c6856 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -11,6 +11,7 @@
 import numpy as np
 import onnxruntime as ort
 import pytest
+from fastapi import HTTPException
 from fastapi.testclient import TestClient
 from PIL import Image
 from pytest import MonkeyPatch
@@ -627,6 +628,7 @@ class TestLoad:
     async def test_load(self) -> None:
         mock_model = mock.Mock(spec=InferenceModel)
         mock_model.loaded = False
+        mock_model.load_attempts = 0
 
         res = await load(mock_model)
 
@@ -650,6 +652,7 @@ async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
         mock_model.model_task = ModelTask.SEARCH
         mock_model.load.side_effect = [OSError, None]
         mock_model.loaded = False
+        mock_model.load_attempts = 0
 
         res = await load(mock_model)
 
@@ -657,6 +660,21 @@ async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
         mock_model.clear_cache.assert_called_once()
         assert mock_model.load.call_count == 2
 
+    async def test_load_raises_without_clearing_cache_if_already_retried(self) -> None:
+        """A model whose load_attempts already exceeds 1 raises without loading or clearing the cache."""
+        mock_model = mock.Mock(spec=InferenceModel)
+        mock_model.model_name = "test_model_name"
+        mock_model.model_type = ModelType.VISUAL
+        mock_model.model_task = ModelTask.SEARCH
+        mock_model.loaded = False
+        mock_model.load_attempts = 2
+
+        with pytest.raises(HTTPException):
+            await load(mock_model)
+
+        mock_model.clear_cache.assert_not_called()
+        mock_model.load.assert_not_called()
+
 
 @pytest.mark.skipif(
     not settings.test_full,