k-diffusion-euler #1019

Merged 28 commits on Oct 31, 2022

Changes from 13 commits

Commits (28)
41e3bf2
k-diffusion-euler
hlky Oct 27, 2022
9032c0b
make style make quality
hlky Oct 27, 2022
5496649
make fix-copies
hlky Oct 27, 2022
6753644
fix tests for euler a
patil-suraj Oct 28, 2022
85ae890
Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
hlky Oct 28, 2022
bdc8334
Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
hlky Oct 28, 2022
003a61b
Update src/diffusers/schedulers/scheduling_euler_discrete.py
hlky Oct 28, 2022
8a3b4a3
Update src/diffusers/schedulers/scheduling_euler_discrete.py
hlky Oct 28, 2022
07a99c7
Merge branch 'k-diffusion-euler' of https://github.com/Sygil-Dev/diff…
patil-suraj Oct 28, 2022
44c0509
remove unused arg and method
patil-suraj Oct 28, 2022
0126944
update doc
patil-suraj Oct 28, 2022
deedc4e
quality
patil-suraj Oct 28, 2022
dde3f8d
make flake happy
patil-suraj Oct 28, 2022
2a33db8
use logger instead of warn
patil-suraj Oct 31, 2022
e1d2c88
raise error instead of deprication
patil-suraj Oct 31, 2022
51a855e
don't require scipy
patil-suraj Oct 31, 2022
18c9d9a
pass generator in step
patil-suraj Oct 31, 2022
66ee52e
fix tests
patil-suraj Oct 31, 2022
3198b77
Apply suggestions from code review
patil-suraj Oct 31, 2022
30db08a
Update tests/test_scheduler.py
patil-suraj Oct 31, 2022
9cf1cf0
remove unused generator
patil-suraj Oct 31, 2022
b1324ca
pass generator as extra_step_kwargs
patil-suraj Oct 31, 2022
c7fe0a0
update tests
patil-suraj Oct 31, 2022
c5e6aa5
pass generator as kwarg
patil-suraj Oct 31, 2022
d6daae7
pass generator as kwarg
patil-suraj Oct 31, 2022
5993631
quality
patil-suraj Oct 31, 2022
6d484c3
fix test for lms
patil-suraj Oct 31, 2022
207a5d2
fix tests
patil-suraj Oct 31, 2022

src/diffusers/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -52,7 +52,7 @@
from .utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_scipy_available():
-    from .schedulers import LMSDiscreteScheduler
+    from .schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler
else:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403

@@ -9,7 +9,13 @@
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import (
+    DDIMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -52,7 +58,9 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+        ],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):

@@ -10,7 +10,13 @@
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import (
+    DDIMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -63,7 +69,9 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+        ],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
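
With the widened `scheduler` annotation above, dropping one of the new schedulers into an existing pipeline becomes a one-line swap. A minimal sketch, assuming a Stable Diffusion v1 checkpoint is available (the model id is illustrative; the beta values are the ones v1 checkpoints were trained with):

from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# replace the pipeline's default scheduler with the new ancestral Euler scheduler
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
image = pipe("an astronaut riding a horse").images[0]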

src/diffusers/schedulers/__init__.py (2 changes: 2 additions & 0 deletions)

@@ -41,6 +41,8 @@


if is_scipy_available() and is_torch_available():
+    from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+    from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (new file: 256 additions)

@@ -0,0 +1,256 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.

"""

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

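        # sigma_t = ((1 - alpha_bar_t) / alpha_bar_t) ** 0.5, reversed so that noise decreases
        # over the sampling loop, with a final sigma of 0.0 appended for the last step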
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)

# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()

# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False

def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

Args:
sample (`torch.FloatTensor`): input sample
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

Returns:
`torch.FloatTensor`: scaled input sample
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
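        # the model was trained on x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * noise, which in
        # the sigma parametrization equals (x_0 + sigma * noise) / (sigma**2 + 1) ** 0.5, hence this scaling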
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps

timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
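        # interpolate the training-time sigmas at these (generally fractional) inference timesteps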
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)

def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).

Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`float`): current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

"""
if not self.is_scale_input_called:
warnings.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
deprecate(
"timestep as an index",
"0.8.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerAncestralDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
standard_warn=False,
)
step_index = timestep
else:
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
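
        # split the ancestral step into a deterministic part (sigma_down) and a stochastic part
        # (sigma_up), chosen so that sigma_down**2 + sigma_up**2 == sigma_to**2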
sigma_from = self.sigmas[step_index]
sigma_to = self.sigmas[step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma

dt = sigma_down - sigma

prev_sample = sample + derivative * dt

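        # ancestral sampling: inject fresh noise with standard deviation sigma_up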
prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up

if not return_dict:
return (prev_sample,)

return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample
)

def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

schedule_timesteps = self.timesteps

        if isinstance(timesteps, (torch.IntTensor, torch.LongTensor)):
deprecate(
"timesteps as indices",
"0.8.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
" pass values from `scheduler.timesteps` as timesteps.",
standard_warn=False,
)
step_indices = timesteps
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

noisy_samples = original_samples + noise * sigma
return noisy_samples

def __len__(self):
return self.config.num_train_timesteps
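
For reference, a minimal sketch of the denoising loop this scheduler expects at this revision (the later commits in this PR additionally thread a `generator` through `step`). The random `noise_pred` is a stand-in for a real epsilon-predicting UNet, and the latent shape is illustrative:

import torch

from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()  # default linear beta schedule
scheduler.set_timesteps(num_inference_steps=50)

# start from pure noise scaled by the scheduler's initial sigma (sigma_max)
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)  # divide by (sigma**2 + 1) ** 0.5
    noise_pred = torch.randn_like(model_input)  # stand-in for unet(model_input, t, ...) output
    sample = scheduler.step(noise_pred, t, sample).prev_sample  # one ancestral Euler step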