diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc
index 95bc18c5f..19d585976 100644
--- a/sherpa-onnx/csrc/provider.cc
+++ b/sherpa-onnx/csrc/provider.cc
@@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) {
     return Provider::kXnnpack;
   } else if (s == "nnapi") {
     return Provider::kNNAPI;
+  } else if (s == "trt") {
+    return Provider::kTRT;
   } else {
     SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
     return Provider::kCPU;
diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h
index 467e5dab5..c104d401a 100644
--- a/sherpa-onnx/csrc/provider.h
+++ b/sherpa-onnx/csrc/provider.h
@@ -18,6 +18,7 @@ enum class Provider {
   kCoreML = 2,   // CoreMLExecutionProvider
   kXnnpack = 3,  // XnnpackExecutionProvider
   kNNAPI = 4,    // NnapiExecutionProvider
+  kTRT = 5,      // TensorrtExecutionProvider
 };
 
 /**
diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc
index d0a697404..431a6a761 100644
--- a/sherpa-onnx/csrc/session.cc
+++ b/sherpa-onnx/csrc/session.cc
@@ -21,6 +21,21 @@
 
 namespace sherpa_onnx {
 
+// Report a failed TensorRT-related ONNX Runtime call.
+//
+// @param status Status returned by the failed call (callers pass a non-null
+//               status). Its message is logged and the status is released.
+// @param s Text that is logged as the list of available providers.
+static void OrtStatusFailure(OrtStatus *status, const char *s) {
+  const auto &api = Ort::GetApi();
+  const char *msg = api.GetErrorMessage(status);
+  SHERPA_ONNX_LOGE(
+      "Failed to enable TensorRT : %s. "
+      "Available providers: %s. Fallback to cuda",
+      msg, s);
+  api.ReleaseStatus(status);
+}
+
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                                                  std::string provider_str) {
   Provider p = StringToProvider(std::move(provider_str));
@@ -53,6 +68,72 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       }
       break;
     }
+    case Provider::kTRT: {
+      struct TrtPairs {
+        const char *op_keys;
+        const char *op_values;
+      };
+
+      std::vector<TrtPairs> trt_options = {
+          {"device_id", "0"},
+          {"trt_max_workspace_size", "2147483648"},
+          {"trt_max_partition_iterations", "10"},
+          {"trt_min_subgraph_size", "5"},
+          {"trt_fp16_enable", "0"},
+          {"trt_detailed_build_log", "0"},
+          {"trt_engine_cache_enable", "1"},
+          {"trt_engine_cache_path", "."},
+          {"trt_timing_cache_enable", "1"},
+          {"trt_timing_cache_path", "."}
+      };
+      // TODO: TensorRT options that are not configured yet
+      // "trt_int8_enable"
+      // "trt_int8_use_native_calibration_table"
+      // "trt_dump_subgraphs"
+
+      // UpdateTensorRTProviderOptions() is given the keys and the values
+      // as two parallel arrays of C strings.
+      std::vector<const char *> option_keys, option_values;
+      for (const TrtPairs &pair : trt_options) {
+        option_keys.emplace_back(pair.op_keys);
+        option_values.emplace_back(pair.op_values);
+      }
+
+      if (std::find(available_providers.begin(), available_providers.end(),
+                    "TensorrtExecutionProvider") != available_providers.end()) {
+        const auto &api = Ort::GetApi();
+
+        // Check each status before the options object is used, so that a
+        // failed call never hands an invalid options pointer to a later
+        // call.
+        OrtTensorRTProviderOptionsV2 *tensorrt_options = nullptr;
+        OrtStatus *status =
+            api.CreateTensorRTProviderOptions(&tensorrt_options);
+        if (status == nullptr) {
+          status = api.UpdateTensorRTProviderOptions(
+              tensorrt_options, option_keys.data(), option_values.data(),
+              option_keys.size());
+        }
+
+        if (status == nullptr) {
+          sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
+        } else {
+          // Logs the failure and releases status.
+          OrtStatusFailure(status, os.str().c_str());
+        }
+
+        if (tensorrt_options != nullptr) {
+          api.ReleaseTensorRTProviderOptions(tensorrt_options);
+        }
+      } else {
+        SHERPA_ONNX_LOGE(
+            "TensorRT is not available. Available providers: %s. "
+            "Fallback to cuda",
+            os.str().c_str());
+      }
+      // break; is omitted here intentionally so that
+      // if TRT not available, CUDA will be used
+    }
     case Provider::kCUDA: {
       if (std::find(available_providers.begin(), available_providers.end(),
                     "CUDAExecutionProvider") != available_providers.end()) {
@@ -116,7 +197,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       break;
     }
   }
-
   return sess_opts;
 }
 