Skip to content

Commit

Permalink
AbstractInterpreter: make it easier to overload pure/concrete-eval (#…
Browse files Browse the repository at this point in the history
…44224)

fix #44174.

(cherry picked from commit daa0849)
  • Loading branch information
aviatesk authored and KristofferC committed Feb 19, 2022
1 parent a9d1ad2 commit 5f0d551
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 67 deletions.
144 changes: 80 additions & 64 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
const_results = Union{InferenceResult,Nothing,ConstResult}[]
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
val = pure_eval_call(f, argtypes)
if val !== nothing
# TODO: add some sort of edge(s)
return CallMeta(val, MethodResultPure(info))
end
end
val = pure_eval_call(interp, f, applicable, arginfo, sv)
val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s)

fargs = arginfo.fargs
for i in 1:napplicable
Expand Down Expand Up @@ -619,27 +614,85 @@ struct MethodCallResult
end
end

function pure_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp, sv)) &&
f !== nothing &&
length(applicable) == 1 &&
is_method_pure(applicable[1]::MethodMatch) &&
is_all_const_arg(arginfo)
end

function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator)
method.generator.expand_early || return false
mi = specialize_method(method, sig, sparams)
isa(mi, MethodInstance) || return false
staged = get_staged(mi)
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
return true
end
return method.pure
end
is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams)

function pure_eval_call(interp::AbstractInterpreter,
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
pure_eval_eligible(interp, f, applicable, arginfo, sv) || return nothing
return _pure_eval_call(f, arginfo)
end
function _pure_eval_call(@nospecialize(f), arginfo::ArgInfo)
args = collect_const_args(arginfo)
try
value = Core._apply_pure(f, args)
return Const(value)
catch
return nothing
end
end

function concrete_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp, sv)) &&
f !== nothing &&
result.edge !== nothing &&
is_total_or_error(result.edge_effects) &&
is_all_const_arg(arginfo)
end

function is_all_const_arg((; argtypes)::ArgInfo)
for a in argtypes
if !isa(a, Const) && !isconstType(a) && !issingletontype(a)
return false
end
for i = 2:length(argtypes)
a = widenconditional(argtypes[i])
isa(a, Const) || isconstType(a) || issingletontype(a) || return false
end
return true
end

function concrete_eval_const_proven_total_or_error(interp::AbstractInterpreter,
@nospecialize(f), (; argtypes)::ArgInfo, _::InferenceState)
args = Any[ (a = widenconditional(argtypes[i]);
isa(a, Const) ? a.val :
isconstType(a) ? (a::DataType).parameters[1] :
(a::DataType).instance) for i in 2:length(argtypes) ]
function collect_const_args((; argtypes)::ArgInfo)
return Any[ let a = widenconditional(argtypes[i])
isa(a, Const) ? a.val :
isconstType(a) ? (a::DataType).parameters[1] :
(a::DataType).instance
end for i in 2:length(argtypes) ]
end

function concrete_eval_call(interp::AbstractInterpreter,
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
concrete_eval_eligible(interp, f, result, arginfo, sv) || return nothing
args = collect_const_args(arginfo)
try
value = Core._call_in_world_total(get_world_counter(interp), f, args...)
return Const(value)
catch e
return nothing
if is_inlineable_constant(value) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(Const(value), ConstResult(result.edge, value), EFFECTS_TOTAL)
end
catch
# The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects)
end
return nothing
end

function const_prop_enabled(interp::AbstractInterpreter, sv::InferenceState, match::MethodMatch)
Expand Down Expand Up @@ -671,19 +724,10 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
if !const_prop_enabled(interp, sv, match)
return nothing
end
if f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo)
rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo, sv)
val = concrete_eval_call(interp, f, result, arginfo, sv)
if val !== nothing
add_backedge!(result.edge, sv)
if rt === nothing
# The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConstResult(result.edge), result.edge_effects)
end
if is_inlineable_constant(rt.val) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(rt, ConstResult(result.edge, rt.val), EFFECTS_TOTAL)
end
return val
end
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
Expand Down Expand Up @@ -1218,36 +1262,6 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
return CallMeta(res, retinfo)
end

function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator)
method.generator.expand_early || return false
mi = specialize_method(method, sig, sparams)
isa(mi, MethodInstance) || return false
staged = get_staged(mi)
(staged isa CodeInfo && (staged::CodeInfo).pure) || return false
return true
end
return method.pure
end
is_method_pure(match::MethodMatch) = is_method_pure(match.method, match.spec_types, match.sparams)

function pure_eval_call(@nospecialize(f), argtypes::Vector{Any})
for i = 2:length(argtypes)
a = widenconditional(argtypes[i])
if !(isa(a, Const) || isconstType(a))
return nothing
end
end
args = Any[ (a = widenconditional(argtypes[i]);
isa(a, Const) ? a.val : (a::DataType).parameters[1]) for i in 2:length(argtypes) ]
try
value = Core._apply_pure(f, args)
return Const(value)
catch
return nothing
end
end

function argtype_by_index(argtypes::Vector{Any}, i::Int)
n = length(argtypes)
na = argtypes[n]
Expand Down Expand Up @@ -1586,8 +1600,10 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
elseif max_methods > 1 && istopfunction(f, :copyto!)
max_methods = 1
elseif la == 3 && istopfunction(f, :typejoin)
val = pure_eval_call(f, argtypes)
return CallMeta(val === nothing ? Type : val, MethodResultPure())
if is_all_const_arg(arginfo)
val = _pure_eval_call(f, arginfo)
return CallMeta(val === nothing ? Type : val, MethodResultPure())
end
end
atype = argtypes_to_type(argtypes)
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
Expand Down
11 changes: 8 additions & 3 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
end
if ms === false
return missing
if ms === false
return missing
end
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end
Expand Down Expand Up @@ -123,3 +123,8 @@ end

# This query is not cached
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)

0 comments on commit 5f0d551

Please sign in to comment.