Skip to content

Commit

Permalink
Subtype: Improve simple_meet resolution for Union inputs (JuliaLa…
Browse files Browse the repository at this point in the history
…ng#49376)

* Improve `simple_meet` resolution.

* Fix for many-to-one cases.

* Test disjoint via `jl_has_empty_intersection`
  • Loading branch information
N5N3 committed Apr 24, 2023
1 parent 24bda32 commit eb4f64b
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 17 deletions.
138 changes: 138 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,144 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
return tu;
}

static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)
{
int subab = 0, subba = 0;
if (jl_egal(a, b)) {
subab = subba = 1;
}
else if (a == jl_bottom_type || b == (jl_value_t*)jl_any_type) {
subab = 1;
}
else if (b == jl_bottom_type || a == (jl_value_t*)jl_any_type) {
subba = 1;
}
else if (hasfree) {
// subab = simple_subtype(a, b);
// subba = simple_subtype(b, a);
}
else if (jl_is_type_type(a) && jl_is_type_type(b) &&
jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b))) {
// issue #24521: don't merge Type{T} where typeof(T) varies
}
else if (jl_typeof(a) == jl_typeof(b) && jl_types_egal(a, b)) {
subab = subba = 1;
}
else {
subab = jl_subtype(a, b);
subba = jl_subtype(b, a);
}
return subab | (subba<<1);
}

int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity);

static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree)
{
if (jl_is_uniontype(b)) {
jl_value_t *b1 = ((jl_uniontype_t *)b)->a, *b2 = ((jl_uniontype_t *)b)->b;
JL_GC_PUSH2(&b1, &b2);
int res = simple_disjoint(a, b1, hasfree) && simple_disjoint(a, b2, hasfree);
JL_GC_POP();
return res;
}
if (!hasfree && !jl_has_free_typevars(b))
return jl_has_empty_intersection(a, b);
return obviously_disjoint(a, b, 0);
}

jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
{
// Unlike `Union`, we don't unwrap `UnionAll` here to avoid possible widening.
size_t nta = count_union_components(&a, 1);
size_t ntb = count_union_components(&b, 1);
size_t nt = nta + ntb;
jl_value_t **temp;
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
flatten_type_union(&a, 1, temp, &count);
flatten_type_union(&b, 1, temp, &count);
assert(count == nt);
size_t i, j;
// first remove disjoint elements.
for (i = 0; i < nt; i++) {
if (simple_disjoint(temp[i], (i < nta ? b : a), jl_has_free_typevars(temp[i])))
temp[i] = NULL;
}
// then check subtyping.
// stemp[k] == -1 : ∃i temp[k] >:ₛ temp[i]
// stemp[k] == 1 : ∃i temp[k] == temp[i]
// stemp[k] == 2 : ∃i temp[k] <:ₛ temp[i]
int8_t *stemp = (int8_t *)alloca(count);
memset(stemp, 0, count);
for (i = 0; i < nta; i++) {
if (temp[i] == NULL) continue;
int hasfree = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (temp[j] == NULL) continue;
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
int subab = subs & 1, subba = subs >> 1;
if (subba && !subab) {
stemp[i] = -1;
if (stemp[j] >= 0) stemp[j] = 2;
}
else if (subab && !subba) {
stemp[j] = -1;
if (stemp[i] >= 0) stemp[i] = 2;
}
else if (subs) {
if (stemp[i] == 0) stemp[i] = 1;
if (stemp[j] == 0) stemp[j] = 1;
}
}
}
int subs[2] = {1, 1}, rs[2] = {1, 1};
for (i = 0; i < nt; i++) {
subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0);
rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0);
}
// return a(b) if a(b) <: b(a)
if (rs[0]) {
JL_GC_POP();
return a;
}
if (rs[1]) {
JL_GC_POP();
return b;
}
// return `Union{}` for `merge_env` if we can't prove `<:` or `>:`
if (!overesi && !subs[0] && !subs[1]) {
JL_GC_POP();
return jl_bottom_type;
}
nt = subs[0] ? nta : subs[1] ? nt : nt;
i = subs[0] ? 0 : subs[1] ? nta : 0;
count = nt - i;
if (!subs[0] && !subs[1]) {
// prepare for over estimation
// only preserve `a` with strict <:, but preserve `b` without strict >:
for (j = 0; j < nt; j++) {
if (stemp[j] < (j < nta ? 2 : 0))
temp[j] = NULL;
}
}
isort_union(&temp[i], count);
temp[nt] = jl_bottom_type;
size_t k;
for (k = nt; k-- > i; ) {
if (temp[k] != NULL) {
if (temp[nt] == jl_bottom_type)
temp[nt] = temp[k];
else
temp[nt] = jl_new_struct(jl_uniontype_type, temp[k], temp[nt]);
}
}
assert(temp[nt] != NULL);
jl_value_t *tu = temp[nt];
JL_GC_POP();
return tu;
}

// unionall types -------------------------------------------------------------

JL_DLLEXPORT jl_value_t *jl_type_unionall(jl_tvar_t *v, jl_value_t *body)
Expand Down
26 changes: 9 additions & 17 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ static int in_union(jl_value_t *u, jl_value_t *x) JL_NOTSAFEPOINT
return in_union(((jl_uniontype_t*)u)->a, x) || in_union(((jl_uniontype_t*)u)->b, x);
}

static int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
{
if (a == b || a == (jl_value_t*)jl_any_type || b == (jl_value_t*)jl_any_type)
return 0;
Expand Down Expand Up @@ -479,20 +479,18 @@ static jl_value_t *simple_join(jl_value_t *a, jl_value_t *b)
return jl_new_struct(jl_uniontype_type, a, b);
}

// compute a greatest lower bound of `a` and `b`
// in many cases, we need to over-estimate this by returning `b`.
static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b)
jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi);
// Compute a greatest lower bound of `a` and `b`
// For the subtype path, we need to over-estimate this by returning `b` in many cases.
// But for `merge_env`, we'd better under-estimate and return a `Union{}`
static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
{
if (a == (jl_value_t*)jl_any_type || b == jl_bottom_type || obviously_egal(a,b))
return b;
if (b == (jl_value_t*)jl_any_type || a == jl_bottom_type)
return a;
if (!(jl_is_type(a) || jl_is_typevar(a)) || !(jl_is_type(b) || jl_is_typevar(b)))
return jl_bottom_type;
if (jl_is_uniontype(a) && in_union(a, b))
return b;
if (jl_is_uniontype(b) && in_union(b, a))
return a;
if (jl_is_kind(a) && jl_is_type_type(b) && jl_typeof(jl_tparam0(b)) == a)
return b;
if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b)
Expand All @@ -501,13 +499,7 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b)
return a;
if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->ub))
return b;
if (obviously_disjoint(a, b, 0))
return jl_bottom_type;
if (!jl_has_free_typevars(a) && !jl_has_free_typevars(b)) {
if (jl_subtype(a, b)) return a;
if (jl_subtype(b, a)) return b;
}
return b;
return simple_intersect(a, b, overesi);
}

static jl_unionall_t *rename_unionall(jl_unionall_t *u)
Expand Down Expand Up @@ -652,7 +644,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
JL_GC_POP();
}
else {
bb->ub = simple_meet(bb->ub, a);
bb->ub = simple_meet(bb->ub, a, 1);
}
assert(bb->ub != (jl_value_t*)b);
if (jl_is_typevar(a)) {
Expand Down Expand Up @@ -3303,7 +3295,7 @@ static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int co
while (v != NULL) {
b1 = jl_svecref(*root, n);
b2 = v->lb;
jl_svecset(*root, n, simple_meet(b1, b2));
jl_svecset(*root, n, simple_meet(b1, b2, 0));
b1 = jl_svecref(*root, n+1);
b2 = v->ub;
jl_svecset(*root, n+1, simple_join(b1, b2));
Expand Down
8 changes: 8 additions & 0 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2331,3 +2331,11 @@ end
# requires assertions enabled (to test union-split in `obviously_disjoint`)
@test !<:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int16})
@test <:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int})

#issue #49354 (requires assertions enabled)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Val)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Val,Pair})
@test <:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Integer,Val})
@test <:(Tuple{Type{Union{Int, Int8}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Integer)
@test !<:(Tuple{Type{Union{Pair{Int, Any}, Pair{Int, Int}}}, Pair{Int, Any}},
Tuple{Type{Union{Pair{Int, Any}, T1}}, T1} where T1<:(Pair{T,T} where {T}))

0 comments on commit eb4f64b

Please sign in to comment.