[release/2.5] ModuleTracker: Add explicit garbage collection
When running an FSDP model with FlopCounterMode, we are seeing a memory
leak. It comes from the ModuleTracker class: even though ModuleTracker
keeps weak references to the operators, the tensors/operators are not
freed after the backward pass. To force-free these tensors/operators
after the backward pass, I explicitly added garbage collection in the
post-forward hook.
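
As a rough illustration (not part of the commit), here is a minimal sketch of the pattern the message describes: ModuleTracker wrapping repeated forward/backward passes. The toy model, tensor sizes, and iteration count are hypothetical stand-ins; the original report used an FSDP-wrapped model under FlopCounterMode, which per the message exercises ModuleTracker.

import torch
from torch.utils.module_tracker import ModuleTracker

# Hypothetical toy model standing in for the FSDP-wrapped model in the report.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

with ModuleTracker():
    for _ in range(3):  # repeated iterations are where the leak built up
        out = model(torch.randn(16, 128, requires_grad=True))
        out.sum().backward()
        # Per the commit message: before this patch, tensors captured via
        # register_multi_grad_hook lingered after backward; with gc.collect()
        # in the post-forward hook, that garbage is reclaimed on the next
        # forward pass.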

(cherry picked from commit 63dc40d)
pragupta committed Oct 30, 2024
1 parent 8ec0173 commit 328b01c
Showing 1 changed file with 2 additions and 1 deletion.
torch/utils/module_tracker.py: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
     register_module_forward_pre_hook,
 )
 from torch.utils._pytree import tree_flatten
-
+import gc
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +136,7 @@ def _fw_post_hook(self, mod, input, output):
         tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
         if tensors:
             register_multi_grad_hook(tensors, self._get_append_fn(name, True))
+        gc.collect()
 
     def __enter__(self):
         self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
