From c4ee6d9b73300d906b8df4602157d646df2415ef Mon Sep 17 00:00:00 2001
From: Robert Barron
Date: Sun, 30 Jul 2023 00:41:10 -0700
Subject: [PATCH 001/449] xyz_grid: allow varying the seed along an axis along with the axis's other changes

---
 scripts/xyz_grid.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 1010845e56b..4eb1b197d50 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -416,6 +416,10 @@ def ui(self, is_img2img):
                 with gr.Column():
                     include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                     include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+                with gr.Column():
+                    vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x"))
+                    vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y"))
+                    vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z"))
                 with gr.Column():
                     margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
 
@@ -475,9 +479,9 @@ def get_dropdown_update_from_params(axis,params):
             (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
         )
 
-        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size]
 
-    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size):
         if not no_fixed_seeds:
             modules.processing.fix_seed(p)
 
@@ -648,6 +652,16 @@ def cell(x, y, z, ix, iy, iz):
             y_opt.apply(pc, y, ys)
             z_opt.apply(pc, z, zs)
 
+            xdim = len(xs) if vary_seeds_x else 1
+            ydim = len(ys) if vary_seeds_y else 1
+
+            if vary_seeds_x:
+                pc.seed += ix
+            if vary_seeds_y:
+                pc.seed += iy * xdim
+            if vary_seeds_z:
+                pc.seed += iz * xdim * ydim
+
             res = process_images(pc)
 
             # Sets subgrid infotexts
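Review note: the new hunk in cell() gives every grid cell its own seed when the vary-seed boxes are ticked. Below is a minimal standalone sketch of the same row-major indexing, with illustrative grid sizes — only the variable names are taken from the patch, the helper itself is hypothetical:

    # Each cell (ix, iy, iz) gets a unique offset from the base seed;
    # a disabled axis collapses to a dimension of 1 so it adds nothing.
    def cell_seed(base_seed, ix, iy, iz, xs, ys,
                  vary_seeds_x=True, vary_seeds_y=True, vary_seeds_z=True):
        xdim = len(xs) if vary_seeds_x else 1
        ydim = len(ys) if vary_seeds_y else 1
        seed = base_seed
        if vary_seeds_x:
            seed += ix
        if vary_seeds_y:
            seed += iy * xdim
        if vary_seeds_z:
            seed += iz * xdim * ydim
        return seed

    # A 3x2 grid with X and Y varied yields offsets 0..5, one per cell:
    offsets = [cell_seed(0, ix, iy, 0, range(3), range(2))
               for iy in range(2) for ix in range(3)]
    assert offsets == [0, 1, 2, 3, 4, 5]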
From 8aa13d5dce2789a7d0bd802e6d62453b3c380496 Mon Sep 17 00:00:00 2001
From: Anthony Fu
Date: Mon, 16 Oct 2023 14:12:18 +0800
Subject: [PATCH 002/449] Interrupt after current generation

---
 modules/call_queue.py     |  1 +
 modules/img2img.py        |  2 +-
 modules/processing.py     |  2 +-
 modules/shared_options.py |  1 +
 modules/shared_state.py   | 11 +++++++++--
 scripts/loopback.py       |  6 +++---
 scripts/xyz_grid.py       |  2 +-
 7 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/modules/call_queue.py b/modules/call_queue.py
index ddf0d57383c..01c6d17f685 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -78,6 +78,7 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
 
         shared.state.skipped = False
         shared.state.interrupted = False
+        shared.state.interrupted_next = False
        shared.state.job_count = 0
 
         if not add_stats:
diff --git a/modules/img2img.py b/modules/img2img.py
index 52cb577a646..31f8c2aaf4c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -49,7 +49,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         if state.skipped:
             state.skipped = False
 
-        if state.interrupted:
+        if state.interrupted or state.interrupted_next:
             break
 
         try:
diff --git a/modules/processing.py b/modules/processing.py
index 40598f5cff4..e7eecd66d10 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -819,7 +819,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.skipped:
                 state.skipped = False
 
-            if state.interrupted:
+            if state.interrupted or state.interrupted_next:
                 break
 
             sd_models.reload_model_weights()  # model can be changed for example by refiner
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 32bf735323f..4638ef068f1 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -113,6 +113,7 @@
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
     "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
     "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
+    "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"),
 }))
 
 options_templates.update(options_section(('API', "API"), {
diff --git a/modules/shared_state.py b/modules/shared_state.py
index a68789cc815..c72c3f63cb2 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -12,6 +12,7 @@ class State:
     skipped = False
     interrupted = False
+    interrupted_next = False
     job = ""
     job_no = 0
     job_count = 0
@@ -76,8 +77,12 @@ def skip(self):
         log.info("Received skip request")
 
     def interrupt(self):
-        self.interrupted = True
-        log.info("Received interrupt request")
+        if shared.opts.interrupt_after_current and self.job_count > 1:
+            self.interrupted_next = True
+            log.info("Received interrupt request, interrupt after current job")
+        else:
+            self.interrupted = True
+            log.info("Received interrupt request")
 
     def nextjob(self):
         if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
@@ -91,6 +96,7 @@ def dict(self):
         obj = {
             "skipped": self.skipped,
             "interrupted": self.interrupted,
+            "interrupted_next": self.interrupted_next,
             "job": self.job,
             "job_count": self.job_count,
             "job_timestamp": self.job_timestamp,
@@ -114,6 +120,7 @@ def begin(self, job: str = "(unknown)"):
         self.id_live_preview = 0
         self.skipped = False
         self.interrupted = False
+        self.interrupted_next = False
        self.textinfo = None
         self.job = job
         devices.torch_gc()
diff --git a/scripts/loopback.py b/scripts/loopback.py
index 2d5feaf9b26..ad921269afb 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -95,7 +95,7 @@ def calculate_denoising_strength(loop):
 
             processed = processing.process_images(p)
 
             # Generation cancelled.
-            if state.interrupted:
+            if state.interrupted or state.interrupted_next:
                 break
 
             if initial_seed is None:
@@ -122,8 +122,8 @@ def calculate_denoising_strength(loop):
 
             p.inpainting_fill = original_inpainting_fill
 
-        if state.interrupted:
-            break
+        if state.interrupted or state.interrupted_next:
+            break
 
         if len(history) > 1:
             grid = images.image_grid(history, rows=1)
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc43d..495008ada08 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -688,7 +688,7 @@ def fix_axis_seeds(axis_opt, axis_list):
         grid_infotext = [None] * (1 + len(zs))
 
         def cell(x, y, z, ix, iy, iz):
-            if shared.state.interrupted:
+            if shared.state.interrupted or state.interrupted_next:
                 return Processed(p, [], p.seed, "")
 
             pc = copy(p)

From 3d15e58b0a30f2ef1e731f9e429f4d3cf1c259c5 Mon Sep 17 00:00:00 2001
From: Anthony Fu
Date: Mon, 16 Oct 2023 15:00:17 +0800
Subject: [PATCH 003/449] feat: refactor

---
 modules/shared_state.py | 12 ++++++------
 modules/ui.py           |  8 +++++++-
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/modules/shared_state.py b/modules/shared_state.py
index c72c3f63cb2..532fdcd8d12 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -77,12 +77,12 @@ def skip(self):
         log.info("Received skip request")
 
     def interrupt(self):
-        if shared.opts.interrupt_after_current and self.job_count > 1:
-            self.interrupted_next = True
-            log.info("Received interrupt request, interrupt after current job")
-        else:
-            self.interrupted = True
-            log.info("Received interrupt request")
+        self.interrupted = True
+        log.info("Received interrupt request")
+
+    def interrupt_next(self):
+        self.interrupted_next = True
+        log.info("Received interrupt request, interrupt after current job")
 
     def nextjob(self):
         if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
diff --git a/modules/ui.py b/modules/ui.py
index bcf39199715..c30093d755d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -216,8 +216,14 @@ def __init__(self, is_img2img):
             outputs=[],
         )
 
+        def interrupt_fn():
+            if shared.state.job_count > 1 and shared.opts.interrupt_after_current:
+                shared.state.interrupt_next()
+            else:
+                shared.state.interrupt()
+
         self.interrupt.click(
-            fn=lambda: shared.state.interrupt(),
+            fn=interrupt_fn,
            inputs=[],
             outputs=[],
         )
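Review note: patches 002-003 end up with two distinct flags on State — `interrupted` aborts as soon as possible, while `interrupted_next` lets the current image finish — and the UI picks one in interrupt_fn. A condensed sketch of the resulting control flow; MiniState is a stand-in for the webui's State class, not its real API:

    class MiniState:
        def __init__(self):
            self.interrupted = False        # stop mid-generation
            self.interrupted_next = False   # stop at the next batch boundary
            self.job_count = 0

        def interrupt(self):
            self.interrupted = True

        def interrupt_next(self):
            self.interrupted_next = True

    def run_batch(state, jobs):
        results = []
        for job in jobs:
            # same test the loops in img2img/processing/loopback now use
            if state.interrupted or state.interrupted_next:
                break
            results.append(job())
        return results

The UI-side dispatch mirrors interrupt_fn in modules/ui.py: it calls interrupt_next() only when more than one job remains and the interrupt_after_current option is enabled, and interrupts immediately otherwise.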
From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Thu, 19 Oct 2023 13:56:17 +0800
Subject: [PATCH 004/449] Add fp8 for sd unet

---
 extensions-builtin/Lora/network.py       |  2 +-
 extensions-builtin/Lora/network_full.py  |  4 ++--
 extensions-builtin/Lora/network_glora.py | 10 +++++-----
 extensions-builtin/Lora/network_hada.py  | 12 ++++++------
 extensions-builtin/Lora/network_ia3.py   |  2 +-
 extensions-builtin/Lora/network_lokr.py  | 18 +++++++++---------
 extensions-builtin/Lora/network_lora.py  |  6 +++---
 extensions-builtin/Lora/network_norm.py  |  4 ++--
 extensions-builtin/Lora/networks.py      |  6 +++---
 modules/cmd_args.py                      |  1 +
 modules/sd_models.py                     |  3 +++
 11 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8de0f..a62e5eff9e1 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -137,7 +137,7 @@ def calc_scale(self):
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index bf6930e96c0..f221c95f3b5 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
 
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
index 492d487078d..efe5c6814fa 100644
--- a/extensions-builtin/Lora/network_glora.py
+++ b/extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.w2b = weights.w["b2.weight"]
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
 
         return self.finalize_updown(updown, orig_weight, output_shape)
 
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 5fcb0695fbb..d95a0fd18e3 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.t2 = weights.w.get("hada_t2")
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
 
         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:
@@ -45,7 +45,7 @@ def calc_updown(self, orig_weight):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
 
         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index 7edc4249791..96faeaf3ede 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.on_input = weights.w["on_input"].item()
 
     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
 
         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 340acdab3d0..fcdaeafd896 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
 
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b
 
         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
 
         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 26c0a72c237..4cc4029510f 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ def create_module(self, weights, key, none_ok=False):
         return module
 
     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)
 
         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index ce450158068..d25afcbb928 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)
 
         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
 
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 60d8dec4cf5..8ea4ea6092d 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                     # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
 
-                self.weight += updown
+                self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                 if ex_bias is not None and hasattr(self, 'bias'):
                     if self.bias is None:
-                        self.bias = torch.nn.Parameter(ex_bias)
+                        self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                     else:
-                        self.bias += ex_bias
+                        self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 4e602a84270..0f14c71e4a8 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -118,3 +118,4 @@
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
+parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3b6cdea18d4..3b8ff8209b3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -391,6 +391,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
+    if shared.cmd_opts.opt_unet_fp8_storage:
+        model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+        timer.record("apply fp8 unet")
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
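Review note: the networks.py hunk above is the heart of fp8 compatibility — `+=` on a float8-stored parameter is not reliable, so the delta is applied by upcasting to the delta's dtype, adding, and casting back into the storage dtype. A minimal sketch of that update step, assuming torch>=2.1 for float8_e4m3fn (the tensors and shapes here are illustrative, not from the patch):

    import torch

    weight = torch.randn(4, 4).to(torch.float8_e4m3fn)  # storage dtype
    updown = torch.randn(4, 4, dtype=torch.float16)     # LoRA delta

    # upcast -> add -> downcast; copy_ writes the result back in place
    weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=weight.dtype))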
From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Thu, 19 Oct 2023 23:57:22 +0800
Subject: [PATCH 005/449] Add sdxl only arg

---
 modules/cmd_args.py  | 1 +
 modules/sd_models.py | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 0f14c71e4a8..20bfb2c442b 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -119,3 +119,4 @@
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
 parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
+parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3b8ff8209b3..08af128fc4e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if shared.cmd_opts.opt_unet_fp8_storage:
         model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
         timer.record("apply fp8 unet")
+    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
+        model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+        timer.record("apply fp8 unet for sdxl")
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 24 Oct 2023 01:49:05 +0800
Subject: [PATCH 006/449] Add CPU fp8 support

Since norm layers need fp32, I only convert the linear operation layers
(Conv2d/Linear). The TE also has some pytorch functions that don't support
bf16 amp on CPU, so I add a condition to indicate whether the autocast is
for the unet.
---
 modules/devices.py    |  6 +++++-
 modules/processing.py |  2 +-
 modules/sd_models.py  | 20 ++++++++++++++++----
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 1d4eb563517..0cd2b55d279 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -71,6 +71,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -93,10 +94,13 @@ def cond_cast_float(input):
 nv_rng = None
 
 
-def autocast(disable=False):
+def autocast(disable=False, unet=False):
     if disable:
         return contextlib.nullcontext()
 
+    if unet and fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
diff --git a/modules/processing.py b/modules/processing.py
index 40598f5cff4..2df8a7ea7a8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if getattr(samples_ddim, 'already_decoded', False):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 08af128fc4e..c5fe57bfec0 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -391,12 +391,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if shared.cmd_opts.opt_unet_fp8_storage:
+
+    if shared.cmd_opts.opt_unet_fp8_storage:
+        enable_fp8 = True
+    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
+        enable_fp8 = True
+
+    if enable_fp8:
+        devices.fp8 = True
+        if devices.device == devices.cpu:
+            for module in model.model.diffusion_model.modules():
+                if isinstance(module, torch.nn.Conv2d):
+                    module.to(torch.float8_e4m3fn)
+                elif isinstance(module, torch.nn.Linear):
+                    module.to(torch.float8_e4m3fn)
+            timer.record("apply fp8 unet for cpu")
+        else:
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
             timer.record("apply fp8 unet")
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
-        timer.record("apply fp8 unet for sdxl")
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

From 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 24 Oct 2023 02:11:27 +0800
Subject: [PATCH 007/449] Fix lint

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index c5fe57bfec0..44d4038b25b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -396,7 +396,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         enable_fp8 = True
     elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
         enable_fp8 = True
-    
+
     if enable_fp8:
         devices.fp8 = True
         if devices.device == devices.cpu:
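Review note: patch 006's condition in code form — fp8-stored weights still need a wider compute dtype, and on CPU the only autocast target torch offers is bfloat16, hence the unet-only bf16 context. A rough sketch of the dispatch, simplified from modules/devices.py; the fp8/device arguments stand in for that module's globals:

    import contextlib
    import torch

    def autocast_for_unet(fp8: bool, device: torch.device):
        # On CPU, fp8 storage falls back to bf16 autocast; elsewhere
        # the usual (or null) context applies.
        if fp8 and device == torch.device("cpu"):
            return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
        return contextlib.nullcontext()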
From 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 25 Oct 2023 11:36:43 +0800
Subject: [PATCH 008/449] fp8 for TE

---
 modules/sd_models.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 44d4038b25b..69395294300 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                     module.to(torch.float8_e4m3fn)
             timer.record("apply fp8 unet for cpu")
         else:
+            if model.is_sdxl:
+                cond_stage = model.conditioner
+            else:
+                cond_stage = model.cond_stage_model
+            for module in cond_stage.modules():
+                if isinstance(module, torch.nn.Linear):
+                    module.to(torch.float8_e4m3fn)
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
             timer.record("apply fp8 unet")
 

From 4830b251366436ee8499c003fe87e46ddb4a4581 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 25 Oct 2023 11:53:37 +0800
Subject: [PATCH 009/449] Fix alphas_cumprod dtype

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 69395294300..23660454505 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -416,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                     module.to(torch.float8_e4m3fn)
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
             timer.record("apply fp8 unet")
+        model.alphas_cumprod = model.alphas_cumprod.to(torch.float32)
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 25 Oct 2023 12:54:28 +0800
Subject: [PATCH 010/449] Fix alphas cumprod

---
 modules/sd_models.py    | 3 ++-
 modules/sd_models_xl.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 23660454505..7ed89a9c93f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         enable_fp8 = True
     elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
         enable_fp8 = True
+    else:
+        enable_fp8 = False
 
     if enable_fp8:
         devices.fp8 = True
@@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                     module.to(torch.float8_e4m3fn)
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
             timer.record("apply fp8 unet")
-        model.alphas_cumprod = model.alphas_cumprod.to(torch.float32)
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 0112332161f..11259a369de 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -93,7 +93,7 @@ def extend_sdxl(model):
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
 
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
-    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
 
     model.conditioner.wrapped = torch.nn.Module()
From dda067f64d3289cee3ffd65767126cb30ae73b13 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 25 Oct 2023 19:53:22 +0800
Subject: [PATCH 011/449] ignore mps for fp8

---
 modules/sd_models.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7ed89a9c93f..ccb6afd2ade 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -392,7 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if shared.cmd_opts.opt_unet_fp8_storage:
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.cmd_opts.opt_unet_fp8_storage:
         enable_fp8 = True
     elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
         enable_fp8 = True

From 0beb131c7ffae6f756a6339206da311232a36970 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 25 Oct 2023 20:07:37 +0800
Subject: [PATCH 012/449] change torch version

---
 modules/launch_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 8cdbafa50c8..636da679ecc 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -308,8 +308,8 @@ def requirements_met(requirements_file):
 
 
 def prepare_environment():
-    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
-    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 
     xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 28 Oct 2023 15:24:26 +0800
Subject: [PATCH 013/449] ManualCast for 10/16 series gpu

---
 modules/devices.py    | 57 ++++++++++++++++++++++++++++++++++++++-----
 modules/processing.py |  2 +-
 modules/sd_models.py  | 21 +++++++++-------
 3 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 0cd2b55d279..c05f2b35e5e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -16,6 +16,23 @@ def has_mps() -> bool:
     return mac_specific.has_mps
 
 
+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -60,8 +77,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True
 
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -92,15 +108,44 @@ def cond_cast_float(input):
 
 nv_rng = None
 
-
-def autocast(disable=False, unet=False):
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+@contextlib.contextmanager
+def manual_autocast():
+    def manual_cast_forward(self, *args, **kwargs):
+        org_dtype = next(self.parameters()).dtype
+        self.to(dtype)
+        result = self.org_forward(*args, **kwargs)
+        self.to(org_dtype)
+        return result
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = manual_cast_forward
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
+
+
+def autocast(disable=False):
+    print(fp8, dtype, shared.cmd_opts.precision, device)
     if disable:
         return contextlib.nullcontext()
 
-    if unet and fp8 and device==cpu:
+    if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
+    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+        return manual_autocast()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
diff --git a/modules/processing.py b/modules/processing.py
index 2df8a7ea7a8..40598f5cff4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if getattr(samples_ddim, 'already_decoded', False):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ccb6afd2ade..31bcb913a51 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     if enable_fp8:
         devices.fp8 = True
+        if model.is_sdxl:
+            cond_stage = model.conditioner
+        else:
+            cond_stage = model.cond_stage_model
+
+        for module in cond_stage.modules():
+            if isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+
         if devices.device == devices.cpu:
             for module in model.model.diffusion_model.modules():
                 if isinstance(module, torch.nn.Conv2d):
                     module.to(torch.float8_e4m3fn)
                 elif isinstance(module, torch.nn.Linear):
                     module.to(torch.float8_e4m3fn)
-            timer.record("apply fp8 unet for cpu")
         else:
-            if model.is_sdxl:
-                cond_stage = model.conditioner
-            else:
-                cond_stage = model.cond_stage_model
-            for module in cond_stage.modules():
-                if isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
             model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
-            timer.record("apply fp8 unet")
+        timer.record("apply fp8")
+    else:
+        devices.fp8 = False
 
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001
From: KohakuBlueleaf
Date: Sat, 28 Oct 2023 16:52:35 +0800
Subject: [PATCH 014/449] Add MPS manual cast

---
 modules/devices.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/devices.py b/modules/devices.py
index c05f2b35e5e..d7c905c283a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -121,6 +121,8 @@ def manual_autocast():
     def manual_cast_forward(self, *args, **kwargs):
         org_dtype = next(self.parameters()).dtype
         self.to(dtype)
+        args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+        kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         result = self.org_forward(*args, **kwargs)
         self.to(org_dtype)
         return result
@@ -136,7 +138,6 @@ def manual_cast_forward(self, *args, **kwargs):
 
 def autocast(disable=False):
-    print(fp8, dtype, shared.cmd_opts.precision, device)
     if disable:
         return contextlib.nullcontext()
 
@@ -146,6 +147,9 @@ def autocast(disable=False):
     if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
         return manual_autocast()
 
+    if has_mps() and shared.cmd_opts.precision != "full":
+        return manual_autocast()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
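Review note: manual casting exists for the cases where torch autocast is unusable — GTX 16xx cards (cuda_no_autocast) and MPS — and works by temporarily swapping each patched layer's forward for a wrapper that moves the module and its tensor arguments to the compute dtype. A trimmed sketch of the same patching pattern, covering only torch.nn.Linear for brevity; target_dtype is illustrative (the webui uses its devices.dtype global):

    import contextlib
    import torch

    target_dtype = torch.float16

    def manual_cast_forward(self, *args, **kwargs):
        # cast module + tensor args to the compute dtype, run the original
        # forward, then restore the module's stored dtype
        org_dtype = next(self.parameters()).dtype
        self.to(target_dtype)
        args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
        kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v
                  for k, v in kwargs.items()}
        result = self.org_forward(*args, **kwargs)
        self.to(org_dtype)
        return result

    @contextlib.contextmanager
    def manual_cast():
        torch.nn.Linear.org_forward = torch.nn.Linear.forward
        torch.nn.Linear.forward = manual_cast_forward
        try:
            yield
        finally:
            torch.nn.Linear.forward = torch.nn.Linear.org_forward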
from webui", default=False) -parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False) +parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None) +parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) @@ -90,7 +92,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -112,9 +114,8 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') +parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all 
-parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
-parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)

From 844c23975f369f85c7179379ec419e1bd067de18 Mon Sep 17 00:00:00 2001
From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com>
Date: Sun, 29 Oct 2023 15:40:58 +0000
Subject: [PATCH 016/449] Add assertions for checking additional settings freezing parameters

---
 modules/options.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/modules/options.py b/modules/options.py
index ab40aff73f0..1ac32d9503a 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -85,18 +85,35 @@ def __setattr__(self, key, value):
 
         if self.data is not None:
             if key in self.data or key in self.data_labels:
+
+                # Check that settings aren't globally frozen
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"
 
+                # Get the info related to the setting being changed
                 info = self.data_labels.get(key, None)
                 if info.do_not_save:
                     return
 
+                # Restrict component arguments
                 comp_args = info.component_args if info else None
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted")
 
+                # Check that this section isn't frozen
+                if cmd_opts.freeze_settings_in_sections is not None:
+                    frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(',')))  # Trim whitespace from section names
+                    section_key = info.section[0]
+                    section_name = info.section[1]
+                    assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
+
+                # Check that this section of the settings isn't frozen
+                if cmd_opts.freeze_specific_settings is not None:
+                    frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(',')))  # Trim whitespace from setting keys
+                    assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
+
+                # Check shorthand option which disables editing options in "saving-paths"
                 if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")
 
                 self.data[key] = value
                 return
@@ -210,8 +227,6 @@ def dumpjson(self):
 
     def add_option(self, key, info):
         self.data_labels[key] = info
-        if key not in self.data:
-            self.data[key] = info.default
 
     def reorder(self):
         """reorder settings so that all items related to section always go together"""
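Review note: the assertions added to __setattr__ boil down to two membership tests over comma-delimited CLI strings. A compact sketch of the same checks — the argument values follow the new flags' format, but the helper function itself is hypothetical:

    def check_not_frozen(key, section_key, freeze_specific, freeze_sections):
        # refuse changes to a setting whose key or section was frozen via
        # --freeze-specific-settings / --freeze-settings-in-sections
        if freeze_sections:
            frozen_sections = [s.strip() for s in freeze_sections.split(",")]
            assert section_key not in frozen_sections, f"section '{section_key}' is frozen"
        if freeze_specific:
            frozen_keys = [k.strip() for k in freeze_specific.split(",")]
            assert key not in frozen_keys, f"setting '{key}' is frozen"

    check_not_frozen("grid_save", "saving-images",
                     "samples_save, samples_format", "upscaling")  # passes
    # check_not_frozen("samples_save", ...) would raise AssertionError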
From be31e7e71a08dc27543d31aa6e6532463ccbf20f Mon Sep 17 00:00:00 2001
From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com>
Date: Sun, 29 Oct 2023 16:05:01 +0000
Subject: [PATCH 017/449] Remove blank line whitespace

---
 modules/options.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/options.py b/modules/options.py
index 1ac32d9503a..e270a42df85 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -85,7 +85,7 @@ def __setattr__(self, key, value):
 
         if self.data is not None:
             if key in self.data or key in self.data_labels:
-                
+
                 # Check that settings aren't globally frozen
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"
 

From 6b9795849d497b41514aa9462690cf7c2802e4f6 Mon Sep 17 00:00:00 2001
From: hako-mikan <122196982+hako-mikan@users.noreply.github.com>
Date: Thu, 9 Nov 2023 20:23:37 +0900
Subject: [PATCH 018/449] Fix model switch bug

---
 extensions-builtin/Lora/networks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 96f935b236f..a21ea0fa356 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -418,7 +418,7 @@ def network_forward(module, input, original_forward):
 def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     self.network_current_names = ()
     self.network_weights_backup = None
-
+    self.network_bias_backup = None
 
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
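Review note: patch 018 is one line but load-bearing — when a layer's cached network state is reset on a model switch, the bias backup has to be dropped along with the weight backup, or a bias captured from the previous checkpoint can later be restored into the new one. Sketch of the repaired reset (field names follow the patch; the function signature is simplified):

    def network_reset_cached_weight(module):
        # reset all cached backups together so switching checkpoints cannot
        # resurrect tensors captured from the previously loaded model
        module.network_current_names = ()
        module.network_weights_backup = None
        module.network_bias_backup = None   # the line the fix adds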
From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 19 Nov 2023 15:50:06 +0800
Subject: [PATCH 019/449] Use options instead of cmd_args

---
 modules/cmd_args.py        |  2 --
 modules/devices.py         | 25 +++++++++------
 modules/initialize_util.py |  1 +
 modules/sd_models.py       | 61 ++++++++++++++++++++------------------
 modules/shared_options.py  |  1 +
 scripts/xyz_grid.py        |  1 +
 6 files changed, 49 insertions(+), 42 deletions(-)

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 088d5deaac1..a9fb9bfa3ef 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -118,5 +118,3 @@
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
 parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
 parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
-parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
-parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index d7c905c283a..03e7bdb7c2e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
     if device_id is None:
         device_id = get_cuda_device_id()
     return (
-        torch.cuda.get_device_capability(device_id) == (7, 5) 
+        torch.cuda.get_device_capability(device_id) == (7, 5)
         and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
     )
 
 
 def get_cuda_device_id():
     return (
-        int(shared.cmd_opts.device_id) 
-        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() 
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
         else 0
     ) or torch.cuda.current_device()
 
@@ -116,16 +116,19 @@ def cond_cast_float(input):
     torch.nn.LayerNorm,
 ]
 
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
 @contextlib.contextmanager
 def manual_autocast():
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d895f4..1b11ead61fd 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,7 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
     startup_timer.record("opts onchange")
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a6c8b2fa7db..eb4914347b9 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -339,10 +339,28 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if not check_fp8(model) and devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
-    if devices.get_optimal_device_name() == "mps":
-        enable_fp8 = False
-    elif shared.cmd_opts.opt_unet_fp8_storage:
-        enable_fp8 = True
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        enable_fp8 = True
-    else:
-        enable_fp8 = False
-
-    if enable_fp8:
+    if check_fp8(model):
         devices.fp8 = True
-        if model.is_sdxl:
-            cond_stage = model.conditioner
-        else:
-            cond_stage = model.cond_stage_model
-
-        for module in cond_stage.modules():
-            if isinstance(module, torch.nn.Linear):
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
                 module.to(torch.float8_e4m3fn)
-
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-        else:
-            model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+            elif isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
         devices.fp8 = False
@@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
     return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:
diff --git a/modules/shared_options.py b/modules/shared_options.py
index f1003f2187b..d27f35e96fc 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -200,6 +200,7 @@
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc43d..b2250c04dc0 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -270,6 +270,7 @@ def __init__(self, *args, **kwargs):
     AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
     AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+    AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
 ]
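Review note: moving fp8 from command-line flags to an option makes the decision a pure function of the loaded model plus the current setting, and reload_model_weights forces a reload whenever the answer changes. A distilled sketch of the decision table — it mirrors check_fp8 above, but the model/device arguments here are stand-ins for the webui's globals:

    def check_fp8(model, fp8_storage: str, device_name: str):
        if model is None:
            return None
        if device_name == "mps":      # fp8 unsupported there
            return False
        if fp8_storage == "Enable":
            return True
        return fp8_storage == "Enable for SDXL" and getattr(model, "is_sdxl", False)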
Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc43d..b2250c04dc0 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ def __init__(self, *args, **kwargs): AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] From 890181e1d456b613bf60f6e8378dc68b39011af9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:54:39 +0800 Subject: [PATCH 020/449] Update the xformers/torch versions --- modules/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 8c339464d46..a3498c11f4b 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -93,8 +93,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): From f383af2729ec2d1969200218577ab19dd78f7d48 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:23 +0800 Subject: [PATCH 021/449] update xformers/torch versions --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 636da679ecc..c225bbc180c 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -312,7 +312,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:31 +0800 Subject: [PATCH 022/449] Better naming --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 03e7bdb7c2e..c19a7f402f1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs): @contextlib.contextmanager -def manual_autocast(): +def manual_cast(): for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward @@ -148,10 +148,10 @@ def 
autocast(disable=False): return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_autocast() + return manual_cast() if has_mps() and shared.cmd_opts.precision != "full": - return manual_autocast() + return manual_cast() if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From b2e039d07bed76350120ff448964c907a3b5e4a3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 20 Nov 2023 14:05:32 +0800 Subject: [PATCH 023/449] Update webui-macos-env.sh --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c42615..db7e8b1a05b 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: [PATCH 024/449] Option for using fp16 weight when apply lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- modules/initialize_util.py | 1 + modules/sd_models.py | 14 +++++++++++--- modules/shared_options.py | 1 + 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfbd23..d22ed843868 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 1b11ead61fd..7fb1d8d5093 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,6 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index eb4914347b9..0a7777f1c82 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") diff --git a/modules/shared_options.py b/modules/shared_options.py index d27f35e96fc..eaa9f1352c7 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -201,6 +201,7 @@ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From f5d719d1f1baa775d838aa75d9af1971bcc78e8f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 22 Nov 2023 01:45:56 +0800 Subject: [PATCH 025/449] Add forced reload for fp16 cache --- modules/initialize_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 7fb1d8d5093..b6767138dd5 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,7 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) - shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") From 40ac134c553ac824d4a96666bba14d550300daa5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 25 Nov 2023 12:35:09 +0800 Subject: [PATCH 026/449] Fix pre-fp8 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a7777f1c82..90437c87ac1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if not check_fp8(model) and devices.fp8: + if devices.fp8: # prevent model to load state dict in fp8 model.half() From 29f04149b60bcf6e8e2b41a161d6cc7e8981710f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 12:07:33 +0300 Subject: [PATCH 027/449] update torch to 2.1.0 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 6 +++--- webui-macos-env.sh | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index eb234a83811..c534a5d66fe 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 264ec9ca6a4..1f2b6c5e8f6 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,11 +308,11 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = 
os.environ.get('REQS_FILE', "requirements_versions.txt")

-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7')
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 24bc5c42615..db7e8b1a05b 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -11,7 +11,7 @@ fi

 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
+export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0"
 export PYTORCH_ENABLE_MPS_FALLBACK=1

 ####################################################################

From dec791d35ddcd02ca33563d3d0355e05e45de8ad Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Tue, 28 Nov 2023 15:05:01 -0700
Subject: [PATCH 028/449] Removed code which forces the inpainting mask to be 0
 or 1. Now fractional values (e.g. 0.5) are accepted.

---
 modules/processing.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index e124e7f0dd2..317458f582f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -83,7 +83,7 @@ def apply_overlay(image, paste_loc, index, overlays):

 def create_binary_mask(image):
     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        image = image.split()[-1].convert("L")
     else:
         image = image.convert('L')
     return image
@@ -319,9 +319,6 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N
             conditioning_mask = np.array(image_mask.convert("L"))
             conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
             conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
-            # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-            conditioning_mask = torch.round(conditioning_mask)
         else:
             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

@@ -1504,7 +1501,6 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
-            latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))

             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)

From bbba133f054706c3668b7d03b0e6d0afc15705db Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Tue, 28 Nov 2023 15:09:43 -0700
Subject: [PATCH 029/449] Removed conflicting step that replaces the softly
 inpainted latents with a naive blend of the original latents.
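Before patch 029's diff, a quick standalone illustration of what patch 028 permits: fractional mask values now survive instead of being thresholded to 0 or 255. This is a minimal sketch, not code from the series; the radial gradient is a made-up stand-in for a real inpainting mask.

```python
import numpy as np
from PIL import Image

# A soft radial mask: values between 0 and 255, not just the extremes.
yy, xx = np.mgrid[0:64, 0:64]
alpha = np.clip(255 - np.hypot(xx - 32, yy - 32) * 8, 0, 255).astype(np.uint8)
mask = Image.fromarray(alpha, mode="L")

hard = mask.point(lambda x: 255 if x > 128 else 0)  # old behaviour: binarized
soft = mask                                         # new behaviour: fractional values kept

print(len(set(hard.getdata())))  # 2 distinct values only
print(len(set(soft.getdata())))  # many intermediate values survive
```

The intermediate values are what the blending formula in the next patch consumes as per-pixel opacities.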
---
 modules/processing.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 317458f582f..ae894f1a756 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1523,9 +1523,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

-        if self.mask is not None:
-            samples = samples * self.nmask + self.init_latent * self.mask
-
         del x
         devices.torch_gc()

From e715e46b6aa7f2e5e147cfa1fa2f49b1d926a074 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Tue, 28 Nov 2023 16:10:22 -0700
Subject: [PATCH 030/449] Implements "scheduling" for blending of the original
 latents and a latent blending formula that preserves details in blend
 transition areas.

---
 modules/sd_samplers_cfg_denoiser.py | 61 ++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38dc3..c4d6fda650f 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -43,6 +43,9 @@ def __init__(self, sampler):
         self.model_wrap = None
         self.mask = None
         self.nmask = None
+        self.mask_blend_power = 1
+        self.mask_blend_scale = 1
+        self.mask_blend_offset = 0
         self.init_latent = None
         self.steps = None
         """number of steps as specified by user in UI"""
@@ -56,6 +59,9 @@ def __init__(self, sampler):
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+
+        # NOTE: masking before denoising can cause the original latents to be oversmoothed
+        # as the original latents do not have noise
         self.mask_before_denoising = False

     @property
@@ -89,6 +95,55 @@ def update_inner_model(self):
         self.sampler.sampler_extra_args['uncond'] = uc

     def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+        def latent_blend(a, b, t):
+            """
+            Interpolates two latent image representations according to the parameter t,
+            where the interpolated vectors' magnitudes are also interpolated separately.
+            The "detail_preservation" factor biases the magnitude interpolation towards
+            the larger of the two magnitudes.
+            """
+            # Record the original latent vector magnitudes.
+            # We bring them to a power so that larger magnitudes are favored over smaller ones.
+            # 64-bit operations are used here to allow large exponents.
+            detail_preservation = 32
+            a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation
+            b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation
+
+            one_minus_t = 1 - t
+
+            # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+            interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation)
+
+            # Linearly interpolate the image vectors.
+            image_interp = a * one_minus_t + b * t
+
+            # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+            # 64-bit operations are used here to allow large exponents.
+            image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001
+
+            # Change the linearly interpolated image vectors' magnitudes to the value we want.
+            # This is the last 64-bit operation.
+            image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
+
+            return image_interp
+
+        def get_modified_nmask(nmask, _sigma):
+            """
+            Converts a negative mask representing the transparency of the original latent vectors being overlayed
+            to a mask that is scaled according to the denoising strength for this step.
+
+            Where:
+                0 = fully opaque, infinite density, fully masked
+                1 = fully transparent, zero density, fully unmasked
+
+            We bring this transparency to a power, as this allows one to simulate N number of blending operations
+            where N can be any positive real value. Using this one can control the balance of influence between
+            the denoiser and the original latents according to the sigma value.
+
+            NOTE: "mask" is not used
+            """
+            return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset)
+
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException

@@ -105,8 +160,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

+        # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            x = self.init_latent * self.mask + self.nmask * x
+            x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +263,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

+        # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))

         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

From a6e584645305c0a91a3d46f73546e191b249210f Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Tue, 28 Nov 2023 16:13:42 -0700
Subject: [PATCH 031/449] Nerfs the aggressive post-processing step of
 overlaying the original image.
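Before patch 031's diff, a standalone check of why patch 030 interpolates magnitudes separately rather than using a plain lerp. This is an illustrative snippet, not code from the series; the random tensors stand in for two latent images, which (like independent noise) are nearly orthogonal.

```python
import torch

torch.manual_seed(0)
a = torch.randn(1, 4, 8, 8)  # stand-in for the original latents
b = torch.randn(1, 4, 8, 8)  # stand-in for the denoised latents

t = 0.5
lerp = a * (1 - t) + b * t

# Nearly orthogonal vectors lose roughly 1/sqrt(2) of their channel
# magnitude under a plain lerp at t=0.5 -- which decodes as a blurry,
# low-contrast band wherever the mask is fractional.
print(torch.norm(a, p=2, dim=1).mean())     # ~1.88
print(torch.norm(lerp, p=2, dim=1).mean())  # ~1.33, magnitude lost
```

latent_blend re-imposes an interpolated magnitude on the lerped direction, which is what keeps detail alive in the transition areas.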
--- modules/processing.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index ae894f1a756..12e08e87630 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1412,7 +1412,12 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1423,8 +1428,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] From debf836fcc8d9becc3da8b1a29e33f40b0d9ef3e Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:15:36 -0700 Subject: [PATCH 032/449] Added UI elements to control blending parameters. --- modules/img2img.py | 48 +++++++++++++++++++++++++++++++- modules/processing.py | 3 ++ modules/sd_samplers_common.py | 3 ++ modules/ui.py | 9 ++++++ scripts/outpainting_mk_2.py | 10 +++++-- scripts/poor_mans_outpainting.py | 11 ++++++-- test/test_img2img.py | 3 ++ 7 files changed, 82 insertions(+), 5 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 1519e132b2b..240d0588422 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -116,7 +116,47 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal process_images(p) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, + mode: int, + prompt: str, + negative_prompt: str, + prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps: int, + sampler_name: str, + mask_blur: int, + mask_alpha: float, + mask_blend_power: float, + mask_blend_scale: float, + mask_blend_offset: float, + inpainting_fill: int, + n_iter: int, + batch_size: int, + cfg_scale: float, + image_cfg_scale: float, + 
denoising_strength: float, + selected_scale_tab: int, + height: int, + width: int, + scale_by: float, + resize_mode: int, + inpaint_full_res: bool, + inpaint_full_res_padding: int, + inpainting_mask_invert: int, + img2img_batch_input_dir: str, + img2img_batch_output_dir: str, + img2img_batch_inpaint_mask_dir: str, + override_settings_texts, + img2img_batch_use_png_info: bool, + img2img_batch_png_info_props: list, + img2img_batch_png_info_dir: str, + request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -174,6 +214,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s init_images=[image], mask=mask, mask_blur=mask_blur, + mask_blend_power=mask_blend_power, + mask_blend_scale=mask_blend_scale, + mask_blend_offset=mask_blend_offset, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -194,6 +237,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + p.extra_generation_params["Mask blend power"] = mask_blend_power + p.extra_generation_params["Mask blend scale"] = mask_blend_scale + p.extra_generation_params["Mask blend offset"] = mask_blend_offset with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index 12e08e87630..da4d6fda9ff 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1349,6 +1349,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None + mask_blend_power: float = 1 + mask_blend_scale: float = 1 + mask_blend_offset: float = 0 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 58efcad2374..8904da2fb6b 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,6 +277,9 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None + self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None + self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 579bab9800c..86c130869e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,6 +732,9 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = 
gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -781,6 +784,9 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, + mask_blend_power, + mask_blend_scale, + mask_blend_offset, inpainting_fill, batch_count, batch_size, @@ -879,6 +885,9 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (mask_blend_power, "Mask blend power"), + (mask_blend_scale, "Mask blend scale"), + (mask_blend_offset, "Mask blend offset"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index c98ab48098e..6aa97edfabe 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,13 +133,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend scale', minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -167,6 +170,9 @@ def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index ea0632b68c1..b10140f1420 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,16 +22,23 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + mask_blend_power = 
gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset + p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 117d2d1eb45..6289e59e14b 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -24,6 +24,9 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, + "mask_blend_power": 1, + "mask_blend_scale": 1, + "mask_blend_offset": 0, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From c5c7fa06aae1ae9f8b6d29ae2da3874921d4729b Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 22:35:07 -0700 Subject: [PATCH 033/449] Added slider for detail preservation strength, removed largely needless offset parameter, changed labels in UI and for saving to/pasting data from PNG files. 
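Before the diff for patch 033, a sketch of the schedule that get_modified_nmask implements once this patch drops the offset term. The formula mirrors patch 030's; the sigma values below are arbitrary illustrations, not taken from any sampler.

```python
import torch

def modified_nmask(nmask, sigma, blend_power=1.0, blend_scale=1.0):
    # nmask: 1 = fully unmasked (denoiser output wins),
    #        0 = fully masked (original latents win).
    # A sigma-dependent exponent makes the original image dominate early
    # in sampling (large sigma) and hands control to the denoiser as
    # sigma falls.
    return nmask ** ((sigma ** blend_power) * blend_scale)

nmask = torch.tensor([0.25, 0.50, 0.75])
for sigma in (14.6, 1.0, 0.1):
    print(sigma, modified_nmask(nmask, sigma))
# sigma=14.6 -> values near 0: keep the original latents
# sigma=0.1  -> values near 1: let the denoiser's output through
```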
--- modules/img2img.py | 10 +++++----- modules/processing.py | 2 +- modules/sd_samplers_cfg_denoiser.py | 11 +++++------ modules/sd_samplers_common.py | 2 +- modules/ui.py | 14 +++++++------- scripts/outpainting_mk_2.py | 12 ++++++------ scripts/poor_mans_outpainting.py | 12 ++++++------ test/test_img2img.py | 2 +- 8 files changed, 32 insertions(+), 33 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 240d0588422..023808d6c00 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -134,7 +134,7 @@ def img2img(id_task: str, mask_alpha: float, mask_blend_power: float, mask_blend_scale: float, - mask_blend_offset: float, + inpaint_detail_preservation: float, inpainting_fill: int, n_iter: int, batch_size: int, @@ -216,7 +216,7 @@ def img2img(id_task: str, mask_blur=mask_blur, mask_blend_power=mask_blend_power, mask_blend_scale=mask_blend_scale, - mask_blend_offset=mask_blend_offset, + inpaint_detail_preservation=inpaint_detail_preservation, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -237,9 +237,9 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - p.extra_generation_params["Mask blend power"] = mask_blend_power - p.extra_generation_params["Mask blend scale"] = mask_blend_scale - p.extra_generation_params["Mask blend offset"] = mask_blend_offset + p.extra_generation_params["Mask blending bias"] = mask_blend_power + p.extra_generation_params["Mask blending preservation"] = mask_blend_scale + p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index da4d6fda9ff..361e8b05d72 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1351,7 +1351,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur: int = None mask_blend_power: float = 1 mask_blend_scale: float = 1 - mask_blend_offset: float = 0 + inpaint_detail_preservation: float = 16 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index c4d6fda650f..598cd487640 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -45,7 +45,7 @@ def __init__(self, sampler): self.nmask = None self.mask_blend_power = 1 self.mask_blend_scale = 1 - self.mask_blend_offset = 0 + self.inpaint_detail_preservation = 16 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -105,14 +105,13 @@ def latent_blend(a, b, t): # Record the original latent vector magnitudes. # We bring them to a power so that larger magnitudes are favored over smaller ones. # 64-bit operations are used here to allow large exponents. - detail_preservation = 32 - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation one_minus_t = 1 - t # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). 
- interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation) + interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) # Linearly interpolate the image vectors. image_interp = a * one_minus_t + b * t @@ -142,7 +141,7 @@ def get_modified_nmask(nmask, _sigma): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset) + return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 8904da2fb6b..ecd8ab0a093 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -279,7 +279,7 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None + self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 86c130869e5..f5e201477a6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,9 +732,9 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -786,7 +786,7 @@ def select_img2img_tab(tab): mask_alpha, mask_blend_power, mask_blend_scale, - mask_blend_offset, + inpaint_detail_preservation, inpainting_fill, batch_count, batch_size, @@ -885,9 +885,9 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blend power"), - (mask_blend_scale, "Mask blend scale"), - (mask_blend_offset, "Mask blend offset"), + 
(mask_blend_power, "Mask blending bias"), + (mask_blend_scale, "Mask blending preservation"), + (inpaint_detail_preservation, "Mask blending detail boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 6aa97edfabe..54d95825aba 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,16 +133,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend scale', minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -172,7 +172,7 @@ def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_ p.mask_blur_y = mask_blur_y*4 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index b10140f1420..e3acb3d4701 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,22 +22,22 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, 
elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 6289e59e14b..88b06eb8d7f 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -26,7 +26,7 @@ def simple_img2img_request(img2img_basic_image_base64): "mask_blur": 4, "mask_blend_power": 1, "mask_blend_scale": 1, - "mask_blend_offset": 0, + "inpaint_detail_preservation": 16, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 284fd8f415ec70e14ae5de0b7f5ce738007a6b7f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:03:50 -0700 Subject: [PATCH 034/449] Tweaked UI sliders and labels. 
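A note on the renames that patch 033 started and this patch continues: the keys written into extra_generation_params must match the labels registered in the paste-field table, or PNG-info reuse silently drops the setting. A minimal sketch of that contract, with a naive parser standing in for the webui's real infotext machinery and the key names reflecting where patches 033-035 end up:

```python
params = {
    "Mask blur": 4,
    "Mask blending bias": 1.0,
    "Mask blending preservation": 0.5,
    "Mask blending contrast boost": 4.0,
}

# Serialize roughly the way generation parameters land in PNG info text...
infotext = ", ".join(f"{k}: {v}" for k, v in params.items())

# ...and parse them back; a mismatched key would never round-trip.
parsed = dict(item.split(": ", 1) for item in infotext.split(", "))
assert parsed["Mask blending bias"] == "1.0"
```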
--- modules/img2img.py | 2 +- modules/ui.py | 6 +++--- scripts/outpainting_mk_2.py | 4 ++-- scripts/poor_mans_outpainting.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 023808d6c00..0ae1636541a 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -239,7 +239,7 @@ def img2img(id_task: str, p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blending bias"] = mask_blend_power p.extra_generation_params["Mask blending preservation"] = mask_blend_scale - p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation + p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index f5e201477a6..3a9038b227d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -733,8 +733,8 @@ def copy_image(img): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -887,7 +887,7 @@ def select_img2img_tab(tab): (mask_blur, "Mask blur"), (mask_blend_power, "Mask blending bias"), (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending detail boost"), + (inpaint_detail_preservation, "Mask blending contrast boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 54d95825aba..bd9cb61bf82 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -134,8 +134,8 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + 
inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index e3acb3d4701..5388f5db462 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -23,8 +23,8 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) From c7a1ff87207544dd4bcf3aefffa67a4a38678c16 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:31:10 -0700 Subject: [PATCH 035/449] Tweaked default values. 
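Before the diff below, a feel for what the inpaint_detail_preservation exponent (whose default drops from 16 to 4 here) does to the magnitude interpolation. This reduces latent_blend's power mean to scalars purely for illustration; the magnitudes 1 and 10 are made up.

```python
def power_mean(a, b, t, p):
    # The magnitude interpolation from latent_blend, reduced to scalars:
    # larger p biases the blend toward the bigger magnitude (soft-max-like).
    return ((a ** p) * (1 - t) + (b ** p) * t) ** (1 / p)

for p in (1, 4, 16):
    print(p, round(power_mean(1.0, 10.0, 0.5, p), 2))
# p=1  -> 5.5   plain average
# p=4  -> 8.41  the new default: biased toward the larger magnitude
# p=16 -> 9.58  the old default: the larger magnitude almost always wins
```

The smaller default trades a little of that detail boost for smoother transitions.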
--- modules/processing.py | 4 ++-- modules/sd_samplers_cfg_denoiser.py | 4 ++-- test/test_img2img.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 361e8b05d72..92fdebadd98 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1350,8 +1350,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_y: int = 4 mask_blur: int = None mask_blend_power: float = 1 - mask_blend_scale: float = 1 - inpaint_detail_preservation: float = 16 + mask_blend_scale: float = 0.5 + inpaint_detail_preservation: float = 4 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 598cd487640..ceb612d7936 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -44,8 +44,8 @@ def __init__(self, sampler): self.mask = None self.nmask = None self.mask_blend_power = 1 - self.mask_blend_scale = 1 - self.inpaint_detail_preservation = 16 + self.mask_blend_scale = 0.5 + self.inpaint_detail_preservation = 4 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" diff --git a/test/test_img2img.py b/test/test_img2img.py index 88b06eb8d7f..5cda2dbaefe 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -25,8 +25,8 @@ def simple_img2img_request(img2img_basic_image_base64): "mask": None, "mask_blur": 4, "mask_blend_power": 1, - "mask_blend_scale": 1, - "inpaint_detail_preservation": 16, + "mask_blend_scale": 0.5, + "inpaint_detail_preservation": 4, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From b25c126ccdbc4da22ade46597a9addf808998989 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:38:53 -0500 Subject: [PATCH 036/449] Protect alphas_cumprod from downcasting --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 841402e8629..de80a493a6f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -387,7 +387,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.upcast_sampling and depth_model: model.depth_model = None + alphas_cumprod = model.alphas_cumprod + model.alphas_cumprod = None model.half() + model.alphas_cumprod = alphas_cumprod + model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model @@ -642,6 +646,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: weight_dtype_conversion = { 'first_stage_model': None, + 'alphas_cumprod': None, '': torch.float16, } From 588a52891dca4d030ca7028dd9c0b56022a68b57 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:40:23 -0500 Subject: [PATCH 037/449] Add options for zero terminal SNR --- modules/shared_options.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index 04e68a712ef..51596777044 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,6 +218,7 @@ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": 
OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -335,6 +336,7 @@ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { From 6d0a8dcd892f7ad9b399fed6edbad6ede13c5f69 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:42:07 -0500 Subject: [PATCH 038/449] Implement zero terminal SNR schedule option --- modules/processing.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index ac58ef86940..c88eec70589 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,6 +863,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. 
+                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+                # Convert alphas_bar_sqrt to betas
+                alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+                alphas_bar[-1] = 4.8973451890853435e-08
+                return alphas_bar
+
+            p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+            if opts.use_downcasted_alpha_bar:
+                p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+            if opts.sd_noise_schedule == "Zero Terminal SNR":
+                p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+                print("rescaling noise schedule for zero snr")
+                p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
 
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

From ec6ee5c13bf3453f8703e225a191333a9bbcf10a Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Wed, 29 Nov 2023 18:10:27 -0500
Subject: [PATCH 039/449] Fix infotext for ztSNR

---
 modules/shared_options.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/shared_options.py b/modules/shared_options.py
index 51596777044..bc3d56dec60 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -218,7 +218,7 @@
     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
     "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), otherwise it's an absolute number of steps"),
-    "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.")
+    "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
 }))
 
 options_templates.update(options_section(('interrogate', "Interrogate"), {
@@ -336,7 +336,7 @@
     'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
     'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
     'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
-    'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models")
+    'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
 }))
 
 options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {

From ffa7f8201d849636bb327b3b40298e7c169ff204 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Wed, 29 Nov 2023 18:10:43 -0500
Subject: [PATCH 040/449] Lint

---
 modules/processing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index c88eec70589..f3883d5b3c2 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -863,7 +863,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
-            
+
             def rescale_zero_terminal_snr_abar(alphas_cumprod):
                 alphas_bar_sqrt = alphas_cumprod.sqrt()
 
@@ -881,7 +881,7 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod):
                 alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
                 alphas_bar[-1] = 4.8973451890853435e-08
                 return alphas_bar
-            
+
             p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
 
             if opts.use_downcasted_alpha_bar:

From de79597ab9894965e3702939b8536ec3dcc3c859 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Wed, 29 Nov 2023 18:33:32 -0500
Subject: [PATCH 041/449] Only apply ztSNR related code if alphas_cumprod
 exists

---
 modules/processing.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index f3883d5b3c2..7e73d7e2988 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -882,15 +882,16 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod):
                 alphas_bar[-1] = 4.8973451890853435e-08
                 return alphas_bar
 
-            p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
-
-            if opts.use_downcasted_alpha_bar:
-                p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
-                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
-            if opts.sd_noise_schedule == "Zero Terminal SNR":
-                p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
-                print("rescaling noise schedule for zero snr")
-                p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+            if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+                if opts.use_downcasted_alpha_bar:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 668ae34e21df848ef4909b8b49c4142a3674701b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:48:31 -0500 Subject: [PATCH 042/449] remove debug print --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 7e73d7e2988..d73c8bfc1c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -890,7 +890,6 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) if opts.sd_noise_schedule == "Zero Terminal SNR": p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): From 50a21cb09fe3e9ea2d4fe058e0484e192c8a86e3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:06:47 +0800 Subject: [PATCH 043/449] Ensure the cached weight will not be affected --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae653a..dcf816b3a85 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") From 309a606c2fa645b6b8623f96ea56117e685a47fb Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:07:45 -0500 Subject: [PATCH 044/449] ensure that original alpha bar always exists --- modules/processing.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d73c8bfc1c9..bfa59038b10 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,17 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast 
alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 81c4ddf6ebebe6f18338de3b0391da1d8521a525 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:11:00 -0500 Subject: [PATCH 045/449] fix linting --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index bfa59038b10..eeccea743a7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: From 4a43334376d9e116f7a1446f042f9af9c0484fc6 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:05:42 -0500 Subject: [PATCH 046/449] Revert 309a606c --- modules/processing.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index eeccea743a7..d73c8bfc1c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,17 +882,15 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + 
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+            if opts.sd_noise_schedule == "Zero Terminal SNR":
+                p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+                p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
 
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

From dc1adeecdd02f3fb910481e808a6d60a77100fea Mon Sep 17 00:00:00 2001
From: drhead
Date: Sat, 2 Dec 2023 14:06:56 -0500
Subject: [PATCH 047/449] Create alphas_cumprod_original on full precision path

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index de80a493a6f..976c7d5be7b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -374,6 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     if shared.cmd_opts.no_half:
         model.float()
+        model.alphas_cumprod_original = alphas_cumprod
        devices.dtype_unet = torch.float32
         timer.record("apply float()")
     else:

From 78acdcf677a96894651ff0d7d8287f2a994f3781 Mon Sep 17 00:00:00 2001
From: drhead
Date: Sat, 2 Dec 2023 14:09:18 -0500
Subject: [PATCH 048/449] fix variable

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 976c7d5be7b..5a19a00a50a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -374,7 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     if shared.cmd_opts.no_half:
         model.float()
-        model.alphas_cumprod_original = alphas_cumprod
+        model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
         timer.record("apply float()")
     else:

From 609dea36ea919aa7db42fd4233c416a45c74578b Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Sat, 2 Dec 2023 18:56:49 -0700
Subject: [PATCH 049/449] Added utility functions related to processing masks.

---
 modules/images.py | 191 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 191 insertions(+)

diff --git a/modules/images.py b/modules/images.py
index eb644733898..b5a0cead67a 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -776,3 +776,194 @@ def flatten(img, bgcolor):
         img = background
 
     return img.convert('RGB')
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+    """
+    Generalized convolution filter capable of applying
+    weighted mean, median, maximum, and minimum filters
+    parametrically using an arbitrary kernel.
+
+    Args:
+        img (nparray):
+            The image, a 2-D array of floats, to which the filter is being applied.
+        kernel (nparray):
+            The kernel, a 2-D array of floats.
+        kernel_center (nparray):
+            The kernel center coordinate, a 1-D array with two elements.
+        percentile_min (float):
+            The lower bound of the histogram window used by the filter,
+            from 0 to 1.
+        percentile_max (float):
+            The upper bound of the histogram window used by the filter,
+            from 0 to 1.
+        min_width (float):
+            The minimum size of the histogram window bounds, in weight units.
+            Must be greater than 0.
+
+    Returns:
+        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+ """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. 
+ + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x /= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + From 73ab982d1b7394574d1cf2e0a151bc457eeed769 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:07:02 -0700 Subject: [PATCH 050/449] Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content. --- modules/processing.py | 119 ++++++++++++++++++++++++++++++++---------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92fdebadd98..ad716e11f7c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ import torch import numpy as np -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageFilter import random import cv2 from skimage import exposure @@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + def apply_overlay(image, paste_loc, index, overlays): if overlays is None or index >= len(overlays): return image @@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays): overlay = overlays[index] if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -140,6 +146,7 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None + masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = 0 @@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim + # todo: generate masks the old fashioned way else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = 
decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + # Generate the mask(s) based on similarity between the original and denoised latent vectors + if getattr(p, "image_mask", None) is not None: + # latent_mask = p.nmask[0].float().cpu() + + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_orig = p.init_latent + latent_proc = samples_ddim + latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, p.width, p.height) + converted_mask = create_binary_mask(converted_mask) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. 
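+                        # (p.paste_to is only set when inpainting at full resolution on a
+                        # cropped region; uncrop() then pastes the crop-sized mask back onto
+                        # a full-size canvas so it matches the overlay image dimensions.)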
+ if p.paste_to is not None: + converted_mask = uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + p.paste_to) + + p.masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + p.overlay_images[i] = image_masked.convert('RGBA') + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, + target_device=devices.cpu, + check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, + use_main_prompt=use_main_prompt, index=index, + all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -923,19 +987,27 @@ def infotext(index=0, use_main_prompt=False): images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. + original_denoised_image = image.copy() image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], + p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') if opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") @@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) - mask_for_overlay: Image = field(default=None, init=False) 
init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1415,12 +1486,6 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1431,13 +1496,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) + self.masks_for_overlay = [] self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1459,10 +1519,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1486,6 +1544,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + if self.masks_for_overlay is not None: + self.masks_for_overlay = self.masks_for_overlay * self.batch_size + if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size From bb04d400c95df01d191ef6c1a43e66b95425fa33 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:08:26 -0700 Subject: [PATCH 051/449] Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection. --- modules/sd_samplers_cfg_denoiser.py | 41 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index ceb612d7936..efbe7a403de 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -102,29 +102,44 @@ def latent_blend(a, b, t): The "detail_preservation" factor biases the magnitude interpolation towards the larger of the two magnitudes. """ - # Record the original latent vector magnitudes. - # We bring them to a power so that larger magnitudes are favored over smaller ones. - # 64-bit operations are used here to allow large exponents. - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + # NOTE: We use inplace operations wherever possible. 
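+        # (Net effect: the two latent vectors are interpolated linearly, then the
+        # result is rescaled so its per-position L2 magnitude matches a weighted
+        # power mean of |a| and |b| with exponent inpaint_detail_preservation.)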
one_minus_t = 1 - t - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) - # Linearly interpolate the image vectors. - image_interp = a * one_minus_t + b * t + a_scaled = a * one_minus_t + b_scaled = b * t + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 + current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. - image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) - - return image_interp + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + + image_interp_scaled = image_interp_scaled.to(result_type) + del result_type + + return image_interp_scaled def get_modified_nmask(nmask, _sigma): """ From 28a2b5b4aab43424733039c31d910e8b8dd507cd Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:20:20 -0700 Subject: [PATCH 052/449] Fixed a math mistake. --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 6648097e851..9495349866b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -969,7 +969,7 @@ def gaussian_kernel_func(coordinate): x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 x = gaussian(x) x -= gauss_zero - x /= gauss_kernel_scale + x *= gauss_kernel_scale x = max(0.0, x) return x From 552f8bc832cd21ee0338e08b6a701687d0d79fad Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:49:41 -0700 Subject: [PATCH 053/449] "Uncrop" the original denoised image for the composite step, fixing a "ValueError: Images do not match" *shudder* --- modules/processing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 66aaab83150..cd7216f8313 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -994,6 +994,10 @@ def infotext(index=0, use_main_prompt=False): # we need to keep the original image around # and use it in the composite step. 
original_denoised_image = image.copy()
+
+                if p.paste_to is not None:
+                    original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to)
+
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
                 if save_samples:

From aaacf4823241450d88315af9d465d6815119fe0d Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 01:27:22 -0700
Subject: [PATCH 054/449] Organized the settings and UI of soft inpainting to
 allow for toggling the feature, and centralized default values to reduce the
 amount of copy-pasta.

---
 modules/img2img.py                  |  14 +--
 modules/processing.py               |   5 +-
 modules/sd_samplers_cfg_denoiser.py |  35 +++++---
 modules/sd_samplers_common.py       |   4 +-
 modules/soft_inpainting.py          | 133 ++++++++++++++++++++++++++++
 modules/ui.py                       |  17 ++--
 scripts/outpainting_mk_2.py         |  15 ++--
 scripts/poor_mans_outpainting.py    |  15 ++--
 test/test_img2img.py                |   8 +-
 9 files changed, 197 insertions(+), 49 deletions(-)
 create mode 100644 modules/soft_inpainting.py

diff --git a/modules/img2img.py b/modules/img2img.py
index 596f741c1ef..3aa8a9cef6c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -15,6 +15,7 @@
 import modules.processing as processing
 from modules.ui import plaintext_to_html
 import modules.scripts
+import modules.soft_inpainting as si
 
 
 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -162,6 +163,7 @@ def img2img(id_task: str,
             sampler_name: str,
             mask_blur: int,
             mask_alpha: float,
+            mask_blend_enabled: bool,
             mask_blend_power: float,
             mask_blend_scale: float,
             inpaint_detail_preservation: float,
@@ -227,6 +229,9 @@ def img2img(id_task: str,
 
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
 
+    soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \
+        if mask_blend_enabled else None
+
     p = StableDiffusionProcessingImg2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
@@ -244,9 +249,7 @@ def img2img(id_task: str,
         init_images=[image],
         mask=mask,
         mask_blur=mask_blur,
-        mask_blend_power=mask_blend_power,
-        mask_blend_scale=mask_blend_scale,
-        inpaint_detail_preservation=inpaint_detail_preservation,
+        soft_inpainting=soft_inpainting,
         inpainting_fill=inpainting_fill,
         resize_mode=resize_mode,
         denoising_strength=denoising_strength,
@@ -267,9 +270,8 @@ def img2img(id_task: str,
 
     if mask:
         p.extra_generation_params["Mask blur"] = mask_blur
-        p.extra_generation_params["Mask blending bias"] = mask_blend_power
-        p.extra_generation_params["Mask blending preservation"] = mask_blend_scale
-        p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation
+        if soft_inpainting is not None:
+            soft_inpainting.add_generation_params(p.extra_generation_params)
 
     with closing(p):
         if is_batch:

diff --git a/modules/processing.py b/modules/processing.py
index cd7216f8313..b209c84a31d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@
 import modules.sd_vae as sd_vae
 from ldm.data.util import AddMiDaS
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+import modules.soft_inpainting as si
 
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
@@ -1425,9 +1426,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     mask_blur_x: int = 4
     mask_blur_y: int = 4
     mask_blur: int = None
-    mask_blend_power: float
= 1 - mask_blend_scale: float = 0.5 - inpaint_detail_preservation: float = 4 + soft_inpainting: si.SoftInpaintingParameters = si.default inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index efbe7a403de..0ee0b7dde03 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,6 +6,7 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import modules.soft_inpainting as si def catenate_conds(conds): @@ -43,9 +44,7 @@ def __init__(self, sampler): self.model_wrap = None self.mask = None self.nmask = None - self.mask_blend_power = 1 - self.mask_blend_scale = 0.5 - self.inpaint_detail_preservation = 4 + self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -95,7 +94,8 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - def latent_blend(a, b, t): + def latent_blend(a, b, t, one_minus_t=None): + """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -104,7 +104,11 @@ def latent_blend(a, b, t): """ # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + if one_minus_t is None: + one_minus_t = 1 - t + + if self.soft_inpainting is None: + return a * one_minus_t + b * t # Linearly interpolate the image vectors. a_scaled = a * one_minus_t @@ -119,10 +123,10 @@ def latent_blend(a, b, t): current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation) del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. 
@@ -156,7 +160,10 @@ def get_modified_nmask(nmask, _sigma): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) + if self.soft_inpainting is None: + return nmask + + return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -176,7 +183,10 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + x = latent_blend(self.init_latent, x, self.nmask, self.mask) + else: + x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -279,7 +289,10 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask) + else: + denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index ecd8ab0a093..9682bee3d8e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,9 +277,7 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None - self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None + self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py new file mode 100644 index 00000000000..259c36ec892 --- /dev/null +++ b/modules/soft_inpainting.py @@ -0,0 +1,133 @@ +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def get_paste_fields(self): + return [ + (self.mask_blend_power, gen_param_labels.mask_blend_power), + (self.mask_blend_scale, gen_param_labels.mask_blend_scale), + (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation), + ] + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + 
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +default = SoftInpaintingSettings(1, 0.5, 4) +ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost") + +ui_info = SoftInpaintingSettings( + mask_blend_power="Shifts when preservation of original content occurs during denoising.", + # "Below 1: Stronger preservation near the end (with low sigma)\n" + # "1: Balanced (proportional to sigma)\n" + # "Above 1: Stronger preservation in the beginning (with high sigma)", + mask_blend_scale="How strongly partially masked content should be preserved.", + # "Low values: Favors generated content.\n" + # "High values: Favors original content.", + inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost") +el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation") + + +def gradio_ui(): + import gradio as gr + from modules.ui_components import InputAccordion + """ + with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + """ + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. 
+ + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + return ( + [ + soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation + ], + [ + (soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) + ] + ) diff --git a/modules/ui.py b/modules/ui.py index b13ed66cbaa..0e4fb17aa75 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,6 +29,7 @@ from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text +import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -678,9 +679,16 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + soft_inpainting = si.gradio_ui() + + + """ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") + """ with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -736,9 +744,7 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, - mask_blend_power, - mask_blend_scale, - inpaint_detail_preservation, + *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -837,11 +843,10 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blending bias"), - (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending contrast boost"), + *(soft_inpainting[1]), 
*scripts.scripts_img2img.infotext_fields ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index bd9cb61bf82..f78886883bd 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,6 +10,7 @@ from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -133,16 +134,14 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] + return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +169,9 @@ def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpai p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation + + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 5388f5db462..11f7f74a820 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,6 +7,7 @@ from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si class 
Script(scripts.Script): @@ -22,23 +23,19 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] + return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation - + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 5cda2dbaefe..87bd850912e 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,6 +1,7 @@ import pytest import requests +import modules.soft_inpainting as si @pytest.fixture() @@ -24,9 +25,10 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_power": 1, - "mask_blend_scale": 0.5, - "inpaint_detail_preservation": 4, + "mask_blend_enabled": True, + "mask_blend_power": si.default.mask_blend_power, + "mask_blend_scale": si.default.mask_blend_scale, + "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 259d33c3c8e27557cb9bab9b3a1dd7fc7450d16c Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 01:57:21 -0700 Subject: [PATCH 055/449] Enables the original functionality to be toggled on and off. 
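
The new `round` parameter of create_binary_mask() keeps the original
hard-threshold behavior when set to True, while round=False (used by the
soft inpainting path) preserves partial alpha values. A minimal sketch of
the difference, illustrative only and assuming PIL plus the function as
modified by this patch:

    from PIL import Image
    from modules.processing import create_binary_mask

    rgba = Image.new("RGBA", (4, 4), (0, 0, 0, 96))  # uniform 96/255 alpha

    hard = create_binary_mask(rgba, round=True)   # alpha <= 128 is thresholded to 0
    soft = create_binary_mask(rgba, round=False)  # partial alpha (96) is kept as-is

    assert hard.getpixel((0, 0)) == 0
    assert soft.getpixel((0, 0)) == 96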
---
 modules/processing.py | 99 ++++++++++++++++++++++++++++++-------------
 1 file changed, 70 insertions(+), 29 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index b209c84a31d..b40b1a40d49 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -88,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays):

     return image

-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L")
+        if round:
+            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        else:
+            image = image.split()[-1].convert("L")
     else:
         image = image.convert('L')
     return image
@@ -316,7 +319,7 @@ def unclip_image_conditioning(self, source_image):
             c_adm = torch.cat((c_adm, noise_level_emb), 1)
         return c_adm

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         self.is_using_inpainting_conditioning = True

         # Handle the different mask inputs
@@ -327,6 +330,11 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N
             conditioning_mask = np.array(image_mask.convert("L"))
             conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
             conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+            if round_image_mask:
+                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+
         else:
             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -350,7 +358,7 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N

         return image_conditioning

-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         source_image = devices.cond_cast_float(source_image)

         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -362,7 +370,10 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
             return self.edit_image_conditioning(source_image)

         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image,
+                                                      latent_image,
+                                                      image_mask=image_mask,
+                                                      round_image_mask=round_image_mask)

         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
@@ -878,8 +889,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         else:
             if opts.sd_vae_decode_method != 'Full':
                 p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
             # Generate the mask(s) based on similarity between the original and denoised latent vectors
-            if getattr(p, "image_mask", None) is not None:
+            if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                 # latent_mask = p.nmask[0].float().cpu()

                 # convert the original mask into a form we use to scale distances for thresholding
@@ -911,7 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     converted_mask = converted_mask.astype(np.uint8)
                     converted_mask = Image.fromarray(converted_mask)
                     converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
-                    converted_mask = create_binary_mask(converted_mask)
+                    converted_mask = create_binary_mask(converted_mask, round=False)

                     # Remove aliasing artifacts using a gaussian blur.
                     converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
@@ -1010,23 +1022,33 @@ def infotext(index=0, use_main_prompt=False):
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
-                    image_mask = p.masks_for_overlay[i].convert('RGB')
-                    image_mask_composite = Image.composite(
-                        original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
-                        images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
-
-                    if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
-
-                    if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
-
-                    if opts.return_mask:
-                        output_images.append(image_mask)
-
-                    if opts.return_mask_composite:
-                        output_images.append(image_mask_composite)
+                if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+                    if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay:
+                        image_mask = p.masks_for_overlay[i].convert('RGB')
+                        image_mask_composite = Image.composite(
+                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+                            images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
+                    elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+                        image_mask = p.mask_for_overlay.convert('RGB')
+                        image_mask_composite = Image.composite(
+                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+                            images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                    else:
+                        image_mask = None
+                        image_mask_composite = None
+
+                    if image_mask is not None and image_mask_composite is not None:
+                        if opts.save_mask:
+                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+
+                        if opts.save_mask_composite:
+                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+
+                        if opts.return_mask:
+                            output_images.append(image_mask)
+
+                        if opts.return_mask_composite:
+                            output_images.append(image_mask_composite)

         del x_samples_ddim
@@ -1439,6 +1461,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     nmask: torch.Tensor = field(default=None, init=False)
     image_conditioning: torch.Tensor = field(default=None, init=False)
     init_img_hash: str = field(default=None, init=False)
+    mask_for_overlay: Image = field(default=None, init=False)
     init_latent: torch.Tensor = field(default=None, init=False)

     def __post_init__(self):
@@ -1471,7 +1494,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
         if image_mask is not None:
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
+            image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))

             if self.inpainting_mask_invert:
                 image_mask = ImageOps.invert(image_mask)
@@ -1489,6 +1512,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
                 image_mask = Image.fromarray(np_mask)

             if self.inpaint_full_res:
+                self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
                 mask = image_mask.convert('L')
                 crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                 crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1500,7 +1524,12 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             else:
                 image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)

-            self.masks_for_overlay = []
+            if self.soft_inpainting is None:
+                np_mask = np.array(image_mask)
+                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+                self.mask_for_overlay = Image.fromarray(np_mask)
+
+            self.masks_for_overlay = [] if self.soft_inpainting is not None else None
             self.overlay_images = []

             latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1522,8 +1551,15 @@ def init(self, all_prompts, all_seeds, all_subseeds):
                 image = images.resize_image(self.resize_mode, image, self.width, self.height)

             if image_mask is not None:
-                self.overlay_images.append(image)
-                self.masks_for_overlay.append(image_mask)
+                if self.soft_inpainting is not None:
+                    # We apply the masks AFTER to adjust mask based on changed content.
+                    self.overlay_images.append(image)
+                    self.masks_for_overlay.append(image_mask)
+                else:
+                    image_masked = Image.new('RGBa', (image.width, image.height))
+                    image_masked.paste(image.convert("RGBA").convert("RGBa"),
+                                       mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+                    self.overlay_images.append(image_masked.convert('RGBA'))

             # crop_region is not None if we are doing inpaint full res
             if crop_region is not None:
@@ -1576,6 +1612,8 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
+            if self.soft_inpainting is None:
+                latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))

             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1587,7 +1625,10 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1,
+                                                                  self.init_latent,
+                                                                  image_mask,
+                                                                  self.soft_inpainting is None)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()

From 15322e1b1a9e31edcc2f7d72a32d02365058737d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 4 Dec 2023 12:36:41 +0300
Subject: [PATCH 056/449] repair old handler for postprocessing API

---
 modules/postprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 3c85a74c1b3..d166f859b68 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -153,4 +153,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         },
     })

-    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
+    return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)

From 883d6a2b34a2817304d23c2481a6f9fc56687a53 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 4 Dec 2023 13:11:00 +0300
Subject: [PATCH 057/449] repair old handler for postprocessing API in a way
 that doesn't break interface
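
The shape of the fix is a thin shim: the old positional API keeps its exact
signature, while the UI entry point swallows the extra task-id argument the
frontend sends. A minimal sketch of the pattern (simplified names):

```python
def run_postprocessing(extras_mode, image, *args, **kwargs):
    ...  # original signature, unchanged for API callers

def run_postprocessing_webui(id_task, *args, **kwargs):
    # UI entry point: drop the task id, delegate everything else
    return run_postprocessing(*args, **kwargs)
```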
---
 modules/postprocessing.py    | 8 ++++++--
 modules/ui_postprocessing.py | 2 +-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index d166f859b68..0c59fad480b 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -6,7 +6,7 @@
 from modules.shared import opts


-def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
     devices.torch_gc()

     shared.state.begin(job="extras")
@@ -128,6 +128,10 @@ def get_images(extras_mode, image, image_folder, input_dir):
     return outputs, ui_common.plaintext_to_html(infotext), ''


+def run_postprocessing_webui(id_task, *args, **kwargs):
+    return run_postprocessing(*args, **kwargs)
+
+
 def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     """old handler for API"""

@@ -153,4 +157,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         },
     })

-    return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
+    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)

diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index fbad0800aca..13d888e48d1 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -35,7 +35,7 @@ def create_ui():
     tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])

     submit.click(
-        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
         _js="submit_extras",
         inputs=[
             dummy_component,

From 22e23dbf29b0bbc807daa57318c31145f8dd0774 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 4 Dec 2023 15:56:03 +0300
Subject: [PATCH 058/449] add hypertile infotext
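
The option() helper below returns None when a setting is still at its
default; a simplified sketch of why that matters, since None-valued entries
are filtered out when the generation parameters text is assembled (names
here are illustrative, not the exact infotext code):

```python
params = {"Hypertile U-Net": True, "Hypertile U-Net max depth": None}
text = ", ".join(f"{k}: {v}" for k, v in params.items() if v is not None)
print(text)  # -> "Hypertile U-Net: True" -- defaults stay out of the infotext
```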
---
 .../hypertile/scripts/hypertile_script.py | 53 +++++++++++++++----
 1 file changed, 42 insertions(+), 11 deletions(-)

diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py
index d3ab6091500..395d584b605 100644
--- a/extensions-builtin/hypertile/scripts/hypertile_script.py
+++ b/extensions-builtin/hypertile/scripts/hypertile_script.py
@@ -17,11 +17,42 @@ def process(self, p, *args):

         configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)

+        self.add_infotext(p)
+
     def before_hr(self, p, *args):
+
+        enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
+
         # exclusive hypertile seed for the second pass
-        if not shared.opts.hypertile_enable_unet:
+        if enable:
             hypertile.set_hypertile_seed(p.all_seeds[0])
-        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass)
+
+        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
+
+        if enable and not shared.opts.hypertile_enable_unet:
+            p.extra_generation_params["Hypertile U-Net second pass"] = True
+
+            self.add_infotext(p, add_unet_params=True)
+
+    def add_infotext(self, p, add_unet_params=False):
+        def option(name):
+            value = getattr(shared.opts, name)
+            default_value = shared.opts.get_default(name)
+            return None if value == default_value else value
+
+        if shared.opts.hypertile_enable_unet:
+            p.extra_generation_params["Hypertile U-Net"] = True
+
+        if shared.opts.hypertile_enable_unet or add_unet_params:
+            p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
+            p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
+            p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
+
+        if shared.opts.hypertile_enable_vae:
+            p.extra_generation_params["Hypertile VAE"] = True
+            p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
+            p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
+            p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')


 def configure_hypertile(width, height, enable_unet=True):
@@ -57,16 +88,16 @@ def on_ui_settings():
     benefit.
     """),

-        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
-        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
-        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
-        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
-        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
+        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
+        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
+        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
+        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
+        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),

-        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
-        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
-        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
-        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
+        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
+        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
+        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
+        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
     }

     for name, opt in options.items():

From 854f8c318c2610c76259056ab02739176aa849e8 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Tue, 5 Dec 2023 04:40:12 +0900
Subject: [PATCH 059/449] remove clean_text()
---
 modules/styles.py | 23 +++--------------------
 1 file changed, 3 insertions(+), 20 deletions(-)

diff --git a/modules/styles.py b/modules/styles.py
index 4d218cd7ed0..7fb6c2e1177 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -2,7 +2,6 @@
 import fnmatch
 import os
 import os.path
-import re
 import typing
 import shutil

@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple):
     path: str = None


-def clean_text(text: str) -> str:
-    """
-    Iterating through a list of regular expressions and replacement strings, we
-    clean up the prompt and style text to make it easier to match against each
-    other.
-    """
-    re_list = [
-        ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
-        ("multiple spaces", re.compile("\s{2,}"), " "),
-    ]
-    for _, regex, replace in re_list:
-        text = regex.sub(replace, text)
-
-    return text.strip(", ")
-
-
 def merge_prompts(style_prompt: str, prompt: str) -> str:
     if "{prompt}" in style_prompt:
         res = style_prompt.replace("{prompt}", prompt)
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles):
     for style in styles:
         prompt = merge_prompts(style, prompt)

-    return clean_text(prompt)
+    return prompt


 def unwrap_style_text_from_prompt(style_text, prompt):
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt):
     Note that the "cleaned" version of the style text is only used for matching
     purposes here. It isn't returned; the original style text is not modified.
     """
-    stripped_prompt = clean_text(prompt)
-    stripped_style_text = clean_text(style_text)
+    stripped_prompt = prompt
+    stripped_style_text = style_text
     if "{prompt}" in stripped_style_text:
         # Work out whether the prompt is wrapped in the style text. If so, we
         # return True and the "inner" prompt text that isn't part of the style.

From 976c1053efeb5054692ed3cfa294cf79196f3946 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 16:06:58 -0700
Subject: [PATCH 060/449] Cleaned up code, moved main code contributions into
 soft_inpainting.py
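
The blending math now lives in one place. For orientation, a reference
(unoptimized) rendition of what latent_blend() computes, written with
keepdim for clarity and assuming [batch, channel, h, w] latents: a linear
interpolation whose magnitude is corrected toward a power mean of the two
input magnitudes, with d = inpaint_detail_preservation biasing it toward the
larger one:

```python
import torch

def latent_blend_reference(a, b, t, d=4.0):
    interp = (1 - t) * a + t * b
    target = ((1 - t) * a.norm(dim=1, keepdim=True) ** d
              + t * b.norm(dim=1, keepdim=True) ** d) ** (1 / d)
    current = interp.norm(dim=1, keepdim=True) + 1e-5
    return interp * (target / current)
```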
---
 modules/processing.py               |  56 ++-------
 modules/sd_samplers_cfg_denoiser.py |  84 ++-----------
 modules/soft_inpainting.py          | 177 ++++++++++++++++++++++++----
 modules/ui.py                       |   7 --
 4 files changed, 174 insertions(+), 150 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index b40b1a40d49..0b3603875a7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -892,55 +892,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             # Generate the mask(s) based on similarity between the original and denoised latent vectors
             if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
-                # latent_mask = p.nmask[0].float().cpu()
-
-                # convert the original mask into a form we use to scale distances for thresholding
-                # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
-                # mask_scalar = mask_scalar / (1.00001-mask_scalar)
-                # mask_scalar = mask_scalar.numpy()
-
-                latent_orig = p.init_latent
-                latent_proc = samples_ddim
-                latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
-
-                kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
-
-                for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
-                    converted_mask = distance_map.float().cpu().numpy()
-                    converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
-                                                                      percentile_min=0.9, percentile_max=1, min_width=1)
-                    converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
-                                                                      percentile_min=0.25, percentile_max=0.75, min_width=1)
-
-                    # The distance at which opacity of original decreases to 50%
-                    # half_weighted_distance = 1 # * mask_scalar
-                    # converted_mask = converted_mask / half_weighted_distance
-
-                    converted_mask = 1 / (1 + converted_mask ** 2)
-                    converted_mask = images.smootherstep(converted_mask)
-                    converted_mask = 1 - converted_mask
-                    converted_mask = 255. * converted_mask
-                    converted_mask = converted_mask.astype(np.uint8)
-                    converted_mask = Image.fromarray(converted_mask)
-                    converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
-                    converted_mask = create_binary_mask(converted_mask, round=False)
-
-                    # Remove aliasing artifacts using a gaussian blur.
-                    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
-
-                    # Expand the mask to fit the whole image if needed.
-                    if p.paste_to is not None:
-                        converted_mask = uncrop(converted_mask,
-                                                (overlay_image.width, overlay_image.height),
-                                                p.paste_to)
-
-                    p.masks_for_overlay[i] = converted_mask
-
-                    image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
-                    image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
-                                       mask=ImageOps.invert(converted_mask.convert('L')))
-
-                    p.overlay_images[i] = image_masked.convert('RGBA')
+                si.generate_adaptive_masks(latent_orig=p.init_latent,
+                                           latent_processed=samples_ddim,
+                                           overlay_images=p.overlay_images,
+                                           masks_for_overlay=p.masks_for_overlay,
+                                           width=p.width,
+                                           height=p.height,
+                                           paste_to=p.paste_to)

             x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
                                                  target_device=devices.cpu,
                                                  check_for_nans=True)
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 0ee0b7dde03..a700e6922fc 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -94,76 +94,6 @@ def update_inner_model(self):
         self.sampler.sampler_extra_args['uncond'] = uc

     def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
-        def latent_blend(a, b, t, one_minus_t=None):
-
-            """
-            Interpolates two latent image representations according to the parameter t,
-            where the interpolated vectors' magnitudes are also interpolated separately.
-            The "detail_preservation" factor biases the magnitude interpolation towards
-            the larger of the two magnitudes.
-            """
-            # NOTE: We use inplace operations wherever possible.
-
-            if one_minus_t is None:
-                one_minus_t = 1 - t
-
-            if self.soft_inpainting is None:
-                return a * one_minus_t + b * t
-
-            # Linearly interpolate the image vectors.
-            a_scaled = a * one_minus_t
-            b_scaled = b * t
-            image_interp = a_scaled
-            image_interp.add_(b_scaled)
-            result_type = image_interp.dtype
-            del a_scaled, b_scaled
-
-            # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
-            # 64-bit operations are used here to allow large exponents.
-            current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
-
-            # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
-            a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t
-            b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t
-            desired_magnitude = a_magnitude
-            desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation)
-            del a_magnitude, b_magnitude, one_minus_t
-
-            # Change the linearly interpolated image vectors' magnitudes to the value we want.
-            # This is the last 64-bit operation.
-            image_interp_scaling_factor = desired_magnitude
-            image_interp_scaling_factor.div_(current_magnitude)
-            image_interp_scaled = image_interp
-            image_interp_scaled.mul_(image_interp_scaling_factor)
-            del current_magnitude
-            del desired_magnitude
-            del image_interp
-            del image_interp_scaling_factor
-
-            image_interp_scaled = image_interp_scaled.to(result_type)
-            del result_type
-
-            return image_interp_scaled
-
-        def get_modified_nmask(nmask, _sigma):
-            """
-            Converts a negative mask representing the transparency of the original latent vectors being overlayed
-            to a mask that is scaled according to the denoising strength for this step.
-
-            Where:
-                0 = fully opaque, infinite density, fully masked
-                1 = fully transparent, zero density, fully unmasked
-
-            We bring this transparency to a power, as this allows one to simulate N number of blending operations
-            where N can be any positive real value. Using this one can control the balance of influence between
-            the denoiser and the original latents according to the sigma value.
-
-            NOTE: "mask" is not used
-            """
-            if self.soft_inpainting is None:
-                return nmask
-
-            return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale)

         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
@@ -184,9 +114,12 @@ def get_modified_nmask(nmask, _sigma):
         # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
             if self.soft_inpainting is None:
-                x = latent_blend(self.init_latent, x, self.nmask, self.mask)
+                x = self.init_latent * self.mask + self.nmask * x
             else:
-                x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
+                x = si.latent_blend(self.soft_inpainting,
+                                    self.init_latent,
+                                    x,
+                                    si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -290,9 +223,12 @@ def get_modified_nmask(nmask, _sigma):
         # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
             if self.soft_inpainting is None:
-                denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask)
+                denoised = self.init_latent * self.mask + self.nmask * denoised
             else:
-                denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
+                denoised = si.latent_blend(self.soft_inpainting,
+                                           self.init_latent,
+                                           denoised,
+                                           si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))

         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py
index 259c36ec892..b81c8dd95d1 100644
--- a/modules/soft_inpainting.py
+++ b/modules/soft_inpainting.py
@@ -4,13 +4,6 @@ def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservati
         self.mask_blend_scale = mask_blend_scale
         self.inpaint_detail_preservation = inpaint_detail_preservation

-    def get_paste_fields(self):
-        return [
-            (self.mask_blend_power, gen_param_labels.mask_blend_power),
-            (self.mask_blend_scale, gen_param_labels.mask_blend_scale),
-            (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation),
-        ]
-
     def add_generation_params(self, dest):
         dest[enabled_gen_param_label] = True
         dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
         dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
         dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation

+# ------------------- Methods -------------------
+
+
+def latent_blend(soft_inpainting, a, b, t):
+    """
+    Interpolates two latent image representations according to the parameter t,
+    where the interpolated vectors' magnitudes are also interpolated separately.
+    The "detail_preservation" factor biases the magnitude interpolation towards
+    the larger of the two magnitudes.
+    """
+    import torch
+
+    # NOTE: We use inplace operations wherever possible.
+
+    one_minus_t = 1 - t
+
+    # Linearly interpolate the image vectors.
+    a_scaled = a * one_minus_t
+    b_scaled = b * t
+    image_interp = a_scaled
+    image_interp.add_(b_scaled)
+    result_type = image_interp.dtype
+    del a_scaled, b_scaled
+
+    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+    # 64-bit operations are used here to allow large exponents.
+    current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+
+    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+    a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
+    b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
+    desired_magnitude = a_magnitude
+    desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
+    del a_magnitude, b_magnitude, one_minus_t
+
+    # Change the linearly interpolated image vectors' magnitudes to the value we want.
+    # This is the last 64-bit operation.
+    image_interp_scaling_factor = desired_magnitude
+    image_interp_scaling_factor.div_(current_magnitude)
+    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
+    image_interp_scaled = image_interp
+    image_interp_scaled.mul_(image_interp_scaling_factor)
+    del current_magnitude
+    del desired_magnitude
+    del image_interp
+    del image_interp_scaling_factor
+    del result_type
+
+    return image_interp_scaled
+
+
+def get_modified_nmask(soft_inpainting, nmask, sigma):
+    """
+    Converts a negative mask representing the transparency of the original latent vectors being overlayed
+    to a mask that is scaled according to the denoising strength for this step.
+
+    Where:
+        0 = fully opaque, infinite density, fully masked
+        1 = fully transparent, zero density, fully unmasked
+
+    We bring this transparency to a power, as this allows one to simulate N number of blending operations
+    where N can be any positive real value. Using this one can control the balance of influence between
+    the denoiser and the original latents according to the sigma value.
+
+    NOTE: "mask" is not used
+    """
+    import torch
+    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
+
+
+def generate_adaptive_masks(
+        latent_orig,
+        latent_processed,
+        overlay_images,
+        masks_for_overlay,
+        width, height,
+        paste_to):
+    import torch
+    import numpy as np
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
+    # latent_mask = p.nmask[0].float().cpu()
+    # convert the original mask into a form we use to scale distances for thresholding
+    # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
+    # mask_scalar = mask_scalar / (1.00001-mask_scalar)
+    # mask_scalar = mask_scalar.numpy()
+
+    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
+
+    kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
+        converted_mask = distance_map.float().cpu().numpy()
+        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+                                                          percentile_min=0.9, percentile_max=1, min_width=1)
+        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+                                                          percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+        # The distance at which opacity of original decreases to 50%
+        # half_weighted_distance = 1  # * mask_scalar
+        # converted_mask = converted_mask / half_weighted_distance
+
+        converted_mask = 1 / (1 + converted_mask ** 2)
+        converted_mask = images.smootherstep(converted_mask)
+        converted_mask = 1 - converted_mask
+        converted_mask = 255. * converted_mask
+        converted_mask = converted_mask.astype(np.uint8)
+        converted_mask = Image.fromarray(converted_mask)
+        converted_mask = images.resize_image(2, converted_mask, width, height)
+        converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+        # Remove aliasing artifacts using a gaussian blur.
+        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+        # Expand the mask to fit the whole image if needed.
+        if paste_to is not None:
+            converted_mask = proc.uncrop(converted_mask,
+                                         (overlay_image.width, overlay_image.height),
+                                         paste_to)
+
+        masks_for_overlay[i] = converted_mask
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4)
+
 enabled_ui_label = "Soft inpainting"
 enabled_gen_param_label = "Soft inpainting enabled"
 enabled_el_id = "soft_inpainting_enabled"

-default = SoftInpaintingSettings(1, 0.5, 4)
-ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost")
+ui_labels = SoftInpaintingSettings(
+    "Schedule bias",
+    "Preservation strength",
+    "Transition contrast boost")

 ui_info = SoftInpaintingSettings(
-    mask_blend_power="Shifts when preservation of original content occurs during denoising.",
-    # "Below 1: Stronger preservation near the end (with low sigma)\n"
-    # "1: Balanced (proportional to sigma)\n"
-    # "Above 1: Stronger preservation in the beginning (with high sigma)",
-    mask_blend_scale="How strongly partially masked content should be preserved.",
-    # "Low values: Favors generated content.\n"
-    # "High values: Favors original content.",
-    inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.")
+    "Shifts when preservation of original content occurs during denoising.",
+    "How strongly partially masked content should be preserved.",
+    "Amplifies the contrast that may be lost in partially masked regions.")

-gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost")
-el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation")
+gen_param_labels = SoftInpaintingSettings(
+    "Soft inpainting schedule bias",
+    "Soft inpainting preservation strength",
+    "Soft inpainting transition contrast boost")
+
+el_ids = SoftInpaintingSettings(
+    "mask_blend_power",
+    "mask_blend_scale",
+    "inpaint_detail_preservation")
+
+
+# ------------------- UI -------------------


 def gradio_ui():

diff --git a/modules/ui.py b/modules/ui.py
index 0e4fb17aa75..4f1265a3e8b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -683,13 +683,6 @@ def copy_image(img):

                         with FormRow():
                             soft_inpainting = si.gradio_ui()
-
-                            """
-                            mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power")
-                            mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale")
-                            inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset")
-                            """
-
                         with FormRow():
                             inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")

From 1455159cf44cd8c21656818463f6095eae887540 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 16:43:57 -0700
Subject: [PATCH 061/449] Fixed issue with whitespace, removed commented out
 code that was meant to be used as a reference.

---
 modules/soft_inpainting.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py
index b81c8dd95d1..56a87774693 100644
--- a/modules/soft_inpainting.py
+++ b/modules/soft_inpainting.py
@@ -179,15 +179,7 @@ def generate_adaptive_masks(
 def gradio_ui():
     import gradio as gr
     from modules.ui_components import InputAccordion
-    """
-    with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
-        with gr.Row():
-            refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
-            create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
-
-            refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
-    """

     with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
         with gr.Group():
             gr.Markdown(
@@ -223,11 +215,11 @@ def gradio_ui():
             gr.Markdown(
                 f"""
                 ### {ui_labels.mask_blend_power}
-                
+
                 The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
                 This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
                 This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
-                
+
                 - **Below 1**: Stronger preservation near the end (with low sigma)
                 - **1**: Balanced (proportional to sigma)
                 - **Above 1**: Stronger preservation in the beginning (with high sigma)
                 """)
             gr.Markdown(
                 f"""
                 ### {ui_labels.mask_blend_scale}
-                
+
                 Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
                 This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
-                
+
                 - **Low values**: Favors generated content.
                 - **High values**: Favors original content.
                 """)
             gr.Markdown(
                 f"""
                 ### {ui_labels.inpaint_detail_preservation}
-                
+
                 This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
                 With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two
                 interpolated vectors. This can prevent the loss of contrast that occurs with linear interpolation.
-                
+
                 - **Low values**: Softer blending, details may fade.
                 - **High values**: Stronger contrast, may over-saturate colors.
                 """)

From 57f29bd61dc30f1a8c94ead9b780f4655f7d7d6d Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 17:41:18 -0700
Subject: [PATCH 062/449] Re-introduce latent blending step from the vanilla
 inpainting procedure.
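
This restores the vanilla blend for the non-soft-inpainting path. With mask
weighting the original latents and nmask = 1 - mask weighting the denoised
sample, fully masked regions keep the original content exactly; a toy sketch
of the formula being re-added:

```python
import torch

init_latent = torch.zeros(1, 4, 8, 8)   # original image latents
samples = torch.randn(1, 4, 8, 8)       # denoised latents
mask = torch.ones(1, 4, 8, 8)           # 1 = keep original here
nmask = 1 - mask

samples = samples * nmask + init_latent * mask
print(torch.equal(samples, init_latent))  # True -- fully masked pixels revert
```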
---
 modules/processing.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/modules/processing.py b/modules/processing.py
index 0b3603875a7..c8dc4d9340e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1597,6 +1597,9 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

+        if self.mask is not None and self.soft_inpainting is None:
+            samples = samples * self.nmask + self.init_latent * self.mask
+
         del x
         devices.torch_gc()

From 60c602232fd760fb548fb0b3d18b5297f8823c2a Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 17:41:51 -0700
Subject: [PATCH 063/449] Restored original formatting.

---
 modules/processing.py | 36 +++++++++++------------------------
 1 file changed, 11 insertions(+), 25 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index c8dc4d9340e..90ae249a4db 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -370,10 +370,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
             return self.edit_image_conditioning(source_image)

         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image,
-                                                      latent_image,
-                                                      image_mask=image_mask,
-                                                      round_image_mask=round_image_mask)
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
@@ -885,7 +882,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

         if getattr(samples_ddim, 'already_decoded', False):
             x_samples_ddim = samples_ddim
-            # todo: generate masks the old fashioned way
+            # todo: generate adaptive masks based on pixel differences.
+            # if p.masks_for_overlay is used, it will already be populated with masks
         else:
             if opts.sd_vae_decode_method != 'Full':
                 p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

             # Generate the mask(s) based on similarity between the original and denoised latent vectors
             if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                 si.generate_adaptive_masks(latent_orig=p.init_latent,
                                            latent_processed=samples_ddim,
                                            overlay_images=p.overlay_images,
                                            masks_for_overlay=p.masks_for_overlay,
                                            width=p.width,
                                            height=p.height,
                                            paste_to=p.paste_to)

-            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
-                                                 target_device=devices.cpu,
-                                                 check_for_nans=True)
+            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -927,9 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             x_samples_ddim = batch_params.images

         def infotext(index=0, use_main_prompt=False):
-            return create_infotext(p, p.prompts, p.seeds, p.subseeds,
-                                   use_main_prompt=use_main_prompt, index=index,
-                                   all_negative_prompts=p.negative_prompts)
+            return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

         save_samples = p.save_samples()
@@ -972,8 +966,7 @@ def infotext(index=0, use_main_prompt=False):
                     image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                 if save_samples:
-                    images.save_image(image, p.outpath_samples, "", p.seeds[i],
-                                      p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                 text = infotext(i)
                 infotexts.append(text)
@@ -983,14 +976,10 @@ def infotext(index=0, use_main_prompt=False):
                 if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                     if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay:
                         image_mask = p.masks_for_overlay[i].convert('RGB')
-                        image_mask_composite = Image.composite(
-                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
-                            images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
                     elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
                         image_mask = p.mask_for_overlay.convert('RGB')
-                        image_mask_composite = Image.composite(
-                            original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
-                            images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                     else:
                         image_mask = None
                         image_mask_composite = None

             if image_mask is not None:
                 if self.soft_inpainting is not None:
                     # We apply the masks AFTER to adjust mask based on changed content.
                     self.overlay_images.append(image)
                     self.masks_for_overlay.append(image_mask)
                 else:
                     image_masked = Image.new('RGBa', (image.width, image.height))
-                    image_masked.paste(image.convert("RGBA").convert("RGBa"),
-                                       mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+                    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+
                     self.overlay_images.append(image_masked.convert('RGBA'))

             # crop_region is not None if we are doing inpaint full res
             if crop_region is not None:
@@ -1583,10 +1572,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1,
-                                                                  self.init_latent,
-                                                                  image_mask,
-                                                                  self.soft_inpainting is None)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()

From b32a334e3da7b06d82441beaa08a673b4f55bca1 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 17:57:10 -0700
Subject: [PATCH 064/449] Applies a convert('RGBA') operation early to mimic
 previous behaviour.

---
 modules/processing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index 90ae249a4db..7fc282cfd21 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1500,7 +1500,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             if image_mask is not None:
                 if self.soft_inpainting is not None:
                     # We apply the masks AFTER to adjust mask based on changed content.
-                    self.overlay_images.append(image)
+                    self.overlay_images.append(image.convert('RGBA'))
                     self.masks_for_overlay.append(image_mask)
                 else:
                     image_masked = Image.new('RGBa', (image.width, image.height))

From 6fc12428e3c5f903584ca7986e0c441f80fa2807 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 19:42:59 -0700
Subject: [PATCH 065/449] Fixed issue where batched inpainting (batch size > 1)
 wouldn't work because of mismatched tensor sizes. The 'already_decoded' case
 should also be handled correctly (tested indirectly).
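
The root cause: with batch size > 1 the latents are [batch, 4, h, w] while
nmask is [4, h, w], so the elementwise products in latent_blend no longer
lined up. A small broadcasting sketch of the fix (unsqueeze adds the missing
batch axis; keepdim keeps the norms broadcastable):

```python
import torch

latents = torch.randn(3, 4, 8, 8)   # batch of 3
nmask = torch.rand(4, 8, 8)         # one mask shared by the batch

t2 = nmask.unsqueeze(0)             # [1, 4, 8, 8] -> broadcasts over the batch
print((latents * t2).shape)         # torch.Size([3, 4, 8, 8])

norm = torch.norm(latents, p=2, dim=1, keepdim=True)
print(norm.shape)                   # torch.Size([3, 1, 8, 8]) -- still broadcastable
```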
---
 modules/processing.py      | 23 ++++++++-----
 modules/soft_inpainting.py | 66 ++++++++++++++++++++++++++++++++------
 2 files changed, 71 insertions(+), 18 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 7fc282cfd21..71bb056a26b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -883,20 +883,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

         if getattr(samples_ddim, 'already_decoded', False):
             x_samples_ddim = samples_ddim
             # todo: generate adaptive masks based on pixel differences.
-            # if p.masks_for_overlay is used, it will already be populated with masks
+            if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
+                si.apply_masks(soft_inpainting=p.soft_inpainting,
+                               nmask=p.nmask,
+                               overlay_images=p.overlay_images,
+                               masks_for_overlay=p.masks_for_overlay,
+                               width=p.width,
+                               height=p.height,
+                               paste_to=p.paste_to)
         else:
             if opts.sd_vae_decode_method != 'Full':
                 p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

             # Generate the mask(s) based on similarity between the original and denoised latent vectors
             if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
-                si.generate_adaptive_masks(latent_orig=p.init_latent,
-                                           latent_processed=samples_ddim,
-                                           overlay_images=p.overlay_images,
-                                           masks_for_overlay=p.masks_for_overlay,
-                                           width=p.width,
-                                           height=p.height,
-                                           paste_to=p.paste_to)
+                si.apply_adaptive_masks(latent_orig=p.init_latent,
+                                        latent_processed=samples_ddim,
+                                        overlay_images=p.overlay_images,
+                                        masks_for_overlay=p.masks_for_overlay,
+                                        width=p.width,
+                                        height=p.height,
+                                        paste_to=p.paste_to)

             x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py
index 56a87774693..b36ac8fa1b7 100644
--- a/modules/soft_inpainting.py
+++ b/modules/soft_inpainting.py
@@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t):

     # NOTE: We use inplace operations wherever possible.

-    one_minus_t = 1 - t
+    # [4][w][h] to [1][4][w][h]
+    t2 = t.unsqueeze(0)
+    # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+    t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+    one_minus_t2 = 1 - t2
+    one_minus_t3 = 1 - t3

     # Linearly interpolate the image vectors.
-    a_scaled = a * one_minus_t
-    b_scaled = b * t
+    a_scaled = a * one_minus_t2
+    b_scaled = b * t2
     image_interp = a_scaled
     image_interp.add_(b_scaled)
     result_type = image_interp.dtype
-    del a_scaled, b_scaled
+    del a_scaled, b_scaled, t2, one_minus_t2

     # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
     # 64-bit operations are used here to allow large exponents.
-    current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)

     # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
-    a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
-    b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
+    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
+    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
     desired_magnitude = a_magnitude
     desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
-    del a_magnitude, b_magnitude, one_minus_t
+    del a_magnitude, b_magnitude, t3, one_minus_t3

     # Change the linearly interpolated image vectors' magnitudes to the value we want.
     # This is the last 64-bit operation.
@@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
     NOTE: "mask" is not used
     """
     import torch
-    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
+    # todo: Why is sigma 2D? Both values are the same.
+    return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)


-def generate_adaptive_masks(
+def apply_adaptive_masks(
         latent_orig,
         latent_processed,
         overlay_images,
@@ -142,6 +149,45 @@ def generate_adaptive_masks(

         overlay_images[i] = image_masked.convert('RGBA')


+def apply_masks(
+        soft_inpainting,
+        nmask,
+        overlay_images,
+        masks_for_overlay,
+        width, height,
+        paste_to):
+    import torch
+    import numpy as np
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    converted_mask = nmask[0].float()
+    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+    converted_mask = 255. * converted_mask
+    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+    converted_mask = Image.fromarray(converted_mask)
+    converted_mask = images.resize_image(2, converted_mask, width, height)
+    converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+    # Remove aliasing artifacts using a gaussian blur.
+    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+    # Expand the mask to fit the whole image if needed.
+    if paste_to is not None:
+        converted_mask = proc.uncrop(converted_mask,
+                                     (width, height),
+                                     paste_to)
+
+    for i, overlay_image in enumerate(overlay_images):
+        masks_for_overlay[i] = converted_mask
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+

 # ------------------- Constants -------------------

From 49bbf1140731036875573bb7c44aa7e74623c856 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 19:47:40 -0700
Subject: [PATCH 066/449] Fixed unused import.

---
 modules/processing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index 71bb056a26b..e1823ac3321 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,7 +9,7 @@

 import torch
 import numpy as np
-from PIL import Image, ImageOps, ImageFilter
+from PIL import Image, ImageOps
 import random
 import cv2
 from skimage import exposure

From 895456c4a2e87f5fe3ee23b4482e68fce317a1ca Mon Sep 17 00:00:00 2001
From: Jabasukuriputo Wang
Date: Tue, 5 Dec 2023 18:00:48 -0600
Subject: [PATCH 067/449] change state dict comparison to ref compare
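
Why `is` instead of `==` here: comparing two state dicts with `==` compares
their tensor values elementwise, which is either expensive or outright
ambiguous, while the intent was only "is this the same object". A small
demonstration:

```python
import torch

a = {"w": torch.ones(2)}
b = {"w": torch.ones(2)}

print(a is b)       # False -- cheap, unambiguous reference check
try:
    _ = (a == b)    # distinct tensor objects: elementwise compare follows
except RuntimeError as e:
    print(e)        # Boolean value of Tensor with more than one element is ambiguous
```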
""" - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) From 672dc4efa8e0da38426b121e7c7216d0a8e465fd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:16:10 +0800 Subject: [PATCH 068/449] Fix forced reload --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index dcf816b3a85..d0046f88c6d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -801,7 +801,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if check_fp8(sd_model) != devices.fp8: # load from state dict again to prevent extra numerical errors forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename: + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) From 4d56383025f2cbd00dc6296161e31a896624ab75 Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Wed, 6 Dec 2023 20:23:56 +0800 Subject: [PATCH 069/449] Long distance memory overflow issue Problem: The memory will slowly increase with the drawing until restarting. Observation: GC analysis shows that no occupation has occurred, so it is suspected to be a problem with the underlying allocator. Reason: Under Linux, glibc is used to allocate memory. glibc uses brk and mmap to allocate memory, and the memory allocated by brk cannot be released until the high-address memory is released. That is to say, if you apply for two pieces of memory A and B through brk, it is impossible to release A before B is released, and it is still occupied by the process. Check the suspected "memory leak" through TOP. So I replaced TCMalloc, but found that libtcmalloc_minimal could not find ptthread_Key_Create. After analysis, it was found that pthread was not entered during compilation. --- webui.sh | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed74..081624c4eb8 100755 --- a/webui.sh +++ b/webui.sh @@ -222,13 +222,29 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" - if [[ ! -z "${TCMALLOC}" ]]; then - echo "Using TCMalloc: ${TCMALLOC}" - export LD_PRELOAD="${TCMALLOC}" - else - printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" - fi + # Define Tcmalloc Libs arrays + TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d") + + # Traversal array + for lib in "${TCMALLOC_LIBS[@]}" + do + #Determine which type of tcmalloc library the library supports + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)" + TC_INFO=(${TCMALLOC//=>/}) + if [[ ! 
-z "${TC_INFO}" ]]; then + echo "Using TCMalloc: ${TC_INFO}" + #Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_Key_Create + if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then + echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}" + export LD_PRELOAD="${TC_INFO}" + break + else + echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error" + fi + else + printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" + fi + done fi } From 746783f7a47f38f728f221cc26fe04035d3ca66b Mon Sep 17 00:00:00 2001 From: Nuullll Date: Wed, 6 Dec 2023 20:55:42 +0800 Subject: [PATCH 070/449] [IPEX] Fix embedding Cast `torch.bmm` args into same `dtype`. Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ``` --- modules/xpu_specific.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790378..ec1ad100aae 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ def torch_xpu_gc(): CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) From 9d2cbf8e97832662e446145d3961c39e78919d3d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 6 Dec 2023 23:06:32 +0900 Subject: [PATCH 071/449] add option: Live preview in full page image viewer make #13459 "show the preview image in the modal view if available" optional --- javascript/imageviewer.js | 2 +- modules/shared_options.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc9d..625c5d148df 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..88cfddedeb4 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -330,6 +330,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From e90d4334ad37024a802f4ef27069b625a6508f72 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: 
Wed, 6 Dec 2023 16:54:42 -0700 Subject: [PATCH 072/449] A custom blending function can be provided by p, replacing the use of soft_inpainting. --- modules/sd_samplers_cfg_denoiser.py | 34 ++++++++++++++--------------- modules/sd_samplers_common.py | 1 - 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a700e6922fc..f13e8dcc5b0 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,7 +6,6 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback -import modules.soft_inpainting as si def catenate_conds(conds): @@ -44,7 +43,6 @@ def __init__(self, sampler): self.model_wrap = None self.mask = None self.nmask = None - self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -94,7 +92,6 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -111,15 +108,24 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(latent): + if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): + return self.p.denoiser_masked_blend_function( + self, + # Using an argument dictionary so that arguments can be added without breaking extensions. 
+                    args=
+                    {
+                        "denoiser": self,
+                        "current_latent": latent,
+                        "sigma": sigma
+                    })
+            else:
+                return self.init_latent * self.mask + self.nmask * latent
+
         # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            if self.soft_inpainting is None:
-                x = self.init_latent * self.mask + self.nmask * x
-            else:
-                x = si.latent_blend(self.soft_inpainting,
-                                    self.init_latent,
-                                    x,
-                                    si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
+            x = apply_blend(x)

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -222,13 +228,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

         # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            if self.soft_inpainting is None:
-                denoised = self.init_latent * self.mask + self.nmask * denoised
-            else:
-                denoised = si.latent_blend(self.soft_inpainting,
-                                           self.init_latent,
-                                           denoised,
-                                           si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
+            denoised = apply_blend(denoised)

         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 9682bee3d8e..58efcad2374 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -277,7 +277,6 @@ def initialize(self, p) -> dict:
         self.model_wrap_cfg.p = p
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None
         self.model_wrap_cfg.step = 0
         self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)

From 4608f6236fc24d937f89500b2c9bf48484537cf9 Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Wed, 6 Dec 2023 18:11:17 -0700
Subject: [PATCH 073/449] Removed changes in some scripts since the arguments
 for soft inpainting are no longer passed through the same path as "mask_blur".
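
In practice, this means a script's options now travel through the script-args
mechanism rather than through img2img()'s parameter list. A minimal sketch of
that pattern (the ExampleInpaintOptions class and its slider are illustrative
only, not part of this patch):

```
import gradio as gr
import modules.scripts as scripts


class ExampleInpaintOptions(scripts.Script):
    # Hypothetical always-on img2img script: values from its UI components are
    # delivered through p.script_args instead of new img2img() parameters.

    def title(self):
        return "Example inpaint options"

    def show(self, is_img2img):
        return scripts.AlwaysVisible if is_img2img else False

    def ui(self, is_img2img):
        # Components returned here become this script's slice of script_args.
        strength = gr.Slider(label='Example strength', minimum=0.0, maximum=1.0,
                             step=0.05, value=0.5, elem_id=self.elem_id("example_strength"))
        return [strength]

    def process(self, p, strength):
        # The same value arrives here, with no change to img2img()'s signature.
        p.extra_generation_params["Example strength"] = strength
```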
--- modules/img2img.py | 50 +------------------------------- modules/ui.py | 7 ----- scripts/outpainting_mk_2.py | 9 ++---- scripts/poor_mans_outpainting.py | 8 ++--- test/test_img2img.py | 5 ---- 5 files changed, 5 insertions(+), 74 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3aa8a9cef6c..c583290a0ea 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -15,7 +15,6 @@ import modules.processing as processing from modules.ui import plaintext_to_html import modules.scripts -import modules.soft_inpainting as si def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): @@ -147,48 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal return batch_results -def img2img(id_task: str, - mode: int, - prompt: str, - negative_prompt: str, - prompt_styles, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps: int, - sampler_name: str, - mask_blur: int, - mask_alpha: float, - mask_blend_enabled: bool, - mask_blend_power: float, - mask_blend_scale: float, - inpaint_detail_preservation: float, - inpainting_fill: int, - n_iter: int, - batch_size: int, - cfg_scale: float, - image_cfg_scale: float, - denoising_strength: float, - selected_scale_tab: int, - height: int, - width: int, - scale_by: float, - resize_mode: int, - inpaint_full_res: bool, - inpaint_full_res_padding: int, - inpainting_mask_invert: int, - img2img_batch_input_dir: str, - img2img_batch_output_dir: str, - img2img_batch_inpaint_mask_dir: str, - override_settings_texts, - img2img_batch_use_png_info: bool, - img2img_batch_png_info_props: list, - img2img_batch_png_info_dir: str, - request: gr.Request, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -229,9 +187,6 @@ def img2img(id_task: str, assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, @@ -249,7 +204,6 @@ def img2img(id_task: str, init_images=[image], mask=mask, mask_blur=mask_blur, - soft_inpainting=soft_inpainting, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -270,8 +224,6 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - if soft_inpainting is not None: - soft_inpainting.add_generation_params(p.extra_generation_params) with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index bd2091e1ff9..d80486dd4ac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,7 +29,6 @@ from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text -import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -680,9 +679,6 @@ def copy_image(img): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - with FormRow(): - soft_inpainting = si.gradio_ui() - with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -737,7 +733,6 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, - *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -836,10 +831,8 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - *(soft_inpainting[1]), *scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index f78886883bd..c98ab48098e 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -134,14 +133,13 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = 
gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] + return [info, pixels, mask_blur, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +168,6 @@ def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mas p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 target_h = math.ceil((init_img.height + up + down) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 11f7f74a820..ea0632b68c1 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,7 +7,6 @@ from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si class Script(scripts.Script): @@ -23,19 +22,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] + return [pixels, mask_blur, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 87bd850912e..117d2d1eb45 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,7 +1,6 @@ import pytest import requests -import modules.soft_inpainting as si @pytest.fixture() @@ -25,10 +24,6 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_enabled": True, - "mask_blend_power": si.default.mask_blend_power, - "mask_blend_scale": si.default.mask_blend_scale, - "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From ac4578912395627731f2cd8529f87a95df1f7644 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 
21:16:27 -0700
Subject: [PATCH 074/449] Removed built-in soft inpainting, added hooks for
 soft inpainting to work through scripts instead.

---
 modules/processing.py               | 94 ++++++++++++-----------------
 modules/scripts.py                  | 70 +++++++++++++++++++++
 modules/sd_samplers_cfg_denoiser.py | 23 +++----
 3 files changed, 118 insertions(+), 69 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 7d46949fa3e..5a1a90afedd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,7 +30,6 @@
 import modules.sd_vae as sd_vae
 from ldm.data.util import AddMiDaS
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
-import modules.soft_inpainting as si

 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
@@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc):
     return image


-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
+def apply_overlay(image, paste_loc, overlay):
+    if overlay is None:
         return image

-    overlay = overlays[index]
-
     if paste_loc is not None:
         image = uncrop(image, (overlay.width, overlay.height), paste_loc)

@@ -150,7 +147,6 @@ class StableDiffusionProcessing:
     do_not_save_grid: bool = False
     extra_generation_params: dict[str, Any] = None
     overlay_images: list = None
-    masks_for_overlay: list = None
     eta: float = None
     do_not_reload_embeddings: bool = False
     denoising_strength: float = None
@@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
             samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

+        if p.scripts is not None:
+            ps = scripts.PostSampleArgs(samples_ddim)
+            p.scripts.post_sample(p, ps)
+            samples_ddim = pp.samples
+
         if getattr(samples_ddim, 'already_decoded', False):
             x_samples_ddim = samples_ddim
-            # todo: generate adaptive masks based on pixel differences.
- if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_masks(soft_inpainting=p.soft_inpainting, - nmask=p.nmask, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -955,9 +937,18 @@ def infotext(index=0, use_main_prompt=False): pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = p.mask_for_overlay + overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) @@ -968,9 +959,9 @@ def infotext(index=0, use_main_prompt=False): original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -981,13 +972,6 @@ def infotext(index=0, use_main_prompt=False): image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: - mask_for_overlay = p.mask_for_overlay - elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]: - mask_for_overlay = p.masks_for_overlay[i] - else: - mask_for_overlay = None - if mask_for_overlay is not None: if opts.return_mask or opts.save_mask: image_mask = mask_for_overlay.convert('RGB') @@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - soft_inpainting: si.SoftInpaintingParameters = si.default + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 @@ -1447,7 +1431,7 @@ def 
init(self, all_prompts, all_seeds, all_subseeds): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1465,7 +1449,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask if self.soft_inpainting is None else None + self.mask_for_overlay = image_mask mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1476,13 +1460,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) - if self.soft_inpainting is None: - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) - - self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1504,15 +1485,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - if self.soft_inpainting is not None: - # We apply the masks AFTER to adjust mask based on changed content. 
- self.overlay_images.append(image.convert('RGBA')) - self.masks_for_overlay.append(image_mask) - else: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1565,7 +1541,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - if self.soft_inpainting is None: + if self.mask_round: latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) @@ -1578,7 +1554,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1589,8 +1565,14 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None and self.soft_inpainting is None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 7f9454eb578..92a07c564e9 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_samples = blended_samples + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwarg pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original 
content is blended with the inpainted content.
+        This is called at every step in the denoising process and once at the end.
+        If is_final_blend is true, this is called for the final blending stage.
+        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+        """
+
+        pass
+
+    def post_sample(self, p, ps: PostSampleArgs, *args):
+        """
+        Called after the samples have been generated,
+        but before they have been decoded by the VAE, if applicable.
+        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+        """
+
+        pass
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
         """
         Called for every image after it has been generated.
@@ -213,6 +252,13 @@ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):

         pass

+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+        """
+        Called for every image after it has been generated; allows replacing the mask and overlay used to composite the result.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
             except Exception:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

+    def post_sample(self, p, ps: PostSampleArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.post_sample(p, ps, *script_args)
+            except Exception:
+                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+    def on_mask_blend(self, p, mba: MaskBlendArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.on_mask_blend(p, mba, *script_args)
+            except Exception:
+                errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True)
+
     def postprocess_image(self, p, pp: PostprocessImageArgs):
         for script in self.alwayson_scripts:
             try:
@@ -775,6 +837,14 @@ def postprocess_image(self, p, pp: PostprocessImageArgs):
             except Exception:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_maskoverlay(p, ppmo, *script_args)
+            except Exception:
+                errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
             try:
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index f13e8dcc5b0..eb9d5dafacc 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -109,19 +109,16 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

         # If we use masks, blending between the denoised and original latent images occurs here.
-        def apply_blend(latent):
-            if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function):
-                return self.p.denoiser_masked_blend_function(
-                    self,
-                    # Using an argument dictionary so that arguments can be added without breaking extensions.
-                    args=
-                    {
-                        "denoiser": self,
-                        "current_latent": latent,
-                        "sigma": sigma
-                    })
-            else:
-                return self.init_latent * self.mask + self.nmask * latent
+        def apply_blend(current_latent):
+            blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+            if self.p.scripts is not None:
+                from modules import scripts
+                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+                self.p.scripts.on_mask_blend(self.p, mba)
+                blended_latent = mba.blended_latent
+
+            return blended_latent

         # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:

From 2abc417834d752e43a283f8603bfddfb1c80b30f Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Wed, 6 Dec 2023 22:25:53 -0700
Subject: [PATCH 075/449] Re-implemented soft inpainting via a script. Also
 fixed some mistakes with the previous hooks, removed unnecessary formatting
 changes, and removed code that I had forgotten to remove.

---
 modules/processing.py      |  23 +--
 modules/scripts.py         |   4 +-
 modules/soft_inpainting.py | 308 ----------------------------
 scripts/soft_inpainting.py | 401 +++++++++++++++++++++++++++++++++++++
 4 files changed, 413 insertions(+), 323 deletions(-)
 delete mode 100644 modules/soft_inpainting.py
 create mode 100644 scripts/soft_inpainting.py

diff --git a/modules/processing.py b/modules/processing.py
index 5a1a90afedd..f8d85bdf527 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if p.scripts is not None:
             ps = scripts.PostSampleArgs(samples_ddim)
             p.scripts.post_sample(p, ps)
-            samples_ddim = pp.samples
+            samples_ddim = ps.samples

         if getattr(samples_ddim, 'already_decoded', False):
             x_samples_ddim = samples_ddim
         else:
             if opts.sd_vae_decode_method != 'Full':
                 p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-
             x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

             x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -944,7 +943,7 @@ def infotext(index=0, use_main_prompt=False):
             if p.scripts is not None:
                 ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                 p.scripts.postprocess_maskoverlay(p, ppmo)
-                mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
+                mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

             if p.color_corrections is not None and i < len(p.color_corrections):
                 if save_samples and opts.save_images_before_color_correction:
@@ -959,7 +958,7 @@ def infotext(index=0, use_main_prompt=False):
                 original_denoised_image = image.copy()

                 if p.paste_to is not None:
-                    original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
+                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

                 image = apply_overlay(image, p.paste_to, overlay_image)

@@ -1512,9 +1511,6 @@ def init(self, all_prompts, all_seeds, all_subseeds):
         if self.overlay_images is not None:
             self.overlay_images = self.overlay_images * self.batch_size

-        if self.masks_for_overlay is not None:
-            self.masks_for_overlay = self.masks_for_overlay * self.batch_size
-
         if self.color_corrections is not None and len(self.color_corrections) == 1:
             self.color_corrections = self.color_corrections * self.batch_size

@@ -1565,14 +1561,15 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

         samples = 
self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - blended_samples = samples * self.nmask + self.init_latent * self.mask + if self.mask is not None: + blended_samples = samples * self.nmask + self.init_latent * self.mask - if self.scripts is not None: - mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) - self.scripts.on_mask_blend(self, mba) - blended_samples = mba.blended_latent + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent - samples = blended_samples + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 92a07c564e9..b6fcf96e047 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -12,12 +12,12 @@ AlwaysVisible = object() class MaskBlendArgs: - def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): self.current_latent = current_latent self.nmask = nmask self.init_latent = init_latent self.mask = mask - self.blended_samples = blended_samples + self.blended_latent = blended_latent self.denoiser = denoiser self.is_final_blend = denoiser is None diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py deleted file mode 100644 index b36ac8fa1b7..00000000000 --- a/modules/soft_inpainting.py +++ /dev/null @@ -1,308 +0,0 @@ -class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): - self.mask_blend_power = mask_blend_power - self.mask_blend_scale = mask_blend_scale - self.inpaint_detail_preservation = inpaint_detail_preservation - - def add_generation_params(self, dest): - dest[enabled_gen_param_label] = True - dest[gen_param_labels.mask_blend_power] = self.mask_blend_power - dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale - dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation - - -# ------------------- Methods ------------------- - - -def latent_blend(soft_inpainting, a, b, t): - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - import torch - - # NOTE: We use inplace operations wherever possible. - - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) - - one_minus_t2 = 1 - t2 - one_minus_t3 = 1 - t3 - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t2 - b_scaled = b * t2 - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled, t2, one_minus_t2 - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). 
- a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, t3, one_minus_t3 - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - del result_type - - return image_interp_scaled - - -def get_modified_nmask(soft_inpainting, nmask, sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - import torch - # todo: Why is sigma 2D? Both values are the same. - return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) - - -def apply_adaptive_masks( - latent_orig, - latent_processed, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. 
* converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc. uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - paste_to) - - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - -def apply_masks( - soft_inpainting, - nmask, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) - converted_mask = 255. * converted_mask - converted_mask = converted_mask.cpu().numpy().astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (width, height), - paste_to) - - for i, overlay_image in enumerate(overlay_images): - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - -# ------------------- Constants ------------------- - - -default = SoftInpaintingSettings(1, 0.5, 4) - -enabled_ui_label = "Soft inpainting" -enabled_gen_param_label = "Soft inpainting enabled" -enabled_el_id = "soft_inpainting_enabled" - -ui_labels = SoftInpaintingSettings( - "Schedule bias", - "Preservation strength", - "Transition contrast boost") - -ui_info = SoftInpaintingSettings( - "Shifts when preservation of original content occurs during denoising.", - "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") - -gen_param_labels = SoftInpaintingSettings( - "Soft inpainting schedule bias", - "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") - -el_ids = SoftInpaintingSettings( - "mask_blend_power", - "mask_blend_scale", - "inpaint_detail_preservation") - - -# ------------------- UI ------------------- - - -def gradio_ui(): - import gradio as gr - from modules.ui_components import InputAccordion - - with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: - with gr.Group(): - gr.Markdown( - """ - Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. 
- **High _Mask blur_** values are recommended! - """) - - result = SoftInpaintingSettings( - gr.Slider(label=ui_labels.mask_blend_power, - info=ui_info.mask_blend_power, - minimum=0, - maximum=8, - step=0.1, - value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), - gr.Slider(label=ui_labels.mask_blend_scale, - info=ui_info.mask_blend_scale, - minimum=0, - maximum=8, - step=0.05, - value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), - gr.Slider(label=ui_labels.inpaint_detail_preservation, - info=ui_info.inpaint_detail_preservation, - minimum=1, - maximum=32, - step=0.5, - value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) - - with gr.Accordion("Help", open=False): - gr.Markdown( - f""" - ### {ui_labels.mask_blend_power} - - The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). - This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. - This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - - - **Below 1**: Stronger preservation near the end (with low sigma) - - **1**: Balanced (proportional to sigma) - - **Above 1**: Stronger preservation in the beginning (with high sigma) - """) - gr.Markdown( - f""" - ### {ui_labels.mask_blend_scale} - - Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. - This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - - - **Low values**: Favors generated content. - - **High values**: Favors original content. - """) - gr.Markdown( - f""" - ### {ui_labels.inpaint_detail_preservation} - - This parameter controls how the original latent vectors and denoised latent vectors are interpolated. - With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. - This can prevent the loss of contrast that occurs with linear interpolation. - - - **Low values**: Softer blending, details may fade. - - **High values**: Stronger contrast, may over-saturate colors. 
- """) - - return ( - [ - soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation - ], - [ - (soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) - ] - ) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py new file mode 100644 index 00000000000..47e0269bf1b --- /dev/null +++ b/scripts/soft_inpainting.py @@ -0,0 +1,401 @@ +import gradio as gr +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. 
+ image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def apply_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. 
+            if paste_to is not None:
+                converted_mask = proc.uncrop(converted_mask,
+                                             (overlay_image.width, overlay_image.height),
+                                             paste_to)
+
+            masks_for_overlay.append(converted_mask)
+
+            image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+            image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                               mask=ImageOps.invert(converted_mask.convert('L')))
+
+            overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+def apply_masks(
+        soft_inpainting,
+        nmask,
+        overlay_images,
+        width, height,
+        paste_to):
+    import torch
+    import numpy as np
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    converted_mask = nmask[0].float()
+    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+    converted_mask = 255. * converted_mask
+    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+    converted_mask = Image.fromarray(converted_mask)
+    converted_mask = images.resize_image(2, converted_mask, width, height)
+    converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+    # Remove aliasing artifacts using a gaussian blur.
+    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+    # Expand the mask to fit the whole image if needed.
+    if paste_to is not None:
+        converted_mask = proc.uncrop(converted_mask,
+                                     (width, height),
+                                     paste_to)
+
+    masks_for_overlay = []
+
+    for i, overlay_image in enumerate(overlay_images):
+        masks_for_overlay.append(converted_mask)
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4)
+
+enabled_ui_label = "Soft inpainting"
+enabled_gen_param_label = "Soft inpainting enabled"
+enabled_el_id = "soft_inpainting_enabled"
+
+ui_labels = SoftInpaintingSettings(
+    "Schedule bias",
+    "Preservation strength",
+    "Transition contrast boost")
+
+ui_info = SoftInpaintingSettings(
+    "Shifts when preservation of original content occurs during denoising.",
+    "How strongly partially masked content should be preserved.",
+    "Amplifies the contrast that may be lost in partially masked regions.")
+
+gen_param_labels = SoftInpaintingSettings(
+    "Soft inpainting schedule bias",
+    "Soft inpainting preservation strength",
+    "Soft inpainting transition contrast boost")
+
+el_ids = SoftInpaintingSettings(
+    "mask_blend_power",
+    "mask_blend_scale",
+    "inpaint_detail_preservation")
+
+
+class Script(scripts.Script):
+
+    def __init__(self):
+        self.masks_for_overlay = None
+        self.overlay_images = None
+
+    def title(self):
+        return "Soft Inpainting"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible if is_img2img else False
+
+    def ui(self, is_img2img):
+        if not is_img2img:
+            return
+
+        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
+            with gr.Group():
+                gr.Markdown(
+                    """
+                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
+                    **High _Mask blur_** values are recommended!
+ """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation] + + def process(self, p, enabled, power, scale, detail_preservation): + if not enabled: + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + if mba.sigma is None: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # todo: Why is sigma 2D? Both values are the same. 
+ mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + from modules import images + from modules.shared import opts + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(soft_inpainting=settings, + nmask=p.nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file From 8dbacc7d018774a3bc801cc57617795274a15087 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:30:30 -0700 Subject: [PATCH 076/449] Fixed "No newline at end of file". --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 47e0269bf1b..6d0cf847952 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -398,4 +398,4 @@ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, e return ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] - ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file + ppmo.overlay_image = self.overlay_images[ppmo.index] From 56604f08a18588e8e6b57d7c3f9c61d6624846f8 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:53:44 -0700 Subject: [PATCH 077/449] Moved image filters used by soft inpainting into soft_inpainting.py from images.py --- modules/images.py | 190 ---------------------------------- scripts/soft_inpainting.py | 205 +++++++++++++++++++++++++++++++++++-- 2 files changed, 199 insertions(+), 196 deletions(-) diff --git a/modules/images.py b/modules/images.py index 9495349866b..16f9ae7cce1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -792,193 +792,3 @@ def flatten(img, bgcolor): return img.convert('RGB') - -def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): - """ - Generalization convolution filter capable of applying - weighted mean, median, maximum, and minimum filters - parametrically using an arbitrary kernel. - - Args: - img (nparray): - The image, a 2-D array of floats, to which the filter is being applied. - kernel (nparray): - The kernel, a 2-D array of floats. - kernel_center (nparray): - The kernel center coordinate, a 1-D array with two elements. 
- percentile_min (float): - The lower bound of the histogram window used by the filter, - from 0 to 1. - percentile_max (float): - The upper bound of the histogram window used by the filter, - from 0 to 1. - min_width (float): - The minimum size of the histogram window bounds, in weight units. - Must be greater than 0. - - Returns: - (nparray): A filtered copy of the input image "img", a 2-D array of floats. - """ - - # Converts an index tuple into a vector. - def vec(x): - return np.array(x) - - kernel_min = -kernel_center - kernel_max = vec(kernel.shape) - kernel_center - - def weighted_histogram_filter_single(idx): - idx = vec(idx) - min_index = np.maximum(0, idx + kernel_min) - max_index = np.minimum(vec(img.shape), idx + kernel_max) - window_shape = max_index - min_index - - class WeightedElement: - """ - An element of the histogram, its weight - and bounds. - """ - def __init__(self, value, weight): - self.value: float = value - self.weight: float = weight - self.window_min: float = 0.0 - self.window_max: float = 1.0 - - # Collect the values in the image as WeightedElements, - # weighted by their corresponding kernel values. - values = [] - for window_tup in np.ndindex(tuple(window_shape)): - window_index = vec(window_tup) - image_index = window_index + min_index - centered_kernel_index = image_index - idx - kernel_index = centered_kernel_index + kernel_center - element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) - values.append(element) - - def sort_key(x: WeightedElement): - return x.value - - values.sort(key=sort_key) - - # Calculate the height of the stack (sum) - # and each sample's range they occupy in the stack - sum = 0 - for i in range(len(values)): - values[i].window_min = sum - sum += values[i].weight - values[i].window_max = sum - - # Calculate what range of this stack ("window") - # we want to get the weighted average across. - window_min = sum * percentile_min - window_max = sum * percentile_max - window_width = window_max - window_min - - # Ensure the window is within the stack and at least a certain size. - if window_width < min_width: - window_center = (window_min + window_max) / 2 - window_min = window_center - min_width / 2 - window_max = window_center + min_width / 2 - - if window_max > sum: - window_max = sum - window_min = sum - min_width - - if window_min < 0: - window_min = 0 - window_max = min_width - - value = 0 - value_weight = 0 - - # Get the weighted average of all the samples - # that overlap with the window, weighted - # by the size of their overlap. - for i in range(len(values)): - if window_min >= values[i].window_max: - continue - if window_max <= values[i].window_min: - break - - s = max(window_min, values[i].window_min) - e = min(window_max, values[i].window_max) - w = e - s - - value += values[i].value * w - value_weight += w - - return value / value_weight if value_weight != 0 else 0 - - img_out = img.copy() - - # Apply the kernel operation over each pixel. - for index in np.ndindex(img.shape): - img_out[index] = weighted_histogram_filter_single(index) - - return img_out - -def smoothstep(x): - """ - The smoothstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * (3 - 2 * x) - -def smootherstep(x): - """ - The smootherstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. 
- """ - return x * x * x * (x * (6 * x - 15) + 10) - - -def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): - """ - Creates a Gaussian kernel with thresholded edges. - - Args: - stddev_radius (float): - Standard deviation of the gaussian kernel, in pixels. - max_radius (int): - The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. - The kernel is thresholded so that any values one pixel beyond this radius - is weighted at 0. - - Returns: - (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) - """ - # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. - def gaussian(sqr_mag): - return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) - - # Helper function for converting a tuple to an array. - def vec(x): - return np.array(x) - - """ - Since a gaussian is unbounded, we need to limit ourselves - to a finite range. - We taper the ends off at the end of that range so they equal zero - while preserving the maximum value of 1 at the mean. - """ - zero_radius = max_radius + 1.0 - gauss_zero = gaussian(zero_radius * zero_radius) - gauss_kernel_scale = 1 / (1 - gauss_zero) - - def gaussian_kernel_func(coordinate): - x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 - x = gaussian(x) - x -= gauss_zero - x *= gauss_kernel_scale - x = max(0.0, x) - return x - - size = max_radius * 2 + 1 - kernel_center = max_radius - kernel = np.zeros((size, size)) - - for index in np.ndindex(kernel.shape): - kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) - - return kernel, kernel_center - diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6d0cf847952..1f451b55370 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -1,4 +1,6 @@ +import numpy as np import gradio as gr +import math from modules.ui_components import InputAccordion import modules.scripts as scripts @@ -101,7 +103,6 @@ def apply_adaptive_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -115,15 +116,15 @@ def apply_adaptive_masks( latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) masks_for_overlay = [] for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% @@ -131,7 +132,7 @@ def apply_adaptive_masks( # converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) + converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. 
* converted_mask converted_mask = converted_mask.astype(np.uint8) @@ -166,7 +167,6 @@ def apply_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -202,6 +202,196 @@ def apply_masks( return masks_for_overlay +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. 
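# [Editor's usage sketch -- not part of the patch.] Once this function is complete,
# the percentile window lets one kernel reproduce several classic filters:

import numpy as np

img_demo = np.random.rand(16, 16)
kernel_demo = np.ones((3, 3))
center_demo = np.array([1, 1])
mean_f = weighted_histogram_filter(img_demo, kernel_demo, center_demo, 0.0, 1.0)
median_f = weighted_histogram_filter(img_demo, kernel_demo, center_demo, 0.5, 0.5)
max_f = weighted_histogram_filter(img_demo, kernel_demo, center_demo, 1.0, 1.0)

# min_width keeps the window non-degenerate, so the "median" and "max" cases are
# narrow weighted averages rather than single samples.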
+ for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + # ------------------- Constants ------------------- @@ -232,6 +422,9 @@ def apply_masks( "inpaint_detail_preservation") +# ----- + + class Script(scripts.Script): def __init__(self): From 0ef4a4cb2365051b1e308f0136a0d8c01d071569 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:54:26 -0700 Subject: [PATCH 078/449] Fixed error that occurs when using vanilla samplers (somehow). 
--- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f8d85bdf527..bea01ec6884 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -937,8 +937,8 @@ def infotext(index=0, use_main_prompt=False): p.scripts.postprocess_image(p, pp) image = pp.image - mask_for_overlay = p.mask_for_overlay - overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + mask_for_overlay = getattr(p, "mask_for_overlay", None) + overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) From f284ae23bcdfa212cf4763659c06e124ec5b1456 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:19:35 -0700 Subject: [PATCH 079/449] Added parameters for the composite stage, fixed batched generation. --- scripts/soft_inpainting.py | 198 +++++++++++++++++++++++++++++-------- 1 file changed, 155 insertions(+), 43 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1f451b55370..1b21aee9db9 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -6,22 +6,34 @@ class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): self.mask_blend_power = mask_blend_power self.mask_blend_scale = mask_blend_scale self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast def add_generation_params(self, dest): dest[enabled_gen_param_label] = True dest[gen_param_labels.mask_blend_power] = self.mask_blend_power dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast # ------------------- Methods ------------------- -def latent_blend(soft_inpainting, a, b, t): +def latent_blend(settings, a, b, t): """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t): # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). 
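# [Editor's worked numbers -- not part of the patch.] The powered mean computed
# below approaches max(|a|, |b|) as inpaint_detail_preservation grows; e.g. with
# |a| = 1, |b| = 3 and t = 0.5:
#
#     d = 1  -> ((1 + 3**1) / 2) ** (1/1)   = 2.00   (plain average)
#     d = 4  -> ((1 + 3**4) / 2) ** (1/4)   ~ 2.53
#     d = 32 -> ((1 + 3**32) / 2) ** (1/32) ~ 2.94   (close to the max)
#
# matching the 1-32 range the UI slider exposes for this setting.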
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + settings.inpaint_detail_preservation) * one_minus_t3 b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * t3 + settings.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. @@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t): return image_interp_scaled -def get_modified_nmask(soft_inpainting, nmask, sigma): +def get_modified_nmask(settings, nmask, sigma): """ Converts a negative mask representing the transparency of the original latent vectors being overlayed to a mask that is scaled according to the denoising strength for this step. @@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) def apply_adaptive_masks( + settings:SoftInpaintingSettings, + nmask, latent_orig, latent_processed, overlay_images, @@ -108,11 +122,13 @@ def apply_adaptive_masks( from PIL import Image, ImageOps, ImageFilter # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() + latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() + mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1-settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -128,10 +144,10 @@ def apply_adaptive_masks( percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance - converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. 
* converted_mask @@ -161,7 +177,7 @@ def apply_adaptive_masks( def apply_masks( - soft_inpainting, + settings, nmask, overlay_images, width, height, @@ -172,7 +188,7 @@ def apply_masks( from PIL import Image, ImageOps, ImageFilter converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) converted_mask = 255. * converted_mask converted_mask = converted_mask.cpu().numpy().astype(np.uint8) converted_mask = Image.fromarray(converted_mask) @@ -395,7 +411,7 @@ def gaussian_kernel_func(coordinate): # ------------------- Constants ------------------- -default = SoftInpaintingSettings(1, 0.5, 4) +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) enabled_ui_label = "Soft inpainting" enabled_gen_param_label = "Soft inpainting enabled" @@ -404,25 +420,37 @@ def gaussian_kernel_func(coordinate): ui_labels = SoftInpaintingSettings( "Schedule bias", "Preservation strength", - "Transition contrast boost") + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") ui_info = SoftInpaintingSettings( "Shifts when preservation of original content occurs during denoising.", "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") gen_param_labels = SoftInpaintingSettings( "Soft inpainting schedule bias", "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") el_ids = SoftInpaintingSettings( "mask_blend_power", "mask_blend_scale", - "inpaint_detail_preservation") + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") -# ----- +# ------------------- Script ------------------- class Script(scripts.Script): @@ -449,28 +477,62 @@ def ui(self, is_img2img): **High _Mask blur_** values are recommended! 
""") - result = SoftInpaintingSettings( + power = \ gr.Slider(label=ui_labels.mask_blend_power, info=ui_info.mask_blend_power, minimum=0, maximum=8, step=0.1, value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), + elem_id=el_ids.mask_blend_power) + scale = \ gr.Slider(label=ui_labels.mask_blend_scale, info=ui_info.mask_blend_scale, minimum=0, maximum=8, step=0.05, value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), + elem_id=el_ids.mask_blend_scale) + detail = \ gr.Slider(label=ui_labels.inpaint_detail_preservation, info=ui_info.inpaint_detail_preservation, minimum=1, maximum=32, step=0.5, value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) with gr.Accordion("Help", open=False): gr.Markdown( @@ -507,41 +569,86 @@ def ui(self, is_img2img): - **High values**: Stronger contrast, may over-saturate colors. """) + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. 
+ """) + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] self.paste_field_names = [] for _, field_name in self.infotext_fields: self.paste_field_names.append(field_name) return [soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation] - - def process(self, p, enabled, power, scale, detail_preservation): + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return # Shut off the rounding it normally does. p.mask_round = False - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - if mba.sigma is None: + if mba.is_final_blend: mba.blended_latent = mba.current_latent return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # todo: Why is sigma 2D? Both values are the same. 
mba.blended_latent = latent_blend(settings, @@ -549,11 +656,11 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) from modules import images from modules.shared import opts @@ -570,15 +677,20 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta self.overlay_images.append(image.convert('RGBA')) + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + if getattr(ps.samples, 'already_decoded', False): - self.masks_for_overlay = apply_masks(soft_inpainting=settings, + self.masks_for_overlay = apply_masks(settings=settings, nmask=p.nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: - self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=p.nmask, + latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, width=p.width, @@ -586,7 +698,7 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta paste_to=p.paste_to) - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From fc3e246c0f4f292c33b181a902cd934629ff0d7a Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:28:38 -0700 Subject: [PATCH 080/449] Fixed complaint about whitespace, updated help section for a parameter. --- scripts/soft_inpainting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1b21aee9db9..6fb5cfbd074 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -572,7 +572,7 @@ def ui(self, is_img2img): gr.Markdown( """ ## Pixel Composite Settings - + Masks are generated based on how much a part of the image changed after denoising. These masks are used to blend the original and final images together. If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. @@ -602,10 +602,10 @@ def ui(self, is_img2img): f""" ### {ui_labels.composite_difference_contrast} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the contrast between the opacity of the original and inpainted content. - - **Low values**: Two images patches must be almost the same in order to retain original pixels. - - **High values**: Two images patches can be very different and still retain original pixels. + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. 
""") self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), From 659f62e120b210e3043712ff928e8b7b6cd6cf61 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 21:39:54 -0700 Subject: [PATCH 081/449] Fixed grammar error. --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6fb5cfbd074..51f9ca2fed2 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -592,7 +592,7 @@ def ui(self, is_img2img): f""" ### {ui_labels.composite_difference_threshold} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the difference at which the original pixels will have less than 50% opacity. - **Low values**: Two images patches must be almost the same in order to retain original pixels. - **High values**: Two images patches can be very different and still retain original pixels. From 16bdcce92d5b482d50cdc32a8f308040d320b6c9 Mon Sep 17 00:00:00 2001 From: Rene Kroon Date: Fri, 8 Dec 2023 21:19:29 +0100 Subject: [PATCH 082/449] #13354: solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706adb..629bf85376d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) From b2414476ef164ba55cff2508c58b73d23bbc3000 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:32:41 -0700 Subject: [PATCH 083/449] soft_inpainting now appears in the "inpaint" section, and will not activate unless inpainting is activated. --- scripts/soft_inpainting.py | 43 ++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 51f9ca2fed2..f10a1e5627c 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -32,6 +32,19 @@ def add_generation_params(self, dest): # ------------------- Methods ------------------- +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + def latent_blend(settings, a, b, t): """ @@ -454,8 +467,8 @@ def gaussian_kernel_func(coordinate): class Script(scripts.Script): - def __init__(self): + self.section = "inpaint" self.masks_for_overlay = None self.overlay_images = None @@ -632,6 +645,9 @@ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_t if not enabled: return + if not processing_uses_inpainting(p): + return + # Shut off the rounding it normally does. 
p.mask_round = False @@ -644,6 +660,9 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de if not enabled: return + if not processing_uses_inpainting(p): + return + if mba.is_final_blend: mba.blended_latent = mba.current_latent return @@ -660,11 +679,18 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return from modules import images from modules.shared import opts + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + # since the original code puts holes in the existing overlay images, # we have to rebuild them. self.overlay_images = [] @@ -682,14 +708,14 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta if getattr(ps.samples, 'already_decoded', False): self.masks_for_overlay = apply_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: self.masks_for_overlay = apply_adaptive_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, @@ -702,5 +728,14 @@ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, e if not enabled: return + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] ppmo.overlay_image = self.overlay_images[ppmo.index] From f1ff932cafa2bf34fa35f41072f21a8ea5474d84 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:33:11 -0700 Subject: [PATCH 084/449] Formatted soft_inpainting. --- scripts/soft_inpainting.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index f10a1e5627c..d9024344251 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -122,7 +122,7 @@ def get_modified_nmask(settings, nmask, sigma): def apply_adaptive_masks( - settings:SoftInpaintingSettings, + settings: SoftInpaintingSettings, nmask, latent_orig, latent_processed, @@ -137,10 +137,10 @@ def apply_adaptive_masks( # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. 
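# [Editor's note -- not part of the patch.] The mask_scalar computed below is an
# odds-style transform of the blended mask value m, mixed with the new
# "mask influence" setting k:
#
#     s = 0.5 * (1 - k) + m * k
#     mask_scalar = s / (1.00001 - s)   # s = 0.5 -> ~1; s near 1 -> very large
#
# Regions where s approaches 1 therefore get a much larger difference threshold
# when the adaptive masks are computed.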
latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) - mask_scalar = (0.5 * (1-settings.composite_mask_influence) + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + mask_scalar * settings.composite_mask_influence) - mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -152,9 +152,9 @@ def apply_adaptive_masks( for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) + percentile_min=0.9, percentile_max=1, min_width=1) converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) + percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% half_weighted_distance = settings.composite_difference_threshold * mask_scalar @@ -276,6 +276,7 @@ class WeightedElement: An element of the histogram, its weight and bounds. """ + def __init__(self, value, weight): self.value: float = value self.weight: float = weight @@ -355,6 +356,7 @@ def sort_key(x: WeightedElement): return img_out + def smoothstep(x): """ The smoothstep function, input should be clamped to 0-1 range. @@ -362,6 +364,7 @@ def smoothstep(x): """ return x * x * (3 - 2 * x) + def smootherstep(x): """ The smootherstep function, input should be clamped to 0-1 range. @@ -385,6 +388,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): Returns: (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. 
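# [Editor's worked numbers -- not part of the patch.] With the values used by
# apply_adaptive_masks() (stddev_radius=1.5, max_radius=2), the taper below is:
#
#     gauss(r*r)  = exp(-r*r / 1.5**2)
#     gauss_zero  = gauss(3**2) ~ 0.0183    # at zero_radius = max_radius + 1
#     weight(r)   = max(0, (gauss(r*r) - gauss_zero) / (1 - gauss_zero))
#
# giving weight(0) = 1, weight(2) ~ 0.15 and weight(3) = 0: a 5x5 kernel whose
# corners fade smoothly to zero instead of being clipped.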
def gaussian(sqr_mag): return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) @@ -656,7 +660,8 @@ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_t # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -675,7 +680,8 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -723,8 +729,8 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta height=p.height, paste_to=p.paste_to) - - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From 59429793440fb3cb1624ddcc702c6f9807373203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:09:45 +0800 Subject: [PATCH 085/449] Fix ControlNet --- modules/xpu_specific.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index ec1ad100aae..9bb0a56155a 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -51,3 +51,9 @@ def torch_xpu_gc(): CondFunc('torch.bmm', lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file From 049d5642e58d572ee8657ac754e72d019eea0e6c Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:11:26 +0800 Subject: [PATCH 086/449] Fix format --- modules/xpu_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 9bb0a56155a..d8da94a0efd 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -56,4 +56,4 @@ def torch_xpu_gc(): lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), 
attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) From 39ec4cfea9040bc94e639eb4aa8ab8ed37a68f01 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 19:12:59 +0600 Subject: [PATCH 087/449] Re-add setting lost as part of e294e46 --- modules/shared_options.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..acb6e2d4819 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,6 +256,7 @@ "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) From 9c201550ddae0b33367adfb99bcbb57ba9b207a9 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 21:04:45 +0600 Subject: [PATCH 088/449] Add keyboard shortcuts for generation (Removed Alt+Enter) Ctrl+Enter to start/restart generation (New) Alt/Option+Enter to skip generation (New) Ctrl+Alt/Option+Enter to interrupt generation --- modules/ui_toprow.py | 4 ++-- script.js | 23 +++++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 88838f97749..c3865e3d91d 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ def create_inline_toprow_image(self): def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff 
--git a/script.js b/script.js index c0e678ea702..69598f45454 100644 --- a/script.js +++ b/script.js @@ -121,16 +121,21 @@ document.addEventListener("DOMContentLoaded", function() { }); /** - * Add a ctrl+enter as a shortcut to start a generation + * Add keyboard shortcuts: + * Ctrl+Enter to start/restart a generation + * Alt/Option+Enter to skip a generation + * Alt/Option+Ctrl+Enter to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + const isCtrlKey = e.metaKey || e.ctrlKey; + const isAltKey = e.altKey; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isEnter && isModifierKey) { + if (isCtrlKey && isEnter && !isAltKey) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -150,6 +155,16 @@ document.addEventListener('keydown', function(e) { } e.preventDefault(); } + + if (isAltKey && isEnter && !isCtrlKey) { + skipButton.click(); + e.preventDefault(); + } + + if (isAltKey && isCtrlKey && isEnter) { + interruptButton.click(); + e.preventDefault(); + } }); /** From 1a79a5049bdfef285235e83f37b201e39dd54f81 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 22:35:31 +0600 Subject: [PATCH 089/449] Assign id for "extra_options". Replace numeric field with slider in Settings. --- .../extra-options-section/scripts/extra_options_section.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df62548..b9867fe63b8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ def ui(self, is_img2img): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ def before_process(self, p, *args): """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear 
in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 5381405eaa1e809e5cfb97522bd4c19d3c946079 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:09:28 -0500 Subject: [PATCH 090/449] re-derive sqrt alpha bar and sqrt one minus alphabar This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent. --- modules/sd_samplers_timesteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c2b..c4bd5c12775 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -36,7 +36,7 @@ def __init__(self, model, *args, **kwargs): self.inner_model = model def predict_eps_from_z_and_v(self, x_t, t, v): - return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t def forward(self, input, timesteps, **kwargs): model_output = self.inner_model.apply_model(input, timesteps, **kwargs) From 23a0e60b9bf90a80f8af9732cc6495fbfce2ea21 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 14:03:41 +0900 Subject: [PATCH 091/449] fix save styles --- modules/styles.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 7fb6c2e1177..07588945eba 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -155,10 +155,8 @@ def load_from_csv(self, path: str): row["name"], prompt, negative_prompt, path ) - def get_style_paths(self) -> list(): - """ - Returns a list of all distinct paths, including the default path, of - files that styles are loaded from.""" + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: @@ -172,9 +170,9 @@ def get_style_paths(self) -> list(): style_paths.add(style.path) # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths.discard("do_not_save") - return list(style_paths) + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -196,20 +194,7 @@ def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility _ = path - # Update any styles without a path to the default path - for style in list(self.styles.values()): - if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) - - # Create a list of all distinct paths, including the default path - style_paths = set() - style_paths.add(self.default_path) - for _, style in self.styles.items(): - if style.path: - 
style_paths.add(style.path) - - # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths = self.get_style_paths() csv_names = [os.path.split(path)[1].lower() for path in style_paths] From 8b74389e76a7678e972583ef16100e90e1519e55 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 15:48:16 +0900 Subject: [PATCH 092/449] fix styles.csv filename --- modules/styles.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 07588945eba..81d9800d184 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -98,10 +98,8 @@ def __init__(self, path: str): self.path = path folder, file = os.path.split(self.path) - self.default_file = file.split("*")[0] + ".csv" - if self.default_file == ".csv": - self.default_file = "styles.csv" - self.default_path = os.path.join(folder, self.default_file) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] From 6b8143a84e112f029ee1868b6ab98b1d2c773ead Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 15:35:06 +0600 Subject: [PATCH 093/449] Number of columns slider: max count set to 20, add description info --- .../extra-options-section/scripts/extra_options_section.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index b9867fe63b8..ac2c3de4643 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -71,7 +71,7 @@ def before_process(self, p, *args): """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 1d42babd324b933bae317cb427fe0513138954f4 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 16:28:56 +0600 Subject: [PATCH 094/449] Replace Ctrl+Alt+Enter with Esc --- modules/ui_toprow.py | 4 ++-- script.js | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index c3865e3d91d..9caf8faa2de 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ def create_inline_toprow_image(self): def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", 
elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff --git a/script.js b/script.js index 69598f45454..44950090aba 100644 --- a/script.js +++ b/script.js @@ -124,18 +124,19 @@ document.addEventListener("DOMContentLoaded", function() { * Add keyboard shortcuts: * Ctrl+Enter to start/restart a generation * Alt/Option+Enter to skip a generation - * Alt/Option+Ctrl+Enter to interrupt a generation + * Esc to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; const isCtrlKey = e.metaKey || e.ctrlKey; const isAltKey = e.altKey; + const isEsc = e.key === 'Escape'; const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isCtrlKey && isEnter && !isAltKey) { + if (isCtrlKey && isEnter) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -156,14 +157,18 @@ document.addEventListener('keydown', function(e) { e.preventDefault(); } - if (isAltKey && isEnter && !isCtrlKey) { + if (isAltKey && isEnter) { skipButton.click(); e.preventDefault(); } - if (isAltKey && isCtrlKey && isEnter) { - interruptButton.click(); - e.preventDefault(); + if (isEsc) { + if (!globalPopup || globalPopup.style.display === "none") { + interruptButton.click(); + e.preventDefault(); + } else { + closePopup(); + } } }); From cee1a4065162982e18f32761259d9107538c2d93 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 17:06:12 +0600 Subject: [PATCH 095/449] Fix linter issues --- script.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/script.js b/script.js index 44950090aba..354154b01d6 100644 --- a/script.js +++ b/script.js @@ -163,11 +163,13 @@ document.addEventListener('keydown', function(e) { } if (isEsc) { + const globalPopup = document.querySelector('.global-popup'); if (!globalPopup || globalPopup.style.display === "none") { interruptButton.click(); e.preventDefault(); } else { - closePopup(); + if (!globalPopup) return; + 
globalPopup.style.display = "none"; } } }); From 6513470f0db1aed1b0a5200634e8e02f7c05e932 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Mon, 11 Dec 2023 18:06:08 +0600 Subject: [PATCH 096/449] Remove unnecessary 'else', add 'lightboxModal' check --- script.js | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/script.js b/script.js index 354154b01d6..be1bc317e2b 100644 --- a/script.js +++ b/script.js @@ -164,12 +164,11 @@ document.addEventListener('keydown', function(e) { if (isEsc) { const globalPopup = document.querySelector('.global-popup'); - if (!globalPopup || globalPopup.style.display === "none") { + const lightboxModal = document.querySelector('#lightboxModal'); + if (!globalPopup || globalPopup.style.display === 'none') { + if (document.activeElement === lightboxModal) return; interruptButton.click(); e.preventDefault(); - } else { - if (!globalPopup) return; - globalPopup.style.display = "none"; } } }); From cc41cc4349514bbfeb9f37445c931a050b076bd6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 13 Dec 2023 02:06:56 +0900 Subject: [PATCH 097/449] on mouse hover show / hide modal image viewer icons --- style.css | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/style.css b/style.css index ee39a57b73e..ec449bdea44 100644 --- a/style.css +++ b/style.css @@ -749,6 +749,22 @@ table.popup-table .link{ display: none; } +@media (pointer: fine) { + .modalPrev:hover, + .modalNext:hover, + .modalControls:hover ~ .modalPrev, + .modalControls:hover ~ .modalNext, + .modalControls:hover .cursor { + opacity: 1; + } + + .modalPrev, + .modalNext, + .modalControls .cursor { + opacity: 0; + } +} + /* context menu (ie for the generate button) */ #context-menu{ From bda86f0fd9653657c146f7c1128f92771d16ad4e Mon Sep 17 00:00:00 2001 From: Hina <102651522+HinaHyugaHime@users.noreply.github.com> Date: Tue, 12 Dec 2023 19:39:14 -0600 Subject: [PATCH 098/449] Update webui.sh --- webui.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed74..046ecf9d001 100755 --- a/webui.sh +++ b/webui.sh @@ -131,7 +131,7 @@ case "$gpu_info" in if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] then # Navi users will still use torch 1.13 because 2.0 does not seem to work. 
- export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" else printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" exit 1 @@ -141,9 +141,8 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" - # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5) - # so switch to nightly rocm5.6 without explicit versions this time + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" + ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" From 89cfbc3bbe401fe1655afb07edbae34ec6af7aca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 13 Dec 2023 12:22:13 +0200 Subject: [PATCH 099/449] Allow pasting in WIDTHxHEIGHT strings into the width/height fields --- javascript/ui.js | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e3d9..18c9f891afc 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: [PATCH 100/449] Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c3781187a..44465f7aa26 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, 
dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ def calc_updown_kb(self, orig_weight, multiplier): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: [PATCH 101/449] Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7aa26..e3ae61a22b4 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ def calc_updown(self, orig_weight): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. 
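# (Context for PATCH 100 above -- the kohya branch is the Cayley transform used by
#  (C)OFT, assuming the standard construction rather than anything stated in the
#  commit itself: block_Q = B - B.T is skew-symmetric, so R = (I + Q) @ inverse(I - Q)
#  is exactly orthogonal, since R.T @ R = inverse(I + Q) @ (I - Q) @ (I + Q) @ inverse(I - Q) = I,
#  using that (I - Q) and (I + Q) are both polynomials in Q and therefore commute.
#  The clamp of norm(Q) against self.constraint is COFT's trust region, which keeps
#  R close to the identity.)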
- return updown * self.multiplier(), ex_bias From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: [PATCH 102/449] remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a22b4..ff4eb59b1c2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ def calc_updown(self, orig_weight): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: [PATCH 103/449] better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b1c2..fa647020f0a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) From 3c0c27757944ae17a7fa4c2323ee9ae2d434dbce Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 14 Dec 2023 19:36:17 +0900 Subject: [PATCH 104/449] default False js_live_preview_in_modal_lightbox --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 41097d8e6c9..d2e86ff10b3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -331,7 +331,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), - "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler 
parameters", "sd"), { From 0c5427960b3a4ffe6d673c28e8e135b26f015717 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:11:59 +0900 Subject: [PATCH 105/449] make modal toolbar and icon opacity adjustable --- modules/shared_gradio_themes.py | 4 ++++ modules/shared_options.py | 2 ++ style.css | 4 ++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py index 822db0a951d..b6dc31450bc 100644 --- a/modules/shared_gradio_themes.py +++ b/modules/shared_gradio_themes.py @@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None): except Exception as e: errors.display(e, "changing gradio theme") shared.gradio_theme = gr.themes.Default(**default_theme_args) + + # append additional values gradio_theme + shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity + shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..86e7636cc58 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -266,6 +266,8 @@ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), + "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), })) diff --git a/style.css b/style.css index ec449bdea44..6d4c8a0d505 100644 --- a/style.css +++ b/style.css @@ -679,7 +679,7 @@ table.popup-table .link{ transition: 0.2s ease background-color; } .modalControls:hover { - background-color:rgba(0,0,0,0.9); + background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity)); } .modalClose { margin-left: auto; @@ -761,7 +761,7 @@ table.popup-table .link{ .modalPrev, .modalNext, .modalControls .cursor { - opacity: 0; + opacity: var(--sd-webui-modal-lightbox-icon-opacity); } } From 1242ba08e19f3d317bdc5924db2b73d0c9569a7f Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 16:57:17 +0800 Subject: [PATCH 106/449] add allow specify the task id and get the location of task in the queue of pending task --- modules/api/api.py | 20 ++++++++++++++++++-- modules/api/models.py | 2 ++ modules/processing.py | 2 ++ modules/progress.py | 21 +++++++++++++++++++-- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index e6edffe7144..5d000ae8bf6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,7 +33,7 @@ import piexif import piexif.helper from contextlib import closing - +from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task def 
script_name_to_index(name, scripts): try: @@ -337,6 +337,10 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): + task_id = create_task_id("text2img") + if txt2imgreq.force_task_id != None: + task_id = txt2imgreq.force_task_id + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -363,6 +367,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.is_api = True @@ -372,12 +378,14 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): try: shared.state.begin(job="scripts_txt2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -387,6 +395,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): + task_id = create_task_id("img2img") + if img2imgreq.force_task_id != None: + task_id = img2imgreq.force_task_id + init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -423,6 +435,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -433,12 +447,14 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): try: shared.state.begin(job="scripts_img2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -514,7 +530,7 @@ def progressapi(self, req: models.ProgressRequest = Depends()): if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task) def interrogateapi(self, interrogatereq: models.InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 6a574771c33..7b7f1773866 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -109,6 
+109,7 @@ def generate_model(self): {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ def generate_model(self): {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..657cacfcace 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_sampler_name: str = None hr_prompt: str = '' hr_negative_prompt: str = '' + force_task_id: str = None cached_hr_uc = [None, None] cached_hr_c = [None, None] @@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None + force_task_id: string = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 69921de7281..553866db113 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -8,10 +8,13 @@ from modules.shared import opts import modules.shared as shared - +from collections import OrderedDict +import string +import random +from typing import List current_task = None -pending_tasks = {} +pending_tasks = OrderedDict() finished_tasks = [] recorded_results = [] recorded_results_limit = 2 @@ -34,6 +37,11 @@ def finish_task(id_task): if len(finished_tasks) > 16: finished_tasks.pop(0) +def create_task_id(task_type): + N = 7 + res = ''.join(random.choices(string.ascii_uppercase + + string.digits, k=N)) + return f"task({task_type}-{res})" def record_results(id_task, res): recorded_results.append((id_task, res)) @@ -44,6 +52,9 @@ def record_results(id_task, res): def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +class PendingTasksResponse(BaseModel): + size: int = Field(title="Pending task size") + tasks: List[str] = Field(title="Pending task ids") class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") @@ -63,8 +74,14 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): + app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) +def get_pending_tasks(): + pending_tasks_ids = [x for x in pending_tasks] + pending_len = len(pending_tasks_ids) + return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) + def progressapi(req: ProgressRequest): active = req.id_task == current_task From d859de37d9ec10cb6c804226328a11c87c444852 Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:48:20 +0800 Subject: [PATCH 107/449] fix ruff lint errors flagged by GitHub CI --- modules/api/api.py | 4 ++-- modules/processing.py | 2 +- modules/progress.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5d000ae8bf6..1f464806727 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -338,7 +338,7 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel def text2imgapi(self, 
txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") - if txt2imgreq.force_task_id != None: + if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id script_runner = scripts.scripts_txt2img @@ -396,7 +396,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): task_id = create_task_id("img2img") - if img2imgreq.force_task_id != None: + if img2imgreq.force_task_id is None: task_id = img2imgreq.force_task_id init_images = img2imgreq.init_images diff --git a/modules/processing.py b/modules/processing.py index 657cacfcace..5added65ed0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1359,7 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None - force_task_id: string = None + force_task_id: str = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 553866db113..6946fb1bc2a 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -78,7 +78,7 @@ def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) def get_pending_tasks(): - pending_tasks_ids = [x for x in pending_tasks] + pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) From da45e73b4ffde2e2a85b64a3e3258a0625bd307e Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:57:58 +0800 Subject: [PATCH 108/449] fix ruff lint errors flagged by GitHub CI --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1f464806727..9fac7e604f3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id - + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From 6d7e57ba6a4d686d515518b5f90e91b32fa01caf Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 18:03:14 +0800 Subject: [PATCH 109/449] fix ruff lint errors flagged by GitHub CI --- modules/api/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9fac7e604f3..8d8e70a46a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From ea272152e0b50dbb2bd675ec020607f3d50c37d0 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 15:08:08 +0800 Subject: [PATCH 110/449] Add FP8 settings into PNG info --- modules/generation_parameters_copypaste.py | 6 ++++++ modules/processing.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4efe53e0c48..dbffe494998 100644 --- a/modules/generation_parameters_copypaste.py 
+++ b/modules/generation_parameters_copypaste.py @@ -314,6 +314,12 @@ def parse_generation_parameters(x: str): if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/processing.py b/modules/processing.py index bea01ec6884..179f2c0fccf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, + "FP8 weight": opts.fp8_storage if devices.fp8 else None, + "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), From 7745db6fc02faf19117838c1e7bcc8a60b5f5e90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:15:08 +0300 Subject: [PATCH 111/449] torch 2.1.2 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index c534a5d66fe..48aa13a1728 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.1.0" - expected_xformers_version = "0.0.22.post7" + expected_torch_version = "2.1.2" + expected_xformers_version = "0.0.23.post1" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 2c54e2a02f4..dabef0f5335 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -315,7 +315,7 @@ def requirements_met(requirements_file): def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") if args.use_ipex: if platform.system() == "Windows": # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main @@ -338,7 +338,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', 
"https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From cd9ce2e31c4a264d7cde17c54d24f8ad94c9cf2c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:40:20 +0300 Subject: [PATCH 112/449] Use radio for FP8 mode selection --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index d470eb8ffd7..fa542ba8cd6 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -206,7 +206,7 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), - "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) From 93eae69895c34361a71dbed17348bcfd132fbc6a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:00:42 +0300 Subject: [PATCH 113/449] move soft inpainting to a built-in extension --- .../soft-inpainting/scripts}/soft_inpainting.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {scripts => extensions-builtin/soft-inpainting/scripts}/soft_inpainting.py (100%) diff --git a/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py similarity index 100% rename from scripts/soft_inpainting.py rename to extensions-builtin/soft-inpainting/scripts/soft_inpainting.py From 86b3aa94e2d36a4f9d5ef1bb7c6ec995ff8eb517 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:04:59 +0300 Subject: [PATCH 114/449] rename pending tasks api endpoint to be more in line with others --- modules/progress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/progress.py b/modules/progress.py index 6946fb1bc2a..85255e821f5 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -74,9 +74,10 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): - app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) + app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + def get_pending_tasks(): pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) From a97832033427096072d5ea914adac3662cda4fd1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 19:39:43 
+0800 Subject: [PATCH 115/449] Let fp8-related settings invalidate cond_cache --- modules/processing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dd97b4eee80..9351e3fb019 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -431,6 +431,8 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps opts.sdxl_crop_top, self.width, self.height, + opts.fp8_storage, + opts.cache_fp16_weight, ) def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): From 98c5fa92015837706adfd9975d5f345ab74f1c99 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 16 Dec 2023 22:14:39 +0900 Subject: [PATCH 116/449] fix extras caption BLIP #14328 --- scripts/postprocessing_caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 243e3ad9c62..9482a03ca8b 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -25,6 +25,6 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option) captions.append(deepbooru.model.tag(pp.image)) if "BLIP" in option: - captions.append(shared.interrogator.generate_caption(pp.image)) + captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) pp.caption = ", ".join([x for x in captions if x]) From de03882d6ca56bc81058f5120f028678a6a54aaa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 08:55:35 +0300 Subject: [PATCH 117/449] make task ids for API work without force_task_id --- modules/api/api.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9637cb81b77..7154c9d5b7b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -336,9 +336,8 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): - task_id = create_task_id("text2img") - if txt2imgreq.force_task_id is None: - task_id = txt2imgreq.force_task_id + task_id = txt2imgreq.force_task_id or create_task_id("txt2img") + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -393,9 +392,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): - task_id = create_task_id("img2img") - if img2imgreq.force_task_id is None: - task_id = img2imgreq.force_task_id + task_id = img2imgreq.force_task_id or create_task_id("img2img") init_images = img2imgreq.init_images if init_images is None: From 10945aa41a158ee03727c5ea77d4ffff6b5370f0 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:27:41 +0900 Subject: [PATCH 118/449] only rewrite ui-config when there is a change, and fix a typo --- modules/ui.py | 4 +++- modules/ui_loadsave.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index d80486dd4ac..f02b55116f5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1086,6 +1086,7 @@ def get_textual_inversion_template_names(): ) loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) 
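# (Aside on PATCHes 106, 114 and 117 above, which together let an API client pick
#  its own task id and then track it. A rough usage sketch follows: the endpoint
#  paths and field names come from the diffs, while the base URL, prompt payload
#  and threading wrapper are illustrative assumptions.)

    import threading

    import requests

    BASE = "http://127.0.0.1:7860"  # assumed: a local webui launched with --api
    TASK_ID = "task(MYJOB01)"       # caller-chosen; autogenerated when omitted

    def submit():
        # /sdapi/v1/txt2img blocks until generation finishes, so run it off-thread
        requests.post(f"{BASE}/sdapi/v1/txt2img",
                      json={"prompt": "a cat", "steps": 20, "force_task_id": TASK_ID})

    threading.Thread(target=submit).start()

    # ids come back in submission order (pending_tasks is an OrderedDict),
    # so the position of TASK_ID in "tasks" is its place in the queue
    queue = requests.get(f"{BASE}/internal/pending-tasks").json()
    print(queue["size"], queue["tasks"])

    # per-task progress for that id
    print(requests.post(f"{BASE}/internal/progress", json={"id_task": TASK_ID}).json())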
ui_settings_from_file = loadsave.ui_settings.copy() settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) @@ -1146,7 +1147,8 @@ def get_textual_inversion_template_names(): modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) - loadsave.dump_defaults() + if ui_settings_from_file != loadsave.ui_settings: + loadsave.dump_defaults() demo.ui_loadsave = loadsave return demo diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 7826786ccde..693ff75c5d8 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -144,7 +144,7 @@ def write_to_file(self, current_ui_settings): json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): - """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" + """saves default values to a file unless the file is present and there was an error loading default values at start""" if self.error_loading and os.path.exists(self.filename): return From e4b4a9c4acf0ca375a8603f7f52fde8467b2d266 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 18:00:01 +0800 Subject: [PATCH 119/449] [IPEX] Slice SDPA into smaller chunks --- modules/xpu_specific.py | 66 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d8da94a0efd..0ebdd5964f1 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -27,6 +27,68 @@ def torch_xpu_gc(): has_xpu = check_for_xpu() + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. 
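# (Worked example with assumed hardware numbers, to make the heuristic concrete:
#  a 16 GiB Arc A770 yields a limit of 16 GiB // 8 = 2 GiB, so an fp16 call with
#  L = S = 4096 allows 2**31 // (4096 * 4096 * 2) = 64 batch rows per chunk, and
#  a flattened batch of 192 rows becomes (192 + 63) // 64 = 3 sequential SDPA
#  calls instead of a single ~6 GiB attention allocation.)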
+ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + if has_xpu: # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', @@ -55,5 +117,5 @@ def torch_xpu_gc(): lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) From f586f4973a0f715e30b42242bb0e6b3f88c37d90 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 19:44:52 +0800 Subject: [PATCH 120/449] Fix device id --- modules/xpu_specific.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd5964f1..f7687a66cc2 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -33,7 +33,7 @@ def torch_xpu_gc(): # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. 
-ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func( From fe4d084390d35de49baf83c3319a72d71f540aee Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:50:00 +0500 Subject: [PATCH 121/449] Update webui.py: respect --server-name in API-only mode, fixing IPv6 binding when no listen argument is passed --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 9ed20b30672..3b587dc4150 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 0d5941edbc4c602b760d4200bd76e044c65a0e40 Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:50:38 +0500 Subject: [PATCH 122/449] Update webui.py Co-authored-by: Aarni Koskela --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3b587dc4150..3f5b67a3060 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), + server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 3e068de0dcec811d515402fc184f70709a785e4f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:48:49 +0900 Subject: [PATCH 123/449] reorder training preprocessing modules in extras tab using the order from before the rework 11d23e8ca55c097ecfa255a05b63f194e25f08be --- scripts/postprocessing_caption.py | 2 +- scripts/postprocessing_create_flipped_copies.py | 2 +- scripts/postprocessing_focal_crop.py | 2 +- scripts/processing_autosized_crop.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 9482a03ca8b..5592a89870e 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -4,7 +4,7 @@ class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): 
name = "Caption" - order = 4000 + order = 4040 def ui(self): with ui_components.InputAccordion(False, label="Caption") as enable: diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py index 3425571dc3b..b673003b6ea 100644 --- a/scripts/postprocessing_create_flipped_copies.py +++ b/scripts/postprocessing_create_flipped_copies.py @@ -6,7 +6,7 @@ class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): name = "Create flipped copies" - order = 4000 + order = 4030 def ui(self): with ui_components.InputAccordion(False, label="Create flipped copies") as enable: diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py index d3baf29878a..cff1dbc5470 100644 --- a/scripts/postprocessing_focal_crop.py +++ b/scripts/postprocessing_focal_crop.py @@ -7,7 +7,7 @@ class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto focal point crop" - order = 4000 + order = 4010 def ui(self): with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py index c098022645d..7e674989814 100644 --- a/scripts/processing_autosized_crop.py +++ b/scripts/processing_autosized_crop.py @@ -28,7 +28,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto-sized crop" - order = 4000 + order = 4020 def ui(self): with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: From 9feb034e343d6d7ef63395821658fb3774b30a24 Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Thu, 21 Dec 2023 20:15:51 +0800 Subject: [PATCH 124/449] support for sdxl-inpaint model --- configs/sd_xl_inpaint.yaml | 98 +++++++++++++++++++++++++++++++++++++ modules/processing.py | 19 +++++++ modules/sd_models_config.py | 6 ++- modules/sd_models_xl.py | 5 ++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 configs/sd_xl_inpaint.yaml diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 00000000000..3bad372186f --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + 
input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f5b6..159548dba7d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -362,6 +376,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. 
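# ---[ editor's note: illustrative sketch, not part of the patch ]---------
# Both conditioning hunks of patch 124 above key off the first UNet
# convolution: a plain SD/SDXL UNet takes 4 latent channels, while an
# inpainting UNet takes 9 (4 noised latent + 4 masked-image latent + 1 mask).
# A minimal standalone version of that probe; `state_dict` is any mapping of
# names to tensors, and the `is not None` guard anticipates the robustness
# fix that patch 128 adds below:

def unet_expects_inpaint_concat(state_dict) -> bool:
    weight = state_dict.get('diffusion_model.input_blocks.0.0.weight', None)
    return weight is not None and weight.shape[1] == 9
# --------------------------------------------------------------------------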
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index deab2f6e237..b38137eb5a9 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -15,6 +15,7 @@
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
-        return config_sdxl
+        if diffusion_model_input.shape[1] == 9:
+            return config_sdxl_inpainting
+        else:
+            return config_sdxl
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 0112332161f..d8a9a73bcb3 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:


 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    sd = self.model.state_dict()
+    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+    if diffusion_model_input.shape[1] == 9:
+        x = torch.cat([x] + cond['c_concat'], dim=1)
+
     return self.model(x, t, cond)

From de1809bd14450cfc41623b6021c7087fb385ab6f Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 22 Dec 2023 00:08:35 +0900
Subject: [PATCH 125/449] handle axis_type is None

---
 scripts/xyz_grid.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index b2250c04dc0..e5083874af9 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -476,6 +476,8 @@ def fill(axis_type, csv_mode):
         fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])

         def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
+            axis_type = axis_type or 0  # if axis type is None set to 0
+
             choices = self.current_axis_options[axis_type].choices
             has_choices = choices is not None

@@ -526,6 +528,8 @@ def get_dropdown_update_from_params(axis, params):
         return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]

     def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
+        x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0  # if axis type is None set to 0
+
         if not no_fixed_seeds:
             modules.processing.fix_seed(p)

From edfae95d90a49ea95394b772817a59dde4175222 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 23 Dec 2023 01:21:00 +0900
Subject: [PATCH 126/449] prevent crash due to Script __init__ exception

---
 modules/scripts.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/modules/scripts.py b/modules/scripts.py
index b6fcf96e047..3a76691186b 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -566,7 +566,12 @@ def initialize_scripts(self, is_img2img):
         auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

         for script_data in auto_processing_scripts + scripts_data:
-            script = script_data.script_class()
+            try:
+                script = script_data.script_class()
+            except Exception:
+                errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
+                continue
+
             script.filename = script_data.path
             script.is_txt2img = not is_img2img
             script.is_img2img = is_img2img

From 00d4a4d4ac75903d8224e9beb1136584dd66fcd8 Mon Sep 17 00:00:00 2001
From: lanyeeee <1210347077@qq.com>
Date: Tue, 26 Dec 2023 14:46:29 +0800
Subject: [PATCH 127/449] move thread-unsafe code to __init__

---
 modules/api/api.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index 7154c9d5b7b..f0a68c672d6 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -251,6 +251,15 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []

+        script_runner = scripts.scripts_img2img
+        if not script_runner.scripts:
+            script_runner.initialize_scripts(True)
+            ui.create_ui()
+        if not self.default_script_arg_txt2img:
+            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
+        if not self.default_script_arg_img2img:
+            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
+
     def add_api_route(self, path: str, endpoint, **kwargs):
         if shared.cmd_opts.api_auth:
             return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -339,11 +348,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

         script_runner = scripts.scripts_txt2img
-        if not script_runner.scripts:
-            script_runner.initialize_scripts(False)
-            ui.create_ui()
-        if not self.default_script_arg_txt2img:
-            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
         selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

         populate = txt2imgreq.copy(update={  # Override __init__ params
@@ -403,11 +407,6 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
         mask = decode_base64_to_image(mask)

         script_runner = scripts.scripts_img2img
-        if not script_runner.scripts:
-            script_runner.initialize_scripts(True)
-            ui.create_ui()
-        if not self.default_script_arg_img2img:
-            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
         selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

         populate = img2imgreq.copy(update={  # Override __init__ params

From bfe418a58d39c69ca2672e7d8a1fd7ad2b34869b Mon Sep 17 00:00:00 2001
From: wangqyqq
Date: Wed, 27 Dec 2023 10:20:56 +0800
Subject: [PATCH 128/449] add robustness checks for sdxl-inpaint detection

---
 modules/processing.py   | 24 +++++++++++++-----------
 modules/sd_models_xl.py |  5 +++--
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 159548dba7d..c05e608ab99 100644
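# ---[ editor's note: illustrative sketch, not part of the patch ]---------
# Patch 126 above wraps script construction in try/except so a single broken
# extension no longer aborts the whole initialization loop. The pattern,
# reduced to a runnable sketch (`report` stands in for modules.errors.report):

def construct_scripts(script_classes, report=print):
    loaded = []
    for cls in script_classes:
        try:
            loaded.append(cls())
        except Exception as e:
            report(f"failed to initialize {cls.__name__}: {e}")  # log, then keep going
    return loaded
# --------------------------------------------------------------------------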
--- a/modules/processing.py +++ b/modules/processing.py @@ -108,17 +108,18 @@ def txt2img_image_conditioning(sd_model, x, width, height): else: sd = sd_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -378,8 +379,9 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None sd = self.sampler.model_wrap.inner_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. 
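# ---[ editor's note: illustrative sketch, not part of the patch ]---------
# In the txt2img branch above there is no source image, so the code
# fabricates one: an all-0.5 "masked image" is VAE-encoded, then an all-ones
# mask channel is prepended. A torch-only sketch of the padding step; the
# zeros tensor stands in for the encoded grey image:

import torch

def fake_inpaint_conditioning(batch: int, h: int, w: int) -> torch.Tensor:
    encoded = torch.zeros(batch, 4, h, w)  # placeholder for the encoded 0.5-grey image
    # pad=(0, 0, 0, 0, 1, 0) adds one channel before the existing four,
    # filled with 1.0, i.e. the "everything is masked" mask
    return torch.nn.functional.pad(encoded, (0, 0, 0, 0, 1, 0), value=1.0)

assert fake_inpaint_conditioning(2, 64, 64).shape == (2, 5, 64, 64)
# --------------------------------------------------------------------------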
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index d8a9a73bcb3..162d0fee892 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -36,8 +36,9 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
     sd = self.model.state_dict()
     diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
-    if diffusion_model_input.shape[1] == 9:
-        x = torch.cat([x] + cond['c_concat'], dim=1)
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            x = torch.cat([x] + cond['c_concat'], dim=1)

     return self.model(x, t, cond)

From de04573438bc111f137359b8f4998780bf315275 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Thu, 28 Dec 2023 06:22:51 +0900
Subject: [PATCH 129/449] create utility truncate_path

util.truncate_path(target_path, base_path) returns target_path relative to
base_path if target_path is a subpath of base_path; otherwise it returns the
absolute path.

---
 modules/util.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/modules/util.py b/modules/util.py
index 60afc0670c7..4861bcb0861 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -2,7 +2,7 @@
 import re

 from modules import shared
-from modules.paths_internal import script_path
+from modules.paths_internal import script_path, cwd


 def natural_sort_key(s, regex=re.compile('([0-9]+)')):
@@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs):
         return

     print(*args, **kwargs)
+
+
+def truncate_path(target_path, base_path=cwd):
+    abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
+    try:
+        if os.path.commonpath([abs_target, abs_base]) == abs_base:
+            return os.path.relpath(abs_target, abs_base)
+    except ValueError:
+        pass
+    return abs_target

From af2951ed53da6d357aea9232538f9ea7e1cdc648 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Thu, 28 Dec 2023 06:52:33 +0900
Subject: [PATCH 130/449] base default image output on data_path

Co-Authored-By: Alberto Cano <34340962+canoalberto@users.noreply.github.com>
---
 modules/paths_internal.py |  1 +
 modules/shared_options.py | 19 ++++++++++---------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 89131a54fa1..b86ecd7f192 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -28,5 +28,6 @@
 extensions_dir = os.path.join(data_path, "extensions")
 extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
 config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")

 roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/shared_options.py b/modules/shared_options.py
index fa542ba8cd6..752a4f1259a 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -1,7 +1,8 @@
+import os
 import gradio as gr

-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules.paths_internal import models_path,
script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts from modules.options import options_section, OptionInfo, OptionHTML, categories @@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), - "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs), "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), - "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), - "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs), + "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs), })) options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { From 892e703b59b2f867d8a202a52fab1db89882ef86 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:52:41 +0900 Subject: [PATCH 131/449] webpath use truncate_path --- modules/ui_gradio_extensions.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index 0d368f8b2c4..a86c368ef70 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -1,17 +1,12 @@ import os import gradio as gr -from modules import localization, shared, scripts -from 
modules.paths import script_path, data_path, cwd +from modules import localization, shared, scripts, util +from modules.paths import script_path, data_path def webpath(fn): - if fn.startswith(cwd): - web_path = os.path.relpath(fn, cwd) - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' + return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}' def javascript_html(): From dc57ec0296e768ee91290e16ab262404837c566d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 29 Dec 2023 01:56:48 +0900 Subject: [PATCH 132/449] save info of init image --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb019..141f2f11167 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1482,7 +1482,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): # Save init image if opts.save_init_img: self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() - images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) image = images.flatten(img, opts.img2img_background_color) From bb07cb6a0df60a96827125ffc09ea182a1ed272c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 10:22:03 +0300 Subject: [PATCH 133/449] a --- modules/api/api.py | 27 +++++++++++++ modules/api/models.py | 2 + modules/generation_parameters_copypaste.py | 19 +++++++++ modules/processing.py | 2 +- modules/processing_scripts/refiner.py | 7 ++-- modules/processing_scripts/seed.py | 13 +++--- modules/ui.py | 46 +++++++++++----------- 7 files changed, 83 insertions(+), 33 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7154c9d5b7b..b3d70940401 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,29 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args + def apply_infotext(self, request, tabname): + if not request.infotext: + return {} + + params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + + for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: + if not field.api: + continue + + value = field.function(params) if field.function else params.get(field.label) + target_type = request.__fields__[field.api].type_ + + if value is None: + continue + + if not isinstance(value, target_type): + value = target_type(value) + + setattr(request, field.api, value) + + return params + def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") @@ -342,6 +365,9 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): if not script_runner.scripts: script_runner.initialize_scripts(False) ui.create_ui() + + infotext_params = self.apply_infotext(txt2imgreq, "txt2img") + if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) @@ -358,6 +384,7 @@ def text2imgapi(self, txt2imgreq: 
models.StableDiffusionTxt2ImgProcessingAPI): args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) diff --git a/modules/api/models.py b/modules/api/models.py index 58083a34fb0..16edf11cf83 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -108,6 +108,7 @@ def generate_model(self): {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ def generate_model(self): {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index dbffe494998..4b4727c4458 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -28,6 +28,19 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima self.paste_field_names = paste_field_names or [] +class PasteField(tuple): + def __new__(cls, component, target, *, api=None): + return super().__new__(cls, (component, target)) + + def __init__(self, component, target, *, api=None): + super().__init__() + + self.api = api + self.component = component + self.label = target if isinstance(target, str) else None + self.function = target if callable(target) else None + + paste_fields: dict[str, dict] = {} registered_param_bindings: list[ParamBinding] = [] @@ -84,6 +97,12 @@ def image_from_url_text(filedata): def add_paste_fields(tabname, init_img, fields, override_settings_component=None): + + if fields: + for i in range(len(fields)): + if not isinstance(fields[i], PasteField): + fields[i] = PasteField(*fields[i]) + paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} # backwards compatibility for existing extensions diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb019..ee2ccf460d2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1135,7 +1135,7 @@ def calculate_target_resolution(self): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if self.hr_checkpoint_name: + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) if self.hr_checkpoint_info is None: diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 29ccb78f903..cefad32b761 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,6 +1,7 @@ import gradio as gr from modules import scripts, sd_models +from modules.generation_parameters_copypaste import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion @@ -31,9 +32,9 @@ def lookup_checkpoint(title): return None if info is None else info.title self.infotext_fields = [ - (enable_refiner, lambda d: 'Refiner' in d), - 
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), - (refiner_switch_at, 'Refiner switch at'), + PasteField(enable_refiner, lambda d: 'Refiner' in d), + PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), + PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), ] return enable_refiner, refiner_checkpoint, refiner_switch_at diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index dc9c2da5000..a3e16a12e9b 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,6 +3,7 @@ import gradio as gr from modules import scripts, ui, errors +from modules.generation_parameters_copypaste import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton @@ -51,12 +52,12 @@ def ui(self, is_img2img): seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras]) self.infotext_fields = [ - (self.seed, "Seed"), - (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), + PasteField(self.seed, "Seed", api="seed"), + PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), + PasteField(subseed, "Variation seed", api="subseed"), + PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"), + PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"), + PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"), ] self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}') diff --git a/modules/ui.py b/modules/ui.py index d80486dd4ac..9db2407ed30 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,7 +28,7 @@ import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text +from modules.generation_parameters_copypaste import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component @@ -436,28 +436,28 @@ def create_ui(): ) txt2img_paste_fields = [ - (toprow.prompt, "Prompt"), - (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), - (cfg_scale, "CFG scale"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_name, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), - (hr_prompt, "Hires prompt"), - (hr_negative_prompt, "Hires negative prompt"), - (hr_prompts_container, lambda d: 
gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + PasteField(toprow.prompt, "Prompt", api="prompt"), + PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"), + PasteField(steps, "Steps", api="steps"), + PasteField(sampler_name, "Sampler", api="sampler_name"), + PasteField(cfg_scale, "CFG scale", api="cfg_scale"), + PasteField(width, "Size-1", api="width"), + PasteField(height, "Size-2", api="height"), + PasteField(batch_size, "Batch size", api="batch_size"), + PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"), + PasteField(denoising_strength, "Denoising strength", api="denoising_strength"), + PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"), + PasteField(hr_scale, "Hires upscale", api="hr_scale"), + PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"), + PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"), + PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), + PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), + PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), + PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), + PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), + PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) From 59d060fd5ea93fcc3fdbfbd13b6e20fda06ecf94 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:11:03 +0900 Subject: [PATCH 134/449] More lora not found warning --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 985b2753bce..72ebd624169 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c563..1518f7e5c89 100644 --- 
a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ def before_ui(): "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) From ba92135a2ba9e210ce5370715e2defcb43df70d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 12:11:09 +0300 Subject: [PATCH 135/449] add override_settings support for infotext API --- modules/api/api.py | 10 ++++ modules/generation_parameters_copypaste.py | 66 ++++++++++++++-------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b3d70940401..fb108486fab 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,7 @@ def apply_infotext(self, request, tabname): params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + handled_fields = {} for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: if not field.api: continue @@ -355,6 +356,15 @@ def apply_infotext(self, request, tabname): value = target_type(value) setattr(request, field.api, value) + handled_fields[field.label] = 1 + + if request.override_settings is None: + request.override_settings = {} + + overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) + for infotext_text, setting_name, value in overriden_settings: + if setting_name not in request.override_settings: + request.override_settings[setting_name] = value return params diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4b4727c4458..86a36c3273c 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -390,6 +390,48 @@ def create_override_settings_dict(text_pairs): return res +def get_override_settings(params, *, skip_fields=None): + """Returns a list of settings overrides from the infotext parameters dictionary. + + This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns + a list of tuples containing the parameter name, setting name, and new value cast to correct type. + + It checks for conditions before adding an override: + - ignores settings that match the current value + - ignores parameter keys present in skip_fields argument. 
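# ---[ editor's note: illustrative sketch, not part of the patch ]---------
# Patch 134 above fans the same "Lora not found" message out to up to three
# sinks (the infotext comment, the console, and a gradio toast) behind two
# new options. The shape of that fan-out, with plain callables standing in
# for shared.opts flags and gr.Warning:

def warn_missing_networks(names, *, to_console=False, to_ui=False, ui_warn=print):
    message = f'Lora not found: {", ".join(names)}'
    if to_console:
        print(f'\n{message}\n')
    if to_ui:
        ui_warn(message)  # gr.Warning(message) in the real code
    return message  # the caller also appends this to model_hijack.comments
# --------------------------------------------------------------------------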
+
+    Example input:
+        {"Clip skip": "2"}
+
+    Example output:
+        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+    """
+
+    res = []
+
+    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+        if param_name in (skip_fields or {}):
+            continue
+
+        v = params.get(param_name, None)
+        if v is None:
+            continue
+
+        if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+            continue
+
+        v = shared.opts.cast_value(setting_name, v)
+        current_value = getattr(shared.opts, setting_name, None)
+
+        if v == current_value:
+            continue
+
+        res.append((param_name, setting_name, v))
+
+    return res
+
+
 def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@@ -431,29 +473,9 @@ def paste_func(prompt):
         already_handled_fields = {key: 1 for _, key in paste_fields}

         def paste_settings(params):
-            vals = {}
-
-            mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
-            for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
-                if param_name in already_handled_fields:
-                    continue
-
-                v = params.get(param_name, None)
-                if v is None:
-                    continue
-
-                if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
-                    continue
-
-                v = shared.opts.cast_value(setting_name, v)
-                current_value = getattr(shared.opts, setting_name, None)
-
-                if v == current_value:
-                    continue
-
-                vals[param_name] = v
+            vals = get_override_settings(params, skip_fields=already_handled_fields)

-            vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+            vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]

             return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

From 8b08b78c03f09898455d54cf099225ed5f8de1ee Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 30 Dec 2023 12:27:23 +0300
Subject: [PATCH 136/449] make it so that if an option from infotext conflicts with an argument from API, the latter overrides the former

---
 modules/api/api.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index fb108486fab..cabccb4c025 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -339,6 +339,7 @@ def apply_infotext(self, request, tabname):
         if not request.infotext:
             return {}

+        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True)  # pydantic v1/v2 have different names for this
         params = generation_parameters_copypaste.parse_generation_parameters(request.infotext)

         handled_fields = {}
@@ -346,6 +347,9 @@ def apply_infotext(self, request, tabname):
             if not field.api:
                 continue

+            if field.api in set_fields:
+                continue
+
             value = field.function(params) if field.function else params.get(field.label)
             target_type = request.__fields__[field.api].type_

@@ -376,7 +380,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
             script_runner.initialize_scripts(False)
             ui.create_ui()

-        infotext_params = self.apply_infotext(txt2imgreq, "txt2img")
+        self.apply_infotext(txt2imgreq, "txt2img")

         if not self.default_script_arg_txt2img:
             self.default_script_arg_txt2img = self.init_default_script_args(script_runner)

From 0aacd4c72b4008d7153e747301fe8c5ffca57f85 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 30 Dec 2023 13:33:18 +0300
Subject: [PATCH 137/449] add support for alwayson scripts for infotext API

---
 modules/api/api.py | 61 +++++++++++++++++++++++++++++---------------
 1 file changed, 42 insertions(+), 19 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index cabccb4c025..946cfe4a92f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -312,8 +312,13 @@ def init_default_script_args(self, script_runner):
                 script_args[script.args_from:script.args_to] = ui_default_values
         return script_args

-    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
+    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
         script_args = default_script_args.copy()
+
+        if input_script_args is not None:
+            for index, value in input_script_args.items():
+                script_args[index] = value
+
         # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
         if selectable_scripts:
             script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -335,41 +340,58 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel
                             script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
         return script_args

-    def apply_infotext(self, request, tabname):
+    def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
         if not request.infotext:
             return {}

+        possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"]
         set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True)  # pydantic v1/v2 have different names for this
         params = generation_parameters_copypaste.parse_generation_parameters(request.infotext)

-        handled_fields = {}
-        for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]:
-            if not field.api:
-                continue
-
-            if field.api in set_fields:
-                continue
-
+        def get_field_value(field, params):
             value = field.function(params) if field.function else params.get(field.label)
-            target_type = request.__fields__[field.api].type_
-
             if value is None:
-                continue
+                return None
+
+            if field.api in request.__fields__:
+                target_type = request.__fields__[field.api].type_
+            else:
+                target_type = type(field.component.value)
+
+            if target_type == type(None):
+                return None

             if not isinstance(value, target_type):
                 value = target_type(value)

-            setattr(request, field.api, value)
-            handled_fields[field.label] = 1
+            return value
+
+        for field in possible_fields:
+            if not field.api:
+                continue
+
+            if field.api in set_fields:
+                continue
+
+            value = get_field_value(field, params)
+            if value is not None:
+                setattr(request, field.api, value)

         if request.override_settings is None:
             request.override_settings = {}

-        overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields)
-        for infotext_text, setting_name, value in overriden_settings:
+        overriden_settings = generation_parameters_copypaste.get_override_settings(params)
+        for _, setting_name, value in overriden_settings:
             if setting_name not in request.override_settings:
                 request.override_settings[setting_name] = value

+        if script_runner is not None and mentioned_script_args is not None:
+            indexes = {v: i for i, v in enumerate(script_runner.inputs)}
+            script_fields = ((field,
indexes[field.component]) for field in possible_fields if field.component in indexes) + + for field, index in script_fields: + mentioned_script_args[index] = get_field_value(field, params) + return params def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): @@ -380,7 +402,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): script_runner.initialize_scripts(False) ui.create_ui() - self.apply_infotext(txt2imgreq, "txt2img") + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) @@ -400,7 +423,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): args.pop('alwayson_scripts', None) args.pop('infotext', None) - script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 11a435b4697c2d735a117f31944c4ebe59c2504c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:34:46 +0300 Subject: [PATCH 138/449] img2img support for infotext API --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 946cfe4a92f..2c8dc2a0f6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -470,6 +470,10 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() + + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) @@ -489,7 +493,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) - script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 8f1826375943718463cec3af97a37886249bdb44 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:48:25 +0300 Subject: [PATCH 139/449] fix bad values read from infotext for API, add comment --- modules/api/api.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2c8dc2a0f6e..2918f785750 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,13 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def apply_infotext(self, request, 
tabname, *, script_runner=None, mentioned_script_args=None):
+        """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
+
+        If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
+
+        Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
+        """
+
         if not request.infotext:
             return {}

@@ -361,7 +368,10 @@ def get_field_value(field, params):
             if target_type == type(None):
                 return None

-            if not isinstance(value, target_type):
+            if isinstance(value, dict) and value.get('__type__') == 'generic_update':  # this is a gradio.update rather than a value
+                value = value.get('value')
+
+            if value is not None and not isinstance(value, target_type):
                 value = target_type(value)

             return value

From f0e2e8b930115012976f7c5bae00e243a7ebbf79 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 30 Dec 2023 15:12:48 +0300
Subject: [PATCH 140/449] update #14354

---
 webui.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/webui.py b/webui.py
index 3f5b67a3060..2c417168aa6 100644
--- a/webui.py
+++ b/webui.py
@@ -39,7 +39,7 @@ def api_only():
     print(f"Startup time: {startup_timer.summary()}.")

     api.launch(
-        server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"),
+        server_name=initialize_util.gradio_server_name(),
         port=cmd_opts.port if cmd_opts.port else 7861,
         root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
     )

From c069c2c5628728c9506dd034ef98e6335fd5bb34 Mon Sep 17 00:00:00 2001
From: lanyeeee <1210347077@qq.com>
Date: Sat, 30 Dec 2023 21:32:22 +0800
Subject: [PATCH 141/449] add locks to ensure init args are thread-safe

---
 modules/api/api.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index f0a68c672d6..45c5c50773a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -251,14 +251,10 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []

-        script_runner = scripts.scripts_img2img
-        if not script_runner.scripts:
-            script_runner.initialize_scripts(True)
-            ui.create_ui()
-        if not self.default_script_arg_txt2img:
-            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
-        if not self.default_script_arg_img2img:
-            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
+        self.txt2img_script_arg_init_lock = Lock()
+        self.img2img_script_arg_init_lock = Lock()
+
+

     def add_api_route(self, path: str, endpoint, **kwargs):
         if shared.cmd_opts.api_auth:
             return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -348,6 +344,12 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

         script_runner = scripts.scripts_txt2img
+        with self.txt2img_script_arg_init_lock:
+            if not script_runner.scripts:
+                script_runner.initialize_scripts(False)
+                ui.create_ui()
+            if not self.default_script_arg_txt2img:
+                self.default_script_arg_txt2img =
self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -407,6 +409,12 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img + with self.img2img_script_arg_init_lock: + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 31992eff9b9714c158b12cec16dfe66c76270dfa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 16:51:02 +0300 Subject: [PATCH 142/449] make it possible again to extract styles that have whitespace at the end. --- modules/styles.py | 47 +++++++++++++++++------------------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 81d9800d184..026c43001a4 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles): return prompt -def unwrap_style_text_from_prompt(style_text, prompt): - """ - Checks the prompt to see if the style text is wrapped around it. If so, - returns True plus the prompt text without the style text. Otherwise, returns - False with the original prompt. +def extract_style_text_from_prompt(style_text, prompt): + """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt. - Note that the "cleaned" version of the style text is only used for matching - purposes here. It isn't returned; the original style text is not modified. + extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg") """ - stripped_prompt = prompt - stripped_style_text = style_text + + stripped_prompt = prompt.strip() + stripped_style_text = style_text.strip() + if "{prompt}" in stripped_style_text: - # Work out whether the prompt is wrapped in the style text. If so, we - # return True and the "inner" prompt text that isn't part of the style. - try: - left, right = stripped_style_text.split("{prompt}", 2) - except ValueError as e: - # If the style text has multple "{prompt}"s, we can't split it into - # two parts. This is an error, but we can't do anything about it. 
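# ---[ editor's note: illustrative sketch, not part of the patch ]---------
# The rewritten extractor of patch 142 compares *stripped* strings, which is
# what makes styles with trailing whitespace extractable again. The suffix
# branch in isolation, as a runnable sketch:

def strip_style_suffix(prompt: str, style: str):
    p, s = prompt.strip(), style.strip()
    if s and p.endswith(s):
        p = p[:len(p) - len(s)]
        return True, p[:-2] if p.endswith(', ') else p
    return False, prompt

assert strip_style_suffix('1girl, masterpiece ', 'masterpiece') == (True, '1girl')
# --------------------------------------------------------------------------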
- print(f"Unable to compare style text to prompt:\n{style_text}") - print(f"Error: {e}") - return False, prompt + left, right = stripped_style_text.split("{prompt}", 2) if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): - prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] + prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] return True, prompt else: - # Work out whether the given prompt ends with the style text. If so, we - # return True and the prompt text up to where the style text starts. if stripped_prompt.endswith(stripped_style_text): - prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] - if prompt.endswith(", "): + prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] + + if prompt.endswith(', '): prompt = prompt[:-2] + return True, prompt return False, prompt @@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = unwrap_style_text_from_prompt( - style.prompt, prompt - ) + match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = unwrap_style_text_from_prompt( - style.negative_prompt, negative_prompt - ) + match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) if not match_negative: return False, prompt, negative_prompt From 7aa27b000a3087dcb5cc7254600064bf70cacd3e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:44:15 +0200 Subject: [PATCH 143/449] Add types to split_grid --- modules/images.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 16f9ae7cce1..d30e8865d7c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -64,9 +64,8 @@ def image_grid(imgs, batch_size=1, rows=None): Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) -def split_grid(image, tile_w=512, tile_h=512, overlap=64): - w = image.width - h = image.height +def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: + w, h = image.size non_overlap_width = tile_w - overlap non_overlap_height = tile_h - overlap From 12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:01:45 +0200 Subject: [PATCH 144/449] Add tile_count property to Grid --- modules/images.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index d30e8865d7c..87a7bf22139 100644 --- a/modules/images.py +++ b/modules/images.py @@ -61,7 +61,13 @@ def image_grid(imgs, batch_size=1, rows=None): return grid -Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) +class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])): + @property + def tile_count(self) -> int: + """ + The total number of tiles in the grid. 
+ """ + return sum(len(row[2]) for row in self.tiles) def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: [PATCH 145/449] Refactor esrgan_upscale to more generic upscale_with_model --- modules/esrgan_model.py | 47 +++++----------------------- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 39 deletions(-) create mode 100644 modules/upscaler_utils.py diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d280..c0d22a99273 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,13 +1,12 @@ import sys -import numpy as np import torch -from PIL import Image import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model def mod2normal(state_dict): @@ -190,40 +189,10 @@ def load_model(self, path: str): return model -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') - - def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 00000000000..8bdda51c4c2 --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. 
* np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: [PATCH 146/449] Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- .../ScuNET/scripts/scunet_model.py | 13 +- .../ScuNET/scunet_model_arch.py | 268 ----- .../SwinIR/scripts/swinir_model.py | 126 +- .../SwinIR/swinir_model_arch.py | 867 -------------- .../SwinIR/swinir_model_arch_v2.py | 1017 ----------------- modules/codeformer/codeformer_arch.py | 276 ----- modules/codeformer/vqgan_arch.py | 435 ------- modules/codeformer_model.py | 195 ++-- modules/esrgan_model.py | 153 +-- modules/esrgan_model_arch.py | 465 -------- modules/gfpgan_model.py | 13 +- modules/launch_utils.py | 7 - modules/modelloader.py | 16 + modules/paths.py | 1 - modules/realesrgan_model.py | 153 +-- modules/sysinfo.py | 2 - modules/upscaler.py | 3 + requirements.txt | 3 +- requirements_versions.txt | 4 +- 19 files changed, 263 insertions(+), 3754 deletions(-) delete mode 100644 extensions-builtin/ScuNET/scunet_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch_v2.py delete mode 100644 modules/codeformer/codeformer_arch.py delete mode 100644 modules/codeformer/vqgan_arch.py delete mode 100644 modules/esrgan_model_arch.py diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64b8e..18cf8e1a0b6 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ def load_model(self, path: str): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? 
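Before the ScuNET loader changes continue, a short usage sketch of the upscale_with_model helper added in the previous patch (hypothetical caller code: the toy "model" and the file names below are placeholders standing in for a real ESRGAN checkpoint):

    import torch
    import torch.nn.functional as F
    from PIL import Image

    from modules.upscaler_utils import upscale_with_model

    def model(x: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for a real upscaler: naive nearest-neighbour 2x.
        return F.interpolate(x, scale_factor=2, mode="nearest")

    img = Image.open("input.png").convert("RGB")
    # tile_size <= 0 upscales in one pass; a positive tile_size splits the
    # image with images.split_grid, runs the model tile by tile behind a
    # tqdm bar sized by the new Grid.tile_count, and stitches the result
    # back together with images.combine_grid.
    result = upscale_with_model(model, img, tile_size=192, tile_overlap=8)
    result.save("output.png")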
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device) def on_ui_settings(): diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a880629b..00000000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. 
- Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 
else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a8ea..85c18b9e939 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,5 +1,5 @@ +import logging import sys -import platform import numpy as np import torch @@ -8,13 +8,11 @@ from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +35,29 @@ def __init__(self, dirname): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: current_config = (model_file, opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + device = self._get_device() + + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if use_compile: - model = 
torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + device=device, + ) devices.torch_gc() return img @@ -69,69 +70,54 @@ def load_model(self, path, scale=4): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) + model = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + dtype=devices.dtype, + ) + if getattr(opts, 'SWIN_torch_compile', False): + try: + model = torch.compile(model) + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) return model + def _get_device(self): + return devices.get_device_for('swinir') + def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, + img, + model, + *, + tile: int, + tile_overlap: int, + window_size=8, + scale=4, + device, ): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + img = img.unsqueeze(0).to(device, dtype=devices.dtype) with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) + output = inference( + img, + model, + tile=tile, + tile_overlap=tile_overlap, + window_size=window_size, + scale=scale, + device=device, + ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: @@ -142,7 +128,16 @@ def upscale( return Image.fromarray(output, "RGB") -def inference(img, model, tile, tile_overlap, window_size, scale): +def inference( + img, + model, + *, + tile: int, + tile_overlap: int, + window_size: int, + scale: int, + device, +): # test the image tile by tile b, c, h, w = img.size() tile = min(tile, h, w) @@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) + E = 
torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: @@ -185,8 +180,7 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b9327473a..00000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. 
-# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. 
- num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. 
or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if 
self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca29e..00000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. 
-# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
- """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=(0, 0)): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. 
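The PatchMerging layer above is a 2x2 space-to-depth gather followed by a linear reduction: four shifted sub-grids are stacked on the channel axis, halving H and W and quadrupling C, before nn.Linear(4 * dim, 2 * dim) halves the channels again. A minimal sketch of the gather step alone:

    import torch

    B, H, W, C = 1, 4, 4, 8
    x = torch.randn(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # even rows, even cols
    x1 = x[:, 1::2, 0::2, :]  # odd rows, even cols
    x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols
    x3 = x[:, 1::2, 1::2, :]  # odd rows, odd cols
    merged = torch.cat([x0, x1, x2, x3], dim=-1)
    print(merged.shape)  # torch.Size([1, 2, 2, 32])

Note that Swin2SR itself constructs every residual group with downsample=None, so this path is defined but unused for super-resolution.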
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. 
- input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. 
Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, 
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' 
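For contrast with the nearest+conv branch being set up here, the pixelshuffle-based reconstruction paths rely on the Upsample module defined earlier: each 2x stage is a 3x3 conv that expands channels by r^2, followed by nn.PixelShuffle, which rearranges those channels into an r-times larger spatial grid. A standalone sketch of a single stage with illustrative sizes:

    import torch
    import torch.nn as nn

    num_feat, r = 64, 2
    stage = nn.Sequential(
        nn.Conv2d(num_feat, (r ** 2) * num_feat, 3, 1, 1),  # 64 -> 256 channels
        nn.PixelShuffle(r),                                 # 256 channels -> 2x spatial
    )
    x = torch.randn(1, num_feat, 16, 16)
    print(stage(x).shape)  # torch.Size([1, 64, 32, 32])

The nearest+conv alternative below trades that learned rearrangement for nearest-neighbor interpolation plus convolutions, which the code's comment notes gives fewer artifacts for real-world x4 SR.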
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = 
self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py deleted file mode 100644 index 12db6814268..00000000000 --- a/modules/codeformer/codeformer_arch.py +++ /dev/null @@ -1,276 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -import math -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from typing import Optional - -from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils.registry import ARCH_REGISTRY - -def calc_mean_std(feat, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - - -def adaptive_instance_normalization(content_feat, style_feat): - """Adaptive instance normalization. - - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. 
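calc_mean_std above returns per-channel statistics shaped to broadcast over H and W; adaptive_instance_normalization then whitens the reference feature with its own statistics and re-colors it with the degraded feature's. A quick shape check, assuming calc_mean_std is in scope:

    import torch

    feat = torch.randn(2, 8, 16, 16)
    mean, std = calc_mean_std(feat)
    assert mean.shape == std.shape == (2, 8, 1, 1)

CodeFormer exposes this through the adain flag of its forward pass, where it is applied to the quantized features.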
- """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, x, mask=None): - if mask is None: - mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") - - -class TransformerSALayer(nn.Module): - def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): - super().__init__() - self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) - # Implementation of Feedforward model - MLP - self.linear1 = nn.Linear(embed_dim, dim_mlp) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_mlp, embed_dim) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward(self, tgt, - tgt_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - - # self attention - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - - # ffn - tgt2 = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout2(tgt2) - return tgt - -class Fuse_sft_block(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.encode_enc = ResBlock(2*in_ch, 
out_ch) - - self.scale = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - self.shift = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - def forward(self, enc_feat, dec_feat, w=1): - enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) - scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out - - -@ARCH_REGISTRY.register() -class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, - codebook_size=1024, latent_size=256, - connect_list=('32', '64', '128', '256'), - fix_modules=('quantize', 'generator')): - super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False - - self.connect_list = connect_list - self.n_layers = n_layers - self.dim_embd = dim_embd - self.dim_mlp = dim_embd*2 - - self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) - self.feat_emb = nn.Linear(256, self.dim_embd) - - # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) - for _ in range(self.n_layers)]) - - # logits_predict head - self.idx_pred_layer = nn.Sequential( - nn.LayerNorm(dim_embd), - nn.Linear(dim_embd, codebook_size, bias=False)) - - self.channels = { - '16': 512, - '32': 256, - '64': 256, - '128': 128, - '256': 128, - '512': 64, - } - - # after second residual block for > 16, before attn layer for ==16 - self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} - # after first residual block for > 16, before attn layer for ==16 - self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} - - # fuse_convs_dict - self.fuse_convs_dict = nn.ModuleDict() - for f_size in self.connect_list: - in_ch = self.channels[f_size] - self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): - # ################### Encoder ##################### - enc_feat_dict = {} - out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] - for i, block in enumerate(self.encoder.blocks): - x = block(x) - if i in out_list: - enc_feat_dict[str(x.shape[-1])] = x.clone() - - lq_feat = x - # ################# Transformer ################### - # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) - pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) - # BCHW -> BC(HW) -> (HW)BC - feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) - query_emb = feat_emb - # Transformer encoder - for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) - - # output logits - logits = self.idx_pred_layer(query_emb) # (hw)bn - logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n - - if code_only: # for training stage II - # logits 
doesn't need softmax before cross_entropy loss - return logits, lq_feat - - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ - soft_one_hot = F.softmax(logits, dim=2) - _, top_idx = torch.topk(soft_one_hot, 1, dim=2) - quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() - - if detach_16: - quant_feat = quant_feat.detach() # for training stage III - if adain: - quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) - - # ################## Generator #################### - x = quant_feat - fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] - - for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block - f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) - out = x - # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py deleted file mode 100644 index 09ee6660dc5..00000000000 --- a/modules/codeformer/vqgan_arch.py +++ /dev/null @@ -1,435 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -''' -VQGAN code, adapted from the original created by the Unleashing Transformers authors: -https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py - -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from basicsr.utils import get_root_logger -from basicsr.utils.registry import ARCH_REGISTRY - -def normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -@torch.jit.script -def swish(x): - return x*torch.sigmoid(x) - - -# Define VQVAE classes -class VectorQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, beta): - super(VectorQuantizer, self).__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) - self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.emb_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) - - mean_distance = torch.mean(d) - # find closest encodings - # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) - # [0-1], higher score, higher confidence - min_encoding_scores = torch.exp(-min_encoding_scores/10) - - min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - # 
compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, { - "perplexity": perplexity, - "min_encodings": min_encodings, - "min_encoding_indices": min_encoding_indices, - "min_encoding_scores": min_encoding_scores, - "mean_distance": mean_distance - } - - def get_codebook_feat(self, indices, shape): - # input indices: batch*token_num -> (batch*token_num)*1 - # shape: batch, height, width, channel - indices = indices.view(-1,1) - min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) - min_encodings.scatter_(1, indices, 1) - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: # reshape back to match original input shape - z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): - super().__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits - self.embed = nn.Embedding(codebook_size, emb_dim) - - def forward(self, z): - hard = self.straight_through if self.training else True - - logits = self.proj(z) - - soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) - - z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() - min_encoding_indices = soft_one_hot.argmax(dim=1) - - return z_q, diff, { - "min_encoding_indices": min_encoding_indices - } - - -class Downsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - - return x - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels=None): - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = normalize(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x_in): - x = x_in - x = 
self.norm1(x) - x = swish(x) - x = self.conv1(x) - x = self.norm2(x) - x = swish(x) - x = self.conv2(x) - if self.in_channels != self.out_channels: - x_in = self.conv_out(x_in) - - return x + x_in - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c)**(-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Encoder(nn.Module): - def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): - super().__init__() - self.nf = nf - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - curr_res = self.resolution - in_ch_mult = (1,)+tuple(ch_mult) - - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - - # residual and downsampling blocks, with attention on smaller res (16x16) - for i in range(self.num_resolutions): - block_in_ch = nf * in_ch_mult[i] - block_out_ch = nf * ch_mult[i] - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - if curr_res in attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != self.num_resolutions - 1: - blocks.append(Downsample(block_in_ch)) - curr_res = curr_res // 2 - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - # normalise and convert to latent size - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) - self.blocks = nn.ModuleList(blocks) - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -class Generator(nn.Module): - def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): - super().__init__() - self.nf = nf - self.ch_mult = ch_mult - self.num_resolutions = len(self.ch_mult) - self.num_res_blocks = res_blocks - self.resolution = img_size - self.attn_resolutions = attn_resolutions - self.in_channels = emb_dim - self.out_channels = 3 - block_in_ch = self.nf * self.ch_mult[-1] - curr_res = self.resolution // 2 ** (self.num_resolutions-1) - - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - for i in 
reversed(range(self.num_resolutions)): - block_out_ch = self.nf * self.ch_mult[i] - - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - - if curr_res in self.attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != 0: - blocks.append(Upsample(block_in_ch)) - curr_res = curr_res * 2 - - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) - - self.blocks = nn.ModuleList(blocks) - - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -@ARCH_REGISTRY.register() -class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, - beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): - super().__init__() - logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks - self.codebook_size = codebook_size - self.embed_dim = emb_dim - self.ch_mult = ch_mult - self.resolution = img_size - self.attn_resolutions = attn_resolutions or [16] - self.quantizer_type = quantizer - self.encoder = Encoder( - self.in_channels, - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - if self.quantizer_type == "nearest": - self.beta = beta #0.25 - self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) - elif self.quantizer_type == "gumbel": - self.gumbel_num_hiddens = emb_dim - self.straight_through = gumbel_straight_through - self.kl_weight = gumbel_kl_weight - self.quantize = GumbelQuantizer( - self.codebook_size, - self.embed_dim, - self.gumbel_num_hiddens, - self.straight_through, - self.kl_weight - ) - self.generator = Generator( - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_ema' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) - logger.info(f'vqgan is loaded from: {model_path} [params_ema]') - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - logger.info(f'vqgan is loaded from: {model_path} [params]') - else: - raise ValueError('Wrong params!') - - - def forward(self, x): - x = self.encoder(x) - quant, codebook_loss, quant_stats = self.quantize(x) - x = self.generator(quant) - return x, codebook_loss, quant_stats - - - -# patch based discriminator -@ARCH_REGISTRY.register() -class VQGANDiscriminator(nn.Module): - def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): - super().__init__() - - layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] - ndf_mult = 1 - ndf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n, 8) - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n_layers, 8) - - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - layers += [ - nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, 
stride=1, padding=1)] # output 1 channel prediction map - self.main = nn.Sequential(*layers) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_d' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - else: - raise ValueError('Wrong params!') - - def forward(self, x): - return self.main(x) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index da42b5e9932..517eadfd8de 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,9 +8,6 @@ from modules import shared, devices, modelloader, errors from modules.paths import models_path -# codeformer people made a choice to include modified basicsr library to their project which makes -# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. -# I am making a choice to include some files from codeformer to work around this issue. model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' @@ -18,115 +15,127 @@ codeformer = None -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) - - path = modules.paths.paths.get("CodeFormer", None) - if path is None: - return - - try: +class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): + def name(self): + return "CodeFormer" + + def __init__(self, dirname): + self.net = None + self.face_helper = None + self.cmd_dir = dirname + + def create_models(self): + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + + if self.net is not None and self.face_helper is not None: + self.net.to(devices.device_codeformer) + return self.net, self.face_helper + model_paths = modelloader.load_models( + model_path, + model_url, + self.cmd_dir, + download_name='codeformer-v0.1.0.pth', + ext_filter=['.pth'], + ) + + if len(model_paths) != 0: + ckpt_path = model_paths[0] + else: + print("Unable to load codeformer model.") + return None, None + net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) + + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=devices.device_codeformer, + ) + + self.net = net + self.face_helper = face_helper + + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def restore(self, np_image, w=None): from torchvision.transforms.functional import normalize - from modules.codeformer.codeformer_arch import CodeFormer from basicsr.utils import img2tensor, tensor2img - from facelib.utils.face_restoration_helper import FaceRestoreHelper - from facelib.detection.retinaface import retinaface - - net_class = CodeFormer - - class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): - def name(self): - return "CodeFormer" - - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = 
modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() + np_image = np_image[:, :, ::-1] - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) + original_resolution = np_image.shape[0:2] - self.net = net - self.face_helper = face_helper + self.create_models() + if self.net is None or self.face_helper is None: + return np_image - return net, face_helper + self.send_model_to(devices.device_codeformer) - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) + self.face_helper.clean_all() + self.face_helper.read_image(np_image) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() - def restore(self, np_image, w=None): - np_image = np_image[:, :, ::-1] + for cropped_face in self.face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - original_resolution = np_image.shape[0:2] + try: + with torch.no_grad(): + res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) + if isinstance(res, tuple): + output = res[0] + else: + output = res + if not isinstance(res, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(res)}") + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + devices.torch_gc() + except Exception: + errors.report('Failed inference for CodeFormer', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - self.create_models() - if self.net is None or self.face_helper is None: - return np_image + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) - self.send_model_to(devices.device_codeformer) + self.face_helper.get_inverse_affine(None) - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize( + restored_img, + (0, 0), + fx=original_resolution[1]/restored_img.shape[1], + fy=original_resolution[0]/restored_img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) - try: - with torch.no_grad(): - output = self.net(cropped_face_t, 
w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + self.face_helper.clean_all() - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) - self.face_helper.get_inverse_affine(None) + return restored_img - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img +def setup_model(dirname): + os.makedirs(model_path, exist_ok=True) + try: global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) - except Exception: errors.report("Error setting up CodeFormer", exc_info=True) - - # sys.path = stored_sys_path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c0d22a99273..a7c7c9e3013 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,122 +1,9 @@ -import sys - -import torch - -import modules.esrgan_model_arch as arch -from modules import modelloader, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = 
ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale - - class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -142,12 +29,11 @@ def __init__(self, dirname): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -160,33 +46,10 @@ def load_model(self, path: str): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, 
mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + ) def esrgan_upscale(model, img): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py deleted file mode 100644 index 2b9888bafbc..00000000000 --- a/modules/esrgan_model_arch.py +++ /dev/null @@ -1,465 +0,0 @@ -# this file is adapted from https://github.com/victorca25/iNNfer - -from collections import OrderedDict -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -#################### -# RRDBNet Generator -#################### - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, - act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', - finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.resrgan_scale = 0 - if in_nc % 16 == 0: - self.resrgan_scale = 1 - elif in_nc != 4 and in_nc % 4 == 0: - self.resrgan_scale = 2 - - fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] - LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) - - if upsample_mode == 'upconv': - upsample_block = upconv_block - elif upsample_mode == 'pixelshuffle': - upsample_block = pixelshuffle_block - else: - raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] - HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) - HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - - outact = act(finalact) if finalact else None - - self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), - *upsampler, HR_conv0, HR_conv1, outact) - - def forward(self, x, outm=None): - if self.resrgan_scale == 1: - feat = pixel_unshuffle(x, scale=4) - elif self.resrgan_scale == 2: - feat = pixel_unshuffle(x, scale=2) - else: - feat = x - - return self.model(feat) - - -class RRDB(nn.Module): - """ - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - """ - - def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() - # This is for backwards compatibility with existing models - if nr == 3: - self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, 
spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - else: - RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] - self.RDBs = nn.Sequential(*RDB_list) - - def forward(self, x): - if hasattr(self, 'RDB1'): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - else: - out = self.RDBs(x) - return out * 0.2 + x - - -class ResidualDenseBlock_5C(nn.Module): - """ - Residual Dense Block - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - Modified options that can be used: - - "Partial Convolution based Padding" arXiv:1811.11718 - - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - {Rakotonirina} and A. {Rasoanaivo} - """ - - def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() - - self.noise = GaussianNoise() if gaussian_noise else None - self.conv1x1 = conv1x1(nf, gc) if plus else None - - self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - if self.conv1x1: - x2 = x2 + self.conv1x1(x) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - if self.conv1x1: - x4 = x4 + x2 - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - if self.noise: - return self.noise(x5.mul(0.2) + x) - else: - return x5 * 0.2 + x - - -#################### -# ESRGANplus -#################### - -class GaussianNoise(nn.Module): - def __init__(self, sigma=0.1, is_relative_detach=False): - super().__init__() - self.sigma = sigma - self.is_relative_detach = is_relative_detach - self.noise = torch.tensor(0, dtype=torch.float) - - def forward(self, x): - if self.training and self.sigma != 0: - self.noise = self.noise.to(x.device) - scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x - sampled_noise = self.noise.repeat(*x.size()).normal_() * scale - x = x + sampled_noise - return x - -def conv1x1(in_planes, out_planes, stride=1): - return 
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -#################### -# SRVGGNetCompact -#################### - -class SRVGGNetCompact(nn.Module): - """A compact VGG-style network structure for super-resolution. - This class is copied from https://github.com/xinntao/Real-ESRGAN - """ - - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() - self.num_in_ch = num_in_ch - self.num_out_ch = num_out_ch - self.num_feat = num_feat - self.num_conv = num_conv - self.upscale = upscale - self.act_type = act_type - - self.body = nn.ModuleList() - # the first conv - self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) - # the first activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the body structure - for _ in range(num_conv): - self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) - # activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the last conv - self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) - # upsample - self.upsampler = nn.PixelShuffle(upscale) - - def forward(self, x): - out = x - for i in range(0, len(self.body)): - out = self.body[i](out) - - out = self.upsampler(out) - # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') - out += base - return out - - -#################### -# Upsampler -#################### - -class Upsample(nn.Module): - r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. - The input data is assumed to be of the form - `minibatch x channels x [optional depth] x [optional height] x width`. - """ - - def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() - if isinstance(scale_factor, tuple): - self.scale_factor = tuple(float(factor) for factor in scale_factor) - else: - self.scale_factor = float(scale_factor) if scale_factor else None - self.mode = mode - self.size = size - self.align_corners = align_corners - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - - def extra_repr(self): - if self.scale_factor is not None: - info = f'scale_factor={self.scale_factor}' - else: - info = f'size={self.size}' - info += f', mode={self.mode}' - return info - - -def pixel_unshuffle(x, scale): - """ Pixel unshuffle. - Args: - x (Tensor): Input feature with shape (b, c, hh, hw). - scale (int): Downsample ratio. - Returns: - Tensor: the pixel unshuffled feature. 
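For reference, the `pixel_unshuffle` helper being deleted here is a plain space-to-depth rearrangement: each `scale x scale` spatial block becomes `scale**2` extra channels. A minimal sketch of the shape contract, assuming only torch (the body restates the hunk that follows):

    import torch

    def pixel_unshuffle(x, scale):
        # (b, c, h*scale, w*scale) -> (b, c*scale**2, h, w)
        b, c, hh, hw = x.size()
        assert hh % scale == 0 and hw % scale == 0
        h, w = hh // scale, hw // scale
        x_view = x.view(b, c, h, scale, w, scale)
        return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale ** 2, h, w)

    x = torch.randn(1, 3, 8, 8)
    print(pixel_unshuffle(x, 2).shape)  # torch.Size([1, 12, 4, 4])

The deleted `RRDBNet.forward` uses this in its `resrgan_scale` branches so that 2x and 1x Real-ESRGAN checkpoints can run the trunk at reduced spatial size.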
- """ - b, c, hh, hw = x.size() - out_channel = c * (scale**2) - assert hh % scale == 0 and hw % scale == 0 - h = hh // scale - w = hw // scale - x_view = x.view(b, c, h, scale, w, scale) - return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): - """ - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - """ - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): - """ Upconv layer """ - upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor - upsample = Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) - return sequential(upsample, conv) - - - - - - - - -#################### -# Basic blocks -#################### - - -def make_layer(basic_block, num_basic_block, **kwarg): - """Make layers by stacking the same blocks. - Args: - basic_block (nn.module): nn.module class for basic block. (block) - num_basic_block (int): number of blocks. (n_layers) - Returns: - nn.Sequential: Stacked blocks in nn.Sequential. 
- """ - layers = [] - for _ in range(num_basic_block): - layers.append(basic_block(**kwarg)) - return nn.Sequential(*layers) - - -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): - """ activation helper """ - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type in ('leakyrelu', 'lrelu'): - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - elif act_type == 'tanh': # [-1, 1] range output - layer = nn.Tanh() - elif act_type == 'sigmoid': # [0, 1] range output - layer = nn.Sigmoid() - else: - raise NotImplementedError(f'activation layer [{act_type}] is not found') - return layer - - -class Identity(nn.Module): - def __init__(self, *kwargs): - super(Identity, self).__init__() - - def forward(self, x, *kwargs): - return x - - -def norm(norm_type, nc): - """ Return a normalization layer """ - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'none': - def norm_layer(x): return Identity() - else: - raise NotImplementedError(f'normalization layer [{norm_type}] is not found') - return layer - - -def pad(pad_type, padding): - """ padding layer helper """ - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - elif pad_type == 'zero': - layer = nn.ZeroPad2d(padding) - else: - raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ShortcutBlock(nn.Module): - """ Elementwise sum the output of a submodule to its input """ - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = x + self.sub(x) - return output - - def __repr__(self): - return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') - - -def sequential(*args): - """ Flatten Sequential. It unwraps nn.Sequential. """ - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. 
- modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', - spectral_norm=False): - """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - if convtype=='PartialConv2D': - from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer - c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='DeformConv2D': - from torchvision.ops import DeformConv2d # not tested - c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='Conv3D': - c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - else: - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - - if spectral_norm: - c = nn.utils.spectral_norm(c) - - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 01d668ecdaf..6b6f17c435a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,8 +1,5 @@ import os -import facexlib -import gfpgan - import modules.face_restoration from modules import paths, shared, devices, modelloader, errors @@ -41,6 +38,8 @@ def gfpgann(): print("Unable to load gfpgan model!") return None + import facexlib.detection.retinaface + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan model_file_path = model_file @@ -81,8 +80,10 @@ def gfpgan_fix_faces(np_image): def setup_model(dirname): try: os.makedirs(model_path, exist_ok=True) - from gfpgan import GFPGANer - from facexlib import detection, parsing # noqa: F401 + import gfpgan + import facexlib.detection + import facexlib.parsing + global user_path global have_gfpgan global gfpgan_constructor @@ -111,7 +112,7 @@ def facex_load_file_from_url2(**kwargs): facexlib.parsing.load_file_from_url = facex_load_file_from_url2 user_path = dirname have_gfpgan = True - gfpgan_constructor = GFPGANer + gfpgan_constructor = gfpgan.GFPGANer class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index dabef0f5335..c2cbd8ce73e 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -345,13 +345,11 @@ def prepare_environment(): stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', 
"https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: @@ -408,15 +406,10 @@ def prepare_environment(): git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) startup_timer.record("clone repositores") - if not is_installed("lpips"): - run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - startup_timer.record("install CodeFormer requirements") - if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb79336..30116932a1d 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil import importlib @@ -10,6 +11,9 @@ from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +181,15 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. 
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model(path, *, device, half: bool = False, dtype=None): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model diff --git a/modules/paths.py b/modules/paths.py index 187b949612a..030646519c3 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,7 +38,6 @@ class Dummy: path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 02841c30289..332d8f4b108 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,6 @@ import os -import numpy as np -from PIL import Image -from realesrgan import RealESRGANer - +from modules.upscaler_utils import upscale_with_model from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader, errors @@ -14,29 +11,20 @@ def __init__(self, path): self.name = "RealESRGAN" self.user_path = path super().__init__() - try: - from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 - from realesrgan import RealESRGANer # noqa: F401 - from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401 - self.enable = True - self.scalers = [] - scalers = self.load_models(path) + self.enable = True + self.scalers = [] + scalers = get_realesrgan_models(self) - local_model_paths = self.find_models(ext_filter=[".pth"]) - for scaler in scalers: - if scaler.local_data_path.startswith("http"): - filename = modelloader.friendly_name(scaler.local_data_path) - local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] - if local_model_candidates: - scaler.local_data_path = local_model_candidates[0] + local_model_paths = self.find_models(ext_filter=[".pth"]) + for scaler in scalers: + if scaler.local_data_path.startswith("http"): + filename = modelloader.friendly_name(scaler.local_data_path) + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] - if scaler.name in opts.realesrgan_enabled_models: - self.scalers.append(scaler) - - except Exception: - errors.report("Error importing Real-ESRGAN", exc_info=True) - self.enable = False - self.scalers = [] + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) def do_upscale(self, img, path): if not self.enable: @@ -48,20 +36,18 @@ def do_upscale(self, img, path): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - upsampler = RealESRGANer( - scale=info.scale, - model_path=info.local_data_path, - model=info.model(), - half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, + mod = modelloader.load_spandrel_model( + 
info.local_data_path, device=self.device, + half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + ) + return upscale_with_model( + mod, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + # TODO: `outscale`? ) - - upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] - - image = Image.fromarray(upsampled) - return image def load_model(self, path): for scaler in self.scalers: @@ -76,58 +62,43 @@ def load_model(self, path): return scaler raise ValueError(f"Unable to find model info: {path}") - def load_models(self, _): - return get_realesrgan_models(self) - -def get_realesrgan_models(scaler): - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - models = [ - UpscalerData( - name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN General WDN 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN AnimeVideo", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN 4x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 4x+ Anime6B", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 2x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - scale=2, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), - ] - return models - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) +def get_realesrgan_models(scaler: UpscalerRealESRGAN): + return [ + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + 
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, + ), + ] diff --git a/modules/sysinfo.py b/modules/sysinfo.py index b669edd0cfd..5abf616b707 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -26,11 +26,9 @@ "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", - "CODEFORMER_REPO", "BLIP_REPO", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", - "CODEFORMER_COMMIT_HASH", "BLIP_COMMIT_HASH", "COMMANDLINE_ARGS", "IGNORE_CMD_ARGS_ERRORS", diff --git a/modules/upscaler.py b/modules/upscaler.py index b256e085b6d..3aee69db8d2 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -98,6 +98,9 @@ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = self.scale = scale self.model = model + def __repr__(self): + return f"" + class UpscalerNone(Upscaler): name = "None" diff --git a/requirements.txt b/requirements.txt index 80b438455ce..36f5674ad99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ basicsr blendmodes clean-fid einops +facexlib fastapi>=0.90.1 gfpgan gradio==3.41.2 @@ -20,13 +21,11 @@ open-clip-torch piexif psutil pytorch_lightning -realesrgan requests resize-right safetensors scikit-image>=0.19 -timm tomesd torch torchdiffeq diff --git a/requirements_versions.txt b/requirements_versions.txt index cb7403a9d4b..042fa708c49 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,6 +5,7 @@ basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 +facexlib==0.3.0 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.41.2 @@ -19,11 +20,10 @@ open-clip-torch==2.20.0 piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 -realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 scikit-image==0.21.0 -timm==0.9.2 +spandrel==0.1.6 tomesd==0.1.3 torch torchdiffeq==0.2.3 From b621a63cf68c788487684250856707cb352b82d0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 23:01:02 +0200 Subject: [PATCH 147/449] Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN --- .github/workflows/run_tests.yaml | 8 ++ .gitignore | 1 + modules/codeformer_model.py | 158 +++++++--------------------- modules/face_restoration_utils.py | 163 +++++++++++++++++++++++++++++ modules/gfpgan_model.py | 166 ++++++++++-------------------- requirements.txt | 1 - requirements_versions.txt | 1 - test/conftest.py | 15 ++- test/test_face_restorers.py | 29 ++++++ test/test_files/two-faces.jpg | Bin 0 -> 14768 bytes test/test_outputs/.gitkeep | 0 11 files changed, 308 insertions(+), 234 deletions(-) create mode 100644 modules/face_restoration_utils.py create mode 100644 test/test_face_restorers.py create mode 100644 test/test_files/two-faces.jpg create mode 100644 test/test_outputs/.gitkeep diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3dafaf8dcfc..cd5c3f868ff 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,6 +20,12 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py + - name: Cache models + id: cache-models + uses: actions/cache@v3 + with: + path: models + key: "2023-12-30" - name: Install test dependencies run: pip install wait-for-it -r requirements-test.txt env: @@ -33,6 +39,8 @@ jobs: TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu 
WEBUI_LAUNCH_LIVE_OUTPUT: "1" PYTHONUNBUFFERED: "1" + - name: Print installed packages + run: pip freeze - name: Start test server run: > python -m coverage run diff --git a/.gitignore b/.gitignore index 09734267ff5..6790e9ee728 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/test/test_outputs diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 517eadfd8de..ceda4bab919 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,140 +1,62 @@ -import os +from __future__ import annotations + +import logging -import cv2 import torch -import modules.face_restoration -import modules.shared -from modules import shared, devices, modelloader, errors -from modules.paths import models_path +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) + +logger = logging.getLogger(__name__) -model_dir = "Codeformer" -model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' +model_download_name = 'codeformer-v0.1.0.pth' -codeformer = None +# used by e.g. postprocessing_codeformer.py +codeformer: face_restoration.FaceRestoration | None = None -class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): +class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): def name(self): return "CodeFormer" - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - from facexlib.detection import retinaface - from facexlib.utils.face_restoration_helper import FaceRestoreHelper - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models( - model_path, - model_url, - self.cmd_dir, - download_name='codeformer-v0.1.0.pth', + def load_net(self) -> torch.Module: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, ext_filter=['.pth'], - ) - - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) - - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - - face_helper = FaceRestoreHelper( - upscale_factor=1, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - use_parse=True, - device=devices.device_codeformer, - ) - - self.net = net - self.face_helper = face_helper - - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) - - def restore(self, np_image, w=None): - from torchvision.transforms.functional import normalize - from basicsr.utils import img2tensor, tensor2img - np_image = np_image[:, :, ::-1] - - original_resolution = np_image.shape[0:2] - - self.create_models() - if self.net is None or self.face_helper is None: - return np_image - - self.send_model_to(devices.device_codeformer) - - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() - - for cropped_face in 
self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - - try: - with torch.no_grad(): - res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) - if isinstance(res, tuple): - output = res[0] - else: - output = res - if not isinstance(res, torch.Tensor): - raise TypeError(f"Expected torch.Tensor, got {type(res)}") - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - self.face_helper.get_inverse_affine(None) - - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] + ): + return modelloader.load_spandrel_model( + model_path, + device=devices.device_codeformer, + ).model + raise ValueError("No codeformer model found") - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize( - restored_img, - (0, 0), - fx=original_resolution[1]/restored_img.shape[1], - fy=original_resolution[0]/restored_img.shape[0], - interpolation=cv2.INTER_LINEAR, - ) + def get_device(self): + return devices.device_codeformer - self.face_helper.clean_all() + def restore(self, np_image, w: float | None = None): + if w is None: + w = getattr(shared.opts, "code_former_weight", 0.5) - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, w=w, adain=True)[0] - return restored_img + return self.restore_with_helper(np_image, restore_face) -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) +def setup_model(dirname: str) -> None: + global codeformer try: - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) except Exception: diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py new file mode 100644 index 00000000000..c65c85ef89d --- /dev/null +++ b/modules/face_restoration_utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import os +from functools import cached_property +from typing import TYPE_CHECKING, Callable + +import cv2 +import numpy as np +import torch + +from modules import devices, errors, face_restoration, shared + +if TYPE_CHECKING: + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +logger = logging.getLogger(__name__) + + +def create_face_helper(device) -> FaceRestoreHelper: + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + if hasattr(retinaface, 'device'): + retinaface.device = device + return FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=device, + ) + + +def restore_with_face_helper( + np_image: np.ndarray, + face_helper: FaceRestoreHelper, + restore_face: Callable[[np.ndarray], np.ndarray], +) -> np.ndarray: + """ + Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. 
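`restore_with_face_helper`, being added here, separates the face bookkeeping (detect, align, paste back) from the model call: the only model-specific piece is the `restore_face` callable. A hedged sketch of the smallest useful callable, handy for exercising the plumbing without a real network (`face_helper` is assumed to be an already-constructed FaceRestoreHelper):

    import torch

    def identity_restore(cropped_face_t: torch.Tensor) -> torch.Tensor:
        # cropped_face_t arrives as a normalised (1, 3, 512, 512) batch;
        # return a (3, 512, 512) face, here simply unchanged
        return cropped_face_t[0]

    # out = restore_with_face_helper(np_image, face_helper, identity_restore)

With this split, the CodeFormer and GFPGAN `restore` methods elsewhere in this patch reduce to one-line callables, `net(t, w=w, adain=True)[0]` and `net(t, return_rgb=False)[0]` respectively.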
+ + `restore_face` should take a cropped face image and return a restored face image. + """ + from basicsr.utils import img2tensor, tensor2img + from torchvision.transforms.functional import normalize + np_image = np_image[:, :, ::-1] + original_resolution = np_image.shape[0:2] + + try: + logger.debug("Detecting faces...") + face_helper.clean_all() + face_helper.read_image(np_image) + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) + for cropped_face in face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + restored_face = tensor2img( + restore_face(cropped_face_t), + rgb2bgr=True, + min_max=(-1, 1), + ) + devices.torch_gc() + except Exception: + errors.report('Failed face-restoration inference', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + logger.debug("Merging restored faces into image") + face_helper.get_inverse_affine(None) + img = face_helper.paste_faces_to_input_image() + img = img[:, :, ::-1] + if original_resolution != img.shape[0:2]: + img = cv2.resize( + img, + (0, 0), + fx=original_resolution[1] / img.shape[1], + fy=original_resolution[0] / img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + logger.debug("Face restoration complete") + finally: + face_helper.clean_all() + return img + + +class CommonFaceRestoration(face_restoration.FaceRestoration): + net: torch.Module | None + model_url: str + model_download_name: str + + def __init__(self, model_path: str): + super().__init__() + self.net = None + self.model_path = model_path + os.makedirs(model_path, exist_ok=True) + + @cached_property + def face_helper(self) -> FaceRestoreHelper: + return create_face_helper(self.get_device()) + + def send_model_to(self, device): + if self.net: + logger.debug("Sending %s to %s", self.net, device) + self.net.to(device) + if self.face_helper: + logger.debug("Sending face helper to %s", device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def get_device(self): + raise NotImplementedError("get_device must be implemented by subclasses") + + def load_net(self) -> torch.Module: + raise NotImplementedError("load_net must be implemented by subclasses") + + def restore_with_helper( + self, + np_image: np.ndarray, + restore_face: Callable[[np.ndarray], np.ndarray], + ) -> np.ndarray: + try: + if self.net is None: + self.net = self.load_net() + except Exception: + logger.warning("Unable to load face-restoration model", exc_info=True) + return np_image + + try: + self.send_model_to(self.get_device()) + return restore_with_face_helper(np_image, self.face_helper, restore_face) + finally: + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + +def patch_facexlib(dirname: str) -> None: + import facexlib.detection + import facexlib.parsing + + det_facex_load_file_from_url = facexlib.detection.load_file_from_url + par_facex_load_file_from_url = facexlib.parsing.load_file_from_url + + def update_kwargs(kwargs): + return dict(kwargs, save_dir=dirname, model_dir=None) + + def facex_load_file_from_url(**kwargs): + return 
det_facex_load_file_from_url(**update_kwargs(kwargs)) + + def facex_load_file_from_url2(**kwargs): + return par_facex_load_file_from_url(**update_kwargs(kwargs)) + + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6b6f17c435a..a356b56fe1a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,126 +1,68 @@ +from __future__ import annotations + +import logging import os -import modules.face_restoration -from modules import paths, shared, devices, modelloader, errors +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) -model_dir = "GFPGAN" -user_path = None -model_path = os.path.join(paths.models_path, model_dir) -model_file_path = None +logger = logging.getLogger(__name__) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -have_gfpgan = False -loaded_gfpgan_model = None - - -def gfpgann(): - global loaded_gfpgan_model - global model_path - global model_file_path - if loaded_gfpgan_model is not None: - loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) - return loaded_gfpgan_model - - if gfpgan_constructor is None: - return None - - models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) - - if len(models) == 1 and models[0].startswith("http"): - model_file = models[0] - elif len(models) != 0: - gfp_models = [] - for item in models: - if 'GFPGAN' in os.path.basename(item): - gfp_models.append(item) - latest_file = max(gfp_models, key=os.path.getctime) - model_file = latest_file - else: - print("Unable to load gfpgan model!") - return None - - import facexlib.detection.retinaface - - if hasattr(facexlib.detection.retinaface, 'device'): - facexlib.detection.retinaface.device = devices.device_gfpgan - model_file_path = model_file - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) - loaded_gfpgan_model = model - - return model - - -def send_model_to(model, device): - model.gfpgan.to(device) - model.face_helper.face_det.to(device) - model.face_helper.face_parse.to(device) +model_download_name = "GFPGANv1.4.pth" +gfpgan_face_restorer: face_restoration.FaceRestoration | None = None + + +class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): + def name(self): + return "GFPGAN" + + def get_device(self): + return devices.device_gfpgan + + def load_net(self) -> None: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, + ext_filter=['.pth'], + ): + if 'GFPGAN' in os.path.basename(model_path): + net = modelloader.load_spandrel_model( + model_path, + device=self.get_device(), + ).model + net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return net + raise ValueError("No GFPGAN model found") + + def restore(self, np_image): + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, return_rgb=False)[0] + + return self.restore_with_helper(np_image, restore_face) def gfpgan_fix_faces(np_image): - model = gfpgann() - if model is None: - return np_image - - send_model_to(model, devices.device_gfpgan) - - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = 
model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - model.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - send_model_to(model, devices.cpu) - + if gfpgan_face_restorer: + return gfpgan_face_restorer.restore(np_image) + logger.warning("GFPGAN face restorer not set up") return np_image -gfpgan_constructor = None +def setup_model(dirname: str) -> None: + global gfpgan_face_restorer - -def setup_model(dirname): try: - os.makedirs(model_path, exist_ok=True) - import gfpgan - import facexlib.detection - import facexlib.parsing - - global user_path - global have_gfpgan - global gfpgan_constructor - global model_file_path - - facexlib_path = model_path - - if dirname is not None: - facexlib_path = dirname - - load_file_from_url_orig = gfpgan.utils.load_file_from_url - facex_load_file_from_url_orig = facexlib.detection.load_file_from_url - facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url - - def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path)) - - def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - gfpgan.utils.load_file_from_url = my_load_file_from_url - facexlib.detection.load_file_from_url = facex_load_file_from_url - facexlib.parsing.load_file_from_url = facex_load_file_from_url2 - user_path = dirname - have_gfpgan = True - gfpgan_constructor = gfpgan.GFPGANer - - class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): - def name(self): - return "GFPGAN" - - def restore(self, np_image): - return gfpgan_fix_faces(np_image) - - shared.face_restorers.append(FaceRestorerGFPGAN()) + face_restoration_utils.patch_facexlib(dirname) + gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname) + shared.face_restorers.append(gfpgan_face_restorer) except Exception: errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/requirements.txt b/requirements.txt index 36f5674ad99..b1329c9e3ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ clean-fid einops facexlib fastapi>=0.90.1 -gfpgan gradio==3.41.2 inflection jsonmerge diff --git a/requirements_versions.txt b/requirements_versions.txt index 042fa708c49..edbb6db9e0e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -7,7 +7,6 @@ clean-fid==0.1.35 einops==0.4.1 facexlib==0.3.0 fastapi==0.94.0 -gfpgan==1.3.8 gradio==3.41.2 httpcore==0.15 inflection==0.5.1 diff --git a/test/conftest.py b/test/conftest.py index 31a5d9eafb8..e4fc567854d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,16 @@ +import base64 import os import pytest -import base64 - test_files_path = os.path.dirname(__file__) + "/test_files" +test_outputs_path = os.path.dirname(__file__) + "/test_outputs" + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") def file_to_base64(filename): @@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str: @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) + + +@pytest.fixture(scope="session") +def initialize() -> 
None:
+    import webui  # noqa: F401
diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py
new file mode 100644
index 00000000000..7760d51bfbe
--- /dev/null
+++ b/test/test_face_restorers.py
@@ -0,0 +1,29 @@
+import os
+from test.conftest import test_files_path, test_outputs_path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+@pytest.mark.usefixtures("initialize")
+@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
+def test_face_restorers(restorer_name):
+    from modules import shared
+
+    if restorer_name == "gfpgan":
+        from modules import gfpgan_model
+        gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
+        restorer = gfpgan_model.gfpgan_fix_faces
+    elif restorer_name == "codeformer":
+        from modules import codeformer_model
+        codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
+        restorer = codeformer_model.codeformer.restore
+    else:
+        raise NotImplementedError("...")
+    img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
+    np_img = np.array(img, dtype=np.uint8)
+    fixed_image = restorer(np_img)
+    assert fixed_image.shape == np_img.shape
+    assert not np.allclose(fixed_image, np_img)  # should have visibly changed
+    Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c9d1b01032a7298d76608c8b65cbb243463491c5
GIT binary patch
literal 14768
[base85-encoded JPEG data omitted]
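The new test exercises both face restorers end-to-end on the bundled two-faces.jpg fixture, asserting that the output keeps the input shape but visibly differs from it. A hedged example of running just these tests locally, assuming the webui environment is installed and the models can be downloaded (the session-scoped `initialize` fixture imports the server in-process rather than launching it, and conftest.py sets IGNORE_CMD_ARGS_ERRORS so pytest's own flags don't trip webui's argument parser):

    python -m pytest -vv test/test_face_restorers.py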
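Stepping back from the test: the refactor above reduces a face-restoration backend to three overrides on CommonFaceRestoration — get_device(), load_net() and restore() — while restore_with_helper() and restore_with_face_helper() take care of detection, cropping, normalization and paste-back. A minimal sketch of a hypothetical new backend under that contract; the class name and model path are placeholders, and the network is assumed to map a face tensor directly to a restored tensor:

import torch

from modules import devices, face_restoration_utils, modelloader


class FaceRestorerSketch(face_restoration_utils.CommonFaceRestoration):
    def name(self):
        return "Sketch"

    def get_device(self):
        return devices.device_gfpgan  # reuse an existing device setting

    def load_net(self) -> torch.nn.Module:
        # load_spandrel_model returns a descriptor; the torch module is .model
        return modelloader.load_spandrel_model(
            "/path/to/model.pth",  # placeholder path
            device=self.get_device(),
        ).model

    def restore(self, np_image):
        def restore_face(cropped_face_t):
            # assumption: the net maps the normalized face crop straight to
            # a restored tensor of the same shape
            assert self.net is not None
            return self.net(cropped_face_t)

        return self.restore_with_helper(np_image, restore_face)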
Date: Wed, 27 Dec 2023 10:55:01 +0200
Subject: [PATCH 148/449] Add experimental HAT model

---
 modules/hat_model.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 modules/hat_model.py

diff --git a/modules/hat_model.py b/modules/hat_model.py
new file mode 100644
index 00000000000..553e1941170
--- /dev/null
+++ b/modules/hat_model.py
@@ -0,0 +1,42 @@
+import os
+import sys
+
+from modules import modelloader, devices
+from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerHAT(Upscaler):
+    def __init__(self, dirname):
+        self.name = "HAT"
+        self.scalers = []
+        self.user_path = dirname
+        super().__init__()
+        for file in self.find_models(ext_filter=[".pt", ".pth"]):
+            name = modelloader.friendly_name(file)
+            scale = 4  # TODO: scale might not be 4, but we can't know without loading the model
+            scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
+            self.scalers.append(scaler_data)
+
+    def do_upscale(self, img, selected_model):
+        try:
+            model = self.load_model(selected_model)
+        except Exception as e:
+            print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
+            return img
+        model.to(devices.device_esrgan)  # TODO: should probably be device_hat
+        return upscale_with_model(
+            model,
+            img,
+            tile_size=opts.ESRGAN_tile,  # TODO: should probably be HAT_tile
+            tile_overlap=opts.ESRGAN_tile_overlap,  # TODO: should probably be HAT_tile_overlap
+        )
+
+    def load_model(self, path: str):
+        if not os.path.isfile(path):
+            raise FileNotFoundError(f"Model file {path} not found")
+        return modelloader.load_spandrel_model(
+            path,
+            device=devices.device_esrgan,  # TODO: should probably be device_hat
+        )

From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Sat, 30 Dec 2023 16:37:03 +0200
Subject: [PATCH 149/449] Verify architecture for loaded Spandrel models

---
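This commit threads an expected_architecture hint through every load_spandrel_model() call site, so that pointing, say, the ESRGAN upscaler at a SwinIR checkpoint fails at load time instead of crashing mid-inference. A sketch of the resulting call-site pattern — the path is a placeholder, and note that a mismatch raises TypeError at this commit but is softened to a logged warning by a later commit in this series:

from modules import devices, modelloader

model = modelloader.load_spandrel_model(
    "/path/to/4x_upscaler.pth",      # placeholder path
    device=devices.device_esrgan,
    expected_architecture="ESRGAN",  # raises TypeError on mismatch (at this commit)
)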
extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 1 + modules/codeformer_model.py | 1 + modules/esrgan_model.py | 1 + modules/gfpgan_model.py | 1 + modules/hat_model.py | 1 + modules/modelloader.py | 13 ++++++++++++- modules/realesrgan_model.py | 7 ++++--- 8 files changed, 22 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a0b6..5f3dd08b337 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ def load_model(self, path: str): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 85c18b9e939..aae159af56e 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -75,6 +75,7 @@ def load_model(self, path, scale=4): filename, device=self._get_device(), dtype=devices.dtype, + expected_architecture="SwinIR", ) if getattr(opts, 'SWIN_torch_compile', False): try: diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ceda4bab919..44b84618ec0 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -37,6 +37,7 @@ def load_net(self) -> torch.Module: return modelloader.load_spandrel_model( model_path, device=devices.device_codeformer, + expected_architecture='CodeFormer', ).model raise ValueError("No codeformer model found") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e3013..70041ab0234 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ def load_model(self, path: str): return modelloader.load_spandrel_model( filename, device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', ) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a356b56fe1a..48f8ad5e294 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,6 +37,7 @@ def load_net(self) -> None: net = modelloader.load_spandrel_model( model_path, device=self.get_device(), + expected_architecture='GFPGAN', ).model net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 return net diff --git a/modules/hat_model.py b/modules/hat_model.py index 553e1941170..7f2abb41660 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -39,4 +39,5 @@ def load_model(self, path: str): return modelloader.load_spandrel_model( path, device=devices.device_esrgan, # TODO: should probably be device_hat + expected_architecture='HAT', ) diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932a1d..f4182559e59 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def 
load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 332d8f4b108..2a2be5ad7fe 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,9 @@ import os -from modules.upscaler_utils import upscale_with_model -from modules.upscaler import Upscaler, UpscalerData -from modules.shared import cmd_opts, opts from modules import modelloader, errors +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model class UpscalerRealESRGAN(Upscaler): @@ -40,6 +40,7 @@ def do_upscale(self, img, path): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="RealESRGAN", ) return upscale_with_model( mod, From 05230c02606080527b65ace9eacb6fb835239877 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 18:02:51 +0300 Subject: [PATCH 150/449] fix img2img api that i broke when implementing infotext support --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/api/api.py b/modules/api/api.py index 2918f785750..2e18c6b9185 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -507,6 +507,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) From f476649c02cf3547d891fa08c50a92f92c4d73bd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:41:19 +0200 Subject: [PATCH 151/449] Correct arg type for restore_face --- modules/face_restoration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index c65c85ef89d..85cb3057050 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -36,7 +36,7 @@ def create_face_helper(device) -> FaceRestoreHelper: def restore_with_face_helper( np_image: np.ndarray, face_helper: FaceRestoreHelper, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: """ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. 
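# (Sketch, not part of the diff: the corrected annotation reflects what
# restore_with_face_helper actually passes to the callback — a normalized
# face crop as a batched float tensor — so a conforming callback looks like
# this hypothetical no-op restorer.)
import torch


def passthrough_restore_face(cropped_face_t: torch.Tensor) -> torch.Tensor:
    return cropped_face_t  # hand the face crop back unchanged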
@@ -126,7 +126,7 @@ def load_net(self) -> torch.Module: def restore_with_helper( self, np_image: np.ndarray, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: try: if self.net is None: From 91560e98c47f8271d444556ef4ae6505dece9aba Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sat, 30 Dec 2023 23:42:10 +0800 Subject: [PATCH 152/449] fix format issue --- modules/api/api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2f718ec262f..d202cb8d27d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -418,7 +418,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -489,14 +488,13 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) From c9174253fb603e6b2552e4c2721fd767b6ede87d Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:45:26 +0200 Subject: [PATCH 153/449] Drop dependency on basicsr --- modules/face_restoration_utils.py | 35 +++++++++++++++++++++++-------- requirements.txt | 1 - requirements_versions.txt | 1 - 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index 85cb3057050..1cbac236480 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -17,6 +17,28 @@ logger = logging.getLogger(__name__) +def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: + """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" + assert img.shape[2] == 3, "image must be RGB" + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return torch.from_numpy(img.transpose(2, 0, 1)).float() + + +def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: + """ + Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. + """ + tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + assert tensor.dim() == 3, "tensor must be RGB" + img_np = tensor.numpy().transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image, no RGB/BGR required + return np.squeeze(img_np, axis=2) + return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) + + def create_face_helper(device) -> FaceRestoreHelper: from facexlib.detection import retinaface from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -43,7 +65,6 @@ def restore_with_face_helper( `restore_face` should take a cropped face image and return a restored face image. 
""" - from basicsr.utils import img2tensor, tensor2img from torchvision.transforms.functional import normalize np_image = np_image[:, :, ::-1] original_resolution = np_image.shape[0:2] @@ -56,23 +77,19 @@ def restore_with_face_helper( face_helper.align_warp_face() logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) for cropped_face in face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) try: with torch.no_grad(): - restored_face = tensor2img( - restore_face(cropped_face_t), - rgb2bgr=True, - min_max=(-1, 1), - ) + cropped_face_t = restore_face(cropped_face_t) devices.torch_gc() except Exception: errors.report('Failed face-restoration inference', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - restored_face = restored_face.astype('uint8') + restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) + restored_face = (restored_face * 255.0).astype('uint8') face_helper.add_restored_face(restored_face) logger.debug("Merging restored faces into image") diff --git a/requirements.txt b/requirements.txt index b1329c9e3ab..731a1be7d9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ GitPython Pillow accelerate -basicsr blendmodes clean-fid einops diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e0e..1e0ccafa7d2 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,7 +1,6 @@ GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 -basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 From b58ed1b2432c3c7643b39e53f7bb567ea8655aae Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 18:02:01 +0200 Subject: [PATCH 154/449] Bump numpy to 1.26.2 This avoids it being downgraded during `launch.py` --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e0e..7ec7abe2ee2 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -13,7 +13,7 @@ inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 lark==1.1.2 -numpy==1.23.5 +numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 From f651405427dfc6d4ef96ecba7f9c2ceb580263fd Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sun, 31 Dec 2023 01:09:13 +0800 Subject: [PATCH 155/449] remove locks, move init code to __init__ --- modules/api/api.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index d202cb8d27d..fc3921c23e1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,8 +251,21 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - self.txt2img_script_arg_init_lock = Lock() - self.img2img_script_arg_init_lock = Lock() + txt2img_script_runner = scripts.scripts_txt2img + img2img_script_runner = scripts.scripts_img2img + + if not txt2img_script_runner.scripts or not img2img_script_runner.scripts: + ui.create_ui() + + if not txt2img_script_runner.scripts: + txt2img_script_runner.initialize_scripts(False) + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = 
self.init_default_script_args(txt2img_script_runner) + + if not img2img_script_runner.scripts: + img2img_script_runner.initialize_scripts(True) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner) @@ -418,16 +431,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -488,16 +495,10 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 1465dab71564bb30091479ceabae6c69e3426bc6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 19:44:05 +0200 Subject: [PATCH 156/449] Make Tensorboard a late import (it was implicitly installed by basicsr) --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 04dda585cc9..c6bcab1538d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,6 @@ import numpy as np from PIL import Image, PngImagePlugin -from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset @@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) def tensorboard_setup(log_directory): + from torch.utils.tensorboard import SummaryWriter os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) return SummaryWriter( log_dir=os.path.join(log_directory, "tensorboard"), @@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
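# (Sketch, not part of the diff: the import deferral into tensorboard_setup()
# above, plus the guarded call just below, means a missing tensorboard package
# degrades into a reported error instead of breaking training startup. The
# same guarded-late-import pattern as a hypothetical standalone helper:)
def optional_summary_writer(log_dir: str):
    try:
        from torch.utils.tensorboard import SummaryWriter  # deferred import
    except ImportError:
        return None  # callers treat a missing writer as "tensorboard disabled"
    return SummaryWriter(log_dir=log_dir)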
old_parallel_processing_allowed = shared.parallel_processing_allowed + tensorboard_writer = None if shared.opts.training_enable_tensorboard: - tensorboard_writer = tensorboard_setup(log_directory) + try: + tensorboard_writer = tensorboard_setup(log_directory) + except ImportError: + errors.report("Error initializing tensorboard", exc_info=True) pin_memory = shared.opts.pin_memory @@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" - if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + if tensorboard_writer and shared.opts.training_tensorboard_save_images: tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: From 48a2a1a437a48cc232725cc813242f98483b7697 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 19:44:38 +0200 Subject: [PATCH 157/449] Don't wait for 10 minutes for test server to come up --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index cd5c3f868ff..f42e4758e63 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -57,7 +57,7 @@ jobs: 2>&1 | tee output.txt & - name: Run tests run: | - wait-for-it --service 127.0.0.1:7860 -t 600 + wait-for-it --service 127.0.0.1:7860 -t 20 python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() From 5fbb13e0da8eb2e26bd2c45ec8ffbb2de669ef47 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 20:46:44 +0200 Subject: [PATCH 158/449] Remove `cleanup_models` code --- modules/initialize.py | 3 --- modules/modelloader.py | 50 ------------------------------------------ 2 files changed, 53 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index ac95fc6f00c..4a3cd98cf86 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -54,9 +54,6 @@ def initialize(): initialize_util.configure_sigint_handler() initialize_util.configure_opts_onchange() - from modules import modelloader - modelloader.cleanup_models() - from modules import sd_models sd_models.setup_model() startup_timer.record("setup SD model") diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559e59..5f7aec3e42a 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -2,7 +2,6 @@ import logging import os -import shutil import importlib from urllib.parse import urlparse @@ -10,7 +9,6 @@ from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone -from modules.paths import script_path, models_path logger = logging.getLogger(__name__) @@ -96,54 +94,6 @@ def friendly_name(file: str): return model_name -def cleanup_models(): - # This code could probably be more efficient if we used a tuple list or something to store the src/destinations - # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler - # somehow auto-register and just do these things... 
- root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") - move_files(src_path, dest_path, ".ckpt") - move_files(src_path, dest_path, ".safetensors") - src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path, ".pth") - src_path = os.path.join(root_path, "gfpgan") - dest_path = os.path.join(models_path, "GFPGAN") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") - move_files(src_path, dest_path) - - -def move_files(src_path: str, dest_path: str, ext_filter: str = None): - try: - os.makedirs(dest_path, exist_ok=True) - if os.path.exists(src_path): - for file in os.listdir(src_path): - fullpath = os.path.join(src_path, file) - if os.path.isfile(fullpath): - if ext_filter is not None: - if ext_filter not in file: - continue - print(f"Moving {file} from {src_path} to {dest_path}.") - try: - shutil.move(fullpath, dest_path) - except Exception: - pass - if len(os.listdir(src_path)) == 0: - print(f"Removing empty folder: {src_path}") - shutil.rmtree(src_path, True) - except Exception: - pass - - def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ From af050dcaa75ef40b6b1c3da3361f32fe52786aeb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:05:59 +0200 Subject: [PATCH 159/449] Soften Spandrel model-architecture check to just a warning --- modules/modelloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559e59..6b7d697f103 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -196,7 +196,9 @@ def load_spandrel_model( import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) if expected_architecture and model.architecture != expected_architecture: - raise TypeError(f"Model {path} is not a {expected_architecture} model") + logger.warning( + f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + ) if half: model = model.model.half() if dtype: From 393a5b82ba6df06d85f2bf7bbbe0456d3d06115f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:12:32 +0200 Subject: [PATCH 160/449] Correct RealESRGAN expected architecture type to ESRGAN --- modules/realesrgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2a2be5ad7fe..65f2e880668 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -40,7 +40,7 @@ def do_upscale(self, img, path): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), - expected_architecture="RealESRGAN", + expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( mod, From 8100e901ab0c5b04d289eebb722c8a653b8beef1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 22:41:53 +0300 Subject: [PATCH 
161/449] fix error with RealESRGAN model failing to upscale fp32 image --- modules/upscaler_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8bdda51c4c2..39f78a0b7bc 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + + model_weight = next(iter(model.parameters())) + img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + with torch.no_grad(): output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = 255. * np.moveaxis(output, 0, 2) output = output.astype(np.uint8) From bc5ae74c7d8949bab37e260b16e76889b9968099 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 21:52:27 +0100 Subject: [PATCH 162/449] Added negative prompts to extra networks lora --- .../Lora/ui_edit_user_metadata.py | 14 +++++++-- .../Lora/ui_extra_networks_lora.py | 9 ++++++ javascript/extraNetworks.js | 31 +++++++++++++------ modules/ui_extra_networks.py | 5 ++- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7011909055..f7859b21f07 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,14 @@ def __init__(self, ui, tabname, page): self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +129,8 @@ def put_values_into_components(self, name): gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +166,8 @@ def create_editor(self): self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = 
gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +203,8 @@ def select_tag(activation_text, evt: gr.SelectData): self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -211,7 +218,10 @@ def select_tag(activation_text, evt: gr.SelectData): self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b12..09ce2a0578e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -45,6 +45,15 @@ def create_item(self, name, index=None, enable_filter=True): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + preferred_negative_weight = item["user_metadata"].get("negative weight") + item["negative_prompt"] = quote_js("") + if negative_prompt: + neg_prompt = negative_prompt + if (preferred_negative_weight > 0): + neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb745c..2bb9795dca4 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -185,8 +185,10 @@ onUiLoaded(setupExtraNetworks); var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - var m = text.match(re_extranet); +var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/; +var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g; +function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { + var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; if (m) { @@ -194,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { - m = found.match(re_extranet); + newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) { + m = found.match(isNeg ? 
re_extranet_neg : re_extranet); if (m[1] == partToSearch) { replaced = true; foundAtPosition = pos; @@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { }); if (foundAtPosition >= 0) { - if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); } if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { @@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return false; } -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); +function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) { - textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { + textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } - updateInput(textarea); + updateInput(textArea); +} + +function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { + if (textToAddNegative.length > 0) { + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + } else { + var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); + updatePromptArea(textToAdd, textarea) + } } function saveCardPreview(event, tabname, filename) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba3338..b8c02241365 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -223,7 +223,10 @@ def create_html_for_item(self, item, tabname): onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + if "negative_prompt" in item: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + else: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' From a2f23f9d22dde87bf2529dcb2854a6a5d3d44278 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:16:51 +0100 Subject: [PATCH 163/449] Code Style fixes --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- javascript/extraNetworks.js | 6 +++--- modules/upscaler_utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 09ce2a0578e..9a6624e3f1e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -52,8 +52,8 @@ def create_item(self, name, index=None, enable_filter=True): neg_prompt = negative_prompt if (preferred_negative_weight > 0): neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) - + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2bb9795dca4..f1ad19a66b7 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -243,11 +243,11 @@ function updatePromptArea(text, textArea, isNeg) { function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { if (textToAddNegative.length > 0) { - updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) - updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")); + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true); } else { var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); - updatePromptArea(textToAdd, textarea) + updatePromptArea(textToAdd, textarea); } } diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7bc..1d610dbffd9 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 3be90740316f8fbb950b31d440458a5e8ed4beb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 00:43:41 +0300 Subject: [PATCH 164/449] fix for the previous fix. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7bc..dde5d7ad43a 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): From c0ca6348e8489651df861a101142805c213c66a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:04:47 +0200 Subject: [PATCH 165/449] load_spandrel_model: always return a model descriptor --- modules/modelloader.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682c55..8bcee08c155 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -142,17 +145,17 @@ def load_spandrel_model( half: bool = False, dtype: str | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: [PATCH 166/449] Be more clear about 
Spandrel model nomenclature --- extensions-builtin/SwinIR/scripts/swinir_model.py | 6 +++--- modules/gfpgan_model.py | 10 ++++++---- modules/modelloader.py | 2 +- modules/realesrgan_model.py | 4 ++-- modules/upscaler_utils.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index aae159af56e..95c7ec648ef 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -71,7 +71,7 @@ def load_model(self, path, scale=4): else: filename = path - model = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), dtype=devices.dtype, @@ -79,10 +79,10 @@ def load_model(self, path, scale=4): ) if getattr(opts, 'SWIN_torch_compile', False): try: - model = torch.compile(model) + model_descriptor.model.compile() except Exception: logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) - return model + return model_descriptor def _get_device(self): return devices.get_device_for('swinir') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e294..445b040925e 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ import logging import os +import torch + from modules import ( devices, errors, @@ -25,7 +27,7 @@ def name(self): def get_device(self): return devices.device_gfpgan - def load_net(self) -> None: + def load_net(self) -> torch.nn.Module: for model_path in modelloader.load_models( model_path=self.model_path, model_url=model_url, @@ -34,13 +36,13 @@ def load_net(self) -> None: ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - net = modelloader.load_spandrel_model( + model = modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return net + model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c155..a7194137571 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e880668..4d35b695c3b 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ def do_upscale(self, img, path): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - mod = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( - mod, + model_descriptor, img, tile_size=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad43a..174c9bc3713 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL
import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: [PATCH 167/449] Deduplicate tiled inference code from SwinIR/ScuNET --- .../ScuNET/scripts/scunet_model.py | 55 +++----------- .../SwinIR/scripts/swinir_model.py | 57 ++------------- modules/upscaler_utils.py | 72 ++++++++++++++++++- 3 files changed, 87 insertions(+), 97 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b337..f799cb76dec 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ def __init__(self, dirname): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ def do_upscale(self, img: PIL.Image.Image, selected_file): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 95c7ec648ef..8a555c794f7 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -4,11 +4,11 @@ import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared 
import opts, state +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -110,14 +110,14 @@ def upscale( w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference( + output = tiled_upscale_2( img, model, - tile=tile, + tile_size=tile, tile_overlap=tile_overlap, - window_size=window_size, scale=scale, device=device, + desc="SwinIR tiles", ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() @@ -129,53 +129,6 @@ def upscale( return Image.fromarray(output, "RGB") -def inference( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size: int, - scale: int, - device, -): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc3713..8e413854303 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. 
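+ # Each output pixel may be covered by several overlapping tiles: tile outputs + # are accumulated into `result` while a parallel `weights` tensor counts the + # number of tiles covering each pixel, and the final division averages the + # overlap regions so tile seams blend smoothly.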
+ b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: [PATCH 168/449] Add utility to inspect a model's parameters (to get dtype/device) --- modules/devices.py | 3 ++- modules/interrogate.py | 3 ++- modules/sd_models_xl.py | 3 ++- modules/torch_utils.py | 17 +++++++++++++++++ modules/upscaler_utils.py | 5 +++-- modules/xlmr.py | 5 ++++- modules/xlmr_m18.py | 5 ++++- test/test_torch_utils.py | 19 +++++++++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 modules/torch_utils.py create mode 100644 test/test_torch_utils.py diff --git a/modules/devices.py b/modules/devices.py index c956207f334..bd6bd579b1a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,6 +4,7 @@ import torch from modules import errors, shared +from modules.torch_utils import get_param if sys.platform == "darwin": from modules import mac_specific @@ -131,7 +132,7 @@ def cond_cast_float(input): def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype + org_dtype = get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d0ae..5be5a10f39c 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,6 +11,7 @@ from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.torch_utils import get_param blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +132,7 @@ def load(self): self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1de31b0da19..c3602a7e1a6 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import sgm.modules.diffusionmodules.denoiser_scaling 
import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules.torch_utils import get_param def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -90,7 +91,7 @@ def get_target_prompt_token_count(self, token_count): def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 00000000000..e5b52393ec8 --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854303..c60e3beb89d 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3cade8..6e000a56e75 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e865529..e3e8196108c 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -5,6 +5,9 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, 
intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +71,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 00000000000..f1aec832945 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules.torch_utils import get_param + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu From d4945f4422e5a0bf31a6dbe4c1aeedd78c09eacb Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:22:30 +0100 Subject: [PATCH 169/449] Removed weight slider for negative prompts --- extensions-builtin/Lora/ui_edit_user_metadata.py | 7 +------ extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index f7859b21f07..3160aecfa38 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,14 +54,13 @@ def __init__(self, ui, tabname, page): self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["negative text"] = negative_text - user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -130,7 +129,6 @@ def put_values_into_components(self, name): user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), user_metadata.get('negative text', ''), - float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -167,7 +165,6 @@ def create_editor(self): self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to 
disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -204,7 +201,6 @@ def select_tag(activation_text, evt: gr.SelectData): self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -219,7 +215,6 @@ def select_tag(activation_text, evt: gr.SelectData): self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, self.edit_notes, ] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9a6624e3f1e..e714fac4692 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -46,13 +46,9 @@ def create_item(self, name, index=None, enable_filter=True): item["prompt"] += " + " + quote_js(" " + activation_text) negative_prompt = item["user_metadata"].get("negative text") - preferred_negative_weight = item["user_metadata"].get("negative weight") item["negative_prompt"] = quote_js("") if negative_prompt: - neg_prompt = negative_prompt - if (preferred_negative_weight > 0): - neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: From b6f74e936e4de3b8d190bffaf3bed67d6d4bd211 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:36:36 +0100 Subject: [PATCH 170/449] Revert change from linting for unrelated file --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 1d610dbffd9..39f78a0b7bc 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import images +from modules import devices, images logger = logging.getLogger(__name__) From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: [PATCH 171/449] change import statements for #14478 --- modules/devices.py | 4 ++-- modules/interrogate.py | 5 ++--- modules/sd_models_xl.py | 4 ++-- modules/upscaler_utils.py | 5 ++--- modules/xlmr.py | 4 ++-- modules/xlmr_m18.py | 5 ++--- test/test_torch_utils.py | 4 ++-- 7 files changed, 14 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index bd6bd579b1a..ff279ac5016 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,7 +4,7 @@ import torch from modules import errors, shared -from modules.torch_utils import get_param +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific @@ -132,7 +132,7 @@ def cond_cast_float(input): def manual_cast_forward(self, *args, **kwargs): - org_dtype = get_param(self).dtype + org_dtype = torch_utils.get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 5be5a10f39c..35a627ca196 100644 --- 
a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,8 +10,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.torch_utils import get_param +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -132,7 +131,7 @@ def load(self): self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = get_param(self.clip_model).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index c3602a7e1a6..0de17af3db2 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,7 +6,7 @@ import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser -from modules.torch_utils import get_param +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -91,7 +91,7 @@ def get_target_prompt_token_count(self, token_count): def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = get_param(model.model.diffusion_model).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb89d..f5cb92d51b8 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56e75..319771b7bf0 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index e3e8196108c..f60555049f5 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,8 +4,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional - -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -71,7 +70,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = 
get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py index f1aec832945..23ccb93a464 100644 --- a/test/test_torch_utils.py +++ b/test/test_torch_utils.py @@ -3,7 +3,7 @@ import pytest import torch -from modules.torch_utils import get_param +from modules import torch_utils @pytest.mark.parametrize("wrapped", [True, False]) @@ -14,6 +14,6 @@ def test_get_param(wrapped): if wrapped: # more or less how spandrel wraps a thing mod = types.SimpleNamespace(model=mod) - p = get_param(mod) + p = torch_utils.get_param(mod) assert p.dtype == torch.float16 assert p.device == cpu From 00901bfbe0095303554f4440b4c12fac262e2e89 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:47:57 +0900 Subject: [PATCH 172/449] handle selectable script_index is None --- modules/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index 3a76691186b..017aed5a01e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -696,6 +696,8 @@ def setup_ui(self): self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): + if script_index is None: + script_index = 0 selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] @@ -739,7 +741,7 @@ def onload_script_visibility(params): def run(self, p, *args): script_index = args[0] - if script_index == 0: + if script_index == 0 or script_index is None: return None script = self.selectable_scripts[script_index-1] From 5692bf1517c3409ad46262c56e65f256389825b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 11:11:14 +0300 Subject: [PATCH 173/449] add missing field for DDIM sampler that was breaking img2img --- modules/sd_samplers_timesteps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c2b..f8afa8bd7e5 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -80,6 +80,7 @@ def __init__(self, funcname, sd_model): self.eta_default = 0.0 self.model_wrap_cfg = CFGDenoiserTimesteps(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) From 003b91f08361c99ecdd97257624d81a2046d3823 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:45:01 +0300 Subject: [PATCH 174/449] rename generation_parameters_copypaste module to infotext --- modules/{generation_parameters_copypaste.py => infotext.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{generation_parameters_copypaste.py => infotext.py} (100%) diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py similarity index 100% rename from modules/generation_parameters_copypaste.py rename to modules/infotext.py From c5496c76461c90bd186ae8804aa65a33cd136d48 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:52:37 +0300 Subject: [PATCH 175/449] infotext.py: add support for old modules.generation_parameters_copypaste name --- modules/infotext.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/infotext.py 
b/modules/infotext.py index 86a36c3273c..bcbeb0fdcd1 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -4,12 +4,15 @@ import json import os import re +import sys import gradio as gr from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image +sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name + re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") From d859cec696a953dbfd6f69f7735e68661748d579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:53:12 +0300 Subject: [PATCH 176/449] infotext.py: rename usages in the codebase --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 14 files changed, 25 insertions(+), 25 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index ac2c3de4643..8aa901fd481 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste +from modules import scripts, shared, ui_components, ui_settings, infotext from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ def ui(self, is_img2img): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 843c59b0400..0e2807de2b6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_scri if not request.infotext: return {} - possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] + possible_fields = 
infotext.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this - params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + params = infotext.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ def get_field_value(field, params): if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params) + overriden_settings = infotext.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: geninfo = "" - params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + params = infotext.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index c583290a0ea..75b3d346f9a 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ import gradio as gr from modules import images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules.infotext import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 0c59fad480b..f776f7b69b9 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext from modules.shared import opts @@ -86,7 +86,7 @@ def get_images(extras_mode, image, image_folder, input_dir): basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index 7789f9a4239..b30df60db3a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -733,7 +733,7 @@
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index cefad32b761..e9941413f24 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index a3e16a12e9b..602932785fe 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import gradio as gr from modules import scripts, ui, errors -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index 991971ad0fb..e1392472099 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import generation_parameters_copypaste, shared + from modules import infotext, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in generation_parameters_copypaste.paste_fields.values(): + for tab_data in infotext.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index e4e18ceb6dd..3a481915f3a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ import modules.scripts from modules import processing -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.infotext import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 9db2407ed30..6451e14c14a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.shared import opts, cmd_opts -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text, PasteField +from modules.infotext import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff 
--git a/modules/ui_common.py b/modules/ui_common.py index 032ec4af762..fd32676f956 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import subprocess as sp from modules import call_queue, shared -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b8c02241365..790af135665 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import html from fastapi.exceptions import HTTPException -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 36a807fcdf9..87aeb6f3359 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import gradio as gr -from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks +from modules import infotext, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ def save_preview(self, index, gallery, name): index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = generation_parameters_copypaste.image_from_url_text(img_info) + image = infotext.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 13d888e48d1..b74a153231c 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste def create_ui(): From d613cd17c72c753bd1e314dff74dc22d9a949374 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:38:29 +0300 Subject: [PATCH 177/449] add automatic backwards version compatibility --- modules/infotext.py | 4 +++- modules/infotext_versions.py | 35 +++++++++++++++++++++++++++++++++++ modules/shared_options.py | 1 + 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 modules/infotext_versions.py diff --git a/modules/infotext.py b/modules/infotext.py index bcbeb0fdcd1..7f30446b9b4 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name @@ -342,6 +342,8 @@ def parse_generation_parameters(x: str): if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": res["Cache FP16 weight for LoRA"] = False + infotext_versions.backcompat(res) + 
skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py new file mode 100644 index 00000000000..01e885a26ef --- /dev/null +++ b/modules/infotext_versions.py @@ -0,0 +1,35 @@ +from modules import shared +from packaging import version +import re + + +v160 = version.parse("1.6.0") + + +def parse_version(text): + if text is None: + return None + + m = re.match(r'([^-]+-[^-]+)-.*', text) + if m: + text = m.group(1) + + try: + return version.parse(text) + except Exception as e: + return None + + +def backcompat(d): + """Checks infotext Version field, and enables backwards compatibility options according to it.""" + + if not shared.opts.auto_backcompat: + return + + ver = parse_version(d.get("Version")) + if ver is None: + return + + if ver < v160: + d["Old prompt editing timelines"] = True + diff --git a/modules/shared_options.py b/modules/shared_options.py index 752a4f1259a..281591da810 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -212,6 +212,7 @@ })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { + "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), From 45b7bba3d06f2d4bc2fffc210cbfcb357b86add6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:51:56 +0300 Subject: [PATCH 178/449] add automatic version support for zero terminal SNR noise schedule option from #14145 --- modules/infotext_versions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 01e885a26ef..9a204d84292 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -4,6 +4,7 @@ v160 = version.parse("1.6.0") +v170_tsnr = version.parse("v1.7.0-225") def parse_version(text): @@ -33,3 +34,6 @@ def backcompat(d): if ver < v160: d["Old prompt editing timelines"] = True + if ver < v170_tsnr: + d["Downcast alphas_cumprod"] = True + From d8126be578c7d4579c0f2ee4adbe35500bc71ce6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:00:39 +0300 Subject: [PATCH 179/449] linter --- modules/infotext.py | 2 +- modules/infotext_versions.py | 2 +- modules/postprocessing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/infotext.py b/modules/infotext.py index 7f30446b9b4..26e9b94934f 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 9a204d84292..a5afeebf136 100644 --- a/modules/infotext_versions.py +++ 
b/modules/infotext_versions.py @@ -17,7 +17,7 @@ def parse_version(text): try: return version.parse(text) - except Exception as e: + except Exception: return None diff --git a/modules/postprocessing.py b/modules/postprocessing.py index f776f7b69b9..facea899f37 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common from modules.shared import opts From 0743ee9b3eda8dd4ceea625d710031577201f4ad Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:50:47 +0300 Subject: [PATCH 180/449] re-layout checkboxes for XYZ grid a bit --- scripts/xyz_grid.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 34267c2c3bc..2d550994709 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -438,17 +438,16 @@ def ui(self, is_img2img): with gr.Column(): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Row(): + vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.") + vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.") + vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.") with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Column(): - vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) - vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) - vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Column(): - csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -531,7 +530,7 @@ def get_dropdown_update_from_params(axis, params): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + 
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode): x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 if not no_fixed_seeds: From ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:28:58 +0300 Subject: [PATCH 181/449] option to convert VAE to bfloat16 (implementation of #9295) --- modules/processing.py | 23 ++++++++++++++++++----- modules/shared_options.py | 1 + 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 846e4796af2..f065688218a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -628,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): sample = decode_first_stage(model, batch[i:i + 1])[0] if check_for_nans: + try: devices.test_for_nans(sample, "vae") except devices.NansException as e: - if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + if shared.opts.auto_vae_precision_bfloat16: + autofix_dtype = torch.bfloat16 + autofix_dtype_text = "bfloat16" + autofix_dtype_setting = "Automatically convert VAE to bfloat16" + autofix_dtype_comment = "" + elif shared.opts.auto_vae_precision: + autofix_dtype = torch.float32 + autofix_dtype_text = "32-bit float" + autofix_dtype_setting = "Automatically revert VAE to 32-bit floats" + autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag." + else: + raise e + + if devices.dtype_vae == autofix_dtype: raise e errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" - "Web UI will now convert VAE into 32-bit float and retry.\n" - "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" - "To always start with 32-bit VAE, use --no-half-vae commandline flag." 
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n" + f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}" ) - devices.dtype_vae = torch.float32 + devices.dtype_vae = autofix_dtype model.first_stage_model.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae) diff --git a/modules/shared_options.py b/modules/shared_options.py index ce06f022eaf..e813546f6a5 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -177,6 +177,7 @@ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), From 0aa7c53c0b9469849377aff83f43c9f75c19b3fa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:50:59 +0300 Subject: [PATCH 182/449] fix borked merge, rename fields to better match what they do, change setting default to true for #13653 --- modules/call_queue.py | 2 +- modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 2 +- modules/shared_state.py | 12 ++++++------ modules/ui_toprow.py | 8 +++++++- scripts/loopback.py | 4 ++-- scripts/xyz_grid.py | 2 +- 8 files changed, 20 insertions(+), 14 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index 01c6d17f685..bcd7c546200 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,7 +78,7 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): shared.state.skipped = False shared.state.interrupted = False - shared.state.interrupted_next = False + shared.state.stopping_generation = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 829faa818ff..e7e8e2510ca 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break try: diff --git a/modules/processing.py b/modules/processing.py index 00de2ed2e9a..f55b85ed350 100644 --- 
a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 7852e0ea362..7581e276e25 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -120,7 +120,6 @@ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), - "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API", "system"), { @@ -286,6 +285,7 @@ "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), + "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"), })) options_templates.update(options_section(('ui', "User interface", "ui"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index 532fdcd8d12..33996691c40 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,7 +12,7 @@ class State: skipped = False interrupted = False - interrupted_next = False + stopping_generation = False job = "" job_no = 0 job_count = 0 @@ -80,9 +80,9 @@ def interrupt(self): self.interrupted = True log.info("Received interrupt request") - def interrupt_next(self): - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") + def stop_generating(self): + self.stopping_generation = True + log.info("Received stop generating request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -96,7 +96,7 @@ def dict(self): obj = { "skipped": self.skipped, "interrupted": self.interrupted, - "interrupted_next": self.interrupted_next, + "stopping_generation": self.stopping_generation, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -120,7 +120,7 @@ def begin(self, job: str = "(unknown)"): self.id_live_preview = 0 self.skipped = False self.interrupted = False - self.interrupted_next = False + self.stopping_generation = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 9caf8faa2de..1abc9117ba3 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -106,8 +106,14 @@ def create_submit_box(self): outputs=[], ) + def interrupt_function(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.stop_generating() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + 
fn=interrupt_function, inputs=[], outputs=[], ) diff --git a/scripts/loopback.py b/scripts/loopback.py index ad921269afb..800ee882a16 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ def calculate_denoising_strength(loop): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if initial_seed is None: @@ -122,7 +122,7 @@ def calculate_denoising_strength(loop): p.inpainting_fill = original_inpainting_fill - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if len(history) > 1: diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 2deff365fa1..2f385ebf295 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -696,7 +696,7 @@ def fix_axis_seeds(axis_opt, axis_list): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted or state.interrupted_next: + if shared.state.interrupted or state.stopping_generation: return Processed(p, [], p.seed, "") pc = copy(p) From 1ffdedc11d49862cf0d030fb0bcc25eb0449939b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:03:08 +0300 Subject: [PATCH 183/449] restore lines lost from #13789 merge --- modules/cmd_args.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 3775e670261..e58059a1fea 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -93,7 +93,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -115,8 +115,9 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') 
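Note the pattern in the two help-text changes above: a retired flag stays parseable so existing launch scripts do not break, but its value is simply never read. A minimal argparse illustration (hypothetical flag name, not from this patch):

    import argparse

    parser = argparse.ArgumentParser()
    # Retired option: still accepted for compatibility, deliberately unused.
    parser.add_argument("--legacy-flag", action="store_true", help="does not do anything")
    args = parser.parse_args(["--legacy-flag"])
    assert args.legacy_flag  # parses fine; the program just never consults it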
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) From 5d7d1823afab0a051a3fbbdb3213bae8051350b7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:25:30 +0300 Subject: [PATCH 184/449] rename infotext.py again, this time to infotext_utils.py; I didn't realize infotext would be used for variable names in multiple places, which makes it awkward to import the module; also fix the bug I caused by this rename that breaks tests --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/{infotext.py => infotext_utils.py} | 0 modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 15 files changed, 25 insertions(+), 25 deletions(-) rename modules/{infotext.py => infotext_utils.py} (100%) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 8aa901fd481..4c10d9c7d6a 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, infotext +from modules import scripts, shared, ui_components, ui_settings, infotext_utils from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ def ui(self, is_img2img): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 0e2807de2b6..9d1292e9571 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, 
script_callbacks, infotext_utils, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_scri if not request.infotext: return {} - possible_fields = infotext.paste_fields[tabname]["fields"] + possible_fields = infotext_utils.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = infotext.parse_generation_parameters(request.infotext) + params = infotext_utils.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ def get_field_value(field, params): if request.override_settings is None: request.override_settings = {} - overriden_settings = infotext.get_override_settings(params) + overriden_settings = infotext_utils.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: geninfo = "" - params = infotext.parse_generation_parameters(geninfo) + params = infotext_utils.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index e7e8e2510ca..04de8e62cf2 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ import gradio as gr from modules import images as imgutil -from modules.infotext import create_override_settings_dict, parse_generation_parameters +from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/infotext.py b/modules/infotext_utils.py similarity index 100% rename from modules/infotext.py rename to modules/infotext_utils.py diff --git a/modules/postprocessing.py b/modules/postprocessing.py index facea899f37..7850328f63f 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils from modules.shared import opts @@ -86,7 +86,7 @@ def get_images(extras_mode, image, image_folder, input_dir): basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index f55b85ed350..213a2879cc6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, 
extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -746,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index e9941413f24..ba33d8a4b80 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 602932785fe..2d3cbb97f06 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import gradio as gr from modules import scripts, ui, errors -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index e1392472099..13fb2814f2b 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import infotext, shared + from modules import infotext_utils, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in infotext.paste_fields.values(): + for tab_data in infotext_utils.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index 3a481915f3a..49660e89167 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ import modules.scripts from modules import processing -from modules.infotext import create_override_settings_dict +from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 378529c79a4..52b15646ab5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.shared import opts, cmd_opts -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion 
as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.infotext import image_from_url_text, PasteField +from modules.infotext_utils import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index fd32676f956..f48ad426072 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import subprocess as sp from modules import call_queue, shared -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 790af135665..beea1316083 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import html from fastapi.exceptions import HTTPException -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 87aeb6f3359..989a649b799 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import gradio as gr -from modules import infotext, images, sysinfo, errors, ui_extra_networks +from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ def save_preview(self, index, gallery, name): index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = infotext.image_from_url_text(img_info) + image = infotext_utils.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index b74a153231c..1edb68c5ced 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste def create_ui(): From 501993ebf210bf3b55173ec1910f0c84c7e75424 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 19:31:06 +0300 Subject: [PATCH 185/449] added a button to run hires fix on selected image in the gallery --- javascript/ui.js | 8 +++ modules/processing.py | 46 ++++++++++++--- modules/txt2img.py | 19 +++++- modules/ui.py | 108 +++++++++++++++++++---------------- modules/ui_common.py | 57 +++++++++++------- modules/ui_postprocessing.py | 8 +-- 6 files changed, 160 insertions(+), 86 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 18c9f891afc..78d1caee5bf 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -150,6 +150,14 @@ function submit() { return res; } +function submit_txt2img_upscale() { + res = submit.apply(null, arguments); + + res[2] = selected_gallery_index(); + + return res; +} + function submit_img2img() { showSubmitButtons('img2img', false); diff --git 
a/modules/processing.py b/modules/processing.py index 213a2879cc6..045c7d7955e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -179,6 +179,7 @@ class StableDiffusionProcessing: token_merging_ratio = 0 token_merging_ratio_hr = 0 disable_extra_networks: bool = False + firstpass_image: Image = None scripts_value: scripts.ScriptRunner = field(default=None, init=False) script_args_value: list = field(default=None, init=False) @@ -1238,18 +1239,45 @@ def init(self, all_prompts, all_seeds, all_subseeds): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) - del x + if self.firstpass_image is not None and self.enable_hr: + # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix - if not self.enable_hr: - return samples - devices.torch_gc() + if self.latent_scale_mode is None: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0 + image = np.moveaxis(image, 2, 0) + + samples = None + decoded_samples = torch.asarray(np.expand_dims(image, 0)) + + else: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + image = torch.from_numpy(np.expand_dims(image, axis=0)) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + decoded_samples = None + devices.torch_gc() - if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: - decoded_samples = None + # here we generate an image normally + + x = self.rng.next() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x + + if not self.enable_hr: + return samples + + devices.torch_gc() + + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) diff --git a/modules/txt2img.py b/modules/txt2img.py index 49660e89167..4a6fe72a65d 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import processing +from modules import processing, infotext_utils from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared @@ -9,9 +9,23 @@ import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: 
gr.Request, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): + assert len(gallery) > 0, 'No image to upscale' + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + image = infotext_utils.image_from_url_text(image_info) + + return txt2img(id_task, request, *args, firstpass_image=image) + + +def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): override_settings = create_override_settings_dict(override_settings_texts) + if firstpass_image is not None: + enable_hr = True + batch_size = 1 + n_iter = 1 + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, + firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 52b15646ab5..3d548430d36 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -375,50 +375,60 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + + txt2img_inputs = [ + dummy_component, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, + steps, + sampler_name, + batch_count, + batch_size, + cfg_scale, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + hr_checkpoint_name, + hr_sampler_name, + hr_prompt, + hr_negative_prompt, + override_settings, + ] + custom_inputs + + txt2img_outputs = [ + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, + ] txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", - inputs=[ - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - steps, - sampler_name, - batch_count, - batch_size, - cfg_scale, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - hr_checkpoint_name, - hr_sampler_name, - hr_prompt, - hr_negative_prompt, - override_settings, - - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], + inputs=txt2img_inputs, + outputs=txt2img_outputs, show_progress=False, ) toprow.prompt.submit(**txt2img_args) toprow.submit.click(**txt2img_args) + output_panel.button_upscale.click( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), + _js="submit_txt2img_upscale", + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + outputs=txt2img_outputs, + show_progress=False, + ) + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", 
inputs=None, outputs=None, show_progress=False) toprow.restore_progress_button.click( @@ -426,10 +436,10 @@ def create_ui(): _js="restoreProgressTxt2img", inputs=[dummy_component], outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -479,7 +489,7 @@ def create_ui(): toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery) extra_tabs.__exit__() @@ -710,7 +720,7 @@ def select_img2img_tab(tab): outputs=[inpaint_controls, mask_alpha], ) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) + output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -755,10 +765,10 @@ def select_img2img_tab(tab): img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -796,10 +806,10 @@ def select_img2img_tab(tab): _js="restoreProgressImg2img", inputs=[dummy_component], outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -839,7 +849,7 @@ def select_img2img_tab(tab): )) extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery) extra_tabs.__exit__() diff --git a/modules/ui_common.py b/modules/ui_common.py index f48ad426072..ff84197c1de 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import dataclasses import json import html import os @@ -104,7 +105,17 @@ def __init__(self, d=None): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") +@dataclasses.dataclass +class OutputPanel: + gallery = None + infotext = None + html_info = None + html_log = None + button_upscale = None + + def create_output_panel(tabname, outdir, toprow=None): + res = OutputPanel() def open_folder(f): if not os.path.exists(f): @@ -136,9 +147,8 @@ def open_folder(f): with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) - generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, 
tooltip="Open images output directory.") @@ -152,6 +162,9 @@ def open_folder(f): 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") } + if tabname == 'txt2img': + res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.") + open_folder_button.click( fn=lambda: open_folder(shared.opts.outdir_samples or outdir), inputs=[], @@ -162,17 +175,17 @@ def open_folder(f): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], + inputs=[res.infotext, res.html_info, res.html_info], + outputs=[res.html_info, res.html_info], show_progress=False, ) @@ -180,14 +193,14 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ], show_progress=False, ) @@ -196,21 +209,21 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ] ) else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] if tabname == "txt2img": @@ -220,11 +233,11 @@ def open_folder(f): for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery, paste_field_names=paste_field_names )) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return res def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): diff --git 
a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 1edb68c5ced..8f09e6588a3 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -28,7 +28,7 @@ def create_ui(): toprow.create_inline_toprow_image() submit = toprow.submit - result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) @@ -48,9 +48,9 @@ def create_ui(): *script_inputs ], outputs=[ - result_images, - html_info_x, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_log, ], show_progress=False, ) From c32c51a0fc49ca4fbfe02bfbae45ec49e7bf9876 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 1 Jan 2024 19:20:54 +0200 Subject: [PATCH 186/449] Fix lint issue from 501993eb --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index 78d1caee5bf..3430b3fefd8 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -151,7 +151,7 @@ function submit() { } function submit_txt2img_upscale() { - res = submit.apply(null, arguments); + var res = submit(...arguments); res[2] = selected_gallery_index(); From c2ea571005dab29b285e31a0ad4a97258360bf2d Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:57:41 -0600 Subject: [PATCH 187/449] Add inpaint options to paste fields --- modules/ui.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 3d548430d36..4cdb0e9c0f3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -840,6 +840,10 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (inpainting_mask_invert, 'Mask mode'), + (inpainting_fill, 'Masked content'), + (inpaint_full_res, 'Inpaint area'), + (inpaint_full_res_padding, 'Only masked padding, pixels'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From a5b6a5a3adcc845237d872750ded34240cc6a810 Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:58:55 -0600 Subject: [PATCH 188/449] Add inpaint options to img2img.py --- modules/img2img.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modules/img2img.py b/modules/img2img.py index 04de8e62cf2..9e09c0a00cc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -225,6 +225,18 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + if inpainting_mask_invert is not None: + p.extra_generation_params["Mask mode"] = inpainting_mask_invert + + if inpainting_fill is not None: + p.extra_generation_params["Masked content"] = inpainting_fill + + if inpaint_full_res is not None: + p.extra_generation_params["Inpaint area"] = inpaint_full_res + + if inpaint_full_res_padding is not None: + p.extra_generation_params["Only masked padding, pixels"] = inpaint_full_res_padding + with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" From 
1341b2208185cd89b0019bda2df63b406ec0cb5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 06:47:26 +0300 Subject: [PATCH 189/449] add an option to hide upscaling progressbar --- modules/shared_options.py | 1 + modules/upscaler_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index cca3f7be3c7..63488f4e700 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -115,6 +115,7 @@ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index f5cb92d51b8..9379f512b12 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -47,7 +47,7 @@ def upscale_with_model( grid = images.split_grid(img, tile_size, tile_size, tile_overlap) newtiles = [] - with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p: for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: @@ -103,7 +103,7 @@ def tiled_upscale_2( ).type_as(img) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) - with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: for h_idx in h_idx_list: if shared.state.interrupted or shared.state.skipped: break From 80873b1538e6ca0c7ebe558f8ce4213b06fd8307 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:05:05 +0300 Subject: [PATCH 190/449] fix #14497 --- modules/img2img.py | 15 --------------- modules/infotext_utils.py | 12 ++++++++++++ modules/processing.py | 13 +++++++++++++ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 9e09c0a00cc..f81405df5db 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -222,21 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if shared.opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) - if mask: - p.extra_generation_params["Mask blur"] = mask_blur - - if inpainting_mask_invert is not None: - p.extra_generation_params["Mask mode"] = inpainting_mask_invert - - if inpainting_fill is not None: - p.extra_generation_params["Masked content"] = inpainting_fill - - if inpaint_full_res is not None: - p.extra_generation_params["Inpaint area"] = inpaint_full_res - - if inpaint_full_res_padding is not None: - p.extra_generation_params["Only masked padding, 
pixels"] = inpaint_full_res_padding - with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 26e9b94934f..e582ee479c6 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -312,6 +312,18 @@ def parse_generation_parameters(x: str): if "Hires negative prompt" not in res: res["Hires negative prompt"] = "" + if "Mask mode" not in res: + res["Mask mode"] = "Inpaint masked" + + if "Masked content" not in res: + res["Masked content"] = 'original' + + if "Inpaint area" not in res: + res["Inpaint area"] = "Whole picture" + + if "Masked area padding" not in res: + res["Masked area padding"] = 32 + restore_old_hires_fix_params(res) # Missing RNG means the default was set, which is GPU RNG diff --git a/modules/processing.py b/modules/processing.py index 045c7d7955e..84e7b1b4577 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1530,6 +1530,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) + self.extra_generation_params["Mask mode"] = "Inpaint not masked" if self.mask_blur_x > 0: np_mask = np.array(image_mask) @@ -1543,6 +1544,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) image_mask = Image.fromarray(np_mask) + if self.mask_blur_x > 0 or self.mask_blur_y > 0: + self.extra_generation_params["Mask blur"] = self.mask_blur + if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') @@ -1553,6 +1557,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): mask = mask.crop(crop_region) image_mask = images.resize_image(2, mask, self.width, self.height) self.paste_to = (x1, y1, x2-x1, y2-y1) + + self.extra_generation_params["Inpaint area"] = "Only masked" + self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) np_mask = np.array(image_mask) @@ -1594,6 +1601,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpainting_fill != 1: image = masking.fill(image, latent_mask) + if self.inpainting_fill == 0: + self.extra_generation_params["Masked content"] = 'fill' + if add_color_corrections: self.color_corrections.append(setup_color_correction(image)) @@ -1643,8 +1653,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.extra_generation_params["Masked content"] = 'latent noise' + elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + self.extra_generation_params["Masked content"] = 'latent nothing' self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) From 980970d39091e572500434c69660bc6eed22498d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:08:32 +0300 Subject: [PATCH 191/449] final touches --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 4cdb0e9c0f3..7116d71c582 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ 
-843,7 +843,7 @@ def select_img2img_tab(tab): (inpainting_mask_invert, 'Mask mode'), (inpainting_fill, 'Masked content'), (inpaint_full_res, 'Inpaint area'), - (inpaint_full_res_padding, 'Only masked padding, pixels'), + (inpaint_full_res_padding, 'Masked area padding'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: [PATCH 192/449] Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- .../ScuNET/scripts/scunet_model.py | 48 +++------- .../SwinIR/scripts/swinir_model.py | 62 ++----------- modules/upscaler_utils.py | 89 ++++++++++++++----- 3 files changed, 87 insertions(+), 112 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76dec..fe5e5a19265 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ def __init__(self, dirname): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ def load_model(self, path: str): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', 
"Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c794f7..bc427feac24 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ def __init__(self, dirname): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... 
+ desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ def load_model(self, path, scale=4): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ def _get_device(self): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512b12..e4c63f09758 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. 
* np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. + """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) From 2cacbc124c49f45da5b66b79d9b0a3ab943472eb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 19:52:32 +0200 Subject: [PATCH 193/449] load_spandrel_model: make `half` `prefer_half` As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not. 
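The fix, in essence: rather than unconditionally calling .half() when fp16 is requested, consult the model's capability flag and keep full precision when halving is unsupported. A sketch of that guard, assuming a descriptor object exposing supports_half and model attributes as in the diff below (spandrel-style, but simplified):

    def maybe_half(model_descriptor, prefer_half: bool) -> bool:
        # Apply fp16 only when both requested and supported by the architecture;
        # return whether half precision was actually applied.
        if prefer_half and getattr(model_descriptor, "supports_half", False):
            model_descriptor.model.half()
            return True
        return False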
--- modules/modelloader.py | 20 ++++++++++++++------ modules/realesrgan_model.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index a7194137571..e100bb2469c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,23 +139,31 @@ def load_upscalers(): def load_spandrel_model( - path: str, + path: str | os.PathLike, *, device: str | torch.device | None, - half: bool = False, + prefer_half: bool = False, dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel - model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) - if half: - model_descriptor.model.half() + half = False + if prefer_half: + if model_descriptor.supports_half: + model_descriptor.model.half() + half = True + else: + logger.info("Model %s does not support half precision, ignoring --half", path) if dtype: model_descriptor.model.to(dtype=dtype) model_descriptor.model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + logger.debug( + "Loaded %s from %s (device=%s, half=%s, dtype=%s)", + model_descriptor, path, device, half, dtype, + ) return model_descriptor diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 4d35b695c3b..ff9d8ac0d69 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -39,7 +39,7 @@ def do_upscale(self, img, path): model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, - half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( From 62bd7624d2adb123e84d4625804e9ad94d1db018 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 10:40:26 +0200 Subject: [PATCH 194/449] Remove licenses for code that's no longer copy-pasted; adjust README --- README.md | 11 +- html/licenses.html | 310 +-------------------------------------------- 2 files changed, 7 insertions(+), 314 deletions(-) diff --git a/README.md b/README.md index 9f9f33b1295..72908fa7764 100644 --- a/README.md +++ b/README.md @@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - k-diffusion - https://github.com/crowsonkb/k-diffusion.git -- GFPGAN - https://github.com/TencentARC/GFPGAN.git -- CodeFormer - https://github.com/sczhou/CodeFormer -- ESRGAN - https://github.com/xinntao/ESRGAN -- SwinIR - https://github.com/JingyunLiang/SwinIR -- Swin2SR - https://github.com/mv-lab/swin2sr +- Spandrel - https://github.com/chaiNNer-org/spandrel implementing + - GFPGAN - https://github.com/TencentARC/GFPGAN.git + - CodeFormer - https://github.com/sczhou/CodeFormer + - ESRGAN - https://github.com/xinntao/ESRGAN + - SwinIR - https://github.com/JingyunLiang/SwinIR + - Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - MiDaS - 
https://github.com/isl-org/MiDaS - Ideas for optimizations - https://github.com/basujindal/stable-diffusion diff --git a/html/licenses.html b/html/licenses.html index ef6f2c0a42b..9f5d1e9dc5c 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -4,107 +4,6 @@ #licenses pre { margin: 1em 0 2em 0;} -
-<h2>CodeFormer</h2>
-<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
-<pre>
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in
-   the documentation and/or other materials provided with the
-   distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived
-   from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
-</pre>
-
-<h2>ESRGAN</h2>
-<small>Code for architecture and reading models copied.</small>
-<pre>
-MIT License
-
-Copyright (c) 2021 victorca25
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-</pre>
-
-<h2>Real-ESRGAN</h2>
-<small>Some code is copied to support ESRGAN models.</small>
-<pre>
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
-   list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived from
-   this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</pre>
-
<h2>InvokeAI</h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
@@ -183,213 +82,6 @@
-<h2>SwinIR</h2>
-<small>Code added by contributors, most likely copied from this repository.</small>
-<pre>
-
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright [2021] [SwinIR Authors]
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-
-</pre>
-
<h2>Memory Efficient Attention</h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
@@ -687,4 +379,4 @@
<h2>TAESD</h2>
-</pre>
\ No newline at end of file
+</pre>
From 7ad6899bf987a8ee615efbcfc99562457f89cd8b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 17:14:05 +0200 Subject: [PATCH 195/449] torch_bgr_to_pil_image: round, don't truncate This matches what `realesrgan` does. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09758..4f1417cf06a 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -30,7 +30,7 @@ def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale - arr = arr.astype(np.uint8) + arr = arr.round().astype(np.uint8) arr = arr[:, :, ::-1] # flip BGR to RGB return Image.fromarray(arr, "RGB") From fccd0b00c2ca17360b7b956cd2e9bd1fb42c017d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 3 Jan 2024 18:55:43 +0900 Subject: [PATCH 196/449] reduce unnecessary re-indexing extra networks dir --- modules/ui_extra_networks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index beea1316083..e1c679eca35 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -417,21 +417,21 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def create_html(): + ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] + def pages_html(): if not ui.pages_contents: - return refresh() - + create_html() return ui.pages_contents def refresh(): for pg in ui.stored_extra_pages: pg.refresh() - - ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] - + create_html() return ui.pages_contents - interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) + interface.load(fn=pages_html, inputs=[], outputs=ui.pages) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui From bfc48fbc244130770991fab284f6fedcef2054e7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 03:46:05 +0900 Subject: [PATCH 197/449] paste infotext cast int as float --- modules/infotext_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479c6..a21329e6238 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -477,6 +477,8 @@ def paste_func(prompt): if valtype == bool and v == "False": val = False + elif valtype == int: + val = float(v) else: val = valtype(v) From dfdc51246c678b585e1bdfdb7d2f202b0ca0e362 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:13 +0200 Subject: [PATCH 198/449] SwinIR: use prefer_half --- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index bc427feac24..6a8e21b0250 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,6 +1,7 @@ import logging import sys +import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils @@ -69,7 +70,7 @@ def 
load_model(self, path, scale=4): model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), - dtype=devices.dtype, + prefer_half=(devices.dtype == torch.float16), expected_architecture="SwinIR", ) if getattr(shared.opts, 'SWIN_torch_compile', False): From 3d31d5c27beb433fa37b30f135ec06a278a87630 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:49 +0200 Subject: [PATCH 199/449] SwinIR: pass model.scale --- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 6a8e21b0250..16bf9b792fc 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -51,7 +51,7 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: model, tile_size=shared.opts.SWIN_tile, tile_overlap=shared.opts.SWIN_tile_overlap, - scale=4, # TODO: This was hard-coded before too... + scale=model.scale, desc="SwinIR", ) devices.torch_gc() From 62470ee23443cb2ad3943a152ccae26a689c86e1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:39:12 +0200 Subject: [PATCH 200/449] upscale_2: cast image to model's dtype --- modules/upscaler_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09758..5db74877617 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. 
""" - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) From 50158a1fc9b4dd47a7bef70d34fbb0b30d5e8b47 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 06:21:53 +0900 Subject: [PATCH 201/449] handle config.json failed to load --- modules/launch_utils.py | 4 +++- modules/options.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce73e..e2ad412a6bb 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -249,7 +249,9 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - errors.report("Could not load settings", exc_info=True) + errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(settings_file, os.path.join(script_path, "tmp", "config.json")) + settings = {} disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/options.py b/modules/options.py index 09ff9403ddc..503b40e98db 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,3 +1,4 @@ +import os import json import sys from dataclasses import dataclass @@ -6,6 +7,7 @@ from modules import errors from modules.shared_cmd_options import cmd_opts +from modules.paths_internal import script_path class OptionInfo: @@ -193,9 +195,13 @@ def same_type(self, x, y): return type_x == type_y def load(self, filename): - with open(filename, "r", encoding="utf8") as file: - self.data = json.load(file) - + try: + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + except Exception: + errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(filename, os.path.join(script_path, "tmp", "config.json")) + self.data = {} # 1.6.0 VAE defaults if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') From d9034b48a526f0a0c3e8f0dbf7c171bf4f0597fd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 00:16:58 +0200 Subject: [PATCH 202/449] Avoid unnecessary `isfile`/`exists` calls --- modules/cache.py | 17 ++++++++--------- modules/extensions.py | 11 ++++++----- modules/extra_networks.py | 7 ++++--- modules/infotext_utils.py | 4 +++- modules/launch_utils.py | 7 ++++--- modules/postprocessing.py | 7 ++++--- modules/shared_init.py | 4 +++- modules/ui_gradio_extensions.py | 8 +++----- modules/ui_loadsave.py | 5 +++-- modules/util.py | 6 +++--- 10 files changed, 41 insertions(+), 35 deletions(-) diff --git a/modules/cache.py b/modules/cache.py index 2d37e7b99d5..a9822a0eb21 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -62,16 +62,15 @@ def cache(subsection): if 
cache_data is None: with cache_lock: if cache_data is None: - if not os.path.isfile(cache_filename): + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except FileNotFoundError: + cache_data = {} + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/extensions.py b/modules/extensions.py index 1899cd52975..99e7ee60f64 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -32,11 +32,12 @@ def __init__(self, path, canonical_name): self.config = configparser.ConfigParser() filepath = os.path.join(path, self.filename) - if os.path.isfile(filepath): - try: - self.config.read(filepath) - except Exception: - errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is), + # so no need to check whether the file exists beforehand. + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b9533677887..cd030fa3184 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -215,9 +215,10 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except FileNotFoundError: + pass except Exception as e: errors.display(e, f"reading extra network user metadata from {metadata_filename}") diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479c6..6978a0bf0f0 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -453,9 +453,11 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") - if os.path.exists(filename): + try: with open(filename, "r", encoding="utf8") as file: prompt = file.read() + except OSError: + pass params = parse_generation_parameters(prompt) script_callbacks.infotext_pasted_callback(prompt, params) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce73e..febd8c24a56 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -245,9 +245,10 @@ def list_extensions(settings_file): settings = {} try: - if os.path.isfile(settings_file): - with open(settings_file, "r", encoding="utf8") as file: - settings = json.load(file) + with open(settings_file, "r", encoding="utf8") as file: + settings = json.load(file) + except FileNotFoundError: + pass except Exception: 
errors.report("Could not load settings", exc_info=True) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7850328f63f..7449b0dc51c 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -97,11 +97,12 @@ def get_images(extras_mode, image, image_folder, input_dir): if pp.caption: caption_filename = os.path.splitext(fullfn)[0] + ".txt" - if os.path.isfile(caption_filename): + existing_caption = "" + try: with open(caption_filename, encoding="utf8") as file: existing_caption = file.read().strip() - else: - existing_caption = "" + except FileNotFoundError: + pass action = shared.opts.postprocessing_existing_caption_action if action == 'Prepend' and existing_caption: diff --git a/modules/shared_init.py b/modules/shared_init.py index d3fb687e0cd..586be3423d0 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -18,8 +18,10 @@ def initialize(): shared.options_templates = shared_options.options_templates shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) shared.restricted_opts = shared_options.restricted_opts - if os.path.exists(shared.config_filename): + try: shared.opts.load(shared.config_filename) + except FileNotFoundError: + pass from modules import devices devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index a86c368ef70..f5278d22f02 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -35,13 +35,11 @@ def stylesheet(fn): return f'' for cssfile in scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - head += stylesheet(cssfile) - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) + user_css = os.path.join(data_path, "user.css") + if os.path.exists(user_css): + head += stylesheet(user_css) return head diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 693ff75c5d8..2555cdb6c60 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -26,8 +26,9 @@ def __init__(self, filename): self.ui_defaults_review = None try: - if os.path.exists(self.filename): - self.ui_settings = self.read_from_file() + self.ui_settings = self.read_from_file() + except FileNotFoundError: + pass except Exception as e: self.error_loading = True errors.display(e, "loading settings") diff --git a/modules/util.py b/modules/util.py index 4861bcb0861..d503f26729e 100644 --- a/modules/util.py +++ b/modules/util.py @@ -21,11 +21,11 @@ def html_path(filename): def html(filename): path = html_path(filename) - if os.path.exists(path): + try: with open(path, encoding="utf8") as file: return file.read() - - return "" + except OSError: + return "" def walk_files(path, allowed_extensions=None): From 420f56c2e85ebbd3f530cf2c7b22022fda13ae13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:28:05 +0300 Subject: [PATCH 203/449] mass file lister as an attempt to tackle #14507 --- modules/extra_networks.py | 5 +-- modules/ui_extra_networks.py | 18 ++++++---- modules/util.py | 70 ++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b9533677887..04249dffd27 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -206,7 +206,7 @@ def parse_prompts(prompts): return res, 
extra_data -def get_user_metadata(filename): +def get_user_metadata(filename, lister=None): if filename is None: return {} @@ -215,7 +215,8 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): + exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename) + if exists: with open(metadata_filename, "r", encoding="utf8") as file: metadata = json.load(file) except Exception as e: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e1c679eca35..c06c8664950 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,7 +3,7 @@ import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util from modules.images import read_info_from_image, save_image_with_geninfo import gradio as gr import json @@ -107,6 +107,7 @@ def __init__(self, title): self.allow_negative_prompt = False self.metadata = {} self.items = {} + self.lister = util.MassFileLister() def refresh(self): pass @@ -123,7 +124,7 @@ def read_user_metadata(self, item): def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) - mtime = os.path.getmtime(filename) + mtime, _ = self.lister.mctime(filename) return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): @@ -137,6 +138,8 @@ def search_terms_from_path(self, filename, possible_directories=None): return "" def create_html(self, tabname): + self.lister.reset() + items_html = '' self.metadata = {} @@ -282,10 +285,10 @@ def get_sort_keys(self, path): List of default keys used for sorting in the UI. """ pth = Path(path) - stat = pth.stat() + mtime, ctime = self.lister.mctime(path) return { - "date_created": int(stat.st_ctime or 0), - "date_modified": int(stat.st_mtime or 0), + "date_created": int(mtime), + "date_modified": int(ctime), "name": pth.name.lower(), "path": str(pth.parent).lower(), } @@ -298,7 +301,7 @@ def find_preview(self, path): potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) for file in potential_files: - if os.path.isfile(file): + if self.lister.exists(file): return self.link_preview(file) return None @@ -308,6 +311,9 @@ def find_description(self, path): Find and read a description file for a given path (without extension). 
""" for file in [f"{path}.txt", f"{path}.description.txt"]: + if not self.lister.exists(file): + continue + try: with open(file, "r", encoding="utf-8", errors="replace") as f: return f.read() diff --git a/modules/util.py b/modules/util.py index 4861bcb0861..c2a27590dca 100644 --- a/modules/util.py +++ b/modules/util.py @@ -66,3 +66,73 @@ def truncate_path(target_path, base_path=cwd): except ValueError: pass return abs_target + + +class MassFileListerCachedDir: + """A class that caches file metadata for a specific directory.""" + + def __init__(self, dirname): + self.files = None + self.files_cased = None + self.dirname = dirname + + stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname)) + files = [(n, s.st_mtime, s.st_ctime) for n, s in stats] + self.files = {x[0].lower(): x for x in files} + self.files_cased = {x[0]: x for x in files} + + +class MassFileLister: + """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" + + def __init__(self): + self.cached_dirs = {} + + def find(self, path): + """ + Find the metadata for a file at the given path. + + Returns: + tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not. + """ + + dirname, filename = os.path.split(path) + + cached_dir = self.cached_dirs.get(dirname) + if cached_dir is None: + cached_dir = MassFileListerCachedDir(dirname) + self.cached_dirs[dirname] = cached_dir + + stats = cached_dir.files_cased.get(filename) + if stats is not None: + return stats + + stats = cached_dir.files.get(filename.lower()) + if stats is None: + return None + + try: + os_stats = os.stat(path, follow_symlinks=False) + return filename, os_stats.st_mtime, os_stats.st_ctime + except Exception: + return None + + def exists(self, path): + """Check if a file exists at the given path.""" + + return self.find(path) is not None + + def mctime(self, path): + """ + Get the modification and creation times for a file at the given path. + + Returns: + tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not. + """ + + stats = self.find(path) + return (0, 0) if stats is None else stats[1:3] + + def reset(self): + """Clear the cache of all directories.""" + self.cached_dirs.clear() From 320a217b78047f30e1aa5e735742669a7f4c6bd8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:39:02 +0300 Subject: [PATCH 204/449] forgot something --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c06c8664950..62db36f53cc 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -114,7 +114,7 @@ def refresh(self): def read_user_metadata(self, item): filename = item.get("filename", None) - metadata = extra_networks.get_user_metadata(filename) + metadata = extra_networks.get_user_metadata(filename, lister=self.lister) desc = metadata.get("description", None) if desc is not None: From 15ec54dd969d6dc3fea7790ca5cce5badcfda426 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 19:47:00 +0300 Subject: [PATCH 205/449] Have upscale button use the same seed as hires fix. 
--- modules/scripts.py | 21 +++++++++++++++++++++ modules/txt2img.py | 14 +++++++++++++- modules/ui.py | 10 +++++----- modules/ui_common.py | 26 +++++++++++++------------- modules/ui_postprocessing.py | 2 +- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 017aed5a01e..cf938ebb972 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -91,6 +91,9 @@ class Script: setup_for_ui_only = False """If true, the script setup will only be run in Gradio UI, not in API""" + controls = None + """A list of controls retured by the ui().""" + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -624,6 +627,7 @@ def create_script_ui_inner(self, script): import modules.api.models as api_models controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + script.controls = controls if controls is None: return @@ -918,6 +922,23 @@ def setup_scrips(self, p, *, is_ui=True): except Exception: errors.report(f"Error running setup: {script.filename}", exc_info=True) + def set_named_arg(self, args, script_type, arg_elem_id, value): + script = next((x for x in self.scripts if type(x).__name__ == script_type), None) + if script is None: + return + + for i, control in enumerate(script.controls): + if arg_elem_id in control.elem_id: + index = script.args_from + i + + if isinstance(args, list): + args[index] = value + return args + elif isinstance(args, tuple): + return args[:index] + (value,) + args[index+1:] + else: + return None + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 4a6fe72a65d..41bb9da33dd 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,3 +1,4 @@ +import json from contextlib import closing import modules.scripts @@ -9,12 +10,19 @@ import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] image = infotext_utils.image_from_url_text(image_info) + gallery_index_from_end = len(gallery) - gallery_index + image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + return txt2img(id_task, request, *args, firstpass_image=image) @@ -22,6 +30,10 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str override_settings = create_override_settings_dict(override_settings_texts) if firstpass_image is not None: + seed = getattr(firstpass_image, 'seed', None) + if seed: + args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) + enable_hr = True batch_size = 1 n_iter = 1 diff --git a/modules/ui.py b/modules/ui.py index 7116d71c582..2d2e333b293 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -405,8 +405,8 @@ def create_ui(): txt2img_outputs = [ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ] @@ -424,7 +424,7 @@ def create_ui(): output_panel.button_upscale.click( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), 
_js="submit_txt2img_upscale", - inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:], outputs=txt2img_outputs, show_progress=False, ) @@ -437,8 +437,8 @@ def create_ui(): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -766,8 +766,8 @@ def select_img2img_tab(tab): ] + custom_inputs, outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -807,8 +807,8 @@ def select_img2img_tab(tab): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, diff --git a/modules/ui_common.py b/modules/ui_common.py index ff84197c1de..f17259c2956 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -108,8 +108,8 @@ def __init__(self, d=None): @dataclasses.dataclass class OutputPanel: gallery = None + generation_info = None infotext = None - html_info = None html_log = None button_upscale = None @@ -175,17 +175,17 @@ def open_folder(f): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[res.infotext, res.html_info, res.html_info], - outputs=[res.html_info, res.html_info], + inputs=[res.generation_info, res.infotext, res.infotext], + outputs=[res.infotext, res.infotext], show_progress=False, ) @@ -193,10 +193,10 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -209,10 +209,10 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -221,8 +221,8 @@ def open_folder(f): ) else: - res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 
8f09e6588a3..7a132ac22ed 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -49,7 +49,7 @@ def create_ui(): ], outputs=[ output_panel.gallery, - output_panel.infotext, + output_panel.generation_info, output_panel.html_log, ], show_progress=False, From 9805f35c6f3ef0b0fc4e3648aa3d6eddf0a907af Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 19:13:36 +0200 Subject: [PATCH 206/449] Ensure GRADIO_ANALYTICS_ENABLED is set early enough --- modules/initialize.py | 2 ++ modules/launch_utils.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index 4a3cd98cf86..7c1ac99ef07 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -1,5 +1,6 @@ import importlib import logging +import os import sys import warnings from threading import Thread @@ -18,6 +19,7 @@ def imports(): warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') import gradio # noqa: F401 startup_timer.record("import gradio") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 7ebdf0b40a3..c2a7ae932ce 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -27,8 +27,7 @@ # Whether to default to printing command output default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") -if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: - os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' +os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') def check_python_version(): From 6fa42e919f27d7316a2c7b61674fb3eb17f3a1bb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 16:13:08 +0200 Subject: [PATCH 207/449] Fix logging configuration again * Only use `tqdm.write()` if `tqdm` is active, defer to stderr * Correct log formatter for TqdmLoggingHandler * If `rich` is installed and `SD_WEBUI_RICH_LOG` is set, use `rich`'s formatter --- modules/logging_config.py | 62 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 79269875608..11eee9a63db 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,41 +1,57 @@ -import os import logging +import os try: - from tqdm.auto import tqdm + from tqdm import tqdm + class TqdmLoggingHandler(logging.Handler): - def __init__(self, level=logging.INFO): - super().__init__(level) + def __init__(self, fallback_handler: logging.Handler): + super().__init__() + self.fallback_handler = fallback_handler def emit(self, record): try: - msg = self.format(record) - tqdm.write(msg) - self.flush() + # If there are active tqdm progress bars, + # attempt to not interfere with them. + if tqdm._instances: + tqdm.write(self.format(record)) + else: + self.fallback_handler.emit(record) except Exception: - self.handleError(record) + self.fallback_handler.emit(record) - TQDM_IMPORTED = True except ImportError: - # tqdm does not exist before first launch - # I will import once the UI finishes seting up the enviroment and reloads. 
- TQDM_IMPORTED = False + TqdmLoggingHandler = None + def setup_logging(loglevel): if loglevel is None: loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") - loghandlers = [] + if not loglevel: + return + + if logging.root.handlers: + # Already configured, do not interfere + return + + if os.environ.get("SD_WEBUI_RICH_LOG"): + from rich.logging import RichHandler + handler = RichHandler() + else: + handler = logging.StreamHandler() + + if TqdmLoggingHandler: + handler = TqdmLoggingHandler(handler) + + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) - if TQDM_IMPORTED: - loghandlers.append(TqdmLoggingHandler()) + handler.setFormatter(formatter) - if loglevel: - log_level = getattr(logging, loglevel.upper(), None) or logging.INFO - logging.basicConfig( - level=log_level, - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - handlers=loghandlers - ) + log_level = getattr(logging, loglevel.upper(), None) or logging.INFO + logging.root.setLevel(log_level) + logging.root.addHandler(handler) From f8f38c7c28e48f9f79225c969e3e82b1adcfb910 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:31:48 +0800 Subject: [PATCH 208/449] Fix dtype casting for OFT module --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020f0a..342fcd0dcaf 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ def calc_updown(self, orig_weight): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ def calc_updown(self, orig_weight): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 18ca987c92f52690daec43a6c67363c341bb6008 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:32:19 +0800 Subject: [PATCH 209/449] Add general forward method for all modules. 
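For illustration, the dispatch this commit introduces boils down to the runnable sketch below; `updown` is a dummy zero delta standing in for what `calc_updown()` would compute from the network weights:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    sd_module = nn.Linear(8, 8)
    ops, extra_kwargs = None, {}
    if isinstance(sd_module, nn.Conv2d):
        ops = F.conv2d
        extra_kwargs = {"stride": sd_module.stride, "padding": sd_module.padding}
    elif isinstance(sd_module, nn.Linear):
        ops = F.linear

    x = torch.randn(2, 8)
    y = sd_module(x)
    updown = torch.zeros_like(sd_module.weight)  # stand-in for calc_updown()
    out = y + ops(x, updown, **extra_kwargs)     # the generic forward: y + op(x, delta)
    assert torch.allclose(out, y)                # a zero delta leaves the output unchanged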
--- extensions-builtin/Lora/network.py | 34 ++++++++++++++++++++++++++++- extensions-builtin/Lora/networks.py | 12 +++++----- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a62e5eff9e1..f9b571b5fa1 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,10 @@ from collections import namedtuple import enum +import torch +import torch.nn as nn +import torch.nn.functional as F + from modules import sd_models, cache, errors, hashes, shared NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) @@ -115,6 +119,29 @@ def __init__(self, net: Network, weights: NetworkWeights): if hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + self.ops = None + self.extra_kwargs = {} + if isinstance(self.sd_module, nn.Conv2d): + self.ops = F.conv2d + self.extra_kwargs = { + 'stride': self.sd_module.stride, + 'padding': self.sd_module.padding + } + elif isinstance(self.sd_module, nn.Linear): + self.ops = F.linear + elif isinstance(self.sd_module, nn.LayerNorm): + self.ops = F.layer_norm + self.extra_kwargs = { + 'normalized_shape': self.sd_module.normalized_shape, + 'eps': self.sd_module.eps + } + elif isinstance(self.sd_module, nn.GroupNorm): + self.ops = F.group_norm + self.extra_kwargs = { + 'num_groups': self.sd_module.num_groups, + 'eps': self.sd_module.eps + } + self.dim = None self.bias = weights.w.get("bias") self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None @@ -155,5 +182,10 @@ def calc_updown(self, target): raise NotImplementedError() def forward(self, x, y): - raise NotImplementedError() + """A general forward implementation for all modules""" + if self.ops is None: + raise NotImplementedError() + else: + updown, ex_bias = self.calc_updown(self.sd_module.weight) + return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 72ebd624169..32e10b62544 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_current_names = wanted_names -def network_forward(module, input, original_forward): +def network_forward(org_module, input, original_forward): """ Old way of applying Lora by executing operations during layer's forward. Stacking many loras this way results in big performance degradation. 
""" if len(loaded_networks) == 0: - return original_forward(module, input) + return original_forward(org_module, input) input = devices.cond_cast_unet(input) - network_restore_weights_from_backup(module) - network_reset_cached_weight(module) + network_restore_weights_from_backup(org_module) + network_reset_cached_weight(org_module) - y = original_forward(module, input) + y = original_forward(org_module, input) - network_layer_name = getattr(module, 'network_layer_name', None) + network_layer_name = getattr(org_module, 'network_layer_name', None) for lora in loaded_networks: module = lora.modules.get(network_layer_name, None) if module is None: From 44744d6005da5e424267698ee3279caa597dfebc Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:38:05 +0800 Subject: [PATCH 210/449] linting --- extensions-builtin/Lora/network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index f9b571b5fa1..b8fd91941c8 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,7 +3,6 @@ from collections import namedtuple import enum -import torch import torch.nn as nn import torch.nn.functional as F @@ -124,7 +123,7 @@ def __init__(self, net: Network, weights: NetworkWeights): if isinstance(self.sd_module, nn.Conv2d): self.ops = F.conv2d self.extra_kwargs = { - 'stride': self.sd_module.stride, + 'stride': self.sd_module.stride, 'padding': self.sd_module.padding } elif isinstance(self.sd_module, nn.Linear): From 88ba095fd022fbc1acbbbc845e23493a5b2c5d8b Mon Sep 17 00:00:00 2001 From: Keshav Nischal <91429557+keshav-nischal@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:15:58 +0530 Subject: [PATCH 211/449] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9f9f33b1295..7f1ea457522 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Stable Diffusion web UI -A browser interface based on Gradio library for Stable Diffusion. +A web interface for Stable Diffusion, implemented using Gradio library. ![](screenshot.png) From 233c66b36eba07b905f9743c2ad807aec33d9ccb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 5 Jan 2024 12:28:32 +0300 Subject: [PATCH 212/449] Make the upscale button update the gallery with the new image rather than replace it. 
--- modules/processing.py | 6 ++- modules/txt2img.py | 85 +++++++++++++++++++++++++++++-------------- 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 84e7b1b4577..dcc807fe38d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,7 +732,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), - "Denoising strength": getattr(p, 'denoising_strength', None), + "Denoising strength": p.extra_generation_params.get("Denoising strength"), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": opts.eta_noise_seed_delta if uses_ensd else None, @@ -1198,6 +1198,8 @@ def calculate_target_resolution(self): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + self.extra_generation_params["Denoising strength"] = self.denoising_strength + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) @@ -1516,6 +1518,8 @@ def mask_blur(self, value): self.mask_blur_y = value def init(self, all_prompts, all_seeds, all_subseeds): + self.extra_generation_params["Denoising strength"] = self.denoising_strength + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) diff --git a/modules/txt2img.py b/modules/txt2img.py index 41bb9da33dd..c4cc12d2f6d 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,36 +7,15 @@ from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html +from PIL import Image import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): - assert len(gallery) > 0, 'No image to upscale' - assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' - - geninfo = json.loads(generation_info) - all_seeds = geninfo["all_seeds"] - - image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] - image = infotext_utils.image_from_url_text(image_info) - - gallery_index_from_end = len(gallery) - gallery_index - image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - - return txt2img(id_task, request, *args, firstpass_image=image) - - -def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, 
sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) - if firstpass_image is not None: - seed = getattr(firstpass_image, 'seed', None) - if seed: - args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) - + if force_enable_hr: enable_hr = True - batch_size = 1 - n_iter = 1 p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -53,7 +32,7 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str width=width, height=height, enable_hr=enable_hr, - denoising_strength=denoising_strength if enable_hr else None, + denoising_strength=denoising_strength, hr_scale=hr_scale, hr_upscaler=hr_upscaler, hr_second_pass_steps=hr_second_pass_steps, @@ -64,7 +43,6 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, - firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img @@ -75,8 +53,61 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str if shared.opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + return p + + +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): + assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + p = txt2img_create_processing(id_task, request, *args) + p.enable_hr = True + p.batch_size = 1 + p.n_iter = 1 + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + p.firstpass_image = infotext_utils.image_from_url_text(image_info) + + gallery_index_from_end = len(gallery) - gallery_index + seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) + + if processed is None: + processed = processing.process_images(p) + + shared.total_tqdm.clear() + + new_gallery = [] + for i, image in enumerate(gallery): + fake_image = Image.new(mode="RGB", size=(1, 1)) + + if i == gallery_index: + already_saved_as = getattr(processed.images[0], 'already_saved_as', None) + if already_saved_as is not None: + fake_image.already_saved_as = already_saved_as + else: + fake_image = processed.images[0] + else: + fake_image.already_saved_as = image["name"] + + new_gallery.append(fake_image) + + geninfo["infotexts"][gallery_index] = processed.info + + return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") + + +def txt2img(id_task: str, request: gr.Request, *args): + p = txt2img_create_processing(id_task, request, *args) + with closing(p): - processed = modules.scripts.scripts_txt2img.run(p, *args) + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) if processed is 
None: processed = processing.process_images(p) From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 16:32:18 +0800 Subject: [PATCH 213/449] [IPEX] Fix SDPA attn_mask dtype --- modules/xpu_specific.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66cc2..4e11125b20b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length From ec9acb31450536a9258192351d6a26421efd7eb4 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 17:17:53 +0800 Subject: [PATCH 214/449] Handle CondFunc exception when resolving attributes --- modules/sd_hijack_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f8684475ec5..79bf6e46862 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -11,10 +11,14 @@ def __new__(cls, orig_func, sub_func, cond_func): break except ImportError: pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + try: + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + except AttributeError: + print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") + pass self.__init__(orig_func, sub_func, cond_func) return lambda *args, **kwargs: self(*args, **kwargs) def __init__(self, orig_func, sub_func, cond_func): From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:09:56 +0800 Subject: [PATCH 215/449] [IPEX] Fix torch.Generator hijack --- modules/xpu_specific.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b20b..1137891a61c 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes 
CondFunc('torch.nn.functional.layer_norm', From 818d6a11e709bf07d48606bdccab944c46a5f4b0 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:14:06 +0800 Subject: [PATCH 216/449] Fix format --- modules/xpu_specific.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 1137891a61c..2971dbc3cf5 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -106,8 +106,8 @@ def is_xpu_device(device: str | torch.device = None): try: # torch.Generator supports "xpu" device since 2.1 torch.Generator("xpu") - except: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + except RuntimeError: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1) CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), lambda orig_func, device=None: is_xpu_device(device)) From a183de04e3f965083e7f3462201327d30c36b958 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 20:03:33 +0800 Subject: [PATCH 217/449] Execute model_loaded_callback after moving to target device --- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 50bc209e465..2c04577152c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): sd_hijack.model_hijack.hijack(sd_model) timer.record("hijack") - script_callbacks.model_loaded_callback(sd_model) - timer.record("script callbacks") - if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 31306d8ba4b..43687e48dcf 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) if not sd_model.lowvram: sd_model.to(devices.device) + script_callbacks.model_loaded_callback(sd_model) + print("VAE weights loaded.") return sd_model From 2f98a35fc4508494355c01ec45f5bec725f570a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 7 Jan 2024 09:21:21 +0300 Subject: [PATCH 218/449] add assets repo; serve fonts locally rather than from google's servers --- modules/launch_utils.py | 3 +++ modules/sysinfo.py | 2 ++ modules/ui.py | 4 +++- style.css | 2 +- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2a7ae932ce..8e58d714550 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -344,11 +344,13 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") + assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") 
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') + assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") @@ -405,6 +407,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) + git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 5abf616b707..f336251e445 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -24,9 +24,11 @@ "XFORMERS_PACKAGE", "CLIP_PACKAGE", "OPENCLIP_PACKAGE", + "ASSETS_REPO", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", "BLIP_REPO", + "ASSETS_COMMIT_HASH", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", "BLIP_COMMIT_HASH", diff --git a/modules/ui.py b/modules/ui.py index 2d2e333b293..a716a04055d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -1223,3 +1223,5 @@ def download_sysinfo(attachment=False): app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"]) + import fastapi.staticfiles + app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets") diff --git a/style.css b/style.css index 6d4c8a0d505..4957c523df3 100644 --- a/style.css +++ b/style.css @@ -1,6 +1,6 @@ /* temporary fix to load default gradio font in frontend instead of backend */ -@import 
url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +@import url('webui-assets/css/sourcesanspro.css'); /* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ From 425507bd10c55f1f804eb5015db74520668f46f9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 10:25:01 -0600 Subject: [PATCH 219/449] add p to cfgdenoiserparams --- modules/script_callbacks.py | 5 ++++- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 9ed7ad21d1b..bb47c18d3f1 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +63,9 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" + self.p = p + """StableDiffusionProcessing object with processing parameters""" + class CFGDenoisedParams: def __init__(self, x, sampling_step, total_sampling_steps, inner_model): diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index eb9d5dafacc..f4ded6bdbe9 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ def apply_blend(current_latent): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From f56cebf5ba24313447b2204c3f804379767201c9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 12:35:35 -0600 Subject: [PATCH 220/449] add self instead --- modules/script_callbacks.py | 6 +++--- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index bb47c18d3f1..053dfc96309 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): self.x = x """Latent image representation in the process of being denoised""" @@ -63,8 +63,8 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" - self.p = p - """StableDiffusionProcessing object with processing 
parameters""" + self.denoiser = denoiser + """Current CFGDenoiser object with processing parameters""" class CFGDenoisedParams: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f4ded6bdbe9..6d76aa965a5 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ def apply_blend(current_latent): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From 37906e429ae5a92f9ea96ab6dc2157b1d7c4d8b6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 7 Jan 2024 20:17:42 -0600 Subject: [PATCH 221/449] make denoiser None by default --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 053dfc96309..a54cb3ebb56 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None): self.x = x """Latent image representation in the process of being denoised""" From 8e292373ec5c54493ce48af7c76f5eaa79dc8abd Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Mon, 8 Jan 2024 06:43:39 -0600 Subject: [PATCH 222/449] lcm sampler --- modules/sd_samplers.py | 3 +- modules/sd_samplers_lcm.py | 104 +++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 modules/sd_samplers_lcm.py diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 45faae62821..a58528a0b52 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_timesteps.samplers_data_timesteps, + *sd_samplers_lcm.samplers_data_lcm, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py new file mode 100644 index 00000000000..59839b720dd --- /dev/null +++ b/modules/sd_samplers_lcm.py @@ -0,0 +1,104 @@ +import torch + +from k_diffusion import utils, sampling +from k_diffusion.external import DiscreteEpsDDPMDenoiser +from k_diffusion.sampling import default_noise_sampler, trange + +from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common + + +class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): + def __init__(self, model): + timesteps 
= 1000 + original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM) + self.skip_steps = timesteps // original_timesteps + + alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device) + for x in range(original_timesteps): + alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + super().__init__(model, alphas_cumprod_valid, quantize=None) + + + def get_sigmas(self, n=None,): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + + start = self.sigma_to_t(self.sigma_max) + end = self.sigma_to_t(self.sigma_min) + + t = torch.linspace(start, end, n, device=shared.sd_model.device) + + return sampling.append_zero(self.t_to_sigma(t)) + + + def sigma_to_t(self, sigma, quantize=None): + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + + + def t_to_sigma(self, timestep): + t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + return super().t_to_sigma(t) + + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + + def get_scaled_out(self, sigma, output, input): + sigma_data = 0.5 + scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0 + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * output + c_skip * input + + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return self.get_scaled_out(sigma, input + eps * c_out, input) + + +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x + + +class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = LCMCompVisDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + + +class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler): + def __init__(self, funcname, sd_model, options=None): + super().__init__(funcname, sd_model, options) + self.model_wrap_cfg = CFGDenoiserLCM(self) + self.model_wrap = self.model_wrap_cfg.inner_model + + +samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})] +samplers_data_lcm = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_lcm +] From df8aa69a99e38ae59a4e599b9dff11eccf3490f4 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:10:03 -0500 Subject: [PATCH 223/449] Add tree-view display for extra networks. 
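The heart of the tree view below is get_tree, which folds one or more root directories into nested dicts whose leaves are ExtraNetworksItem wrappers, pruning empty folders along the way. A minimal single-root sketch of the same recursion, assuming items maps absolute file paths to arbitrary payloads:

    import os

    def build_tree(root: str, items: dict) -> dict:
        # Files and folders can later be told apart by value type: a nested
        # dict is a folder, anything else is an item payload.
        tree = {}
        for name in sorted(os.listdir(root)):
            path = os.path.join(root, name)
            if os.path.isdir(path):
                subtree = build_tree(path, items)
                if subtree:           # drop empty directories, as the patch does
                    tree[name] = subtree
            elif path in items:       # ignore files that are not known items
                tree[name] = items[path]
        return tree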
--- html/extra-networks-card-minimal.html | 3 + html/extra-networks-card.html | 1 + javascript/extraNetworks.js | 5 + modules/shared_options.py | 1 + modules/ui_extra_networks.py | 352 ++++++++++++++++++-------- style.css | 92 +++++++ 6 files changed, 355 insertions(+), 99 deletions(-) create mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html new file mode 100644 index 00000000000..a6a54d9f4ac --- /dev/null +++ b/html/extra-networks-card-minimal.html @@ -0,0 +1,3 @@ +
+ {name}{copy_path_button}{metadata_button}{edit_button} +
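Both the minimal template above and the full card template whose diff follows are filled from one shared kwargs dict via str.format, which silently ignores keys a given template never references; that is what the patch's own comment "Some items here might not be used depending on HTML template used" relies on. A tiny demonstration of that behavior:

    minimal = "{name}{edit_button}"
    full = "{background_image}{name}{edit_button}"
    args = {"name": "some-lora", "edit_button": "[edit]", "background_image": "[img]"}

    assert minimal.format(**args) == "some-lora[edit]"    # extra keys are ignored
    assert full.format(**args) == "[img]some-lora[edit]"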
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 39674666f1e..d76011d737c 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,6 +1,7 @@
{background_image}
+    {copy_path_button}
 	{metadata_button}
 	{edit_button}
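The JavaScript hunk below adds the extraNetworksCopyCardPath clipboard handler; its counterpart is the copy_path_button markup that ui_extra_networks.py builds server-side with the item's filesystem path baked into the onclick attribute. The exact markup is partly lost in this rendering, so the following only sketches the escaping concern (the copy-path-button and card-button class names come from the style.css hunk later in this patch; the rest is assumed):

    import html

    def copy_path_button_html(filename: str) -> str:
        # Escape the path so quotes cannot break out of the onclick attribute;
        # extraNetworksCopyCardPath is the handler added in the JS hunk below.
        safe = html.escape(filename, quote=True)
        return (
            f'<div class="copy-path-button card-button" '
            f"onclick=\"extraNetworksCopyCardPath(event, '{safe}')\"></div>"
        )

    print(copy_path_button_html('C:\\models\\lora "v2".safetensors'))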
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb745c..40309d5575c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -337,6 +337,11 @@ function requestGet(url, data, handler, errorHandler) { xhr.send(js); } +function extraNetworksCopyCardPath(event, path) { + navigator.clipboard.writeText(path); + event.stopPropagation(); +} + function extraNetworksRequestMetadata(event, extraPage, cardName) { var showError = function() { extraNetworksShowMetadata("there was an error getting metadata"); diff --git a/modules/shared_options.py b/modules/shared_options.py index d2e86ff10b3..e698c26498c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,6 +238,7 @@ "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), + "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba3338..8667617b9f5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,6 +2,8 @@ import os.path import urllib.parse from pathlib import Path +from typing import Optional, Union +from dataclasses import dataclass from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo @@ -15,10 +17,8 @@ extra_pages = [] allowed_dirs = set() - default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -28,6 +28,58 @@ def allowed_preview_extensions(): return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) +@dataclass +class ExtraNetworksItem: + """Wrapper for dictionaries representing ExtraNetworks items.""" + item: dict + + +def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict: + """Recursively builds a directory tree. + + Args: + paths: Path or list of paths to directories. These paths are treated as roots from which + the tree will be built. + items: A dictionary associating filepaths to an ExtraNetworksItem instance. + + Returns: + The result directory tree. + """ + if isinstance(paths, (str,)): + paths = [paths] + + def _get_tree(_paths: list[str]): + _res = {} + for path in _paths: + if os.path.isdir(path): + dir_items = os.listdir(path) + # Ignore empty directories. 
+ if not dir_items: + continue + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + # We only want to store non-empty folders in the tree. + if dir_tree: + _res[os.path.basename(path)] = dir_tree + else: + if path not in items: + continue + # Add the ExtraNetworksItem to the result. + _res[os.path.basename(path)] = items[path] + return _res + + res = {} + # Handle each root directory separately. + # Each root WILL have a key/value at the root of the result dict though + # the value can be an empty dict if the directory is empty. We want these + # placeholders for empty dirs so we can inform the user later. + for path in paths: + # Wrap the path in a list since that is what the `_get_tree` expects. + res[path] = _get_tree([path]) + if res[path]: + res[path] = res[path][os.path.basename(path)] + + return res + def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -80,7 +132,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): item = page.items.get(name) page.read_user_metadata(item) - item_html = page.create_html_for_item(item, tabname) + item_html = page.create_item_html(tabname, item) return JSONResponse({"html": item_html}) @@ -96,13 +148,15 @@ def quote_js(s): s = s.replace('"', '\\"') return f'"{s}"' - class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.card_page = shared.html("extra-networks-card.html") + if shared.opts.extra_networks_tree_view: + self.card_page = shared.html("extra-networks-card-minimal.html") + else: + self.card_page = shared.html("extra-networks-card.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -136,12 +190,141 @@ def search_terms_from_path(self, filename, possible_directories=None): return "" - def create_html(self, tabname): - items_html = '' + def create_item_html(self, tabname: str, item: dict) -> str: + """Generates HTML for a single ExtraNetworks Item + + Args: + tabname: The name of the active tab. + item: Dictionary containing item information. + + Returns: + HTML string generated for this item. + Can be empty if the item is not meant to be shown. + """ + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) + + preview = item.get("preview", None) + height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + background_image = f'' if preview else '' + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + copy_path_button = f"
" + + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_button = f"" + + edit_button = f"
" + + local_path = "" + filename = item.get("filename", "") + for reldir in self.allowed_directories_for_previews(): + absdir = os.path.abspath(reldir) + + if filename.startswith(absdir): + local_path = filename[len(absdir):] + + # if this is true, the item must not be shown in the default view, and must instead only be + # shown when searching for it + if shared.opts.extra_networks_hidden_models == "Always": + search_only = False + else: + search_only = "/." in local_path or "\\." in local_path + + if search_only and shared.opts.extra_networks_hidden_models == "Never": + return "" + + sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + + # Some items here might not be used depending on HTML template used. + args = { + "background_image": background_image, + "card_clicked": onclick, + "copy_path_button": copy_path_button, + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "edit_button": edit_button, + "local_preview": quote_js(item["local_preview"]), + "metadata_button": metadata_button, + "name": html.escape(item["name"]), + "prompt": item.get("prompt", None), + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "search_only": " search_only" if search_only else "", + "search_term": item.get("search_term", ""), + "sort_keys": sort_keys, + "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "tabname": quote_js(tabname), + } + + return self.card_page.format(**args) + + def create_tree_view_html(self, tabname: str) -> str: + """Generates HTML for displaying folders in a tree view. + + Args: + tabname: The name of the active tab. + + Returns: + HTML string generated for this tree view. + """ + self_name_id = self.name.replace(" ", "_") + res = f"
" self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + roots = self.allowed_directories_for_previews() + tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} + tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) + + if not tree: + return res + "
" + + file_template = "
  • {}
  • " + dir_template = ( + "
    " + "{}" + "{}" + "
    " + ) + + def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: + """Recursively builds HTML for a tree.""" + _res = "
      " + if not data: + return "
      • DIRECTORY IS EMPTY
      " + + for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): + if isinstance(v, (ExtraNetworksItem,)): + _res += file_template.format(self.create_item_html(tabname, v.item)) + else: + _res += dir_template.format("", k, _build_tree(v)) + return _res + + res += "
        " + # Add each root directory to the tree. + for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): + # If root is empty, append the "disabled" attribute to the template details tag. + res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
      " + res += "
    " + + return res + + def create_card_view_html(self, tabname): + items_html = "" + self.metadata = {} subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): for dirname in sorted(dirs, key=shared.natural_sort_key): @@ -171,40 +354,45 @@ def create_html(self, tabname): if subdirs: subdirs = {"": 1, **subdirs} - subdirs_html = "".join([f""" - -""" for subdir in subdirs]) + subdirs_html_template = ( + "" + ) + subdirs_html = "".join( + [ + subdirs_html_template.format( + " search-all" if subdir == "" else "", + tabname, + html.escape(subdir if subdir != "" else "all"), + ) for subdir in subdirs + ] + ) self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata - - if "user_metadata" not in item: - self.read_user_metadata(item) - - items_html += self.create_html_for_item(item, tabname) + items_html += self.create_item_html(tabname, item) - if items_html == '': + if items_html == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) self_name_id = self.name.replace(" ", "_") - res = f""" -
    -{subdirs_html} -
    -
    -{items_html} -
    -""" + res = ( + f"
    {subdirs_html}
    " + f"
    {items_html}
    " + ) return res + def create_html(self, tabname): + if shared.opts.extra_networks_tree_view: + return self.create_tree_view_html(tabname) + else: + return self.create_card_view_html(tabname) + def create_item(self, name, index=None): raise NotImplementedError() @@ -214,66 +402,6 @@ def list_items(self): def allowed_directories_for_previews(self): return [] - def create_html_for_item(self, item, tabname): - """ - Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown. - """ - - preview = item.get("preview", None) - - onclick = item.get("onclick", None) - if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' - background_image = f'' if preview else '' - metadata_button = "" - metadata = item.get("metadata") - if metadata: - metadata_button = f"" - - edit_button = f"
    " - - local_path = "" - filename = item.get("filename", "") - for reldir in self.allowed_directories_for_previews(): - absdir = os.path.abspath(reldir) - - if filename.startswith(absdir): - local_path = filename[len(absdir):] - - # if this is true, the item must not be shown in the default view, and must instead only be - # shown when searching for it - if shared.opts.extra_networks_hidden_models == "Always": - search_only = False - else: - search_only = "/." in local_path or "\\." in local_path - - if search_only and shared.opts.extra_networks_hidden_models == "Never": - return "" - - sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() - - args = { - "background_image": background_image, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "prompt": item.get("prompt", None), - "tabname": quote_js(tabname), - "local_preview": quote_js(item["local_preview"]), - "name": html.escape(item["name"]), - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', - "search_term": item.get("search_term", ""), - "metadata_button": metadata_button, - "edit_button": edit_button, - "search_only": " search_only" if search_only else "", - "sort_keys": sort_keys, - } - - return self.card_page.format(**args) - def get_sort_keys(self, path): """ List of default keys used for sorting in the UI. @@ -360,7 +488,6 @@ def tab_name_score(name): return sorted(pages, key=lambda x: tab_scores[x.name]) - def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): from modules.ui import switch_values_symbol @@ -381,7 +508,6 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) @@ -390,30 +516,60 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) + tab_controls = {} + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) + + tab_controls["edit_search"] = edit_search + tab_controls["dropdown_sort"] = dropdown_sort + tab_controls["button_sortorder"] = button_sortorder + tab_controls["button_refresh"] = button_refresh + tab_controls["checkbox_show_dirs"] = 
checkbox_show_dirs ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] - for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + tab.select( + fn=lambda: [gr.update(visible=False) for _ in tab_controls], + _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) + + visible_controls = list(tab_controls.keys()) + if shared.opts.extra_networks_tree_view: + visible_controls = ["button_refresh"] for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' - - tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) + jscode = ( + "extraNetworksTabSelected(" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" + ");" + ) + + tab.select( + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + _js="function(){ " + jscode + " }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def pages_html(): if not ui.pages_contents: return refresh() @@ -478,5 +634,3 @@ def save_preview(index, images, filename): for editor in ui.user_metadata_editors: editor.setup_ui(gallery) - - diff --git a/style.css b/style.css index ee39a57b73e..680a5f837b7 100644 --- a/style.css +++ b/style.css @@ -1157,3 +1157,95 @@ body.resizing .resize-handle { left: 7.5px; border-left: 1px dashed var(--border-color-primary); } + +.extra-network-cards .card .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .button-column { + display: inline-flex; + visibility: hidden; + color: white; + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; +} + +.extra-network-cards .card-minimal:hover .button-column { + visibility: visible; +} + +.extra-network-cards .card-minimal .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .metadata-button:before{ + content: "🛈"; +} + +.extra-network-cards .card-minimal .edit-button:before{ + content: "🛠"; +} + +.extra-network-cards .card-minimal .card-button { + color: white; + text-shadow: 2px 2px 3px black; + font-size: 1rem; + width: 1.5rem; +} + +.extra-network-cards .card-minimal .card-button:hover { + color: red; +} + +.extra-network-cards .card-minimal { + display: inline-flex; + position: relative; + overflow: hidden; + cursor: pointer; + font-size: 1rem; + font-weight: bold; + line-break: anywhere; +} + +.file-item { + list-style-type: '📄'; +} + +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +details[disabled] summary { + pointer-events: none; + user-select: 
none; +} + +details.folder-item > summary { + list-style-type: '📁'; +} + +details.folder-item[open] > summary { + list-style-type: '📂'; +} + +.file-item, +.folder-item, +.folder-item-summary { + display: block; + font-size: 1rem; + padding: 0.05rem; + cursor: pointer; + user-select: none; +} + +.folder-item-summary:hover, +.file-item:hover { + -webkit-transition: all 0.1s ease-in-out; + transition: all 0.1s ease-in-out; + background-color: var(--neutral-200); +} + +.dark .folder-item-summary:hover, +.dark .file-item:hover { + -webkit-transition: all 0.05s ease-in-out; + transition: all 0.05s ease-in-out; + background-color: var(--neutral-800); +} From 67a70ad112e1b0fb262852ab830896302dbd306a Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:11:24 -0500 Subject: [PATCH 224/449] fix indentation --- html/extra-networks-card.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index d76011d737c..7770094da14 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,7 +1,7 @@
    {background_image}
-    {copy_path_button}
+	{copy_path_button}
 	{metadata_button}
 	{edit_button}
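Like the earlier "linting" commit in this series, the commit below only strips trailing whitespace, the kind pycodestyle reports as W291/W293. A cleanup of that sort can be reproduced mechanically; as a generic convenience sketch, not anything the repo itself ships:

    from pathlib import Path

    def strip_trailing_whitespace(path: str) -> None:
        # Rewrite a file with trailing spaces/tabs removed from every line.
        p = Path(path)
        lines = p.read_text(encoding="utf-8").splitlines()
        p.write_text("\n".join(line.rstrip() for line in lines) + "\n", encoding="utf-8")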
    From 34fc215249e2bc0acc66cda47319e40b6e46a05f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:23:01 -0500 Subject: [PATCH 225/449] fix linting --- modules/ui_extra_networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8667617b9f5..ab484a5dc0b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,7 +216,7 @@ def create_item_html(self, tabname: str, item: dict) -> str: onclick = item.get("onclick", None) if onclick is None: onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - + copy_path_button = f"
    " metadata_button = "" @@ -523,7 +523,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - + tab_controls["edit_search"] = edit_search tab_controls["dropdown_sort"] = dropdown_sort tab_controls["button_sortorder"] = button_sortorder @@ -560,7 +560,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], _js="function(){ " + jscode + " }", inputs=[], outputs=list(tab_controls.values()), From 46413b20a4d88355b5fac5d27cc9fcb634cdb49e Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 16:23:35 -0500 Subject: [PATCH 226/449] Increase limits for upscalers. --- modules/api/models.py | 2 +- scripts/postprocessing_upscale.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/models.py b/modules/api/models.py index 33894b3e694..11ba4f0b952 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index ed709688de4..5946678e5be 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,13 +21,13 @@ def ui(self): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): with gr.Column(elem_id="upscaling_column_size", scale=4): - upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") + upscaling_resize_w = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"): upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") From 8d986727b39ee6616190559efa1c41c1942b99b0 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 9 Jan 2024 03:01:20 -0600 Subject: [PATCH 227/449] include tls arguments in api uvicorn init --- modules/api/api.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9d1292e9571..59e46335231 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -879,7 +879,15 @@ def get_extensions_list(self): def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) + uvicorn.run( + self.app, + host=server_name, + port=port, + timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, + root_path=root_path, + ssl_keyfile=shared.cmd_opts.tls_keyfile, + ssl_certfile=shared.cmd_opts.tls_certfile + ) def kill_webui(self): restart.stop_program() From 209c26a1cb9e4be357ab3c5e7613caf3cbc26183 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:11:44 +0800 Subject: [PATCH 228/449] improve efficiency and support more device --- modules/devices.py | 60 ++++++++++++++++++++++++++++++------------ modules/shared_init.py | 1 + 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ff279ac5016..6edfb127896 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -110,6 +110,7 @@ def enable_tf32(): dtype: torch.dtype = torch.float16 dtype_vae: torch.dtype = torch.float16 dtype_unet: torch.dtype = torch.float16 +dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False @@ -131,21 +132,49 @@ def 
cond_cast_float(input): ] -def manual_cast_forward(self, *args, **kwargs): - org_dtype = torch_utils.get_param(self).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result +def manual_cast_forward(target_dtype): + def forward_wrapper(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + if not target_dtype == org_dtype == dtype_inference: + self.to(target_dtype) + args = [ + arg.to(target_dtype) + if isinstance(arg, torch.Tensor) + else arg + for arg in args + ] + kwargs = { + k: v.to(target_dtype) + if isinstance(v, torch.Tensor) + else v + for k, v in kwargs.items() + } + + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + + if target_dtype != dtype_inference: + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + return forward_wrapper @contextlib.contextmanager -def manual_cast(): +def manual_cast(target_dtype): for module_type in patch_module_list: org_forward = module_type.forward - module_type.forward = manual_cast_forward + if module_type == torch.nn.MultiheadAttention and has_xpu(): + module_type.forward = manual_cast_forward(torch.float32) + else: + module_type.forward = manual_cast_forward(target_dtype) module_type.org_forward = org_forward try: yield None @@ -161,15 +190,12 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_cast() - - if has_mps() and shared.cmd_opts.precision != "full": - return manual_cast() - - if dtype == torch.float32 or shared.cmd_opts.precision == "full": + if dtype == torch.float32 and shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype_inference) + return torch.autocast("cuda") diff --git a/modules/shared_init.py b/modules/shared_init.py index 586be3423d0..935e3a21cf2 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,6 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu" From 42e6df723c68af775b73c9fa4f43f99345348689 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Tue, 9 Jan 2024 22:39:39 +0800 Subject: [PATCH 229/449] Fix bugs when arg dtype doesn't match --- modules/devices.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6edfb127896..e05740524f3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -134,24 +134,19 @@ def cond_cast_float(input): def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): + if any( + isinstance(arg, torch.Tensor) and arg.dtype != target_dtype + for arg in args + ): + args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for 
arg in args] + kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + org_dtype = torch_utils.get_param(self).dtype - if not target_dtype == org_dtype == dtype_inference: + if org_dtype != target_dtype: self.to(target_dtype) - args = [ - arg.to(target_dtype) - if isinstance(arg, torch.Tensor) - else arg - for arg in args - ] - kwargs = { - k: v.to(target_dtype) - if isinstance(v, torch.Tensor) - else v - for k, v in kwargs.items() - } - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) + if org_dtype != target_dtype: + self.to(org_dtype) if target_dtype != dtype_inference: if isinstance(result, tuple): From c2c05fcca8f3547783c5440c04ec10cc63c65db5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:53:58 +0800 Subject: [PATCH 230/449] linting and debugs --- modules/devices.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index e05740524f3..ad36f6562bc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -140,20 +140,20 @@ def forward_wrapper(self, *args, **kwargs): ): args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - + org_dtype = torch_utils.get_param(self).dtype if org_dtype != target_dtype: self.to(target_dtype) result = self.org_forward(*args, **kwargs) if org_dtype != target_dtype: self.to(org_dtype) - + if target_dtype != dtype_inference: if isinstance(result, tuple): result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i for i in result ) elif isinstance(result, torch.Tensor): @@ -185,7 +185,7 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32 and shared.cmd_opts.precision == "full": + if dtype == torch.float32: return contextlib.nullcontext() if has_xpu() or has_mps() or cuda_no_autocast(): From e00365962b17550a42235d1fbe2ad2c7cc4b8961 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:13:34 +0800 Subject: [PATCH 231/449] Apply correct inference precision implementation --- modules/devices.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562bc..9e1f207c336 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,6 +132,21 @@ def cond_cast_float(input): ] +def cast_output(result): + if isinstance(result, tuple): + result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + + +def autocast_with_cast_output(self, *args, **kwargs): + result = self.org_forward(*args, **kwargs) + if dtype_inference != dtype: + result = cast_output(result) + return result + + def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -149,15 +164,7 @@ def forward_wrapper(self, *args, **kwargs): self.to(org_dtype) if target_dtype != dtype_inference: - if isinstance(result, tuple): - result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i - for i in result - ) - elif isinstance(result, torch.Tensor): 
- result = result.to(dtype_inference) + result = cast_output(result) return result return forward_wrapper @@ -178,6 +185,20 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward +@contextlib.contextmanager +def precision_full_with_autocast(autocast_ctx): + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = autocast_with_cast_output + module_type.org_forward = org_forward + try: + with autocast_ctx: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -191,6 +212,9 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) + if dtype_inference == torch.float32 and dtype != torch.float32: + return precision_full_with_autocast(torch.autocast("cuda")) + return torch.autocast("cuda") From 1fd69655fe340325863cbd7bf5297e034a6a3a0a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:15:05 +0800 Subject: [PATCH 232/449] Revert "Apply correct inference precision implementation" This reverts commit e00365962b17550a42235d1fbe2ad2c7cc4b8961. --- modules/devices.py | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 9e1f207c336..ad36f6562bc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,21 +132,6 @@ def cond_cast_float(input): ] -def cast_output(result): - if isinstance(result, tuple): - result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) - elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) - return result - - -def autocast_with_cast_output(self, *args, **kwargs): - result = self.org_forward(*args, **kwargs) - if dtype_inference != dtype: - result = cast_output(result) - return result - - def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -164,7 +149,15 @@ def forward_wrapper(self, *args, **kwargs): self.to(org_dtype) if target_dtype != dtype_inference: - result = cast_output(result) + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) return result return forward_wrapper @@ -185,20 +178,6 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward -@contextlib.contextmanager -def precision_full_with_autocast(autocast_ctx): - for module_type in patch_module_list: - org_forward = module_type.forward - module_type.forward = autocast_with_cast_output - module_type.org_forward = org_forward - try: - with autocast_ctx: - yield None - finally: - for module_type in patch_module_list: - module_type.forward = module_type.org_forward - - def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -212,9 +191,6 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) - if dtype_inference == torch.float32 and dtype != torch.float32: - return precision_full_with_autocast(torch.autocast("cuda")) - return torch.autocast("cuda") From 58d5b042cd02f287faabef399134b97d323691f2 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:23:40 
+0800 Subject: [PATCH 233/449] Apply the correct behavior of precision='full' --- modules/devices.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562bc..29a270d11a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,11 +185,14 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32: - return contextlib.nullcontext() - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype_inference) + return manual_cast(dtype) + + if fp8 and dtype_inference == torch.float32: + return manual_cast(dtype) + + if dtype == torch.float32 or dtype_inference == torch.float32: + return contextlib.nullcontext() return torch.autocast("cuda") From ca671e5d7b9d03227f01e6bcb350032b6d14e722 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:30:55 +0800 Subject: [PATCH 234/449] rearrange if-statements for cpu --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 29a270d11a6..0321d12c6ab 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,15 +185,15 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype) - if fp8 and dtype_inference == torch.float32: return manual_cast(dtype) if dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype) + return torch.autocast("cuda") From 3cc7572f5dec429d265ce937249eb4b5cc18d0ba Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:46:10 -0500 Subject: [PATCH 235/449] Restore scale factor limit changes to master branch. --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index 11ba4f0b952..33894b3e694 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") From 9dd25348248c06770897f4cde64e2663dfa9f2de Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:47:39 -0500 Subject: [PATCH 236/449] Restore scale factor limit changes to master branch. --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 5946678e5be..a57f9d4a4b6 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,7 +21,7 @@ def ui(self): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): From 4d9f2c3ec8e791b9c354292c4333fc85b8f8d740 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 6 Jan 2024 21:41:29 +0900 Subject: [PATCH 237/449] update p.seed and p.subseed --- modules/txt2img.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index c4cc12d2f6d..d22a1f319e3 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -67,13 +67,16 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g geninfo = json.loads(generation_info) all_seeds = geninfo["all_seeds"] + all_subseeds = geninfo["all_subseeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) gallery_index_from_end = len(gallery) - gallery_index seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + subseed = all_subseeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.seed = seed + p.subseed = subseed with closing(p): processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) From 3db6938caa719aaa38b52edecf42740ef62b0c3c Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 10 Jan 2024 18:11:48 -0500 Subject: [PATCH 238/449] begin redesign of tree module. --- html/extra-networks-card-minimal.html | 3 +- html/extra-networks-pane.html | 11 ++ html/extra-networks-tree-directory.html | 4 + html/extra-networks-tree-file.html | 1 + javascript/extraNetworks.js | 22 +++ modules/shared_options.py | 1 - modules/ui_extra_networks.py | 164 ++++++++++--------- style.css | 204 +++++++++++++----------- 8 files changed, 248 insertions(+), 162 deletions(-) create mode 100644 html/extra-networks-pane.html create mode 100644 html/extra-networks-tree-directory.html create mode 100644 html/extra-networks-tree-file.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html index a6a54d9f4ac..d66df7dfb3f 100644 --- a/html/extra-networks-card-minimal.html +++ b/html/extra-networks-card-minimal.html @@ -1,3 +1,4 @@
    - {name}{copy_path_button}{metadata_button}{edit_button} + {name} + {copy_path_button}{metadata_button}{edit_button}
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html new file mode 100644 index 00000000000..93bad698ddb --- /dev/null +++ b/html/extra-networks-pane.html @@ -0,0 +1,11 @@ +
    + {subdirs_html} +
    +
    +
    + {tree_html} +
    +
    + {items_html} +
    +
    \ No newline at end of file diff --git a/html/extra-networks-tree-directory.html b/html/extra-networks-tree-directory.html new file mode 100644 index 00000000000..cec15588635 --- /dev/null +++ b/html/extra-networks-tree-directory.html @@ -0,0 +1,4 @@ +
    +{folder_name} +{content} +
    \ No newline at end of file diff --git a/html/extra-networks-tree-file.html b/html/extra-networks-tree-file.html new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/html/extra-networks-tree-file.html @@ -0,0 +1 @@ + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 40309d5575c..33f45c8e4d0 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -253,6 +253,28 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksFolderClick(event, tabs_id) { + var els = document.querySelectorAll(".folder-item-summary.selected"); + [...els].forEach(el => { + el.classList.remove("selected"); + }); + event.target.classList.add("selected"); + + var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); + var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + searchTextArea.value = text; + updateInput(searchTextArea); + + if (event.target.parentElement.open) { + // before close + console.log("closed"); + } else { + // before open + console.log("opened"); + //console.log("Opened:", event.target.parentElement); + } +} + function extraNetworksSearchButton(tabs_id, event) { var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; diff --git a/modules/shared_options.py b/modules/shared_options.py index e698c26498c..d2e86ff10b3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,7 +238,6 @@ "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), - "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ab484a5dc0b..6318594fa5b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,10 +73,11 @@ def _get_tree(_paths: list[str]): # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: + short_path = os.path.basename(path) # Wrap the path in a list since that is what the `_get_tree` expects. 
- res[path] = _get_tree([path]) - if res[path]: - res[path] = res[path][os.path.basename(path)] + res[short_path] = _get_tree([path]) + if res[short_path]: + res[short_path] = res[short_path][os.path.basename(path)] return res @@ -153,10 +154,9 @@ def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - if shared.opts.extra_networks_tree_view: - self.card_page = shared.html("extra-networks-card-minimal.html") - else: - self.card_page = shared.html("extra-networks-card.html") + self.extra_networks_pane_template = shared.html("extra-networks-pane.html") + self.card_page_template = shared.html("extra-networks-card.html") + self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -182,15 +182,14 @@ def link_preview(self, filename): def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) - for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): - parentdir = os.path.abspath(parentdir) + parentdir = os.path.dirname(os.path.abspath(parentdir)) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\', '/') + return os.path.relpath(abspath, parentdir) return "" - def create_item_html(self, tabname: str, item: dict) -> str: + def create_item_html(self, tabname: str, item: dict, template: Optional[str] = None) -> str: """Generates HTML for a single ExtraNetworks Item Args: @@ -265,7 +264,10 @@ def create_item_html(self, tabname: str, item: dict) -> str: "tabname": quote_js(tabname), } - return self.card_page.format(**args) + if template: + return template.format(**args) + else: + return self.card_page.format(**args) def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -276,53 +278,67 @@ def create_tree_view_html(self, tabname: str) -> str: Returns: HTML string generated for this tree view. """ - self_name_id = self.name.replace(" ", "_") - res = f"
    " - - self.metadata = {} - self.items = {x["name"]: x for x in self.list_items()} + res = f"" + # Generate HTML for the tree. roots = self.allowed_directories_for_previews() tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) if not tree: - return res + "
    " + return res - file_template = "
  • {}
  • " + file_template = "
  • {card}
  • " dir_template = ( - "
    " - "{}" - "{}" + "
    " + "" + "{folder_name}" + "" + "
      {content}
    " "
    " ) def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" - _res = "
      " + _res = "" if not data: - return "
      • DIRECTORY IS EMPTY
      " + return "
    • DIRECTORY IS EMPTY
    • " for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _res += file_template.format(self.create_item_html(tabname, v.item)) + item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + _res += file_template.format(**{"card": item_html}) else: - _res += dir_template.format("", k, _build_tree(v)) + tmp = dir_template.format( + **{ + "attributes": "", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + _res += tmp + return _res - res += "
        " # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
          " + res += dir_template.format( + **{ + "attributes": "open" if v else "open", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + res += "
        " res += "
      " - res += "
    " return res - def create_card_view_html(self, tabname): - items_html = "" - self.metadata = {} + def create_subdirs_html(self, tabname): subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: @@ -355,43 +371,53 @@ def create_card_view_html(self, tabname): subdirs = {"": 1, **subdirs} subdirs_html_template = ( - "" ) - subdirs_html = "".join( + return "".join( [ subdirs_html_template.format( - " search-all" if subdir == "" else "", - tabname, - html.escape(subdir if subdir != "" else "all"), + **{ + "classes": "search-all" if not subdir else "", + "tabname": tabname, + "content": html.escape(subdir if subdir else "all"), + } ) for subdir in subdirs ] ) + def create_card_view_html(self, tabname): + res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - items_html += self.create_item_html(tabname, item) + res += self.create_item_html(tabname, item, self.card_page_template) - if items_html == "": + if res == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) - items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) - - self_name_id = self.name.replace(" ", "_") - - res = ( - f"
    {subdirs_html}
    " - f"
    {items_html}
    " - ) + res = shared.html("extra-networks-no-cards.html").format(dirs=dirs) return res def create_html(self, tabname): - if shared.opts.extra_networks_tree_view: - return self.create_tree_view_html(tabname) - else: - return self.create_card_view_html(tabname) + self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + + tree_view_html = self.create_tree_view_html(tabname) + subdirs_html = self.create_subdirs_html(tabname) + card_view_html = self.create_card_view_html(tabname) + network_type_id = self.name.replace(" ", "_") + + return self.extra_networks_pane_template.format( + **{ + "tabname": tabname, + "network_type_id": network_type_id, + "tree_html": tree_view_html, + "subdirs_html": subdirs_html, + "items_html": card_view_html, + } + ) def create_item(self, name, index=None): raise NotImplementedError() @@ -516,19 +542,19 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - tab_controls = {} - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - tab_controls["edit_search"] = edit_search - tab_controls["dropdown_sort"] = dropdown_sort - tab_controls["button_sortorder"] = button_sortorder - tab_controls["button_refresh"] = button_refresh - tab_controls["checkbox_show_dirs"] = checkbox_show_dirs + tab_controls = [ + edit_search, + dropdown_sort, + button_sortorder, + button_refresh, + checkbox_show_dirs, + ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -538,32 +564,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) - visible_controls = list(tab_controls.keys()) - if shared.opts.extra_networks_tree_view: - visible_controls = ["button_refresh"] - for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" jscode = ( "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" ");" ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=True) for _ in tab_controls], 
_js="function(){ " + jscode + " }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) diff --git a/style.css b/style.css index 680a5f837b7..4f285c68b06 100644 --- a/style.css +++ b/style.css @@ -863,7 +863,7 @@ footer { margin-bottom: 1em; } -.extra-network-cards{ +.extra-network-pane{ height: calc(100vh - 24rem); overflow: clip scroll; resize: vertical; @@ -908,53 +908,75 @@ footer { width: auto; } -.extra-network-cards .nocards{ +.extra-network-pane .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-pane .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-pane .nocards li{ margin-left: 0.5em; } +.extra-network-pane :is(.card, .card-minimal) .button-row{ + display: inline-flex; + visibility: hidden; + color: white; +} -.extra-network-cards .card .button-row{ - display: none; +.extra-network-pane .card .button-row { position: absolute; - color: white; right: 0; - z-index: 1 + z-index: 1; } -.extra-network-cards .card:hover .button-row{ - display: flex; + +.extra-network-pane .card-minimal .button-row { + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; } -.extra-network-cards .card .card-button{ +.extra-network-pane :is(.card:hover, .card-minimal:hover) .button-row{ + visibility: visible; +} + +.extra-network-pane .card-button{ color: white; } -.extra-network-cards .card .metadata-button:before{ +.extra-network-pane .copy-path-button:before { + content: "⎘"; +} + +.extra-network-pane .metadata-button:before{ content: "🛈"; } -.extra-network-cards .card .edit-button:before{ +.extra-network-pane .edit-button:before{ content: "🛠"; } -.extra-network-cards .card .card-button { +.extra-network-pane .card-button { + width: 1.5em; text-shadow: 2px 2px 3px black; + color: white; padding: 0.25em 0.1em; - font-size: 200%; - width: 1.5em; } -.extra-network-cards .card .card-button:hover{ + +.extra-network-pane .card-button:hover{ color: red; } +.extra-network-pane .card .card-button { + font-size: 2rem; +} + +.extra-network-pane .card-minimal .card-button { + font-size: 1rem; +} .standalone-card-preview.card .preview{ position: absolute; @@ -963,7 +985,7 @@ footer { height:100%; } -.extra-network-cards .card, .standalone-card-preview.card{ +.extra-network-pane .card, .standalone-card-preview.card{ display: inline-block; margin: 0.5rem; width: 16rem; @@ -980,15 +1002,15 @@ footer { background-image: url('./file=html/card-no-preview.png') } -.extra-network-cards .card:hover{ +.extra-network-pane .card:hover{ box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); } -.extra-network-cards .card .actions .additional{ +.extra-network-pane .card .actions .additional{ display: none; } -.extra-network-cards .card .actions{ +.extra-network-pane .card .actions{ position: absolute; bottom: 0; left: 0; @@ -999,45 +1021,45 @@ footer { text-shadow: 0 0 0.2em black; } -.extra-network-cards .card .actions *{ +.extra-network-pane .card .actions *{ color: white; } -.extra-network-cards .card .actions .name{ +.extra-network-pane .card .actions .name{ font-size: 1.7em; font-weight: bold; line-break: anywhere; } -.extra-network-cards .card .actions .description { +.extra-network-pane .card .actions .description { display: block; max-height: 3em; white-space: pre-wrap; line-height: 1.1; } -.extra-network-cards .card .actions .description:hover { +.extra-network-pane .card .actions .description:hover { max-height: none; } -.extra-network-cards .card .actions:hover 
.additional{ +.extra-network-pane .card .actions:hover .additional{ display: block; } -.extra-network-cards .card ul{ +.extra-network-pane .card ul{ margin: 0.25em 0 0.75em 0.25em; cursor: unset; } -.extra-network-cards .card ul a{ +.extra-network-pane .card ul a{ cursor: pointer; } -.extra-network-cards .card ul a:hover{ +.extra-network-pane .card ul a:hover{ color: red; } -.extra-network-cards .card .preview{ +.extra-network-pane .card .preview{ position: absolute; object-fit: cover; width: 100%; @@ -1158,94 +1180,98 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-cards .card .copy-path-button:before { - content: "⎘"; -} - -.extra-network-cards .card-minimal .button-column { +.extra-network-pane .card-minimal { display: inline-flex; - visibility: hidden; - color: white; - padding-left: 0.5rem; - padding-right: 0.5rem; - align-items: center; + flex-grow: 1; + position: relative; + overflow: hidden; + cursor: pointer; + font-size: 1rem; + font-weight: bold; + line-break: anywhere; } -.extra-network-cards .card-minimal:hover .button-column { - visibility: visible; +/* Pushes buttons to right */ +.extra-network-pane .card-minimal .name { + flex-grow: 1; } -.extra-network-cards .card-minimal .copy-path-button:before { - content: "⎘"; +.file-item, +.folder-item, +.folder-item-summary { + padding-left: 0.05rem; + cursor: pointer; + user-select: none; + font-size: 1rem; } -.extra-network-cards .card-minimal .metadata-button:before{ - content: "🛈"; +.extra-network-pane .extra-network-tree .folder-item-summary:hover, +.extra-network-pane .extra-network-tree .file-item:hover { + -webkit-transition: all 0.1s ease-in-out; + transition: all 0.1s ease-in-out; + background-color: var(--neutral-200); } -.extra-network-cards .card-minimal .edit-button:before{ - content: "🛠"; +.dark .extra-network-pane .extra-network-tree .folder-item-summary:hover, +.dark .extra-network-pane .extra-network-tree .file-item:hover { + -webkit-transition: all 0.05s ease-in-out; + transition: all 0.05s ease-in-out; + background-color: var(--neutral-800); } -.extra-network-cards .card-minimal .card-button { - color: white; - text-shadow: 2px 2px 3px black; - font-size: 1rem; - width: 1.5rem; +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +.extra-network-pane .extra-network-tree details[disabled] summary { + pointer-events: none; + user-select: none; } -.extra-network-cards .card-minimal .card-button:hover { - color: red; +.extra-network-pane .extra-network-tree details.folder-item > summary { + list-style-type: '📁'; + text-overflow: ellipsis; } -.extra-network-cards .card-minimal { - display: inline-flex; - position: relative; - overflow: hidden; - cursor: pointer; - font-size: 1rem; - font-weight: bold; - line-break: anywhere; +.extra-network-pane .extra-network-tree details.folder-item[open] > summary { + list-style-type: '📂'; + text-overflow: ellipsis; } -.file-item { - list-style-type: '📄'; +.extra-network-pane .extra-network tree ul.folder-container { + list-style: none; + font-size: 1rem; + text-overflow: ellipsis; } -/* prevents clicking/collapsing of details tags when disabled attribute is used*/ -details[disabled] summary { - pointer-events: none; - user-select: none; +.extra-network-pane .extra-network-tree li.file-item { + display: flex; + position: relative; + align-items: center; } -details.folder-item > summary { - list-style-type: '📁'; +.extra-network-pane .extra-network-tree li.file-item::before { + content: '📄'; + 
font-size: 0.85rem; + vertical-align: middle; } -details.folder-item[open] > summary { - list-style-type: '📂'; +.extra-network-pane { + display: flex; } -.file-item, -.folder-item, -.folder-item-summary { +.extra-network-pane .extra-network-subdirs { display: block; +} +.extra-network-pane .extra-network-tree { font-size: 1rem; - padding: 0.05rem; - cursor: pointer; - user-select: none; + width: 25%; } - -.folder-item-summary:hover, -.file-item:hover { - -webkit-transition: all 0.1s ease-in-out; - transition: all 0.1s ease-in-out; - background-color: var(--neutral-200); +.extra-network-pane .extra-network-cards { + flex-grow: 1; } -.dark .folder-item-summary:hover, -.dark .file-item:hover { - -webkit-transition: all 0.05s ease-in-out; - transition: all 0.05s ease-in-out; +.dark .extra-network-tree .folder-item-summary.selected{ background-color: var(--neutral-800); } + +.extra-network-tree .folder-item-summary.selected { + background-color: var(--neutral-200); +} From 0011640ab187e5a59098bb80d165a23f9d08568c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 Jan 2024 08:29:42 +0200 Subject: [PATCH 239/449] Logging: set formatter correctly for fallback logger too --- modules/logging_config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 11eee9a63db..8e31d8c9fd1 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -36,20 +36,21 @@ def setup_logging(loglevel): # Already configured, do not interfere return + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) + if os.environ.get("SD_WEBUI_RICH_LOG"): from rich.logging import RichHandler handler = RichHandler() else: handler = logging.StreamHandler() + handler.setFormatter(formatter) if TqdmLoggingHandler: handler = TqdmLoggingHandler(handler) - formatter = logging.Formatter( - '%(asctime)s %(levelname)s [%(name)s] %(message)s', - '%Y-%m-%d %H:%M:%S', - ) - handler.setFormatter(formatter) log_level = getattr(logging, loglevel.upper(), None) or logging.INFO From 0726a6e12e85a37d1e514f5603acf9f058c11783 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 11 Jan 2024 15:06:57 -0500 Subject: [PATCH 240/449] Finish base layout. Fix bugs. Need to test for stability and clean up. 
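A rough sketch of the tree interaction this settles on (illustrative only:
the tab name and folder path below are made-up examples, while gradioApp()
and updateInput() are the existing helpers used throughout extraNetworks.js):

    // Clicking a folder's summary copies its data-path attribute into the
    // tab's search box; updateInput() fires the input event so the card
    // filter re-runs against the new query.
    var summary = gradioApp().querySelector("#txt2img_extra_tabs .folder-item-summary");
    var searchTextArea = gradioApp().querySelector("#txt2img_extra_search > label > textarea");
    searchTextArea.value = summary.getAttribute("data-path"); // e.g. "Lora/characters"
    updateInput(searchTextArea);

Cards are matched against their hidden search_terms text, so selecting a
folder filters the card view down to everything underneath it.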
--- .../Lora/ui_extra_networks_lora.py | 5 +- html/extra-networks-card.html | 4 +- html/extra-networks-pane.html | 3 - javascript/extraNetworks.js | 48 ++++---- modules/ui_extra_networks.py | 113 ++++++------------ modules/ui_extra_networks_checkpoints.py | 5 +- modules/ui_extra_networks_hypernets.py | 6 +- .../ui_extra_networks_textual_inversion.py | 5 +- style.css | 24 ++-- 9 files changed, 89 insertions(+), 124 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b12..db612fa2ba3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -24,13 +24,16 @@ def create_item(self, name, index=None, enable_filter=True): alias = lora_on_disk.get_alias() + search_terms = [self.search_terms_from_path(lora_on_disk.filename)] + if lora_on_disk.hash: + search_terms.append(lora_on_disk.hash) item = { "name": name, "filename": lora_on_disk.filename, "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), + "search_terms": search_terms, "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 7770094da14..d163fe370a5 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -6,9 +6,7 @@ {edit_button}
    -
    - -
    +
    {search_terms}
    {name} {description}
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 93bad698ddb..20cf6686755 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,6 +1,3 @@ -
    - {subdirs_html} -
    {tree_html} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 33f45c8e4d0..4cc67fd18a3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -24,8 +24,6 @@ function setupExtraNetworksForTab(tabname) { var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); - var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); @@ -33,14 +31,14 @@ function setupExtraNetworksForTab(tabname) { tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); - tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); var visible = text.indexOf(searchTerm) != -1; @@ -100,15 +98,6 @@ function setupExtraNetworksForTab(tabname) { extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; - - var showDirsUpdate = function() { - var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; - toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); - localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); - }; - showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; - showDirs.addEventListener("change", showDirsUpdate); - showDirsUpdate(); } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,14 +125,23 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } +function clearSearch(tabname) { + // Clear search box. + var tab_id = tabname + "_extra_search"; + var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); + searchTextarea.value = ""; + updateInput(searchTextarea); +} + -function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) +function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); + clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - + clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -254,6 +252,15 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksFolderClick(event, tabs_id) { + // If folder is open but not selected, we don't want to collapse it. Instead + // we override the removal of the "open" attribute so that the folder is + // only selected but remains open. 
Since this is a toggle event, removing + // the "open" attribute instead forces the event to add it back which keeps it open. + if (event.target.parentElement.open && !event.target.classList.contains("selected")) { + // before event handler removes "open" + event.target.parentElement.removeAttribute("open"); + } + var els = document.querySelectorAll(".folder-item-summary.selected"); [...els].forEach(el => { el.classList.remove("selected"); @@ -261,18 +268,9 @@ function extraNetworksFolderClick(event, tabs_id) { event.target.classList.add("selected"); var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); - var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + var text = event.target.classList.contains("search-all") ? "" : event.target.getAttribute("data-path"); searchTextArea.value = text; updateInput(searchTextArea); - - if (event.target.parentElement.open) { - // before close - console.log("closed"); - } else { - // before open - console.log("opened"); - //console.log("Opened:", event.target.parentElement); - } } function extraNetworksSearchButton(tabs_id, event) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6318594fa5b..2e226ba0029 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -48,23 +48,24 @@ def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) if isinstance(paths, (str,)): paths = [paths] - def _get_tree(_paths: list[str]): + def _get_tree(_paths: list[str], _root: str): _res = {} for path in _paths: + relpath = os.path.relpath(path, _root) if os.path.isdir(path): dir_items = os.listdir(path) # Ignore empty directories. if not dir_items: continue - dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root) # We only want to store non-empty folders in the tree. if dir_tree: - _res[os.path.basename(path)] = dir_tree + _res[relpath] = dir_tree else: if path not in items: continue # Add the ExtraNetworksItem to the result. - _res[os.path.basename(path)] = items[path] + _res[relpath] = items[path] return _res res = {} @@ -73,11 +74,13 @@ def _get_tree(_paths: list[str]): # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: - short_path = os.path.basename(path) + root = os.path.dirname(path) + relpath = os.path.relpath(path, root) # Wrap the path in a list since that is what the `_get_tree` expects. - res[short_path] = _get_tree([path]) - if res[short_path]: - res[short_path] = res[short_path][os.path.basename(path)] + res[relpath] = _get_tree([path], root) + if res[relpath]: + # We need to pull the inner path out one for these root dirs. + res[relpath] = res[relpath][relpath] return res @@ -245,6 +248,17 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + search_terms_html = "" + search_term_template = "{search_term}" + for search_term in item.get("search_terms", []): + search_terms_html += search_term_template.format( + **{ + "style": "display: none;", + "class": "search_terms" + (" search_only" if search_only else ""), + "search_term": search_term, + } + ) + # Some items here might not be used depending on HTML template used. 
args = { "background_image": background_image, @@ -258,7 +272,7 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_only": " search_only" if search_only else "", - "search_term": item.get("search_term", ""), + "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "tabname": quote_js(tabname), @@ -278,7 +292,7 @@ def create_tree_view_html(self, tabname: str) -> str: Returns: HTML string generated for this tree view. """ - res = f"" + res = "" # Generate HTML for the tree. roots = self.allowed_directories_for_previews() @@ -291,7 +305,8 @@ def create_tree_view_html(self, tabname: str) -> str: file_template = "
  • {card}
  • " dir_template = ( "
    " - "" + "" "{folder_name}" "" "
      {content}
    " @@ -309,16 +324,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) _res += file_template.format(**{"card": item_html}) else: - tmp = dir_template.format( + _res += dir_template.format( **{ "attributes": "", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) - _res += tmp - return _res # Add each root directory to the tree. @@ -329,65 +343,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: **{ "attributes": "open" if v else "open", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) res += "" res += "" - return res - def create_subdirs_html(self, tabname): - subdirs = {} - - for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): - for dirname in sorted(dirs, key=shared.natural_sort_key): - x = os.path.join(root, dirname) - - if not os.path.isdir(x): - continue - - subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - - if shared.opts.extra_networks_dir_button_function: - if not subdir.startswith("/"): - subdir = "/" + subdir - else: - while subdir.startswith("/"): - subdir = subdir[1:] - - is_empty = len(os.listdir(x)) == 0 - if not is_empty and not subdir.endswith("/"): - subdir = subdir + "/" - - if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories: - continue - - subdirs[subdir] = 1 - - if subdirs: - subdirs = {"": 1, **subdirs} - - subdirs_html_template = ( - "" - ) - return "".join( - [ - subdirs_html_template.format( - **{ - "classes": "search-all" if not subdir else "", - "tabname": tabname, - "content": html.escape(subdir if subdir else "all"), - } - ) for subdir in subdirs - ] - ) - def create_card_view_html(self, tabname): res = "" self.items = {x["name"]: x for x in self.list_items()} @@ -405,7 +369,6 @@ def create_html(self, tabname): self.items = {x["name"]: x for x in self.list_items()} tree_view_html = self.create_tree_view_html(tabname) - subdirs_html = self.create_subdirs_html(tabname) card_view_html = self.create_card_view_html(tabname) network_type_id = self.name.replace(" ", "_") @@ -414,7 +377,6 @@ def create_html(self, tabname): "tabname": tabname, "network_type_id": network_type_id, "tree_html": tree_view_html, - "subdirs_html": subdirs_html, "items_html": card_view_html, } ) @@ -534,7 +496,12 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change( + fn=lambda: None, + _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + inputs=[], + outputs=[], + ) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() @@ -542,18 +509,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + edit_search = gr.Textbox('', show_label=False, 
elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) tab_controls = [ edit_search, dropdown_sort, button_sortorder, button_refresh, - checkbox_show_dirs, ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) @@ -562,7 +527,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for tab in unrelated_tabs: tab.select( fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", inputs=[], outputs=tab_controls, show_progress=False, diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 1693e71f16f..e7976ba1260 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -21,13 +21,16 @@ def create_item(self, name, index=None, enable_filter=True): return path, ext = os.path.splitext(checkpoint.filename) + search_terms = [self.search_terms_from_path(checkpoint.filename)] + if checkpoint.sha256: + search_terms.append(checkpoint.sha256) return { "name": checkpoint.name_for_extra, "filename": checkpoint.filename, "shorthash": checkpoint.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "search_terms": search_terms, "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index c96c4fa3b12..2fb4bd190a1 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -20,14 +20,16 @@ def create_item(self, name, index=None, enable_filter=True): path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None - + search_terms = [self.search_terms_from_path(path)] + if sha256: + search_terms.append(sha256) return { "name": name, "filename": full_path, "shorthash": shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""), + "search_terms": search_terms, "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 
1b334fda174..deb7cb8733b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -18,13 +18,16 @@ def create_item(self, name, index=None, enable_filter=True): return path, ext = os.path.splitext(embedding.filename) + search_terms = [self.search_terms_from_path(embedding.filename)] + if embedding.hash: + search_terms.append(embedding.hash) return { "name": name, "filename": embedding.filename, "shorthash": embedding.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""), + "search_terms": search_terms, "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, diff --git a/style.css b/style.css index 4f285c68b06..70d80d6a170 100644 --- a/style.css +++ b/style.css @@ -878,16 +878,8 @@ footer { margin: 0.3em; } -.extra-network-subdirs{ - padding: 0.2em 0.35em; -} - -.extra-network-subdirs button{ - margin: 0 0.15em; -} .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort, -.extra-networks .tab-nav .show-dirs +.extra-networks .tab-nav .sort { margin: 0.3em; align-self: center; @@ -1196,6 +1188,10 @@ body.resizing .resize-handle { flex-grow: 1; } +.folder-container { + margin-left: 1.5em !important; +} + .file-item, .folder-item, .folder-item-summary { @@ -1235,7 +1231,7 @@ body.resizing .resize-handle { text-overflow: ellipsis; } -.extra-network-pane .extra-network tree ul.folder-container { +.extra-network-pane .extra-network-tree ul.folder-container { list-style: none; font-size: 1rem; text-overflow: ellipsis; @@ -1257,15 +1253,15 @@ body.resizing .resize-handle { display: flex; } -.extra-network-pane .extra-network-subdirs { - display: block; -} .extra-network-pane .extra-network-tree { font-size: 1rem; - width: 25%; + min-width: 25%; + max-width: 25%; + border: 1px solid var(--block-border-color); } .extra-network-pane .extra-network-cards { flex-grow: 1; + border: 1px solid var(--block-border-color); } .dark .extra-network-tree .folder-item-summary.selected{ From 541881e3188b4f59d82043c5fdfac02607ec3119 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:11:06 -0800 Subject: [PATCH 241/449] Adjust brush size with hotkeys. 
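A minimal sketch of how the new bindings resolve (simplified from the
surrounding keydown handler; the if-body here is illustrative, not the
exact dispatch code):

    // "[" arrives as event.code "BracketLeft" and "]" as "BracketRight";
    // both go through the same lookup as the existing canvas hotkeys.
    const action = hotkeyActions[event.code]; // e.g. the BracketLeft entry
    if (action) {
        action(); // -> adjustBrushSize(elemId, 10) to shrink, (elemId, -10) to grow
    }

As wired below, shrink passes +10 and grow passes -10, so adjustBrushSize
evidently treats its second argument like a wheel delta (positive = smaller).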
--- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 45c7600ac5f..c695debc65b 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,6 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", + canvas_hotkey_shrink_brush: "BracketLeft", + canvas_hotkey_grow_brush: "BracketRight", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, @@ -227,6 +229,8 @@ onUiLoaded(async() => { const functionMap = { "Zoom": "canvas_hotkey_zoom", "Adjust brush size": "canvas_hotkey_adjust", + "Shrink brush size": "canvas_hotkey_shrink_brush", + "Grow brush size": "canvas_hotkey_grow_brush", "Moving canvas": "canvas_hotkey_move", "Fullscreen": "canvas_hotkey_fullscreen", "Reset Zoom": "canvas_hotkey_reset", @@ -686,7 +690,9 @@ onUiLoaded(async() => { const hotkeyActions = { [hotkeysConfig.canvas_hotkey_reset]: resetZoom, [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, + [defaultHotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10), + [defaultHotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10) }; const action = hotkeyActions[event.code]; From 47b52d9b28aaabb603358fae5d9b824c49aa627b Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:31:26 -0800 Subject: [PATCH 242/449] Add # to the invalid_filename_chars list --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 87a7bf22139..b6f2358c3ee 100644 --- a/modules/images.py +++ b/modules/images.py @@ -321,7 +321,7 @@ def resize(im, w, h): return res -invalid_filename_chars = '<>:"/\\|?*\n\r\t' +invalid_filename_chars = '#<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') From b6dc307c99a25bc3f3f3e96ce640a4710c14e133 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 13 Jan 2024 14:45:15 +0400 Subject: [PATCH 243/449] fix_extension_check_for_requirements --- modules/extensions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/extensions.py b/modules/extensions.py index 99e7ee60f64..04bda297e79 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -224,13 +224,16 @@ def list_extensions(): # check for requirements for extension in extensions: + if not extension.enabled: + continue + for req in extension.metadata.requires: required_extension = loaded_extensions.get(req) if required_extension is None: errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) continue - if not extension.enabled: + if not required_extension.enabled: errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) continue From 02e6963325e5221e0efb96a63f3dc849550489b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 13 Jan 2024 13:16:39 -0500 Subject: [PATCH 244/449] continue cleanup and redesign. 
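Editor's note on this commit: this pass scopes the search, sort, and refresh controls to each extra-networks tab (element ids move from "{tabname}_extra_*" to "{tabname}_{tab_id}_extra_*") and introduces a tree view whose directory rows behave like a small state machine; the comment block in extraNetworksTreeProcessDirectoryClick below spells out the states. A condensed sketch of that branching, with stand-in helper names — expand/collapse/select stand in for the patch's _expand_or_collapse and _select_button helpers and are not part of this patch:

    function onDirectoryClick(btn, ul) {
        const selected = btn.hasAttribute("selected");
        const expanded = !ul.hasAttribute("data-hidden");
        if (selected && expanded) {         // selected + open: collapse and deselect
            collapse(ul, btn);
            btn.removeAttribute("selected");
        } else if (!selected && expanded) { // open but not selected: just select it
            select(btn);
        } else {                            // closed: expand and select
            expand(ul, btn);
            select(btn);
        }
    }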
--- html/extra-networks-tree-button.html | 11 + javascript/extraNetworks.js | 294 ++++++++++++++++++--------- modules/ui_extra_networks.py | 177 +++++++++++++--- style.css | 277 +++++++++++++++++++------ 4 files changed, 572 insertions(+), 187 deletions(-) create mode 100644 html/extra-networks-tree-button.html diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html new file mode 100644 index 00000000000..920330f7a7d --- /dev/null +++ b/html/extra-networks-tree-button.html @@ -0,0 +1,11 @@ + + \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3e3b03f320e..cce2246899e 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,88 +16,110 @@ function toggleCss(key, css, enable) { } function setupExtraNetworksForTab(tabname) { - gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); + this_tab.classList.add('extra-networks'); - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); - var search = searchDiv.querySelector('textarea'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - - tabs.appendChild(searchDiv); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); - - var applyFilter = function() { - var searchTerm = search.value.toLowerCase(); - - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { - var searchOnly = elem.querySelector('.search_only'); - - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); - - var visible = text.indexOf(searchTerm) != -1; - - if (searchOnly && searchTerm.length < 4) { - visible = false; - } - - elem.style.display = visible ? "" : "none"; - }); - - applySort(); - }; - - var applySort = function() { - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; - sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); - var sortKeyStore = sortKey + "-" + (reverse ? 
"Descending" : "Ascending") + "-" + cards.length; + function registerPrompt(tabname, id) { + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - if (sortKeyStore == sort.dataset.sortkey) { - return; + if (!activePromptTextarea[tabname]) { + activePromptTextarea[tabname] = textarea; } - sort.dataset.sortkey = sortKeyStore; - cards.forEach(function(card) { - card.originalParentElement = card.parentElement; + textarea.addEventListener("focus", function() { + activePromptTextarea[tabname] = textarea; }); - var sortedCards = Array.from(cards); - sortedCards.sort(function(cardA, cardB) { - var a = cardA.dataset[sortKey]; - var b = cardB.dataset[sortKey]; - if (!isNaN(a) && !isNaN(b)) { - return parseInt(a) - parseInt(b); + } + + this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { + var tab_id = elem.getAttribute("id"); + + var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); + var searchDiv = gradioApp().QuerySelector("#" + tab_id + "_extra_search"); + console.log("HERE:", tab_id + "_extra_search", searchDiv); + var search = searchDiv.value; + var sort = gradioApp().getElementById(tabname + '_extra_sort'); + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); + var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); + tabs.appendChild(searchDiv); + tabs.appendChild(sort); + tabs.appendChild(sortOrder); + tabs.appendChild(refresh); + var applyFilter = function() { + var searchTerm = search.value.toLowerCase(); + + gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { + var searchOnly = elem.querySelector('.search_only'); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); + + var visible = text.indexOf(searchTerm) != -1; + + if (searchOnly && searchTerm.length < 4) { + visible = false; + } + + elem.style.display = visible ? "" : "none"; + }); + + applySort(); + }; + + var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + + var reverse = sortOrder.classList.contains("sortReverse"); + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { + return; } - - return (a < b ? -1 : (a > b ? 1 : 0)); - }); - if (reverse) { - sortedCards.reverse(); - } - cards.forEach(function(card) { - card.remove(); - }); - sortedCards.forEach(function(card) { - card.originalParentElement.appendChild(card); + sort.dataset.sortkey = sortKeyStore; + + cards.forEach(function(card) { + card.originalParentElement = card.parentElement; + }); + var sortedCards = Array.from(cards); + sortedCards.sort(function(cardA, cardB) { + var a = cardA.dataset[sortKey]; + var b = cardB.dataset[sortKey]; + if (!isNaN(a) && !isNaN(b)) { + return parseInt(a) - parseInt(b); + } + + return (a < b ? -1 : (a > b ? 
1 : 0)); + }); + if (reverse) { + sortedCards.reverse(); + } + cards.forEach(function(card) { + card.remove(); + }); + sortedCards.forEach(function(card) { + card.originalParentElement.appendChild(card); + }); + }; + + search.addEventListener("input", applyFilter); + sortOrder.addEventListener("click", function() { + sortOrder.classList.toggle("sortReverse"); + applySort(); }); - }; + applyFilter(); + + extraNetworksApplySort[tab_id] = applySort; + extraNetworksApplyFilter[tab_id] = applyFilter; - search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); + registerPrompt(tab_id, tab_id + "_prompt"); + registerPrompt(tab_id, tab_id + "_neg_prompt"); }); - applyFilter(); - extraNetworksApplySort[tabname] = applySort; - extraNetworksApplyFilter[tabname] = applyFilter; + + } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,12 +158,12 @@ function clearSearch(tabname) { function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); - clearSearch(tabname); + //clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - clearSearch(tabname); + //clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -159,23 +181,6 @@ var activePromptTextarea = {}; function setupExtraNetworks() { setupExtraNetworksForTab('txt2img'); setupExtraNetworksForTab('img2img'); - - function registerPrompt(tabname, id) { - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (!activePromptTextarea[tabname]) { - activePromptTextarea[tabname] = textarea; - } - - textarea.addEventListener("focus", function() { - activePromptTextarea[tabname] = textarea; - }); - } - - registerPrompt('txt2img', 'txt2img_prompt'); - registerPrompt('txt2img', 'txt2img_neg_prompt'); - registerPrompt('img2img', 'img2img_prompt'); - registerPrompt('img2img', 'img2img_neg_prompt'); } onUiLoaded(setupExtraNetworks); @@ -262,6 +267,106 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { + /** + * Processes `onclick` events when user clicks on files in tree. + * + * @param event The generated event. + * @param btn The clicked `action-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + var par = btn.parentElement; + var search_id = tabname + "_" + tab_id + "_extra_search"; + var type = par.getAttribute("data-tree-entry-type"); + var path = par.getAttribute("data-path"); +} + +function extraNetworksTreeProcessDirectoryClick(event, btn) { + /** + * Processes `onclick` events when user clicks on directories in tree. + * + * Here is how the tree reacts to clicks for various states: + * unselected unopened directory: Diretory is selected and expanded. + * unselected opened directory: Directory is selected. + * selected opened directory: Directory is collapsed and deselected. + * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. + * + * @param event The generated event. 
+     * @param btn The clicked `action-list-item` button.
+     */
+    var ul = btn.nextElementSibling;
+    // This is the actual target that the user clicked on within the target button.
+    // We use this to detect if the chevron was clicked.
+    var true_targ = event.target;
+
+    function _expand_or_collapse(_ul, _btn) {
+        // Expands <ul>
      if it is collapsed, collapses otherwise. Updates button attributes. + if (_ul.hasAttribute("data-hidden")) { + _ul.removeAttribute("data-hidden"); + _btn.setAttribute("expanded", "true"); + } else { + _ul.setAttribute("data-hidden", ""); + _btn.setAttribute("expanded", "false"); + } + } + + function _remove_selected_from_all() { + // Removes the `selected` attribute from all buttons. + var sels = document.querySelectorAll("button.action-list-content"); + [...sels].forEach(el => { + el.removeAttribute("selected"); + }) + } + + function _select_button(_btn) { + // Removes `selected` attribute from all buttons then adds to passed button. + _remove_selected_from_all(); + _btn.setAttribute("selected", ""); + } + + // If user clicks on the chevron, then we do not select the folder. + if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + _expand_or_collapse(ul, btn); + } else { + // User clicked anywhere else on the button. + if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + // If folder is select and open, collapse and deselect button. + _expand_or_collapse(ul, btn); + btn.removeAttribute("selected"); + } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + // If folder is open and not selected, then we don't collapse; just select. + // NOTE: Double inversion sucks but it is the clearest way to show the branching here. + _expand_or_collapse(ul, btn); + _select_button(btn); + } else { + // All other cases, just select the button. + _select_button(btn); + } + + } +} + +function extraNetworksTreeOnClick(event, tabname, tab_id) { + /** + * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * + * Determines whether the clicked button in the tree is for a file entry or a directory + * then calls the appropriate function. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + var btn = event.currentTarget; + var par = btn.parentElement; + if (par.getAttribute("data-tree-entry-type") === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + } else { + extraNetworksTreeProcessDirectoryClick(event, btn); + } +} + function extraNetworksFolderClick(event, tabs_id) { // If folder is open but not selected, we don't want to collapse it. Instead // we override the removal of the "open" attribute so that the folder is @@ -434,3 +539,10 @@ window.addEventListener("keydown", function(event) { closePopup(); } }); + +function testprint(e) { + console.log(e); +} + +const testinput = gradioApp().querySelector("#txt2img_lora_extra_search"); +testinput.addEventListener("input", testprint); \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 093ac7b4f0c..9cf5b57fbcb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,6 +19,90 @@ allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] +tree_tpl = ( + "" + "
        " + "{content}" + "
      " +) + +tree_ul_tpl = ( + "
        " + "{content}" + "
      " +) + +tree_li_dir_tpl = ( + "
    • " + "{content}" + "
    • " +) +tree_li_file_tpl = ( + "
    • " + "{content}" + "
    • " +) + +tree_btn_dir_tpl = ( + "" +) + +tree_btn_file_action_buttons_tpl = ( + "
      " + "
      " + "
      " + "
      " + "
      " + "
      " +) + +tree_btn_file_tpl = ( + "" + "" +) + + @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -160,6 +244,7 @@ def __init__(self, title): self.extra_networks_pane_template = shared.html("extra-networks-pane.html") self.card_page_template = shared.html("extra-networks-card.html") self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") + self.tree_button_template = shared.html("extra-networks-tree-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -279,7 +364,9 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "tabname": quote_js(tabname), + "tabname": tabname, + "tab_id": self.id_page, + } if template: @@ -306,55 +393,81 @@ def create_tree_view_html(self, tabname: str) -> str: if not tree: return res - file_template = "
    • {card}
    • " - dir_template = ( - "
      " - "" - "{folder_name}" - "" - "
        {content}
      " - "
      " - ) - def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" _res = "" if not data: - return "
    • DIRECTORY IS EMPTY
    • " + return ( + "
      " + "Directory is empty" + "
      " + ) for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) - _res += file_template.format(**{"card": item_html}) + _action_buttons = tree_btn_file_action_buttons_tpl.format( + **{ + "path": quote_js(k), + "filename": quote_js(v.item["name"]), + "tabname": quote_js(tabname), + "tab_id": quote_js(self.id_page), + } + ) + _btn = tree_btn_file_tpl.format( + **{ + "label": v.item["name"], + "filter": v.item["search_terms"], + "tabname": tabname, + "tab_id": self.id_page, + "buttons": _action_buttons, + } + ) + _li = tree_li_file_tpl.format( + **{ + "hash": v.item["shorthash"], + "path": k, + "type": "file", + #"content": _btn, + "content": self.create_item_html(tabname, v.item, self.tree_button_template), + } + ) + _res += _li + #item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + #_res += file_template.format(**{"card": item_html}) else: - _res += dir_template.format( + _btn = tree_btn_dir_tpl.format( **{ - "attributes": "", + "label": os.path.basename(k), "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "tab_id": self.id_page, } ) + _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) + _res += _li return _res # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += "
        " - res += dir_template.format( + btn = tree_btn_dir_tpl.format( **{ - "attributes": "open" if v else "open", + "label": os.path.basename(k), "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "tab_id": self.id_page, } ) - res += "
      " - res += "
    " - return res + ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) + res += li + + return tree_tpl.format( + **{ + "content": res, + "tabname": tabname, + "tab_id": self.id_page, + } + ) def create_card_view_html(self, tabname): res = "" @@ -375,7 +488,7 @@ def create_html(self, tabname): tree_view_html = self.create_tree_view_html(tabname) card_view_html = self.create_card_view_html(tabname) - network_type_id = self.name.replace(" ", "_") + network_type_id = self.id_page return self.extra_networks_pane_template.format( **{ @@ -506,7 +619,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", inputs=[], outputs=[], ) @@ -517,13 +630,11 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) tab_controls = [ - edit_search, dropdown_sort, button_sortorder, button_refresh, diff --git a/style.css b/style.css index aaafaa9d7e1..8aa41088dff 100644 --- a/style.css +++ b/style.css @@ -955,15 +955,15 @@ footer { color: white; } -.extra-network-pane .copy-path-button:before { +.extra-network-pane .copy-path-button::before { content: "⎘"; } -.extra-network-pane .metadata-button:before{ +.extra-network-pane .metadata-button::before{ content: "🛈"; } -.extra-network-pane .edit-button:before{ +.extra-network-pane .edit-button::before{ content: "🛠"; } @@ -1188,102 +1188,253 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-pane .card-minimal { - display: inline-flex; - flex-grow: 1; - position: relative; - overflow: hidden; - cursor: pointer; - font-size: 1rem; - font-weight: bold; - line-break: anywhere; +/* ========================= */ +.extra-network-pane { + display: flex; } -/* Pushes buttons to right */ -.extra-network-pane .card-minimal .name { - flex-grow: 1; +.extra-network-pane .extra-network-cards { + display: block; } -.folder-container { - margin-left: 1.5em !important; +.extra-network-pane .extra-network-tree { + display: block; + font-size: 1rem; + min-width: 25%; + border: 1px solid var(--block-border-color); + overflow: hidden; } -.file-item, -.folder-item, -.folder-item-summary { - padding-left: 0.05rem; +.extra-network-tree .action-list--tree { cursor: pointer; + -webkit-user-select: none; + -moz-user-select: none; + -ms-user-select: none; user-select: none; - font-size: 1rem; + margin: 0; + padding: 0; } -.extra-network-pane .extra-network-tree .folder-item-summary:hover, -.extra-network-pane 
.extra-network-tree .file-item:hover { - -webkit-transition: all 0.1s ease-in-out; - transition: all 0.1s ease-in-out; - background-color: var(--neutral-200); +/* Remove auto indentation from tree. Will be overridden later. */ +.extra-network-tree .action-list--subgroup { + margin: 0 !important; + padding: 0 !important; + box-shadow: 0.6rem 0 0 var(--body-background-fill) inset, + 0.8rem 0 0 var(--neutral-800) inset; +} + +/* Set indentation for each depth of tree. */ +.extra-network-tree .action-list--subgroup > .action-list-item { + margin-left: 0.4rem !important; + padding-left: 0.4rem !important; } -.dark .extra-network-pane .extra-network-tree .folder-item-summary:hover, -.dark .extra-network-pane .extra-network-tree .file-item:hover { +/* Styles for tree
<ul> elements. */
+.extra-network-tree .action-list {
+
+}
+
+/* Styles for tree
<li> elements. */
+.extra-network-tree .action-list-item {
+    list-style: none;
+    position: relative;
+    background-color: transparent;
+}
+
+/* Directory
<ul> */
+.extra-network-tree .action-list-content[expanded=false]+.action-list--subgroup {
+    height: 0;
+    overflow: hidden;
+    visibility: hidden;
+    opacity: 0;
+}
+
+.extra-network-tree .action-list-content[expanded=true]+.action-list--subgroup {
+    height: auto;
+    overflow: visible;
+    visibility: visible;
+    opacity: 1;
+}
+
+/* File
<li> */
+.extra-network-tree .action-list-item--subitem {
+}
+
+/*
<li> containing <ul>
          */ +.extra-network-tree .action-list-item--has-subitem { +} + +/* BUTTON ELEMENTS */ +/* \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 97a97d61f2c..56c9dc4c35d 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -267,7 +267,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { * Processes `onclick` events when user clicks on files in tree. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -288,7 +288,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -310,7 +310,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.action-list-content"); + var sels = document.querySelectorAll("button.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -331,7 +331,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // If user clicks on the chevron, then we do not select the folder. - if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + if (true_targ.matches(".tree-list-item-action--leading, .tree-list-item-action-chevron")) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. @@ -356,7 +356,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function extraNetworksTreeOnClick(event, tabname, tab_id) { /** - * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 9cf5b57fbcb..a49c6c1c54b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -20,28 +20,28 @@ default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] tree_tpl = ( - " \ No newline at end of file diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html new file mode 100644 index 00000000000..4d29b1be273 --- /dev/null +++ b/html/extra-networks-tree.html @@ -0,0 +1,42 @@ +
          +
          + +
          + +
          +
          + +
          +
          + +
          +
          +
          + {tree} +
          +
          \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 56c9dc4c35d..cf98452a296 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -37,15 +37,20 @@ function setupExtraNetworksForTab(tabname) { return; // `continue` doesn't work in `forEach` loops. This is equivalent. } - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); + var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); + if (!sort) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); + if (!sort_dir) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); + if (!refresh) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -72,8 +77,8 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var reverse = sort_dir.dataset.sortdir == "Descending"; + var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; @@ -107,10 +112,7 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); - }); + applySort(); applyFilter(); extraNetworksApplySort[tab_id] = applySort; @@ -274,7 +276,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { var par = btn.parentElement; var search_id = tabname + "_" + tab_id + "_extra_search"; var type = par.getAttribute("data-tree-entry-type"); - var path = par.getAttribute("data-path"); + var path = btn.getAttribute("data-path"); } function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { @@ -310,7 +312,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.tree-list-content"); + var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -345,11 +347,11 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // NOTE: Double inversion sucks but it is the clearest way to show the branching here. 
_expand_or_collapse(ul, btn); _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } else { // All other cases, just select the button. _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } } } @@ -374,6 +376,48 @@ function extraNetworksTreeOnClick(event, tabname, tab_id) { } } +function extraNetworksTreeSortOnClick(event, tabname, tab_id) { + var curr_mode = event.currentTarget.dataset.sortmode; + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var sort_dir = el_sort_dir.dataset.sortdir; + if (curr_mode == "path") { + event.currentTarget.dataset.sortmode = "name"; + event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by filename"); + } else if (curr_mode == "name") { + event.currentTarget.dataset.sortmode = "date_created"; + event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date created"); + } else if (curr_mode == "date_created") { + event.currentTarget.dataset.sortmode = "date_modified"; + event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date modified"); + } else { + event.currentTarget.dataset.sortmode = "path"; + event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by path"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { + var curr_dir = event.currentTarget.getAttribute("data-sortdir"); + if (curr_dir == "Ascending") { + event.currentTarget.dataset.sortdir = "Descending"; + event.currentTarget.setAttribute("title", "Sort descending"); + } else { + event.currentTarget.dataset.sortdir = "Ascending"; + event.currentTarget.setAttribute("title", "Sort ascending"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { + console.log("refresh clicked"); + var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); + btn_refresh_internal.dispatchEvent(new Event("click")); +} + var globalPopup = null; var globalPopupInner = null; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a49c6c1c54b..4ba2bea1a8c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,50 +19,6 @@ allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] -tree_tpl = ( - "" - "
            " - "{content}" - "
          " -) - -tree_ul_tpl = ( - "
            " - "{content}" - "
          " -) - -tree_li_dir_tpl = ( - "
        • " - "{content}" - "
        • " -) -tree_li_file_tpl = ( - "
        • " - "{content}" - "
        • " -) - -action_list_item_action_leading = ( - "" - "" - "" -) - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -201,9 +157,13 @@ def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.extra_networks_pane_template = shared.html("extra-networks-pane.html") - self.card_page_template = shared.html("extra-networks-card.html") - self.tree_button_template = shared.html("extra-networks-tree-button.html") + self.pane_tpl = shared.html("extra-networks-pane.html") + self.tree_tpl = shared.html("extra-networks-tree.html") + self.card_tpl = shared.html("extra-networks-card.html") + self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") + self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") + self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") + self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -268,12 +228,8 @@ def create_item_html( onclick = item.get("onclick", None) if onclick is None: - print("HERE") - print("TABNAME:", tabname) - print("PROMPT:", item["prompt"]) - print("NEG_PROMPT:", item.get("negative_prompt", "")) - print("ALLOW_NEG:", self.allow_negative_prompt) - onclick_js_tpl = "cardClicked('{tabname}', '{prompt}', '{neg_prompt}', '{allow_neg}');" + # Don't quote prompt/neg_prompt since they are stored as js strings already. + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -284,15 +240,23 @@ def create_item_html( ) onclick = html.escape(onclick) - - copy_path_button = f"
          " - - metadata_button = "" + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) + btn_metadata = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" - - edit_button = f"
          " + btn_metadata = self.btn_metadata_tpl.format( + **{ + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) + btn_edit_item = self.btn_edit_item_tpl.format( + **{ + "tabname": tabname, + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) local_path = "" filename = item.get("filename", "") @@ -334,11 +298,11 @@ def create_item_html( args = { "background_image": background_image, "card_clicked": onclick, - "copy_path_button": copy_path_button, + "copy_path_button": btn_copy_path, "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "edit_button": edit_button, + "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), - "metadata_button": metadata_button, + "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', @@ -355,6 +319,57 @@ def create_item_html( else: return args + def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + if not content: + return None + + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-dir", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": "", + "data_path": dir_path, + "data_hash": "", + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗀", + "action_list_item_label": os.path.basename(dir_path), + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": "", + } + ) + ul = f"
            {content}
          " + return f"
        • {btn + ul}
        • " + + def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str: + item_html_args = self.create_item_html(tabname, item) + action_buttons = "".join( + [ + item_html_args["copy_path_button"], + item_html_args["metadata_button"], + item_html_args["edit_button"], + ] + ) + action_buttons = f"
          {action_buttons}
          " + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-file", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": item_html_args["card_clicked"], + "data_path": item_name, + "data_hash": item["shorthash"], + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗎", + "action_list_item_label": item["name"], + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": action_buttons, + } + ) + return f"
        • {btn}
        • " + def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -385,57 +400,9 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _item_html_args = self.create_item_html(tabname, v.item) - _action_buttons = "".join( - [ - _item_html_args["copy_path_button"], - _item_html_args["metadata_button"], - _item_html_args["edit_button"], - ] - ) - _action_buttons = f"
          {_action_buttons}
          " - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-file", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": _item_html_args["card_clicked"], - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗎", - "action_list_item_label": v.item["name"], - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": _action_buttons, - } - ) - - _li = tree_li_file_tpl.format( - **{ - "hash": v.item["shorthash"], - "path": k, - "type": "file", - "content": _btn, - } - ) - _file_li.append(_li) + _file_li.append(self.create_tree_file_item_html(tabname, k, v.item)) else: - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) - _dir_li.append(_li) + _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v))) # Directories should always be displayed before files. return "".join(_dir_li) + "".join(_file_li) @@ -443,31 +410,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - subtree = _build_tree(v) - if subtree: - ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) - res += li + item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v)) + if item_html: + res += item_html - return tree_tpl.format( + return self.tree_tpl.format( **{ - "content": res, "tabname": tabname, "tab_id": self.id_page, + "tree": f"
            {res}
          " } ) @@ -475,8 +426,7 @@ def create_card_view_html(self, tabname): res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - print("HEEEERRE:", item) - res += self.create_item_html(tabname, item, self.card_page_template) + res += self.create_item_html(tabname, item, self.card_tpl) if res == "": dirs = "".join([f"
        • {x}
        • " for x in self.allowed_directories_for_previews()]) @@ -493,7 +443,7 @@ def create_html(self, tabname): card_view_html = self.create_card_view_html(tabname) network_type_id = self.id_page - return self.extra_networks_pane_template.format( + return self.pane_tpl.format( **{ "tabname": tabname, "network_type_id": network_type_id, @@ -612,6 +562,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] + button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) + for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): @@ -633,51 +585,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - - tab_controls = [ - dropdown_sort, - button_sortorder, - button_refresh, - ] - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - for tab in unrelated_tabs: - tab.select( - fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - for page, tab in zip(ui.stored_extra_pages, related_tabs): - allow_prompt = "true" if page.allow_prompt else "false" - allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - - jscode = ( - "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" - ");" - ) - - tab.select( - fn=lambda: [gr.update(visible=True) for _ in tab_controls], - _js="function(){ " + jscode + " }", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") - def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] @@ -693,6 +603,8 @@ def refresh(): return ui.pages_contents interface.load(fn=pages_html, inputs=[], outputs=ui.pages) + # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick() + # button is unused and hidden at all times. Only used in order to fire this event. 
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/style.css b/style.css index 2dafe97fd87..08573248654 100644 --- a/style.css +++ b/style.css @@ -1196,17 +1196,33 @@ body.resizing .resize-handle { overflow: hidden; } -.extra-network-tree .tree-list--tree { - cursor: pointer; - -webkit-user-select: none; - -moz-user-select: none; - -ms-user-select: none; - user-select: none; - margin: 0; +.extra-network-tree .tree-list { + margin: 0 0.25rem; padding: 0; - margin-left: 0.25rem; } +.extra-network-tree .tree-list .tree-list-controls { + position: relative; + display: grid; + width: 100%; + padding: 0 !important; + margin-top: 0 !important; + margin-bottom: 0 !important; + font-size: 1rem; + text-align: left; + user-select: none; + background-color: transparent; + border: none; + transition: background 33.333ms linear; + grid-template-rows: min-content; + grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; + grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-gap: 0.1rem; + align-items: start; +} + +.extra-network-tree .tree-list--tree {} + /* Remove auto indentation from tree. Will be overridden later. */ .extra-network-tree .tree-list--subgroup { margin: 0 !important; @@ -1221,9 +1237,6 @@ body.resizing .resize-handle { padding-left: 0.4rem !important; } -/* Styles for tree
<ul> elements. */
-.extra-network-tree .tree-list {}
-
 /* Styles for tree
          • elements. */ .extra-network-tree .tree-list-item { list-style: none; @@ -1288,26 +1301,182 @@ body.resizing .resize-handle { padding-top: 0.5rem !important; } -.dark .extra-network-tree button.tree-list-content:hover { +.dark .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-800); } -.dark .extra-network-tree button.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-700); } -.extra-network-tree button.tree-list-content:hover { +.extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree button.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-300); } +/* ==== CHEVRON ICON ACTIONS ==== */ +/* Define the animation for the arrow when it is clicked. */ +.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { + -ms-transform: rotate(135deg); + -webkit-transform: rotate(135deg); + transform: rotate(135deg); + transition: transform 0.2s; +} + +.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { + -ms-transform: rotate(225deg); + -webkit-transform: rotate(225deg); + transform: rotate(225deg); + transition: transform 0.2s; +} + +.tree-list-item-action-chevron { + display: inline-flex; + /* Uses box shadow to generate a pseudo chevron `>` icon. */ + padding: 0.3rem; + box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; + transform: rotate(135deg); +} + +/* ==== SEARCH INPUT ACTIONS ==== */ +/* Add icon to left side of */ +.extra-network-tree .tree-list-controls .tree-list-search::before { + content: "🔎︎"; + position: absolute; + margin: 0.5rem; + font-size: 1rem; + color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-controls .tree-list-search { + display: inline-flex; + grid-area: tree-list-controls-col-0; + position: relative; + margin: 0.5rem; +} + +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { + border: 1px solid var(--button-secondary-border-color); + border-radius: 0.5rem; + color: var(--button-secondary-text-color); + background-color: transparent; + width: 100%; + padding-left: 2rem; + line-height: 1rem; +} + +/* clear button (x on right side) styling */ +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { + -webkit-appearance: none; + appearance: none; + cursor: pointer; + height: 1rem; + width: 1rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +/* ==== SORT ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort { + grid-area: tree-list-controls-col-1; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree 
.tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== SORT DIRECTION ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort-dir { + grid-area: tree-list-controls-col-2; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== REFRESH ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-refresh { + grid-area: tree-list-controls-col-3; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-refresh-icon:active { + -ms-transform: rotate(180deg); + -webkit-transform: rotate(180deg); + transform: rotate(180deg); + transition: transform 0.2s; +} + +/* ==== TREE GRID CONFIG ==== */ + /* Text for button. */ .extra-network-tree .tree-list-item-label { position: relative; @@ -1332,6 +1501,7 @@ body.resizing .resize-handle { align-items: right; } + /* Icon for button when it is before label. */ .extra-network-tree .tree-list-item-visual--leading { grid-area: leading-visual; @@ -1348,7 +1518,7 @@ body.resizing .resize-handle { /* Dropdown arrow for button. */ .extra-network-tree .tree-list-item-action--leading { - margin-right: 0.2rem; + margin-right: 0.5rem; margin-left: 0.2rem; } @@ -1356,30 +1526,6 @@ body.resizing .resize-handle { visibility: hidden; } -/* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { - -ms-transform: rotate(135deg); - -webkit-transform: rotate(135deg); - transform: rotate(135deg); - transition: transform 0.2s; -} - -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { - -ms-transform: rotate(225deg); - -webkit-transform: rotate(225deg); - transform: rotate(225deg); - transition: transform 0.2s; -} - -.tree-list-item-action-chevron { - display: inline-flex; - /* Uses box shadow to generate a pseudo chevron `>` icon. 
*/ - padding: 0.3rem; - box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; - transform: rotate(135deg); -} - - .extra-network-tree .tree-list-item-action--leading { grid-area: leading-action; } @@ -1399,41 +1545,3 @@ body.resizing .resize-handle { .extra-network-tree .tree-list-content:hover .button-row { visibility: visible; } - -/* Add icon to left side of */ -.extra-network-tree .tree-list-search::before { - content: "🔎︎"; - position: absolute; - margin: 0.5rem; - font-size: 1rem; - color: var(--input-placeholder-color); -} - -.extra-network-tree .tree-list-search { - position: relative; - margin: 0.5rem; -} - -.extra-network-tree .tree-list-search .tree-list-search-text { - border: 1px solid var(--button-secondary-border-color); - border-radius: 0.5rem; - color: var(--button-secondary-text-color); - background-color: transparent; - width: 100%; - padding-left: 2rem; - line-height: 1rem; -} - -/* clear button (x on right side) styling */ -.extra-network-tree .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { - -webkit-appearance: none; - appearance: none; - cursor: pointer; - height: 1rem; - width: 1rem; - mask-image: url('data:image/svg+xml,'); - mask-repeat: no-repeat; - mask-position: center center; - mask-size: 100%; - background-color: var(--input-placeholder-color); -} \ No newline at end of file From 1fdc18e6a01eb40889d46fab40f21aa138d64b01 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 15 Jan 2024 18:01:13 -0500 Subject: [PATCH 256/449] Run linting --- modules/ui_extra_networks.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4ba2bea1a8c..06fa22afada 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,7 +13,6 @@ from fastapi.exceptions import HTTPException from modules.infotext_utils import image_from_url_text -from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() @@ -225,7 +224,6 @@ def create_item_html( width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' background_image = f'' if preview else '' - onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. 
@@ -239,7 +237,7 @@ def create_item_html( } ) onclick = html.escape(onclick) - + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) btn_metadata = "" metadata = item.get("metadata") @@ -551,8 +549,6 @@ def tab_name_score(name): return sorted(pages, key=lambda x: tab_scores[x.name]) def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): - from modules.ui import switch_values_symbol - ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -588,6 +584,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + for tab in unrelated_tabs: + tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False) + def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] From c1e04c63b3d2a5102b4cf6deddea337c8a964c53 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:18:20 +0900 Subject: [PATCH 257/449] callback postprocess_image_after_composite --- modules/processing.py | 5 +++++ modules/scripts.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe38d..2c3365a026f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1029,6 +1029,11 @@ def infotext(index=0, use_main_prompt=False): image = apply_overlay(image, p.paste_to, overlay_image) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image_after_composite(p, pp) + image = pp.image + if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) diff --git a/modules/scripts.py b/modules/scripts.py index cf938ebb972..060069cf36a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -262,6 +262,15 @@ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): pass + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + Same as postprocess_image but after inpaint_full_res composite + So that it operates on the full image instead of the inpaint_full_res crop region. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. 
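For extension authors, the new postprocess_image_after_composite hook is wired up like the existing postprocess_image callback: define the method on a Script subclass and modify pp.image in place. A minimal sketch, assuming an always-on script (the border effect is purely illustrative and not part of this patch):

    from PIL import ImageOps

    from modules import scripts

    class BorderAfterComposite(scripts.Script):
        def title(self):
            return "Border after composite"

        def show(self, is_img2img):
            return scripts.AlwaysVisible  # run without being picked in the script dropdown

        def postprocess_image_after_composite(self, p, pp, *args):
            # pp.image is the fully composited result, not the inpaint_full_res crop
            pp.image = ImageOps.expand(pp.image, border=4, fill="black")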
@@ -856,6 +865,14 @@ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image_after_composite(p, pp, *script_args) + except Exception: + errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: From 14b9762bcab41fe0f7411498d9d3ffdc69ea1dda Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:08:07 +0900 Subject: [PATCH 258/449] immediately stop on second interrupt Revert "immediately stop on second interrupt" This reverts commit ab409072a1f2a9c911a63aee98a6b42081803cdc. immediately stop on second interrupt --- modules/ui_toprow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 1abc9117ba3..fbe705be130 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -107,8 +107,9 @@ def create_submit_box(self): ) def interrupt_function(): - if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current: shared.state.stop_generating() + gr.Info("Generation will stop after finishing this image, click again to stop immediately.") else: shared.state.interrupt() From a75dfe1c0de1564f4607a6f4ed7f4a7c00ef18a0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:03:48 +0100 Subject: [PATCH 259/449] - expand fields to include model name and hash - write these in the CSV log file - ensure old log files are updated w.r.t delimiter count --- modules/ui_common.py | 47 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c2956..cbb2495d6c9 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -36,6 +36,29 @@ def plaintext_to_html(text, classname=None): return f"

<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>
            " +def update_logfile(logfile_path, fields): + import csv + + with open(logfile_path, "r", encoding="utf8", newline="") as file: + reader = csv.reader(file) + rows = list(reader) + + # blank file: leave it as is + if not rows: + return + + rows[0] = fields + + # append new fields to each row as empty values + for row in rows[1:]: + while len(row) < len(fields): + row.append("") + + with open(logfile_path, "w", encoding="utf8", newline="") as file: + writer = csv.writer(file) + writer.writerows(rows) + + def save_files(js_data, images, do_make_zip, index): import csv filenames = [] @@ -64,11 +87,31 @@ def __init__(self, d=None): os.makedirs(shared.opts.outdir_save, exist_ok=True) + fields = [ + "prompt", + "seed", + "width", + "height", + "sampler", + "cfgs", + "steps", + "filename", + "negative_prompt", + "sd_model_name", + "sd_model_hash", + ] + logfile_path = os.path.join(shared.opts.outdir_save, "log.csv") + + # NOTE: ensure csv integrity when fields are added by + # updating headers and padding with delimeters where needed + if os.path.exists(logfile_path): + update_logfile(logfile_path, fields) + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + writer.writerow(fields) for image_index, filedata in enumerate(images, start_index): image = image_from_url_text(filedata) @@ -86,7 +129,7 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 315e40a49c32438551ed6b66138acdf664ecdbc8 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:11:28 +0100 Subject: [PATCH 260/449] reuse variable for log file path --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index cbb2495d6c9..1db72092d59 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -107,7 +107,7 @@ def __init__(self, d=None): if os.path.exists(logfile_path): update_logfile(logfile_path, fields) - with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + with open(logfile_path, "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: From 4f9626703345ced77935e6bbb06de0b4522d53b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 13:35:01 -0500 Subject: [PATCH 261/449] Finish cleanup. 
--- html/extra-networks-card-minimal.html | 4 - html/extra-networks-card.html | 10 +- html/extra-networks-edit-item-button.html | 2 +- html/extra-networks-metadata-button.html | 2 +- html/extra-networks-pane.html | 6 +- html/extra-networks-tree-button.html | 3 +- html/extra-networks-tree.html | 14 +- javascript/extraNetworks.js | 170 +++++++++++---------- modules/ui_extra_networks.py | 162 +++++++++++++++----- modules/ui_extra_networks_user_metadata.py | 2 +- style.css | 81 +++++++--- 11 files changed, 288 insertions(+), 168 deletions(-) delete mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html deleted file mode 100644 index d66df7dfb3f..00000000000 --- a/html/extra-networks-card-minimal.html +++ /dev/null @@ -1,4 +0,0 @@ -
            - {name} - {copy_path_button}{metadata_button}{edit_button} -
            diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index ca683dc47ec..f1d959a6733 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,9 +1,9 @@ -
            +
            {background_image}
            {copy_path_button}{metadata_button}{edit_button}
            -
            -
            {search_terms}
            - {name} - {description} +
            +
            {search_terms}
            + {name} + {description}
            diff --git a/html/extra-networks-edit-item-button.html b/html/extra-networks-edit-item-button.html index 7d2677d9f14..0fe43082ad1 100644 --- a/html/extra-networks-edit-item-button.html +++ b/html/extra-networks-edit-item-button.html @@ -1,4 +1,4 @@
            + onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
            \ No newline at end of file diff --git a/html/extra-networks-metadata-button.html b/html/extra-networks-metadata-button.html index ad6d6f41662..285b5b3b658 100644 --- a/html/extra-networks-metadata-button.html +++ b/html/extra-networks-metadata-button.html @@ -1,4 +1,4 @@ \ No newline at end of file diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 20cf6686755..bf46ca16305 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,8 @@ -
            -
            +
            +
            {tree_html}
            -
            +
            {items_html}
            \ No newline at end of file diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html index 20a9b0b863f..9dc2e2a40c8 100644 --- a/html/extra-networks-tree-button.html +++ b/html/extra-networks-tree-button.html @@ -1,8 +1,7 @@
            diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 4d29b1be273..23f6af1058c 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -2,36 +2,36 @@
            diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index cf98452a296..a3f003bf475 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,25 +31,15 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var tab_id = elem.getAttribute("id"); - var search = gradioApp().querySelector("#" + tab_id + "_extra_search"); - if (!search) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); - if (!sort) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); - if (!sort_dir) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); - if (!refresh) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. + var extra_networks_tabname = elem.id; + var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); + + // If any of the buttons above don't exist, we want to skip this iteration of the loop. + if (!search || !sort_mode || !sort_dir || !refresh) { + return; // `return` is equivalent of `continue` but for forEach loops. } var applyFilter = function() { @@ -78,14 +68,14 @@ function setupExtraNetworksForTab(tabname) { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); var reverse = sort_dir.dataset.sortdir == "Descending"; - var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; - if (sortKeyStore == sort.dataset.sortkey) { + if (sortKeyStore == sort_mode.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; + sort_mode.dataset.sortkey = sortKeyStore; cards.forEach(function(card) { card.originalParentElement = card.parentElement; @@ -115,8 +105,8 @@ function setupExtraNetworksForTab(tabname) { applySort(); applyFilter(); - extraNetworksApplySort[tab_id] = applySort; - extraNetworksApplyFilter[tab_id] = applyFilter; + extraNetworksApplySort[extra_networks_tabname] = applySort; + extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -148,14 +138,6 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } -function clearSearch(tabname) { - // Clear search box. 
- var tab_id = tabname + "_extra_search"; - var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); - searchTextarea.value = ""; - updateInput(searchTextarea); -} - function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); @@ -264,22 +246,20 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } -function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on files in tree. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ - var par = btn.parentElement; - var search_id = tabname + "_" + tab_id + "_extra_search"; - var type = par.getAttribute("data-tree-entry-type"); - var path = btn.getAttribute("data-path"); + // NOTE: Currently unused. + return; } -function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on directories in tree. * @@ -289,10 +269,10 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * selected opened directory: Directory is collapsed and deselected. * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var ul = btn.nextElementSibling; // This is the actual target that the user clicked on within the target button. @@ -301,12 +281,12 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _expand_or_collapse(_ul, _btn) { // Expands
              if it is collapsed, collapses otherwise. Updates button attributes. - if (_ul.hasAttribute("data-hidden")) { - _ul.removeAttribute("data-hidden"); - _btn.setAttribute("expanded", "true"); + if (_ul.hasAttribute("hidden")) { + _ul.removeAttribute("hidden"); + _btn.dataset.expanded = ""; } else { - _ul.setAttribute("data-hidden", ""); - _btn.setAttribute("expanded", "false"); + _ul.setAttribute("hidden", ""); + delete _btn.dataset.expanded; } } @@ -314,19 +294,19 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // Removes the `selected` attribute from all buttons. var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { - el.removeAttribute("selected"); + delete el.dataset.selected; }); } function _select_button(_btn) { - // Removes `selected` attribute from all buttons then adds to passed button. + // Removes `data-selected` attribute from all buttons then adds to passed button. _remove_selected_from_all(); - _btn.setAttribute("selected", ""); + _btn.dataset.selected = ""; } - function _update_search(_tabname, _tab_id, _search_text) { + function _update_search(_tabname, _extra_networks_tabname, _search_text) { // Update search input with select button's path. - var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_search"); + var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search"); search_input_elem.value = _search_text; updateInput(search_input_elem); } @@ -337,48 +317,58 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. - if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + if ("selected" in btn.dataset && !(ul.hasAttribute("hidden"))) { // If folder is select and open, collapse and deselect button. _expand_or_collapse(ul, btn); - btn.removeAttribute("selected"); - _update_search(tabname, tab_id, ""); - } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + delete btn.dataset.selected; + _update_search(tabname, extra_networks_tabname, ""); + } else if (!(!("selected" in btn.dataset) && !(ul.hasAttribute("hidden")))) { // If folder is open and not selected, then we don't collapse; just select. // NOTE: Double inversion sucks but it is the clearest way to show the branching here. _expand_or_collapse(ul, btn); - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } else { // All other cases, just select the button. - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } } } -function extraNetworksTreeOnClick(event, tabname, tab_id) { +function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. * - * @param event The generated event. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. 
Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var btn = event.currentTarget; var par = btn.parentElement; - if (par.getAttribute("data-tree-entry-type") === "file") { - extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + if (par.dataset.treeEntryType === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname); } else { - extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id); + extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname); } } -function extraNetworksTreeSortOnClick(event, tabname, tab_id) { +function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Mode button. + * + * Modifies the data attributes of the Sort Mode button to cycle between + * various sorting modes. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var curr_mode = event.currentTarget.dataset.sortmode; - var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir"); var sort_dir = el_sort_dir.dataset.sortdir; if (curr_mode == "path") { event.currentTarget.dataset.sortmode = "name"; @@ -397,23 +387,43 @@ function extraNetworksTreeSortOnClick(event, tabname, tab_id) { event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; event.currentTarget.setAttribute("title", "Sort by path"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { - var curr_dir = event.currentTarget.getAttribute("data-sortdir"); - if (curr_dir == "Ascending") { +function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Direction button. + * + * Modifies the data attributes of the Sort Direction button to cycle between + * ascending and descending sort directions. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + if (event.currentTarget.dataset.sortdir == "Ascending") { event.currentTarget.dataset.sortdir = "Descending"; event.currentTarget.setAttribute("title", "Sort descending"); } else { event.currentTarget.dataset.sortdir = "Ascending"; event.currentTarget.setAttribute("title", "Sort ascending"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { - console.log("refresh clicked"); +function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Refresh Page button. 
+ * + * In order to actually call the python functions in `ui_extra_networks.py` + * to refresh the page, we created an empty gradio button in that file with an + * event handler that refreshes the page. So what this function here does + * is it manually raises a `click` event on that button. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); btn_refresh_internal.dispatchEvent(new Event("click")); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 06fa22afada..a03207b2bc0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -155,7 +155,14 @@ class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() - self.id_page = self.name.replace(" ", "_") + # This is the actual name of the extra networks tab (not txt2img/img2img). + self.extra_networks_tabname = self.name.replace(" ", "_") + self.allow_prompt = True + self.allow_negative_prompt = False + self.metadata = {} + self.items = {} + self.lister = util.MassFileLister() + # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") self.tree_tpl = shared.html("extra-networks-tree.html") self.card_tpl = shared.html("extra-networks-card.html") @@ -163,11 +170,6 @@ def __init__(self, title): self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") - self.allow_prompt = True - self.allow_negative_prompt = False - self.metadata = {} - self.items = {} - self.lister = util.MassFileLister() def refresh(self): pass @@ -202,15 +204,17 @@ def create_item_html( item: dict, template: Optional[str] = None, ) -> Union[str, dict]: - """Generates HTML for a single ExtraNetworks Item + """Generates HTML for a single ExtraNetworks Item. Args: tabname: The name of the active tab. item: Dictionary containing item information. + template: Optional template string to use. Returns: - HTML string generated for this item. - Can be empty if the item is not meant to be shown. + If a template is passed: HTML string generated for this item. + Can be empty if the item is not meant to be shown. + If no template is passed: A dictionary containing the generated item's attributes. 
""" metadata = item.get("metadata") if metadata: @@ -244,14 +248,14 @@ def create_item_html( if metadata: btn_metadata = self.btn_metadata_tpl.format( **{ - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) btn_edit_item = self.btn_edit_item_tpl.format( **{ "tabname": tabname, - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) @@ -307,9 +311,9 @@ def create_item_html( "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, } if template: @@ -317,7 +321,32 @@ def create_item_html( else: return args - def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + def create_tree_dir_item_html( + self, + tabname: str, + dir_path: str, + content: Optional[str] = None, + ) -> Optional[str]: + """Generates HTML for a directory item in the tree. + + The generated HTML is of the format: + ```html +
            • +
              +
                + {content} +
              +
            • + ``` + + Args: + tabname: The name of the active tab. + dir_path: Path to the directory for this item. + content: Optional HTML string that will be wrapped by this
                . + + Returns: + HTML formatted string. + """ if not content: return None @@ -326,7 +355,7 @@ def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Option "search_terms": "", "subclass": "tree-list-content-dir", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "onclick_extra": "", "data_path": dir_path, "data_hash": "", @@ -337,10 +366,32 @@ def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Option "action_list_item_action_trailing": "", } ) - ul = f"
                  {content}
                " - return f"
              • {btn + ul}
              • " + ul = f"" + return ( + "
              • " + f"{btn + ul}" + "
              • " + ) + + def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str: + """Generates HTML for a file item in the tree. + + The generated HTML is of the format: + ```html +
              • + +
                +
              • + ``` - def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str: + Args: + tabname: The name of the active tab. + file_path: The path to the file for this item. + item: Dictionary containing the item information. + + Returns: + HTML formatted string. + """ item_html_args = self.create_item_html(tabname, item) action_buttons = "".join( [ @@ -355,9 +406,9 @@ def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) - "search_terms": "", "subclass": "tree-list-content-file", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "onclick_extra": item_html_args["card_clicked"], - "data_path": item_name, + "data_path": file_path, "data_hash": item["shorthash"], "action_list_item_action_leading": "", "action_list_item_visual_leading": "🗎", @@ -366,11 +417,17 @@ def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) - "action_list_item_action_trailing": action_buttons, } ) - return f"
              • {btn}
              • " + return ( + "
              • " + f"{btn}" + "
              • " + ) def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. + The generated HTML uses `extra-networks-tree.html` as a template. + Args: tabname: The name of the active tab. @@ -379,7 +436,7 @@ def create_tree_view_html(self, tabname: str) -> str: """ res = "" - # Generate HTML for the tree. + # Setup the tree dictionary. roots = self.allowed_directories_for_previews() tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) @@ -388,7 +445,17 @@ def create_tree_view_html(self, tabname: str) -> str: return res def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]: - """Recursively builds HTML for a tree.""" + """Recursively builds HTML for a tree. + + Args: + data: Dictionary representing a directory tree. Can be NoneType. + Data keys should be absolute paths from the root and values + should be subdirectory trees or an ExtraNetworksItem. + + Returns: + If data is not None: HTML string + Else: None + """ if not data: return None @@ -402,25 +469,36 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional else: _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v))) - # Directories should always be displayed before files. + # Directories should always be displayed before files so we order them here. return "".join(_dir_li) + "".join(_file_li) # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): - # If root is empty, append the "disabled" attribute to the template details tag. item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v)) - if item_html: + # Only add non-empty entries to the tree. + if item_html is not None: res += item_html return self.tree_tpl.format( **{ "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "tree": f"
                  {res}
                " } ) - def create_card_view_html(self, tabname): + def create_card_view_html(self, tabname: str) -> str: + """Generates HTML for the network Card View section for a tab. + + This HTML goes into the `extra-networks-pane.html`
                with + `id='{tabname}_{extra_networks_tabname}_cards`. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): @@ -433,20 +511,26 @@ def create_card_view_html(self, tabname): return res def create_html(self, tabname): + """Generates an HTML string for the current pane. + + The generated HTML uses `extra-networks-pane.html` as a template. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} - tree_view_html = self.create_tree_view_html(tabname) - card_view_html = self.create_card_view_html(tabname) - network_type_id = self.id_page - return self.pane_tpl.format( **{ "tabname": tabname, - "network_type_id": network_type_id, - "tree_html": tree_view_html, - "items_html": card_view_html, + "extra_networks_tabname": self.extra_networks_tabname, + "tree_html": self.create_tree_view_html(tabname), + "items_html": self.create_card_view_html(tabname), } ) @@ -561,16 +645,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) for page in ui.stored_extra_pages: - with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: - with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]): pass - elem_id = f"{tabname}_{page.id_page}_cards_html" + elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}", inputs=[], outputs=[], ) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 989a649b799..2ca937fd117 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -14,7 +14,7 @@ def __init__(self, ui, tabname, page): self.ui = ui self.tabname = tabname self.page = page - self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + self.id_part = f"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata" self.box = None diff --git a/style.css b/style.css index 08573248654..1090e4368ab 100644 --- a/style.css +++ b/style.css @@ -879,13 +879,6 @@ footer { margin-bottom: 1em; } -.extra-network-pane{ - height: calc(100vh - 24rem); - overflow: clip scroll; - resize: vertical; - min-height: 52rem; -} - .extra-networks > div.tab-nav{ min-height: 3.4rem; } @@ -1182,23 +1175,63 @@ body.resizing .resize-handle { /* ========================= */ .extra-network-pane { display: flex; -} - -.extra-network-pane .extra-network-cards { - display: block; + height: calc(100vh - 24rem); + resize: vertical; + min-height: 52rem; } .extra-network-pane .extra-network-tree { - display: block; + flex: 1; + flex-direction: column; + display: flex; font-size: 1rem; - min-width: 25%; border: 1px solid 
var(--block-border-color); - overflow: hidden; } -.extra-network-tree .tree-list { - margin: 0 0.25rem; +.extra-network-pane .extra-network-cards { + flex: 3; + overflow: clip auto !important; + border: 1px solid var(--block-border-color); +} + +.extra-network-pane .extra-network-tree .tree-list { + flex: 1; + display: flex; + flex-direction: column; padding: 0; + width: 100%; + overflow: hidden; +} + +.extra-network-pane .extra-network-tree .tree-list .tree-list-container { + flex: 1; + overflow: clip auto !important; + width: 100%; +} + + +.extra-network-pane .extra-network-cards::-webkit-scrollbar, +.extra-network-pane .tree-list-container::-webkit-scrollbar { + background-color: transparent; + width: 16px; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-track, +.extra-network-pane .tree-list-container::-webkit-scrollbar-track { + background-color: transparent; + background-clip: content-box; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb, +.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb { + background-color: var(--border-color-primary); + border-radius: 16px; + border: 4px solid var(--background-fill-primary); +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-button, +.extra-network-pane .tree-list-container::-webkit-scrollbar-button { + display: none; } .extra-network-tree .tree-list .tree-list-controls { @@ -1244,17 +1277,15 @@ body.resizing .resize-handle { background-color: transparent; } -/* Directory
                  visibility based on expanded attribute. */ -.extra-network-tree .tree-list-content[expanded=false]+.tree-list--subgroup { +/* Directory
                    visibility based on data-expanded attribute. */ +.extra-network-tree .tree-list-content+.tree-list--subgroup { height: 0; - overflow: hidden; visibility: hidden; opacity: 0; } -.extra-network-tree .tree-list-content[expanded=true]+.tree-list--subgroup { +.extra-network-tree .tree-list-content[data-expanded]+.tree-list--subgroup { height: auto; - overflow: visible; visibility: visible; opacity: 1; } @@ -1307,7 +1338,7 @@ body.resizing .resize-handle { background-color: var(--neutral-800); } -.dark .extra-network-tree div.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-700); } @@ -1317,20 +1348,20 @@ body.resizing .resize-handle { background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-300); } /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { -ms-transform: rotate(135deg); -webkit-transform: rotate(135deg); transform: rotate(135deg); transition: transform 0.2s; } -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir[data-expanded] .tree-list-item-action-chevron { -ms-transform: rotate(225deg); -webkit-transform: rotate(225deg); transform: rotate(225deg); From ccee26b0653b4f6778c107d68df52da27446abd2 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 14:54:07 -0500 Subject: [PATCH 262/449] fix bugs --- javascript/extraNetworks.js | 28 ++++++++----------- modules/ui_extra_networks.py | 35 ++++++++++++------------ modules/ui_extra_networks_checkpoints.py | 3 +- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index a3f003bf475..caaa3fae0d2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,11 +31,12 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var extra_networks_tabname = elem.id; - var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); - var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); - var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); - var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); + // tabname_full = {tabname}_{extra_networks_tabname} + var tabname_full = elem.id; + var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); // If any of the buttons above don't exist, we want to skip this iteration of the loop. 
if (!search || !sort_mode || !sort_dir || !refresh) { @@ -44,16 +45,13 @@ function setupExtraNetworksForTab(tabname) { var applyFilter = function() { var searchTerm = search.value.toLowerCase(); - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase(); }).join(" "); var visible = text.indexOf(searchTerm) != -1; - if (searchOnly && searchTerm.length < 4) { visible = false; } @@ -66,7 +64,6 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sort_dir.dataset.sortdir == "Descending"; var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); @@ -104,9 +101,8 @@ function setupExtraNetworksForTab(tabname) { search.addEventListener("input", applyFilter); applySort(); applyFilter(); - - extraNetworksApplySort[extra_networks_tabname] = applySort; - extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; + extraNetworksApplySort[tabname_full] = applySort; + extraNetworksApplyFilter[tabname_full] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -147,12 +143,12 @@ function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); } -function applyExtraNetworkFilter(tabname) { - setTimeout(extraNetworksApplyFilter[tabname], 1); +function applyExtraNetworkFilter(tabname_full) { + setTimeout(extraNetworksApplyFilter[tabname_full], 1); } -function applyExtraNetworkSort(tabname) { - setTimeout(extraNetworksApplySort[tabname], 1); +function applyExtraNetworkSort(tabname_full) { + setTimeout(extraNetworksApplySort[tabname_full], 1); } var extraNetworksApplyFilter = {}; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a03207b2bc0..55cd1da2772 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -237,7 +237,7 @@ def create_item_html( "tabname": tabname, "prompt": item["prompt"], "neg_prompt": item.get("negative_prompt", ""), - "allow_neg": "true" if self.allow_negative_prompt else "false" + "allow_neg": str(self.allow_negative_prompt).lower(), } ) onclick = html.escape(onclick) @@ -291,7 +291,7 @@ def create_item_html( search_terms_html += search_term_template.format( **{ "style": "display: none;", - "class": "search_terms" + (" search_only" if search_only else ""), + "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } ) @@ -307,7 +307,7 @@ def create_item_html( "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"), "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, @@ -369,7 +369,7 @@ def create_tree_dir_item_html( ul = f"" return ( "
                  • " - f"{btn + ul}" + f"{btn}{ul}" "
                  • " ) @@ -561,7 +561,7 @@ def find_preview(self, path): Find a preview PNG for a given path (without extension) and call link_preview on it. """ - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) + potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], []) for file in potential_files: if self.lister.exists(file): @@ -642,7 +642,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] - button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) + button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False) for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab: @@ -652,24 +652,25 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change( - fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}", - inputs=[], - outputs=[], - ) - editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() ui.user_metadata_editors.append(editor) - related_tabs.append(tab) - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) - ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False) for tab in unrelated_tabs: - tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False) + tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False) + + for page, tab in zip(ui.stored_extra_pages, related_tabs): + jscode = ( + "function(){{" + f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()});" + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + "}}" + ) + tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False) def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index e7976ba1260..a8c3367198f 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -2,7 +2,6 @@ import os from modules import shared, ui_extra_networks, sd_models -from modules.ui_extra_networks import quote_js from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor @@ -31,7 +30,7 @@ def create_item(self, name, index=None, enable_filter=True): "preview": self.find_preview(path), "description": self.find_description(path), "search_terms": search_terms, - "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', + "onclick": html.escape(f"return selectCheckpoint('{name}');"), "local_preview": 
f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, From 0b83f4c26376c87505769a3687cb06e321b413a1 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:38:05 +0900 Subject: [PATCH 263/449] reuse seed from infotexts --- modules/processing_scripts/seed.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 2d3cbb97f06..7b727d17cb1 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -6,6 +6,7 @@ from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton +from modules import infotext_utils class ScriptSeed(scripts.ScriptBuiltinUI): @@ -77,7 +78,6 @@ def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_f p.seed_resize_from_h = seed_resize_from_h - def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -85,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def copy_seed(gen_info_string: str, index): res = -1 - try: gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError: + infotext = gen_info.get('infotexts')[index] + gen_parameters = infotext_utils.parse_generation_parameters(infotext) + res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) + except Exception: if gen_info_string: - errors.report(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True) return [res, gr.update()] From 6acd8e28fceccbbea0c09c91aed0eff5a3504315 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 14 Jan 2024 01:58:01 +0900 Subject: [PATCH 264/449] save_files info base on infotexts --- modules/ui_common.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c2956..6a4465e9358 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -40,8 +40,9 @@ def save_files(js_data, images, do_make_zip, index): import csv filenames = [] fullfns = [] + parsed_infotexts = [] - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + # quick dictionary to class object conversion. 
Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: @@ -49,16 +50,14 @@ def __init__(self, d=None): setattr(self, key, value) data = json.loads(js_data) - p = MyObject(data) + path = shared.opts.outdir_save save_to_dirs = shared.opts.use_save_to_dirs_for_ui extension: str = shared.opts.samples_format start_index = 0 - only_one = False if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - only_one = True images = [images[index]] start_index = index @@ -74,10 +73,12 @@ def __init__(self, d=None): image = image_from_url_text(filedata) is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) p.batch_index = image_index-1 - fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parsed_infotexts.append(parameters) + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) @@ -86,12 +87,12 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) # Make Zip if do_make_zip: - zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0] - namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True) + p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts] + namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True) zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") zip_filepath = os.path.join(path, f"{zip_filename}.zip") From d224fed0ce16e6f1507b0da29e7495ab2b0035e4 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 265/449] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf2410..1049c6c3c59 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), 
detailed, intricate @@ -240,6 +240,8 @@ def parse_generation_parameters(x: str): returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ def parse_generation_parameters(x: str): infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From 45a51c07e2ce5a36dc3cddde9a666a4c7b61af82 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:21:58 +0900 Subject: [PATCH 266/449] parse_generation_parameters with no skip_fields --- modules/processing_scripts/seed.py | 2 +- modules/ui_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 7b727d17cb1..7a4c0159831 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -88,7 +88,7 @@ def copy_seed(gen_info_string: str, index): try: gen_info = json.loads(gen_info_string) infotext = gen_info.get('infotexts')[index] - gen_parameters = infotext_utils.parse_generation_parameters(infotext) + gen_parameters = infotext_utils.parse_generation_parameters(infotext, []) res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) except Exception: if gen_info_string: diff --git a/modules/ui_common.py b/modules/ui_common.py index 6a4465e9358..6d7f3a67574 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -76,7 +76,7 @@ def __init__(self, d=None): p.batch_index = image_index-1 - parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], []) parsed_infotexts.append(parameters) fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) From 6916de5c0bd8df3835d450caa3327d1924db081c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 267/449] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf2410..1049c6c3c59 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate @@ -240,6 +240,8 @@ def parse_generation_parameters(x: str): returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ def parse_generation_parameters(x: str): infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From 
e1dfd452c0447e729a7341c434b4aab6063aa654 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:27:29 +0900 Subject: [PATCH 268/449] parse_generation_parameters with no skip_fields --- modules/txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index e617cb1c576..92a160d61da 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,7 +70,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) - parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index]) + parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], []) p.seed = parameters.get('Seed', -1) p.subseed = parameters.get('Variation seed', -1) From 2cf23099ebb81832a27c8016f14062885f5a9c98 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 18 Jan 2024 04:44:21 +0900 Subject: [PATCH 269/449] fix console total progress bar when using txt2img_upscale add p.txt2img_upscale as indicator --- modules/processing.py | 7 +++++-- modules/txt2img.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe38d..547ab2f86a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1227,8 +1227,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter - - shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) + if getattr(self, 'txt2img_upscale', False): + total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count + else: + total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count + shared.total_tqdm.updateTotal(total_steps) state.job_count = state.job_count * 2 state.processing_has_refined_job_count = True diff --git a/modules/txt2img.py b/modules/txt2img.py index 92a160d61da..4efcb4c3d71 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -64,6 +64,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g p.enable_hr = True p.batch_size = 1 p.n_iter = 1 + p.txt2img_upscale = True geninfo = json.loads(generation_info) From f25c81a74462554890ac7327a30629b332db1084 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 17 Jan 2024 22:38:51 -0500 Subject: [PATCH 270/449] Fix embeddings add/remove to/from prompt on click bugs. --- javascript/extraNetworks.js | 13 +++---------- modules/ui_extra_networks.py | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index caaa3fae0d2..1e2786ab0d2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -169,8 +169,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { var m = text.match(isNeg ? 
re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; + var extraTextBeforeNet = opts.extra_networks_add_text_separator; if (m) { - var extraTextBeforeNet = opts.extra_networks_add_text_separator; var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; @@ -183,7 +183,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } return found; }); - if (foundAtPosition >= 0) { if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); @@ -193,13 +192,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } } } else { - newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { - if (found == text) { - replaced = true; - return ""; - } - return found; - }); + newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, "g"), ""); + replaced = (newTextareaText != textarea.value); } if (replaced) { @@ -211,7 +205,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 55cd1da2772..5dd4e443f8e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -236,7 +236,7 @@ def create_item_html( **{ "tabname": tabname, "prompt": item["prompt"], - "neg_prompt": item.get("negative_prompt", ""), + "neg_prompt": item.get("negative_prompt", "''"), "allow_neg": str(self.allow_negative_prompt).lower(), } ) From 0181c1f76b97162c42401f1e6286ae73d8aa6033 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 19 Jan 2024 00:14:03 +0800 Subject: [PATCH 271/449] Fix nested manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index 0321d12c6ab..3702862979c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -165,6 +165,8 @@ def forward_wrapper(self, *args, **kwargs): @contextlib.contextmanager def manual_cast(target_dtype): for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + continue org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -175,7 +177,9 @@ def manual_cast(target_dtype): yield None finally: for module_type in patch_module_list: - module_type.forward = module_type.org_forward + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From 50e444fa1d1c779b0c9c915e44ed2d78b587f277 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:09 -0500 Subject: [PATCH 272/449] Fix missing important style. 
--- style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 1090e4368ab..57c52354a0b 100644 --- a/style.css +++ b/style.css @@ -28,7 +28,7 @@ div.gradio-container{ } .hidden{ - display: none; + display: none !important; } .compact{ From 69f4f148dce0868b748b700f96942d4036e848c9 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:33 -0500 Subject: [PATCH 273/449] Fix various bugs including refresh bug. --- javascript/extraNetworks.js | 7 +++++-- modules/ui_extra_networks.py | 31 ++++++++++++++++--------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1e2786ab0d2..3029afec847 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -55,8 +55,11 @@ function setupExtraNetworksForTab(tabname) { if (searchOnly && searchTerm.length < 4) { visible = false; } - - elem.style.display = visible ? "" : "none"; + if (visible) { + elem.classList.remove("hidden"); + } else { + elem.classList.add("hidden"); + } }); applySort(); diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5dd4e443f8e..656e7f181e4 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,22 +216,17 @@ def create_item_html( Can be empty if the item is not meant to be shown. If no template is passed: A dictionary containing the generated item's attributes. """ - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata - - if "user_metadata" not in item: - self.read_user_metadata(item) - preview = item.get("preview", None) - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;" + card_style = style_height + style_width + style_font_size background_image = f'' if preview else '' onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. 
- onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -286,11 +281,10 @@ def create_item_html( ).strip() search_terms_html = "" - search_term_template = "{search_term}" + search_term_template = "" for search_term in item.get("search_terms", []): search_terms_html += search_term_template.format( **{ - "style": "display: none;", "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } @@ -301,7 +295,7 @@ def create_item_html( "background_image": background_image, "card_clicked": onclick, "copy_path_button": btn_copy_path, - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""), "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), "metadata_button": btn_metadata, @@ -311,7 +305,7 @@ def create_item_html( "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", + "style": card_style, "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, } @@ -500,7 +494,6 @@ def create_card_view_html(self, tabname: str) -> str: HTML formatted string. """ res = "" - self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): res += self.create_item_html(tabname, item, self.card_tpl) @@ -524,6 +517,14 @@ def create_html(self, tabname): self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} + # Populate the instance metadata for each item. 
+ for item in self.items.values(): + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) return self.pane_tpl.format( **{ From a97147bc8a43ade7c18bb755f0cfac111fc1a619 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:10:02 +0100 Subject: [PATCH 274/449] Add support for DAT upscaler models --- modules/dat_model.py | 81 +++++++++++++++++++++++++++++++++++++++ modules/shared_items.py | 5 +++ modules/shared_options.py | 3 ++ 3 files changed, 89 insertions(+) create mode 100644 modules/dat_model.py diff --git a/modules/dat_model.py b/modules/dat_model.py new file mode 100644 index 00000000000..8637351c500 --- /dev/null +++ b/modules/dat_model.py @@ -0,0 +1,81 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + +from icecream import ic + +class UpscalerDAT(Upscaler): + def __init__(self, user_path): + self.name = "DAT" + self.user_path = user_path + self.scalers = [] + super().__init__() + + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scaler_data = UpscalerData(name, file, upscaler=self, scale=None) + self.scalers.append(scaler_data) + + for model in get_dat_models(self): + if model.name in opts.dat_enabled_models: + self.scalers.append(model) + + def do_upscale(self, img, selected_model): + try: + info = self.load_model(selected_model) + except Exception as e: + errors.report(f"Unable to load DAT model {path}", exc_info=True) + return img + + model_descriptor = modelloader.load_spandrel_model( + info.local_data_path, + device=self.device, + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="DAT", + ) + return upscale_with_model( + model_descriptor, + img, + tile_size=opts.DAT_tile, + tile_overlap=opts.DAT_tile_overlap, + ) + + def load_model(self, path): + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") + + +def get_dat_models(scaler): + return [ + UpscalerData( + name="DAT x2", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", + scale=2, + upscaler=scaler, + ), + UpscalerData( + name="DAT x3", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", + scale=3, + upscaler=scaler, + ), + UpscalerData( + name="DAT x4", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", + scale=4, + upscaler=scaler, + ), + ] diff --git a/modules/shared_items.py b/modules/shared_items.py index 13fb2814f2b..88f636452c7 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -8,6 +8,11 @@ def realesrgan_models_names(): return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] +def dat_models_names(): + import modules.dat_model + return [x.name for x in modules.dat_model.get_dat_models(None)] + + def postprocessing_scripts(): import modules.scripts diff --git a/modules/shared_options.py b/modules/shared_options.py index 63488f4e700..48a206ce941 100644 
--- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,6 +97,9 @@ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), + "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), + "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) From 2e7efe47b6c78a59f9f3d64c1d30f54751a31a56 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:39:14 +0100 Subject: [PATCH 275/449] Minor cleanup --- modules/dat_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/dat_model.py b/modules/dat_model.py index 8637351c500..495d5f4937d 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -1,12 +1,10 @@ import os -import sys -from modules import modelloader, devices +from modules import modelloader, errors from modules.shared import cmd_opts, opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -from icecream import ic class UpscalerDAT(Upscaler): def __init__(self, user_path): @@ -24,13 +22,13 @@ def __init__(self, user_path): if model.name in opts.dat_enabled_models: self.scalers.append(model) - def do_upscale(self, img, selected_model): + def do_upscale(self, img, path): try: - info = self.load_model(selected_model) - except Exception as e: + info = self.load_model(path) + except Exception: errors.report(f"Unable to load DAT model {path}", exc_info=True) return img - + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, From 1ddb886a804dc69f542ebc71bdd7baec48f677b6 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:48:46 +0100 Subject: [PATCH 276/449] Fix wrong options value --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 48a206ce941..74a2a67f933 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,7 +97,7 @@ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), - "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], 
"Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), From d31dc7a73996ee611a72ef641237606fc9eddca5 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 00:40:03 +0400 Subject: [PATCH 277/449] fix extras big batch crashes --- modules/postprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7449b0dc51c..f14882321ec 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -62,8 +62,6 @@ def get_images(extras_mode, image, image_folder, input_dir): else: image_data = image_placeholder - shared.state.assign_current_image(image_data) - parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters @@ -92,6 +90,8 @@ def get_images(extras_mode, image, image_folder, input_dir): pp.image.info = existing_pnginfo pp.image.info["postprocessing"] = infotext + shared.state.assign_current_image(pp.image) + if save_output: fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) From 4dde96109bb9621bb5608f7f792dc569d692ecf5 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:38:34 -0800 Subject: [PATCH 278/449] Update zoom.js --- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index dd6028c1e25..df60c1a1775 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,8 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", - canvas_hotkey_shrink_brush: "BracketLeft", - canvas_hotkey_grow_brush: "BracketRight", + canvas_hotkey_shrink_brush: "KeyQ", + canvas_hotkey_grow_brush: "KeyW", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, From bcdcc8be7d5fb923d64b7eb8a9a831f26f179d97 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:07 -0800 Subject: [PATCH 279/449] Update hotkey_config.py --- .../canvas-zoom-and-pan/scripts/hotkey_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 9d9accce8cf..6d9e34d8b0e 100644 --- 
a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -4,8 +4,8 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), - "canvas_hotkey_shrink_brush": shared.OptionInfo("[", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("]", "Enlarge the brush size"), + "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 30f31d23c6d212389aef035137c381dd79d44a26 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:33 -0800 Subject: [PATCH 280/449] Update hotkey_config.py --- extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 6d9e34d8b0e..89b7c31f22d 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -5,7 +5,7 @@ "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 56676ff923497901d56fcbca1ac549095e71d72d Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 11:49:05 +0400 Subject: [PATCH 281/449] fix tab indexes reset after restart ui --- modules/ui.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py 
b/modules/ui.py index a716a04055d..ebd33a85616 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -532,7 +532,7 @@ def add_copy_image_controls(tab_name, elem): if category == "image": with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) + img2img_selected_tab = gr.Number(value=0, visible=False) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) @@ -613,7 +613,7 @@ def copy_image(img): elif category == "dimensions": with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) + selected_scale_tab = gr.Number(value=0, visible=False) with gr.Tabs(): with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 7a132ac22ed..ff22a178075 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -5,7 +5,7 @@ def create_ui(): dummy_component = gr.Label(visible=False) - tab_index = gr.State(value=0) + tab_index = gr.Number(value=0, visible=False) with gr.Row(equal_height=False, variant='compact'): with gr.Column(variant='compact'): From 81126027f5226e7ee58e1a99194eb9ec7b8ec6e7 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:31:12 +0800 Subject: [PATCH 282/449] Avoid early disable --- modules/devices.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index 3702862979c..3bde169904b 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -164,9 +164,11 @@ def forward_wrapper(self, *args, **kwargs): @contextlib.contextmanager def manual_cast(target_dtype): + applied = False for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue + applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -176,6 +178,8 @@ def manual_cast(target_dtype): try: yield None finally: + if not applied: + return for module_type in patch_module_list: if hasattr(module_type, "org_forward"): module_type.forward = module_type.org_forward From 4a66d2fb228584bb38dc22db6a3e657561834c7a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:33:59 +0800 Subject: [PATCH 283/449] Avoid exceptions to be silenced --- modules/devices.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 3bde169904b..dfffaf24fd0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -178,12 +178,11 @@ def manual_cast(target_dtype): try: yield None finally: - if not applied: - return - for module_type in patch_module_list: - if hasattr(module_type, "org_forward"): - module_type.forward = module_type.org_forward - delattr(module_type, "org_forward") + if applied: + for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From ed383eb5a0db78d0da60420bb464943e4cc9b847 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 15:37:49 +0400 Subject: [PATCH 284/449] keep postprocessing 
upscale selected tab after restart --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index a57f9d4a4b6..e269682d0f3 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -15,7 +15,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): order = 1000 def ui(self): - selected_tab = gr.State(value=0) + selected_tab = gr.Number(value=0, visible=False) with gr.Column(): with FormRow(): From 2310cd66e5381fbe6b966894381c6ee7b762898f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 11:43:45 -0500 Subject: [PATCH 285/449] Add toggle button for tree view. Use default settings for sortmode and direction. --- html/extra-networks-pane.html | 55 ++++++++++++++++++++-- html/extra-networks-tree.html | 37 --------------- javascript/extraNetworks.js | 20 ++++++-- modules/ui_extra_networks.py | 7 +++ style.css | 86 +++++++++++++++++++++++------------ 5 files changed, 132 insertions(+), 73 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index bf46ca16305..73dad2ab248 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,55 @@
-    <div id="{tabname}_{extra_networks_tabname}_tree" class="extra-network-tree">
-        {tree_html}
-    </div>
-    <div id="{tabname}_{extra_networks_tabname}_cards" class="extra-network-cards">
-        {items_html}
-    </div>
-</div>
+    [control-row markup lost in extraction: a new "extra-network-control" header is added
+    here, holding the search input, the sort-mode button (data-sortmode="{data_sortmode}",
+    data-sortkey="{data_sortkey}"), the sort-direction button (data-sortdir="{data_sortdir}"),
+    the tree-view toggle, and the refresh button, each wired to the matching
+    extraNetworksControl*OnClick handler renamed in javascript/extraNetworks.js below]
+    <div class="extra-network-pane-content">
+        <div id="{tabname}_{extra_networks_tabname}_tree" class="extra-network-tree">
+            {tree_html}
+        </div>
+        <div id="{tabname}_{extra_networks_tabname}_cards" class="extra-network-cards">
+            {items_html}
+        </div>
+    </div>
+</div>
\ No newline at end of file
diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html
index 23f6af1058c..39649e86093 100644
--- a/html/extra-networks-tree.html
+++ b/html/extra-networks-tree.html
@@ -1,41 +1,4 @@
 <div class="tree-list">
-    [controls markup lost in extraction: the old tree-list-controls block (search box,
-    sort-mode, sort-direction, and refresh buttons) is deleted from the tree template;
-    those controls now live in extra-networks-pane.html above]
     <div class="tree-list-container">
         {tree}
     </div>
 </div>
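For reference, the data-sortmode/data-sortkey/data-sortdir attributes used by the new pane markup are plain strings computed in modules/ui_extra_networks.py (see the create_html() hunk later in this patch). A standalone sketch of that derivation, with the two option values inlined instead of read from shared.opts:

    # Sketch of the data-sort* derivation from create_html(); order_field and
    # order are inlined stand-ins for shared.opts.extra_networks_card_order_field
    # and shared.opts.extra_networks_card_order.
    order_field = "Date Modified"  # one of: Path, Name, Date Created, Date Modified
    order = "Ascending"            # or "Descending"

    data_sortdir = order
    data_sortmode = order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
    item_count = 42                # stand-in for len(self.items)
    data_sortkey = f"{data_sortmode}-{data_sortdir}-{item_count}"

    assert data_sortmode == "date_modified"  # matches the [data-sortmode="..."] CSS selectors
    assert data_sortkey == "date_modified-Ascending-42"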
                    diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3029afec847..ce788328ce5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -348,7 +348,7 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { } } -function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Mode button. * @@ -382,7 +382,7 @@ function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Direction button. * @@ -403,7 +403,21 @@ function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Tree View button. + * + * Toggles the tree view in the extra networks pane. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree").classList.toggle("hidden"); + event.currentTarget.classList.toggle("extra-network-control--enabled"); +} + +function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Refresh Page button. 
* diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 656e7f181e4..4c8a407449e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -526,10 +526,17 @@ def create_html(self, tabname): if "user_metadata" not in item: self.read_user_metadata(item) + data_sortdir = shared.opts.extra_networks_card_order + data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() + data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + return self.pane_tpl.format( **{ "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, + "data_sortmode": data_sortmode, + "data_sortkey": data_sortkey, + "data_sortdir": data_sortdir, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } diff --git a/style.css b/style.css index 57c52354a0b..f3fd1571bb1 100644 --- a/style.css +++ b/style.css @@ -1178,6 +1178,13 @@ body.resizing .resize-handle { height: calc(100vh - 24rem); resize: vertical; min-height: 52rem; + flex-direction: column; +} + +.extra-network-pane .extra-network-pane-content { + display: flex; + flex: 1; + flex-direction: row; } .extra-network-pane .extra-network-tree { @@ -1234,7 +1241,7 @@ body.resizing .resize-handle { display: none; } -.extra-network-tree .tree-list .tree-list-controls { +.extra-network-pane .extra-network-control { position: relative; display: grid; width: 100%; @@ -1248,8 +1255,7 @@ body.resizing .resize-handle { border: none; transition: background 33.333ms linear; grid-template-rows: min-content; - grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; - grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-template-columns: minmax(0, auto) repeat(4, min-content); grid-gap: 0.1rem; align-items: start; } @@ -1342,16 +1348,16 @@ body.resizing .resize-handle { background-color: var(--neutral-700); } +.extra-network-tree div.tree-list-content[data-selected] { + background-color: var(--neutral-300); +} + .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[data-selected] { - background-color: var(--neutral-300); -} - /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. 
*/ .extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { @@ -1378,7 +1384,7 @@ body.resizing .resize-handle { /* ==== SEARCH INPUT ACTIONS ==== */ /* Add icon to left side of */ -.extra-network-tree .tree-list-controls .tree-list-search::before { +.extra-network-pane .extra-network-control .extra-network-control--search::before { content: "🔎︎"; position: absolute; margin: 0.5rem; @@ -1386,14 +1392,12 @@ body.resizing .resize-handle { color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-controls .tree-list-search { +.extra-network-pane .extra-network-control .extra-network-control--search { display: inline-flex; - grid-area: tree-list-controls-col-0; position: relative; - margin: 0.5rem; } -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text { border: 1px solid var(--button-secondary-border-color); border-radius: 0.5rem; color: var(--button-secondary-text-color); @@ -1404,7 +1408,7 @@ body.resizing .resize-handle { } /* clear button (x on right side) styling */ -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button { -webkit-appearance: none; appearance: none; cursor: pointer; @@ -1418,8 +1422,7 @@ body.resizing .resize-handle { } /* ==== SORT ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort { - grid-area: tree-list-controls-col-1; +.extra-network-pane .extra-network-control .extra-network-control--sort { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1427,7 +1430,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort .extra-network-control--sort-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1436,25 +1439,24 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } /* ==== SORT DIRECTION ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort-dir { - grid-area: tree-list-controls-col-2; +.extra-network-pane .extra-network-control .extra-network-control--sort-dir { padding: 0.25rem; display: inline-flex; cursor: 
pointer; @@ -1462,7 +1464,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir .extra-network-control--sort-dir-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1471,17 +1473,43 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Ascending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Descending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } +/* ==== TREE VIEW ICON ACTIONS ==== */ +.extra-network-pane .extra-network-control .extra-network-control--tree-view { + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-pane .extra-network-control .extra-network-control--tree-view .extra-network-control--tree-view-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-700); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-300); +} + /* ==== REFRESH ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-refresh { - grid-area: tree-list-controls-col-3; +.extra-network-pane .extra-network-control .extra-network-control--refresh { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1489,7 +1517,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { +.extra-network-pane .extra-network-control .extra-network-control--refresh .extra-network-control--refresh-icon { height: 1.5rem; width: 1.5rem; mask-image: url('data:image/svg+xml,'); @@ -1499,7 +1527,7 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-refresh-icon:active { +.extra-network-pane .extra-network-control .extra-network-control--refresh-icon:active { -ms-transform: rotate(180deg); -webkit-transform: rotate(180deg); transform: rotate(180deg); From b67a49441fc420f37c6bef1172a0b1ad5c42f30f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 13:28:37 -0500 Subject: [PATCH 286/449] Add option in settings to enable/disable tree view by default. --- html/extra-networks-pane.html | 4 ++-- modules/shared_options.py | 1 + modules/ui_extra_networks.py | 7 +++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 73dad2ab248..9f5b3ecee62 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -29,7 +29,7 @@
-            class="extra-network-control--tree-view"
+            class="extra-network-control--tree-view {tree_view_btn_extra_class}"
@@ -45,7 +45,7 @@
-        <div id="{tabname}_{extra_networks_tabname}_tree" class="extra-network-tree">
+        <div id="{tabname}_{extra_networks_tabname}_tree" class="extra-network-tree {tree_view_div_extra_class}">
             {tree_html}
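The two new format fields are filled in by create_html() in modules/ui_extra_networks.py, diffed just below, based on the new extra_networks_tree_view_default_enabled option. A minimal sketch of that branch, with the option inlined as a plain boolean instead of read from shared.opts:

    # Sketch of the tree-view class selection added to create_html() below;
    # tree_view_default_enabled stands in for
    # shared.opts.extra_networks_tree_view_default_enabled (default: False).
    tree_view_default_enabled = False

    tree_view_btn_extra_class = ""
    tree_view_div_extra_class = "hidden"
    if tree_view_default_enabled:
        tree_view_btn_extra_class = "extra-network-control--enabled"
        tree_view_div_extra_class = ""

    # With the default (False), the toggle button renders without the
    # "--enabled" highlight and the tree div starts with the "hidden" class,
    # so the directory tree stays collapsed until the toggle is clicked.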
                    diff --git a/modules/shared_options.py b/modules/shared_options.py index 63488f4e700..e0a6d977c30 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -251,6 +251,7 @@ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), + "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c8a407449e..80160b847b6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -529,6 +529,11 @@ def create_html(self, tabname): data_sortdir = shared.opts.extra_networks_card_order data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + tree_view_btn_extra_class = "" + tree_view_div_extra_class = "hidden" + if shared.opts.extra_networks_tree_view_default_enabled: + tree_view_btn_extra_class = "extra-network-control--enabled" + tree_view_div_extra_class = "" return self.pane_tpl.format( **{ @@ -537,6 +542,8 @@ def create_html(self, tabname): "data_sortmode": data_sortmode, "data_sortkey": data_sortkey, "data_sortdir": data_sortdir, + "tree_view_btn_extra_class": tree_view_btn_extra_class, + "tree_view_div_extra_class": tree_view_div_extra_class, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } From 25e8273d2f6481deca221b29d35093f6d0c9da6a Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:21:36 +0900 Subject: [PATCH 287/449] re-work multi --styles-file --styles-file change to append str --styles-file is [] then defaults to [styles.csv] --styles-file accepts paths or paths with wildcard "*" the first `--styles-file` entry is use as the default styles file path if filename a wildcard then the first matching file is used if no match is found, create a new "styles.csv" in the same dir as the first path when saving a new style it will be save in the default styles file when saving a existing style, it will be saved to file it belongs to order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file --- modules/cmd_args.py | 2 +- modules/shared.py | 3 +- modules/styles.py | 85 +++++++++++++++++++------------------ modules/ui_prompt_styles.py | 9 ++-- 4 files changed, 52 insertions(+), 47 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e58059a1fea..f1251b6c850 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -88,7 +88,7 @@ parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") 
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[]) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/shared.py b/modules/shared.py index 636619391fc..ccdca4e7040 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,3 +1,4 @@ +import os import sys import gradio as gr @@ -11,7 +12,7 @@ batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond parallel_processing_allowed = True -styles_filename = cmd_opts.styles_file +styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')] config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} diff --git a/modules/styles.py b/modules/styles.py index 026c43001a4..9edcc7e4447 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,16 +1,15 @@ +from pathlib import Path import csv -import fnmatch import os -import os.path import typing import shutil class PromptStyle(typing.NamedTuple): name: str - prompt: str - negative_prompt: str - path: str = None + prompt: str | None + negative_prompt: str | None + path: str | None = None def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: - def __init__(self, path: str): + def __init__(self, paths: list[str | Path]): self.no_style = PromptStyle("None", "", "", None) self.styles = {} - self.path = path - - folder, file = os.path.split(self.path) - filename, _, ext = file.partition('*') - self.default_path = os.path.join(folder, filename + ext) + self.paths = paths + self.all_styles_files: list[Path] = [] + + folder, file = os.path.split(self.paths[0]) + if '*' in file or '?' 
in file: + # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path + self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv'))) + self.paths.insert(0, self.default_path) + else: + self.default_path = Path(self.paths[0]) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] @@ -99,33 +103,31 @@ def reload(self): """ self.styles.clear() - path, filename = os.path.split(self.path) - - if "*" in filename: - fileglob = filename.split("*")[0] + "*.csv" - filelist = [] - for file in os.listdir(path): - if fnmatch.fnmatch(file, fileglob): - filelist.append(file) - # Add a visible divider to the style list - half_len = round(len(file) / 2) - divider = f"{'-' * (20 - half_len)} {file.upper()}" - divider = f"{divider} {'-' * (40 - len(divider))}" - self.styles[divider] = PromptStyle( - f"{divider}", None, None, "do_not_save" - ) - # Add styles from this CSV file - self.load_from_csv(os.path.join(path, file)) - if len(filelist) == 0: - print(f"No styles found in {path} matching {fileglob}") - return - elif not os.path.exists(self.path): - print(f"Style database not found: {self.path}") - return - else: - self.load_from_csv(self.path) - - def load_from_csv(self, path: str): + # scans for all styles files + all_styles_files = [] + for pattern in self.paths: + folder, file = os.path.split(pattern) + if '*' in file or '?' in file: + found_files = Path(folder).glob(file) + [all_styles_files.append(file) for file in found_files] + else: + # if os.path.exists(pattern): + all_styles_files.append(Path(pattern)) + + # Remove any duplicate entries + seen = set() + self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))] + + for styles_file in self.all_styles_files: + if len(all_styles_files) > 1: + # add divider when more than styles file + # '---------------- STYLES ----------------' + divider = f' {styles_file.stem.upper()} '.center(40, '-') + self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save") + if styles_file.is_file(): + self.load_from_csv(styles_file) + + def load_from_csv(self, path: str | Path): with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: @@ -137,7 +139,7 @@ def load_from_csv(self, path: str): negative_prompt = row.get("negative_prompt", "") # Add style to database self.styles[row["name"]] = PromptStyle( - row["name"], prompt, negative_prompt, path + row["name"], prompt, negative_prompt, str(path) ) def get_style_paths(self) -> set: @@ -145,11 +147,11 @@ def get_style_paths(self) -> set: # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) + self.styles[style.name] = style._replace(path=str(self.default_path)) # Create a list of all distinct paths, including the default path style_paths = set() - style_paths.add(self.default_path) + style_paths.add(str(self.default_path)) for _, style in self.styles.items(): if style.path: style_paths.add(style.path) @@ -177,7 +179,6 @@ def apply_negative_styles_to_prompt(self, prompt, styles): def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility - _ = path style_paths = self.get_style_paths() diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 0d74c23fa19..d67e3f17e10 100644 --- 
a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt): if not name: return gr.update(visible=False) - style = styles.PromptStyle(name, prompt, negative_prompt) + existing_style = shared.prompt_styles.styles.get(name) + path = existing_style.path if existing_style is not None else None + + style = styles.PromptStyle(name, prompt, negative_prompt, path) shared.prompt_styles.styles[style.name] = style - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return gr.update(visible=True) @@ -34,7 +37,7 @@ def delete_style(name): return shared.prompt_styles.styles.pop(name, None) - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return '', '', '' From 8459015017fde8833a4159b8d56a32f92ebf4186 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:19:53 +0100 Subject: [PATCH 288/449] skip if headers haven't changed --- modules/ui_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 92b00719238..2c67834c6c5 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import csv import dataclasses import json import html @@ -37,8 +38,6 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): - import csv - with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) @@ -47,6 +46,10 @@ def update_logfile(logfile_path, fields): if not rows: return + # file is already synced, do nothing + if len(rows[0]) == len(fields): + return + rows[0] = fields # append new fields to each row as empty values @@ -60,7 +63,6 @@ def update_logfile(logfile_path, fields): def save_files(js_data, images, do_make_zip, index): - import csv filenames = [] fullfns = [] parsed_infotexts = [] From f190b85182a89f873fa99d897ac20310047131ea Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:27:38 +0100 Subject: [PATCH 289/449] restore saving fields --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 2c67834c6c5..6eb740f16c5 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -132,7 +132,7 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 4aa99f77abd439469f20e6e2941dd553031ed2d0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 22:04:53 +0100 Subject: [PATCH 290/449] add docstring --- modules/ui_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui_common.py b/modules/ui_common.py index 6eb740f16c5..29fe7d0e920 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -38,6 +38,7 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): + """Update a logfile from old format to new format to maintain CSV integrity.""" 
with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) From e36827af3254f7bac9f8c78d6d56c709959b40b6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 07:20:52 +0900 Subject: [PATCH 291/449] improve get_crop_region --- modules/masking.py | 43 +++++++++---------------------------------- modules/processing.py | 2 +- 2 files changed, 10 insertions(+), 35 deletions(-) diff --git a/modules/masking.py b/modules/masking.py index be9f84c76f6..29a3945278e 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -3,40 +3,15 @@ def get_crop_region(mask, pad=0): """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. - For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" - - h, w = mask.shape - - crop_left = 0 - for i in range(w): - if not (mask[:, i] == 0).all(): - break - crop_left += 1 - - crop_right = 0 - for i in reversed(range(w)): - if not (mask[:, i] == 0).all(): - break - crop_right += 1 - - crop_top = 0 - for i in range(h): - if not (mask[i] == 0).all(): - break - crop_top += 1 - - crop_bottom = 0 - for i in reversed(range(h)): - if not (mask[i] == 0).all(): - break - crop_bottom += 1 - - return ( - int(max(crop_left-pad, 0)), - int(max(crop_top-pad, 0)), - int(min(w - crop_right + pad, w)), - int(min(h - crop_bottom + pad, h)) - ) + For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)""" + mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) + box = mask_img.getbbox() + if box: + x1, y1, x2, y2 = box + else: # when no box is found + x1, y1 = mask_img.size + x2 = y2 = 0 + return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1]) def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): diff --git a/modules/processing.py b/modules/processing.py index 6b63179512e..72d8093ba32 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1562,7 +1562,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') - crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) + crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region From 2974b9cee94dc474ffbc9e9617d14c9aaf9e1e63 Mon Sep 17 00:00:00 2001 From: Stefan Benten Date: Sun, 21 Jan 2024 14:05:47 +0100 Subject: [PATCH 292/449] modules/api/api.py: add api endpoint to refresh embeddings list --- modules/api/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index b3d74e513a3..b6bb9d06ab1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -230,6 +230,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + 
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) @@ -643,6 +644,10 @@ def convert_embeddings(embeddings): "skipped": convert_embeddings(db.skipped_embeddings), } + def refresh_embeddings(self): + with self.queue_lock: + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() From d7d3166a2749fe04f0ba1d8d0f88c8e8819a379d Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sun, 21 Jan 2024 11:27:24 -0500 Subject: [PATCH 293/449] Fix broken scrollbars --- html/extra-networks-tree.html | 4 +--- style.css | 20 +++++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 39649e86093..beec888cd9d 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -1,5 +1,3 @@
From d7d3166a2749fe04f0ba1d8d0f88c8e8819a379d Mon Sep 17 00:00:00 2001
From: Sj-Si
Date: Sun, 21 Jan 2024 11:27:24 -0500
Subject: [PATCH 293/449] Fix broken scrollbars

---
 html/extra-networks-tree.html |  4 +---
 style.css                     | 20 +++++++-------------
 2 files changed, 8 insertions(+), 16 deletions(-)

diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html
index 39649e86093..beec888cd9d 100644
--- a/html/extra-networks-tree.html
+++ b/html/extra-networks-tree.html
@@ -1,5 +1,3 @@
-
-    {tree}
-
+    {tree}
\ No newline at end of file
diff --git a/style.css b/style.css
index f3fd1571bb1..ff1d907226a 100644
--- a/style.css
+++ b/style.css
@@ -1179,20 +1179,20 @@ body.resizing .resize-handle {
     resize: vertical;
     min-height: 52rem;
     flex-direction: column;
+    overflow: hidden;
 }
 
 .extra-network-pane .extra-network-pane-content {
     display: flex;
     flex: 1;
-    flex-direction: row;
+    overflow: hidden;
 }
 
 .extra-network-pane .extra-network-tree {
     flex: 1;
-    flex-direction: column;
-    display: flex;
     font-size: 1rem;
     border: 1px solid var(--block-border-color);
+    overflow: clip auto !important;
 }
 
 .extra-network-pane .extra-network-cards {
@@ -1210,34 +1210,28 @@ body.resizing .resize-handle {
     overflow: hidden;
 }
 
-.extra-network-pane .extra-network-tree .tree-list .tree-list-container {
-    flex: 1;
-    overflow: clip auto !important;
-    width: 100%;
-}
-
 .extra-network-pane .extra-network-cards::-webkit-scrollbar,
-.extra-network-pane .tree-list-container::-webkit-scrollbar {
+.extra-network-pane .extra-network-tree::-webkit-scrollbar {
    background-color: transparent;
    width: 16px;
 }
 
 .extra-network-pane .extra-network-cards::-webkit-scrollbar-track,
-.extra-network-pane .tree-list-container::-webkit-scrollbar-track {
+.extra-network-pane .extra-network-tree::-webkit-scrollbar-track {
    background-color: transparent;
    background-clip: content-box;
 }
 
 .extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb,
-.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb {
+.extra-network-pane .extra-network-tree::-webkit-scrollbar-thumb {
    background-color: var(--border-color-primary);
    border-radius: 16px;
    border: 4px solid var(--background-fill-primary);
 }
 
 .extra-network-pane .extra-network-cards::-webkit-scrollbar-button,
-.extra-network-pane .tree-list-container::-webkit-scrollbar-button {
+.extra-network-pane .extra-network-tree::-webkit-scrollbar-button {
    display: none;
 }
From 26e1cd7ec47c8d234d2ea3f189b1147329c9059c Mon Sep 17 00:00:00 2001
From: Sj-Si
Date: Sun, 21 Jan 2024 11:34:08 -0500
Subject: [PATCH 294/449] Remove unnecessary template and simplify tree list.

---
 html/extra-networks-tree.html |  3 ---
 modules/ui_extra_networks.py  | 11 +----------
 2 files changed, 1 insertion(+), 13 deletions(-)
 delete mode 100644 html/extra-networks-tree.html

diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html
deleted file mode 100644
index beec888cd9d..00000000000
--- a/html/extra-networks-tree.html
+++ /dev/null
@@ -1,3 +0,0 @@
-
-    {tree}
-
\ No newline at end of file
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 80160b847b6..157b3a6d4a9 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -164,7 +164,6 @@ def __init__(self, title):
         self.lister = util.MassFileLister()
         # HTML Templates
         self.pane_tpl = shared.html("extra-networks-pane.html")
-        self.tree_tpl = shared.html("extra-networks-tree.html")
         self.card_tpl = shared.html("extra-networks-card.html")
         self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
         self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
@@ -420,8 +419,6 @@ def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -
     def create_tree_view_html(self, tabname: str) -> str:
         """Generates HTML for displaying folders in a tree view.
 
-        The generated HTML uses `extra-networks-tree.html` as a template.
-
         Args:
             tabname: The name of the active tab.
 
@@ -473,13 +470,7 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional
             if item_html is not None:
                 res += item_html
 
-        return self.tree_tpl.format(
-            **{
-                "tabname": tabname,
-                "extra_networks_tabname": self.extra_networks_tabname,
-                "tree": f"{res}"
-            }
-        )
+        return f"{res}"
 
     def create_card_view_html(self, tabname: str) -> str:
         """Generates HTML for the network Card View section for a tab.
From fd383140cf405100f3c619f106472273a7545beb Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Mon, 22 Jan 2024 02:52:34 -0800
Subject: [PATCH 295/449] fix: wrong devices for eye and constraint

---
 extensions-builtin/Lora/network_oft.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 342fcd0dcaf..d1c46a4b22e 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -57,12 +57,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
 
     def calc_updown(self, orig_weight):
         oft_blocks = self.oft_blocks.to(orig_weight.device)
-        eye = torch.eye(self.block_size, device=self.oft_blocks.device)
+        eye = torch.eye(self.block_size, device=oft_blocks.device)
 
         if self.is_kohya:
             block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
             norm_Q = torch.norm(block_Q.flatten())
-            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+            new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
             block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
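Both lines changed in PATCH 295 are instances of the same cross-device bug: eye was created on the original device of self.oft_blocks and constraint stayed wherever it was loaded, while the rest of the math runs on orig_weight.device. A toy reproduction of the corrected pattern — hypothetical sizes and values, with a CPU fallback so it runs on any machine:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
orig_weight = torch.zeros(4, 4, device=device)

oft_blocks = torch.randn(1, 4, 4).to(orig_weight.device)  # moved, mirroring the real code
constraint = torch.tensor(0.05)                           # may still live on the CPU

eye = torch.eye(4, device=oft_blocks.device)              # follow the moved tensor, not the original one
block_Q = oft_blocks - oft_blocks.transpose(1, 2)         # skew-symmetric part
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint.to(oft_blocks.device))  # .to() avoids a device-mismatch error
print(new_norm_Q.device)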
From f4e931f18fa4f94aece1f4dabd4dd0d635ecec13 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 22 Jan 2024 23:20:30 +0300
Subject: [PATCH 296/449] put extra networks controls row into the tabs UI element for #14588

---
 html/extra-networks-pane.html |  2 +-
 javascript/extraNetworks.js   | 24 +++++++++++++++--
 modules/ui.py                 |  4 +--
 modules/ui_extra_networks.py  |  2 +-
 style.css                     | 51 +++++++++++++++++++----------------
 5 files changed, 54 insertions(+), 29 deletions(-)

diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html
index 9f5b3ecee62..0c763f71058 100644
--- a/html/extra-networks-pane.html
+++ b/html/extra-networks-pane.html
@@ -1,5 +1,5 @@
-
+