Add FlaxCLIPTextModelWithProjection #25254
This is necessary to support the Flax port of Stable Diffusion XL: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/fb6d705fb518524cabc79c77f13a0e7921bcab3a/text_encoder_2/config.json#L3
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Should we maybe, for now, just add it in a subfolder of sdxl in diffusers here: https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion_xl instead of having to rely on [...]? That would also not force the user to install transformers from main :-)
The PyTorch version of the same model was added 9 months ago, so I assumed it was ok. But sure, we can do that. In that case, how do we deal with it?
Yes, of course, this was meant as the long-term solution.
Ah yeah, good point: JAX & PyTorch share the same config, so this will indeed become complicated. OK, let's try to get it merged here. CLIP is important enough to be merged to
Thanks for adding this @pcuenca - looking clean! Good with me to add this Flax model if you require it for SD / SD XL - I'll leave it to you and Patrick to decide on whether it's an appropriate addition, but from a Flax transformers side the design is great.
Feel free to add the model to the testing file so that it is run by the CI; you just need to add the model class here:
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Thanks a lot for your in-depth review @sanchit-gandhi! 🙌 I couldn't get back to this PR until today, but I think I addressed all your comments and fixed an error in the example docstring. I was getting some seemingly unrelated CI failures so I just merged the latest
Looks good - thanks for iterating @pcuenca!
cc @patrickvonplaten too, in case we want to consider other alternatives :) Otherwise feel free to merge when appropriate, as I can't do it.
Looks good to me!
* Add FlaxClipTextModelWithProjection

  This is necessary to support the Flax port of Stable Diffusion XL: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/fb6d705fb518524cabc79c77f13a0e7921bcab3a/text_encoder_2/config.json#L3

  Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
  Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
* Use FlaxCLIPTextModelOutput
* make fix-copies again
* Apply suggestions from code review

  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Use `return_dict` for consistency with other uses.

  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Fix docstring example.
* Add new model to FlaxCLIPTextModelTest
* Add to IGNORE_NON_AUTO_CONFIGURED list
* Fix naming convention.

---------

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
What does this PR do?
FlaxCLIPTextModelWithProjection is necessary to support the Flax port of Stable Diffusion XL: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/fb6d705fb518524cabc79c77f13a0e7921bcab3a/text_encoder_2/config.json#L3

I can add some tests, if necessary, after this approach is validated.
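For readers unfamiliar with the "with projection" variant, here is a rough NumPy sketch of what such a head computes on top of the text encoder output. It is not the code from this PR. It assumes the pooled feature is the hidden state at the end-of-sequence token, located as the position of the highest token id, and that the projection is a bias-free linear map. The function name `project_text_features`, the toy shapes, and the token ids are all made up for illustration.

```python
import numpy as np


def project_text_features(last_hidden_state, input_ids, projection):
    """Pool each sequence at its EOS position, then apply a bias-free projection.

    Assumption: the EOS token has the highest id in the vocabulary, so its
    position can be found with argmax over the token ids.
    """
    batch_index = np.arange(last_hidden_state.shape[0])
    eos_positions = input_ids.argmax(axis=-1)
    pooled = last_hidden_state[batch_index, eos_positions]  # (batch, hidden_size)
    return pooled @ projection  # (batch, projection_dim)


# Toy inputs (hypothetical sizes; 99 stands in for the EOS token id).
rng = np.random.default_rng(0)
hidden_size, projection_dim = 8, 4
last_hidden_state = rng.normal(size=(2, 5, hidden_size))
input_ids = np.array([[1, 7, 9, 99, 0], [1, 5, 99, 0, 0]])
projection = rng.normal(size=(hidden_size, projection_dim))

text_embeds = project_text_features(last_hidden_state, input_ids, projection)
```

The result has one row per input sequence and `projection_dim` columns, which is the kind of fixed-size text embedding a downstream pipeline would consume.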
Who can review?
@patrickvonplaten @patil-suraj @sanchit-gandhi @younesbelkada