Skip to content

Commit

Permalink
[TensorRT EP] support user_compute_stream in python API (#20168)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

* Implement `user_compute_stream` python api for TensorRT EP
* Using this option will implicitly set `has_user_compute_stream` as
`true`
* Extend existing TRTEP unit test to verify `user_compute_stream` option
* This has been verified in local pytorch env, with
`torch.cuda.Stream()` passing into `user_compute_stream`:
```python
...
# Before inference
if torch.cuda.is_available():
    s = torch.cuda.Stream()
    option = {"user_compute_stream": str(s.cuda_stream)}
    sess.set_providers(["TensorrtExecutionProvider"], [option])
    options = sess.get_provider_options()

    assert "TensorrtExecutionProvider" in options
    assert options["TensorrtExecutionProvider"].get("user_compute_stream", "") == str(s.cuda_stream)
    assert options["TensorrtExecutionProvider"].get("has_user_compute_stream", "") == "1"
...
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Align with existing `user_compute_stream` python implementations for
[CUDA EP](https://github.com/microsoft/onnxruntime/pull/19229)/[ROCm
EP](#19619)
  • Loading branch information
yf711 committed Apr 16, 2024
1 parent e02aef1 commit 54f91ea
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace tensorrt {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kUserComputeStream = "user_compute_stream";
constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations";
constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size";
constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size";
Expand Down Expand Up @@ -55,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";

TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
TensorrtExecutionProviderInfo info{};
void* user_compute_stream = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
Expand All @@ -71,6 +73,14 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
})
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations)
.AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
.AddValueParser(
tensorrt::provider_option_names::kUserComputeStream,
[&user_compute_stream](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
user_compute_stream = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable)
Expand Down Expand Up @@ -107,6 +117,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
return info;
}

Expand All @@ -115,6 +127,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.max_partition_iterations)},
{tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)},
{tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)},
{tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
Expand Down Expand Up @@ -171,6 +184,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{tensorrt::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{tensorrt::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{tensorrt::provider_option_names::kMaxPartitionIterations, MakeStringWithClassicLocale(info.trt_max_partition_iterations)},
{tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)},
{tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)},
Expand Down Expand Up @@ -253,10 +267,14 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.device_id = internal_options.device_id;

// The 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance can be set by C API UpdateTensorRTProviderOptionsWithValue() as well
// We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options
// We only set the 'has_user_compute_stream' of the OrtTensorRTProviderOptionsV2 instance if it is provided in options or user_compute_stream is provided
if (options.find("has_user_compute_stream") != options.end()) {
trt_provider_options_v2.has_user_compute_stream = internal_options.has_user_compute_stream;
}
if (options.find("user_compute_stream") != options.end() && internal_options.user_compute_stream != nullptr) {
trt_provider_options_v2.user_compute_stream = internal_options.user_compute_stream;
trt_provider_options_v2.has_user_compute_stream = true;
}

trt_provider_options_v2.trt_max_partition_iterations = internal_options.max_partition_iterations;
trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,15 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "user_compute_stream") {
if (!option.second.empty()) {
auto stream = std::stoull(option.second, nullptr, 0);
params.user_compute_stream = reinterpret_cast<void*>(stream);
params.has_user_compute_stream = true;
} else {
params.has_user_compute_stream = false;
ORT_THROW("[ERROR] [TensorRT] The value for the key 'user_compute_stream' should be a string to define the compute stream for the inference to run on.\n");
}
} else if (option.first == "trt_max_partition_iterations") {
if (!option.second.empty()) {
params.trt_max_partition_iterations = std::stoi(option.second);
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_set_providers_with_options(self):
option["trt_engine_cache_path"] = engine_cache_path
force_sequential_engine_build = "true"
option["trt_force_sequential_engine_build"] = force_sequential_engine_build
option["user_compute_stream"] = "1"
sess.set_providers(["TensorrtExecutionProvider"], [option])

options = sess.get_provider_options()
Expand All @@ -326,6 +327,8 @@ def test_set_providers_with_options(self):
self.assertEqual(option["trt_engine_cache_enable"], "1")
self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path))
self.assertEqual(option["trt_force_sequential_engine_build"], "1")
self.assertEqual(option["user_compute_stream"], "1")
self.assertEqual(option["has_user_compute_stream"], "1")

from onnxruntime.capi import _pybind_state as C

Expand Down Expand Up @@ -354,6 +357,19 @@ def test_set_providers_with_options(self):
sess.set_providers(['TensorrtExecutionProvider'], [option])
"""

try:
import torch

if torch.cuda.is_available():
s = torch.cuda.Stream()
option["user_compute_stream"] = str(s.cuda_stream)
sess.set_providers(["TensorrtExecutionProvider"], [option])
options = sess.get_provider_options()
self.assertEqual(options["TensorrtExecutionProvider"]["user_compute_stream"], str(s.cuda_stream))
self.assertEqual(options["TensorrtExecutionProvider"]["has_user_compute_stream"], "1")
except ImportError:
print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream")

if "CUDAExecutionProvider" in onnxrt.get_available_providers():
cuda_success = 0

Expand Down

0 comments on commit 54f91ea

Please sign in to comment.