diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index afbaca9cfe..2937a8c1c2 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -1139,42 +1139,51 @@ export function getIsInstanceClassTypes( // Create a helper function that returns a list of class types or // undefined if any of the types are not valid. const addClassTypesToList = (types: Type[]) => { - types.forEach((subtype) => { - if (isClass(subtype)) { - evaluator.inferVarianceForClass(subtype); - subtype = specializeWithUnknownTypeArgs( - subtype, + types.forEach((type) => { + const subtypes: Type[] = []; + if (isClass(type)) { + evaluator.inferVarianceForClass(type); + type = specializeWithUnknownTypeArgs( + type, evaluator.getTupleClassType(), useVarianceForSpecialization ? evaluator.getObjectType() : undefined ); - if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) { - subtype = convertToInstantiable(getUnknownTypeForCallable()); - } + doForEachSubtype(type, (subtype) => { + if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) { + subtypes.push(convertToInstantiable(getUnknownTypeForCallable())); + } else { + subtypes.push(subtype); + } + }); + } else { + subtypes.push(type); } - if (isInstantiableClass(subtype)) { - // If this is a reference to a class that has type promotions (e.g. - // float or complex), remove the promotions for purposes of the - // isinstance check). - if (!subtype.priv.includeSubclasses && subtype.priv.includePromotions) { - subtype = ClassType.cloneRemoveTypePromotions(subtype); + for (let subtype of subtypes) { + if (isInstantiableClass(subtype)) { + // If this is a reference to a class that has type promotions (e.g. + // float or complex), remove the promotions for purposes of the + // isinstance check). + if (!subtype.priv.includeSubclasses && subtype.priv.includePromotions) { + subtype = ClassType.cloneRemoveTypePromotions(subtype); + } + classTypeList.push(subtype); + } else if (isTypeVar(subtype) && TypeBase.isInstantiable(subtype)) { + classTypeList.push(subtype); + } else if (isNoneTypeClass(subtype)) { + assert(isInstantiableClass(subtype)); + classTypeList.push(subtype); + } else if ( + isFunction(subtype) && + subtype.shared.parameters.length === 2 && + subtype.shared.parameters[0].category === ParamCategory.ArgsList && + subtype.shared.parameters[1].category === ParamCategory.KwargsDict + ) { + classTypeList.push(subtype); + } else { + foundNonClassType = true; } - classTypeList.push(subtype); - } else if (isTypeVar(subtype) && TypeBase.isInstantiable(subtype)) { - classTypeList.push(subtype); - } else if (isNoneTypeClass(subtype)) { - assert(isInstantiableClass(subtype)); - classTypeList.push(subtype); - } else if ( - isFunction(subtype) && - subtype.shared.parameters.length === 2 && - subtype.shared.parameters[0].category === ParamCategory.ArgsList && - subtype.shared.parameters[1].category === ParamCategory.KwargsDict - ) { - classTypeList.push(subtype); - } else { - foundNonClassType = true; } }); }; diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 54b5a6bc67..608464e42f 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -7,7 +7,7 @@ * Functions that operate on Type objects. */ -import { appendArray } from '../common/collectionUtils'; +import { allCombinations, appendArray } from '../common/collectionUtils'; import { assert } from '../common/debug'; import { ParamCategory } from '../parser/parseNodes'; import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution'; @@ -1105,7 +1105,7 @@ export function specializeWithUnknownTypeArgs( type: ClassType, tupleClassType?: ClassType, objectTypeForVarianceCheck?: Type -): ClassType { +): ClassType | UnionType { if (type.shared.typeParams.length === 0) { return type; } @@ -1121,12 +1121,39 @@ export function specializeWithUnknownTypeArgs( ); } - return ClassType.specialize( - type, - type.shared.typeParams.map((param) => getUnknownForTypeVar(param, tupleClassType, objectTypeForVarianceCheck)), - /* isTypeArgExplicit */ false, - /* includeSubclasses */ type.priv.includeSubclasses - ); + const result = UnionType.create(); + const constraintCombinations = new Array(); + + // since constraints can't be specialized, we create a union of every possible combination of constraints + // instead of specializing them or leaving them as Unknown (cringe) + for (const typeParam of type.shared.typeParams) { + const currentConstraints = new Array(); + constraintCombinations.push(currentConstraints); + if (typeParam.shared.constraints.length) { + for (const constraint of typeParam.shared.constraints) { + currentConstraints.push(constraint); + } + } else { + currentConstraints.push(getUnknownForTypeVar(typeParam, tupleClassType, objectTypeForVarianceCheck)); + } + } + for (const typeVarsToSpecialize of allCombinations(constraintCombinations)) { + UnionType.addType( + result, + ClassType.specialize( + type, + typeVarsToSpecialize, + /* isTypeArgExplicit */ false, + /* includeSubclasses */ type.priv.includeSubclasses + ) + ); + } + if (result.priv.subtypes.length === 1) { + // convert it back to a ClassType if there's only one type, because there's places where the result is checked + // that don't account for unions and we want to minimize the risk of breaking things + return result.priv.subtypes[0] as ClassType; + } + return result; } /** diff --git a/packages/pyright-internal/src/common/collectionUtils.ts b/packages/pyright-internal/src/common/collectionUtils.ts index 38d1272bf3..c9cfe33555 100644 --- a/packages/pyright-internal/src/common/collectionUtils.ts +++ b/packages/pyright-internal/src/common/collectionUtils.ts @@ -428,3 +428,17 @@ export function arrayEquals(c1: T[], c2: T[], predicate: (e1: T, e2: T) => bo return c1.every((v, i) => predicate(v, c2[i])); } + +export const allCombinations = (items: T[][]): T[][] => { + const [head, ...tail] = items; + if (!tail.length) { + return head.map((value) => [value]); + } + const result: T[][] = []; + for (const item1 of head) { + for (const item2 of allCombinations(tail)) { + result.push([item1, ...item2]); + } + } + return result; +}; diff --git a/packages/pyright-internal/src/tests/collectionUtils.test.ts b/packages/pyright-internal/src/tests/collectionUtils.test.ts index 0308aba099..67985c4401 100644 --- a/packages/pyright-internal/src/tests/collectionUtils.test.ts +++ b/packages/pyright-internal/src/tests/collectionUtils.test.ts @@ -176,3 +176,59 @@ class D extends B { this.name = name; } } + +describe('allCombinations', () => { + test('2', () => { + assert.deepEqual( + utils.allCombinations([ + ['int', 'str'], + ['dict', 'list'], + ]), + [ + ['int', 'dict'], + ['int', 'list'], + ['str', 'dict'], + ['str', 'list'], + ] + ); + }); + test('3', () => { + assert.deepEqual( + utils.allCombinations([ + ['int', 'str'], + ['dict', 'list'], + ['asdf', 'fdsa'], + ]), + [ + ['int', 'dict', 'asdf'], + ['int', 'dict', 'fdsa'], + ['int', 'list', 'asdf'], + ['int', 'list', 'fdsa'], + ['str', 'dict', 'asdf'], + ['str', 'dict', 'fdsa'], + ['str', 'list', 'asdf'], + ['str', 'list', 'fdsa'], + ] + ); + }); + describe('varying lengths', () => { + test('2 and 1', () => { + assert.deepEqual(utils.allCombinations([['int', 'str'], ['dict', 'list'], ['asdf']]), [ + ['int', 'dict', 'asdf'], + ['int', 'list', 'asdf'], + ['str', 'dict', 'asdf'], + ['str', 'list', 'asdf'], + ]); + }); + test('3 and 2 and 1', () => { + assert.deepEqual(utils.allCombinations([['int', 'str', 'fdsa'], ['dict', 'list'], ['asdf']]), [ + ['int', 'dict', 'asdf'], + ['int', 'list', 'asdf'], + ['str', 'dict', 'asdf'], + ['str', 'list', 'asdf'], + ['fdsa', 'dict', 'asdf'], + ['fdsa', 'list', 'asdf'], + ]); + }); + }); +}); diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py index cf15a0f39a..e754f88bc2 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingUsingBounds.py @@ -99,3 +99,10 @@ def partially_unknown(self, value=None): def goo[KT, VT](self: MutableMapping[KT, VT]) -> Iterator[KT]: assert isinstance(self, Reversible) return reversed(self) + +class Constraints[T: (int, str), U: (int, str), V: int]: + ... + +def _(value: object): + if isinstance(value, Constraints): + assert_type(value, Constraints[int, int, int] | Constraints[int, str, int] | Constraints[str, int, int] | Constraints[str, str, int])