Skip to content

Commit

Permalink
Add support for assert_type
Browse files Browse the repository at this point in the history
See python/cpython#30843.

The implementation mostly follows that of cast(). It relies on
`mypy.sametypes.is_same_type()`.
  • Loading branch information
JelleZijlstra committed Apr 14, 2022
1 parent 0c6b290 commit 22a8e4d
Show file tree
Hide file tree
Showing 24 changed files with 173 additions and 12 deletions.
10 changes: 9 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES,
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
Expand Down Expand Up @@ -3144,6 +3144,14 @@ def visit_cast_expr(self, expr: CastExpr) -> Type:
context=expr)
return target_type

def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type:
source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form),
allow_none_return=True, always_allow_any=True)
target_type = expr.type
if not is_same_type(source_type, target_type):
self.msg.assert_type_fail(source_type, target_type, expr)
return source_type

def visit_reveal_expr(self, expr: RevealExpr) -> Type:
"""Type check a reveal_type expression."""
if expr.kind == REVEAL_TYPE:
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __str__(self) -> str:
REDUNDANT_CAST: Final = ErrorCode(
"redundant-cast", "Check that cast changes type of expression", "General"
)
ASSERT_TYPE: Final = ErrorCode(
"assert-type", "Check that assert_type() call succeeds", "General"
)
COMPARISON_OVERLAP: Final = ErrorCode(
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
)
Expand Down
5 changes: 4 additions & 1 deletion mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing_extensions import Final

from mypy.nodes import (
Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES,
AssertTypeExpr, Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES,
LITERAL_NO, NameExpr, LITERAL_TYPE, IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr,
UnicodeExpr, ListExpr, TupleExpr, SetExpr, DictExpr, CallExpr, SliceExpr, CastExpr,
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
Expand Down Expand Up @@ -175,6 +175,9 @@ def visit_slice_expr(self, e: SliceExpr) -> None:
def visit_cast_expr(self, e: CastExpr) -> None:
return None

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
return None

def visit_conditional_expr(self, e: ConditionalExpr) -> None:
return None

Expand Down
5 changes: 5 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,11 @@ def redundant_cast(self, typ: Type, context: Context) -> None:
self.fail('Redundant cast to {}'.format(format_type(typ)), context,
code=codes.REDUNDANT_CAST)

def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None:
self.fail(f"Expression is of type {format_type(source_type)}, "
f"not {format_type(target_type)}", context,
code=codes.ASSERT_TYPE)

def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None:
self.fail("{} becomes {} due to an unfollowed import".format(prefix, format_type(typ)),
ctx, code=codes.NO_ANY_UNIMPORTED)
Expand Down
6 changes: 5 additions & 1 deletion mypy/mixedtraverser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from mypy.nodes import (
Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr,
PromoteExpr, NewTypeExpr
)
Expand Down Expand Up @@ -79,6 +79,10 @@ def visit_cast_expr(self, o: CastExpr) -> None:
super().visit_cast_expr(o)
o.type.accept(self)

def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
super().visit_assert_type_expr(o)
o.type.accept(self)

def visit_type_application(self, o: TypeApplication) -> None:
super().visit_type_application(o)
for t in o.types:
Expand Down
16 changes: 16 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,22 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_cast_expr(self)


class AssertTypeExpr(Expression):
"""Represents a typing.assert_type(expr, type) call."""
__slots__ = ('expr', 'type')

expr: Expression
type: "mypy.types.Type"

def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None:
super().__init__()
self.expr = expr
self.type = typ

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_assert_type_expr(self)


class RevealExpr(Expression):
"""Reveal type expression reveal_type(expr) or reveal_locals() expression."""

Expand Down
23 changes: 21 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from typing_extensions import Final, TypeAlias as _TypeAlias

from mypy.nodes import (
MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue,
ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr,
IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt,
Expand Down Expand Up @@ -94,7 +94,7 @@
from mypy.errorcodes import ErrorCode
from mypy import message_registry, errorcodes as codes
from mypy.types import (
NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType,
ASSERT_TYPE_NAMES, NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType,
CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue,
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
Expand Down Expand Up @@ -3897,6 +3897,19 @@ def visit_call_expr(self, expr: CallExpr) -> None:
expr.analyzed.line = expr.line
expr.analyzed.column = expr.column
expr.analyzed.accept(self)
elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES):
if not self.check_fixed_args(expr, 2, 'assert_type'):
return
# Translate second argument to an unanalyzed type.
try:
target = self.expr_to_unanalyzed_type(expr.args[1])
except TypeTranslationError:
self.fail('assert_type() type is not a type', expr)
return
expr.analyzed = AssertTypeExpr(expr.args[0], target)
expr.analyzed.line = expr.line
expr.analyzed.column = expr.column
expr.analyzed.accept(self)
elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES):
if not self.check_fixed_args(expr, 1, 'reveal_type'):
return
Expand Down Expand Up @@ -4200,6 +4213,12 @@ def visit_cast_expr(self, expr: CastExpr) -> None:
if analyzed is not None:
expr.type = analyzed

def visit_assert_type_expr(self, expr: AssertTypeExpr) -> None:
expr.expr.accept(self)
analyzed = self.anal_type(expr.type)
if analyzed is not None:
expr.type = analyzed

def visit_reveal_expr(self, expr: RevealExpr) -> None:
if expr.kind == REVEAL_TYPE:
if expr.expr is not None:
Expand Down
6 changes: 5 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from typing import Dict, List, cast, TypeVar, Optional

from mypy.nodes import (
MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
AssertTypeExpr, MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr,
OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr,
CastExpr, TypeAlias,
Expand Down Expand Up @@ -226,6 +226,10 @@ def visit_cast_expr(self, node: CastExpr) -> None:
super().visit_cast_expr(node)
self.fixup_type(node.type)

def visit_assert_type_expr(self, node: AssertTypeExpr) -> None:
super().visit_assert_type_expr(node)
self.fixup_type(node.type)

def visit_super_expr(self, node: SuperExpr) -> None:
super().visit_super_expr(node)
if node.info is not None:
Expand Down
6 changes: 5 additions & 1 deletion mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a

from mypy.checkmember import bind_self
from mypy.nodes import (
Node, Expression, MypyFile, FuncDef, ClassDef, AssignmentStmt, NameExpr, MemberExpr, Import,
AssertTypeExpr, Node, Expression, MypyFile, FuncDef, ClassDef, AssignmentStmt, NameExpr, MemberExpr, Import,
ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr,
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
Expand Down Expand Up @@ -686,6 +686,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
super().visit_cast_expr(e)
self.add_type_dependencies(e.type)

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
super().visit_assert_type_expr(e)
self.add_type_dependencies(e.type)

def visit_type_application(self, e: TypeApplication) -> None:
super().visit_type_application(e)
for typ in e.types:
Expand Down
6 changes: 5 additions & 1 deletion mypy/server/subexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List

from mypy.nodes import (
Expression, Node, MemberExpr, YieldFromExpr, YieldExpr, CallExpr, OpExpr, ComparisonExpr,
AssertTypeExpr, Expression, Node, MemberExpr, YieldFromExpr, YieldExpr, CallExpr, OpExpr, ComparisonExpr,
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
Expand Down Expand Up @@ -99,6 +99,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
self.add(e)
super().visit_cast_expr(e)

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
self.add(e)
super().visit_assert_type_expr(e)

def visit_reveal_expr(self, e: RevealExpr) -> None:
self.add(e)
super().visit_reveal_expr(e)
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> str:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> str:
return self.dump([o.expr, o.type], o)

def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> str:
return self.dump([o.expr, o.type], o)

def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> str:
if o.kind == mypy.nodes.REVEAL_TYPE:
return self.dump([o.expr], o)
Expand Down
5 changes: 4 additions & 1 deletion mypy/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from mypy.visitor import NodeVisitor
from mypy.nodes import (
Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
AssertTypeExpr, Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt,
ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt,
TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr,
Expand Down Expand Up @@ -205,6 +205,9 @@ def visit_slice_expr(self, o: SliceExpr) -> None:
def visit_cast_expr(self, o: CastExpr) -> None:
o.expr.accept(self)

def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
o.expr.accept(self)

def visit_reveal_expr(self, o: RevealExpr) -> None:
if o.kind == REVEAL_TYPE:
assert o.expr is not None
Expand Down
5 changes: 4 additions & 1 deletion mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import List, Dict, cast, Optional, Iterable

from mypy.nodes import (
MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
AssertTypeExpr, MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
OverloadedFuncDef, ClassDef, Decorator, Block, Var,
OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt,
RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt,
Expand Down Expand Up @@ -407,6 +407,9 @@ def visit_cast_expr(self, node: CastExpr) -> CastExpr:
return CastExpr(self.expr(node.expr),
self.type(node.type))

def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
return AssertTypeExpr(self.expr(node.expr), self.type(node.type))

def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
if node.kind == REVEAL_TYPE:
assert node.expr is not None
Expand Down
5 changes: 5 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@
'typing_extensions.reveal_type',
)

ASSERT_TYPE_NAMES: Final = (
'typing.assert_type',
'typing_extensions.assert_Type',
)

# Attributes that can optionally be defined in the body of a subclass of
# enum.Enum but are removed from the class __dict__ by EnumMeta.
ENUM_REMOVED_PROPS: Final = (
Expand Down
7 changes: 7 additions & 0 deletions mypy/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
pass

@abstractmethod
def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
pass

@abstractmethod
def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
pass
Expand Down Expand Up @@ -523,6 +527,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
pass

def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
pass

def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
pass

Expand Down
5 changes: 4 additions & 1 deletion mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional, Union, Callable, cast

from mypy.nodes import (
Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr,
AssertTypeExpr, Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr,
ConditionalExpr, ComparisonExpr, IntExpr, FloatExpr, ComplexExpr, StrExpr,
BytesExpr, EllipsisExpr, ListExpr, TupleExpr, DictExpr, SetExpr, ListComprehension,
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
Expand Down Expand Up @@ -203,6 +203,9 @@ def transform_super_expr(builder: IRBuilder, o: SuperExpr) -> Value:
def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value:
if isinstance(expr.analyzed, CastExpr):
return translate_cast_expr(builder, expr.analyzed)
elif isinstance(expr.analyzed, AssertTypeExpr):
# Compile to a no-op.
return builder.accept(expr.analyzed.expr)

callee = expr.callee
if isinstance(callee, IndexExpr) and isinstance(callee.analyzed, TypeApplication):
Expand Down
5 changes: 4 additions & 1 deletion mypyc/irbuild/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import NoReturn

from mypy.nodes import (
MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
AssertTypeExpr, MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
IntExpr, NameExpr, Var, IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, CallExpr,
IndexExpr, Block, ListExpr, ExpressionStmt, MemberExpr, ForStmt,
BreakStmt, ContinueStmt, ConditionalExpr, OperatorAssignmentStmt, TupleExpr, ClassDef,
Expand Down Expand Up @@ -327,6 +327,9 @@ def visit_var(self, o: Var) -> None:
def visit_cast_expr(self, o: CastExpr) -> Value:
assert False, "CastExpr should have been handled in CallExpr"

def visit_assert_type_expr(self, o: AssertTypeExpr) -> Value:
assert False, "AssertTypeExpr should have been handled in CallExpr"

def visit_star_expr(self, o: StarExpr) -> Value:
assert False, "should have been handled in Tuple/List/Set/DictExpr or CallExpr"

Expand Down
11 changes: 11 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,17 @@ L0:
o = r3
return 1

[case testAssertType]
from typing import assert_type
def f(x: int) -> None:
y = assert_type(x, int)
[out]
def f(x):
x, y :: int
L0:
y = x
return 1

[case testDownCast]
from typing import cast, List, Tuple
class A: pass
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,18 @@ class B: pass
[out]
main:3: error: "A" not callable

-- assert_type()

[case testAssertType]
from typing import assert_type, Any
from typing_extensions import Literal
a: int = 1
returned = assert_type(a, int)
reveal_type(returned) # N: Revealed type is "builtins.int"
assert_type(a, str) # E: Expression is of type "int", not "str"
assert_type(a, Any) # E: Expression is of type "int", not "Any"
assert_type(a, Literal[1]) # E: Expression is of type "int", not "Literal[1]"
[builtins fixtures/tuple.pyi]

-- None return type
-- ----------------
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from abc import abstractmethod, ABCMeta
class GenericMeta(type): pass

def cast(t, o): ...
def assert_type(o, t): ...
overload = 0
Any = 0
Union = 0
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/lib-stub/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# the stubs under fixtures/.

cast = 0
assert_type = 0
overload = 0
Any = 0
Union = 0
Expand Down
Loading

0 comments on commit 22a8e4d

Please sign in to comment.