diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index d102923c2cd..17a37ffc949 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -21,10 +21,11 @@ from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_records import ( + DuplicateModelException, InvalidModelException, + ModelRecordChanges, UnknownModelException, ) -from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -309,8 +310,10 @@ async def update_model_record( """Update a model's config.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store + installer = ApiDependencies.invoker.services.model_manager.install try: - model_response: AnyModelConfig = record_store.update_model(key, changes=changes) + record_store.update_model(key, changes=changes) + model_response: AnyModelConfig = installer.sync_model_path(key) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 0c44571eeed..c44d2c3bcbc 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -468,6 +468,19 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + @abstractmethod + def sync_model_path(self, key: str) -> AnyModelConfig: + """ + Move model into the location indicated by its basetype, type and name. + + Call this after updating a model's attributes in order to move + the model's path into the location indicated by its basetype, type and + name. Applies only to models whose paths are within the root `models_dir` + directory. + + May raise an UnknownModelException. + """ + @abstractmethod def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 6d71caf28b7..a7e1fab680d 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -526,7 +526,7 @@ def _scan_models_directory(self) -> None: installed.update(self.scan_directory(models_dir)) self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") - def _sync_model_path(self, key: str) -> AnyModelConfig: + def sync_model_path(self, key: str) -> AnyModelConfig: """ Move model into the location indicated by its basetype, type and name. @@ -538,16 +538,13 @@ def _sync_model_path(self, key: str) -> AnyModelConfig: May raise an UnknownModelException. """ model = self.record_store.get_model(key) - old_path = Path(model.path) - models_dir = self.app_config.models_path + old_path = Path(model.path).resolve() + models_dir = self.app_config.models_path.resolve() - try: - old_path.relative_to(models_dir) + if not old_path.is_relative_to(models_dir): return model - except ValueError: - pass - new_path = models_dir / model.base.value / model.type.value / old_path.name + new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix) if old_path == new_path or new_path.exists() and old_path == new_path.resolve(): return model @@ -559,11 +556,11 @@ def _sync_model_path(self, key: str) -> AnyModelConfig: return model def _scan_register(self, model: Path) -> bool: - if model in self._cached_model_paths: + if model.resolve() in self._cached_model_paths: return True try: id = self.register_path(model) - self._sync_model_path(id) # possibly move it to right place in `models` + self.sync_model_path(id) # possibly move it to right place in `models` self._logger.info(f"Registered {model.name} with id {id}") self._models_installed.add(id) except DuplicateModelException: diff --git a/invokeai/app/services/model_records/__init__.py b/invokeai/app/services/model_records/__init__.py index 7f888cf1f31..4fee477466d 100644 --- a/invokeai/app/services/model_records/__init__.py +++ b/invokeai/app/services/model_records/__init__.py @@ -6,6 +6,7 @@ ModelRecordServiceBase, UnknownModelException, ModelSummary, + ModelRecordChanges, ModelRecordOrderBy, ) from .model_records_sql import ModelRecordServiceSQL # noqa F401 @@ -17,5 +18,6 @@ "InvalidModelException", "UnknownModelException", "ModelSummary", + "ModelRecordChanges", "ModelRecordOrderBy", ] diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index ad9b2bb7a86..032c1fd6a21 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -20,7 +20,7 @@ ModelInstallServiceBase, URLModelSource, ) -from invokeai.app.services.model_records import UnknownModelException +from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 @@ -82,6 +82,18 @@ def test_install( assert model_record.source == embedding_file.as_posix() +def test_rename( + mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig +) -> None: + store = mm2_installer.record_store + key = mm2_installer.install_path(embedding_file) + model_record = store.get_model(key) + assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors") + store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2"))) + new_model_record = mm2_installer.sync_model_path(key) + assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors") + + @pytest.mark.parametrize( "fixture_name,size,destination", [