IP-Adapter Safetensor Support (#6041)
## Summary

This PR adds support for IP-Adapter safetensors files for direct use inside InvokeAI.
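
For context, the new loading path splits the flat safetensors key space into the two sub-dicts (`image_proj` and `ip_adapter`) that the existing `IPAdapter` classes already expect; the full helper is `load_ip_adapter_tensors` in `invokeai/backend/ip_adapter/ip_adapter.py` below. A minimal sketch of that logic (the function name here is illustrative, not the one in the diff):

```python
import pathlib

import safetensors.torch
import torch


def load_ip_adapter_state_dict(ckpt_path: pathlib.Path, device: str = "cpu") -> dict:
    """Split a flat IP-Adapter .safetensors file into "image_proj" / "ip_adapter" sub-dicts."""
    state_dict: dict = {"ip_adapter": {}, "image_proj": {}}

    if ckpt_path.suffix == ".safetensors":
        flat = safetensors.torch.load_file(ckpt_path, device=device)
        for key, tensor in flat.items():
            if key.startswith("image_proj."):
                state_dict["image_proj"][key.removeprefix("image_proj.")] = tensor
            elif key.startswith("ip_adapter."):
                state_dict["ip_adapter"][key.removeprefix("ip_adapter.")] = tensor
            else:
                raise RuntimeError(f"Unexpected IP-Adapter state dict key: '{key}'")
    else:
        # Diffusers-style folder layout: the weights live in ip_adapter.bin.
        state_dict = torch.load(ckpt_path / "ip_adapter.bin", map_location="cpu")

    return state_dict
```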

## Test

You can download the [Composition Adapters](https://huggingface.co/ostris/ip-composition-adapter), which weren't previously supported in Invoke, and try them out. Every other IP-Adapter model should work too.

If you pick a safetensors IP-Adapter model, you will also need to select ViT-H or ViT-G next to it. This is a raw implementation; it can be refined further based on feedback.
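
Concretely, the node resolves the encoder name from the new `clip_vision_model` field roughly as sketched below (`resolve_image_encoder_name` is a hypothetical standalone wrapper around the logic in `invokeai/app/invocations/ip_adapter.py`; `auto` only works for InvokeAI-format adapters, which record their own encoder):

```python
from typing import Optional

# ViT-H is the SD1.5 encoder, ViT-G the SDXL one (names as installed by InvokeAI).
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


def resolve_image_encoder_name(clip_vision_model: str, image_encoder_model_id: Optional[str]) -> str:
    """Return the name of the CLIP Vision model to load for a given IP-Adapter."""
    if clip_vision_model == "auto":
        # Only InvokeAI-format (diffusers) IP-Adapters carry an image_encoder_model_id;
        # raw checkpoint/safetensors models must have ViT-H or ViT-G set explicitly.
        if image_encoder_model_id is None:
            raise RuntimeError("You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models.")
        return image_encoder_model_id.split("/")[-1].strip()
    return CLIP_VISION_MODEL_MAP[clip_vision_model]
```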

Prompt: `Spiderman holding a bunny` -- the result has the exact same composition as the adapter image.

![opera_UHlo1IyXPT](https://github.com/invoke-ai/InvokeAI/assets/54517381/00bf9f0b-149f-478d-87ca-3252b68d1054)
blessedcoolant authored Apr 3, 2024
2 parents 132aadc + be574cb commit 7da04b8
Showing 19 changed files with 390 additions and 143 deletions.
75 changes: 50 additions & 25 deletions invokeai/app/invocations/ip_adapter.py
@@ -1,21 +1,22 @@
from builtins import float
from typing import List, Union
from typing import List, Literal, Union

from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)


class IPAdapterField(BaseModel):
@@ -48,20 +49,27 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""

# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)

clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="auto",
ui_order=2,
)
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
@@ -86,10 +94,21 @@ def validate_begin_end_step_percent(self) -> Self:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))

if self.clip_vision_model == "auto":
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
raise RuntimeError(
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
)
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)

return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
@@ -102,19 +121,25 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput:
)

def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)

if not len(image_encoder_models) > 0:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
)

installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1

if len(image_encoder_models) == 0:
context.logger.error("Error while fetching CLIP Vision Image Encoder")
assert len(image_encoder_models) == 1

return image_encoder_models[0]
13 changes: 2 additions & 11 deletions invokeai/app/invocations/latent.py
@@ -43,11 +43,7 @@
WithMetadata,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import (
DenoiseMaskOutput,
ImageOutput,
LatentsOutput,
)
from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -68,12 +64,7 @@
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField

13 changes: 3 additions & 10 deletions invokeai/app/invocations/metadata.py
@@ -2,16 +2,8 @@

from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -43,6 +35,7 @@ class IPAdapterMetadataField(BaseModel):

image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
106 changes: 72 additions & 34 deletions invokeai/backend/ip_adapter/ip_adapter.py
@@ -1,8 +1,11 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed

from typing import Optional, Union
import pathlib
from typing import List, Optional, TypedDict, Union

import safetensors
import safetensors.torch
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -13,10 +16,17 @@
from .resampler import Resampler


class IPAdapterStateDict(TypedDict):
ip_adapter: dict[str, torch.Tensor]
image_proj: dict[str, torch.Tensor]


class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""

def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
def __init__(
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
):
super().__init__()

self.cross_attention_dim = cross_attention_dim
@@ -25,7 +35,7 @@ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extr
self.norm = torch.nn.LayerNorm(cross_attention_dim)

@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -45,7 +55,7 @@ def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_toke
model.load_state_dict(state_dict)
return model

def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
@@ -57,7 +67,7 @@ def forward(self, image_embeds):
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""

def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
super().__init__()

self.proj = torch.nn.Sequential(
@@ -68,7 +78,7 @@ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
)

@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
"""Initialize an MLPProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -87,7 +97,7 @@ def from_state_dict(cls, state_dict: dict[torch.Tensor]):
model.load_state_dict(state_dict)
return model

def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens

Expand All @@ -97,7 +107,7 @@ class IPAdapter(RawModel):

def __init__(
self,
state_dict: dict[str, torch.Tensor],
state_dict: IPAdapterStateDict,
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -129,24 +139,27 @@ def calc_size(self):

return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""

def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@@ -157,31 +170,32 @@ def _init_image_proj_model(self, state_dict):
).to(self.device, dtype=self.dtype)

@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter Plus with full features."""

def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)


class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL."""

def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@@ -192,24 +206,48 @@ def _init_image_proj_model(self, state_dict):
).to(self.device, dtype=self.dtype)


def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}

if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")

return state_dict


def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)

if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
# IPAdapter (with ImageProjModel)
if "proj.weight" in state_dict["image_proj"]:
return IPAdapter(state_dict, device=device, dtype=dtype)
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).

# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
elif "proj_in.weight" in state_dict["image_proj"]:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).

# IPAdapterFull (with MLPProjModel)
elif "proj.0.weight" in state_dict["image_proj"]:
return IPAdapterFull(state_dict, device=device, dtype=dtype)

# Unrecognized IP Adapter Architectures
else:
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
