Skip to content

Commit

Permalink
make method table definitions unsorted
Browse files Browse the repository at this point in the history
Rather than stopping at the first subtyping match, this scans the entire
table of methods for intersections for every lookup. It then filters the
results to remove ambiguities and sort for specificity order.

Also exports the has_ambiguity computation from ml-matches, as it's
cheap here to provide (already computed), and required by Core.Compiler
(and also more correct than what the optimizer had been doing, though
only on extreme hypothetical edge cases).
  • Loading branch information
vtjnash committed Jul 1, 2020
1 parent ad94873 commit 4c75c70
Show file tree
Hide file tree
Showing 22 changed files with 530 additions and 638 deletions.
11 changes: 6 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ const _REF_NAME = Ref.body.name
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])

function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt}}, max_methods::Int, world::UInt)
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt)
box = Core.Box(atype)
return get!(cache, atype) do
_min_val = UInt[typemin(UInt)]
_max_val = UInt[typemax(UInt)]
ms = _methods_by_ftype(box.contents, max_methods, world, _min_val, _max_val)
return ms, _min_val[1], _max_val[1]
_ambig = Int32[0]
ms = _methods_by_ftype(box.contents, max_methods, world, false, _min_val, _max_val, _ambig)
return ms, _min_val[1], _max_val[1], _ambig[1] != 0
end
end

function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt})
ms, minvalid, maxvalid = matching_methods(atype, cache, max_methods, world)
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt})
ms, minvalid, maxvalid, ambig = matching_methods(atype, cache, max_methods, world)
min_valid[1] = max(min_valid[1], minvalid)
max_valid[1] = min(max_valid[1], maxvalid)
return ms
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mutable struct InferenceState

# cached results of calling `_methods_by_ftype`, including `min_valid` and
# `max_valid`, to be used in inlining
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt}}
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
Expand Down Expand Up @@ -111,7 +111,7 @@ mutable struct InferenceState
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false, false,
IdDict{Any, Tuple{Any, UInt, UInt}}(),
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(),
interp)
result.result = frame
cached && push!(get_inference_cache(interp), result)
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mutable struct OptimizationState
const_api::Bool
# cached results of calling `_methods_by_ftype` from inference, including
# `min_valid` and `max_valid`
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt}}
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
# TODO: This will be eliminated once optimization no longer needs to do method lookups
interp::AbstractInterpreter
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
Expand Down Expand Up @@ -64,7 +64,7 @@ mutable struct OptimizationState
src, inmodule, nargs,
get_world_counter(), UInt(1), get_world_counter(),
sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt}}(), interp)
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp)
end
end

Expand Down
27 changes: 8 additions & 19 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
@nospecialize

struct InvokeData
entry::Core.TypeMapEntry
entry::Method
types0
min_valid::UInt
max_valid::UInt
Expand Down Expand Up @@ -683,17 +683,6 @@ function analyze_method!(idx::Int, sig::Signature, @nospecialize(metharg), meths
return nothing
end

# Check if we intersect any of this method's ambiguities
# TODO: We could split out the ambiguous case as another "union split" case.
# For now, we just reject the method
if method.ambig !== nothing && invoke_data === nothing
for entry::Core.TypeMapEntry in method.ambig
if typeintersect(sig.atype, entry.sig) !== Bottom
return nothing
end
end
end

# Bail out if any static parameters are left as TypeVar
ok = true
for i = 1:length(methsp)
Expand Down Expand Up @@ -943,12 +932,11 @@ is_builtin(s::Signature) =
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::InvokeData, sv::OptimizationState, todo::Vector{Any})
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]
method = invoke_data.entry.func
method = invoke_data.entry
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
sig.atype, method.sig)::SimpleVector
sig.atype, method.sig)::SimpleVector
methsp = methsp::SimpleVector
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
calltype)
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data, calltype)
handle_single_case!(ir, stmt, idx, result, true, todo)
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
return nothing
Expand Down Expand Up @@ -1046,10 +1034,11 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# in the case that the cache is nonempty, so it should be unchanged
# The max number of methods should be the same as in inference most
# of the time, and should not affect correctness otherwise.
(meth, min_valid, max_valid) =
(meth, min_valid, max_valid, ambig) =
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world)
if meth === false
if meth === false || ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
too_many = true
break
elseif length(meth) == 0
Expand Down Expand Up @@ -1159,7 +1148,7 @@ function compute_invoke_data(@nospecialize(atypes), world::UInt)
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt),
invoke_types, world) # XXX: min_valid, max_valid
invoke_entry === nothing && return nothing
invoke_data = InvokeData(invoke_entry::Core.TypeMapEntry, invoke_types, min_valid[1], max_valid[1])
invoke_data = InvokeData(invoke_entry::Method, invoke_types, min_valid[1], max_valid[1])
atype0 = atypes[2]
atypes = atypes[4:end]
pushfirst!(atypes, atype0)
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

@inline isexpr(@nospecialize(stmt), head::Symbol) = isa(stmt, Expr) && stmt.head === head
@eval Core.UpsilonNode() = $(Expr(:new, Core.UpsilonNode))
Core.PhiNode() = Core.PhiNode(Any[], Any[])

"""
Expand Down Expand Up @@ -313,6 +312,8 @@ end

struct UndefToken
end
const undef_token = UndefToken()


function getindex(x::UseRef)
stmt = x.stmt
Expand Down
4 changes: 0 additions & 4 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ function make_ssa!(ci::CodeInfo, code::Vector{Any}, idx, slot, @nospecialize(typ
idx
end

struct UndefToken
end
const undef_token = UndefToken()

function new_to_regular(@nospecialize(stmt), new_offset::Int)
if isa(stmt, NewSSAValue)
return SSAValue(stmt.id + new_offset)
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1160,14 +1160,14 @@ function invoke_tfunc(interp::AbstractInterpreter, @nospecialize(ft), @nospecial
isdispatchelem(ft) || return Any # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
argtype = Tuple{ft, argtype.parameters...}
entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, get_world_counter(interp))
if entry === nothing
meth = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, get_world_counter(interp))
if meth === nothing
return Any
end
# XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
meth = entry.func
meth = meth::Method
(ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), argtype, meth.sig)::SimpleVector
rt, edge = typeinf_edge(interp, meth::Method, ti, env, sv)
rt, edge = typeinf_edge(interp, meth, ti, env, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
return rt
end
Expand Down
33 changes: 21 additions & 12 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -835,10 +835,10 @@ function _methods(@nospecialize(f), @nospecialize(t), lim::Int, world::UInt)
end

function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, UInt[typemin(UInt)], UInt[typemax(UInt)])
return _methods_by_ftype(t, lim, world, false, UInt[typemin(UInt)], UInt[typemax(UInt)], Cint[0])
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), t, lim, 0, world, min, max)
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end

# high-level, more convenient method lookup functions
Expand Down Expand Up @@ -895,7 +895,9 @@ function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
world = typemax(UInt)
min = UInt[typemin(UInt)]
max = UInt[typemax(UInt)]
ms = ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), tt, -1, 1, world, min, max)::Array{Any,1}
has_ambig = Int32[0]
ms = _methods_by_ftype(tt, -1, world, true, min, max, has_ambig)
ms === false && return false
return MethodList(Method[m[3] for m in ms], typeof(f).name.mt)
end

Expand Down Expand Up @@ -1177,7 +1179,7 @@ function which(@nospecialize(tt::Type))
if m === nothing
error("no unique matching method found for the specified argument types")
end
return m.func::Method
return m::Method
end

"""
Expand Down Expand Up @@ -1294,11 +1296,10 @@ end
"""
Base.isambiguous(m1, m2; ambiguous_bottom=false) -> Bool
Determine whether two methods `m1` and `m2` (typically of the same
function) are ambiguous. This test is performed in the context of
other methods of the same function; in isolation, `m1` and `m2` might
be ambiguous, but if a third method resolving the ambiguity has been
defined, this returns `false`.
Determine whether two methods `m1` and `m2` may be ambiguous for some call
signature. This test is performed in the context of other methods of the same
function; in isolation, `m1` and `m2` might be ambiguous, but if a third method
resolving the ambiguity has been defined, this returns `false`.
For parametric types, the `ambiguous_bottom` keyword argument controls whether
`Union{}` counts as an ambiguous intersection of type parameters – when `true`,
Expand All @@ -1325,15 +1326,23 @@ false
```
"""
function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false)
# TODO: eagerly returning `morespecific` is wrong, and fails to consider
# the possibility of an ambiguity caused by a third method:
# see the precise algorithm in ml_matches for a more correct computation
if m1 === m2 || morespecific(m1.sig, m2.sig) || morespecific(m2.sig, m1.sig)
return false
end
ti = typeintersect(m1.sig, m2.sig)
(ti <: m1.sig && ti <: m2.sig) || return false # XXX: completely wrong, obviously
ti === Bottom && return false
if !ambiguous_bottom
has_bottom_parameter(ti) && return false
end
ml = _methods_by_ftype(ti, -1, typemax(UInt))
isempty(ml) && return true
for m in ml
if ti <: m[3].sig
m === m1 && continue
m === m2 && continue
if ti <: m[3].sig && morespecific(m[3].sig, m1.sig) && morespecific(m[3].sig, m2.sig)
return false
end
end
Expand Down
2 changes: 1 addition & 1 deletion doc/src/manual/methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ julia> g(2, 3.0)
julia> g(2.0, 3.0)
ERROR: MethodError: g(::Float64, ::Float64) is ambiguous. Candidates:
g(x::Float64, y) in Main at none:1
g(x, y::Float64) in Main at none:1
g(x::Float64, y) in Main at none:1
Possible fix, define
g(::Float64, ::Float64)
```
Expand Down
23 changes: 5 additions & 18 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_uint8(s->s, TAG_METHOD);
jl_method_t *m = (jl_method_t*)v;
int internal = 1;
int external_mt = 0;
internal = module_in_worklist(m->module);
if (!internal) {
// flag this in the backref table as special
Expand All @@ -581,22 +580,11 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_uint8(s->s, internal);
if (!internal)
return;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
assert((jl_value_t*)mt != jl_nothing);
external_mt = !module_in_worklist(mt->module);
jl_serialize_value(s, m->specializations);
jl_serialize_value(s, m->speckeyset);
jl_serialize_value(s, (jl_value_t*)m->name);
jl_serialize_value(s, (jl_value_t*)m->file);
write_int32(s->s, m->line);
if (external_mt) {
jl_serialize_value(s, jl_nothing);
jl_serialize_value(s, jl_nothing);
}
else {
jl_serialize_value(s, (jl_value_t*)m->ambig);
jl_serialize_value(s, (jl_value_t*)m->resorted);
}
write_int32(s->s, m->called);
write_int32(s->s, m->nargs);
write_int32(s->s, m->nospecialize);
Expand Down Expand Up @@ -995,7 +983,8 @@ static void jl_collect_backedges(jl_array_t *s, jl_array_t *t)
}
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid);
int ambig = 0;
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false) {
valid = 0;
break;
Expand Down Expand Up @@ -1405,10 +1394,6 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->line = read_int32(s->s);
m->primary_world = jl_world_counter;
m->deleted_world = ~(size_t)0;
m->ambig = jl_deserialize_value(s, (jl_value_t**)&m->ambig);
jl_gc_wb(m, m->ambig);
m->resorted = jl_deserialize_value(s, (jl_value_t**)&m->resorted);
jl_gc_wb(m, m->resorted);
m->called = read_int32(s->s);
m->nargs = read_int32(s->s);
m->nospecialize = read_int32(s->s);
Expand Down Expand Up @@ -1858,7 +1843,9 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
int valid = 1;
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid);
int ambig = 0;
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
valid = 0;
}
Expand Down
Loading

0 comments on commit 4c75c70

Please sign in to comment.