Skip to content

Commit

Permalink
Fix segfault at the end of execution
Browse files Browse the repository at this point in the history
  • Loading branch information
yelite committed Aug 12, 2022
1 parent 9bff2ca commit e61d43b
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 1 deletion.
21 changes: 21 additions & 0 deletions include/tvm/script/printer/traced_object_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_tab
void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
runtime::PackedFunc f);

/*!
* \brief Remove function from dispatch table.
* \param dispatch_table The dispatch table.
* \param token The dispatch token.
* \param type_index The TVM object type index for the dispatch function to be removed.
*/
void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token,
uint32_t type_index);

constexpr const char* kDefaultDispatchToken = "";

/*!
Expand Down Expand Up @@ -173,6 +182,18 @@ class TracedObjectFunctor {
return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
}

/*!
* \brief Remove dispatch function
* \param token The dispatch token.
* \param type_index The TVM object type index for the dispatch function to be removed.
*
* This is useful when dispatch function comes from other language's runtime, and
* those function should be removed before that language runtime shuts down.
*/
void remove_dispatch(String token, uint32_t type_index) {
RemoveDispatchFunction(&dispatch_table_, token, type_index);
}

private:
DispatchTable dispatch_table_;
};
Expand Down
35 changes: 34 additions & 1 deletion python/tvm/script/printer/ir_docsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import atexit
from contextlib import ExitStack, contextmanager
from typing import Callable, Dict, Mapping, Optional, Sequence, Type, TypeVar, Generator
from typing import Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar

from tvm._ffi import get_object_type_index, register_object
from tvm.runtime import Object, ObjectPath
Expand All @@ -26,6 +27,33 @@
from .frame import Frame
from .var_table import VarTable

_REGISTERED_TYPES: Set[Tuple[str, int]] = set() # {(dispatch_token, type_index)}


def _cleanup_dispatch_function():
for dispatch_token, type_index in _REGISTERED_TYPES:
_ffi_api.IRDocsifierRemoveDispatch(dispatch_token, type_index)


_CLEANUP_REGISTERED = False


def _ensure_cleanup_function_registered():
"""
Add a cleanup function to be called on interpreter termination,
to remove all dispatch functions registered on the Python side.
Without cleaning up those dispatch functions, program will segfault
on termination. It's because dispatch functions are referenced from the
static memory of libtvm, thus they will be cleaned up at the very end,
making calls to Py_DecRef after Python interpreter terminates.
"""
global _CLEANUP_REGISTERED # pylint: disable=global-statement

if not _CLEANUP_REGISTERED:
atexit.register(_cleanup_dispatch_function)
_CLEANUP_REGISTERED = True


@register_object("script.printer.IRDocsifier")
class IRDocsifier(Object):
Expand Down Expand Up @@ -78,7 +106,12 @@ def set_dispatch(
there is no dispatch function registered with the current dispatch token.
"""
type_index = get_object_type_index(node_type)
if type_index is None:
raise TypeError(f"{type(node_type)} is not a registered TVM object type.")

_ensure_cleanup_function_registered()
_ffi_api.IRDocsifierSetDispatch(dispatch_token, type_index, dispatch_function) # type: ignore # pylint: disable=no-member
_REGISTERED_TYPES.add((dispatch_token, type_index))

def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
"""
Expand Down
4 changes: 4 additions & 0 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierSetDispatch")
.set_body_typed([](String token, uint64_t type_index, runtime::PackedFunc f) {
IRDocsifier::vtable().set_dispatch(token, type_index, std::move(f));
});
TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
.set_body_typed([](String token, uint64_t type_index) {
IRDocsifier::vtable().remove_dispatch(token, type_index);
});
} // namespace printer
} // namespace script
} // namespace tvm
10 changes: 10 additions & 0 deletions src/script/printer/traced_object_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uin
}
slot = f;
}

void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token,
uint32_t type_index) {
std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
if (table->size() <= type_index) {
return;
}
(*table)[type_index] = nullptr;
}

} // namespace printer
} // namespace script
} // namespace tvm
14 changes: 14 additions & 0 deletions tests/cpp/tvmscript_printer_traced_object_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ TEST(TracedObjectFunctorTest, ExtraArg) {
ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3);
}

TEST(TracedObjectFunctorTest, RemoveDispatchFunction) {
TracedObjectFunctor<String> functor;
ObjectPath path = ObjectPath::Root();

functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });

ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");

functor.remove_dispatch("tir", FooObjectNode::RuntimeTypeIndex());
ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
}

TEST(TracedObjectFunctorTest, CallWithUnregisteredType) {
TracedObjectFunctor<int, int> functor;
ObjectPath path = ObjectPath::Root();
Expand Down

0 comments on commit e61d43b

Please sign in to comment.