From 8539ef4cd344c0357731b8135bb74fa3d25c6ac5 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 10:58:09 -0700 Subject: [PATCH] - accidentally introduced 'transforms' namespace - can't use default Target("tensorrt") arg --- python/tvm/relay/op/contrib/tensorrt.py | 9 ++++++++- src/relay/backend/contrib/codegen_c/codegen.cc | 10 +++++----- src/relay/backend/contrib/cutlass/codegen.cc | 10 +++++----- src/relay/backend/contrib/tensorrt/codegen.cc | 10 +++++----- src/relay/transforms/compiler_function_utils.cc | 16 ++++++++-------- src/relay/transforms/compiler_function_utils.h | 15 ++++++++------- 6 files changed, 39 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index a499414fa5ac..d659f514d9a3 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -111,7 +111,9 @@ def get_tensorrt_use_fp16() -> bool: def partition_for_tensorrt( mod: tvm.IRModule, params: Optional[Dict[str, tvm.nd.NDArray]] = None, - target: tvm.target.Target = tvm.target.Target("tensorrt"), + # CAUTION: Can't use default Target("tensorrt") here since the target kind is only available + # if is_tensorrt_compiler_enabled() == True. + target: Optional[tvm.target.Target] = None, ) -> tvm.IRModule: """Partition all functions in mod to greedily offload supported operators to TensorRT. @@ -130,8 +132,13 @@ def partition_for_tensorrt( The partitioned module. """ + assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled" if params: mod["main"] = bind_params_by_name(mod["main"], params) + if target is None: + # Use a default target. The get_tensorrt_target() function will similarly create an + # equivalent default target when compilation continues after partitioning. + target = tvm.target.Target("tensorrt") seq = tvm.transform.Sequential( [ diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index ee8724fe92fe..41f0a0a06408 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -360,8 +360,8 @@ class CodegenCModule { }; /*! \brief The actual translation pass. */ -transform::Pass CCompilerImpl() { - auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CCompilerImpl() { + auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod); Target target = GetCCompilerTarget(); @@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() { return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {}); } -transform::Pass CCompilerPass() { +tvm::transform::Pass CCompilerPass() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), - transforms::MarkCompilerFunctionsAsExtern("ccompiler")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), + transform::MarkCompilerFunctionsAsExtern("ccompiler")}); } } // namespace contrib diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index de2934173b5f..2e76ab1cbbf6 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -902,8 +902,8 @@ class CutlassModuleCodegen { * \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python * function which does the main CUTLASS training, c-code generation and compilation steps. */ -transform::Pass CompileForCutlassImpl() { - auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CompileForCutlassImpl() { + auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; @@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) { TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule); -transform::Pass CompileForCutlass() { +tvm::transform::Pass CompileForCutlass() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), - CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), + CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")}); } } // namespace cutlass diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 526f6bf7588a..dda5736b1be6 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -348,8 +348,8 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * function will require a linear scan of imported runtime modules to find the matching * TensorRTRuntimeModule implementing it. */ -transform::Pass CompileForTensorRTImpl() { - auto pass_func = [](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CompileForTensorRTImpl() { + auto pass_func = [](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod); Target target = GetTensorRTTarget(); @@ -400,10 +400,10 @@ transform::Pass CompileForTensorRTImpl() { return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", {}); } -transform::Pass CompileForTensorRT() { +tvm::transform::Pass CompileForTensorRT() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"), - CompileForTensorRTImpl(), transforms::MarkCompilerFunctionsAsExtern("tensorrt")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"), + CompileForTensorRTImpl(), transform::MarkCompilerFunctionsAsExtern("tensorrt")}); } } // namespace tensorrt diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 0df9f5ee294c..1dafcd10a361 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -24,14 +24,13 @@ #include "./compiler_function_utils.h" -#include "../op/call/call.h" #include "tvm/relay/analysis.h" #include "tvm/relay/expr_functor.h" #include "tvm/relay/transform.h" namespace tvm { namespace relay { -namespace transforms { +namespace transform { namespace { /*! @@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { return global_var; } -transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, - std::string compiler_filter) { +tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( IRModule mod, transform::PassContext ctx) { @@ -235,12 +234,13 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach } // Any Java programmers in the house? -transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) { +tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols( + std::string compiler_filter) { return OutlineCompilerFunctions(std::make_shared(), std::move(compiler_filter)); } -transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { +tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod); @@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); } -transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars) { +tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars) { runtime::TypedPackedFunc pass_func = [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars); @@ -295,6 +295,6 @@ TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo") .set_body_typed(InlineCompilerFunctionsBoundTo); -} // namespace transforms +} // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index aa98430318a6..f3499faec262 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -66,7 +66,7 @@ namespace tvm { namespace relay { -namespace transforms { +namespace transform { /*! * \brief Abstract class representing a cache of unique global vars keyed by functions. This can @@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache { * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. */ -transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, - std::string compiler_filter = ""); +tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter = ""); /*! * \brief A pass to outline all let-bound and literal functions in direct call positions which have @@ -119,7 +119,8 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism * to prepare the IRModule before custom lowering. */ -transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = ""); +tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols( + std::string compiler_filter = ""); /*! * \brief A pass to mark all global functions which have a "Compiler" attribute matching @@ -132,7 +133,7 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to * cleanup the IRModule after custom lowering. */ -transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); +tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); /*! * \brief A pass to inline all global "Compiler" functions which are bound to a global var @@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); * This pass may be useful for external codegen which needs to undo partitioning based on * properties of the entire partition. */ -transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars); +tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars); -} // namespace transforms +} // namespace transform } // namespace relay } // namespace tvm