diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f824b5535c470..bf38dd56247d9 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1597,6 +1597,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
notes : string
(Optional) Some notes for the model
+
onnx_model_filename : string
+
(Optional) Filename of the original ONNX model.
partition_name : string
(Optional) partitioned graph name.
source : string
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 32a9f06464ace..f71b60a9f5e13 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -64,10 +64,20 @@ struct OrtTensorRTProviderOptionsV2 { * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" * + * 3. In the case of building weight-stripped engines, the same security reasons as listed in 1) apply to the + * "onnx_model_filename" node attribute of EP context node, which contains a filename of the ONNX model with the + * weights needed for the refit process. User can specify a folder path relative to the current working + * directory by means of the "trt_onnx_model_folder_path" option. + * */ - int trt_dump_ep_context_model{0}; // Dump EP context node model - const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. - int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false, + // nonzero = true + const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for + // the ONNX model containing the weights (applicable only when + // the "trt_weight_stripped_engine_enable" option is enabled) const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c3805d119a18d..dea8775c89a30 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3299,6 +3299,11 @@ void RegisterContribSchemas() { "(Optional) SDK version used to convert the model.", AttributeProto::STRING, OPTIONAL_VALUE) + .Attr( + "onnx_model_filename", + "(Optional) Filename of the original ONNX model.", + AttributeProto::STRING, + OPTIONAL_VALUE) .Attr( "hardware_architecture", "(Optional) Hardware architecture.", diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 1994d1f5ab0b8..959da4944ff59 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -8,8 +8,10 @@ #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/framework/execution_provider.h" +#include "tensorrt_execution_provider.h" namespace onnxruntime { +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); /* * Check whether the graph has the EP context contrib op. @@ -67,7 +69,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, char* engine_data, size_t size, const int64_t embed_mode, - std::string compute_capability, + const std::string compute_capability, + const std::string onnx_model_path, const logging::Logger* logger) { auto model_build = graph_viewer.CreateModel(*logger); auto& graph_build = model_build->MainGraph(); @@ -88,6 +91,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture + auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename std::string engine_data_str = ""; attr_0->set_name(EMBED_MODE); attr_0->set_type(onnx::AttributeProto_AttributeType_INT); @@ -106,13 +110,17 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, attr_2->set_name(COMPUTE_CAPABILITY); attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); attr_2->set_s(compute_capability); + attr_3->set_name(ONNX_MODEL_FILENAME); + attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - int num_attributes = 3; + constexpr int num_attributes = 4; node_attributes->reserve(num_attributes); node_attributes->emplace(EMBED_MODE, *attr_0); node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); + node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3); // Create EP context node graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); @@ -205,7 +213,7 @@ void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path; } -bool IsAbsolutePath(std::string& path_string) { +bool IsAbsolutePath(const std::string& path_string) { #ifdef _WIN32 onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); auto path = std::filesystem::path(ort_path_string.c_str()); @@ -219,7 +227,7 @@ bool IsAbsolutePath(std::string& path_string) { } // Like "../file_path" -bool IsRelativePathToParentPath(std::string& path_string) { +bool IsRelativePathToParentPath(const std::string& path_string) { #ifdef _WIN32 onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); auto path = std::filesystem::path(ort_path_string.c_str()); @@ -236,6 +244,28 @@ bool IsRelativePathToParentPath(std::string& path_string) { #endif } +/* + * Get the weight-refitted engine cache path from a weight-stripped engine cache path + * + * Weight-stipped engine: + * An engine with weights stripped and its size is smaller than a regualr engine. + * The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine + * + * Weight-refitted engine: + * An engine that its weights have been refitted and it's simply a regular engine. + * The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine + */ +std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { + std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); + std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; + return refitted_engine_cache_path; +} + +bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { + // The weight-stripped engine cache has the naming of xxx.stripped.engine + return engine_cache_path.stem().extension().string() == ".stripped"; +} + Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { if (!ValidateEPCtxNode(graph_viewer)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); @@ -271,6 +301,22 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph // The engine cache and context model (current model) should be in the same directory std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); auto engine_cache_path = ctx_model_dir.append(cache_path); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + + // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled + if (!weight_stripped_engine_refit_) { + weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); + } + + // If the serialized refitted engine is present, use it directly without refitting the engine again + if (weight_stripped_engine_refit_) { + const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); + if (std::filesystem::exists(refitted_engine_cache_path)) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + engine_cache_path = refitted_engine_cache_path.string(); + weight_stripped_engine_refit_ = false; + } + } if (!std::filesystem::exists(engine_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, @@ -290,6 +336,21 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + + if (weight_stripped_engine_refit_) { + const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + std::string weight_stripped_engine_cache = engine_cache_path.string(); + auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + weight_stripped_engine_cache, + true /* path check for security */, + (*trt_engine_).get(), + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index 9f1e5178428e7..f8fefc12c3453 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/providers/tensorrt/nv_includes.h" #include "core/providers/shared_library/provider_api.h" @@ -15,6 +16,7 @@ static const std::string EPCONTEXT_OP = "EPContext"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; +static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; static const std::string EPCONTEXT_WARNING = "It's suggested to set the ORT graph optimization level to 0 and \ @@ -29,12 +31,13 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, char* engine_data, size_t size, const int64_t embed_mode, - std::string compute_capability, + const std::string compute_capability, + const std::string onnx_model_path, const logging::Logger* logger); std::string GetCtxModelPath(const std::string& ep_context_file_path, const std::string& original_model_path); -bool IsAbsolutePath(std::string& path_string); -bool IsRelativePathToParentPath(std::string& path_string); +bool IsAbsolutePath(const std::string& path_string); +bool IsRelativePathToParentPath(const std::string& path_string); void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, @@ -46,7 +49,17 @@ class TensorRTCacheModelHandler { TensorRTCacheModelHandler(std::unique_ptr* trt_engine, nvinfer1::IRuntime* trt_runtime, std::string ep_context_model_path, - std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), compute_capability_(compute_capability) { + std::string compute_capability, + bool weight_stripped_engine_refit, + std::string onnx_model_folder_path, + bool detailed_build_log) + : trt_engine_(trt_engine), + trt_runtime_(trt_runtime), + ep_context_model_path_(ep_context_model_path), + compute_capability_(compute_capability), + weight_stripped_engine_refit_(weight_stripped_engine_refit), + onnx_model_folder_path_(onnx_model_folder_path), + detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); @@ -59,5 +72,8 @@ class TensorRTCacheModelHandler { nvinfer1::IRuntime* trt_runtime_; std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory std::string compute_capability_; + bool weight_stripped_engine_refit_; + std::string onnx_model_folder_path_; + bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7c9248bca8e53..0c3886ccf6b38 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1238,6 +1238,13 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; + // incase the EP context is dumped the engine cache has to be enabled + auto enable_engine_cache_for_ep_context_model = [this]() { + if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + engine_cache_enable_ = true; + } + }; + // Get environment variables if (info.has_trt_options) { max_partition_iterations_ = info.max_partition_iterations; @@ -1255,12 +1262,15 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } dump_subgraphs_ = info.dump_subgraphs; engine_cache_enable_ = info.engine_cache_enable; + weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; + onnx_model_folder_path_ = info.onnx_model_folder_path; timing_cache_enable_ = info.timing_cache_enable; force_timing_cache_match_ = info.force_timing_cache; detailed_build_log_ = info.detailed_build_log; dump_ep_context_model_ = info.dump_ep_context_model; ep_context_file_path_ = info.ep_context_file_path; ep_context_embed_mode_ = info.ep_context_embed_mode; + enable_engine_cache_for_ep_context_model(); if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; cache_prefix_ = info.engine_cache_prefix; @@ -1354,6 +1364,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true); } + const std::string weight_stripped_engine_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kWeightStrippedEngineEnable); + if (!weight_stripped_engine_enable_env.empty()) { + weight_stripped_engine_enable_ = std::stoi(weight_stripped_engine_enable_env) != 0; + } + + const std::string onnx_model_folder_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOnnxModelFolderPath); + if (!onnx_model_folder_path_env.empty()) { + onnx_model_folder_path_ = onnx_model_folder_path_env; + } + const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable); if (!timing_cache_enable_env.empty()) { timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true); @@ -1384,6 +1404,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); } + enable_engine_cache_for_ep_context_model(); + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); @@ -1623,6 +1645,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_dla_core: " << dla_core_ << ", trt_dump_subgraphs: " << dump_subgraphs_ << ", trt_engine_cache_enable: " << engine_cache_enable_ + << ", trt_weight_stripped_engine_enable: " << weight_stripped_engine_enable_ + << ", trt_onnx_model_folder_path: " << onnx_model_folder_path_ << ", trt_cache_path: " << cache_path_ << ", trt_global_cache_path: " << global_cache_path_ << ", trt_engine_decryption_enable: " << engine_decryption_enable_ @@ -2275,7 +2299,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { // Construct subgraph capability from node list std::vector> result; - // Get ModelPath const auto& path_string = graph.ModelPath().ToPathString(); #ifdef _WIN32 @@ -2466,6 +2489,67 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } +/** + * Refit the weight-stripped engine + */ +common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log) { +#if NV_TENSORRT_MAJOR >= 10 + std::filesystem::path onnx_model_path{onnx_model_folder_path}; + onnx_model_path.append(onnx_model_filename); + if (path_check && IsAbsolutePath(onnx_model_path.string())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "For security purpose, the ONNX model path should be set with " + "a relative path, but it is an absolute path: " + + onnx_model_path.string()); + } + if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The ONNX model path has '..'. For security purpose, it's not " + "allowed to point outside the directory."); + } + + if (!std::filesystem::exists(onnx_model_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The ONNX model " + onnx_model_path.string() + + " does not exist."); + } + + // weight-stripped engine refit logic + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log); + auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); + auto parser_refitter = std::unique_ptr( + nvonnxparser::createParserRefitter(*refitter, trt_logger)); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + if (refitter->refitCudaEngine()) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + + // serialize the refitted engine to disk + if (serialize_refitted_engine) { + std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); + nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); + std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); + engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache; + } + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards."); +#endif +} + common::Status TensorrtExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { for (auto& fused_node_graph : fused_nodes_and_graphs) { @@ -2489,7 +2573,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } + // limit used tactic sources if (!tactic_sources_.empty()) { nvinfer1::TacticSources tactics = trt_config->getTacticSources(); @@ -2820,10 +2919,18 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; - const std::string engine_cache_path = cache_path_prefix + ".engine"; + std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path_prefix + ".profile"; + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); @@ -2863,6 +2970,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not deserialize engine from cache: " + engine_cache_path); } + } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { // Decrypt engine size_t engine_size = 0; @@ -2977,12 +3085,26 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView serialized_engine->size(), ep_context_embed_mode_, compute_capability_, + model_path_, GetLogger())}; DumpCtxModel(model_proto.get(), ctx_model_path_); } } } + if (weight_stripped_engine_refit_) { + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine.get(), + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } + // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading @@ -3049,6 +3171,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView 0, ep_context_embed_mode_, compute_capability_, + model_path_, GetLogger())); if (ep_context_embed_mode_ == 0) { DumpCtxModel(model_proto_.get(), ctx_model_path_); @@ -3069,11 +3192,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, - dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, - runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, - dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, - global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, - builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix}; + dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, + engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], + context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, + engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, + detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, + auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix}; *state = p.release(); return 0; }; @@ -3137,7 +3261,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); } const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; - const std::string engine_cache_path = cache_path_prefix + ".engine"; + std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path_prefix + ".profile"; std::string timing_cache_path = ""; @@ -3145,6 +3269,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); } + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + // Load serialized engine if (trt_state->engine_cache_enable && trt_engine == nullptr) { std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); @@ -3173,6 +3305,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; trt_engine = trt_state->engine->get(); context_update = true; + } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) { shape_ranges = DeserializeProfileV2(profile_file); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; @@ -3285,6 +3418,16 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; } #endif + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled"; + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled"; +#else + LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; +#endif + } // limit used tactic sources if (trt_state->filter_tactic_sources) { nvinfer1::TacticSources tactics = trt_config->getTacticSources(); @@ -3379,6 +3522,19 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView DumpCtxModel(model_proto_.get(), ctx_model_path_); } context_update = true; + + if (weight_stripped_engine_refit_) { + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + trt_engine, + true /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } } if (context_update) { @@ -3579,7 +3735,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con std::unordered_map output_types; // TRT engine output name -> ORT output tensor type // Get engine binary data and deserialize it - auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), model_path_, compute_capability_); + auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, + runtime_.get(), + model_path_, + compute_capability_, + weight_stripped_engine_enable_, + onnx_model_folder_path_, + detailed_build_log_); auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index eabbbdea1c4ac..389cd471db2ae 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -27,7 +27,9 @@ static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE"; static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS"; static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE"; static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH"; -// As a timing cache can be used across multiple ONNX files it makes sense to have a seperate cache path +static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE"; +static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH"; +// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH"; static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE"; static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH"; @@ -217,6 +219,7 @@ struct SubGraphContext { using SubGraphContextMap = std::unordered_map>; using DDSOutputAllocatorMap = std::unordered_map>; +std::string GetWeightRefittedEnginePath(std::string engine_cache_path); // Logical device representation. class TensorrtExecutionProvider : public IExecutionProvider { @@ -263,6 +266,17 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + /** + * Refit the weight-stripped engine + */ + static common::Status RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log); + private: mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; @@ -280,6 +294,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool int8_use_native_tensorrt_calibration_table_ = false; bool dump_subgraphs_ = false; bool engine_cache_enable_ = false; + bool weight_stripped_engine_enable_ = false; + bool weight_stripped_engine_refit_ = false; + std::string onnx_model_folder_path_; bool build_heuristics_enable_ = false; bool sparsity_enable_ = false; int builder_optimization_level_ = 3; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index cd2087c9d7472..7f7587543d175 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -27,6 +27,8 @@ constexpr const char* kDLACore = "trt_dla_core"; constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; constexpr const char* kEngineCachePath = "trt_engine_cache_path"; +constexpr const char* kWeightStrippedEngineEnable = "trt_weight_stripped_engine_enable"; +constexpr const char* kOnnxModelFolderPath = "trt_onnx_model_folder_path"; constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix"; constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; @@ -92,6 +94,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) @@ -139,6 +143,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)}, {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)}, {tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)}, + {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)}, + {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)}, {tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)}, @@ -180,6 +186,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); + const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path); const ProviderOptions options{ {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, @@ -198,6 +205,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)}, {tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_}, {tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_}, + {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)}, + {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_}, {tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)}, {tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_}, {tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)}, @@ -289,6 +298,8 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_dla_core = internal_options.dla_core; trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs; trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable; + trt_provider_options_v2.trt_weight_stripped_engine_enable = internal_options.weight_stripped_engine_enable; + trt_provider_options_v2.trt_onnx_model_folder_path = copy_string_if_needed(internal_options.onnx_model_folder_path); trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path); trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 80424b8d6d196..df9d8456573fe 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -32,6 +32,8 @@ struct TensorrtExecutionProviderInfo { bool dump_subgraphs{false}; bool engine_cache_enable{false}; std::string engine_cache_path{""}; + bool weight_stripped_engine_enable{false}; + std::string onnx_model_folder_path{""}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 568da57a50956..7beeba336d1a4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -90,6 +90,8 @@ struct Tensorrt_Provider : Provider { info.dump_subgraphs = options.trt_dump_subgraphs != 0; info.engine_cache_enable = options.trt_engine_cache_enable != 0; info.engine_cache_path = options.trt_engine_cache_path == nullptr ? "" : options.trt_engine_cache_path; + info.weight_stripped_engine_enable = options.trt_weight_stripped_engine_enable != 0; + info.onnx_model_folder_path = options.trt_onnx_model_folder_path == nullptr ? "" : options.trt_onnx_model_folder_path; info.engine_decryption_enable = options.trt_engine_decryption_enable != 0; info.engine_decryption_lib_path = options.trt_engine_decryption_lib_path == nullptr ? "" : options.trt_engine_decryption_lib_path; info.force_sequential_engine_build = options.trt_force_sequential_engine_build != 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ef56a7960c5f5..fdd2eadf7f97e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2334,6 +2334,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor delete[] ptr->trt_profile_max_shapes; delete[] ptr->trt_profile_opt_shapes; delete[] ptr->trt_ep_context_file_path; + delete[] ptr->trt_onnx_model_folder_path; } std::unique_ptr p(ptr); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index bab32853ec493..f964cfbdae2d3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -501,7 +501,9 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) - std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path; + std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, + trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path, + onnx_model_folder_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -614,6 +616,21 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n"); } + } else if (option.first == "trt_weight_stripped_engine_enable") { + if (option.second == "True" || option.second == "true") { + params.trt_weight_stripped_engine_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.trt_weight_stripped_engine_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_weight_stripped_engine_enable' should be 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "trt_onnx_model_folder_path") { + if (!option.second.empty()) { + onnx_model_folder_path = option.second; + params.trt_onnx_model_folder_path = onnx_model_folder_path.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_onnx_model_folder_path' should be a path string i.e. 'engine_cache'.\n"); + } } else if (option.first == "trt_engine_decryption_enable") { if (option.second == "True" || option.second == "true") { params.trt_engine_decryption_enable = true; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 62291762f61b8..e4b7e087c0839 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -112,6 +112,8 @@ namespace perftest { "\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n" "\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n" "\t [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n" + "\t [TensorRT only] [trt_weight_stripped_engine_enable]: Enable weight-stripped engine build.\n" + "\t [TensorRT only] [trt_onnx_model_folder_path]: Folder path for the ONNX model with weights.\n" "\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" "\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" "\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n"