From 9e3403df899684d18bbc534393d9fdc5033d59a5 Mon Sep 17 00:00:00 2001
From: Bryce
Date: Sat, 30 Dec 2023 21:21:49 -0800
Subject: [PATCH] feature: clean up terminal output

- record timing and memory usage of various steps
- re-use logging context for composition images
- load sdxl weights in a more VRAM-efficient way
- switch to diffusers weights as the default weights for sd15
---
 imaginairy/api/generate.py                    |  15 +-
 imaginairy/api/generate_compvis.py            |   2 +-
 imaginairy/api/generate_refiners.py           | 313 ++++++++++--------
 imaginairy/cli/shared.py                      |   5 +-
 imaginairy/config.py                          |   4 +-
 .../enhancers/face_restoration_codeformer.py  |   2 +-
 imaginairy/modules/refiners_sd.py             | 114 ++++++-
 imaginairy/schema.py                          |  58 +++-
 imaginairy/utils/log_utils.py                 |  98 +++++-
 imaginairy/utils/model_manager.py             |  34 +-
 scripts/generate_phraselist.py                |  15 +-
 tests/test_enhancers/test_prompt_expansion.py |   5 +
 tox.ini                                       |   2 +-
 13 files changed, 468 insertions(+), 199 deletions(-)

diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py
index 127e87b6..13c9162f 100755
--- a/imaginairy/api/generate.py
+++ b/imaginairy/api/generate.py
@@ -84,7 +84,6 @@ def _record_step(img, description, image_count, step_count, prompt):
             f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
             f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
         )
-
        for image_type in result.images:
            subpath = os.path.join(outdir, image_type)
            os.makedirs(subpath, exist_ok=True)
@@ -92,7 +91,7 @@ def _record_step(img, description, image_count, step_count, prompt):
                subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
            )
            result.save(filepath, image_type=image_type)
-            logger.info(f"    [{image_type}] saved to: {filepath}")
+            logger.info(f"    {image_type:<22} {filepath}")
            if image_type == return_filename_type:
                result_filenames.append(filepath)
                if videogen:
@@ -123,7 +122,8 @@ def _record_step(img, description, image_count, step_count, prompt):
                    start_pause_duration_ms=1500,
                    end_pause_duration_ms=1000,
                )
-                logger.info(f"    [gif] {len(frames)} frames saved to: {filepath}")
+                image_type = "gif"
+                logger.info(f"    {image_type:<22} {filepath}")
            if make_compare_gif and prompt.init_image:
                subpath = os.path.join(outdir, "gif")
                os.makedirs(subpath, exist_ok=True)
@@ -137,7 +137,8 @@ def _record_step(img, description, image_count, step_count, prompt):
                    imgs=frames,
                    outpath=filepath,
                )
-                logger.info(f"    [gif-comparison] saved to: {filepath}")
+                image_type = "gif-comparison"
+                logger.info(f"    {image_type:<22} {filepath}")
            base_count += 1
            del result

@@ -192,9 +193,8 @@
    ), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        for i, prompt in enumerate(prompts):
            concrete_prompt = prompt.make_concrete_copy()
-            logger.info(
-                f"🖼 Generating {i + 1}/{num_prompts}: {concrete_prompt.prompt_description()}"
-            )
+            prog_text = f"{i + 1}/{num_prompts}"
+            logger.info(f"🖼 {prog_text} {concrete_prompt.prompt_description()}")
            for attempt in range(unsafe_retry_count + 1):
                if attempt > 0 and isinstance(concrete_prompt.seed, int):
                    concrete_prompt.seed += 100_000_000 + attempt
@@ -204,7 +204,6 @@ def imagine(
                    progress_img_callback=progress_img_callback,
                    progress_img_interval_steps=progress_img_interval_steps,
                    progress_img_interval_min_s=progress_img_interval_min_s,
-                    half_mode=half_mode,
                    add_caption=add_caption,
                    dtype=torch.float16 if half_mode else torch.float32,
                )
diff --git a/imaginairy/api/generate_compvis.py b/imaginairy/api/generate_compvis.py
index 06de995d..633b614c 100644
--- a/imaginairy/api/generate_compvis.py
+++ b/imaginairy/api/generate_compvis.py
@@ -475,7 +475,7 @@ def latent_logger(latents):
            is_nsfw=safety_score.is_nsfw,
            safety_score=safety_score,
            result_images=result_images,
-            timings=lc.get_timings(),
+            performance_stats=lc.get_performance_stats(),
            progress_latents=progress_latents.copy(),
        )

diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py
index 3baabfc5..0c92a151 100644
--- a/imaginairy/api/generate_refiners.py
+++ b/imaginairy/api/generate_refiners.py
@@ -1,10 +1,13 @@
 """Functions for generating refined images"""

 import logging
+from contextlib import nullcontext
+from typing import Any

 from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
 from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
 from imaginairy.utils import clear_gpu_cache
+from imaginairy.utils.log_utils import ImageLoggingContext

 logger = logging.getLogger(__name__)

@@ -18,7 +21,9 @@ def generate_single_image(
    add_caption=False,
    return_latent=False,
    dtype=None,
-    half_mode=None,
+    logging_context: ImageLoggingContext | None = None,
+    output_perf=False,
+    image_name="",
 ):
    import torch.nn
    from PIL import Image, ImageOps
@@ -59,41 +64,48 @@ def generate_single_image(
    from imaginairy.utils.safety import create_safety_score

    if dtype is None:
-        dtype = torch.float16 if half_mode else torch.float32
+        dtype = torch.float16

    get_device()
    clear_gpu_cache()

    prompt = prompt.make_concrete_copy()
-    sd = get_diffusion_model_refiners(
-        weights_config=prompt.model_weights,
-        for_inpainting=prompt.should_use_inpainting
-        and prompt.inpaint_method == "finetune",
-        dtype=dtype,
-    )
+    if not logging_context:
+
+        def latent_logger(latents):
+            progress_latents.append(latents)
+
+        lc = ImageLoggingContext(
+            prompt=prompt,
+            debug_img_callback=debug_img_callback,
+            progress_img_callback=progress_img_callback,
+            progress_img_interval_steps=progress_img_interval_steps,
+            progress_img_interval_min_s=progress_img_interval_min_s,
+            progress_latent_callback=latent_logger
+            if prompt.collect_progress_latents
+            else None,
+        )
+        _context: Any = lc
+    else:
+        lc = logging_context
+        _context = nullcontext()
+    with _context:
+        with lc.timing("model-load"):
+            sd = get_diffusion_model_refiners(
+                weights_config=prompt.model_weights,
+                for_inpainting=prompt.should_use_inpainting
+                and prompt.inpaint_method == "finetune",
+                dtype=dtype,
+            )
+        lc.model = sd
+        seed_everything(prompt.seed)
+        downsampling_factor = 8
+        latent_channels = 4
+        batch_size = 1
+
+        mask_image = None
+        mask_image_orig = None

-    seed_everything(prompt.seed)
-    downsampling_factor = 8
-    latent_channels = 4
-    batch_size = 1
-
-    mask_image = None
-    mask_image_orig = None
-
-    def latent_logger(latents):
-        progress_latents.append(latents)
-
-    with ImageLoggingContext(
-        prompt=prompt,
-        model=sd,
-        debug_img_callback=debug_img_callback,
-        progress_img_callback=progress_img_callback,
-        progress_img_interval_steps=progress_img_interval_steps,
-        progress_img_interval_min_s=progress_img_interval_min_s,
-        progress_latent_callback=latent_logger
-        if prompt.collect_progress_latents
-        else None,
-    ) as lc:
        sd.set_tile_mode(prompt.tile_mode)

        result_images: dict[str, torch.Tensor | None | Image.Image] = {}
@@ -178,63 +190,14 @@ def latent_logger(latents):
        assert prompt.seed is not None
        noise = randn_seeded(seed=prompt.seed, size=shape).to(
-            get_device(), dtype=sd.dtype
+            sd.unet.device, dtype=sd.unet.dtype
        )
        noised_latent = noise

        controlnets = []

        if control_modes:
-            for control_input in control_inputs:
-                controlnet, control_image_t, control_image_disp = prep_control_input(
-                    control_input=control_input,
-                    sd=sd,
-                    init_image_t=init_image_t,
-                    fit_width=prompt.width,
-                    fit_height=prompt.height,
-                )
-                result_images[f"control-{control_input.mode}"] = control_image_disp
-                controlnets.append((controlnet, control_image_t))
-
-        if prompt.allow_compose_phase:
-            cutoff_size = get_model_default_image_size(prompt.model_architecture)
-            cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
-            compose_kwargs = {
-                "prompt": prompt,
-                "target_height": prompt.height,
-                "target_width": prompt.width,
-                "cutoff": cutoff_size,
-                "dtype": dtype,
-            }
-
-            if prompt.init_image:
-                compose_kwargs.update(
-                    {
-                        "target_height": init_image.height,
-                        "target_width": init_image.width,
-                    }
-                )
-            comp_image, comp_img_orig = _generate_composition_image(**compose_kwargs)
-
-            if comp_image is not None:
-                prompt.fix_faces = False  # done in composition
-                result_images["composition"] = comp_img_orig
-                result_images["composition-upscaled"] = comp_image
-                composition_strength = prompt.composition_strength
-                first_step = int((prompt.steps) * composition_strength)
-                noise_step = int((prompt.steps - 1) * composition_strength)
-                log_img(comp_img_orig, "comp_image")
-                log_img(comp_image, "comp_image_upscaled")
-                comp_image_t = pillow_img_to_torch_image(comp_image)
-                comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
-                init_latent = sd.lda.encode(comp_image_t)
-                compose_control_inputs: list[ControlInput]
-                if prompt.model_weights.architecture.primary_alias == "sdxl":
-                    compose_control_inputs = []
-                else:
-                    compose_control_inputs = [
-                        ControlInput(mode="details", image=comp_image, strength=1),
-                    ]
-                for control_input in compose_control_inputs:
+            with lc.timing("control-image-prep"):
+                for control_input in control_inputs:
                    (
                        controlnet,
                        control_image_t,
@@ -242,16 +205,73 @@ def latent_logger(latents):
                        control_image_disp,
                    ) = prep_control_input(
                        control_input=control_input,
                        sd=sd,
-                        init_image_t=None,
+                        init_image_t=init_image_t,
                        fit_width=prompt.width,
                        fit_height=prompt.height,
                    )
                    result_images[f"control-{control_input.mode}"] = control_image_disp
                    controlnets.append((controlnet, control_image_t))

+        if prompt.allow_compose_phase:
+            with lc.timing("composition"):
+                cutoff_size = get_model_default_image_size(prompt.model_architecture)
+                cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30))
+                compose_kwargs = {
+                    "prompt": prompt,
+                    "target_height": prompt.height,
+                    "target_width": prompt.width,
+                    "cutoff": cutoff_size,
+                    "dtype": dtype,
+                }
+
+                if prompt.init_image:
+                    compose_kwargs.update(
+                        {
+                            "target_height": init_image.height,
+                            "target_width": init_image.width,
+                        }
+                    )
+                comp_image, comp_img_orig = _generate_composition_image(
+                    **compose_kwargs, logging_context=lc
+                )
+
+                if comp_image is not None:
+                    prompt.fix_faces = False  # done in composition
+                    result_images["composition"] = comp_img_orig
+                    result_images["composition-upscaled"] = comp_image
+                    composition_strength = prompt.composition_strength
+                    first_step = int((prompt.steps) * composition_strength)
+                    noise_step = int((prompt.steps - 1) * composition_strength)
+                    log_img(comp_img_orig, "comp_image")
+                    log_img(comp_image, "comp_image_upscaled")
+                    comp_image_t = pillow_img_to_torch_image(comp_image)
+                    comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
+                    init_latent = sd.lda.encode(comp_image_t)
+                    compose_control_inputs: list[ControlInput]
+                    if prompt.model_weights.architecture.primary_alias == "sdxl":
+                        compose_control_inputs = []
+                    else:
+                        compose_control_inputs = [
+                            ControlInput(mode="details", image=comp_image, strength=1),
+                        ]
+                    for control_input in compose_control_inputs:
+                        (
+                            controlnet,
+                            control_image_t,
+                            control_image_disp,
+                        ) = prep_control_input(
+                            control_input=control_input,
+                            sd=sd,
+                            init_image_t=None,
+                            fit_width=prompt.width,
+                            fit_height=prompt.height,
+                        )
+                        result_images[
+                            f"control-{control_input.mode}"
+                        ] = control_image_disp
+                        controlnets.append((controlnet, control_image_t))
+
        for controlnet, control_image_t in controlnets:
-            msg = f"Injecting controlnet {controlnet.name}. setting to device: {sd.unet.device}, dtype: {sd.unet.dtype}"
-            print(msg)
            controlnet.set_controlnet_condition(
                control_image_t.to(device=sd.unet.device, dtype=sd.unet.dtype)
            )
@@ -263,7 +283,7 @@ def latent_logger(latents):
        else:
            msg = f"Unknown solver type: {prompt.solver_type}"
            raise ValueError(msg)
-        sd.scheduler.to(device=sd.device, dtype=sd.dtype)
+        sd.scheduler.to(device=sd.unet.device, dtype=sd.unet.dtype)
        sd.set_num_inference_steps(prompt.steps)

        if hasattr(sd, "mask_latents") and mask_image is not None:
@@ -288,60 +308,66 @@ def latent_logger(latents):
                x=init_latent, noise=noise, step=sd.steps[noise_step]
            )

-        text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
-            positive_prompts=prompt.prompts,
-            negative_prompts=prompt.negative_prompt,
-            positive_conditioning_override=prompt.conditioning,
-        )
-
-        for k, v in text_conditioning_kwargs.items():
-            text_conditioning_kwargs[k] = v.to(
-                device=sd.unet.device, dtype=sd.unet.dtype
+        with lc.timing("text-conditioning"):
+            text_conditioning_kwargs = sd.calculate_text_conditioning_kwargs(
+                positive_prompts=prompt.prompts,
+                negative_prompts=prompt.negative_prompt,
+                positive_conditioning_override=prompt.conditioning,
            )
+
+            for k, v in text_conditioning_kwargs.items():
+                text_conditioning_kwargs[k] = v.to(
+                    device=sd.unet.device, dtype=sd.unet.dtype
+                )

        x = noised_latent
        x = x.to(device=sd.unet.device, dtype=sd.unet.dtype)
        clear_gpu_cache()
-        for step in tqdm(sd.steps[first_step:], bar_format="    {l_bar}{bar}{r_bar}"):
-            log_latent(x, "noisy_latent")
-            x = sd(
-                x,
-                step=step,
-                condition_scale=prompt.prompt_strength,
-                **text_conditioning_kwargs,
-            )
-            # trying to clear memory. not sure if this helps
-            sd.unet.set_context(context="self_attention_map", value={})
-            sd.unet._reset_context()
-            clear_gpu_cache()
+        with lc.timing("unet"):
+            for step in tqdm(
+                sd.steps[first_step:], bar_format="    {l_bar}{bar}{r_bar}", leave=False
+            ):
+                log_latent(x, "noisy_latent")
+                x = sd(
+                    x,
+                    step=step,
+                    condition_scale=prompt.prompt_strength,
+                    **text_conditioning_kwargs,
+                )
+                # trying to clear memory. not sure if this helps
+                sd.unet.set_context(context="self_attention_map", value={})
+                sd.unet._reset_context()
+                clear_gpu_cache()

        logger.debug("Decoding image")
        if x.device != sd.lda.device:
            sd.lda.to(x.device)
            clear_gpu_cache()
-
-        gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))
+        with lc.timing("decode-img"):
+            gen_img = sd.lda.decode_latents(x.to(dtype=sd.lda.dtype))

        if mask_image_orig and init_image:
-            result_images["pre-reconstitution"] = gen_img
-            mask_final = mask_image_orig.copy()
-            # mask_final = ImageOps.invert(mask_final)
-
-            log_img(mask_final, "reconstituting mask")
-            # gen_img = Image.composite(gen_img, init_image, mask_final)
-            gen_img = combine_image(
-                original_img=init_image,
-                generated_img=gen_img,
-                mask_img=mask_final,
-            )
-            log_img(gen_img, "reconstituted image")
+            with lc.timing("combine-image"):
+                result_images["pre-reconstitution"] = gen_img
+                mask_final = mask_image_orig.copy()
+                # mask_final = ImageOps.invert(mask_final)
+
+                log_img(mask_final, "reconstituting mask")
+                # gen_img = Image.composite(gen_img, init_image, mask_final)
+                gen_img = combine_image(
+                    original_img=init_image,
+                    generated_img=gen_img,
+                    mask_img=mask_final,
+                )
+                log_img(gen_img, "reconstituted image")

        upscaled_img = None
        rebuilt_orig_img = None

        if add_caption:
-            caption = generate_caption(gen_img)
-            logger.info(f"Generated caption: {caption}")
+            with lc.timing("caption-img"):
+                caption = generate_caption(gen_img)
+                logger.info(f"Generated caption: {caption}")

        with lc.timing("safety-filter"):
            safety_score = create_safety_score(
@@ -352,13 +378,15 @@ def latent_logger(latents):
        progress_latents.clear()
        if not safety_score.is_filtered:
            if prompt.fix_faces:
-                logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
-                with lc.timing("face enhancement"):
-                    gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
+                with lc.timing("face-enhancement"):
+                    logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+                    gen_img = enhance_faces(
+                        gen_img, fidelity=prompt.fix_faces_fidelity
+                    )
            if prompt.upscale:
-                logger.info("Upscaling 🖼 using real-ESRGAN...")
                with lc.timing("upscaling"):
-                    upscaled_img = upscale_image(gen_img)
+                    logger.info("Upscaling 🖼 using real-ESRGAN...")
+                    upscaled_img = upscale_image(gen_img)

        # put the newly generated patch back into the original, full-size image
        if prompt.mask_modify_original and mask_image_orig and starting_image:
@@ -390,13 +418,19 @@ def latent_logger(latents):
                is_nsfw=safety_score.is_nsfw,
                safety_score=safety_score,
                result_images=result_images,
-                timings=lc.get_timings(),
+                performance_stats=lc.get_performance_stats(),
                progress_latents=[],  # todo
            )
            _most_recent_result = result

-            if result.timings:
-                logger.info(f"Image Generated. Timings: {result.timings_str()}")
+            _image_name = f"{image_name} " if image_name else ""
+            logger.info(f"Generated {_image_name}image in {result.total_time():.1f}s")
+
+            if result.performance_stats:
+                log = logger.info if output_perf else logger.debug
+                log(f"    Timings: {result.timings_str()}")
+                log(f"    Peak VRAM: {result.gpu_str('memory_peak')}")
+                log(f"    Ending VRAM: {result.gpu_str('memory_end')}")
            for controlnet, _ in controlnets:
                controlnet.eject()
            clear_gpu_cache()
@@ -495,6 +529,7 @@ def _generate_composition_image(
    target_width,
    cutoff: tuple[int, int] = (512, 512),
    dtype=None,
+    logging_context=None,
 ):
    from PIL import Image

@@ -530,7 +565,13 @@ def _generate_composition_image(
        },
    )

-    result = generate_single_image(composition_prompt, dtype=dtype)
+    result = generate_single_image(
+        composition_prompt,
+        dtype=dtype,
+        logging_context=logging_context,
+        output_perf=False,
+        image_name="composition",
+    )
    img = result.images["generated"]
    while img.width < target_width:
        from imaginairy.enhancers.upscale_realesrgan import upscale_image
@@ -538,9 +579,11 @@ def _generate_composition_image(
        if prompt.fix_faces:
            from imaginairy.enhancers.face_restoration_codeformer import enhance_faces

-            img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
-
-        img = upscale_image(img, ultrasharp=True)
+            with logging_context.timing("face-enhancement"):
+                logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
+                img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
+        with logging_context.timing("upscaling"):
+            img = upscale_image(img, ultrasharp=True)

    img = img.resize(
        (target_width, target_height),
diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py
index 1bf22d91..a41b570f 100644
--- a/imaginairy/cli/shared.py
+++ b/imaginairy/cli/shared.py
@@ -109,8 +109,11 @@ def _imagine_cmd(
        raise ValueError(msg)

    total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
+    img_msg = ""
+    if len(init_images) > 0:
+        img_msg = f" and {len(init_images)} image(s)"
    logger.info(
-        f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images."
+        f"Received {len(prompt_texts)} prompt(s){img_msg}. Will repeat these {repeats} times to create {total_image_count} images.\n"
    )

    from imaginairy.api import imagine_image_files
diff --git a/imaginairy/config.py b/imaginairy/config.py
index b4a96751..d1e99c91 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -124,7 +124,7 @@ def __post_init__(self):
        aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
        defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
-        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
+        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/889b629140e71758e1e0006e355c331a5744b4bf/",
    ),
    ModelWeightsConfig(
        name="Stable Diffusion 1.5 - Inpainting",
@@ -151,7 +151,7 @@ def __post_init__(self):
        name="OpenJourney V4",
        aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"],
        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
-        weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
+        weights_location="https://huggingface.co/prompthero/openjourney/tree/f4572661b028c732b2b97c8fbdc32fa5db3afe03/",
        defaults={"negative_prompt": "poor quality"},
    ),
    ModelWeightsConfig(
diff --git a/imaginairy/enhancers/face_restoration_codeformer.py b/imaginairy/enhancers/face_restoration_codeformer.py
index 01b3fc54..010f7f25 100644
--- a/imaginairy/enhancers/face_restoration_codeformer.py
+++ b/imaginairy/enhancers/face_restoration_codeformer.py
@@ -74,7 +74,7 @@ def enhance_faces(img, fidelity=0):
    num_det_faces = face_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5
    )
-    logger.info(f"Enhancing {num_det_faces} faces")
+    logger.debug(f"Enhancing {num_det_faces} faces")

    # align and warp each face
    face_helper.align_warp_face()
diff --git a/imaginairy/modules/refiners_sd.py b/imaginairy/modules/refiners_sd.py
index e034df13..88b4b3b7 100644
--- a/imaginairy/modules/refiners_sd.py
+++ b/imaginairy/modules/refiners_sd.py
@@ -3,11 +3,18 @@
 import logging
 import math
 from functools import lru_cache
-from typing import List, Literal
+from typing import Any, List, Literal

+import refiners.fluxion.layers as fl
 import torch
 from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.layers.chain import ChainError
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.latent_diffusion.model import (
+    TLatentDiffusionModel,
+)
+from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 from refiners.foundationals.latent_diffusion.self_attention_guidance import (
     SelfAttentionMap,
 )
@@ -25,7 +32,11 @@
    SDXLAutoencoder,
    StableDiffusion_XL as RefinerStableDiffusion_XL,
 )
-from torch import Tensor, nn
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
+    DoubleTextEncoder,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from torch import Tensor, device as Device, dtype as DType, nn
 from torch.nn import functional as F

 from imaginairy.schema import WeightedPrompt
@@ -83,6 +94,45 @@ def set_tile_mode(self, tile_mode: TileModeType = ""):


 class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
+    def __init__(
+        self,
+        unet: SD1UNet | None = None,
+        lda: SD1Autoencoder | None = None,
+        clip_text_encoder: CLIPTextEncoderL | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        unet = unet or SD1UNet(in_channels=4)
+        lda = lda or SD1Autoencoder()
+        clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
+        scheduler = scheduler or DDIM(num_inference_steps=50)
+        fl.Module.__init__(self)
+
+        # all this is to allow us to make structural copies without unnecessary device or dtype shuffling
+        # since default behavior was to put everything on the same device and dtype and we want the option to
+        # not alter them from whatever they're already set to
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+        to_kwargs: dict[str, Any] = {}
+
+        if device is not None:
+            device = device if isinstance(device, Device) else Device(device=device)
+            to_kwargs["device"] = device
+        if dtype is not None:
+            to_kwargs["dtype"] = dtype
+
+        self.device = device
+        self.dtype = dtype
+
+        if to_kwargs:
+            self.unet = unet.to(**to_kwargs)
+            self.lda = lda.to(**to_kwargs)
+            self.clip_text_encoder = clip_text_encoder.to(**to_kwargs)
+            self.scheduler = scheduler.to(**to_kwargs)
+
    def calculate_text_conditioning_kwargs(
        self,
        positive_prompts: List[WeightedPrompt],
@@ -122,6 +172,48 @@ def prompts_to_embeddings(self, prompts: List[WeightedPrompt]) -> Tensor:


 class StableDiffusion_XL(TileModeMixin, RefinerStableDiffusion_XL):
+    def __init__(
+        self,
+        unet: SDXLUNet | None = None,
+        lda: SDXLAutoencoder | None = None,
+        clip_text_encoder: DoubleTextEncoder | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str | None = "cpu",
+        dtype: DType | None = None,
+    ) -> None:
+        unet = unet or SDXLUNet(in_channels=4)
+        lda = lda or SDXLAutoencoder()
+        clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
+        scheduler = scheduler or DDIM(num_inference_steps=30)
+        fl.Module.__init__(self)
+
+        # all this is to allow us to make structural copies without unnecessary device or dtype shuffling
+        # since default behavior was to put everything on the same device and dtype and we want the option to
+        # not alter them from whatever they're already set to
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+        to_kwargs: dict[str, Any] = {}
+
+        if device is not None:
+            device = device if isinstance(device, Device) else Device(device=device)
+            to_kwargs["device"] = device
+        if dtype is not None:
+            to_kwargs["dtype"] = dtype
+
+        self.device = device  # type: ignore
+        self.dtype = dtype  # type: ignore
+        self.unet = unet
+        self.lda = lda
+        self.clip_text_encoder = clip_text_encoder
+        self.scheduler = scheduler
+        if to_kwargs:
+            self.unet = self.unet.to(**to_kwargs)
+            self.lda = self.lda.to(**to_kwargs)
+            self.clip_text_encoder = self.clip_text_encoder.to(**to_kwargs)
+            self.scheduler = self.scheduler.to(**to_kwargs)
+
    def forward(  # type: ignore
        self,
        x: Tensor,
@@ -144,6 +236,22 @@ def forward(  # type: ignore
            **kwargs,
        )

+    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
+        logger.debug("Making structural copy of StableDiffusion_XL model")
+
+        sd = self.__class__(
+            unet=self.unet.structural_copy(),
+            lda=self.lda.structural_copy(),
+            clip_text_encoder=self.clip_text_encoder,
+            scheduler=self.scheduler,
+            device=self.device,
+            dtype=None,  # type: ignore
+        )
+        logger.debug(
+            f"dtype: {sd.dtype} unet-dtype:{sd.unet.dtype} lda-dtype:{sd.lda.dtype} text-encoder-dtype:{sd.clip_text_encoder.dtype} scheduler-dtype:{sd.scheduler.dtype}"
+        )
+        return sd
+
    def calculate_text_conditioning_kwargs(
        self,
        positive_prompts: List[WeightedPrompt],
@@ -452,7 +560,7 @@ def __init__(
            device=target.device,
            dtype=target.dtype,
        )
-        print(
+        logger.debug(
            f"controlnet: {name} loaded to device {target.device} and type {target.dtype}"
        )

diff --git a/imaginairy/schema.py b/imaginairy/schema.py
index 4830a492..b5ccbf39 100644
--- a/imaginairy/schema.py
+++ b/imaginairy/schema.py
@@ -362,6 +362,7 @@ def __init__(
            composition_strength=composition_strength,
            inpaint_method=inpaint_method,
        )
+        self._default_negative_prompt = None

    @field_validator("prompt", "negative_prompt", mode="before")
    def make_into_weighted_prompts(
@@ -401,16 +402,20 @@ def sort_prompts(cls, v):
            v.sort(key=lambda p: p.weight, reverse=True)
        return v

+    @property
+    def default_negative_prompt(self):
+        default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
+        if self.model_weights:
+            default_negative_prompt = self.model_weights.defaults.get(
+                "negative_prompt", default_negative_prompt
+            )
+        return default_negative_prompt
+
    @model_validator(mode="after")
    def validate_negative_prompt(self):
        if self.negative_prompt == []:
-            default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
-            if self.model_weights:
-                default_negative_prompt = self.model_weights.defaults.get(
-                    "negative_prompt", default_negative_prompt
-                )
+            self.negative_prompt = [WeightedPrompt(text=self.default_negative_prompt)]

-            self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
        return self

    @field_validator("prompt_strength", mode="before")
@@ -667,15 +672,27 @@ def model_architecture(self) -> config.ModelArchitecture:
        return self.model_weights.architecture

    def prompt_description(self):
+        if self.negative_prompt_text == self.default_negative_prompt:
+            neg_prompt = "DEFAULT-NEGATIVE-PROMPT"
+        else:
+            neg_prompt = f'"{self.negative_prompt_text}"'
+
+        from termcolor import colored
+
+        prompt_text = colored(self.prompt_text, "green")
+
        return (
-            f'"{self.prompt_text}" {self.width}x{self.height}px '
-            f'negative-prompt:"{self.negative_prompt_text}" '
+            f'"{prompt_text}"\n'
+            "    "
+            f"negative-prompt:{neg_prompt}\n"
+            "    "
+            f"size:{self.width}x{self.height}px "
            f"seed:{self.seed} "
            f"prompt-strength:{self.prompt_strength} "
            f"steps:{self.steps} solver-type:{self.solver_type} "
            f"init-image-strength:{self.init_image_strength} "
            f"arch:{self.model_architecture.aliases[0]} "
-            f"weights: {self.model_weights.aliases[0]}"
+            f"weights:{self.model_weights.aliases[0]}"
        )

    def logging_dict(self):
@@ -724,7 +741,7 @@ def __init__(
        is_nsfw,
        safety_score,
        result_images=None,
-        timings=None,
+        performance_stats=None,
        progress_latents=None,
    ):
        import torch
@@ -750,7 +767,7 @@ def __init__(
                r_img = torch_img_to_pillow_img(r_img)
            self.images[img_type] = r_img

-        self.timings = timings
+        self.performance_stats = performance_stats
        self.progress_latents = progress_latents

        # for backward compat
@@ -771,9 +788,24 @@ def metadata_dict(self):
        }

    def timings_str(self) -> str:
-        if not self.timings:
+        if not self.performance_stats:
            return ""
-        return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
+        return " ".join(
+            f"{k}:{v['duration']:.2f}s" for k, v in self.performance_stats.items()
+        )
+
+    def total_time(self) -> float:
+        if not self.performance_stats:
+            return 0
+        return self.performance_stats["total"]["duration"]
+
+    def gpu_str(self, stat_name="memory_peak") -> str:
+        if not self.performance_stats:
+            return ""
+        return " ".join(
+            f"{k}:{v[stat_name]/(10**6):.1f}MB"
+            for k, v in self.performance_stats.items()
+        )

    def _exif(self) -> "Image.Exif":
        from PIL import Image
diff --git a/imaginairy/utils/log_utils.py b/imaginairy/utils/log_utils.py
index 44256282..ae5a9a33 100644
--- a/imaginairy/utils/log_utils.py
+++ b/imaginairy/utils/log_utils.py
@@ -5,6 +5,7 @@
 import re
 import time
 import warnings
+from typing import Callable

 _CURRENT_LOGGING_CONTEXT = None

@@ -53,23 +54,65 @@ def increment_step():


 class TimingContext:
-    def __init__(self, logging_context, description):
-        self.logging_context = logging_context
+    """Tracks time and memory usage of a block of code"""
+
+    def __init__(
+        self,
+        description: str,
+        device: str | None = None,
+        callback_fn: Callable | None = None,
+    ):
+        from imaginairy.utils import get_device
+
        self.description = description
+        self._device = device or get_device()
+        self.callback_fn = callback_fn
+
+        self.start_time = None
+        self.end_time = None
+        self.duration = 0

-    def __enter__(self):
+        self.memory_start = 0
+        self.memory_end = 0
+        self.memory_peak = 0
+
+    def start(self):
+        if self._device == "cuda":
+            import torch
+
+            torch.cuda.reset_peak_memory_stats()
+            self.memory_start = torch.cuda.memory_allocated()
+
+        self.end_time = None
        self.start_time = time.time()

+    def stop(self):
+        self.end_time = time.time()
+        self.duration += self.end_time - self.start_time
+
+        if self._device == "cuda":
+            import torch
+
+            self.memory_end = torch.cuda.memory_allocated()
+            self.memory_peak = max(
+                torch.cuda.max_memory_allocated() - self.memory_start, self.memory_peak
+            )
+
+        if self.callback_fn is not None:
+            self.callback_fn(self)
+
+    def __enter__(self):
+        self.start()
+        return self
+
    def __exit__(self, exc_type, exc_value, traceback):
-        self.logging_context.timings[self.description] = time.time() - self.start_time
+        self.stop()


 class ImageLoggingContext:
    def __init__(
        self,
        prompt,
-        model,
+        model=None,
        debug_img_callback=None,
        img_outdir=None,
        progress_img_callback=None,
@@ -88,29 +131,60 @@ def __init__(
        self.progress_img_interval_min_s = progress_img_interval_min_s
        self.progress_latent_callback = progress_latent_callback

-        self.start_ts = time.perf_counter()
-        self.timings = {}
+        self.summary_context = TimingContext("total")
+        self.summary_context.start()
+        self.timing_contexts = {}

        self.last_progress_img_ts = 0
        self.last_progress_img_step = -1000
        self._prev_log_context = None

    def __enter__(self):
+        self.start()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+
+    def start(self):
        global _CURRENT_LOGGING_CONTEXT
        self._prev_log_context = _CURRENT_LOGGING_CONTEXT
        _CURRENT_LOGGING_CONTEXT = self
        return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def stop(self):
        global _CURRENT_LOGGING_CONTEXT
        _CURRENT_LOGGING_CONTEXT = self._prev_log_context

    def timing(self, description):
-        return TimingContext(self, description)
+        if description not in self.timing_contexts:
+
+            def cb(context):
+                self.timing_contexts[description] = context
+
+            tc = TimingContext(description, callback_fn=cb)
+            self.timing_contexts[description] = tc
+        return self.timing_contexts[description]
+
+    def get_performance_stats(self) -> dict[str, dict[str, float]]:
+        # calculate max peak seen in any timing context
+        self.summary_context.stop()
+        self.timing_contexts["total"] = self.summary_context
+
+        self.summary_context.memory_peak = max(
+            max(context.memory_peak, context.memory_start, context.memory_end)
+            for context in self.timing_contexts.values()
+        )

-    def get_timings(self):
-        self.timings["total"] = time.perf_counter() - self.start_ts
-        return self.timings
+        performance_stats = {}
+        for context in self.timing_contexts.values():
+            performance_stats[context.description] = {
+                "duration": context.duration,
+                "memory_start": context.memory_start,
+                "memory_end": context.memory_end,
+                "memory_peak": context.memory_peak,
+                "memory_delta": context.memory_end - context.memory_start,
+            }
+        return performance_stats

    def log_conditioning(self, conditioning, description):
        if not self.debug_img_callback:
diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py
index e9cff14e..d1810a41 100644
--- a/imaginairy/utils/model_manager.py
+++ b/imaginairy/utils/model_manager.py
@@ -123,12 +123,9 @@ def load_model_from_config_old(

    if half_mode:
        model = model.half()
-        print("halved")

    model.to(get_device())
-    print("moved to device")

    model.eval()
-    print("set to eval mode")

    return model

@@ -222,12 +219,18 @@ def get_diffusion_model_refiners(
 ) -> LatentDiffusionModel:
    """Load a diffusion model."""

-    return _get_diffusion_model_refiners(
+    sd = _get_diffusion_model_refiners(
        weights_location=weights_config.weights_location,
        architecture_alias=weights_config.architecture.primary_alias,
        for_inpainting=for_inpainting,
        dtype=dtype,
    )
+    # ensures a "fresh" copy that doesn't have additional injected parts
+    # sd = sd.structural_copy()
+
+    sd.set_self_attention_guidance(enable=True)
+
+    return sd


 hf_repo_url_pattern = re.compile(
@@ -241,7 +244,9 @@ def parse_diffusers_repo_url(url: str) -> dict[str, str]:


 def is_diffusers_repo_url(url: str) -> bool:
-    return bool(parse_diffusers_repo_url(url))
+    result = bool(parse_diffusers_repo_url(url))
+    logger.debug(f"{url} is diffusers repo url: {result}")
+    return result


 def normalize_diffusers_repo_url(url: str) -> str:
@@ -332,17 +337,16 @@ def _get_sd15_diffusion_model_refiners(
        device=device, dtype=dtype, lda=SD1AutoencoderSliced(), unet=unet
    )
    logger.debug("Loading VAE")
-    sd.lda.load_state_dict(vae_weights)
+    sd.lda.load_state_dict(vae_weights, assign=True)

    logger.debug("Loading text encoder")
-    sd.clip_text_encoder.load_state_dict(text_encoder_weights)
+    sd.clip_text_encoder.load_state_dict(text_encoder_weights, assign=True)

    logger.debug("Loading UNet")
-    sd.unet.load_state_dict(unet_weights, strict=False)
+    sd.unet.load_state_dict(unet_weights, strict=False, assign=True)

    logger.debug(f"'{weights_location}' Loaded")
-
-    sd.set_self_attention_guidance(enable=True)
+    sd.to(device=device, dtype=dtype)

    return sd

@@ -711,16 +715,15 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
            device="cpu",
        )
    )
-    text_encoder = DoubleTextEncoder(device="cpu", dtype=dtype)
+    text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
    text_encoder.load_state_dict(text_encoder_weights)
    del text_encoder_weights
-    lda = lda.to(device=device)
+    lda = lda.to(device=device, dtype=torch.float32)
    unet = unet.to(device=device)
    text_encoder = text_encoder.to(device=device)
    sd = StableDiffusion_XL(
-        device=device, dtype=dtype, lda=lda, unet=unet, clip_text_encoder=text_encoder
+        device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
    )
-    sd.lda.to(device=device, dtype=torch.float32)

    return sd

@@ -729,9 +732,6 @@ def load_sdxl_pipeline(base_url, device=None):
    logger.info(f"Loading SDXL weights from {base_url}")
    device = device or get_device()
    sd = load_sdxl_diffusers_weights(base_url, device=device)
-
-    sd.set_self_attention_guidance(enable=True)
-
    return sd

diff --git a/scripts/generate_phraselist.py b/scripts/generate_phraselist.py
index 3ece47dc..e832fc9d 100644
--- a/scripts/generate_phraselist.py
+++ b/scripts/generate_phraselist.py
@@ -1,7 +1,7 @@
 import re


-def generate_phrase_list(subject, num_phrases=100):
+def generate_phrase_list(subject, num_phrases=100, max_words=6):
    """Generate a list of phrases for a given subject."""
    from openai import OpenAI

@@ -9,7 +9,7 @@ def generate_phrase_list(subject, num_phrases=100):

    prompt = (
        f'Make list of archetypal imagery about "{subject}". These will provide composition ideas to an artist. '
-        f"No more than 6 words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
+        f"No more than {max_words} words per idea. Make {num_phrases} ideas. Provide the output as plaintext with each idea on a new line. "
        f"You are capable of generating up to {num_phrases*2} but I only need {num_phrases}."
    )
    messages = [
@@ -19,8 +19,8 @@ def generate_phrase_list(subject, num_phrases=100):
    response = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=messages,
-        temperature=0.1,
-        max_tokens=2623,
+        temperature=1,
+        max_tokens=4000,
        top_p=1,
        frequency_penalty=0.17,
        presence_penalty=0,
@@ -39,4 +39,9 @@ def generate_phrase_list(subject, num_phrases=100):


 if __name__ == "__main__":
-    print(generate_phrase_list("traditional christmas", num_phrases=200))
+    phrase_list = generate_phrase_list(
+        subject="symbolism for the human condition and the struggle between good and evil",
+        num_phrases=200,
+        max_words=15,
+    )
+    print(phrase_list)
diff --git a/tests/test_enhancers/test_prompt_expansion.py b/tests/test_enhancers/test_prompt_expansion.py
index 0510181a..3a7d3b30 100644
--- a/tests/test_enhancers/test_prompt_expansion.py
+++ b/tests/test_enhancers/test_prompt_expansion.py
@@ -42,3 +42,8 @@ def test_prompt_expander_from_wordlist():

 def test_get_phraselist_names():
    print(", ".join(category_list()))
+
+
+def test_complex_prompt():
+    prompt = "{_painting-style_} of {_art-scene_}. painting"
+    assert len(list(expand_prompts(prompt, n=100))) == 100
diff --git a/tox.ini b/tox.ini
index 90befd69..20dfcf66 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,7 +12,7 @@ plugins = pydantic.mypy
 exclude = ^(\./|)(downloads|dist|build|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)
 ignore_missing_imports = True
 warn_unused_configs = True
-warn_unused_ignores = True
+warn_unused_ignores = False

 [mypy-imaginairy.vendored.*]
 follow_imports = skip
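Usage note (illustrative sketch, not part of the applied patch): the performance stats recorded by the new TimingContext/ImageLoggingContext plumbing can be consumed as below. The generate_single_image parameters and the total_time/timings_str/gpu_str methods are taken from the diff above; the ImaginePrompt arguments and printed values are assumptions for the example.

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.schema import ImaginePrompt

    # hypothetical prompt; any valid ImaginePrompt works
    prompt = ImaginePrompt("a red fox in the snow", seed=1, steps=15)

    # output_perf=True emits the timing/VRAM summary at INFO instead of DEBUG
    result = generate_single_image(prompt, output_perf=True)

    print(f"total: {result.total_time():.1f}s")
    print(result.timings_str())           # e.g. "model-load:2.10s unet:4.52s total:7.01s"
    print(result.gpu_str("memory_peak"))  # per-step peak VRAM, reported in MB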