-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for distributed data parallel training (#116)
* make code changes in `train_cifar10.py` to allow DDP (distributed data parallel) * add instructions to README on how to run cifar10 image generation code on multiple GPUs * fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting * fix: load checkpoint on right device * fix runner ci requirements (#125) * change pytorch lightning version * fix pip version * fix pip in code cov * change variable name `world_size` to `total_num_gpus` * change: do not overwrite batch size flag * add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps * fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode * rename file, update README * add original CIFAR10 training file --------- Co-authored-by: Alexander Tong <alexandertongdev@gmail.com>
- Loading branch information
1 parent
b4525b5
commit c25e191
Showing
4 changed files
with
247 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master. | ||
|
||
# Authors: Kilian Fatras | ||
# Alexander Tong | ||
# Imahn Shekhzadeh | ||
|
||
import copy | ||
import math | ||
import os | ||
|
||
import torch | ||
from absl import app, flags | ||
from torch.nn.parallel import DistributedDataParallel | ||
from torch.utils.data import DistributedSampler | ||
from torchdyn.core import NeuralODE | ||
from torchvision import datasets, transforms | ||
from tqdm import trange | ||
from utils_cifar import ema, generate_samples, infiniteloop, setup | ||
|
||
from torchcfm.conditional_flow_matching import ( | ||
ConditionalFlowMatcher, | ||
ExactOptimalTransportConditionalFlowMatcher, | ||
TargetConditionalFlowMatcher, | ||
VariancePreservingConditionalFlowMatcher, | ||
) | ||
from torchcfm.models.unet.unet import UNetModelWrapper | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string("model", "otcfm", help="flow matching model type") | ||
flags.DEFINE_string("output_dir", "./results/", help="output_directory") | ||
# UNet | ||
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") | ||
|
||
# Training | ||
flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4 | ||
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping") | ||
flags.DEFINE_integer( | ||
"total_steps", 400001, help="total training steps" | ||
) # Lipman et al uses 400k but double batch size | ||
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup") | ||
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128 | ||
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") | ||
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate") | ||
flags.DEFINE_bool("parallel", False, help="multi gpu training") | ||
flags.DEFINE_string( | ||
"master_addr", "localhost", help="master address for Distributed Data Parallel" | ||
) | ||
flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel") | ||
|
||
# Evaluation | ||
flags.DEFINE_integer( | ||
"save_step", | ||
20000, | ||
help="frequency of saving checkpoints, 0 to disable during training", | ||
) | ||
|
||
|
||
def warmup_lr(step): | ||
return min(step, FLAGS.warmup) / FLAGS.warmup | ||
|
||
|
||
def train(rank, total_num_gpus, argv): | ||
print( | ||
"lr, total_steps, ema decay, save_step:", | ||
FLAGS.lr, | ||
FLAGS.total_steps, | ||
FLAGS.ema_decay, | ||
FLAGS.save_step, | ||
) | ||
|
||
if FLAGS.parallel and total_num_gpus > 1: | ||
# When using `DistributedDataParallel`, we need to divide the batch | ||
# size ourselves based on the total number of GPUs of the current node. | ||
batch_size_per_gpu = FLAGS.batch_size // total_num_gpus | ||
setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port) | ||
else: | ||
batch_size_per_gpu = FLAGS.batch_size | ||
|
||
# DATASETS/DATALOADER | ||
dataset = datasets.CIFAR10( | ||
root="./data", | ||
train=True, | ||
download=True, | ||
transform=transforms.Compose( | ||
[ | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
] | ||
), | ||
) | ||
sampler = DistributedSampler(dataset) if FLAGS.parallel else None | ||
dataloader = torch.utils.data.DataLoader( | ||
dataset, | ||
batch_size=batch_size_per_gpu, | ||
sampler=sampler, | ||
shuffle=False if FLAGS.parallel else True, | ||
num_workers=FLAGS.num_workers, | ||
drop_last=True, | ||
) | ||
|
||
datalooper = infiniteloop(dataloader) | ||
|
||
# Calculate number of epochs | ||
steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size) | ||
num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch) | ||
|
||
# MODELS | ||
net_model = UNetModelWrapper( | ||
dim=(3, 32, 32), | ||
num_res_blocks=2, | ||
num_channels=FLAGS.num_channel, | ||
channel_mult=[1, 2, 2, 2], | ||
num_heads=4, | ||
num_head_channels=64, | ||
attention_resolutions="16", | ||
dropout=0.1, | ||
).to( | ||
rank | ||
) # new dropout + bs of 128 | ||
|
||
ema_model = copy.deepcopy(net_model) | ||
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) | ||
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) | ||
if FLAGS.parallel: | ||
net_model = DistributedDataParallel(net_model, device_ids=[rank]) | ||
ema_model = DistributedDataParallel(ema_model, device_ids=[rank]) | ||
|
||
# show model size | ||
model_size = 0 | ||
for param in net_model.parameters(): | ||
model_size += param.data.nelement() | ||
print("Model params: %.2f M" % (model_size / 1024 / 1024)) | ||
|
||
################################# | ||
# OT-CFM | ||
################################# | ||
|
||
sigma = 0.0 | ||
if FLAGS.model == "otcfm": | ||
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) | ||
elif FLAGS.model == "icfm": | ||
FM = ConditionalFlowMatcher(sigma=sigma) | ||
elif FLAGS.model == "fm": | ||
FM = TargetConditionalFlowMatcher(sigma=sigma) | ||
elif FLAGS.model == "si": | ||
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma) | ||
else: | ||
raise NotImplementedError( | ||
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']" | ||
) | ||
|
||
savedir = FLAGS.output_dir + FLAGS.model + "/" | ||
os.makedirs(savedir, exist_ok=True) | ||
|
||
global_step = 0 # to keep track of the global step in training loop | ||
|
||
with trange(num_epochs, dynamic_ncols=True) as epoch_pbar: | ||
for epoch in epoch_pbar: | ||
epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}") | ||
if sampler is not None: | ||
sampler.set_epoch(epoch) | ||
|
||
with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar: | ||
for step in step_pbar: | ||
global_step += step | ||
|
||
optim.zero_grad() | ||
x1 = next(datalooper).to(rank) | ||
x0 = torch.randn_like(x1) | ||
t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) | ||
vt = net_model(t, xt) | ||
loss = torch.mean((vt - ut) ** 2) | ||
loss.backward() | ||
torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new | ||
optim.step() | ||
sched.step() | ||
ema(net_model, ema_model, FLAGS.ema_decay) # new | ||
|
||
# sample and Saving the weights | ||
if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0: | ||
generate_samples( | ||
net_model, FLAGS.parallel, savedir, global_step, net_="normal" | ||
) | ||
generate_samples( | ||
ema_model, FLAGS.parallel, savedir, global_step, net_="ema" | ||
) | ||
torch.save( | ||
{ | ||
"net_model": net_model.state_dict(), | ||
"ema_model": ema_model.state_dict(), | ||
"sched": sched.state_dict(), | ||
"optim": optim.state_dict(), | ||
"step": global_step, | ||
}, | ||
savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt", | ||
) | ||
|
||
|
||
def main(argv): | ||
# get world size (number of GPUs) | ||
total_num_gpus = int(os.getenv("WORLD_SIZE", 1)) | ||
|
||
if FLAGS.parallel and total_num_gpus > 1: | ||
train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv) | ||
else: | ||
use_cuda = torch.cuda.is_available() | ||
device = torch.device("cuda" if use_cuda else "cpu") | ||
train(rank=device, total_num_gpus=total_num_gpus, argv=argv) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters