diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f824b5535c470..bf38dd56247d9 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1597,6 +1597,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
notes : string
(Optional) Some notes for the model
+onnx_model_filename : string
+(Optional) Filename of the original ONNX model.
partition_name : string
(Optional) partitioned graph name.
source : string
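
For illustration, a minimal standalone sketch of how the new `onnx_model_filename` attribute could be read back from a dumped EP context model via the ONNX protobuf API; the file name `ctx_model.onnx` is hypothetical and error handling is omitted.

```cpp
// Minimal sketch: read the optional "onnx_model_filename" attribute from an
// EPContext node in a dumped context model. Assumes the ONNX protobuf headers
// are available; "ctx_model.onnx" is a hypothetical file name.
#include <fstream>
#include <iostream>
#include <string>
#include "onnx/onnx_pb.h"

int main() {
  onnx::ModelProto model;
  std::ifstream in("ctx_model.onnx", std::ios::binary);
  if (!model.ParseFromIstream(&in)) return 1;

  for (const auto& node : model.graph().node()) {
    if (node.op_type() != "EPContext") continue;
    for (const auto& attr : node.attribute()) {
      if (attr.name() == "onnx_model_filename") {
        // Filename of the original ONNX model used for weight refitting.
        std::cout << "onnx_model_filename: " << attr.s() << "\n";
      }
    }
  }
  return 0;
}
```
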
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 32a9f06464ace..f71b60a9f5e13 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -64,10 +64,20 @@ struct OrtTensorRTProviderOptionsV2 {
* - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
* - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
*
+ * 3. In the case of building weight-stripped engines, the same security reasons as listed in 1) apply to the
+ * "onnx_model_filename" node attribute of the EP context node, which contains the filename of the ONNX model with the
+ * weights needed for the refit process. Users can specify a folder path relative to the current working
+ * directory by means of the "trt_onnx_model_folder_path" option.
+ *
*/
- int trt_dump_ep_context_model{0}; // Dump EP context node model
- const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
- int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+ int trt_dump_ep_context_model{0}; // Dump EP context node model
+ const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
+ int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+ int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
+ // nonzero = true
+ const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
+ // the ONNX model containing the weights (applicable only when
+ // the "trt_weight_stripped_engine_enable" option is enabled)
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
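
A minimal sketch of how the two new provider options might be set through the ORT C API, assuming an existing `OrtSessionOptions*`; the folder name `weights_dir` is hypothetical and status checks are omitted.

```cpp
// Minimal sketch of enabling the new options through the ORT C API.
// Error handling is omitted; "weights_dir" is a hypothetical folder name.
#include <onnxruntime_c_api.h>

void ConfigureWeightStrippedTrt(OrtSessionOptions* session_options) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);

  // Build/refit weight-stripped engines and point the EP at the folder
  // (relative to the current working directory) holding the ONNX model
  // with the weights needed for refitting.
  const char* keys[] = {"trt_weight_stripped_engine_enable",
                        "trt_onnx_model_folder_path",
                        "trt_engine_cache_enable"};
  const char* values[] = {"1", "weights_dir", "1"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 3);

  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options);
  api->ReleaseTensorRTProviderOptions(trt_options);
}
```
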
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index c3805d119a18d..dea8775c89a30 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3299,6 +3299,11 @@ void RegisterContribSchemas() {
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
+ .Attr(
+ "onnx_model_filename",
+ "(Optional) Filename of the original ONNX model.",
+ AttributeProto::STRING,
+ OPTIONAL_VALUE)
.Attr(
"hardware_architecture",
"(Optional) Hardware architecture.",
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
index 1994d1f5ab0b8..959da4944ff59 100644
--- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -8,8 +8,10 @@
#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/execution_provider.h"
+#include "tensorrt_execution_provider.h"
namespace onnxruntime {
+extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
/*
* Check whether the graph has the EP context contrib op.
@@ -67,7 +69,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
- std::string compute_capability,
+ const std::string compute_capability,
+ const std::string onnx_model_path,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();
@@ -88,6 +91,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture
+ auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename
std::string engine_data_str = "";
attr_0->set_name(EMBED_MODE);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
@@ -106,13 +110,17 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);
+ attr_3->set_name(ONNX_MODEL_FILENAME);
+ attr_3->set_type(onnx::AttributeProto_AttributeType_STRING);
+ attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string());
auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
- int num_attributes = 3;
+ constexpr int num_attributes = 4;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
+ node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);
// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
@@ -205,7 +213,7 @@ void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}
-bool IsAbsolutePath(std::string& path_string) {
+bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
auto path = std::filesystem::path(ort_path_string.c_str());
@@ -219,7 +227,7 @@ bool IsAbsolutePath(std::string& path_string) {
}
// Like "../file_path"
-bool IsRelativePathToParentPath(std::string& path_string) {
+bool IsRelativePathToParentPath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
auto path = std::filesystem::path(ort_path_string.c_str());
@@ -236,6 +244,28 @@ bool IsRelativePathToParentPath(std::string& path_string) {
#endif
}
+/*
+ * Get the weight-refitted engine cache path from a weight-stripped engine cache path
+ *
+ * Weight-stripped engine:
+ * An engine whose weights have been stripped; it is smaller than a regular engine.
+ * The cache name of a weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine
+ *
+ * Weight-refitted engine:
+ * An engine whose weights have been refitted, i.e. simply a regular engine.
+ * The cache name of a weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine
+ */
+std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) {
+ std::filesystem::path stripped_engine_cache_path(stripped_engine_cache);
+ std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine";
+ return refitted_engine_cache_path;
+}
+
+bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
+ // The weight-stripped engine cache has the naming of xxx.stripped.engine
+ return engine_cache_path.stem().extension().string() == ".stripped";
+}
+
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
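
A small standalone sketch of the cache-naming convention the helpers above rely on, using only `std::filesystem`; the cache name is illustrative.

```cpp
// Standalone sketch of the cache-name convention documented above:
// "<prefix>.stripped.engine" is detected via the inner extension and the
// refitted cache drops that marker. Names are illustrative only.
#include <cassert>
#include <filesystem>
#include <string>

int main() {
  const std::filesystem::path stripped("TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine");

  // stem() removes ".engine"; the inner extension is then ".stripped".
  assert(stripped.stem().extension() == ".stripped");

  // stem().stem() drops both suffixes; appending ".engine" yields the
  // refitted engine cache name.
  const std::string refitted = stripped.stem().stem().string() + ".engine";
  assert(refitted == "TensorrtExecutionProvider_TRTKernel_XXXXX.engine");
  return 0;
}
```
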
@@ -271,6 +301,22 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
// The engine cache and context model (current model) should be in the same directory
std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
auto engine_cache_path = ctx_model_dir.append(cache_path);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();
+
+ // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled
+ if (!weight_stripped_engine_refit_) {
+ weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path);
+ }
+
+ // If the serialized refitted engine is present, use it directly without refitting the engine again
+ if (weight_stripped_engine_refit_) {
+ const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string());
+ if (std::filesystem::exists(refitted_engine_cache_path)) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists.";
+ engine_cache_path = refitted_engine_cache_path.string();
+ weight_stripped_engine_refit_ = false;
+ }
+ }
if (!std::filesystem::exists(engine_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
@@ -290,6 +336,21 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string());
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
+
+ if (weight_stripped_engine_refit_) {
+ const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
+ std::string weight_stripped_engine_cache = engine_cache_path.string();
+ auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
+ onnx_model_folder_path_,
+ weight_stripped_engine_cache,
+ true /* path check for security */,
+ (*trt_engine_).get(),
+ true /* serialize refitted engine to disk */,
+ detailed_build_log_);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
}
return Status::OK();
}
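
As a reading aid, a simplified sketch of the cache-resolution order implemented in `GetEpContextFromGraph` above; it uses plain `std::filesystem`, the struct and function names are hypothetical, and it is not part of the EP's API.

```cpp
// Simplified sketch of the lookup order above: prefer an already-refitted
// "<name>.engine"; otherwise fall back to "<name>.stripped.engine" and mark
// it for refitting. Mirrors the logic above but is not the EP's API.
#include <filesystem>

struct ResolvedEngineCache {
  std::filesystem::path path;
  bool needs_refit;
};

ResolvedEngineCache ResolveEngineCache(const std::filesystem::path& cache_from_ep_context_node) {
  const std::filesystem::path candidate = cache_from_ep_context_node;
  const bool is_stripped = candidate.stem().extension() == ".stripped";

  if (is_stripped) {
    const std::filesystem::path refitted =
        candidate.parent_path() / (candidate.stem().stem().string() + ".engine");
    if (std::filesystem::exists(refitted)) {
      // A previously serialized refitted engine is reused directly.
      return {refitted, false};
    }
  }
  return {candidate, is_stripped};
}
```
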
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
index 9f1e5178428e7..f8fefc12c3453 100644
--- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -5,6 +5,7 @@
#include
#include
+#include
#include "core/providers/tensorrt/nv_includes.h"
#include "core/providers/shared_library/provider_api.h"
@@ -15,6 +16,7 @@ static const std::string EPCONTEXT_OP = "EPContext";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
+static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
static const std::string EPCONTEXT_WARNING =
"It's suggested to set the ORT graph optimization level to 0 and \
@@ -29,12 +31,13 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
- std::string compute_capability,
+ const std::string compute_capability,
+ const std::string onnx_model_path,
const logging::Logger* logger);
std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
-bool IsAbsolutePath(std::string& path_string);
-bool IsRelativePathToParentPath(std::string& path_string);
+bool IsAbsolutePath(const std::string& path_string);
+bool IsRelativePathToParentPath(const std::string& path_string);
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
@@ -46,7 +49,17 @@ class TensorRTCacheModelHandler {
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime,
std::string ep_context_model_path,
- std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), compute_capability_(compute_capability) {
+ std::string compute_capability,
+ bool weight_stripped_engine_refit,
+ std::string onnx_model_folder_path,
+ bool detailed_build_log)
+ : trt_engine_(trt_engine),
+ trt_runtime_(trt_runtime),
+ ep_context_model_path_(ep_context_model_path),
+ compute_capability_(compute_capability),
+ weight_stripped_engine_refit_(weight_stripped_engine_refit),
+ onnx_model_folder_path_(onnx_model_folder_path),
+ detailed_build_log_(detailed_build_log) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
@@ -59,5 +72,8 @@ class TensorRTCacheModelHandler {
nvinfer1::IRuntime* trt_runtime_;
std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory
std::string compute_capability_;
+ bool weight_stripped_engine_refit_;
+ std::string onnx_model_folder_path_;
+ bool detailed_build_log_;
}; // TRTCacheModelHandler
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 7c9248bca8e53..0c3886ccf6b38 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1238,6 +1238,13 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
+ // In case the EP context model is dumped, the engine cache has to be enabled
+ auto enable_engine_cache_for_ep_context_model = [this]() {
+ if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) {
+ engine_cache_enable_ = true;
+ }
+ };
+
// Get environment variables
if (info.has_trt_options) {
max_partition_iterations_ = info.max_partition_iterations;
@@ -1255,12 +1262,15 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
dump_subgraphs_ = info.dump_subgraphs;
engine_cache_enable_ = info.engine_cache_enable;
+ weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
+ onnx_model_folder_path_ = info.onnx_model_folder_path;
timing_cache_enable_ = info.timing_cache_enable;
force_timing_cache_match_ = info.force_timing_cache;
detailed_build_log_ = info.detailed_build_log;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_file_path_ = info.ep_context_file_path;
ep_context_embed_mode_ = info.ep_context_embed_mode;
+ enable_engine_cache_for_ep_context_model();
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
cache_path_ = info.engine_cache_path;
cache_prefix_ = info.engine_cache_prefix;
@@ -1354,6 +1364,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true);
}
+ const std::string weight_stripped_engine_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kWeightStrippedEngineEnable);
+ if (!weight_stripped_engine_enable_env.empty()) {
+ weight_stripped_engine_enable_ = std::stoi(weight_stripped_engine_enable_env) != 0;
+ }
+
+ const std::string onnx_model_folder_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOnnxModelFolderPath);
+ if (!onnx_model_folder_path_env.empty()) {
+ onnx_model_folder_path_ = onnx_model_folder_path_env;
+ }
+
const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable);
if (!timing_cache_enable_env.empty()) {
timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true);
@@ -1384,6 +1404,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
}
+ enable_engine_cache_for_ep_context_model();
+
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath);
cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath);
@@ -1623,6 +1645,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_dla_core: " << dla_core_
<< ", trt_dump_subgraphs: " << dump_subgraphs_
<< ", trt_engine_cache_enable: " << engine_cache_enable_
+ << ", trt_weight_stripped_engine_enable: " << weight_stripped_engine_enable_
+ << ", trt_onnx_model_folder_path: " << onnx_model_folder_path_
<< ", trt_cache_path: " << cache_path_
<< ", trt_global_cache_path: " << global_cache_path_
<< ", trt_engine_decryption_enable: " << engine_decryption_enable_
@@ -2275,7 +2299,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
// Construct subgraph capability from node list
std::vector<std::unique_ptr<ComputeCapability>> result;
-
// Get ModelPath
const auto& path_string = graph.ModelPath().ToPathString();
#ifdef _WIN32
@@ -2466,6 +2489,67 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
return result;
}
+/**
+ * Refit the weight-stripped engine
+ */
+common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename,
+ std::string& onnx_model_folder_path,
+ std::string& weight_stripped_engine_cache_path,
+ bool path_check,
+ nvinfer1::ICudaEngine* trt_engine,
+ bool serialize_refitted_engine,
+ bool detailed_build_log) {
+#if NV_TENSORRT_MAJOR >= 10
+ std::filesystem::path onnx_model_path{onnx_model_folder_path};
+ onnx_model_path.append(onnx_model_filename);
+ if (path_check && IsAbsolutePath(onnx_model_path.string())) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "For security purposes, the ONNX model path should be set with "
+ "a relative path, but it is an absolute path: " +
+ onnx_model_path.string());
+ }
+ if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "The ONNX model path has '..'. For security purposes, it's not "
+ "allowed to point outside the directory.");
+ }
+
+ if (!std::filesystem::exists(onnx_model_path)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "The ONNX model " + onnx_model_path.string() +
+ " does not exist.");
+ }
+
+ // weight-stripped engine refit logic
+ TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log);
+ auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
+ auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
+ nvonnxparser::createParserRefitter(*refitter, trt_logger));
+ if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
+ }
+ if (refitter->refitCudaEngine()) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine.";
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
+ }
+
+ // serialize the refitted engine to disk
+ if (serialize_refitted_engine) {
+ std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cache_path);
+ nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize();
+ std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out);
+ engine_file.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialize the refitted engine to " << refitted_engine_cache;
+ }
+ return Status::OK();
+#else
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards.");
+#endif
+}
+
common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (auto& fused_node_graph : fused_nodes_and_graphs) {
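
An illustrative sketch of what the two path checks guard against before refitting, using plain `std::filesystem`; it is not the EP's implementation (which also handles wide-character paths on Windows), and the example paths are hypothetical.

```cpp
// Illustrative sketch of the two path checks applied before refitting:
// absolute paths and paths containing ".." are rejected. Plain
// std::filesystem only; not the EP's implementation.
#include <filesystem>
#include <iostream>
#include <string>

bool IsRejectedForRefit(const std::string& path_string) {
  const std::filesystem::path p(path_string);
  if (p.is_absolute()) return true;       // absolute path -> rejected
  for (const auto& component : p) {
    if (component == "..") return true;   // points outside the directory -> rejected
  }
  return false;
}

int main() {
  std::cout << IsRejectedForRefit("weights_dir/net.onnx") << "\n";  // accepted (0)
  std::cout << IsRejectedForRefit("/abs/net.onnx") << "\n";         // absolute on POSIX -> rejected
  std::cout << IsRejectedForRefit("../net.onnx") << "\n";           // contains ".." -> rejected
  return 0;
}
```
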
@@ -2489,7 +2573,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
+ if (weight_stripped_engine_enable_) {
+#if NV_TENSORRT_MAJOR >= 10
+ trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+ trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+#else
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+#endif
+ }
+
// limit used tactic sources
if (!tactic_sources_.empty()) {
nvinfer1::TacticSources tactics = trt_config->getTacticSources();
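
For context, a hedged standalone sketch of building a weight-stripped, refittable engine directly with the TensorRT 10 API, mirroring the `kSTRIP_PLAN`/`kREFIT_IDENTICAL` flags set above; the file names and the logger are hypothetical.

```cpp
// Standalone sketch (TensorRT 10 API): build a weight-stripped, refittable
// engine from an ONNX file. "model.onnx" / "model.stripped.engine" are
// hypothetical names; error handling is reduced to early returns.
#include <fstream>
#include <iostream>
#include <memory>
#include <NvInfer.h>
#include <NvOnnxParser.h>

class StderrLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::cerr << msg << "\n";
  }
};

int main() {
  StderrLogger logger;
  auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
  auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));
  auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));
  if (!parser->parseFromFile("model.onnx", static_cast<int>(nvinfer1::ILogger::Severity::kWARNING)))
    return 1;

  auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
  // The same two flags the EP sets when trt_weight_stripped_engine_enable is on:
  config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);       // omit weights from the plan
  config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);  // allow refit with identical weights

  auto plan = std::unique_ptr<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
  if (!plan) return 1;

  std::ofstream out("model.stripped.engine", std::ios::binary);
  out.write(static_cast<const char*>(plan->data()), plan->size());
  return 0;
}
```
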
@@ -2820,10 +2919,18 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
- const std::string engine_cache_path = cache_path_prefix + ".engine";
+ std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path_prefix + ".profile";
+ // If weight-stripped engine is enabled and the refitted engine cache is not present,
+ // TRT EP will use the engine cache with the ".stripped.engine" suffix instead.
+ const std::filesystem::path engine_cache_fs_path = engine_cache_path;
+ if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
+ engine_cache_path = cache_path_prefix + ".stripped.engine";
+ weight_stripped_engine_refit_ = true;
+ }
+
// Generate file name for dumping ep context model
if (dump_ep_context_model_ && ctx_model_path_.empty()) {
ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_);
@@ -2863,6 +2970,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
+
} else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
// Decrypt engine
size_t engine_size = 0;
@@ -2977,12 +3085,26 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
serialized_engine->size(),
ep_context_embed_mode_,
compute_capability_,
+ model_path_,
GetLogger())};
DumpCtxModel(model_proto.get(), ctx_model_path_);
}
}
}
+ if (weight_stripped_engine_refit_) {
+ auto status = RefitEngine(model_path_,
+ onnx_model_folder_path_,
+ engine_cache_path,
+ false /* path check for security */,
+ trt_engine.get(),
+ true /* serialize refitted engine to disk */,
+ detailed_build_log_);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
+
// Build context
// Note: Creating an execution context from an engine is thread safe per TRT doc
// https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
@@ -3049,6 +3171,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
0,
ep_context_embed_mode_,
compute_capability_,
+ model_path_,
GetLogger()));
if (ep_context_embed_mode_ == 0) {
DumpCtxModel(model_proto_.get(), ctx_model_path_);
@@ -3069,11 +3192,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
- dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
- runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
- dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
- global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_,
- builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix};
+ dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision,
+ engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name],
+ context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_,
+ engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_,
+ detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_,
+ auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix};
*state = p.release();
return 0;
};
@@ -3137,7 +3261,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
}
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
- const std::string engine_cache_path = cache_path_prefix + ".engine";
+ std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path_prefix + ".profile";
std::string timing_cache_path = "";
@@ -3145,6 +3269,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
}
+ // If weight-stripped engine is enabled and the refitted engine cache is not present,
+ // TRT EP will use the engine cache with the ".stripped.engine" suffix instead.
+ const std::filesystem::path engine_cache_fs_path = engine_cache_path;
+ if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) {
+ engine_cache_path = cache_path_prefix + ".stripped.engine";
+ weight_stripped_engine_refit_ = true;
+ }
+
// Load serialized engine
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
@@ -3173,6 +3305,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
trt_engine = trt_state->engine->get();
context_update = true;
+
} else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
shape_ranges = DeserializeProfileV2(profile_file);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
@@ -3285,6 +3418,16 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
}
#endif
+ if (weight_stripped_engine_enable_) {
+#if NV_TENSORRT_MAJOR >= 10
+ trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] STRIP_PLAN is enabled";
+ trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] REFIT_IDENTICAL is enabled";
+#else
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!";
+#endif
+ }
// limit used tactic sources
if (trt_state->filter_tactic_sources) {
nvinfer1::TacticSources tactics = trt_config->getTacticSources();
@@ -3379,6 +3522,19 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
DumpCtxModel(model_proto_.get(), ctx_model_path_);
}
context_update = true;
+
+ if (weight_stripped_engine_refit_) {
+ auto status = RefitEngine(model_path_,
+ onnx_model_folder_path_,
+ engine_cache_path,
+ false /* path check for security */,
+ trt_engine,
+ true /* serialize refitted engine to disk */,
+ detailed_build_log_);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
}
if (context_update) {
@@ -3579,7 +3735,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
std::unordered_map<std::string, size_t> output_types; // TRT engine output name -> ORT output tensor type
// Get engine binary data and deserialize it
- auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), model_path_, compute_capability_);
+ auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine,
+ runtime_.get(),
+ model_path_,
+ compute_capability_,
+ weight_stripped_engine_enable_,
+ onnx_model_folder_path_,
+ detailed_build_log_);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index eabbbdea1c4ac..389cd471db2ae 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -27,7 +27,9 @@ static const std::string kDLACore = "ORT_TENSORRT_DLA_CORE";
static const std::string kDumpSubgraphs = "ORT_TENSORRT_DUMP_SUBGRAPHS";
static const std::string kEngineCacheEnable = "ORT_TENSORRT_ENGINE_CACHE_ENABLE";
static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
-// As a timing cache can be used across multiple ONNX files it makes sense to have a seperate cache path
+static const std::string kWeightStrippedEngineEnable = "ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE";
+static const std::string kOnnxModelFolderPath = "ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH";
+// As a timing cache can be used across multiple ONNX files it makes sense to have a separate cache path
static const std::string kTimingCachePath = "ORT_TENSORRT_GLOBAL_CACHE_PATH";
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
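
A minimal sketch of driving the same feature through the new environment variables read in the constructor changes above; POSIX `setenv` is assumed and the folder name is hypothetical.

```cpp
// Minimal sketch: the same switches via environment variables, which the EP
// reads at construction time. POSIX setenv is assumed; "weights_dir" is a
// hypothetical folder name.
#include <cstdlib>

void EnableWeightStrippedEnginesViaEnv() {
  setenv("ORT_TENSORRT_WEIGHT_STRIPPED_ENGINE_ENABLE", "1", /*overwrite*/ 1);
  setenv("ORT_TENSORRT_ONNX_MODEL_FOLDER_PATH", "weights_dir", 1);
  // Engine caching is typically enabled alongside, so stripped/refitted
  // engines can be written next to the EP context model.
  setenv("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "1", 1);
}
```
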
@@ -217,6 +219,7 @@ struct SubGraphContext {
using SubGraphContextMap = std::unordered_map<std::string, std::unique_ptr<SubGraphContext>>;
using DDSOutputAllocatorMap = std::unordered_map<std::string, std::unique_ptr<OutputAllocator>>;
+std::string GetWeightRefittedEnginePath(std::string engine_cache_path);
// Logical device representation.
class TensorrtExecutionProvider : public IExecutionProvider {
@@ -263,6 +266,17 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
+ /**
+ * Refit the weight-stripped engine
+ */
+ static common::Status RefitEngine(std::string onnx_model_filename,
+ std::string& onnx_model_folder_path,
+ std::string& weight_stripped_engine_cache_path,
+ bool path_check,
+ nvinfer1::ICudaEngine* trt_engine,
+ bool serialize_refitted_engine,
+ bool detailed_build_log);
+
private:
mutable TensorrtExecutionProviderInfo info_;
bool external_stream_ = false;
@@ -280,6 +294,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool int8_use_native_tensorrt_calibration_table_ = false;
bool dump_subgraphs_ = false;
bool engine_cache_enable_ = false;
+ bool weight_stripped_engine_enable_ = false;
+ bool weight_stripped_engine_refit_ = false;
+ std::string onnx_model_folder_path_;
bool build_heuristics_enable_ = false;
bool sparsity_enable_ = false;
int builder_optimization_level_ = 3;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
index cd2087c9d7472..7f7587543d175 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -27,6 +27,8 @@ constexpr const char* kDLACore = "trt_dla_core";
constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs";
constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable";
constexpr const char* kEngineCachePath = "trt_engine_cache_path";
+constexpr const char* kWeightStrippedEngineEnable = "trt_weight_stripped_engine_enable";
+constexpr const char* kOnnxModelFolderPath = "trt_onnx_model_folder_path";
constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
@@ -92,6 +94,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs)
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path)
+ .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable)
+ .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
@@ -139,6 +143,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDumpSubgraphs, MakeStringWithClassicLocale(info.dump_subgraphs)},
{tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.engine_cache_enable)},
{tensorrt::provider_option_names::kEngineCachePath, MakeStringWithClassicLocale(info.engine_cache_path)},
+ {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.weight_stripped_engine_enable)},
+ {tensorrt::provider_option_names::kOnnxModelFolderPath, MakeStringWithClassicLocale(info.onnx_model_folder_path)},
{tensorrt::provider_option_names::kEngineCachePrefix, MakeStringWithClassicLocale(info.engine_cache_prefix)},
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
@@ -180,6 +186,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes);
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
+ const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
@@ -198,6 +205,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kEngineCacheEnable, MakeStringWithClassicLocale(info.trt_engine_cache_enable)},
{tensorrt::provider_option_names::kEngineCachePath, kEngineCachePath_},
{tensorrt::provider_option_names::kEngineCachePrefix, kEngineCachePrefix_},
+ {tensorrt::provider_option_names::kWeightStrippedEngineEnable, MakeStringWithClassicLocale(info.trt_weight_stripped_engine_enable)},
+ {tensorrt::provider_option_names::kOnnxModelFolderPath, kOnnxModelFolderPath_},
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.trt_engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, kDecryptionLibPath_},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.trt_force_sequential_engine_build)},
@@ -289,6 +298,8 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_dla_core = internal_options.dla_core;
trt_provider_options_v2.trt_dump_subgraphs = internal_options.dump_subgraphs;
trt_provider_options_v2.trt_engine_cache_enable = internal_options.engine_cache_enable;
+ trt_provider_options_v2.trt_weight_stripped_engine_enable = internal_options.weight_stripped_engine_enable;
+ trt_provider_options_v2.trt_onnx_model_folder_path = copy_string_if_needed(internal_options.onnx_model_folder_path);
trt_provider_options_v2.trt_engine_cache_path = copy_string_if_needed(internal_options.engine_cache_path);
trt_provider_options_v2.trt_engine_cache_prefix = copy_string_if_needed(internal_options.engine_cache_prefix);
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
index 80424b8d6d196..df9d8456573fe 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -32,6 +32,8 @@ struct TensorrtExecutionProviderInfo {
bool dump_subgraphs{false};
bool engine_cache_enable{false};
std::string engine_cache_path{""};
+ bool weight_stripped_engine_enable{false};
+ std::string onnx_model_folder_path{""};
bool engine_decryption_enable{false};
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index 568da57a50956..7beeba336d1a4 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -90,6 +90,8 @@ struct Tensorrt_Provider : Provider {
info.dump_subgraphs = options.trt_dump_subgraphs != 0;
info.engine_cache_enable = options.trt_engine_cache_enable != 0;
info.engine_cache_path = options.trt_engine_cache_path == nullptr ? "" : options.trt_engine_cache_path;
+ info.weight_stripped_engine_enable = options.trt_weight_stripped_engine_enable != 0;
+ info.onnx_model_folder_path = options.trt_onnx_model_folder_path == nullptr ? "" : options.trt_onnx_model_folder_path;
info.engine_decryption_enable = options.trt_engine_decryption_enable != 0;
info.engine_decryption_lib_path = options.trt_engine_decryption_lib_path == nullptr ? "" : options.trt_engine_decryption_lib_path;
info.force_sequential_engine_build = options.trt_force_sequential_engine_build != 0;
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index ef56a7960c5f5..fdd2eadf7f97e 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -2334,6 +2334,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_max_shapes;
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
+ delete[] ptr->trt_onnx_model_folder_path;
}
std::unique_ptr p(ptr);
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index bab32853ec493..f964cfbdae2d3 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -501,7 +501,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance.
// (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance
// and TRT EP instance, so it won't be released.)
- std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path;
+ std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
+ trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
+ onnx_model_folder_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
@@ -614,6 +616,21 @@ std::unique_ptr CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_cache_prefix' should be a string to customize engine cache prefix i.e. 'FRCNN' or 'yolov4'.\n");
}
+ } else if (option.first == "trt_weight_stripped_engine_enable") {
+ if (option.second == "True" || option.second == "true") {
+ params.trt_weight_stripped_engine_enable = true;
+ } else if (option.second == "False" || option.second == "false") {
+ params.trt_weight_stripped_engine_enable = false;
+ } else {
+ ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_weight_stripped_engine_enable' should be 'True' or 'False'. Default value is 'False'.\n");
+ }
+ } else if (option.first == "trt_onnx_model_folder_path") {
+ if (!option.second.empty()) {
+ onnx_model_folder_path = option.second;
+ params.trt_onnx_model_folder_path = onnx_model_folder_path.c_str();
+ } else {
+ ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_onnx_model_folder_path' should be a path string i.e. 'engine_cache'.\n");
+ }
} else if (option.first == "trt_engine_decryption_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_engine_decryption_enable = true;
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 62291762f61b8..e4b7e087c0839 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -112,6 +112,8 @@ namespace perftest {
"\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n"
"\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n"
"\t [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n"
+ "\t [TensorRT only] [trt_weight_stripped_engine_enable]: Enable weight-stripped engine build.\n"
+ "\t [TensorRT only] [trt_onnx_model_folder_path]: Folder path for the ONNX model with weights.\n"
"\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n"
"\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n"
"\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n"