Skip to content

Commit

Permalink
Apply typeshed bases change recursively (#639)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored May 28, 2023
1 parent dd21852 commit e328f3e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
and similar functions, which previously emitted
`inference_failure` (#636)
- Take into account additional base classes declared in stub
files (fixing some false positives around `typing.IO`) (#635)
files (fixing some false positives around `typing.IO`) (#635, #639)
- Fix crash on stubs that contain dict or set literals (#634)
- Remove more old special cases and improve robustness of
annotation parsing (#630)
Expand Down
18 changes: 17 additions & 1 deletion pyanalyze/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _build_type_object(self, typ: Union[type, super, str]) -> TypeObject:
return TypeObject(typ, self.get_additional_bases(typ))
else:
plugin_bases = self.get_additional_bases(typ)
typeshed_bases = self._get_typeshed_bases(typ)
typeshed_bases = self._get_recursive_typeshed_bases(typ)
additional_bases = plugin_bases | typeshed_bases
# Is it marked as a protocol in stubs? If so, use the stub definition.
if self.ts_finder.is_protocol(typ):
Expand All @@ -190,6 +190,22 @@ def _build_type_object(self, typ: Union[type, super, str]) -> TypeObject:

return TypeObject(typ, additional_bases)

def _get_recursive_typeshed_bases(
self, typ: Union[type, str]
) -> Set[Union[type, str]]:
seen = set()
to_do = {typ}
result = set()
while to_do:
typ = to_do.pop()
if typ in seen:
continue
bases = self._get_typeshed_bases(typ)
result |= bases
to_do |= bases
seen.add(typ)
return result

def _get_typeshed_bases(self, typ: Union[type, str]) -> Set[Union[type, str]]:
base_values = self.ts_finder.get_bases_recursively(typ)
return {base.typ for base in base_values if isinstance(base, TypedValue)}
Expand Down
5 changes: 5 additions & 0 deletions pyanalyze/test_type_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,8 @@ def capybara():
with open("x", "rb") as f:
assert_type(f, io.BufferedReader)
want_io(f)

def pacarana():
with open("x", "w+b") as f:
assert_type(f, io.BufferedRandom)
want_io(f)
6 changes: 6 additions & 0 deletions pyanalyze/typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def get_bases_recursively(self, typ: Union[type, str]) -> List[Value]:
return bases

def get_bases_for_fq_name(self, fq_name: str) -> Optional[List[Value]]:
if fq_name in (
"typing.Generic",
"typing.Protocol",
"typing_extensions.Protocol",
):
return []
info = self._get_info_for_name(fq_name)
mod, _ = fq_name.rsplit(".", maxsplit=1)
return self._get_bases_from_info(info, mod, fq_name)
Expand Down

0 comments on commit e328f3e

Please sign in to comment.