diff --git a/docs/changelog.md b/docs/changelog.md index 02c99b44..772d173e 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- Fix type narrowing on the `else` case of `issubclass()` + (#401) - Fix indexing a list with an index typed as a `TypeVar` (#400) - Fix "This function should have an @asynq() decorator" diff --git a/pyanalyze/implementation.py b/pyanalyze/implementation.py index 428aec55..1f57fbff 100644 --- a/pyanalyze/implementation.py +++ b/pyanalyze/implementation.py @@ -2,6 +2,7 @@ from .error_code import ErrorCode from .extensions import reveal_type from .format_strings import parse_format_string +from .predicates import IsAssignablePredicate from .safe import safe_hasattr, safe_isinstance, safe_issubclass from .stacked_scopes import ( NULL_CONSTRAINT, @@ -12,6 +13,7 @@ PredicateProvider, OrConstraint, VarnameWithOrigin, + annotate_with_constraint, ) from .signature import ( ANY_SIGNATURE, @@ -96,20 +98,21 @@ def flatten_unions( def _issubclass_impl(ctx: CallContext) -> Value: class_or_tuple = ctx.vars["class_or_tuple"] - extension = None - if isinstance(class_or_tuple, KnownValue): - if isinstance(class_or_tuple.val, type): - extension = ParameterTypeGuardExtension( - "cls", SubclassValue(TypedValue(class_or_tuple.val)) - ) - elif isinstance(class_or_tuple.val, tuple) and all( - isinstance(elt, type) for elt in class_or_tuple.val - ): - vals = [SubclassValue(TypedValue(elt)) for elt in class_or_tuple.val] - extension = ParameterTypeGuardExtension("cls", MultiValuedValue(vals)) - if extension is not None: - return AnnotatedValue(TypedValue(bool), [extension]) - return TypedValue(bool) + varname = ctx.varname_for_arg("cls") + if varname is None or not isinstance(class_or_tuple, KnownValue): + return TypedValue(bool) + if isinstance(class_or_tuple.val, type): + narrowed_type = SubclassValue(TypedValue(class_or_tuple.val)) + elif isinstance(class_or_tuple.val, tuple) and all( + isinstance(elt, type) for elt in class_or_tuple.val + ): + vals = [SubclassValue(TypedValue(elt)) for elt in class_or_tuple.val] + narrowed_type = unite_values(*vals) + else: + return TypedValue(bool) + predicate = IsAssignablePredicate(narrowed_type, ctx.visitor, positive_only=False) + constraint = Constraint(varname, ConstraintType.predicate, True, predicate) + return annotate_with_constraint(TypedValue(bool), constraint) def _isinstance_impl(ctx: CallContext) -> ImplReturn: diff --git a/pyanalyze/test_implementation.py b/pyanalyze/test_implementation.py index 8160f558..619ce329 100644 --- a/pyanalyze/test_implementation.py +++ b/pyanalyze/test_implementation.py @@ -1147,6 +1147,19 @@ def capybara(x: type, y): ), ) + @assert_passes() + def test_negative_narrowing(self) -> None: + from typing import Type, Union + + def capybara(x: Union[Type[str], Type[int]]) -> None: + assert_is_value( + x, SubclassValue(TypedValue(str)) | SubclassValue(TypedValue(int)) + ) + if issubclass(x, str): + assert_is_value(x, SubclassValue(TypedValue(str))) + else: + assert_is_value(x, SubclassValue(TypedValue(int))) + class TestInferenceHelpers(TestNameCheckVisitorBase): @assert_passes()