Skip to content

Commit

Permalink
Add traced_object_functor
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <jshao@octoml.ai>
Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
3 people committed Aug 5, 2022
1 parent 4231ebb commit a472810
Show file tree
Hide file tree
Showing 4 changed files with 421 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/tvm/script/printer/traced_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class TracedObject {
using ObjectType = typename RefT::ContainerType;

public:
using ObjectRefType = RefT;

// Don't use this direcly. For convenience, call MakeTraced() instead.
explicit TracedObject(const RefT& object_ref, ObjectPath path)
: ref_(object_ref), path_(std::move(path)) {}
Expand Down
168 changes: 168 additions & 0 deletions include/tvm/script/printer/traced_object_functor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_

#include <tvm/node/node.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/script/printer/traced_object.h>

#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace script {
namespace printer {

namespace {

namespace detail {
/*!
* \brief Helper template class to extract the type of first argument of a function
* \tparam FType The function type.
*/
template <typename FType>
struct first_arg_type_helper;

template <typename R, typename ArgOne, typename... OtherArgs>
struct first_arg_type_helper<R(ArgOne, OtherArgs...)> {
using T = ArgOne;
};

/*!
* \brief Template alias for the type of first argument of a function
* \tparam FType The function type.
*
* The name of public functions are in snake case to be consistent with
* tvm/node/functor.h
*/
template <typename FType>
using first_arg_type = typename detail::first_arg_type_helper<
typename tvm::runtime::detail::function_signature<FType>::FType>::T;
} // namespace detail

} // namespace

namespace dispatch_table {
/*
* Functions in dispatch_table namespace is created to reduce the binary bloat
* from template and also hide implementation details from this header
*/

using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;

constexpr const char* kDefaultDispatchToken = "";

const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
const String& token, uint32_t type_index);
void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
runtime::PackedFunc f);
} // namespace dispatch_table

/*!
* \brief Dynamic dispatch functor based on TracedObject.
*
* This functor dispatches based on the type of object ref inside the input TracedObject,
* and the input dispatch token.
*/
template <typename R, typename... Args>
class TracedObjectFunctor {
private:
using TSelf = TracedObjectFunctor<R, Args...>;

template <class TObjectRef, class TCallable>
using IsDispatchFunction =
typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;

public:
/*!
* \brief Call the dispatch function.
* \param token The dispatch token.
* \param traced_object The traced object.
* \param args Other args.
*
* \return The return value of the dispatch function
*
* If the TObjectRef isn't registered with the token, it will try to find
* dispatch function for TObjectRef with kDefaultDispatchToken.
*/
template <class TObjectRef>
R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
const runtime::PackedFunc& dispatch_function = dispatch_table::GetDispatchFunction(
dispatch_table_, token, traced_object.Get()->type_index());
return dispatch_function(traced_object.Get(), traced_object.GetPath(),
std::forward<Args>(args)...);
}

/*!
* \brief Set the dispatch function
* \param f The dispatch function.
*
* This takes a type-erased packed function as input. It should be used
* through FFI boundary, for example, registering dispatch function from Python.
*/
TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
dispatch_table::SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
return *this;
}

/*!
* \brief Set the dispatch function
* \param token The dispatch token.
* \param f The dispatch function.
*
* The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
*/
template <typename TCallable,
typename TObjectRef = typename detail::first_arg_type<TCallable>::ObjectRefType,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(String token, TCallable f) {
return set_dispatch(token, //
TObjectRef::ContainerType::RuntimeTypeIndex(), //
runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
[f](TObjectRef object, ObjectPath path, Args... args) -> R {
return f(MakeTraced(object, path), std::forward<Args>(args)...);
}));
}
/*!
* \brief Set the default dispatch function
* \param f The dispatch function.
*
* Default dispatch function will be used if there is no function registered
* with the requested dispatch token.
*
* Default dispatch function has an empty string as dispatch token.
*/
template <typename TCallable>
TSelf& set_dispatch(TCallable f) {
return set_dispatch(dispatch_table::kDefaultDispatchToken, std::forward<TCallable>(f));
}

private:
dispatch_table::DispatchTable dispatch_table_;
};

} // namespace printer
} // namespace script
} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
77 changes: 77 additions & 0 deletions src/script/printer/traced_object_functor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/script/printer/traced_object_functor.h>

namespace tvm {
namespace script {
namespace printer {
namespace dispatch_table {

const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table,
const String& token, uint32_t type_index) {
auto it = table.find(token);
if (it == table.end()) {
return nullptr;
}
const std::vector<runtime::PackedFunc>& tab = it->second;
if (type_index >= tab.size()) {
return nullptr;
}
const PackedFunc* f = &tab[type_index];
if (f->defined()) {
return f;
} else {
return nullptr;
}
}

const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
const String& token, uint32_t type_index) {
if (const runtime::PackedFunc* pf =
GetDispatchFunctionForToken(dispatch_table, token, type_index)) {
return *pf;
} else if (const runtime::PackedFunc* pf =
GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) {
// Fallback to function with the default dispatch token
return *pf;
} else {
ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
<< runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
throw;
}
}

void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
runtime::PackedFunc f) {
std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
if (table->size() <= type_index) {
table->resize(type_index + 1, nullptr);
}
runtime::PackedFunc& slot = (*table)[type_index];
if (slot != nullptr) {
ICHECK(false) << "Dispatch for type is already registered: "
<< runtime::Object::TypeIndex2Key(type_index);
}
slot = f;
}
} // namespace dispatch_table
} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit a472810

Please sign in to comment.