Skip to content

Commit

Permalink
[onnxruntime] Adds disablePerSessionThreads option to model loading (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored and patins1 committed Oct 30, 2022
1 parent 7c11b24 commit 1659a6a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ai.djl.training.GradientCollector;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtSession;

/**
Expand All @@ -42,13 +43,33 @@ public final class OrtEngine extends Engine {

private OrtEngine() {
// init OrtRuntime
this.env = OrtEnvironment.getEnvironment();
OrtEnvironment.ThreadingOptions options = new OrtEnvironment.ThreadingOptions();
try {
Integer interOpThreads = Integer.getInteger("ai.djl.onnxruntime.num_interop_threads");
Integer intraOpsThreads = Integer.getInteger("ai.djl.onnxruntime.num_threads");
if (interOpThreads != null) {
options.setGlobalInterOpNumThreads(interOpThreads);
}
if (intraOpsThreads != null) {
options.setGlobalIntraOpNumThreads(intraOpsThreads);
}
OrtLoggingLevel logging = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
String name = OrtEnvironment.DEFAULT_NAME;
this.env = OrtEnvironment.getEnvironment(logging, name, options);
} catch (OrtException e) {
options.close();
throw new AssertionError("Failed to config OrtEnvironment", e);
}
}

static Engine newInstance() {
return new OrtEngine();
}

OrtEnvironment getEnv() {
return env;
}

/** {@inheritDoc} */
@Override
public Engine getAlternativeEngine() {
Expand Down Expand Up @@ -87,8 +108,7 @@ public boolean hasCapability(String capability) {
if (StandardCapabilities.MKL.equals(capability)) {
return true;
} else if (StandardCapabilities.CUDA.equals(capability)) {
try {
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
sessionOptions.addCUDA();
return true;
} catch (OrtException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ private SessionOptions getSessionOptions(Map<String, ?> options) throws OrtExcep
ortSession.setCPUArenaAllocator(true);
}

String disablePerSessionThreads = (String) options.get("disablePerSessionThreads");
if (Boolean.parseBoolean(disablePerSessionThreads)) {
ortSession.disablePerSessionThreads();
}

String customOpLibrary = (String) options.get("customOpLibrary");
if (customOpLibrary != null) {
ortSession.registerCustomOpLibrary(customOpLibrary);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void close() {
private static final class SystemManager extends OrtNDManager implements SystemNDManager {

SystemManager() {
super(null, null, OrtEnvironment.getEnvironment());
super(null, null, ((OrtEngine) Engine.getEngine(OrtEngine.ENGINE_NAME)).getEnv());
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.IOException;
Expand All @@ -39,6 +40,12 @@

public class OrtTest {

@BeforeClass
public void setUp() {
System.setProperty("ai.djl.onnxruntime.num_threads", "1");
System.setProperty("ai.djl.onnxruntime.num_interop_threads", "1");
}

@Test
public void testOrt() throws TranslateException, ModelException, IOException {
try {
Expand All @@ -52,6 +59,7 @@ public void testOrt() throws TranslateException, ModelException, IOException {
.optOption("optLevel", "NO_OPT")
.optOption("memoryPatternOptimization", "true")
.optOption("cpuArenaAllocator", "true")
.optOption("disablePerSessionThreads", "true")
.build();

IrisFlower virginica = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);
Expand Down

0 comments on commit 1659a6a

Please sign in to comment.