Skip to content

Commit

Permalink
- accidentally introduced 'transforms' namespace
Browse files Browse the repository at this point in the history
- can't use default Target("tensorrt") arg
  • Loading branch information
mbs-octoml committed Jul 1, 2022
1 parent f173fbc commit 8539ef4
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 31 deletions.
9 changes: 8 additions & 1 deletion python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def get_tensorrt_use_fp16() -> bool:
def partition_for_tensorrt(
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.NDArray]] = None,
target: tvm.target.Target = tvm.target.Target("tensorrt"),
# CAUTION: Can't use default Target("tensorrt") here since the target kind is only available
# if is_tensorrt_compiler_enabled() == True.
target: Optional[tvm.target.Target] = None,
) -> tvm.IRModule:
"""Partition all functions in mod to greedily offload supported operators to TensorRT.
Expand All @@ -130,8 +132,13 @@ def partition_for_tensorrt(
The partitioned module.
"""
assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled"
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
if target is None:
# Use a default target. The get_tensorrt_target() function will similarly create an
# equivalent default target when compilation continues after partitioning.
target = tvm.target.Target("tensorrt")

seq = tvm.transform.Sequential(
[
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ class CodegenCModule {
};

/*! \brief The actual translation pass. */
transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
Target target = GetCCompilerTarget();

Expand All @@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
}

transform::Pass CCompilerPass() {
tvm::transform::Pass CCompilerPass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transform::MarkCompilerFunctionsAsExtern("ccompiler")});
}

} // namespace contrib
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -902,8 +902,8 @@ class CutlassModuleCodegen {
* \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python
* function which does the main CUTLASS training, c-code generation and compilation steps.
*/
transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
Expand All @@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {

TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);

transform::Pass CompileForCutlass() {
tvm::transform::Pass CompileForCutlass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")});
}

} // namespace cutlass
Expand Down
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
* function will require a linear scan of imported runtime modules to find the matching
* TensorRTRuntimeModule implementing it.
*/
transform::Pass CompileForTensorRTImpl() {
auto pass_func = [](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CompileForTensorRTImpl() {
auto pass_func = [](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod);
Target target = GetTensorRTTarget();

Expand Down Expand Up @@ -400,10 +400,10 @@ transform::Pass CompileForTensorRTImpl() {
return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", {});
}

transform::Pass CompileForTensorRT() {
tvm::transform::Pass CompileForTensorRT() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
CompileForTensorRTImpl(), transforms::MarkCompilerFunctionsAsExtern("tensorrt")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
CompileForTensorRTImpl(), transform::MarkCompilerFunctionsAsExtern("tensorrt")});
}

} // namespace tensorrt
Expand Down
16 changes: 8 additions & 8 deletions src/relay/transforms/compiler_function_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@

#include "./compiler_function_utils.h"

#include "../op/call/call.h"
#include "tvm/relay/analysis.h"
#include "tvm/relay/expr_functor.h"
#include "tvm/relay/transform.h"

namespace tvm {
namespace relay {
namespace transforms {
namespace transform {
namespace {

/*!
Expand Down Expand Up @@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) {
return global_var;
}

transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
std::string compiler_filter) {
tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
IRModule mod, transform::PassContext ctx) {
Expand All @@ -235,12 +234,13 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
}

// Any Java programmers in the house?
transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) {
tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
std::string compiler_filter) {
return OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(),
std::move(compiler_filter));
}

transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
Expand All @@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
}

transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars);
Expand Down Expand Up @@ -295,6 +295,6 @@ TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
.set_body_typed(InlineCompilerFunctionsBoundTo);

} // namespace transforms
} // namespace transform
} // namespace relay
} // namespace tvm
15 changes: 8 additions & 7 deletions src/relay/transforms/compiler_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

namespace tvm {
namespace relay {
namespace transforms {
namespace transform {

/*!
* \brief Abstract class representing a cache of unique global vars keyed by functions. This can
Expand Down Expand Up @@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache {
* If \p compiler_filter is non-empty only functions with that as their attribute value are
* outlined.
*/
transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
std::string compiler_filter = "");
tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
std::string compiler_filter = "");

/*!
* \brief A pass to outline all let-bound and literal functions in direct call positions which have
Expand All @@ -119,7 +119,8 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
* This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism
* to prepare the IRModule before custom lowering.
*/
transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = "");
tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
std::string compiler_filter = "");

/*!
* \brief A pass to mark all global functions which have a "Compiler" attribute matching
Expand All @@ -132,7 +133,7 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
* This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to
* cleanup the IRModule after custom lowering.
*/
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");

/*!
* \brief A pass to inline all global "Compiler" functions which are bound to a global var
Expand All @@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
* This pass may be useful for external codegen which needs to undo partitioning based on
* properties of the entire partition.
*/
transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);

} // namespace transforms
} // namespace transform
} // namespace relay
} // namespace tvm

Expand Down

0 comments on commit 8539ef4

Please sign in to comment.