From 15d70c1a806f308e361264e2224762bae664ab9c Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 26 May 2023 19:35:40 -0700 Subject: [PATCH] Add typeshed bases to explicit type object bases --- docs/changelog.md | 2 ++ pyanalyze/checker.py | 7 ++++--- pyanalyze/test_type_object.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 77e4fd7e..2d367f01 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- Take into account additional base classes declared in stub + files (fixing some false positives around `typing.IO`) (#635) - Fix crash on stubs that contain dict or set literals (#634) - Remove more old special cases and improve robustness of annotation parsing (#630) diff --git a/pyanalyze/checker.py b/pyanalyze/checker.py index 8c4c3109..1fb095ee 100644 --- a/pyanalyze/checker.py +++ b/pyanalyze/checker.py @@ -163,15 +163,16 @@ def _build_type_object(self, typ: Union[type, super, str]) -> TypeObject: elif isinstance(typ, super): return TypeObject(typ, self.get_additional_bases(typ)) else: - additional_bases = self.get_additional_bases(typ) + plugin_bases = self.get_additional_bases(typ) + typeshed_bases = self._get_typeshed_bases(typ) + additional_bases = plugin_bases | typeshed_bases # Is it marked as a protocol in stubs? If so, use the stub definition. if self.ts_finder.is_protocol(typ): - bases = self._get_typeshed_bases(typ) return TypeObject( typ, additional_bases, is_protocol=True, - protocol_members=self._get_protocol_members(bases), + protocol_members=self._get_protocol_members(typeshed_bases), ) # Is it a protocol at runtime? if is_instance_of_typing_name(typ, "_ProtocolMeta") and safe_getattr( diff --git a/pyanalyze/test_type_object.py b/pyanalyze/test_type_object.py index 9f535407..70a5d624 100644 --- a/pyanalyze/test_type_object.py +++ b/pyanalyze/test_type_object.py @@ -374,3 +374,33 @@ def capybara(t1: Type[int], t2: type): want_hash([]) # E: incompatible_argument want_myhash([]) # E: incompatible_argument + + +class TestIO(TestNameCheckVisitorBase): + @assert_passes() + def test_text(self): + from typing import TextIO + from typing_extensions import assert_type + import io + + def want_io(x: TextIO): + x.write("hello") + + def capybara(): + with open("x") as f: + assert_type(f, io.TextIOWrapper) + want_io(f) + + @assert_passes() + def test_binary(self): + from typing import BinaryIO + from typing_extensions import assert_type + import io + + def want_io(x: BinaryIO): + x.write(b"hello") + + def capybara(): + with open("x", "rb") as f: + assert_type(f, io.BufferedReader) + want_io(f)