+ * the analog for Relay Functions.
+ * See also kConstants above, which is the analog for PrimFuncs.
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index d4a82cd8d427..bc11d43cb0ca 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -198,8 +198,9 @@ class VMExecutor(Executor): device : :py:class:`~tvm.runtime.Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ def __init__(self, mod, device, target): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 1353d8c5f595..32ad6c70794c 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -570,8 +570,9 @@ class GraphExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ def __init__(self, mod, device, target): @@ -630,8 +631,9 @@ class AotExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. 
""" def __init__(self, mod, device, target): @@ -639,7 +641,6 @@ def __init__(self, mod, device, target): self.mod = mod self.device = device self.target = target - assert target.attrs.get("executor", "graph") == "aot" def _make_executor(self, expr=None): if expr: @@ -719,8 +720,11 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N device : :py:class:`Device` The device to execute the code. - target : :py:class:`tvm.Target` - The corresponding context + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. + CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so + heterogenous compilation is not yet supported. params : dict of str to NDArray Input parameters to the graph that do not change @@ -730,24 +734,31 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` """ + raw_targets = Target.canon_multi_target(target) if mod is None: mod = IRModule() if device is not None: - assert device.device_type == _nd.device(str(target), 0).device_type + assert device.device_type == raw_targets[0].kind.device_type else: - device = _nd.device(str(target), 0) + # Derive the default device from the first target. 
+ device = _nd.device(raw_targets[0].kind.device_type, 0) if params is not None: mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) - if isinstance(target, str): - target = Target(target) + assert "executor" not in raw_targets[0].attrs or raw_targets[0].attrs["executor"] == kind + if kind == "debug": - return _interpreter.Interpreter(mod, device, target) + assert len(raw_targets) == 1, "The interpreter currently only supports a single target" + return _interpreter.Interpreter(mod, device, raw_targets[0]) if kind == "graph": - return GraphExecutor(mod, device, target) + return GraphExecutor(mod, device, raw_targets) if kind == "vm": - return VMExecutor(mod, device, target) + return VMExecutor(mod, device, raw_targets) if kind == "aot": - return AotExecutor(mod, device, target) + # The AOT requires the executor as a target attribute. + # (The compilation paths for the other executors currently do not always provide this + # attribute, hence the above generic assert is more forgiving). + assert "executor" in raw_targets[0].attrs + return AotExecutor(mod, device, raw_targets) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c931289d40c6..d7979a757171 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. Returns ------- @@ -1412,7 +1412,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. 
Returns ------- diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 301bfa73c818..063439e068a4 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -358,7 +358,7 @@ class AllocateConst(Stmt): data_or_idx : Union[NDArray, int] If an NDArray, this is the const data associated with the constant. If an integer, this is the index into the - "Constants" attribute of the `IRModule` that contains the + "constants" attribute of the `IRModule` that contains the `AllocateConst`. body : Stmt diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 5020e79714b2..ae60970b78af 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1167,11 +1167,19 @@ class AOTExecutorCodegen : public MixedModeVisitor { // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. LoweredOutput ret; - ret.params = std::unordered_map>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + for (const auto& kv : const_name_to_constant) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + // Collect any constants extracted during lowering. 
+ for (const auto& kv : params_) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); } // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main @@ -1212,9 +1220,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { lowered_mod = pack_calls(lowered_mod); } - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any runtime modules generated by external codegen. + ret.external_mods = + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); // This is the point where we separate the functions in the module by target VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); @@ -1227,8 +1235,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { << PrettyPrint(kv.second); } - ret.external_mods = external_modules.value(); - // Extract USMP metadata to pass onto metadata sources Map pool_var_info; std::vector pool_vars; @@ -1316,11 +1322,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; *rv = get_param_by_name(key); }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - *rv = get_param_id(key); - }); } else if (name == "get_irmodule") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); @@ -1362,17 +1363,11 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { runtime::NDArray get_param_by_name(String key) { auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second.second; + return (*it).second; } Array get_external_modules() { return output_.external_mods; } - int get_param_id(String key) { - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return 
(*it).second.first; - } - Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 628dee0844ec..9a68b567305d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -86,17 +86,6 @@ struct ExecutorCodegen { return ret; } - std::unordered_map GetParamIds() { - std::unordered_map ret; - auto names = CallFunc>("list_params_name", nullptr); - for (const auto& expr : names) { - // Implicit cast from runtime::String to std::string - std::string key = expr; - ret[key] = CallFunc("get_param_id", key); - } - return ret; - } - Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); } @@ -478,6 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode { for (size_t i = 0; i < variables.size(); i++) { auto it = ret_.params.find(variables[i].operator std::string()); if (it != ret_.params.end()) { + VLOG(1) << "constant '" << variables[i] << "' has been captured in external module"; ret_.params.erase(it); } } diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 842ede3bf20b..81a5b5bbd9d8 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) { ACLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
+ const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 72c32fb5b19e..3791773ad67d 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) { BNNSJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index fd1c39bb9283..ee8724fe92fe 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -16,17 +16,17 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include + #include #include #include #include #include -#include #include #include +#include "../../../transforms/compiler_function_utils.h" #include "../../utils.h" #include "codegen_c.h" @@ -34,30 +34,62 @@ namespace tvm { namespace relay { namespace contrib { -using namespace backend; +/*! \brief Return the "ccompiler" Target instance to use to guide compilation. */ +Target GetCCompilerTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "ccompiler") { + // Use the default compilation options if no specific "ccompiler" target was given + // in the overall targets list. In that case target_hooks.cc will invoke the custom pass + // without pushing any target instance onto the implicit target stack. + target = Target("ccompiler"); + } + return target; +} /*! - * \brief An example codegen that is only used for quick prototyping and testing - * purpose. Only several binary options are covered. Users - * may need to extend them to cover more operators. + * \brief Emits C/C++ code for a single function. + * + * For testing and demonstration only, only a few binary operators are supported. */ -class CodegenC : public MemoizedExprTranslator>, public CodegenCBase { +class CodegenC : public backend::MemoizedExprTranslator>, public CodegenCBase { public: - explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } + CodegenC(std::unordered_map* const_name_to_constant, + Array* const_names, bool* needs_extra_headers, std::string ext_func_id) + : const_name_to_constant_(const_name_to_constant), + const_names_(const_names), + needs_extra_headers_(needs_extra_headers), + ext_func_id_(std::move(ext_func_id)) {} - std::vector VisitExprDefault_(const Object* op) final { + /*! + * \brief Emit the source code that invokes C compiler compatible wrappers. + * + * \return The emitted code. 
+ */ + std::string JIT(const std::vector& out) override { + if (!ext_func_args_.empty()) { + *needs_extra_headers_ = true; + } + // Write function macros + for (auto decl : func_decl_) { + code_stream_ << decl << "\n"; + } + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); + } + + private: + std::vector VisitExprDefault_(const Object* op) override { LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); return {}; } - std::vector VisitExpr_(const VarNode* node) final { + std::vector VisitExpr_(const VarNode* node) override { ext_func_args_.push_back(GetRef(node)); Output output; output.name = node->name_hint(); return {output}; } - std::vector VisitExpr_(const TupleNode* node) final { + std::vector VisitExpr_(const TupleNode* node) override { std::vector outs; for (auto field : node->fields) { auto res = VisitExpr(field); @@ -67,7 +99,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code return outs; } - std::vector VisitExpr_(const TupleGetItemNode* op) final { + std::vector VisitExpr_(const TupleGetItemNode* op) override { auto res = VisitExpr(op->tuple); ICHECK_GT(res.size(), static_cast(op->index)); @@ -76,19 +108,21 @@ class CodegenC : public MemoizedExprTranslator>, public Code return {res[op->index]}; } - std::vector VisitExpr_(const ConstantNode* cn) final { + std::vector VisitExpr_(const ConstantNode* cn) override { std::ostringstream decl_stream; std::ostringstream buf_stream; Output output; // Get const: static_cast(gcc_0_consts[0]->data) - output.name = CreateDataReference(ext_func_id_, const_idx_); + size_t const_id = const_name_to_constant_->size(); + output.name = CreateDataReference(ext_func_id_, const_id); const auto* type_node = cn->checked_type().as(); ICHECK(type_node); const auto& dtype = GetDtypeString(type_node); // Generate the global variable for needed ndarrays if (const_array_name_.empty()) { + *needs_extra_headers_ = true; const_array_name_ = 
CreateNDArrayPool(ext_func_id_); std::string checker = CreateInitChecker(ext_func_id_); ext_func_body_.insert(ext_func_body_.begin(), checker); @@ -97,14 +131,14 @@ class CodegenC : public MemoizedExprTranslator>, public Code ICHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; output.dtype = dtype; - std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); - const_vars_.push_back(const_var_name); - const_idx_++; + std::string const_var_name = CreateConstVar(ext_func_id_, const_id); + const_name_to_constant_->emplace(const_var_name, cn->data); + const_names_->push_back(const_var_name); return {output}; } - std::vector VisitExpr_(const CallNode* call) final { + std::vector VisitExpr_(const CallNode* call) override { std::ostringstream macro_stream; std::ostringstream decl_stream; std::ostringstream buf_stream; @@ -114,17 +148,17 @@ class CodegenC : public MemoizedExprTranslator>, public Code // Make function declaration macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", "; - if (IsOp(call, "add")) { + if (backend::IsOp(call, "add")) { macro_stream << "+"; - } else if (IsOp(call, "subtract")) { + } else if (backend::IsOp(call, "subtract")) { macro_stream << "-"; - } else if (IsOp(call, "multiply")) { + } else if (backend::IsOp(call, "multiply")) { macro_stream << "*"; } else { LOG(FATAL) << "Unrecognized op"; } - auto in_shape = GetShape(call->args[0]->checked_type()); + auto in_shape = backend::GetShape(call->args[0]->checked_type()); for (size_t i = 0; i < in_shape.size(); ++i) { macro_stream << ", " << in_shape[i]; } @@ -152,7 +186,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code } std::string out = "buf_" + std::to_string(buf_idx_++); - auto out_shape = GetShape(call->checked_type()); + auto out_shape = backend::GetShape(call->checked_type()); int out_size = 1; for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; @@ -175,27 +209,21 @@ class 
+  /*! \brief The index of the next available allocated buffer. */
*/ +class CodegenCModule { public: - std::tuple, String, String> GenCFunc(const Function& func) { - ICHECK(func.defined()) << "Input error: expect a Relay function."; - CodegenC builder(GetExtSymbol(func)); - auto out = builder.VisitExpr(func->body); - return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out)); - } + CodegenCModule(Target target, IRModule mod) : target_(std::move(target)), mod_(std::move(mod)) {} - runtime::Module CreateCSourceModule(const ObjectRef& ref) override { - ICHECK(ref->IsInstance()); - auto res = GenCFunc(Downcast(ref)); - Array variables = std::get<0>(res); - String func_name = std::get<1>(res); - - Optional opt_target = Target::Current(); - if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") { - Optional header = opt_target.value()->GetAttr("header"); - if (header.defined() && !header.value().empty()) { - code_stream_ << header.value().c_str() << "\n"; + runtime::Module CreateCSourceModule() { + for (const auto& kv : mod_->functions) { + if (const auto* function_node = GetCCompilerFunctionNode(kv.second)) { + GenCFunc(GetRef(function_node)); } } + return Finalize(); + } + + /*! \brief Returns the accumulated constant name to constant mapping. */ + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; + } + + private: + /*! \brief Emits the standard C/C++ header into \p os. */ + void EmitPreamble(std::ostringstream& os) { + // Custom header, if any. + Optional header = target_->GetAttr("header"); + if (header.defined() && !header.value().empty()) { + os << header.value().c_str() << "\n"; + } + + // Standard includes. 
+ os << "#include \n"; + os << "#include \n"; + os << "#include \n"; + os << "#include \n"; + os << "#include \n"; - // Create headers - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - if (!variables.empty()) { + if (needs_extra_headers_) { // This segment would be generated in C++ because of the usage // of tvm::runtime::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external // codegen. Moreover, in microTVM we dont expect this part to be generated. - code_stream_ << "#ifdef __cplusplus\n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#endif\n"; + os << "#ifdef __cplusplus\n"; + os << "#include \n"; + os << "#include \n"; + os << "#endif\n"; } - // Append some common macro for operator definition. + // Define some macros to help operator implementations. const char* operator_macro = R"op_macro( #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \ void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ @@ -272,38 +302,97 @@ class CSourceCodegen : public CSourceModuleCodegenBase { } )op_macro"; - code_stream_ << operator_macro << "\n\n"; - code_stream_ << std::get<2>(res); - std::string code = code_stream_.str(); + os << operator_macro << "\n\n"; + } + + void GenCFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + std::string ext_func_id = backend::GetExtSymbol(function); + CodegenC builder(&const_name_to_constant_, &const_names_, &needs_extra_headers_, ext_func_id); + std::vector out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + func_names_.push_back(ext_func_id); + } + + /*! \brief Returns function if it is tagged with "Compiler=ccompiler". 
*/ + static const FunctionNode* GetCCompilerFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && opt_compiler.value() == "ccompiler") { + return function_node; + } + } + return nullptr; + } + + runtime::Module Finalize() { + std::ostringstream os; + EmitPreamble(os); + os << code_stream_.str(); + std::string code = os.str(); + + VLOG(1) << "CodegenCModule generated:" << std::endl << code; // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", Array{func_name}, variables); + return (*pf)(code, "c", func_names_, const_names_); } - private: + /*! \brief "ccompiler" Target with compilation options to use. */ + Target target_; + /*! \brief Module we are compiling. */ + IRModule mod_; + + /*! \brief True if we need to include the ndarray and packed function headers. */ + bool needs_extra_headers_ = false; + /*! \brief The accumulated constant name to constant mapping. */ + std::unordered_map const_name_to_constant_; + /*! \brief The accumulated constant names, in the order they were generated. */ + Array const_names_; + /*! \brief The accumulated function names. */ + Array func_names_; + /*! + * \brief The accumulated code stream containing all function definitions. + * (Does not include the preamble.) + */ std::ostringstream code_stream_; }; -/*! - * \brief The external compiler/codegen tool. It takes a Relay expression/module and - * compile it into a runtime module. - * - * The external codegen tool should have been registered similiarly to LLVM, - * CUDA, etc, under TVM, so the generated code could be packed in a runtime - * module. This module simplifies code serialization and invocation. 
- */ -runtime::Module CCompiler(const ObjectRef& ref) { - CSourceCodegen csource; - return csource.CreateCSourceModule(ref); -} +/*! \brief The actual translation pass. */ +transform::Pass CCompilerImpl() { + auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { + VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod); + Target target = GetCCompilerTarget(); + + // Emit the C/C++ code and package it as a CSourceModule. + CodegenCModule codegen(target, mod); + runtime::Module runtime_mod = codegen.CreateCSourceModule(); + + // Capture the new runtime module. + Array external_mods = + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + external_mods.push_back(runtime_mod); + + // Capture the new constants. + Map const_name_to_constant = + mod->GetAttr>(tvm::attr::kConstNameToConstant).value_or({}); + for (const auto& kv : codegen.const_name_to_constant()) { + ICHECK_EQ(const_name_to_constant.count(kv.first), 0); + const_name_to_constant.Set(kv.first, kv.second); + } -TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); + return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods}, + {tvm::attr::kConstNameToConstant, const_name_to_constant}}); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {}); +} -TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) - .add_attr_option("header", String("")); // value is prepended to every output CModule +transform::Pass CCompilerPass() { + return transform::Sequential( + {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), + transforms::MarkCompilerFunctionsAsExtern("ccompiler")}); +} } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 49a5bca068d1..1ee72c149f1a 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ 
b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -409,7 +409,7 @@ class CodegenCBase { * * \return The created reference */ - std::string CreateDataReference(const std::string& symbol, int const_id) const { + std::string CreateDataReference(const std::string& symbol, size_t const_id) const { return "(float*)(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)"; } @@ -421,8 +421,8 @@ class CodegenCBase { * * \return The created variable name */ - std::string CreateConstVar(const std::string& symbol, int const_id) const { - return symbol + "_const_" + std::to_string(const_id++); + std::string CreateConstVar(const std::string& symbol, size_t const_id) const { + return symbol + "_const_" + std::to_string(const_id); } /*! \brief The external function source code stream. */ @@ -433,7 +433,14 @@ class CodegenCBase { int indent_{0}; }; +/*! + * \brief A pass to translate all "Primitive" Relay functions with "Compiler=ccompiler" to + * a \p CSourceModule. + */ +transform::Pass CCompilerPass(); + } // namespace contrib } // namespace relay } // namespace tvm + #endif // TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ diff --git a/src/relay/backend/contrib/codegen_c/target.cc b/src/relay/backend/contrib/codegen_c/target.cc new file mode 100644 index 000000000000..623057ac1762 --- /dev/null +++ b/src/relay/backend/contrib/codegen_c/target.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "./codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief This demonstration external codegen target emits C/C++ for compilation by the native C + * compiler on CPU. + * - Patterns: None, functions must be explicitly marked as "Primitive" and "Compiler=ccompiler". + * - Custom compiler: relay/backend/contrib/codegen_c/codegen.cc + */ +TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kRelayToTIR, CCompilerPass()) + // Value is prepended to every output CModule. + .add_attr_option("header", String("")); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 4966f3f01c7d..de6d0f74061b 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include #include "../../../../runtime/contrib/json/json_node.h" @@ -150,7 +152,8 @@ class JSONSerializer : public MemoizedExprTranslator(func_); @@ -162,8 +165,18 @@ class JSONSerializer : public MemoizedExprTranslatorbody); } - /*!\brief Return the required params. */ - Array GetParams() const { return params_; } + /*! + * \brief Returns the accumulated map from constant names to the NDArray they must be bound to + * at runtime. Also referred to as 'params' elsewhere in the code. 
+ */ + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; + } + + /*! + * \brief Return the constant names in order they were encountered during translation. + */ + const Array& const_names() const { return const_names_; } /*!\brief Return the generated json. */ std::string GetJSON() { @@ -245,11 +258,15 @@ class JSONSerializer : public MemoizedExprTranslator(vn)]; } - std::vector VisitExpr_(const ConstantNode* cn) { - std::string name = symbol_ + "_const_" + std::to_string(params_.size()); - params_.push_back(name); - auto node = std::make_shared(name, "const" /* op_type_ */); - return AddNode(node, GetRef(cn)); + std::vector VisitExpr_(const ConstantNode* constant_node) { + std::string name = symbol_ + "_const_" + std::to_string(const_names_.size()); + VLOG(1) << "Will require parameter '" << name + << "' to be supplied by the ConstLoaderModule at runtime"; + ICHECK_EQ(const_name_to_constant_.count(name), 0); + const_name_to_constant_.emplace(name, constant_node->data); + const_names_.push_back(name); + auto node = std::make_shared(name, /*op_type=*/"const"); + return AddNode(node, GetRef(constant_node)); } std::vector VisitExpr_(const TupleNode* tn) { @@ -340,8 +357,17 @@ class JSONSerializer : public MemoizedExprTranslator nodes_; /*! \brief Output of the JSON graph. */ std::vector heads_; - /*! \brief The list of required constants. */ - Array params_; + /*! + * \brief A map from constant names to NDArrays for each Constant encountered during + * translation to JSON. The JSON will record only the constant name. The actual NDArray must + * be made available at runtime from a ConstLoaderModule. + */ + std::unordered_map const_name_to_constant_; + /*! + * \brief The domain of the above map, but in order the constants were encountered during + * translation. 
+ */ + Array const_names_; }; } // namespace contrib diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 772007792ae6..de2934173b5f 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -43,6 +43,18 @@ namespace cutlass { namespace { +/*! \brief Return the "cutlass" Target instance to use to guide compilation. */ +Target GetCutlassTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "cutlass") { + // Use the default CUTLASS compilation options if no specific "cutlass" target was given + // in the overall targets list. In that case target_hooks.cc will invoke the custom pass + // without pushing any target instance onto the implicit target stack. + target = Target("cutlass"); + } + return target; +} + using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, @@ -563,7 +575,7 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorExitScope(); code_stream_ << "}\n"; - this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, const_array_name_, out, true); + this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, /*const_arr_name=*/"", out, true); return code_stream_.str(); } @@ -769,7 +781,7 @@ class CodegenCutlass : public backend::MemoizedExprTranslator attrs_; /*! @@ -781,8 +793,6 @@ class CodegenCutlass : public backend::MemoizedExprTranslator ext_func_args_; /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ std::vector ext_func_body_; - /*! \brief The array declared to store the constant values. */ - std::string const_array_name_; /*! \brief The declaration of intermediate buffers. 
*/ std::vector buf_decl_; }; // class CodegenCutlass @@ -863,14 +873,14 @@ class CutlassModuleCodegen { const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); - return (*pf)(code_stream_.str(), "cu", func_names_, const_vars_); + return (*pf)(code_stream_.str(), "cu", func_names_, /*const_vars=*/Array()); } /*! * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute * value "cutlass". */ - const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + static const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { if (const auto* function_node = expr.as()) { Optional opt_compiler = function_node->GetAttr(attr::kCompiler); if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { @@ -886,8 +896,6 @@ class CutlassModuleCodegen { std::ostringstream code_stream_; /*! \brief The accumulated function names. */ Array func_names_; - /*! \brief The accumulated constant names. */ - Array const_vars_; }; // CutlassModuleCodegen /*! 
@@ -899,14 +907,12 @@ transform::Pass CompileForCutlassImpl() { VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; - Optional opt_cutlass_target = Target::Current(); - ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available"; - ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass"); - runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value()); + Target target = GetCutlassTarget(); + runtime::Module runtime_mod = (*pf)(mod, target); Array external_mods = - mod->GetAttr>("external_mods", Array()).value(); + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); external_mods.push_back(runtime_mod); - return WithAttr(mod, "external_mods", external_mods); + return WithAttr(mod, tvm::attr::kExternalMods, external_mods); }; return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {}); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f17cdafa76a5..2f47c23a7cf9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -585,11 +585,15 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { DNNLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; #else DNNLModuleCodegen dnnl; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 19bfa8c68298..b01c23ed806a 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 149cc485c752..e08cd240d4d1 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -318,11 +318,16 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
+ const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc index 2c29896d1b0e..2e6fb1326314 100644 --- a/src/relay/backend/contrib/verilator/codegen.cc +++ b/src/relay/backend/contrib/verilator/codegen.cc @@ -111,10 +111,15 @@ runtime::Module VerilatorBackend(const ObjectRef& ref) { VerilatorJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
// Create runtime object - auto n = make_object(func_name, graph_json, params); + auto n = make_object(func_name, graph_json, + serializer.const_names()); // Get Verilator compiler options auto ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index af426e5c71cb..faf9d2899fc3 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -259,21 +259,31 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any runtime modules generated by external codegen. + ret.external_mods = + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + for (const auto& kv : const_name_to_constant) { + VLOG(1) << "constant '" << kv.first << "' contributed by external codegen"; + ICHECK(ret.params.emplace(kv.first, kv.second).second); } - ret.function_metadata = std::move(function_metadata_); - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any constants extracted during lowering. 
+ for (const auto& kv : params_) { + VLOG(1) << "constant '" << kv.first << "' contributed by TECompiler"; + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + ret.function_metadata = std::move(function_metadata_); // This is the point where we separate the functions in the module by target ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); - ret.external_mods = external_modules.value(); ret.metadata = ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */, {} /* output_tensor_types */, {} /* pools */, {} /* devices */, @@ -650,14 +660,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.second; - }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.first; + *rv = (*it).second; }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 08fa18b61e16..210f77330afd 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -1224,7 +1224,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // annotate the module with the resulting runtime modules. // TODO(mbs): runtime modules should be first class rather than attributes. 
Array external_mods = - module->GetAttr>("external_mods", Array()).value(); + module->GetAttr>(tvm::attr::kExternalMods).value_or({}); Array new_external_mods = compiler->LowerExternalFunctions(); VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size() << " new external modules"; @@ -1246,7 +1246,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr device_contexts.Set(kv.first, kv.second); // copy-on-write. } - updated_module = WithAttrs(updated_module, {{"external_mods", std::move(external_mods)}, + updated_module = WithAttrs(updated_module, {{tvm::attr::kExternalMods, std::move(external_mods)}, {"device_contexts", std::move(device_contexts)}}); if (backend::IsAutoSchedulerEnabled()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 67924a7835fb..d6fae8c72b5e 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -223,7 +223,11 @@ struct LoweredOutput { Map lowered_funcs; Array external_mods; Map function_metadata; - std::unordered_map> params; + /*! + * \brief Map from constant names (allocated by the codegen as constants are encountered) + * to the constant's value. 
+ */ + std::unordered_map params; ExecutorCodegenMetadata metadata; }; @@ -249,7 +253,7 @@ struct ConstantUpdater : public ExprVisitor { void VisitExpr_(const ConstantNode* cn) final { std::string name = symbol_ + "_const_" + std::to_string(const_idx_++); - VLOG(1) << "Binding " << name << " to constant of type " << PrettyPrint(cn->checked_type()); + VLOG(1) << "binding '" << name << "' to constant of type " << PrettyPrint(cn->checked_type()); (*params_)[name] = cn->data; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7371fd1f8083..a8bd3df32a90 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1166,11 +1166,27 @@ void VMCompiler::Codegen() { for (const auto& kv : per_tvm_target_modules) { ICHECK(kv.first->kind->device_type != kDLExtDev); } - Array ext_mods = - context_.module->GetAttr>("external_mods", Array()) - .value(); - VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build and " << ext_mods.size() - << " external runtime modules"; + + // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time + // and IRModule-at-a-time). + Array external_mods = + context_.module->GetAttr>(tvm::attr::kExternalMods).value_or({}); + + // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes). + Map const_name_to_constant = + context_.module->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + + VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, " + << external_mods.size() << " external runtime modules, " << const_name_to_constant.size() + << " external constants, and " << params_.size() << " local constants"; + + // Any constant bindings must be merged into the overall 'params' map we've directly accumulated + // via the TECompiler callback. 
+ for (const auto& kv : const_name_to_constant) { + ICHECK_EQ(params_.count(kv.first), 0); + params_.emplace(kv.first, kv.second); + } runtime::Module lib; if (per_tvm_target_modules.empty()) { @@ -1183,7 +1199,7 @@ void VMCompiler::Codegen() { } lib = - codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, + codegen::CreateMetadataModule(params_, lib, external_mods, config_->host_target, Runtime::Create("cpp"), Executor::Create("graph"), // DNS HACK relay::backend::ExecutorCodegenMetadata()); exec_->SetLib(lib); diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 1b0f002f1def..0df9f5ee294c 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -50,7 +50,7 @@ const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler } /*! - * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given + * \brief Rewrite calls to inlined and let-bound "Compiler" functions to global functions. The given * module will be extended with the newly outlined functions. */ class Outliner : public MixedModeMutator { @@ -58,6 +58,38 @@ class Outliner : public MixedModeMutator { Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + if (AsFunctionNode(value, compiler_filter_)) { + // Inline on-the-fly if the let-bound value is a function of interest. 
+ this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + if (AsFunctionNode(value, compiler_filter_)) { + // The let binding is no longer needed since inlined on-the-fly above. + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) final { Call new_call = Downcast(post); if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) { diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 6664594fc0a0..aa98430318a6 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -95,9 +95,10 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache { }; /*! - * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" - * attribute. The given \p GlobalSymbolCache is used to determine a unique global symbol for each - * function, which is also assigned to the "global_symbol" attribute of the new global function. + * \brief A pass to outline all let-bound and literal functions in direct call positions which have + * a "Compiler" attribute. The given \p GlobalSymbolCache is used to determine a unique global + * symbol for each function, which is also assigned to the "global_symbol" attribute of the new + * global function. * * At most one function with the same global symbol is outlined. 
* @@ -108,9 +109,9 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach std::string compiler_filter = ""); /*! - * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" - * attribute. The functions are bound to unique global vars according to their existing - * "global_symbol" attribute. At most one function with the same global symbol is outlined. + * \brief A pass to outline all let-bound and literal functions in direct call positions which have + * a "Compiler" attribute. The functions are bound to unique global vars according to their + * existing "global_symbol" attribute. At most one function with the same global symbol is outlined. * * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 00953a1907e1..f52e95b2adbf 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -148,7 +148,7 @@ class TargetHookVisitor : public MixedModeVisitor { Pass RelayToTIRTargetHook(CompilationConfig config) { auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { - VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook before:" << std::endl << PrettyPrint(mod); TargetHookVisitor target_hook_visitor(mod, config); std::vector custom_passes = target_hook_visitor.Visit(); for (const auto& custom_pass : custom_passes) { @@ -161,11 +161,14 @@ Pass RelayToTIRTargetHook(CompilationConfig config) { mod = custom_pass.pass(mod); } else { // Invoke the pass. + // Note that there may be a non-external codegen target in scope. Each custom pass + // must be prepared to handle this, eg by creating a default target instance if the + // current target is either null or of a generic kind such as 'cuda' or 'llvm'. 
VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; mod = custom_pass.pass(mod); } } - VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook after:" << std::endl << PrettyPrint(mod); return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index e5ca82d5c099..ec301d10812f 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -215,6 +215,8 @@ runtime::Module CreateMetadataModule( String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { + VLOG(1) << "From module of type '" << mod->type_key() << "' found const var '" + << variables[i] << "' for symbol '" << symbol << "'"; symbol_const_vars.push_back(variables[i].operator std::string()); } ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 237f923516da..f9e620ba3322 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -80,14 +80,14 @@ tvm::transform::Pass ExtractPrimFuncConstants() { } auto* attrs = m->attrs.CopyOnWrite(); ConstArrayType constant_array_ = - (attrs->dict.count(tvm::attr::kConstantsArray)) - ? Downcast(attrs->dict[tvm::attr::kConstantsArray]) + (attrs->dict.count(tvm::attr::kConstants)) + ? 
Downcast(attrs->dict[tvm::attr::kConstants]) : ConstArrayType(); Applicator a = Applicator(); func->body = a.Apply(func->body, constant_array_); const ConstArrayType constant_list = a.constant_array_; if (constant_list.size()) { - attrs->dict.Set(tvm::attr::kConstantsArray, constant_list); + attrs->dict.Set(tvm::attr::kConstants, constant_list); } return GetRef(func); }; diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 4f451a125184..873475ac1ce7 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -235,37 +235,29 @@ def make_mod(): @pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") -def test_extern_gcc_consts(): - @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") - def constant_updater(expr, symbol): - """A dummy constant updater just to test that a custom one works.""" - return {"ccompiler_0_p0": tvm.nd.array(y0_data)} - - x = relay.var("x", shape=(8, 8)) - y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_extern_gcc_consts(check_result): + shape = (8, 8) + dtype = "float32" + x = relay.var("x", shape=shape) + y0_data = np.random.uniform(0, 1, shape).astype(dtype) - x0 = relay.var("x0", shape=(8, 8)) - y0_const = relay.const(y0_data, "float32") + x0 = relay.var("x0", shape=shape) + y0_const = relay.const(y0_data, dtype) z = x0 + y0_const f = relay.Function([x0], z) f = set_external_func_attr(f, "ccompiler", "ccompiler_0") call = relay.Call(f, [x]) mod = tvm.IRModule.from_expr(call) - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - compiler = relay.backend.vm.VMCompiler() - compiler.lower(mod, "llvm") - compiler.codegen() - params = compiler.get_params() - assert len(params) == 1 - assert "ccompiler_0_p0" in params.keys() - - with 
tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - _, _, params = relay.build(mod, target="llvm") - assert len(params) == 1 - assert "ccompiler_0_p0" in params.keys() - - tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") + # Note that while the VMCompiler get_params() will return all 'parameters' from both + # TVM and external codegen compiled code, the GraphExecutor.get_params() will return only + # those from non-external modules. So in the following we'll test by execution rather than + # test by inspection. + x_data = np.random.rand(*shape).astype(dtype) + inputs = {"x": x_data} + expected_result = x_data + y0_data + check_result(mod, inputs, shape, expected_result, target="llvm") @pytest.mark.skipif( diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index 66abeff8ab29..b1056f60b82b 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -75,6 +75,39 @@ def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float1 ) +def original_mod_let_bound(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + let %f = fn(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + 
add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + }; + %1 = %f(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + def expected_outlined_mod(): return tvm.parser.parse( """ @@ -175,6 +208,13 @@ def test_outline_compiler_functions_with_existing_global_symbols(): tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) +def test_outline_let_bound_compiler_functions_with_existing_global_symbols(): + actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( + "cutlass" + )(original_mod_let_bound()) + tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) + + def test_mark_compiler_functions_as_extern(): actual_extern_mod = tvm.relay.transform.MarkCompilerFunctionsAsExtern("cutlass")( expected_outlined_mod() diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index b135973718bc..e3cff18c51f8 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -21,6 +21,7 @@ import pytest import tvm import tvm.topi.testing +import tvm.testing from tvm import relay from tvm.relay.testing.layers import batch_norm_infer from tvm.target.datatype import ( @@ -560,4 +561,4 
@@ def test_posites2(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index cb49e7286fbb..82f4f6515c09 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -18,6 +18,7 @@ import tvm from tvm import tir from tvm.script import tir as T +import tvm.testing @tvm.script.ir_module @@ -49,7 +50,7 @@ def constant3(a: T.handle) -> None: def test_const_extraction(): mod = tvm.tir.transform.ExtractPrimFuncConstants()(Module4) - constants = mod.attrs["Constants"] + constants = mod.attrs["constants"] assert len(constants) == 2 def _visit(stmt): @@ -63,4 +64,4 @@ def _visit(stmt): if __name__ == "__main__": - test_const_extraction() + tvm.testing.main()