From ed93abc2612d308d177038582b43bf559cced15f Mon Sep 17 00:00:00 2001
From: Bryce
Date: Fri, 29 Dec 2023 09:01:14 -0800
Subject: [PATCH] perf: improve memory usage

add warning for corrupt weights files
---
 imaginairy/api/generate_refiners.py         |  2 ++
 imaginairy/enhancers/upscale_realesrgan.py  |  2 +-
 imaginairy/utils/model_manager.py           | 12 ++++---
 imaginairy/vendored/realesrgan.py           |  4 ++-
 imaginairy/weight_management/translation.py | 36 +++++++++++++++++++--
 5 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py
index 4b57b3e5..3baabfc5 100644
--- a/imaginairy/api/generate_refiners.py
+++ b/imaginairy/api/generate_refiners.py
@@ -310,7 +310,9 @@ def latent_logger(latents):
                 condition_scale=prompt.prompt_strength,
                 **text_conditioning_kwargs,
             )
+            # trying to clear memory. not sure if this helps
             sd.unet.set_context(context="self_attention_map", value={})
+            sd.unet._reset_context()
             clear_gpu_cache()
 
             logger.debug("Decoding image")
diff --git a/imaginairy/enhancers/upscale_realesrgan.py b/imaginairy/enhancers/upscale_realesrgan.py
index 664a4dbc..2343d792 100644
--- a/imaginairy/enhancers/upscale_realesrgan.py
+++ b/imaginairy/enhancers/upscale_realesrgan.py
@@ -12,7 +12,7 @@
 
 
 @memory_managed_model("realesrgan_upsampler", memory_usage_mb=70)
-def realesrgan_upsampler(tile=1024, tile_pad=50, ultrasharp=False):
+def realesrgan_upsampler(tile=512, tile_pad=50, ultrasharp=False):
     model = RRDBNet(
         num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
     )
diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py
index 93c51bf9..e9cff14e 100644
--- a/imaginairy/utils/model_manager.py
+++ b/imaginairy/utils/model_manager.py
@@ -674,7 +674,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
     vae_weights_path = download_diffusers_weights(
         base_url=base_url, sub="vae", prefer_fp16=False
     )
-    print(vae_weights_path)
+    logger.debug(f"vae: {vae_weights_path}")
     vae_weights = translator.load_and_translate_weights(
         source_path=vae_weights_path,
         device="cpu",
@@ -684,8 +684,10 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
     del vae_weights
 
     translator = translators.diffusers_unet_sdxl_to_refiners_translator()
-    unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
-    print(unet_weights_path)
+    unet_weights_path = download_diffusers_weights(
+        base_url=base_url, sub="unet", prefer_fp16=True
+    )
+    logger.debug(f"unet: {unet_weights_path}")
     unet_weights = translator.load_and_translate_weights(
         source_path=unet_weights_path,
         device="cpu",
@@ -700,8 +702,8 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
     text_encoder_2_path = download_diffusers_weights(
         base_url=base_url, sub="text_encoder_2"
     )
-    print(text_encoder_1_path)
-    print(text_encoder_2_path)
+    logger.debug(f"text encoder 1: {text_encoder_1_path}")
+    logger.debug(f"text encoder 2: {text_encoder_2_path}")
     text_encoder_weights = (
         translators.DoubleTextEncoderTranslator().load_and_translate_weights(
             text_encoder_l_weights_path=text_encoder_1_path,
diff --git a/imaginairy/vendored/realesrgan.py b/imaginairy/vendored/realesrgan.py
index 373d0fae..94b119ec 100644
--- a/imaginairy/vendored/realesrgan.py
+++ b/imaginairy/vendored/realesrgan.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 import queue
@@ -12,6 +13,7 @@
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
+logger = logging.getLogger(__name__)
 
 class RealESRGANer:
     """A helper class for upsampling images with RealESRGAN.
@@ -146,7 +148,7 @@ def tile_process(self):
         self.output = self.img.new_zeros(output_shape)
         tiles_x = math.ceil(width / self.tile_size)
         tiles_y = math.ceil(height / self.tile_size)
-
+        logger.debug(f"Tiling with {tiles_x}x{tiles_y} ({tiles_x*tiles_y}) tiles")
         # loop over all tiles
         for y in range(tiles_y):
             for x in range(tiles_x):
diff --git a/imaginairy/weight_management/translation.py b/imaginairy/weight_management/translation.py
index 873e2e9b..928d7875 100644
--- a/imaginairy/weight_management/translation.py
+++ b/imaginairy/weight_management/translation.py
@@ -4,7 +4,7 @@
 from typing import Dict
 
 import torch
-from refiners.fluxion import load_from_safetensors
+from safetensors import safe_open
 from torch import device as Device
 
 logger = logging.getLogger(__name__)
@@ -29,7 +29,8 @@ def load_and_translate_weights(
             source_weights = torch.load(source_path, map_location="cpu")
         elif extension in ["safetensors"]:
-            source_weights = load_from_safetensors(source_path, device=device)
+            with safe_open(source_path, framework="pt", device=device) as f:  # type: ignore
+                source_weights = {k: f.get_tensor(k) for k in f.keys()}  # noqa
         else:
             msg = f"Unsupported extension {extension}"
             raise ValueError(msg)
 
@@ -79,10 +80,30 @@ def load(cls, path):
         return cls(**d)
 
 
+def check_nan_path(path: str, device):
+    from safetensors import safe_open
+
+    with safe_open(path, framework="pt", device=device) as f:  # type: ignore
+        for k in f.keys():  # noqa
+            if torch.any(torch.isnan(f.get_tensor(k))):
+                print(f"Found nan values in {k} of {path}")
+
+
 def translate_weights(
     source_weights: TensorDict, weight_map: WeightTranslationMap
 ) -> TensorDict:
     new_state_dict: TensorDict = {}
+    # check source weights for nan
+    for k, v in source_weights.items():
+        nan_count = torch.sum(torch.isnan(v)).item()
+        if nan_count:
+            msg = (
+                f"Found {nan_count} nan values in {k} of source state dict."
+                " This could indicate the source weights are corrupted and "
+                "need to be re-downloaded. "
+            )
+            logger.warning(msg)
+
     # print(f"Translating {len(source_weights)} weights")
     # print(f"Using {len(weight_map.name_map)} name mappings")
     # print(source_weights.keys())
@@ -142,7 +163,7 @@ def translate_weights(
 
     if source_weights:
         msg = f"Unmapped keys: {list(source_weights.keys())}"
-        print(msg)
+        logger.info(msg)
         for k in source_weights:
             if isinstance(source_weights[k], torch.Tensor):
                 print(f" {k}: {source_weights[k].shape}")
@@ -154,6 +175,15 @@ def translate_weights(
         if key in new_state_dict:
             new_state_dict[key] = new_state_dict[key].reshape(new_shape)
 
+    # check for nan values
+    for k in list(new_state_dict.keys()):
+        v = new_state_dict[k]
+        nan_count = torch.sum(torch.isnan(v)).item()
+        if nan_count:
+            logger.warning(
+                f"Found {nan_count} nan values in {k} of converted state dict."
+            )
+
     return new_state_dict
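
As a quick illustration of the corrupt-weights check this patch introduces, a minimal
standalone sketch (the weights filename below is hypothetical) that scans a safetensors
file for NaN values using the same safe_open/isnan approach as check_nan_path and
translate_weights above:

    import torch
    from safetensors import safe_open

    # Scan every tensor in a safetensors file and report NaN counts,
    # mirroring the warning added in translate_weights().
    weights_path = "diffusion_pytorch_model.fp16.safetensors"  # hypothetical file
    with safe_open(weights_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            nan_count = torch.sum(torch.isnan(f.get_tensor(key))).item()
            if nan_count:
                print(f"{key}: {nan_count} nan values; weights may need re-download")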