
Update _get_prev_sample function in PNDMScheduler to be better supported by fx #6878

Closed · wants to merge 5 commits
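For context on the changes below: `torch.fx` symbolic tracing records tensor operations into a graph, but it cannot follow Python control flow or Python-number indexing that depends on tensor values, which is exactly the pattern these schedulers used. A minimal sketch of the failure mode and of the branchless rewrite this PR applies (illustrative only; `Branchy` and `Branchless` are made-up names, not code from this PR):

```python
import torch
import torch.fx


class Branchy(torch.nn.Module):
    def forward(self, x):
        # Data-dependent Python branch: under symbolic_trace, `x.sum() > 0`
        # is a Proxy, not a concrete bool, so tracing raises a TraceError.
        if x.sum() > 0:
            return x + 1
        return x - 1


class Branchless(torch.nn.Module):
    def forward(self, x):
        # The same selection as a tensor op: records cleanly into the graph.
        return torch.where(x.sum() > 0, x + 1, x - 1)


try:
    torch.fx.symbolic_trace(Branchy())
except torch.fx.proxy.TraceError as exc:
    print(f"tracing failed as expected: {exc}")

graph_module = torch.fx.symbolic_trace(Branchless())
print(graph_module.graph)
```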
22 changes: 15 additions & 7 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
```diff
@@ -294,7 +294,7 @@ def scale_model_input(
         if self.step_index is None:
             self._init_step_index(timestep)

-        sigma = self.sigmas[self.step_index]
+        sigma = self.sigmas.index_select(0, self.step_index)
         sample = sample / ((sigma**2 + 1) ** 0.5)

         self.is_scale_input_called = True
```
```diff
@@ -425,7 +425,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
         pos = 1 if len(indices) > 1 else 0

-        return indices[pos].item()
+        return torch.tensor(indices[pos].item())

     def _init_step_index(self, timestep):
         if self.begin_index is None:
```
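The `index_for_timestep` change makes the function return a 0-d tensor instead of a Python `int`, so the value stored as `step_index` is a tensor that `Tensor.index_select` can consume during tracing. A toy check of the indexing pattern (the names `sigmas` and `step_index` here are illustrative, not the scheduler's attributes):

```python
import torch

sigmas = torch.linspace(1.0, 0.0, 5)       # toy stand-in for self.sigmas

step_index = torch.tensor([2])             # tensor index, fx-friendly
print(sigmas.index_select(0, step_index))  # tensor([0.5000])
print(sigmas[2])                           # tensor(0.5000), same value via plain indexing
```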
```diff
@@ -500,9 +500,15 @@ def step(
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)

-        sigma = self.sigmas[self.step_index]
+        #sigma = self.sigmas[self.step_index]
+        sigma = self.sigmas.index_select(0, self.step_index)

-        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+        #gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+        condition = s_tmin <= sigma
+        condition1 = sigma <= s_tmax
+        gamma = torch.where(condition & condition1,
+                            torch.minimum(torch.tensor(s_churn / (len(self.sigmas) - 1)), torch.tensor(2**0.5 - 1)),
+                            torch.tensor(0.0))

         noise = randn_tensor(
             model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
```
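As a sanity check on the `gamma` rewrite in the hunk above, the two formulations agree both inside and outside the `[s_tmin, s_tmax]` band. A hedged, self-contained comparison (constants are arbitrary, not taken from the scheduler):

```python
import torch

s_churn, n_sigmas = 2.0, 11  # arbitrary illustrative values


def gamma_eager(sigma, s_tmin, s_tmax):
    # Original pattern: Python conditional, opaque to fx tracing.
    return min(s_churn / (n_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0


def gamma_branchless(sigma, s_tmin, s_tmax):
    # PR pattern: the same selection via torch.where / torch.minimum.
    return torch.where(
        (s_tmin <= sigma) & (sigma <= s_tmax),
        torch.minimum(torch.tensor(s_churn / (n_sigmas - 1)), torch.tensor(2**0.5 - 1)),
        torch.tensor(0.0),
    )


sigma = torch.tensor(0.5)
for bounds in [(0.0, 1.0), (0.6, 1.0)]:  # sigma inside, then outside the band
    eager = torch.tensor(float(gamma_eager(sigma, *bounds)))
    assert torch.isclose(eager, gamma_branchless(sigma, *bounds))
```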
```diff
@@ -511,8 +517,9 @@
         eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)

-        if gamma > 0:
-            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+        #if gamma > 0:
+        #    sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+        sample = torch.where(gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample)

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         # NOTE: "original_sample" should not be an expected prediction_type but is left in for
```
```diff
@@ -532,7 +539,8 @@
         # 2. Convert to an ODE derivative
         derivative = (sample - pred_original_sample) / sigma_hat

-        dt = self.sigmas[self.step_index + 1] - sigma_hat
+        #dt = self.sigmas[self.step_index + 1] - sigma_hat
+        dt = self.sigmas.index_select(0, self.step_index + 1) - sigma_hat

         prev_sample = sample + derivative * dt
```
13 changes: 11 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm.py
```diff
@@ -417,8 +417,17 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        if not isinstance(timestep, torch.Tensor):
+            timestep = torch.tensor(timestep)
+        if not isinstance(prev_timestep, torch.Tensor):
+            prev_timestep = torch.tensor(prev_timestep)
+        alpha_prod_t = self.alphas_cumprod.index_select(0, timestep)
+        updated_prev_timestep = torch.where(
+            prev_timestep >= 0, prev_timestep, self.alphas_cumprod.size(dim=0) + prev_timestep
+        )
+        alpha_prod_t_prev = torch.where(
+            prev_timestep >= 0, self.alphas_cumprod.index_select(0, updated_prev_timestep), self.final_alpha_cumprod
+        )
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
```
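The two stacked `torch.where` calls are the subtle part of this diff: because `torch.where` evaluates both branches, `index_select` runs even when `prev_timestep < 0`, so the timestep is first wrapped into a valid index and the `final_alpha_cumprod` fallback is selected afterwards. A small standalone sketch of the same pattern (toy buffer values and an illustrative class name, not the scheduler itself):

```python
import torch
import torch.fx


class AlphaPrev(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Toy stand-ins for the scheduler's arrays (values are illustrative).
        self.register_buffer("alphas_cumprod", torch.linspace(0.99, 0.01, 10))
        self.register_buffer("final_alpha_cumprod", torch.tensor(1.0))

    def forward(self, prev_timestep):
        # Wrap a possibly-negative timestep into a valid index first: torch.where
        # evaluates both branches, so index_select must receive an in-range index
        # even when the fallback value is the one ultimately selected.
        updated = torch.where(
            prev_timestep >= 0, prev_timestep, self.alphas_cumprod.size(0) + prev_timestep
        )
        return torch.where(
            prev_timestep >= 0,
            self.alphas_cumprod.index_select(0, updated),
            self.final_alpha_cumprod,
        )


m = AlphaPrev()
print(m(torch.tensor([3])))              # alphas_cumprod[3]
print(m(torch.tensor([-1])))             # final_alpha_cumprod, via the fallback
print(torch.fx.symbolic_trace(m).graph)  # traces: no data-dependent Python branches
```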