diff --git a/include/tvm/runtime/name_transforms.h b/include/tvm/runtime/name_transforms.h new file mode 100644 index 000000000000..267dda4158c8 --- /dev/null +++ b/include/tvm/runtime/name_transforms.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/name_transforms.h + * \brief Transformations which are applied on names to generate appropriately named. + * These functions are used in both Runtime and Backend. + */ +#ifndef TVM_RUNTIME_NAME_TRANSFORMS_H_ +#define TVM_RUNTIME_NAME_TRANSFORMS_H_ + +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Sanitize name for output into compiler artifacts + * \param name Original name + * \return Sanitized name + */ +std::string SanitizeName(const std::string& name); + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_NAME_TRANSFORMS_H_ diff --git a/python/tvm/relay/backend/name_transforms.py b/python/tvm/relay/backend/name_transforms.py index 19208725a8b9..bbf51a8e24b1 100644 --- a/python/tvm/relay/backend/name_transforms.py +++ b/python/tvm/relay/backend/name_transforms.py @@ -97,14 +97,3 @@ def prefix_generated_name(names: Union[List[str], str]): """ return _backend.PrefixGeneratedName(_preprocess_names(names)) - - -def sanitize_name(original_name: str): - """Sanitize name for output into compiler artifacts - - Parameters - ---------- - original_name : str - Original name to sanitize - """ - return _backend.SanitizeName(original_name) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a7e10ad72e55..1915eb9322ff 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -23,6 +23,7 @@ import tvm from tvm import relay from tvm.ir import IRModule +from tvm.runtime.name_transforms import sanitize_name from ... import nd as _nd from .. import analysis @@ -30,7 +31,6 @@ from .. import function as _function from .. import op as _op from .. import qnn as _qnn -from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape from .common import lstm_cell, to_int_list, shape_of, try_infer_value diff --git a/python/tvm/runtime/name_transforms.py b/python/tvm/runtime/name_transforms.py new file mode 100644 index 000000000000..402a47f1a114 --- /dev/null +++ b/python/tvm/runtime/name_transforms.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Name transformation functions shared in Backend and Runtime +""" + +from . import _ffi_api + + +def sanitize_name(original_name: str): + """Sanitize name for output into compiler artifacts + + Parameters + ---------- + original_name : str + Original name to sanitize + """ + return _ffi_api.SanitizeName(original_name) diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index ce72595dc10b..51dd4b219313 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -23,6 +23,7 @@ */ #include "./aot_lower_main.h" +#include #include #include @@ -227,7 +228,7 @@ class AOTMainLowerer : public MixedModeVisitor { for (auto input : lowered_main_func->params) { input_vars_.push_back(input); - std::string input_name = SanitizeName(input->name_hint()); + std::string input_name = tvm::runtime::SanitizeName(input->name_hint()); // We don't want the compiler changing input names in the // event of a sanitization collision. Therefore, enforcing // the var created to use the input_name strictly. @@ -518,7 +519,7 @@ class AOTMainLowerer : public MixedModeVisitor { return; } if (target_attr_map[target_kind.value()]) { - std::string context_name = SanitizeName(device_context_name); + std::string context_name = tvm::runtime::SanitizeName(device_context_name); tir::Var device_context_var("device_context_" + context_name, DataType::Handle()); auto pair = target_contexts.find(target_kind.value()); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 6a9cadb6f770..786b3f81a5ae 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -534,7 +535,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { return; } if (target_attr_map[target_kind.value()]) { - std::string context_name = SanitizeName(device_context_name); + std::string context_name = tvm::runtime::SanitizeName(device_context_name); tir::Var device_context_var("device_context_" + context_name, DataType::Handle()); auto pair = target_contexts.find(target_kind.value()); diff --git a/src/relay/backend/name_transforms.cc b/src/relay/backend/name_transforms.cc index a2f24216ec24..4f364b811bcc 100644 --- a/src/relay/backend/name_transforms.cc +++ b/src/relay/backend/name_transforms.cc @@ -19,6 +19,7 @@ #include "name_transforms.h" +#include #include #include @@ -84,22 +85,11 @@ std::string CombineNames(const Array& names) { return combined_name; } -std::string SanitizeName(const std::string& name) { - ICHECK(!name.empty()) << "Name is empty"; - - auto isNotAlnum = [](char c) { return !std::isalnum(c); }; - std::string sanitized_input = name; - std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_'); - - return sanitized_input; -} - TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle); TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle); TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle); TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName); TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName); -TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName); } // namespace backend } // namespace relay diff --git a/src/relay/backend/name_transforms.h b/src/relay/backend/name_transforms.h index a30ba6b10825..f59280af2222 100644 --- a/src/relay/backend/name_transforms.h +++ b/src/relay/backend/name_transforms.h @@ -102,13 +102,6 @@ inline std::string PrefixGeneratedName(const Array& names) { return "TVMGen_" + CombineNames(names); } -/*! - * \brief Sanitize name for output into compiler artifacts - * \param name Original name - * \return Sanitized name - */ -std::string SanitizeName(const std::string& name); - } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index e2df2e4272ad..f6cdf6d1ca18 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -507,7 +508,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); if (func->GetAttr(attr::kCompiler).defined()) { - auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)); + auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint)); GlobalVar gvar = GlobalVar(fn_name_mangled); mangled_gvars_[pair.first->name_hint] = gvar; } @@ -526,7 +527,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (func->GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, - String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); + String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint)))); func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, func->type_params, DictAttrs(new_dict)); diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc index 985c857ed55f..7f7daabf3fc2 100644 --- a/src/runtime/aot_executor/aot_executor.cc +++ b/src/runtime/aot_executor/aot_executor.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -98,7 +99,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (String::CanConvertFrom(args[0])) { - int in_idx = this->GetInputIndex(args[0].operator String()); + int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); if (in_idx >= 0) this->SetInput(in_idx, args[1]); } else { this->SetInput(args[0], args[1]); @@ -107,7 +108,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, } else if (name == "set_input_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (String::CanConvertFrom(args[0])) { - int in_idx = this->GetInputIndex(args[0].operator String()); + int in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); } else { this->SetInputZeroCopy(args[0], args[1]); @@ -116,7 +117,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, } else if (name == "set_output_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (String::CanConvertFrom(args[0])) { - int out_idx = this->GetOutputIndex(args[0].operator String()); + int out_idx = this->GetOutputIndex(tvm::runtime::SanitizeName(args[0].operator String())); if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]); } else { this->SetOutputZeroCopy(args[0], args[1]); @@ -134,7 +135,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int in_idx = 0; if (String::CanConvertFrom(args[0])) { - in_idx = this->GetInputIndex(args[0].operator String()); + in_idx = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); } else { in_idx = args[0]; } @@ -153,7 +154,7 @@ PackedFunc AotExecutor::GetFunction(const std::string& name, } else if (name == "get_input_index") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string"; - *rv = this->GetInputIndex(args[0].operator String()); + *rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String())); }); } else { return PackedFunc(); diff --git a/src/runtime/name_transforms.cc b/src/runtime/name_transforms.cc new file mode 100644 index 000000000000..608b88ac430e --- /dev/null +++ b/src/runtime/name_transforms.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { + +std::string SanitizeName(const std::string& name) { + ICHECK(!name.empty()) << "Name is empty"; + + auto isNotAlnum = [](char c) { return !std::isalnum(c); }; + std::string sanitized_input = name; + std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_'); + + return sanitized_input; +} + +TVM_REGISTER_GLOBAL("runtime.SanitizeName").set_body_typed(SanitizeName); + +} // namespace runtime +} // namespace tvm diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index fa38d9b9f4d1..ed7058f1f198 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 88a7a99b4c25..ce5f5d5b5357 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -507,7 +508,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } else { codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); } - call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ","; + call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << ","; } for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { call_args_ss << "void* output" << i << ","; @@ -565,10 +566,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::stringstream call_args_ss; if (metadata_->io_pool_allocations.empty()) { for (const auto& input : metadata_->inputs) { - call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ","; + call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << ","; } for (const auto& output : metadata_->outputs) { - call_args_ss << "outputs->" << relay::backend::SanitizeName(output); + call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output); call_args_ss << ","; } } @@ -578,7 +579,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { if (IsInternalWorkspaceBuffer(pool_var)) { call_args_ss << "&" << pool_name << ","; } else { - call_args_ss << "workspace_pools->" << relay::backend::SanitizeName(pool_name) << ","; + call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ","; } } for (const String& device : metadata_->devices) { diff --git a/tests/cpp/name_transforms_test.cc b/tests/cpp/name_transforms_test.cc index 09a5bbfb583a..12a2ce1d0761 100644 --- a/tests/cpp/name_transforms_test.cc +++ b/tests/cpp/name_transforms_test.cc @@ -21,6 +21,7 @@ #include #include +#include using namespace tvm::relay::backend; using namespace tvm::runtime; diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py index b67bc90d34fd..89c34eaac8b6 100644 --- a/tests/python/relay/aot/test_cpp_aot.py +++ b/tests/python/relay/aot/test_cpp_aot.py @@ -118,7 +118,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), @pytest.mark.parametrize("enable_usmp", [True, False]) @pytest.mark.parametrize("target_kind", ["c", "llvm"]) -def test_mobilenet(enable_usmp, target_kind): +def test_mobilenet(enable_usmp: bool, target_kind: str): """Full network test with Mobilenet""" ir_mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape] @@ -203,5 +203,44 @@ def test_pass_wrong_device_arg(): # TODO write asserts for # and type of device. +@pytest.mark.parametrize("target_kind", ["c", "llvm"]) +@pytest.mark.parametrize("input_name", ["input:0", "input@0", "input_0"]) +def test_aot_input_name_with_special_character(target_kind: str, input_name: str): + """Test name transforms in AOT for input names with special characters.""" + dtype = "float32" + input_1 = relay.var(input_name, shape=(10, 5), dtype=dtype) + weight = relay.var("weight", shape=(1, 5), dtype=dtype) + output = relay.add(input_1, weight) + func = relay.Function([input_1, weight], output) + + input_data = np.random.rand(10, 5).astype(dtype) + weight_data = np.random.rand(1, 5).astype(dtype) + expected_output = input_data + weight_data + params = {"weight": weight_data} + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = tvm.relay.build( + tvm.IRModule.from_expr(func), + target=target_kind, + params=params, + executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}), + ) + temp_dir = tvm.contrib.utils.TempDirectory() + test_so_path = temp_dir / "test.so" + mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", "-O0"]) + # test both original name and transformed name + for name in ["input_0", input_name]: + loaded_mod = tvm.runtime.load_module(test_so_path) + runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0))) + inputs = {name: input_data} + runner.set_input(**inputs) + + input_ind = runner.get_input_index(name) + assert (runner.get_input(input_ind).asnumpy() == input_data).all() + + runner.run() + assert (runner.get_output(0).asnumpy() == expected_output).all() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/test_name_transforms.py b/tests/python/relay/test_name_transforms.py index 1c3435a6cc85..72976dc19c21 100644 --- a/tests/python/relay/test_name_transforms.py +++ b/tests/python/relay/test_name_transforms.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest from tvm import TVMError from tvm.relay.backend.name_transforms import ( @@ -22,9 +23,8 @@ to_c_constant_style, prefix_name, prefix_generated_name, - sanitize_name, ) -import pytest +from tvm.runtime.name_transforms import sanitize_name def test_to_c_function_style():