add support for distributed data parallel training (#116)
* make code changes in `train_cifar10.py` to allow DDP (distributed data parallel)

* add instructions to README on how to run cifar10 image generation code on multiple GPUs

* fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting

* fix: load checkpoint on right device

* fix runner ci requirements (#125)

* change pytorch lightning version

* fix pip version

* fix pip in code cov

* change variable name `world_size` to `total_num_gpus`

* change: do not overwrite batch size flag

* add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps

* fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode

* rename file, update README

* add original CIFAR10 training file

---------

Co-authored-by: Alexander Tong <alexandertongdev@gmail.com>
ImahnShekhzadeh and atong01 authored Aug 21, 2024
1 parent b4525b5 commit c25e191
Showing 4 changed files with 247 additions and 5 deletions.
6 changes: 2 additions & 4 deletions examples/images/cifar10/README.md
@@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example:
Note that you can train all our methods in parallel on multiple GPUs with DistributedDataParallel. To do so, pass the number of GPUs to `torchrun`, set the parallel flag to True, and provide the master address and port on the command line. As an example:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
```
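
Under the hood, `torchrun` sets the `WORLD_SIZE` and `RANK` environment variables for each process it launches, and the training script reads them to initialize the process group. Below is a minimal sketch of that flow, mirroring the logic of `train_cifar10_ddp.py` and `utils_cifar.setup` in this commit and shown with the default address and port:

```python
# Illustrative sketch only; the real logic lives in train_cifar10_ddp.py and
# utils_cifar.setup further down in this commit.
import os

from torch import distributed as dist

world_size = int(os.getenv("WORLD_SIZE", 1))  # number of processes launched by torchrun
rank = int(os.getenv("RANK", 0))              # index of this process

if world_size > 1:
    # the script forwards its --master_addr / --master_port flags like this
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # each rank then builds its model on GPU `rank` and wraps it in
    # DistributedDataParallel(model, device_ids=[rank])
```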

*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory).

To compute the FID from the OT-CFM model at the end of training, run:

```bash
2 changes: 1 addition & 1 deletion examples/images/cifar10/compute_fid.py
@@ -51,7 +51,7 @@
# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
checkpoint = torch.load(PATH)
checkpoint = torch.load(PATH, map_location=device)
state_dict = checkpoint["ema_model"]
try:
new_net.load_state_dict(state_dict)
214 changes: 214 additions & 0 deletions examples/images/cifar10/train_cifar10_ddp.py
@@ -0,0 +1,214 @@
# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong
# Imahn Shekhzadeh

import copy
import math
import os

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup

from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher,
VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS

flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results/", help="output_directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
"total_steps", 400001, help="total training steps"
) # Lipman et al. use 400k steps but with double the batch size
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string(
"master_addr", "localhost", help="master address for Distributed Data Parallel"
)
flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")

# Evaluation
flags.DEFINE_integer(
"save_step",
20000,
help="frequency of saving checkpoints, 0 to disable during training",
)


def warmup_lr(step):
return min(step, FLAGS.warmup) / FLAGS.warmup


def train(rank, total_num_gpus, argv):
print(
"lr, total_steps, ema decay, save_step:",
FLAGS.lr,
FLAGS.total_steps,
FLAGS.ema_decay,
FLAGS.save_step,
)

if FLAGS.parallel and total_num_gpus > 1:
# When using `DistributedDataParallel`, we need to divide the batch
# size ourselves based on the total number of GPUs of the current node.
batch_size_per_gpu = FLAGS.batch_size // total_num_gpus
setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
else:
batch_size_per_gpu = FLAGS.batch_size

# DATASETS/DATALOADER
dataset = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
),
)
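    # In distributed mode, `DistributedSampler` gives each rank a disjoint shard of
    # the dataset and handles shuffling itself, so the DataLoader's own `shuffle`
    # must stay False whenever the sampler is used.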
sampler = DistributedSampler(dataset) if FLAGS.parallel else None
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size_per_gpu,
sampler=sampler,
shuffle=False if FLAGS.parallel else True,
num_workers=FLAGS.num_workers,
drop_last=True,
)

datalooper = infiniteloop(dataloader)

# Calculate number of epochs
steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)
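    # e.g., with the 50,000 CIFAR-10 training images and the default batch_size=128:
    # steps_per_epoch = ceil(50000 / 128) = 391 and num_epochs = ceil(400001 / 391) = 1024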

# MODELS
net_model = UNetModelWrapper(
dim=(3, 32, 32),
num_res_blocks=2,
num_channels=FLAGS.num_channel,
channel_mult=[1, 2, 2, 2],
num_heads=4,
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
).to(
rank
) # new dropout + bs of 128

ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
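    # Wrapping in DistributedDataParallel makes each rank average its gradients with
    # the others during backward(); `device_ids=[rank]` assumes one process per GPU.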
if FLAGS.parallel:
net_model = DistributedDataParallel(net_model, device_ids=[rank])
ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

# show model size
model_size = 0
for param in net_model.parameters():
model_size += param.data.nelement()
print("Model params: %.2f M" % (model_size / 1024 / 1024))

#################################
# OT-CFM
#################################

sigma = 0.0
if FLAGS.model == "otcfm":
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "icfm":
FM = ConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "fm":
FM = TargetConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "si":
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
else:
raise NotImplementedError(
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
)

savedir = FLAGS.output_dir + FLAGS.model + "/"
os.makedirs(savedir, exist_ok=True)

global_step = 0 # to keep track of the global step in training loop

with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
for epoch in epoch_pbar:
epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
if sampler is not None:
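                # Reseed the sampler's shuffling for this epoch; otherwise every
                # epoch would iterate the shards in the same order.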
sampler.set_epoch(epoch)

with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
for step in step_pbar:
                    global_step += 1  # one optimization step per iteration of the inner loop

optim.zero_grad()
x1 = next(datalooper).to(rank)
x0 = torch.randn_like(x1)
t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
vt = net_model(t, xt)
loss = torch.mean((vt - ut) ** 2)
loss.backward()
torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new
optim.step()
sched.step()
ema(net_model, ema_model, FLAGS.ema_decay) # new

                    # sample images and save the model weights
if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
generate_samples(
net_model, FLAGS.parallel, savedir, global_step, net_="normal"
)
generate_samples(
ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
)
torch.save(
{
"net_model": net_model.state_dict(),
"ema_model": ema_model.state_dict(),
"sched": sched.state_dict(),
"optim": optim.state_dict(),
"step": global_step,
},
savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
)


def main(argv):
# get world size (number of GPUs)
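    # (torchrun sets WORLD_SIZE and RANK for every process it launches; the defaults
    # below keep single-GPU/CPU runs working without torchrun)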
total_num_gpus = int(os.getenv("WORLD_SIZE", 1))

if FLAGS.parallel and total_num_gpus > 1:
train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv)
else:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train(rank=device, total_num_gpus=total_num_gpus, argv=argv)


if __name__ == "__main__":
app.run(main)
30 changes: 30 additions & 0 deletions examples/images/cifar10/utils_cifar.py
@@ -1,6 +1,8 @@
import copy
import os

import torch
from torch import distributed as dist
from torchdyn.core import NeuralODE

# from torchvision.transforms import ToPILImage
@@ -10,6 +12,34 @@
device = torch.device("cuda" if use_cuda else "cpu")


def setup(
rank: int,
total_num_gpus: int,
master_addr: str = "localhost",
master_port: str = "12355",
backend: str = "nccl",
):
"""Initialize the distributed environment.
Args:
rank: Rank of the current process.
total_num_gpus: Number of GPUs used in the job.
master_addr: IP address of the master node.
master_port: Port number of the master node.
backend: Backend to use.
"""

os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

# initialize the process group
dist.init_process_group(
backend=backend,
rank=rank,
world_size=total_num_gpus,
)


def generate_samples(model, parallel, savedir, step, net_="normal"):
"""Save 64 generated images (8 x 8) for sanity check along training.
