Skip to content

Commit

Permalink
fix(nodes): ip adapter uses valid ModelIdentifierField for image en…
Browse files Browse the repository at this point in the history
…coder model

- Add class method to `ModelIdentifierField` to construct the field from a model config
- Use this to construct a valid IP adapter model field
  • Loading branch information
psychedelicious authored and hipsterusername committed Mar 10, 2024
1 parent 145bb45 commit 8c2ff79
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/invocations/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput:
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField(key=image_encoder_models[0].key),
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
Expand Down
15 changes: 14 additions & 1 deletion invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType

from .baseinvocation import (
BaseInvocation,
Expand All @@ -26,6 +26,19 @@ class ModelIdentifierField(BaseModel):
description="The submodel to load, if this is a main model", default=None
)

@classmethod
def from_config(
cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
) -> "ModelIdentifierField":
return cls(
key=config.key,
hash=config.hash,
name=config.name,
base=config.base,
type=config.type,
submodel_type=submodel_type,
)


class LoRAField(BaseModel):
lora: ModelIdentifierField = Field(description="Info to load lora model")
Expand Down

0 comments on commit 8c2ff79

Please sign in to comment.