diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h index 6f04b66cec97..4c09b0a41b79 100644 --- a/include/tvm/script/printer/traced_object.h +++ b/include/tvm/script/printer/traced_object.h @@ -86,6 +86,8 @@ class TracedObject { using ObjectType = typename RefT::ContainerType; public: + using ObjectRefType = RefT; + // Don't use this direcly. For convenience, call MakeTraced() instead. explicit TracedObject(const RefT& object_ref, ObjectPath path) : ref_(object_ref), path_(std::move(path)) {} diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h new file mode 100644 index 000000000000..05fbbf79f2ee --- /dev/null +++ b/include/tvm/script/printer/traced_object_functor.h @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ +#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +namespace { + +namespace detail { +/*! + * \brief Helper template class to extract the type of first argument of a function + * \tparam FType The function type. + */ +template +struct FirstArgTypeGetter; + +template +struct FirstArgTypeGetter { + using T = ArgOne; +}; + +/*! + * \brief Template alias for the type of first argument of a function + * \tparam FType The function type. + * + * The name of public functions are in snake case to be consistent with + * tvm/node/functor.h + */ +template +using FirstArgType = typename detail::FirstArgTypeGetter< + typename tvm::runtime::detail::function_signature::FType>::T; +} // namespace detail + +} // namespace + +/* + * This type alias and the following free functions are created to reduce the binary bloat + * from template and also hide implementation details from this header + */ +using DispatchTable = std::unordered_map>; + +/*! + * \brief Get function from dispatch table. + * \param dispatch_table The dispatch table. + * \param token The dispatch token. + * \param type_index The type index of the Object type to be dispatched. + * + * \return The dispatch function. + */ +const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, + const String& token, uint32_t type_index); + +/*! + * \brief Set function in dispatch table. + * \param dispatch_table The dispatch table. + * \param token The dispatch token. + * \param type_index The type index of the Object type to be dispatched. + * \param f The dispatch function. + */ +void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, + runtime::PackedFunc f); + +constexpr const char* kDefaultDispatchToken = ""; + +/*! + * \brief Dynamic dispatch functor based on TracedObject. + * + * This functor dispatches based on the type of object ref inside the input TracedObject, + * and the input dispatch token. + */ +template +class TracedObjectFunctor { + private: + using TSelf = TracedObjectFunctor; + + template + using IsDispatchFunction = + typename std::is_convertible, Args...)>>; + + public: + /*! + * \brief Call the dispatch function. + * \param token The dispatch token. + * \param traced_object The traced object. + * \param args Other args. + * + * \return The return value of the dispatch function + * + * If the TObjectRef isn't registered with the token, it will try to find + * dispatch function for TObjectRef with kDefaultDispatchToken. + */ + template + R operator()(const String& token, TracedObject traced_object, Args... args) const { + const runtime::PackedFunc& dispatch_function = + GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index()); + return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...); + } + + /*! + * \brief Set the dispatch function + * \param token The dispatch token. + * \param type_index The TVM object type index for this dispatch function. + * \param f The dispatch function. + * + * This takes a type-erased packed function as input. It should be used + * through FFI boundary, for example, registering dispatch function from Python. + */ + TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { + SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f)); + return *this; + } + + /*! + * \brief Set the dispatch function + * \param token The dispatch token. + * \param f The dispatch function. + * + * The diaptch function should have signature `R(TracedObject, Args...)`. + */ + template ::ObjectRefType, + typename = std::enable_if_t::value>> + TSelf& set_dispatch(String token, TCallable f) { + return set_dispatch( + token, // + TObjectRef::ContainerType::RuntimeTypeIndex(), // + runtime::TypedPackedFunc( + [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R { + return f(MakeTraced(object, path), args...); + })); + } + /*! + * \brief Set the default dispatch function + * \param f The dispatch function. + * + * Default dispatch function will be used if there is no function registered + * with the requested dispatch token. + * + * Default dispatch function has an empty string as dispatch token. + */ + template + TSelf& set_dispatch(TCallable&& f) { + return set_dispatch(kDefaultDispatchToken, std::forward(f)); + } + + private: + DispatchTable dispatch_table_; +}; + +} // namespace printer +} // namespace script +} // namespace tvm +#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc new file mode 100644 index 000000000000..a018099a1de0 --- /dev/null +++ b/src/script/printer/traced_object_functor.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace script { +namespace printer { + +const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table, + const String& token, uint32_t type_index) { + auto it = table.find(token); + if (it == table.end()) { + return nullptr; + } + const std::vector& tab = it->second; + if (type_index >= tab.size()) { + return nullptr; + } + const PackedFunc* f = &tab[type_index]; + if (f->defined()) { + return f; + } else { + return nullptr; + } +} + +const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, + const String& token, uint32_t type_index) { + if (const runtime::PackedFunc* pf = + GetDispatchFunctionForToken(dispatch_table, token, type_index)) { + return *pf; + } else if (const runtime::PackedFunc* pf = + GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) { + // Fallback to function with the default dispatch token + return *pf; + } else { + ICHECK(false) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"; + throw; + } +} + +void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, + runtime::PackedFunc f) { + std::vector* table = &(*dispatch_table)[token]; + if (table->size() <= type_index) { + table->resize(type_index + 1, nullptr); + } + runtime::PackedFunc& slot = (*table)[type_index]; + if (slot != nullptr) { + ICHECK(false) << "Dispatch for type is already registered: " + << runtime::Object::TypeIndex2Key(type_index); + } + slot = f; +} +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc new file mode 100644 index 000000000000..3fd52d44aa8c --- /dev/null +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::script::printer; + +namespace { + +class FooObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "test.FooObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); +}; + +class FooObject : public ObjectRef { + public: + FooObject() { this->data_ = make_object(); } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(FooObjectNode); + +class BarObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "test.BarObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); +}; + +class BarObject : public ObjectRef { + public: + BarObject() { this->data_ = make_object(); } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(BarObjectNode); + +String ComputeFoo(TracedObject foo) { return "Foo"; } + +} // anonymous namespace + +TEST(TracedObjectFunctorTest, NormalRegistration) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); + ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithFunction) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); + ICHECK_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax"); + ICHECK_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); }; + auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); }; + + functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default)); + functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir)); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "default"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir"); +} + +TEST(TracedObjectFunctorTest, ExtraArg) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); + ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); + ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3); +} + +TEST(TracedObjectFunctorTest, CallWithUnregisteredType) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + bool failed = false; + try { + ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} + +TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o, int x) { return x; }); + + bool failed = false; + try { + functor.set_dispatch([](TracedObject o, int x) { return x; }); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} + +TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + + bool failed = false; + try { + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +}