Skip to content

Commit

Permalink
Add keepGraphOptimize flag.
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Nov 17, 2022
1 parent 57f0327 commit cee50cf
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ protected NDList forwardInternal(
for (NDArray array : inputs) {
inputDescriptions.add(array.getName(), array.getShape());
}
NDList outputs = IValueUtils.forward(this, inputs, training);
NDList outputs = IValueUtils.forward(this, inputs, training, true);
for (NDArray array : outputs) {
outputDescriptions.add(array.getName(), array.getShape());
}
Expand All @@ -139,7 +139,7 @@ protected NDList forwardInternal(
}
}
}
return IValueUtils.forward(this, inputs, training);
return IValueUtils.forward(this, inputs, training, true);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ private IValueUtils() {}
* @param block the block that contains PyTorch module
* @param inputs the input {@link NDList}
* @param isTrain if running on training mode
* @param keepGraphOptimize whether to keep graph optimization enabled during inference; always
*     false on Android, and ignored if isTrain is true
* @return the result {@link NDList}
*/
public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
public static NDList forward(
PtSymbolBlock block, NDList inputs, boolean isTrain, boolean keepGraphOptimize) {
IValue[] iValues = getInputs(inputs);
long[] iValueHandles = Arrays.stream(iValues).mapToLong(IValue::getHandle).toArray();
long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueHandles, isTrain);
long result =
PyTorchLibrary.LIB.moduleForward(
block.getHandle(), iValueHandles, isTrain, isTrain || keepGraphOptimize);
PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
Arrays.stream(iValues).forEach(IValue::close);
try (IValue iValue = new IValue(result)) {
Expand All @@ -61,7 +66,8 @@ public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain
*/
/**
 * Runs the forward pass of the PyTorch module wrapped by the given block using raw
 * {@code IValue} inputs.
 *
 * <p>This overload always runs in inference mode ({@code isTrain = false}) and keeps graph
 * optimization enabled ({@code keepGraphOptimize = true}).
 *
 * @param block the block that contains the PyTorch module
 * @param inputs the input IValues
 * @return the resulting IValue
 */
public static IValue forward(PtSymbolBlock block, IValue... inputs) {
    long[] handles = new long[inputs.length];
    for (int i = 0; i < inputs.length; i++) {
        handles[i] = inputs[i].getHandle();
    }
    return new IValue(
            PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false, true));
}

private static int addToMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,8 @@ public static PtSymbolBlock loadModule(
mapLocation,
extraFileKeys,
extraFileValues,
trainParam);
trainParam,
true);
return new PtSymbolBlock(manager, handle);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ native long moduleLoad(
boolean mapLocation,
String[] extraFileNames,
String[] extraFileValues,
boolean trainParam);
boolean trainParam,
boolean keepGraphOptimize);

native long moduleLoad(
InputStream is, int[] device, boolean mapLocation, byte[] buffer, long size);
Expand All @@ -505,7 +506,8 @@ native long moduleLoad(

native void moduleTrain(long handle);

native long moduleForward(long moduleHandle, long[] iValueHandles, boolean isTrain);
native long moduleForward(
long moduleHandle, long[] iValueHandles, boolean isTrain, boolean keepGraphOptimize);

native void setGraphExecutorOptimize(boolean enabled);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

// The file is the implementation for PyTorch inference operations

struct JITCallGuard {
// RAII guard used on inference paths when JIT graph optimization should stay ENABLED:
// it disables gradient tracking for the scope, but (unlike JITCallGuardB) does not
// install a GraphOptimizerEnabledGuard{false}.
struct JITCallGuardA {
#ifdef V1_10_X
// Legacy 1.10.x build: combine the two no-grad guards.
torch::autograd::AutoGradMode no_autograd_guard{false};
torch::NoGradGuard no_grad;
// NOTE(review): the non-V1_10_X branch is collapsed in this diff view; presumably it
// contains only `c10::InferenceMode guard;` — confirm against the full file.
Expand All @@ -31,40 +31,18 @@ struct JITCallGuard {
#endif
};

JNIEXPORT jlong JNICALL
Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljava_lang_String_2_3Ljava_lang_String_2Z(
JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray, jboolean jmap_location, jobjectArray jefnames,
jobjectArray jefvalues, jboolean jtrainParam) {
API_BEGIN()
const std::string path = djl::utils::jni::GetStringFromJString(env, jpath);
const torch::Device device = utils::GetDeviceFromJDevice(env, jarray);
std::unordered_map<std::string, std::string> map;
size_t len = static_cast<size_t>(env->GetArrayLength(jefnames));
for (size_t i = 0; i < len; ++i) {
auto jname = (jstring) env->GetObjectArrayElement(jefnames, i);
auto name = djl::utils::jni::GetStringFromJString(env, jname);
map[name] = "";
}
// RAII guard for inference calls that must also bypass the JIT graph optimizer:
// entering the scope disables gradient tracking and, on non-1.10.x builds, turns the
// graph-executor optimizer off for the duration of the call (keepGraphOptimize == false).
struct JITCallGuardB {
#ifdef V1_10_X
// Legacy 1.10.x build: use the two no-grad guards instead of InferenceMode.
torch::autograd::AutoGradMode autograd_off{false};
torch::NoGradGuard grad_off;
#else
// InferenceMode implies no-grad plus inference-only fast paths.
c10::InferenceMode inference_scope;
// false => the graph executor runs the module without its optimizing passes.
torch::jit::GraphOptimizerEnabledGuard optimizer_off{false};
#endif
};

if (!jtrainParam) {
JITCallGuard guard;
torch::jit::Module module;
if (jmap_location) {
module = torch::jit::load(path, device, map);
module.eval();
} else {
module = torch::jit::load(path, torch::nullopt, map);
module.eval();
module.to(device);
}
const auto* module_ptr = new torch::jit::Module(module);
for (size_t i = 0; i < len; ++i) {
auto jname = (jstring) env->GetObjectArrayElement(jefnames, i);
auto name = djl::utils::jni::GetStringFromJString(env, jname);
env->SetObjectArrayElement(jefvalues, i, env->NewStringUTF(map[name].c_str()));
}
return reinterpret_cast<uintptr_t>(module_ptr);
}
// Shared helper for the moduleLoad JNI entry point: loads the TorchScript module at
// `path`, places it on `device`, and returns it as an opaque heap-pointer handle that
// the Java side wraps in a PtSymbolBlock.
// `map` is intentionally taken by value: torch::jit::load() fills the requested
// extra-file values into this local copy, which are then copied back out via `jefvalues`.
inline jlong modelLoadHelpFunc(JNIEnv* env, std::string path, torch::Device device, std::unordered_map<std::string, std::string> map,
jboolean jmap_location, jobjectArray jefnames, jobjectArray jefvalues) {
torch::jit::Module module;
if (jmap_location) {
// map_location: map the module's tensors directly onto the target device at load time.
module = torch::jit::load(path, device, map);
// NOTE(review): lines are collapsed in this diff view here (presumably the else-branch
// that loads without a device and then moves the module); confirm against the full file.
Expand All @@ -75,12 +53,40 @@ Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljav
module.to(device);
}
// Copy the module to the heap; the returned handle owns it until released from Java.
const auto* module_ptr = new torch::jit::Module(module);
// Write the extra-file values collected during load back into the Java String[].
size_t len = static_cast<size_t>(env->GetArrayLength(jefnames));
for (size_t i = 0; i < len; ++i) {
auto jname = (jstring) env->GetObjectArrayElement(jefnames, i);
auto name = djl::utils::jni::GetStringFromJString(env, jname);
env->SetObjectArrayElement(jefvalues, i, env->NewStringUTF(map[name].c_str()));
}
return reinterpret_cast<uintptr_t>(module_ptr);
};

// JNI entry point for PyTorchLibrary.moduleLoad(String, int[], boolean, String[], String[],
// boolean, boolean): loads a TorchScript module and returns an opaque native handle.
// jtrainParam == true  -> load with autograd enabled (no inference guard).
// jtrainParam == false -> load under an inference guard; jkeepGraphOptimize selects whether
//                         the JIT graph optimizer stays enabled (guard A) or is turned off
//                         (guard B, e.g. the Android path per the Java-side javadoc).
JNIEXPORT jlong JNICALL
Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljava_lang_String_2_3Ljava_lang_String_2ZZ(
    JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray, jboolean jmap_location, jobjectArray jefnames,
    jobjectArray jefvalues, jboolean jtrainParam, jboolean jkeepGraphOptimize) {
  API_BEGIN()
  const std::string path = djl::utils::jni::GetStringFromJString(env, jpath);
  const torch::Device device = utils::GetDeviceFromJDevice(env, jarray);
  // Seed the extra-files map with the requested keys; torch::jit::load() fills the values.
  std::unordered_map<std::string, std::string> map;
  size_t len = static_cast<size_t>(env->GetArrayLength(jefnames));
  for (size_t i = 0; i < len; ++i) {
    auto jname = (jstring) env->GetObjectArrayElement(jefnames, i);
    map[djl::utils::jni::GetStringFromJString(env, jname)] = "";
    // Release the per-iteration local ref so a long key list cannot exhaust the local-ref table.
    env->DeleteLocalRef(jname);
  }

  // Single load path shared by all three guard configurations (avoids triplicating the call).
  auto load = [&]() -> jlong {
    return modelLoadHelpFunc(env, path, device, map, jmap_location, jefnames, jefvalues);
  };

  if (jtrainParam) {
    // Training parameters requested: keep autograd active.
    return load();
  }
  if (jkeepGraphOptimize) {
    // Inference, JIT graph optimizer left enabled.
    JITCallGuardA guard;
    return load();
  }
  // Inference with the graph optimizer disabled.
  JITCallGuardB guard;
  return load();
  API_END_RETURN()
}

Expand Down Expand Up @@ -213,8 +219,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleTrain(
API_END()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(
JNIEnv* env, jobject jthis, jlong module_handle, jlongArray jivalue_ptrs, jboolean jis_train) {
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(JNIEnv* env, jobject jthis,
jlong module_handle, jlongArray jivalue_ptrs, jboolean jis_train, jboolean jkeepGraphOptimize) {
API_BEGIN()
auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(module_handle);
size_t len = env->GetArrayLength(jivalue_ptrs);
Expand All @@ -227,10 +233,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(
torch::IValue output = [&]() {
if (jis_train) {
return module_ptr->forward(inputs);
} else {
if (jkeepGraphOptimize) {
JITCallGuardA guard;
return module_ptr->forward(inputs);
} else {
JITCallGuardB guard;
return module_ptr->forward(inputs);
}
}
// disable autograd
JITCallGuard guard;
return module_ptr->forward(inputs);
}();
env->ReleaseLongArrayElements(jivalue_ptrs, jptrs, djl::utils::jni::RELEASE_MODE);
const auto* result_ptr = new torch::IValue(output);
Expand Down

0 comments on commit cee50cf

Please sign in to comment.