Skip to content

Commit

Permalink
Propagate iteration info to optimizer (#36684)
Browse files Browse the repository at this point in the history
This supersedes #36169. Rather than re-implementing the iteration
analysis as done there, this uses the new stmtinfo infrastrcture
to propagate all the analysis done during inference all the way
to inlining. As a result, it applies not only to splats of
singletons, but also to splats of any other short iterable
that inference can analyze. E.g.:

```
f(x) = (x...,)
@code_typed f(1=>2)
@benchmark f(1=>2)
```

Before:
```
julia> @code_typed f(1=>2)
CodeInfo(
1 ─ %1 = Core._apply_iterate(Base.iterate, Core.tuple, x)::Tuple{Int64,Int64}
└──      return %1
) => Tuple{Int64,Int64}

julia> @benchmark f(1=>2)
BenchmarkTools.Trial:
  memory estimate:  96 bytes
  allocs estimate:  3
  --------------
  minimum time:     242.659 ns (0.00% GC)
  median time:      246.904 ns (0.00% GC)
  mean time:        255.390 ns (1.08% GC)
  maximum time:     4.415 μs (93.94% GC)
  --------------
  samples:          10000
  evals/sample:     405
```

After:
```
julia> @code_typed f(1=>2)
CodeInfo(
1 ─ %1 = Base.getfield(x, 1)::Int64
│   %2 = Base.getfield(x, 2)::Int64
│   %3 = Core.tuple(%1, %2)::Tuple{Int64,Int64}
└──      return %3
) => Tuple{Int64,Int64}

julia> @benchmark f(1=>2)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.701 ns (0.00% GC)
  median time:      1.925 ns (0.00% GC)
  mean time:        1.904 ns (0.00% GC)
  maximum time:     6.941 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000
```

I also implemented the TODO, I had left in #36169 to inline
the iterate calls themselves, which gives another 3x
improvement over the solution in that PR:

```
julia> @code_typed f(1)
CodeInfo(
1 ─ %1 = Core.tuple(x)::Tuple{Int64}
└──      return %1
) => Tuple{Int64}

julia> @benchmark f(1)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.696 ns (0.00% GC)
  median time:      1.699 ns (0.00% GC)
  mean time:        1.702 ns (0.00% GC)
  maximum time:     5.389 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000
```

Fixes #36087
Fixes #29114
  • Loading branch information
Keno authored Jul 18, 2020
1 parent 6a4793a commit 435bf88
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 179 deletions.
78 changes: 49 additions & 29 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(fullmatch, thisfullmatch)
end
end
info = UnionSplitInfo(splitsigs, infos)
info = UnionSplitInfo(infos)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
Expand Down Expand Up @@ -505,13 +505,13 @@ end
# returns an array of types
function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
return typ.fields
return typ.fields, nothing
end

if isa(typ, Const)
val = typ.val
if isa(val, SimpleVector) || isa(val, Tuple)
return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here!
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
end
end

Expand All @@ -529,27 +529,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(tti, Union)
utis = uniontypes(tti)
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
result = Any[rewrap_unionall(p, tti0) for p in utis[1].parameters]
for t in utis[2:end]
if length(t.parameters) != length(result)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
for j in 1:length(t.parameters)
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
end
end
return result
return result, nothing
elseif tti0 <: Tuple
if isa(tti0, DataType)
if isvatuple(tti0) && length(tti0.parameters) == 1
return Any[Vararg{unwrapva(tti0.parameters[1])}]
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
else
return Any[ p for p in tti0.parameters ]
return Any[ p for p in tti0.parameters ], nothing
end
elseif !isa(tti, DataType)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
else
len = length(tti.parameters)
last = tti.parameters[len]
Expand All @@ -558,12 +558,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if va
elts[len] = Vararg{elts[len]}
end
return elts
return elts, nothing
end
elseif tti0 === SimpleVector || tti0 === Any
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
elseif tti0 <: Array
return Any[Vararg{eltype(tti0)}]
return Any[Vararg{eltype(tti0)}], nothing
else
return abstract_iteration(interp, itft, typ, vtypes, sv)
end
Expand All @@ -572,30 +572,34 @@ end
# simulate iteration protocol on container type up to fixpoint
function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
if itft === nothing
iteratef = getfield(Main.Base, :iterate)
itft = Const(iteratef)
elseif isa(itft, Const)
iteratef = itft.val
else
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
@assert !isvarargtype(itertype)
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv).rt
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv)
stateordonet = call.rt
info = call.info
# Return Bottom if this is not an iterator.
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
stateordonet === Bottom && return Any[Bottom]
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, info)])
valtype = statetype = Bottom
ret = Any[]
calls = CallMeta[call]

# Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite
# length iterators, or interesting prefix
while true
stateordonet_widened = widenconst(stateordonet)
if stateordonet_widened === Nothing
return ret
return ret, AbstractIterationInfo(calls)
end
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
break
Expand All @@ -607,12 +611,14 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if nstatetype statetype
return Any[Bottom]
return Any[Bottom], nothing
end
valtype = getfield_tfunc(stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv).rt
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = call.rt
push!(calls, call)
end
# From here on, we start asking for results on the widened types, rather than
# the precise (potentially const) state type
Expand All @@ -629,15 +635,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
if typeintersect(stateordonet, Nothing) === Union{}
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
return Any[Bottom]
return Any[Bottom], nothing
end
break
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
end
push!(ret, Vararg{valtype})
return ret
return ret, nothing
end

# do apply(af, fargs...), where af is a function value
Expand All @@ -656,13 +662,15 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
nargs = length(aargtypes)
splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
ctypes = Any[Any[aft]]
infos = [Union{Nothing, AbstractIterationInfo}[]]
for i = 1:nargs
ctypes´ = []
infos′ = []
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
if !isvarargtype(ti)
cti = precise_container_type(interp, itft, ti, vtypes, sv)
cti, info = precise_container_type(interp, itft, ti, vtypes, sv)
else
cti = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
cti, info = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
# We can't represent a repeating sequence of the same types,
# so tmerge everything together to get one type that represents
# everything.
Expand All @@ -678,19 +686,29 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
if _any(t -> t === Bottom, cti)
continue
end
for ct in ctypes
for j = 1:length(ctypes)
ct = ctypes[j]
if isvarargtype(ct[end])
# This is vararg, we're not gonna be able to do any inling,
# drop the info
info = nothing

tail = tuple_tail_elem(unwrapva(ct[end]), cti)
push!(ctypes´, push!(ct[1:(end - 1)], tail))
else
push!(ctypes´, append!(ct[:], cti))
end
push!(infos′, push!(copy(infos[j]), info))
end
end
ctypes = ctypes´
infos = infos′
end
local info = nothing
for ct in ctypes
retinfos = ApplyCallInfo[]
retinfo = UnionSplitApplyCallInfo(retinfos)
for i = 1:length(ctypes)
ct = ctypes[i]
arginfo = infos[i]
lct = length(ct)
# truncate argument list at the first Vararg
for i = 1:lct-1
Expand All @@ -701,15 +719,17 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
end
end
call = abstract_call(interp, nothing, ct, vtypes, sv, max_methods)
info = call.info
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if res === Any
# No point carrying forward the info, we're not gonna inline it anyway
retinfo = nothing
break
end
end
# TODO: Add a special info type to capture all the iteration info.
# For now, only propagate info if we don't also union-split the iteration
return CallMeta(res, length(ctypes) == 1 ? info : false)
return CallMeta(res, retinfo)
end

function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
Expand Down Expand Up @@ -779,7 +799,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
end
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
cti = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
cti, _ = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
idx = argtypes[3].val
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
Expand Down
Loading

0 comments on commit 435bf88

Please sign in to comment.