From 82bd9b1897dab2a45d17abac0ebbd3c9ed4fda9c Mon Sep 17 00:00:00 2001 From: Yam <40912707+Yam0214@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:31:47 +0800 Subject: [PATCH] add DDIM to CLIP Guided Stable Diffusion and add code example (#4920) * add DDIM to CLIP Guided Stable Diffusion and add example code * modify code sample * use generator and modify code sample --- ppdiffusers/examples/community/README.md | 72 +++++++++++++++++++ .../community/clip_guided_stable_diffusion.py | 43 ++++++++--- .../inference_clip_guided_stable_diffusion.py | 41 ++++++----- 3 files changed, 126 insertions(+), 30 deletions(-) create mode 100644 ppdiffusers/examples/community/README.md diff --git a/ppdiffusers/examples/community/README.md b/ppdiffusers/examples/community/README.md new file mode 100644 index 000000000000..3918d86c96c7 --- /dev/null +++ b/ppdiffusers/examples/community/README.md @@ -0,0 +1,72 @@ +# Community Examples + +社区示例包含由社区添加的推理和训练示例。可以从下表中了解所有社区实例的概况。点击**Code Example**,跳转到对应实例的可运行代码,可以复制并运行。如果一个示例不能像预期的那样工作,请创建一个issue提问。 + +|Example|Description|Code Example|Author| +|-|-|-|-| +|CLIP Guided Stable Diffusion|使用CLIP引导Stable Diffusion实现文生图|[CLIP Guided Stable Diffusion](#CLIP%20Guided%20Stable%20Diffusion)|| + +## Example usages + +### CLIP Guided Stable Diffusion + +使用 CLIP 模型引导 Stable Diffusion 去噪,可以生成更真实的图像。 + +以下代码运行需要16GB的显存。 + +```python +import os + +import paddle +from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion + +from paddlenlp.transformers import CLIPFeatureExtractor, CLIPModel + +feature_extractor = CLIPFeatureExtractor.from_pretrained( + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K") +clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", + dtype=paddle.float32) + +guided_pipeline = CLIPGuidedStableDiffusion.from_pretrained( + "runwayml/stable-diffusion-v1-5", + clip_model=clip_model, + feature_extractor=feature_extractor, + paddle_dtype=paddle.float16, +) +guided_pipeline.enable_attention_slicing() + +prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece" + +generator = paddle.Generator().manual_seed(2022) +with paddle.amp.auto_cast(True, level="O2"): + images = [] + for i in range(4): + image = guided_pipeline( + prompt, + num_inference_steps=50, + guidance_scale=7.5, + clip_guidance_scale=100, + num_cutouts=4, + use_cutouts=False, + generator=generator, + unfreeze_unet=False, + unfreeze_vae=False, + ).images[0] + images.append(image) + +# save images locally +if not os.path.exists("clip_guided_sd"): + os.mkdir("clip_guided_sd") +for i, img in enumerate(images): + img.save(f"./clip_guided_sd/image_{i}.png") +``` +生成的图片保存在`images`列表中,样例如下: + +| image_0 | image_1 | image_2 | image_3 | +|:-------------------:|:-------------------:|:-------------------:|:-------------------:| +|![][clip_guided_sd_0]|![][clip_guided_sd_1]|![][clip_guided_sd_2]|![][clip_guided_sd_3]| + +[clip_guided_sd_0]: https://user-images.githubusercontent.com/40912707/220514674-e5cb29a3-b07e-4e8f-a4c8-323b35637294.png +[clip_guided_sd_1]: https://user-images.githubusercontent.com/40912707/220514703-1eaf444e-1506-4c44-b686-5950fd79a3da.png +[clip_guided_sd_2]: https://user-images.githubusercontent.com/40912707/220514765-89e48c13-156f-4e61-b433-06f1283d2265.png +[clip_guided_sd_3]: https://user-images.githubusercontent.com/40912707/220514751-82d63fd4-e35e-482b-a8e1-c5c956119b2e.png diff --git a/ppdiffusers/examples/community/clip_guided_stable_diffusion.py b/ppdiffusers/examples/community/clip_guided_stable_diffusion.py index b575d2f3e58c..13053b46cf8d 100644 --- a/ppdiffusers/examples/community/clip_guided_stable_diffusion.py +++ b/ppdiffusers/examples/community/clip_guided_stable_diffusion.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import Callable, List, Optional, Union import paddle @@ -28,6 +29,7 @@ ) from ppdiffusers import ( AutoencoderKL, + DDIMScheduler, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, @@ -84,7 +86,7 @@ def __init__( clip_model: CLIPModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], feature_extractor: CLIPFeatureExtractor, ): super().__init__() @@ -99,7 +101,12 @@ def __init__( ) self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) - self.make_cutouts = MakeCutouts(feature_extractor.size) + self.cut_out_size = ( + feature_extractor.size + if isinstance(feature_extractor.size, int) + else feature_extractor.size["shortest_edge"] + ) + self.make_cutouts = MakeCutouts(self.cut_out_size) set_stop_gradient(self.text_encoder, True) set_stop_gradient(self.clip_model, True) @@ -152,7 +159,7 @@ def cond_fn( # predict the noise residual noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, PNDMScheduler): + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t # compute predicted original sample from predicted noise also called @@ -174,7 +181,7 @@ def cond_fn( if use_cutouts: image = self.make_cutouts(image, num_cutouts) else: - resize_transform = transforms.Resize(self.feature_extractor.size) + resize_transform = transforms.Resize(self.cut_out_size) image = paddle.stack([resize_transform(img) for img in image], axis=0) image = self.normalize(image).astype(latents.dtype) @@ -207,11 +214,12 @@ def __call__( guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, clip_guidance_scale: Optional[float] = 100, clip_prompt: Optional[Union[str, List[str]]] = None, num_cutouts: Optional[int] = 4, use_cutouts: Optional[bool] = True, - seed: Optional[int] = None, + generator: Optional[paddle.Generator] = None, latents: Optional[paddle.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -277,9 +285,9 @@ def __call__( text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids) text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, axis=-1, keepdim=True) # duplicate text embeddings clip for each generation per prompt - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt, 1]) - text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, seq_len, -1]) + bs_embed, _ = text_embeddings_clip.shape + text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt]) + text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, -1]) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -334,8 +342,7 @@ def __call__( # However this currently doesn't work in `mps`. latents_shape = [batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8] if latents is None: - paddle.seed(seed) - latents = paddle.randn(latents_shape, dtype=text_embeddings.dtype) + latents = paddle.randn(latents_shape, generator=generator, dtype=text_embeddings.dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") @@ -350,6 +357,20 @@ def __call__( # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents @@ -381,7 +402,7 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/ppdiffusers/examples/community/inference_clip_guided_stable_diffusion.py b/ppdiffusers/examples/community/inference_clip_guided_stable_diffusion.py index a3fd8de9e68a..484914c08901 100644 --- a/ppdiffusers/examples/community/inference_clip_guided_stable_diffusion.py +++ b/ppdiffusers/examples/community/inference_clip_guided_stable_diffusion.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion from IPython.display import display from PIL import Image @@ -34,7 +35,7 @@ def image_grid(imgs, rows, cols): def create_clip_guided_pipeline( model_id="CompVis/stable-diffusion-v1-4", clip_model_id="openai/clip-vit-large-patch14", scheduler="plms" ): - pipeline = StableDiffusionPipeline.from_pretrained(model_id) + pipeline = StableDiffusionPipeline.from_pretrained(model_id, paddle_dtype=paddle.float16) if scheduler == "lms": scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") @@ -116,29 +117,31 @@ def infer( clip_guidance_scale = 100 # @param {type: "number"} num_cutouts = 4 # @param {type: "number"} use_cutouts = False # @param ["False", "True"] - unfreeze_unet = True # @param ["False", "True"] - unfreeze_vae = True # @param ["False", "True"] + unfreeze_unet = False # @param ["False", "True"] + unfreeze_vae = False # @param ["False", "True"] seed = 3788086447 # @param {type: "number"} model_id = "CompVis/stable-diffusion-v1-4" clip_model_id = "openai/clip-vit-large-patch14" # @param ["openai/clip-vit-base-patch32", "openai/clip-vit-base-patch14", "openai/clip-rn101", "openai/clip-rn50"] {allow-input: true} scheduler = "plms" # @param ['plms', 'lms'] guided_pipeline = create_clip_guided_pipeline(model_id, clip_model_id) - grid_image = infer( - prompt=prompt, - negative_prompt=negative_prompt, - clip_prompt=clip_prompt, - num_return_images=num_return_images, - num_images_per_prompt=num_images_per_prompt, - num_inference_steps=num_inference_steps, - clip_guidance_scale=clip_guidance_scale, - guidance_scale=guidance_scale, - guided_pipeline=guided_pipeline, - use_cutouts=use_cutouts, - num_cutouts=num_cutouts, - seed=seed, - unfreeze_unet=unfreeze_unet, - unfreeze_vae=unfreeze_vae, - ) + + with paddle.amp.auto_cast(True, level="O2"): + grid_image = infer( + prompt=prompt, + negative_prompt=negative_prompt, + clip_prompt=clip_prompt, + num_return_images=num_return_images, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=num_inference_steps, + clip_guidance_scale=clip_guidance_scale, + guidance_scale=guidance_scale, + guided_pipeline=guided_pipeline, + use_cutouts=use_cutouts, + num_cutouts=num_cutouts, + seed=seed, + unfreeze_unet=unfreeze_unet, + unfreeze_vae=unfreeze_vae, + ) display(grid_image)