Skip to content

Commit

Permalink
Fix a memory leak when using Python callbacks.
Browse files Browse the repository at this point in the history
Fixes #527.

PiperOrigin-RevId: 481173711
Change-Id: Ie176a8bda09727b6be8b6641d3be9a75c675d566
  • Loading branch information
saran-t authored and copybara-github committed Oct 14, 2022
1 parent af74044 commit ec6ea6a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ Python bindings
- The length of ``MjData.contact`` is now ``ncon`` rather than ``nconmax``, allowing it to be straightforwardly used as
an iterator without needing to check ``ncon``.

- Fix a memory leak when a Python callable is installed as callback
(`#527 <https://github.com/deepmind/mujoco/issues/527>`_).

Version 2.2.2 (September 7, 2022)
---------------------------------
Expand Down
17 changes: 17 additions & 0 deletions python/mujoco/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,23 @@ def __call__(self, m, d, stage):
self.assertEqual(sensor_callback.count, 1)
self.assertEqual(data_with_sensor.sensordata[0], 17)

def test_mjcb_control_not_leak_memory(self):
model_instances = []
data_instances = []
for _ in range(10):
mujoco.set_mjcb_control(None)
model_instances.append(mujoco.MjModel.from_xml_string('<mujoco/>'))
data_instances.append(mujoco.MjData(model_instances[-1]))
mujoco.set_mjcb_control(lambda m, d: None)
mujoco.mj_step(model_instances[-1], data_instances[-1])
mujoco.set_mjcb_control(None)
while data_instances:
d = data_instances.pop()
self.assertEqual(sys.getrefcount(d), 2)
while model_instances:
m = model_instances.pop()
self.assertEqual(sys.getrefcount(m), 2)

def test_can_initialize_mjv_structs(self):
self.assertIsInstance(mujoco.MjvScene(), mujoco.MjvScene)
self.assertIsInstance(mujoco.MjvCamera(), mujoco.MjvCamera)
Expand Down
18 changes: 9 additions & 9 deletions python/mujoco/callbacks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using enable_if_not_const_t =
// table that associates raw MuJoCo struct pointers back to the pointers to
// their corresponding wrappers.
template <typename Raw>
static enable_if_not_const_t<Raw, py::handle> MjWrapperLookup(Raw* ptr) {
static enable_if_not_const_t<Raw, py::object> MjWrapperLookup(Raw* ptr) {
using LookupFnType = MjWrapper<Raw>* (Raw*);
static LookupFnType* const lookup = []() -> LookupFnType* {
py::gil_scoped_acquire gil;
Expand Down Expand Up @@ -93,8 +93,8 @@ static enable_if_not_const_t<Raw, py::handle> MjWrapperLookup(Raw* ptr) {
const auto [src, type] =
py::detail::type_caster_base<MjWrapper<Raw>>::src_and_type(wrapper);
if (type) {
py::handle instance =
py::detail::find_registered_python_instance(wrapper, type);
py::object instance = py::reinterpret_steal<py::object>(
py::detail::find_registered_python_instance(wrapper, type));
if (!instance) {
if (!PyErr_Occurred()) {
PyErr_SetString(
Expand All @@ -117,7 +117,7 @@ static enable_if_not_const_t<Raw, py::handle> MjWrapperLookup(Raw* ptr) {
}

template <typename Raw>
static const py::handle MjWrapperLookup(const Raw* ptr) {
static const py::object MjWrapperLookup(const Raw* ptr) {
return MjWrapperLookup(const_cast<Raw*>(ptr));
}

Expand Down Expand Up @@ -172,28 +172,28 @@ static PyObject* py_mju_user_free = nullptr;
static PyObject* py_mjcb_passive = nullptr;
static void PyMjcbPassive(const raw::MjModel* m, raw::MjData* d) {
CallPyCallback<void>("mjcb_passive", py_mjcb_passive,
MjWrapperLookup(m), MjWrapperLookup(d));
MjWrapperLookup(m), MjWrapperLookup(d));
}

static PyObject* py_mjcb_control = nullptr;
static void PyMjcbControl(const raw::MjModel* m, raw::MjData* d) {
CallPyCallback<void>("mjcb_control", py_mjcb_control,
MjWrapperLookup(m), MjWrapperLookup(d));
MjWrapperLookup(m), MjWrapperLookup(d));
}

static PyObject* py_mjcb_contactfilter = nullptr;
static int PyMjcbContactfilter(
const raw::MjModel* m, raw::MjData* d, int geom1, int geom2) {
return CallPyCallback<int>("mjcb_contactfilter", py_mjcb_contactfilter,
MjWrapperLookup(m), MjWrapperLookup(d),
geom1, geom2);
MjWrapperLookup(m), MjWrapperLookup(d),
geom1, geom2);
}

static PyObject* py_mjcb_sensor = nullptr;
static void
PyMjcbSensor(const raw::MjModel* m, raw::MjData* d, int stage) {
CallPyCallback<void>("mjcb_sensor", py_mjcb_sensor,
MjWrapperLookup(m), MjWrapperLookup(d), stage);
MjWrapperLookup(m), MjWrapperLookup(d), stage);
}

static PyObject* py_mjcb_time = nullptr;
Expand Down

0 comments on commit ec6ea6a

Please sign in to comment.