Returning None from training_step with multi GPU DDP training #5243

Open
iamkucuk opened this issue Dec 23, 2020 · 26 comments
Labels
distributed Generic distributed-related topic feature Is an improvement or enhancement help wanted Open to be worked on priority: 1 Medium priority task

@iamkucuk

iamkucuk commented Dec 23, 2020

🐛 Bug

Returning None from training_step during multi-GPU DDP training freezes training without raising an exception.

To Reproduce

Start multi-GPU training with a training_step function that returns None.

Example training_step function:

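    def training_step(self, batch, batch_idx):
        # Assumes `import random` and `import torch` at module level;
        # calc_loss stands for the user's own loss function.
        data, target = batch
        model_outputs = self(data)
        loss = calc_loss(model_outputs, target)

        # Skip the batch on a NaN loss; the random condition only makes the hang easier to reproduce.
        if torch.isnan(loss) or random.random() < .05:
            return None

        return loss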

Example trainer:

    trainer = Trainer(
        gpus=2,
        distributed_backend="ddp"
    )

Expected behavior

Training should continue, skipping the current batch, as pointed out here.

Environment

No specific environment is needed to reproduce this bug.

Additional context

This issue was mentioned in #4956, but without specifics.

Note: While this issue is being investigated, help with a workaround would be great!

@iamkucuk iamkucuk added bug Something isn't working help wanted Open to be worked on labels Dec 23, 2020
@awaelchli awaelchli added the distributed Generic distributed-related topic label Dec 23, 2020
@awaelchli
Contributor

Hi @iamkucuk
This is somewhat expected, and it is not obvious how to solve it. As you may know, in DDP each subprocess sees a different split of the data, and conditionally skipping the training step causes the processes to get out of sync. I'm not exactly sure why training freezes, but it must have to do with the processes being out of sync, or waiting to sync gradients that were not computed in some processes.
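A simplified, single-machine model may make the out-of-sync explanation more concrete. The sketch below is an illustration under stated assumptions, not Lightning's or PyTorch's actual code: each "rank" is a thread, and DDP's gradient all-reduce is stood in for by a `threading.Barrier` sized to the whole world. A rank that skips never reaches the barrier, so the ranks that did compute gradients are left waiting. `simulate_ddp_step` is a hypothetical helper name.

```python
import threading


def simulate_ddp_step(skip_flags, timeout=0.5):
    """Toy model of one DDP step; each entry of skip_flags is one rank.

    A rank that does not skip must enter the gradient all-reduce, modelled here
    as a Barrier sized to the whole world. Unlike real NCCL/Gloo collectives,
    the barrier gives up after `timeout` seconds so the example terminates.
    """
    world_size = len(skip_flags)
    allreduce = threading.Barrier(world_size)
    outcome = [None] * world_size

    def run_rank(rank):
        if skip_flags[rank]:
            outcome[rank] = "skipped"  # this rank never joins the collective
            return
        try:
            allreduce.wait(timeout=timeout)  # stand-in for DDP's gradient all-reduce
            outcome[rank] = "synced"
        except threading.BrokenBarrierError:
            outcome[rank] = "hung"  # in real DDP this rank would simply block here

    threads = [threading.Thread(target=run_rank, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome


print(simulate_ddp_step([False, False]))  # ['synced', 'synced']
print(simulate_ddp_step([False, True]))   # ['hung', 'skipped']
print(simulate_ddp_step([True, True]))    # ['skipped', 'skipped']
```

In real DDP there is no short timeout: the waiting ranks block inside the collective (possibly until the process-group timeout), which would look like the silent freeze reported above. Skipping is only safe when every rank makes the same decision.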

@SeanNaren do you have an idea what we could do when the user wants to skip a step in DDP?

@iamkucuk In the meantime, I would investigate why the loss in your output is NaN in the first place; I am sure it can be avoided entirely. Regarding the second part of your condition (random), what purpose does it serve? Can you not just reduce the data by 5% at the beginning?

@iamkucuk
Author

iamkucuk commented Dec 23, 2020

Hi @awaelchli
I am not sure yet, but it may be an exploding-gradient issue, where a single batch generates very large gradients. It happens rarely, and the model learns perfectly when it doesn't. I tried clipping gradients, which seriously slows down learning in my case but seemingly solves the problem. Another approach I tried is accumulating gradients, which I think reduces the effect of a single batch producing low-quality gradients, and it did reduce the NaN loss errors significantly. However, the problem still persists.

Another approach I tried was setting the loss to torch.tensor(0), hoping this would skip the model update for that batch. However, this disconnects the loss from the computation graph.

The random.random() < .05 condition only serves to trigger the None return more often, since the NaN loss occurs very rarely. It has nothing to do with my training procedure.
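For reference, gradient clipping by norm in plain PyTorch looks roughly like the sketch below; in Lightning the same effect is usually configured with `Trainer(gradient_clip_val=...)`, and gradient accumulation with `Trainer(accumulate_grad_batches=...)`. The toy model and the artificially scaled loss are assumptions standing in for a rare "bad batch". Note that clipping only rescales finite gradients: if a gradient is already NaN, the computed norm is NaN too, and clipping cannot recover it.

```python
import torch
from torch import nn

# A tiny model and a deliberately huge loss, standing in for the rare
# "bad batch" that produces an exploding gradient.
torch.manual_seed(0)
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
loss = (model(x) * 1e6).pow(2).mean()
loss.backward()

# clip_grad_norm_ rescales all gradients in place so that their combined L2
# norm is at most max_norm, and returns the norm measured *before* clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the combined L2 norm after clipping to check the effect.
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))

print(f"before: {norm_before.item():.3e}, after: {norm_after.item():.3f}")
```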

@tchaton
Contributor

tchaton commented Jan 4, 2021

Hey @awaelchli,

I wonder if we could do the following:

    class LightningModule:
        # Introduce this attribute so we don't force a synchronisation on every step for all users.
        might_training_step_return_none = False

    # In the loop, after calling training_step:
    #   process_i -> returns None
    #   process_j -> returns loss
    if might_training_step_return_none:
        # sum of the "returned None" flags over all processes; > 0 means at least one wants to skip
        should_skip = self.trainer.accelerator_backend.all_reduce(output is None) > 0
        if should_skip:
            return
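Outside of Lightning's internals, the same "skip if any process says skip" idea can be sketched directly with `torch.distributed`. This is a minimal sketch under stated assumptions, not Lightning's implementation: `sync_skip` is a hypothetical helper, and it uses an all-reduce with `ReduceOp.MAX` on a 0/1 flag, which gives the same answer as the "sum > 0" test proposed above. It falls back to the local decision when no process group is initialized, so it is also harmless on a single device.

```python
import torch
import torch.distributed as dist


def sync_skip(local_skip: bool, device=None) -> bool:
    """Return True on every rank if at least one rank wants to skip this step.

    Each rank contributes 1 (skip) or 0 (keep). An all-reduce with MAX leaves the
    same value on every rank, so all ranks take the same branch and stay in step.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single process (or no process group yet): nothing to synchronize.
        return local_skip
    # With the NCCL backend the flag must live on this rank's GPU (e.g. pass
    # loss.device); with Gloo a CPU tensor works.
    flag = torch.tensor(int(local_skip), device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # in place; every rank gets the max
    return bool(flag.item())
```

Inside `training_step`, each rank would compute its local flag (for example `not torch.isfinite(loss).item()`), call `sync_skip(flag, device=loss.device)`, and return `None` on every rank when the result is `True`. Because all ranks then take the same branch, none of them is left waiting on a collective the others never enter; later comments in this thread report that training no longer freezes when all ranks return None together.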
 

@awaelchli
Contributor

awaelchli commented Jan 4, 2021

Yes, this is totally feasible.
With this solution we will be on the "safe side" and always skip in all processes when at least one says skip.
Actually, we already have this logic for early stopping: see the accelerator, which has a reduce for early stopping.

The other option is to skip only if all processes say skip. I can't come up with a use case, but surely someone will. With that in mind, I would aim for a parameter like:

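    # skip_decision: "sum", "prod", or a process index
    # "sum":  sum of the skip flags > 0 means skip (at least one process says skip)
    # "prod": product of the skip flags > 0 means skip (all processes say skip)
    # int:    the process with this index decides (broadcast)
    skip_decision: Union[str, int] = "sum"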
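The three proposed policies can be illustrated with a small pure-Python sketch that combines a list of per-rank flags (index = rank), as if the flags had already been gathered from every process. `decide_skip` is a hypothetical name; in a real DDP run, the "sum" and "prod" variants would correspond to an all-reduce with `ReduceOp.SUM` or `ReduceOp.PRODUCT`, and the integer variant to `torch.distributed.broadcast` from that rank.

```python
import math
from typing import Sequence, Union


def decide_skip(flags: Sequence[bool], skip_decision: Union[str, int] = "sum") -> bool:
    """Combine per-rank skip flags (index = rank) into one decision shared by all ranks."""
    values = [int(bool(f)) for f in flags]  # 1 = this rank wants to skip, 0 = keep
    if isinstance(skip_decision, int):
        # int: the rank with this index decides for everyone (like a broadcast)
        return bool(values[skip_decision])
    if skip_decision == "sum":
        # sum > 0: at least one rank wants to skip
        return sum(values) > 0
    if skip_decision == "prod":
        # prod > 0: every rank wants to skip
        return math.prod(values) > 0
    raise ValueError(f"unknown skip_decision: {skip_decision!r}")


flags = [False, True, False]  # only rank 1 hit a bad batch
print(decide_skip(flags, "sum"))   # True
print(decide_skip(flags, "prod"))  # False
print(decide_skip(flags, 0))       # False
```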

@tchaton tchaton self-assigned this Jan 5, 2021
@edenlightning edenlightning added priority: 1 Medium priority task feature Is an improvement or enhancement and removed bug Something isn't working labels Jan 19, 2021
@stale

stale bot commented Feb 19, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Feb 19, 2021
@awaelchli awaelchli removed the won't fix This will not be worked on label Feb 19, 2021
@stale

stale bot commented Mar 21, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Mar 21, 2021
@awaelchli awaelchli removed the won't fix This will not be worked on label Mar 21, 2021
@aced125

aced125 commented Apr 15, 2021

Would love to know when this feature is on the master branch! Deepspeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.

@edenlightning edenlightning added this to the v1.4 milestone Apr 27, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.5 Jul 6, 2021
@AAnoosheh

Would love to know when this feature is on the master branch! Deepspeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.

I think this is already resolved in #5359 and merged into 1.1.x. If so, it can be closed.
(Though it might still be unresolved when using DDP-Sharded or if you accidentally log the NaN loss, causing a freeze)

@awaelchli
Contributor

@AAnoosheh no, #5359 was closed without being merged. We still need to work on it and figure out a solution.

@awaelchli awaelchli modified the milestones: v1.5, v1.6 Oct 22, 2021
@carmocca carmocca modified the milestones: 1.6, future Feb 17, 2022
@kazimpal87

Hi, I'm in a similar situation. My batches are formed from the output of an object detector, so sometimes a batch will essentially be of size zero (I can't think of a good way to explain this, but just trust that it makes sense). In this case, I would like to return None from train_step, or at least return some kind of zero loss tensor with zero gradients. If it's not easy to return None, is there some way to artificially construct a zero tensor with all the appropriate gradients present so that the DDP sync will work?
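One way to build such a tensor is to multiply every trainable parameter by zero and sum the results. This is a sketch under assumptions we have not verified against every DDP or Lightning configuration. The value is 0, yet it is connected to each parameter, so `backward()` writes an all-zero gradient to every parameter and each rank contributes a gradient to DDP's all-reduce. `zero_loss_like` is a hypothetical helper name. Two caveats apply. First, DDP averages gradients across ranks, so a rank contributing zeros still dilutes the other ranks' gradients for that step. Second, with stateful optimizers such as Adam, or with weight decay, a zero gradient is not a true no-op, because the optimizer may still change the parameters.

```python
import torch
from torch import nn


def zero_loss_like(module: nn.Module) -> torch.Tensor:
    """Scalar 0 that is connected to every trainable parameter of `module`.

    Backpropagating it writes an all-zero .grad to each parameter, so every
    parameter still takes part in the gradient synchronisation.
    """
    params = [p for p in module.parameters() if p.requires_grad]
    # The derivative of (p * 0).sum() w.r.t. p is 0 everywhere, so each gradient
    # is exactly zero (assuming the parameters themselves are finite).
    return torch.stack([(p * 0.0).sum() for p in params]).sum()


model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
loss = zero_loss_like(model)
loss.backward()
print(loss.item())  # 0.0
```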

@AsaphLightricks

Hi, is there a solution to this issue?

@carmocca
Contributor

carmocca commented May 2, 2022

No @AsaphLightricks. Returning None in training_step with DDP is currently unsupported.

@jhauret
Contributor

jhauret commented Jun 21, 2022

I would need to return None from time to time to implement this.

@yashpatel5400

Also looking for a patch for this! (My use case is pretty much exactly what @kazimpal87 described above)

@magehrig

I also need this feature. Is there a known workaround?

@Borda Borda self-assigned this Nov 7, 2022
@EricWiener
Contributor

Hi, any updates on this?

@popcornell

popcornell commented May 13, 2023

I second this; this is basically a must-have feature.
Even skipping the update entirely across all workers when just one returns None would be okay for me.
Right now training freezes, so Lightning + DDP cannot be used for a lot of tasks, e.g. automatic speech recognition (ASR), where this issue is common due to the loss used combined with mixed precision.

Can't it be done here, basically skipping the reduce altogether and returning None? (I'm not very familiar with Lightning, sorry.)
https://github.com/Lightning-AI/lightning/blob/7268670d1aec7d40241962c5bc81b0b871de0a57/src/lightning/pytorch/strategies/ddp.py#LL314C5-L331C1

@gianscarpe
Contributor

Hi! I'll give this a look :)

@popcornell

Great, thanks! BTW, if someone has this problem, for now it can be worked around with manual optimization.

On my side, training freezes not on the iteration where None is returned but on the one right after, at this call in pytorch_lightning/loops/training_epoch_loop.py:

    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)

The iteration that returns None apparently proceeds without a problem and even updates the progress bar.

@popcornell

popcornell commented May 18, 2023

I have a fix, but it is inside the training loop (it is similar to @tchaton's suggestion above):

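    import logging

    import torch
    import torch.distributed as torch_dist

    log = logging.getLogger(__name__)  # `log` is assumed to be a standard logging.Logger

    # Method of a LightningModule; assumes self.nan_countdown is initialised (e.g. in __init__).
    def _reduce_loss(self, utt_id, batch_id, loss, reduction="mean"):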

        assert loss.shape[0] == len(utt_id), "loss must be reduced to batch dimension !"

        mask_nan_inf = torch.logical_or(torch.isnan(loss), ~torch.isfinite(loss))
        if torch.any(mask_nan_inf):
            where_invalid = torch.where(mask_nan_inf)[0]
            for indx in range(where_invalid.shape[0]):
                inv_indx = where_invalid[indx].item()
                log.info(
                    f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, for utt_id {utt_id[inv_indx]}"
                )
            # if any is invalid then we must flag this to all processes
            flag_skip = torch.ones((), device=loss.device, dtype=torch.bool)
        else:
            flag_skip = torch.zeros((), device=loss.device, dtype=torch.bool)

        # sub-optimal but will do,
        # till they fix it in https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
        world_size = torch_dist.get_world_size()
        torch_dist.barrier()
        # now gather
        result = [torch.zeros_like(flag_skip) for _ in range(world_size)]
        torch_dist.all_gather(result, flag_skip)
        any_invalid = torch.sum(torch.stack(result)).bool().item()

        if any_invalid:
            if self.nan_countdown >= 100:
                raise RuntimeError(
                    "Too many NaNs loss iterations encountered, stopping !"
                )
            self.nan_countdown += 1
            return None
        else:
            self.nan_countdown = 1
            return loss.mean() if reduction == "mean" else loss.sum()

Basically, I gather a flag across all DDP workers; if any worker sets the flag, all workers return None.
If all of them return None, training no longer freezes.
But it would be neat if this were handled inside Lightning; I feel I am adding unnecessary synchronization here.

I am sure that for someone more familiar with Lightning's behind-the-scenes magic, it would be easy to add something similar in the right place.

@amorehead
Contributor

amorehead commented Oct 9, 2023

@popcornell, thanks for sharing! This logic seems to work for OOM errors occurring in the forward pass. However, when I applied similar logic in Lightning's backward hook, I ran into the dreaded DDP freezing issue, which suggests that my DDP ranks still fall out of sync. My backward code is below. Does anyone here have an idea what might be causing the ranks to fall out of sync with each other? Notably, my code never logs the warning "Ran out of memory in the backward pass. Skipping batches for all ranks.", which suggests that at least one of the ranks never reaches the barrier (i.e., torch_dist.barrier()) once all the others have.

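    from typing import Any
    import logging

    import torch
    import torch.distributed as torch_dist

    log = logging.getLogger(__name__)

    # Method of a LightningModule whose network is stored in self.net.
    def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Any):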
        """Overrides Lightning's `backward` step to add an out-of-memory (OOM) check."""
        # by default, do not skip the current batch
        skip_flag = torch.zeros(
            (), device=self.device, dtype=torch.bool, requires_grad=False
        )  # NOTE: for skipping batches in a multi-device setting

        try:
            loss.backward(*args, **kwargs, retain_graph=False)
        except RuntimeError as e:
            skip_flag = torch.ones((), device=self.device, dtype=torch.bool, requires_grad=False)
            if "out of memory" in str(e):
                log.warning(
                    f"Ran out of memory in the backward pass, where `torch_dist.is_initialized` is {torch_dist.is_initialized()}. Skipping batch due to: {e}"
                )
                if not torch_dist.is_initialized():
                    # NOTE: for skipping batches in a single-device setting
                    del loss  # delete the computation graph
                    for p in self.net.parameters():
                        if p.grad is not None:
                            del p.grad

        # NOTE: for skipping batches in a multi-device setting
        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
        if torch_dist.is_initialized():
            # if any rank skips a batch, then all other ranks need to skip
            # their batches as well so DDP can properly keep all ranks synced
            world_size = torch_dist.get_world_size()
            torch_dist.barrier()
            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
            torch_dist.all_gather(result, skip_flag)
            any_skipped = torch.sum(torch.stack(result)).bool().item()
            if any_skipped:
                del loss  # delete the computation graph
                for p in self.net.parameters():
                    if p.grad is not None:
                        del p.grad
                log.warning(
                    "Ran out of memory in the backward pass. Skipping batches for all ranks."
                )

@amorehead
Contributor

Following up on my previous comment, pytorch/pytorch#18853 (comment) seems to discuss a related issue. In my case, the OOM errors may be causing loss.backward to complete for some variables of the overall autograd graph but not for all of them, since some do not finish their loss.backward call. Put another way, is there a way to manually tell DDP to mark certain autograd variables as "having completed their backward pass" even if they actually haven't? That way, DDP ranks would never fall out of sync with each other.

@awaelchli
Contributor

Your code probably hangs because one process gets the exception and the other does not, so the one that did not raise hangs at the barrier while watching the other die (sad). When you catch the exception and run some logic, you need to synchronize that decision across processes: either none of them raise or all of them raise, but a mix is not allowed.

@amorehead
Contributor

@awaelchli, thanks for your suggestion. I agree: if a DDP rank raises an exception (e.g., e in my code above), this kind of issue will definitely occur. However, what confuses me is that my warning log log.warning(f"Ran out of memory in the backward pass. Skipping batch due to: {e}") gets printed to my log file right before my entire training job freezes. This implies that the specific DDP rank that OOM'd did go down the if branch of my exception-handling code, which should not re-raise the underlying exception e. To check this, I am now running a test of the same scenario with the else: raise e branch of the exception-handling logic removed, to see whether some DDP ranks were raising the actual exception when they shouldn't have been. I'll report back with my findings.

@amorehead
Contributor

amorehead commented Oct 10, 2023

@awaelchli, sadly, removing the else: raise e branch in my exception handling logic above does not resolve the DDP rank freezing issue I am facing. What else might be causing one of the ranks to die before it reaches the torch_dist.barrier() call? I can confirm that I can see exactly one warning log being issued before my training job completely freezes across all ranks:

Ran out of memory in the backward pass, where `torch_dist.is_initialized` is True. Skipping batch due to: CUDA out of memory. Tried to allocate 5.18 GiB (GPU 0; 79.15 GiB total capacity; 72.98 GiB already allocated; 139.25 MiB free; 78.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.

Since torch_dist.is_initialized is confirmed to be set to True here, any DDP rank that logs such a message should not return None, thereby preventing such a rank from "dying early".

@amorehead
Contributor

amorehead commented Oct 11, 2023

Sadly, it looks like I may have found a culprit for this issue: wandb (using the latest version 0.15.12 - reference: pytorch/xla#1562 (comment)).

After disabling my WandbLogger temporarily and rerunning my model training script from the latest checkpoint (shortly after which my script would normally OOM during the loss.backward() call), I am now not seeing the original OOM error which would normally freeze all my DDP ranks from there on. I don't believe I am having wandb watch my model's computational graph in any way, unless similar behavior is enabled by default with Lightning's WandbLogger.

Has anyone else experienced this issue and found a way around it (besides perhaps setting offline=True for WandbLogger and manually syncing local logs to the remote later on)? @sydholl, has the wandb team seen any issues like this before (wandb/wandb#2091)?
