
Diffusion model using TensorFlow | Progressive Distillation


ru5h16h/diffusion


Diffusion Model using TensorFlow

This repository provides a TensorFlow 2 implementation of diffusion models for unconditional image generation, using the following methods: DDPM, DDIM, and progressive distillation.
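For reference, the standard equations behind these methods are summarized below in the notation of the original DDPM, DDIM and progressive-distillation papers. This is a reference sketch only. The repository's own parametrization (for example ε- versus x-prediction, or the loss weighting) may differ.

```latex
% Forward (noising) process, with \bar\alpha_t the cumulative product of (1 - \beta_i)
q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1 - \bar\alpha_t)\,I\big),
\qquad x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, I)

% DDPM "simple" training objective (noise prediction)
\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0,\,t,\,\epsilon}\big\lVert \epsilon - \epsilon_\theta(x_t, t) \big\rVert^2

% Deterministic DDIM update (\eta = 0) from step t to an earlier step s
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad x_s = \sqrt{\bar\alpha_s}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_s}\,\epsilon_\theta(x_t, t)

% Progressive distillation, writing \alpha_t = \sqrt{\bar\alpha_t} and \sigma_t = \sqrt{1 - \bar\alpha_t}:
% the teacher takes two DDIM steps t \to t' \to t'' and the student is trained to predict
\tilde{x}_0 = \frac{x_{t''} - (\sigma_{t''} / \sigma_t)\, x_t}{\alpha_{t''} - (\sigma_{t''} / \sigma_t)\, \alpha_t}
```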

Set-up

Create a virtual environment:

python3 -m venv venv

Activate the virtual environment:

source venv/bin/activate

Install the requirements:

pip install -r requirements.txt

File Structure

  • dataset/: Contains modules for loading and preprocessing training data.
  • difussion/: Modules for the diffusion process (e.g., forward process, reverse step using DDIM/DDPM, variance scheduling).
  • model/: Helper modules for the UNet architecture.
  • train.py: Script for training the reverse process.
  • train_distill.py: Script for distilling the original model using progressive distillation.
  • infer.py: Script for running inference with a pre-trained model.
  • utils.py: Module containing utility functions.
  • configs.py: Helper module for loading configurations.
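The forward process and variance scheduling described above for difussion/ can be sketched in plain Python. This is an illustrative, framework-free sketch and not the repository's TensorFlow code. The cosine schedule of Nichol & Dhariwal (2021) is an assumption, and the repository may use a different one. The helper names `cosine_alpha_bars` and `q_sample` are hypothetical.

```python
import math
import random
from typing import List


def cosine_alpha_bars(num_steps: int, s: float = 0.008) -> List[float]:
    """Cumulative signal levels alpha_bar_t for t = 0..num_steps.

    Assumption: the cosine schedule of Nichol & Dhariwal (2021); the
    repository's variance scheduler may use another schedule.
    """
    def f(t: int) -> float:
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    f0 = f(0)
    return [f(t) / f0 for t in range(num_steps + 1)]  # alpha_bar_0 = 1.0 (no noise)


def q_sample(x0: List[float], alpha_bar_t: float, noise: List[float]) -> List[float]:
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * e for x, e in zip(x0, noise)]


alpha_bars = cosine_alpha_bars(1000)
x0 = [0.5, -0.25, 1.0]                        # a toy "image" of three pixels
noise = [random.gauss(0.0, 1.0) for _ in x0]  # epsilon ~ N(0, I)
x_t = q_sample(x0, alpha_bars[500], noise)    # halfway through the noising process
```

The signal level decreases monotonically from 1 (clean image) towards 0 (pure noise), so `q_sample` can jump directly to any timestep without simulating the intermediate steps.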

Training and Inference

The configurations for the diffusion process, dataset, UNet model, and training/inference settings are defined in configs.yaml. After updating them, train the original (teacher) model with:

python train.py
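Training the teacher is typically done by minimizing the DDPM noise-prediction loss. The sketch below shows one loss evaluation without any framework. It assumes ε-prediction and DDPM's linear β schedule (1e-4 to 0.02 over 1000 steps). `diffusion_loss` and the `model(x_t, t)` signature are hypothetical and do not describe train.py's actual interface.

```python
import math
import operator
import random
from itertools import accumulate
from typing import Callable, List

# Hypothetical model signature: (noisy pixels, integer timestep) -> predicted noise.
EpsModel = Callable[[List[float], int], List[float]]

# DDPM's linear beta schedule; alpha_bars[0] = 1.0 means "no noise yet".
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
alpha_bars = [1.0] + list(accumulate((1.0 - b for b in betas), operator.mul))


def diffusion_loss(model: EpsModel, x0: List[float], alpha_bars: List[float],
                   rng: random.Random) -> float:
    """Mean-squared error between true and predicted noise at a random timestep."""
    t = rng.randint(1, len(alpha_bars) - 1)           # uniform timestep in [1, T]
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    a, b = math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])
    x_t = [a * x + b * e for x, e in zip(x0, noise)]  # sample from q(x_t | x_0)
    eps_pred = model(x_t, t)
    return sum((p - e) ** 2 for p, e in zip(eps_pred, noise)) / len(noise)
```

A TensorFlow training loop would perform the same computation on batched tensors inside a `tf.GradientTape` and apply the resulting gradients to the UNet's weights.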

To distill a student model, set train_cfg.teacher_checkpoint in configs.yaml to the teacher model's checkpoint, then run:

python train_distill.py
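In progressive distillation (Salimans & Ho, 2022), each round trains a student to reproduce two deterministic DDIM steps of the teacher in a single step, which halves the number of sampling steps. The sketch below shows how the student's regression target is formed, in plain Python. It assumes x-prediction and the schedule α_t = cos(πt/2), σ_t = sin(πt/2) with continuous t in [0, 1]. `teacher_x_pred` is a hypothetical stand-in for the trained teacher, and train_distill.py may organize this differently.

```python
import math
from typing import List, Tuple


def alpha_sigma(t: float) -> Tuple[float, float]:
    """Assumed variance-preserving cosine schedule: alpha_t^2 + sigma_t^2 = 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def ddim_step(z_t, x_pred, alpha_t, sigma_t, alpha_s, sigma_s) -> List[float]:
    """Deterministic DDIM update from time t to an earlier time s, given an x-prediction."""
    eps_pred = [(z - alpha_t * x) / sigma_t for z, x in zip(z_t, x_pred)]
    return [alpha_s * x + sigma_s * e for x, e in zip(x_pred, eps_pred)]


def distillation_target(z_t, z_t2, alpha_t, sigma_t, alpha_t2, sigma_t2) -> List[float]:
    """x-target for which a single DDIM step t -> t'' from z_t lands exactly on z_t2."""
    ratio = sigma_t2 / sigma_t
    return [(z2 - ratio * z) / (alpha_t2 - ratio * alpha_t) for z, z2 in zip(z_t, z_t2)]


def teacher_x_pred(z: List[float], t: float) -> List[float]:
    """Hypothetical teacher; in practice this would be the trained UNet."""
    return [0.5 * v for v in z]


N = 2                                   # student sampling steps; the teacher uses 2N
t = 0.75
t1, t2 = t - 0.5 / N, t - 1.0 / N       # t' and t''
z_t = [0.2, -1.1, 0.4]                  # noisy input at time t

# Two teacher DDIM steps: t -> t' -> t''.
z_t1 = ddim_step(z_t, teacher_x_pred(z_t, t), *alpha_sigma(t), *alpha_sigma(t1))
z_t2 = ddim_step(z_t1, teacher_x_pred(z_t1, t1), *alpha_sigma(t1), *alpha_sigma(t2))

# The student is trained to predict this target from (z_t, t).
target = distillation_target(z_t, z_t2, *alpha_sigma(t), *alpha_sigma(t2))
```

In the paper, the converged student becomes the teacher for the next round. Repeated distillation runs would therefore presumably point train_cfg.teacher_checkpoint at the previous round's student each time.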

For inference, set train_cfg.checkpoint in configs.yaml to the checkpoint of the original or distilled model, then run:

python infer.py
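Sampling with either model amounts to running the reverse process over a chosen number of steps, and a distilled student is meant to need fewer steps than its teacher. The sketch below implements deterministic DDIM sampling with evenly strided timesteps, without any framework. It assumes ε-prediction and DDPM's linear β schedule. `ddim_sample` and the `eps_model(x, t)` signature are hypothetical and do not describe infer.py's API.

```python
import math
import operator
import random
from itertools import accumulate
from typing import Callable, List

# Linear beta schedule (DDPM defaults); alpha_bars[0] = 1.0 is the clean image.
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
alpha_bars = [1.0] + list(accumulate((1.0 - b for b in betas), operator.mul))


def ddim_sample(eps_model: Callable[[List[float], int], List[float]],
                num_pixels: int, alpha_bars: List[float], num_steps: int,
                rng: random.Random) -> List[float]:
    """Deterministic DDIM (eta = 0) from pure noise at t = T down to t = 0."""
    T = len(alpha_bars) - 1
    # Evenly strided timesteps, e.g. [1000, 750, 500, 250, 0] for num_steps = 4.
    ts = [round(T * (num_steps - i) / num_steps) for i in range(num_steps + 1)]
    x = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]
    for t, s in zip(ts, ts[1:]):
        eps = eps_model(x, t)
        a_t, a_s = alpha_bars[t], alpha_bars[s]
        # Predict the clean image, then re-noise it to the earlier level s.
        x0_hat = [(xi - math.sqrt(1.0 - a_t) * e) / math.sqrt(a_t) for xi, e in zip(x, eps)]
        x = [math.sqrt(a_s) * x0 + math.sqrt(1.0 - a_s) * e for x0, e in zip(x0_hat, eps)]
    return x
```

Each step costs one network evaluation, so halving `num_steps` halves sampling time. Each distillation round aims to preserve sample quality at the reduced step count.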
