[Core] Add AuraFlow #8796
Conversation
```diff
@@ -0,0 +1,401 @@
+# Copyright 2024 Stability AI, Lavender Flow, The HuggingFace Team. All rights reserved.
```
New model because it differs from the SD3 one (non-exhaustive list):
- Uses register tokens.
- Mixes MMDiT blocks with another kind of simple DiT block (one that takes the concatenation of `encoder_hidden_states` and `hidden_states` as its input).
- The final layer norm is different.
- Position embeddings are different (uses learned positional embeddings).
- The feedforward is different: we only support GeLU and its variants in the feedforward, while this model uses SwiGLU (see the sketch after this list).
- No pooled projections.
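For reference, a minimal generic sketch of a SwiGLU feedforward, an illustration only and not the exact layer in this PR (names and shapes are assumptions):

```python
import torch
import torch.nn as nn


class SwiGLUFeedForward(nn.Module):
    """Generic SwiGLU feedforward: out = W_out(SiLU(W_gate(x)) * W_up(x))."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated activation: the SiLU branch gates the linear "up" branch.
        return self.out_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
```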
```python
def _set_gradient_checkpointing(self, module, value=False):
    # Toggle gradient checkpointing on any submodule that exposes the flag.
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value
```
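This hook is what lets the public `enable_gradient_checkpointing()` toggle on `ModelMixin` subclasses reach the transformer blocks; a small usage sketch:

```python
from diffusers import ModelMixin

# Sketch: any diffusers ModelMixin that implements _set_gradient_checkpointing
# gains this public toggle, trading extra compute for lower activation memory.
def enable_checkpointing(model: ModelMixin) -> None:
    model.enable_gradient_checkpointing()  # flips `gradient_checkpointing = True` via the hook above
```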
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have deliberately kept additional methods like feedforward chunking, QKV fusion, etc. out of this class because it helps with the initial review.
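(For context, feedforward chunking usually looks like the generic sketch below; this is not the PR's code, and the names are assumptions.)

```python
import torch


def chunked_feed_forward(ff: torch.nn.Module, hidden_states: torch.Tensor, chunk_size: int, dim: int = 1):
    # Run the feedforward on sequence chunks and concatenate the results:
    # peak activation memory shrinks roughly by the number of chunks.
    if hidden_states.shape[dim] % chunk_size != 0:
        raise ValueError(
            f"sequence length {hidden_states.shape[dim]} must be divisible by chunk_size {chunk_size}"
        )
    num_chunks = hidden_states.shape[dim] // chunk_size
    return torch.cat([ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=dim)], dim=dim)
```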
src/diffusers/pipelines/lavender_flow/pipeline_lavender_flow.py
Very nice! Left some comments, mainly on the attention processor.
Looking into the test failures 👀
I've been testing a fork of this with LoRA support, and it works without any changes: just add the PEFT adapter to the Transformer2D model and the SD3 LoRA loader mixin to the pipeline.
@yiyixuxu @DN6 I have addressed the comments. Here are some extended comments from my end:
@bghira, I will add LoRA support in an immediate follow-up PR once this one is merged, to keep the reviewing scope concrete and manageable. It's not just about adding those classes like you mentioned; we also need to scale and unscale the layers appropriately.
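(For background, the pattern pipelines use for a per-call LoRA scale looks roughly like this sketch; the function and argument names around the two helpers are assumptions.)

```python
from diffusers.utils import scale_lora_layers, unscale_lora_layers


def encode_with_lora_scale(text_encoder, input_ids, lora_scale: float = 1.0):
    # Multiply the LoRA weights in before the forward pass...
    scale_lora_layers(text_encoder, lora_scale)
    outputs = text_encoder(input_ids)
    # ...and restore the original scaling afterwards.
    unscale_lora_layers(text_encoder, lora_scale)
    return outputs
```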
I have also added fast tests and decided to make the default value of the negative prompt None instead of "This is watermark, jpeg image white background, web image". I think this is better aligned with our other pipelines. I will include this negative prompt in the docs once I start adding them.
very nice!!
```diff
@@ -158,7 +158,12 @@ def scale_noise(
     def _sigma_to_t(self, sigma):
         return sigma * self.config.num_train_timesteps

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
```
Umm, I don't think these changes are introduced in this PR.
you're right, I was confused 😅
padding="max_length", | ||
return_tensors="pt", | ||
) | ||
text_inputs = {k: v.to(device) for k, v in text_inputs.items()} |
Just curious: what else is in `text_inputs` other than the `text_input_ids`?
`attention_mask`
thanks!
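(For reference, a tokenizer call shaped like the one above returns a dict-like `BatchEncoding` that typically carries both keys; the checkpoint name below is an assumption, not the one used in this PR.)

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")  # checkpoint name assumed
text_inputs = tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
print(text_inputs.keys())  # dict_keys(['input_ids', 'attention_mask'])
```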
* add lavender flow transformer

Co-authored-by: YiYi Xu <yixu310@gmail.com>
What does this PR do?
Adds AuraFlow from Fal.
Test code:
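(The original snippet isn't preserved in this capture; below is a minimal sketch of how the pipeline might be invoked. The class name, checkpoint id, and generation settings are all assumptions.)

```python
import torch
from diffusers import AuraFlowPipeline  # class name assumed; the PR's files use "lavender_flow"

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
image = pipe(
    prompt="close-up portrait of a majestic iguana",
    negative_prompt=None,  # default changed to None in this PR
    num_inference_steps=50,
    guidance_scale=3.5,
).images[0]
image.save("aura_flow.png")
```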
Warning
To download the model you must be a member of the AuraDiffusion org. Follow this (internal) Slack message.
Gives:
TODOs
Because of the last point above, the noise scheduling code is taken from the original codebase. But I think this PR is still ready for a first review.
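(For background on that last point: flow-matching schedulers in diffusers typically build their sigma schedule roughly like the sketch below, which is consistent with the `_sigma_to_t` shown earlier; the shift value here is an assumption.)

```python
import numpy as np

# Rough sketch of a shifted flow-matching sigma schedule; the shift value is an assumption.
num_inference_steps = 50
num_train_timesteps = 1000
shift = 3.0

timesteps = np.linspace(1, num_train_timesteps, num_inference_steps)[::-1].copy()
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # shift correction toward higher noise
timesteps = sigmas * num_train_timesteps  # matches `_sigma_to_t(sigma)` above
```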