
add support for distributed data parallel training #116

Merged · 11 commits · Aug 21, 2024

Conversation

@ImahnShekhzadeh (Contributor) commented May 21, 2024

This PR adds support for distributed data parallel (DDP) training, replacing DataParallel with DistributedDataParallel in train_cifar10.py; DDP is enabled via the parallel flag. To achieve this, the code is refactored, and the flags master_addr and master_port are added.

I tested the changes: on a single GPU, I get an FID of 3.74 (with the OT-CFM method); on two GPUs with DDP, I get an FID of 3.81.
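
For reviewers who have not used DDP before, the setup that the new flags feed into looks roughly like this (a minimal sketch, not the exact code in this PR; ddp_setup and wrap_model are made-up names for illustration):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_setup(rank: int, world_size: int, master_addr: str, master_port: str) -> None:
    # Every process must agree on the rendezvous address before joining the group.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def wrap_model(model: torch.nn.Module, rank: int) -> DDP:
    # One replica per GPU; DDP all-reduces gradients during the backward pass.
    return DDP(model.to(rank), device_ids=[rank])

In practice each GPU runs one such process, launched for example via torch.multiprocessing.spawn or torchrun.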

Before submitting

  • Did you make sure the title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with the pytest command?
  • Did you run the pre-commit hooks with the pre-commit run -a command?

@kilianFatras (Collaborator)

Hi, thank you for your contribution!

I had an internal implementation with Fabric from Lightning, but I would like to rely only on PyTorch for this example. I need some time to review it (a few days/weeks). I will come back to it soon.

@kilianFatras (Collaborator) left a comment

Thank you for this nice contribution! My main concern is the data shuffling: I think we should keep shuffling the data. Maybe I am missing something, and I am happy to learn about it.

Resolved review threads: examples/images/cifar10/compute_fid.py; examples/images/cifar10/train_cifar10.py (two outdated threads)
@@ -81,7 +89,8 @@ def train(argv):
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=FLAGS.batch_size,
-        shuffle=True,
+        sampler=DistributedSampler(dataset) if FLAGS.parallel else None,
+        shuffle=False if FLAGS.parallel else True,
Collaborator

Hum, I am rather unsure about this. Where do you shuffle the data then?

Contributor Author

Fair point. I found this warning in the PyTorch docs:
[Screenshot of the PyTorch DistributedSampler documentation: in distributed mode, set_epoch() must be called at the beginning of each epoch, before creating the DataLoader iterator, otherwise the same ordering is used in every epoch.]

So in my current implementation, train_sampler.set_epoch(epoch) is missing, which I will add now.
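
To make explicit where the shuffling happens (a small standalone sketch of standard DistributedSampler behavior, not the PR code; num_replicas/rank are fixed here only so it runs outside a process group): the sampler shuffles internally (shuffle=True is its default), and set_epoch(epoch) reseeds that shuffle each epoch.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(1024, 3, 32, 32))  # stand-in for CIFAR-10
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)  # shuffle=True by default
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, shuffle=False)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed the shuffle *before* creating the iterator
    for (x1,) in dataloader:  # each rank sees its own shard, in a new order per epoch
        pass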

Collaborator

Perfect. Once you have finished your changes, I will run the code myself. Once I get it working, I will merge the PR.

Final question: can you try to load and run the existing checkpoints? I just want to be sure that people can reproduce our results. Thx.

@ImahnShekhzadeh (Contributor Author) commented Jul 29, 2024

Ok, I refactored the training loop to use num_epochs instead of FLAGS.total_steps, since sampler.set_epoch(epoch) uses an epoch count. However, I think we need to change more than this. The PyTorch warning I pasted above mentions that we need to use sampler.set_epoch(epoch) "before creating the DataLoader iterator", but right now, the data loader iterator is created once before the training loop:

from utils_cifar import infiniteloop

datalooper = infiniteloop(dataloader)
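
For context, infiniteloop is (as far as I remember, paraphrasing rather than copying utils_cifar) a generator along these lines, so fresh passes over the data start inside the generator and there is no point at which sampler.set_epoch(epoch) could be called:

def infiniteloop(dataloader):
    # Yields batches forever: whenever one pass over the dataloader ends, a new
    # iterator is created silently, with no hook for sampler.set_epoch(epoch).
    while True:
        for x, y in iter(dataloader):
            yield x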

The way I would change this is by having a training loop like this:

# datalooper = infiniteloop(dataloader)

with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch_idx, data in enumerate(dataloader):
            # step += 1  # tricky
            optim.zero_grad()
            x1 = data.to(rank)  # old: `x1 = next(datalooper)`
            [...]

Is this fine by you? IMO, what is a bit tricky is handling the step counter correctly (it determines when checkpoints are saved and when samples are generated during training). In a distributed setup we'll have several processes running in parallel, so we would probably save checkpoints and images multiple times (once per process/GPU). And since the filenames do not reflect the process ID, one process would overwrite the files of another. What do you think?
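
To sketch what I have in mind for the saving part (a hypothetical helper, not code from this PR; the name and checkpoint path are made up): only rank 0 would write, since all ranks hold identical DDP weights.

import torch
import torch.distributed as dist

def maybe_save(model: torch.nn.Module, step: int, save_every: int) -> None:
    # Hypothetical rank-0-only checkpointing; assumes the process group is initialized.
    if step % save_every != 0:
        return
    if dist.get_rank() == 0:
        # For a DDP-wrapped model, saving model.module.state_dict() instead avoids
        # the "module." prefix in the keys.
        torch.save(model.state_dict(), f"checkpoint_step{step}.pt")
    dist.barrier()  # keep the ranks loosely in sync around the save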

Contributor Author

About your question: When you say "existing checkpoints", which ones do you mean? I had once run the training and generation of samples on one GPU and gotten an FID of 3.8 (which is only slightly worse than the 3.5 you report).

atong01 and others added 2 commits July 29, 2024 19:07
* change pytorch lightning version

* fix pip version

* fix pip in code cov
@kilianFatras (Collaborator)

I like the new changes. @atong01 do you mind having a look? I also think it would be great to keep the original train_cifar10.py.

While I like this code, it is slightly more complicated than the previous one, so I would keep both. The idea of this package is that any master's student can easily understand it in one hour. @ImahnShekhzadeh, can you rename this file to train_cifar10_ddp.py and re-add the previous file, please? Thanks

@ImahnShekhzadeh (Contributor Author)

@ImahnShekhzadeh, can you rename this file to train_cifar10_ddp.py and re-add the previous file, please? Thanks

Done

@atong01 (Owner) commented Aug 21, 2024

LGTM. Thanks for the contribution @ImahnShekhzadeh

@atong01 merged commit c25e191 into atong01:main on Aug 21, 2024
31 checks passed