Skip to content

Commit

Permalink
narrow constraints to a union of all possible conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
DetachHead committed Nov 21, 2024
1 parent d1ba0a6 commit 1a71f1d
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 37 deletions.
67 changes: 38 additions & 29 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1139,42 +1139,51 @@ export function getIsInstanceClassTypes(
// Create a helper function that returns a list of class types or
// undefined if any of the types are not valid.
const addClassTypesToList = (types: Type[]) => {
types.forEach((subtype) => {
if (isClass(subtype)) {
evaluator.inferVarianceForClass(subtype);
subtype = specializeWithUnknownTypeArgs(
subtype,
types.forEach((type) => {
const subtypes: Type[] = [];
if (isClass(type)) {
evaluator.inferVarianceForClass(type);
type = specializeWithUnknownTypeArgs(
type,
evaluator.getTupleClassType(),
useVarianceForSpecialization ? evaluator.getObjectType() : undefined
);

if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) {
subtype = convertToInstantiable(getUnknownTypeForCallable());
}
doForEachSubtype(type, (subtype) => {
if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) {
subtypes.push(convertToInstantiable(getUnknownTypeForCallable()));
} else {
subtypes.push(subtype);
}
});
} else {
subtypes.push(type);
}

if (isInstantiableClass(subtype)) {
// If this is a reference to a class that has type promotions (e.g.
// float or complex), remove the promotions for purposes of the
// isinstance check).
if (!subtype.priv.includeSubclasses && subtype.priv.includePromotions) {
subtype = ClassType.cloneRemoveTypePromotions(subtype);
for (let subtype of subtypes) {
if (isInstantiableClass(subtype)) {
// If this is a reference to a class that has type promotions (e.g.
// float or complex), remove the promotions for purposes of the
// isinstance check).
if (!subtype.priv.includeSubclasses && subtype.priv.includePromotions) {
subtype = ClassType.cloneRemoveTypePromotions(subtype);
}
classTypeList.push(subtype);
} else if (isTypeVar(subtype) && TypeBase.isInstantiable(subtype)) {
classTypeList.push(subtype);
} else if (isNoneTypeClass(subtype)) {
assert(isInstantiableClass(subtype));
classTypeList.push(subtype);
} else if (
isFunction(subtype) &&
subtype.shared.parameters.length === 2 &&
subtype.shared.parameters[0].category === ParamCategory.ArgsList &&
subtype.shared.parameters[1].category === ParamCategory.KwargsDict
) {
classTypeList.push(subtype);
} else {
foundNonClassType = true;
}
classTypeList.push(subtype);
} else if (isTypeVar(subtype) && TypeBase.isInstantiable(subtype)) {
classTypeList.push(subtype);
} else if (isNoneTypeClass(subtype)) {
assert(isInstantiableClass(subtype));
classTypeList.push(subtype);
} else if (
isFunction(subtype) &&
subtype.shared.parameters.length === 2 &&
subtype.shared.parameters[0].category === ParamCategory.ArgsList &&
subtype.shared.parameters[1].category === ParamCategory.KwargsDict
) {
classTypeList.push(subtype);
} else {
foundNonClassType = true;
}
});
};
Expand Down
43 changes: 35 additions & 8 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Functions that operate on Type objects.
*/

import { appendArray } from '../common/collectionUtils';
import { allCombinations, appendArray } from '../common/collectionUtils';
import { assert } from '../common/debug';
import { ParamCategory } from '../parser/parseNodes';
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
Expand Down Expand Up @@ -1105,7 +1105,7 @@ export function specializeWithUnknownTypeArgs(
type: ClassType,
tupleClassType?: ClassType,
objectTypeForVarianceCheck?: Type
): ClassType {
): ClassType | UnionType {
if (type.shared.typeParams.length === 0) {
return type;
}
Expand All @@ -1121,12 +1121,39 @@ export function specializeWithUnknownTypeArgs(
);
}

return ClassType.specialize(
type,
type.shared.typeParams.map((param) => getUnknownForTypeVar(param, tupleClassType, objectTypeForVarianceCheck)),
/* isTypeArgExplicit */ false,
/* includeSubclasses */ type.priv.includeSubclasses
);
const result = UnionType.create();
const constraintCombinations = new Array<Type[]>();

// since constraints can't be specialized, we create a union of every possible combination of constraints
// instead of specializing them or leaving them as Unknown (cringe)
for (const typeParam of type.shared.typeParams) {
const currentConstraints = new Array<Type>();
constraintCombinations.push(currentConstraints);
if (typeParam.shared.constraints.length) {
for (const constraint of typeParam.shared.constraints) {
currentConstraints.push(constraint);
}
} else {
currentConstraints.push(getUnknownForTypeVar(typeParam, tupleClassType, objectTypeForVarianceCheck));
}
}
for (const typeVarsToSpecialize of allCombinations(constraintCombinations)) {
UnionType.addType(
result,
ClassType.specialize(
type,
typeVarsToSpecialize,
/* isTypeArgExplicit */ false,
/* includeSubclasses */ type.priv.includeSubclasses
)
);
}
if (result.priv.subtypes.length === 1) {
// convert it back to a ClassType if there's only one type, because there's places where the result is checked
// that don't account for unions and we want to minimize the risk of breaking things
return result.priv.subtypes[0] as ClassType;
}
return result;
}

/**
Expand Down
14 changes: 14 additions & 0 deletions packages/pyright-internal/src/common/collectionUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,17 @@ export function arrayEquals<T>(c1: T[], c2: T[], predicate: (e1: T, e2: T) => bo

return c1.every((v, i) => predicate(v, c2[i]));
}

export const allCombinations = <T>(items: T[][]): T[][] => {
const [head, ...tail] = items;
if (!tail.length) {
return head.map((value) => [value]);
}
const result: T[][] = [];
for (const item1 of head) {
for (const item2 of allCombinations(tail)) {
result.push([item1, ...item2]);
}
}
return result;
};
56 changes: 56 additions & 0 deletions packages/pyright-internal/src/tests/collectionUtils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,59 @@ class D extends B {
this.name = name;
}
}

describe('allCombinations', () => {
test('2', () => {
assert.deepEqual(
utils.allCombinations([
['int', 'str'],
['dict', 'list'],
]),
[
['int', 'dict'],
['int', 'list'],
['str', 'dict'],
['str', 'list'],
]
);
});
test('3', () => {
assert.deepEqual(
utils.allCombinations([
['int', 'str'],
['dict', 'list'],
['asdf', 'fdsa'],
]),
[
['int', 'dict', 'asdf'],
['int', 'dict', 'fdsa'],
['int', 'list', 'asdf'],
['int', 'list', 'fdsa'],
['str', 'dict', 'asdf'],
['str', 'dict', 'fdsa'],
['str', 'list', 'asdf'],
['str', 'list', 'fdsa'],
]
);
});
describe('varying lengths', () => {
test('2 and 1', () => {
assert.deepEqual(utils.allCombinations([['int', 'str'], ['dict', 'list'], ['asdf']]), [
['int', 'dict', 'asdf'],
['int', 'list', 'asdf'],
['str', 'dict', 'asdf'],
['str', 'list', 'asdf'],
]);
});
test('3 and 2 and 1', () => {
assert.deepEqual(utils.allCombinations([['int', 'str', 'fdsa'], ['dict', 'list'], ['asdf']]), [
['int', 'dict', 'asdf'],
['int', 'list', 'asdf'],
['str', 'dict', 'asdf'],
['str', 'list', 'asdf'],
['fdsa', 'dict', 'asdf'],
['fdsa', 'list', 'asdf'],
]);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,10 @@ def partially_unknown(self, value=None):
def goo[KT, VT](self: MutableMapping[KT, VT]) -> Iterator[KT]:
assert isinstance(self, Reversible)
return reversed(self)

class Constraints[T: (int, str), U: (int, str), V: int]:
...

def _(value: object):
if isinstance(value, Constraints):
assert_type(value, Constraints[int, int, int] | Constraints[int, str, int] | Constraints[str, int, int] | Constraints[str, str, int])

0 comments on commit 1a71f1d

Please sign in to comment.