Skip to content

Commit

Permalink
fix noise scheduler error in stable diffusion (#2171)
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli authored Nov 20, 2023
1 parent 40ae4ae commit 862dcb9
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion keras_cv/models/stable_diffusion/noise_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def add_noise(
sqrt_one_minus_alpha_prod = ops.expand_dims(
sqrt_one_minus_alpha_prod, axis=-1
)

sqrt_alpha_prod = ops.cast(
sqrt_alpha_prod, dtype=original_samples.dtype
)
sqrt_one_minus_alpha_prod = ops.cast(
sqrt_one_minus_alpha_prod, dtype=noise.dtype
)
noisy_samples = (
sqrt_alpha_prod * original_samples
+ sqrt_one_minus_alpha_prod * noise
Expand Down

0 comments on commit 862dcb9

Please sign in to comment.