diff --git a/scripts/convert_aura_flow_to_diffusers.py b/scripts/convert_aura_flow_to_diffusers.py new file mode 100644 index 000000000000..74c34f4851ff --- /dev/null +++ b/scripts/convert_aura_flow_to_diffusers.py @@ -0,0 +1,131 @@ +import argparse + +import torch +from huggingface_hub import hf_hub_download + +from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel + + +def load_original_state_dict(args): + model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin") + state_dict = torch.load(model_pt, map_location="cpu") + return state_dict + + +def calculate_layers(state_dict_keys, key_prefix): + dit_layers = set() + for k in state_dict_keys: + if key_prefix in k: + dit_layers.add(int(k.split(".")[2])) + print(f"{key_prefix}: {len(dit_layers)}") + return len(dit_layers) + + +# similar to SD3 but only for the last norm layer +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_transformer(state_dict): + converted_state_dict = {} + state_dict_keys = list(state_dict.keys()) + + converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens") + converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias") + + converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight") + converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias") + converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight") + converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias") + + converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight") + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks 🎸. + for i in range(mmdit_layers): + # feed-forward + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for orig_k, diffuser_k in path_mapping.items(): + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.{k}.weight" + ) + + # norms + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop( + f"model.double_layers.{i}.{orig_k}.1.weight" + ) + + # attns + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.double_layers.{i}.attn.{k}.weight" + ) + + # Single-DiT blocks. 
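+    # The single (non-MM) DiT layers live under `model.single_layers.{i}.*` in the original checkpoint;
+    # the loop below renames them to the diffusers layout: mlp.{c_fc1,c_fc2,c_proj} -> ff.{linear_1,linear_2,out_projection},
+    # modCX.1 -> norm1.linear, and attn.w1{q,k,v,o} -> attn.{to_q,to_k,to_v,to_out.0}.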
+ for i in range(single_dit_layers): + # feed-forward + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop( + f"model.single_layers.{i}.mlp.{k}.weight" + ) + + # norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop( + f"model.single_layers.{i}.modCX.1.weight" + ) + + # attns + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop( + f"model.single_layers.{i}.attn.{k}.weight" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None) + + return converted_state_dict + + +@torch.no_grad() +def populate_state_dict(args): + original_state_dict = load_original_state_dict(args) + state_dict_keys = list(original_state_dict.keys()) + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + converted_state_dict = convert_transformer(original_state_dict) + model_diffusers = AuraFlowTransformer2DModel( + num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers + ) + model_diffusers.load_state_dict(converted_state_dict, strict=True) + + return model_diffusers + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str) + parser.add_argument("--dump_path", default="aura-flow", type=str) + parser.add_argument("--hub_id", default=None, type=str) + args = parser.parse_args() + + model_diffusers = populate_state_dict(args) + model_diffusers.save_pretrained(args.dump_path) + if args.hub_id is not None: + model_diffusers.push_to_hub(args.hub_id) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 873c30ee4be8..f230a9bcad2b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ _import_structure["models"].extend( [ "AsymmetricAutoencoderKL", + "AuraFlowTransformer2DModel", "AutoencoderKL", "AutoencoderKLTemporalDecoder", "AutoencoderTiny", @@ -235,6 +236,7 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", + "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "ChatGLMModel", @@ -507,6 +509,7 @@ else: from .models import ( AsymmetricAutoencoderKL, + AuraFlowTransformer2DModel, AutoencoderKL, AutoencoderKLTemporalDecoder, AutoencoderTiny, @@ -646,6 +649,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, AudioLDMPipeline, + AuraFlowPipeline, ChatGLMModel, ChatGLMTokenizer, CLIPImageProjection, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 960d18dc0feb..39dc149ff6d1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,6 +38,7 @@ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.dit_transformer_2d"] = 
["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] @@ -84,6 +85,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + AuraFlowTransformer2DModel, DiTTransformer2DModel, DualTransformer2DModel, HunyuanDiT2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ef25d24f9f1a..961bcd2a49e8 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -22,7 +22,7 @@ from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available -from ..utils.torch_utils import maybe_allow_in_graph +from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -104,6 +104,7 @@ def __init__( cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, @@ -118,6 +119,10 @@ def __init__( context_pre_only=None, ): super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm + self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim @@ -170,6 +175,9 @@ def __init__( elif qk_norm == "layer_norm": self.norm_q = nn.LayerNorm(dim_head, eps=eps) self.norm_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "layer_norm_across_heads": # Lumina applys qk norm across all heads self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) @@ -211,10 +219,10 @@ def __init__( self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim) + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) @@ -223,6 +231,14 @@ def __init__( if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for 
native Flash/memory_efficient_attention @@ -1137,6 +1153,100 @@ def __call__( return hidden_states, encoder_hidden_states +class AuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + i=0, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. 
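A minimal sketch of how the new `AuraFlowAttnProcessor2_0` and the `fp32_layer_norm` / `added_proj_bias` options compose with `Attention`, mirroring the arguments that `AuraFlowJointTransformerBlock` passes further below; the toy dimensions and tensor shapes are illustrative only:

    import torch
    from diffusers.models.attention_processor import Attention, AuraFlowAttnProcessor2_0

    dim, heads, head_dim = 64, 4, 16  # toy sizes; the full AuraFlow model uses much larger values

    attn = Attention(
        query_dim=dim,
        cross_attention_dim=None,
        added_kv_proj_dim=dim,              # enables add_q/k/v_proj for the text stream
        added_proj_bias=False,              # new constructor flag
        dim_head=head_dim,
        heads=heads,
        qk_norm="fp32_layer_norm",          # new qk_norm option
        out_dim=dim,
        bias=False,
        out_bias=False,
        processor=AuraFlowAttnProcessor2_0(),
        context_pre_only=False,
    )

    image_tokens = torch.randn(2, 16, dim)  # (batch, image tokens, dim)
    text_tokens = torch.randn(2, 8, dim)    # (batch, text tokens, dim)
    image_out, text_out = attn(hidden_states=image_tokens, encoder_hidden_states=text_tokens)
    # image_out: (2, 16, 64), text_out: (2, 8, 64)
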
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ec1c68b86c89..0890842f5775 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -473,11 +473,12 @@ def forward(self, sample, condition=None): class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift + self.scale = scale def forward(self, timesteps): t_emb = get_timestep_embedding( @@ -485,6 +486,7 @@ def forward(self, timesteps): self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, ) return t_emb diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 16d76faad0c5..4e532f3fc990 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -51,6 +51,18 @@ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: return x +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + class AdaLayerNormZero(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). @@ -60,7 +72,7 @@ class AdaLayerNormZero(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): super().__init__() if num_embeddings is not None: self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) @@ -68,8 +80,15 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): self.emb = None self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." 
+ ) def forward( self, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 9c5c9b6dbe16..ae5103160790 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py new file mode 100644 index 000000000000..eb3b749c88c5 --- /dev/null +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -0,0 +1,402 @@ +# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_processor import Attention, AuraFlowAttnProcessor2_0 +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormZero, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Taken from the original aura flow inference code. +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +# Aura Flow patch embed doesn't use convs for projections. +# Additionally, it uses learned positional embeddings. +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + return latent + self.pos_embed + + +# Taken from the original Aura flow inference code. +# Our feedforward only has GELU but Aura uses SiLU. 
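+# The feed-forward below is a SwiGLU-style MLP: silu(linear_1(x)) * linear_2(x) followed by `out_projection`,
+# with the hidden width set to two thirds of `4 * dim` and rounded up to the next multiple of 256.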
+class AuraFlowFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=None) -> None: + super().__init__() + if hidden_dim is None: + hidden_dim = 4 * dim + + final_hidden_dim = int(2 * hidden_dim / 3) + final_hidden_dim = find_multiple(final_hidden_dim, 256) + + self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False) + self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False) + self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.linear_1(x)) * self.linear_2(x) + x = self.out_projection(x) + return x + + +class AuraFlowPreFinalBlock(nn.Module): + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +@maybe_allow_in_graph +class AuraFlowSingleTransformerBlock(nn.Module): + """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999): + residual = hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + # Attention. + attn_output = self.attn(hidden_states=norm_hidden_states, i=i) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(hidden_states) + hidden_states = gate_mlp.unsqueeze(1) * ff_output + hidden_states = residual + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class AuraFlowJointTransformerBlock(nn.Module): + r""" + Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): + + * QK Norm in the attention blocks + * No bias in the attention blocks + * Most LayerNorms are in FP32 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + is_last (`bool`): Boolean to determine if this is the last block in the model. 
+ """ + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + added_proj_bias=False, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + context_pre_only=False, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff_context = AuraFlowFeedForward(dim, dim * 4) + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0 + ): + residual = hidden_states + residual_context = encoder_hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i + ) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) + hidden_states = residual + hidden_states + + # Process attention outputs for the `encoder_hidden_states`. 
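+        # Same adaLN-Zero pattern as the image stream above: a gated residual over the attention output,
+        # followed by a scale/shift-modulated norm and a gated feed-forward on the text tokens.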
+ encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) + encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) + encoder_hidden_states = residual_context + encoder_hidden_states + + return encoder_hidden_states, hidden_states + + +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 64, + patch_size: int = 2, + in_channels: int = 4, + num_mmdit_layers: int = 4, + num_single_dit_layers: int = 32, + attention_head_dim: int = 256, + num_attention_heads: int = 12, + joint_attention_dim: int = 2048, + caption_projection_dim: int = 3072, + out_channels: int = 4, + pos_embed_max_size: int = 1024, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = AuraFlowPatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.context_embedder = nn.Linear( + self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False + ) + self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) + self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.joint_transformer_blocks = nn.ModuleList( + [ + AuraFlowJointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_mmdit_layers) + ] + ) + self.single_transformer_blocks = nn.ModuleList( + [ + AuraFlowSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_single_dit_layers) + ] + ) + + self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + # https://arxiv.org/abs/2309.16588 + # prevents artifacts in the attention maps + self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + height, width = hidden_states.shape[-2:] + + # Apply patch embedding, timestep embedding, and project the caption embeddings. + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
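+        # `pos_embed` patchifies (B, C, H, W) into (B, (H // p) * (W // p), p * p * C) patch tokens,
+        # projects them to `inner_dim` with a linear layer, and adds the learned positional embeddings.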
+ temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) + temb = self.time_step_proj(temb) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + encoder_hidden_states = torch.cat( + [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 + ) + + # MMDiT blocks. + for index_block, block in enumerate(self.joint_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block + ) + + # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) + if len(self.single_transformer_blocks) > 0: + encoder_seq_len = encoder_hidden_states.size(1) + combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + combined_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + combined_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) + + hidden_states = combined_hidden_states[:, encoder_seq_len:] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + out_channels = self.config.out_channels + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 8313ffd87a50..cc0dcbd79e9f 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -14,7 +14,6 @@ from typing import Dict, Optional, Union import torch -import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config @@ -29,20 +28,12 @@ ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous +from ..normalization 
import AdaLayerNormContinuous, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FP32LayerNorm(nn.LayerNorm): - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - origin_dtype = inputs.dtype - return F.layer_norm( - inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps - ).to(origin_dtype) - - class AdaLayerNormShift(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f26690cedc34..7be018354d86 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -250,6 +250,7 @@ "StableDiffusionLDM3DPipeline", ] ) + _import_structure["aura_flow"] = ["AuraFlowPipeline"] _import_structure["stable_diffusion_3"] = [ "StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline", @@ -418,6 +419,7 @@ AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) + from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, diff --git a/src/diffusers/pipelines/aura_flow/__init__.py b/src/diffusers/pipelines/aura_flow/__init__.py new file mode 100644 index 000000000000..e1917baa61e2 --- /dev/null +++ b/src/diffusers/pipelines/aura_flow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_aura_flow import AuraFlowPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py new file mode 100644 index 000000000000..73b149e853cf --- /dev/null +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -0,0 +1,489 @@ +# Copyright 2024 AuraFlow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
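+
+# Example usage (sketch). The repo id below is a placeholder for a fully assembled AuraFlow pipeline,
+# e.g. one built around a transformer converted with scripts/convert_aura_flow_to_diffusers.py:
+#
+#     import torch
+#     from diffusers import AuraFlowPipeline
+#
+#     pipe = AuraFlowPipeline.from_pretrained("<aura-flow-pipeline-repo>", torch_dtype=torch.float16).to("cuda")
+#     image = pipe("A photo of a corgi", guidance_scale=3.5, num_inference_steps=50).images[0]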
+import inspect +from typing import List, Optional, Tuple, Union + +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from ...image_processor import VaeImageProcessor +from ...models import AuraFlowTransformer2DModel, AutoencoderKL +from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AuraFlowPipeline(DiffusionPipeline): + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: AuraFlowTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." 
+ ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + text_input_ids = text_inputs["input_ids"] + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) + prompt_attention_mask = 
prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = 512, + width: Optional[int] = 512, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Determine batch size. + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + + # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. 
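+        # `prepare_latents` below returns noise (or the user-supplied `latents`) of shape
+        # (batch_size * num_images_per_prompt, in_channels, height // vae_scale_factor, width // vae_scale_factor).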
+ latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) + timestep = timestep.to(latents.device, dtype=latents.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 83ce63981abd..779e691f0c27 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -158,7 +158,12 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -168,17 +173,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" - self.num_inference_steps = num_inference_steps - timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps - ) + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index f03fe42cdd66..5df0d6d28f53 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AuraFlowTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e9aecf0697ab..97fc0ca1b6d0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -182,6 +182,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AuraFlowPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ChatGLMModel(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py new file mode 100644 index 000000000000..57fac4ba769c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_aura_flow.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AuraFlowTransformer2DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = AuraFlowTransformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = embedding_dim = 32 + sequence_length = 256 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "patch_size": 2, + "in_channels": 4, + "num_mmdit_layers": 1, + "num_single_dit_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "out_channels": 4, + "pos_embed_max_size": 256, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/pipelines/aura_flow/__init__.py b/tests/pipelines/aura_flow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py new file mode 100644 index 000000000000..9a2f1846f1d9 --- /dev/null +++ b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py @@ -0,0 +1,121 @@ +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler +from diffusers.utils.testing_utils import ( + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = AuraFlowPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AuraFlowTransformer2DModel( + sample_size=32, + patch_size=2, + in_channels=4, + num_mmdit_layers=1, + num_single_dit_layers=1, + attention_head_dim=8, + num_attention_heads=4, + caption_projection_dim=32, + joint_attention_dim=32, + out_channels=4, + pos_embed_max_size=256, + ) + + text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=32, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + 
"vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "height": None, + "width": None, + } + return inputs + + def test_aura_flow_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + do_classifier_free_guidance = inputs["guidance_scale"] > 1 + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = pipe.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + device=torch_device, + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_attention_slicing_forward_pass(self): + # Attention slicing needs to implemented differently for this because how single DiT and MMDiT + # blocks interfere with each other. + return