Skip to content

Commit

Permalink
Add magic to debug
Browse files Browse the repository at this point in the history
  • Loading branch information
StAlKeR7779 committed Aug 14, 2023
1 parent 409e5d0 commit 511da59
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
scheduler: SchedulerMixin
noise: torch.Tensor
noise: Optional[torch.Tensor]

def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.
Expand Down Expand Up @@ -124,7 +124,10 @@ def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
t = einops.repeat(t, "-> batch", batch=batch_size)
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
# get very confused about what is happening from step to step when we do that.
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
if self.noise is not None:
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
else:
mask_latents = self.mask_latents.clone()
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
Expand Down Expand Up @@ -368,19 +371,21 @@ def latents_from_embeddings(
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
else:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)

latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
# TODO: debug better with or without Oo
if False:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)

latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)

additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))

Expand Down

0 comments on commit 511da59

Please sign in to comment.