diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2c531cf8cee0..49c3e82b8e7b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -41,6 +41,8 @@ from .schedulers import ( DDIMScheduler, DDPMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 098e8e6a85c9..e64f50b38993 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -9,7 +9,13 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import ( + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -52,7 +58,9 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -334,6 +342,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index db8134834564..799b1c19f6c1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -10,7 +10,13 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import ( + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -63,7 +69,9 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -335,6 +343,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d49f55aa921d..77bd55f2744d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -379,6 +379,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 71a2be62128c..55472f34cf68 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -352,6 +352,11 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index f179acde09e2..c3999d2cac61 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,6 +19,8 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler + from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler + from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 000000000000..b153dfbe2f02 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,261 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, deprecate, logging +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete +class EulerAncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", + ) + + if not self.is_scale_input_called: + logger.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + prev_sample = prev_sample + noise * sigma_up + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): + deprecate( + "timesteps as indices", + "0.8.0", + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to" + " pass values from `scheduler.timesteps` as timesteps.", + standard_warn=False, + ) + step_indices = timesteps + else: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py new file mode 100644 index 000000000000..a71fbda13df6 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,270 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, deprecate, logging +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete +class EulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", + ) + + if not self.is_scale_input_called: + logger.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + schedule_timesteps = self.timesteps + + if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): + deprecate( + "timesteps as indices", + "0.8.0", + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to" + " pass values from `scheduler.timesteps` as timesteps.", + standard_warn=False, + ) + step_indices = timesteps + else: + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 93834c76550b..5dd583279708 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -272,6 +272,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class EulerAncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class EulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 6b23aca33bc6..c21d40f4c5da 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -24,6 +24,8 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, StableDiffusionPipeline, @@ -361,6 +363,96 @@ def test_stable_diffusion_k_lms(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_k_euler_ancestral(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_attention_chunk(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 194fb66f663f..9285eed20ff0 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import tempfile import unittest from typing import Dict, List, Tuple @@ -22,6 +23,8 @@ from diffusers import ( DDIMScheduler, DDPMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, PNDMScheduler, @@ -77,7 +80,11 @@ def check_over_configs(self, time_step=0, **config): num_inference_steps = kwargs.pop("num_inference_steps", None) + # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + time_step = float(time_step) + sample = self.dummy_sample residual = 0.1 * sample @@ -94,7 +101,13 @@ def check_over_configs(self, time_step=0, **config): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -106,6 +119,9 @@ def check_over_forward(self, time_step=0, **forward_kwargs): num_inference_steps = kwargs.pop("num_inference_steps", None) for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + time_step = float(time_step) + sample = self.dummy_sample residual = 0.1 * sample @@ -122,9 +138,12 @@ def check_over_forward(self, time_step=0, **forward_kwargs): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - torch.manual_seed(0) + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample - torch.manual_seed(0) + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -141,6 +160,10 @@ def test_from_pretrained_save_pretrained(self): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + timestep = 1 + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) + with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) @@ -151,10 +174,13 @@ def test_from_pretrained_save_pretrained(self): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - torch.manual_seed(0) - output = scheduler.step(residual, 1, sample, **kwargs).prev_sample - torch.manual_seed(0) - new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample + + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) + new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -163,7 +189,14 @@ def test_step_shape(self): num_inference_steps = kwargs.pop("num_inference_steps", None) + timestep_0 = 0 + timestep_1 = 1 + for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep_0 = float(timestep_0) + timestep_1 = float(timestep_1) + scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -175,8 +208,8 @@ def test_step_shape(self): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample + output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -216,6 +249,9 @@ def recursive_check(tuple_object, dict_object): timestep = 1 for scheduler_class in self.scheduler_classes: + if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): + timestep = float(timestep) + scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -227,6 +263,9 @@ def recursive_check(tuple_object, dict_object): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) outputs_dict = scheduler.step(residual, timestep, sample, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): @@ -234,6 +273,9 @@ def recursive_check(tuple_object, dict_object): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler + if "generator" in set(inspect.signature(scheduler.step).parameters.keys()): + kwargs["generator"] = torch.Generator().manual_seed(0) outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs) recursive_check(outputs_tuple, outputs_dict) @@ -933,6 +975,117 @@ def test_full_loop_no_noise(self): assert abs(result_mean.item() - 1.31) < 1e-3 +class EulerDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (EulerDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + generator = torch.Generator().manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + print(result_sum, result_mean) + + assert abs(result_sum.item() - 10.0807) < 1e-2 + assert abs(result_mean.item() - 0.0131) < 1e-3 + + +class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (EulerAncestralDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + generator = torch.Generator().manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + print(result_sum, result_mean) + assert abs(result_sum.item() - 152.3192) < 1e-2 + assert abs(result_mean.item() - 0.1983) < 1e-3 + + class IPNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (IPNDMScheduler,) forward_default_kwargs = (("num_inference_steps", 50),)