diff --git a/mypy/checker.py b/mypy/checker.py index b786155079e5..3bd9c494a890 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7216,22 +7216,32 @@ def is_unsafe_overlapping_overload_signatures( # # This discrepancy is unfortunately difficult to get rid of, so we repeat the # checks twice in both directions for now. + # + # Note that we ignore possible overlap between type variables and None. This + # is technically unsafe, but unsafety is tiny and this prevents some common + # use cases like: + # @overload + # def foo(x: None) -> None: .. + # @overload + # def foo(x: T) -> Foo[T]: ... return is_callable_compatible( signature, other, - is_compat=is_overlapping_types_no_promote_no_uninhabited, + is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none, is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), ignore_return=False, check_args_covariantly=True, allow_partial_overlap=True, + no_unify_none=True, ) or is_callable_compatible( other, signature, - is_compat=is_overlapping_types_no_promote_no_uninhabited, + is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none, is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), ignore_return=False, check_args_covariantly=False, allow_partial_overlap=True, + no_unify_none=True, ) @@ -7717,12 +7727,18 @@ def is_subtype_no_promote(left: Type, right: Type) -> bool: return is_subtype(left, right, ignore_promotions=True) -def is_overlapping_types_no_promote_no_uninhabited(left: Type, right: Type) -> bool: +def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool: # For the purpose of unsafe overload checks we consider list[] and list[int] # non-overlapping. This is consistent with how we treat list[int] and list[str] as # non-overlapping, despite [] belongs to both. Also this will prevent false positives # for failed type inference during unification. - return is_overlapping_types(left, right, ignore_promotions=True, ignore_uninhabited=True) + return is_overlapping_types( + left, + right, + ignore_promotions=True, + ignore_uninhabited=True, + prohibit_none_typevar_overlap=True, + ) def is_private(node_name: str) -> bool: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9e46d9ee39cb..3282ee338b87 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2400,6 +2400,11 @@ def check_overload_call( # typevar. See https://github.com/python/mypy/issues/4063 for related discussion. erased_targets: list[CallableType] | None = None unioned_result: tuple[Type, Type] | None = None + + # Determine whether we need to encourage union math. This should be generally safe, + # as union math infers better results in the vast majority of cases, but it is very + # computationally intensive. + none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets) union_interrupted = False # did we try all union combinations? if any(self.real_union(arg) for arg in arg_types): try: @@ -2412,6 +2417,7 @@ def check_overload_call( arg_names, callable_name, object_type, + none_type_var_overlap, context, ) except TooManyUnions: @@ -2444,8 +2450,10 @@ def check_overload_call( # If any of checks succeed, stop early. if inferred_result is not None and unioned_result is not None: # Both unioned and direct checks succeeded, choose the more precise type. - if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance( - get_proper_type(inferred_result[0]), AnyType + if ( + is_subtype(inferred_result[0], unioned_result[0]) + and not isinstance(get_proper_type(inferred_result[0]), AnyType) + and not none_type_var_overlap ): return inferred_result return unioned_result @@ -2495,7 +2503,8 @@ def check_overload_call( callable_name=callable_name, object_type=object_type, ) - if union_interrupted: + # Do not show the extra error if the union math was forced. + if union_interrupted and not none_type_var_overlap: self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result @@ -2650,6 +2659,44 @@ def overload_erased_call_targets( matches.append(typ) return matches + def possible_none_type_var_overlap( + self, arg_types: list[Type], plausible_targets: list[CallableType] + ) -> bool: + """Heuristic to determine whether we need to try forcing union math. + + This is needed to avoid greedy type variable match in situations like this: + @overload + def foo(x: None) -> None: ... + @overload + def foo(x: T) -> list[T]: ... + + x: int | None + foo(x) + we want this call to infer list[int] | None, not list[int | None]. + """ + if not plausible_targets or not arg_types: + return False + has_optional_arg = False + for arg_type in get_proper_types(arg_types): + if not isinstance(arg_type, UnionType): + continue + for item in get_proper_types(arg_type.items): + if isinstance(item, NoneType): + has_optional_arg = True + break + if not has_optional_arg: + return False + + min_prefix = min(len(c.arg_types) for c in plausible_targets) + for i in range(min_prefix): + if any( + isinstance(get_proper_type(c.arg_types[i]), NoneType) for c in plausible_targets + ) and any( + isinstance(get_proper_type(c.arg_types[i]), TypeVarType) for c in plausible_targets + ): + return True + return False + def union_overload_result( self, plausible_targets: list[CallableType], @@ -2659,6 +2706,7 @@ def union_overload_result( arg_names: Sequence[str | None] | None, callable_name: str | None, object_type: Type | None, + none_type_var_overlap: bool, context: Context, level: int = 0, ) -> list[tuple[Type, Type]] | None: @@ -2698,20 +2746,23 @@ def union_overload_result( # Step 3: Try a direct match before splitting to avoid unnecessary union splits # and save performance. - with self.type_overrides_set(args, arg_types): - direct = self.infer_overload_return_type( - plausible_targets, - args, - arg_types, - arg_kinds, - arg_names, - callable_name, - object_type, - context, - ) - if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)): - # We only return non-unions soon, to avoid greedy match. - return [direct] + if not none_type_var_overlap: + with self.type_overrides_set(args, arg_types): + direct = self.infer_overload_return_type( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) + if direct is not None and not isinstance( + get_proper_type(direct[0]), (UnionType, AnyType) + ): + # We only return non-unions soon, to avoid greedy match. + return [direct] # Step 4: Split the first remaining union type in arguments into items and # try to match each item individually (recursive). @@ -2729,6 +2780,7 @@ def union_overload_result( arg_names, callable_name, object_type, + none_type_var_overlap, context, level + 1, ) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5712d7375e50..da92f7398d4e 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1299,6 +1299,7 @@ def is_callable_compatible( check_args_covariantly: bool = False, allow_partial_overlap: bool = False, strict_concatenate: bool = False, + no_unify_none: bool = False, ) -> bool: """Is the left compatible with the right, using the provided compatibility check? @@ -1415,7 +1416,9 @@ def g(x: int) -> int: ... # (below) treats type variables on the two sides as independent. if left.variables: # Apply generic type variables away in left via type inference. - unified = unify_generic_callable(left, right, ignore_return=ignore_return) + unified = unify_generic_callable( + left, right, ignore_return=ignore_return, no_unify_none=no_unify_none + ) if unified is None: return False left = unified @@ -1427,7 +1430,9 @@ def g(x: int) -> int: ... # So, we repeat the above checks in the opposite direction. This also # lets us preserve the 'symmetry' property of allow_partial_overlap. if allow_partial_overlap and right.variables: - unified = unify_generic_callable(right, left, ignore_return=ignore_return) + unified = unify_generic_callable( + right, left, ignore_return=ignore_return, no_unify_none=no_unify_none + ) if unified is not None: right = unified @@ -1687,6 +1692,8 @@ def unify_generic_callable( target: NormalizedCallableType, ignore_return: bool, return_constraint_direction: int | None = None, + *, + no_unify_none: bool = False, ) -> NormalizedCallableType | None: """Try to unify a generic callable type with another callable type. @@ -1708,6 +1715,10 @@ def unify_generic_callable( type.ret_type, target.ret_type, return_constraint_direction ) constraints.extend(c) + if no_unify_none: + constraints = [ + c for c in constraints if not isinstance(get_proper_type(c.target), NoneType) + ] inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints) if None in inferred_vars: return None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 50acd7d77c8c..4910dfe05d31 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2185,36 +2185,63 @@ def bar2(*x: int) -> int: ... [builtins fixtures/tuple.pyi] [case testOverloadDetectsPossibleMatchesWithGenerics] -from typing import overload, TypeVar, Generic +# flags: --strict-optional +from typing import overload, TypeVar, Generic, Optional, List T = TypeVar('T') +# The examples below are unsafe, but it is a quite common pattern +# so we ignore the possibility of type variables taking value `None` +# for the purpose of overload overlap checks. @overload -def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def foo(x: None, y: None) -> str: ... @overload def foo(x: T, y: T) -> int: ... def foo(x): ... +oi: Optional[int] +reveal_type(foo(None, None)) # N: Revealed type is "builtins.str" +reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]" +reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]" + +@overload +def foo_list(x: None) -> None: ... +@overload +def foo_list(x: T) -> List[T]: ... +def foo_list(x): ... + +reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + # What if 'T' is 'object'? @overload -def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def bar(x: None, y: int) -> str: ... @overload def bar(x: T, y: T) -> int: ... def bar(x, y): ... class Wrapper(Generic[T]): @overload - def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def foo(self, x: None, y: None) -> str: ... @overload def foo(self, x: T, y: None) -> int: ... def foo(self, x): ... @overload - def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def bar(self, x: None, y: int) -> str: ... @overload def bar(self, x: T, y: T) -> int: ... def bar(self, x, y): ... +@overload +def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def baz(x: T, y: T) -> int: ... +def baz(x): ... +[builtins fixtures/tuple.pyi] + [case testOverloadFlagsPossibleMatches] from wrapper import * [file wrapper.pyi] @@ -3996,7 +4023,7 @@ T = TypeVar('T') class FakeAttribute(Generic[T]): @overload - def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... @overload def dummy(self, instance: T, owner: Type[T]) -> int: ... def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ...