Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mm): handle integer state dict keys in probe #6051

Merged
merged 2 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from .util.model_util import lora_token_vector_length, read_checkpoint_meta

CkptType = Dict[str, Any]
CkptType = Dict[str | int, Any]

LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: {
Expand Down Expand Up @@ -219,7 +219,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)

for key in ckpt.keys():
for key in [str(k) for k in ckpt.keys()]:
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_model_probe.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from pathlib import Path

import pytest
from torch import tensor

from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.config import InvalidModelConfigException
from invokeai.backend.model_manager.probe import (
CkptType,
ModelProbe,
VaeFolderProbe,
get_default_settings_controlnet_t2i_adapter,
get_default_settings_main,
Expand Down Expand Up @@ -52,3 +56,25 @@ def test_default_settings_main():
assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024
assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None
assert get_default_settings_main(BaseModelType.Any) is None


def test_probe_handles_state_dict_with_integer_keys():
# This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044
state_dict_with_integer_keys: CkptType = {
320: (
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
),
}
with pytest.raises(InvalidModelConfigException):
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
Loading