Skip to content

Commit

Permalink
[Runtime] Add 'static_library' runtime::Module (#11442)
Browse files Browse the repository at this point in the history
(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for
context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).

This adds a new 'DSO exportable' runtime module representing the contents of a .o file. It
allows external codegen toolchains to yield a result which:
 - Like CSource modules, can be conveyed directly to the final export_library compilation
   step for linking into the final .so and saved to a know location without risk the
   underlying code artifact will be lost.
 - Like DSOLibrary modules, are self contained so that no additional compile-time arguments
   need be conveyed from the CSource module to the final export_library command line

Since this is the third flavor of 'DSO exportable' module, add a Module::IsDSOExportable.

Since adding the above, can't resist also adding a Module::ImplementsFunction virtual and
calling it from TEComplier to check if an external codegen function actually provided the
implementation it promised.

Note:
 - I've left the existing implementation of runtime.load_module alone which
   relinks .o files to .so files.
 - Though also contained in the .o metadata, I require static libraries to always
   carry their list of exported function names.

This is all pretty stop gap pending a good rework of TVM to supoprt the notion of artifacts
and, perhaps, build rules.
  • Loading branch information
mbs-octoml authored May 26, 2022
1 parent a9ece3d commit db5f4fe
Show file tree
Hide file tree
Showing 26 changed files with 356 additions and 61 deletions.
28 changes: 28 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

Expand Down Expand Up @@ -190,6 +191,33 @@ class TVM_DLL ModuleNode : public Object {
/*! \return The module it imports from */
const std::vector<Module>& imports() const { return imports_; }

/*!
* \brief Returns true if this module is 'DSO exportable'.
*
* A DSO exportable module (eg a CSourceModuleNode of type_key 'c') can be incorporated into the
* final runtime artifact (ie shared library) by compilation and/or linking using the external
* compiler (llvm, nvcc, etc). DSO exportable modules must implement SaveToFile.
*
* By contrast, non-DSO exportable modules (eg CUDAModuleNode of type_key 'cuda') typically must
* be incorporated into the final runtime artifact by being serialized as data into the
* artifact, then deserialized at runtime. Non-DSO exportable modules must implement SaveToBinary,
* and have a matching deserializer registered as 'runtime.module.loadbinary_<type_key>'.
*
* The default implementation returns false.
*/
virtual bool IsDSOExportable() const;

/*!
* \brief Returns true if this module has a definition for a function of \p name. If
* \p query_imports is true, also search in any imported modules.
*
* Note that even if this function returns true the corresponding \p GetFunction result may be
* nullptr if the function is not yet callable without further compilation.
*
* The default implementation just checkis if \p GetFunction is non-null.
*/
virtual bool ImplementsFunction(const String& name, bool query_imports = false);

// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys
import os
import subprocess
import logging

from .._ffi.base import py_str

Expand Down Expand Up @@ -238,6 +239,7 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
cmd += objects
if options:
cmd += options
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
Expand All @@ -264,6 +266,7 @@ def _windows_compile(output, objects, options):
cmd += options

try:
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
except FileNotFoundError:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import subprocess
import os
import warnings
import logging

import tvm._ffi
from tvm.target import Target
Expand Down Expand Up @@ -102,6 +103,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]

logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

(out, _) = proc.communicate()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .object_generic import convert_to_object, convert, const
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib
from .module import load_module, enabled, system_lib, load_static_library
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict

Expand Down
52 changes: 46 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,28 @@ def entry_func(self):
self._entry = self.get_function(self.entry_name)
return self._entry

def implements_function(self, name, query_imports=False):
"""Returns True if the module has a definition for the global function with name. Note
that has_function(name) does not imply get_function(name) is non-null since the module
may be, eg, a CSourceModule which cannot supply a packed-func implementation of the function
without further compilation. However, get_function(name) non null should always imply
has_function(name).
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether to also query modules imported by this module.
Returns
-------
b : Bool
True if module (or one of its imports) has a definition for name.
"""
return _ffi_api.ModuleImplementsFunction(self, name, query_imports)

def get_function(self, name, query_imports=False):
"""Get function from the module.
Expand Down Expand Up @@ -217,6 +239,18 @@ def imported_modules(self):
nmod = _ffi_api.ModuleImportsSize(self)
return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)]

@property
def is_dso_exportable(self):
"""Returns true if module is 'DSO exportable', ie can be included in result of
export_library by the external compiler directly.
Returns
-------
b : Bool
True if the module is DSO exportable.
"""
return _ffi_api.ModuleIsDSOExportable(self)

def save(self, file_name, fmt=""):
"""Save the module to file.
Expand Down Expand Up @@ -332,8 +366,7 @@ def _collect_from_import_tree(self, filter_func):
return dso_modules

def _collect_dso_modules(self):
is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == "c")
return self._collect_from_import_tree(is_dso_exportable)
return self._collect_from_import_tree(lambda m: m.is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""
Expand Down Expand Up @@ -418,10 +451,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
else:
object_format = fcompile.object_format
else:
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
if module.type_key == "c":
if len(module.format) > 0:
assert module.format in [
"c",
Expand All @@ -436,6 +466,9 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
if kwargs["cc"] == "nvcc":
object_format = "cu"
has_c_module = True
else:
assert module.type_key == "llvm" or module.type_key == "static_library"
object_format = "o"
path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
module.save(path_obj)
files.append(path_obj)
Expand Down Expand Up @@ -552,6 +585,13 @@ def load_module(path, fmt=""):
return _ffi_api.ModuleLoadFromFile(path, fmt)


def load_static_library(path, func_names):
"""Load the .o library at path which implements functions with func_names.
Unlike the generic load_module the result will remain as a static_library
and will not be relinked on-the-fly into a .so library."""
return _ffi_api.ModuleLoadStaticLibrary(path, func_names)


def enabled(target):
"""Whether module runtime is enabled for target
Expand Down
2 changes: 1 addition & 1 deletion src/printer/model_library_format_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
bool show_warning)
: text_printer_{show_meta_data, annotate, show_warning} {}

const char* type_key() const override { return "model_library_format_printer"; }
const char* type_key() const final { return "model_library_format_printer"; }

std::string Print(const ObjectRef& node) {
Doc doc;
Expand Down
11 changes: 10 additions & 1 deletion src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,22 @@ class EthosUModuleNode : public ModuleNode {
return PackedFunc();
}

const char* type_key() const override { return "c"; }
const char* type_key() const final { return "c"; }

static Module Create(Array<CompilationArtifact> compilation_artifacts) {
auto n = make_object<EthosUModuleNode>(compilation_artifacts);
return Module(n);
}

bool IsDSOExportable() const final { return true; }

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find_if(compilation_artifacts_.begin(), compilation_artifacts_.end(),
[&name](const CompilationArtifact& artifact) {
return artifact->function_name == name;
}) != compilation_artifacts_.end();
}

private:
std::string c_source;
Array<CompilationArtifact> compilation_artifacts_;
Expand Down
29 changes: 16 additions & 13 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TECompilerImpl : public TECompilerNode {
}
}
for (const auto& global_var : to_be_deleted) {
VLOG(1) << "Removing definition for external codegened '" << global_var->name_hint << "'";
module->Remove(global_var);
}
// HOWEVER we still need a Relay definition to go with those now external functions, so
Expand Down Expand Up @@ -203,27 +204,29 @@ class TECompilerImpl : public TECompilerNode {

std::string ext_name = "relay.ext." + opt_compiler.value();
auto pf = tvm::runtime::Registry::Get(ext_name);
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
ICHECK(pf) << "Failed to find the external codegen tool for " << ext_name;
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
VLOG_CONTEXT << ext_name;
VLOG_CONTEXT << opt_compiler.value();
With<Target> with_target(it.first->target);
runtime::Module ext_mod = (*pf)(src_func);
if (ext_mod.defined()) {
if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
// It's possible the codegen yielded C or C++ tracked separately and thus the
// returned runtime module can be empty.
VLOG(1) << "Unable to find definition for the external function '"
<< opt_symbol_name.value()
<< "' in the runtime module generated by external codegen '"
<< opt_compiler.value() << "'";
// TODO(mbs): Can this be an ICHECKs?
if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) {
VLOG(1) << "Note that the external codegen for '" << opt_compiler.value()
<< "' returned a runtime module which does not appear to implement '"
<< opt_symbol_name.value() << "'";
}
ret.push_back(ext_mod);
} else {
// A warning only so that we can write unit tests which can return an empty runtime
// module.
LOG(WARNING) << "No external runtime module was generated by external codegen '"
<< opt_compiler.value() << "'";
// It is valid for the external codegen function to return null:
// - Unit tests can use it.
// - The true compilation may have already been handled by a RelayToTIR custom hook pass
// on the Target's kind. The original Relay functions will be left in place so
// that we can capture that their function names are now externally defined.
VLOG(1) << "Note that no external runtime module was generated by external codegen '"
<< opt_compiler.value() << "'";
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class VMCompiler : public runtime::ModuleNode {

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

const char* type_key() const { return "VMCompiler"; }
const char* type_key() const final { return "VMCompiler"; }

/*!
* \brief Set the parameters
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/aot_executor/aot_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode {
/*!
* \return The type key of the executor.
*/
const char* type_key() const override { return "AotExecutorFactory"; }
const char* type_key() const final { return "AotExecutorFactory"; }

/*!
* \brief Save the module to binary stream.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/const_loader_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class ConstLoaderModuleNode : public ModuleNode {
return PackedFunc(nullptr);
}

const char* type_key() const { return "const_loader"; }
const char* type_key() const final { return "const_loader"; }

/*!
* \brief Get the list of constants that is required by the given module.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class JSONRuntimeBase : public ModuleNode {
LoadGraph(graph_json_);
}

const char* type_key() const override { return "json"; }
const char* type_key() const override { return "json"; } // May be overridden

/*! \brief Initialize a specific json runtime. */
virtual void Init(const Array<NDArray>& consts) = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/tensorrt/tensorrt_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
*
* \return module type key.
*/
const char* type_key() const override { return "tensorrt"; }
const char* type_key() const final { return "tensorrt"; }

/*!
* \brief Initialize runtime. Create TensorRT layer from JSON
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/graph_executor/graph_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class TVM_DLL GraphExecutorFactory : public runtime::ModuleNode {
/*!
* \return The type key of the executor.
*/
const char* type_key() const override { return "GraphExecutorFactory"; }
const char* type_key() const final { return "GraphExecutorFactory"; }

/*!
* \brief Save the module to binary stream.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
explicit MetadataModuleNode(runtime::metadata::Metadata metadata)
: metadata_{::std::move(metadata)} {}

const char* type_key() const { return "metadata_module"; }
const char* type_key() const final { return "metadata_module"; }

static Module LoadFromBinary() {
return Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
Expand Down
18 changes: 16 additions & 2 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Module Module::LoadFromFile(const std::string& file_name, const std::string& for
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
VLOG(1) << "Loading module from '" << file_name << "' of format '" << fmt << "'";
const PackedFunc* f = Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `." << format << "` files is not registered,"
<< " resolved to (" << load_f_name << ") in the global registry."
Expand Down Expand Up @@ -132,6 +133,12 @@ std::string ModuleNode::GetFormat() {
return "";
}

bool ModuleNode::IsDSOExportable() const { return false; }

bool ModuleNode::ImplementsFunction(const String& name, bool query_imports) {
return GetFunction(name, query_imports) != nullptr;
}

bool RuntimeEnabled(const std::string& target) {
std::string f_name;
if (target == "cpu") {
Expand Down Expand Up @@ -191,8 +198,15 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);

TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
mod->SaveToFile(name, fmt);
.set_body_typed([](Module mod, String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });

TVM_REGISTER_GLOBAL("runtime.ModuleIsDSOExportable").set_body_typed([](Module mod) {
return mod->IsDSOExportable();
});

TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction")
.set_body_typed([](Module mod, String name, bool query_imports) {
return mod->ImplementsFunction(std::move(name), query_imports);
});

TVM_REGISTER_OBJECT_TYPE(ModuleNode);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/stackvm/stackvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace runtime {

class StackVMModuleNode : public runtime::ModuleNode {
public:
const char* type_key() const { return "stackvm"; }
const char* type_key() const final { return "stackvm"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
Expand Down
Loading

0 comments on commit db5f4fe

Please sign in to comment.