Skip to content

Commit

Permalink
tests: add test_probe_handles_state_dict_with_integer_keys
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed Mar 25, 2024
1 parent 6ad733f commit a6881ec
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_model_probe.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from pathlib import Path

import pytest
from torch import tensor

from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.config import InvalidModelConfigException
from invokeai.backend.model_manager.probe import (
CkptType,
ModelProbe,
VaeFolderProbe,
get_default_settings_controlnet_t2i_adapter,
get_default_settings_main,
Expand Down Expand Up @@ -52,3 +56,25 @@ def test_default_settings_main():
assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024
assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None
assert get_default_settings_main(BaseModelType.Any) is None


def test_probe_handles_state_dict_with_integer_keys():
# This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044
state_dict_with_integer_keys: CkptType = {
320: (
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
),
}
with pytest.raises(InvalidModelConfigException):
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)

0 comments on commit a6881ec

Please sign in to comment.