[Core] Add AuraFlow #8796

Merged: 42 commits, Jul 11, 2024
Changes from 13 commits

Commits (42)
0ebbd30
add lavender flow transformer
sayakpaul Jul 4, 2024
939d990
progress.
sayakpaul Jul 4, 2024
f08baf3
progress
sayakpaul Jul 4, 2024
005a8f6
progress
sayakpaul Jul 5, 2024
e238d56
move out the attention processor.
sayakpaul Jul 5, 2024
570c258
finish implementation of pipeline
sayakpaul Jul 5, 2024
b881190
default neg prompt
sayakpaul Jul 5, 2024
b8237b2
up
sayakpaul Jul 5, 2024
89eea61
fixes
sayakpaul Jul 5, 2024
2c97d04
up
sayakpaul Jul 5, 2024
a50e1ff
up for pr
sayakpaul Jul 5, 2024
b0d29b2
fix copies
sayakpaul Jul 5, 2024
ae037cf
Merge branch 'main' into lavender-flow
sayakpaul Jul 5, 2024
ad6cb66
move fp32 layer norm to normalization
sayakpaul Jul 5, 2024
8ae6be7
minor fixes
sayakpaul Jul 5, 2024
47ff911
remove boolean flag and resort to norm_type
sayakpaul Jul 5, 2024
10ed96f
eliminate added_qk_norm
sayakpaul Jul 5, 2024
3d9265e
add added_proj_bias
sayakpaul Jul 5, 2024
84708c4
lavender flow -> aura flow
sayakpaul Jul 6, 2024
4bdea0d
Fix the `added_proj_bias` default value (#8800)
sayakpaul Jul 7, 2024
bcbc972
Merge branch 'main' into lavender-flow
sayakpaul Jul 7, 2024
89fad69
remnant aura flow renaming
sayakpaul Jul 7, 2024
e73442f
make it possible to reuse prompt embeds.
sayakpaul Jul 7, 2024
dccc682
rename to auraflow
sayakpaul Jul 7, 2024
f23151b
[lavender-flow] use flow match euler scheduler (#8799)
yiyixuxu Jul 7, 2024
d9a01f4
resolve conflicts
sayakpaul Jul 9, 2024
8984d23
more feedback.
sayakpaul Jul 9, 2024
a281547
context_norm_type fix
sayakpaul Jul 9, 2024
8830bf1
fix circular import
sayakpaul Jul 9, 2024
4334f72
fix conversion
sayakpaul Jul 9, 2024
b1dc5ec
add fast tests for pipeline
sayakpaul Jul 9, 2024
942377d
fix test file name
sayakpaul Jul 9, 2024
f8a08b5
fix test path
sayakpaul Jul 9, 2024
e9832f9
spacing between initialization
sayakpaul Jul 9, 2024
2c87250
style
sayakpaul Jul 9, 2024
1b3e620
add test for the transformer model.
sayakpaul Jul 9, 2024
66ff7f5
Merge branch 'main' into lavender-flow
sayakpaul Jul 9, 2024
ed33913
remove context_norm_type
sayakpaul Jul 9, 2024
6531e54
remove ada continuous.
sayakpaul Jul 10, 2024
0f721ac
address yiyi
sayakpaul Jul 10, 2024
95708dc
Merge branch 'main' into lavender-flow
yiyixuxu Jul 11, 2024
15d3198
style
yiyixuxu Jul 11, 2024
129 changes: 129 additions & 0 deletions scripts/convert_lavender_flow_to_diffusers.py
@@ -0,0 +1,129 @@
import argparse

import torch
from huggingface_hub import hf_hub_download

from diffusers.models.transformers.lavender_transformer_2d import LavenderFlowTransformer2DModel


def load_original_state_dict(args):
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="model.bin")
state_dict = torch.load(model_pt, map_location="cpu")
return state_dict


def calculate_layers(state_dict_keys, key_prefix):
dit_layers = set()
for k in state_dict_keys:
if key_prefix in k:
dit_layers.add(int(k.split(".")[2]))
print(f"{key_prefix}: {len(dit_layers)}")
return len(dit_layers)


# similar to SD3 but only for the last norm layer
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight


def convert_transformer(state_dict):
converted_state_dict = {}
state_dict_keys = list(state_dict.keys())

converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")

converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")

converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")

mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

# MMDiT blocks 🎸.
for i in range(mmdit_layers):
# feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
for orig_k, diffuser_k in path_mapping.items():
for k in ["c_fc1", "c_fc2", "c_proj"]:
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{k}.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.{k}.weight"
)

# norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.1.weight"
)

# attns
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.double_layers.{i}.attn.{k}.weight"
)

# Single-DiT blocks.
for i in range(single_dit_layers):
# feed-forward
for k in ["c_fc1", "c_fc2", "c_proj"]:
converted_state_dict[f"single_transformer_blocks.{i}.ff.{k}.weight"] = state_dict.pop(
f"model.single_layers.{i}.mlp.{k}.weight"
)

# norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
f"model.single_layers.{i}.modCX.1.weight"
)

# attns
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.single_layers.{i}.attn.{k}.weight"
)

# Final blocks.
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)

return converted_state_dict


@torch.no_grad()
def populate_state_dict(args):
original_state_dict = load_original_state_dict(args)
state_dict_keys = list(original_state_dict.keys())
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

converted_state_dict = convert_transformer(original_state_dict)
model_diffusers = LavenderFlowTransformer2DModel(
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
)
model_diffusers.load_state_dict(converted_state_dict, strict=True)

return model_diffusers


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
parser.add_argument("--dump_path", default="lavender-flow", type=str)
parser.add_argument("--hub_id", default=None, type=str)
args = parser.parse_args()

model_diffusers = populate_state_dict(args)
model_diffusers.save_pretrained(args.dump_path)
if args.hub_id is not None:
model_diffusers.push_to_hub(args.hub_id)
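For context, a minimal sketch of loading the checkpoint this script writes out with `save_pretrained`. The local path matches the script's `--dump_path` default; the fp16 dtype is purely illustrative, and reading the layer counts back assumes the constructor arguments are registered to the model config, as is standard for diffusers models.

import torch

from diffusers.models.transformers.lavender_transformer_2d import LavenderFlowTransformer2DModel

# Reload the converted weights serialized by `save_pretrained` above; fp16 is only an example dtype.
transformer = LavenderFlowTransformer2DModel.from_pretrained("lavender-flow", torch_dtype=torch.float16)

# Sanity-check that the inferred layer counts survived the round trip.
print(transformer.config.num_mmdit_layers, transformer.config.num_single_dit_layers)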
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -88,6 +88,7 @@
"HunyuanDiT2DMultiControlNetModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"LavenderFlowTransformer2DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -267,6 +268,7 @@
"KandinskyV22PriorPipeline",
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
"LavenderFlowPipeline",
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
@@ -509,6 +511,7 @@
HunyuanDiT2DMultiControlNetModel,
I2VGenXLUNet,
Kandinsky3UNet,
LavenderFlowTransformer2DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -666,6 +669,7 @@
KandinskyV22PriorPipeline,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
LavenderFlowPipeline,
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
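Since both names are now registered at the package root (and in the lazy `_import_structure` tables), they can be imported directly from `diffusers`. A minimal sketch, where the checkpoint id is a placeholder and the call signature is assumed to follow the usual diffusers text-to-image pattern:

from diffusers import LavenderFlowPipeline, LavenderFlowTransformer2DModel

# "<org>/<lavender-flow-checkpoint>" is a hypothetical repo id, not one shipped by this PR.
pipe = LavenderFlowPipeline.from_pretrained("<org>/<lavender-flow-checkpoint>")
image = pipe(prompt="a photo of a lavender field at dusk").images[0]

# The transformer can also be loaded or swapped on its own.
transformer = LavenderFlowTransformer2DModel.from_pretrained(
    "<org>/<lavender-flow-checkpoint>", subfolder="transformer"
)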
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -41,6 +41,7 @@
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.lavender_transformer_2d"] = ["LavenderFlowTransformer2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
@@ -85,6 +86,7 @@
DiTTransformer2DModel,
DualTransformer2DModel,
HunyuanDiT2DModel,
LavenderFlowTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
135 changes: 129 additions & 6 deletions src/diffusers/models/attention_processor.py
@@ -22,7 +22,7 @@
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -102,6 +102,7 @@ def __init__(
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
@@ -115,8 +116,11 @@ def __init__(
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
context_pre_only=None,
use_fp32_layer_norm=False,
):
super().__init__()
from .normalization import FP32LayerNorm

self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
@@ -166,8 +170,33 @@ def __init__(
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
self.norm_q = (
nn.LayerNorm(dim_head, eps=eps)
if not use_fp32_layer_norm
else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
)
self.norm_k = (
nn.LayerNorm(dim_head, eps=eps)
if not use_fp32_layer_norm
else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")

if added_qk_norm is None:
self.norm_added_q = None
self.norm_added_k = None
elif added_qk_norm == "layer_norm":
self.norm_added_q = (
nn.LayerNorm(dim_head, eps=eps)
if not use_fp32_layer_norm
else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
)
self.norm_added_k = (
nn.LayerNorm(dim_head, eps=eps)
if not use_fp32_layer_norm
else FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
)
else:
raise ValueError(f"unknown added_qk_norm: {added_qk_norm}. Should be None or 'layer_norm'")

@@ -205,10 +234,10 @@ def __init__(
self.to_v = None

if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=out_bias)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
@@ -1133,6 +1162,100 @@ def __call__(
return hidden_states, encoder_hidden_states


class LavenderFlowAttnProcessor2_0:
"""Attention processor used typically in processing Lavender Flow."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"LavenderFlowAttnProcessor2_0 requires PyTorch 2.1 or above, as it uses the `scale` argument of `F.scaled_dot_product_attention()`. Please upgrade your PyTorch installation."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
i=0,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Concatenate the projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# Attention.
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states


class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
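To make the wiring of the new processor concrete, here is a minimal sketch of pairing `LavenderFlowAttnProcessor2_0` with an `Attention` module configured the way this diff expects: joint sample/context projections via `added_kv_proj_dim`, plus fp32 QK layer norms. The sizes are illustrative, and the `added_qk_norm` / `use_fp32_layer_norm` keyword arguments shown here exist only at this commit; later commits in this PR rework them ("remove boolean flag and resort to norm_type", "eliminate added_qk_norm"). Running it needs PyTorch 2.1+ because of the `scale` argument to `F.scaled_dot_product_attention`.

import torch

from diffusers.models.attention_processor import Attention, LavenderFlowAttnProcessor2_0

dim, heads, head_dim = 256, 8, 32  # illustrative sizes; the real blocks take these from the model config

attn = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    dim_head=head_dim,
    heads=heads,
    qk_norm="layer_norm",
    added_qk_norm="layer_norm",
    added_kv_proj_dim=dim,      # creates add_k_proj / add_v_proj for the context stream
    context_pre_only=False,     # also creates add_q_proj and to_add_out
    use_fp32_layer_norm=True,   # routes the QK norms through FP32LayerNorm
    out_dim=dim,
    bias=False,
    out_bias=False,
    processor=LavenderFlowAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 64, dim)          # image tokens
encoder_hidden_states = torch.randn(2, 16, dim)  # text/context tokens

# The processor returns the two streams separately, with the context split back off the front.
hidden_states, encoder_hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)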
4 changes: 3 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -386,18 +386,20 @@ def forward(self, sample, condition=None):


class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale

def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb

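The only functional change to `Timesteps` is the new `scale` argument, which is forwarded to `get_timestep_embedding`, where it multiplies the sinusoid arguments before the sin/cos projections are taken. A minimal sketch with illustrative values:

import torch

from diffusers.models.embeddings import Timesteps

# scale=1000 is only an example value to show the newly plumbed-through argument.
time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)

emb = time_proj(torch.tensor([0.25, 0.5, 0.75]))
print(emb.shape)  # torch.Size([3, 256])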