[Community Pipeline] IPAdapter FaceID #6276

Merged
16 commits merged on Jan 15, 2024

fabiorigano
Contributor

What does this PR do?

Add support for IPAdapter FaceID
Fixes #6243

Who can review?

@patrickvonplaten, @sayakpaul, @yiyixuxu

@fabiorigano
Contributor Author

import cv2
from insightface.app import FaceAnalysis
import numpy as np
from PIL import Image
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image

def image_grid(imgs, rows, cols):
    """Paste rows*cols equally sized PIL images into a single grid image."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # column index selects x, row index selects y
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1
)

pipeline = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None
).to("cuda")

generator = torch.Generator(device="cpu").manual_seed(42)
num_images = 4
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")

app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

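# load_image returns an RGB PIL image; insightface expects a BGR array (OpenCV convention)
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
faces = app.get(image)
# use the normalized embedding of the first detected face, with a batch dimension added
image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)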

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", weight_name="ip-adapter-faceid_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)
images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    image_embeds=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20,
    num_images_per_prompt=num_images,
    width=512,
    height=704,
    generator=generator,
).images
image_grid(images, 1, num_images)

Output:
[image]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul requested a review from yiyixuxu December 22, 2023 01:41
@a-r-r-o-w
Member

a-r-r-o-w commented Dec 22, 2023

Thank you so much for adding this @fabiorigano! Was trying to work on this as well but great to see the quick addition. Sharing some more results from the PR with EpicRealism-v4.

Results
Original Image
Generated Images

Will be interesting to see if multiple reference images could be provided as input to improve on the quality in the future.

@sayakpaul
Member

@fabiorigano thanks so much for this.

Let's try to get the CI green.

weight_name: str,
subfolder: Optional[str] = None,
Member

We can pass this to the kwargs right? I don't think there's a need to expose this specifically. WDYT?

Contributor Author

> We can pass this to the kwargs right? I don't think there's a need to expose this specifically. WDYT?

Yes, we can pass it as a keyword argument. The other IP Adapter models need the image encoder, stored in a subfolder of h94/IP-Adapter, while the FaceID model doesn't require an image encoder, so I made it Optional. I will remove it in the next update.

@@ -684,13 +684,20 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value

-elif "proj.3.weight" in state_dict:
+elif "proj.0.weight" in state_dict:
Member

Is using this key a better option? Can we use a more resilient condition here to avoid side-effects?

Contributor Author

I changed to proj.0.weight because both IPAdapter Full and FaceID state dicts have it, while the IPAdapter Full proj.3.weight key is named norm.weight in the FaceID model
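
The key-based detection discussed here can be illustrated with a small sketch. It relies only on what the thread states: both the IPAdapter Full and the FaceID state dicts contain `proj.0.weight`, and the key that IPAdapter Full calls `proj.3.weight` is named `norm.weight` in FaceID. The function name `classify_image_proj` and its return labels are hypothetical and not part of diffusers. Checking two keys together is just one possible way to make the condition more resilient.

```python
def classify_image_proj(state_dict):
    """Guess the image-projection variant from the keys of a state dict.

    Hypothetical helper for illustration; the labels are not diffusers names.
    """
    keys = set(state_dict)
    if "proj.0.weight" in keys and "norm.weight" in keys:
        # FaceID: the key that Full calls proj.3.weight is named norm.weight
        return "faceid"
    if "proj.0.weight" in keys and "proj.3.weight" in keys:
        # IPAdapter Full
        return "full"
    return "unknown"


print(classify_image_proj({"proj.0.weight": None, "norm.weight": None}))
print(classify_image_proj({"proj.0.weight": None, "proj.3.weight": None}))
```

Requiring a second, variant-specific key avoids a false match for any other checkpoint that merely happens to contain `proj.0.weight`.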

@sayakpaul
Member

Would be also nice to have some documentation about it here:
https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#ip-adapter

WDYT?

hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
rank=128,
Member

Do we need to make rank an argument?

Contributor Author

Not for now, but perhaps new IP Adapter will be released in the future, using different LoRA ranks. Do you think it is better to remove it for now?

@fabiorigano
Contributor Author

> Would be also nice to have some documentation about it here: https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#ip-adapter
>
> WDYT?

Sure, I can add the snippet on top of this PR

@patrickvonplaten
Contributor

@yiyixuxu could you take a look here?

@@ -46,7 +48,6 @@ class IPAdapterMixin:
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-subfolder: str,
Contributor

This is a breaking change, can we leave as it was here?

Contributor Author

done as you suggested!

Comment on lines 140 to 148
try:
    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path_or_dict,
        subfolder=os.path.join(subfolder, "image_encoder"),
    ).to(self.device, dtype=self.dtype)
    self.image_encoder = image_encoder
except TypeError:
    print("IPAdapter: `subfolder` not found, `image_encoder` is None, use image_embeds.")
Contributor

Let's try to not use try...except here please
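
One way to avoid the try...except, sketched under assumptions: decide up front whether an image encoder should be loaded by checking `subfolder` explicitly. The helper name `image_encoder_subfolder` is hypothetical. The snippet under review only shows that the encoder lives in `os.path.join(subfolder, "image_encoder")`, and the thread states that FaceID does not require an image encoder.

```python
import os


def image_encoder_subfolder(subfolder):
    """Return where the image encoder would live, or None when there is none.

    Hypothetical helper: a caller would skip loading the encoder on None
    and rely on precomputed image embeddings instead.
    """
    if subfolder is None:
        return None
    return os.path.join(subfolder, "image_encoder")


for candidate in (None, "models"):
    path = image_encoder_subfolder(candidate)
    if path is None:
        print("no image encoder to load; pass image embeddings instead")
    else:
        print(f"would load image encoder from {path}")
```

An explicit check keeps unrelated `TypeError`s from being swallowed, which is the usual concern with a broad try...except around a loader call.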

@patrickvonplaten
Contributor

@yiyixuxu can you take a look here?

@a-r-r-o-w
Member

Btw, FaceID v2 and SDXL versions were made public last week. Maybe we could look into supporting and testing both in this PR too?

@sayakpaul
Member

One at a time please.

@fabiorigano
Contributor Author

fabiorigano commented Jan 6, 2024

I just fixed the loading of the LoRA weights, because some of them were missing.
With this update, output images are more aligned with the input image, thanks to both the LoRA layers and the IP Adapter.

[attached images: _p0, _p1]

I updated set_default_attn_processor in order to restore processors with torch sdpa when using unload_ip_adapter.
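
A rough sketch of that selection logic, not the PR's actual code: pick the scaled-dot-product-attention processor when the functional module exposes `scaled_dot_product_attention`, otherwise fall back to the plain one. `AttnProcessor2_0` and `AttnProcessor` are diffusers class names, but here they are returned as strings, and the module is passed in through a hypothetical `functional` parameter so the example runs without torch installed.

```python
from types import SimpleNamespace


def default_processor_name(functional):
    """Return the name of the attention processor to restore by default.

    `functional` stands in for torch.nn.functional; this is an illustration,
    not the implementation from the PR.
    """
    if hasattr(functional, "scaled_dot_product_attention"):
        return "AttnProcessor2_0"  # relies on torch's built-in SDPA
    return "AttnProcessor"  # plain attention when SDPA is unavailable


# Two stand-in modules: one with SDPA available, one without
with_sdpa = SimpleNamespace(scaled_dot_product_attention=lambda q, k, v: None)
without_sdpa = SimpleNamespace()
print(default_processor_name(with_sdpa))
print(default_processor_name(without_sdpa))
```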

@yiyixuxu
Collaborator

yiyixuxu commented Jan 9, 2024

@patrickvonplaten
This is still an experimental feature - does creating a community pipeline for this make more sense?

@yiyixuxu
Collaborator

yiyixuxu commented Jan 9, 2024

@sayakpaul, I completely agree that this can enable real use cases and has huge potential.

However, this model was released as an experimental checkpoint, so I think it's too early for us to integrate. I think we have two options here:

  1. add this as a community pipeline
  2. wait a little bit to see if they release a stable checkpoint and if this model gains more popularity

I prefer option 1. It will require a little bit of effort to move all the files to the community folder, but it will allow users to try this pipeline out in the meantime. What do you think? @fabiorigano it will be completely up to you :)

@fabiorigano
Contributor Author

hi @yiyixuxu, no problem for me, I can move everything to the community folder. I would leave some lines in common with IPAdapterFull unchanged, because in my opinion the current PR implementation is more readable and general.

I also think that all classes copying set_default_attn_processor from UNet2DConditionModel should be updated to support pytorch's sdpa, as it is done in this PR.

@yiyixuxu
Collaborator

yiyixuxu commented Jan 9, 2024

@fabiorigano

> I would leave some lines in common with IPAdapterFull unchanged, because in my opinion the current PR implementation is more readable and general.
> I also think that all classes copying set_default_attn_processor from UNet2DConditionModel should be updated to support pytorch's sdpa, as it is done in this PR.

OK by me - let's review it again once you move it to the community folder :)

@yiyixuxu
Collaborator

yiyixuxu commented Jan 9, 2024

@fabiorigano
it will be easier if you separate it into two PRs:

  • one PR just for the community pipeline, we can merge that quickly
  • and the other PR for the other changes you mentioned - it will require more time for us to review

YiYi

@fabiorigano
Contributor Author

> @fabiorigano
> it will be easier if you separate it into two PRs:
>
>   • one PR just for the community pipeline, we can merge that quickly
>   • and the other PR for the other changes you mentioned - it will require more time for us to review
>
> YiYi

Ok perfect! Will do it tomorrow

@sayakpaul
Member

> However, this model was released as an experimental checkpoint, so I think it's too early for us to integrate

Makes sense, yeah!

@okaris

okaris commented Jan 10, 2024

@fabiorigano Thank you for your contribution. I took your branch on a test run but haven't had success loading the FaceID SDXL weights.

The loader tries to load ip-adapter-faceid_sdxl.bin as a LoRA and throws an error:

self.ip_pipe.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin")
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/diffusers/loaders/ip_adapter.py", line 164, in load_ip_adapter
unet._load_ip_adapter_weights(state_dict)
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/diffusers/loaders/unet.py", line 826, in _load_ip_adapter_weights
attn_module.to_q.set_lora_layer(
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Linear' object has no attribute 'set_lora_layer'
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/runner.py", line 296, in setup
    for event in worker.setup():
  File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/worker.py", line 126, in _wait
    raise FatalWorkerException(raise_on_error + ": " + done.error_detail)
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: 'Linear' object has no attribute 'set_lora_layer'

ip-adapter-faceid_sdxl_lora.safetensors throws:

File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/diffusers/loaders/ip_adapter.py", line 164, in load_ip_adapter
unet._load_ip_adapter_weights(state_dict)
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/diffusers/loaders/unet.py", line 796, in _load_ip_adapter_weights
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
KeyError: 'latents'
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/runner.py", line 296, in setup
    for event in worker.setup():
  File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/worker.py", line 126, in _wait
    raise FatalWorkerException(raise_on_error + ": " + done.error_detail)
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: 'latents'

@fabiorigano
Contributor Author

fabiorigano commented Jan 10, 2024

hi @okaris, when diffusers.utils.USE_PEFT_BACKEND is True, this model can't be loaded. To bypass this constant, I made a small change in its definition, so you can set it to False (@yiyixuxu @sayakpaul I would like to know if this is acceptable for you). In any case, this implementation only supports SD; I am quite sure something must be updated for SDXL.

I moved everything to the community folder.

@fabiorigano changed the title from "Add support for IPAdapter FaceID" to "[Community Pipeline] IPAdapter FaceID" on Jan 10, 2024
@okaris

okaris commented Jan 11, 2024

@fabiorigano do you mind sharing an example of how to use that to disable it?

@yiyixuxu
Collaborator

hi @fabiorigano

> when diffusers.utils.USE_PEFT_BACKEND is True, this model can't be loaded.

do you know why it won't work when this constant is True? i.e. which line of code is causing the problem?

@fabiorigano
Contributor Author

@okaris

you have the full example in the README

import diffusers
# set this before creating the pipeline
diffusers.utils.USE_PEFT_BACKEND = False

@fabiorigano
Contributor Author

@yiyixuxu

> do you know why it won't work when this constant is True? i.e. which line of code is causing the problem?

Here, torch.nn.Linear is used instead of LoRACompatibleLinear, so I cannot call set_lora_layer.

@yiyixuxu
Collaborator

@fabiorigano
ohh I saw you reverted it - does that mean it's working now?

cc @sayakpaul here for his expertise
@fabiorigano needs to set diffusers.utils.USE_PEFT_BACKEND = False - what's the best way to go about this? see comments #6276 (comment)

@fabiorigano
Contributor Author

fabiorigano commented Jan 11, 2024

@yiyixuxu yes, at first I was probably setting it after creating the pipeline, but it works if set before, as I did in the README example. Sorry for not checking before committing 😅
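
Why the order matters can be shown with a toy model. It assumes, following the explanation earlier in the thread, that the layer class is chosen when the model is built: with the PEFT backend on, plain `Linear` layers are created, and only the LoRA-compatible class has `set_lora_layer`. All classes and the `flags` object below are stand-ins, not diffusers code.

```python
from types import SimpleNamespace

# Stand-in for the module-level constant diffusers.utils.USE_PEFT_BACKEND
flags = SimpleNamespace(USE_PEFT_BACKEND=True)


class Linear:
    """Stand-in for torch.nn.Linear: it has no LoRA hook."""


class LoRACompatibleLinear(Linear):
    """Stand-in for the LoRA-aware layer."""

    def set_lora_layer(self, lora_layer):
        self.lora_layer = lora_layer


def build_layer():
    # The flag is read once, when the layer is constructed
    return Linear() if flags.USE_PEFT_BACKEND else LoRACompatibleLinear()


late = build_layer()            # built while the flag is still True
flags.USE_PEFT_BACKEND = False  # flipping it now does not change `late`
early = build_layer()           # built after the flag was set to False

print(hasattr(late, "set_lora_layer"))
print(hasattr(early, "set_lora_layer"))
```

In this toy model only the layer built after the flag was set to False exposes `set_lora_layer`, which mirrors the `AttributeError: 'Linear' object has no attribute 'set_lora_layer'` reported above.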

@yiyixuxu
Collaborator

Alright! great! I think it is good to merge then
I will wait for @sayakpaul to take a final look :)

@okaris

okaris commented Jan 13, 2024

Hi @fabiorigano how can I test the community pipeline version before it's merged?

huggingface_hub.utils._errors.HfHubHTTPError: 404 Client Error: Not Found for url: https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/ip_adapter_face_id.py

@fabiorigano
Contributor Author

Hi @okaris, pass the path to the python file as custom_pipeline:

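import torch
from diffusers import DiffusionPipeline

# noise_scheduler and vae are assumed to be defined beforehand, as in the README example
pipeline = DiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    custom_pipeline="./path/to/ip_adapter_face_id.py",
)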

@patrickvonplaten
Contributor

Great job @fabiorigano!

@patrickvonplaten merged commit f825221 into huggingface:main Jan 15, 2024
14 checks passed
@okaris

okaris commented Feb 7, 2024

Is there a way to make this work with the PEFT backend enabled @yiyixuxu @fabiorigano ?

@yiyixuxu mentioned this pull request Feb 14, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Add support for IPAdapter FaceID

* Add docs

* Move subfolder to kwargs

* Fix quality

* Fix image encoder loading

* Fix loading + add test

* Move to community folder

* Fix style

* Revert constant update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Development

Successfully merging this pull request may close these issues.

IP-Adapter Face Id
7 participants