Skip to content

Commit

Permalink
[dynamo] handle 3.13.0 __dict__ watcher bug (pytorch#138284)
Browse files Browse the repository at this point in the history
python/cpython#116115 introduced a bug (python/cpython#125608) where changing the attributes of an object may not fire the dict watchers registered to the object's `__dict__`. It has been fixed by python/cpython#125611 but will only be in 3.13.1+.

This PR disables the dict watcher guard shortcut for `__dict__`s on 3.13.0 and warns the user to try using 3.13.1+ instead. We also added a simple test to check for this functionality in the future.

Pull Request resolved: pytorch#138284
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#138030
  • Loading branch information
williamwen42 authored and rahulsingh-intel committed Oct 29, 2024
1 parent b6d0cdb commit 32ddf67
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
20 changes: 20 additions & 0 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,26 @@ def test_self_mutating1(self):
else:
self.assertExpectedInline(cnt.frame_count, """1""")

def test_nn_module_setattr(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.var = 0

@torch.compile(backend="eager", dynamic=False)
def f(x, m):
return x + m.var

inp = torch.ones(3)
m = Mod()

self.assertEqual(f(inp, m), inp)
# In 3.13.0, setattr will not fire a __dict__'s watchers,
# so guards may not be invalidated.
m.var = 1
# should trigger a recompile
self.assertEqual(f(inp, m), inp + 1)

@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_generation_tag(self):
cnt = torch._dynamo.testing.CompileCounter()
Expand Down
35 changes: 27 additions & 8 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import textwrap
import time
import types
import warnings
import weakref
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -651,6 +652,20 @@ def guard_on_dict_keys_and_order(self, value, guard):
key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
)

@staticmethod
def _get_generic_dict_manager_example_value(example_value):
# due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115,
# reported in https://github.com/python/cpython/issues/125608,
# fixed by https://github.com/python/cpython/pull/125611), we cannot take
# advantage of __dict__ versions to speed up guard checks.
if sys.version_info >= (3, 13) and sys.version_info < (3, 13, 1):
warnings.warn(
"Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
RuntimeWarning,
)
return None
return example_value

def getattr_on_nn_module(
self,
source,
Expand Down Expand Up @@ -776,7 +791,7 @@ def getitem_on_dict_mgr(
# Guard Manager
mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
source=mod_dict_source,
example_value=mod_dict,
example_value=self._get_generic_dict_manager_example_value(mod_dict),
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
)

Expand Down Expand Up @@ -1271,7 +1286,7 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
mod_dict_source = f"{guard.name}.__dict__"
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
source=mod_dict_source,
example_value=val.__dict__,
example_value=self._get_generic_dict_manager_example_value(val.__dict__),
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
)

Expand Down Expand Up @@ -2261,12 +2276,16 @@ def add_code_part(code_part, guard, log_only=False):
structured_guard_fns.append(
lambda: {
"code": code_part,
"stack": structured.from_traceback(guard.stack.summary())
if guard.stack
else None,
"user_stack": structured.from_traceback(guard.user_stack)
if guard.user_stack
else None,
"stack": (
structured.from_traceback(guard.stack.summary())
if guard.stack
else None
),
"user_stack": (
structured.from_traceback(guard.user_stack)
if guard.user_stack
else None
),
}
)

Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <torch/extension.h>

#include <torch/csrc/dynamo/debug_macros.h>

#ifdef USE_CUDA
#include <ATen/cuda/EmptyTensor.h>
#endif
Expand Down Expand Up @@ -655,7 +657,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {

static std::unordered_map<PyObject*, uint64_t> dict_version_map;
static int dict_version_watcher_id;
static uint64_t global_dict_version_id = 0;
static uint64_t global_dict_version_id = 1;
static int dict_version_watch_callback(
PyDict_WatchEvent event,
PyObject* dict,
Expand Down

0 comments on commit 32ddf67

Please sign in to comment.