From 2aaeda4b84a863004a6694a7d562462fbe531ece Mon Sep 17 00:00:00 2001 From: EXPLOSION Date: Wed, 9 Aug 2023 15:17:13 +0900 Subject: [PATCH] Reconsider constraints involving parameter specifications (#15272) - Fixes https://github.com/python/mypy/issues/15037 - Fixes https://github.com/python/mypy/issues/15065 - Fixes https://github.com/python/mypy/issues/15073 - Fixes https://github.com/python/mypy/issues/15388 - Fixes https://github.com/python/mypy/issues/15086 Yet another part of https://github.com/python/mypy/pull/14903 that's finally been extracted! --- mypy/constraints.py | 129 ++++++++++++++---- mypy/test/testconstraints.py | 62 +++++++++ mypy/test/typefixture.py | 42 ++++++ .../unit/check-parameter-specification.test | 32 ++++- 4 files changed, 241 insertions(+), 24 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 299c6292a259..9c55b56dd70e 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -82,15 +82,19 @@ def __repr__(self) -> str: op_str = "<:" if self.op == SUPERTYPE_OF: op_str = ":>" - return f"{self.type_var} {op_str} {self.target}" + return f"{self.origin_type_var} {op_str} {self.target}" def __hash__(self) -> int: - return hash((self.type_var, self.op, self.target)) + return hash((self.origin_type_var, self.op, self.target)) def __eq__(self, other: object) -> bool: if not isinstance(other, Constraint): return False - return (self.type_var, self.op, self.target) == (other.type_var, other.op, other.target) + return (self.origin_type_var, self.op, self.target) == ( + other.origin_type_var, + other.op, + other.target, + ) def infer_constraints_for_callable( @@ -698,25 +702,54 @@ def visit_instance(self, template: Instance) -> list[Constraint]: ) elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType): suffix = get_proper_type(instance_arg) + prefix = mapped_arg.prefix + length = len(prefix.arg_types) if isinstance(suffix, CallableType): - prefix = mapped_arg.prefix from_concat = bool(prefix.arg_types) or suffix.from_concatenate suffix = suffix.copy_modified(from_concatenate=from_concat) if isinstance(suffix, (Parameters, CallableType)): # no such thing as variance for ParamSpecs # TODO: is there a case I am missing? - # TODO: constraints between prefixes - prefix = mapped_arg.prefix - suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types) :], - suffix.arg_kinds[len(prefix.arg_kinds) :], - suffix.arg_names[len(prefix.arg_names) :], + length = min(length, len(suffix.arg_types)) + + constrained_to = suffix.copy_modified( + suffix.arg_types[length:], + suffix.arg_kinds[length:], + suffix.arg_names[length:], + ) + constrained_from = mapped_arg.copy_modified( + prefix=prefix.copy_modified( + prefix.arg_types[length:], + prefix.arg_kinds[length:], + prefix.arg_names[length:], + ) ) - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + + res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to)) + res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to)) elif isinstance(suffix, ParamSpecType): - res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix)) + suffix_prefix = suffix.prefix + length = min(length, len(suffix_prefix.arg_types)) + + constrained = suffix.copy_modified( + prefix=suffix_prefix.copy_modified( + suffix_prefix.arg_types[length:], + suffix_prefix.arg_kinds[length:], + suffix_prefix.arg_names[length:], + ) + ) + constrained_from = mapped_arg.copy_modified( + prefix=prefix.copy_modified( + prefix.arg_types[length:], + prefix.arg_kinds[length:], + prefix.arg_names[length:], + ) + ) + + res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained)) + res.append(Constraint(constrained_from, SUBTYPE_OF, constrained)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -768,26 +801,56 @@ def visit_instance(self, template: Instance) -> list[Constraint]: template_arg, ParamSpecType ): suffix = get_proper_type(mapped_arg) + prefix = template_arg.prefix + length = len(prefix.arg_types) if isinstance(suffix, CallableType): prefix = template_arg.prefix from_concat = bool(prefix.arg_types) or suffix.from_concatenate suffix = suffix.copy_modified(from_concatenate=from_concat) + # TODO: this is almost a copy-paste of code above: make this into a function if isinstance(suffix, (Parameters, CallableType)): # no such thing as variance for ParamSpecs # TODO: is there a case I am missing? - # TODO: constraints between prefixes - prefix = template_arg.prefix + length = min(length, len(suffix.arg_types)) - suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types) :], - suffix.arg_kinds[len(prefix.arg_kinds) :], - suffix.arg_names[len(prefix.arg_names) :], + constrained_to = suffix.copy_modified( + suffix.arg_types[length:], + suffix.arg_kinds[length:], + suffix.arg_names[length:], ) - res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) + constrained_from = template_arg.copy_modified( + prefix=prefix.copy_modified( + prefix.arg_types[length:], + prefix.arg_kinds[length:], + prefix.arg_names[length:], + ) + ) + + res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to)) + res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to)) elif isinstance(suffix, ParamSpecType): - res.append(Constraint(template_arg, SUPERTYPE_OF, suffix)) + suffix_prefix = suffix.prefix + length = min(length, len(suffix_prefix.arg_types)) + + constrained = suffix.copy_modified( + prefix=suffix_prefix.copy_modified( + suffix_prefix.arg_types[length:], + suffix_prefix.arg_kinds[length:], + suffix_prefix.arg_names[length:], + ) + ) + constrained_from = template_arg.copy_modified( + prefix=prefix.copy_modified( + prefix.arg_types[length:], + prefix.arg_kinds[length:], + prefix.arg_names[length:], + ) + ) + + res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained)) + res.append(Constraint(constrained_from, SUBTYPE_OF, constrained)) else: # This case should have been handled above. assert not isinstance(tvar, TypeVarTupleType) @@ -954,9 +1017,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: prefix_len = len(prefix.arg_types) cactual_ps = cactual.param_spec() + cactual_prefix: Parameters | CallableType + if cactual_ps: + cactual_prefix = cactual_ps.prefix + else: + cactual_prefix = cactual + + max_prefix_len = len( + [k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)] + ) + prefix_len = min(prefix_len, max_prefix_len) + + # we could check the prefixes match here, but that should be caught elsewhere. if not cactual_ps: - max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)]) - prefix_len = min(prefix_len, max_prefix_len) res.append( Constraint( param_spec, @@ -970,7 +1043,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) ) else: - res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps)) + # earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed + assert isinstance(cactual_prefix, Parameters) + + constrained_by = cactual_ps.copy_modified( + prefix=cactual_prefix.copy_modified( + cactual_prefix.arg_types[prefix_len:], + cactual_prefix.arg_kinds[prefix_len:], + cactual_prefix.arg_names[prefix_len:], + ) + ) + res.append(Constraint(param_spec, SUBTYPE_OF, constrained_by)) # compare prefixes cactual_prefix = cactual.copy_modified( diff --git a/mypy/test/testconstraints.py b/mypy/test/testconstraints.py index f40996145cba..be1d435f9cca 100644 --- a/mypy/test/testconstraints.py +++ b/mypy/test/testconstraints.py @@ -156,3 +156,65 @@ def test_var_length_tuple_with_fixed_length_tuple(self) -> None: Instance(fx.std_tuplei, [fx.a]), SUPERTYPE_OF, ) + + def test_paramspec_constrained_with_concatenate(self) -> None: + # for legibility (and my own understanding), `Tester.normal()` is `Tester[P]` + # and `Tester.concatenate()` is `Tester[Concatenate[A, P]]` + # ... and 2nd arg to infer_constraints ends up on LHS of equality + fx = self.fx + + # I don't think we can parametrize... + for direction in (SUPERTYPE_OF, SUBTYPE_OF): + print(f"direction is {direction}") + # equiv to: x: Tester[Q] = Tester.normal() + assert set( + infer_constraints(Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q]), direction) + ) == { + Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q), + Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q), + } + + # equiv to: x: Tester[Q] = Tester.concatenate() + assert set( + infer_constraints( + Instance(fx.gpsi, [fx.p_concatenate]), Instance(fx.gpsi, [fx.q]), direction + ) + ) == { + Constraint(type_var=fx.p_concatenate, op=SUPERTYPE_OF, target=fx.q), + Constraint(type_var=fx.p_concatenate, op=SUBTYPE_OF, target=fx.q), + } + + # equiv to: x: Tester[Concatenate[B, Q]] = Tester.normal() + assert set( + infer_constraints( + Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q_concatenate]), direction + ) + ) == { + Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q_concatenate), + Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q_concatenate), + } + + # equiv to: x: Tester[Concatenate[B, Q]] = Tester.concatenate() + assert set( + infer_constraints( + Instance(fx.gpsi, [fx.p_concatenate]), + Instance(fx.gpsi, [fx.q_concatenate]), + direction, + ) + ) == { + # this is correct as we assume other parts of mypy will warn that [B] != [A] + Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q), + Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q), + } + + # equiv to: x: Tester[Concatenate[A, Q]] = Tester.concatenate() + assert set( + infer_constraints( + Instance(fx.gpsi, [fx.p_concatenate]), + Instance(fx.gpsi, [fx.q_concatenate]), + direction, + ) + ) == { + Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q), + Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q), + } diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index bf1500a3cdec..df78eeb62956 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -5,6 +5,8 @@ from __future__ import annotations +from typing import Sequence + from mypy.nodes import ( ARG_OPT, ARG_POS, @@ -26,6 +28,9 @@ Instance, LiteralType, NoneType, + Parameters, + ParamSpecFlavor, + ParamSpecType, Type, TypeAliasType, TypeOfAny, @@ -238,6 +243,31 @@ def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleTy "GV2", mro=[self.oi], typevars=["T", "Ts", "S"], typevar_tuple_index=1 ) + def make_parameter_specification( + name: str, id: int, concatenate: Sequence[Type] + ) -> ParamSpecType: + return ParamSpecType( + name, + name, + id, + ParamSpecFlavor.BARE, + self.o, + AnyType(TypeOfAny.from_omitted_generics), + prefix=Parameters( + concatenate, [ARG_POS for _ in concatenate], [None for _ in concatenate] + ), + ) + + self.p = make_parameter_specification("P", 1, []) + self.p_concatenate = make_parameter_specification("P", 1, [self.a]) + self.q = make_parameter_specification("Q", 2, []) + self.q_concatenate = make_parameter_specification("Q", 2, [self.b]) + self.q_concatenate_a = make_parameter_specification("Q", 2, [self.a]) + + self.gpsi = self.make_type_info( + "GPS", mro=[self.oi], typevars=["P"], paramspec_indexes={0} + ) + def _add_bool_dunder(self, type_info: TypeInfo) -> None: signature = CallableType([], [], [], Instance(self.bool_type_info, []), self.function) bool_func = FuncDef("__bool__", [], Block([])) @@ -299,6 +329,7 @@ def make_type_info( bases: list[Instance] | None = None, typevars: list[str] | None = None, typevar_tuple_index: int | None = None, + paramspec_indexes: set[int] | None = None, variances: list[int] | None = None, ) -> TypeInfo: """Make a TypeInfo suitable for use in unit tests.""" @@ -326,6 +357,17 @@ def make_type_info( AnyType(TypeOfAny.from_omitted_generics), ) ) + elif paramspec_indexes is not None and id - 1 in paramspec_indexes: + v.append( + ParamSpecType( + n, + n, + id, + ParamSpecFlavor.BARE, + self.o, + AnyType(TypeOfAny.from_omitted_generics), + ) + ) else: if variances: variance = variances[id - 1] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 114fe1f8438a..f11b9aa599ed 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -776,7 +776,7 @@ _P = ParamSpec("_P") class Job(Generic[_P]): def __init__(self, target: Callable[_P, None]) -> None: - self.target = target + ... def func( action: Union[Job[int], Callable[[int], None]], @@ -1535,6 +1535,36 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ... def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] +[case testComplicatedParamSpecReturnType] +# regression test for https://github.com/python/mypy/issues/15073 +from typing import TypeVar, Callable +from typing_extensions import ParamSpec, Concatenate + +R = TypeVar("R") +P = ParamSpec("P") + +def f( +) -> Callable[[Callable[Concatenate[Callable[P, R], P], R]], Callable[P, R]]: + def r(fn: Callable[Concatenate[Callable[P, R], P], R]) -> Callable[P, R]: ... + return r +[builtins fixtures/paramspec.pyi] + +[case testParamSpecToParamSpecAssignment] +# minimized from https://github.com/python/mypy/issues/15037 +# ~ the same as https://github.com/python/mypy/issues/15065 +from typing import Callable +from typing_extensions import Concatenate, ParamSpec + +P = ParamSpec("P") + +def f(f: Callable[Concatenate[int, P], None]) -> Callable[P, None]: ... + +x: Callable[ + [Callable[Concatenate[int, P], None]], + Callable[P, None], +] = f +[builtins fixtures/paramspec.pyi] + [case testParamSpecDecoratorAppliedToGeneric] # flags: --new-type-inference from typing import Callable, List, TypeVar