diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 8521e74466f..307082cd1c7 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -130,7 +130,7 @@ protected NDList forwardInternal( for (NDArray array : inputs) { inputDescriptions.add(array.getName(), array.getShape()); } - NDList outputs = IValueUtils.forward(this, inputs, training); + NDList outputs = IValueUtils.forward(this, inputs, training, true); for (NDArray array : outputs) { outputDescriptions.add(array.getName(), array.getShape()); } @@ -139,7 +139,7 @@ protected NDList forwardInternal( } } } - return IValueUtils.forward(this, inputs, training); + return IValueUtils.forward(this, inputs, training, true); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java index 0a1da4b888b..d11ee30f3c8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java @@ -39,12 +39,17 @@ private IValueUtils() {} * @param block the block that contains PyTorch module * @param inputs the input {@link NDList} * @param isTrain if running on training mode + * @param keepGraphOptimize whether to keep graph optimization enabled during inference; always + * false on Android, ignored if isTrain is true * @return the result {@link NDList} */ - public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) { + public static NDList forward( + PtSymbolBlock block, NDList inputs, boolean isTrain, boolean keepGraphOptimize) { IValue[] iValues = getInputs(inputs); long[] iValueHandles = 
Arrays.stream(iValues).mapToLong(IValue::getHandle).toArray(); - long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueHandles, isTrain); + long result = + PyTorchLibrary.LIB.moduleForward( + block.getHandle(), iValueHandles, isTrain, isTrain || keepGraphOptimize); PtNDManager manager = (PtNDManager) inputs.get(0).getManager(); Arrays.stream(iValues).forEach(IValue::close); try (IValue iValue = new IValue(result)) { @@ -61,7 +66,8 @@ public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain */ public static IValue forward(PtSymbolBlock block, IValue... inputs) { long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray(); - return new IValue(PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false)); + return new IValue( + PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false, true)); } private static int addToMap( diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 862833d445f..45899220b98 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1580,7 +1580,8 @@ public static PtSymbolBlock loadModule( mapLocation, extraFileKeys, extraFileValues, - trainParam); + trainParam, + true); return new PtSymbolBlock(manager, handle); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 3817123906e..a6bb1441386 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -496,7 +496,8 @@ native long moduleLoad( boolean mapLocation, String[] extraFileNames, String[] extraFileValues, - 
boolean trainParam); + boolean trainParam, + boolean keepGraphOptimize); native long moduleLoad( InputStream is, int[] device, boolean mapLocation, byte[] buffer, long size); @@ -505,7 +506,8 @@ native long moduleLoad( native void moduleTrain(long handle); - native long moduleForward(long moduleHandle, long[] iValueHandles, boolean isTrain); + native long moduleForward( + long moduleHandle, long[] iValueHandles, boolean isTrain, boolean keepGraphOptimize); native void setGraphExecutorOptimize(boolean enabled); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc index b7a00c8757f..a4192df7832 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc @@ -19,7 +19,7 @@ // The file is the implementation for PyTorch inference operations -struct JITCallGuard { +struct JITCallGuardA { #ifdef V1_10_X torch::autograd::AutoGradMode no_autograd_guard{false}; torch::NoGradGuard no_grad; @@ -31,40 +31,18 @@ struct JITCallGuard { #endif }; -JNIEXPORT jlong JNICALL -Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljava_lang_String_2_3Ljava_lang_String_2Z( - JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray, jboolean jmap_location, jobjectArray jefnames, - jobjectArray jefvalues, jboolean jtrainParam) { - API_BEGIN() - const std::string path = djl::utils::jni::GetStringFromJString(env, jpath); - const torch::Device device = utils::GetDeviceFromJDevice(env, jarray); - std::unordered_map map; - size_t len = static_cast(env->GetArrayLength(jefnames)); - for (size_t i = 0; i < len; ++i) { - auto jname = (jstring) env->GetObjectArrayElement(jefnames, i); - auto name = djl::utils::jni::GetStringFromJString(env, jname); - map[name] = ""; - } +struct 
JITCallGuardB { +#ifdef V1_10_X + torch::autograd::AutoGradMode no_autograd_guard{false}; + torch::NoGradGuard no_grad; +#else + c10::InferenceMode guard; + torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false}; +#endif +}; - if (!jtrainParam) { - JITCallGuard guard; - torch::jit::Module module; - if (jmap_location) { - module = torch::jit::load(path, device, map); - module.eval(); - } else { - module = torch::jit::load(path, torch::nullopt, map); - module.eval(); - module.to(device); - } - const auto* module_ptr = new torch::jit::Module(module); - for (size_t i = 0; i < len; ++i) { - auto jname = (jstring) env->GetObjectArrayElement(jefnames, i); - auto name = djl::utils::jni::GetStringFromJString(env, jname); - env->SetObjectArrayElement(jefvalues, i, env->NewStringUTF(map[name].c_str())); - } - return reinterpret_cast(module_ptr); - } +inline jlong modelLoadHelpFunc(JNIEnv* env, std::string path, torch::Device device, std::unordered_map map, + jboolean jmap_location, jobjectArray jefnames, jobjectArray jefvalues) { torch::jit::Module module; if (jmap_location) { module = torch::jit::load(path, device, map); @@ -75,12 +53,40 @@ Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljav module.to(device); } const auto* module_ptr = new torch::jit::Module(module); + size_t len = static_cast(env->GetArrayLength(jefnames)); for (size_t i = 0; i < len; ++i) { auto jname = (jstring) env->GetObjectArrayElement(jefnames, i); auto name = djl::utils::jni::GetStringFromJString(env, jname); env->SetObjectArrayElement(jefvalues, i, env->NewStringUTF(map[name].c_str())); } return reinterpret_cast(module_ptr); +}; + +JNIEXPORT jlong JNICALL +Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljava_lang_String_2_3Ljava_lang_String_2ZZ( + JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray, jboolean jmap_location, jobjectArray jefnames, + jobjectArray jefvalues, jboolean jtrainParam, jboolean jkeepGraphOptimize) { 
+ API_BEGIN() + const std::string path = djl::utils::jni::GetStringFromJString(env, jpath); + const torch::Device device = utils::GetDeviceFromJDevice(env, jarray); + std::unordered_map map; + size_t len = static_cast(env->GetArrayLength(jefnames)); + for (size_t i = 0; i < len; ++i) { + auto jname = (jstring) env->GetObjectArrayElement(jefnames, i); + auto name = djl::utils::jni::GetStringFromJString(env, jname); + map[name] = ""; + } + + if (!jtrainParam) { + if (jkeepGraphOptimize) { + JITCallGuardA guard; + return modelLoadHelpFunc(env, path, device, map, jmap_location, jefnames, jefvalues); + } else { + JITCallGuardB guard; + return modelLoadHelpFunc(env, path, device, map, jmap_location, jefnames, jefvalues); + } + } + return modelLoadHelpFunc(env, path, device, map, jmap_location, jefnames, jefvalues); API_END_RETURN() } @@ -213,8 +219,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleTrain( API_END() } -JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward( - JNIEnv* env, jobject jthis, jlong module_handle, jlongArray jivalue_ptrs, jboolean jis_train) { +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(JNIEnv* env, jobject jthis, + jlong module_handle, jlongArray jivalue_ptrs, jboolean jis_train, jboolean jkeepGraphOptimize) { API_BEGIN() auto* module_ptr = reinterpret_cast(module_handle); size_t len = env->GetArrayLength(jivalue_ptrs); @@ -227,10 +233,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward( torch::IValue output = [&]() { if (jis_train) { return module_ptr->forward(inputs); + } else { + if (jkeepGraphOptimize) { + JITCallGuardA guard; + return module_ptr->forward(inputs); + } else { + JITCallGuardB guard; + return module_ptr->forward(inputs); + } } - // disable autograd - JITCallGuard guard; - return module_ptr->forward(inputs); }(); env->ReleaseLongArrayElements(jivalue_ptrs, jptrs, djl::utils::jni::RELEASE_MODE); const auto* result_ptr = 
new torch::IValue(output);