Skip to content

Commit

Permalink
WIP: also build compiled calls for ccall (#216)
Browse files Browse the repository at this point in the history
* in some cases, also build compiled calls for ccall

* fixes

* put back precompile statement

* fix working with sparams

* fix compiling ccalls when they are RHS of assignments

* fix for TypeofBottom

* fix calls to get pointer in ccall
  • Loading branch information
KristofferC authored Mar 28, 2019
1 parent 8758ba9 commit c1b99a1
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 30 deletions.
4 changes: 3 additions & 1 deletion src/commands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ function maybe_step_through_wrapper!(@nospecialize(recurse), frame::Frame)
end
end
ret = evaluate_call!(dummy_breakpoint, frame, last)
@assert isa(ret, BreakpointRef)
if !isa(ret, BreakpointRef) # Happens if next call is Compiled
return frame
end
frame.framedata.ssavalues[frame.pc] = Wrapper()
return maybe_step_through_wrapper!(recurse, callee(frame))
end
Expand Down
8 changes: 6 additions & 2 deletions src/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ function evaluate_foreigncall(frame::Frame, call_expr::Expr)
return Core.eval(moduleof(frame), Expr(head, args...))
end

# We have to intercept llvmcall before we try it as a builtin
# We have to intercept ccalls / llvmcalls before we try it as a builtin
function bypass_builtins(frame, call_expr, pc)
if isassigned(frame.framecode.methodtables, pc)
tme = frame.framecode.methodtables[pc]
if isa(tme, Compiled)
fargs = collect_args(frame, call_expr)
f = to_function(fargs[1])
return Some{Any}(f(fargs[2:end]...))
if parentmodule(f) === JuliaInterpreter.CompiledCalls
return Some{Any}(Base.invokelatest(f, fargs[2:end]...))
else
return Some{Any}(f(fargs[2:end]...))
end
end
end
return nothing
Expand Down
142 changes: 116 additions & 26 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,14 @@ Currently it looks up `GlobalRef`s (for which it needs `mod` to know the scope i
which this will run) and ensures that no statement includes nested `:call` expressions
(splitting them out into multiple SSA-form statements if needed).
"""
function optimize!(code::CodeInfo, mod::Module)
function optimize!(code::CodeInfo, scope)
mod = moduleof(scope)
sparams = scope isa Method ? Symbol[sparam_syms(scope)...] : Symbol[]
code.inferred && error("optimization of inferred code not implemented")
# TODO: because of builtins.jl, for CodeInfos like
# %1 = Core.apply_type
# %2 = (%1)(args...)
# it would be best to *not* resolve the GlobalRef at %1

## Replace GlobalRefs with QuoteNodes
for (i, stmt) in enumerate(code.code)
if isa(stmt, GlobalRef)
Expand Down Expand Up @@ -150,42 +151,131 @@ function optimize!(code::CodeInfo, mod::Module)
# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(code.code))
# @show code
for (idx, stmt) in enumerate(code.code)
# Foregincalls can be rhs of assignments
if isexpr(stmt, :(=))
stmt = stmt.args[2]
end
if isexpr(stmt, :call)
# Check for :llvmcall
arg1 = stmt.args[1]
if arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall
if (arg1 == :llvmcall || lookup_stmt(code.code, arg1) == Base.llvmcall) && isempty(sparams) && scope isa Method
uuid = uuid4()
ustr = replace(string(uuid), '-'=>'_')
methname = Symbol("llvmcall_", ustr)
nargs = length(stmt.args)-4
argnames = [Symbol("arg", string(i)) for i = 1:nargs]
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
while true
pc = step_expr!(Compiled(), frame)
pc == idx && break
pc === nothing && error("this should never happen")
end
str, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
def = quote
function $methname($(argnames...))
return Base.llvmcall($str, $RetType, $ArgType, $(argnames...))
end
end
f = Core.eval(CompiledCalls, def)
stmt.args[1] = QuoteNode(f)
deleteat!(stmt.args, 2:4)
build_compiled_call!(stmt, methname, Base.llvmcall, stmt.args[2:4], code, idx, nargs, sparams)
methodtables[idx] = Compiled()
end
elseif isexpr(stmt, :foreigncall) && scope isa Method
f = lookup_stmt(code.code, stmt.args[1])
if isa(f, Ptr)
f = string(uuid4())
elseif isexpr(f, :call)
length(f.args) == 3 || continue
f.args[1] === tuple || continue
lib = f.args[3] isa String ? f.args[3] : f.args[3].value
prefix = f.args[2] isa String ? f.args[2] : f.args[2].value
f = Symbol(prefix, '_', lib)
end
# Punt on non literal ccall arguments for now
if !(isa(f, String) || isa(f, Symbol) || isa(f, Ptr))
continue
end
# TODO: Only compile one ccall per call and argument types
uuid = uuid4()
ustr = replace(string(uuid), '-'=>'_')
methname = Symbol("ccall", '_', f, '_', ustr)
nargs = stmt.args[5]
build_compiled_call!(stmt, methname, :ccall, stmt.args[1:3], code, idx, nargs, sparams)
methodtables[idx] = Compiled()
end
end

return code, methodtables
end

function parametric_type_to_expr(t::Type)
t isa Core.TypeofBottom && return t
return t.hasfreetypevars ? Expr(:curly, t.name.name, ((tv-> tv isa TypeVar ? tv.name : tv).(t.parameters))...) : t
end

# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_call!(stmt, methname, fcall, typargs, code, idx, nargs, sparams)
argnames = Any[Symbol("arg", string(i)) for i = 1:nargs]
if fcall == :ccall
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]
# The result of this is useful to have next to you when reading this code:
# f(x, y) = ccall(:jl_value_ptr, Ptr{Cvoid}, (Float32,Any), x, y)
# @code_lowered f(2, 3)
args = []
for (atype, arg) in zip(ArgType, stmt.args[6:6+nargs-1])
if atype === Any
push!(args, arg)
else
@assert arg isa SSAValue
unsafe_convert_expr = code.code[arg.id]
cconvert_expr = code.code[unsafe_convert_expr.args[3].id]
push!(args, cconvert_expr.args[3])
end
end
else
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
if idxstart < idx
while true
pc = step_expr!(Compiled(), frame)
pc == idx && break
pc === nothing && error("this should never happen")
end
end
cfunc, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])
args = stmt.args[5:end]
end
if isa(cfunc, Expr)
cfunc = eval(cfunc)
end
if isa(cfunc, Symbol)
cfunc = QuoteNode(cfunc)
end
if fcall == :ccall
ArgType = Expr(:tuple, [parametric_type_to_expr(t) for t in ArgType]...)
end
if isa(RetType, SimpleVector)
@assert length(RetType) == 1
RetType = RetType[1]
end
RetType = parametric_type_to_expr(RetType)
wrapargs = copy(argnames)
for sparam in sparams
push!(wrapargs, :(::Type{$sparam}))
end
if stmt.args[4] == :(:llvmcall)
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
end)
else
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
end)
end
f = Core.eval(CompiledCalls, def)
stmt.args[1] = QuoteNode(f)
stmt.head = :call
deleteat!(stmt.args, 2:length(stmt.args))
append!(stmt.args, args)
for i in 1:length(sparams)
push!(stmt.args, :($(Expr(:static_parameter, 1))))
end
return
end

1 change: 1 addition & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ function _precompile_()
@assert precompile(Tuple{typeof(enter_call_expr), Expr})
@assert precompile(Tuple{typeof(copy_codeinfo), Core.CodeInfo})
@assert precompile(Tuple{typeof(optimize!), Core.CodeInfo, Module})
@assert precompile(Tuple{typeof(optimize!), Core.CodeInfo, Method})
@assert precompile(Tuple{typeof(set_structtype_const), Module, Symbol})
@assert precompile(Tuple{typeof(namedtuple), Vector{Any}})
@assert precompile(Tuple{typeof(resolvefc), Frame, Any})
Expand Down
2 changes: 1 addition & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end
const BREAKPOINT_EXPR = :($(QuoteNode(getproperty))($JuliaInterpreter, :__BREAKPOINT_MARKER__))
function FrameCode(scope, src::CodeInfo; generator=false, optimize=true)
if optimize
src, methodtables = optimize!(copy_codeinfo(src), moduleof(scope))
src, methodtables = optimize!(copy_codeinfo(src), scope)
else
src = copy_codeinfo(src)
methodtables = Vector{Union{Compiled,TypeMapEntry}}(undef, length(src.code))
Expand Down
10 changes: 10 additions & 0 deletions test/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,13 @@ function hash220(x::Tuple{Ptr{UInt8},Int}, h::UInt)
ccall(Base.memhash, UInt, (Ptr{UInt8}, Csize_t, UInt32), x[1], x[2], h % UInt32) + h
end
@test @interpret(hash220((Ptr{UInt8}(0),0), UInt(1))) == hash220((Ptr{UInt8}(0),0), UInt(1))

# ccall with type parameters
@test (@interpret Base.unsafe_convert(Ptr{Int}, [1,2])) isa Ptr{Int}

# ccall with call to get the pointer
cf = [@cfunction(fcfun, Int, (Int, Int))]
function call_cf()
ccall(cf[1], Int, (Int, Int), 1, 2)
end
@test (@interpret call_cf()) == call_cf()

0 comments on commit c1b99a1

Please sign in to comment.