From 4a6224ec5cfff68bf1e47bb2ee605dd5635e386f Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 28 Jan 2022 20:07:47 -0800 Subject: [PATCH] protocol cache (#450) From #430 --- docs/changelog.md | 1 + pyanalyze/name_check_visitor.py | 4 ++-- pyanalyze/signature.py | 18 ++++++++++++++++-- pyanalyze/type_object.py | 13 +++++++++++-- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index f10d7e81..42aa2fa5 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Cache decisions about whether classes implement protocols (#450) - Fix application of multiple suggested changes per file when an earlier change has added or removed lines (#449) - Treat `NoReturn` like `Any` in `**kwargs` calls (#446) diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 814cede5..a2ca344d 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -2050,10 +2050,10 @@ def _maybe_record_usages_from_import(self, node: ast.ImportFrom) -> None: module_name = parent_module_name + "." + node.module else: module_name = parent_module_name + if module_name is None: + return module = sys.modules.get(module_name) if module is None: - if module_name is None: - return try: module = __import__(module_name) except Exception: diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index 09ab19e1..cc0c6fa5 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -377,7 +377,7 @@ class ParameterKind(enum.Enum): } -@dataclass +@dataclass(frozen=True) class SigParameter: """Represents a single parameter to a callable.""" @@ -411,7 +411,7 @@ class SigParameter: def __post_init__(self) -> None: # backward compatibility if self.default is EMPTY: - self.default = None + object.__setattr__(self, "default", None) def substitute_typevars(self, typevars: TypeVarMap) -> "SigParameter": return SigParameter( @@ -519,6 +519,20 @@ def __post_init__(self) -> None: ) self.validate() + def __hash__(self) -> int: + return hash( + ( + tuple(self.parameters.items()), + self.return_value, + self.impl, + self.callable, + self.is_asynq, + self.has_return_annotation, + self.allow_call, + self.evaluator, + ) + ) + def validate(self) -> None: seen_kinds = set() seen_with_default = set() diff --git a/pyanalyze/type_object.py b/pyanalyze/type_object.py index 1669af43..d5093321 100644 --- a/pyanalyze/type_object.py +++ b/pyanalyze/type_object.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass, field import inspect -from typing import Container, Set, Sequence, Union +from typing import Container, Dict, Set, Sequence, Union from unittest import mock from .safe import safe_isinstance, safe_issubclass, safe_in @@ -15,6 +15,7 @@ CanAssignContext, CanAssignError, KnownValue, + TypeVarMap, TypedValue, Value, stringify_object, @@ -42,6 +43,9 @@ class TypeObject: protocol_members: Set[str] = field(default_factory=set) is_thrift_enum: bool = field(init=False) is_universally_assignable: bool = field(init=False) + _protocol_positive_cache: Dict[Value, TypeVarMap] = field( + default_factory=dict, repr=False + ) def __post_init__(self) -> None: if isinstance(self.typ, str): @@ -121,6 +125,9 @@ def can_assign( return CanAssignError( f"Cannot assign super object {other_val} to protocol {self}" ) + tv_map = self._protocol_positive_cache.get(other_val) + if tv_map is not None: + return tv_map # This is a guard against infinite recursion if the Protocol is recursive if ctx.can_assume_compatibility(self, other): return {} @@ -141,7 +148,9 @@ def can_assign( f"Value of protocol member {member!r} conflicts", [tv_map] ) tv_maps.append(tv_map) - return unify_typevar_maps(tv_maps) + result = unify_typevar_maps(tv_maps) + self._protocol_positive_cache[other_val] = result + return result def is_instance(self, obj: object) -> bool: """Whether obj is an instance of this type."""