introduce @noinfer macro to tell the compiler to avoid excess inference

This commit introduces a new compiler annotation named `@noinfer`, which
asks the compiler to avoid excess inference.

## Understand `@nospecialize`

In order to discuss `@noinfer`, it would help a lot to understand the
behavior of `@nospecialize`.

Its docstring says simply:
> This is only a hint for the compiler to avoid excess code generation.

More specifically, it works by _suppressing dispatch_ on the complex
runtime types of the annotated arguments, as the example below shows:
```julia
julia> function invokef(f, itr)
           local r = 0
           r += f(itr[1])
           r += f(itr[2])
           r += f(itr[3])
           r
       end;

julia> _isa = isa; # just for the sake of explanation, global variable to prevent inlining
julia> f(a) = _isa(a, Function);
julia> g(@nospecialize a) = _isa(a, Function);
julia> dispatchonly = Any[sin, muladd, nothing]; # untyped container can cause excessive runtime dispatch

julia> @code_typed invokef(f, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = (f)(%2)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = (f)(%5)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = (f)(%8)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any

julia> @code_typed invokef(g, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = invoke f(%2::Any)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = invoke f(%5::Any)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = invoke f(%8::Any)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any
```

The calls of `f` remain `:call` expressions (and are thus dispatched and
compiled at runtime), while the calls of `g` are resolved as `:invoke`
expressions. This is because `@nospecialize` asks the compiler not to
compile `g` for concrete argument types, but only for its precisely
declared argument types. In this way `invokef(g, dispatchonly)` avoids
runtime dispatch and the accompanying JIT compilation (i.e. "excess
code generation").
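
One way to check from the REPL whether a method carries the annotation is to look at the `nospecialize` bitmask stored on its `Method` object, the same field that the `is_nospecialized` helper in this commit reads. The sketch below is only illustrative: `plain`, `nospec`, and `has_nospecialize` are hypothetical names, and `Method.nospecialize` is an internal field (not public API) that may change between Julia versions.

```julia
# Hypothetical stand-ins for `f` and `g` from the example above.
plain(a) = isa(a, Function)
nospec(@nospecialize a) = isa(a, Function)

# `Method.nospecialize` is an internal bitmask: a nonzero value means at least
# one argument of the method was annotated with `@nospecialize`.
has_nospecialize(func) = first(methods(func)).nospecialize != 0

has_nospecialize(plain)   # no argument is annotated
has_nospecialize(nospec)  # the only argument is annotated
```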

The problem is that `@nospecialize` influences dispatch only; it does not
intervene in inference in any way. So "excess inference" is still possible
when the compiler sees considerably complex argument types during
inference:
```julia
julia> withinfernce = tuple(sin, muladd, "foo"); # typed container can cause excessive inference

julia> @time @code_typed invokef(f, withinfernce);
  0.000812 seconds (3.77 k allocations: 217.938 KiB, 94.34% compilation time)

julia> @time @code_typed invokef(g, withinfernce);
  0.000753 seconds (3.77 k allocations: 218.047 KiB, 92.42% compilation time)
```

The purpose of this PR is to provide a more drastic way to avoid excess
compilation, one that suppresses excess inference as well.

## Design

Here are some ideas for implementing the functionality:
1. make `@nospecialize` avoid inference also
2. add noinfer effect when `@nospecialize`d method is annotated as `@noinline` also
3. implement as `@pure`-like boolean annotation to request noinfer effect on top of `@nospecialize`
4. implement as annotation that is orthogonal to `@nospecialize`

After trying 1 through 3, I decided to submit 3 for now, because I think
its interface is ready for experimentation.

### 1. make `@nospecialize` avoid inference also

This is almost the same as what Jameson did in
<vtjnash@8ab7b6b>.
It turned out that this approach performs very badly, because some
`@nospecialize`'d arguments still need inference to
perform reasonably. For example, the following
definition of `getindex(@nospecialize(t::Tuple), i::Int)` would perform
very badly if `@nospecialize` blocked inference, because the succeeding
optimizations would lack useful type information:
<https://github.com/JuliaLang/julia/blob/12d364e8249a07097a233ce7ea2886002459cc50/base/tuple.jl#L29-L30>
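
To illustrate why blocking inference would hurt here, the following sketch asks inference for a return type through a `@nospecialize`'d tuple argument. The helpers `firstelt` and `caller` are hypothetical and not part of this commit. With plain `@nospecialize`, inference is still expected to see the concrete tuple type at the call site and report `Int` rather than `Any`.

```julia
# Hypothetical helpers for illustration only.
firstelt(@nospecialize(t::Tuple)) = t[1]
caller() = firstelt((1, "two"))

# Ask inference for the return type(s) of `caller()`.
# `@nospecialize` alone does not hide the tuple's element types from inference.
rts = Base.return_types(caller, Tuple{})
```

If `@nospecialize` also blocked inference, as in approach 1, `firstelt` would be inferred with `t::Tuple` only, and the result would presumably widen to `Any`.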

### 2. add noinfer effect when `@nospecialize`d method is annotated as `@noinline` also

The important observation is that we often use `@nospecialize` even when
we expect inference to forward type and constant information.
Conversely, we may be able to exploit the fact that we usually don't
expect inference to forward information to a callee when we annotate it
as `@noinline`.
So the idea is to enable inference suppression when a `@nospecialize`'d
method is also annotated as `@noinline`.

It's a reasonable choice, and could be implemented efficiently after
<#41922>.
But it sounds a bit weird to me to associate the noinfer effect with
`@noinline`, and I also think there may be cases where we want to inline
a method while _partially_ avoiding inference, e.g.:
```julia
# the compiler will always infer with `f::Any`
@noinline function twof(@nospecialize(f), n) # do we really want to prevent inlining of this method body?
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```
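
As a quick usage sketch of what `twof` computes, the block below restates the definition (hoisting the repeated name lookup into a local variable, which is my own rearrangement) and calls it with three different functions. It only shows which branch each call takes and says nothing about how inference treats the method.

```julia
# Restatement of `twof` from above, for a self-contained example.
@noinline function twof(@nospecialize(f), n)
    # The type name of a function such as `+` contains the operator character.
    name = string(typeof(f).name.name::Symbol)
    if occursin('+', name)
        2 + n
    elseif occursin('*', name)
        2n
    else
        zero(n)
    end
end

twof(+, 3)    # takes the first branch: 2 + n
twof(*, 3)    # takes the second branch: 2n
twof(sin, 3)  # falls through to zero(n)
```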

### 3. implement as `@pure`-like boolean annotation to request noinfer effect on top of `@nospecialize`

This is what this commit implements. It basically replaces the previous
`@noinline` flag with a newly introduced annotation named `@noinfer`. It is
still associated with `@nospecialize` and only has an effect when used
together with `@nospecialize`, but at least it is no longer tied to
`@noinline`, which should help us reason about the behavior of `@noinfer`
and experiment with its effect more reliably:
```julia
# the compiler will always infer with `f::Any`
Base.@noinfer function twof(@nospecialize(f), n) # the compiler may or may not inline this method
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```

### 4. implement as annotation that is orthogonal to `@nospecialize`

Actually, we can have `@nospecialize` and `@noinfer` separately, which
would allow us to configure compilation strategies in a more fine-grained
way.
```julia
function noinfspec(Base.@noinfer(f), @nospecialize(g))
    ...
end
```

I'm fine with this approach, if initial experiments show `@noinfer` is
useful.
aviatesk committed Aug 19, 2021
1 parent 29c9ea0 commit 0e77d92
Showing 15 changed files with 204 additions and 29 deletions.
9 changes: 8 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
@@ -350,6 +350,9 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
add_remark!(interp, sv, "Refusing to infer into `depwarn`")
return MethodCallResult(Any, false, false, nothing)
end
if is_noinfer(method)
sig = get_nospecialize_sig(method, sig, sparams)
end
topmost = nothing
# Limit argument type tuple growth of functions:
# look through the parents list to see if there's a call to the same method
@@ -584,7 +587,11 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
end
end
force |= allconst
mi = specialize_method(match; preexisting=!force)
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=!force)
else
mi = specialize_method(match; preexisting=!force)
end
if mi === nothing
add_remark!(interp, sv, "[constprop] Failed to specialize")
return nothing
11 changes: 7 additions & 4 deletions base/compiler/ssair/inlining.jl
@@ -785,8 +785,7 @@ end

function analyze_method!(match::MethodMatch, atypes::Vector{Any},
state::InliningState, @nospecialize(stmttyp))
method = match.method
methsig = method.sig
(; method, sparams) = match

# Check that we habe the correct number of arguments
na = Int(method.nargs)
@@ -800,7 +799,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
validate_sparams(sparams) || return nothing

et = state.et

@@ -809,7 +808,11 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# See if there exists a specialization for this method signature
mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance}
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=true)
else
mi = specialize_method(match; preexisting=true)
end
if !isa(mi, MethodInstance)
return compileable_specialization(et, match)
end
26 changes: 25 additions & 1 deletion base/compiler/utilities.jl
@@ -96,6 +96,11 @@ function is_inlineable_constant(@nospecialize(x))
return count_const_size(x) <= MAX_INLINE_CONST_SIZE
end

is_nospecialized(method::Method) = method.nospecialize != 0

is_noinfer(method::Method) = method.noinfer && is_nospecialized(method)
# is_noinfer(method::Method) = is_nospecialized(method) && is_declared_noinline(method)

###########################
# MethodInstance/CodeInfo #
###########################
@@ -144,6 +149,20 @@ function get_compileable_sig(method::Method, @nospecialize(atypes), sparams::Sim
isa(atypes, DataType) || return nothing
mt = ccall(:jl_method_table_for, Any, (Any,), atypes)
mt === nothing && return nothing
atypes′ = ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any),
mt, atypes, sparams, method)
is_compileable = isdispatchtuple(atypes) ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), atypes′, method) != 0
return is_compileable ? atypes′ : nothing
end

function get_nospecialize_sig(method::Method, @nospecialize(atypes), sparams::SimpleVector)
if isa(atypes, UnionAll)
atypes, sparams = normalize_typevars(method, atypes, sparams)
end
isa(atypes, DataType) || return method.sig
mt = ccall(:jl_method_table_for, Any, (Any,), atypes)
mt === nothing && return method.sig
return ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any),
mt, atypes, sparams, method)
end
@@ -188,7 +207,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl
if preexisting
# check cached specializations
# for an existing result stored there
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)
return ccall(:jl_specializations_lookup, Ref{MethodInstance}, (Any, Any), method, atypes)
end
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams)
end
@@ -197,6 +216,11 @@ function specialize_method(match::MethodMatch; kwargs...)
return specialize_method(match.method, match.spec_types, match.sparams; kwargs...)
end

function specialize_method_noinfer((; method, spec_types, sparams)::MethodMatch; kwargs...)
atypes = get_nospecialize_sig(method, spec_types, sparams)
return specialize_method(method, atypes, sparams; kwargs...)
end

# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
9 changes: 8 additions & 1 deletion base/essentials.jl
@@ -56,7 +56,6 @@ end
Applied to a function argument name, hints to the compiler that the method
should not be specialized for different types of that argument,
but instead to use precisely the declared type for each argument.
This is only a hint for avoiding excess code generation.
Can be applied to an argument within a formal argument list,
or in the function body.
When applied to an argument, the macro must wrap the entire argument expression.
@@ -87,6 +86,14 @@ end
f(y) = [x for x in y]
@specialize
```
!!! note
This is only a hint for the compiler to avoid excess code generation by suppressing
dispatches with complex runtime types of the annotated arguments.
Note that `@nospecialize` doesn't intervene into inference, and thus it doesn't
eliminate any latency due to inference that may happen when the compiler sees complex
types that can be known statically. Use [`Base.@noinfer`](@ref) together with
`@nospecialize` also in order to suppress excess inference for such a case.
"""
macro nospecialize(vars...)
if nfields(vars) === 1
33 changes: 27 additions & 6 deletions base/expr.jl
@@ -254,10 +254,12 @@ end
macro noinline() Expr(:meta, :noinline) end

"""
@pure ex
@pure(ex)
Base.@pure function f(args...)
...
end
Base.@pure f(args...) = ...
`@pure` gives the compiler a hint for the definition of a pure function,
`Base.@pure` gives the compiler a hint for the definition of a pure function,
helping for type inference.
This macro is intended for internal compiler use and may be subject to changes.
@@ -267,10 +269,12 @@ macro pure(ex)
end

"""
@aggressive_constprop ex
@aggressive_constprop(ex)
Base.@aggressive_constprop function f(args...)
...
end
Base.@aggressive_constprop f(args...) = ...
`@aggressive_constprop` requests more aggressive interprocedural constant
`Base.@aggressive_constprop` requests more aggressive interprocedural constant
propagation for the annotated function. For a method where the return type
depends on the value of the arguments, this can yield improved inference results
at the cost of additional compile time.
@@ -279,6 +283,23 @@ macro aggressive_constprop(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
end

"""
Base.@noinfer function f(args...)
@nospecialize ...
...
end
Base.@noinfer f(@nospecialize args...) = ...
Tells the compiler to infer `f` only with the precisely declared types of arguments.
It can eliminate a latency problem due to excessive inference that can happen when the
compiler sees a considerable complexity of argument types during inference.
Note that this macro only has effect when used together with [`@nospecialize`](@ref),
and the effect is only applied to `@nospecialize`d arguments.
"""
macro noinfer(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :noinfer) : ex)
end

"""
@propagate_inbounds
2 changes: 2 additions & 0 deletions doc/src/base/base.md
@@ -273,6 +273,8 @@ Base.@inline
Base.@noinline
Base.@nospecialize
Base.@specialize
Base.@noinfer
Base.@aggressive_constprop
Base.gensym
Base.@gensym
var"name"
3 changes: 2 additions & 1 deletion src/ast.c
@@ -58,7 +58,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
jl_sym_t *aggressive_constprop_sym;
jl_sym_t *aggressive_constprop_sym; jl_sym_t *noinfer_sym;
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
@@ -397,6 +397,7 @@ void jl_init_common_symbols(void)
polly_sym = jl_symbol("polly");
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
aggressive_constprop_sym = jl_symbol("aggressive_constprop");
noinfer_sym = jl_symbol("noinfer");
isdefined_sym = jl_symbol("isdefined");
nospecialize_sym = jl_symbol("nospecialize");
specialize_sym = jl_symbol("specialize");
2 changes: 2 additions & 0 deletions src/dump.c
@@ -671,6 +671,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_int8(s->s, m->pure);
write_int8(s->s, m->is_for_opaque_closure);
write_int8(s->s, m->aggressive_constprop);
write_int32(s->s, m->noinfer);
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
jl_serialize_value(s, (jl_value_t*)m->roots);
jl_serialize_value(s, (jl_value_t*)m->ccallable);
@@ -1525,6 +1526,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->pure = read_int8(s->s);
m->is_for_opaque_closure = read_int8(s->s);
m->aggressive_constprop = read_int8(s->s);
m->noinfer = read_int32(s->s);
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
jl_gc_wb(m, m->slot_syms);
m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots);
8 changes: 2 additions & 6 deletions src/gf.c
@@ -2051,10 +2051,8 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
intptr_t nspec = (mt == jl_type_type_mt || mt == jl_nonfunction_mt ? m->nargs + 1 : mt->max_args + 2);
jl_compilation_sig(ti, env, m, nspec, &newparams);
tt = (newparams ? jl_apply_tuple_type(newparams) : ti);
int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple ||
jl_isa_compileable_sig(tt, m);
JL_GC_POP();
return is_compileable ? (jl_value_t*)tt : jl_nothing;
return (jl_value_t*)tt;
}

// compile-time method lookup
@@ -2098,9 +2096,7 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES
}
else {
tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
if (tt != jl_nothing) {
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
}
}
18 changes: 11 additions & 7 deletions src/jltypes.c
@@ -2331,7 +2331,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(19,
jl_perm_symsvec(20,
"code",
"codelocs",
"ssavaluetypes",
@@ -2350,8 +2350,9 @@ void jl_init_types(void) JL_GC_DISABLED
"inlineable",
"propagate_inbounds",
"pure",
"aggressive_constprop"),
jl_svec(19,
"aggressive_constprop",
"noinfer"),
jl_svec(20,
jl_array_any_type,
jl_array_int32_type,
jl_any_type,
@@ -2370,14 +2371,15 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
jl_emptysvec,
0, 1, 19);
0, 1, 20);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(26,
jl_perm_symsvec(27,
"name",
"module",
"file",
@@ -2403,8 +2405,9 @@ void jl_init_types(void) JL_GC_DISABLED
"isva",
"pure",
"is_for_opaque_closure",
"aggressive_constprop"),
jl_svec(26,
"aggressive_constprop",
"noinfer"),
jl_svec(27,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
@@ -2430,6 +2433,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
jl_emptysvec,
0, 1, 10);
2 changes: 2 additions & 0 deletions src/julia.h
@@ -278,6 +278,7 @@ typedef struct _jl_code_info_t {
uint8_t propagate_inbounds;
uint8_t pure;
uint8_t aggressive_constprop;
uint8_t noinfer;
} jl_code_info_t;

// This type describes a single method definition, and stores data
@@ -326,6 +327,7 @@ typedef struct _jl_method_t {
uint8_t pure;
uint8_t is_for_opaque_closure;
uint8_t aggressive_constprop;
uint8_t noinfer;

// hidden fields:
// lock for modifications to the method
2 changes: 1 addition & 1 deletion src/julia_internal.h
@@ -1374,7 +1374,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
extern jl_sym_t *aggressive_constprop_sym;
extern jl_sym_t *aggressive_constprop_sym; extern jl_sym_t *noinfer_sym;
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;
5 changes: 5 additions & 0 deletions src/method.c
@@ -289,6 +289,8 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)aggressive_constprop_sym)
li->aggressive_constprop = 1;
else if (ma == (jl_value_t*)noinfer_sym)
li->noinfer = 1;
else
jl_array_ptr_set(meta, ins++, ma);
}
@@ -380,6 +382,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->pure = 0;
src->edges = jl_nothing;
src->aggressive_constprop = 0;
src->noinfer = 0;
return src;
}

@@ -567,6 +570,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
m->called = called;
m->pure = src->pure;
m->aggressive_constprop = src->aggressive_constprop;
m->noinfer = src->noinfer;
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);

jl_array_t *copy = NULL;
@@ -683,6 +687,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
m->deleted_world = ~(size_t)0;
m->is_for_opaque_closure = 0;
m->aggressive_constprop = 0;
m->noinfer = 0;
JL_MUTEX_INIT(&m->writelock);
return m;
}