Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure proper handling of sparams for widened compile signatures #47667

Merged
merged 4 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
return ci
end
if may_discard_trees(interp)
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, def))
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, linfo.sparam_vals, def))
else
cache_the_tree = true
end
Expand Down
11 changes: 8 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ function get_compileable_sig(method::Method, @nospecialize(atype), sparams::Simp
mt, atype, sparams, method)
end

isa_compileable_sig(@nospecialize(atype), method::Method) =
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any), atype, method))
isa_compileable_sig(@nospecialize(atype), sparams::SimpleVector, method::Method) =
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any, Any), atype, sparams, method))

# eliminate UnionAll vars that might be degenerate due to having identical bounds,
# or a concrete upper bound and appearing covariantly.
Expand Down Expand Up @@ -200,7 +200,12 @@ function specialize_method(method::Method, @nospecialize(atype), sparams::Simple
if compilesig
new_atype = get_compileable_sig(method, atype, sparams)
new_atype === nothing && return nothing
atype = new_atype
if atype !== new_atype
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), new_atype, method.sig)::SimpleVector
if sparams === sp_[2]::SimpleVector
atype = new_atype
end
end
end
if preexisting
# check cached specializations
Expand Down
338 changes: 230 additions & 108 deletions src/gf.c

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ static jl_callptr_t _jl_compile_codeinst(
// hack to export this pointer value to jl_dump_method_disasm
jl_atomic_store_release(&this_code->specptr.fptr, (void*)getAddressForFunction(decls.specFunctionObject));
}
if (this_code== codeinst)
if (this_code == codeinst)
fptr = addr;
}

Expand Down
26 changes: 25 additions & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1093,12 +1093,36 @@ jl_value_t *jl_unwrap_unionall(jl_value_t *v)
}

// wrap `t` in the same unionalls that surround `u`
// where `t` is derived from `u`, so the error checks in jl_type_unionall are unnecessary
jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u)
{
if (!jl_is_unionall(u))
return t;
JL_GC_PUSH1(&t);
t = jl_rewrap_unionall(t, ((jl_unionall_t*)u)->body);
jl_tvar_t *v = ((jl_unionall_t*)u)->var;
// normalize `T where T<:S` => S
if (t == (jl_value_t*)v)
return v->ub;
// where var doesn't occur in body just return body
if (!jl_has_typevar(t, v))
return t;
JL_GC_PUSH1(&t);
//if (v->lb == v->ub) // TODO maybe
// t = jl_substitute_var(body, v, v->ub);
//else
t = jl_new_struct(jl_unionall_type, v, t);
JL_GC_POP();
return t;
}

// wrap `t` in the same unionalls that surround `u`
// where `t` is extended from `u`, so the checks in jl_rewrap_unionall are unnecessary
jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u)
{
if (!jl_is_unionall(u))
return t;
t = jl_rewrap_unionall_(t, ((jl_unionall_t*)u)->body);
JL_GC_PUSH1(&t);
t = jl_new_struct(jl_unionall_type, ((jl_unionall_t*)u)->var, t);
JL_GC_POP();
return t;
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ STATIC_INLINE int jl_is_concrete_type(jl_value_t *v) JL_NOTSAFEPOINT
return jl_is_datatype(v) && ((jl_datatype_t*)v)->isconcretetype;
}

JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_method_t *definition);
JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_svec_t *sparams, jl_method_t *definition);

// type constructors
JL_DLLEXPORT jl_typename_t *jl_new_typename_in(jl_sym_t *name, jl_module_t *inmodule, int abstract, int mutabl);
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ JL_DLLEXPORT jl_value_t *jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_
jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u);
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u);
int jl_count_union_components(jl_value_t *v);
JL_DLLEXPORT jl_value_t *jl_nth_union_component(jl_value_t *v JL_PROPAGATES_ROOT, int i) JL_NOTSAFEPOINT;
int jl_find_union_component(jl_value_t *haystack, jl_value_t *needle, unsigned *nth) JL_NOTSAFEPOINT;
Expand Down
4 changes: 2 additions & 2 deletions src/precompile.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ static void jl_compile_all_defs(jl_array_t *mis)
size_t i, l = jl_array_len(allmeths);
for (i = 0; i < l; i++) {
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
if (jl_isa_compileable_sig((jl_tupletype_t*)m->sig, m)) {
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
// method has a single compilable specialization, e.g. its definition
// signature is concrete. in this case we can just hint it.
jl_compile_hint((jl_tupletype_t*)m->sig);
Expand Down Expand Up @@ -354,7 +354,7 @@ static void *jl_precompile_(jl_array_t *m)
mi = (jl_method_instance_t*)item;
size_t min_world = 0;
size_t max_world = ~(size_t)0;
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->def.method))
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->sparam_vals, mi->def.method))
mi = jl_get_specialization1((jl_tupletype_t*)mi->specTypes, jl_atomic_load_acquire(&jl_world_counter), &min_world, &max_world, 0);
if (mi)
jl_array_ptr_1d_push(m2, (jl_value_t*)mi);
Expand Down
6 changes: 3 additions & 3 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -2890,8 +2890,8 @@ static jl_value_t *intersect_sub_datatype(jl_datatype_t *xd, jl_datatype_t *yd,
jl_value_t *super_pattern=NULL;
JL_GC_PUSH2(&isuper, &super_pattern);
jl_value_t *wrapper = xd->name->wrapper;
super_pattern = jl_rewrap_unionall((jl_value_t*)((jl_datatype_t*)jl_unwrap_unionall(wrapper))->super,
wrapper);
super_pattern = jl_rewrap_unionall_((jl_value_t*)((jl_datatype_t*)jl_unwrap_unionall(wrapper))->super,
wrapper);
int envsz = jl_subtype_env_size(super_pattern);
jl_value_t *ii = jl_bottom_type;
{
Expand Down Expand Up @@ -3528,7 +3528,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t *
if (jl_is_uniontype(ans_unwrapped)) {
ans_unwrapped = switch_union_tuple(((jl_uniontype_t*)ans_unwrapped)->a, ((jl_uniontype_t*)ans_unwrapped)->b);
if (ans_unwrapped != NULL) {
*ans = jl_rewrap_unionall(ans_unwrapped, *ans);
*ans = jl_rewrap_unionall_(ans_unwrapped, *ans);
}
}
JL_GC_POP();
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
# this is needed to disambiguate
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X

rand(X) = rand(default_rng(), X)
rand(::Type{X}) where {X} = rand(default_rng(), X)
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ f11366(x::Type{Ref{T}}) where {T} = Ref{x}


let f(T) = Type{T}
@test Base.return_types(f, Tuple{Type{Int}}) == [Type{Type{Int}}]
@test Base.return_types(f, Tuple{Type{Int}}) == Any[Type{Type{Int}}]
end

# issue #9222
Expand Down
25 changes: 25 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7902,3 +7902,28 @@ struct ModTparamTestStruct{M}; end
end
@test ModTparamTestStruct{@__MODULE__}() == 2
@test ModTparamTestStruct{ModTparamTest}() == 1

# issue #47476
f47476(::Union{Int, NTuple{N,Int}}...) where {N} = N
# force it to populate the MethodInstance specializations cache
# with the correct sparams
code_typed(f47476, (Vararg{Union{Int, NTuple{2,Int}}},));
code_typed(f47476, (Int, Vararg{Union{Int, NTuple{2,Int}}},));
code_typed(f47476, (Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
code_typed(f47476, (Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
code_typed(f47476, (Int, Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
@test f47476(1, 2, 3, 4, 5, 6, (7, 8)) === 2
@test_throws UndefVarError(:N) f47476(1, 2, 3, 4, 5, 6, 7)

vect47476(::Type{T}) where {T} = T
@test vect47476(Type{Type{Type{Int32}}}) === Type{Type{Type{Int32}}}
@test vect47476(Type{Type{Type{Int64}}}) === Type{Type{Type{Int64}}}

g47476(::Union{Nothing,Int,Val{T}}...) where {T} = T
@test_throws UndefVarError(:T) g47476(nothing, 1, nothing, 2, nothing, 3, nothing, 4, nothing, 5)
@test g47476(nothing, 1, nothing, 2, nothing, 3, nothing, 4, nothing, 5, Val(6)) === 6
let spec = only(methods(g47476)).specializations
@test !isempty(spec)
@test any(mi -> mi !== nothing && Base.isvatuple(mi.specTypes), spec)
@test all(mi -> mi === nothing || !Base.has_free_typevars(mi.specTypes), spec)
end
4 changes: 2 additions & 2 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,8 @@ end
f(x, y) = x + y
f(x::Int, y) = 2x + y
end
precompile(M.f, (Int, Any))
precompile(M.f, (AbstractFloat, Any))
@test precompile(M.f, (Int, Any))
@test precompile(M.f, (AbstractFloat, Any))
mis = map(methods(M.f)) do m
m.specializations[1]
end
Expand Down