Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'static_library' runtime::Module #11442

Merged
merged 1 commit into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

Expand Down Expand Up @@ -190,6 +191,33 @@ class TVM_DLL ModuleNode : public Object {
/*! \return The module it imports from */
const std::vector<Module>& imports() const { return imports_; }

/*!
* \brief Returns true if this module is 'DSO exportable'.
*
* A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be incorporated into the
* final runtime artifact (ie shared library) by compilation and/or linking using the external
* compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile.
*
* By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 'cuda') typically must
* be incorporated into the final runtime artifact by being serialized as data into the
* artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary,
* and have a matching deserializer registered as 'runtime.module.loadbinary_<type_key>'.
*
* The default implementation returns false.
*/
virtual bool IsDSOExportable() const;

/*!
* \brief Returns true if this module has a definition for a function of \p name. If
* \p query_imports is true, also search in any imported modules.
*
* Note that even if this function returns true the corresponding \p GetFunction result may be
* nullptr if the function is not yet callable without further compilation.
*
* The default implementation just checkis if \p GetFunction is non-null.
*/
virtual bool ImplementsFunction(const String& name, bool query_imports = false);

// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys
import os
import subprocess
import logging

from .._ffi.base import py_str

Expand Down Expand Up @@ -238,6 +239,7 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
cmd += objects
if options:
cmd += options
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
Expand All @@ -264,6 +266,7 @@ def _windows_compile(output, objects, options):
cmd += options

try:
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
except FileNotFoundError:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import subprocess
import os
import warnings
import logging

import tvm._ffi
from tvm.target import Target
Expand Down Expand Up @@ -102,6 +103,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]

logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

(out, _) = proc.communicate()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .object_generic import convert_to_object, convert, const
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib
from .module import load_module, enabled, system_lib, load_static_library
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict

Expand Down
52 changes: 46 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,28 @@ def entry_func(self):
self._entry = self.get_function(self.entry_name)
return self._entry

def implements_function(self, name, query_imports=False):
"""Returns True if the module has a definition for the global function with name. Note
that has_function(name) does not imply get_function(name) is non-null since the module
may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function
without further compilation. However, get_function(name) non null should always imply
has_function(name).

Parameters
----------
name : str
The name of the function

query_imports : bool
Whether to also query modules imported by this module.

Returns
-------
b : Bool
True if module (or one of its imports) has a definition for name.
"""
return _ffi_api.ModuleImplementsFunction(self, name, query_imports)

def get_function(self, name, query_imports=False):
"""Get function from the module.

Expand Down Expand Up @@ -217,6 +239,18 @@ def imported_modules(self):
nmod = _ffi_api.ModuleImportsSize(self)
return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)]

@property
def is_dso_exportable(self):
"""Returns true if module is 'DSO exportable', ie can be included in result of
export_library by the external compiler directly.

Returns
-------
b : Bool
True if the module is DSO exportable.
"""
return _ffi_api.ModuleIsDSOExportable(self)

def save(self, file_name, fmt=""):
"""Save the module to file.

Expand Down Expand Up @@ -332,8 +366,7 @@ def _collect_from_import_tree(self, filter_func):
return dso_modules

def _collect_dso_modules(self):
is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == "c")
return self._collect_from_import_tree(is_dso_exportable)
return self._collect_from_import_tree(lambda m: m.is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""
Expand Down Expand Up @@ -418,10 +451,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
else:
object_format = fcompile.object_format
else:
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
if module.type_key == "c":
if len(module.format) > 0:
assert module.format in [
"c",
Expand All @@ -436,6 +466,9 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
if kwargs["cc"] == "nvcc":
object_format = "cu"
has_c_module = True
else:
assert module.type_key == "llvm" or module.type_key == "static_library"
object_format = "o"
path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
module.save(path_obj)
files.append(path_obj)
Expand Down Expand Up @@ -552,6 +585,13 @@ def load_module(path, fmt=""):
return _ffi_api.ModuleLoadFromFile(path, fmt)


def load_static_library(path, func_names):
"""Load the .o library at path which implements functions with func_names.
Unlike the generic load_module the result will remain as a static_library
and will not be relinked on-the-fly into a .so library."""
return _ffi_api.ModuleLoadStaticLibrary(path, func_names)


def enabled(target):
"""Whether module runtime is enabled for target

Expand Down
2 changes: 1 addition & 1 deletion src/printer/model_library_format_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
bool show_warning)
: text_printer_{show_meta_data, annotate, show_warning} {}

const char* type_key() const override { return "model_library_format_printer"; }
const char* type_key() const final { return "model_library_format_printer"; }

std::string Print(const ObjectRef& node) {
Doc doc;
Expand Down
11 changes: 10 additions & 1 deletion src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,22 @@ class EthosUModuleNode : public ModuleNode {
return PackedFunc();
}

const char* type_key() const override { return "c"; }
const char* type_key() const final { return "c"; }

static Module Create(Array<CompilationArtifact> compilation_artifacts) {
auto n = make_object<EthosUModuleNode>(compilation_artifacts);
return Module(n);
}

bool IsDSOExportable() const final { return true; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find_if(compilation_artifacts_.begin(), compilation_artifacts_.end(),
[&name](const CompilationArtifact& artifact) {
return artifact->function_name == name;
}) != compilation_artifacts_.end();
}

private:
std::string c_source;
Array<CompilationArtifact> compilation_artifacts_;
Expand Down
29 changes: 16 additions & 13 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TECompilerImpl : public TECompilerNode {
}
}
for (const auto& global_var : to_be_deleted) {
VLOG(1) << "Removing definition for external codegened '" << global_var->name_hint << "'";
module->Remove(global_var);
}
// HOWEVER we still need a Relay definition to go with those now external functions, so
Expand Down Expand Up @@ -203,27 +204,29 @@ class TECompilerImpl : public TECompilerNode {

std::string ext_name = "relay.ext." + opt_compiler.value();
auto pf = tvm::runtime::Registry::Get(ext_name);
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
ICHECK(pf) << "Failed to find the external codegen tool for " << ext_name;
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
VLOG_CONTEXT << ext_name;
VLOG_CONTEXT << opt_compiler.value();
With<Target> with_target(it.first->target);
runtime::Module ext_mod = (*pf)(src_func);
if (ext_mod.defined()) {
if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
// It's possible the codegen yielded C or C++ tracked separately and thus the
// returned runtime module can be empty.
VLOG(1) << "Unable to find definition for the external function '"
<< opt_symbol_name.value()
<< "' in the runtime module generated by external codegen '"
<< opt_compiler.value() << "'";
// TODO(mbs): Can this be an ICHECKs?
if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) {
VLOG(1) << "Note that the external codegen for '" << opt_compiler.value()
<< "' returned a runtime module which does not appear to implement '"
<< opt_symbol_name.value() << "'";
}
ret.push_back(ext_mod);
} else {
// A warning only so that we can write unit tests which can return an empty runtime
// module.
LOG(WARNING) << "No external runtime module was generated by external codegen '"
<< opt_compiler.value() << "'";
// It is valid for the external codegen function to return null:
// - Unit tests can use it.
// - The true compilation may have already been handled by a RelayToTIR custom hook pass
// on the Target's kind. The original Relay functions will be left in place so
// that we can capture that their function names are now externally defined.
VLOG(1) << "Note that no external runtime module was generated by external codegen '"
<< opt_compiler.value() << "'";
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class VMCompiler : public runtime::ModuleNode {

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

const char* type_key() const { return "VMCompiler"; }
const char* type_key() const final { return "VMCompiler"; }

/*!
* \brief Set the parameters
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/aot_executor/aot_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode {
/*!
* \return The type key of the executor.
*/
const char* type_key() const override { return "AotExecutorFactory"; }
const char* type_key() const final { return "AotExecutorFactory"; }

/*!
* \brief Save the module to binary stream.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/const_loader_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class ConstLoaderModuleNode : public ModuleNode {
return PackedFunc(nullptr);
}

const char* type_key() const { return "const_loader"; }
const char* type_key() const final { return "const_loader"; }

/*!
* \brief Get the list of constants that is required by the given module.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class JSONRuntimeBase : public ModuleNode {
LoadGraph(graph_json_);
}

const char* type_key() const override { return "json"; }
const char* type_key() const override { return "json"; } // May be overridden

/*! \brief Initialize a specific json runtime. */
virtual void Init(const Array<NDArray>& consts) = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/tensorrt/tensorrt_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
*
* \return module type key.
*/
const char* type_key() const override { return "tensorrt"; }
const char* type_key() const final { return "tensorrt"; }

/*!
* \brief Initialize runtime. Create TensorRT layer from JSON
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/graph_executor/graph_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode {
/*!
* \return The type key of the executor.
*/
const char* type_key() const override { return "GraphExecutorFactory"; }
const char* type_key() const final { return "GraphExecutorFactory"; }

/*!
* \brief Save the module to binary stream.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
explicit MetadataModuleNode(runtime::metadata::Metadata metadata)
: metadata_{::std::move(metadata)} {}

const char* type_key() const { return "metadata_module"; }
const char* type_key() const final { return "metadata_module"; }

static Module LoadFromBinary() {
return Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
Expand Down
18 changes: 16 additions & 2 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'";
const PackedFunc* f = Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered,"
<< " resolved to (" << load_f_name << ") in the global registry."
Expand Down Expand Up @@ -132,6 +133,12 @@ std::string ModuleNode::GetFormat() {
return "";
}

bool ModuleNode::IsDSOExportable() const { return false; }

bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
return GetFunction(name, query_imports) != nullptr;
}

bool RuntimeEnabled(const std::string& target) {
std::string f_name;
if (target == "cpu") {
Expand Down Expand Up @@ -191,8 +198,15 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);

TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
mod->SaveToFile(name, fmt);
.set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });

TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) {
return mod->IsDSOExportable();
});

TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction")
.set_body_typed([](Module mod, String name, bool query_imports) {
return mod->ImplementsFunction(std::move(name), query_imports);
});

TVM_REGISTER_OBJECT_TYPE(ModuleNode);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/stackvm/stackvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace runtime {

class StackVMModuleNode : public runtime::ModuleNode {
public:
const char* type_key() const { return "stackvm"; }
const char* type_key() const final { return "stackvm"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
Expand Down
Loading