diff --git a/doc/changelog.rst b/doc/changelog.rst index b78d76fd7d..2323f8baa2 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -102,6 +102,8 @@ Python bindings - The length of ``MjData.contact`` is now ``ncon`` rather than ``nconmax``, allowing it to be straightforwardly used as an iterator without needing to check ``ncon``. +- Fix a memory leak when a Python callable is installed as callback + (`#527 `_). Version 2.2.2 (September 7, 2022) --------------------------------- diff --git a/python/mujoco/bindings_test.py b/python/mujoco/bindings_test.py index 66e66cb092..79ab3ad842 100644 --- a/python/mujoco/bindings_test.py +++ b/python/mujoco/bindings_test.py @@ -962,6 +962,23 @@ def __call__(self, m, d, stage): self.assertEqual(sensor_callback.count, 1) self.assertEqual(data_with_sensor.sensordata[0], 17) + def test_mjcb_control_not_leak_memory(self): + model_instances = [] + data_instances = [] + for _ in range(10): + mujoco.set_mjcb_control(None) + model_instances.append(mujoco.MjModel.from_xml_string('')) + data_instances.append(mujoco.MjData(model_instances[-1])) + mujoco.set_mjcb_control(lambda m, d: None) + mujoco.mj_step(model_instances[-1], data_instances[-1]) + mujoco.set_mjcb_control(None) + while data_instances: + d = data_instances.pop() + self.assertEqual(sys.getrefcount(d), 2) + while model_instances: + m = model_instances.pop() + self.assertEqual(sys.getrefcount(m), 2) + def test_can_initialize_mjv_structs(self): self.assertIsInstance(mujoco.MjvScene(), mujoco.MjvScene) self.assertIsInstance(mujoco.MjvCamera(), mujoco.MjvCamera) diff --git a/python/mujoco/callbacks.cc b/python/mujoco/callbacks.cc index e32fe753d9..929e345861 100644 --- a/python/mujoco/callbacks.cc +++ b/python/mujoco/callbacks.cc @@ -46,7 +46,7 @@ using enable_if_not_const_t = // table that associates raw MuJoCo struct pointers back to the pointers to // their corresponding wrappers. template -static enable_if_not_const_t MjWrapperLookup(Raw* ptr) { +static enable_if_not_const_t MjWrapperLookup(Raw* ptr) { using LookupFnType = MjWrapper* (Raw*); static LookupFnType* const lookup = []() -> LookupFnType* { py::gil_scoped_acquire gil; @@ -93,8 +93,8 @@ static enable_if_not_const_t MjWrapperLookup(Raw* ptr) { const auto [src, type] = py::detail::type_caster_base>::src_and_type(wrapper); if (type) { - py::handle instance = - py::detail::find_registered_python_instance(wrapper, type); + py::object instance = py::reinterpret_steal( + py::detail::find_registered_python_instance(wrapper, type)); if (!instance) { if (!PyErr_Occurred()) { PyErr_SetString( @@ -117,7 +117,7 @@ static enable_if_not_const_t MjWrapperLookup(Raw* ptr) { } template -static const py::handle MjWrapperLookup(const Raw* ptr) { +static const py::object MjWrapperLookup(const Raw* ptr) { return MjWrapperLookup(const_cast(ptr)); } @@ -172,28 +172,28 @@ static PyObject* py_mju_user_free = nullptr; static PyObject* py_mjcb_passive = nullptr; static void PyMjcbPassive(const raw::MjModel* m, raw::MjData* d) { CallPyCallback("mjcb_passive", py_mjcb_passive, - MjWrapperLookup(m), MjWrapperLookup(d)); + MjWrapperLookup(m), MjWrapperLookup(d)); } static PyObject* py_mjcb_control = nullptr; static void PyMjcbControl(const raw::MjModel* m, raw::MjData* d) { CallPyCallback("mjcb_control", py_mjcb_control, - MjWrapperLookup(m), MjWrapperLookup(d)); + MjWrapperLookup(m), MjWrapperLookup(d)); } static PyObject* py_mjcb_contactfilter = nullptr; static int PyMjcbContactfilter( const raw::MjModel* m, raw::MjData* d, int geom1, int geom2) { return CallPyCallback("mjcb_contactfilter", py_mjcb_contactfilter, - MjWrapperLookup(m), MjWrapperLookup(d), - geom1, geom2); + MjWrapperLookup(m), MjWrapperLookup(d), + geom1, geom2); } static PyObject* py_mjcb_sensor = nullptr; static void PyMjcbSensor(const raw::MjModel* m, raw::MjData* d, int stage) { CallPyCallback("mjcb_sensor", py_mjcb_sensor, - MjWrapperLookup(m), MjWrapperLookup(d), stage); + MjWrapperLookup(m), MjWrapperLookup(d), stage); } static PyObject* py_mjcb_time = nullptr;