From f291fad1db29142d68b6da090e7861bc004470a8 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Thu, 21 Dec 2023 18:33:17 +0100 Subject: [PATCH 1/9] Add support for IPAdapter FaceID --- src/diffusers/loaders/ip_adapter.py | 33 ++- src/diffusers/loaders/unet.py | 80 +++-- src/diffusers/models/attention_processor.py | 280 ++++++++++++++++++ src/diffusers/models/embeddings.py | 13 +- .../pipeline_stable_diffusion.py | 16 +- 5 files changed, 388 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 158bde436374..a8e5fa0cad31 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Dict, Union +from typing import Dict, Optional, Union import torch from huggingface_hub.utils import validate_hf_hub_args @@ -34,6 +34,8 @@ from ..models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + LoRAIPAdapterAttnProcessor, + LoRAIPAdapterAttnProcessor2_0, ) logger = logging.get_logger(__name__) @@ -46,8 +48,8 @@ class IPAdapterMixin: def load_ip_adapter( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - subfolder: str, weight_name: str, + subfolder: Optional[str] = None, **kwargs, ): """ @@ -135,14 +137,15 @@ def load_ip_adapter( # load CLIP image encoer here if it has not been registered to the pipeline yet if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: if not isinstance(pretrained_model_name_or_path_or_dict, dict): - logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - pretrained_model_name_or_path_or_dict, - subfolder=os.path.join(subfolder, "image_encoder"), - ).to(self.device, dtype=self.dtype) - self.image_encoder = image_encoder - else: - raise ValueError("`image_encoder` cannot be None when using IP Adapters.") + try: + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=os.path.join(subfolder, "image_encoder"), + ).to(self.device, dtype=self.dtype) + self.image_encoder = image_encoder + except TypeError: + print("IPAdapter: `subfolder` not found, `image_encoder` is None, use image_embeds.") # create feature extractor if it has not been registered to the pipeline yet if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: @@ -153,5 +156,13 @@ def load_ip_adapter( def set_ip_adapter_scale(self, scale): for attn_processor in self.unet.attn_processors.values(): - if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + if isinstance( + attn_processor, + ( + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + LoRAIPAdapterAttnProcessor, + LoRAIPAdapterAttnProcessor2_0, + ), + ): attn_processor.scale = scale diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7dec43571b1c..9e284d89a349 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -684,13 +684,20 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value - elif "proj.3.weight" in state_dict: + elif "proj.0.weight" in 
state_dict: # IP-Adapter Full - clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] - cross_attention_dim = state_dict["proj.3.weight"].shape[0] + clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in + norm_layer = "proj.3.weight" if "proj.3.weight" in state_dict else "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim image_projection = MLPProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, ) for key, value in state_dict.items(): @@ -744,14 +751,24 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAIPAdapterAttnProcessor, + LoRAIPAdapterAttnProcessor2_0, ) + use_lora = False if "proj.weight" in state_dict["image_proj"]: # IP-Adapter num_image_text_embeds = 4 - elif "proj.3.weight" in state_dict["image_proj"]: + elif "proj.0.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token + for k in state_dict["ip_adapter"].keys(): + if "lora" in k: + num_image_text_embeds = 4 + use_lora = True + break else: # IP-Adapter Plus num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] @@ -774,20 +791,47 @@ def _load_ip_adapter_weights(self, state_dict): block_id = int(name[len("down_blocks.")]) hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: - attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor - ) - attn_procs[name] = attn_processor_class() + if use_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=128, + ).to(self.device, dtype=self.dtype) + else: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() else: - attn_processor_class = ( - IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor - ) - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=num_image_text_embeds, - ).to(dtype=self.dtype, device=self.device) + if use_lora: + attn_processor_class = ( + LoRAIPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else LoRAIPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=128, + num_tokens=num_image_text_embeds, + ).to(dtype=self.dtype, device=self.device) + + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + ).to(dtype=self.dtype, device=self.device) value_dict = {} for k, w in 
attn_procs[name].state_dict().items(): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 23a3e2bb3791..03b84776b6b5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2344,11 +2344,289 @@ def __call__( return hidden_states +class LoRAIPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = 
attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAdapterAttnProcessor2_0(nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, 
head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor, + LoRAIPAdapterAttnProcessor, + LoRAIPAdapterAttnProcessor2_0, ) ADDED_KV_ATTENTION_PROCESSORS = ( @@ -2369,6 +2647,8 @@ def __call__( LoRAXFormersAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + LoRAIPAdapterAttnProcessor, + LoRAIPAdapterAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index db68591bdb44..39ab7b07a64f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -462,15 +462,22 @@ def forward(self, image_embeds: torch.FloatTensor): class MLPProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() from .attention import FeedForward - self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") self.norm = nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds: torch.FloatTensor): - return self.norm(self.ff(image_embeds)) + if self.num_tokens == 1: + return self.norm(self.ff(image_embeds)) + else: + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) class CombinedTimestepLabelEmbeddings(nn.Module): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b05d0b17dd5a..d24900e4e39e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -785,6 +785,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -836,6 +837,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. @@ -943,7 +946,14 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - if ip_adapter_image is not None: + if image_embeds is not None: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to( + device=device, dtype=prompt_embeds.dtype + ) + negative_image_embeds = torch.zeros_like(image_embeds) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + elif ip_adapter_image is not None: output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True image_embeds, negative_image_embeds = self.encode_image( ip_adapter_image, device, num_images_per_prompt, output_hidden_state @@ -971,7 +981,9 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None or image_embeds is not None else None + ) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None From 40f626d5fb85aeb2d1429add195989cd4873287b Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 26 Dec 2023 10:15:48 +0100 Subject: [PATCH 2/9] Add docs --- .../en/using-diffusers/loading_adapters.md | 55 ++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index d9d4a675dd37..853a7905d849 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -485,7 +485,7 @@ image.save("sdxl_t2i.png") -You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations. +You can use the IP-Adapter face models to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations. Weights are loaded with the same method used for the other IP-Adapters. ```python @@ -495,7 +495,7 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a -It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model. +It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face models. @@ -549,6 +549,57 @@ image = pipeline( +IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by third-party software, so no image encoder needs to be loaded. +You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. 
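+The snippet below assumes `insightface` and its dependencies are installed; it extracts a face embedding with insightface's `FaceAnalysis` and passes it to the pipeline as `image_embeds`.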
+ +``` +import cv2 +from insightface.app import FaceAnalysis +import numpy as np +from PIL import Image +import torch +from diffusers import StableDiffusionPipeline, DDIMScheduler +from diffusers.utils import load_image + +noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1 +) + +pipeline = StableDiffusionPipeline.from_pretrained( + "SG161222/Realistic_Vision_V4.0_noVAE", + torch_dtype=torch.float16, + scheduler=noise_scheduler +).to("cuda") + +generator = torch.Generator(device="cpu").manual_seed(42) +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") + +# Extract image embeddings +app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(640, 640)) + +image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) +faces = app.get(image) +image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + +# Load IP Adapter weights and run inference +pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", weight_name="ip-adapter-faceid_sd15.bin") +pipeline.set_ip_adapter_scale(0.7) +images = pipeline( + prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", + image_embeds=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=20, num_images_per_prompt=1, width=512, height=704, + generator=generator +).images[0] +``` + ### LCM-Lora You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights. From db6550a228941b538f340fb5b65ed16c43a21b88 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 26 Dec 2023 10:43:43 +0100 Subject: [PATCH 3/9] Move subfolder to kwargs --- src/diffusers/loaders/ip_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index a8e5fa0cad31..df9caa9465d7 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -49,7 +49,6 @@ def load_ip_adapter( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], weight_name: str, - subfolder: Optional[str] = None, **kwargs, ): """ @@ -97,6 +96,7 @@ def load_ip_adapter( local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) user_agent = { "file_type": "attn_procs_weights", From 6c29e66eb023f2805e4a4fd697815e9a0d2c6468 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 26 Dec 2023 10:45:47 +0100 Subject: [PATCH 4/9] Fix quality --- src/diffusers/loaders/ip_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index df9caa9465d7..0c310019f024 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Dict, Optional, Union +from typing import Dict, Union import torch from huggingface_hub.utils import validate_hf_hub_args From f4141acbe6e52e742c1d4c79b462776eb353cc99 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 3 Jan 2024 09:45:48 +0100 Subject: [PATCH 5/9] Fix image encoder loading --- docs/source/en/using-diffusers/loading_adapters.md | 5 +++-- src/diffusers/loaders/ip_adapter.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index 853a7905d849..f0fb3bb19490 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -549,7 +549,8 @@ image = pipeline( -IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by third-party software, so no image encoder needs to be loaded. +IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. +You need to install `insightface` and all its requirements to use this model. You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. ``` @@ -589,7 +590,7 @@ faces = app.get(image) image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) # Load IP Adapter weights and run inference -pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", weight_name="ip-adapter-faceid_sd15.bin") +pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin") pipeline.set_ip_adapter_scale(0.7) images = pipeline( prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index a85db3455c13..46a3e7dc62b3 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -48,6 +48,7 @@ class IPAdapterMixin: def load_ip_adapter( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + subfolder: str, weight_name: str, **kwargs, ): @@ -96,7 +97,6 @@ def load_ip_adapter( local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) user_agent = { "file_type": "attn_procs_weights", @@ -136,7 +136,7 @@ def load_ip_adapter( # load CLIP image encoder here if it has not been registered to the pipeline yet if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: - if not isinstance(pretrained_model_name_or_path_or_dict, dict): + if not isinstance(pretrained_model_name_or_path_or_dict, dict) and subfolder is not None: logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") image_encoder = CLIPVisionModelWithProjection.from_pretrained( pretrained_model_name_or_path_or_dict, @@ -144,6 +144,12 @@ def load_ip_adapter( ).to(self.device, dtype=self.dtype) self.image_encoder = image_encoder self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) + elif subfolder is None: + logger.warning( + "Cannot load an image encoder because `subfolder` is None. " + "If you do not load an image_encoder, you must extract the" + " image embeddings from the input image and pass them as image_embeds to the pipeline." 
+ ) else: raise ValueError("`image_encoder` cannot be None when using IP Adapters.") From f2a952e21f6eff4d85e312be93d88d4a41c6d8a3 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 6 Jan 2024 12:31:14 +0100 Subject: [PATCH 6/9] Fix loading + add test --- src/diffusers/loaders/ip_adapter.py | 5 +- src/diffusers/loaders/unet.py | 76 +++++++++++++++---- src/diffusers/models/unet_2d_condition.py | 9 ++- .../test_ip_adapter_stable_diffusion.py | 19 +++++ 4 files changed, 89 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 46a3e7dc62b3..12d11bbb1a5e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -146,9 +146,10 @@ def load_ip_adapter( self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) elif subfolder is None: logger.warning( - "Cannot load an image encoder because `subfolder` is None. " - "If you do not load an image_encoder, you must extract the" + "Cannot load an image encoder because `subfolder` is None." + " If you do not load an image_encoder, you must extract the" " image embeddings from the input image and pass them as image_embeds to the pipeline." + " This behaviour is intended only for the IP Adapter FaceID model." ) else: raise ValueError("`image_encoder` cannot be None when using IP Adapters.") diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index aa6cb77bb12e..7660436ea9c6 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -774,11 +774,10 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - LoRAAttnProcessor, - LoRAAttnProcessor2_0, LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0, ) + from ..models.lora import LoRALinearLayer use_lora = False if "proj.weight" in state_dict["image_proj"]: @@ -802,7 +801,7 @@ def _load_ip_adapter_weights(self, state_dict): # set ip-adapter cross-attention processors & load state_dict attn_procs = {} - key_id = 1 + key_id = 1 if not use_lora else 0 for name in self.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim if name.startswith("mid_block"): @@ -814,22 +813,67 @@ def _load_ip_adapter_weights(self, state_dict): block_id = int(name[len("down_blocks.")]) hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() if use_lora: - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] + attn_module = self + for n in name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + # Set the `lora_layer` attribute of the attention-related matrices. 
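+                    # FaceID checkpoints ship LoRA down/up weights for the to_q/to_k/to_v/to_out projections
+                    # alongside the IP-Adapter weights, so the LoRA rank above is read from the checkpoint
+                    # rather than hard-coded.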
+ attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=rank, + ) ) - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=128, - ).to(self.device, dtype=self.dtype) - else: - attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=rank, + ) ) - attn_procs[name] = attn_processor_class() + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + + value_dict = {} + for k, module in attn_module.named_children(): + index = "." + if not hasattr(module, "set_lora_layer"): + index = ".0." + module = module[0] + lora_layer = getattr(module, "lora_layer") + for lora_name, w in lora_layer.state_dict().items(): + value_dict.update( + { + f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ + f"{key_id}.{k}_lora.{lora_name}" + ] + } + ) + + attn_module.load_state_dict(value_dict, strict=False) + attn_module.to(dtype=self.dtype, device=self.device) + key_id += 1 else: if use_lora: + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] attn_processor_class = ( LoRAIPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") @@ -839,7 +883,7 @@ def _load_ip_adapter_weights(self, state_dict): hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, - rank=128, + rank=rank, num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) @@ -861,7 +905,7 @@ def _load_ip_adapter_weights(self, state_dict): value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) attn_procs[name].load_state_dict(value_dict) - key_id += 2 + key_id += 2 if not use_lora else 1 self.set_attn_processor(attn_procs) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 7b4f9f5594ea..d72eff04e505 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config @@ -28,7 +29,9 @@ Attention, AttentionProcessor, AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, AttnProcessor, + AttnProcessor2_0, ) from .embeddings import ( GaussianFourierProjection, @@ -682,9 +685,11 @@ def set_default_attn_processor(self): Disables custom attention processors and sets the default attention implementation. 
""" if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() + processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 289d2b7d6573..e0bb3ad176f5 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -248,6 +248,25 @@ def test_unload(self): ] assert processors == [True] * len(processors) + def test_unload_faceid(self): + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin") + pipeline.set_ip_adapter_scale(0.7) + + pipeline.unload_ip_adapter() + pipeline.unload_lora_weights() + + assert getattr(pipeline, "image_encoder") is None + assert getattr(pipeline, "feature_extractor") is None + processors = [ + isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0)) + for name, attn_proc in pipeline.unet.attn_processors.items() + ] + assert processors == [True] * len(processors) + @slow @require_torch_gpu From 529e968d7206413ff477ea42dc13624a68c52b9f Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 10 Jan 2024 20:58:47 +0100 Subject: [PATCH 7/9] Move to community folder --- .../en/using-diffusers/loading_adapters.md | 56 +- examples/community/README.md | 63 +- examples/community/ip_adapter_face_id.py | 1525 +++++++++++++++++ src/diffusers/loaders/ip_adapter.py | 21 +- src/diffusers/loaders/unet.py | 120 +- src/diffusers/models/attention_processor.py | 280 --- src/diffusers/models/embeddings.py | 13 +- src/diffusers/models/unet_2d_condition.py | 9 +- .../pipeline_stable_diffusion.py | 16 +- src/diffusers/utils/constants.py | 3 +- .../test_ip_adapter_stable_diffusion.py | 19 - 11 files changed, 1615 insertions(+), 510 deletions(-) create mode 100644 examples/community/ip_adapter_face_id.py diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index f0fb3bb19490..d9d4a675dd37 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -485,7 +485,7 @@ image.save("sdxl_t2i.png") -You can use the IP-Adapter face models to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations. +You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations. Weights are loaded with the same method used for the other IP-Adapters. ```python @@ -495,7 +495,7 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a -It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face models. 
+It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model. @@ -549,58 +549,6 @@ image = pipeline( -IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. -You need to install `insightface` and all its requirements to use this model. -You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. - -``` -import cv2 -from insightface.app import FaceAnalysis -import numpy as np -from PIL import Image -import torch -from diffusers import StableDiffusionPipeline, DDIMScheduler -from diffusers.utils import load_image - -noise_scheduler = DDIMScheduler( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1 -) - -pipeline = StableDiffusionPipeline.from_pretrained( - "SG161222/Realistic_Vision_V4.0_noVAE", - torch_dtype=torch.float16, - scheduler=noise_scheduler -).to("cuda") - -generator = torch.Generator(device="cpu").manual_seed(42) -image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") - -# Extract image embeddings -app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) -app.prepare(ctx_id=0, det_size=(640, 640)) - -image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) -faces = app.get(image) -image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) - -# Load IP Adapter weights and run inference -pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin") -pipeline.set_ip_adapter_scale(0.7) -images = pipeline( - prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", - image_embeds=image, - negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", - num_inference_steps=20, num_images_per_prompt=1, width=512, height=704, - generator=generator -).images[0] -``` - ### LCM-Lora You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights. diff --git a/examples/community/README.md b/examples/community/README.md index 3baab2025880..7c19d5fe3f44 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -56,7 +56,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) | | DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) | | Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. 
| [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) | - +| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -3186,4 +3186,63 @@ pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, t inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps) pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg") -``` \ No newline at end of file +``` + +### IP Adapter Face ID +IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. +You need to install `insightface` and all its requirements to use this model. +You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. +You have to disable PEFT BACKEND in order to load weights. + +```py +import diffusers +diffusers.utils.USE_PEFT_BACKEND = False +import torch +from diffusers.utils import load_image +import cv2 +import numpy as np +from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler +from insightface.app import FaceAnalysis + + +noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, +) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16) +pipeline = DiffusionPipeline.from_pretrained( + "SG161222/Realistic_Vision_V4.0_noVAE", + torch_dtype=torch.float16, + scheduler=noise_scheduler, + vae=vae, + custom_pipeline="./forked/diffusers/examples/community/ip_adapter_face_id.py" +) +pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin") +pipeline.to("cuda") + +generator = torch.Generator(device="cpu").manual_seed(42) +num_images=2 + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") + +app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) +app.prepare(ctx_id=0, det_size=(640, 640)) +image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) +faces = app.get(image) +image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) +images = pipeline( + prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", + image_embeds=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704, + generator=generator +).images + +for i in range(num_images): + images[i].save(f"c{i}.png") +``` diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py new file mode 100644 index 000000000000..e3c5a2c84ee0 --- /dev/null +++ b/examples/community/ip_adapter_face_id.py @@ -0,0 +1,1525 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +from safetensors import safe_open + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import FusedAttnProcessor2_0 +from diffusers.models.lora import adjust_lora_scale_text_encoder, LoRALinearLayer +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + _get_model_file, + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LoRAIPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = 
hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAdapterAttnProcessor2_0(nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + lora_scale (`float`, defaults to 1.0): + the weight scale of LoRA. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = 
key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from diffusers.models.attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class IPAdapterFaceIDStableDiffusionPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer ([`~transformers.CLIPTokenizer`]):
+            A `CLIPTokenizer` to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A `UNet2DConditionModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + self._load_ip_adapter_weights(state_dict) + + def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): + updated_state_dict = {} + clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + + image_projection = IPAdapterFullImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value + + image_projection.load_state_dict(updated_state_dict) + return image_projection + + def _load_ip_adapter_weights(self, state_dict): + from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + ) + + num_image_text_embeds = 4 + + self.unet.encoder_hid_proj = None + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 0 + for name in self.unet.attn_processors.keys(): + 
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet.config.block_out_channels[block_id] + if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] + attn_module = self.unet + for n in name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=rank, + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + + value_dict = {} + for k, module in attn_module.named_children(): + index = "." + if not hasattr(module, "set_lora_layer"): + index = ".0." 
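+                        # `to_out` is an `nn.ModuleList([Linear, Dropout])` and does not expose
+                        # `set_lora_layer` itself, so descend into its first entry and account for
+                        # the extra ".0." segment in the state-dict key built below.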
+ module = module[0] + lora_layer = getattr(module, "lora_layer") + for lora_name, w in lora_layer.state_dict().items(): + value_dict.update( + { + f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ + f"{key_id}.{k}_lora.{lora_name}" + ] + } + ) + + attn_module.load_state_dict(value_dict, strict=False) + attn_module.to(dtype=self.dtype, device=self.device) + key_id += 1 + else: + rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] + attn_processor_class = ( + LoRAIPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else LoRAIPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=rank, + num_tokens=num_image_text_embeds, + ).to(dtype=self.dtype, device=self.device) + + value_dict = {} + for k, w in attn_procs[name].state_dict().items(): + value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + + attn_procs[name].load_state_dict(value_dict) + key_id += 1 + + self.unet.set_attn_processor(attn_procs) + + # convert IP-Adapter Image Projection layers to diffusers + image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + + self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.unet.config.encoder_hid_dim_type = "ip_image_proj" + + def set_ip_adapter_scale(self, scale): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for attn_processor in unet.attn_processors.values(): + if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)): + attn_processor.scale = scale + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
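+        # The legacy contract returned a single tensor with the negative (unconditional)
+        # embeddings concatenated in front of the positive ones, which is reproduced below.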
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
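+
+        Example (illustrative sketch; assumes `pipe` is an instance of this pipeline already placed on the
+        desired device):
+
+        ```py
+        >>> prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+        ...     prompt="a photo of an astronaut",
+        ...     device=pipe._execution_device,
+        ...     num_images_per_prompt=1,
+        ...     do_classifier_free_guidance=True,
+        ...     negative_prompt="blurry, low quality",
+        ... )
+        ```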
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
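+                # For example, `clip_skip=1` with a 12-layer encoder selects
+                # `hidden_states[-2]`, i.e. the output of the penultimate layer.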
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
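+    # As a concrete illustration, with `guidance_scale = 7.5` the denoising loop below computes
+    #   noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
+    # while `guidance_scale = 1` reduces this to the purely text-conditional prediction.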
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            image_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated image embeddings (for IP-Adapter FaceID, the face embeddings extracted from the
+                reference image).
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
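+
+        Example (illustrative sketch; `pipe` is assumed to be an instance of this pipeline, and the Hub
+        repository, weight file name, and insightface settings below are assumptions that may need adapting):
+
+        ```py
+        >>> import cv2
+        >>> import torch
+        >>> from insightface.app import FaceAnalysis
+
+        >>> # extract a face embedding from a reference image (assumed local path)
+        >>> app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
+        >>> app.prepare(ctx_id=0, det_size=(640, 640))
+        >>> faces = app.get(cv2.imread("face.jpg"))
+        >>> image_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
+
+        >>> pipe.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
+        >>> image = pipe(prompt="a photo of a person", image_embeds=image_embeds).images[0]
+        ```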
+ """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if image_embeds is not None: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to( + device=device, dtype=prompt_embeds.dtype + ) + negative_image_embeds = torch.zeros_like(image_embeds) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 6.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
+
+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                0
+            ]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, 
nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 12d11bbb1a5e..039b6b910a57 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -34,8 +34,6 @@ from ..models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - LoRAIPAdapterAttnProcessor, - LoRAIPAdapterAttnProcessor2_0, ) logger = logging.get_logger(__name__) @@ -136,7 +134,7 @@ def load_ip_adapter( # load CLIP image encoder here if it has not been registered to the pipeline yet if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: - if not isinstance(pretrained_model_name_or_path_or_dict, dict) and subfolder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") image_encoder = CLIPVisionModelWithProjection.from_pretrained( pretrained_model_name_or_path_or_dict, @@ -144,13 +142,6 @@ def load_ip_adapter( ).to(self.device, dtype=self.dtype) self.image_encoder = image_encoder self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) - elif subfolder is None: - logger.warning( - "Cannot load an image encoder because `subfolder` is None." - " If you do not load an image_encoder, you must extract the" - " image embeddings from the input image and pass them as image_embeds to the pipeline." - " This behaviour is intended only for the IP Adapter FaceID model." - ) else: raise ValueError("`image_encoder` cannot be None when using IP Adapters.") @@ -166,15 +157,7 @@ def load_ip_adapter( def set_ip_adapter_scale(self, scale): unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet for attn_processor in unet.attn_processors.values(): - if isinstance( - attn_processor, - ( - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - LoRAIPAdapterAttnProcessor, - LoRAIPAdapterAttnProcessor2_0, - ), - ): + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): attn_processor.scale = scale def unload_ip_adapter(self): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7660436ea9c6..11a32a92aee8 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -707,20 +707,13 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value - elif "proj.0.weight" in state_dict: + elif "proj.3.weight" in state_dict: # IP-Adapter Full - clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] - clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] - multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in - norm_layer = "proj.3.weight" if "proj.3.weight" in state_dict else "norm.weight" - cross_attention_dim = state_dict[norm_layer].shape[0] - num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["proj.3.weight"].shape[0] image_projection = IPAdapterFullImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=clip_embeddings_dim_in, - mult=multiplier, - num_tokens=num_tokens, + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim ) for key, value in state_dict.items(): @@ -774,23 +767,14 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0, 
IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - LoRAIPAdapterAttnProcessor, - LoRAIPAdapterAttnProcessor2_0, ) - from ..models.lora import LoRALinearLayer - use_lora = False if "proj.weight" in state_dict["image_proj"]: # IP-Adapter num_image_text_embeds = 4 - elif "proj.0.weight" in state_dict["image_proj"]: + elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token - for k in state_dict["ip_adapter"].keys(): - if "lora" in k: - num_image_text_embeds = 4 - use_lora = True - break else: # IP-Adapter Plus num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] @@ -801,7 +785,7 @@ def _load_ip_adapter_weights(self, state_dict): # set ip-adapter cross-attention processors & load state_dict attn_procs = {} - key_id = 1 if not use_lora else 0 + key_id = 1 for name in self.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim if name.startswith("mid_block"): @@ -817,95 +801,23 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() - if use_lora: - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] - attn_module = self - for n in name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, - out_features=attn_module.to_q.out_features, - rank=rank, - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, - out_features=attn_module.to_k.out_features, - rank=rank, - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, - out_features=attn_module.to_v.out_features, - rank=rank, - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=rank, - ) - ) - - value_dict = {} - for k, module in attn_module.named_children(): - index = "." - if not hasattr(module, "set_lora_layer"): - index = ".0." 
- module = module[0] - lora_layer = getattr(module, "lora_layer") - for lora_name, w in lora_layer.state_dict().items(): - value_dict.update( - { - f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ - f"{key_id}.{k}_lora.{lora_name}" - ] - } - ) - - attn_module.load_state_dict(value_dict, strict=False) - attn_module.to(dtype=self.dtype, device=self.device) - key_id += 1 else: - if use_lora: - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] - attn_processor_class = ( - LoRAIPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else LoRAIPAdapterAttnProcessor - ) - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - rank=rank, - num_tokens=num_image_text_embeds, - ).to(dtype=self.dtype, device=self.device) - - else: - attn_processor_class = ( - IPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else IPAdapterAttnProcessor - ) - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=num_image_text_embeds, - ).to(dtype=self.dtype, device=self.device) + attn_processor_class = ( + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + ).to(dtype=self.dtype, device=self.device) value_dict = {} for k, w in attn_procs[name].state_dict().items(): value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) attn_procs[name].load_state_dict(value_dict) - key_id += 2 if not use_lora else 1 + key_id += 2 self.set_attn_processor(attn_procs) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 75faf4e1a74d..ac9563e186bb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2329,289 +2329,11 @@ def __call__( return hidden_states -class LoRAIPAdapterAttnProcessor(nn.Module): - r""" - Attention processor for IP-Adapater. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. 
- """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = 
hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class LoRAIPAdapterAttnProcessor2_0(nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - inner_dim = 
key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor, - LoRAIPAdapterAttnProcessor, - LoRAIPAdapterAttnProcessor2_0, ) ADDED_KV_ATTENTION_PROCESSORS = ( @@ -2632,8 +2354,6 @@ def __call__( LoRAXFormersAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - LoRAIPAdapterAttnProcessor, - LoRAIPAdapterAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 99c9ceb72a93..293b751cb67d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -463,22 +463,15 @@ def forward(self, image_embeds: torch.FloatTensor): class IPAdapterFullImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): super().__init__() from .attention import FeedForward - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") self.norm = nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds: torch.FloatTensor): - if self.num_tokens == 1: - return self.norm(self.ff(image_embeds)) - else: - x = self.ff(image_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - return self.norm(x) + return self.norm(self.ff(image_embeds)) class CombinedTimestepLabelEmbeddings(nn.Module): diff --git 
a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d72eff04e505..7b4f9f5594ea 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config @@ -29,9 +28,7 @@ Attention, AttentionProcessor, AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, AttnProcessor, - AttnProcessor2_0, ) from .embeddings import ( GaussianFourierProjection, @@ -685,11 +682,9 @@ def set_default_attn_processor(self): Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = ( - AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() - ) + processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a467a7c931c3..dc4ad60ce091 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -789,7 +789,6 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - image_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -841,8 +840,6 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - image_embeds (`torch.FloatTensor`, *optional*): - Pre-generated image embeddings. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
@@ -951,14 +948,7 @@ def __call__(
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
-        if image_embeds is not None:
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
-                device=device, dtype=prompt_embeds.dtype
-            )
-            negative_image_embeds = torch.zeros_like(image_embeds)
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-        elif ip_adapter_image is not None:
+        if ip_adapter_image is not None:
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
@@ -986,9 +976,7 @@ def __call__(
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
        # 6.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = (
-            {"image_embeds": image_embeds} if ip_adapter_image is not None or image_embeds is not None else None
-        )
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 8850da073e95..a83626e32ed8 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -48,7 +48,8 @@
    version.parse(importlib.metadata.version("transformers")).base_version
 ) >= version.parse(MIN_TRANSFORMERS_VERSION)
 
-USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+_use_peft = _required_peft_version and _required_transformers_version
+USE_PEFT_BACKEND = _use_peft
 
 if USE_PEFT_BACKEND and _CHECK_PEFT:
    dep_version_check("peft")
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index e0bb3ad176f5..289d2b7d6573 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -248,25 +248,6 @@ def test_unload(self):
        ]
        assert processors == [True] * len(processors)
 
-    def test_unload_faceid(self):
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
-        )
-        pipeline.to(torch_device)
-        pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin")
-        pipeline.set_ip_adapter_scale(0.7)
-
-        pipeline.unload_ip_adapter()
-        pipeline.unload_lora_weights()
-
-        assert getattr(pipeline, "image_encoder") is None
-        assert getattr(pipeline, "feature_extractor") is None
-        processors = [
-            isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0))
-            for name, attn_proc in pipeline.unet.attn_processors.items()
-        ]
-        assert processors == [True] * len(processors)
-
 
 @slow
 @require_torch_gpu

From 217d9d073981605acab5200fc841f20c798c1449 Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Wed, 10 Jan 2024 21:19:54 +0100
Subject: [PATCH 8/9] Fix style

---
 examples/community/README.md             |  2 +-
 examples/community/ip_adapter_face_id.py | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index f205f3b70b15..2fdbdb414c6f 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -3307,7 +3307,7 @@ pipeline = DiffusionPipeline.from_pretrained(
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
-    custom_pipeline="./forked/diffusers/examples/community/ip_adapter_face_id.py"
+    custom_pipeline="ip_adapter_face_id"
 )
 pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
 pipeline.to("cuda")
diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
index e3c5a2c84ee0..d9325742cf49 100644
--- a/examples/community/ip_adapter_face_id.py
+++ b/examples/community/ip_adapter_face_id.py
@@ -14,12 +14,12 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from safetensors import safe_open
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from packaging import version
+from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from diffusers.configuration_utils import FrozenDict
@@ -27,20 +27,20 @@
 from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.attention_processor import FusedAttnProcessor2_0
-from diffusers.models.lora import adjust_lora_scale_text_encoder, LoRALinearLayer
+from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
-    _get_model_file,
    USE_PEFT_BACKEND,
+    _get_model_file,
    deprecate,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -555,7 +555,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
-            )
+        )
        if weight_name.endswith(".safetensors"):
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(model_file, framework="pt", device="cpu") as f:
@@ -1438,7 +1438,7 @@ def __call__(
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
        # 6.1 Add image embeds for IP-Adapter
-        added_cond_kwargs ={"image_embeds": image_embeds} if image_embeds is not None else None
+        added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
 
        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None

From 73922cff800de90bb3368d7bb0cada62a177db3e Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Thu, 11 Jan 2024 18:59:38 +0100
Subject: [PATCH 9/9] Revert constant update

---
 src/diffusers/utils/constants.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index a83626e32ed8..8850da073e95 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -48,8 +48,7 @@
    version.parse(importlib.metadata.version("transformers")).base_version
 ) >= version.parse(MIN_TRANSFORMERS_VERSION)
 
-_use_peft = _required_peft_version and _required_transformers_version
-USE_PEFT_BACKEND = _use_peft
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
 
 if USE_PEFT_BACKEND and _CHECK_PEFT:
    dep_version_check("peft")
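
Reviewer note on the final state of this series: `_load_ip_adapter_weights` now tells the three adapter families apart purely by which keys appear in the `image_proj` state dict. Below is a minimal sketch of that detection, assuming a checkpoint with the same key layout as the `h94/IP-Adapter` releases; the helper name `detect_ip_adapter_variant` is illustrative only, not a diffusers API.

```python
def detect_ip_adapter_variant(image_proj: dict) -> tuple:
    """Infer the adapter family and image-token count from `state_dict["image_proj"]`."""
    if "proj.weight" in image_proj:
        # Plain IP-Adapter: a single linear projection with a fixed 4 image tokens.
        return "ip-adapter", 4
    if "proj.3.weight" in image_proj:
        # IP-Adapter Full Face: MLP projection, 256 CLIP tokens + 1 CLS token.
        return "ip-adapter-full-face", 257
    # IP-Adapter Plus: a resampler whose learned latents fix the token count.
    return "ip-adapter-plus", image_proj["latents"].shape[1]


# e.g. variant, num_tokens = detect_ip_adapter_variant(state_dict["image_proj"])
```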
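For completeness, a usage sketch of the community FaceID pipeline this series lands in `examples/community/ip_adapter_face_id.py`, following the README hunk in PATCH 8/9. The face embedding below is a random placeholder, not a working input: real embeddings come from an external face-recognition model (the FaceID checkpoints are trained on InsightFace embeddings), and the 512-dim shape is an assumption based on the SD 1.5 FaceID checkpoint.

```python
import torch
from diffusers import DiffusionPipeline

# Resolve the community pipeline by name (the PATCH 8/9 README fix replaces
# the local "./forked/..." path with the registered pipeline name).
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    custom_pipeline="ip_adapter_face_id",
)
pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
pipeline.to("cuda")

# Placeholder face embedding; substitute a real InsightFace embedding here.
face_embeds = torch.randn(1, 512, dtype=torch.float16)

images = pipeline(
    prompt="a photo of a person outdoors, high quality",
    image_embeds=face_embeds,
    num_inference_steps=30,
).images
```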