diff --git a/mypy/applytype.py b/mypy/applytype.py index e4a364ae383f..2bc2fa92f7dc 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -4,11 +4,55 @@ import mypy.sametypes from mypy.expandtype import expand_type from mypy.types import ( - Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types + Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types, + TypeVarDef, TypeVarLikeDef, ProperType ) from mypy.nodes import Context +def get_target_type( + tvar: TypeVarLikeDef, + type: ProperType, + callable: CallableType, + report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], + context: Context, + skip_unsatisfied: bool +) -> Optional[Type]: + # TODO(shantanu): fix for ParamSpecDef + assert isinstance(tvar, TypeVarDef) + values = get_proper_types(tvar.values) + if values: + if isinstance(type, AnyType): + return type + if isinstance(type, TypeVarType) and type.values: + # Allow substituting T1 for T if every allowed value of T1 + # is also a legal value of T. + if all(any(mypy.sametypes.is_same_type(v, v1) for v in values) + for v1 in type.values): + return type + matching = [] + for value in values: + if mypy.subtypes.is_subtype(type, value): + matching.append(value) + if matching: + best = matching[0] + # If there are more than one matching value, we select the narrowest + for match in matching[1:]: + if mypy.subtypes.is_subtype(match, best): + best = match + return best + if skip_unsatisfied: + return None + report_incompatible_typevar_value(callable, type, tvar.name, context) + else: + upper_bound = tvar.upper_bound + if not mypy.subtypes.is_subtype(type, upper_bound): + if skip_unsatisfied: + return None + report_incompatible_typevar_value(callable, type, tvar.name, context) + return type + + def apply_generic_arguments( callable: CallableType, orig_types: Sequence[Optional[Type]], report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], @@ -29,52 +73,20 @@ def apply_generic_arguments( # Check that inferred type variable values are compatible with allowed # values and bounds. Also, promote subtype values to allowed values. types = get_proper_types(orig_types) - for i, type in enumerate(types): + + # Create a map from type variable id to target type. + id_to_type = {} # type: Dict[TypeVarId, Type] + + for tvar, type in zip(tvars, types): assert not isinstance(type, PartialType), "Internal error: must never apply partial type" - values = get_proper_types(callable.variables[i].values) if type is None: continue - if values: - if isinstance(type, AnyType): - continue - if isinstance(type, TypeVarType) and type.values: - # Allow substituting T1 for T if every allowed value of T1 - # is also a legal value of T. - if all(any(mypy.sametypes.is_same_type(v, v1) for v in values) - for v1 in type.values): - continue - matching = [] - for value in values: - if mypy.subtypes.is_subtype(type, value): - matching.append(value) - if matching: - best = matching[0] - # If there are more than one matching value, we select the narrowest - for match in matching[1:]: - if mypy.subtypes.is_subtype(match, best): - best = match - types[i] = best - else: - if skip_unsatisfied: - types[i] = None - else: - report_incompatible_typevar_value(callable, type, callable.variables[i].name, - context) - else: - upper_bound = callable.variables[i].upper_bound - if not mypy.subtypes.is_subtype(type, upper_bound): - if skip_unsatisfied: - types[i] = None - else: - report_incompatible_typevar_value(callable, type, callable.variables[i].name, - context) - # Create a map from type variable id to target type. - id_to_type = {} # type: Dict[TypeVarId, Type] - for i, tv in enumerate(tvars): - typ = types[i] - if typ: - id_to_type[tv.id] = typ + target_type = get_target_type( + tvar, type, callable, report_incompatible_typevar_value, context, skip_unsatisfied + ) + if target_type is not None: + id_to_type[tvar.id] = target_type # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] diff --git a/mypy/checker.py b/mypy/checker.py index d52441c8332e..7085b03b5b38 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1389,15 +1389,14 @@ def expand_typevars(self, defn: FuncItem, typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator subst = [] # type: List[List[Tuple[TypeVarId, Type]]] - tvars = typ.variables or [] - tvars = tvars[:] + tvars = list(typ.variables) or [] if defn.info: # Class type variables tvars += defn.info.defn.type_vars or [] + # TODO(shantanu): audit for paramspec for tvar in tvars: - if tvar.values: - subst.append([(tvar.id, value) - for value in tvar.values]) + if isinstance(tvar, TypeVarDef) and tvar.values: + subst.append([(tvar.id, value) for value in tvar.values]) # Make a copy of the function to check for each combination of # value restricted type variables. (Except when running mypyc, # where we need one canonical version of the function.) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e6cde8cf370b..a44bc3f7ed20 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4311,6 +4311,8 @@ def merge_typevars_in_callables_by_name( for tvdef in target.variables: name = tvdef.fullname if name not in unique_typevars: + # TODO(shantanu): fix for ParamSpecDef + assert isinstance(tvdef, TypeVarDef) unique_typevars[name] = TypeVarType(tvdef) variables.append(tvdef) rename[tvdef.id] = unique_typevars[name] diff --git a/mypy/checkmember.py b/mypy/checkmember.py index c9a5a2c86d97..b9221958dacb 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1,11 +1,11 @@ """Type checking of attribute access""" -from typing import cast, Callable, Optional, Union, List +from typing import cast, Callable, Optional, Union, Sequence from typing_extensions import TYPE_CHECKING from mypy.types import ( - Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef, - Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, + Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, + TypeVarLikeDef, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType ) from mypy.nodes import ( @@ -676,7 +676,7 @@ def analyze_class_attribute_access(itype: Instance, name: str, mx: MemberContext, override_info: Optional[TypeInfo] = None, - original_vars: Optional[List[TypeVarDef]] = None + original_vars: Optional[Sequence[TypeVarLikeDef]] = None ) -> Optional[Type]: """Analyze access to an attribute on a class object. @@ -839,7 +839,7 @@ def analyze_enum_class_attribute_access(itype: Instance, def add_class_tvars(t: ProperType, isuper: Optional[Instance], is_classmethod: bool, original_type: Type, - original_vars: Optional[List[TypeVarDef]] = None) -> Type: + original_vars: Optional[Sequence[TypeVarLikeDef]] = None) -> Type: """Instantiate type variables during analyze_class_attribute_access, e.g T and Q in the following: @@ -883,7 +883,7 @@ class B(A[str]): pass assert isuper is not None t = cast(CallableType, expand_type_by_instance(t, isuper)) freeze_type_vars(t) - return t.copy_modified(variables=tvars + t.variables) + return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): return Overloaded([cast(CallableType, add_class_tvars(item, isuper, is_classmethod, original_type, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index b805f3c0be83..2e3db6b109a4 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -40,6 +40,8 @@ def freshen_function_type_vars(callee: F) -> F: tvdefs = [] tvmap = {} # type: Dict[TypeVarId, Type] for v in callee.variables: + # TODO(shantanu): fix for ParamSpecDef + assert isinstance(v, TypeVarDef) tvdef = TypeVarDef.new_unification_variable(v) tvdefs.append(tvdef) tvmap[v.id] = TypeVarType(tvdef) diff --git a/mypy/fixup.py b/mypy/fixup.py index 023df1e31331..30e1a0dae2b9 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -11,7 +11,8 @@ from mypy.types import ( CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, - TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny) + TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, TypeVarDef +) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -183,10 +184,11 @@ def visit_callable_type(self, ct: CallableType) -> None: if ct.ret_type is not None: ct.ret_type.accept(self) for v in ct.variables: - if v.values: - for val in v.values: - val.accept(self) - v.upper_bound.accept(self) + if isinstance(v, TypeVarDef): + if v.values: + for val in v.values: + val.accept(self) + v.upper_bound.accept(self) for arg in ct.bound_args: if arg: arg.accept(self) diff --git a/mypy/messages.py b/mypy/messages.py index 3b24cf6001aa..6a4f2a50457c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -21,7 +21,7 @@ from mypy.errors import Errors from mypy.types import ( Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, - UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, + UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, TypeVarDef, UninhabitedType, TypeOfAny, UnboundType, PartialType, get_proper_type, ProperType, get_proper_types ) @@ -1862,16 +1862,20 @@ def [T <: int] f(self, x: int, y: T) -> None if tp.variables: tvars = [] for tvar in tp.variables: - upper_bound = get_proper_type(tvar.upper_bound) - if (isinstance(upper_bound, Instance) and - upper_bound.type.fullname != 'builtins.object'): - tvars.append('{} <: {}'.format(tvar.name, format_type_bare(upper_bound))) - elif tvar.values: - tvars.append('{} in ({})' - .format(tvar.name, ', '.join([format_type_bare(tp) - for tp in tvar.values]))) + if isinstance(tvar, TypeVarDef): + upper_bound = get_proper_type(tvar.upper_bound) + if (isinstance(upper_bound, Instance) and + upper_bound.type.fullname != 'builtins.object'): + tvars.append('{} <: {}'.format(tvar.name, format_type_bare(upper_bound))) + elif tvar.values: + tvars.append('{} in ({})' + .format(tvar.name, ', '.join([format_type_bare(tp) + for tp in tvar.values]))) + else: + tvars.append(tvar.name) else: - tvars.append(tvar.name) + # For other TypeVarLikeDefs, just use the repr + tvars.append(repr(tvar)) s = '[{}] {}'.format(', '.join(tvars), s) return 'def {}'.format(s) diff --git a/mypy/nodes.py b/mypy/nodes.py index 9af5dc5f75cc..dff1dd6c1072 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2079,7 +2079,7 @@ class TypeVarExpr(TypeVarLikeExpr): This is also used to represent type variables in symbol tables. - A type variable is not valid as a type unless bound in a TypeVarScope. + A type variable is not valid as a type unless bound in a TypeVarLikeScope. That happens within: 1. a generic class that uses the type variable as a type argument or diff --git a/mypy/plugin.py b/mypy/plugin.py index ed2d80cfaf29..eb31878b62a7 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -126,7 +126,7 @@ class C: pass from mypy.nodes import ( Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.types import Type, Instance, CallableType, TypeList, UnboundType, ProperType from mypy.messages import MessageBuilder from mypy.options import Options @@ -265,7 +265,7 @@ def fail(self, msg: str, ctx: Context, serious: bool = False, *, @abstractmethod def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, report_invalid_types: bool = True, diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 55a9a469e97b..dc17450664c8 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -10,7 +10,7 @@ from mypy.plugins.common import try_getting_str_literals from mypy.types import ( Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType, - TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType + TypeVarDef, TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType ) from mypy.subtypes import is_subtype from mypy.typeops import make_simplified_union @@ -204,6 +204,7 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: # Tweak the signature to include the value type as context. It's # only needed for type inference since there's a union with a type # variable that accepts everything. + assert isinstance(signature.variables[0], TypeVarDef) tv = TypeVarType(signature.variables[0]) return signature.copy_modified( arg_types=[signature.arg_types[0], @@ -269,6 +270,7 @@ def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType: # Tweak the signature to include the value type as context. It's # only needed for type inference since there's a union with a type # variable that accepts everything. + assert isinstance(signature.variables[0], TypeVarDef) tv = TypeVarType(signature.variables[0]) typ = make_simplified_union([value_type, tv]) return signature.copy_modified( diff --git a/mypy/semanal.py b/mypy/semanal.py index 312888ade82e..9fd49f4a6912 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -78,7 +78,7 @@ EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, ParamSpecExpr ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars from mypy.visitor import NodeVisitor from mypy.errors import Errors, report_internal_error @@ -97,7 +97,7 @@ from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, - TypeVariableQuery, TypeVarList, remove_dups, has_any_from_unimported_type, + TypeVarLikeQuery, TypeVarLikeList, remove_dups, has_any_from_unimported_type, check_for_explicit_any, type_constructors, fix_instance_types ) from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError @@ -161,7 +161,7 @@ class SemanticAnalyzer(NodeVisitor[None], # Stack of outer classes (the second tuple item contains tvars). type_stack = None # type: List[Optional[TypeInfo]] # Type variables bound by the current scope, be it class or function - tvar_scope = None # type: TypeVarScope + tvar_scope = None # type: TypeVarLikeScope # Per-module options options = None # type: Options @@ -235,7 +235,7 @@ def __init__(self, self.imports = set() self.type = None self.type_stack = [] - self.tvar_scope = TypeVarScope() + self.tvar_scope = TypeVarLikeScope() self.function_stack = [] self.block_depth = [0] self.loop_depth = 0 @@ -478,7 +478,7 @@ def file_context(self, self.is_stub_file = file_node.path.lower().endswith('.pyi') self._is_typeshed_stub_file = is_typeshed_file(file_node.path) self.globals = file_node.names - self.tvar_scope = TypeVarScope() + self.tvar_scope = TypeVarLikeScope() self.named_tuple_analyzer = NamedTupleAnalyzer(options, self) self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg) @@ -1212,7 +1212,7 @@ class Foo(Bar, Generic[T]): ... Returns (remaining base expressions, inferred type variables, is protocol). """ removed = [] # type: List[int] - declared_tvars = [] # type: TypeVarList + declared_tvars = [] # type: TypeVarLikeList is_protocol = False for i, base_expr in enumerate(base_type_exprs): self.analyze_type_expr(base_expr) @@ -1260,10 +1260,16 @@ class Foo(Bar, Generic[T]): ... tvar_defs = [] # type: List[TypeVarDef] for name, tvar_expr in declared_tvars: tvar_def = self.tvar_scope.bind_new(name, tvar_expr) + assert isinstance(tvar_def, TypeVarDef), ( + "mypy does not currently support ParamSpec use in generic classes" + ) tvar_defs.append(tvar_def) return base_type_exprs, tvar_defs, is_protocol - def analyze_class_typevar_declaration(self, base: Type) -> Optional[Tuple[TypeVarList, bool]]: + def analyze_class_typevar_declaration( + self, + base: Type + ) -> Optional[Tuple[TypeVarLikeList, bool]]: """Analyze type variables declared using Generic[...] or Protocol[...]. Args: @@ -1282,7 +1288,7 @@ def analyze_class_typevar_declaration(self, base: Type) -> Optional[Tuple[TypeVa sym.node.fullname == 'typing.Protocol' and base.args or sym.node.fullname == 'typing_extensions.Protocol' and base.args): is_proto = sym.node.fullname != 'typing.Generic' - tvars = [] # type: TypeVarList + tvars = [] # type: TypeVarLikeList for arg in unbound.args: tag = self.track_incomplete_refs() tvar = self.analyze_unbound_tvar(arg) @@ -1312,9 +1318,9 @@ def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarExpr]]: def get_all_bases_tvars(self, base_type_exprs: List[Expression], - removed: List[int]) -> TypeVarList: + removed: List[int]) -> TypeVarLikeList: """Return all type variable references in bases.""" - tvars = [] # type: TypeVarList + tvars = [] # type: TypeVarLikeList for i, base_expr in enumerate(base_type_exprs): if i not in removed: try: @@ -1322,7 +1328,7 @@ def get_all_bases_tvars(self, except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVariableQuery(self.lookup_qualified, self.tvar_scope)) + base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) tvars.extend(base_tvars) return remove_dups(tvars) @@ -2416,7 +2422,7 @@ def analyze_alias(self, rvalue: Expression, typ = None # type: Optional[Type] if res: typ, depends_on = res - found_type_vars = typ.accept(TypeVariableQuery(self.lookup_qualified, self.tvar_scope)) + found_type_vars = typ.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) alias_tvars = [name for (name, node) in found_type_vars] qualified_tvars = [node.fullname for (name, node) in found_type_vars] else: @@ -4455,7 +4461,7 @@ def add_unknown_imported_symbol(self, # @contextmanager - def tvar_scope_frame(self, frame: TypeVarScope) -> Iterator[None]: + def tvar_scope_frame(self, frame: TypeVarLikeScope) -> Iterator[None]: old_scope = self.tvar_scope self.tvar_scope = frame yield @@ -4788,11 +4794,11 @@ def analyze_type_expr(self, expr: Expression) -> None: # them semantically analyzed, however, if they need to treat it as an expression # and not a type. (Which is to say, mypyc needs to do this.) Do the analysis # in a fresh tvar scope in order to suppress any errors about using type variables. - with self.tvar_scope_frame(TypeVarScope()): + with self.tvar_scope_frame(TypeVarLikeScope()): expr.accept(self) def type_analyzer(self, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, @@ -4815,7 +4821,7 @@ def type_analyzer(self, *, def anal_type(self, typ: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index ba0972e8c302..3fd2498c696f 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -14,7 +14,7 @@ from mypy.types import ( Type, FunctionLike, Instance, TupleType, TPDICT_FB_NAMES, ProperType, get_proper_type ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.errorcodes import ErrorCode from mypy import join @@ -105,7 +105,7 @@ def accept(self, node: Node) -> None: @abstractmethod def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, report_invalid_types: bool = True) -> Optional[Type]: diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 587df57e8a08..1c411886ac7d 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -370,9 +370,10 @@ def visit_callable_type(self, typ: CallableType) -> None: if typ.fallback is not None: typ.fallback.accept(self) for tv in typ.variables: - tv.upper_bound.accept(self) - for value in tv.values: - value.accept(self) + if isinstance(tv, TypeVarDef): + tv.upper_bound.accept(self) + for value in tv.values: + value.accept(self) def visit_overloaded(self, t: Overloaded) -> None: for item in t.items(): diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 39a35c728028..1c4fc77d62d2 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -89,6 +89,7 @@ 'check-reports.test', 'check-errorcodes.test', 'check-annotated.test', + 'check-parameter-specification.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index 7d5dce18fc66..4c7a165036a2 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -1,16 +1,19 @@ from typing import Optional, Dict, Union -from mypy.types import TypeVarDef -from mypy.nodes import TypeVarExpr, SymbolTableNode +from mypy.types import TypeVarLikeDef, TypeVarDef, ParamSpecDef +from mypy.nodes import ParamSpecExpr, TypeVarExpr, TypeVarLikeExpr, SymbolTableNode -class TypeVarScope: - """Scope that holds bindings for type variables. Node fullname -> TypeVarDef.""" +class TypeVarLikeScope: + """Scope that holds bindings for type variables and parameter specifications. + + Node fullname -> TypeVarLikeDef. + """ def __init__(self, - parent: 'Optional[TypeVarScope]' = None, + parent: 'Optional[TypeVarLikeScope]' = None, is_class_scope: bool = False, - prohibited: 'Optional[TypeVarScope]' = None) -> None: - """Initializer for TypeVarScope + prohibited: 'Optional[TypeVarLikeScope]' = None) -> None: + """Initializer for TypeVarLikeScope Parameters: parent: the outer scope for this scope @@ -18,7 +21,7 @@ def __init__(self, prohibited: Type variables that aren't strictly in scope exactly, but can't be bound because they're part of an outer class's scope. """ - self.scope = {} # type: Dict[str, TypeVarDef] + self.scope = {} # type: Dict[str, TypeVarLikeDef] self.parent = parent self.func_id = 0 self.class_id = 0 @@ -28,9 +31,9 @@ def __init__(self, self.func_id = parent.func_id self.class_id = parent.class_id - def get_function_scope(self) -> 'Optional[TypeVarScope]': + def get_function_scope(self) -> 'Optional[TypeVarLikeScope]': """Get the nearest parent that's a function scope, not a class scope""" - it = self # type: Optional[TypeVarScope] + it = self # type: Optional[TypeVarLikeScope] while it is not None and it.is_class_scope: it = it.parent return it @@ -44,36 +47,49 @@ def allow_binding(self, fullname: str) -> bool: return False return True - def method_frame(self) -> 'TypeVarScope': + def method_frame(self) -> 'TypeVarLikeScope': """A new scope frame for binding a method""" - return TypeVarScope(self, False, None) + return TypeVarLikeScope(self, False, None) - def class_frame(self) -> 'TypeVarScope': + def class_frame(self) -> 'TypeVarLikeScope': """A new scope frame for binding a class. Prohibits *this* class's tvars""" - return TypeVarScope(self.get_function_scope(), True, self) + return TypeVarLikeScope(self.get_function_scope(), True, self) - def bind_new(self, name: str, tvar_expr: TypeVarExpr) -> TypeVarDef: + def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeDef: if self.is_class_scope: self.class_id += 1 i = self.class_id else: self.func_id -= 1 i = self.func_id - tvar_def = TypeVarDef(name, - tvar_expr.fullname, - i, - values=tvar_expr.values, - upper_bound=tvar_expr.upper_bound, - variance=tvar_expr.variance, - line=tvar_expr.line, - column=tvar_expr.column) + if isinstance(tvar_expr, TypeVarExpr): + tvar_def = TypeVarDef( + name, + tvar_expr.fullname, + i, + values=tvar_expr.values, + upper_bound=tvar_expr.upper_bound, + variance=tvar_expr.variance, + line=tvar_expr.line, + column=tvar_expr.column + ) # type: TypeVarLikeDef + elif isinstance(tvar_expr, ParamSpecExpr): + tvar_def = ParamSpecDef( + name, + tvar_expr.fullname, + i, + line=tvar_expr.line, + column=tvar_expr.column + ) + else: + assert False self.scope[tvar_expr.fullname] = tvar_def return tvar_def - def bind_existing(self, tvar_def: TypeVarDef) -> None: + def bind_existing(self, tvar_def: TypeVarLikeDef) -> None: self.scope[tvar_def.fullname] = tvar_def - def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarDef]: + def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarLikeDef]: fullname = item.fullname if isinstance(item, SymbolTableNode) else item assert fullname is not None if fullname in self.scope: diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 934a4421362f..06b44bd497f0 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -13,7 +13,7 @@ from abc import abstractmethod from mypy.ordered_dict import OrderedDict -from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set +from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set, Sequence from mypy_extensions import trait T = TypeVar('T') @@ -21,7 +21,7 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, RawExpressionType, Instance, NoneType, TypeType, - UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, + UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeDef, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, PlaceholderType, TypeAliasType, get_proper_type ) @@ -218,7 +218,7 @@ def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] def translate_variables(self, - variables: List[TypeVarDef]) -> List[TypeVarDef]: + variables: Sequence[TypeVarLikeDef]) -> Sequence[TypeVarLikeDef]: return variables def visit_overloaded(self, t: Overloaded) -> Type: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f1a96eacd23e..74041175cb33 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -16,17 +16,17 @@ CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor, StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, - PlaceholderType, Overloaded, get_proper_type, TypeAliasType + PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeDef, ParamSpecDef ) from mypy.nodes import ( TypeInfo, Context, SymbolTableNode, Var, Expression, nongen_builtins, check_arg_names, check_arg_kinds, ARG_POS, ARG_NAMED, - ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, + ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, TypeVarLikeExpr, ParamSpecExpr, TypeAlias, PlaceholderNode, SYMBOL_FUNCBASE_TYPES, Decorator, MypyFile ) from mypy.typetraverser import TypeTraverserVisitor -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext from mypy.semanal_shared import SemanticAnalyzerCoreInterface @@ -64,7 +64,7 @@ def analyze_type_alias(node: Expression, api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarScope, + tvar_scope: TypeVarLikeScope, plugin: Plugin, options: Options, is_typeshed_stub: bool, @@ -117,7 +117,7 @@ class TypeAnalyser(SyntheticTypeVisitor[Type], TypeAnalyzerPluginInterface): def __init__(self, api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarScope, + tvar_scope: TypeVarLikeScope, plugin: Plugin, options: Options, is_typeshed_stub: bool, *, @@ -200,11 +200,23 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(no_subscript_builtin_alias(fullname, propose_alt=not self.defining_alias), t) tvar_def = self.tvar_scope.get_binding(sym) + if isinstance(sym.node, ParamSpecExpr): + if tvar_def is None: + self.fail('ParamSpec "{}" is unbound'.format(t.name), t) + return AnyType(TypeOfAny.from_error) + self.fail('Invalid location for ParamSpec "{}"'.format(t.name), t) + self.note( + 'You can use ParamSpec as the first argument to Callable, e.g., ' + "'Callable[{}, int]'".format(t.name), + t + ) + return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None and self.defining_alias: self.fail('Can\'t use bound type variable "{}"' ' to define generic alias'.format(t.name), t) return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None: + assert isinstance(tvar_def, TypeVarDef) if len(t.args) > 0: self.fail('Type variable "{}" used with arguments'.format(t.name), t) return TypeVarType(tvar_def, t.line) @@ -607,6 +619,32 @@ def visit_placeholder_type(self, t: PlaceholderType) -> Type: assert isinstance(n.node, TypeInfo) return self.analyze_type_with_type_info(n.node, t.args, t) + def analyze_callable_args_for_paramspec( + self, + callable_args: Type, + ret_type: Type, + fallback: Instance, + ) -> Optional[CallableType]: + """Construct a 'Callable[P, RET]', where P is ParamSpec, return None if we cannot.""" + if not isinstance(callable_args, UnboundType): + return None + sym = self.lookup_qualified(callable_args.name, callable_args) + if sym is None: + return None + tvar_def = self.tvar_scope.get_binding(sym) + if not isinstance(tvar_def, ParamSpecDef): + return None + + # TODO(shantanu): construct correct type for paramspec + return CallableType( + [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)], + [nodes.ARG_STAR, nodes.ARG_STAR2], + [None, None], + ret_type=ret_type, + fallback=fallback, + is_ellipsis_args=True + ) + def analyze_callable_type(self, t: UnboundType) -> Type: fallback = self.named_type('builtins.function') if len(t.args) == 0: @@ -619,10 +657,11 @@ def analyze_callable_type(self, t: UnboundType) -> Type: fallback=fallback, is_ellipsis_args=True) elif len(t.args) == 2: + callable_args = t.args[0] ret_type = t.args[1] - if isinstance(t.args[0], TypeList): + if isinstance(callable_args, TypeList): # Callable[[ARG, ...], RET] (ordinary callable type) - analyzed_args = self.analyze_callable_args(t.args[0]) + analyzed_args = self.analyze_callable_args(callable_args) if analyzed_args is None: return AnyType(TypeOfAny.from_error) args, kinds, names = analyzed_args @@ -631,7 +670,7 @@ def analyze_callable_type(self, t: UnboundType) -> Type: names, ret_type=ret_type, fallback=fallback) - elif isinstance(t.args[0], EllipsisType): + elif isinstance(callable_args, EllipsisType): # Callable[..., RET] (with literal ellipsis; accept arbitrary arguments) ret = CallableType([AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)], @@ -641,8 +680,19 @@ def analyze_callable_type(self, t: UnboundType) -> Type: fallback=fallback, is_ellipsis_args=True) else: - self.fail('The first argument to Callable must be a list of types or "..."', t) - return AnyType(TypeOfAny.from_error) + # Callable[P, RET] (where P is ParamSpec) + maybe_ret = self.analyze_callable_args_for_paramspec( + callable_args, + ret_type, + fallback + ) + if maybe_ret is None: + # Callable[?, RET] (where ? is something invalid) + # TODO(shantanu): change error to mention paramspec, once we actually have some + # support for it + self.fail('The first argument to Callable must be a list of types or "..."', t) + return AnyType(TypeOfAny.from_error) + ret = maybe_ret else: self.fail('Please use "Callable[[], ]" or "Callable"', t) return AnyType(TypeOfAny.from_error) @@ -791,13 +841,14 @@ def tvar_scope_frame(self) -> Iterator[None]: self.tvar_scope = old_scope def infer_type_variables(self, - type: CallableType) -> List[Tuple[str, TypeVarExpr]]: + type: CallableType) -> List[Tuple[str, TypeVarLikeExpr]]: """Return list of unique type variables referred to in a callable.""" names = [] # type: List[str] - tvars = [] # type: List[TypeVarExpr] + tvars = [] # type: List[TypeVarLikeExpr] for arg in type.arg_types: - for name, tvar_expr in arg.accept(TypeVariableQuery(self.lookup_qualified, - self.tvar_scope)): + for name, tvar_expr in arg.accept( + TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope) + ): if name not in names: names.append(name) tvars.append(tvar_expr) @@ -806,29 +857,30 @@ def infer_type_variables(self, # functions in the return type belong to those functions, not the # function we're currently analyzing. for name, tvar_expr in type.ret_type.accept( - TypeVariableQuery(self.lookup_qualified, self.tvar_scope, - include_callables=False)): + TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope, include_callables=False) + ): if name not in names: names.append(name) tvars.append(tvar_expr) return list(zip(names, tvars)) - def bind_function_type_variables(self, - fun_type: CallableType, defn: Context) -> List[TypeVarDef]: + def bind_function_type_variables( + self, fun_type: CallableType, defn: Context + ) -> Sequence[TypeVarLikeDef]: """Find the type variables of the function type and bind them in our tvar_scope""" if fun_type.variables: for var in fun_type.variables: var_node = self.lookup_qualified(var.name, defn) assert var_node, "Binding for function type variable not found within function" var_expr = var_node.node - assert isinstance(var_expr, TypeVarExpr) + assert isinstance(var_expr, TypeVarLikeExpr) self.tvar_scope.bind_new(var.name, var_expr) return fun_type.variables typevars = self.infer_type_variables(fun_type) # Do not define a new type variable if already defined in scope. typevars = [(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)] - defs = [] # type: List[TypeVarDef] + defs = [] # type: List[TypeVarLikeDef] for name, tvar in typevars: if not self.tvar_scope.allow_binding(tvar.fullname): self.fail("Type variable '{}' is bound by an outer class".format(name), defn) @@ -860,17 +912,22 @@ def anal_type(self, t: Type, nested: bool = True) -> Type: if nested: self.nesting_level -= 1 - def anal_var_defs(self, var_defs: List[TypeVarDef]) -> List[TypeVarDef]: - a = [] # type: List[TypeVarDef] - for vd in var_defs: - a.append(TypeVarDef(vd.name, - vd.fullname, - vd.id.raw_id, - self.anal_array(vd.values), - vd.upper_bound.accept(self), - vd.variance, - vd.line)) - return a + def anal_var_def(self, var_def: TypeVarLikeDef) -> TypeVarLikeDef: + if isinstance(var_def, TypeVarDef): + return TypeVarDef( + var_def.name, + var_def.fullname, + var_def.id.raw_id, + self.anal_array(var_def.values), + var_def.upper_bound.accept(self), + var_def.variance, + var_def.line + ) + else: + return var_def + + def anal_var_defs(self, var_defs: Sequence[TypeVarLikeDef]) -> List[TypeVarLikeDef]: + return [self.anal_var_def(vd) for vd in var_defs] def named_type_with_normalized_str(self, fully_qualified_name: str) -> Instance: """Does almost the same thing as `named_type`, except that we immediately @@ -898,7 +955,7 @@ def tuple_type(self, items: List[Type]) -> TupleType: return TupleType(items, fallback=self.named_type('builtins.tuple', [any_type])) -TypeVarList = List[Tuple[str, TypeVarExpr]] +TypeVarLikeList = List[Tuple[str, TypeVarLikeExpr]] # Mypyc doesn't support callback protocols yet. MsgCallback = Callable[[str, Context, DefaultNamedArg(Optional[ErrorCode], 'code')], None] @@ -1059,11 +1116,11 @@ def flatten_tvars(ll: Iterable[List[T]]) -> List[T]: return remove_dups(chain.from_iterable(ll)) -class TypeVariableQuery(TypeQuery[TypeVarList]): +class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): def __init__(self, lookup: Callable[[str, Context], Optional[SymbolTableNode]], - scope: 'TypeVarScope', + scope: 'TypeVarLikeScope', *, include_callables: bool = True, include_bound_tvars: bool = False) -> None: @@ -1080,12 +1137,12 @@ def _seems_like_callable(self, type: UnboundType) -> bool: return True return False - def visit_unbound_type(self, t: UnboundType) -> TypeVarList: + def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: name = t.name node = self.lookup(name, t) - if node and isinstance(node.node, TypeVarExpr) and ( + if node and isinstance(node.node, TypeVarLikeExpr) and ( self.include_bound_tvars or self.scope.get_binding(node) is None): - assert isinstance(node.node, TypeVarExpr) + assert isinstance(node.node, TypeVarLikeExpr) return [(name, node.node)] elif not self.include_callables and self._seems_like_callable(t): return [] @@ -1094,7 +1151,7 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarList: else: return super().visit_unbound_type(t) - def visit_callable_type(self, t: CallableType) -> TypeVarList: + def visit_callable_type(self, t: CallableType) -> TypeVarLikeList: if self.include_callables: return super().visit_callable_type(t) else: diff --git a/mypy/typeops.py b/mypy/typeops.py index 09f418b129ed..732c19f72113 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -10,7 +10,7 @@ import sys from mypy.types import ( - TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, + TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, TypeVarLikeDef, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, copy_type, TypeAliasType, TypeQuery @@ -113,7 +113,7 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, special_sig: Optional[str], is_new: bool, orig_self_type: Optional[Type] = None) -> CallableType: """Create a type object type based on the signature of __init__.""" - variables = [] # type: List[TypeVarDef] + variables = [] # type: List[TypeVarLikeDef] variables.extend(info.defn.type_vars) variables.extend(init_type.variables) @@ -227,6 +227,8 @@ class B(A): pass # TODO: infer bounds on the type of *args? return cast(F, func) self_param_type = get_proper_type(func.arg_types[0]) + + variables = [] # type: Sequence[TypeVarLikeDef] if func.variables and supported_self_type(self_param_type): if original_type is None: # TODO: type check method override (see #7861). @@ -472,7 +474,9 @@ def true_or_false(t: Type) -> ProperType: return new_t -def erase_def_to_union_or_bound(tdef: TypeVarDef) -> Type: +def erase_def_to_union_or_bound(tdef: TypeVarLikeDef) -> Type: + # TODO(shantanu): fix for ParamSpecDef + assert isinstance(tdef, TypeVarDef) if tdef.values: return make_simplified_union(tdef.values) else: diff --git a/mypy/types.py b/mypy/types.py index 98943e374e48..a2651a01b37a 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -328,12 +328,34 @@ def is_meta_var(self) -> bool: return self.meta_level > 0 -class TypeVarDef(mypy.nodes.Context): - """Definition of a single type variable.""" - +class TypeVarLikeDef(mypy.nodes.Context): name = '' # Name (may be qualified) fullname = '' # Fully qualified name id = None # type: TypeVarId + + def __init__( + self, name: str, fullname: str, id: Union[TypeVarId, int], line: int = -1, column: int = -1 + ) -> None: + super().__init__(line, column) + self.name = name + self.fullname = fullname + if isinstance(id, int): + id = TypeVarId(id) + self.id = id + + def __repr__(self) -> str: + return self.name + + def serialize(self) -> JsonDict: + raise NotImplementedError + + @classmethod + def deserialize(cls, data: JsonDict) -> 'TypeVarLikeDef': + raise NotImplementedError + + +class TypeVarDef(TypeVarLikeDef): + """Definition of a single type variable.""" values = None # type: List[Type] # Value restriction, empty list if no restriction upper_bound = None # type: Type variance = INVARIANT # type: int @@ -341,13 +363,8 @@ class TypeVarDef(mypy.nodes.Context): def __init__(self, name: str, fullname: str, id: Union[TypeVarId, int], values: List[Type], upper_bound: Type, variance: int = INVARIANT, line: int = -1, column: int = -1) -> None: - super().__init__(line, column) + super().__init__(name, fullname, id, line, column) assert values is not None, "No restrictions must be represented by empty list" - self.name = name - self.fullname = fullname - if isinstance(id, int): - id = TypeVarId(id) - self.id = id self.values = values self.upper_bound = upper_bound self.variance = variance @@ -389,6 +406,28 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarDef': ) +class ParamSpecDef(TypeVarLikeDef): + """Definition of a single ParamSpec variable.""" + + def serialize(self) -> JsonDict: + assert not self.id.is_meta_var() + return { + '.class': 'ParamSpecDef', + 'name': self.name, + 'fullname': self.fullname, + 'id': self.id.raw_id, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'ParamSpecDef': + assert data['.class'] == 'ParamSpecDef' + return ParamSpecDef( + data['name'], + data['fullname'], + data['id'], + ) + + class UnboundType(ProperType): """Instance type that has not been bound during semantic analysis.""" @@ -976,7 +1015,7 @@ def __init__(self, fallback: Instance, name: Optional[str] = None, definition: Optional[SymbolNode] = None, - variables: Optional[List[TypeVarDef]] = None, + variables: Optional[Sequence[TypeVarLikeDef]] = None, line: int = -1, column: int = -1, is_ellipsis_args: bool = False, @@ -1028,7 +1067,7 @@ def copy_modified(self, fallback: Bogus[Instance] = _dummy, name: Bogus[Optional[str]] = _dummy, definition: Bogus[SymbolNode] = _dummy, - variables: Bogus[List[TypeVarDef]] = _dummy, + variables: Bogus[Sequence[TypeVarLikeDef]] = _dummy, line: Bogus[int] = _dummy, column: Bogus[int] = _dummy, is_ellipsis_args: Bogus[bool] = _dummy, @@ -2057,15 +2096,19 @@ def visit_callable_type(self, t: CallableType) -> str: if t.variables: vs = [] - # We reimplement TypeVarDef.__repr__ here in order to support id_mapper. for var in t.variables: - if var.values: - vals = '({})'.format(', '.join(val.accept(self) for val in var.values)) - vs.append('{} in {}'.format(var.name, vals)) - elif not is_named_instance(var.upper_bound, 'builtins.object'): - vs.append('{} <: {}'.format(var.name, var.upper_bound.accept(self))) + if isinstance(var, TypeVarDef): + # We reimplement TypeVarDef.__repr__ here in order to support id_mapper. + if var.values: + vals = '({})'.format(', '.join(val.accept(self) for val in var.values)) + vs.append('{} in {}'.format(var.name, vals)) + elif not is_named_instance(var.upper_bound, 'builtins.object'): + vs.append('{} <: {}'.format(var.name, var.upper_bound.accept(self))) + else: + vs.append(var.name) else: - vs.append(var.name) + # For other TypeVarLikeDefs, just use the repr + vs.append(repr(var)) s = '{} {}'.format('[{}]'.format(', '.join(vs)), s) return 'def {}'.format(s) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test new file mode 100644 index 000000000000..14c426cde1bf --- /dev/null +++ b/test-data/unit/check-parameter-specification.test @@ -0,0 +1,29 @@ +[case testBasicParamSpec] +from typing_extensions import ParamSpec +P = ParamSpec('P') +[builtins fixtures/tuple.pyi] + +[case testParamSpecLocations] +from typing import Callable, List +from typing_extensions import ParamSpec, Concatenate +P = ParamSpec('P') + +x: P # E: ParamSpec "P" is unbound + +def foo1(x: Callable[P, int]) -> Callable[P, str]: ... + +def foo2(x: P) -> P: ... # E: Invalid location for ParamSpec "P" \ + # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' + +# TODO(shantanu): uncomment once we have support for Concatenate +# def foo3(x: Concatenate[int, P]) -> int: ... $ E: Invalid location for Concatenate + +def foo4(x: List[P]) -> None: ... # E: Invalid location for ParamSpec "P" \ + # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' + +def foo5(x: Callable[[int, str], P]) -> None: ... # E: Invalid location for ParamSpec "P" \ + # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' + +def foo6(x: Callable[[P], int]) -> None: ... # E: Invalid location for ParamSpec "P" \ + # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' +[builtins fixtures/tuple.pyi]