Skip to content

Commit

Permalink
Support for DDIMScheduler in Diffusion Policy (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
kashyapakshay authored May 8, 2024
1 parent f5de57b commit 460df2c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
8 changes: 8 additions & 0 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
Expand Down Expand Up @@ -110,6 +111,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
Expand Down Expand Up @@ -144,3 +146,9 @@ def __post_init__(self):
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)
18 changes: 16 additions & 2 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
Expand Down Expand Up @@ -126,6 +127,19 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
return {"loss": loss}


def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")


class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
Expand All @@ -138,12 +152,12 @@ def __init__(self, config: DiffusionConfig):
* config.n_obs_steps,
)

self.noise_scheduler = DDPMScheduler(
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
Expand Down
1 change: 1 addition & 0 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ policy:
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
Expand Down

0 comments on commit 460df2c

Please sign in to comment.