diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 45de6579fa..c64852dae4 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -83,6 +83,7 @@ import { mapSubtypes, partiallySpecializeType, preserveUnknown, + shouldUseVarianceForSpecialization, specializeTupleClass, specializeWithUnknownTypeArgs, transformPossibleRecursiveTypeAlias, @@ -761,9 +762,7 @@ function narrowTypeBasedOnClassPattern( exprType = specializeWithUnknownTypeArgs( exprType, evaluator.getTupleClassType(), - // for backwards compatibility with bacly typed code, we don't specialize using variance if the type we're - // narrowing is Any/Unknown - isAnyOrUnknown(type) || isPartlyUnknown(type) ? undefined : evaluator.getObjectType() + shouldUseVarianceForSpecialization(type) ? evaluator.getObjectType() : undefined ); } diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index c1d0c964d5..c830735fa3 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -82,7 +82,6 @@ import { isMetaclassInstance, isNoneInstance, isNoneTypeClass, - isPartlyUnknown, isProperty, isTupleClass, isTupleGradualForm, @@ -92,6 +91,7 @@ import { makeTypeVarsFree, mapSubtypes, MemberAccessFlags, + shouldUseVarianceForSpecialization, specializeTupleClass, specializeWithUnknownTypeArgs, stripTypeForm, @@ -1135,11 +1135,7 @@ export function getIsInstanceClassTypes( ): (ClassType | TypeVarType | FunctionType)[] | undefined { let foundNonClassType = false; const classTypeList: (ClassType | TypeVarType | FunctionType)[] = []; - /** - * if the type we're narrowing is Any or Unknown, we don't want to specialize using the - * variance/bound for compatibility with less strictly typed code (cringe) - */ - const useVarianceForSpecialization = !isAnyOrUnknown(typeToNarrow) && !isPartlyUnknown(typeToNarrow); + const useVarianceForSpecialization = shouldUseVarianceForSpecialization(typeToNarrow); // Create a helper function that returns a list of class types or // undefined if any of the types are not valid. const addClassTypesToList = (types: Type[]) => { diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 3116faf599..4e394a8d85 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -1053,6 +1053,17 @@ export function getTypeVarScopeIds(type: Type): TypeVarScopeId[] { return scopeIds; } +/** + * if the type we're narrowing is Any or Unknown, we don't want to specialize using the + * variance/bound for compatibility with less strictly typed code (cringe) + */ +export const shouldUseVarianceForSpecialization = (type: Type) => + !isAnyOrUnknown(type) && + !isPartlyUnknown(type) && + // TODO: this logic should probably be moved into `isAny`/`isUnknown` or something, + // to fix issues like https://github.com/DetachHead/basedpyright/issues/746 + (type.category !== TypeCategory.TypeVar || !type.shared.isSynthesized); + /** * Specializes the class with "Unknown" type args (or the equivalent for ParamSpecs or TypeVarTuples), or its * widest possible type if its variance is known and {@link objectTypeForVarianceCheck} is provided (`object` if diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py index c7e2f68a29..1b00ee6f78 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any, Never, assert_type, Iterable @@ -44,15 +45,25 @@ def foo(value: object): class AnyOrUnknown: """for backwards compatibility with badly typed code we keep the old functionality when narrowing `Any`/Unknown""" - def foo(self, value: Any): + def __init__(self, value): + """arguments in `__init__` get turned into fake type vars if they're untyped, so we need to handle this case. + see https://github.com/DetachHead/basedpyright/issues/746""" if isinstance(value, Iterable): assert_type(value, Iterable[Any]) - def bar(self, value: Any): + def any(self, value: Any): + if isinstance(value, Iterable): + assert_type(value, Iterable[Any]) + + def match_case(self, value: Any): match value: case Iterable(): assert_type(value, Iterable[Any]) + def unknown(self, value): + if isinstance(value, Iterable): + assert_type(value, Iterable[Any]) + def partially_unknown(self, value=None): if isinstance(value, Iterable): assert_type(value, Iterable[Any])