Skip to content

Commit

Permalink
Handle external functions in other images
Browse files Browse the repository at this point in the history
When we find an edge to an external function (already cached in a
loaded pkgimage), we emit a global variable which we will patch during
loading with the address of the function to call.
  • Loading branch information
vchuravy committed Oct 17, 2022
1 parent 381f5bd commit 2ed4acd
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 32 deletions.
59 changes: 56 additions & 3 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ typedef struct {
std::vector<GlobalValue*> jl_sysimg_gvars;
std::map<jl_code_instance_t*, std::tuple<uint32_t, uint32_t>> jl_fvar_map;
std::vector<void*> jl_value_to_llvm;
std::vector<jl_code_instance_t*> jl_external_to_llvm;
} jl_native_code_desc_t;

extern "C" JL_DLLEXPORT
Expand All @@ -118,6 +119,15 @@ void jl_get_llvm_gvs_impl(void *native_code, arraylist_t *gvs)
memcpy(gvs->items, data->jl_value_to_llvm.data(), gvs->len * sizeof(void*));
}

// Copy the code instances recorded in `jl_external_to_llvm` into `external_fns`.
//
// native_code:  opaque handle that is reinterpreted as a `jl_native_code_desc_t*`.
// external_fns: destination list; it is grown by one slot per recorded code
//               instance and the new slots are filled with the
//               `jl_code_instance_t*` values, in vector order.
//
// The number of bytes copied is derived from the source vector, not from the
// destination's total length, so a list that already holds entries cannot make
// the copy read past the end of the vector.
// NOTE(review): assumes `arraylist_grow` extends `len` by the requested count
// and that `items` is an array of pointer-sized slots -- confirm against
// arraylist.h.
extern "C" JL_DLLEXPORT
void jl_get_llvm_external_fns_impl(void *native_code, arraylist_t *external_fns)
{
    jl_native_code_desc_t *data = (jl_native_code_desc_t*)native_code;
    const size_t count = data->jl_external_to_llvm.size();
    // Remember where the freshly grown region starts before `len` changes.
    const size_t first_new = external_fns->len;
    arraylist_grow(external_fns, count);
    // An empty vector may report a null data(); memcpy must not be handed a
    // null pointer even for a zero-byte copy.
    if (count != 0) {
        memcpy(external_fns->items + first_new,
               data->jl_external_to_llvm.data(),
               count * sizeof(jl_code_instance_t*));
    }
}

extern "C" JL_DLLEXPORT
LLVMOrcThreadSafeModuleRef jl_get_llvm_module_impl(void *native_code)
{
Expand Down Expand Up @@ -251,10 +261,12 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup, and can
// also be used by extern consumers like GPUCompiler.jl to obtain a module containing
// all reachable & inferable functions. The `policy` flag switches between the default
// mode `0`, the extern mode `1`.
// all reachable & inferable functions.
// The `policy` flag switches between the default mode `0` and the extern mode `1` used by GPUCompiler.
// `_imaging_mode` controls if raw pointers can be embedded (e.g. the code will be loaded into the same session).
// `_external_linkage` creates linkages between pkgimages.
extern "C" JL_DLLEXPORT
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode)
void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage)
{
++CreateNativeCalls;
CreateNativeMax.updateMax(jl_array_len(methods));
Expand Down Expand Up @@ -284,6 +296,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
compiler_start_time = jl_hrtime();

params.imaging = imaging;
params.external_linkage = _external_linkage;

// compile all methods for the current world and type-inference world
size_t compile_for[] = { jl_typeinf_world, jl_atomic_load_acquire(&jl_world_counter) };
Expand Down Expand Up @@ -342,6 +355,46 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
}
CreateNativeMethods += emitted.size();

size_t offset = gvars.size();
data->jl_external_to_llvm.resize(params.external_fns.size());

for (auto &extern_fn : params.external_fns) {
jl_code_instance_t *this_code = std::get<0>(extern_fn.first);
bool specsig = std::get<1>(extern_fn.first);
assert(specsig && "Error external_fns doesn't handle non-specsig yet");
Function *F = extern_fn.second;
Module *M = F->getParent();

Type *T_funcp = F->getFunctionType()->getPointerTo();
// Can't create a GV with type FunctionType. Alias also doesn't work
GlobalVariable *GV = new GlobalVariable(*M, T_funcp, false,
GlobalVariable::ExternalLinkage,
Constant::getNullValue(T_funcp),
F->getName());


// Need to insert load instruction... can't RAUW
for (Value *Use: F->users()) {
if (auto CI = dyn_cast<CallInst>(Use)) {
auto Callee = new LoadInst(T_funcp, GV, "", false, Align(1), CI); // TODO correct Align?
CI->setCalledOperand(Callee);
continue;
} else {
llvm::outs() << *Use << "\n";
assert(false);
}
}

assert(F->getNumUses() == 0); // declaration counts as use
GV->takeName(F);
F->eraseFromParent();

size_t idx = gvars.size() - offset;
assert(idx >= 0);
data->jl_external_to_llvm.at(idx) = this_code;
gvars.push_back(std::string(GV->getName()));
}

// clones the contents of the module `m` to the shadow_output collector
// while examining and recording what kind of function pointer we have
for (auto &def : emitted) {
Expand Down
3 changes: 2 additions & 1 deletion src/codegen-stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ JL_DLLEXPORT void jl_dump_native_fallback(void *native_code,
const char *bc_fname, const char *unopt_bc_fname, const char *obj_fname, const char *asm_fname,
const char *sysimg_data, size_t sysimg_len) UNAVAILABLE
JL_DLLEXPORT void jl_get_llvm_gvs_fallback(void *native_code, arraylist_t *gvs) UNAVAILABLE
JL_DLLEXPORT void jl_get_llvm_external_fns_fallback(void *native_code, arraylist_t *gvs) UNAVAILABLE

JL_DLLEXPORT void jl_extern_c_fallback(jl_function_t *f, jl_value_t *rt, jl_value_t *argt, char *name) UNAVAILABLE
JL_DLLEXPORT jl_value_t *jl_dump_method_asm_fallback(jl_method_instance_t *linfo, size_t world,
Expand Down Expand Up @@ -66,7 +67,7 @@ JL_DLLEXPORT size_t jl_jit_total_bytes_fallback(void)
return 0;
}

JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode) UNAVAILABLE
JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int _policy, int _imaging_mode, int _external_linkage) UNAVAILABLE

JL_DLLEXPORT void jl_dump_compiles_fallback(void *s)
{
Expand Down
50 changes: 40 additions & 10 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ class jl_codectx_t {
jl_codegen_params_t &emission_context;
llvm::MapVector<jl_code_instance_t*, jl_codegen_call_target_t> call_targets;
std::map<void*, GlobalVariable*> &global_targets;
std::map<std::tuple<jl_code_instance_t*, bool>, Function*> &external_calls;
Function *f = NULL;
// local var info. globals are not in here.
std::vector<jl_varinfo_t> slots;
Expand Down Expand Up @@ -1454,6 +1455,7 @@ class jl_codectx_t {

bool debug_enabled = false;
bool use_cache = false;
bool external_linkage = false;
const jl_cgparams_t *params = NULL;

std::vector<orc::ThreadSafeModule> llvmcall_modules;
Expand All @@ -1463,8 +1465,10 @@ class jl_codectx_t {
emission_context(params),
call_targets(),
global_targets(params.globals),
external_calls(params.external_fns),
world(params.world),
use_cache(params.cache),
external_linkage(params.external_linkage),
params(params.params) { }

jl_typecache_t &types() {
Expand Down Expand Up @@ -4017,9 +4021,18 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
std::string name;
StringRef protoname;
bool need_to_emit = true;
// TODO: We should check if the code is available externally
// and then emit a trampoline.
if (ctx.use_cache) {
bool cache_valid = ctx.use_cache;
bool external = false;
if (ctx.external_linkage) {
if (jl_object_in_image((jl_value_t*)codeinst)) {
// Target is present in another pkgimage
jl_printf(JL_STDERR, "\n (emit_invoke:) Want to resolve method!\n");
cache_valid = true;
external = true;
}
}

if (cache_valid) {
// optimization: emit the correct name immediately, if we know it
// TODO: use `emitted` map here too to try to consolidate names?
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
Expand All @@ -4046,6 +4059,13 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, argv, nargs, &cc, &return_roots, rt);
else
result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, argv, nargs, rt);
if (external) {
assert(!need_to_emit);
auto calledF = jl_Module->getFunction(protoname);
assert(calledF);
// TODO: Check if already present?
ctx.external_calls[std::make_tuple(codeinst, specsig)] = calledF;
}
handled = true;
if (need_to_emit) {
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
Expand Down Expand Up @@ -5365,7 +5385,17 @@ static Function *emit_tojlinvoke(jl_code_instance_t *codeinst, Module *M, jl_cod
Function *theFunc;
Value *theFarg;
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
if (params.cache && invoke != NULL) {

bool cache_valid = params.cache;
if (params.external_linkage) {
if (jl_object_in_image((jl_value_t*)codeinst)) {
// Target is present in another pkgimage
jl_printf(JL_STDERR, "\n (emit_jlinvoke) Want to resolve method\n");
cache_valid = true;
}
}

if (cache_valid && invoke != NULL) {
StringRef theFptrName = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)invoke, codeinst);
theFunc = cast<Function>(
M->getOrInsertFunction(theFptrName, jlinvoke_func->_type(ctx.builder.getContext())).getCallee());
Expand Down Expand Up @@ -8262,11 +8292,11 @@ void jl_compile_workqueue(
StringRef preal_decl = "";
bool preal_specsig = false;
auto invoke = jl_atomic_load_relaxed(&codeinst->invoke);
// TODO: available_extern
// We need to emit a trampoline that loads the target address in an extern_module from a GV
// Right now we will unnecessarily emit a function we have already compiled in a native module
// again in a calling module.
if (params.cache && invoke != NULL) {
bool cache_valid = params.cache;
if (params.external_linkage) {
cache_valid = jl_object_in_image((jl_value_t*)codeinst);
}
if (cache_valid && invoke != NULL) {
auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (invoke == jl_fptr_args_addr) {
preal_decl = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
Expand All @@ -8275,7 +8305,7 @@ void jl_compile_workqueue(
preal_decl = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
preal_specsig = true;
}
}
}
else {
auto &result = emitted[codeinst];
jl_llvm_functions_t *decls = NULL;
Expand Down
2 changes: 2 additions & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ typedef struct _jl_codegen_params_t {
// outputs
std::vector<std::pair<jl_code_instance_t*, jl_codegen_call_target_t>> workqueue;
std::map<void*, GlobalVariable*> globals;
std::map<std::tuple<jl_code_instance_t*,bool>, Function*> external_fns;
std::map<jl_datatype_t*, DIType*> ditypes;
std::map<jl_datatype_t*, Type*> llvmtypes;
DenseMap<Constant*, GlobalVariable*> mergedConstants;
Expand Down Expand Up @@ -200,6 +201,7 @@ typedef struct _jl_codegen_params_t {
size_t world = 0;
const jl_cgparams_t *params = &jl_default_cgparams;
bool cache = false;
bool external_linkage = false;
bool imaging;
_jl_codegen_params_t(orc::ThreadSafeContext ctx) : tsctx(std::move(ctx)), tsctx_lock(tsctx.getLock()), imaging(imaging_default()) {}
} jl_codegen_params_t;
Expand Down
2 changes: 2 additions & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@
XX(jl_voidpointer_type) \
XX(jl_void_type) \
XX(jl_weakref_type) \
XX(jl_build_ids) \
XX(jl_linkage_blobs) \

// Data symbols that are defined inside the public libjulia
#define JL_EXPORTED_DATA_SYMBOLS(XX) \
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@
YY(jl_get_LLVM_VERSION) \
YY(jl_dump_native) \
YY(jl_get_llvm_gvs) \
YY(jl_get_llvm_external_fns) \
YY(jl_dump_function_asm) \
YY(jl_LLVMCreateDisasm) \
YY(jl_LLVMDisasmInstruction) \
Expand Down
3 changes: 2 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,11 +970,12 @@ JL_DLLEXPORT jl_value_t *jl_dump_fptr_asm(uint64_t fptr, char raw_mc, const char
JL_DLLEXPORT jl_value_t *jl_dump_function_ir(jl_llvmf_dump_t *dump, char strip_ir_metadata, char dump_module, const char *debuginfo);
JL_DLLEXPORT jl_value_t *jl_dump_function_asm(jl_llvmf_dump_t *dump, char raw_mc, const char* asm_variant, const char *debuginfo, char binary);

void *jl_create_native(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int policy, int imaging_mode);
void *jl_create_native(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvmmod, const jl_cgparams_t *cgparams, int policy, int imaging_mode, int cache);
void jl_dump_native(void *native_code,
const char *bc_fname, const char *unopt_bc_fname, const char *obj_fname, const char *asm_fname,
const char *sysimg_data, size_t sysimg_len);
void jl_get_llvm_gvs(void *native_code, arraylist_t *gvs);
void jl_get_llvm_external_fns(void *native_code, arraylist_t *gvs);
JL_DLLEXPORT void jl_get_function_id(void *native_code, jl_code_instance_t *ncode,
int32_t *func_idx, int32_t *specfunc_idx);

Expand Down
8 changes: 4 additions & 4 deletions src/precompile.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ static int precompile_enq_all_specializations_(jl_methtable_t *mt, void *env)
return jl_typemap_visitor(jl_atomic_load_relaxed(&mt->defs), precompile_enq_all_specializations__, env);
}

static void *jl_precompile_(jl_array_t *m)
static void *jl_precompile_(jl_array_t *m, int external_linkage)
{
jl_array_t *m2 = NULL;
jl_method_instance_t *mi = NULL;
Expand All @@ -366,7 +366,7 @@ static void *jl_precompile_(jl_array_t *m)
jl_array_ptr_1d_push(m2, item);
}
}
void *native_code = jl_create_native(m2, NULL, NULL, 0, 1);
void *native_code = jl_create_native(m2, NULL, NULL, 0, 1, external_linkage);
JL_GC_POP();
return native_code;
}
Expand All @@ -379,7 +379,7 @@ static void *jl_precompile(int all)
if (all)
jl_compile_all_defs(m);
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
void *native_code = jl_precompile_(m);
void *native_code = jl_precompile_(m, 0);
JL_GC_POP();
return native_code;
}
Expand All @@ -398,7 +398,7 @@ static void *jl_precompile_worklist(jl_array_t *worklist)
assert(jl_is_module(mod));
foreach_mtable_in_module(mod, precompile_enq_all_specializations_, m);
}
void *native_code = jl_precompile_(m);
void *native_code = jl_precompile_(m, 1);
JL_GC_POP();
return native_code;
}
Expand Down
Loading

0 comments on commit 2ed4acd

Please sign in to comment.