add support for distributed data parallel training #116
Conversation
…e on multiple GPUs
…k` for device setting
Hi, thank you for your contribution! I had an internal implementation with Fabric from Lightning, but I would like to rely only on PyTorch for this example. I need some time to review it (a few days/weeks). I will come back to it soon.
Thank you for this nice contribution! My main concern is about the data shuffling: I think we should keep shuffling the data. Maybe I am missing something, and I am happy to learn about it.
```diff
@@ -81,7 +89,8 @@ def train(argv):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=FLAGS.batch_size,
-        shuffle=True,
+        sampler=DistributedSampler(dataset) if FLAGS.parallel else None,
+        shuffle=False if FLAGS.parallel else True,
```
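For context, the change above amounts to something like the following minimal sketch (`build_dataloader` is a hypothetical helper, not code from the PR):

```python
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def build_dataloader(dataset: Dataset, batch_size: int, parallel: bool) -> DataLoader:
    # With a DistributedSampler, the DataLoader itself must not shuffle;
    # the sampler shuffles within each rank's shard (reshuffled per epoch
    # once set_epoch() is called).
    sampler = DistributedSampler(dataset) if parallel else None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=sampler is None,
    )
```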
Hmm, I am rather unsure about this. Where do you shuffle the data then?
Fair point. I found a warning in the PyTorch docs for `DistributedSampler` (paraphrasing): in distributed mode, `set_epoch()` has to be called at the beginning of each epoch, before creating the `DataLoader` iterator, otherwise the same ordering is used in every epoch.

So in my current implementation, `train_sampler.set_epoch(epoch)` is missing, which I will add now.
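For what it's worth, here is a small self-contained sketch of why `set_epoch` matters; the explicit `num_replicas`/`rank` arguments are only there so it runs without `init_process_group`:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))  # toy dataset, just for illustration
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

print(list(sampler))  # some permutation of this rank's shard
print(list(sampler))  # same permutation again: the epoch (and thus the seed) is unchanged

sampler.set_epoch(1)  # reseeds the shuffle with seed + epoch
print(list(sampler))  # a different permutation (in general)
```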
Perfect. Once you have finished your change, I will run the code myself. Once I get it working, I will merge the PR.
Final question: can you try to load and run the existing checkpoints? I just want to be sure that people can reproduce our results. Thx.
Ok, I refactored the training loop to use `num_epochs` instead of `FLAGS.total_steps`, since `sampler.set_epoch(epoch)` expects an epoch count. However, I think we need to change more than this. The PyTorch warning I pasted above mentions that we need to call `sampler.set_epoch(epoch)` "before creating the `DataLoader` iterator", but right now, the data loader iterator is created once before the training loop:
```python
from utils_cifar import infiniteloop

datalooper = infiniteloop(dataloader)
```
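For context, `infiniteloop` in `utils_cifar` is roughly a generator that cycles over the `DataLoader` forever, so the iterator is never recreated and `set_epoch` never takes effect. A sketch of such a helper (paraphrased, not the verbatim repo code):

```python
def infiniteloop(dataloader):
    # Cycle over the DataLoader indefinitely, yielding only the images.
    while True:
        for x, _y in iter(dataloader):
            yield x
```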
The way I would change this is by having a training loop like this:

```python
# datalooper = infiniteloop(dataloader)
with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch_idx, data in enumerate(dataloader):
            # step += 1  # tricky
            optim.zero_grad()
            x1 = data.to(rank)  # old: `x1 = next(datalooper)`
            [...]
```
Is this fine by you? IMO, what is a bit tricky is handling the `step` counter correctly (it determines when checkpoints are saved and when samples are generated during training). In a distributed setup, we'll have several processes running in parallel, so we would probably save checkpoints and images multiple times (once per process/GPU). However, since the filenames do not reflect the process ID, one process would also overwrite the files of the others. What do you think?
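One common way to handle this is to let only rank 0 write checkpoints and sample images. A sketch (not code from this PR; `save_step`, `ckpt_path`, and `model` are placeholders):

```python
import torch.distributed as dist


def is_main_process() -> bool:
    # Rank 0 only (or single-process runs without an initialized process group).
    return not dist.is_initialized() or dist.get_rank() == 0


# Hypothetical guard inside the training loop:
# if step % save_step == 0 and is_main_process():
#     torch.save({"model": model.state_dict(), "step": step}, ckpt_path)
```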
About your question: when you say "existing checkpoints", which ones do you mean? I had once run the training and generation of samples on one GPU and got an FID of 3.8 (which is only slightly worse than the 3.5 you report).
* change pytorch lightning version
* fix pip version
* fix pip in code cov
…eps, rewrite training loop to use epochs instead of steps
…in distributed mode
I like the new changes. @atong01, do you mind having a look? I also think it would be great to keep the original `train_cifar10.py`. While I like this code, it is slightly more complicated than the previous one, so I would keep both. The idea of this package is that any master's student can easily understand it in one hour. @ImahnShekhzadeh, can you rename this file to `train_cifar10_ddp.py`, please, and re-add the previous file? Thanks
Done
LGTM. Thanks for the contribution @ImahnShekhzadeh
This PR adds support for distributed data parallel (DDP) training and replaces `DataParallel` with `DistributedDataParallel` in `train_cifar.py`, which can be used via the flag `parallel`. To achieve this, the code is refactored, and the flags `master_addr` and `master_port` are added. I tested the changes: on a single GPU, I get an FID of 3.74 (with the OT-CFM method); on two GPUs with DDP, I get an FID of 3.81.
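To illustrate how the `master_addr`/`master_port` flags typically come together, here is a minimal sketch of DDP process-group setup; the helper name `ddp_setup` and the spawn-based launch are assumptions, not necessarily how the PR wires it up:

```python
import os

import torch
import torch.distributed as dist


def ddp_setup(rank: int, world_size: int, master_addr: str, master_port: str) -> None:
    # Rendezvous info comes from the (assumed) flags master_addr / master_port.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


# Typical launch: one process per GPU, e.g.
# torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
# where `train` calls ddp_setup(rank, ...) first and wraps the model in
# DistributedDataParallel(model, device_ids=[rank]).
```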
Before submitting
- Did you run all the tests with the `pytest` command?
- Did you run the `pre-commit run -a` command?