ddim_inv.py
import sys
import numpy as np
import torch
import torch.nn.functional as F
from random import randrange
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from diffusers import DDIMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
sys.path.insert(0, "src/utils")
from base_pipeline import BasePipeline
from cross_attention import prep_unet
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class DDIMInversion(BasePipeline):
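    # Inverts a real image into the diffusion model's noise space: the image is
    # encoded with the VAE, then the DDIM update is run in reverse timestep
    # order (x_0 -> x_T) while the predicted noise is regularized at each step
    # to stay close to white Gaussian noise, via an auto-correlation term
    # (auto_corr_loss) and a KL term (kl_divergence).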
    def auto_corr_loss(self, x, random_shift=True):
        B, C, H, W = x.shape
        assert B == 1
        x = x.squeeze(0)
        # x must be shape [C,H,W] now
        reg_loss = 0.0
        for ch_idx in range(x.shape[0]):
            noise = x[ch_idx][None, None, :, :]
            while True:
                if random_shift:
                    roll_amount = randrange(1, noise.shape[2] // 2)
                else:
                    roll_amount = 1
                # penalize the squared mean auto-correlation of the noise map
                # at the chosen shift, along both H (dims=2) and W (dims=3)
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                # repeat at coarser pyramid levels until the map is 8x8
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss
    def kl_divergence(self, x):
        # twice the KL divergence between N(mu, var) and N(0, 1);
        # the constant factor of 0.5 is folded into lambda_kl
        _mu = x.mean()
        _var = x.var()
        return _var + _mu**2 - 1 - torch.log(_var + 1e-7)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        num_inversion_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        img=None,  # the input image as a PIL image
        torch_dtype=torch.float32,
        # inversion regularization parameters
        lambda_ac: float = 20.0,
        lambda_kl: float = 20.0,
        num_reg_steps: int = 5,
        num_ac_rolls: int = 5,
    ):
        # 0. modify the unet to be useful :D
        self.unet = prep_unet(self.unet)

        # set the scheduler to be the Inverse DDIM scheduler
        # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        self.scheduler.set_timesteps(num_inversion_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Encode the input image with the first stage model
        x0 = np.array(img) / 255
        x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).to(device)
        x0 = (x0 - 0.5) * 2.0
        with torch.no_grad():
            x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
        latents = x0_enc = 0.18215 * x0_enc

        # Decode and return the image
        with torch.no_grad():
            x0_dec = self.decode_latents(x0_enc.detach())
        image_x0_dec = self.numpy_to_pil(x0_dec)

        with torch.no_grad():
            prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
        extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
        # Do the inversion
        num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order  # should be 0?
        with self.progress_bar(total=num_inversion_steps) as progress_bar:
            # walk the timesteps in ascending order (toward higher noise),
            # skipping the endpoints
            for i, t in enumerate(timesteps.flip(0)[1:-1]):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # regularization of the noise prediction: nudge e_t toward
                # white Gaussian statistics by gradient descent on both losses
                e_t = noise_pred
                for _outer in range(num_reg_steps):
                    if lambda_ac > 0:
                        for _inner in range(num_ac_rolls):
                            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                            l_ac = self.auto_corr_loss(_var)
                            l_ac.backward()
                            _grad = _var.grad.detach() / num_ac_rolls
                            e_t = e_t - lambda_ac * _grad
                    if lambda_kl > 0:
                        _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                        l_kld = self.kl_divergence(_var)
                        l_kld.backward()
                        _grad = _var.grad.detach()
                        e_t = e_t - lambda_kl * _grad
                    e_t = e_t.detach()
                noise_pred = e_t

                # compute the next noisy sample x_t -> x_t+1 (inversion requires
                # a DDIM scheduler whose step() supports reverse=True)
                latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
        x_inv = latents.detach().clone()

        # Post-processing: decode the inverted latents back to an image
        image = self.decode_latents(latents.detach())
        image = self.numpy_to_pil(image)
        return x_inv, image, image_x0_dec
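

# Usage sketch: a minimal, hypothetical way to drive the pipeline. The
# checkpoint name, image path, and prompt below are illustrative assumptions,
# not part of this file, and the scheduler must be a DDIM variant whose step()
# accepts reverse=True (see the commented-out MyDDIMScheduler line above).
if __name__ == "__main__":
    from PIL import Image

    pipe = DDIMInversion.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # assumed checkpoint
        torch_dtype=torch.float32,
    ).to(device)

    # 512x512 RGB input, matching the (x/255 - 0.5) * 2 preprocessing above
    img = Image.open("assets/test.png").convert("RGB").resize((512, 512))  # assumed path

    x_inv, x_inv_img, x_dec_img = pipe(
        prompt="a photo of a cat",  # assumed caption describing the input image
        guidance_scale=1.0,         # <= 1 disables classifier-free guidance during inversion
        num_inversion_steps=50,
        img=img,
    )
    x_inv_img[0].save("inverted_reconstruction.png")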