Commit
add pytorch cuDNN acceleration (#1592)
Lanking authored Apr 21, 2022
1 parent bbdcbf7 commit 109c9b4
Showing 5 changed files with 28 additions and 0 deletions.
11 changes: 11 additions & 0 deletions docs/development/inference_performance_optimization.md
@@ -61,6 +61,17 @@ You can enable it by

You might see an exception if a certain data type or operator is not supported by the oneDNN device.

#### cuDNN acceleration
PyTorch has a special flag that speeds up CNNs and related networks. If your input size does not change frequently,
you may benefit from enabling this configuration in your model:

```
-Dai.djl.pytorch.cudnn_benchmark=true
```

If your input shape changes frequently, this setting may hurt performance, because cuDNN re-runs its auto-tuning
whenever it sees a new input shape. For more information, see this
[article](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#enable-cudnn-auto-tuner).
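
The flag can also be set from code instead of the command line. The following is a minimal sketch, not part of this commit: the class name `CudnnBenchmarkConfig` and the helper `enableCudnnBenchmark` are hypothetical, and only the property name comes from the change above. Judging by the engine change in this commit, the property is read when the PyTorch engine is created, so the sketch assumes it must be set before any model is loaded.

```java
public class CudnnBenchmarkConfig {

    // Property name taken from the documentation change in this commit.
    static final String KEY = "ai.djl.pytorch.cudnn_benchmark";

    // Hypothetical helper: sets the flag for the current JVM.
    // Assumption: call this before the PyTorch engine is first used.
    static void enableCudnnBenchmark() {
        System.setProperty(KEY, "true");
    }

    public static void main(String[] args) {
        enableCudnnBenchmark();
        // Load models only after this point so the engine can see the flag.
        System.out.println(KEY + "=" + Boolean.getBoolean(KEY));
    }
}
```

Passing `-Dai.djl.pytorch.cudnn_benchmark=true` on the command line has the same effect and avoids any ordering concerns inside the application.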

#### Thread configuration
There are two configurations you can set to optimize the inference performance.

@@ -55,6 +55,10 @@ static Engine newInstance() {
            if (Integer.getInteger("ai.djl.pytorch.num_threads") != null) {
                JniUtils.setNumThreads(Integer.getInteger("ai.djl.pytorch.num_threads"));
            }
            // enable the cuDNN auto-tuner to speed up CNN-style models
            if (Boolean.getBoolean("ai.djl.pytorch.cudnn_benchmark")) {
                JniUtils.setBenchmarkCuDNN(true);
            }
            logger.info("Number of inter-op threads is " + JniUtils.getNumInteropThreads());
            logger.info("Number of intra-op threads is " + JniUtils.getNumThreads());
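
Because the check above uses `Boolean.getBoolean`, only the string `true` (in any letter case) turns the flag on; values such as `1` or `yes` leave it off. A small sketch of that behavior follows. The class `CudnnFlagParsing` and the helper `wouldEnable` are hypothetical and exist only to illustrate how the standard library parses the property.

```java
public class CudnnFlagParsing {

    static final String KEY = "ai.djl.pytorch.cudnn_benchmark";

    // Hypothetical helper: reports what Boolean.getBoolean(KEY) returns
    // for a given raw property value, then restores the previous value.
    static boolean wouldEnable(String rawValue) {
        String previous = System.getProperty(KEY);
        try {
            System.setProperty(KEY, rawValue);
            return Boolean.getBoolean(KEY);
        } finally {
            if (previous == null) {
                System.clearProperty(KEY);
            } else {
                System.setProperty(KEY, previous);
            }
        }
    }

    public static void main(String[] args) {
        for (String value : new String[] {"true", "TRUE", "1", "yes"}) {
            System.out.println("\"" + value + "\" -> " + wouldEnable(value));
        }
    }
}
```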

@@ -84,6 +84,10 @@ public static void setNumThreads(int threads) {
        PyTorchLibrary.LIB.torchSetNumThreads(threads);
    }

    public static void setBenchmarkCuDNN(boolean enable) {
        PyTorchLibrary.LIB.torchSetBenchmarkCuDNN(enable);
    }

    public static synchronized Set<String> getFeatures() {
        if (configs != null) {
            return configs;
@@ -32,6 +32,8 @@ private PyTorchLibrary() {}

    native void torchSetNumThreads(int threads);

    native void torchSetBenchmarkCuDNN(boolean enable);

    native void torchManualSeed(long seed);

    native void torchShowConfig(Set<String> set);
@@ -161,6 +161,13 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_setGraphExecutorOp
  API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSetBenchmarkCuDNN(
    JNIEnv* env, jobject jthis, jboolean jenabled) {
  API_BEGIN()
  torch::globalContext().setBenchmarkCuDNN(jenabled);
  API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleEval(
    JNIEnv* env, jobject jthis, jlong module_handle) {
  API_BEGIN()