Skip to content

Commit

Permalink
Move module_tracker to logging for confused hierarchy (#134467) (#134501)
Browse files Browse the repository at this point in the history

* Move module_tracker to logging for confused hierarchy (#134467)

Fixes #134242

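Make sure the module tracker never raises an error when its view of the module hierarchy becomes inconsistent. Log messages about such inconsistencies can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or through the standard Python `logging` configuration.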

Pull Request resolved: #134467
Approved by: https://github.com/malfet

* Fix bad merge conflict resolution
  • Loading branch information
albanD committed Aug 27, 2024
1 parent 6a79d4a commit b84e8c6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 15 deletions.
42 changes: 36 additions & 6 deletions test/test_module_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from copy import copy

import torch
from torch import nn
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
from torch.utils.checkpoint import checkpoint
from torch.utils.module_tracker import ModuleTracker


Expand All @@ -14,7 +16,7 @@ def test_module_hierarchy(self):
seen_fw = []
seen_bw = []

class Foo(torch.nn.Module):
class Foo(nn.Module):
def forward(self, x):
x = x["a"].relu_()
seen_fw.append((copy(tracker.parents), tracker.is_bw))
Expand All @@ -23,12 +25,12 @@ def forward(self, x):
)
return {"a": torch.mm(x, x)}

class Mod(torch.nn.Module):
def __init__(self):
class Mod(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = Foo()
self.b = torch.nn.ModuleDict({"nest": Foo()})
self.c = torch.nn.ModuleList([Foo()])
self.b = nn.ModuleDict({"nest": Foo()})
self.c = nn.ModuleList([Foo()])

def forward(self, x):
x = self.c[0](x)
Expand Down Expand Up @@ -68,8 +70,36 @@ def forward(self, x):
],
)

def test_confused_hierarchy(self):
    """Check that ModuleTracker tolerates a module that re-enters itself.

    ``MyMod.forward`` calls ``self(inp)`` once before delegating to its
    inner ``nn.Linear``, so the same module is entered again while it is
    already active. The test makes no assertions: it passes as long as
    neither the forward nor the backward pass raises.
    """

    class MyMod(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.inner = nn.Linear(2, 2)
            # Toggled on every call so that forward recurses exactly once.
            self.ran = False

        def forward(self, inp):
            if not self.ran:
                self.ran = True
                return self(inp)
            else:
                self.ran = False
                return self.inner(inp)

    mod = MyMod()
    inp = torch.rand(1, 2, requires_grad=True)

    # Should not fail
    with ModuleTracker():
        res = mod(inp)
        res.sum().backward()

    # Should not fail
    with ModuleTracker():
        res = checkpoint(lambda inp: mod(inp), inp)
        res.sum().backward()

def test_bw_detection(self):
mod = torch.nn.Linear(2, 2)
mod = nn.Linear(2, 2)

with ModuleTracker() as tracker:
mod(torch.rand(2, requires_grad=True)).sum().backward()
Expand Down
3 changes: 2 additions & 1 deletion torch/autograd/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def __subclasshook__(cls, C):

def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
with torch.enable_grad():
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn

Expand Down
22 changes: 14 additions & 8 deletions torch/utils/module_tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import logging
import weakref

from typing import Set
Expand All @@ -11,6 +12,10 @@
)
from torch.utils._pytree import tree_flatten


logger = logging.getLogger(__name__)


__all__ = ["ModuleTracker"]


Expand Down Expand Up @@ -93,9 +98,10 @@ def fn(*args):
if is_bw:
self._maybe_set_engine_callback()
if name in self.parents:
print(
"The module hierarchy tracking seems to be messed up."
"Please file a bug to PyTorch."
logger.info(
"The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
name,
"backward" if is_bw else "forward",
)
self.parents.add(name)

Expand All @@ -105,11 +111,11 @@ def _get_pop_fn(self, name, is_bw):
def fn(*args):
if name in self.parents:
self.parents.remove(name)
elif not is_bw:
# Due to some input/output not requiring gradients, we cannot enforce
# proper nesting in backward
raise RuntimeError(
"The Module hierarchy tracking is wrong. Report a bug to PyTorch"
else:
logger.info(
"The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
name,
"backward" if is_bw else "forward",
)

return fn
Expand Down

0 comments on commit b84e8c6

Please sign in to comment.