Skip to content

Commit

Permalink
[TensorRT EP] Weightless API integration (#20412)
Browse files Browse the repository at this point in the history
This PR includes the weight-stripped engine feature (thanks @moraxu for
the #20214) which is the major feature for TRT 10 integration.

Two TRT EP options are added:

- `trt_weight_stripped_engine_enable`: Enable weight-stripped engine
build and refit.
- `trt_onnx_model_folder_path`: In the quick load case using embedded
engine model / EPContext mode, the original onnx filename is in the
node's attribute, and this option specifies the directory of that onnx
file if needed.

Normal weight-stripped engine workflow:

![image](https://github.com/microsoft/onnxruntime/assets/54722500/9f314865-cbda-4979-a7ac-b31c7a553b56)
Weight-stripped engine and quick load workflow:

![image](https://github.com/microsoft/onnxruntime/assets/54722500/9f31db51-a7a8-495b-ba25-54c7f904cbad)

see the doc [here
](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#tensorrt-ep-caches)for
more information about EPContext model.

---------

Co-authored-by: yf711 <yifanl@microsoft.com>
Co-authored-by: Ye Wang <52801275+wangyems@users.noreply.github.com>
Co-authored-by: Michal Guzek <moraxu@users.noreply.github.com>
Co-authored-by: pengwa <pengwa@microsoft.com>
Co-authored-by: wejoncy <wejoncy@163.com>
Co-authored-by: Yi Zhang <zhanyi@microsoft.com>
Co-authored-by: Yi Zhang <your@email.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
Co-authored-by: Adam Pocock <adam.pocock@oracle.com>
Co-authored-by: cao lei <jslhcl@gmail.com>
Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
Co-authored-by: inisis <46103969+inisis@users.noreply.github.com>
Co-authored-by: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com>
Co-authored-by: mo-ja <60505697+mo-ja@users.noreply.github.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Sumit Agarwal <sumitagarwal330@gmail.com>
Co-authored-by: Atanas Dimitrov <70822030+neNasko1@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
Co-authored-by: Dhruv Matani <dhruvbird@gmail.com>
Co-authored-by: Dhruv Matani <dhruv.matani@grammarly.com>
Co-authored-by: wangshuai09 <391746016@qq.com>
Co-authored-by: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com>
Co-authored-by: Xu Xing <xing.xu@intel.com>
Co-authored-by: Dmitri Smirnov <yuslepukhin@users.noreply.github.com>
Co-authored-by: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com>
Co-authored-by: Sai Kishan Pampana <sai.kishan.pampana@intel.com>
Co-authored-by: rachguo <rachguo@rachguos-Mini.attlocal.net>
Co-authored-by: Jian Chen <cjian@microsoft.com>
Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com>
Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Andrew Fantino <15876180+afantino951@users.noreply.github.com>
Co-authored-by: Thomas Boby <thomas@boby.uk>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Michal Guzek <mguzek@nvidia.com>
Co-authored-by: George Wu <jywu@microsoft.com>
  • Loading branch information
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 23 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>notes</tt> : string</dt>
<dd>(Optional) Some notes for the model</dd>
<dt><tt>onnx_model_filename</tt> : string</dt>
<dd>(Optional) Filename of the original ONNX model.</dd>
<dt><tt>partition_name</tt> : string</dt>
<dd>(Optional) partitioned graph name.</dd>
<dt><tt>source</tt> : string</dt>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,20 @@ struct OrtTensorRTProviderOptionsV2 {
* - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
* - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
*
* 3. In the case of building weight-stripped engines, the same security reasons as listed in 1) apply to the
* "onnx_model_filename" node attribute of EP context node, which contains a filename of the ONNX model with the
* weights needed for the refit process. User can specify a folder path relative to the current working
* directory by means of the "trt_onnx_model_folder_path" option.
*
*/
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
// nonzero = true
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,11 @@ void RegisterContribSchemas() {
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"onnx_model_filename",
"(Optional) Filename of the original ONNX model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"hardware_architecture",
"(Optional) Hardware architecture.",
Expand Down
69 changes: 65 additions & 4 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/execution_provider.h"
#include "tensorrt_execution_provider.h"

namespace onnxruntime {
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);

/*
* Check whether the graph has the EP context contrib op.
Expand Down Expand Up @@ -67,7 +69,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
std::string compute_capability,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();
Expand All @@ -88,6 +91,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture
auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename
std::string engine_data_str = "";
attr_0->set_name(EMBED_MODE);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
Expand All @@ -106,13 +110,17 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);
attr_3->set_name(ONNX_MODEL_FILENAME);
attr_3->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string());

auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
int num_attributes = 3;
constexpr int num_attributes = 4;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
Expand Down Expand Up @@ -205,7 +213,7 @@ void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

bool IsAbsolutePath(std::string& path_string) {
bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
auto path = std::filesystem::path(ort_path_string.c_str());
Expand All @@ -219,7 +227,7 @@ bool IsAbsolutePath(std::string& path_string) {
}

// Like "../file_path"
bool IsRelativePathToParentPath(std::string& path_string) {
bool IsRelativePathToParentPath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
auto path = std::filesystem::path(ort_path_string.c_str());
Expand All @@ -236,6 +244,28 @@ bool IsRelativePathToParentPath(std::string& path_string) {
#endif
}

/*
* Get the weight-refitted engine cache path from a weight-stripped engine cache path
*
* Weight-stipped engine:
* An engine with weights stripped and its size is smaller than a regualr engine.
* The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine
*
* Weight-refitted engine:
* An engine that its weights have been refitted and it's simply a regular engine.
* The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine
*/
std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) {
std::filesystem::path stripped_engine_cache_path(stripped_engine_cache);
std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine";
return refitted_engine_cache_path;
}

bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
// The weight-stripped engine cache has the naming of xxx.stripped.engine
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
Expand Down Expand Up @@ -271,6 +301,22 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
// The engine cache and context model (current model) should be in the same directory
std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
auto engine_cache_path = ctx_model_dir.append(cache_path);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();

// If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled
if (!weight_stripped_engine_refit_) {
weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path);
}

// If the serialized refitted engine is present, use it directly without refitting the engine again
if (weight_stripped_engine_refit_) {
const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string());
if (std::filesystem::exists(refitted_engine_cache_path)) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists.";
engine_cache_path = refitted_engine_cache_path.string();
weight_stripped_engine_refit_ = false;
}
}

if (!std::filesystem::exists(engine_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
Expand All @@ -290,6 +336,21 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string());
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();

if (weight_stripped_engine_refit_) {
const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
std::string weight_stripped_engine_cache = engine_cache_path.string();
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
onnx_model_folder_path_,
weight_stripped_engine_cache,
true /* path check for security */,
(*trt_engine_).get(),
true /* serialize refitted engine to disk */,
detailed_build_log_);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
}
}
return Status::OK();
}
Expand Down
24 changes: 20 additions & 4 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <string>
#include <filesystem>
#include <memory>

#include "core/providers/tensorrt/nv_includes.h"
#include "core/providers/shared_library/provider_api.h"
Expand All @@ -15,6 +16,7 @@ static const std::string EPCONTEXT_OP = "EPContext";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
static const std::string EPCONTEXT_WARNING =
"It's suggested to set the ORT graph optimization level to 0 and \
Expand All @@ -29,12 +31,13 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
std::string compute_capability,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger);
std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
bool IsAbsolutePath(std::string& path_string);
bool IsRelativePathToParentPath(std::string& path_string);
bool IsAbsolutePath(const std::string& path_string);
bool IsRelativePathToParentPath(const std::string& path_string);
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
Expand All @@ -46,7 +49,17 @@ class TensorRTCacheModelHandler {
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime,
std::string ep_context_model_path,
std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), compute_capability_(compute_capability) {
std::string compute_capability,
bool weight_stripped_engine_refit,
std::string onnx_model_folder_path,
bool detailed_build_log)
: trt_engine_(trt_engine),
trt_runtime_(trt_runtime),
ep_context_model_path_(ep_context_model_path),
compute_capability_(compute_capability),
weight_stripped_engine_refit_(weight_stripped_engine_refit),
onnx_model_folder_path_(onnx_model_folder_path),
detailed_build_log_(detailed_build_log) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

Expand All @@ -59,5 +72,8 @@ class TensorRTCacheModelHandler {
nvinfer1::IRuntime* trt_runtime_;
std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory
std::string compute_capability_;
bool weight_stripped_engine_refit_;
std::string onnx_model_folder_path_;
bool detailed_build_log_;
}; // TRTCacheModelHandler
} // namespace onnxruntime
Loading

0 comments on commit 454fcdd

Please sign in to comment.