Skip to content

Commit

Permalink
fix to work sample generation in fp8 ref kohya-ss#1057
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss authored and Disty0 committed Jan 28, 2024
1 parent 88e20eb commit f0a8988
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion library/sdxl_lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,11 @@ def __call__(
if up1 is not None:
uncond_pool = up1

dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)

# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
Expand Down Expand Up @@ -1028,6 +1032,7 @@ def __call__(
if is_cancelled_callback is not None and is_cancelled_callback():
return None

self.unet.to(unet_dtype)
return latents

def latents_to_image(self, latents):
Expand Down

0 comments on commit f0a8988

Please sign in to comment.