Skip to content

Commit

Permalink
Stop the source generator thread by sending a Python exception to it …
Browse files Browse the repository at this point in the history
…via the CPython api, this avoids an decref error on shutdown
  • Loading branch information
dagardner-nv committed Aug 29, 2024
1 parent e774f2d commit cdea4ad
Showing 1 changed file with 61 additions and 10 deletions.
71 changes: 61 additions & 10 deletions python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
#include <pybind11/gil.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h> // for PyExc_SystemExit
#include <rxcpp/rx.hpp>

#include <chrono>
#include <exception>
#include <fstream>
#include <functional>
Expand All @@ -54,6 +56,58 @@
// IWYU thinks we need array for py::print
// IWYU pragma: no_include <array>

namespace {
namespace py = pybind11;
using namespace std::chrono_literals;
constexpr std::chrono::seconds ThreadShutdownTimeoutSec = 10s;

void stop_python_thread(py::object& thread)
{
if (thread.attr("is_alive")().cast<bool>())
{
// The source generator thread needs to stop prior to the thread object going out of scope
auto thread_id = thread.attr("ident").cast<unsigned long>();
try
{
DVLOG(10) << "Attempting to kill python thread " << thread.attr("name").cast<std::string>()
<< " id: " << thread_id;

// Using PyExc_SystemExit since Python won't log it as an uncaught exception
auto num_threads_stopped = PyThreadState_SetAsyncExc(thread_id, PyExc_SystemExit);
if (num_threads_stopped != 1)
{
LOG(WARNING) << "Expected return value of 1 from PyThreadState_SetAsyncExc, received: "
<< num_threads_stopped;
}
else
{
// wait until the thread has actually stopped
bool is_alive = thread.attr("is_alive")().cast<bool>();
const auto deadline = std::chrono::system_clock::now() + ThreadShutdownTimeoutSec;
while (is_alive && std::chrono::system_clock::now() < deadline)
{
{
pybind11::gil_scoped_release no_gil;
boost::this_fiber::yield();
}

is_alive = thread.attr("is_alive")().cast<bool>();
}

if (is_alive)
{
LOG(WARNING) << "Unable to stop thread: " << thread_id << " after "
<< ThreadShutdownTimeoutSec.count() << " seconds";
}
}
} catch (const std::exception& e)
{
LOG(ERROR) << "Encountered error while attempting to stop thread: " << thread_id << ": " << e.what();
}
}
}
} // namespace

namespace mrc::pymrc {

namespace py = pybind11;
Expand Down Expand Up @@ -159,6 +213,7 @@ void iterator_thread(py::iterator itr, py::object queue, py::object exception_qu
py::gil_scoped_acquire gil;
PyIteratorIterator wrapped_iter(std::move(itr));
PyIteratorIterator sentinel;

try
{
while (wrapped_iter != sentinel)
Expand All @@ -172,9 +227,7 @@ void iterator_thread(py::iterator itr, py::object queue, py::object exception_qu
}
} catch (py::error_already_set py_except)
{
std::cerr << "\n********************\nCaught python exception : " << py_except.what() << std::flush;
exception_queue.attr("put")(py::str(py_except.what()));
std::cerr << "\n********************\npushed\n" << std::flush;
}
}

Expand Down Expand Up @@ -248,15 +301,10 @@ class PyIteratorWrapper
{
if (!m_thread.attr("is_alive")().cast<bool>())
{
std::cerr << "**************\n************\nqueue is empty, thread has stopped\n" << std::flush;
// Check to see if we got an exception
if (!m_exception_queue.attr("empty")().cast<bool>())
{
std::cerr << "**************\n************\nExcepion queue has an item\n" << std::flush;

auto py_err_str = m_exception_queue.attr("get")("block"_a = true, "timeout"_a = 0.5);
std::cerr << "**************\n************\nGot an exception, attempting to cast and re-throw\n"
<< std::flush;
throw std::runtime_error(py_err_str.cast<std::string>());
}
throw pybind11::stop_iteration();
Expand All @@ -272,12 +320,14 @@ class PyIteratorWrapper
~PyIteratorWrapper()
{
{
AcquireGIL gil;
pybind11::gil_scoped_acquire gil;

stop_python_thread(m_thread);

m_thread = py::object();
m_exception_queue = py::object();
m_queue = py::object();
m_empty_exception = py::object();
m_thread = py::object();
}
}

Expand All @@ -304,7 +354,8 @@ class PyIteratorWrapper
auto bound_iter_thread_fn = functools_mod.attr(
"partial")(iter_thread_fn, "itr"_a = iter, "queue"_a = m_queue, "exception_queue"_a = m_exception_queue);

m_thread = threading_mod.attr("Thread")("target"_a = bound_iter_thread_fn, "daemon"_a = true);
m_thread = threading_mod.attr(
"Thread")("target"_a = bound_iter_thread_fn, "name"_a = "py_gen_src", "daemon"_a = true);
m_thread.attr("start")();
}

Expand Down

0 comments on commit cdea4ad

Please sign in to comment.