Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI][BUGFIX] Grab GIL when check env signals #17419

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ cdef inline void* c_handle(object handle):
# python env API
cdef extern from "Python.h":
int PyErr_CheckSignals()
void* PyGILState_Ensure()
void PyGILState_Release(void*)
void Py_IncRef(void*)
void Py_DecRef(void*)

cdef extern from "tvm/runtime/c_backend_api.h":
int TVMBackendRegisterEnvCAPI(const char* name, void* ptr)
Expand All @@ -210,11 +214,13 @@ cdef _init_env_api():
# so backend can call tvm::runtime::EnvCheckSignals to check
# signal when executing a long running function.
#
# This feature is only enabled in cython for now due to problems of calling
# these functions in ctypes.
#
# When the functions are not registered, the signals will be handled
# only when the FFI function returns.
# Also registers the gil state release and ensure as PyErr_CheckSignals
# function is called with gil released and we need to regrab the gil
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"), <void*>PyErr_CheckSignals))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), <void*>PyGILState_Release))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), <void*>Py_IncRef))
CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), <void*>Py_DecRef))

_init_env_api()
16 changes: 0 additions & 16 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -376,19 +376,3 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object

# Py_INCREF and Py_DECREF are C macros, not function objects.
# Therefore, providing a wrapper function.
cdef void _py_incref_wrapper(void* py_object):
Py_INCREF(<object>py_object)
cdef void _py_decref_wrapper(void* py_object):
Py_DECREF(<object>py_object)

def _init_pythonapi_inc_def_ref():
register_func = TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
register_func(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)
register_func(c_str("PyGILState_Release"), <void*>PyGILState_Release)

_init_pythonapi_inc_def_ref()
12 changes: 8 additions & 4 deletions src/runtime/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ class EnvCAPIRegistry {
// implementation of tvm::runtime::EnvCheckSignals
void CheckSignals() {
// check python signal to see if there are exception raised
if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) {
// The error will let FFI know that the frontend environment
// already set an error.
throw EnvErrorAlreadySet("");
if (pyerr_check_signals != nullptr) {
// The C++ env comes without gil, so we need to grab gil here
WithGIL context(this);
if ((*pyerr_check_signals)() != 0) {
// The error will let FFI know that the frontend environment
// already set an error.
throw EnvErrorAlreadySet("");
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
std::this_thread::sleep_for(duration);
});

TVM_REGISTER_GLOBAL("testing.check_signals").set_body_typed([](double sleep_period) {
while (true) {
std::chrono::duration<int64_t, std::nano> duration(static_cast<int64_t>(sleep_period * 1e9));
std::this_thread::sleep_for(duration);
runtime::EnvCheckSignals();
}
});

TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant<String, IntImm> {
if (x % 2 == 0) {
return IntImm(DataType::Int(64), x / 2);
Expand Down
Loading