Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try dynamic thresholding #3962

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))


def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, mimic_scale: float, threshold_enable: bool, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2

Expand Down Expand Up @@ -117,6 +117,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
mimic_scale=mimic_scale,
threshold_enable=threshold_enable,
width=width,
height=height,
restore_faces=restore_faces,
Expand Down
17 changes: 13 additions & 4 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class StableDiffusionProcessing():
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, mimic_scale: float = 7.5, threshold_enable: bool = False, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

Expand All @@ -101,6 +101,8 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.mimic_scale: float = mimic_scale
self.threshold_enable: float = threshold_enable
self.width: int = width
self.height: int = height
self.restore_faces: bool = restore_faces
Expand Down Expand Up @@ -130,6 +132,9 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0

if not threshold_enable:
self.mimic_scale = 0

self.scripts = None
self.script_args = None
self.all_prompts = None
Expand Down Expand Up @@ -251,6 +256,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
self.height = p.height
self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
self.mimic_scale = p.mimic_scale
self.steps = p.steps
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
Expand Down Expand Up @@ -299,6 +305,7 @@ def js(self):
"height": self.height,
"sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
"mimic_scale": self.mimic_scale,
"steps": self.steps,
"batch_size": self.batch_size,
"restore_faces": self.restore_faces,
Expand Down Expand Up @@ -445,6 +452,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Mimic CFG scale": None if not p.threshold_enable else p.mimic_scale,
"Threshold percentile": None if not p.threshold_enable else opts.dynamic_threshold_percentile,
}

generation_params.update(p.extra_generation_params)
Expand Down Expand Up @@ -703,11 +712,11 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x), mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)
return samples

x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height), mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)

samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

Expand Down Expand Up @@ -913,7 +922,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier

samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning, mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)

if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
Expand Down
67 changes: 53 additions & 14 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def adjust_steps_if_invalid(self, p, num_steps):

return num_steps

def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
Expand All @@ -271,7 +271,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,

return samples

def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, hold_enable=False):
self.initialize(p)

self.init_latent = None
Expand Down Expand Up @@ -300,17 +300,52 @@ def __init__(self, model):
self.init_latent = None
self.step = 0

def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
def _dynthresh(self, cond, uncond, cond_scale, conds_list, mimic_scale):
    """Apply classifier-free guidance, then dynamically threshold the result.

    The prediction guided at ``cond_scale`` is re-centred per channel, clamped,
    and rescaled so that its peak deviation from the channel mean matches the
    peak deviation of the prediction guided at ``mimic_scale``.

    Args:
        cond: Conditional denoised outputs. They are reshaped to
            ``(batch, conds_per_batch, *uncond.shape[1:])``, so rows are
            assumed to be grouped by batch item -- TODO confirm with caller.
        uncond: Unconditional denoised outputs, one row per batch item;
            presumably ``(batch, channels, height, width)``.
        cond_scale: Guidance scale that is actually applied.
        conds_list: For each batch item, a list of ``(cond_index, weight)``
            pairs. Only the weights are used here.
        mimic_scale: Guidance scale whose value range is imitated.

    Returns:
        A tensor with the same shape as ``uncond``. The arithmetic is promoted
        to at least float32, so the dtype may be wider than the inputs'.

    Raises:
        ValueError: If the number of rows in ``cond`` is not a whole multiple
            of the number of rows in ``uncond``.
    """
    # Integer arithmetic instead of a float division, and a real exception
    # instead of an assert, so the check survives ``python -O``.
    conds_per_batch, remainder = divmod(cond.shape[0], uncond.shape[0])
    if remainder:
        raise ValueError("Expected # of conds per batch to be constant across batches")

    cond_stacked = cond.reshape((-1, conds_per_batch) + tuple(uncond.shape[1:]))
    diff = cond_stacked - uncond.unsqueeze(1)

    # conds_list is (batch, cond, 2); keep only the weight of each pair.
    # float32 is requested explicitly so half-precision latents are promoted
    # before the quantile below (torch.quantile appears to accept only
    # float32/float64 input -- confirm against the PyTorch version in use).
    weights = torch.tensor(
        [[weight for _, weight in conds] for conds in conds_list],
        dtype=torch.float32,
        device=diff.device,
    )
    weights = weights.reshape(*weights.shape, 1, 1, 1)
    diff_weighted = (diff * weights).sum(1)

    # Reference range: peak |deviation from the channel mean| at mimic_scale.
    mimic_flat = (uncond + diff_weighted * mimic_scale).flatten(2)
    mimic_centered = mimic_flat - mimic_flat.mean(dim=2, keepdim=True)
    mimic_peak = mimic_centered.abs().max(dim=2, keepdim=True).values

    # Prediction at the real guidance scale, centred per channel.
    guided = uncond + diff_weighted * cond_scale
    guided_flat = guided.flatten(2)
    guided_means = guided_flat.mean(dim=2, keepdim=True)
    guided_centered = guided_flat - guided_means

    guided_q = torch.quantile(
        guided_centered.abs(), opts.dynamic_threshold_percentile, dim=2, keepdim=True
    )
    limit = torch.maximum(guided_q, mimic_peak)
    clamped = guided_centered.clamp(-limit, limit)

    # A perfectly flat channel has limit == 0; dividing by it would turn the
    # whole channel into NaN. The numerator is 0 there, so dividing by 1 keeps
    # the channel at its mean and leaves every other channel untouched.
    safe_limit = torch.where(limit > 0, limit, torch.ones_like(limit))
    rescaled = clamped / safe_limit * mimic_peak

    return (rescaled + guided_means).unflatten(2, guided.shape[2:])

def combine_denoised(self, x_out, conds_list, uncond, cond_scale, mimic_scale, threshold_enable):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)

for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if threshold_enable:
denoised = self._dynthresh(x_out[:-uncond.shape[0]], denoised_uncond, cond_scale, conds_list, mimic_scale)
else:
denoised = torch.clone(denoised_uncond)

for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

return denoised

def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, mimic_scale, threshold_enable):
if state.interrupted or state.skipped:
raise InterruptedException

Expand Down Expand Up @@ -351,7 +386,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):

x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, mimic_scale, threshold_enable)

if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
Expand Down Expand Up @@ -474,7 +509,7 @@ def get_sigmas(self, p, steps):

return sigmas

def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
steps, t_enc = setup_img2img_steps(p, steps)

sigmas = self.get_sigmas(p, steps)
Expand Down Expand Up @@ -502,12 +537,14 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
'cond_scale': p.cfg_scale,
'mimic_scale': mimic_scale,
'threshold_enable': threshold_enable,
}, disable=False, callback=self.callback_state, **extra_params_kwargs))

return samples

def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
steps = steps or p.steps

sigmas = self.get_sigmas(p, steps)
Expand All @@ -528,7 +565,9 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
'cond_scale': p.cfg_scale,
'mimic_scale': mimic_scale,
'threshold_enable': threshold_enable,
}, disable=False, callback=self.callback_state, **extra_params_kwargs))

return samples
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def list_samplers():
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
"dynamic_threshold_percentile": OptionInfo(0.999, "For latent fix, the top percentile of latents to clamp (ex: 0.999 means the top 0.1% is clamped)", gr.Slider, {"minimum": 0.9, "maximum": 1.0, "step": 0.0005})
}))

options_templates.update(options_section(('interrogate', "Interrogate Options"), {
Expand Down
4 changes: 3 additions & 1 deletion modules/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from modules.ui import plaintext_to_html


def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, mimic_scale: float, threshold_enable: bool, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
Expand All @@ -27,6 +27,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
mimic_scale=mimic_scale,
threshold_enable=threshold_enable,
width=width,
height=height,
restore_faces=restore_faces,
Expand Down
Loading