diff --git a/docs/changelog.md b/docs/changelog.md index 0b04d809..b4dcf3e9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -9,7 +9,7 @@ and similar functions, which previously emitted `inference_failure` (#636) - Take into account additional base classes declared in stub - files (fixing some false positives around `typing.IO`) (#635) + files (fixing some false positives around `typing.IO`) (#635, #639) - Fix crash on stubs that contain dict or set literals (#634) - Remove more old special cases and improve robustness of annotation parsing (#630) diff --git a/pyanalyze/checker.py b/pyanalyze/checker.py index 1fb095ee..61f1f4c7 100644 --- a/pyanalyze/checker.py +++ b/pyanalyze/checker.py @@ -164,7 +164,7 @@ def _build_type_object(self, typ: Union[type, super, str]) -> TypeObject: return TypeObject(typ, self.get_additional_bases(typ)) else: plugin_bases = self.get_additional_bases(typ) - typeshed_bases = self._get_typeshed_bases(typ) + typeshed_bases = self._get_recursive_typeshed_bases(typ) additional_bases = plugin_bases | typeshed_bases # Is it marked as a protocol in stubs? If so, use the stub definition. if self.ts_finder.is_protocol(typ): @@ -190,6 +190,22 @@ def _build_type_object(self, typ: Union[type, super, str]) -> TypeObject: return TypeObject(typ, additional_bases) + def _get_recursive_typeshed_bases( + self, typ: Union[type, str] + ) -> Set[Union[type, str]]: + seen = set() + to_do = {typ} + result = set() + while to_do: + typ = to_do.pop() + if typ in seen: + continue + bases = self._get_typeshed_bases(typ) + result |= bases + to_do |= bases + seen.add(typ) + return result + def _get_typeshed_bases(self, typ: Union[type, str]) -> Set[Union[type, str]]: base_values = self.ts_finder.get_bases_recursively(typ) return {base.typ for base in base_values if isinstance(base, TypedValue)} diff --git a/pyanalyze/test_type_object.py b/pyanalyze/test_type_object.py index 70a5d624..f4af01df 100644 --- a/pyanalyze/test_type_object.py +++ b/pyanalyze/test_type_object.py @@ -404,3 +404,8 @@ def capybara(): with open("x", "rb") as f: assert_type(f, io.BufferedReader) want_io(f) + + def pacarana(): + with open("x", "w+b") as f: + assert_type(f, io.BufferedRandom) + want_io(f) diff --git a/pyanalyze/typeshed.py b/pyanalyze/typeshed.py index f783a1ce..b3a3e8a8 100644 --- a/pyanalyze/typeshed.py +++ b/pyanalyze/typeshed.py @@ -361,6 +361,12 @@ def get_bases_recursively(self, typ: Union[type, str]) -> List[Value]: return bases def get_bases_for_fq_name(self, fq_name: str) -> Optional[List[Value]]: + if fq_name in ( + "typing.Generic", + "typing.Protocol", + "typing_extensions.Protocol", + ): + return [] info = self._get_info_for_name(fq_name) mod, _ = fq_name.rsplit(".", maxsplit=1) return self._get_bases_from_info(info, mod, fq_name)