diff --git a/engines/onnxruntime/onnxruntime-engine/README.md b/engines/onnxruntime/onnxruntime-engine/README.md
index 87e690fb674..cc689131949 100644
--- a/engines/onnxruntime/onnxruntime-engine/README.md
+++ b/engines/onnxruntime/onnxruntime-engine/README.md
@@ -85,3 +85,11 @@ Gradle:
 }
 implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.11.0"
 ```
+
+#### Enable TensorRT execution
+
+ONNXRuntime offers TensorRT execution as a backend. In DJL, users can specify the following option in the Criteria to enable it:
+
+```
+optOption("OrtDevice", "TensorRT")
+```
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
index dfd9caa1fbc..0e9ae348660 100644
--- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
+++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java
@@ -78,7 +78,26 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
         try {
             SessionOptions ortOptions = getSessionOptions(options);
             Device device = manager.getDevice();
-            if (device.isGpu()) {
+            if (options.containsKey("OrtDevice")) {
+                String ortDevice = (String) options.get("OrtDevice");
+                // Each case must end with break: without it, execution falls
+                // through the remaining cases into the default branch, and
+                // every supported device would register extra providers and
+                // then throw UnsupportedOperationException.
+                switch (ortDevice) {
+                    case "TensorRT":
+                        ortOptions.addTensorrt(device.getDeviceId());
+                        break;
+                    case "ROCM":
+                        ortOptions.addROCM();
+                        break;
+                    case "CoreML":
+                        ortOptions.addCoreML();
+                        break;
+                    default:
+                        throw new UnsupportedOperationException(
+                                ortDevice + " not supported by DJL");
+                }
+            } else if (device.isGpu()) {
                 ortOptions.addCUDA(manager.getDevice().getDeviceId());
             }
             OrtSession session = env.createSession(modelFile.toString(), ortOptions);