Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REPLCompletions: replace get_type by the proper inference #49206

Merged
merged 1 commit into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 174 additions & 117 deletions stdlib/REPL/src/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module REPLCompletions

export completions, shell_completions, bslash_completions, completion_text

using Core: CodeInfo, MethodInstance, CodeInstance, Const
const CC = Core.Compiler
using Base.Meta
using Base: propertynames, something

Expand Down Expand Up @@ -151,21 +153,21 @@ function complete_symbol(sym::String, @nospecialize(ffunc), context_module::Modu

ex = Meta.parse(lookup_name, raise=false, depwarn=false)

b, found = get_value(ex, context_module)
if found
val = b
if isa(b, Module)
mod = b
res = repl_eval_ex(ex, context_module)
res === nothing && return Completion[]
if res isa Const
val = res.val
if isa(val, Module)
mod = val
lookup_module = true
else
lookup_module = false
t = typeof(b)
t = typeof(val)
end
else # If the value is not found using get_value, the expression contain an advanced expression
else
lookup_module = false
t, found = get_type(ex, context_module)
t = CC.widenconst(res)
end
found || return Completion[]
end

suggestions = Completion[]
Expand Down Expand Up @@ -404,133 +406,182 @@ function find_start_brace(s::AbstractString; c_start='(', c_end=')')
return (startind:lastindex(s), method_name_end)
end

# Returns the value of an expression if `sym` is defined in the current namespace `fn`.
# This method is used to iterate down to the value of an expression like
# `:(REPL.REPLCompletions.whitespace_chars)`; a `dump` of this expression
# will show that it consists of `Expr`s, `QuoteNode`s and `Symbol`s, which all need to
# be handled differently to iterate down to the value of `whitespace_chars`.
function get_value(sym::Expr, fn)
if sym.head === :quote || sym.head === :inert
return sym.args[1], true
end
sym.head !== :. && return (nothing, false)
for ex in sym.args
ex, found = get_value(ex, fn)::Tuple{Any, Bool}
!found && return (nothing, false)
fn, found = get_value(ex, fn)::Tuple{Any, Bool}
!found && return (nothing, false)
end
return (fn, true)
# Code cache used by `REPLInterpreter`: a thin wrapper around an `IdDict` that maps
# each `MethodInstance` to its `CodeInstance`.
struct REPLInterpreterCache
# backing table, keyed by `MethodInstance`
dict::IdDict{MethodInstance,CodeInstance}
end
get_value(sym::Symbol, fn) = isdefined(fn, sym) ? (getfield(fn, sym), true) : (nothing, false)
get_value(sym::QuoteNode, fn) = (sym.value, true)
get_value(sym::GlobalRef, fn) = get_value(sym.name, sym.mod)
get_value(sym, fn) = (sym, true)

# Return the type of a getfield call expression
function get_type_getfield(ex::Expr, fn::Module)
length(ex.args) == 3 || return Any, false # should never happen, but just for safety
fld, found = get_value(ex.args[3], fn)
fld isa Symbol || return Any, false
obj = ex.args[2]
objt, found = get_type(obj, fn)
found || return Any, false
objt isa DataType || return Any, false
hasfield(objt, fld) || return Any, false
return fieldtype(objt, fld), true
# Zero-argument constructor: builds a cache backed by a fresh, empty `IdDict`.
REPLInterpreterCache() = REPLInterpreterCache(IdDict{MethodInstance,CodeInstance}())
# Module-level cache instance; `get_code_cache` returns it whenever
# `jl_generating_output()` does not report `1`.
const REPL_INTERPRETER_CACHE = REPLInterpreterCache()

# Pick the code cache a new `REPLInterpreter` should use: a throwaway
# `REPLInterpreterCache` while `jl_generating_output()` reports `1`, and the shared
# `REPL_INTERPRETER_CACHE` otherwise.
function get_code_cache()
    # XXX Avoid storing analysis results into the cache that persists across precompilation,
    # as [sys|pkg]image currently doesn't support serializing externally created `CodeInstance`.
    # Otherwise, `CodeInstance`s created by `REPLInterpreter`, which are much less optimized
    # than those produced by `NativeInterpreter`, will leak into the native code cache,
    # potentially causing runtime slowdown.
    # (see https://github.com/JuliaLang/julia/issues/48453).
    generating_output = (@ccall jl_generating_output()::Cint) == 1
    return generating_output ? REPLInterpreterCache() : REPL_INTERPRETER_CACHE
end

# Determines the return type with the Compiler of a function call using the type information of the arguments.
function get_type_call(expr::Expr, fn::Module)
f_name = expr.args[1]
f, found = get_type(f_name, fn)
found || return (Any, false) # If the function f is not found return Any.
args = Any[]
for i in 2:length(expr.args) # Find the type of the function arguments
typ, found = get_type(expr.args[i], fn)
found ? push!(args, typ) : push!(args, Any)
struct REPLInterpreter <: CC.AbstractInterpreter
repl_frame::CC.InferenceResult
world::UInt
inf_params::CC.InferenceParams
opt_params::CC.OptimizationParams
inf_cache::Vector{CC.InferenceResult}
code_cache::REPLInterpreterCache
function REPLInterpreter(repl_frame::CC.InferenceResult;
world::UInt = Base.get_world_counter(),
inf_params::CC.InferenceParams = CC.InferenceParams(),
opt_params::CC.OptimizationParams = CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
code_cache::REPLInterpreterCache = get_code_cache())
return new(repl_frame, world, inf_params, opt_params, inf_cache, code_cache)
end
world = Base.get_world_counter()
return_type = Core.Compiler.return_type(Tuple{f, args...}, world)
return (return_type, true)
end

# Returns the return type. example: get_type(:(Base.strip("", ' ')), Main) returns (SubString{String}, true)
function try_get_type(sym::Expr, fn::Module)
val, found = get_value(sym, fn)
found && return Core.Typeof(val), found
if sym.head === :call
# getfield call is special cased as the evaluation of getfield provides good type information,
# is inexpensive and it is also performed in the complete_symbol function.
a1 = sym.args[1]
if a1 === :getfield || a1 === GlobalRef(Core, :getfield)
return get_type_getfield(sym, fn)
# `Core.Compiler` accessor overloads for `REPLInterpreter`: each one simply returns the
# corresponding field stored on the interpreter.
CC.InferenceParams(interp::REPLInterpreter) = interp.inf_params
CC.OptimizationParams(interp::REPLInterpreter) = interp.opt_params
CC.get_world_counter(interp::REPLInterpreter) = interp.world
CC.get_inference_cache(interp::REPLInterpreter) = interp.inf_cache
# The code cache is handed out as a `WorldView` over `interp.code_cache`, built with
# `CC.WorldRange(interp.world)`.
CC.code_cache(interp::REPLInterpreter) = CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
# Dictionary-style operations on that `WorldView` forward to the wrapped `IdDict`
# (`wvc.cache.dict`); the world range carried by the view is not consulted here.
CC.get(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default)
CC.getindex(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi)
CC.haskey(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi)
CC.setindex!(wvc::CC.WorldView{REPLInterpreterCache}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi)

# REPLInterpreter is only used for type analysis, so it should disable optimization entirely
CC.may_optimize(::REPLInterpreter) = false

# REPLInterpreter analyzes a top-level frame, so better to not bail out from it
CC.bail_out_toplevel_call(::REPLInterpreter, ::CC.InferenceLoopState, ::CC.InferenceState) = false

# `REPLInterpreter` aggressively resolves global bindings to enable reasonable completions
# for lines like `Mod.a.|` (where `|` is the cursor position).
# Aggressive binding resolution poses challenges for the inference cache validation
# (until https://github.com/JuliaLang/julia/issues/40399 is implemented).
# To avoid the cache validation issues, `REPLInterpreter` only allows aggressive binding
# resolution for the top-level frame representing the REPL input code (`repl_frame`) and for
# child `getproperty` frames that are constant propagated from the `repl_frame`. This works,
# since
# a.) these frames are never cached, and
# b.) their results are only observed by the non-cached `repl_frame`.
#
# `REPLInterpreter` also aggressively concrete-evaluates `:inconsistent` calls within
# `repl_frame` to provide reasonable completions for lines like `Ref(Some(42))[].|`.
# Aggressive concrete evaluation allows us to get accurate type information about complex
# expressions that otherwise cannot be constant folded, in a safe way, i.e. it still
# doesn't evaluate effectful expressions like `pop!(xs)`.
# Similarly to the aggressive binding resolution, aggressive concrete evaluation doesn't
# present any cache validation issues because `repl_frame` is never cached.

# Report whether `sv` is the frame for the REPL input itself, i.e. whether its inference
# result is the very object stored as `interp.repl_frame`.
function is_repl_frame(interp::REPLInterpreter, sv::CC.InferenceState)
    return sv.result === interp.repl_frame
end

# aggressive global binding resolution within `repl_frame`
function CC.abstract_eval_globalref(interp::REPLInterpreter, g::GlobalRef,
sv::CC.InferenceState)
if is_repl_frame(interp, sv)
if CC.isdefined_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
return get_type_call(sym, fn)
elseif sym.head === :thunk
thk = sym.args[1]
rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::Core.CodeInfo, fn)
rt !== Any && return (rt, true)
elseif sym.head === :ref
# some simple cases of `expand`
return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn)
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn)
elseif sym.head === :toplevel || sym.head === :block
isempty(sym.args) && return (nothing, true)
return try_get_type(sym.args[end], fn)
elseif sym.head === :escape || sym.head === :var"hygienic-scope"
return try_get_type(sym.args[1], fn)
return Union{}
end
return (Any, false)
return @invoke CC.abstract_eval_globalref(interp::CC.AbstractInterpreter, g::GlobalRef,
sv::CC.InferenceState)
end

try_get_type(other, fn::Module) = get_type(other, fn)
# Report whether `sv` is a non-cached frame of a method named `getproperty` whose parent
# frame is the `repl_frame` of `interp`.
function is_repl_frame_getproperty(interp::REPLInterpreter, sv::CC.InferenceState)
    method = sv.linfo.def
    # Short-circuit in the same order as the individual guards: the definition must be a
    # `Method`, it must be called `getproperty`, and the frame must not be a cached one.
    if !(method isa Method) || method.name !== :getproperty || sv.cached
        return false
    end
    return is_repl_frame(interp, sv.parent)
end

function get_type(sym::Expr, fn::Module)
# try to analyze nests of calls. if this fails, try using the expanded form.
val, found = try_get_type(sym, fn)
found && return val, found
# https://github.com/JuliaLang/julia/issues/27184
if isexpr(sym, :macrocall)
_, found = get_type(first(sym.args), fn)
found || return Any, false
end
newsym = try
macroexpand(fn, sym; recursive=false)
catch e
# user code failed in macroexpand (ignore it)
return Any, false
end
val, found = try_get_type(newsym, fn)
if !found
newsym = try
Meta.lower(fn, sym)
catch e
# user code failed in lowering (ignore it)
return Any, false
# aggressive global binding resolution for `getproperty(::Module, ::Symbol)` calls within `repl_frame`
function CC.builtin_tfunction(interp::REPLInterpreter, @nospecialize(f),
argtypes::Vector{Any}, sv::CC.InferenceState)
if f === Core.getglobal && is_repl_frame_getproperty(interp, sv)
if length(argtypes) == 2
a1, a2 = argtypes
if isa(a1, Const) && isa(a2, Const)
a1val, a2val = a1.val, a2.val
if isa(a1val, Module) && isa(a2val, Symbol)
g = GlobalRef(a1val, a2val)
if CC.isdefined_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
return Union{}
end
end
end
val, found = try_get_type(newsym, fn)
end
return val, found
return @invoke CC.builtin_tfunction(interp::CC.AbstractInterpreter, f::Any,
argtypes::Vector{Any}, sv::CC.InferenceState)
end

function get_type(sym, fn::Module)
val, found = get_value(sym, fn)
return found ? Core.Typeof(val) : Any, found
# Aggressive concrete evaluation for `:inconsistent` frames within `repl_frame`:
# when `sv` is the `repl_frame`, the call result is rebuilt with its effects marked as
# always `consistent` before deferring to the generic `AbstractInterpreter` method.
function CC.concrete_eval_eligible(interp::REPLInterpreter, @nospecialize(f),
                                   result::CC.MethodCallResult, arginfo::CC.ArgInfo,
                                   sv::CC.InferenceState)
    checked = if is_repl_frame(interp, sv)
        CC.MethodCallResult(result.rt, result.edgecycle, result.edgelimited, result.edge,
                            CC.Effects(result.effects; consistent=CC.ALWAYS_TRUE))
    else
        result
    end
    return @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, f::Any,
                                             checked::CC.MethodCallResult,
                                             arginfo::CC.ArgInfo,
                                             sv::CC.InferenceState)
end

# Return a copy of `src` whose statements have been passed, together with `mod` and an
# empty static-parameter `svec`, to the runtime's `jl_resolve_globals_in_ir`.
# The input `src` itself is left untouched: only the copy's `code` is handed to the runtime.
# NOTE(review): presumably the C routine rewrites bare global symbols in the statements
# into references qualified by `mod` — confirm against the runtime source.
function resolve_toplevel_symbols!(mod::Module, src::Core.CodeInfo)
    newsrc = copy(src)
    # `binding_effects` is a C `int`, so it is passed as `Cint` rather than the
    # platform-sized `Int`.
    @ccall jl_resolve_globals_in_ir(
        #=jl_array_t *stmts=# newsrc.code::Any,
        #=jl_module_t *m=# mod::Any,
        #=jl_svec_t *sparam_vals=# Core.svec()::Any,
        #=int binding_effects=# 0::Cint)::Cvoid
    return newsrc
end

function get_type(T, found::Bool, default_any::Bool)
return found ? T :
default_any ? Any : throw(ArgumentError("argument not found"))
# Lower `ex` in `context_module` and run type inference with `REPLInterpreter` on the
# resulting top-level code.
# Returns `nothing` when lowering throws, when `ex` lowers to a symbol that is not defined
# in `context_module`, or when the lowered form is an `Expr` other than a `:thunk`.
# Returns a `Const` for a defined symbol or a non-`Expr` lowered value, and otherwise
# whatever inference stored as the result of the top-level frame.
function repl_eval_ex(@nospecialize(ex), context_module::Module)
    lowered = try
        Meta.lower(context_module, ex)
    catch # macro expansion failed, etc.
        return nothing
    end
    if lowered isa Symbol
        isdefined(context_module, lowered) || return nothing
        return Const(getfield(context_module, lowered))
    elseif !(lowered isa Expr)
        return Const(lowered) # `ex` is literal
    elseif !isexpr(lowered, :thunk)
        return nothing # lowered to `Expr(:error, ...)` or similar
    end
    thunk_src = lowered.args[1]::Core.CodeInfo

    # construct top-level `MethodInstance`
    toplevel_mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ())
    toplevel_mi.specTypes = Tuple{}
    toplevel_mi.def = context_module
    resolved_src = resolve_toplevel_symbols!(context_module, thunk_src)
    @atomic toplevel_mi.uninferred = resolved_src

    # infer the frame without caching it; the interpreter keeps `inf_result` as its `repl_frame`
    inf_result = CC.InferenceResult(toplevel_mi)
    interp = REPLInterpreter(inf_result)
    frame = CC.InferenceState(inf_result, resolved_src, #=cache=#:no, interp)::CC.InferenceState
    CC.typeinf(interp, frame)

    return frame.result.result
end

# Method completion on function call expression that look like :(max(1))
MAX_METHOD_COMPLETIONS::Int = 40
function _complete_methods(ex_org::Expr, context_module::Module, shift::Bool)
funct, found = get_type(ex_org.args[1], context_module)::Tuple{Any,Bool}
!found && return 2, funct, [], Set{Symbol}()

funct = repl_eval_ex(ex_org.args[1], context_module)
funct === nothing && return 2, nothing, [], Set{Symbol}()
funct = CC.widenconst(funct)
args_ex, kwargs_ex, kwargs_flag = complete_methods_args(ex_org, context_module, true, true)
return kwargs_flag, funct, args_ex, kwargs_ex
end
Expand Down Expand Up @@ -635,7 +686,14 @@ function detect_args_kwargs(funargs::Vector{Any}, context_module::Module, defaul
# argument types
push!(args_ex, Any)
else
push!(args_ex, get_type(get_type(ex, context_module)..., default_any))
argt = repl_eval_ex(ex, context_module)
if argt !== nothing
push!(args_ex, CC.widenconst(argt))
elseif default_any
push!(args_ex, Any)
else
throw(ArgumentError("argument not found"))
end
end
end
end
Expand Down Expand Up @@ -709,7 +767,6 @@ function close_path_completion(str, startpos, r, paths, pos)
return lastindex(str) <= pos || str[nextind(str, pos)] != '"'
end


function bslash_completions(string::String, pos::Int)
slashpos = something(findprev(isequal('\\'), string, pos), 0)
if (something(findprev(in(bslash_separators), string, pos), 0) < slashpos &&
Expand Down
Loading