Returning None from training_step with multi GPU DDP training #5243
Hi @iamkucuk @SeanNaren, do you have an idea what we could do when the user wants to skip a step in DDP? @iamkucuk, in the meantime I would investigate why the loss in your output is NaN in the first place; I am sure it can be avoided entirely. Regarding the second part of your condition (the random one), what purpose does it serve? Can you not just reduce the data by 5% at the beginning?
Hi @awaelchli. Another approach I tried was setting the loss to [...]. The [...]
Hey @awaelchli, I wonder if we could do the following:

- call training_step
- [...]
Yes, this is totally feasible. The other option is to skip only if all processes say skip. I can't come up with a use case, but surely someone will. With that in mind, I would aim for a parameter like:

```python
# sum or prod
# sum > 0 means skip
# prod > 0 means skip
# int: process with this index decides (broadcast)
skip_decision: Union[str, int] = "sum"
```
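For context, a rough sketch (not an existing Lightning API) of how such a `skip_decision` reduction could behave, assuming `torch.distributed` is already initialised; `decide_skip` and its arguments are hypothetical names chosen for illustration:

```python
from typing import Union

import torch
import torch.distributed as dist


def decide_skip(local_skip: bool, device, skip_decision: Union[str, int] = "sum") -> bool:
    """Return True if the current step should be skipped on every rank."""
    flag = torch.tensor(float(local_skip), device=device)
    if skip_decision == "sum":
        # any rank wanting to skip forces everyone to skip
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    elif skip_decision == "prod":
        # skip only if *all* ranks want to skip
        dist.all_reduce(flag, op=dist.ReduceOp.PRODUCT)
    else:
        # an int means: the rank with this index decides for everyone
        dist.broadcast(flag, src=skip_decision)
    return flag.item() > 0
```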
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
Would love to know when this feature is on the master branch! DeepSpeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.
I think this is already resolved in #5359 and merged into 1.1.x. If so, it can be closed.
@AAnoosheh no, #5359 was just closed and couldn't be merged yet. We still need to work on it and figure out a solution.
Hi, I'm in a similar situation. My batches are formed from the output of an object detector, so sometimes the batch will essentially be of size zero (I can't think of a good way to explain this, but just trust that it makes sense). In this case, I would like to return None from training_step, or at least return some kind of zero loss tensor with zero gradients. If it's not easy to return None, is there some way to artificially construct a zero tensor with all the appropriate gradients present so that the DDP sync will work?
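One workaround sometimes suggested for this kind of effectively-empty batch (a hedged sketch, not an official Lightning mechanism): return a zero loss that still depends on every parameter, so DDP has matching, all-zero gradients to all-reduce on that rank. `batch_is_empty` and `compute_loss` are hypothetical helpers:

```python
def training_step(self, batch, batch_idx):
    if batch_is_empty(batch):  # hypothetical helper for the detector-output case
        # multiplying by 0.0 keeps a graph that touches every parameter,
        # so backward produces zero gradients instead of skipping the sync
        return sum(p.sum() for p in self.parameters()) * 0.0
    return self.compute_loss(batch)  # hypothetical helper returning the real loss
```

The cost is one "wasted" optimizer step with zero gradients, but every rank keeps participating in the gradient all-reduce.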
Hi, is there a solution to this issue?
No @AsaphLightricks. Returning None from training_step is still not supported with DDP.
I would need to return None from time to time to implement this.
Also looking for a patch for this! (My use case is pretty much exactly what @kazimpal87 described above.)
I also need this feature. Is there a known workaround?
Hi, any updates on this?
Seconding this; it's basically a must-have feature. Can't it be done here? Basically skipping the reduce altogether and returning None? (Not very familiar with Lightning, sorry.)
Hi! I'll give this a look :)
Great, thanks. For now it can be solved with manual optimization, BTW, if someone has this problem. On my side the training does not freeze on the iteration where `batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)` is reached; the iteration that returns None actually goes forward without a problem, apparently, and even updates the progress bar.
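For reference, a rough sketch of the manual-optimization workaround mentioned above; `SkipAwareModule`, `net`, and `criterion` are placeholder names, and under DDP the skip condition still has to agree across ranks (e.g., via the flag-gathering shown further down), otherwise ranks that call backward will wait on ranks that don't:

```python
import torch
import pytorch_lightning as pl


class SkipAwareModule(pl.LightningModule):
    def __init__(self, net, criterion):
        super().__init__()
        self.automatic_optimization = False  # take control of backward/step ourselves
        self.net = net
        self.criterion = criterion

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = self.criterion(self.net(x), y)
        # NOTE: in DDP this check must evaluate identically on all ranks
        # (synchronize a skip flag first), or the gradient all-reduce will hang
        if not torch.isfinite(loss):
            return  # skip this batch: no backward, no optimizer step
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```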
I have a fix, but it is inside the training loop (it is similar to @tchaton's suggestion above):

```python
import logging

import torch
import torch.distributed as torch_dist

log = logging.getLogger(__name__)


# method on the LightningModule; `self.nan_countdown` is assumed to be
# initialised to 1 in __init__
def _reduce_loss(self, utt_id, batch_id, loss, reduction="mean"):
    assert loss.shape[0] == len(utt_id), "loss must be reduced to batch dimension !"
    mask_nan_inf = torch.logical_or(torch.isnan(loss), ~torch.isfinite(loss))
    if torch.any(mask_nan_inf):
        where_invalid = torch.where(mask_nan_inf)[0]
        for indx in range(where_invalid.shape[0]):
            inv_indx = where_invalid[indx].item()
            log.info(
                f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, for utt_id {utt_id[inv_indx]}"
            )
        # if any is invalid then we must flag this to all processes
        flag_skip = torch.ones((), device=loss.device, dtype=torch.bool)
    else:
        flag_skip = torch.zeros((), device=loss.device, dtype=torch.bool)

    # sub-optimal but will do,
    # till they fix it in https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
    world_size = torch_dist.get_world_size()
    torch_dist.barrier()
    # now gather the per-rank flags
    result = [torch.zeros_like(flag_skip) for _ in range(world_size)]
    torch_dist.all_gather(result, flag_skip)
    any_invalid = torch.sum(torch.stack(result)).bool().item()

    if any_invalid:
        if self.nan_countdown >= 100:
            raise RuntimeError("Too many NaNs loss iterations encountered, stopping !")
        self.nan_countdown += 1
        return None
    else:
        self.nan_countdown = 1
        return loss.mean() if reduction == "mean" else loss.sum()
```

Basically I gather a flag across all DDP workers; if any of the workers sets the flag, all workers must return None. I am sure that for someone more familiar with Lightning's background magic it must be easy to add something similar in the right place.
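For illustration, a hedged sketch of how such a helper might be called from training_step; the batch layout and `compute_loss` are assumptions, not part of the snippet above:

```python
def training_step(self, batch, batch_idx):
    utt_id, feats, targets = batch                        # assumed batch layout
    per_sample_loss = self.compute_loss(feats, targets)   # assumed: shape (batch_size,)
    loss = self._reduce_loss(utt_id, batch_idx, per_sample_loss)
    # `loss` is either the reduced loss or None on *all* ranks simultaneously,
    # so returning it keeps the DDP processes in sync
    return loss
```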
@popcornell, thanks for sharing! This logic seems to work for OOM errors occurring in the forward pass. However, when I tried applying similar logic in `backward`, I ran into problems (see the follow-up below):

```python
# as in the previous snippet: `torch_dist` is `torch.distributed`, `log` is a
# module-level logger, `Any` comes from `typing`, and this is a method on the
# LightningModule (`self.net` is the wrapped network)
def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Any):
    """Overrides Lightning's `backward` step to add an out-of-memory (OOM) check."""
    # by default, do not skip the current batch
    skip_flag = torch.zeros(
        (), device=self.device, dtype=torch.bool, requires_grad=False
    )  # NOTE: for skipping batches in a multi-device setting
    try:
        loss.backward(*args, **kwargs, retain_graph=False)
    except RuntimeError as e:
        skip_flag = torch.ones((), device=self.device, dtype=torch.bool, requires_grad=False)
        if "out of memory" in str(e):
            log.warning(
                f"Ran out of memory in the backward pass, where `torch_dist.is_initialized` is {torch_dist.is_initialized()}. Skipping batch due to: {e}"
            )
            if not torch_dist.is_initialized():
                # NOTE: for skipping batches in a single-device setting
                del loss  # delete the computation graph
                for p in self.net.parameters():
                    if p.grad is not None:
                        del p.grad

    # NOTE: for skipping batches in a multi-device setting
    # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
    if torch_dist.is_initialized():
        # if any rank skips a batch, then all other ranks need to skip
        # their batches as well so DDP can properly keep all ranks synced
        world_size = torch_dist.get_world_size()
        torch_dist.barrier()
        result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
        torch_dist.all_gather(result, skip_flag)
        any_skipped = torch.sum(torch.stack(result)).bool().item()
        if any_skipped:
            del loss  # delete the computation graph
            for p in self.net.parameters():
                if p.grad is not None:
                    del p.grad
            log.warning(
                "Ran out of memory in the backward pass. Skipping batches for all ranks."
            )
```
Following up on my previous comment, it seems pytorch/pytorch#18853 (comment) discusses a related issue. In this case, it may be that the OOM errors I am seeing are causing [...]
Your code probably hangs because one process gets the exception and the other does not, and so the one that did not raise will hang at the barrier while watching the other die (sad). When you catch it and run some logic, you will need to synchronize this decision. Make it so that either all of them don't raise, or all of them raise, but a mix is not allowed.
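A small sketch of the "all raise or none raise" pattern described above (hedged; `synchronized_failure` is a hypothetical helper): each rank contributes a failure flag, and the reduced value drives one shared decision, so no rank is left waiting at a barrier:

```python
import torch
import torch.distributed as dist


def synchronized_failure(local_failed: bool, device) -> bool:
    """True on every rank if any rank failed, so all ranks take the same branch."""
    flag = torch.tensor(1.0 if local_failed else 0.0, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # same result on every rank
    return flag.item() > 0


# usage sketch: catch the error locally, then decide collectively
# failed = did_this_rank_hit_oom_or_nan()          # hypothetical check
# if synchronized_failure(failed, some_device):
#     ...                                          # every rank skips (or raises) together
```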
@awaelchli, thanks for your suggestion. I agree: if a DDP rank raises an exception (e.g., [...])
@awaelchli, sadly, removing the [...]
Since [...]
Sadly, it looks like I may have found a culprit for this issue: after disabling my [...]. Has anyone else experienced this issue and found a way around it (besides perhaps setting [...])?
🐛 Bug
Returning None from training_step with multi-GPU DDP training freezes the training without raising an exception.
To Reproduce
Start multi-GPU training with a None-returning training_step function.
Example training_step function:
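The original snippet is not preserved in this text; below is a minimal sketch of the kind of training_step described in the thread (a NaN check plus a random skip condition), assuming a simple classification setup inside the LightningModule:

```python
import random

import torch
import torch.nn.functional as F


# method on the LightningModule; the exact skip condition is assumed
def training_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # skip batches with a NaN loss, plus a random ~5% of batches
    if torch.isnan(loss) or random.random() < 0.05:
        return None  # with DDP, this is what freezes training
    return loss
```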
Example trainer:
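Likewise, the original Trainer example is not preserved; a sketch of a configuration that would reproduce the setup (two GPUs with DDP, using the argument names from the Lightning 1.1 era in which this issue was filed):

```python
import pytorch_lightning as pl

# `model` is assumed to be the LightningModule defined above
trainer = pl.Trainer(
    gpus=2,             # at least two devices are needed to trigger the DDP hang
    accelerator="ddp",  # Lightning 1.1-era way of selecting DDP
    max_epochs=10,
)
trainer.fit(model)
```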
Expected behavior
Training should continue, skipping the current batch, as pointed out here.
Environment
No specific environment is needed to reproduce this bug.
Additional context
This issue was mentioned in #4956, but without specifics.
Note: while this issue is being investigated, help with a workaround would be great!