perf: improve memory usage
add warning for corrupt weights files
brycedrennan committed Dec 29, 2023
1 parent 26d1ff9 commit ed93abc
Showing 5 changed files with 46 additions and 10 deletions.
2 changes: 2 additions & 0 deletions imaginairy/api/generate_refiners.py
@@ -310,7 +310,9 @@ def latent_logger(latents):
condition_scale=prompt.prompt_strength,
**text_conditioning_kwargs,
)
# trying to clear memory. not sure if this helps
sd.unet.set_context(context="self_attention_map", value={})
sd.unet._reset_context()
clear_gpu_cache()

logger.debug("Decoding image")
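The block above calls a cache-clearing helper after denoising. As a point of reference, here is a minimal sketch of what such a helper typically does; the real `clear_gpu_cache` is defined elsewhere in imaginairy and may differ:

```python
# Minimal sketch of a GPU cache-clearing helper like the clear_gpu_cache()
# called above; the actual implementation lives elsewhere in imaginairy.
import gc

import torch


def clear_gpu_cache():
    gc.collect()  # drop unreferenced Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the allocator
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()  # equivalent for Apple-silicon backends
```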
2 changes: 1 addition & 1 deletion imaginairy/enhancers/upscale_realesrgan.py
@@ -12,7 +12,7 @@


@memory_managed_model("realesrgan_upsampler", memory_usage_mb=70)
def realesrgan_upsampler(tile=1024, tile_pad=50, ultrasharp=False):
def realesrgan_upsampler(tile=512, tile_pad=50, ultrasharp=False):
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)
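Halving the default tile size from 1024 to 512 is where the memory win comes from: per-tile activation memory grows with tile area, so a 512px tile needs roughly a quarter of the memory of a 1024px tile, traded against running four times as many tiles. Illustrative numbers, not measurements:

```python
# Illustrative only: per-tile work scales with tile area, which is why
# shrinking the default tile from 1024 to 512 lowers peak memory.
for tile in (1024, 512):
    print(f"tile={tile}: {tile * tile / 1e6:.2f} MPix per tile")
# tile=1024: 1.05 MPix per tile
# tile=512:  0.26 MPix per tile
```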
12 changes: 7 additions & 5 deletions imaginairy/utils/model_manager.py
@@ -674,7 +674,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
vae_weights_path = download_diffusers_weights(
base_url=base_url, sub="vae", prefer_fp16=False
)
print(vae_weights_path)
logger.debug(f"vae: {vae_weights_path}")
vae_weights = translator.load_and_translate_weights(
source_path=vae_weights_path,
device="cpu",
@@ -684,8 +684,10 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
del vae_weights

translator = translators.diffusers_unet_sdxl_to_refiners_translator()
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
print(unet_weights_path)
unet_weights_path = download_diffusers_weights(
base_url=base_url, sub="unet", prefer_fp16=True
)
logger.debug(f"unet: {unet_weights_path}")
unet_weights = translator.load_and_translate_weights(
source_path=unet_weights_path,
device="cpu",
@@ -700,8 +702,8 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
text_encoder_2_path = download_diffusers_weights(
base_url=base_url, sub="text_encoder_2"
)
print(text_encoder_1_path)
print(text_encoder_2_path)
logger.debug(f"text encoder 1: {text_encoder_1_path}")
logger.debug(f"text encoder 2: {text_encoder_2_path}")
text_encoder_weights = (
translators.DoubleTextEncoderTranslator().load_and_translate_weights(
text_encoder_l_weights_path=text_encoder_1_path,
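Since these paths now go through `logger.debug` instead of `print`, they only appear when debug logging is enabled, for example via the standard library (assuming the module's logger is created with `logging.getLogger(__name__)`, as translation.py below does):

```python
# Enable debug output to see the weight paths logged above (assumes the
# module uses logging.getLogger(__name__), as translation.py does).
import logging

logging.basicConfig(level=logging.DEBUG)
# or, more narrowly:
logging.getLogger("imaginairy.utils.model_manager").setLevel(logging.DEBUG)
```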
4 changes: 3 additions & 1 deletion imaginairy/vendored/realesrgan.py
@@ -1,3 +1,4 @@
import logging
import math
import os
import queue
@@ -12,6 +13,7 @@

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

logger = logging.getLogger(__name__)

class RealESRGANer:
"""A helper class for upsampling images with RealESRGAN.
@@ -146,7 +148,7 @@ def tile_process(self):
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)

logger.debug(f"Tiling with {tiles_x}x{tiles_y} ({tiles_x*tiles_y}) tiles")
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
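The new debug line reports the tile grid computed just above it; for example, with a hypothetical 2048x1536 input and the new 512px default:

```python
# Hypothetical example of the tile-count math logged above.
import math

width, height, tile_size = 2048, 1536, 512
tiles_x = math.ceil(width / tile_size)   # 4
tiles_y = math.ceil(height / tile_size)  # 3
print(f"Tiling with {tiles_x}x{tiles_y} ({tiles_x * tiles_y}) tiles")
```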
36 changes: 33 additions & 3 deletions imaginairy/weight_management/translation.py
@@ -4,7 +4,7 @@
from typing import Dict

import torch
from refiners.fluxion import load_from_safetensors
from safetensors import safe_open
from torch import device as Device

logger = logging.getLogger(__name__)
@@ -29,7 +29,8 @@ def load_and_translate_weights(
source_weights = torch.load(source_path, map_location="cpu")

elif extension in ["safetensors"]:
source_weights = load_from_safetensors(source_path, device=device)
with safe_open(source_path, framework="pt", device=device) as f: # type: ignore
source_weights = {k: f.get_tensor(k) for k in f.keys()} # noqa
else:
msg = f"Unsupported extension {extension}"
raise ValueError(msg)
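Replacing refiners' `load_from_safetensors` with `safetensors.safe_open` drops a dependency from this code path and loads lazily: `safe_open` memory-maps the file and only materializes a tensor when `get_tensor` is called. The same pattern in isolation (the file path is hypothetical):

```python
# Standalone version of the safe_open pattern used above; the file path
# is hypothetical. Tensors are materialized one at a time from the mmap.
from safetensors import safe_open

path = "weights.safetensors"  # hypothetical
with safe_open(path, framework="pt", device="cpu") as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}
```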
@@ -79,10 +80,30 @@ def load(cls, path):
return cls(**d)


def check_nan_path(path: str, device):
from safetensors import safe_open

with safe_open(path, framework="pt", device=device) as f: # type: ignore
for k in f.keys(): # noqa
if torch.any(torch.isnan(f.get_tensor(k))):
print(f"Found nan values in {k} of {path}")


def translate_weights(
source_weights: TensorDict, weight_map: WeightTranslationMap
) -> TensorDict:
new_state_dict: TensorDict = {}
# check source weights for nan
for k, v in source_weights.items():
nan_count = torch.sum(torch.isnan(v)).item()
if nan_count:
msg = (
f"Found {nan_count} nan values in {k} of source state dict."
" This could indicate the source weights are corrupted and "
"need to be re-downloaded. "
)
logger.warning(msg)

# print(f"Translating {len(source_weights)} weights")
# print(f"Using {len(weight_map.name_map)} name mappings")
# print(source_weights.keys())
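A quick way to see the new warning fire is a deliberately corrupted state dict (hypothetical demonstration, not repo code):

```python
# Hypothetical demonstration: a state dict with an injected NaN triggers
# the same count logic as the warning added above.
import torch

source_weights = {"layer.weight": torch.tensor([1.0, float("nan"), 3.0])}
for k, v in source_weights.items():
    nan_count = torch.sum(torch.isnan(v)).item()
    if nan_count:
        print(f"Found {nan_count} nan values in {k}")  # -> 1
```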
@@ -142,7 +163,7 @@ def translate_weights(

if source_weights:
msg = f"Unmapped keys: {list(source_weights.keys())}"
print(msg)
logger.info(msg)
for k in source_weights:
if isinstance(source_weights[k], torch.Tensor):
print(f" {k}: {source_weights[k].shape}")
@@ -154,6 +175,15 @@
if key in new_state_dict:
new_state_dict[key] = new_state_dict[key].reshape(new_shape)

# check for nan values
for k in list(new_state_dict.keys()):
v = new_state_dict[k]
nan_count = torch.sum(torch.isnan(v)).item()
if nan_count:
logger.warning(
f"Found {nan_count} nan values in {k} of converted state dict."
)

return new_state_dict


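Checking for NaNs both before and after translation is what makes the warning actionable: NaNs in the source dict point to a corrupt download, while NaNs that appear only in the converted dict point to a translation or dtype-conversion problem. A small helper (not in the repo) that mirrors both checks:

```python
# Helper (not in the repo) mirroring both NaN checks: run it on the source
# dict and again on the translated dict to localize where NaNs appear.
import torch


def count_nans(state_dict: dict) -> dict:
    return {
        k: int(torch.isnan(v).sum().item())
        for k, v in state_dict.items()
        if isinstance(v, torch.Tensor) and torch.isnan(v).any()
    }
```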
