
Metric with ddp spawn causes script to hang after Trainer.fit() #331

Closed
awaelchli opened this issue Jun 29, 2021 · 8 comments
Labels
bug / fix Something isn't working distributed DDP, etc. help wanted Extra attention is needed Lightning Priority Critical task/issue
Milestone

Comments

awaelchli (Contributor) commented Jun 29, 2021

🐛 Bug

Adding a torchmetrics metric as an attribute to the model causes processes to hang when launching with ddp spawn.

To Reproduce

  1. Install torchmetrics 0.4 and PyTorch Lightning 1.3.7
  2. Run the script below

Code sample

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.metrics import Accuracy


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = Accuracy()  # add this to break it all

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        max_epochs=1,
        weights_summary=None,
        accelerator="ddp_spawn",
        gpus=2,
    )
    trainer.fit(model, train_dataloader=train_data)


if __name__ == '__main__':
    run()

A few important notes:

  • Simply assigning self.metric = Accuracy() is enough to trigger the hang; you don't have to use the metric at all, just assign it.
  • Installing torchmetrics < 0.4 avoids the problem
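The symptom can be sketched without torch at all. The snippet below is illustrative only (threads stand in for the spawned ranks; it is not what Lightning or torchmetrics actually do): one "rank" blocks forever on a step its peer never reaches, so waiting for it never returns, which is exactly how the script appears to hang after `Trainer.fit()`.

```python
import threading

# Illustrative stand-in for the hang: one "rank" blocks on a step
# (here, an Event that is never set) that its peer never reaches.
blocker = threading.Event()  # never set -> the peer never arrives

worker = threading.Thread(target=blocker.wait, daemon=True)
worker.start()

worker.join(timeout=1.0)     # like waiting for Trainer.fit() to return
hung = worker.is_alive()     # True: the worker is still stuck
print("worker hung:", hung)  # -> worker hung: True
```

A join with a timeout like this is also a cheap way to detect the hang in a test harness instead of blocking CI forever.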

Expected behavior

No hang.

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version: whatever ships with torch 1.8.0
  • GPU models and configuration: our grid cluster
  • Any other relevant information: no

Additional context

trying to prepare PL patch release: Lightning-AI/pytorch-lightning#8198

@awaelchli awaelchli added bug / fix Something isn't working help wanted Extra attention is needed labels Jun 29, 2021
github-actions commented
Hi! Thanks for your contribution, great first issue!

@kaushikb11 kaushikb11 added the Priority Critical task/issue label Jun 29, 2021
awaelchli (Contributor Author) commented

I will try to git bisect against torchmetrics master to find the offending commit.

awaelchli (Contributor Author) commented

Bisecting between 0.3.2 and 0.4.0 I found that commit fc3333b is the problematic one. cc @tchaton

hudeven commented Jun 30, 2021

Thanks @awaelchli for bringing it up! I'm using Lightning + torchmetrics and encountered the same issue with DDP. cc: @tchaton, @maximsch2

SkafteNicki (Member) commented

@awaelchli could you try the fix in PR #338 and see if that fixes the problem?

awaelchli (Contributor Author) commented Jul 1, 2021

It was fixed on the Lightning side in Lightning-AI/pytorch-lightning#8218, which accounts for the changes in torchmetrics state_dict saving and loading.
Your branch also seems to fix it when I test against PL 1.3.7.
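The class of incompatibility being fixed here can be sketched in plain Python. The dicts below are mock checkpoints with hypothetical key names, not real Lightning or torchmetrics output: once metrics start contributing entries to the model's state_dict, a consumer written against the old key set sees keys it never planned for.

```python
# Mock checkpoints (hypothetical keys, for illustration only): the second one
# simulates torchmetrics 0.4 adding metric state to the model's state_dict.
ckpt_old = {"layer.weight": ..., "layer.bias": ...}
ckpt_new = {"layer.weight": ..., "layer.bias": ...,
            "metric.correct": ..., "metric.total": ...}

def unexpected_keys(checkpoint, expected):
    """Keys the consumer did not plan for: the shape of the incompatibility."""
    return sorted(set(checkpoint) - set(expected))

print(unexpected_keys(ckpt_new, ckpt_old))
# -> ['metric.correct', 'metric.total']
```

Whether such a mismatch surfaces as an error, a warning, or a hang depends on where the disagreement happens; in this issue it showed up as a hang in ddp spawn.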

awaelchli (Contributor Author) commented

By the way, #339 is reworking the state_dict logic in parallel, which also resolves the problem here.

@Borda Borda added this to the v0.4 milestone Jul 2, 2021
SkafteNicki (Member) commented

Closing as #339 solved this.

@Borda Borda added the distributed DDP, etc. label Aug 8, 2021