diff --git a/src/main/c/common-tools.cpp b/src/main/c/common-tools.cpp index 76ce00b..146b6f0 100644 --- a/src/main/c/common-tools.cpp +++ b/src/main/c/common-tools.cpp @@ -8,6 +8,7 @@ extern "C" { static jclass ClassAssertionError; static jclass ClassClass; static jclass ClassMethod; +static jclass ClassThread; static jclass ClassIllegalArgumentException; static jclass ClassNoSuchElementException; static jclass ClassBoolean; @@ -34,6 +35,21 @@ static jvmtiEnv * GetJvmti(JavaVM *jvm) { static std::list classLoadHooks; // NOLINT(cert-err58-cpp) +static jthread GetCurrentThread(JNIEnv *env) { + jmethodID mid = env->GetStaticMethodID(ClassThread, "currentThread", "()Ljava/lang/Thread;"); + return env->CallStaticObjectMethod(ClassThread, mid); +} + +static bool EnsureStackDoesNotContain(jvmtiEnv *jvmti, JNIEnv *env, jmethodID mid) { + jvmtiFrameInfo frames[25]; + jint count; + jvmti->GetStackTrace(GetCurrentThread(env), 0, 25, frames, &count); + for (int i = 0; i < count; ++i) { + if (frames[i].method == mid) return false; + } + return true; +} + static void classLoadHook(jvmtiEnv *jvmti_env, JNIEnv* jni_env, jclass class_being_redefined, @@ -44,14 +60,18 @@ static void classLoadHook(jvmtiEnv *jvmti_env, const unsigned char* class_data, jint* new_class_data_len, unsigned char** new_class_data) { + if (name == nullptr) return; jmethodID mid = jni_env->GetMethodID(ClassClassLoadHook, "transform", "(Ljava/lang/ClassLoader;Ljava/lang/String;Ljava/lang/Class;Ljava/security/ProtectionDomain;[B)[B"); jbyteArray arr = jni_env->NewByteArray(class_data_len); jni_env->SetByteArrayRegion(arr, 0, class_data_len, (jbyte*) class_data); jstring j_name = jni_env->NewStringUTF(name); for (const jobject &item : classLoadHooks) { - jobject obj = jni_env->CallObjectMethod(item, mid, loader, j_name, class_being_redefined, protection_domain, arr); - if (obj != nullptr) { - arr = reinterpret_cast(obj); + jmethodID mid2 = jni_env->GetMethodID(jni_env->GetObjectClass(item), "transform", "(Ljava/lang/ClassLoader;Ljava/lang/String;Ljava/lang/Class;Ljava/security/ProtectionDomain;[B)[B"); + if (EnsureStackDoesNotContain(jvmti_env, jni_env, mid2)) { + jobject obj = jni_env->CallObjectMethod(item, mid, loader, j_name, class_being_redefined, protection_domain, arr); + if (obj != nullptr) { + arr = reinterpret_cast(obj); + } } } *new_class_data_len = jni_env->GetArrayLength(arr); @@ -64,6 +84,7 @@ static void InitTools(JNIEnv *env) { ClassAssertionError = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/AssertionError"))); ClassClass = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/Class"))); ClassMethod = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/reflect/Method"))); + ClassThread = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/Thread"))); ClassIllegalArgumentException = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/IllegalArgumentException"))); ClassNoSuchElementException = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/util/NoSuchElementException"))); ClassBoolean = reinterpret_cast(env->NewGlobalRef(env->FindClass("java/lang/Boolean"))); diff --git a/src/main/resources/libnativeutil.dll b/src/main/resources/libnativeutil.dll index be3668f..1147d6d 100644 Binary files a/src/main/resources/libnativeutil.dll and b/src/main/resources/libnativeutil.dll differ diff --git a/src/main/resources/libnativeutil.so b/src/main/resources/libnativeutil.so index 6d4f81d..bddd57a 100644 Binary files a/src/main/resources/libnativeutil.so and b/src/main/resources/libnativeutil.so differ diff --git a/src/test/java/net/blueberrymc/native_util/NativeUtilTest.java b/src/test/java/net/blueberrymc/native_util/NativeUtilTest.java index 6b3f132..4fbe428 100644 --- a/src/test/java/net/blueberrymc/native_util/NativeUtilTest.java +++ b/src/test/java/net/blueberrymc/native_util/NativeUtilTest.java @@ -83,25 +83,6 @@ public void testGetObjectSize() { assert size == 16L : size; } - @Test - public void testClassLoadHook() { - AtomicBoolean hasLoaded = new AtomicBoolean(); - NativeUtil.registerClassLoadHook((loader, className, classBeingRedefined, protectionDomain, buf) -> { - System.out.println("Loading class " + className); - if (className.equals("net/blueberrymc/native_util/TestClass")) { - hasLoaded.set(true); - } - return null; - }); - new TestClass(); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); - } - assert hasLoaded.get() : "net/blueberrymc/native_util/TestClass was not loaded"; - } - private static class A { @SuppressWarnings("unused") public int getSomething() { diff --git a/src/test/java/net/blueberrymc/native_util/TestClass.java b/src/test/java/net/blueberrymc/native_util/TestClass.java deleted file mode 100644 index 1bf8977..0000000 --- a/src/test/java/net/blueberrymc/native_util/TestClass.java +++ /dev/null @@ -1,4 +0,0 @@ -package net.blueberrymc.native_util; - -public class TestClass { -}