Skip to content

Commit

Permalink
move SanitizeName
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Oct 11, 2022
1 parent d023ef4 commit 4bfa4af
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 32 deletions.
15 changes: 15 additions & 0 deletions include/tvm/runtime/name_mangling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

#include <string>

namespace tvm {
namespace runtime {

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace support
} // namespace tvm
5 changes: 3 additions & 2 deletions src/relay/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include "./aot_lower_main.h"

#include <tvm/runtime/name_mangling.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -227,7 +228,7 @@ class AOTMainLowerer : public MixedModeVisitor {

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
std::string input_name = SanitizeName(input->name_hint());
std::string input_name = tvm::runtime::SanitizeName(input->name_hint());
// We don't want the compiler changing input names in the
// event of a sanitization collision. Therefore, enforcing
// the var created to use the input_name strictly.
Expand Down Expand Up @@ -518,7 +519,7 @@ class AOTMainLowerer : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
Expand Down Expand Up @@ -534,7 +535,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return;
}
if (target_attr_map[target_kind.value()]) {
std::string context_name = SanitizeName(device_context_name);
std::string context_name = tvm::runtime::SanitizeName(device_context_name);
tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

auto pair = target_contexts.find(target_kind.value());
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/name_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "name_transforms.h"

#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/registry.h>

#include <cctype>
Expand Down Expand Up @@ -84,22 +85,11 @@ std::string CombineNames(const Array<String>& names) {
return combined_name;
}

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName);

} // namespace backend
} // namespace relay
Expand Down
7 changes: 0 additions & 7 deletions src/relay/backend/name_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,6 @@ inline std::string PrefixGeneratedName(const Array<String>& names) {
return "TVMGen_" + CombineNames(names);
}

/*!
* \brief Sanitize name for output into compiler artifacts
* \param name Original name
* \return Sanitized name
*/
std::string SanitizeName(const std::string& name);

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/name_mangling.h>

#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -507,7 +508,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint));
auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint));
GlobalVar gvar = GlobalVar(fn_name_mangled);
mangled_gvars_[pair.first->name_hint] = gvar;
}
Expand All @@ -526,7 +527,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
auto new_dict = func->attrs->dict;
new_dict.Set(tvm::attr::kGlobalSymbol,
String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint))));
String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint))));
func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type,
func->type_params, DictAttrs(new_dict));

Expand Down
11 changes: 6 additions & 5 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/name_mangling.h>

#include <limits>
#include <memory>
Expand Down Expand Up @@ -98,7 +99,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
Expand All @@ -107,7 +108,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
Expand All @@ -116,7 +117,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "set_output_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
int out_idx = this->GetOutputIndex(args[0].operator String());
int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
} else {
this->SetOutputZeroCopy(args[0], args[1]);
Expand All @@ -134,7 +135,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
if (String::CanConvertFrom(args[0])) {
in_idx = this->GetInputIndex(args[0].operator String());
in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
} else {
in_idx = args[0];
}
Expand All @@ -153,7 +154,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
} else if (name == "get_input_index") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
*rv = this->GetInputIndex(args[0].operator String());
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
});
} else {
return PackedFunc();
Expand Down
41 changes: 41 additions & 0 deletions src/runtime/name_mangling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/logging.h>
#include <tvm/runtime/name_mangling.h>

#include <algorithm>
#include <cctype>
#include <string>

namespace tvm {
namespace runtime {

std::string SanitizeName(const std::string& name) {
ICHECK(!name.empty()) << "Name is empty";

auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitized_input = name;
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');

return sanitized_input;
}

} // namespace runtime
} // namespace tvm
1 change: 1 addition & 0 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/usmp/utils.h>
Expand Down
9 changes: 5 additions & 4 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/memory_io.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/name_mangling.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -507,7 +508,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
} else {
codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
}
call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ",";
call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << ",";
}
for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
call_args_ss << "void* output" << i << ",";
Expand Down Expand Up @@ -565,10 +566,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
std::stringstream call_args_ss;
if (metadata_->io_pool_allocations.empty()) {
for (const auto& input : metadata_->inputs) {
call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ",";
call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << ",";
}
for (const auto& output : metadata_->outputs) {
call_args_ss << "outputs->" << relay::backend::SanitizeName(output);
call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output);
call_args_ss << ",";
}
}
Expand All @@ -578,7 +579,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
if (IsInternalWorkspaceBuffer(pool_var)) {
call_args_ss << "&" << pool_name << ",";
} else {
call_args_ss << "workspace_pools->" << relay::backend::SanitizeName(pool_name) << ",";
call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
}
}
for (const String& device : metadata_->devices) {
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/name_transforms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/name_mangling.h>

using namespace tvm::relay::backend;
using namespace tvm::runtime;
Expand Down

0 comments on commit 4bfa4af

Please sign in to comment.