Skip to content

Commit

Permalink
Merge pull request #14208 from CodeHatchling/soft-inpainting
Browse files Browse the repository at this point in the history
Soft Inpainting
  • Loading branch information
AUTOMATIC1111 committed Dec 14, 2023
2 parents f3cc5f8 + f1ff932 commit 8c32594
Show file tree
Hide file tree
Showing 5 changed files with 904 additions and 27 deletions.
1 change: 1 addition & 0 deletions modules/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,3 +791,4 @@ def flatten(img, bgcolor):
img = background

return img.convert('RGB')

92 changes: 67 additions & 25 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,35 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')


def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
def uncrop(image, dest_size, paste_loc):
x, y, w, h = paste_loc
base_image = Image.new('RGBA', dest_size)
image = images.resize_image(1, image, w, h)
base_image.paste(image, (x, y))
image = base_image

return image

overlay = overlays[index]

def apply_overlay(image, paste_loc, overlay):
if overlay is None:
return image

if paste_loc is not None:
x, y, w, h = paste_loc
base_image = Image.new('RGBA', (overlay.width, overlay.height))
image = images.resize_image(1, image, w, h)
base_image.paste(image, (x, y))
image = base_image
image = uncrop(image, (overlay.width, overlay.height), paste_loc)

image = image.convert('RGBA')
image.alpha_composite(overlay)
image = image.convert('RGB')

return image

def create_binary_mask(image):
def create_binary_mask(image, round=True):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
if round:
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
else:
image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
Expand Down Expand Up @@ -308,7 +315,7 @@ def unclip_image_conditioning(self, source_image):
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm

def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True

# Handle the different mask inputs
Expand All @@ -320,8 +327,10 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
if round_image_mask:
# Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)

else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

Expand All @@ -345,7 +354,7 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N

return image_conditioning

def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)

# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
Expand All @@ -357,7 +366,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
return self.edit_image_conditioning(source_image)

if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
Expand Down Expand Up @@ -867,6 +876,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
p.scripts.post_sample(p, ps)
samples_ddim = ps.samples

if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
Expand Down Expand Up @@ -922,13 +936,31 @@ def infotext(index=0, use_main_prompt=False):
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image

mask_for_overlay = getattr(p, "mask_for_overlay", None)
overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None

if p.scripts is not None:
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
p.scripts.postprocess_maskoverlay(p, ppmo)
mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)

image = apply_overlay(image, p.paste_to, i, p.overlay_images)
# If the intention is to show the output from the model
# that is being composited over the original image,
# we need to keep the original image around
# and use it in the composite step.
original_denoised_image = image.copy()

if p.paste_to is not None:
original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

image = apply_overlay(image, p.paste_to, overlay_image)

if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
Expand All @@ -938,16 +970,17 @@ def infotext(index=0, use_main_prompt=False):
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:

if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
image_mask = p.mask_for_overlay.convert('RGB')
image_mask = mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.return_mask:
output_images.append(image_mask)

if opts.return_mask_composite or opts.save_mask_composite:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
Expand Down Expand Up @@ -1351,6 +1384,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
mask_round: bool = True
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
Expand Down Expand Up @@ -1396,7 +1430,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
image_mask = create_binary_mask(image_mask)
image_mask = create_binary_mask(image_mask, round=self.mask_round)

if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
Expand Down Expand Up @@ -1503,7 +1537,8 @@ def init(self, all_prompts, all_seeds, all_subseeds):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
latmask = np.around(latmask)
if self.mask_round:
latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))

self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
Expand All @@ -1515,7 +1550,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask

self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
Expand All @@ -1527,7 +1562,14 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
blended_samples = samples * self.nmask + self.init_latent * self.mask

if self.scripts is not None:
mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
self.scripts.on_mask_blend(self, mba)
blended_samples = mba.blended_latent

samples = blended_samples

del x
devices.torch_gc()
Expand Down
70 changes: 70 additions & 0 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,31 @@

AlwaysVisible = object()

class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
self.current_latent = current_latent
self.nmask = nmask
self.init_latent = init_latent
self.mask = mask
self.blended_latent = blended_latent

self.denoiser = denoiser
self.is_final_blend = denoiser is None
self.sigma = sigma

class PostSampleArgs:
def __init__(self, samples):
self.samples = samples

class PostprocessImageArgs:
def __init__(self, image):
self.image = image

class PostProcessMaskOverlayArgs:
def __init__(self, index, mask_for_overlay, overlay_image):
self.index = index
self.mask_for_overlay = mask_for_overlay
self.overlay_image = overlay_image

class PostprocessBatchListArgs:
def __init__(self, images):
Expand Down Expand Up @@ -206,13 +226,39 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwarg

pass

def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
"""
Called in inpainting mode when the original content is blended with the inpainted content.
This is called at every step in the denoising process and once at the end.
If is_final_blend is true, this is called for the final blending stage.
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
"""

pass

def post_sample(self, p, ps: PostSampleArgs, *args):
"""
Called after the samples have been generated,
but before they have been decoded by the VAE, if applicable.
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
"""

pass

def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""

pass

def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
"""
Called for every image after it has been generated.
"""

pass

def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
Expand Down Expand Up @@ -767,6 +813,22 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)

def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)

def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
Expand All @@ -775,6 +837,14 @@ def postprocess_image(self, p, pp: PostprocessImageArgs):
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
Expand Down
21 changes: 19 additions & 2 deletions modules/sd_samplers_cfg_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(self, sampler):
self.sampler = sampler
self.model_wrap = None
self.p = None

# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False

@property
Expand Down Expand Up @@ -105,8 +108,21 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

# If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(current_latent):
blended_latent = current_latent * self.nmask + self.init_latent * self.mask

if self.p.scripts is not None:
from modules import scripts
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
self.p.scripts.on_mask_blend(self.p, mba)
blended_latent = mba.blended_latent

return blended_latent

# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
x = apply_blend(x)

batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
Expand Down Expand Up @@ -207,8 +223,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
denoised = apply_blend(denoised)

self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

Expand Down
Loading

0 comments on commit 8c32594

Please sign in to comment.