Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix minor bugs involving model manager handling of model paths #6024

Merged
merged 5 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@

from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
AnyModelConfig,
Expand Down Expand Up @@ -309,8 +310,10 @@ async def update_model_record(
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try:
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
record_store.update_model(key, changes=changes)
model_response: AnyModelConfig = installer.sync_model_path(key)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
Expand Down
13 changes: 13 additions & 0 deletions invokeai/app/services/model_install/model_install_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,19 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""

@abstractmethod
def sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.

Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.

May raise an UnknownModelException.
"""

@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
"""
Expand Down
17 changes: 7 additions & 10 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def _scan_models_directory(self) -> None:
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

def _sync_model_path(self, key: str) -> AnyModelConfig:
def sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.

Expand All @@ -538,16 +538,13 @@ def _sync_model_path(self, key: str) -> AnyModelConfig:
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
old_path = Path(model.path)
models_dir = self.app_config.models_path
old_path = Path(model.path).resolve()
models_dir = self.app_config.models_path.resolve()

try:
old_path.relative_to(models_dir)
if not old_path.is_relative_to(models_dir):
psychedelicious marked this conversation as resolved.
Show resolved Hide resolved
return model
except ValueError:
pass

new_path = models_dir / model.base.value / model.type.value / old_path.name
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)

if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model
Expand All @@ -559,11 +556,11 @@ def _sync_model_path(self, key: str) -> AnyModelConfig:
return model

def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
if model.resolve() in self._cached_model_paths:
psychedelicious marked this conversation as resolved.
Show resolved Hide resolved
return True
try:
id = self.register_path(model)
self._sync_model_path(id) # possibly move it to right place in `models`
self.sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/services/model_records/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ModelRecordServiceBase,
UnknownModelException,
ModelSummary,
ModelRecordChanges,
ModelRecordOrderBy,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401
Expand All @@ -17,5 +18,6 @@
"InvalidModelException",
"UnknownModelException",
"ModelSummary",
"ModelRecordChanges",
"ModelRecordOrderBy",
]
14 changes: 13 additions & 1 deletion tests/app/services/model_install/test_model_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ModelInstallServiceBase,
URLModelSource,
)
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403

Expand Down Expand Up @@ -82,6 +82,18 @@ def test_install(
assert model_record.source == embedding_file.as_posix()


def test_rename(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
new_model_record = mm2_installer.sync_model_path(key)
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")


@pytest.mark.parametrize(
"fixture_name,size,destination",
[
Expand Down
Loading