Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Sanitize input/output name in runtime #13046

Merged
merged 7 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions include/tvm/runtime/name_mangling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <string>

namespace tvm {
namespace runtime {

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace runtime
} // namespace tvm
5 changes: 3 additions & 2 deletions src/relay/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include "./aot_lower_main.h"

#include <tvm/runtime/name_mangling.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -227,7 +228,7 @@ class AOTMainLowerer : public MixedModeVisitor {

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
std::string input_name = SanitizeName(input->name_hint());
std::string input_name = tvm::runtime::SanitizeName(input->name_hint());
// We don't want the compiler changing input names in the
// event of a sanitization collision. Therefore, enforcing
// the var created to use the input_name strictly.
Expand Down Expand Up @@ -518,7 +519,7 @@ class AOTMainLowerer : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
Expand Down Expand Up @@ -534,7 +535,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/name_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "name_transforms.h"

#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/registry.h>

#include <cctype>
Expand Down Expand Up @@ -84,22 +85,11 @@ std::string CombineNames(const Array<String>& names) {
return combined_name;
}

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName);

} // namespace backend
} // namespace relay
Expand Down
7 changes: 0 additions & 7 deletions src/relay/backend/name_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,6 @@ inline std::string PrefixGeneratedName(const Array<String>& names) {
return "TVMGen_" + CombineNames(names);
}

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/name_mangling.h>

#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -507,7 +508,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint));
auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint));
GlobalVar gvar = GlobalVar(fn_name_mangled);
mangled_gvars_[pair.first->name_hint] = gvar;
}
Expand All @@ -526,7 +527,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto new_dict = func->attrs->dict;
new_dict.Set(tvm::attr::kGlobalSymbol,
String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint))));
String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint))));
func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type,
func->type_params, DictAttrs(new_dict));

Expand Down
11 changes: 6 additions & 5 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/name_mangling.h>

#include <limits>
#include <memory>
Expand Down Expand Up @@ -98,7 +99,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
Expand All @@ -107,7 +108,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it also make sense to add a python test that passes in names like input:0 to test that this functionality doesn't regress for the various AoT commands?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it makes sense. I will add a test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test which exams both mangled name and the original name

if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
Expand All @@ -116,7 +117,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_output_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int out_idx = this->GetOutputIndex(args[0].operator String());
int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
} else {
this->SetOutputZeroCopy(args[0], args[1]);
Expand All @@ -134,7 +135,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
if (String::CanConvertFrom(args[0])) {
in_idx = this->GetInputIndex(args[0].operator String());
in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
} else {
in_idx = args[0];
}
Expand All @@ -153,7 +154,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "get_input_index") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
});
} else {
return PackedFunc();
Expand Down
41 changes: 41 additions & 0 deletions src/runtime/name_mangling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/name_mangling.h>

#include <algorithm>
#include <cctype>
#include <string>

namespace tvm {
namespace runtime {

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

} // namespace runtime
} // namespace tvm
1 change: 1 addition & 0 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/usmp/utils.h>
Expand Down
9 changes: 5 additions & 4 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/memory_io.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -507,7 +508,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
} else {
codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
}
call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ",";
call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << ",";
}
for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
call_args_ss << "void* output" << i << ",";
Expand Down Expand Up @@ -565,10 +566,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
std::stringstream call_args_ss;
if (metadata_->io_pool_allocations.empty()) {
for (const auto& input : metadata_->inputs) {
call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ",";
call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << ",";
}
for (const auto& output : metadata_->outputs) {
call_args_ss << "outputs->" << relay::backend::SanitizeName(output);
call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output);
call_args_ss << ",";
}
}
Expand All @@ -578,7 +579,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
if (IsInternalWorkspaceBuffer(pool_var)) {
call_args_ss << "&" << pool_name << ",";
} else {
call_args_ss << "workspace_pools->" << relay::backend::SanitizeName(pool_name) << ",";
call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
}
}
for (const String& device : metadata_->devices) {
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/name_transforms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/name_mangling.h>

using namespace tvm::relay::backend;
using namespace tvm::runtime;
Expand Down