Skip to content

Commit

Permalink
Merge pull request #48 from jpy-consortium/pr-25-release-gil
Browse files Browse the repository at this point in the history
Release the GIL while in Java, with a special method to deliberately acquire it
  • Loading branch information
devinrsmith authored May 17, 2022
2 parents 2d084e7 + ea2171d commit da3dba9
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 44 deletions.
15 changes: 15 additions & 0 deletions src/main/c/jni/org_jpy_PyLib.c
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,21 @@ JNIEXPORT jboolean JNICALL Java_org_jpy_PyLib_hasGil
return result;
}

/*
* Class: org_jpy_PyLib
* Method: ensureGil
* Signature: (Ljava/util/function/Supplier;)Ljava/lang/Object;
*/
JNIEXPORT jobject JNICALL Java_org_jpy_PyLib_ensureGil
(JNIEnv* jenv, jclass jLibClass, jobject supplier)
{
jobject result;
JPy_BEGIN_GIL_STATE
result = (*jenv)->CallObjectMethod(jenv, supplier, JPy_Supplier_get_MID);
JPy_END_GIL_STATE
return result;
}


/*
* Class: org_jpy_python_PyLib
Expand Down
8 changes: 8 additions & 0 deletions src/main/c/jni/org_jpy_PyLib.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 84 additions & 20 deletions src/main/c/jpy_jmethod.c
Original file line number Diff line number Diff line change
Expand Up @@ -268,48 +268,80 @@ PyObject* JMethod_InvokeMethod(JNIEnv* jenv, JPy_JMethod* method, PyObject* pyAr
JPy_DIAG_PRINT(JPy_DIAG_F_EXEC, "JMethod_InvokeMethod: calling static Java method %s#%s\n", declaringClass->javaName, JPy_AS_UTF8(method->name));

if (returnType == JPy_JVoid) {
Py_BEGIN_ALLOW_THREADS;
(*jenv)->CallStaticVoidMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JVOID();
} else if (returnType == JPy_JBoolean) {
jboolean v = (*jenv)->CallStaticBooleanMethodA(jenv, classRef, method->mid, jArgs);
jboolean v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticBooleanMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JBOOLEAN(v);
} else if (returnType == JPy_JChar) {
jchar v = (*jenv)->CallStaticCharMethodA(jenv, classRef, method->mid, jArgs);
jchar v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticCharMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JCHAR(v);
} else if (returnType == JPy_JByte) {
jbyte v = (*jenv)->CallStaticByteMethodA(jenv, classRef, method->mid, jArgs);
jbyte v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticByteMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JBYTE(v);
} else if (returnType == JPy_JShort) {
jshort v = (*jenv)->CallStaticShortMethodA(jenv, classRef, method->mid, jArgs);
jshort v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticShortMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JSHORT(v);
} else if (returnType == JPy_JInt) {
jint v = (*jenv)->CallStaticIntMethodA(jenv, classRef, method->mid, jArgs);
jint v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticIntMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JINT(v);
} else if (returnType == JPy_JLong) {
jlong v = (*jenv)->CallStaticLongMethodA(jenv, classRef, method->mid, jArgs);
jlong v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticLongMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JLONG(v);
} else if (returnType == JPy_JFloat) {
jfloat v = (*jenv)->CallStaticFloatMethodA(jenv, classRef, method->mid, jArgs);
jfloat v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticFloatMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JFLOAT(v);
} else if (returnType == JPy_JDouble) {
jdouble v = (*jenv)->CallStaticDoubleMethodA(jenv, classRef, method->mid, jArgs);
jdouble v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticDoubleMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JDOUBLE(v);
} else if (returnType == JPy_JString) {
jstring v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs);
jstring v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FromJString(jenv, v);
JPy_DELETE_LOCAL_REF(v);
} else {
jobject v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs);
jobject v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallStaticObjectMethodA(jenv, classRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JMethod_FromJObject(jenv, method, pyArgs, jArgs, 0, returnType, v);
JPy_DELETE_LOCAL_REF(v);
Expand All @@ -326,48 +358,80 @@ PyObject* JMethod_InvokeMethod(JNIEnv* jenv, JPy_JMethod* method, PyObject* pyAr
objectRef = ((JPy_JObj*) self)->objectRef;

if (returnType == JPy_JVoid) {
Py_BEGIN_ALLOW_THREADS;
(*jenv)->CallVoidMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JVOID();
} else if (returnType == JPy_JBoolean) {
jboolean v = (*jenv)->CallBooleanMethodA(jenv, objectRef, method->mid, jArgs);
jboolean v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallBooleanMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JBOOLEAN(v);
} else if (returnType == JPy_JChar) {
jchar v = (*jenv)->CallCharMethodA(jenv, objectRef, method->mid, jArgs);
jchar v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallCharMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JCHAR(v);
} else if (returnType == JPy_JByte) {
jbyte v = (*jenv)->CallByteMethodA(jenv, objectRef, method->mid, jArgs);
jbyte v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallByteMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JBYTE(v);
} else if (returnType == JPy_JShort) {
jshort v = (*jenv)->CallShortMethodA(jenv, objectRef, method->mid, jArgs);
jshort v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallShortMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JSHORT(v);
} else if (returnType == JPy_JInt) {
jint v = (*jenv)->CallIntMethodA(jenv, objectRef, method->mid, jArgs);
jint v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallIntMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JINT(v);
} else if (returnType == JPy_JLong) {
jlong v = (*jenv)->CallLongMethodA(jenv, objectRef, method->mid, jArgs);
jlong v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallLongMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JLONG(v);
} else if (returnType == JPy_JFloat) {
jfloat v = (*jenv)->CallFloatMethodA(jenv, objectRef, method->mid, jArgs);
jfloat v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallFloatMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JFLOAT(v);
} else if (returnType == JPy_JDouble) {
jdouble v = (*jenv)->CallDoubleMethodA(jenv, objectRef, method->mid, jArgs);
jdouble v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallDoubleMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FROM_JDOUBLE(v);
} else if (returnType == JPy_JString) {
jstring v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs);
jstring v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JPy_FromJString(jenv, v);
JPy_DELETE_LOCAL_REF(v);
} else {
jobject v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs);
jobject v;
Py_BEGIN_ALLOW_THREADS;
v = (*jenv)->CallObjectMethodA(jenv, objectRef, method->mid, jArgs);
Py_END_ALLOW_THREADS;
JPy_ON_JAVA_EXCEPTION_GOTO(error);
returnValue = JMethod_FromJObject(jenv, method, pyArgs, jArgs, 1, returnType, v);
JPy_DELETE_LOCAL_REF(v);
Expand Down
4 changes: 4 additions & 0 deletions src/main/c/jpy_jtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ int JType_PythonToJavaConversionError(JPy_JType* type, PyObject* pyArg)

int JType_CreateJavaObject(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclass classRef, jmethodID initMID, jvalue value, jobject* objectRef)
{
Py_BEGIN_ALLOW_THREADS;
*objectRef = (*jenv)->NewObjectA(jenv, classRef, initMID, &value);
Py_END_ALLOW_THREADS;
if (*objectRef == NULL) {
PyErr_NoMemory();
return -1;
Expand All @@ -388,7 +390,9 @@ int JType_CreateJavaObject(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclas

int JType_CreateJavaObject_2(JNIEnv* jenv, JPy_JType* type, PyObject* pyArg, jclass classRef, jmethodID initMID, jvalue value1, jvalue value2, jobject* objectRef)
{
Py_BEGIN_ALLOW_THREADS;
*objectRef = (*jenv)->NewObject(jenv, classRef, initMID, value1, value2);
Py_END_ALLOW_THREADS;
if (*objectRef == NULL) {
PyErr_NoMemory();
return -1;
Expand Down
7 changes: 7 additions & 0 deletions src/main/c/jpy_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ jmethodID JPy_Throwable_getCause_MID = NULL;
// stack trace element
jclass JPy_StackTraceElement_JClass = NULL;

// java.util.function.Supplier
jclass JPy_Supplier_JClass = NULL;
jmethodID JPy_Supplier_get_MID = NULL;

// }}}


Expand Down Expand Up @@ -954,6 +958,9 @@ int JPy_InitGlobalVars(JNIEnv* jenv)
DEFINE_METHOD(JPy_Throwable_getCause_MID, JPy_Throwable_JClass, "getCause", "()Ljava/lang/Throwable;");
DEFINE_METHOD(JPy_Throwable_getStackTrace_MID, JPy_Throwable_JClass, "getStackTrace", "()[Ljava/lang/StackTraceElement;");

DEFINE_CLASS(JPy_Supplier_JClass, "java/util/function/Supplier");
DEFINE_METHOD(JPy_Supplier_get_MID, JPy_Supplier_JClass, "get", "()Ljava/lang/Object;")

// JType_AddClassAttribute is actually called from within JType_GetType(), but not for
// JPy_JObject and JPy_JClass for an obvious reason. So we do it now:
JType_AddClassAttribute(jenv, JPy_JObject);
Expand Down
3 changes: 3 additions & 0 deletions src/main/c/jpy_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ extern jmethodID JPy_PyObject_Init_MID;
extern jclass JPy_PyDictWrapper_JClass;
extern jmethodID JPy_PyDictWrapper_GetPointer_MID;

extern jclass JPy_Supplier_JClass;
extern jmethodID JPy_Supplier_get_MID;

#ifdef __cplusplus
} /* extern "C" */
#endif
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/jpy/PyLib.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.function.Supplier;

import static org.jpy.PyLibConfig.*;

Expand Down Expand Up @@ -431,6 +432,8 @@ public static void stopPython() {

public static native boolean hasGil();

public static native <T> T ensureGil(Supplier<T> runnable);

/**
* Calls a Python callable and returns the resulting Python object.
* <p>
Expand Down
45 changes: 21 additions & 24 deletions src/main/java/org/jpy/PyObjectReferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,30 @@ public int cleanupOnlyUseFromGIL() {
}

private int cleanupOnlyUseFromGIL(long[] buffer) {
if (!PyLib.hasGil()) {
throw new IllegalStateException(
"We should only be calling PyObjectReferences.cleanupOnlyUseFromGIL if we have the GIL!");
}

int index = 0;
while (index < buffer.length) {
final Reference<? extends PyObject> reference = referenceQueue.poll();
if (reference == null) {
break;
return PyLib.ensureGil(() -> {
int index = 0;
while (index < buffer.length) {
final Reference<? extends PyObject> reference = referenceQueue.poll();
if (reference == null) {
break;
}
index = appendIfNotClosed(buffer, index, reference);
}
if (index == 0) {
return 0;
}
index = appendIfNotClosed(buffer, index, reference);
}
if (index == 0) {
return 0;
}

// We really really really want to make sure we *already* have the GIL lock at this point in
// time. Otherwise, we block here until the GIL is available for us, and stall all cleanup
// related to our PyObjects.
// We really really really want to make sure we *already* have the GIL lock at this point in
// time. Otherwise, we block here until the GIL is available for us, and stall all cleanup
// related to our PyObjects.

if (index == 1) {
PyLib.decRef(buffer[0]);
return 1;
}
PyLib.decRefs(buffer, index);
return index;
if (index == 1) {
PyLib.decRef(buffer[0]);
return 1;
}
PyLib.decRefs(buffer, index);
return index;
});
}

private int appendIfNotClosed(long[] buffer, int index, Reference<? extends PyObject> reference) {
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/jpy/PyLibTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -283,4 +283,26 @@ public void decRefs() {

PyLib.decRefs(new long[] { pyObject1, pyObject2, 0, 0 }, 2);
}

@Test
public void testEnsureGIL() {
assertFalse(PyLib.hasGil());
boolean[] lambdaSuccessfullyRan = {false};
Integer intResult = PyLib.ensureGil(() -> {
assertTrue(PyLib.hasGil());
lambdaSuccessfullyRan[0] = true;
return 123;
});
assertEquals((Integer) 123, intResult);
assertTrue(lambdaSuccessfullyRan[0]);

try {
Object result = PyLib.ensureGil(() -> {
throw new IllegalStateException("Error from inside GIL block");
});
fail("Exception expected");
} catch (IllegalStateException expectedException) {
assertEquals("Error from inside GIL block", expectedException.getMessage());
}//let anything else rethrow as a failure
}
}

0 comments on commit da3dba9

Please sign in to comment.