Skip to content

Commit

Permalink
[AOT] Sanitize input/output name in runtime (#13046)
Browse files Browse the repository at this point in the history
This PR adds name sanitization to input/output at AOT runtime module. This means when we use set_input/set_output with AOT, even if we use the original name of those input/outputs which were sanitized and changed in the codegen, it will map them to the correct input/output.

For example if your model has an input with name `input:0`, the AOT codegen would change this to `input_0` but at runtime if we try to set this input it does not exist. Now, we use the same sanitization at runtime
  • Loading branch information
mehrdadh authored Oct 17, 2022
1 parent e8ba1dc commit 3f0d3f2
Show file tree
Hide file tree
Showing 16 changed files with 184 additions and 47 deletions.
43 changes: 43 additions & 0 deletions include/tvm/runtime/name_transforms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/name_transforms.h
* \brief Transformations which are applied on names to generate appropriately named.
* These functions are used in both Runtime and Backend.
*/
#ifndef TVM_RUNTIME_NAME_TRANSFORMS_H_
#define TVM_RUNTIME_NAME_TRANSFORMS_H_

#include <string>

namespace tvm {
namespace runtime {

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_NAME_TRANSFORMS_H_
11 changes: 0 additions & 11 deletions python/tvm/relay/backend/name_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,3 @@ def prefix_generated_name(names: Union[List[str], str]):
"""

return _backend.PrefixGeneratedName(_preprocess_names(names))


def sanitize_name(original_name: str):
"""Sanitize name for output into compiler artifacts
Parameters
----------
original_name : str
Original name to sanitize
"""
return _backend.SanitizeName(original_name)
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import tvm
from tvm import relay
from tvm.ir import IRModule
from tvm.runtime.name_transforms import sanitize_name

from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn as _qnn
from ..backend.name_transforms import sanitize_name
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import lstm_cell, to_int_list, shape_of, try_infer_value
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/runtime/name_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Name transformation functions shared in Backend and Runtime
"""

from . import _ffi_api


def sanitize_name(original_name: str):
"""Sanitize name for output into compiler artifacts
Parameters
----------
original_name : str
Original name to sanitize
"""
return _ffi_api.SanitizeName(original_name)
5 changes: 3 additions & 2 deletions src/relay/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include "./aot_lower_main.h"

#include <tvm/runtime/name_transforms.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -227,7 +228,7 @@ class AOTMainLowerer : public MixedModeVisitor {

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
std::string input_name = SanitizeName(input->name_hint());
std::string input_name = tvm::runtime::SanitizeName(input->name_hint());
// We don't want the compiler changing input names in the
// event of a sanitization collision. Therefore, enforcing
// the var created to use the input_name strictly.
Expand Down Expand Up @@ -518,7 +519,7 @@ class AOTMainLowerer : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
Expand Down Expand Up @@ -534,7 +535,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/name_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "name_transforms.h"

#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/registry.h>

#include <cctype>
Expand Down Expand Up @@ -84,22 +85,11 @@ std::string CombineNames(const Array<String>& names) {
return combined_name;
}

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName);

} // namespace backend
} // namespace relay
Expand Down
7 changes: 0 additions & 7 deletions src/relay/backend/name_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,6 @@ inline std::string PrefixGeneratedName(const Array<String>& names) {
return "TVMGen_" + CombineNames(names);
}

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/name_transforms.h>

#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -507,7 +508,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint));
auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint));
GlobalVar gvar = GlobalVar(fn_name_mangled);
mangled_gvars_[pair.first->name_hint] = gvar;
}
Expand All @@ -526,7 +527,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto new_dict = func->attrs->dict;
new_dict.Set(tvm::attr::kGlobalSymbol,
String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint))));
String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint))));
func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type,
func->type_params, DictAttrs(new_dict));

Expand Down
11 changes: 6 additions & 5 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/name_transforms.h>

#include <limits>
#include <memory>
Expand Down Expand Up @@ -98,7 +99,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
Expand All @@ -107,7 +108,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
Expand All @@ -116,7 +117,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_output_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int out_idx = this->GetOutputIndex(args[0].operator String());
int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
} else {
this->SetOutputZeroCopy(args[0], args[1]);
Expand All @@ -134,7 +135,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
if (String::CanConvertFrom(args[0])) {
in_idx = this->GetInputIndex(args[0].operator String());
in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
} else {
in_idx = args[0];
}
Expand All @@ -153,7 +154,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "get_input_index") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
});
} else {
return PackedFunc();
Expand Down
44 changes: 44 additions & 0 deletions src/runtime/name_transforms.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cctype>
#include <string>

namespace tvm {
namespace runtime {

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

TVM_REGISTER_GLOBAL("runtime.SanitizeName").set_body_typed(SanitizeName);

} // namespace runtime
} // namespace tvm
1 change: 1 addition & 0 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/usmp/utils.h>
Expand Down
9 changes: 5 additions & 4 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/memory_io.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -507,7 +508,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
} else {
codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
}
call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ",";
call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << ",";
}
for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
call_args_ss << "void* output" << i << ",";
Expand Down Expand Up @@ -565,10 +566,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
std::stringstream call_args_ss;
if (metadata_->io_pool_allocations.empty()) {
for (const auto& input : metadata_->inputs) {
call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ",";
call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << ",";
}
for (const auto& output : metadata_->outputs) {
call_args_ss << "outputs->" << relay::backend::SanitizeName(output);
call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output);
call_args_ss << ",";
}
}
Expand All @@ -578,7 +579,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
if (IsInternalWorkspaceBuffer(pool_var)) {
call_args_ss << "&" << pool_name << ",";
} else {
call_args_ss << "workspace_pools->" << relay::backend::SanitizeName(pool_name) << ",";
call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
}
}
for (const String& device : metadata_->devices) {
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/name_transforms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/name_transforms.h>

using namespace tvm::relay::backend;
using namespace tvm::runtime;
Expand Down
Loading

0 comments on commit 3f0d3f2

Please sign in to comment.