Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Implement relax.Function.bind_params #15626

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);

/*!
* \brief Bind symbolic vars to constant shape values.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>

namespace tvm {
Expand All @@ -48,6 +50,26 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});

/*!
* \brief Infer a binding map for symbolic variables
*
* If a set of relax variables are replaced within an expression, this
* may result in removal of the definition site of a symbolic
* variable. This utility function determines the symbolic variable
* replacements that can be inferred based on the replaced relax
* variables, and can be used alongside the `Bind` utility function to
* replace both the relax variables and the implied symbolic
* variables.
*
* \param binds A map of relax variables to relax expressions
*
* \param analyzer The analyzer to use for simplifications
*
* \return A map of TIR variables to TIR expressions
*/
TVM_DLL tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& binds, arith::Analyzer* analyzer);

/*!
* \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
* dtype).
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,56 @@ def bind_symbolic_vars(

return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type: ignore

def bind_params(
self,
binding_map: Mapping[
Union[str, Var],
Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
],
) -> "Function":
"""Return a new function with updated symbolic variable

Parameters
----------
binding_map: Mapping[
Union[str, Var],
Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
]

The mapping of values to be replaced.

Keys may be either a `relax.Var` or a string name of the
Relax variable. If the variables are referred to by name,
the name must uniquely identify a parameter in the
function.

Values must be a relax expression, or a value that is
convertible into a relax expression. The value must be
compatible with the variable being replaced.

Returns
-------
func: Function

The updated function
"""

def _normalize_value(value):
# Conversions that must occur prior to the FFI
# conversions.
if isinstance(value, int):
# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
return tvm.tir.const(value, "int64")
elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
return tvm.relax.const(value)
else:
return value

binding_map = {key: _normalize_value(value) for key, value in binding_map.items()}

return _ffi_api.FunctionBindParams(self, binding_map) # type: ignore


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:

def BindParams(
func_name: str,
params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]],
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.

Expand All @@ -397,8 +397,13 @@ def BindParams(
func_name: str
The function name to be bound

params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
The map from param name to constant tensors.
params : Dict[
Union[str,relax.Var],
Union[tvm.runtime.NDArray, np.ndarray],
]

The map from parameter or parameter name name to constant
tensors.

Returns
-------
Expand Down
116 changes: 81 additions & 35 deletions src/relax/transform/bind_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include <tuple>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -81,45 +82,88 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
}
}

std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
ICHECK(func.defined());
ICHECK(untyped_params.defined());

// Map from string to the variable(s) with that name.
std::unordered_map<std::string, Array<relax::Var>> string_lookup;
std::unordered_set<const relax::VarNode*> var_set;
for (const auto& param : func->params) {
string_lookup[param->name_hint()].push_back(param);
var_set.insert(param.get());
}

Map<relax::Var, relax::Expr> relax_var_remap;

auto normalize_key = [&](ObjectRef obj) -> relax::Var {
if (auto opt_str = obj.as<String>()) {
std::string str = opt_str.value();
auto it = string_lookup.find(str);
CHECK(it != string_lookup.end())
<< "Function does not have parameter with name \"" << str << "\". "
<< "Function parameters are named "
<< func->params.Map([](const auto& param) { return param->name_hint(); });
CHECK_EQ(it->second.size(), 1)
<< "Function contains multiple parameters with name \"" << str << "\". "
<< "The Relax variables " << it->second << " are all named \"" << str << "\"";
auto var = it->second[0];
CHECK(!relax_var_remap.count(var))
<< "Remap of variable " << var << " was defined multiple times";

return var;
} else if (auto opt_var = obj.as<relax::Var>()) {
auto var = opt_var.value();
CHECK(!relax_var_remap.count(var))
<< "Remap of variable " << var << " was defined multiple times";
CHECK(var_set.count(var.get()))
<< "Function does not use Relax variable " << var << " as a parameter. "
<< "Function parameters are " << func->params;
return var;
} else {
LOG(FATAL)
<< "Expected bound parameter to be a relax::Var, "
<< " or a string that uniquely identifies a relax::Var param within the function. "
<< "However, received object " << obj << " of type " << obj->GetTypeKey();
}
};
auto normalize_value = [&](ObjectRef obj) -> relax::Expr {
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
return Constant(opt.value());
} else {
LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
<< " into relax expression";
}
};

for (const auto& [key, value] : untyped_params) {
relax_var_remap.Set(normalize_key(key), normalize_value(value));
}

arith::Analyzer analyzer;
Map<tir::Var, PrimExpr> symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer);

// for (const auto& [bind_param, bind_expr] : relax_var_remap) {
// MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer);
// }

return {relax_var_remap, symbolic_var_map};
}

/*!
* \brief Bind params to function by using name
* \param func Relax function
* \param params params dict
* \return Function
*/
inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
std::unordered_map<std::string, Var> name_dict;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}
Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>& untyped_params) {
auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);

arith::Analyzer analyzer;
Map<Var, Expr> bind_dict;
Map<tir::Var, PrimExpr> symbolic_var_map;

for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
const Var& arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
}
Expr const_expr = Constant(kv.second);
bind_dict.Set(arg, const_expr);
MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
}
Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
<< "\n";
return ret;
return Downcast<Function>(bound_expr);
}

/*!
Expand All @@ -129,7 +173,7 @@ inline Function BindParamsByName(Function func, const Map<String, runtime::NDArr
* \param param The param dict
* \return The module after binding params.
*/
IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef> bind_params) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
Expand All @@ -138,13 +182,13 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
// Use global_symbol if it's external linkage
Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
} else {
// Use global var's name_hint if it's internal linkage
if (func_pr.first->name_hint == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
}
Expand All @@ -153,9 +197,11 @@ IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> p
return GetRef<IRModule>(new_module);
}

TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams);

namespace transform {

Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
return CreateModulePass(pass_func, 0, "BindParams", {});
Expand Down
56 changes: 56 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,62 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}

tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
tvm::Map<tir::Var, PrimExpr> tir_var_remap;

auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape,
const PrimExpr& expr_shape) {
if (auto var = var_shape.as<tir::Var>()) {
tir_var_remap.Set(var.value(), expr_shape);
}
};

auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
auto var_shape = var.as<ShapeStructInfoNode>();
if (!var_shape) return;
if (!var_shape->values.defined()) return;

auto expr_shape = expr.as<ShapeStructInfoNode>();
CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_shape->values.defined()) return;

auto var_shape_arr = var_shape->values.value();
auto expr_shape_arr = expr_shape->values.value();
CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
<< "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
<< " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
for (size_t i = 0; i < var_shape_arr.size(); i++) {
bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
}
};

auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) {
auto var_tensor = var.as<TensorStructInfoNode>();
if (!var_tensor) return;
if (!var_tensor->shape.defined()) return;

auto expr_tensor = expr.as<TensorStructInfoNode>();
CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
if (!expr_tensor->shape.defined()) return;

bind_from_shape(GetStructInfo(var_tensor->shape.value()),
GetStructInfo(expr_tensor->shape.value()));
};

for (const auto& [relax_var, relax_expr] : relax_var_remap) {
auto var_sinfo = GetStructInfo(relax_var);
auto expr_sinfo = GetStructInfo(relax_expr);

bind_from_tensor(var_sinfo, expr_sinfo);
bind_from_shape(var_sinfo, expr_sinfo);
}

return tir_var_remap;
}

bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
Expand Down
Loading
Loading