From a6881ec885869a9dddbab57cae94ee1fa2e25bf0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 26 Mar 2024 08:56:12 +1100 Subject: [PATCH] tests: add test_probe_handles_state_dict_with_integer_keys --- tests/test_model_probe.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 53720585ef7..8be7089cf52 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -1,9 +1,13 @@ from pathlib import Path import pytest +from torch import tensor from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant +from invokeai.backend.model_manager.config import InvalidModelConfigException from invokeai.backend.model_manager.probe import ( + CkptType, + ModelProbe, VaeFolderProbe, get_default_settings_controlnet_t2i_adapter, get_default_settings_main, @@ -52,3 +56,25 @@ def test_default_settings_main(): assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024 assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None assert get_default_settings_main(BaseModelType.Any) is None + + +def test_probe_handles_state_dict_with_integer_keys(): + # This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044 + state_dict_with_integer_keys: CkptType = { + 320: ( + { + "linear1.weight": tensor([1.0]), + "linear1.bias": tensor([1.0]), + "linear2.weight": tensor([1.0]), + "linear2.bias": tensor([1.0]), + }, + { + "linear1.weight": tensor([1.0]), + "linear1.bias": tensor([1.0]), + "linear2.weight": tensor([1.0]), + "linear2.bias": tensor([1.0]), + }, + ), + } + with pytest.raises(InvalidModelConfigException): + ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)