Commit
Added prompt2prompt guidance with prompt parsing
Works like [word1:word2:step], where word1 is the old concept, word2 is the new concept, and step is the sampling step at which word1 is replaced by word2 in the sampling loop.

You can also write [word:step] to introduce a new concept at a given step, or [word::step] to remove a concept at a given step.
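For example, with 50 sampling steps the prompt

    a painting of a [cat:dog:0.5] in the park

is conditioned on "a painting of a cat in the park" for steps 0-24 and on "a painting of a dog in the park" from step 25 (0.5 * 50) onward. A step value of 1 or greater is read as an absolute step index; a value below 1 is read as a fraction of the total step count (see __get_step in prompt_parser.py below).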
Doggettx committed Sep 10, 2022
1 parent cd3d653 commit 3b5c504
Showing 3 changed files with 120 additions and 50 deletions.
17 changes: 13 additions & 4 deletions ldm/models/diffusion/ddim.py
@@ -107,7 +107,7 @@ def sample(self,
                                                     log_every_t=log_every_t,
                                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                                     unconditional_conditioning=unconditional_conditioning,
-                                                    )
+                                                    **kwargs)
         return samples, intermediates
 
     @torch.no_grad()
@@ -116,7 +116,7 @@ def ddim_sampling(self, cond, shape,
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, **kwargs):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -134,7 +134,8 @@ def ddim_sampling(self, cond, shape,
         time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         print(f"Running DDIM Sampling with {total_steps} timesteps")
 
+        prompt_guidance = kwargs.get('prompt_guidance')
         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
 
         for i, step in enumerate(iterator):
@@ -146,6 +147,9 @@ def ddim_sampling(self, cond, shape,
                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                 img = img_orig * mask + (1. - mask) * img
 
+            if prompt_guidance is not None and i < len(prompt_guidance):
+                cond = prompt_guidance[i]
+
             outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                       quantize_denoised=quantize_denoised, temperature=temperature,
                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
@@ -221,20 +225,25 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
 
     @torch.no_grad()
     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-               use_original_steps=False):
+               use_original_steps=False, **kwargs):
 
         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
         timesteps = timesteps[:t_start]
 
         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
         print(f"Running DDIM Sampling with {total_steps} timesteps")
+        prompt_guidance = kwargs.get('prompt_guidance')
 
         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
         x_dec = x_latent
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
 
+            if prompt_guidance is not None and i < len(prompt_guidance):
+                cond = prompt_guidance[i]
+
             x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                           unconditional_guidance_scale=unconditional_guidance_scale,
                                           unconditional_conditioning=unconditional_conditioning)
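Both sampling loops use the same guard; as a minimal standalone sketch of the semantics (illustrative names, not part of the commit):

    # prompt_guidance[i] overrides the conditioning at step i, so the cond
    # passed into the sampler only matters when no guidance array is given
    # (or when the array is shorter than the number of steps).
    def step_conditioning(cond, prompt_guidance, i):
        if prompt_guidance is not None and i < len(prompt_guidance):
            return prompt_guidance[i]
        return cond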
94 changes: 94 additions & 0 deletions scripts/prompt_parser.py
@@ -0,0 +1,94 @@
import numpy as np
import re


class PromptParser:
    def __init__(self, model):
        self.model = model
        # Matches a bracketed swap statement, a run of text up to the next
        # '[', or the trailing remainder of the prompt.
        self.regex = re.compile(r'\[.*?\]|.+?(?=\[)|.*')

    def get_prompt_guidance(self, prompt, steps, batch_size):
        # Build an array with one conditioning entry per sampling step;
        # entry i is the conditioning to apply at step i.
        prompts = self.parse_prompt(prompt, steps)
        prompt_guidance = np.empty(steps, dtype=object)
        cg = None

        index = 0
        next_step = 0

        print()
        for i in range(0, steps):
            if i == next_step:
                _, text = prompts[index]
                print(f"Swapping at step {i} to: {text}")
                cg = self.model.get_learned_conditioning(batch_size * text)

                index += 1

                if index < len(prompts):
                    next_step, _ = prompts[index]

            prompt_guidance[i] = cg

        return prompt_guidance

    def __parse_float(self, text):
        try:
            return float(text)
        except ValueError:
            return 0.

    def __parse_swap_statement(self, statement):
        # '[a:b:0.5]' -> ('a', 'b', 0.5); '[b:0.5]' -> ('', 'b', 0.5);
        # '[a::0.5]' -> ('a', '', 0.5).
        fields = str.split(statement[1:-1], ':')
        if len(fields) < 2:
            return "", "", 0.

        if len(fields) == 2:
            return "", fields[0], self.__parse_float(fields[1])
        else:
            return fields[0], fields[1], self.__parse_float(fields[2])

    def __get_step(self, token, steps):
        # A weight of 1 or more is an absolute step index; below 1 it is a
        # fraction of the total number of steps.
        _, _, weight = token
        if weight >= 1.:
            return int(weight)
        else:
            return int(weight * steps)

    def parse_prompt(self, prompt, steps):
        # Expand the swap statements into a list of (step, prompt_text)
        # pairs, one for each distinct swap point (step 0 included).
        tokens = self.__get_tokens(prompt)
        values = np.array([self.__get_step(token, steps) for token in tokens if type(token) is tuple])
        values = np.concatenate(([0], values))
        values = np.sort(np.unique(values))

        builders = [(value, list()) for value in values]

        for token in tokens:
            if type(token) is tuple:
                for value, text in builders:
                    word1, word2, _ = token
                    step = self.__get_step(token, steps)
                    text.append(word1 if value < step else word2)
            else:
                for _, text in builders:
                    text.append(token)

        return [(value, ''.join(text)) for value, text in builders]

    def __get_tokens(self, prompt):
        # Split the prompt into plain strings and (word1, word2, weight)
        # swap tuples.
        parts = self.regex.findall(prompt)
        result = list()

        for part in parts:
            if len(part) == 0:
                continue

            if part[0] == '[':
                result.append(self.__parse_swap_statement(part))
            else:
                result.append(part)

        return result
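A quick illustration of how the parser behaves (not part of the commit; `model` is assumed to be a loaded Stable Diffusion model exposing get_learned_conditioning):

    parser = PromptParser(model)

    # parse_prompt expands the swap statements into one full prompt per
    # segment: with 50 steps, "a [house:boat:0.3] on a [hill::0.6]" yields
    # [(0, 'a house on a hill'), (15, 'a boat on a hill'), (30, 'a boat on a ')]
    prompts = parser.parse_prompt("a [house:boat:0.3] on a [hill::0.6]", steps=50)

    # get_prompt_guidance turns the same prompt into a per-step array of
    # conditioning tensors; entry i is used at sampling step i. Note that
    # batch_size * text repeats the string itself when batch_size > 1, so the
    # parser is effectively meant for batch_size=1 (the new --n_samples default).
    guidance = parser.get_prompt_guidance("a [house:boat:0.3] on a [hill::0.6]", steps=50, batch_size=1)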


59 changes: 13 additions & 46 deletions scripts/txt2img.py
@@ -5,7 +5,6 @@
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
-from imwatermark import WatermarkEncoder
 from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
@@ -18,15 +17,7 @@
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-
-
-# load safety model
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
+from prompt_parser import PromptParser
 
 def chunk(it, size):
     it = iter(it)
@@ -65,14 +56,6 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-def put_watermark(img, wm_encoder=None):
-    if wm_encoder is not None:
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = wm_encoder.encode(img, 'dwtDct')
-        img = Image.fromarray(img[:, :, ::-1])
-    return img
-
-
 def load_replacement(x):
     try:
         hwc = x.shape
@@ -84,16 +67,6 @@ def load_replacement(x):
     return x
 
 
-def check_safety(x_image):
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-    assert x_checked_image.shape[0] == len(has_nsfw_concept)
-    for i in range(len(has_nsfw_concept)):
-        if has_nsfw_concept[i]:
-            x_checked_image[i] = load_replacement(x_checked_image[i])
-    return x_checked_image, has_nsfw_concept
-
-
 def main():
     parser = argparse.ArgumentParser()
 
@@ -151,7 +124,7 @@ def main():
     parser.add_argument(
         "--n_iter",
         type=int,
-        default=2,
+        default=1,
         help="sample this often",
     )
     parser.add_argument(
@@ -181,7 +154,7 @@ def main():
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=3,
+        default=1,
         help="how many samples to produce for each given prompt. A.k.a. batch size",
     )
     parser.add_argument(
@@ -251,11 +224,6 @@ def main():
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
 
-    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
-    wm = "StableDiffusionV1"
-    wm_encoder = WatermarkEncoder()
-    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-
     batch_size = opt.n_samples
     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
     if not opt.from_file:
@@ -278,6 +246,7 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 
+    prompt_parser = PromptParser(model)
     precision_scope = autocast if opt.precision=="autocast" else nullcontext
     with torch.no_grad():
         with precision_scope("cuda"):
@@ -291,7 +260,11 @@ def main():
                             uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
+
+                        prompt_guidance = prompt_parser.get_prompt_guidance(prompts[0], opt.ddim_steps, batch_size)
+
+                        c = prompt_guidance[0]
+
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                         samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                          conditioning=c,
@@ -301,26 +274,21 @@ def main():
                                                          unconditional_guidance_scale=opt.scale,
                                                          unconditional_conditioning=uc,
                                                          eta=opt.ddim_eta,
-                                                         x_T=start_code)
+                                                         x_T=start_code,
+                                                         prompt_guidance=prompt_guidance)
 
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                         x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
 
-                        x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
-
-                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
                        if not opt.skip_save:
-                            for x_sample in x_checked_image_torch:
+                            for x_sample in x_samples_ddim:
                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                 img = Image.fromarray(x_sample.astype(np.uint8))
-                                img = put_watermark(img, wm_encoder)
                                 img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                                 base_count += 1
 
                         if not opt.skip_grid:
-                            all_samples.append(x_checked_image_torch)
+                            all_samples.append(x_samples_ddim)
 
                 if not opt.skip_grid:
                     # additionally, save as grid
@@ -331,7 +299,6 @@ def main():
                     # to image
                     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                     img = Image.fromarray(grid.astype(np.uint8))
-                    img = put_watermark(img, wm_encoder)
                     img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                     grid_count += 1
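An illustrative invocation (made-up prompt, using the script's existing flags):

    python scripts/txt2img.py --prompt "a painting of a [cat:dog:0.5] in the park" --ddim_steps 50

Since only ddim.py threads prompt_guidance through to its sampling loops, the swap takes effect on the default DDIM path; the PLMS sampler is unchanged by this commit.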
