Skip to content

Commit

Permalink
protocol cache (#450)
Browse files Browse the repository at this point in the history
From #430
  • Loading branch information
JelleZijlstra authored Jan 29, 2022
1 parent b201328 commit 4a6224e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Cache decisions about whether classes implement protocols (#450)
- Fix application of multiple suggested changes per file
when an earlier change has added or removed lines (#449)
- Treat `NoReturn` like `Any` in `**kwargs` calls (#446)
Expand Down
4 changes: 2 additions & 2 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,10 +2050,10 @@ def _maybe_record_usages_from_import(self, node: ast.ImportFrom) -> None:
module_name = parent_module_name + "." + node.module
else:
module_name = parent_module_name
if module_name is None:
return
module = sys.modules.get(module_name)
if module is None:
if module_name is None:
return
try:
module = __import__(module_name)
except Exception:
Expand Down
18 changes: 16 additions & 2 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ class ParameterKind(enum.Enum):
}


@dataclass
@dataclass(frozen=True)
class SigParameter:
"""Represents a single parameter to a callable."""

Expand Down Expand Up @@ -411,7 +411,7 @@ class SigParameter:
def __post_init__(self) -> None:
# backward compatibility
if self.default is EMPTY:
self.default = None
object.__setattr__(self, "default", None)

def substitute_typevars(self, typevars: TypeVarMap) -> "SigParameter":
return SigParameter(
Expand Down Expand Up @@ -519,6 +519,20 @@ def __post_init__(self) -> None:
)
self.validate()

def __hash__(self) -> int:
return hash(
(
tuple(self.parameters.items()),
self.return_value,
self.impl,
self.callable,
self.is_asynq,
self.has_return_annotation,
self.allow_call,
self.evaluator,
)
)

def validate(self) -> None:
seen_kinds = set()
seen_with_default = set()
Expand Down
13 changes: 11 additions & 2 deletions pyanalyze/type_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
from dataclasses import dataclass, field
import inspect
from typing import Container, Set, Sequence, Union
from typing import Container, Dict, Set, Sequence, Union
from unittest import mock

from .safe import safe_isinstance, safe_issubclass, safe_in
Expand All @@ -15,6 +15,7 @@
CanAssignContext,
CanAssignError,
KnownValue,
TypeVarMap,
TypedValue,
Value,
stringify_object,
Expand Down Expand Up @@ -42,6 +43,9 @@ class TypeObject:
protocol_members: Set[str] = field(default_factory=set)
is_thrift_enum: bool = field(init=False)
is_universally_assignable: bool = field(init=False)
_protocol_positive_cache: Dict[Value, TypeVarMap] = field(
default_factory=dict, repr=False
)

def __post_init__(self) -> None:
if isinstance(self.typ, str):
Expand Down Expand Up @@ -121,6 +125,9 @@ def can_assign(
return CanAssignError(
f"Cannot assign super object {other_val} to protocol {self}"
)
tv_map = self._protocol_positive_cache.get(other_val)
if tv_map is not None:
return tv_map
# This is a guard against infinite recursion if the Protocol is recursive
if ctx.can_assume_compatibility(self, other):
return {}
Expand All @@ -141,7 +148,9 @@ def can_assign(
f"Value of protocol member {member!r} conflicts", [tv_map]
)
tv_maps.append(tv_map)
return unify_typevar_maps(tv_maps)
result = unify_typevar_maps(tv_maps)
self._protocol_positive_cache[other_val] = result
return result

def is_instance(self, obj: object) -> bool:
"""Whether obj is an instance of this type."""
Expand Down

0 comments on commit 4a6224e

Please sign in to comment.