Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PPDiffusers] Fix ppdiffusers bug and support ZH stablediffusion #3663

Merged
merged 13 commits into from
Nov 6, 2022
2 changes: 1 addition & 1 deletion paddlenlp/transformers/clip/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def quick_gelu(x):

F.quick_gelu = quick_gelu

NEG_INF = float("-inf") # -1e4 -1e9
NEG_INF = -1e9 # float("-inf") -1e4 -1e9


class VisionTransformer(nn.Layer):
Expand Down
16 changes: 8 additions & 8 deletions ppdiffusers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

## 1. News 📢

* 🔥 **2022.11.04 支持 IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 和 IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1 中文权重**
* 🔥 **2022.10.27 发布 PPDiffusers仓库**


Expand Down Expand Up @@ -39,7 +40,7 @@ python setup.py install

## 4. 使用PPDiffusers快速体验Stable Diffusion模型!

Stable Diffusion 是一个**文本到图像(text-to-image)**的**潜在扩散模型(latent diffusion model, ldm)**, 该模型是由来自[CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) 的工程师以及 [RunwayML](https://runwayml.com/)一起开发而完成的。该模型使用了大小为**512x512**的[LAION-5B](https://laion.ai/blog/laion-5b/)数据集子集进行训练。该模型使用了Openai开源的**CLIP ViT-L/14** 文本编码器(text_encoder)来编码提示(prompt)文本,从而作为引导条件(注意该部分权重不进行训练)。该模型使用了Unet模型(860M参数)和text encoder(123M参数),并且可以在具有4GB显存(注:当前paddle版本需要进行优化,无法在4GB的显卡上运行)的GPU进行推理预测
Stable Diffusion 是一个**文本到图像(text-to-image)**的**潜在扩散模型(latent diffusion model, ldm)**, 该模型是由来自[CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) 的工程师以及 [RunwayML](https://runwayml.com/)一起开发而完成的。该模型使用了大小为**512x512**的[LAION-5B](https://laion.ai/blog/laion-5b/)数据集子集进行训练。该模型使用了Openai开源的**CLIP ViT-L/14** 文本编码器(text_encoder)来编码提示(prompt)文本,从而作为引导条件(注意该部分权重不进行训练)。该模型使用了Unet模型(860M参数)和text encoder(123M参数),并且可以在具有4GB显存的GPU上进行推理预测

___注意___:
___为了方便国内用户下载使用及快速体验Stable Diffusion模型,我们在百度云(BOS)上提供了paddle版本的镜像权重。注意:为了使用该模型与权重,你必须接受该模型所要求的**License**,请访问huggingface的[model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), 仔细阅读里面的**License**,然后签署该协议。___
Expand All @@ -61,8 +62,7 @@ image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
<center><image src="https://user-images.githubusercontent.com/50394665/197779466-04543823-8b83-41d6-94e8-146a7dac00d7.png" width="600"></center>

<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197779466-04543823-8b83-41d6-94e8-146a7dac00d7.png">

### 4.2 使用Stable Diffusion进行由文本引导的图片-图片的生成

Expand All @@ -74,10 +74,10 @@ from io import BytesIO

from ppdiffusers import StableDiffusionImg2ImgPipeline

# load the pipeline
# 加载pipeline
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# let's download an initial image
# 下载初始图片
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/sketch-mountains-input.png"

response = requests.get(url)
Expand All @@ -92,7 +92,7 @@ with paddle.amp.auto_cast(True):
image.save("fantasy_landscape.png")
```

<center><image src="https://user-images.githubusercontent.com/50394665/197780044-34e6f8ca-6864-4c3d-bb99-28e0aadf867b.png" width="600"></center>
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197780044-34e6f8ca-6864-4c3d-bb99-28e0aadf867b.png">


### 4.3 使用Stable Diffusion根据文本补全图片
Expand Down Expand Up @@ -125,7 +125,7 @@ with paddle.amp.auto_cast(True):

image.save("cat_on_bench.png")
```
<center><image src="https://user-images.githubusercontent.com/50394665/197783711-ab3caf2e-5a4d-4099-8d01-d6ca80ca8e78.png" width="600"></center>
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197783711-ab3caf2e-5a4d-4099-8d01-d6ca80ca8e78.png">

Tips: 下面的使用方法是新版本的代码,也是官方推荐的代码,注意必须配合**runwayml/stable-diffusion-inpainting**才可正常使用。
```python
Expand Down Expand Up @@ -153,7 +153,7 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]

image.save("cat_on_bench_new.png")
```
<center><image src="https://user-images.githubusercontent.com/50394665/198016801-87cec13b-0d89-41c3-aedb-c89a43d76153.png" width="600"></center>
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/198016801-87cec13b-0d89-41c3-aedb-c89a43d76153.png">

## 5. Credits

Expand Down
2 changes: 1 addition & 1 deletion ppdiffusers/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.6.0.dev1
0.6.1
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,10 @@ def __call__(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids)[0]

attention_mask = paddle.ones_like(text_input_ids)
text_embeddings = self.text_encoder(text_input_ids,
attention_mask=attention_mask)[0]

# duplicate text embeddings for each generation per prompt
bs_embed, seq_len, _ = text_embeddings.shape
Expand Down Expand Up @@ -323,7 +326,9 @@ def __call__(
truncation=True,
return_tensors="pd",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
attention_mask = paddle.ones_like(uncond_input.input_ids)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids, attention_mask=attention_mask)[0]

# duplicate unconditional embeddings for each generation per prompt
seq_len = uncond_embeddings.shape[1]
Expand Down
10 changes: 7 additions & 3 deletions ppdiffusers/examples/community/composable_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
Expand Down Expand Up @@ -253,7 +253,9 @@ def __call__(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids)[0]
attention_mask = paddle.ones_like(text_input_ids)
text_embeddings = self.text_encoder(text_input_ids,
attention_mask=attention_mask)[0]

# duplicate text embeddings for each generation per prompt, using mps friendly method
# bs_embed, seq_len, _ = text_embeddings.shape
Expand Down Expand Up @@ -318,7 +320,9 @@ def __call__(
truncation=True,
return_tensors="pd",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
attention_mask = paddle.ones_like(uncond_input.input_ids)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids, attention_mask=attention_mask)[0]

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
# seq_len = uncond_embeddings.shape[1]
Expand Down
10 changes: 7 additions & 3 deletions ppdiffusers/examples/community/interpolate_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
Expand Down Expand Up @@ -277,7 +277,9 @@ def __call__(
)
text_input_ids = text_input_ids[:, :self.tokenizer.
model_max_length]
text_embeddings = self.text_encoder(text_input_ids)[0]
attention_mask = paddle.ones_like(text_input_ids)
text_embeddings = self.text_encoder(
text_input_ids, attention_mask=attention_mask)[0]
else:
batch_size = text_embeddings.shape[0]

Expand Down Expand Up @@ -318,7 +320,9 @@ def __call__(
truncation=True,
return_tensors="pd",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
attention_mask = paddle.ones_like(uncond_input.input_ids)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids, attention_mask=attention_mask)[0]

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
Expand Down
33 changes: 22 additions & 11 deletions ppdiffusers/examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@


def parse_prompt_attention(text):
"""
r"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
Expand Down Expand Up @@ -186,6 +186,7 @@ def pad_tokens_and_weights(tokens,
max_length,
bos,
eos,
pad,
no_boseos_middle=True,
chunk_length=77):
r"""
Expand All @@ -194,8 +195,9 @@ def pad_tokens_and_weights(tokens,
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos
] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
tokens[i] = [bos] + tokens[i] + [
eos
] + [pad] * (max_length - 2 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [
1.0
Expand Down Expand Up @@ -238,7 +240,9 @@ def get_unweighted_text_embeddings(
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0]
attention_mask = paddle.ones_like(text_input_chunk)
text_embedding = pipe.text_encoder(text_input_chunk,
attention_mask=attention_mask)[0]

if no_boseos_middle:
if i == 0:
Expand All @@ -254,7 +258,9 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = paddle.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
attention_mask = paddle.ones_like(text_input)
text_embeddings = pipe.text_encoder(text_input,
attention_mask=attention_mask)[0]
return text_embeddings


Expand Down Expand Up @@ -336,14 +342,17 @@ def get_weighted_text_embeddings(
2) * max_embeddings_multiples + 2

# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
# support bert tokenizer
bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
pad = pipe.tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
Expand All @@ -355,6 +364,7 @@ def get_weighted_text_embeddings(
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
Expand Down Expand Up @@ -481,7 +491,7 @@ def __init__(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
Expand Down Expand Up @@ -753,7 +763,8 @@ def __call__(
timesteps = timesteps.tile([
batch_size * num_images_per_prompt,
])

if seed is not None:
paddle.seed(seed)
noise = paddle.randn(
init_latents.shape,
dtype=latents_dtype,
Expand Down Expand Up @@ -926,8 +937,8 @@ def text2img(

def img2img(
self,
init_image: Union[paddle.Tensor, PIL.Image.Image],
prompt: Union[str, List[str]],
init_image: Union[paddle.Tensor, PIL.Image.Image],
negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
Expand Down Expand Up @@ -1016,9 +1027,9 @@ def img2img(

def inpaint(
self,
prompt: Union[str, List[str]],
init_image: Union[paddle.Tensor, PIL.Image.Image],
mask_image: Union[paddle.Tensor, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
Expand Down
10 changes: 7 additions & 3 deletions ppdiffusers/examples/community/wildcard_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
Expand Down Expand Up @@ -298,7 +298,9 @@ def __call__(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids)[0]
attention_mask = paddle.ones_like(text_input_ids)
text_embeddings = self.text_encoder(text_input_ids,
attention_mask=attention_mask)[0]

# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
Expand Down Expand Up @@ -337,7 +339,9 @@ def __call__(
truncation=True,
return_tensors="pd",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
attention_mask = paddle.ones_like(uncond_input.input_ids)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids, attention_mask=attention_mask)[0]

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
Expand Down
Loading