Skip to content

Commit

Permalink
Add internal dlvsym support for ccall's
Browse files Browse the repository at this point in the history
This change lays the necessary groundwork to support performing versioned
symbol lookups using `dlvsym`.

Since there's no way to ask for a versioned symbol from `ccall` this
code is currently unused, but it is a pre-requisite to add symbol
versioning to Julia's internal libraries.
  • Loading branch information
topolarity committed Mar 15, 2023
1 parent 7ba7e32 commit f32da50
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 40 deletions.
60 changes: 36 additions & 24 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ static Value *runtime_sym_lookup(
IRBuilder<> &irbuilder,
jl_codectx_t *ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
const char *f_name, const char *f_version, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
{
Expand All @@ -129,7 +129,7 @@ static Value *runtime_sym_lookup(
// global HMODULE *libptrgv
// global void **llvmgv
// if (*llvmgv == NULL) {
// *llvmgv = jl_load_and_lookup(f_lib, f_name, libptrgv);
// *llvmgv = jl_load_and_lookup(f_lib, f_name, f_version, libptrgv);
// }
// return (*llvmgv)
auto T_pvoidfunc = JuliaType::get_pvoidfunc_ty(irbuilder.getContext());
Expand Down Expand Up @@ -158,8 +158,9 @@ static Value *runtime_sym_lookup(
Value *nameval = stringConstPtr(emission_context, irbuilder, f_name);
if (lib_expr) {
jl_cgval_t libval = emit_expr(*ctx, lib_expr);
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlsym_func),
{ boxed(*ctx, libval), nameval });
Value *versionval = nullableStringConstPtr(emission_context, irbuilder, f_version);
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlvsym_func),
{ boxed(*ctx, libval), nameval, versionval });
}
else {
Value *libname;
Expand All @@ -170,8 +171,9 @@ static Value *runtime_sym_lookup(
// f_lib is actually one of the special sentinel values
libname = ConstantExpr::getIntToPtr(ConstantInt::get(getSizeTy(irbuilder.getContext()), (uintptr_t)f_lib), getInt8PtrTy(irbuilder.getContext()));
}
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
{ libname, nameval, libptrgv });
Value *versionval = nullableStringConstPtr(emission_context, irbuilder, f_version);
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlvsym_func),
{ libname, nameval, versionval, libptrgv });
}
StoreInst *store = irbuilder.CreateAlignedStore(llvmf, llvmgv, Align(sizeof(void*)));
store->setAtomic(AtomicOrdering::Release);
Expand All @@ -188,18 +190,18 @@ static Value *runtime_sym_lookup(
static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
const char *f_name, const char *f_version, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
{
return runtime_sym_lookup(ctx.emission_context, ctx.builder, &ctx, funcptype, f_lib, lib_expr,
f_name, f, libptrgv, llvmgv, runtime_lib);
f_name, f_version, f, libptrgv, llvmgv, runtime_lib);
}

static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f)
const char *f_name, const char *f_version, Function *f)
{
auto T_pvoidfunc = JuliaType::get_pvoidfunc_ty(ctx.builder.getContext());
GlobalVariable *libptrgv;
Expand All @@ -223,15 +225,16 @@ static Value *runtime_sym_lookup(
libptrgv = prepare_global_in(jl_Module, libptrgv);
}
llvmgv = prepare_global_in(jl_Module, llvmgv);
return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f, libptrgv, llvmgv, runtime_lib);
return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f_version, f, libptrgv, llvmgv, runtime_lib);
}

// Emit a "PLT" entry that will be lazily initialized
// when being called the first time.
static GlobalVariable *emit_plt_thunk(
jl_codectx_t &ctx,
FunctionType *functype, const AttributeList &attrs,
CallingConv::ID cc, const char *f_lib, const char *f_name,
CallingConv::ID cc,
const char *f_lib, const char *f_name, const char *f_version,
GlobalVariable *libptrgv, GlobalVariable *llvmgv,
bool runtime_lib)
{
Expand All @@ -256,8 +259,8 @@ static GlobalVariable *emit_plt_thunk(
fname);
BasicBlock *b0 = BasicBlock::Create(M->getContext(), "top", plt);
IRBuilder<> irbuilder(b0);
Value *ptr = runtime_sym_lookup(ctx.emission_context, irbuilder, NULL, funcptype, f_lib, NULL, f_name, plt, libptrgv,
llvmgv, runtime_lib);
Value *ptr = runtime_sym_lookup(ctx.emission_context, irbuilder, NULL, funcptype,
f_lib, NULL, f_name, f_version, plt, libptrgv, llvmgv, runtime_lib);
StoreInst *store = irbuilder.CreateAlignedStore(irbuilder.CreateBitCast(ptr, T_pvoidfunc), got, Align(sizeof(void*)));
store->setAtomic(AtomicOrdering::Release);
SmallVector<Value*, 16> args;
Expand Down Expand Up @@ -303,7 +306,8 @@ static Value *emit_plt(
jl_codectx_t &ctx,
FunctionType *functype,
const AttributeList &attrs,
CallingConv::ID cc, const char *f_lib, const char *f_name)
CallingConv::ID cc,
const char *f_lib, const char *f_name, const char *f_version)
{
++PLT;
assert(ctx.emission_context.imaging);
Expand All @@ -320,7 +324,7 @@ static Value *emit_plt(
GlobalVariable *&sharedgot = pltMap[key];
if (!sharedgot) {
sharedgot = emit_plt_thunk(ctx,
functype, attrs, cc, f_lib, f_name, libptrgv, llvmgv, runtime_lib);
functype, attrs, cc, f_lib, f_name, f_version, libptrgv, llvmgv, runtime_lib);
}
GlobalVariable *got = prepare_global_in(jl_Module, sharedgot);
LoadInst *got_val = ctx.builder.CreateAlignedLoad(got->getValueType(), got, Align(sizeof(void*)));
Expand Down Expand Up @@ -565,6 +569,7 @@ typedef struct {
void (*fptr)(void); // if the argument is a constant pointer
const char *f_name; // if the symbol name is known
const char *f_lib; // if a library name is specified
const char *f_version;
jl_value_t *lib_expr; // expression to compute library path lazily
jl_value_t *gcroot;
} native_sym_arg_t;
Expand All @@ -576,6 +581,8 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va
void (*&fptr)(void) = out.fptr;
const char *&f_name = out.f_name;
const char *&f_lib = out.f_lib;
const char *&f_version = out.f_version;
f_version = NULL;

jl_value_t *ptr = static_eval(ctx, arg);
if (ptr == NULL) {
Expand Down Expand Up @@ -709,20 +716,24 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
}
else {
if (sym.lib_expr) {
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())), NULL, sym.lib_expr, sym.f_name, ctx.f);
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())), NULL,
sym.lib_expr, sym.f_name, sym.f_version, ctx.f);
}
else if (ctx.emission_context.imaging) {
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())), sym.f_lib, NULL, sym.f_name, ctx.f);
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())),
sym.f_lib, NULL, sym.f_name, sym.f_version, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
}
else {
void *symaddr;

void* libsym = jl_get_library_(sym.f_lib, 0);
if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) {
int symbol_found = jl_dlvsym(libsym, sym.f_name, sym.f_version, &symaddr, 0);
if (!libsym || !symbol_found) {
// Error mode, either the library or the symbol couldn't be find during compiletime.
// Fallback to a runtime symbol lookup.
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())), sym.f_lib, NULL, sym.f_name, ctx.f);
res = runtime_sym_lookup(ctx, cast<PointerType>(getInt8PtrTy(ctx.builder.getContext())),
sym.f_lib, NULL, sym.f_name, sym.f_version, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
} else {
// since we aren't saving this code, there's no sense in
Expand Down Expand Up @@ -2043,25 +2054,26 @@ jl_cgval_t function_sig_t::emit_a_ccall(
PointerType *funcptype = PointerType::get(functype, 0);
if (symarg.lib_expr) {
++DeferredCCallLookups;
llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, symarg.f_version, ctx.f);
}
else if (ctx.emission_context.imaging) {
++DeferredCCallLookups;
// vararg requires musttail,
// but musttail is incompatible with noreturn.
if (functype->isVarArg())
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, symarg.f_version, ctx.f);
else
llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name);
llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name, symarg.f_version);
}
else {
void *symaddr;
void *libsym = jl_get_library_(symarg.f_lib, 0);
if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) {
int symbol_found = jl_dlvsym(libsym, symarg.f_name, symarg.f_version, &symaddr, 0);
if (!libsym || !symbol_found) {
++DeferredCCallLookups;
// either the library or the symbol could not be found, place a runtime
// lookup here instead.
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, symarg.f_version, ctx.f);
} else {
++LiteralCCalls;
// since we aren't saving this code, there's no sense in
Expand Down
11 changes: 11 additions & 0 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ static Value *stringConstPtr(
Args);
}

static Value *nullableStringConstPtr(
jl_codegen_params_t &emission_context,
IRBuilder<> &irbuilder,
const char *txt)
{
if (txt == NULL)
return ConstantPointerNull::get(PointerType::get(Type::getInt8Ty(irbuilder.getContext()), 0));

return stringConstPtr(emission_context, irbuilder, txt);
}


// --- MDNode ---
Metadata *to_md_tree(jl_value_t *val, LLVMContext &ctxt) {
Expand Down
10 changes: 5 additions & 5 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,16 +1024,16 @@ static const auto memcmp_func = new JuliaFunction{
None); },
// TODO: inferLibFuncAttributes(*memcmp_func, TLI);
};
static const auto jldlsym_func = new JuliaFunction{
static const auto jldlvsym_func = new JuliaFunction{
XSTR(jl_load_and_lookup),
[](LLVMContext &C) { return FunctionType::get(JuliaType::get_pvoidfunc_ty(C),
{getInt8PtrTy(C), getInt8PtrTy(C), PointerType::get(getInt8PtrTy(C), 0)}, false); },
{getInt8PtrTy(C), getInt8PtrTy(C), getInt8PtrTy(C), PointerType::get(getInt8PtrTy(C), 0)}, false); },
nullptr,
};
static const auto jllazydlsym_func = new JuliaFunction{
static const auto jllazydlvsym_func = new JuliaFunction{
XSTR(jl_lazy_load_and_lookup),
[](LLVMContext &C) { return FunctionType::get(JuliaType::get_pvoidfunc_ty(C),
{JuliaType::get_prjlvalue_ty(C), getInt8PtrTy(C)}, false); },
{JuliaType::get_prjlvalue_ty(C), getInt8PtrTy(C), getInt8PtrTy(C)}, false); },
nullptr,
};
static const auto jltypeassert_func = new JuliaFunction{
Expand Down Expand Up @@ -8742,7 +8742,7 @@ static void init_jit_functions(void)
add_named_global(jl_typeof_func, (void*)NULL);
add_named_global(jl_write_barrier_func, (void*)NULL);
add_named_global(jl_write_barrier_binding_func, (void*)NULL);
add_named_global(jldlsym_func, &jl_load_and_lookup);
add_named_global(jldlvsym_func, &jl_load_and_lookup);
add_named_global("jl_adopt_thread", &jl_adopt_thread);
add_named_global(jlgetcfunctiontrampoline_func, &jl_get_cfunction_trampoline);
add_named_global(jlgetnthfieldchecked_func, &jl_get_nth_field_checked);
Expand Down
39 changes: 34 additions & 5 deletions src/dlload.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,16 +381,28 @@ JL_DLLEXPORT void *jl_load_dynamic_library(const char *modname, unsigned flags,
return handle;
}

JL_DLLEXPORT int jl_dlsym(void *handle, const char *symbol, void ** value, int throw_err) JL_NOTSAFEPOINT
JL_DLLEXPORT int jl_dlvsym(void *handle, const char *symbol, const char *version, void ** value, int throw_err) JL_NOTSAFEPOINT
{

int symbol_found = 0;

/* First, get the symbol value */
if (version != NULL) {
#if defined(_OS_WINDOWS_) || !defined(__USE_GNU)
if (throw_err) {
jl_errorf("could not load symbol \"%s\" (version \"%s\"):\ndlvsym is not available on this platform", symbol, version);
}
return 0;
#else
*value = dlvsym(handle, symbol, version);
#endif
} else {
#ifdef _OS_WINDOWS_
*value = GetProcAddress((HMODULE) handle, symbol);
*value = GetProcAddress((HMODULE) handle, symbol);
#else
*value = dlsym(handle, symbol);
*value = dlsym(handle, symbol);
#endif
}

/* Next, check for errors. On Windows, a NULL pointer means the symbol was
* not found. On everything else, we can have NULL symbols, so we check for
Expand All @@ -405,7 +417,15 @@ JL_DLLEXPORT int jl_dlsym(void *handle, const char *symbol, void ** value, int t
const char *err = "";
if (!symbol_found) {
dlerror(); /* Reset error status. */
*value = dlsym(handle, symbol);
if (version != NULL) {
#if !defined(__USE_GNU)
// unreachable
#else
*value = dlvsym(handle, symbol, version);
#endif
} else {
*value = dlsym(handle, symbol);
}
err = dlerror();
symbol_found = *value != NULL || err == NULL;
}
Expand All @@ -416,11 +436,20 @@ JL_DLLEXPORT int jl_dlsym(void *handle, const char *symbol, void ** value, int t
char err[256];
win32_formatmessage(GetLastError(), err, sizeof(err));
#endif
jl_errorf("could not load symbol \"%s\":\n%s", symbol, err);
if (version != NULL) {
jl_errorf("could not load symbol \"%s\" (version \"%s\"):\n%s", symbol, version, err);
} else {
jl_errorf("could not load symbol \"%s\":\n%s", symbol, err);
}
}
return symbol_found;
}

JL_DLLEXPORT int jl_dlsym(void *handle, const char *symbol, void ** value, int throw_err) JL_NOTSAFEPOINT
{
return jl_dlvsym(handle, symbol, NULL, value, throw_err);
}

#ifdef _OS_WINDOWS_
//Look for symbols in win32 libraries
JL_DLLEXPORT const char *jl_dlfind_win32(const char *f_name)
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
XX(jl_dlclose) \
XX(jl_dlopen) \
XX(jl_dlsym) \
XX(jl_dlvsym) \
XX(jl_dump_host_cpu) \
XX(jl_check_pkgimage_clones) \
XX(jl_egal) \
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,7 @@ JL_DLLEXPORT jl_libhandle jl_load_dynamic_library(const char *fname, unsigned fl
JL_DLLEXPORT jl_libhandle jl_dlopen(const char *filename, unsigned flags) JL_NOTSAFEPOINT;
JL_DLLEXPORT int jl_dlclose(jl_libhandle handle) JL_NOTSAFEPOINT;
JL_DLLEXPORT int jl_dlsym(jl_libhandle handle, const char *symbol, void ** value, int throw_err) JL_NOTSAFEPOINT;
JL_DLLEXPORT int jl_dlvsym(jl_libhandle handle, const char *symbol, const char *version, void ** value, int throw_err) JL_NOTSAFEPOINT;

// evaluation
JL_DLLEXPORT jl_value_t *jl_toplevel_eval(jl_module_t *m, jl_value_t *v);
Expand Down
4 changes: 2 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,8 +1285,8 @@ void win32_formatmessage(DWORD code, char *reason, int len) JL_NOTSAFEPOINT;

JL_DLLEXPORT void *jl_get_library_(const char *f_lib, int throw_err);
#define jl_get_library(f_lib) jl_get_library_(f_lib, 1)
JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name, _Atomic(void*) *hnd);
JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name);
JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name, const char *f_version, _Atomic(void*) *hnd);
JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name, const char *f_version);
JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline(
jl_value_t *fobj, jl_datatype_t *result, htable_t *cache, jl_svec_t *fill,
void *(*init_trampoline)(void *tramp, void **nval),
Expand Down
8 changes: 4 additions & 4 deletions src/runtime_ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ void *jl_get_library_(const char *f_lib, int throw_err)
}

extern "C" JL_DLLEXPORT
void *jl_load_and_lookup(const char *f_lib, const char *f_name, _Atomic(void*) *hnd)
void *jl_load_and_lookup(const char *f_lib, const char *f_name, const char *f_version, _Atomic(void*) *hnd)
{
void *handle = jl_atomic_load_acquire(hnd);
if (!handle)
jl_atomic_store_release(hnd, (handle = jl_get_library(f_lib)));
void * ptr;
jl_dlsym(handle, f_name, &ptr, 1);
jl_dlvsym(handle, f_name, f_version, &ptr, 1);
return ptr;
}

// jl_load_and_lookup, but with library computed at run time on first call
extern "C" JL_DLLEXPORT
void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name)
void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name, const char *f_version)
{
char *f_lib;

Expand All @@ -77,7 +77,7 @@ void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name)
else
jl_type_error("ccall", (jl_value_t*)jl_symbol_type, lib_val);
void *ptr;
jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1);
jl_dlvsym(jl_get_library(f_lib), f_name, f_version, &ptr, 1);
return ptr;
}

Expand Down

0 comments on commit f32da50

Please sign in to comment.