Skip to content

Commit

Permalink
fix #8505, static parameters in staged functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Oct 23, 2014
1 parent 45a1463 commit 81136af
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
14 changes: 7 additions & 7 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -611,11 +611,11 @@ const limit_tuple_type_n = function (t::Tuple, lim::Int)
return t
end

function func_for_method(m::Method, tt)
function func_for_method(m::Method, tt, env)
if !m.isstaged
return m.func.code
end
(ccall(:jl_instantiate_staged,Any,(Any,Any),m,tt)).code
(ccall(:jl_instantiate_staged,Any,(Any,Any,Any),m,tt,env)).code
end

function abstract_call_gf(f, fargs, argtypes, e)
Expand Down Expand Up @@ -691,7 +691,7 @@ function abstract_call_gf(f, fargs, argtypes, e)
for (m::Tuple) in x
local linfo
try
linfo = func_for_method(m[3],argtypes)
linfo = func_for_method(m[3],argtypes,m[2])
catch
rettype = Any
break
Expand Down Expand Up @@ -742,7 +742,7 @@ function invoke_tfunc(f, types, argtypes)
for (m::Tuple) in applicable
local linfo
try
linfo = func_for_method(m[3],types)
linfo = func_for_method(m[3],types,m[2])
catch
return Any
end
Expand Down Expand Up @@ -2091,7 +2091,7 @@ function inlineable(f, e::Expr, atypes, sv, enclosing_ast)

local linfo
try
linfo = func_for_method(meth[3],atypes)
linfo = func_for_method(meth[3],atypes,meth[2])
catch
return NF
end
Expand Down Expand Up @@ -3061,7 +3061,7 @@ code_typed(f, types) = code_typed(call, tuple(isa(f,Type)?Type{f}:typeof(f), typ
function code_typed(f::Function, types::(Type...))
asts = []
for x in _methods(f,types,-1)
linfo = func_for_method(x[3],types)
linfo = func_for_method(x[3],types,x[2])
(tree, ty) = typeinf(linfo, x[1], x[2])
if !isa(tree,Expr)
push!(asts, ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, tree))
Expand All @@ -3076,7 +3076,7 @@ return_types(f, types) = return_types(call, tuple(isa(f,Type)?Type{f}:typeof(f),
function return_types(f::Function, types)
rt = []
for x in _methods(f,types,-1)
linfo = func_for_method(x[3],types)
linfo = func_for_method(x[3],types,x[2])
(tree, ty) = typeinf(linfo, x[1], x[2])
push!(rt, ty)
end
Expand Down
41 changes: 28 additions & 13 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,23 @@ jl_function_t *jl_instantiate_method(jl_function_t *f, jl_tuple_t *sp)
return nf;
}

// append values of static parameters to closure environment
static jl_function_t *with_appended_env(jl_function_t *meth, jl_tuple_t *sparams)
{
if (sparams == jl_null)
return meth;
jl_value_t *temp = (jl_value_t*)jl_alloc_tuple(jl_tuple_len(sparams)/2);
JL_GC_PUSH1(&temp);
size_t i;
for(i=0; i < jl_tuple_len(temp); i++) {
jl_tupleset(temp, i, jl_tupleref(sparams,i*2+1));
}
temp = (jl_value_t*)jl_tuple_append((jl_tuple_t*)meth->env, (jl_tuple_t*)temp);
meth = jl_new_closure(meth->fptr, temp, meth->linfo);
JL_GC_POP();
return meth;
}

// make a new method that calls the generated code from the given linfo
jl_function_t *jl_reinstantiate_method(jl_function_t *f, jl_lambda_info_t *li)
{
Expand Down Expand Up @@ -814,12 +831,7 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
newmeth = jl_new_closure(unspec->fptr, method->env, unspec->linfo);

if (sparams != jl_null) {
temp = (jl_value_t*)jl_alloc_tuple(jl_tuple_len(sparams)/2);
for(i=0; i < jl_tuple_len(temp); i++) {
jl_tupleset(temp, i, jl_tupleref(sparams,i*2+1));
}
temp = (jl_value_t*)jl_tuple_append((jl_tuple_t*)newmeth->env, (jl_tuple_t*)temp);
newmeth = jl_new_closure(newmeth->fptr, temp, newmeth->linfo);
newmeth = with_appended_env(newmeth, sparams);
}

(void)jl_method_cache_insert(mt, type, newmeth);
Expand Down Expand Up @@ -932,7 +944,7 @@ static jl_value_t *lookup_match(jl_value_t *a, jl_value_t *b, jl_tuple_t **penv,
return ti;
}

DLLEXPORT jl_function_t *jl_instantiate_staged(jl_methlist_t *m, jl_tuple_t *tt)
DLLEXPORT jl_function_t *jl_instantiate_staged(jl_methlist_t *m, jl_tuple_t *tt, jl_tuple_t *env)
{
jl_expr_t *ex = NULL;
jl_expr_t *oldast = NULL;
Expand Down Expand Up @@ -962,7 +974,8 @@ DLLEXPORT jl_function_t *jl_instantiate_staged(jl_methlist_t *m, jl_tuple_t *tt)
jl_cellset(argnames->args,i,arg);
}
}
jl_cellset(ex->args, 1, jl_apply(m->func, tt->data, jl_tuple_len(tt)));
func = with_appended_env(m->func, env);
jl_cellset(ex->args, 1, jl_apply(func, tt->data, jl_tuple_len(tt)));
func = (jl_function_t*)jl_toplevel_eval_in(m->func->linfo->module, (jl_value_t*)ex);
JL_GC_POP();
return func;
Expand Down Expand Up @@ -1015,7 +1028,7 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in
if (m != JL_NULL) {
func = m->func;
if (m->isstaged)
func = jl_instantiate_staged(m,tt);
func = jl_instantiate_staged(m,tt,env);
JL_GC_POP();
if (!cache)
return func;
Expand All @@ -1029,7 +1042,7 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in
func = m->func;

if (m->isstaged)
func = jl_instantiate_staged(m,tt);
func = jl_instantiate_staged(m,tt,env);

// don't bother computing this if no arguments are tuples
for(i=0; i < jl_tuple_len(tt); i++) {
Expand Down Expand Up @@ -1462,6 +1475,7 @@ static void all_p2c(jl_value_t *ast, jl_tuple_t *tvars)
jl_lambda_info_t *li = (jl_lambda_info_t*)ast;
li->ast = jl_prepare_ast(li, jl_null);
parameters_to_closureenv(li->ast, tvars);
all_p2c(li->ast, tvars);
}
else if (jl_is_expr(ast)) {
jl_expr_t *e = (jl_expr_t*)ast;
Expand All @@ -1477,9 +1491,7 @@ static void precompile_unspecialized(jl_function_t *func, jl_tuple_t *sig, jl_tu
// add static parameter names to end of closure env; compile
// assuming they are there. method cache will fill them in when
// it constructs closures for new "specializations".
func->linfo->ast = jl_prepare_ast(func->linfo, jl_null);
parameters_to_closureenv(func->linfo->ast, tvars);
all_p2c(func->linfo->ast, tvars);
all_p2c((jl_value_t*)func->linfo, tvars);
}
jl_trampoline_compile_function(func, 1, sig ? sig : jl_tuple_type);
}
Expand Down Expand Up @@ -1762,6 +1774,9 @@ void jl_add_method(jl_function_t *gf, jl_tuple_t *types, jl_function_t *meth,
assert(jl_is_mtable(jl_gf_mtable(gf)));
if (meth->linfo != NULL)
meth->linfo->name = jl_gf_name(gf);
if (isstaged && tvars != jl_null) {
all_p2c((jl_value_t*)meth->linfo, tvars);
}
(void)jl_method_table_insert(jl_gf_mtable(gf), types, meth, tvars, isstaged);
}

Expand Down
7 changes: 7 additions & 0 deletions test/staged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@ stagedfunction h(x)
end
end
@test MyTest8497.h(3) == 4

# static parameters (issue #8505)
stagedfunction f8505{T}(x::Vector{T})
T
end
@test f8505([1.0]) === Float64
@test f8505([1]) === Int

0 comments on commit 81136af

Please sign in to comment.