From 793d0fd2b59cfc046626e0fbd5a3721d39c93276 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Thu, 28 Apr 2022 14:40:06 -0700
Subject: [PATCH] add tensorRT option

---
 engines/onnxruntime/onnxruntime-engine/README.md  |  8 ++++++++
 .../java/ai/djl/onnxruntime/engine/OrtModel.java  | 19 ++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md
index 87e690fb674..cc689131949 100644
--- a/engines/onnxruntime/onnxruntime-engine/README.md
+++ b/engines/onnxruntime/onnxruntime-engine/README.md
@@ -85,3 +85,11 @@ Gradle:
     }
     implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.11.0"
 ```
+
+#### Enable TensorRT execution
+
+ONNXRuntime offers TensorRT as an execution backend. In DJL, users can enable it by specifying the following option in the Criteria:
+
+```
+optOption("OrtDevice", "TensorRT")
+```
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
index dfd9caa1fbc..0e9ae348660 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
@@ -78,7 +78,24 @@ public void load(Path modelPath, String prefix, Map options)
         try {
             SessionOptions ortOptions = getSessionOptions(options);
             Device device = manager.getDevice();
-            if (device.isGpu()) {
+            // An explicit "OrtDevice" option takes precedence over the default CUDA provider.
+            if (options != null && options.containsKey("OrtDevice")) {
+                String ortDevice = (String) options.get("OrtDevice");
+                switch (ortDevice) {
+                    case "TensorRT":
+                        ortOptions.addTensorrt(manager.getDevice().getDeviceId());
+                        break;
+                    case "ROCM":
+                        ortOptions.addROCM();
+                        break;
+                    case "CoreML":
+                        ortOptions.addCoreML();
+                        break;
+                    default:
+                        throw new UnsupportedOperationException(
+                                ortDevice + " not supported by DJL");
+                }
+            } else if (device.isGpu()) {
                 ortOptions.addCUDA(manager.getDevice().getDeviceId());
             }
             OrtSession session = env.createSession(modelFile.toString(), ortOptions);