fix: interrupted generations don't prevent more generations
fixes #424

- perf: improve memory usage when loading SD15.
- feature: clean up CLI output more
- feature: cuda memory tracking context manager
- feature: use safetensors fp16 for sd15
brycedrennan committed Jan 2, 2024
1 parent 1b85a38 commit ebec3b0
Showing 9 changed files with 291 additions and 31 deletions.
1 change: 1 addition & 0 deletions imaginairy/api/generate.py
@@ -206,6 +206,7 @@ def imagine(
                 progress_img_interval_min_s=progress_img_interval_min_s,
                 add_caption=add_caption,
                 dtype=torch.float16 if half_mode else torch.float32,
+                output_perf=True,
             )
             if not result.safety_score.is_filtered:
                 break
6 changes: 4 additions & 2 deletions imaginairy/api/generate_refiners.py
@@ -431,8 +431,10 @@ def latent_logger(latents):
     if result.performance_stats:
         log = logger.info if output_perf else logger.debug
         log(f"    Timings: {result.timings_str()}")
-        log(f"    Peak VRAM: {result.gpu_str('memory_peak')}")
-        log(f"    Ending VRAM: {result.gpu_str('memory_end')}")
+        if torch.cuda.is_available():
+            log(f"    Peak VRAM: {result.gpu_str('memory_peak')}")
+            log(f"    Peak VRAM Delta: {result.gpu_str('memory_peak_delta')}")
+            log(f"    Ending VRAM: {result.gpu_str('memory_end')}")
     for controlnet, _ in controlnets:
         controlnet.eject()
     clear_gpu_cache()
4 changes: 2 additions & 2 deletions imaginairy/cli/shared.py
@@ -111,9 +111,9 @@ def _imagine_cmd(
     total_image_count = len(prompt_texts) * max(len(init_images), 1) * repeats
     img_msg = ""
     if len(init_images) > 0:
-        img_msg = f" and {len(init_images)} image(s)"
+        img_msg = f" x {len(init_images)} image(s)"
     logger.info(
-        f"Received {len(prompt_texts)} prompt(s){img_msg}. Will repeat these {repeats} times to create {total_image_count} images.\n"
+        f"Generating {total_image_count} images. ({len(prompt_texts)} prompt(s){img_msg} x {repeats} repetitions)\n"
     )

     from imaginairy.api import imagine_image_files
2 changes: 1 addition & 1 deletion imaginairy/config.py
@@ -124,7 +124,7 @@ def __post_init__(self):
         aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
         architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
         defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
-        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/889b629140e71758e1e0006e355c331a5744b4bf/",
+        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/",
     ),
     ModelWeightsConfig(
         name="Stable Diffusion 1.5 - Inpainting",
4 changes: 2 additions & 2 deletions imaginairy/modules/refiners_sd.py
@@ -100,7 +100,7 @@ def __init__(
         lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
         scheduler: Scheduler | None = None,
-        device: Device | str = "cpu",
+        device: Device | str | None = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SD1UNet(in_channels=4)
@@ -124,7 +124,7 @@ def __init__(
         if dtype is not None:
             to_kwargs["dtype"] = dtype

-        self.device = device
+        self.device = device  # type: ignore
         self.dtype = dtype

         if to_kwargs:
38 changes: 24 additions & 14 deletions imaginairy/utils/log_utils.py
@@ -7,6 +7,10 @@
 import warnings
 from typing import Callable

+import torch.cuda
+
+from imaginairy.utils.memory_tracker import TorchRAMTracker
+
 _CURRENT_LOGGING_CONTEXT = None

 logger = logging.getLogger(__name__)
@@ -72,29 +76,34 @@ def __init__(
         self.end_time = None
         self.duration = 0

-        self.memory_start = 0
+        self.memory_context = None
+        self.memory_start = None
         self.memory_end = 0
         self.memory_peak = 0
+        self.memory_peak_delta = 0

     def start(self):
         # supports repeated calls to start/stop
         if self._device == "cuda":
-            import torch
-
-            torch.cuda.reset_peak_memory_stats()
-            self.memory_start = torch.cuda.memory_allocated()
+            self.memory_context = TorchRAMTracker(self.description)
+            self.memory_context.start()
+            if self.memory_start is None:
+                self.memory_start = self.memory_context.start_memory
         self.end_time = None
         self.start_time = time.time()

     def stop(self):
         # supports repeated calls to start/stop
         self.end_time = time.time()
         self.duration += self.end_time - self.start_time

         if self._device == "cuda":
-            import torch
+            self.memory_context.stop()

-            self.memory_end = torch.cuda.memory_allocated()
-            self.memory_peak = max(
-                torch.cuda.max_memory_allocated() - self.memory_start, self.memory_peak
+            self.memory_end = self.memory_context.end_memory
+            self.memory_peak = max(self.memory_context.peak_memory, self.memory_peak)
+            self.memory_peak_delta = max(
+                self.memory_context.peak_memory_delta, self.memory_peak_delta
             )

         if self.callback_fn is not None:
@@ -170,10 +179,11 @@ def get_performance_stats(self) -> dict[str, dict[str, float]]:
         self.summary_context.stop()
         self.timing_contexts["total"] = self.summary_context

-        self.summary_context.memory_peak = max(
-            max(context.memory_peak, context.memory_start, context.memory_end)
-            for context in self.timing_contexts.values()
-        )
+        if torch.cuda.is_available():
+            self.summary_context.memory_peak = max(
+                max(context.memory_peak, context.memory_start, context.memory_end)
+                for context in self.timing_contexts.values()
+            )

         performance_stats = {}
         for context in self.timing_contexts.values():
@@ -182,7 +192,7 @@ def get_performance_stats(self) -> dict[str, dict[str, float]]:
                 "memory_start": context.memory_start,
                 "memory_end": context.memory_end,
                 "memory_peak": context.memory_peak,
-                "memory_delta": context.memory_end - context.memory_start,
+                "memory_peak_delta": context.memory_peak_delta,
             }
         return performance_stats
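Note the pause/resume semantics the TimingContext changes preserve: memory_start is latched only on the first start() (hence the `if self.memory_start is None` guard), memory_end reflects the last stop(), and peaks fold across cycles with max(). A hypothetical mini-replay of that bookkeeping (illustrative only; not code from the commit):

class CycleStats:  # hypothetical stand-in for TimingContext's memory fields
    def __init__(self):
        self.memory_start = None
        self.memory_peak = 0
        self.memory_end = 0

    def record_cycle(self, start_mem, peak_mem, end_mem):
        if self.memory_start is None:  # first start() wins
            self.memory_start = start_mem
        self.memory_peak = max(peak_mem, self.memory_peak)  # fold peaks
        self.memory_end = end_mem  # last stop() wins

stats = CycleStats()
stats.record_cycle(start_mem=800, peak_mem=1500, end_mem=900)
stats.record_cycle(start_mem=900, peak_mem=1200, end_mem=850)  # resumed later
assert (stats.memory_start, stats.memory_peak, stats.memory_end) == (800, 1500, 850)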
43 changes: 43 additions & 0 deletions imaginairy/utils/memory_tracker.py
@@ -0,0 +1,43 @@
+import contextlib
+from typing import Callable, List
+
+import torch
+
+
+class TorchRAMTracker(contextlib.ContextDecorator):
+    """Tracks peak CUDA memory usage for a block of code."""
+
+    _memory_stack: List[int] = []
+    mem_interface = torch.cuda
+
+    def __init__(
+        self, name="", callback_fn: "Callable[[TorchRAMTracker], None] | None" = None
+    ):
+        self.name = name
+        self.peak_memory = 0
+        self.start_memory = 0
+        self.end_memory = 0
+        self.callback_fn = callback_fn
+        self._stack_depth = None
+
+    def start(self):
+        current_peak = self.mem_interface.max_memory_allocated()
+        TorchRAMTracker._memory_stack.append(current_peak)
+        self._stack_depth = len(TorchRAMTracker._memory_stack)
+        self.mem_interface.reset_peak_memory_stats()
+        self.start_memory = self.mem_interface.memory_allocated()
+
+    def stop(self):
+        end_peak = self.mem_interface.max_memory_allocated()
+        peaks = TorchRAMTracker._memory_stack[self._stack_depth :] + [end_peak]
+        self.peak_memory = max(peaks)
+        del TorchRAMTracker._memory_stack[self._stack_depth :]
+        self.end_memory = self.mem_interface.memory_allocated()
+        self.peak_memory_delta = self.peak_memory - self.start_memory
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, *exc):
+        self.stop()
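For reference, a hedged usage sketch of the new tracker (not part of the commit; assumes a CUDA device is present). TorchRAMTracker works as a context manager and, via contextlib.ContextDecorator, as a function decorator. The class-level _memory_stack is what makes nesting safe: each start() calls the global torch.cuda.reset_peak_memory_stats(), so an inner tracker would otherwise erase the peak an outer tracker needs; the stack records each peak observed so far, and stop() folds the stacked peaks back into the outer result.

import torch

from imaginairy.utils.memory_tracker import TorchRAMTracker

with TorchRAMTracker("outer") as outer:
    x = torch.empty(1024, 1024, device="cuda")      # ~4 MB held
    with TorchRAMTracker("inner") as inner:         # inner start() resets the global peak counter
        y = torch.empty(2048, 2048, device="cuda")  # ~16 MB, counted even after being freed
        del y
    del x

# inner saw only its own span; outer recovers the true peak via _memory_stack
print(inner.peak_memory, inner.peak_memory_delta)
print(outer.peak_memory, outer.peak_memory_delta)

@TorchRAMTracker("decorated")  # ContextDecorator also allows decorator form
def step():
    return torch.empty(512, 512, device="cuda")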
101 changes: 91 additions & 10 deletions imaginairy/utils/model_manager.py
@@ -16,6 +16,7 @@
     try_to_load_from_cache,
 )
 from omegaconf import OmegaConf
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion import DoubleTextEncoder, SD1UNet, SDXLUNet
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 from safetensors.torch import load_file
@@ -226,7 +227,7 @@ def get_diffusion_model_refiners(
         dtype=dtype,
     )
     # ensures a "fresh" copy that doesn't have additional injected parts
-    # sd = sd.structural_copy()
+    sd = sd.structural_copy()

     sd.set_self_attention_guidance(enable=True)

@@ -290,9 +291,19 @@ def _get_diffusion_model_refiners(
         raise ValueError(msg)

     MOST_RECENTLY_LOADED_MODEL = sd
+
+    msg = (
+        f"sd dtype:{sd.dtype} device:{sd.device}\n"
+        f"sd.unet dtype:{sd.unet.dtype} device:{sd.unet.device}\n"
+        f"sd.lda dtype:{sd.lda.dtype} device:{sd.lda.device}\n"
+        f"sd.clip_text_encoder dtype:{sd.clip_text_encoder.dtype} device:{sd.clip_text_encoder.device}\n"
+    )
+    logger.debug(msg)
+
     return sd


+# new
 def _get_sd15_diffusion_model_refiners(
     weights_location: str,
     for_inpainting: bool = False,
@@ -316,20 +327,19 @@ def _get_sd15_diffusion_model_refiners(
             vae_weights,
             unet_weights,
             text_encoder_weights,
-        ) = load_sd15_diffusers_weights(weights_location)
+        ) = load_sd15_diffusers_weights(weights_location, device="cpu")
     else:
         (
             vae_weights,
             unet_weights,
             text_encoder_weights,
         ) = load_stable_diffusion_compvis_weights(weights_location)
-
     StableDiffusionCls: type[LatentDiffusionModel]
     if for_inpainting:
-        unet = SD1UNet(in_channels=9, device=device, dtype=dtype)
+        unet = SD1UNet(in_channels=9, device="cpu", dtype=dtype)
         StableDiffusionCls = StableDiffusion_1_Inpainting
     else:
-        unet = SD1UNet(in_channels=4, device=device, dtype=dtype)
+        unet = SD1UNet(in_channels=4, device="cpu", dtype=dtype)
         StableDiffusionCls = StableDiffusion_1
     logger.debug(f"Using class {StableDiffusionCls.__name__}")
@@ -347,7 +357,70 @@ def _get_sd15_diffusion_model_refiners(

     logger.debug(f"'{weights_location}' Loaded")
     sd.to(device=device, dtype=dtype)
     return sd


+def _get_sd15_diffusion_model_refiners_new(
+    weights_location: str,
+    for_inpainting: bool = False,
+    device=None,
+    dtype=torch.float16,
+) -> LatentDiffusionModel:
+    """
+    Load a diffusion model.
+    Weights location may also be shortcut name, e.g. "SD-1.5"
+    """
+    from imaginairy.modules.refiners_sd import (
+        SD1AutoencoderSliced,
+        StableDiffusion_1,
+        StableDiffusion_1_Inpainting,
+    )
+
+    device = device or get_device()
+    if is_diffusers_repo_url(weights_location):
+        (
+            vae_weights,
+            unet_weights,
+            text_encoder_weights,
+        ) = load_sd15_diffusers_weights(weights_location, device="cpu")
+    else:
+        (
+            vae_weights,
+            unet_weights,
+            text_encoder_weights,
+        ) = load_stable_diffusion_compvis_weights(weights_location)
+    StableDiffusionCls: type[LatentDiffusionModel]
+    if for_inpainting:
+        unet = SD1UNet(in_channels=9, device="cpu", dtype=dtype)
+        StableDiffusionCls = StableDiffusion_1_Inpainting
+    else:
+        unet = SD1UNet(in_channels=4, device="cpu", dtype=dtype)
+        StableDiffusionCls = StableDiffusion_1
+
+    logger.debug("Loading UNet")
+    unet.load_state_dict(unet_weights, strict=False, assign=True)
+    del unet_weights
+    unet.to(device=device, dtype=dtype)
+
+    logger.debug("Loading VAE")
+    lda = SD1AutoencoderSliced(device=device, dtype=dtype)
+    lda.load_state_dict(vae_weights, assign=True)
+    del vae_weights
+    lda.to(device=device, dtype=dtype)
+
+    logger.debug("Loading text encoder")
+    clip_text_encoder = CLIPTextEncoderL()
+    clip_text_encoder.load_state_dict(text_encoder_weights, assign=True)
+    del text_encoder_weights
+    clip_text_encoder.to(device=device, dtype=dtype)
+
+    logger.debug(f"Using class {StableDiffusionCls.__name__}")
+
+    sd = StableDiffusionCls(device=None, dtype=dtype, lda=lda, unet=unet)  # type: ignore
+    sd.to(device=device, dtype=dtype)
+
+    logger.debug(f"'{weights_location}' Loaded")
+    return sd
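Two things keep peak memory low in the loader above (the "perf" bullet in the commit message): weights are adopted with load_state_dict(..., assign=True), available since PyTorch 2.1, which swaps the checkpoint tensors into the module instead of copying them into pre-allocated parameters; and each submodule is moved to the GPU separately, with its checkpoint dict deleted first, so only one full copy of each component exists at a time. A minimal sketch of that pattern, with a hypothetical TinyNet standing in for the UNet/VAE/text encoder:

import torch
from torch import nn

class TinyNet(nn.Module):  # hypothetical stand-in for SD1UNet etc.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

def load_cheaply(state_dict, device="cuda", dtype=torch.float16):
    model = TinyNet()  # parameters start on CPU
    # assign=True (torch >= 2.1): adopt the checkpoint tensors in place
    # rather than copying them into freshly allocated parameters.
    model.load_state_dict(state_dict, assign=True)
    state_dict.clear()  # drop checkpoint references before the device move
    return model.to(device=device, dtype=dtype)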


@@ -378,7 +451,7 @@ def load_controlnet(control_weights_location, half_mode):
         "model"
     ]["params"]["control_stage_config"]
     controlnet = instantiate_from_config(control_stage_config)
-    controlnet.load_state_dict(controlnet_state_dict)
+    controlnet.load_state_dict(controlnet_state_dict, assign=True)
     controlnet.to(get_device())
     return controlnet

@@ -663,7 +736,15 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
         source_format=FORMAT_NAMES.DIFFUSERS,
         dest_format=FORMAT_NAMES.REFINERS,
     )
-
+    first_vae = next(iter(vae_weights.values()))
+    first_unet = next(iter(unet_weights.values()))
+    first_encoder = next(iter(text_encoder_weights.values()))
+    msg = (
+        f"vae weights. dtype: {first_vae.dtype} device: {first_vae.device}\n"
+        f"unet weights. dtype: {first_unet.dtype} device: {first_unet.device}\n"
+        f"text_encoder weights. dtype: {first_encoder.dtype} device: {first_encoder.device}\n"
+    )
+    logger.debug(msg)
     return vae_weights, unet_weights, text_encoder_weights

@@ -684,7 +765,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16):
         device="cpu",
     )
     lda = SDXLAutoencoderSliced(device="cpu", dtype=dtype)
-    lda.load_state_dict(vae_weights)
+    lda.load_state_dict(vae_weights, assign=True)
     del vae_weights

     translator = translators.diffusers_unet_sdxl_to_refiners_translator()
@@ -697,7 +778,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16):
         device="cpu",
     )
     unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
-    unet.load_state_dict(unet_weights)
+    unet.load_state_dict(unet_weights, assign=True)
     del unet_weights

     text_encoder_1_path = download_diffusers_weights(
@@ -716,7 +797,7 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16):
         )
     )
     text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
-    text_encoder.load_state_dict(text_encoder_weights)
+    text_encoder.load_state_dict(text_encoder_weights, assign=True)
     del text_encoder_weights
     lda = lda.to(device=device, dtype=torch.float32)
     unet = unet.to(device=device)