diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index f96db2bfb01..bb49ac51c05 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -22,6 +22,7 @@ import ai.djl.training.GradientCollector; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtLoggingLevel; import ai.onnxruntime.OrtSession; /** @@ -42,13 +43,33 @@ public final class OrtEngine extends Engine { private OrtEngine() { // init OrtRuntime - this.env = OrtEnvironment.getEnvironment(); + OrtEnvironment.ThreadingOptions options = new OrtEnvironment.ThreadingOptions(); + try { + Integer interOpThreads = Integer.getInteger("ai.djl.onnxruntime.num_interop_threads"); + Integer intraOpsThreads = Integer.getInteger("ai.djl.onnxruntime.num_threads"); + if (interOpThreads != null) { + options.setGlobalInterOpNumThreads(interOpThreads); + } + if (intraOpsThreads != null) { + options.setGlobalIntraOpNumThreads(intraOpsThreads); + } + OrtLoggingLevel logging = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING; + String name = OrtEnvironment.DEFAULT_NAME; + this.env = OrtEnvironment.getEnvironment(logging, name, options); + } catch (OrtException e) { + options.close(); + throw new AssertionError("Failed to config OrtEnvironment", e); + } } static Engine newInstance() { return new OrtEngine(); } + OrtEnvironment getEnv() { + return env; + } + /** {@inheritDoc} */ @Override public Engine getAlternativeEngine() { @@ -87,8 +108,7 @@ public boolean hasCapability(String capability) { if (StandardCapabilities.MKL.equals(capability)) { return true; } else if (StandardCapabilities.CUDA.equals(capability)) { - try { - OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); + try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) { sessionOptions.addCUDA(); return true; } catch (OrtException e) { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index 6f2cf648b93..81e07032a1c 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -179,6 +179,11 @@ private SessionOptions getSessionOptions(Map options) throws OrtExcep ortSession.setCPUArenaAllocator(true); } + String disablePerSessionThreads = (String) options.get("disablePerSessionThreads"); + if (Boolean.parseBoolean(disablePerSessionThreads)) { + ortSession.disablePerSessionThreads(); + } + String customOpLibrary = (String) options.get("customOpLibrary"); if (customOpLibrary != null) { ortSession.registerCustomOpLibrary(customOpLibrary); diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index 3cfbddce660..f18ad30a0fd 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -127,7 +127,7 @@ public void close() { private static final class SystemManager extends OrtNDManager implements SystemNDManager { SystemManager() { - super(null, null, OrtEnvironment.getEnvironment()); + super(null, null, ((OrtEngine) Engine.getEngine(OrtEngine.ENGINE_NAME)).getEnv()); } /** {@inheritDoc} */ diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index 6287bd526a3..5bf1f1a4d3f 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/engines/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -29,6 +29,7 @@ import org.testng.Assert; import org.testng.SkipException; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.io.IOException; @@ -39,6 +40,12 @@ public class OrtTest { + @BeforeClass + public void setUp() { + System.setProperty("ai.djl.onnxruntime.num_threads", "1"); + System.setProperty("ai.djl.onnxruntime.num_interop_threads", "1"); + } + @Test public void testOrt() throws TranslateException, ModelException, IOException { try { @@ -52,6 +59,7 @@ public void testOrt() throws TranslateException, ModelException, IOException { .optOption("optLevel", "NO_OPT") .optOption("memoryPatternOptimization", "true") .optOption("cpuArenaAllocator", "true") + .optOption("disablePerSessionThreads", "true") .build(); IrisFlower virginica = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);