diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 36c89b3ceda4..e35ab9044609 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -36,7 +36,7 @@ from mypy import join from mypy.meet import narrow_declared_type from mypy.maptype import map_instance_to_supertype -from mypy.subtypes import is_subtype, is_equivalent, find_member +from mypy.subtypes import is_subtype, is_equivalent, find_member, non_method_protocol_members from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type, bind_self @@ -264,15 +264,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: callee_type = self.apply_method_signature_hook( e, callee_type, object_type, signature_hook) ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type) - if (isinstance(e.callee, RefExpr) and len(e.args) == 2 and - e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass')): - for expr in mypy.checker.flatten(e.args[1]): - tp = self.chk.type_map[expr] - if (isinstance(tp, CallableType) and tp.is_type_obj() and - tp.type_object().is_protocol and - not tp.type_object().runtime_protocol): - self.chk.fail('Only @runtime protocols can be used with' - ' instance and class checks', e) + if isinstance(e.callee, RefExpr) and len(e.args) == 2: + if e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass'): + self.check_runtime_protocol_test(e) + if e.callee.fullname == 'builtins.issubclass': + self.check_protocol_issubclass(e) if isinstance(ret_type, UninhabitedType): self.chk.binder.unreachable() if not allow_none_return and isinstance(ret_type, NoneTyp): @@ -280,6 +276,25 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: return AnyType(TypeOfAny.from_error) return ret_type + def check_runtime_protocol_test(self, e: CallExpr) -> None: + for expr in mypy.checker.flatten(e.args[1]): + tp = self.chk.type_map[expr] + if (isinstance(tp, CallableType) and tp.is_type_obj() and + tp.type_object().is_protocol and + not tp.type_object().runtime_protocol): + self.chk.fail('Only @runtime protocols can be used with' + ' instance and class checks', e) + + def check_protocol_issubclass(self, e: CallExpr) -> None: + for expr in mypy.checker.flatten(e.args[1]): + tp = self.chk.type_map[expr] + if (isinstance(tp, CallableType) and tp.is_type_obj() and + tp.type_object().is_protocol): + attr_members = non_method_protocol_members(tp.type_object()) + if attr_members: + self.chk.msg.report_non_method_protocol(tp.type_object(), + attr_members, e) + def check_typeddict_call(self, callee: TypedDictType, arg_kinds: List[int], arg_names: Sequence[Optional[str]], diff --git a/mypy/meet.py b/mypy/meet.py index 33d7f6e3df4a..3e883b53fd4d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -44,6 +44,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return narrowed elif isinstance(declared, (Instance, TupleType)): return meet_types(declared, narrowed) + elif isinstance(declared, TypeType) and isinstance(narrowed, TypeType): + return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item)) return narrowed diff --git a/mypy/messages.py b/mypy/messages.py index 5fe0f33ce591..bdde2e1bbc0b 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1056,6 +1056,15 @@ def concrete_only_call(self, typ: Type, context: Context) -> None: self.fail("Only concrete class can be given where {} is expected" .format(self.format(typ)), context) + def report_non_method_protocol(self, tp: TypeInfo, members: List[str], + context: Context) -> None: + self.fail("Only protocols that don't have non-method members can be" + " used with issubclass()", context) + if len(members) < 3: + attrs = ', '.join(members) + self.note('Protocol "{}" has non-method member(s): {}' + .format(tp.name(), attrs), context) + def note_call(self, subtype: Type, call: Type, context: Context) -> None: self.note('"{}.__call__" has type {}'.format(self.format_bare(subtype), self.format(call, verbosity=1)), context) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 049da8a07228..e5034cd50bc1 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -123,7 +123,9 @@ def visit_any(self, left: AnyType) -> bool: def visit_none_type(self, left: NoneTyp) -> bool: if experiments.STRICT_OPTIONAL: return (isinstance(self.right, NoneTyp) or - is_named_instance(self.right, 'builtins.object')) + is_named_instance(self.right, 'builtins.object') or + isinstance(self.right, Instance) and self.right.type.is_protocol and + not self.right.type.protocol_members) else: return True @@ -386,7 +388,7 @@ def f(self) -> A: ... is_compat = is_proper_subtype(subtype, supertype) if not is_compat: return False - if isinstance(subtype, NoneTyp) and member.startswith('__') and member.endswith('__'): + if isinstance(subtype, NoneTyp) and isinstance(supertype, CallableType): # We want __hash__ = None idiom to work even without --strict-optional return False subflags = get_member_flags(member, left.type) @@ -516,6 +518,21 @@ def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) - return typ +def non_method_protocol_members(tp: TypeInfo) -> List[str]: + """Find all non-callable members of a protocol.""" + + assert tp.is_protocol + result = [] # type: List[str] + anytype = AnyType(TypeOfAny.special_form) + instance = Instance(tp, [anytype] * len(tp.defn.type_vars)) + + for member in tp.protocol_members: + typ = find_member(member, instance, instance) + if not isinstance(typ, CallableType): + result.append(member) + return result + + def is_callable_subtype(left: CallableType, right: CallableType, ignore_return: bool = False, ignore_pos_arg_names: bool = False, diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index fd43bbf3e0f3..1da2c1f63a52 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -2118,3 +2118,64 @@ main:10: note: def other(self, *args: Any, hint: Optional[str] = ..., ** main:10: note: Got: main:10: note: def other(self) -> int +[case testObjectAllowedInProtocolBases] +from typing import Protocol +class P(Protocol, object): + pass +[out] + +[case testNoneSubtypeOfEmptyProtocol] +from typing import Protocol +class P(Protocol): + pass + +x: P = None +[out] + +[case testNoneSubtypeOfAllProtocolsWithoutStrictOptional] +from typing import Protocol +class P(Protocol): + attr: int + def meth(self, arg: str) -> str: + pass + +x: P = None +[out] + +[case testNoneSubtypeOfEmptyProtocolStrict] +# flags: --strict-optional +from typing import Protocol +class P(Protocol): + pass +x: P = None + +class PBad(Protocol): + x: int +y: PBad = None # E: Incompatible types in assignment (expression has type "None", variable has type "PBad") +[out] + +[case testOnlyMethodProtocolUsableWithIsSubclass] +from typing import Protocol, runtime, Union, Type +@runtime +class P(Protocol): + def meth(self) -> int: + pass +@runtime +class PBad(Protocol): + x: str + +class C: + x: str + def meth(self) -> int: + pass +class E: pass + +cls: Type[Union[C, E]] +issubclass(cls, PBad) # E: Only protocols that don't have non-method members can be used with issubclass() \ + # N: Protocol "PBad" has non-method member(s): x +if issubclass(cls, P): + reveal_type(cls) # E: Revealed type is 'Type[__main__.C]' +else: + reveal_type(cls) # E: Revealed type is 'Type[__main__.E]' +[builtins fixtures/isinstance.pyi] +[out]