Skip to content

Commit

Permalink
[TVMScript] Printer Registry (apache#12237)
Browse files Browse the repository at this point in the history
This PR:

- Adds the registry of printing function (traced_object_layered_functor.cc)

Compared to the prototype version, this:
- Consolidates the implementation into a single class, since this class is only for the TVMScript printer.
- Deduces the TObjectRef when calling set_dispatch.

Tracking issue: apache#11912

Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
2 people authored and Mikael Sevenier committed Aug 12, 2022
1 parent 627d439 commit f1dabac
Show file tree
Hide file tree
Showing 4 changed files with 431 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/tvm/script/printer/traced_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class TracedObject {
using ObjectType = typename RefT::ContainerType;

public:
using ObjectRefType = RefT;

// Don't use this direcly. For convenience, call MakeTraced() instead.
explicit TracedObject(const RefT& object_ref, ObjectPath path)
: ref_(object_ref), path_(std::move(path)) {}
Expand Down
183 changes: 183 additions & 0 deletions include/tvm/script/printer/traced_object_functor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_

#include <tvm/node/node.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/script/printer/traced_object.h>

#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace script {
namespace printer {

namespace {

namespace detail {
/*!
* \brief Helper template class to extract the type of first argument of a function
* \tparam FType The function type.
*/
template <typename FType>
struct FirstArgTypeGetter;

template <typename R, typename ArgOne, typename... OtherArgs>
struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
using T = ArgOne;
};

/*!
* \brief Template alias for the type of first argument of a function
* \tparam FType The function type.
*
* The name of public functions are in snake case to be consistent with
* tvm/node/functor.h
*/
template <typename FType>
using FirstArgType = typename detail::FirstArgTypeGetter<
typename tvm::runtime::detail::function_signature<FType>::FType>::T;
} // namespace detail

} // namespace

/*
* This type alias and the following free functions are created to reduce the binary bloat
* from template and also hide implementation details from this header
*/
using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;

/*!
* \brief Get function from dispatch table.
* \param dispatch_table The dispatch table.
* \param token The dispatch token.
* \param type_index The type index of the Object type to be dispatched.
*
* \return The dispatch function.
*/
const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
const String& token, uint32_t type_index);

/*!
* \brief Set function in dispatch table.
* \param dispatch_table The dispatch table.
* \param token The dispatch token.
* \param type_index The type index of the Object type to be dispatched.
* \param f The dispatch function.
*/
void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
runtime::PackedFunc f);

constexpr const char* kDefaultDispatchToken = "";

/*!
* \brief Dynamic dispatch functor based on TracedObject.
*
* This functor dispatches based on the type of object ref inside the input TracedObject,
* and the input dispatch token.
*/
template <typename R, typename... Args>
class TracedObjectFunctor {
private:
using TSelf = TracedObjectFunctor<R, Args...>;

template <class TObjectRef, class TCallable>
using IsDispatchFunction =
typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;

public:
/*!
* \brief Call the dispatch function.
* \param token The dispatch token.
* \param traced_object The traced object.
* \param args Other args.
*
* \return The return value of the dispatch function
*
* If the TObjectRef isn't registered with the token, it will try to find
* dispatch function for TObjectRef with kDefaultDispatchToken.
*/
template <class TObjectRef>
R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
const runtime::PackedFunc& dispatch_function =
GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index());
return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...);
}

/*!
* \brief Set the dispatch function
* \param token The dispatch token.
* \param type_index The TVM object type index for this dispatch function.
* \param f The dispatch function.
*
* This takes a type-erased packed function as input. It should be used
* through FFI boundary, for example, registering dispatch function from Python.
*/
TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
return *this;
}

/*!
* \brief Set the dispatch function
* \param token The dispatch token.
* \param f The dispatch function.
*
* The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
*/
template <typename TCallable,
typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(String token, TCallable f) {
return set_dispatch(
token, //
TObjectRef::ContainerType::RuntimeTypeIndex(), //
runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
[f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R {
return f(MakeTraced(object, path), args...);
}));
}
/*!
* \brief Set the default dispatch function
* \param f The dispatch function.
*
* Default dispatch function will be used if there is no function registered
* with the requested dispatch token.
*
* Default dispatch function has an empty string as dispatch token.
*/
template <typename TCallable>
TSelf& set_dispatch(TCallable&& f) {
return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
}

private:
DispatchTable dispatch_table_;
};

} // namespace printer
} // namespace script
} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
75 changes: 75 additions & 0 deletions src/script/printer/traced_object_functor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/script/printer/traced_object_functor.h>

namespace tvm {
namespace script {
namespace printer {

const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table,
const String& token, uint32_t type_index) {
auto it = table.find(token);
if (it == table.end()) {
return nullptr;
}
const std::vector<runtime::PackedFunc>& tab = it->second;
if (type_index >= tab.size()) {
return nullptr;
}
const PackedFunc* f = &tab[type_index];
if (f->defined()) {
return f;
} else {
return nullptr;
}
}

const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
const String& token, uint32_t type_index) {
if (const runtime::PackedFunc* pf =
GetDispatchFunctionForToken(dispatch_table, token, type_index)) {
return *pf;
} else if (const runtime::PackedFunc* pf =
GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) {
// Fallback to function with the default dispatch token
return *pf;
} else {
ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
<< runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
throw;
}
}

void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
runtime::PackedFunc f) {
std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
if (table->size() <= type_index) {
table->resize(type_index + 1, nullptr);
}
runtime::PackedFunc& slot = (*table)[type_index];
if (slot != nullptr) {
ICHECK(false) << "Dispatch for type is already registered: "
<< runtime::Object::TypeIndex2Key(type_index);
}
slot = f;
}
} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit f1dabac

Please sign in to comment.