From 95f437609edabdcd4dca1e18e39e3a60709b87ae Mon Sep 17 00:00:00 2001
From: Cody Tapscott <topolarity@tapscott.me>
Date: Tue, 14 Mar 2023 12:30:03 -0400
Subject: [PATCH] Look up `ccall` symbols in local internal libraries first

This is similar to what we do on Windows, where handles to the exact
library are required anyway. Looking symbols up through explicit
handles to the executable, libjulia, and libjulia-internal is enough
to make sure that our `dlsym` lookups are directed to the correct
libjulia, even when loading a Julia run-time within Julia.

The second change needed to get things working (not included in this
commit) is to add symbol versioning, so that the runtime linker does not
mix up symbols between the two libraries.
---
 src/ccall.cpp            | 30 +++++++++++--------------
 src/codegen.cpp          |  2 --
 src/dlload.c             | 48 ++++++++++++++++++++--------------------
 src/init.c               |  8 ++-----
 src/julia_internal.h     |  9 ++++----
 src/runtime_ccall.cpp    |  2 --
 src/runtime_intrinsics.c |  4 +---
 7 files changed, 45 insertions(+), 58 deletions(-)

diff --git a/src/ccall.cpp b/src/ccall.cpp
index e490f4146cad2..1087525e1b341 100644
--- a/src/ccall.cpp
+++ b/src/ccall.cpp
@@ -61,7 +61,6 @@ static bool runtime_sym_gvs(jl_codectx_t &ctx, const char *f_lib, const char *f_
     bool runtime_lib = false;
     GlobalVariable *libptrgv;
     jl_codegen_params_t::SymMapGV *symMap;
-#ifdef _OS_WINDOWS_
     if ((intptr_t)f_lib == (intptr_t)JL_EXE_LIBNAME) {
         libptrgv = prepare_global_in(M, jlexe_var);
         symMap = &ctx.emission_context.symMapExe;
@@ -74,9 +73,7 @@ static bool runtime_sym_gvs(jl_codectx_t &ctx, const char *f_lib, const char *f_
         libptrgv = prepare_global_in(M, jldll_var);
         symMap = &ctx.emission_context.symMapDll;
     }
-    else
-#endif
-    if (f_lib == NULL) {
+    else if (f_lib == NULL) {
         libptrgv = jl_emit_RTLD_DEFAULT_var(M);
         symMap = &ctx.emission_context.symMapDefault;
     }
@@ -631,16 +628,12 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va
                 std::string iname("i");
                 iname += f_name;
                 if (jl_dlsym(jl_libjulia_internal_handle, iname.c_str(), &symaddr, 0)) {
-#ifdef _OS_WINDOWS_
                     f_lib = JL_LIBJULIA_INTERNAL_DL_LIBNAME;
-#endif
                     f_name = jl_symbol_name(jl_symbol(iname.c_str()));
                 }
-#ifdef _OS_WINDOWS_
                 else {
-                    f_lib = jl_dlfind_win32(f_name);
+                    f_lib = jl_dlfind(f_name);
                 }
-#endif
             }
         }
         else if (jl_is_cpointer_type(jl_typeof(ptr))) {
@@ -726,7 +719,8 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
             void *symaddr;
 
             void* libsym = jl_get_library_(sym.f_lib, 0);
-            if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) {
+            int symbol_found = jl_dlsym(libsym, sym.f_name, &symaddr, 0);
+            if (!libsym || !symbol_found) {
                 // Error mode, either the library or the symbol couldn't be find during compiletime.
                 // Fallback to a runtime symbol lookup.
                 res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())), sym.f_lib, NULL, sym.f_name, ctx.f);
@@ -1381,18 +1375,19 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
         if ((uintptr_t)fptr == ptr)
             return true;
         if (f_lib) {
-#ifdef _OS_WINDOWS_
             if ((f_lib == JL_EXE_LIBNAME) || // preventing invalid pointer access
                 (f_lib == JL_LIBJULIA_INTERNAL_DL_LIBNAME) ||
-                (f_lib == JL_LIBJULIA_DL_LIBNAME) ||
-                (!strcmp(f_lib, jl_crtdll_basename))) {
+                (f_lib == JL_LIBJULIA_DL_LIBNAME)) {
+                // libjulia-like
+            }
+            else
+#ifdef _OS_WINDOWS_
+            if (strcmp(f_lib, jl_crtdll_basename) == 0) {
                 // libjulia-like
             }
             else
-                return false;
-#else
-            return false;
 #endif
+            return false;
         }
         return f_name && f_name == name;
     };
@@ -2082,7 +2077,8 @@ jl_cgval_t function_sig_t::emit_a_ccall(
         else {
             void *symaddr;
             void *libsym = jl_get_library_(symarg.f_lib, 0);
-            if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) {
+            int symbol_found = jl_dlsym(libsym, symarg.f_name, &symaddr, 0);
+            if (!libsym || !symbol_found) {
                 ++DeferredCCallLookups;
                 // either the library or the symbol could not be found, place a runtime
                 // lookup here instead.
diff --git a/src/codegen.cpp b/src/codegen.cpp
index b6b86ba4442e1..0b62c481b9e41 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -8686,11 +8686,9 @@ static void init_jit_functions(void)
 {
     add_named_global(jlstack_chk_guard_var, &__stack_chk_guard);
     add_named_global(jlRTLD_DEFAULT_var, &jl_RTLD_DEFAULT_handle);
-#ifdef _OS_WINDOWS_
     add_named_global(jlexe_var, &jl_exe_handle);
     add_named_global(jldll_var, &jl_libjulia_handle);
     add_named_global(jldlli_var, &jl_libjulia_internal_handle);
-#endif
     auto size2pjlvalue = [](Type *T_size) -> Type * {
         return get_pjlvalue(T_size->getContext());
     };
diff --git a/src/dlload.c b/src/dlload.c
index 9f4e8be29952d..64365848ad6f3 100644
--- a/src/dlload.c
+++ b/src/dlload.c
@@ -189,6 +189,7 @@ JL_DLLEXPORT JL_NO_SANITIZE void *jl_dlopen(const char *filename, unsigned flags
         if (!dlopen)
             return NULL;
         void *libdl_handle = dlopen("libdl.so", RTLD_NOW | RTLD_NOLOAD);
+        assert(libdl_handle);
         dlopen = (dlopen_prototype*)dlsym(libdl_handle, "dlopen");
         dlclose(libdl_handle);
         assert(dlopen);
@@ -239,6 +240,25 @@ JL_DLLEXPORT int jl_dlclose(void *handle) JL_NOTSAFEPOINT
 #endif
 }
 
+void *jl_find_dynamic_library_by_addr(void *symbol) {
+    void *handle;
+#ifdef _OS_WINDOWS_
+    if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+                            (LPCWSTR)symbol,
+                            (HMODULE*)&handle)) {
+        jl_error("could not load base module");
+    }
+#else
+    Dl_info info;
+    if (!dladdr(symbol, &info) || !info.dli_fname) {
+        jl_error("could not load base module");
+    }
+    handle = dlopen(info.dli_fname, RTLD_NOW | RTLD_NOLOAD);
+    dlclose(handle); // Undo ref count increment from `dlopen`
+#endif
+    return handle;
+}
+
 JL_DLLEXPORT void *jl_load_dynamic_library(const char *modname, unsigned flags, int throw_err)
 {
     char path[PATHBUF], relocated[PATHBUF];
@@ -255,26 +275,6 @@ JL_DLLEXPORT void *jl_load_dynamic_library(const char *modname, unsigned flags,
     int n_extensions = endswith_extension(modname) ? 1 : N_EXTENSIONS;
     int ret;
 
-    /*
-      this branch returns handle of libjulia-internal
-    */
-    if (modname == NULL) {
-#ifdef _OS_WINDOWS_
-        if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
-                                (LPCWSTR)(uintptr_t)(&jl_load_dynamic_library),
-                                (HMODULE*)&handle)) {
-            jl_error("could not load base module");
-        }
-#else
-        Dl_info info;
-        if (!dladdr((void*)(uintptr_t)&jl_load_dynamic_library, &info) || !info.dli_fname) {
-            jl_error("could not load base module");
-        }
-        handle = dlopen(info.dli_fname, RTLD_NOW);
-#endif
-        goto done;
-    }
-
     abspath = jl_isabspath(modname);
     is_atpath = 0;
 
@@ -421,9 +421,8 @@ JL_DLLEXPORT int jl_dlsym(void *handle, const char *symbol, void ** value, int t
     return symbol_found;
 }
 
-#ifdef _OS_WINDOWS_
-//Look for symbols in win32 libraries
-JL_DLLEXPORT const char *jl_dlfind_win32(const char *f_name)
+// Look for symbols in internal libraries
+JL_DLLEXPORT const char *jl_dlfind(const char *f_name)
 {
     void * dummy;
     if (jl_dlsym(jl_exe_handle, f_name, &dummy, 0))
@@ -432,6 +431,7 @@ JL_DLLEXPORT const char *jl_dlfind_win32(const char *f_name)
         return JL_LIBJULIA_INTERNAL_DL_LIBNAME;
     if (jl_dlsym(jl_libjulia_handle, f_name, &dummy, 0))
         return JL_LIBJULIA_DL_LIBNAME;
+#ifdef _OS_WINDOWS_
     if (jl_dlsym(jl_kernel32_handle, f_name, &dummy, 0))
         return "kernel32";
     if (jl_dlsym(jl_crtdll_handle, f_name, &dummy, 0)) // Prefer crtdll over ntdll
@@ -440,6 +440,7 @@ JL_DLLEXPORT const char *jl_dlfind_win32(const char *f_name)
         return "ntdll";
     if (jl_dlsym(jl_winsock_handle, f_name, &dummy, 0))
         return "ws2_32";
+#endif
     // additional common libraries (libc?) could be added here, but in general,
     // it is better to specify the library explicitly in the code. This exists
     // mainly to ease compatibility with linux, and for libraries that don't
@@ -451,7 +452,6 @@ JL_DLLEXPORT const char *jl_dlfind_win32(const char *f_name)
     // which defaults to jl_libjulia_internal_handle, where we won't find it, and
     // will throw the appropriate error.
 }
-#endif
 
 #ifdef __cplusplus
 }
diff --git a/src/init.c b/src/init.c
index 5990bd24aaabd..95a5a32704f2f 100644
--- a/src/init.c
+++ b/src/init.c
@@ -763,15 +763,11 @@ JL_DLLEXPORT void julia_init(JL_IMAGE_SEARCH rel)
     void *stack_lo, *stack_hi;
     jl_init_stack_limits(1, &stack_lo, &stack_hi);
 
-    jl_libjulia_internal_handle = jl_load_dynamic_library(NULL, JL_RTLD_DEFAULT, 1);
+    jl_libjulia_internal_handle = jl_find_dynamic_library_by_addr(&jl_load_dynamic_library);
+    jl_libjulia_handle = jl_find_dynamic_library_by_addr(&jl_any_type);
 #ifdef _OS_WINDOWS_
     jl_exe_handle = GetModuleHandleA(NULL);
     jl_RTLD_DEFAULT_handle = jl_libjulia_internal_handle;
-    if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
-                            (LPCWSTR)&jl_any_type,
-                            (HMODULE*)&jl_libjulia_handle)) {
-        jl_error("could not load base module");
-    }
     jl_ntdll_handle = jl_dlopen("ntdll.dll", JL_RTLD_NOLOAD); // bypass julia's pathchecking for system dlls
     jl_kernel32_handle = jl_dlopen("kernel32.dll", JL_RTLD_NOLOAD);
     jl_crtdll_handle = jl_dlopen(jl_crtdll_name, JL_RTLD_NOLOAD);
diff --git a/src/julia_internal.h b/src/julia_internal.h
index 4f1a0b4513d8d..5d3f26e1eb1ba 100644
--- a/src/julia_internal.h
+++ b/src/julia_internal.h
@@ -1254,11 +1254,11 @@ JL_DLLEXPORT uint64_t jl_rand(void) JL_NOTSAFEPOINT;
 JL_DLLEXPORT void jl_srand(uint64_t) JL_NOTSAFEPOINT;
 JL_DLLEXPORT void jl_init_rand(void);
 
+JL_DLLEXPORT extern void *jl_exe_handle;
+JL_DLLEXPORT extern void *jl_libjulia_handle;
 JL_DLLEXPORT extern void *jl_libjulia_internal_handle;
 JL_DLLEXPORT extern void *jl_RTLD_DEFAULT_handle;
 #if defined(_OS_WINDOWS_)
-JL_DLLEXPORT extern void *jl_exe_handle;
-JL_DLLEXPORT extern void *jl_libjulia_handle;
 JL_DLLEXPORT extern const char *jl_crtdll_basename;
 extern void *jl_ntdll_handle;
 extern void *jl_kernel32_handle;
@@ -1268,6 +1268,7 @@ void win32_formatmessage(DWORD code, char *reason, int len) JL_NOTSAFEPOINT;
 #endif
 
 JL_DLLEXPORT void *jl_get_library_(const char *f_lib, int throw_err);
+void *jl_find_dynamic_library_by_addr(void *symbol);
 #define jl_get_library(f_lib) jl_get_library_(f_lib, 1)
 JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name, _Atomic(void*) *hnd);
 JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name);
@@ -1277,11 +1278,11 @@ JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline(
     jl_unionall_t *env, jl_value_t **vals);
 
 
-// Windows only
+// Special filenames used to refer to internal julia libraries
 #define JL_EXE_LIBNAME                  ((const char*)1)
 #define JL_LIBJULIA_DL_LIBNAME          ((const char*)2)
 #define JL_LIBJULIA_INTERNAL_DL_LIBNAME ((const char*)3)
-JL_DLLEXPORT const char *jl_dlfind_win32(const char *name);
+JL_DLLEXPORT const char *jl_dlfind(const char *name);
 
 // libuv wrappers:
 JL_DLLEXPORT int jl_fs_rename(const char *src_path, const char *dst_path);
diff --git a/src/runtime_ccall.cpp b/src/runtime_ccall.cpp
index e3543c9f62656..fa2184f555f28 100644
--- a/src/runtime_ccall.cpp
+++ b/src/runtime_ccall.cpp
@@ -31,14 +31,12 @@ void *jl_get_library_(const char *f_lib, int throw_err)
 {
     if (f_lib == NULL)
         return jl_RTLD_DEFAULT_handle;
-#ifdef _OS_WINDOWS_
     if (f_lib == JL_EXE_LIBNAME)
         return jl_exe_handle;
     if (f_lib == JL_LIBJULIA_INTERNAL_DL_LIBNAME)
         return jl_libjulia_internal_handle;
     if (f_lib == JL_LIBJULIA_DL_LIBNAME)
         return jl_libjulia_handle;
-#endif
     JL_LOCK(&libmap_lock);
     // This is the only operation we do on the map, which doesn't invalidate
     // any references or iterators.
diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c
index 0ac5b277b0657..9babdf89f098b 100644
--- a/src/runtime_intrinsics.c
+++ b/src/runtime_intrinsics.c
@@ -508,10 +508,8 @@ JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty)
     else
         JL_TYPECHK(cglobal, symbol, v)
 
-#ifdef _OS_WINDOWS_
     if (!f_lib)
-        f_lib = (char*)jl_dlfind_win32(f_name);
-#endif
+        f_lib = (char*)jl_dlfind(f_name);
 
     void *ptr;
     jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1);