add DDIM to CLIP Guided Stable Diffusion and add code example (#4920)
* add DDIM to CLIP Guided Stable Diffusion and add example code

* modify code sample

* use generator and modify code sample
Yam0214 authored Feb 22, 2023
1 parent 9a00c15 commit 82bd9b1
Showing 3 changed files with 126 additions and 30 deletions.
72 changes: 72 additions & 0 deletions ppdiffusers/examples/community/README.md
@@ -0,0 +1,72 @@
# Community Examples

Community examples contain inference and training examples contributed by the community. The table below gives an overview of all community examples. Click **Code Example** to jump to the runnable code for the corresponding example, which you can copy and run. If an example does not work as expected, please open an issue.

|Example|Description|Code Example|Author|
|-|-|-|-|
|CLIP Guided Stable Diffusion|Text-to-image generation with CLIP-guided Stable Diffusion|[CLIP Guided Stable Diffusion](#CLIP%20Guided%20Stable%20Diffusion)||

## Example usages

### CLIP Guided Stable Diffusion

Using a CLIP model to guide Stable Diffusion denoising can produce more realistic images.

Running the code below requires 16 GB of GPU memory.

```python
import os

import paddle
from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion

from paddlenlp.transformers import CLIPFeatureExtractor, CLIPModel

feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
                                       dtype=paddle.float32)

guided_pipeline = CLIPGuidedStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    paddle_dtype=paddle.float16,
)
guided_pipeline.enable_attention_slicing()

prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

generator = paddle.Generator().manual_seed(2022)
with paddle.amp.auto_cast(True, level="O2"):
    images = []
    for i in range(4):
        image = guided_pipeline(
            prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            clip_guidance_scale=100,
            num_cutouts=4,
            use_cutouts=False,
            generator=generator,
            unfreeze_unet=False,
            unfreeze_vae=False,
        ).images[0]
        images.append(image)

# save images locally
if not os.path.exists("clip_guided_sd"):
    os.mkdir("clip_guided_sd")
for i, img in enumerate(images):
    img.save(f"./clip_guided_sd/image_{i}.png")
```
The generated images are stored in the `images` list; sample outputs are shown below:

| image_0 | image_1 | image_2 | image_3 |
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
|![][clip_guided_sd_0]|![][clip_guided_sd_1]|![][clip_guided_sd_2]|![][clip_guided_sd_3]|

[clip_guided_sd_0]: https://user-images.githubusercontent.com/40912707/220514674-e5cb29a3-b07e-4e8f-a4c8-323b35637294.png
[clip_guided_sd_1]: https://user-images.githubusercontent.com/40912707/220514703-1eaf444e-1506-4c44-b686-5950fd79a3da.png
[clip_guided_sd_2]: https://user-images.githubusercontent.com/40912707/220514765-89e48c13-156f-4e61-b433-06f1283d2265.png
[clip_guided_sd_3]: https://user-images.githubusercontent.com/40912707/220514751-82d63fd4-e35e-482b-a8e1-c5c956119b2e.png
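
Since the pipeline diff below wires in `DDIMScheduler` support and a new `eta` argument, a usage note may help. The following is a minimal, untested sketch that reuses the checkpoints from the example above; it assumes the pipeline accepts a `scheduler` override in `from_pretrained`, as the `DiffusionPipeline` API elsewhere in ppdiffusers does.

```python
import paddle
from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion

from paddlenlp.transformers import CLIPFeatureExtractor, CLIPModel
from ppdiffusers import DDIMScheduler

# Standard Stable Diffusion v1.x DDIM settings.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
                                       dtype=paddle.float32)

# Assumption: `scheduler=` overrides the pipeline's default scheduler at load time.
guided_pipeline = CLIPGuidedStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    scheduler=scheduler,
    paddle_dtype=paddle.float16,
)
guided_pipeline.enable_attention_slicing()

generator = paddle.Generator().manual_seed(2022)
image = guided_pipeline(
    "fantasy forest landscape, golden vector elements, highly detailed",
    num_inference_steps=50,
    guidance_scale=7.5,
    clip_guidance_scale=100,
    eta=0.0,  # eta is only consumed by DDIM; other schedulers ignore it
    generator=generator,
).images[0]
image.save("clip_guided_sd_ddim.png")
```

With `eta=0.0`, DDIM sampling is deterministic given the generator seed; larger values (up to 1.0) reintroduce noise at each step. As the diff shows, the pipeline forwards `eta` and `generator` to the scheduler only when its `step` signature accepts them.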
43 changes: 32 additions & 11 deletions ppdiffusers/examples/community/clip_guided_stable_diffusion.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import paddle
@@ -28,6 +29,7 @@
)
from ppdiffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
@@ -84,7 +86,7 @@ def __init__(
clip_model: CLIPModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
@@ -99,7 +101,12 @@ def __init__(
)

self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
self.make_cutouts = MakeCutouts(feature_extractor.size)
self.cut_out_size = (
feature_extractor.size
if isinstance(feature_extractor.size, int)
else feature_extractor.size["shortest_edge"]
)
self.make_cutouts = MakeCutouts(self.cut_out_size)

set_stop_gradient(self.text_encoder, True)
set_stop_gradient(self.clip_model, True)
@@ -152,7 +159,7 @@ def cond_fn(
# predict the noise residual
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

if isinstance(self.scheduler, PNDMScheduler):
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t
# compute predicted original sample from predicted noise also called
@@ -174,7 +181,7 @@ def cond_fn(
if use_cutouts:
image = self.make_cutouts(image, num_cutouts)
else:
resize_transform = transforms.Resize(self.feature_extractor.size)
resize_transform = transforms.Resize(self.cut_out_size)
image = paddle.stack([resize_transform(img) for img in image], axis=0)
image = self.normalize(image).astype(latents.dtype)

@@ -207,11 +214,12 @@ def __call__(
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
clip_guidance_scale: Optional[float] = 100,
clip_prompt: Optional[Union[str, List[str]]] = None,
num_cutouts: Optional[int] = 4,
use_cutouts: Optional[bool] = True,
seed: Optional[int] = None,
generator: Optional[paddle.Generator] = None,
latents: Optional[paddle.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -277,9 +285,9 @@ def __call__(
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, axis=-1, keepdim=True)
# duplicate text embeddings clip for each generation per prompt
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt, 1])
text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
bs_embed, _ = text_embeddings_clip.shape
text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt])
text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, -1])

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -334,8 +342,7 @@ def __call__(
# However this currently doesn't work in `mps`.
latents_shape = [batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8]
if latents is None:
paddle.seed(seed)
latents = paddle.randn(latents_shape, dtype=text_embeddings.dtype)
latents = paddle.randn(latents_shape, generator=generator, dtype=text_embeddings.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -350,6 +357,20 @@
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta

# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
@@ -381,7 +402,7 @@ def __call__(
)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
from IPython.display import display
from PIL import Image
@@ -34,7 +35,7 @@ def image_grid(imgs, rows, cols):
def create_clip_guided_pipeline(
model_id="CompVis/stable-diffusion-v1-4", clip_model_id="openai/clip-vit-large-patch14", scheduler="plms"
):
pipeline = StableDiffusionPipeline.from_pretrained(model_id)
pipeline = StableDiffusionPipeline.from_pretrained(model_id, paddle_dtype=paddle.float16)

if scheduler == "lms":
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
@@ -116,29 +117,31 @@ def infer(
clip_guidance_scale = 100 # @param {type: "number"}
num_cutouts = 4 # @param {type: "number"}
use_cutouts = False # @param ["False", "True"]
unfreeze_unet = True # @param ["False", "True"]
unfreeze_vae = True # @param ["False", "True"]
unfreeze_unet = False # @param ["False", "True"]
unfreeze_vae = False # @param ["False", "True"]
seed = 3788086447 # @param {type: "number"}

model_id = "CompVis/stable-diffusion-v1-4"
clip_model_id = "openai/clip-vit-large-patch14" # @param ["openai/clip-vit-base-patch32", "openai/clip-vit-base-patch14", "openai/clip-rn101", "openai/clip-rn50"] {allow-input: true}
scheduler = "plms" # @param ['plms', 'lms']
guided_pipeline = create_clip_guided_pipeline(model_id, clip_model_id)
grid_image = infer(
prompt=prompt,
negative_prompt=negative_prompt,
clip_prompt=clip_prompt,
num_return_images=num_return_images,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=num_inference_steps,
clip_guidance_scale=clip_guidance_scale,
guidance_scale=guidance_scale,
guided_pipeline=guided_pipeline,
use_cutouts=use_cutouts,
num_cutouts=num_cutouts,
seed=seed,
unfreeze_unet=unfreeze_unet,
unfreeze_vae=unfreeze_vae,
)

with paddle.amp.auto_cast(True, level="O2"):
grid_image = infer(
prompt=prompt,
negative_prompt=negative_prompt,
clip_prompt=clip_prompt,
num_return_images=num_return_images,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=num_inference_steps,
clip_guidance_scale=clip_guidance_scale,
guidance_scale=guidance_scale,
guided_pipeline=guided_pipeline,
use_cutouts=use_cutouts,
num_cutouts=num_cutouts,
seed=seed,
unfreeze_unet=unfreeze_unet,
unfreeze_vae=unfreeze_vae,
)

display(grid_image)
