diff --git a/src/jltypes.c b/src/jltypes.c index 2e3a38d7df3ec..b66a312de1129 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -485,19 +485,19 @@ static int union_sort_cmp(jl_value_t *a, jl_value_t *b) JL_NOTSAFEPOINT } } -static int count_union_components(jl_value_t **types, size_t n) +static int count_union_components(jl_value_t **types, size_t n, int widen) { size_t i, c = 0; for (i = 0; i < n; i++) { jl_value_t *e = types[i]; while (jl_is_uniontype(e)) { jl_uniontype_t *u = (jl_uniontype_t*)e; - c += count_union_components(&u->a, 1); + c += count_union_components(&u->a, 1, widen); e = u->b; } - if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) { + if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) { jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e); - c += count_union_components(&u->a, 2); + c += count_union_components(&u->a, 2, widen); } else { c++; @@ -506,21 +506,21 @@ static int count_union_components(jl_value_t **types, size_t n) return c; } -static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx) +static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx, int widen) { size_t i; for (i = 0; i < n; i++) { jl_value_t *e = types[i]; while (jl_is_uniontype(e)) { jl_uniontype_t *u = (jl_uniontype_t*)e; - flatten_type_union(&u->a, 1, out, idx); + flatten_type_union(&u->a, 1, out, idx, widen); e = u->b; } - if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) { + if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) { // flatten this UnionAll into place by switching the union and unionall jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e); size_t old_idx = 0; - flatten_type_union(&u->a, 2, out, idx); + flatten_type_union(&u->a, 2, out, idx, widen); for (; old_idx < *idx; old_idx++) out[old_idx] = jl_rewrap_unionall(out[old_idx], e); } @@ -560,11 +560,11 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n) if (n == 1) return ts[0]; - size_t nt = count_union_components(ts, n); + size_t nt = count_union_components(ts, n, 1); jl_value_t **temp; JL_GC_PUSHARGS(temp, nt+1); size_t count = 0; - flatten_type_union(ts, n, temp, &count); + flatten_type_union(ts, n, temp, &count, 1); assert(count == nt); size_t j; for (i = 0; i < nt; i++) { @@ -641,14 +641,14 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree) jl_value_t *simple_union(jl_value_t *a, jl_value_t *b) { - size_t nta = count_union_components(&a, 1); - size_t ntb = count_union_components(&b, 1); + size_t nta = count_union_components(&a, 1, 1); + size_t ntb = count_union_components(&b, 1, 1); size_t nt = nta + ntb; jl_value_t **temp; JL_GC_PUSHARGS(temp, nt+1); size_t count = 0; - flatten_type_union(&a, 1, temp, &count); - flatten_type_union(&b, 1, temp, &count); + flatten_type_union(&a, 1, temp, &count, 1); + flatten_type_union(&b, 1, temp, &count, 1); assert(count == nt); size_t i, j; size_t ra = nta, rb = ntb; @@ -717,6 +717,113 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b) return tu; } +int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity); + +static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree) +{ + if (jl_is_uniontype(b)) { + jl_value_t *b1 = ((jl_uniontype_t *)b)->a, *b2 = ((jl_uniontype_t *)b)->b; + JL_GC_PUSH2(&b1, &b2); + int res = simple_disjoint(a, b1, hasfree) && simple_disjoint(a, b2, hasfree); + JL_GC_POP(); + return res; + } + if (!hasfree && !jl_has_free_typevars(b)) + return jl_has_empty_intersection(a, b); + return obviously_disjoint(a, b, 0); +} + +jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi) +{ + // Unlike `Union`, we don't unwrap `UnionAll` here to avoid possible widening. + size_t nta = count_union_components(&a, 1, 0); + size_t ntb = count_union_components(&b, 1, 0); + size_t nt = nta + ntb; + jl_value_t **temp; + JL_GC_PUSHARGS(temp, nt+1); + size_t count = 0; + flatten_type_union(&a, 1, temp, &count, 0); + flatten_type_union(&b, 1, temp, &count, 0); + assert(count == nt); + size_t i, j; + // first remove disjoint elements. + for (i = 0; i < nt; i++) { + if (simple_disjoint(temp[i], (i < nta ? b : a), jl_has_free_typevars(temp[i]))) + temp[i] = NULL; + } + // then check subtyping. + // stemp[k] == -1 : ∃i temp[k] >:ₛ temp[i] + // stemp[k] == 1 : ∃i temp[k] == temp[i] + // stemp[k] == 2 : ∃i temp[k] <:ₛ temp[i] + int8_t *stemp = (int8_t *)alloca(count); + memset(stemp, 0, count); + for (i = 0; i < nta; i++) { + if (temp[i] == NULL) continue; + int hasfree = jl_has_free_typevars(temp[i]); + for (j = nta; j < nt; j++) { + if (temp[j] == NULL) continue; + int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j])); + int subab = subs & 1, subba = subs >> 1; + if (subba && !subab) { + stemp[i] = -1; + if (stemp[j] >= 0) stemp[j] = 2; + } + else if (subab && !subba) { + stemp[j] = -1; + if (stemp[i] >= 0) stemp[i] = 2; + } + else if (subs) { + if (stemp[i] == 0) stemp[i] = 1; + if (stemp[j] == 0) stemp[j] = 1; + } + } + } + int subs[2] = {1, 1}, rs[2] = {1, 1}; + for (i = 0; i < nt; i++) { + subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0); + rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0); + } + // return a(b) if a(b) <: b(a) + if (rs[0]) { + JL_GC_POP(); + return a; + } + if (rs[1]) { + JL_GC_POP(); + return b; + } + // return `Union{}` for `merge_env` if we can't prove `<:` or `>:` + if (!overesi && !subs[0] && !subs[1]) { + JL_GC_POP(); + return jl_bottom_type; + } + nt = subs[0] ? nta : subs[1] ? nt : nt; + i = subs[0] ? 0 : subs[1] ? nta : 0; + count = nt - i; + if (!subs[0] && !subs[1]) { + // prepare for over estimation + // only preserve `a` with strict <:, but preserve `b` without strict >: + for (j = 0; j < nt; j++) { + if (stemp[j] < (j < nta ? 2 : 0)) + temp[j] = NULL; + } + } + isort_union(&temp[i], count); + temp[nt] = jl_bottom_type; + size_t k; + for (k = nt; k-- > i; ) { + if (temp[k] != NULL) { + if (temp[nt] == jl_bottom_type) + temp[nt] = temp[k]; + else + temp[nt] = jl_new_struct(jl_uniontype_type, temp[k], temp[nt]); + } + } + assert(temp[nt] != NULL); + jl_value_t *tu = temp[nt]; + JL_GC_POP(); + return tu; +} // unionall types ------------------------------------------------------------- diff --git a/src/subtype.c b/src/subtype.c index 518c566193b70..3115c5984eca5 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -447,7 +447,7 @@ static int obviously_in_union(jl_value_t *u, jl_value_t *x) return obviously_egal(u, x); } -static int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity) +int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity) { if (a == b || a == (jl_value_t*)jl_any_type || b == (jl_value_t*)jl_any_type) return 0; @@ -559,6 +559,7 @@ static jl_value_t *simple_join(jl_value_t *a, jl_value_t *b) return simple_union(a, b); } +jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi); // Compute a greatest lower bound of `a` and `b` // For the subtype path, we need to over-estimate this by returning `b` in many cases. // But for `merge_env`, we'd better under-estimate and return a `Union{}` @@ -570,10 +571,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi) return a; if (!(jl_is_type(a) || jl_is_typevar(a)) || !(jl_is_type(b) || jl_is_typevar(b))) return jl_bottom_type; - if (jl_is_uniontype(a) && obviously_in_union(a, b)) - return b; - if (jl_is_uniontype(b) && obviously_in_union(b, a)) - return a; if (jl_is_kind(a) && jl_is_type_type(b) && jl_typeof(jl_tparam0(b)) == a) return b; if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b) @@ -582,13 +579,7 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi) return a; if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->ub)) return b; - if (obviously_disjoint(a, b, 0)) - return jl_bottom_type; - if (!jl_has_free_typevars(a) && !jl_has_free_typevars(b)) { - if (jl_subtype(a, b)) return a; - if (jl_subtype(b, a)) return b; - } - return overesi ? b : jl_bottom_type; + return simple_intersect(a, b, overesi); } // main subtyping algorithm diff --git a/test/subtype.jl b/test/subtype.jl index aad1424a2d66c..b38588155ef64 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -2491,6 +2491,14 @@ end @test !<:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int16}) @test <:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int}) +#issue #49354 (requires assertions enabled) +@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Val) +@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Val,Pair}) +@test <:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Integer,Val}) +@test <:(Tuple{Type{Union{Int, Int8}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Integer) +@test !<:(Tuple{Type{Union{Pair{Int, Any}, Pair{Int, Int}}}, Pair{Int, Any}}, + Tuple{Type{Union{Pair{Int, Any}, T1}}, T1} where T1<:(Pair{T,T} where {T})) + let A = Tuple{Type{T}, T, Val{T}} where T, B = Tuple{Type{S}, Val{S}, Val{S}} where S @test_broken typeintersect(A, B) != Union{}