From 0ebbd30d31e18b3bd5903e8a5d955a451c346e28 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 14:34:12 +0530 Subject: [PATCH 01/37] add lavender flow transformer --- scripts/convert_lavender_flow_to_diffusers.py | 1 + src/diffusers/models/attention_processor.py | 14 +- src/diffusers/models/normalization.py | 38 +- .../transformers/lavender_transformer_2d.py | 486 ++++++++++++++++++ .../pipelines/lavender_flow/__init__.py | 0 .../lavender_flow/pipeline_lavender_flow.py | 433 ++++++++++++++++ 6 files changed, 964 insertions(+), 8 deletions(-) create mode 100644 scripts/convert_lavender_flow_to_diffusers.py create mode 100644 src/diffusers/models/transformers/lavender_transformer_2d.py create mode 100644 src/diffusers/pipelines/lavender_flow/__init__.py create mode 100644 src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -0,0 +1 @@ + diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9d495695e330..bbd19975ac61 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph +from .normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -115,6 +116,7 @@ def __init__( processor: Optional["AttnProcessor"] = None, out_dim: int = None, context_pre_only=None, + use_fp32_layer_norm=False, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -166,8 +168,16 @@ def __init__( self.norm_q = None self.norm_k = None elif qk_norm == "layer_norm": - self.norm_q = nn.LayerNorm(dim_head, eps=eps) - self.norm_k = nn.LayerNorm(dim_head, eps=eps) + self.norm_q = ( + nn.LayerNorm(dim_head, eps=eps) + if not use_fp32_layer_norm + else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + ) + self.norm_k = ( + nn.LayerNorm(dim_head, eps=eps) + if not use_fp32_layer_norm + else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + ) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index a1a7ce91d754..72549b94ed45 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -48,6 +48,15 @@ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: return x +# Copied from diffusers.models.transformers.hunyuan_transformer_2d.FP32LayerNorm +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps + ).to(origin_dtype) + + class AdaLayerNormZero(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). @@ -57,7 +66,7 @@ class AdaLayerNormZero(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. 
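+        use_fp32_layer_norm (`bool`, *optional*, defaults to `False`):
+            Whether to compute the layer norm in float32 via `FP32LayerNorm` instead of `nn.LayerNorm`.
+        bias (`bool`, *optional*, defaults to `True`):
+            Whether the modulation `nn.Linear` projection uses a bias.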
""" - def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, use_fp32_layer_norm=False, bias=True): super().__init__() if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) @@ -65,8 +74,12 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): self.emb = None self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + self.norm = ( + nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + if not use_fp32_layer_norm + else FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + ) def forward( self, @@ -169,22 +182,35 @@ def __init__( eps=1e-5, bias=True, norm_type="layer_norm", + use_fp32_layer_norm=False, ): super().__init__() + if use_fp32_layer_norm and norm_type != "layer_norm": + raise ValueError("`use_fp32_layer_norm` can only be True when `norm_type` is 'layer_norm'.") + self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": - self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + self.norm = ( + LayerNorm(embedding_dim, eps, elementwise_affine, bias) + if not use_fp32_layer_norm + else FP32LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias) + ) elif norm_type == "rms_norm": self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + elif norm_type == "no_norm": + self.norm = None else: raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: - # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT) emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = torch.chunk(emb, 2, dim=1) - x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + # lavender flow doesn't have a norm at one place here + if self.norm is not None: + x = self.norm(x) + x = x * (1 + scale)[:, None, :] + shift[:, None, :] return x diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py new file mode 100644 index 000000000000..6435db39d43b --- /dev/null +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -0,0 +1,486 @@ +# Copyright 2024 Stability AI, Lavender Flow, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_processor import Attention +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Taken from the original lavender flow inference code. +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +# Lavender Flow patch embed doesn't use convs for projections. +# Additionally, it uses learned positional embeddings. +class LvenderFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + return latent + self.pos_embed + + +# Taken from the original lavender flow inference code. +# Our feedforward only has GELU but lavender uses SiLU. +class LavenderFlowFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=None) -> None: + super().__init__() + if hidden_dim is None: + hidden_dim = 4 * dim + + n_hidden = int(2 * hidden_dim / 3) + n_hidden = find_multiple(n_hidden, 256) + + self.c_fc1 = nn.Linear(dim, n_hidden, bias=False) + self.c_fc2 = nn.Linear(dim, n_hidden, bias=False) + self.c_proj = nn.Linear(n_hidden, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class LavenderFlowAttnProcessor2_0: + """Attention processor used typically in processing Lavender Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "LavenderFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. 
" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + if encoder_hidden_states is not None: + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + if encoder_hidden_states is not None: + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
+ if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None and not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if encoder_hidden_states is not None and context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class LavenderFlowDiTTransformerBlock(nn.module): + """Similar `LavenderFlowTransformerBlock with a single DiT instead of an MMDiT.""" + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + + processor = LavenderFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm", + out_dim=dim, + bias=False, + processor=processor, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = LavenderFlowFeedForward(dim, dim * 4) + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + # Attention. + attn_output = self.attn(hidden_states=norm_hidden_states) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + return hidden_states + + +@maybe_allow_in_graph +class LavenderFlowTransformerBlock(nn.Module): + r""" + Transformer block for Lavender Flow. Similar to SD3 MMDiT. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + is_last (`bool`): Boolean to determine if this is the last block in the model. 
+ """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): + super().__init__() + self.is_last = is_last + context_norm_type = "ada_norm_continous" if is_last else "ada_norm_zero" + + self.norm1 = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, bias=False, use_fp32_layer_norm=True + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + + processor = LavenderFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm", + out_dim=dim, + context_pre_only=is_last, + bias=False, + processor=processor, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = LavenderFlowFeedForward(dim, dim * 4) + self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + if not is_last: + self.ff_context = LavenderFlowFeedForward(dim, dim * 4) + else: + self.ff_context = None + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +class LavenderFlowTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 64, + patch_size: int = 2, + in_channels: int = 4, + num_layers: int = 36, + num_single_dit_layers: int = 32, + attention_head_dim: int = 256, + num_attention_heads: int = 12, + joint_attention_dim: int = 2048, + caption_projection_dim: int = 3072, + out_channels: int = 4, + pos_embed_max_size: int = 1024, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = LvenderFlowPatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.context_embedder = nn.Linear( + self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False + ) + self.time_step_embed = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.joint_transformer_blocks = nn.ModuleList( + [ + LavenderFlowTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + is_last=i == num_layers - 1, + ) + for i in range(self.config.num_layers) + ] + ) + self.single_transformer_blocks = nn.ModuleList( + [ + LavenderFlowDiTTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_single_dit_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, norm_type="no_norm") + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + # https://arxiv.org/abs/2309.16588 + # prevents artifacts in the attention maps + self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + use_register_tokens: bool = True, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + height, width = hidden_states.shape[-2:] + + # Apply patch embedding, timestep embedding, and project the caption embeddings. + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
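+        # `hidden_states` goes from (B, C, H, W) to (B, (H // patch_size) * (W // patch_size), inner_dim)
+        # above; the learned positional embeddings are added inside the patch embed.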
+ temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) + temb = self.time_step_proj(temb) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + # This doesn't apply to the negative prompt embeds. So, we need to keep that in mind. + if use_register_tokens: + encoder_hidden_states = torch.cat( + [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 + ) + + for index_block, block in enumerate(self.joint_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + if len(self.single_transformer_blocks) > 0: + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_seq_len = encoder_hidden_states.size(1) + combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + combined_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + combined_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) + + hidden_states = combined_hidden_states[:, encoder_seq_len:] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + out_channels = self.config.out_channels + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/lavender_flow/__init__.py b/src/diffusers/pipelines/lavender_flow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py new file mode 100644 index 000000000000..766d3250c25e --- /dev/null +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -0,0 +1,433 @@ +# Copyright 2024 Lavender-Flow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LavenderFlowPipeline(DiffusionPipeline): + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. 
+ max_length = max_sequence_length + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs.attention_mask.unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = uncond_input.attention_mask.unsqueeze(-1).expand( + negative_prompt_embeds.shape + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // 
self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "This is watermark, jpeg image white background, web image", + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = 512, + width: Optional[int] = 512, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + max_sequence_length: int = 256, + ) -> Union[ImagePipelineOutput, Tuple]: + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + # timesteps, num_inference_steps = retrieve_timesteps( + # self.scheduler, num_inference_steps, device, timesteps, sigmas + # ) + + # 5. Prepare latents. 
+ latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) From 939d990cbbdb978c17f6a7132f7d99528875cf4d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 18:42:21 +0530 Subject: [PATCH 02/37] progress. 
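A rough end-to-end sketch for the conversion path added here (the repo id and
dump path are just the script's current argparse defaults, nothing final):

    # python scripts/convert_lavender_flow_to_diffusers.py \
    #     --original_state_dict_repo_id AuraDiffusion/auradiffusion-v0.1a0 \
    #     --dump_path lavender-flow
    from diffusers.models.transformers.lavender_transformer_2d import LavenderFlowTransformer2DModel

    model = LavenderFlowTransformer2DModel.from_pretrained("lavender-flow")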
--- scripts/convert_lavender_flow_to_diffusers.py | 128 ++++++++++++++++++ src/diffusers/models/attention_processor.py | 9 +- src/diffusers/models/embeddings.py | 4 +- src/diffusers/models/normalization.py | 6 +- .../transformers/hunyuan_transformer_2d.py | 6 +- .../transformers/lavender_transformer_2d.py | 49 +++---- 6 files changed, 164 insertions(+), 38 deletions(-) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py index 8b137891791f..5791bdac9eb5 100644 --- a/scripts/convert_lavender_flow_to_diffusers.py +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -1 +1,129 @@ +import argparse +import torch +from huggingface_hub import hf_hub_download + +from diffusers.models.transformers.lavender_transformer_2d import LavenderFlowTransformer2DModel + + +def load_original_state_dict(args): + model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="model.bin") + state_dict = torch.load(model_pt, map_location="cpu") + return state_dict + + +def calculate_layers(state_dict_keys, key_prefix): + dit_layers = set() + for k in state_dict_keys: + if key_prefix in k: + dit_layers.add(int(k.split(".")[2])) + print(f"{key_prefix}: {len(dit_layers)}") + return len(dit_layers) + + +def convert_transformer(state_dict): + converted_state_dict = {} + state_dict_keys = list(state_dict.keys()) + + converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens") + converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias") + + converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight") + converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias") + converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight") + converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias") + + converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight") + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks 🎸. + for i in range(mmdit_layers): + # feed-forward + for path in ["mlpX", "mlpC"]: + diffuser_path = "ff" if path == "mlpX" else "ff_context" + for k in ["c_fc1", "c_fc2", "c_proj"]: + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_path}.{k}.weight"] = state_dict.pop( + f"model.double_layers.{i}.{path}.{k}.weight" + ) + + # norms + for path in ["modX", "modC"]: + diffuser_path = "norm1" if path == "modX" else "norm1_context" + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_path}.linear.weight"] = state_dict.pop( + f"model.double_layers.{i}.{path}.1.weight" + ) + + # attns + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + context_attn_mapping = {"w2q": "add_q_proj", "w2k": "add_k_proj", "w2v": "add_v_proj", "w2o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.double_layers.{i}.attn.{k}.weight" + ) + + # Single-DiT blocks. 
+ for i in range(single_dit_layers): + # feed-forward + for k in ["c_fc1", "c_fc2", "c_proj"]: + converted_state_dict[f"single_transformer_blocks.{i}.ff.{k}.weight"] = state_dict.pop( + f"model.single_layers.{i}.mlp.{k}.weight" + ) + + # norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop( + f"model.single_layers.{i}.modCX.1.weight" + ) + + # attns + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.single_layers.{i}.attn.{k}.weight" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight") + converted_state_dict["norm_out.linear.weight"] = state_dict.pop("model.modF.1.weight") + + return converted_state_dict + + +@torch.no_grad() +def populate_state_dict(args): + original_state_dict = load_original_state_dict(args) + state_dict_keys = list(original_state_dict.keys()) + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + converted_state_dict = convert_transformer(original_state_dict) + + # with init_empty_weights(): + # model_diffusers = LavenderFlowTransformer2DModel(num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers) + + # with torch.no_grad(): + # unexpected_keys = load_model_dict_into_meta(model_diffusers, converted_state_dict) + model_diffusers = LavenderFlowTransformer2DModel( + num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers + ) + model_diffusers.load_state_dict(converted_state_dict, strict=True) + # assert len(unexpected_keys) == 0, "Something wrong." 
+ + return model_diffusers + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str) + parser.add_argument("--dump_path", default="lavender-flow", type=str) + parser.add_argument("--hub_id", default=None, type=str) + args = parser.parse_args() + + model_diffusers = populate_state_dict(args) + model_diffusers.save_pretrained(args.dump_path) + if args.hub_id is not None: + model_diffusers.push_to_hub(args.hub_id) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index bbd19975ac61..1906a75c6283 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,7 +23,6 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -119,6 +118,8 @@ def __init__( use_fp32_layer_norm=False, ): super().__init__() + from .normalization import FP32LayerNorm + self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim self.use_bias = bias @@ -215,10 +216,10 @@ def __init__( self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index cb6cb065dd32..a42f79b90b6f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -386,11 +386,12 @@ def forward(self, sample, condition=None): class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift + self.scale = scale def forward(self, timesteps): t_emb = get_timestep_embedding( @@ -398,6 +399,7 @@ def forward(self, timesteps): self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, ) return t_emb diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 72549b94ed45..820d1310ccfd 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -53,7 +53,11 @@ class FP32LayerNorm(nn.LayerNorm): def forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype return F.layer_norm( - inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, 
).to(origin_dtype) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 8313ffd87a50..f5b6ee122947 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -39,7 +39,11 @@ class FP32LayerNorm(nn.LayerNorm): def forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype return F.layer_norm( - inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, ).to(origin_dtype) diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index 6435db39d43b..d1897885b241 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -41,7 +41,7 @@ def find_multiple(n: int, k: int) -> int: # Lavender Flow patch embed doesn't use convs for projections. # Additionally, it uses learned positional embeddings. -class LvenderFlowPatchEmbed(nn.Module): +class LavenderFlowPatchEmbed(nn.Module): def __init__( self, height=224, @@ -117,19 +117,7 @@ def __call__( **kwargs, ) -> torch.FloatTensor: residual = hidden_states - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - if encoder_hidden_states is not None: - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size = hidden_states.shape[0] # `sample` projections. 
query = attn.to_q(hidden_states) @@ -171,21 +159,16 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - if encoder_hidden_states is not None and not attn.context_pre_only: + if encoder_hidden_states is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if encoder_hidden_states is not None and context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states else: return hidden_states -class LavenderFlowDiTTransformerBlock(nn.module): +class LavenderFlowDiTTransformerBlock(nn.Module): """Similar `LavenderFlowTransformerBlock with a single DiT instead of an MMDiT.""" def __init__(self, dim, num_attention_heads, attention_head_dim): @@ -200,8 +183,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm", + use_fp32_layer_norm=True, out_dim=dim, bias=False, + out_bias=False, processor=processor, ) @@ -263,10 +248,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm", + use_fp32_layer_norm=True, out_dim=dim, - context_pre_only=is_last, bias=False, + out_bias=False, processor=processor, + context_pre_only=False, ) self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) @@ -323,7 +310,7 @@ def __init__( sample_size: int = 64, patch_size: int = 2, in_channels: int = 4, - num_layers: int = 36, + num_mmdit_layers: int = 4, num_single_dit_layers: int = 32, attention_head_dim: int = 256, num_attention_heads: int = 12, @@ -337,7 +324,7 @@ def __init__( self.out_channels = out_channels if out_channels is not None else default_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = LvenderFlowPatchEmbed( + self.pos_embed = LavenderFlowPatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, @@ -349,7 +336,7 @@ def __init__( self.context_embedder = nn.Linear( self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False ) - self.time_step_embed = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) self.joint_transformer_blocks = nn.ModuleList( @@ -358,9 +345,8 @@ def __init__( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, - is_last=i == num_layers - 1, ) - for i in range(self.config.num_layers) + for i in range(self.config.num_mmdit_layers) ] ) self.single_transformer_blocks = nn.ModuleList( @@ -374,7 +360,7 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, norm_type="no_norm") + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, norm_type="no_norm", bias=False) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) # https://arxiv.org/abs/2309.16588 @@ -401,6 +387,7 @@ def forward( hidden_states = 
self.pos_embed(hidden_states) # takes care of adding positional embeddings too. temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) temb = self.time_step_proj(temb) + print(f"{temb[0, :4]=}") encoder_hidden_states = self.context_embedder(encoder_hidden_states) # This doesn't apply to the negative prompt embeds. So, we need to keep that in mind. if use_register_tokens: @@ -435,10 +422,10 @@ def custom_forward(*inputs): ) if len(self.single_transformer_blocks) > 0: - for index_block, block in enumerate(self.single_transformer_blocks): - encoder_seq_len = encoder_hidden_states.size(1) - combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + encoder_seq_len = encoder_hidden_states.size(1) + combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): From f08baf33ae738926ff71ee9e19b9c19a6dae8a02 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 19:29:31 +0530 Subject: [PATCH 03/37] progress --- scripts/convert_lavender_flow_to_diffusers.py | 20 +++++++++---------- .../transformers/lavender_transformer_2d.py | 16 ++++++++------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py index 5791bdac9eb5..0d4b7b71e32f 100644 --- a/scripts/convert_lavender_flow_to_diffusers.py +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -43,23 +43,23 @@ def convert_transformer(state_dict): # MMDiT blocks 🎸. for i in range(mmdit_layers): # feed-forward - for path in ["mlpX", "mlpC"]: - diffuser_path = "ff" if path == "mlpX" else "ff_context" + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + for orig_k, diffuser_k in path_mapping.items(): for k in ["c_fc1", "c_fc2", "c_proj"]: - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_path}.{k}.weight"] = state_dict.pop( - f"model.double_layers.{i}.{path}.{k}.weight" + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{k}.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.{k}.weight" ) # norms - for path in ["modX", "modC"]: - diffuser_path = "norm1" if path == "modX" else "norm1_context" - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_path}.linear.weight"] = state_dict.pop( - f"model.double_layers.{i}.{path}.1.weight" + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.1.weight" ) # attns - x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} - context_attn_mapping = {"w2q": "add_q_proj", "w2k": "add_k_proj", "w2v": "add_v_proj", "w2o": "to_add_out"} + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} for attn_mapping in [x_attn_mapping, context_attn_mapping]: for k, v in attn_mapping.items(): converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index d1897885b241..0fe0d31dbf11 100644 --- 
a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -116,7 +116,6 @@ def __call__( *args, **kwargs, ) -> torch.FloatTensor: - residual = hidden_states batch_size = hidden_states.shape[0] # `sample` projections. @@ -132,9 +131,9 @@ def __call__( # attention if encoder_hidden_states is not None: - query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -151,8 +150,8 @@ def __call__( # Split the attention outputs. if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], ) # linear proj @@ -385,15 +384,17 @@ def forward( # Apply patch embedding, timestep embedding, and project the caption embeddings. hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + # print(f"{hidden_states[0, :4, :4]=}") temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) temb = self.time_step_proj(temb) - print(f"{temb[0, :4]=}") + # print(f"{temb[0, :4]=}") encoder_hidden_states = self.context_embedder(encoder_hidden_states) # This doesn't apply to the negative prompt embeds. So, we need to keep that in mind. if use_register_tokens: encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) + # print(f"{encoder_hidden_states[0, :4, :4]=}") for index_block, block in enumerate(self.joint_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -420,6 +421,7 @@ def custom_forward(*inputs): encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) + # print(f"{encoder_hidden_states[0, :4, :4]=}") if len(self.single_transformer_blocks) > 0: encoder_seq_len = encoder_hidden_states.size(1) From 005a8f6b2e25e7a86fe217097a3619fa12daf50c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:05:48 +0530 Subject: [PATCH 04/37] progress --- scripts/convert_lavender_flow_to_diffusers.py | 16 +-- src/diffusers/models/attention_processor.py | 18 +++ .../transformers/lavender_transformer_2d.py | 116 ++++++++++-------- 3 files changed, 94 insertions(+), 56 deletions(-) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py index 0d4b7b71e32f..f6f8a5064002 100644 --- a/scripts/convert_lavender_flow_to_diffusers.py +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -21,6 +21,13 @@ def calculate_layers(state_dict_keys, key_prefix): return len(dit_layers) +# similar to SD3 but only for the last norm layer +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + def convert_transformer(state_dict): converted_state_dict = {} state_dict_keys = list(state_dict.keys()) @@ -88,7 +95,7 @@ def convert_transformer(state_dict): # Final blocks. 
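Before the final-block conversion just below, it may help to see what `swap_scale_shift` actually does: it treats the original `model.modF.1.weight` as packed `[shift, scale]` and reorders it into the `[scale, shift]` layout that `AdaLayerNormContinuous` unpacks with `scale, shift = emb.chunk(2, dim=1)`. A self-contained toy run (illustrative values only, not a real checkpoint tensor):

import torch


def swap_scale_shift(weight, dim):
    # same logic as the helper above: packed [shift, scale] -> [scale, shift]
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)


w = torch.arange(4.0)                 # pretend [0, 1] is the shift half and [2, 3] the scale half
print(swap_scale_shift(w, dim=None))  # tensor([2., 3., 0., 1.])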
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight") - converted_state_dict["norm_out.linear.weight"] = state_dict.pop("model.modF.1.weight") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None) return converted_state_dict @@ -101,17 +108,10 @@ def populate_state_dict(args): single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") converted_state_dict = convert_transformer(original_state_dict) - - # with init_empty_weights(): - # model_diffusers = LavenderFlowTransformer2DModel(num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers) - - # with torch.no_grad(): - # unexpected_keys = load_model_dict_into_meta(model_diffusers, converted_state_dict) model_diffusers = LavenderFlowTransformer2DModel( num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers ) model_diffusers.load_state_dict(converted_state_dict, strict=True) - # assert len(unexpected_keys) == 0, "Something wrong." return model_diffusers diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1906a75c6283..cfd73c8e8f19 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -102,6 +102,7 @@ def __init__( cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, + added_qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -182,6 +183,23 @@ def __init__( else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + if added_qk_norm is None: + self.norm_added_q = None + self.norm_added_k = None + elif added_qk_norm == "layer_norm": + self.norm_added_q = ( + nn.LayerNorm(dim_head, eps=eps) + if not use_fp32_layer_norm + else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + ) + self.norm_added_k = ( + nn.LayerNorm(dim_head, eps=eps) + if not use_fp32_layer_norm + else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + ) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index 0fe0d31dbf11..46190a290edd 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -113,6 +113,7 @@ def __call__( attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, + i=0, *args, **kwargs, ) -> torch.FloatTensor: @@ -129,18 +130,43 @@ def __call__( encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - # attention + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. 
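To make the joint-attention layout explicit: the concatenation just below puts the text/context projections first and the image projections after them along the sequence dimension, attention runs over the combined sequence, and the two streams are split back apart afterwards in that same order. A standalone shape sketch (batch and sequence sizes are illustrative, not taken from a real config):

import torch

batch, text_len, image_len, heads, head_dim = 2, 64, 256, 12, 256  # illustrative sizes
context_q = torch.randn(batch, text_len, heads, head_dim)
sample_q = torch.randn(batch, image_len, heads, head_dim)

joint = torch.cat([context_q, sample_q], dim=1)  # context tokens come first
assert joint.shape == (batch, text_len + image_len, heads, head_dim)

# after attention, the image part is everything past the text length and the
# text part is the leading slice, mirroring the split further down
image_part, text_part = joint[:, text_len:], joint[:, :text_len]
assert image_part.shape[1] == image_len and text_part.shape[1] == text_len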
if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # Attention. hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False ) @@ -167,6 +193,7 @@ def __call__( return hidden_states +@maybe_allow_in_graph class LavenderFlowDiTTransformerBlock(nn.Module): """Similar `LavenderFlowTransformerBlock with a single DiT instead of an MMDiT.""" @@ -192,23 +219,21 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.ff = LavenderFlowFeedForward(dim, dim * 4) - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999): + residual = hidden_states + # Norm + Projection. norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) # Attention. - attn_output = self.attn(hidden_states=norm_hidden_states) + attn_output = self.attn(hidden_states=norm_hidden_states, i=i) # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(hidden_states) + hidden_states = gate_mlp.unsqueeze(1) * ff_output + hidden_states = residual + hidden_states return hidden_states @@ -216,7 +241,11 @@ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): @maybe_allow_in_graph class LavenderFlowTransformerBlock(nn.Module): r""" - Transformer block for Lavender Flow. Similar to SD3 MMDiT. + Transformer block for Lavender Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): + + * QK Norm in the attention blocks + * No bias in the attention blocks + * Most LayerNorms are in FP32 Parameters: dim (`int`): The number of channels in the input and output. 
@@ -247,6 +276,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm", + added_qk_norm="layer_norm", use_fp32_layer_norm=True, out_dim=dim, bias=False, @@ -264,8 +294,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): self.ff_context = None def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0 ): + residual = hidden_states + residual_context = encoder_hidden_states + # Norm + Projection. norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( @@ -274,28 +307,20 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i ) # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) + hidden_states = residual + hidden_states # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) + encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) + encoder_hidden_states = residual_context + encoder_hidden_states return encoder_hidden_states, hidden_states @@ -377,25 +402,20 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - use_register_tokens: bool = True, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] # Apply patch embedding, timestep embedding, and project the caption embeddings. hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
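A few lines below, the projected caption embeddings get a small set of learned register tokens prepended (cf. the https://arxiv.org/abs/2309.16588 reference above), so the text sequence entering the MMDiT blocks is slightly longer than the tokenized prompt. A minimal standalone sketch of that step, with made-up sizes:

import torch
import torch.nn as nn

batch, seq_len, dim, num_registers = 2, 77, 3072, 8  # hypothetical sizes
register_tokens = nn.Parameter(torch.randn(1, num_registers, dim) * 0.02)
caption_embeds = torch.randn(batch, seq_len, dim)

caption_embeds = torch.cat([register_tokens.repeat(batch, 1, 1), caption_embeds], dim=1)
assert caption_embeds.shape == (batch, num_registers + seq_len, dim)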
- # print(f"{hidden_states[0, :4, :4]=}") temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) temb = self.time_step_proj(temb) - # print(f"{temb[0, :4]=}") encoder_hidden_states = self.context_embedder(encoder_hidden_states) - # This doesn't apply to the negative prompt embeds. So, we need to keep that in mind. - if use_register_tokens: - encoder_hidden_states = torch.cat( - [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 - ) - # print(f"{encoder_hidden_states[0, :4, :4]=}") + encoder_hidden_states = torch.cat( + [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 + ) + # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -419,10 +439,10 @@ def custom_forward(*inputs): else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block ) - # print(f"{encoder_hidden_states[0, :4, :4]=}") + # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) if len(self.single_transformer_blocks) > 0: encoder_seq_len = encoder_hidden_states.size(1) combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -450,7 +470,7 @@ def custom_forward(*inputs): else: combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) - hidden_states = combined_hidden_states[:, encoder_seq_len:] + hidden_states = combined_hidden_states[:, encoder_seq_len:] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From e238d56a97ba7b0720aca0f8e29a97bf88662ef2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:07:25 +0530 Subject: [PATCH 05/37] move out the attention processor. --- src/diffusers/models/attention_processor.py | 96 ++++++++++++++++++- .../transformers/lavender_transformer_2d.py | 96 +------------------ 2 files changed, 96 insertions(+), 96 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cfd73c8e8f19..f89222a31721 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -22,7 +22,7 @@ from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available -from ..utils.torch_utils import maybe_allow_in_graph +from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1162,6 +1162,100 @@ def __call__( return hidden_states, encoder_hidden_states +class LavenderFlowAttnProcessor2_0: + """Attention processor used typically in processing Lavender Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "LavenderFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + i=0, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. 
+ query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. 
diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index 46190a290edd..b50de3035fb6 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention +from ..attention_processor import Attention, LavenderFlowAttnProcessor2_0 from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -99,100 +99,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class LavenderFlowAttnProcessor2_0: - """Attention processor used typically in processing Lavender Flow.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): - raise ImportError( - "LavenderFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - i=0, - *args, - **kwargs, - ) -> torch.FloatTensor: - batch_size = hidden_states.shape[0] - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - # `context` projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - # Reshape. - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, attn.heads, head_dim) - value = value.view(batch_size, -1, attn.heads, head_dim) - - # Apply QK norm. - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Concatenate the projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) - - query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # Attention. 
- hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # Split the attention outputs. - if encoder_hidden_states is not None: - hidden_states, encoder_hidden_states = ( - hidden_states[:, encoder_hidden_states.shape[1] :], - hidden_states[:, : encoder_hidden_states.shape[1]], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - if encoder_hidden_states is not None: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - if encoder_hidden_states is not None: - return hidden_states, encoder_hidden_states - else: - return hidden_states - - @maybe_allow_in_graph class LavenderFlowDiTTransformerBlock(nn.Module): """Similar `LavenderFlowTransformerBlock with a single DiT instead of an MMDiT.""" From 570c25890d9feadd2f3b69624cb117da54863c47 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:25:08 +0530 Subject: [PATCH 06/37] finish implementation of pipeline --- .../lavender_flow/pipeline_lavender_flow.py | 110 ++++++++++++------ 1 file changed, 73 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index 766d3250c25e..4a62013deea7 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import Callable, List, Optional, Tuple, Union import torch @@ -29,6 +30,66 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class LavenderFlowPipeline(DiffusionPipeline): _optional_components = ["tokenizer", "text_encoder"] model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -342,12 +403,11 @@ def __call__( ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps - # timesteps, num_inference_steps = retrieve_timesteps( - # self.scheduler, num_inference_steps, device, timesteps, sigmas - # ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels @@ -362,37 +422,20 @@ def __call__( latents, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop + # 6. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - current_timestep = t - if not torch.is_tensor(current_timestep): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) - elif len(current_timestep.shape) == 0: - current_timestep = current_timestep[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(latent_model_input.shape[0]) + timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, - timestep=current_timestep, + timestep=timestep, return_dict=False, )[0] @@ -401,13 +444,8 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # learned sigma - if self.transformer.config.out_channels // 2 == latent_channels: - noise_pred = noise_pred.chunk(2, dim=1)[0] - else: - noise_pred = noise_pred - - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -416,12 +454,10 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - else: + if output_type == "latent": image = latents - - if not output_type == "latent": + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models From b881190cbc6d953f953d1fa26d7140dba7d5f65b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:26:14 +0530 Subject: [PATCH 07/37] default neg promot --- src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index 4a62013deea7..77c459320c7c 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -182,7 +182,7 @@ def encode_prompt( self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, - negative_prompt: str = "", + negative_prompt: str = "This is watermark, jpeg image white background, web image", num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, From b8237b2c3882b2ee103d1c062955aca5415534f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:29:25 +0530 Subject: [PATCH 08/37] up --- src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index 77c459320c7c..987a9c45e251 
100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -339,7 +339,6 @@ def __call__( num_images_per_prompt: Optional[int] = 1, height: Optional[int] = 512, width: Optional[int] = 512, - eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, From 89eea61363b542108abdd7943b385f3c85852c5b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 09:42:06 +0530 Subject: [PATCH 09/37] fixes --- .../lavender_flow/pipeline_lavender_flow.py | 54 +++++++++++++++---- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index 987a9c45e251..ccca71bfacbd 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -19,10 +19,9 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, SD3Transformer2DModel +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - logging, -) +from ...utils import logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -304,29 +303,57 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() def __call__( self, @@ -456,6 +483,11 @@ def __call__( if output_type == "latent": image = latents else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) From 2c97d0419cd9c853f9ca9bab278a6315986d7de5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 13:47:48 +0530 Subject: [PATCH 10/37] up --- .../lavender_flow/pipeline_lavender_flow.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index ccca71bfacbd..c87f610b65df 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -110,6 +110,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs def check_inputs( self, prompt, @@ -224,9 +225,7 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] - # See Section 3.1. of the paper. 
max_length = max_sequence_length - if prompt_embeds is None: text_inputs = self.tokenizer( prompt, @@ -236,7 +235,7 @@ def encode_prompt( return_tensors="pt", ) text_inputs = {k: v.to(device) for k, v in text_inputs.items()} - text_input_ids = text_inputs.input_ids + text_input_ids = text_inputs["input_ids"] untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( @@ -249,7 +248,7 @@ def encode_prompt( ) prompt_embeds = self.text_encoder(**text_inputs)[0] - prompt_attention_mask = text_inputs.attention_mask.unsqueeze(-1).expand(prompt_embeds.shape) + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) prompt_embeds = prompt_embeds * prompt_attention_mask if self.text_encoder is not None: @@ -265,7 +264,7 @@ def encode_prompt( # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) # get unconditional embeddings for classifier free guidance @@ -281,8 +280,8 @@ def encode_prompt( ) uncond_input = {k: v.to(device) for k, v in uncond_input.items()} negative_prompt_embeds = self.text_encoder(**uncond_input)[0] - negative_prompt_attention_mask = uncond_input.attention_mask.unsqueeze(-1).expand( - negative_prompt_embeds.shape + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) ) negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask @@ -295,7 +294,7 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) else: negative_prompt_embeds = None @@ -394,7 +393,7 @@ def __call__( negative_prompt_attention_mask, ) - # 2. Default height and width to transformer + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -437,6 +436,7 @@ def __call__( # 5. Prepare latents. latent_channels = self.transformer.config.in_channels + effective_batch_size = batch_size * num_images_per_prompt latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, @@ -450,12 +450,21 @@ def __call__( # 6. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + dt = 1.0 / num_inference_steps + dt = ( + torch.tensor([dt] * effective_batch_size) + .to(self.device) + .view([effective_batch_size, *([1] * len(latents.shape[1:]))]) + ) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): + for i, t in enumerate(range(num_inference_steps, 0, -1)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) + t = t / num_inference_steps + timestep = ( + torch.tensor([t]).expand(latent_model_input.shape[0]).to(latents.device, dtype=latents.dtype) + ) # predict noise model_output noise_pred = self.transformer( @@ -471,7 +480,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = (latents - dt * noise_pred).to(latents.dtype) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From a50e1ff88ac9bdb6771aaa96ded5105c3597a78a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 13:53:30 +0530 Subject: [PATCH 11/37] up for pr --- src/diffusers/__init__.py | 4 ++ src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 2 + .../pipelines/lavender_flow/__init__.py | 48 +++++++++++++++++++ .../lavender_flow/pipeline_lavender_flow.py | 4 +- src/diffusers/utils/dummy_pt_objects.py | 15 ++++++ .../dummy_torch_and_transformers_objects.py | 15 ++++++ 8 files changed, 89 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6f80cab0f357..652136732212 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -88,6 +88,7 @@ "HunyuanDiT2DMultiControlNetModel", "I2VGenXLUNet", "Kandinsky3UNet", + "LavenderFlowTransformer2DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -267,6 +268,7 @@ "KandinskyV22PriorPipeline", "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", + "LavenderFlowPipeline", "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", @@ -509,6 +511,7 @@ HunyuanDiT2DMultiControlNetModel, I2VGenXLUNet, Kandinsky3UNet, + LavenderFlowTransformer2DModel, ModelMixin, MotionAdapter, MultiAdapter, @@ -666,6 +669,7 @@ KandinskyV22PriorPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, + LavenderFlowPipeline, LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f3fda596aa71..f1f7ff82c35a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,7 @@ _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] + _import_structure["transformers.lavender_transformer_2d"] = ["LavenderFlowTransformer2DModel"] 
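Stepping back to the denoising loop in the pipeline patch above: it is a plain fixed-step Euler integration of the predicted flow, with `dt = 1 / num_inference_steps` and the normalized time walking from 1.0 down to `1 / num_inference_steps`. A toy standalone version of the same update rule, with a dummy velocity function standing in for the transformer:

import torch

num_inference_steps = 4
latents = torch.randn(1, 4, 64, 64)
dt = 1.0 / num_inference_steps


def velocity(x, t):
    # stand-in for the transformer's (CFG-combined) prediction
    return x * t


for step in range(num_inference_steps, 0, -1):
    t = step / num_inference_steps       # 1.0, 0.75, 0.5, 0.25
    latents = latents - dt * velocity(latents, t)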
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] @@ -85,6 +86,7 @@ DiTTransformer2DModel, DualTransformer2DModel, HunyuanDiT2DModel, + LavenderFlowTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 04bd21b70737..77f5f64321ce 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -5,6 +5,7 @@ from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel + from .lavender_transformer_2d import LavenderFlowTransformer2DModel from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4f135c9e43aa..6fa7671f2321 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -242,6 +242,7 @@ "StableDiffusionLDM3DPipeline", ] ) + _import_structure["lavender_flow"] = ["LavenderFlowPipeline"] _import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"] _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] @@ -480,6 +481,7 @@ LatentConsistencyModelPipeline, ) from .latent_diffusion import LDMTextToImagePipeline + from .lavender_flow import LavenderFlowPipeline from .ledits_pp import ( LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput, diff --git a/src/diffusers/pipelines/lavender_flow/__init__.py b/src/diffusers/pipelines/lavender_flow/__init__.py index e69de29bb2d1..c81c38e262c0 100644 --- a/src/diffusers/pipelines/lavender_flow/__init__.py +++ b/src/diffusers/pipelines/lavender_flow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lavender_flow"] = ["LavenderFlowPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lavender_flow import LavenderFlowPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py 
b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index c87f610b65df..5c656761c3ef 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -18,7 +18,7 @@ from transformers import T5Tokenizer, UMT5EncoderModel from ...image_processor import VaeImageProcessor -from ...models import AutoencoderKL, SD3Transformer2DModel +from ...models import AutoencoderKL, LavenderFlowTransformer2DModel from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging @@ -98,7 +98,7 @@ def __init__( tokenizer: T5Tokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKL, - transformer: SD3Transformer2DModel, + transformer: LavenderFlowTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 354ce7e0ba34..6224afa076fb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LavenderFlowTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a1bb667128df..3beccb913dbd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -962,6 +962,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LavenderFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionAdapterPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From b0d29b29261d45187b78f36c1e5e0c6697bed9e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 14:05:39 +0530 Subject: [PATCH 12/37] fix copies --- src/diffusers/utils/dummy_pt_objects.py | 12 ++++---- .../dummy_torch_and_transformers_objects.py | 30 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6224afa076fb..9170582911ca 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,7 +122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class LavenderFlowTransformer2DModel(metaclass=DummyObject): +class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -137,7 +137,7 @@ def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch"]) -class HunyuanDiT2DControlNetModel(metaclass=DummyObject): +class HunyuanDiT2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -152,7 +152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DModel(metaclass=DummyObject): +class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -167,7 +167,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): +class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -182,7 +182,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class I2VGenXLUNet(metaclass=DummyObject): +class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Kandinsky3UNet(metaclass=DummyObject): +class LavenderFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3beccb913dbd..25ef41e0f291 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -677,6 +677,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LavenderFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -962,21 +977,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LavenderFlowPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableDiffusionAdapterPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From ad6cb66c3be96b73834e0cb0cc2cf9e826e45379 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 14:13:51 +0530 Subject: [PATCH 13/37] move fp32 layer norm to normalization --- src/diffusers/models/normalization.py | 1 - .../models/transformers/hunyuan_transformer_2d.py | 15 +-------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 820d1310ccfd..fe3cc07aefa1 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -48,7 +48,6 @@ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: return x -# Copied from diffusers.models.transformers.hunyuan_transformer_2d.FP32LayerNorm 
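The `FP32LayerNorm` kept below in `normalization.py` is a thin wrapper whose only job is a dtype round-trip: the normalization itself always runs in float32 while inputs and outputs stay in the model's compute dtype. A standalone check of that behaviour, using a minimal re-definition so the snippet runs on its own:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FP32LayerNormSketch(nn.LayerNorm):
    # illustrative copy of the wrapper, not the library class itself
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)


x = torch.randn(2, 16, dtype=torch.bfloat16)
norm = FP32LayerNormSketch(16, elementwise_affine=False)
assert norm(x).dtype == torch.bfloat16  # math ran in fp32, output dtype matches the input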
class FP32LayerNorm(nn.LayerNorm): def forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index f5b6ee122947..cc0dcbd79e9f 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -14,7 +14,6 @@ from typing import Dict, Optional, Union import torch -import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config @@ -29,24 +28,12 @@ ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous +from ..normalization import AdaLayerNormContinuous, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FP32LayerNorm(nn.LayerNorm): - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - origin_dtype = inputs.dtype - return F.layer_norm( - inputs.float(), - self.normalized_shape, - self.weight.float() if self.weight is not None else None, - self.bias.float() if self.bias is not None else None, - self.eps, - ).to(origin_dtype) - - class AdaLayerNormShift(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. From 8ae6be768ce508ed89eb849bf0b75b689b020bb9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 14:21:16 +0530 Subject: [PATCH 14/37] minor fixes --- .../pipelines/lavender_flow/pipeline_lavender_flow.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py index 5c656761c3ef..05778f1a147c 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py @@ -199,8 +199,7 @@ def encode_prompt( prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` - instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For - PixArt-Alpha, this should be "". + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): @@ -211,9 +210,8 @@ def encode_prompt( Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" - string. - max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + Pre-generated negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. 
""" if device is None: device = self._execution_device From 47ff911fee5f54d1794db31be4f7bb68db8a51b1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 15:29:38 +0530 Subject: [PATCH 15/37] remove boolean flag and resort to norm_type --- src/diffusers/models/attention_processor.py | 31 ++++++------------- src/diffusers/models/normalization.py | 26 +++++++--------- .../transformers/lavender_transformer_2d.py | 16 +++++----- 3 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f89222a31721..3007aaa91685 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -116,7 +116,6 @@ def __init__( processor: Optional["AttnProcessor"] = None, out_dim: int = None, context_pre_only=None, - use_fp32_layer_norm=False, ): super().__init__() from .normalization import FP32LayerNorm @@ -170,16 +169,11 @@ def __init__( self.norm_q = None self.norm_k = None elif qk_norm == "layer_norm": - self.norm_q = ( - nn.LayerNorm(dim_head, eps=eps) - if not use_fp32_layer_norm - else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - ) - self.norm_k = ( - nn.LayerNorm(dim_head, eps=eps) - if not use_fp32_layer_norm - else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - ) + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") @@ -187,16 +181,11 @@ def __init__( self.norm_added_q = None self.norm_added_k = None elif added_qk_norm == "layer_norm": - self.norm_added_q = ( - nn.LayerNorm(dim_head, eps=eps) - if not use_fp32_layer_norm - else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - ) - self.norm_added_k = ( - nn.LayerNorm(dim_head, eps=eps) - if not use_fp32_layer_norm - else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - ) + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps) + elif added_qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index fe3cc07aefa1..3b87eeb822c0 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -69,7 +69,7 @@ class AdaLayerNormZero(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. 
""" - def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, use_fp32_layer_norm=False, bias=True): + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): super().__init__() if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) @@ -78,11 +78,14 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, use self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) - self.norm = ( - nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - if not use_fp32_layer_norm - else FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) - ) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) def forward( self, @@ -185,20 +188,15 @@ def __init__( eps=1e-5, bias=True, norm_type="layer_norm", - use_fp32_layer_norm=False, ): super().__init__() - if use_fp32_layer_norm and norm_type != "layer_norm": - raise ValueError("`use_fp32_layer_norm` can only be True when `norm_type` is 'layer_norm'.") self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": - self.norm = ( - LayerNorm(embedding_dim, eps, elementwise_affine, bias) - if not use_fp32_layer_norm - else FP32LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias) - ) + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias) elif norm_type == "rms_norm": self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) elif norm_type == "no_norm": diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index b50de3035fb6..c896431e72ba 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -106,7 +106,7 @@ class LavenderFlowDiTTransformerBlock(nn.Module): def __init__(self, dim, num_attention_heads, attention_head_dim): super().__init__() - self.norm1 = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") processor = LavenderFlowAttnProcessor2_0() self.attn = Attention( @@ -114,8 +114,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, - qk_norm="layer_norm", - use_fp32_layer_norm=True, + qk_norm="fp32_layer_norm", out_dim=dim, bias=False, out_bias=False, @@ -165,14 +164,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): self.is_last = is_last context_norm_type = "ada_norm_continous" if is_last else "ada_norm_zero" - self.norm1 = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( - dim, dim, elementwise_affine=False, bias=False, 
use_fp32_layer_norm=True + dim, dim, elementwise_affine=False, bias=False, norm_type="fp32_layer_norm" ) elif context_norm_type == "ada_norm_zero": - self.norm1_context = AdaLayerNormZero(dim, bias=False, use_fp32_layer_norm=True) + self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") processor = LavenderFlowAttnProcessor2_0() self.attn = Attention( @@ -181,9 +180,8 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, - qk_norm="layer_norm", - added_qk_norm="layer_norm", - use_fp32_layer_norm=True, + qk_norm="fp32_layer_norm", + added_qk_norm="fp32_layer_norm", out_dim=dim, bias=False, out_bias=False, From 10ed96f666a05b05c05ebdc1284afe30ec00235f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 15:34:18 +0530 Subject: [PATCH 16/37] eliminate added_qk_norm --- src/diffusers/models/attention_processor.py | 24 +++++++++---------- .../transformers/lavender_transformer_2d.py | 1 - 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3007aaa91685..620ed1a41d74 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -102,7 +102,6 @@ def __init__( cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, - added_qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -177,18 +176,6 @@ def __init__( else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") - if added_qk_norm is None: - self.norm_added_q = None - self.norm_added_k = None - elif added_qk_norm == "layer_norm": - self.norm_added_q = nn.LayerNorm(dim_head, eps=eps) - self.norm_added_k = nn.LayerNorm(dim_head, eps=eps) - elif added_qk_norm == "fp32_layer_norm": - self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - else: - raise ValueError(f"unknown qk_norm: {qk_norm}. 
Should be None or 'layer_norm'") - if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -235,6 +222,17 @@ def __init__( if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "layer_norm": + self.norm_added_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_added_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index c896431e72ba..e33a2937e0de 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -181,7 +181,6 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="fp32_layer_norm", - added_qk_norm="fp32_layer_norm", out_dim=dim, bias=False, out_bias=False, From 3d9265e9d67d3d5a40195c7db88d140949c7ab9d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 15:41:55 +0530 Subject: [PATCH 17/37] add added_proj_bias --- src/diffusers/models/attention_processor.py | 7 ++++--- .../models/transformers/lavender_transformer_2d.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 620ed1a41d74..257f393525b0 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -103,6 +103,7 @@ def __init__( cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, @@ -210,10 +211,10 @@ def __init__( self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/lavender_transformer_2d.py index e33a2937e0de..896273cdfaa3 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/lavender_transformer_2d.py @@ -178,6 +178,7 @@ def __init__(self, dim, num_attention_heads, 
attention_head_dim, is_last=False): query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, + added_proj_bias=False, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="fp32_layer_norm", From 84708c40be5cf55f62657f2ff81587b2d2c9f2d4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 6 Jul 2024 07:51:58 +0530 Subject: [PATCH 18/37] lavender flow -> aura flow --- scripts/convert_lavender_flow_to_diffusers.py | 4 +- src/diffusers/__init__.py | 8 ++-- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/attention_processor.py | 4 +- src/diffusers/models/transformers/__init__.py | 2 +- ...ormer_2d.py => auraflow_transformer_2d.py} | 42 +++++++++---------- src/diffusers/pipelines/__init__.py | 4 +- .../{lavender_flow => aura_flow}/__init__.py | 4 +- .../pipeline_aura_flow.py} | 8 ++-- src/diffusers/utils/dummy_pt_objects.py | 26 ++++++------ .../dummy_torch_and_transformers_objects.py | 30 ++++++------- 11 files changed, 68 insertions(+), 68 deletions(-) rename src/diffusers/models/transformers/{lavender_transformer_2d.py => auraflow_transformer_2d.py} (92%) rename src/diffusers/pipelines/{lavender_flow => aura_flow}/__init__.py (89%) rename src/diffusers/pipelines/{lavender_flow/pipeline_lavender_flow.py => aura_flow/pipeline_aura_flow.py} (98%) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py index f6f8a5064002..cf500c0aedcc 100644 --- a/scripts/convert_lavender_flow_to_diffusers.py +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -3,7 +3,7 @@ import torch from huggingface_hub import hf_hub_download -from diffusers.models.transformers.lavender_transformer_2d import LavenderFlowTransformer2DModel +from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel def load_original_state_dict(args): @@ -108,7 +108,7 @@ def populate_state_dict(args): single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") converted_state_dict = convert_transformer(original_state_dict) - model_diffusers = LavenderFlowTransformer2DModel( + model_diffusers = AuraFlowTransformer2DModel( num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers ) model_diffusers.load_state_dict(converted_state_dict, strict=True) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 652136732212..604bca829dca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ _import_structure["models"].extend( [ "AsymmetricAutoencoderKL", + "AuraFlowTransformer2DModel", "AutoencoderKL", "AutoencoderKLTemporalDecoder", "AutoencoderTiny", @@ -88,7 +89,6 @@ "HunyuanDiT2DMultiControlNetModel", "I2VGenXLUNet", "Kandinsky3UNet", - "LavenderFlowTransformer2DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -233,6 +233,7 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", + "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", @@ -268,7 +269,6 @@ "KandinskyV22PriorPipeline", "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", - "LavenderFlowPipeline", "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", @@ -499,6 +499,7 @@ else: from .models import ( AsymmetricAutoencoderKL, + AuraFlowTransformer2DModel, AutoencoderKL, AutoencoderKLTemporalDecoder, AutoencoderTiny, @@ -511,7 +512,6 @@ HunyuanDiT2DMultiControlNetModel, I2VGenXLUNet, Kandinsky3UNet, - LavenderFlowTransformer2DModel, ModelMixin, 
MotionAdapter, MultiAdapter, @@ -636,6 +636,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, AudioLDMPipeline, + AuraFlowPipeline, CLIPImageProjection, CycleDiffusionPipeline, HunyuanDiTControlNetPipeline, @@ -669,7 +670,6 @@ KandinskyV22PriorPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, - LavenderFlowPipeline, LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f1f7ff82c35a..4051a062034c 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,10 +38,10 @@ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] - _import_structure["transformers.lavender_transformer_2d"] = ["LavenderFlowTransformer2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] @@ -83,10 +83,10 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + AuraFlowTransformer2DModel, DiTTransformer2DModel, DualTransformer2DModel, HunyuanDiT2DModel, - LavenderFlowTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 257f393525b0..12a1eef34db2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1150,13 +1150,13 @@ def __call__( return hidden_states, encoder_hidden_states -class LavenderFlowAttnProcessor2_0: +class AuraFlowAttnProcessor2_0: """Attention processor used typically in processing Lavender Flow.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): raise ImportError( - "LavenderFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. 
" ) def __call__( diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 77f5f64321ce..b1b90241a41e 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -2,10 +2,10 @@ if is_torch_available(): + from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel - from .lavender_transformer_2d import LavenderFlowTransformer2DModel from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/src/diffusers/models/transformers/lavender_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py similarity index 92% rename from src/diffusers/models/transformers/lavender_transformer_2d.py rename to src/diffusers/models/transformers/auraflow_transformer_2d.py index 896273cdfaa3..feb19910d843 100644 --- a/src/diffusers/models/transformers/lavender_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI, Lavender Flow, The HuggingFace Team. All rights reserved. +# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention, LavenderFlowAttnProcessor2_0 +from ..attention_processor import Attention, AuraFlowAttnProcessor2_0 from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -32,16 +32,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Taken from the original lavender flow inference code. +# Taken from the original aura flow inference code. def find_multiple(n: int, k: int) -> int: if n % k == 0: return n return n + k - (n % k) -# Lavender Flow patch embed doesn't use convs for projections. +# Aura Flow patch embed doesn't use convs for projections. # Additionally, it uses learned positional embeddings. -class LavenderFlowPatchEmbed(nn.Module): +class AuraFlowPatchEmbed(nn.Module): def __init__( self, height=224, @@ -78,9 +78,9 @@ def forward(self, latent): return latent + self.pos_embed -# Taken from the original lavender flow inference code. -# Our feedforward only has GELU but lavender uses SiLU. -class LavenderFlowFeedForward(nn.Module): +# Taken from the original Aura flow inference code. +# Our feedforward only has GELU but Aura uses SiLU. 
+class AuraFlowFeedForward(nn.Module): def __init__(self, dim, hidden_dim=None) -> None: super().__init__() if hidden_dim is None: @@ -100,15 +100,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @maybe_allow_in_graph -class LavenderFlowDiTTransformerBlock(nn.Module): - """Similar `LavenderFlowTransformerBlock with a single DiT instead of an MMDiT.""" +class AuraFlowDiTTransformerBlock(nn.Module): + """Similar `AuraFlowTransformerBlock with a single DiT instead of an MMDiT.""" def __init__(self, dim, num_attention_heads, attention_head_dim): super().__init__() self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - processor = LavenderFlowAttnProcessor2_0() + processor = AuraFlowAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -122,7 +122,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): ) self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - self.ff = LavenderFlowFeedForward(dim, dim * 4) + self.ff = AuraFlowFeedForward(dim, dim * 4) def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999): residual = hidden_states @@ -144,9 +144,9 @@ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9 @maybe_allow_in_graph -class LavenderFlowTransformerBlock(nn.Module): +class AuraFlowTransformerBlock(nn.Module): r""" - Transformer block for Lavender Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): + Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): * QK Norm in the attention blocks * No bias in the attention blocks @@ -173,7 +173,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): elif context_norm_type == "ada_norm_zero": self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - processor = LavenderFlowAttnProcessor2_0() + processor = AuraFlowAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -190,10 +190,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): ) self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - self.ff = LavenderFlowFeedForward(dim, dim * 4) + self.ff = AuraFlowFeedForward(dim, dim * 4) self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) if not is_last: - self.ff_context = LavenderFlowFeedForward(dim, dim * 4) + self.ff_context = AuraFlowFeedForward(dim, dim * 4) else: self.ff_context = None @@ -229,7 +229,7 @@ def forward( return encoder_hidden_states, hidden_states -class LavenderFlowTransformer2DModel(ModelMixin, ConfigMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config @@ -252,7 +252,7 @@ def __init__( self.out_channels = out_channels if out_channels is not None else default_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.pos_embed = LavenderFlowPatchEmbed( + self.pos_embed = AuraFlowPatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, @@ -269,7 +269,7 @@ def __init__( self.joint_transformer_blocks = nn.ModuleList( [ - LavenderFlowTransformerBlock( + AuraFlowTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, @@ -279,7 +279,7 @@ def __init__( ) self.single_transformer_blocks = nn.ModuleList( [ - LavenderFlowDiTTransformerBlock( + 
AuraFlowDiTTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6fa7671f2321..7f4014019577 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -242,7 +242,7 @@ "StableDiffusionLDM3DPipeline", ] ) - _import_structure["lavender_flow"] = ["LavenderFlowPipeline"] + _import_structure["aura_flow"] = ["AuraFlowPipeline"] _import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"] _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] @@ -407,6 +407,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) + from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, @@ -481,7 +482,6 @@ LatentConsistencyModelPipeline, ) from .latent_diffusion import LDMTextToImagePipeline - from .lavender_flow import LavenderFlowPipeline from .ledits_pp import ( LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput, diff --git a/src/diffusers/pipelines/lavender_flow/__init__.py b/src/diffusers/pipelines/aura_flow/__init__.py similarity index 89% rename from src/diffusers/pipelines/lavender_flow/__init__.py rename to src/diffusers/pipelines/aura_flow/__init__.py index c81c38e262c0..e1917baa61e2 100644 --- a/src/diffusers/pipelines/lavender_flow/__init__.py +++ b/src/diffusers/pipelines/aura_flow/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_lavender_flow"] = ["LavenderFlowPipeline"] + _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_lavender_flow import LavenderFlowPipeline + from .pipeline_aura_flow import AuraFlowPipeline else: import sys diff --git a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py similarity index 98% rename from src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py rename to src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 05778f1a147c..92d93fb29f80 100644 --- a/src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -1,4 +1,4 @@ -# Copyright 2024 Lavender-Flow Authors and The HuggingFace Team. All rights reserved. +# Copyright 2024 AuraFlow Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ from transformers import T5Tokenizer, UMT5EncoderModel from ...image_processor import VaeImageProcessor -from ...models import AutoencoderKL, LavenderFlowTransformer2DModel +from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging @@ -89,7 +89,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LavenderFlowPipeline(DiffusionPipeline): +class AuraFlowPipeline(DiffusionPipeline): _optional_components = ["tokenizer", "text_encoder"] model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -98,7 +98,7 @@ def __init__( tokenizer: T5Tokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKL, - transformer: LavenderFlowTransformer2DModel, + transformer: AuraFlowTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9170582911ca..57fde90c236d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderKL(metaclass=DummyObject): +class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -32,7 +32,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderKLTemporalDecoder(metaclass=DummyObject): +class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -47,7 +47,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderTiny(metaclass=DummyObject): +class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -62,7 +62,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ConsistencyDecoderVAE(metaclass=DummyObject): +class AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -77,7 +77,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ControlNetModel(metaclass=DummyObject): +class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -92,7 +92,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ControlNetXSAdapter(metaclass=DummyObject): +class ControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -107,7 +107,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DiTTransformer2DModel(metaclass=DummyObject): +class ControlNetXSAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -122,7 +122,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DControlNetModel(metaclass=DummyObject): +class DiTTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -137,7 +137,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DModel(metaclass=DummyObject): +class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ 
-152,7 +152,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): +class HunyuanDiT2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -167,7 +167,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class I2VGenXLUNet(metaclass=DummyObject): +class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -182,7 +182,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class Kandinsky3UNet(metaclass=DummyObject): +class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -197,7 +197,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class LavenderFlowTransformer2DModel(metaclass=DummyObject): +class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 25ef41e0f291..457de6039b6f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -182,6 +182,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AuraFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -677,21 +692,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LavenderFlowPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 4bdea0d224ed05ca92748261e065bef694fd9eba Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 7 Jul 2024 07:29:06 +0530 Subject: [PATCH 19/37] Fix the `added_proj_bias` default value (#8800) checking --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 12a1eef34db2..55096d180ea2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -103,7 +103,7 @@ def __init__( cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = None, + added_proj_bias: Optional[bool] = True, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, From 89fad69fab0d42eb9f73b75b07ff51f14008801c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 
Sun, 7 Jul 2024 07:55:08 +0530 Subject: [PATCH 20/37] remnant aura flow renaming --- scripts/convert_lavender_flow_to_diffusers.py | 2 +- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/models/normalization.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_lavender_flow_to_diffusers.py index cf500c0aedcc..2792dda0fec8 100644 --- a/scripts/convert_lavender_flow_to_diffusers.py +++ b/scripts/convert_lavender_flow_to_diffusers.py @@ -119,7 +119,7 @@ def populate_state_dict(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str) - parser.add_argument("--dump_path", default="lavender-flow", type=str) + parser.add_argument("--dump_path", default="aura-flow", type=str) parser.add_argument("--hub_id", default=None, type=str) args = parser.parse_args() diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 55096d180ea2..34962ff8ae77 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1151,7 +1151,7 @@ def __call__( class AuraFlowAttnProcessor2_0: - """Attention processor used typically in processing Lavender Flow.""" + """Attention processor used typically in processing Aura Flow.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 3b87eeb822c0..c05b6bb07766 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -208,7 +208,7 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torc # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT) emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = torch.chunk(emb, 2, dim=1) - # lavender flow doesn't have a norm at one place here + # aura flow doesn't have a norm at one place here if self.norm is not None: x = self.norm(x) x = x * (1 + scale)[:, None, :] + shift[:, None, :] From e73442f1c8392b36838778c7451637832070eac4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 7 Jul 2024 09:11:20 +0530 Subject: [PATCH 21/37] make it possible to reuse prompt embeds. 
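To illustrate what this change enables, the sketch below encodes a prompt once with `encode_prompt` and then reuses the returned embeddings across several sampling calls instead of re-running the text encoder each time. It is a rough, illustrative sketch rather than part of the committed diff: the checkpoint identifier, prompt text, seeds, and the fp16/CUDA placement are placeholders not taken from this series (the conversion script above only names the original weights repo), and keyword arguments are used throughout because the positional order of `encode_prompt` is reshuffled in a later patch.

    import torch
    from diffusers import AuraFlowPipeline

    # Placeholder: substitute a diffusers-format AuraFlow checkpoint, e.g. a
    # directory assembled from the output of scripts/convert_aura_flow_to_diffusers.py.
    pipe = AuraFlowPipeline.from_pretrained(
        "<path-or-repo-of-converted-aura-flow-checkpoint>", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")

    # Encode once; the pipeline returns the positive/negative embeddings and
    # their attention masks as a 4-tuple.
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="A photo of a lavender field at sunrise",
        negative_prompt="",
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
    )

    # Reuse the cached embeddings for several seeds without re-running UMT5.
    for seed in (0, 1, 2):
        image = pipe(
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            num_inference_steps=50,
            guidance_scale=5.0,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).images[0]
        image.save(f"aura_flow_{seed}.png")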
--- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 92d93fb29f80..97a0e234f33d 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -107,7 +107,9 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs From dccc6827d4d01f46ee95646bdb27b931870682d0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 7 Jul 2024 18:46:36 +0530 Subject: [PATCH 22/37] rename to auraflow --- ...der_flow_to_diffusers.py => convert_aura_flow_to_diffusers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{convert_lavender_flow_to_diffusers.py => convert_aura_flow_to_diffusers.py} (100%) diff --git a/scripts/convert_lavender_flow_to_diffusers.py b/scripts/convert_aura_flow_to_diffusers.py similarity index 100% rename from scripts/convert_lavender_flow_to_diffusers.py rename to scripts/convert_aura_flow_to_diffusers.py From f23151b1e8b655baecc793ff5a81d66d77a04c92 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 7 Jul 2024 09:26:55 -1000 Subject: [PATCH 23/37] [lavender-flow] use flow match euler scheduler (#8799) * suppoort custom sigmas * style * apply feedbacks --- .../pipelines/aura_flow/pipeline_aura_flow.py | 21 ++++++---------- .../scheduling_flow_match_euler_discrete.py | 25 ++++++++++++------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 97a0e234f33d..261d8e93dcef 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -430,13 +430,14 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare timesteps + + # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels - effective_batch_size = batch_size * num_images_per_prompt latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, @@ -450,21 +451,15 @@ def __call__( # 6. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - dt = 1.0 / num_inference_steps - dt = ( - torch.tensor([dt] * effective_batch_size) - .to(self.device) - .view([effective_batch_size, *([1] * len(latents.shape[1:]))]) - ) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(range(num_inference_steps, 0, -1)): + for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - t = t / num_inference_steps - timestep = ( - torch.tensor([t]).expand(latent_model_input.shape[0]).to(latents.device, dtype=latents.dtype) - ) + timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) + timestep = timestep.to(latents.device, dtype=latents.dtype) # predict noise model_output noise_pred = self.transformer( @@ -480,7 +475,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = (latents - dt * noise_pred).to(latents.dtype) + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 83ce63981abd..779e691f0c27 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -158,7 +158,12 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -168,17 +173,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" - self.num_inference_steps = num_inference_steps - timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps - ) + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) From 8984d231cc9dcfa127cd14c12c4da0d1d5aec180 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 05:31:32 +0200 Subject: [PATCH 24/37] more feedback. --- scripts/convert_aura_flow_to_diffusers.py | 4 ++-- src/diffusers/models/attention_processor.py | 7 ++----- .../models/transformers/auraflow_transformer_2d.py | 14 +++++++------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/scripts/convert_aura_flow_to_diffusers.py b/scripts/convert_aura_flow_to_diffusers.py index 2792dda0fec8..6b59737bf0e4 100644 --- a/scripts/convert_aura_flow_to_diffusers.py +++ b/scripts/convert_aura_flow_to_diffusers.py @@ -52,7 +52,7 @@ def convert_transformer(state_dict): # feed-forward path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} for orig_k, diffuser_k in path_mapping.items(): - for k in ["c_fc1", "c_fc2", "c_proj"]: + for k in ["linear_1", "linear_2", "out_projection"]: converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{k}.weight"] = state_dict.pop( f"model.double_layers.{i}.{orig_k}.{k}.weight" ) @@ -76,7 +76,7 @@ def convert_transformer(state_dict): # Single-DiT blocks. 
for i in range(single_dit_layers): # feed-forward - for k in ["c_fc1", "c_fc2", "c_proj"]: + for k in ["linear_1", "linear_2", "out_projection"]: converted_state_dict[f"single_transformer_blocks.{i}.ff.{k}.weight"] = state_dict.pop( f"model.single_layers.{i}.mlp.{k}.weight" ) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0e86f011cbb5..5c10f05c0320 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph +from .normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -119,7 +120,6 @@ def __init__( context_pre_only=None, ): super().__init__() - from .normalization import FP32LayerNorm self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads @@ -230,10 +230,7 @@ def __init__( self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: - if qk_norm == "layer_norm": - self.norm_added_q = nn.LayerNorm(dim_head, eps=eps) - self.norm_added_k = nn.LayerNorm(dim_head, eps=eps) - elif qk_norm == "fp32_layer_norm": + if qk_norm == "fp32_layer_norm": self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) else: diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index feb19910d843..c36bf17bee59 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -86,16 +86,16 @@ def __init__(self, dim, hidden_dim=None) -> None: if hidden_dim is None: hidden_dim = 4 * dim - n_hidden = int(2 * hidden_dim / 3) - n_hidden = find_multiple(n_hidden, 256) + final_hidden_dim = int(2 * hidden_dim / 3) + final_hidden_dim = find_multiple(final_hidden_dim, 256) - self.c_fc1 = nn.Linear(dim, n_hidden, bias=False) - self.c_fc2 = nn.Linear(dim, n_hidden, bias=False) - self.c_proj = nn.Linear(n_hidden, dim, bias=False) + self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False) + self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False) + self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.silu(self.c_fc1(x)) * self.c_fc2(x) - x = self.c_proj(x) + x = F.silu(self.linear_1(x)) * self.linear_2(x) + x = self.out_projection(x) return x From a28154744795ba9f84e3ad382ca902e1b069f5db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 05:35:00 +0200 Subject: [PATCH 25/37] context_norm_type fix --- .../models/transformers/auraflow_transformer_2d.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index c36bf17bee59..8b15917a86f1 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -159,10 +159,8 @@ class AuraFlowTransformerBlock(nn.Module): is_last (`bool`): Boolean to determine if this is 
the last block in the model. """ - def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): + def __init__(self, dim, num_attention_heads, attention_head_dim, context_norm_type="ada_norm_zero"): super().__init__() - self.is_last = is_last - context_norm_type = "ada_norm_continous" if is_last else "ada_norm_zero" self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") @@ -172,6 +170,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): ) elif context_norm_type == "ada_norm_zero": self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + else: + raise ValueError( + "Invalid norm type provided for `context_norm_type`. Valid values are are: 'ada_norm_continous' and 'ada_norm_zero'." + ) processor = AuraFlowAttnProcessor2_0() self.attn = Attention( @@ -192,10 +194,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, is_last=False): self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.ff = AuraFlowFeedForward(dim, dim * 4) self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - if not is_last: - self.ff_context = AuraFlowFeedForward(dim, dim * 4) - else: - self.ff_context = None + self.ff_context = AuraFlowFeedForward(dim, dim * 4) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0 From 8830bf10dda89ef5b742e8ce3337f3b5ee7cfa72 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 05:39:43 +0200 Subject: [PATCH 26/37] fix circular import --- src/diffusers/models/attention_processor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5c10f05c0320..961bcd2a49e8 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,7 +23,6 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph -from .normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -121,6 +120,9 @@ def __init__( ): super().__init__() + # To prevent circular import. 
+ from .normalization import FP32LayerNorm + self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim From 4334f72f3645db4484acc4730879b4d6557be2a5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 06:09:03 +0200 Subject: [PATCH 27/37] fix conversion --- scripts/convert_aura_flow_to_diffusers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/scripts/convert_aura_flow_to_diffusers.py b/scripts/convert_aura_flow_to_diffusers.py index 6b59737bf0e4..74c34f4851ff 100644 --- a/scripts/convert_aura_flow_to_diffusers.py +++ b/scripts/convert_aura_flow_to_diffusers.py @@ -7,7 +7,7 @@ def load_original_state_dict(args): - model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="model.bin") + model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin") state_dict = torch.load(model_pt, map_location="cpu") return state_dict @@ -51,9 +51,10 @@ def convert_transformer(state_dict): for i in range(mmdit_layers): # feed-forward path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} for orig_k, diffuser_k in path_mapping.items(): - for k in ["linear_1", "linear_2", "out_projection"]: - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{k}.weight"] = state_dict.pop( + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop( f"model.double_layers.{i}.{orig_k}.{k}.weight" ) @@ -76,8 +77,9 @@ def convert_transformer(state_dict): # Single-DiT blocks. for i in range(single_dit_layers): # feed-forward - for k in ["linear_1", "linear_2", "out_projection"]: - converted_state_dict[f"single_transformer_blocks.{i}.ff.{k}.weight"] = state_dict.pop( + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop( f"model.single_layers.{i}.mlp.{k}.weight" ) From b1dc5ec748417b5804c98368dbe5eb2606000632 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 06:55:28 +0200 Subject: [PATCH 28/37] add fast tests for pipeline --- .../pipelines/aura_flow/pipeline_aura_flow.py | 11 +- tests/pipelines/aura_flow/__init__.py | 0 .../test_pipeline_stable_diffusion_3.py | 120 ++++++++++++++++++ 3 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 tests/pipelines/aura_flow/__init__.py create mode 100644 tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 261d8e93dcef..2ca935189e7c 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -90,7 +90,7 @@ def retrieve_timesteps( class AuraFlowPipeline(DiffusionPipeline): - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( @@ -183,8 +183,8 @@ def check_inputs( def encode_prompt( self, prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, do_classifier_free_guidance: bool = True, - negative_prompt: str = "This is watermark, jpeg image white background, web image", 
num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -269,6 +269,7 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( @@ -357,7 +358,7 @@ def upcast_vae(self): def __call__( self, prompt: Union[str, List[str]] = None, - negative_prompt: str = "This is watermark, jpeg image white background, web image", + negative_prompt: Union[str, List[str]] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, @@ -415,9 +416,9 @@ def __call__( negative_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, + prompt=prompt, negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, prompt_embeds=prompt_embeds, diff --git a/tests/pipelines/aura_flow/__init__.py b/tests/pipelines/aura_flow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py b/tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py new file mode 100644 index 000000000000..5b0b05256b43 --- /dev/null +++ b/tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py @@ -0,0 +1,120 @@ +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler +from diffusers.utils.testing_utils import ( + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = AuraFlowPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AuraFlowTransformer2DModel( + sample_size=32, + patch_size=2, + in_channels=4, + num_mmdit_layers=1, + num_single_dit_layers=1, + attention_head_dim=8, + num_attention_heads=4, + caption_projection_dim=32, + joint_attention_dim=32, + out_channels=4, + pos_embed_max_size=256, + ) + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a 
burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "height": None, + "width": None, + } + return inputs + + def test_aura_flow_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + do_classifier_free_guidance = inputs["guidance_scale"] > 1 + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = pipe.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + device=torch_device, + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_attention_slicing_forward_pass(self): + # Attention slicing needs to implemented differently for this because how single DiT and MMDiT + # blocks interfere with each other. + return From 942377d2e75fd7e7d7b761f08374be9eade6a048 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 06:56:34 +0200 Subject: [PATCH 29/37] fix test file name --- ...t_pipeline_stable_diffusion_3.py => test_pipeline_aura_dlow.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py => test_pipeline_aura_dlow.py (100%) diff --git a/tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py b/test_pipeline_aura_dlow.py similarity index 100% rename from tests/pipelines/aura_flow/test_pipeline_stable_diffusion_3.py rename to test_pipeline_aura_dlow.py From f8a08b541d7ddd16cae440dbeaf7919f8d4f3f36 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 06:57:49 +0200 Subject: [PATCH 30/37] fix test path --- .../pipelines/aura_flow/test_pipeline_aura_dlow.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test_pipeline_aura_dlow.py => tests/pipelines/aura_flow/test_pipeline_aura_dlow.py (100%) diff --git a/test_pipeline_aura_dlow.py b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py similarity index 100% rename from test_pipeline_aura_dlow.py rename to tests/pipelines/aura_flow/test_pipeline_aura_dlow.py From e9832f9b9be5e2ca808f206bdff10200d1e1f56f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 06:59:15 +0200 Subject: [PATCH 31/37] spacxing brtween initialization --- tests/pipelines/aura_flow/test_pipeline_aura_dlow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py index 5b0b05256b43..154e47e88b2f 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py @@ -42,6 +42,7 @@ def get_dummy_components(self): out_channels=4, pos_embed_max_size=256, ) + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") From 2c872503e1f02469a15056c2089d63046f88774c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 07:00:40 +0200 Subject: [PATCH 32/37] style --- tests/pipelines/aura_flow/test_pipeline_aura_dlow.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py index 154e47e88b2f..9a2f1846f1d9 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py @@ -42,7 +42,7 @@ def get_dummy_components(self): out_channels=4, pos_embed_max_size=256, ) - + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") From 1b3e620d44fef5cbe6dffb64df97196ee54b9c81 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 08:58:48 +0200 Subject: [PATCH 33/37] add test for the transformer model. --- .../test_models_transformer_aura_flow.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_aura_flow.py diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py new file mode 100644 index 000000000000..57fac4ba769c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_aura_flow.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AuraFlowTransformer2DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = AuraFlowTransformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = embedding_dim = 32 + sequence_length = 256 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "patch_size": 2, + "in_channels": 4, + "num_mmdit_layers": 1, + "num_single_dit_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "out_channels": 4, + "pos_embed_max_size": 256, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict From ed33913637f690af94159c15f7889b8bb3be0bdc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Jul 2024 17:55:37 +0200 Subject: [PATCH 34/37] remove context_norm_type --- .../models/transformers/auraflow_transformer_2d.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 8b15917a86f1..c55e9a1aacd5 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -159,21 +159,11 @@ class AuraFlowTransformerBlock(nn.Module): is_last (`bool`): Boolean to determine if this is the last block in the model. """ - def __init__(self, dim, num_attention_heads, attention_head_dim, context_norm_type="ada_norm_zero"): + def __init__(self, dim, num_attention_heads, attention_head_dim): super().__init__() self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - - if context_norm_type == "ada_norm_continous": - self.norm1_context = AdaLayerNormContinuous( - dim, dim, elementwise_affine=False, bias=False, norm_type="fp32_layer_norm" - ) - elif context_norm_type == "ada_norm_zero": - self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - else: - raise ValueError( - "Invalid norm type provided for `context_norm_type`. Valid values are are: 'ada_norm_continous' and 'ada_norm_zero'." - ) + self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") processor = AuraFlowAttnProcessor2_0() self.attn = Attention( From 6531e5453b3bcffd3b0f1cfcffcf22974230a4c4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 06:42:46 +0200 Subject: [PATCH 35/37] remove ada continuous. 
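This commit drops the "no_norm" and fp32 branches that an earlier patch threaded through `AdaLayerNormContinuous`, restoring that class to its upstream form, and moves the norm-free scale/shift modulation AuraFlow applies before the final projection into a dedicated `AuraFlowPreFinalBlock`. A minimal shape-check sketch for that block follows; it assumes this branch is installed from source, and the tensor sizes are purely illustrative:

# AuraFlowPreFinalBlock runs SiLU -> Linear on the pooled conditioning embedding
# and modulates the hidden states with the resulting scale/shift; unlike
# AdaLayerNormContinuous, no LayerNorm is applied first.
import torch

from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowPreFinalBlock

block = AuraFlowPreFinalBlock(embedding_dim=32, conditioning_embedding_dim=32)
hidden_states = torch.randn(2, 256, 32)  # (batch, seq_len, dim)
temb = torch.randn(2, 32)                # pooled timestep/conditioning embedding

out = block(hidden_states, temb)
assert out.shape == hidden_states.shape
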
--- src/diffusers/models/normalization.py | 12 ++--------- .../transformers/auraflow_transformer_2d.py | 21 +++++++++++++++++-- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 73ea1acd6ecc..4e532f3fc990 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -224,28 +224,20 @@ def __init__( norm_type="layer_norm", ): super().__init__() - self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) - elif norm_type == "fp32_layer_norm": - self.norm = FP32LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias) elif norm_type == "rms_norm": self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) - elif norm_type == "no_norm": - self.norm = None else: raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: - # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT) + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = torch.chunk(emb, 2, dim=1) - # aura flow doesn't have a norm at one place here - if self.norm is not None: - x = self.norm(x) - x = x * (1 + scale)[:, None, :] + shift[:, None, :] + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index c55e9a1aacd5..f62210ba0313 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -26,7 +26,7 @@ from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, FP32LayerNorm +from ..normalization import AdaLayerNormZero, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -97,6 +97,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.linear_1(x)) * self.linear_2(x) x = self.out_projection(x) return x + +class AuraFlowPreFinalBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int + ): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale)[:, None, :] + shift[:, None, :] + return x @maybe_allow_in_graph @@ -277,7 +294,7 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, norm_type="no_norm", bias=False) + self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) # https://arxiv.org/abs/2309.16588 From 0f721acc9418155fe6abba5264e7ac78fa68ffd4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Jul 2024 06:46:23 
+0200 Subject: [PATCH 36/37] address yiyi --- .../transformers/auraflow_transformer_2d.py | 21 ++++++++----------- .../pipelines/aura_flow/pipeline_aura_flow.py | 20 ++---------------- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index f62210ba0313..eb3b749c88c5 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -97,18 +97,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.linear_1(x)) * self.linear_2(x) x = self.out_projection(x) return x - + + class AuraFlowPreFinalBlock(nn.Module): - def __init__( - self, - embedding_dim: int, - conditioning_embedding_dim: int - ): + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) - + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = torch.chunk(emb, 2, dim=1) @@ -117,8 +114,8 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torc @maybe_allow_in_graph -class AuraFlowDiTTransformerBlock(nn.Module): - """Similar `AuraFlowTransformerBlock with a single DiT instead of an MMDiT.""" +class AuraFlowSingleTransformerBlock(nn.Module): + """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" def __init__(self, dim, num_attention_heads, attention_head_dim): super().__init__() @@ -161,7 +158,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9 @maybe_allow_in_graph -class AuraFlowTransformerBlock(nn.Module): +class AuraFlowJointTransformerBlock(nn.Module): r""" Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): @@ -275,7 +272,7 @@ def __init__( self.joint_transformer_blocks = nn.ModuleList( [ - AuraFlowTransformerBlock( + AuraFlowJointTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, @@ -285,7 +282,7 @@ def __init__( ) self.single_transformer_blocks = nn.ModuleList( [ - AuraFlowDiTTransformerBlock( + AuraFlowSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 2ca935189e7c..73b149e853cf 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from transformers import T5Tokenizer, UMT5EncoderModel @@ -112,14 +112,12 @@ def __init__( ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs def check_inputs( self, prompt, height, width, negative_prompt, - callback_steps, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, @@ -128,14 +126,6 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -372,11 +362,9 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: # 1. Check inputs. Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor @@ -387,7 +375,6 @@ def __call__( height, width, negative_prompt, - callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, @@ -481,9 +468,6 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) if output_type == "latent": image = latents From 15d3198b15d1a14b2b21b0a33f37577ab8fd3dba Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 11 Jul 2024 20:35:45 +0200 Subject: [PATCH 37/37] style --- src/diffusers/utils/dummy_torch_and_transformers_objects.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 072a538f4b7f..97fc0ca1b6d0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -182,7 +182,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - class AuraFlowPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]
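
A minimal end-to-end usage sketch for the `AuraFlowPipeline` added in this series; it assumes the branch is installed from source, a CUDA device is available, and a converted checkpoint exists on the Hub ("fal/AuraFlow" below is a placeholder repo id, not something defined by these patches):

# Sketch of running the new AuraFlowPipeline; the call arguments mirror the
# __call__ signature introduced above (negative_prompt falls back to "" when
# classifier-free guidance is enabled).
import torch

from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

image = pipe(
    prompt="A photo of an astronaut riding a horse on the moon",
    negative_prompt=None,
    num_inference_steps=50,
    guidance_scale=5.0,
    max_sequence_length=256,
).images[0]
image.save("auraflow.png")

The fast tests added in this series can be run with, for example, `pytest tests/pipelines/aura_flow tests/models/transformers/test_models_transformer_aura_flow.py`.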