Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtype: Improve simple_meet resolution for Union inputs #49376

Merged
merged 3 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 121 additions & 14 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -485,19 +485,19 @@ static int union_sort_cmp(jl_value_t *a, jl_value_t *b) JL_NOTSAFEPOINT
}
}

static int count_union_components(jl_value_t **types, size_t n)
static int count_union_components(jl_value_t **types, size_t n, int widen)
{
size_t i, c = 0;
for (i = 0; i < n; i++) {
jl_value_t *e = types[i];
while (jl_is_uniontype(e)) {
jl_uniontype_t *u = (jl_uniontype_t*)e;
c += count_union_components(&u->a, 1);
c += count_union_components(&u->a, 1, widen);
e = u->b;
}
if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e);
c += count_union_components(&u->a, 2);
c += count_union_components(&u->a, 2, widen);
}
else {
c++;
Expand All @@ -506,21 +506,21 @@ static int count_union_components(jl_value_t **types, size_t n)
return c;
}

static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx)
static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx, int widen)
{
size_t i;
for (i = 0; i < n; i++) {
jl_value_t *e = types[i];
while (jl_is_uniontype(e)) {
jl_uniontype_t *u = (jl_uniontype_t*)e;
flatten_type_union(&u->a, 1, out, idx);
flatten_type_union(&u->a, 1, out, idx, widen);
e = u->b;
}
if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
// flatten this UnionAll into place by switching the union and unionall
jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e);
size_t old_idx = 0;
flatten_type_union(&u->a, 2, out, idx);
flatten_type_union(&u->a, 2, out, idx, widen);
for (; old_idx < *idx; old_idx++)
out[old_idx] = jl_rewrap_unionall(out[old_idx], e);
}
Expand Down Expand Up @@ -560,11 +560,11 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
if (n == 1)
return ts[0];

size_t nt = count_union_components(ts, n);
size_t nt = count_union_components(ts, n, 1);
jl_value_t **temp;
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
flatten_type_union(ts, n, temp, &count);
flatten_type_union(ts, n, temp, &count, 1);
assert(count == nt);
size_t j;
for (i = 0; i < nt; i++) {
Expand Down Expand Up @@ -641,14 +641,14 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)

jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
{
size_t nta = count_union_components(&a, 1);
size_t ntb = count_union_components(&b, 1);
size_t nta = count_union_components(&a, 1, 1);
size_t ntb = count_union_components(&b, 1, 1);
size_t nt = nta + ntb;
jl_value_t **temp;
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
flatten_type_union(&a, 1, temp, &count);
flatten_type_union(&b, 1, temp, &count);
flatten_type_union(&a, 1, temp, &count, 1);
flatten_type_union(&b, 1, temp, &count, 1);
assert(count == nt);
size_t i, j;
size_t ra = nta, rb = ntb;
Expand Down Expand Up @@ -717,6 +717,113 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
return tu;
}

int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity);

// Cheap disjointness test between `a` and a (possibly `Union`) type `b`.
//
// `hasfree` is supplied by the caller; the visible call site passes
// `jl_has_free_typevars(a)`.
//
// - If `b` is a `Union`, `a` must be disjoint from both halves.
// - If neither side has free type variables, defer to
//   `jl_has_empty_intersection`.
// - Otherwise fall back to `obviously_disjoint(a, b, 0)`.
//
// Returns non-zero when the pair is considered disjoint.
static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree)
{
    if (!jl_is_uniontype(b)) {
        // Leaf case: pick the check according to free type variables.
        if (hasfree || jl_has_free_typevars(b))
            return obviously_disjoint(a, b, 0);
        return jl_has_empty_intersection(a, b);
    }

    jl_uniontype_t *ub = (jl_uniontype_t *)b;
    jl_value_t *lhs = ub->a, *rhs = ub->b;
    JL_GC_PUSH2(&lhs, &rhs);
    int disjoint = 0;
    // The second half is only examined when the first is already disjoint.
    if (simple_disjoint(a, lhs, hasfree))
        disjoint = simple_disjoint(a, rhs, hasfree) != 0;
    JL_GC_POP();
    return disjoint;
}

// Compute a simple approximation of the meet of `a` and `b`, working on the
// flattened `Union` components of both inputs.
//
// `overesi` selects what happens when neither side can be shown to be covered
// by the other: non-zero asks for an over-estimate, zero returns
// `jl_bottom_type` (the original comment attributes this case to `merge_env`).
//
// Outline:
//   1. Flatten `a` and `b` into `temp[0..nta)` and `temp[nta..nt)`.
//   2. Drop every component that `simple_disjoint` reports as disjoint from
//      the whole opposite input.
//   3. Compare each surviving `a` component with each surviving `b` component
//      via `simple_subtype2` and record the relation in `rel[]`.
//   4. Return `a` or `b` unchanged when all of its components survived and are
//      marked `> 0`; otherwise build a `Union` of the selected components.
jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
{
    // Unlike `Union`, `UnionAll` is not unwrapped here (widen == 0) to avoid
    // possible widening.
    size_t nta = count_union_components(&a, 1, 0);
    size_t ntb = count_union_components(&b, 1, 0);
    size_t nt = nta + ntb;
    jl_value_t **temp;
    // One extra rooted slot: `temp[nt]` is used below as the accumulator.
    JL_GC_PUSHARGS(temp, nt+1);
    size_t count = 0;
    flatten_type_union(&a, 1, temp, &count, 0);
    flatten_type_union(&b, 1, temp, &count, 0);
    assert(count == nt);

    // Step 2: remove components disjoint from the opposite input.
    for (size_t n = 0; n < nt; n++) {
        jl_value_t *other = n < nta ? b : a;
        int freevars = jl_has_free_typevars(temp[n]);
        if (simple_disjoint(temp[n], other, freevars))
            temp[n] = NULL;
    }

    // Step 3: pairwise subtype classification.
    // rel[k] == -1 : ∃i temp[k] >:ₛ temp[i]
    // rel[k] ==  1 : ∃i temp[k] == temp[i]
    // rel[k] ==  2 : ∃i temp[k] <:ₛ temp[i]
    int8_t *rel = (int8_t *)alloca(count);
    memset(rel, 0, count);
    for (size_t ia = 0; ia < nta; ia++) {
        if (temp[ia] == NULL)
            continue;
        int afree = jl_has_free_typevars(temp[ia]);
        for (size_t ib = nta; ib < nt; ib++) {
            if (temp[ib] == NULL)
                continue;
            int r = simple_subtype2(temp[ia], temp[ib], afree || jl_has_free_typevars(temp[ib]));
            int a_in_b = r & 1;
            int b_in_a = r >> 1;
            if (a_in_b && !b_in_a) {
                // strict: temp[ia] below temp[ib]; -1 is sticky once set
                rel[ib] = -1;
                if (rel[ia] >= 0)
                    rel[ia] = 2;
            }
            else if (b_in_a && !a_in_b) {
                // strict: temp[ib] below temp[ia]
                rel[ia] = -1;
                if (rel[ib] >= 0)
                    rel[ib] = 2;
            }
            else if (r) {
                // both directions hold: only mark entries not yet classified
                if (rel[ia] == 0)
                    rel[ia] = 1;
                if (rel[ib] == 0)
                    rel[ib] = 1;
            }
        }
    }

    // Index 0 summarizes the `a` side, index 1 the `b` side.
    // covered[s]: every component is either removed or marked > 0.
    // intact[s] : every component survived and is marked > 0.
    int covered[2] = {1, 1}, intact[2] = {1, 1};
    for (size_t n = 0; n < nt; n++) {
        int side = n >= nta;
        int marked = rel[n] > 0;
        covered[side] &= (temp[n] == NULL || marked);
        intact[side] &= (temp[n] != NULL && marked);
    }

    int neither = !covered[0] && !covered[1];
    jl_value_t *early = NULL;
    if (intact[0])
        early = a;                  // a <: b
    else if (intact[1])
        early = b;                  // b <: a
    else if (!overesi && neither)
        early = jl_bottom_type;     // can't prove `<:` or `>:`; under-estimate
    if (early != NULL) {
        JL_GC_POP();
        return early;
    }

    // Select the slice of `temp` that forms the result.
    size_t lo = 0;
    if (covered[0])
        nt = nta;                   // only the `a` components
    else if (covered[1])
        lo = nta;                   // only the `b` components
    count = nt - lo;
    if (neither) {
        // Over-estimation: only preserve `a` components with strict <:,
        // but preserve `b` components without strict >:.
        for (size_t n = 0; n < nt; n++) {
            int threshold = n < nta ? 2 : 0;
            if (rel[n] < threshold)
                temp[n] = NULL;
        }
    }
    isort_union(&temp[lo], count);

    // Fold the remaining entries right-to-left into a nested Union, keeping
    // the partial result in the rooted slot `temp[nt]` across allocations.
    jl_value_t **acc = &temp[nt];
    *acc = jl_bottom_type;
    size_t k = nt;
    while (k > lo) {
        k--;
        if (temp[k] == NULL)
            continue;
        if (*acc == jl_bottom_type)
            *acc = temp[k];
        else
            *acc = jl_new_struct(jl_uniontype_type, temp[k], *acc);
    }
    assert(temp[nt] != NULL);
    jl_value_t *tu = temp[nt];
    JL_GC_POP();
    return tu;
}

// unionall types -------------------------------------------------------------

Expand Down
15 changes: 3 additions & 12 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ static int obviously_in_union(jl_value_t *u, jl_value_t *x)
return obviously_egal(u, x);
}

static int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
{
if (a == b || a == (jl_value_t*)jl_any_type || b == (jl_value_t*)jl_any_type)
return 0;
Expand Down Expand Up @@ -559,6 +559,7 @@ static jl_value_t *simple_join(jl_value_t *a, jl_value_t *b)
return simple_union(a, b);
}

jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi);
// Compute a greatest lower bound of `a` and `b`
// For the subtype path, we need to over-estimate this by returning `b` in many cases.
// But for `merge_env`, we'd better under-estimate and return a `Union{}`
Expand All @@ -570,10 +571,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
return a;
if (!(jl_is_type(a) || jl_is_typevar(a)) || !(jl_is_type(b) || jl_is_typevar(b)))
return jl_bottom_type;
if (jl_is_uniontype(a) && obviously_in_union(a, b))
return b;
if (jl_is_uniontype(b) && obviously_in_union(b, a))
return a;
if (jl_is_kind(a) && jl_is_type_type(b) && jl_typeof(jl_tparam0(b)) == a)
return b;
if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b)
Expand All @@ -582,13 +579,7 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
return a;
if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->ub))
return b;
if (obviously_disjoint(a, b, 0))
return jl_bottom_type;
if (!jl_has_free_typevars(a) && !jl_has_free_typevars(b)) {
if (jl_subtype(a, b)) return a;
if (jl_subtype(b, a)) return b;
}
return overesi ? b : jl_bottom_type;
return simple_intersect(a, b, overesi);
}

// main subtyping algorithm
Expand Down
8 changes: 8 additions & 0 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,14 @@ end
@test !<:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int16})
@test <:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int})

#issue #49354 (requires assertions enabled)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Val)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Val,Pair})
@test <:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Integer,Val})
@test <:(Tuple{Type{Union{Int, Int8}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Integer)
@test !<:(Tuple{Type{Union{Pair{Int, Any}, Pair{Int, Int}}}, Pair{Int, Any}},
Tuple{Type{Union{Pair{Int, Any}, T1}}, T1} where T1<:(Pair{T,T} where {T}))

let A = Tuple{Type{T}, T, Val{T}} where T,
B = Tuple{Type{S}, Val{S}, Val{S}} where S
@test_broken typeintersect(A, B) != Union{}
Expand Down