diff --git a/src/Analyser/TypeSpecifier.php b/src/Analyser/TypeSpecifier.php index ef69051f2a..fc43f83499 100644 --- a/src/Analyser/TypeSpecifier.php +++ b/src/Analyser/TypeSpecifier.php @@ -40,7 +40,6 @@ use PHPStan\Type\Constant\ConstantStringType; use PHPStan\Type\ConstantScalarType; use PHPStan\Type\ConstantType; -use PHPStan\Type\Enum\EnumCaseObjectType; use PHPStan\Type\FloatType; use PHPStan\Type\FunctionTypeSpecifyingExtension; use PHPStan\Type\Generic\GenericClassStringType; @@ -232,7 +231,7 @@ public function specifyTypesInCondition( $exprRightType = $scope->getType($expr->right); if ( $exprLeftType instanceof ConstantScalarType - || $exprLeftType instanceof EnumCaseObjectType + || count($exprLeftType->getEnumCases()) === 1 || ($exprLeftType instanceof ConstantType && !$exprRightType->equals($exprLeftType) && $exprRightType->isSuperTypeOf($exprLeftType)->yes()) ) { $types = $this->create( @@ -246,7 +245,7 @@ public function specifyTypesInCondition( } if ( $exprRightType instanceof ConstantScalarType - || $exprRightType instanceof EnumCaseObjectType + || count($exprRightType->getEnumCases()) === 1 || ($exprRightType instanceof ConstantType && !$exprLeftType->equals($exprRightType) && $exprLeftType->isSuperTypeOf($exprRightType)->yes()) ) { $leftType = $this->create( diff --git a/src/Rules/Comparison/MatchExpressionRule.php b/src/Rules/Comparison/MatchExpressionRule.php index e9081d8e7b..14ea973ec3 100644 --- a/src/Rules/Comparison/MatchExpressionRule.php +++ b/src/Rules/Comparison/MatchExpressionRule.php @@ -9,19 +9,13 @@ use PHPStan\Rules\Rule; use PHPStan\Rules\RuleErrorBuilder; use PHPStan\Type\Constant\ConstantBooleanType; -use PHPStan\Type\Enum\EnumCaseObjectType; use PHPStan\Type\NeverType; use PHPStan\Type\ObjectType; -use PHPStan\Type\SubtractableType; use PHPStan\Type\TypeCombinator; -use PHPStan\Type\TypeUtils; -use PHPStan\Type\TypeWithClassName; use PHPStan\Type\UnionType; use PHPStan\Type\VerbosityLevel; use UnhandledMatchError; -use function array_keys; use function array_map; -use function array_values; use function count; use function sprintf; @@ -92,44 +86,13 @@ public function processNode(Node $node, Scope $scope): array if (!$hasDefault && !$nextArmIsDead) { $remainingType = $node->getEndScope()->getType($matchCondition); - if ($remainingType instanceof TypeWithClassName && $remainingType instanceof SubtractableType) { - $subtractedType = $remainingType->getSubtractedType(); - if ($subtractedType !== null && $remainingType->getClassReflection() !== null) { - $classReflection = $remainingType->getClassReflection(); - if ($classReflection->isEnum()) { - $cases = []; - foreach (array_keys($classReflection->getEnumCases()) as $name) { - $cases[$name] = new EnumCaseObjectType($classReflection->getName(), $name); - } - - $subtractedTypes = TypeUtils::flattenTypes($subtractedType); - $set = true; - foreach ($subtractedTypes as $subType) { - if (!$subType instanceof EnumCaseObjectType) { - $set = false; - break; - } - - if ($subType->getClassName() !== $classReflection->getName()) { - $set = false; - break; - } - - unset($cases[$subType->getEnumCaseName()]); - } - - $cases = array_values($cases); - $casesCount = count($cases); - if ($set) { - if ($casesCount > 1) { - $remainingType = new UnionType($cases); - } - if ($casesCount === 1) { - $remainingType = $cases[0]; - } - } - } - } + $cases = $remainingType->getEnumCases(); + $casesCount = count($cases); + if ($casesCount > 1) { + $remainingType = new UnionType($cases); + } + if ($casesCount === 1) { + $remainingType = $cases[0]; } if ( !$remainingType instanceof NeverType diff --git a/src/Type/Accessory/HasMethodType.php b/src/Type/Accessory/HasMethodType.php index e795257f3b..ccd1b782e2 100644 --- a/src/Type/Accessory/HasMethodType.php +++ b/src/Type/Accessory/HasMethodType.php @@ -143,6 +143,11 @@ public function getCallableParametersAcceptors(ClassMemberAccessAnswerer $scope) ]; } + public function getEnumCases(): array + { + return []; + } + public function traverse(callable $cb): Type { return $this; diff --git a/src/Type/Accessory/HasOffsetType.php b/src/Type/Accessory/HasOffsetType.php index e5bd3841fe..2172c0a08b 100644 --- a/src/Type/Accessory/HasOffsetType.php +++ b/src/Type/Accessory/HasOffsetType.php @@ -284,6 +284,11 @@ public function toArrayKey(): Type return new ErrorType(); } + public function getEnumCases(): array + { + return []; + } + public function traverse(callable $cb): Type { return $this; diff --git a/src/Type/Accessory/HasOffsetValueType.php b/src/Type/Accessory/HasOffsetValueType.php index 9cbef14e77..bded4e6dec 100644 --- a/src/Type/Accessory/HasOffsetValueType.php +++ b/src/Type/Accessory/HasOffsetValueType.php @@ -325,6 +325,11 @@ public function toArrayKey(): Type return new ErrorType(); } + public function getEnumCases(): array + { + return []; + } + public function traverse(callable $cb): Type { $newValueType = $cb($this->valueType); diff --git a/src/Type/Accessory/HasPropertyType.php b/src/Type/Accessory/HasPropertyType.php index 8793c2f4ce..19c3935a8d 100644 --- a/src/Type/Accessory/HasPropertyType.php +++ b/src/Type/Accessory/HasPropertyType.php @@ -104,6 +104,11 @@ public function getCallableParametersAcceptors(ClassMemberAccessAnswerer $scope) return [new TrivialParametersAcceptor()]; } + public function getEnumCases(): array + { + return []; + } + public function traverse(callable $cb): Type { return $this; diff --git a/src/Type/CallableType.php b/src/Type/CallableType.php index 2bf97afb0a..48ef36e11c 100644 --- a/src/Type/CallableType.php +++ b/src/Type/CallableType.php @@ -397,6 +397,11 @@ public function isScalar(): TrinaryLogic return TrinaryLogic::createMaybe(); } + public function getEnumCases(): array + { + return []; + } + public function isCommonCallable(): bool { return $this->isCommonCallable; diff --git a/src/Type/ClosureType.php b/src/Type/ClosureType.php index b7fdf5d901..b87ee2709d 100644 --- a/src/Type/ClosureType.php +++ b/src/Type/ClosureType.php @@ -250,6 +250,11 @@ public function isCallable(): TrinaryLogic return TrinaryLogic::createYes(); } + public function getEnumCases(): array + { + return []; + } + /** * @return ParametersAcceptor[] */ diff --git a/src/Type/Enum/EnumCaseObjectType.php b/src/Type/Enum/EnumCaseObjectType.php index cda6d91a18..419f00f142 100644 --- a/src/Type/Enum/EnumCaseObjectType.php +++ b/src/Type/Enum/EnumCaseObjectType.php @@ -145,6 +145,11 @@ public function isSmallerThanOrEqual(Type $otherType): TrinaryLogic return TrinaryLogic::createNo(); } + public function getEnumCases(): array + { + return [$this]; + } + /** * @param mixed[] $properties */ diff --git a/src/Type/IntersectionType.php b/src/Type/IntersectionType.php index ed38236134..f23abd6001 100644 --- a/src/Type/IntersectionType.php +++ b/src/Type/IntersectionType.php @@ -639,6 +639,18 @@ public function shuffleArray(): Type return $this->intersectTypes(static fn (Type $type): Type => $type->shuffleArray()); } + public function getEnumCases(): array + { + $enumCases = []; + foreach ($this->types as $type) { + foreach ($type->getEnumCases() as $enumCase) { + $enumCases[] = $enumCase; + } + } + + return $enumCases; + } + public function isCallable(): TrinaryLogic { return $this->intersectResults(static fn (Type $type): TrinaryLogic => $type->isCallable()); diff --git a/src/Type/IterableType.php b/src/Type/IterableType.php index 966731872d..b66a143e58 100644 --- a/src/Type/IterableType.php +++ b/src/Type/IterableType.php @@ -335,6 +335,11 @@ public function isScalar(): TrinaryLogic return TrinaryLogic::createNo(); } + public function getEnumCases(): array + { + return []; + } + public function inferTemplateTypes(Type $receivedType): TemplateTypeMap { if ($receivedType instanceof UnionType || $receivedType instanceof IntersectionType) { diff --git a/src/Type/MixedType.php b/src/Type/MixedType.php index e547d6b2da..dcef5330ae 100644 --- a/src/Type/MixedType.php +++ b/src/Type/MixedType.php @@ -241,6 +241,11 @@ public function isCallable(): TrinaryLogic return TrinaryLogic::createMaybe(); } + public function getEnumCases(): array + { + return []; + } + /** * @return ParametersAcceptor[] */ diff --git a/src/Type/NeverType.php b/src/Type/NeverType.php index 14be4fa9c9..f04f312626 100644 --- a/src/Type/NeverType.php +++ b/src/Type/NeverType.php @@ -408,6 +408,11 @@ public function isScalar(): TrinaryLogic return TrinaryLogic::createNo(); } + public function getEnumCases(): array + { + return []; + } + /** * @param mixed[] $properties */ diff --git a/src/Type/NonexistentParentClassType.php b/src/Type/NonexistentParentClassType.php index b5b7061c1f..c2a5dd1054 100644 --- a/src/Type/NonexistentParentClassType.php +++ b/src/Type/NonexistentParentClassType.php @@ -139,6 +139,11 @@ public function isScalar(): TrinaryLogic return TrinaryLogic::createNo(); } + public function getEnumCases(): array + { + return []; + } + /** * @param mixed[] $properties */ diff --git a/src/Type/ObjectType.php b/src/Type/ObjectType.php index a78c806587..23d9243204 100644 --- a/src/Type/ObjectType.php +++ b/src/Type/ObjectType.php @@ -1039,6 +1039,35 @@ public function unsetOffset(Type $offsetType): Type return $this; } + public function getEnumCases(): array + { + $classReflection = $this->getClassReflection(); + if ($classReflection === null) { + return []; + } + + if (!$classReflection->isEnum()) { + return []; + } + + $subtracted = []; + if ($this->subtractedType !== null) { + foreach ($this->subtractedType->getEnumCases() as $enumCase) { + $subtracted[$enumCase->getEnumCaseName()] = true; + } + } + + $cases = []; + foreach ($classReflection->getEnumCases() as $enumCase) { + if (array_key_exists($enumCase->getName(), $subtracted)) { + continue; + } + $cases[] = new EnumCaseObjectType($classReflection->getName(), $enumCase->getName(), $classReflection); + } + + return $cases; + } + public function isCallable(): TrinaryLogic { $parametersAcceptors = $this->findCallableParametersAcceptors(); diff --git a/src/Type/ObjectWithoutClassType.php b/src/Type/ObjectWithoutClassType.php index ffac3045bb..a929c6d4f9 100644 --- a/src/Type/ObjectWithoutClassType.php +++ b/src/Type/ObjectWithoutClassType.php @@ -119,6 +119,11 @@ function () use ($level): string { ); } + public function getEnumCases(): array + { + return []; + } + public function subtract(Type $type): Type { if ($type instanceof self) { diff --git a/src/Type/StaticType.php b/src/Type/StaticType.php index b7aeddb2f3..4d7d79a204 100644 --- a/src/Type/StaticType.php +++ b/src/Type/StaticType.php @@ -411,6 +411,11 @@ public function isCallable(): TrinaryLogic return $this->getStaticObjectType()->isCallable(); } + public function getEnumCases(): array + { + return $this->getStaticObjectType()->getEnumCases(); + } + public function isArray(): TrinaryLogic { return $this->getStaticObjectType()->isArray(); diff --git a/src/Type/StrictMixedType.php b/src/Type/StrictMixedType.php index ed2271f833..972057f069 100644 --- a/src/Type/StrictMixedType.php +++ b/src/Type/StrictMixedType.php @@ -312,6 +312,11 @@ public function getReferencedTemplateTypes(TemplateTypeVariance $positionVarianc return []; } + public function getEnumCases(): array + { + return []; + } + public function traverse(callable $cb): Type { return $this; diff --git a/src/Type/Traits/LateResolvableTypeTrait.php b/src/Type/Traits/LateResolvableTypeTrait.php index 9a981fda8a..ff3fbfa2b6 100644 --- a/src/Type/Traits/LateResolvableTypeTrait.php +++ b/src/Type/Traits/LateResolvableTypeTrait.php @@ -260,6 +260,11 @@ public function isCallable(): TrinaryLogic return $this->resolve()->isCallable(); } + public function getEnumCases(): array + { + return $this->resolve()->getEnumCases(); + } + public function getCallableParametersAcceptors(ClassMemberAccessAnswerer $scope): array { return $this->resolve()->getCallableParametersAcceptors($scope); diff --git a/src/Type/Traits/NonObjectTypeTrait.php b/src/Type/Traits/NonObjectTypeTrait.php index 10d929ab44..0cc6cd4f37 100644 --- a/src/Type/Traits/NonObjectTypeTrait.php +++ b/src/Type/Traits/NonObjectTypeTrait.php @@ -79,4 +79,9 @@ public function isCloneable(): TrinaryLogic return TrinaryLogic::createNo(); } + public function getEnumCases(): array + { + return []; + } + } diff --git a/src/Type/Type.php b/src/Type/Type.php index 69886d58df..7d1183a1bc 100644 --- a/src/Type/Type.php +++ b/src/Type/Type.php @@ -12,6 +12,7 @@ use PHPStan\TrinaryLogic; use PHPStan\Type\Constant\ConstantArrayType; use PHPStan\Type\Constant\ConstantStringType; +use PHPStan\Type\Enum\EnumCaseObjectType; use PHPStan\Type\Generic\TemplateTypeMap; use PHPStan\Type\Generic\TemplateTypeReference; use PHPStan\Type\Generic\TemplateTypeVariance; @@ -118,6 +119,11 @@ public function shiftArray(): Type; public function shuffleArray(): Type; + /** + * @return list + */ + public function getEnumCases(): array; + public function isCallable(): TrinaryLogic; /** diff --git a/src/Type/UnionType.php b/src/Type/UnionType.php index 5e3913ee0e..02e28c03f5 100644 --- a/src/Type/UnionType.php +++ b/src/Type/UnionType.php @@ -622,6 +622,11 @@ public function shuffleArray(): Type return $this->unionTypes(static fn (Type $type): Type => $type->shuffleArray()); } + public function getEnumCases(): array + { + return $this->pickTypes(static fn (Type $type) => $type->getEnumCases()); + } + public function isCallable(): TrinaryLogic { return $this->unionResults(static fn (Type $type): TrinaryLogic => $type->isCallable()); diff --git a/tests/PHPStan/Type/ObjectTypeTest.php b/tests/PHPStan/Type/ObjectTypeTest.php index 3cb3212860..e635d4863d 100644 --- a/tests/PHPStan/Type/ObjectTypeTest.php +++ b/tests/PHPStan/Type/ObjectTypeTest.php @@ -16,12 +16,14 @@ use InvalidArgumentException; use Iterator; use LogicException; +use ObjectTypeEnums\FooEnum; use PHPStan\Testing\PHPStanTestCase; use PHPStan\TrinaryLogic; use PHPStan\Type\Accessory\HasMethodType; use PHPStan\Type\Accessory\HasPropertyType; use PHPStan\Type\Constant\ConstantIntegerType; use PHPStan\Type\Constant\ConstantStringType; +use PHPStan\Type\Enum\EnumCaseObjectType; use PHPStan\Type\Generic\GenericObjectType; use PHPStan\Type\Generic\TemplateTypeFactory; use PHPStan\Type\Generic\TemplateTypeScope; @@ -32,6 +34,7 @@ use Throwable; use ThrowPoints\TryCatch\MyInvalidArgumentException; use Traversable; +use function count; use function sprintf; use const PHP_VERSION_ID; @@ -582,4 +585,53 @@ public function testHasOffsetValueType( ); } + public function dataGetEnumCases(): iterable + { + yield [ + new ObjectType(stdClass::class), + [], + ]; + + yield [ + new ObjectType(FooEnum::class), + [ + new EnumCaseObjectType(FooEnum::class, 'FOO'), + new EnumCaseObjectType(FooEnum::class, 'BAR'), + new EnumCaseObjectType(FooEnum::class, 'BAZ'), + ], + ]; + + yield [ + new ObjectType(FooEnum::class, new EnumCaseObjectType(FooEnum::class, 'FOO')), + [ + new EnumCaseObjectType(FooEnum::class, 'BAR'), + new EnumCaseObjectType(FooEnum::class, 'BAZ'), + ], + ]; + + yield [ + new ObjectType(FooEnum::class, new UnionType([new EnumCaseObjectType(FooEnum::class, 'FOO'), new EnumCaseObjectType(FooEnum::class, 'BAR')])), + [ + new EnumCaseObjectType(FooEnum::class, 'BAZ'), + ], + ]; + } + + /** + * @dataProvider dataGetEnumCases + * @param list $expectedEnumCases + */ + public function testGetEnumCases( + ObjectType $type, + array $expectedEnumCases, + ): void + { + $enumCases = $type->getEnumCases(); + $this->assertCount(count($expectedEnumCases), $enumCases); + foreach ($enumCases as $i => $enumCase) { + $expectedEnumCase = $expectedEnumCases[$i]; + $this->assertTrue($expectedEnumCase->equals($enumCase), sprintf('%s->equals(%s)', $expectedEnumCase->describe(VerbosityLevel::precise()), $enumCase->describe(VerbosityLevel::precise()))); + } + } + } diff --git a/tests/PHPStan/Type/data/ObjectTypeEnums.php b/tests/PHPStan/Type/data/ObjectTypeEnums.php new file mode 100644 index 0000000000..d190bf7497 --- /dev/null +++ b/tests/PHPStan/Type/data/ObjectTypeEnums.php @@ -0,0 +1,12 @@ +