Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement "engine" for managing inference/codegen #54816

Merged
merged 3 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 191 additions & 159 deletions base/compiler/typeinfer.jl

Large diffs are not rendered by default.

37 changes: 8 additions & 29 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mutable struct InferenceResult
effects::Effects # if optimization is finished
analysis_results::AnalysisResults # AnalysisResults with e.g. result::ArgEscapeCache if optimized, otherwise NULL_ANALYSIS_RESULTS
is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively
ci::CodeInstance # CodeInstance if this result has been added to the cache
ci::CodeInstance # CodeInstance if this result may be added to the cache
function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector})
return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing,
WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false)
Expand Down Expand Up @@ -402,36 +402,15 @@ get_inference_world(interp::NativeInterpreter) = interp.world
get_inference_cache(interp::NativeInterpreter) = interp.inf_cache
cache_owner(interp::NativeInterpreter) = nothing

"""
already_inferred_quick_test(::AbstractInterpreter, ::MethodInstance)
engine_reserve(interp::AbstractInterpreter, mi::MethodInstance) = engine_reserve(mi, cache_owner(interp))
engine_reserve(mi::MethodInstance, @nospecialize owner) = ccall(:jl_engine_reserve, Any, (Any, Any), mi, owner)::CodeInstance
# engine_fulfill(::AbstractInterpreter, ci::CodeInstance, src::CodeInfo) = ccall(:jl_engine_fulfill, Cvoid, (Any, Any), ci, src) # currently the same as engine_reject, so just use that one
engine_reject(::AbstractInterpreter, ci::CodeInstance) = ccall(:jl_engine_fulfill, Cvoid, (Any, Ptr{Cvoid}), ci, C_NULL)

For the `NativeInterpreter`, we don't need to do an actual cache query to know if something
was already inferred. If we reach this point, but the inference flag has been turned off,
then it's in the cache. This is purely for a performance optimization.
"""
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) = !mi.inInference
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) = false

"""
lock_mi_inference(::AbstractInterpreter, mi::MethodInstance)

Hint that `mi` is in inference to help accelerate bootstrapping.
This is particularly used by `NativeInterpreter` and helps us limit the amount of wasted
work we might do when inference is working on initially inferring itself by letting us
detect when inference is already in progress and not running a second copy on it.
This creates a data-race, but the entry point into this code from C (`jl_type_infer`)
already includes detection and restriction on recursion, so it is hopefully mostly a
benign problem, since it should really only happen during the first phase of bootstrapping
that we encounter this flag.
"""
lock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (mi.inInference = true; nothing)
lock_mi_inference(::AbstractInterpreter, ::MethodInstance) = return

"""
See `lock_mi_inference`.
"""
unlock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (mi.inInference = false; nothing)
unlock_mi_inference(::AbstractInterpreter, ::MethodInstance) = return
function already_inferred_quick_test end
function lock_mi_inference end
function unlock_mi_inference end

"""
add_remark!(::AbstractInterpreter, sv::InferenceState, remark)
Expand Down
32 changes: 2 additions & 30 deletions doc/src/devdocs/locks.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,8 @@ The following is a level 5 lock
The following are a level 6 lock, which can only recurse to acquire locks at lower levels:

> * codegen
> * jl_modules_mutex
The following is an almost root lock (level end-1), meaning only the root look may be held when
trying to acquire it:

> * typeinf
>
> > this one is perhaps one of the most tricky ones, since type-inference can be invoked from many
> > points
> >
> > currently the lock is merged with the codegen lock, since they call each other recursively
The following lock synchronizes IO operation. Be aware that doing any I/O (for example,
printing warning messages or debug information) while holding any other lock listed above
may result in pernicious and hard-to-find deadlocks. BE VERY CAREFUL!
Expand Down Expand Up @@ -149,34 +138,17 @@ Module serializer : toplevel lock

JIT & type-inference : codegen lock

MethodInstance/CodeInstance updates : Method->writelock, codegen lock
MethodInstance/CodeInstance updates : Method->writelock

> * These are set at construction and immutable:
> * specTypes
> * sparam_vals
> * def
> * owner
> * These are set by `jl_type_infer` (while holding codegen lock):
> * cache
> * rettype
> * inferred
* valid ages

> * `inInference` flag:
> * optimization to quickly avoid recurring into `jl_type_infer` while it is already running
> * actual state (of setting `inferred`, then `fptr`) is protected by codegen lock
> * Function pointers:
> * these transition once, from `NULL` to a value, while the codegen lock is held
>
> * Code-generator cache (the contents of `functionObjectsDecls`):
> * these can transition multiple times, but only while the codegen lock is held
> * it is valid to use old version of this, or block for new versions of this, so races are benign,
> as long as the code is careful not to reference other data in the method instance (such as `rettype`)
> and assume it is coordinated, unless also holding the codegen lock
> * these transition once, from `NULL` to a value, which is coordinated internal to the JIT
>
LLVMContext : codegen lock
Method : Method->writelock

Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ SRCS := \
simplevector runtime_intrinsics precompile jloptions mtarraylist \
threading scheduler stackwalk gc gc-debug gc-pages gc-stacks gc-alloc-profiler gc-page-profiler method \
jlapi signal-handling safepoint timing subtype rtutils gc-heap-snapshot \
crc32c APInt-C processor ircode opaque_closure codegen-stubs coverage runtime_ccall
crc32c APInt-C processor ircode opaque_closure codegen-stubs coverage runtime_ccall engine

RT_LLVMLINK :=
CG_LLVMLINK :=
Expand Down
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ jl_code_instance_t *jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_
jl_error("Refusing to automatically run type inference with custom cache lookup.");
}
else {
codeinst = jl_type_infer(mi, world, 0, SOURCE_MODE_ABI);
codeinst = jl_type_infer(mi, world, SOURCE_MODE_ABI);
/* Even if this codeinst is ordinarily not cacheable, we need to force
* it into the cache here, since it was explicitly requested and is
* otherwise not reachable from anywhere in the system image.
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9895,7 +9895,7 @@ void jl_compile_workqueue(
if (policy != CompilationPolicy::Default &&
jl_atomic_load_relaxed(&codeinst->inferred) == jl_nothing) {
// Codegen lock is held, so SOURCE_MODE_FORCE_SOURCE_UNCACHED is not required
codeinst = jl_type_infer(codeinst->def, jl_atomic_load_relaxed(&codeinst->max_world), 0, SOURCE_MODE_FORCE_SOURCE);
codeinst = jl_type_infer(codeinst->def, jl_atomic_load_relaxed(&codeinst->max_world), SOURCE_MODE_FORCE_SOURCE);
}
if (codeinst) {
orc::ThreadSafeModule result_m =
Expand Down
158 changes: 158 additions & 0 deletions src/engine.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#include <mutex>
#include <condition_variable>
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/SmallVector.h>
#include "julia.h"
#include "julia_internal.h"
#include "julia_assert.h"

using namespace llvm;

struct ReservationInfo {
int16_t tid = 0;
jl_code_instance_t *ci = nullptr;
};

struct InferKey {
jl_method_instance_t *mi = nullptr;
jl_value_t *owner = nullptr;
};

template<> struct llvm::DenseMapInfo<InferKey> {
using FirstInfo = DenseMapInfo<jl_method_instance_t*>;
using SecondInfo = DenseMapInfo<jl_value_t*>;

static inline InferKey getEmptyKey() {
return InferKey{FirstInfo::getEmptyKey(),
SecondInfo::getEmptyKey()};
}

static inline InferKey getTombstoneKey() {
return InferKey{FirstInfo::getTombstoneKey(),
SecondInfo::getTombstoneKey()};
}

static unsigned getHashValue(const InferKey& PairVal) {
return detail::combineHashValue(FirstInfo::getHashValue(PairVal.mi),
SecondInfo::getHashValue(PairVal.owner));
}

static bool isEqual(const InferKey &LHS, const InferKey &RHS) {
return LHS.mi == RHS.mi && LHS.owner == RHS.owner;
}
};

static std::mutex engine_lock;
static std::condition_variable engine_wait;
// map from MethodInstance to threadid that owns it currently for inference
static DenseMap<InferKey, ReservationInfo> Reservations;
// vector of which threads are blocked and which lease they need
static SmallVector<InferKey, 0> Awaiting; // (this could be merged into ptls also)


#ifdef __cplusplus
extern "C" {
#endif

jl_code_instance_t *jl_engine_reserve(jl_method_instance_t *m, jl_value_t *owner)
{
jl_task_t *ct = jl_current_task;
ct->ptls->engine_nqueued++; // disables finalizers until inference is finished on this method graph
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
jl_code_instance_t *ci = jl_new_codeinst_uninit(m, owner); // allocate a placeholder
JL_GC_PUSH1(&ci);
int8_t gc_state = jl_gc_safe_enter(ct->ptls);
InferKey key = {m, owner};
std::unique_lock<std::mutex> lock(engine_lock);
auto tid = jl_atomic_load_relaxed(&ct->tid);
if ((signed)Awaiting.size() < tid + 1)
Awaiting.resize(tid + 1);
while (1) {
auto record = Reservations.find(key);
if (record == Reservations.end()) {
Reservations[key] = ReservationInfo{tid, ci};
lock.unlock();
jl_gc_safe_leave(ct->ptls, gc_state); // contains jl_gc_safepoint
JL_GC_POP();
return ci;
}
// before waiting, need to run deadlock/cycle detection
// there is a cycle if the thread holding our lease is blocked
// and waiting for (transitively) any lease that is held by this thread
auto wait_tid = record->second.tid;
while (1) {
if (wait_tid == tid) {
lock.unlock();
jl_gc_safe_leave(ct->ptls, gc_state); // contains jl_gc_safepoint
JL_GC_POP();
ct->ptls->engine_nqueued--;
return ci; // break the cycle
}
if ((signed)Awaiting.size() <= wait_tid)
break; // no cycle, since it is running (and this should be unreachable)
auto key2 = Awaiting[wait_tid];
if (key2.mi == nullptr)
break; // no cycle, since it is running
auto record2 = Reservations.find(key2);
if (record2 == Reservations.end())
break; // no cycle, since it is about to resume
assert(wait_tid != record2->second.tid);
wait_tid = record2->second.tid;
}
Awaiting[tid] = key;
engine_wait.wait(lock);
gbaraldi marked this conversation as resolved.
Show resolved Hide resolved
Awaiting[tid] = InferKey{};
}
}

int jl_engine_hasreserved(jl_method_instance_t *m, jl_value_t *owner)
{
jl_task_t *ct = jl_current_task;
InferKey key = {m, owner};
std::unique_lock<std::mutex> lock(engine_lock);
auto record = Reservations.find(key);
return record != Reservations.end() && record->second.tid == jl_atomic_load_relaxed(&ct->tid);
}

STATIC_INLINE int gc_marked(uintptr_t bits) JL_NOTSAFEPOINT
{
return (bits & GC_MARKED) != 0;
}

void jl_engine_sweep(jl_ptls_t *gc_all_tls_states)
{
std::unique_lock<std::mutex> lock(engine_lock);
bool any = false;
for (auto I = Reservations.begin(); I != Reservations.end(); ++I) {
jl_code_instance_t *ci = I->second.ci;
if (!gc_marked(jl_astaggedvalue(ci)->bits.gc)) {
auto tid = I->second.tid;
Reservations.erase(I);
jl_ptls_t ptls2 = gc_all_tls_states[tid];
ptls2->engine_nqueued--;
any = true;
}
}
if (any)
engine_wait.notify_all();
}

void jl_engine_fulfill(jl_code_instance_t *ci, jl_code_info_t *src)
{
jl_task_t *ct = jl_current_task;
std::unique_lock<std::mutex> lock(engine_lock);
auto record = Reservations.find(InferKey{ci->def, ci->owner});
if (record == Reservations.end() || record->second.ci != ci)
return;
assert(jl_atomic_load_relaxed(&ct->tid) == record->second.tid);
ct->ptls->engine_nqueued--; // re-enables finalizers, but doesn't immediately try to run them
Reservations.erase(record);
engine_wait.notify_all();
}


#ifdef __cplusplus
}
#endif
10 changes: 5 additions & 5 deletions src/gc-stacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ static void *malloc_stack(size_t bufsz) JL_NOTSAFEPOINT
}


static void free_stack(void *stkbuf, size_t bufsz)
static void free_stack(void *stkbuf, size_t bufsz) JL_NOTSAFEPOINT
{
#ifdef JL_USE_GUARD_PAGE
size_t guard_size = LLT_ALIGN(jl_guard_size, jl_page_size);
Expand Down Expand Up @@ -111,7 +111,7 @@ static void *malloc_stack(size_t bufsz) JL_NOTSAFEPOINT
}
# endif

static void free_stack(void *stkbuf, size_t bufsz)
static void free_stack(void *stkbuf, size_t bufsz) JL_NOTSAFEPOINT
{
#ifdef JL_USE_GUARD_PAGE
size_t guard_size = LLT_ALIGN(jl_guard_size, jl_page_size);
Expand All @@ -124,7 +124,7 @@ static void free_stack(void *stkbuf, size_t bufsz)
}
#endif

JL_DLLEXPORT uint32_t jl_get_num_stack_mappings(void)
JL_DLLEXPORT uint32_t jl_get_num_stack_mappings(void) JL_NOTSAFEPOINT
{
return jl_atomic_load_relaxed(&num_stack_mappings);
}
Expand Down Expand Up @@ -159,7 +159,7 @@ static unsigned select_pool(size_t nb) JL_NOTSAFEPOINT
}


static void _jl_free_stack(jl_ptls_t ptls, void *stkbuf, size_t bufsz)
static void _jl_free_stack(jl_ptls_t ptls, void *stkbuf, size_t bufsz) JL_NOTSAFEPOINT
{
#ifdef _COMPILER_ASAN_ENABLED_
__asan_unpoison_stack_memory((uintptr_t)stkbuf, bufsz);
Expand Down Expand Up @@ -238,7 +238,7 @@ JL_DLLEXPORT void *jl_malloc_stack(size_t *bufsz, jl_task_t *owner) JL_NOTSAFEPO
return stk;
}

void sweep_stack_pools(void)
void sweep_stack_pools(void) JL_NOTSAFEPOINT
{
// Stack sweeping algorithm:
// // deallocate stacks if we have too many sitting around unused
Expand Down
13 changes: 7 additions & 6 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ JL_DLLEXPORT void jl_gc_run_pending_finalizers(jl_task_t *ct)
if (ct == NULL)
ct = jl_current_task;
jl_ptls_t ptls = ct->ptls;
if (!ptls->in_finalizer && ptls->locks.len == 0 && ptls->finalizers_inhibited == 0) {
if (!ptls->in_finalizer && ptls->locks.len == 0 && ptls->finalizers_inhibited == 0 && ptls->engine_nqueued == 0) {
run_finalizers(ct, 0);
}
}
Expand Down Expand Up @@ -620,7 +620,7 @@ JL_DLLEXPORT void jl_finalize_th(jl_task_t *ct, jl_value_t *o)
}

// explicitly scheduled objects for the sweepfunc callback
static void gc_sweep_foreign_objs_in_list(arraylist_t *objs)
static void gc_sweep_foreign_objs_in_list(arraylist_t *objs) JL_NOTSAFEPOINT
{
size_t p = 0;
for (size_t i = 0; i < objs->len; i++) {
Expand All @@ -638,7 +638,7 @@ static void gc_sweep_foreign_objs_in_list(arraylist_t *objs)
objs->len = p;
}

static void gc_sweep_foreign_objs(void)
static void gc_sweep_foreign_objs(void) JL_NOTSAFEPOINT
{
assert(gc_n_threads);
for (int i = 0; i < gc_n_threads; i++) {
Expand Down Expand Up @@ -1584,8 +1584,11 @@ STATIC_INLINE void gc_sweep_pool_page(gc_page_profiler_serializer_t *s, jl_gc_pa
// sweep over all memory that is being used and not in a pool
static void gc_sweep_other(jl_ptls_t ptls, int sweep_full) JL_NOTSAFEPOINT
{
sweep_stack_pools();
gc_sweep_foreign_objs();
sweep_malloced_memory();
sweep_big(ptls, sweep_full);
jl_engine_sweep(gc_all_tls_states);
}

static void gc_pool_sync_nfree(jl_gc_pagemeta_t *pg, jl_taggedvalue_t *last) JL_NOTSAFEPOINT
Expand Down Expand Up @@ -3662,8 +3665,6 @@ static int _jl_gc_collect(jl_ptls_t ptls, jl_gc_collection_t collection)
#endif
current_sweep_full = sweep_full;
sweep_weak_refs();
sweep_stack_pools();
gc_sweep_foreign_objs();
gc_sweep_other(ptls, sweep_full);
gc_scrub();
gc_verify_tags();
Expand Down Expand Up @@ -3953,7 +3954,7 @@ JL_DLLEXPORT void jl_gc_collect(jl_gc_collection_t collection)
// Only disable finalizers on current thread
// Doing this on all threads is racy (it's impossible to check
// or wait for finalizers on other threads without dead lock).
if (!ptls->finalizers_inhibited && ptls->locks.len == 0) {
if (!ptls->finalizers_inhibited && ptls->locks.len == 0 && ptls->engine_nqueued == 0) {
JL_TIMING(GC, GC_Finalizers);
run_finalizers(ct, 0);
}
Expand Down
Loading