Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix invdepth within existential subtyping. #49049

Merged
merged 2 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y)
return 0;
}

static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d);
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);

static int reachable_var(jl_value_t *x, jl_tvar_t *y, jl_stenv_t *e);

Expand All @@ -706,7 +706,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
// for this to work we need to compute issub(left,right) before issub(right,left),
// since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous.
if (e->intersection) {
jl_value_t *ub = intersect_aside(bb->ub, a, e, 0, bb->depth0);
jl_value_t *ub = intersect_aside(a, bb->ub, e, bb->depth0);
JL_GC_PUSH1(&ub);
if (ub != (jl_value_t*)b && (!jl_is_typevar(ub) || !reachable_var(ub, b, e)))
bb->ub = ub;
Expand Down Expand Up @@ -2054,11 +2054,6 @@ static int subtype_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
return subtype_in_env_(x, y, e, e->invdepth, e->Rinvdepth);
}

static int subtype_bounds_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
{
return subtype_in_env_(x, y, e, R ? e->invdepth : d, R ? d : e->Rinvdepth);
}

JL_DLLEXPORT int jl_subtype(jl_value_t *x, jl_value_t *y)
{
return jl_subtype_env(x, y, NULL, 0);
Expand Down Expand Up @@ -2265,27 +2260,24 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);

// intersect in nested union environment, similar to subtype_ccheck
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
{
// band-aid for #30335
if (x == (jl_value_t*)jl_any_type && !jl_is_typevar(y))
return y;
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;
// band-aid for #46736
if (jl_egal(x, y))
if (obviously_egal(x, y))
return x;

jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

e->invdepth = e->Rinvdepth = depth;
jl_value_t *res = intersect_all(x, y, e);

pop_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
pop_unionstate(&e->Runions, &oldRunions);
return res;
}

Expand Down Expand Up @@ -2386,14 +2378,16 @@ static int try_subtype_by_bounds(jl_value_t *a, jl_value_t *b, jl_stenv_t *e)
return 0;
}

static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e, int R, int d)
static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e, int flip)
{
if (a == jl_bottom_type || b == (jl_value_t *)jl_any_type || try_subtype_by_bounds(a, b, e))
return 1;
jl_value_t *root=NULL; jl_savedenv_t se;
JL_GC_PUSH1(&root);
save_env(e, &root, &se);
int ret = subtype_bounds_in_env(a, b, e, R, d);
int invdepth = flip ? e->Rinvdepth : e->invdepth;
int Rinvdepth = flip ? e->invdepth : e->Rinvdepth;
int ret = subtype_in_env_(a, b, e, invdepth, Rinvdepth);
restore_env(e, root, &se);
free_env(&se);
JL_GC_POP();
Expand All @@ -2415,7 +2409,7 @@ static void set_bound(jl_value_t **bound, jl_value_t *val, jl_tvar_t *v, jl_sten
}

// subtype, treating all vars as existential
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int flip)
{
jl_varbinding_t *v = e->vars;
int len = 0;
Expand All @@ -2434,7 +2428,9 @@ static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *
v->right = 1;
v = v->prev;
}
int issub = subtype_bounds_in_env(x, y, e, R, d);
int invdepth = flip ? e->Rinvdepth : e->invdepth;
int Rinvdepth = flip ? e->invdepth : e->Rinvdepth;
int issub = subtype_in_env_(x, y, e, invdepth, Rinvdepth);
n = 0; v = e->vars;
while (n < len) {
assert(v != NULL);
Expand Down Expand Up @@ -2512,25 +2508,23 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
{
jl_varbinding_t *bb = lookup(e, b);
if (bb == NULL)
return R ? intersect_aside(a, b->ub, e, 1, 0) : intersect_aside(b->ub, a, e, 0, 0);
return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0);
if (reachable_var(bb->lb, b, e) || reachable_var(bb->ub, b, e))
return a;
if (bb->lb == bb->ub && jl_is_typevar(bb->lb)) {
return intersect(a, bb->lb, e, param);
}
if (bb->lb == bb->ub && jl_is_typevar(bb->lb))
return R ? intersect(a, bb->lb, e, param) : intersect(bb->lb, a, e, param);
if (!jl_is_type(a) && !jl_is_typevar(a))
return set_var_to_const(bb, a, NULL);
int d = bb->depth0;
jl_value_t *root=NULL; jl_savedenv_t se;
if (param == 2) {
jl_value_t *ub = NULL;
JL_GC_PUSH2(&ub, &root);
if (!jl_has_free_typevars(a)) {
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, a, e, 0, d);
int issub = subtype_in_env_existential(bb->lb, a, e, R);
restore_env(e, root, &se);
if (issub) {
issub = subtype_in_env_existential(a, bb->ub, e, 1, d);
issub = subtype_in_env_existential(a, bb->ub, e, !R);
restore_env(e, root, &se);
}
free_env(&se);
Expand All @@ -2542,10 +2536,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
else {
e->triangular++;
ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
ub = R ? intersect_aside(a, bb->ub, e, bb->depth0) : intersect_aside(bb->ub, a, e, bb->depth0);
e->triangular--;
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, ub, e, 0, d);
int issub = subtype_in_env_existential(bb->lb, ub, e, R);
restore_env(e, root, &se);
free_env(&se);
if (!issub) {
Expand Down Expand Up @@ -2576,7 +2570,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
JL_GC_POP();
return ub;
}
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, bb->depth0) : intersect_aside(bb->ub, a, e, bb->depth0);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (bb->constraintkind == 1 || e->triangular) {
Expand All @@ -2587,7 +2581,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
else if (bb->constraintkind == 0) {
JL_GC_PUSH1(&ub);
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, 0, d)) {
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, R)) {
JL_GC_POP();
return (jl_value_t*)b;
}
Expand Down Expand Up @@ -3107,6 +3101,9 @@ static void flip_vars(jl_stenv_t *e)
btemp->right = !btemp->right;
btemp = btemp->prev;
}
int temp = e->invdepth;
e->invdepth = e->Rinvdepth;
e->Rinvdepth = temp;
}

// intersection where xd nominally inherits from yd
Expand Down Expand Up @@ -3154,11 +3151,11 @@ static jl_value_t *intersect_invariant(jl_value_t *x, jl_value_t *y, jl_stenv_t
jl_savedenv_t se;
JL_GC_PUSH2(&ii, &root);
save_env(e, &root, &se);
if (!subtype_in_env_existential(x, y, e, 0, e->invdepth))
if (!subtype_in_env_existential(x, y, e, 0))
ii = NULL;
else {
restore_env(e, root, &se);
if (!subtype_in_env_existential(y, x, e, 0, e->invdepth))
if (!subtype_in_env_existential(y, x, e, 1))
ii = NULL;
}
restore_env(e, root, &se);
Expand Down Expand Up @@ -3320,7 +3317,8 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
}
jl_value_t *ub=NULL, *lb=NULL;
JL_GC_PUSH2(&lb, &ub);
ub = intersect_aside(xub, yub, e, 0, xx ? xx->depth0 : 0);
int d = xx ? xx->depth0 : yy ? yy->depth0 : 0;
ub = R ? intersect_aside(yub, xub, e, d) : intersect_aside(xub, yub, e, d);
if (reachable_var(xlb, (jl_tvar_t*)y, e))
lb = ylb;
else
Expand Down
16 changes: 10 additions & 6 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1804,8 +1804,14 @@ end
#end

# issue #32386
@test typeintersect(Type{S} where S<:(Vector{Pair{_A,N} where N} where _A),
Type{Vector{T}} where T) == Type{Vector{Pair{_A,N} where N}} where _A
@testintersect(Type{S} where S<:(Vector{Pair{_A,N} where N} where _A),
Type{Vector{T}} where T,
Type{Vector{Pair{_A,N} where N}} where _A)

# pr #49049
@testintersect(Tuple{Type{Pair{T, A} where {T, A<:Array{T}}}, Int, Any},
Tuple{Type{F}, Any, Int} where {F<:(Pair{T, A} where {T, A<:Array{T}})},
Tuple{Type{Pair{T, A} where {T, A<:(Array{T})}}, Int, Int})

# issue #32488
struct S32488{S <: Tuple, T, N, L}
Expand Down Expand Up @@ -2431,11 +2437,9 @@ abstract type MyAbstract47877{C}; end
struct MyType47877{A,B} <: MyAbstract47877{A} end
let A = Tuple{Type{T}, T} where T,
B = Tuple{Type{MyType47877{W, V} where V<:Union{Base.BitInteger, MyAbstract47877{W}}}, MyAbstract47877{<:Base.BitInteger}} where W
C = Tuple{Type{MyType47877{W, V} where V<:Union{MyAbstract47877{W1}, Base.BitInteger}}, MyType47877{W, V} where V<:Union{MyAbstract47877{W1}, Base.BitInteger}} where {W<:Base.BitInteger, W1<:Base.BitInteger}
# ensure that merge_env for innervars does not blow up (the large Unions ensure this will take excessive memory if it does)
@test typeintersect(A, B) == C # suboptimal, but acceptable
C = Tuple{Type{MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}}, MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}} where W<:Base.BitInteger
@test typeintersect(B, A) == C
# ensure that merge_env for innervars does not blow up (the large Unions ensure this will take excessive memory if it does)
@testintersect(A, B, C)
end

let a = (isodd(i) ? Pair{Char, String} : Pair{String, String} for i in 1:2000)
Expand Down