From fa7a57619156aa21fe38a493d88a270ba1f538e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com> Date: Mon, 13 Mar 2023 18:41:28 +0300 Subject: [PATCH 1/3] Update schedulers.mdx (#2647) Fix typos --- docs/source/en/using-diffusers/schedulers.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/schedulers.mdx b/docs/source/en/using-diffusers/schedulers.mdx index 9283df3aabde4..e17d826c7dab1 100644 --- a/docs/source/en/using-diffusers/schedulers.mdx +++ b/docs/source/en/using-diffusers/schedulers.mdx @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. # Schedulers Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize -a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers/overview.mdx). +a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers/overview.mdx). Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample, schedulers define the whole denoising process, *i.e.*: @@ -24,7 +24,7 @@ schedulers define the whole denoising process, *i.e.*: They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**. It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best. -The following paragraphs shows how to do so with the 🧨 Diffusers library. +The following paragraphs show how to do so with the 🧨 Diffusers library. ## Load pipeline From 4ae54b37898cf2c97518038d5028645bdb816dca Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 13 Mar 2023 19:10:33 +0100 Subject: [PATCH 2/3] [attention] Fix attention (#2656) * [attention] Fix attention * fix * correct --- src/diffusers/models/attention.py | 7 +++++-- tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b476e762675c1..6da318e65593d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -271,9 +271,10 @@ def __init__( def forward( self, hidden_states, + attention_mask=None, encoder_hidden_states=None, + encoder_attention_mask=None, timestep=None, - attention_mask=None, cross_attention_kwargs=None, class_labels=None, ): @@ -302,12 +303,14 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here # 2. 
Cross-Attention attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index dfb14617f8853..d4fd304583733 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -737,7 +737,7 @@ def test_stable_diffusion_vae_tiling(self): # make sure that more than 4 GB is allocated mem_bytes = torch.cuda.max_memory_allocated() - assert mem_bytes > 4e9 + assert mem_bytes > 5e9 assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2 def test_stable_diffusion_fp16_vs_autocast(self): From d9b8adc4cac287a10d0641533bf6741cd9e5c80a Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 14 Mar 2023 05:18:30 +0900 Subject: [PATCH 3/3] Add support for Multi-ControlNet to StableDiffusionControlNetPipeline (#2627) * support for List[ControlNetModel] on init() * Add to support for multiple ControlNetCondition * rename conditioning_scale to scale * scaling bugfix * Manually merge `MultiControlNet` #2621 Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: Patrick von Platen * cleanups - don't expose ControlNetCondition - move scaling to ControlNetModel * make style error correct * remove ControlNetCondition to reduce code diff * refactoring image/cond_scale * add explain for `images` * Add docstrings * all fast-test passed * Add a slow test * nit * Apply suggestions from code review * small precision fix * nits MultiControlNet -> MultiControlNetModel - Matches existing naming a bit closer MultiControlNetModel inherit from model utils class - Don't have to re-write fp16 test Skip tests that save multi controlnet pipeline - Clearer than changing test body Don't auto-batch the number of input images to the number of controlnets. We generally like to require the user to pass the expected number of inputs. This simplifies the processing code a bit more Use existing image pre-processing code a bit more. We can rely on the existing image pre-processing code and keep the inference loop a bit simpler. 
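A minimal usage sketch of the API this patch adds (the checkpoints and image URLs mirror the slow test added below; the conditioning-scale values are only illustrative):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Two independently trained ControlNets, combined at inference time.
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

# Passing a list of ControlNets wraps them in a MultiControlNetModel internally.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[controlnet_pose, controlnet_canny],
    safety_checker=None,
)
pipe.enable_model_cpu_offload()

# One conditioning image per ControlNet, in the same order as `controlnet`.
image_pose = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
)
image_canny = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

image = pipe(
    "bird and Chef",
    image=[image_pose, image_canny],
    # One scale per ControlNet; a single float is broadcast to all of them.
    controlnet_conditioning_scale=[1.0, 0.8],
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
```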
--------- Co-authored-by: Patrick von Platen Co-authored-by: William Berman --- src/diffusers/models/controlnet.py | 5 + .../pipeline_stable_diffusion_controlnet.py | 214 +++++++++++++++--- .../test_stable_diffusion_controlnet.py | 196 ++++++++++++++++ 3 files changed, 384 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 5b55da5af3727..a9686dfc31bf4 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -389,6 +389,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -492,6 +493,10 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(sample) + # 6. scaling + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + if not return_dict: return (down_block_res_samples, mid_block_res_sample) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 00f35ee4d0218..d6440214df1ee 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -14,14 +14,18 @@ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch +from torch import nn from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.controlnet import ControlNetOutput +from ...models.modeling_utils import ModelMixin from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -85,6 +89,63 @@ """ +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. 
+ """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states, + image, + scale, + class_labels, + timestep_cond, + attention_mask, + cross_attention_kwargs, + return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + class StableDiffusionControlNetPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`ControlNetModel`]): - Provides additional conditioning to the unet during the denoising process + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -122,7 +185,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, @@ -146,6 +209,9 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -432,6 +498,7 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -470,6 +537,41 @@ def check_inputs( f" {negative_prompt_embeds.shape}." 
) + # Check `image` + + if isinstance(self.controlnet, ControlNetModel): + self.check_image(image, prompt, prompt_embeds) + elif isinstance(self.controlnet, MultiControlNetModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + if len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + + if isinstance(self.controlnet, ControlNetModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(self.controlnet, MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) @@ -501,7 +603,9 @@ def check_inputs( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) - def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype): + def prepare_image( + self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance + ): if not isinstance(image, torch.Tensor): if isinstance(image, PIL.Image.Image): image = [image] @@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, image = image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents @@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents def _default_height_width(self, height, width, image): - if isinstance(image, list): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. 
+ while isinstance(image, list): image = image[0] if height is None: @@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image): return height, width + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -593,7 +715,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -602,10 +724,14 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can - also be accepted as an image. The control image is automatically resized to fit the output image. + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -658,10 +784,10 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. - + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Examples: Returns: @@ -676,7 +802,15 @@ def __call__( # 1. Check inputs. 
Raise error if not correct self.check_inputs( - prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, ) # 2. Define call parameters @@ -693,6 +827,9 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -705,18 +842,37 @@ def __call__( ) # 4. Prepare image - image = self.prepare_image( - image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) + if isinstance(self.controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + elif isinstance(self.controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) - if do_classifier_free_guidance: - image = torch.cat([image] * 2) + images.append(image_) + + image = images + else: + assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -746,20 +902,16 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # controlnet(s) inference down_block_res_samples, mid_block_res_sample = self.controlnet( latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - # predict the noise residual noise_pred = self.unet( latent_model_input, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py index 406cbc9ad0895..c235b493340d0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py @@ -14,6 +14,7 @@ # limitations under the License. 
import gc +import tempfile import unittest import numpy as np @@ -27,6 +28,7 @@ StableDiffusionControlNetPipeline, UNet2DConditionModel, ) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import require_torch_gpu @@ -143,6 +145,160 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) +class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + + images = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + 
"output_type": "numpy", + "image": images, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_save_pretrained_raise_not_implemented_exception(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdir: + try: + # save_pretrained is not implemented for Multi-ControlNet + pipe.save_pretrained(tmpdir) + except NotImplementedError: + pass + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_float16(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_local(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_optional_components(self): + ... + + @slow @require_torch_gpu class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase): @@ -396,3 +552,43 @@ def test_sequential_cpu_offloading(self): mem_bytes = torch.cuda.max_memory_allocated() # make sure that less than 7 GB is allocated assert mem_bytes < 4 * 10**9 + + +@slow +@require_torch_gpu +class StableDiffusionMultiControlNetPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_pose_and_canny(self): + controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=[controlnet_pose, controlnet_canny] + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird and Chef" + image_canny = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + image_pose = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png" + ) + + output = pipe(prompt, [image_pose, image_canny], generator=generator, output_type="np") + + image = output.images[0] + + assert image.shape == (768, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose_canny_out.npy" + ) + + assert np.abs(expected_image - image).max() < 5e-2