Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for assert_type #12584

Merged
merged 3 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES,
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
Expand Down Expand Up @@ -3144,6 +3144,14 @@ def visit_cast_expr(self, expr: CastExpr) -> Type:
context=expr)
return target_type

def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type:
source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form),
allow_none_return=True, always_allow_any=True)
target_type = expr.type
if not is_same_type(source_type, target_type):
self.msg.assert_type_fail(source_type, target_type, expr)
return source_type

def visit_reveal_expr(self, expr: RevealExpr) -> Type:
"""Type check a reveal_type expression."""
if expr.kind == REVEAL_TYPE:
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __str__(self) -> str:
REDUNDANT_CAST: Final = ErrorCode(
"redundant-cast", "Check that cast changes type of expression", "General"
)
ASSERT_TYPE: Final = ErrorCode(
"assert-type", "Check that assert_type() call succeeds", "General"
)
COMPARISON_OVERLAP: Final = ErrorCode(
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
)
Expand Down
6 changes: 5 additions & 1 deletion mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension,
GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr,
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr,
AssertTypeExpr,
)
from mypy.visitor import ExpressionVisitor

Expand Down Expand Up @@ -175,6 +176,9 @@ def visit_slice_expr(self, e: SliceExpr) -> None:
def visit_cast_expr(self, e: CastExpr) -> None:
return None

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
return None

def visit_conditional_expr(self, e: ConditionalExpr) -> None:
return None

Expand Down
5 changes: 5 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,11 @@ def redundant_cast(self, typ: Type, context: Context) -> None:
self.fail('Redundant cast to {}'.format(format_type(typ)), context,
code=codes.REDUNDANT_CAST)

def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None:
self.fail(f"Expression is of type {format_type(source_type)}, "
f"not {format_type(target_type)}", context,
code=codes.ASSERT_TYPE)

def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None:
self.fail("{} becomes {} due to an unfollowed import".format(prefix, format_type(typ)),
ctx, code=codes.NO_ANY_UNIMPORTED)
Expand Down
6 changes: 5 additions & 1 deletion mypy/mixedtraverser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from mypy.nodes import (
Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr,
PromoteExpr, NewTypeExpr
)
Expand Down Expand Up @@ -79,6 +79,10 @@ def visit_cast_expr(self, o: CastExpr) -> None:
super().visit_cast_expr(o)
o.type.accept(self)

def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
super().visit_assert_type_expr(o)
o.type.accept(self)

def visit_type_application(self, o: TypeApplication) -> None:
super().visit_type_application(o)
for t in o.types:
Expand Down
16 changes: 16 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,22 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_cast_expr(self)


class AssertTypeExpr(Expression):
"""Represents a typing.assert_type(expr, type) call."""
__slots__ = ('expr', 'type')

expr: Expression
type: "mypy.types.Type"

def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None:
super().__init__()
self.expr = expr
self.type = typ

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_assert_type_expr(self)


class RevealExpr(Expression):
"""Reveal type expression reveal_type(expr) or reveal_locals() expression."""

Expand Down
23 changes: 21 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from typing_extensions import Final, TypeAlias as _TypeAlias

from mypy.nodes import (
MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue,
ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr,
IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt,
Expand Down Expand Up @@ -99,7 +99,7 @@
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
is_named_instance,
ASSERT_TYPE_NAMES, is_named_instance,
)
from mypy.typeops import function_type, get_type_vars
from mypy.type_visitor import TypeQuery
Expand Down Expand Up @@ -3897,6 +3897,19 @@ def visit_call_expr(self, expr: CallExpr) -> None:
expr.analyzed.line = expr.line
expr.analyzed.column = expr.column
expr.analyzed.accept(self)
elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES):
if not self.check_fixed_args(expr, 2, 'assert_type'):
return
# Translate second argument to an unanalyzed type.
try:
target = self.expr_to_unanalyzed_type(expr.args[1])
except TypeTranslationError:
self.fail('assert_type() type is not a type', expr)
return
expr.analyzed = AssertTypeExpr(expr.args[0], target)
expr.analyzed.line = expr.line
expr.analyzed.column = expr.column
expr.analyzed.accept(self)
elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES):
if not self.check_fixed_args(expr, 1, 'reveal_type'):
return
Expand Down Expand Up @@ -4200,6 +4213,12 @@ def visit_cast_expr(self, expr: CastExpr) -> None:
if analyzed is not None:
expr.type = analyzed

def visit_assert_type_expr(self, expr: AssertTypeExpr) -> None:
expr.expr.accept(self)
analyzed = self.anal_type(expr.type)
if analyzed is not None:
expr.type = analyzed

def visit_reveal_expr(self, expr: RevealExpr) -> None:
if expr.kind == REVEAL_TYPE:
if expr.expr is not None:
Expand Down
6 changes: 5 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr,
OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr,
CastExpr, TypeAlias,
CastExpr, TypeAlias, AssertTypeExpr,
MDEF
)
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -226,6 +226,10 @@ def visit_cast_expr(self, node: CastExpr) -> None:
super().visit_cast_expr(node)
self.fixup_type(node.type)

def visit_assert_type_expr(self, node: AssertTypeExpr) -> None:
super().visit_assert_type_expr(node)
self.fixup_type(node.type)

def visit_super_expr(self, node: SuperExpr) -> None:
super().visit_super_expr(node)
if node.info is not None:
Expand Down
7 changes: 6 additions & 1 deletion mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
AssertTypeExpr,
)
from mypy.operators import (
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
Expand Down Expand Up @@ -686,6 +687,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
super().visit_cast_expr(e)
self.add_type_dependencies(e.type)

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
super().visit_assert_type_expr(e)
self.add_type_dependencies(e.type)

def visit_type_application(self, e: TypeApplication) -> None:
super().visit_type_application(e)
for typ in e.types:
Expand Down
6 changes: 5 additions & 1 deletion mypy/server/subexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
AssignmentExpr,
AssignmentExpr, AssertTypeExpr,
)
from mypy.traverser import TraverserVisitor

Expand Down Expand Up @@ -99,6 +99,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
self.add(e)
super().visit_cast_expr(e)

def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
self.add(e)
super().visit_assert_type_expr(e)

def visit_reveal_expr(self, e: RevealExpr) -> None:
self.add(e)
super().visit_reveal_expr(e)
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> str:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> str:
return self.dump([o.expr, o.type], o)

def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> str:
return self.dump([o.expr, o.type], o)

def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> str:
if o.kind == mypy.nodes.REVEAL_TYPE:
return self.dump([o.expr], o)
Expand Down
5 changes: 4 additions & 1 deletion mypy/traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from mypy.visitor import NodeVisitor
from mypy.nodes import (
Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
AssertTypeExpr, Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt,
ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt,
TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr,
Expand Down Expand Up @@ -205,6 +205,9 @@ def visit_slice_expr(self, o: SliceExpr) -> None:
def visit_cast_expr(self, o: CastExpr) -> None:
o.expr.accept(self)

def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
o.expr.accept(self)

def visit_reveal_expr(self, o: RevealExpr) -> None:
if o.kind == REVEAL_TYPE:
assert o.expr is not None
Expand Down
5 changes: 4 additions & 1 deletion mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import List, Dict, cast, Optional, Iterable

from mypy.nodes import (
MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
AssertTypeExpr, MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef,
OverloadedFuncDef, ClassDef, Decorator, Block, Var,
OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt,
RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt,
Expand Down Expand Up @@ -407,6 +407,9 @@ def visit_cast_expr(self, node: CastExpr) -> CastExpr:
return CastExpr(self.expr(node.expr),
self.type(node.type))

def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
return AssertTypeExpr(self.expr(node.expr), self.type(node.type))

def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
if node.kind == REVEAL_TYPE:
assert node.expr is not None
Expand Down
5 changes: 5 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@
'typing_extensions.reveal_type',
)

ASSERT_TYPE_NAMES: Final = (
'typing.assert_type',
'typing_extensions.assert_type',
)

# Attributes that can optionally be defined in the body of a subclass of
# enum.Enum but are removed from the class __dict__ by EnumMeta.
ENUM_REMOVED_PROPS: Final = (
Expand Down
7 changes: 7 additions & 0 deletions mypy/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
pass

@abstractmethod
def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
pass

@abstractmethod
def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
pass
Expand Down Expand Up @@ -523,6 +527,9 @@ def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T:
def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T:
pass

def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T:
pass

def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T:
pass

Expand Down
5 changes: 4 additions & 1 deletion mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ConditionalExpr, ComparisonExpr, IntExpr, FloatExpr, ComplexExpr, StrExpr,
BytesExpr, EllipsisExpr, ListExpr, TupleExpr, DictExpr, SetExpr, ListComprehension,
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
AssignmentExpr,
AssignmentExpr, AssertTypeExpr,
Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS
)
from mypy.types import TupleType, Instance, TypeType, ProperType, get_proper_type
Expand Down Expand Up @@ -203,6 +203,9 @@ def transform_super_expr(builder: IRBuilder, o: SuperExpr) -> Value:
def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value:
if isinstance(expr.analyzed, CastExpr):
return translate_cast_expr(builder, expr.analyzed)
elif isinstance(expr.analyzed, AssertTypeExpr):
# Compile to a no-op.
return builder.accept(expr.analyzed.expr)

callee = expr.callee
if isinstance(callee, IndexExpr) and isinstance(callee.analyzed, TypeApplication):
Expand Down
5 changes: 4 additions & 1 deletion mypyc/irbuild/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import NoReturn

from mypy.nodes import (
MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
AssertTypeExpr, MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr,
IntExpr, NameExpr, Var, IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, CallExpr,
IndexExpr, Block, ListExpr, ExpressionStmt, MemberExpr, ForStmt,
BreakStmt, ContinueStmt, ConditionalExpr, OperatorAssignmentStmt, TupleExpr, ClassDef,
Expand Down Expand Up @@ -327,6 +327,9 @@ def visit_var(self, o: Var) -> None:
def visit_cast_expr(self, o: CastExpr) -> Value:
assert False, "CastExpr should have been handled in CallExpr"

def visit_assert_type_expr(self, o: AssertTypeExpr) -> Value:
assert False, "AssertTypeExpr should have been handled in CallExpr"

def visit_star_expr(self, o: StarExpr) -> Value:
assert False, "should have been handled in Tuple/List/Set/DictExpr or CallExpr"

Expand Down
11 changes: 11 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,17 @@ L0:
o = r3
return 1

[case testAssertType]
from typing import assert_type
def f(x: int) -> None:
y = assert_type(x, int)
[out]
def f(x):
x, y :: int
L0:
y = x
return 1

[case testDownCast]
from typing import cast, List, Tuple
class A: pass
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,18 @@ class B: pass
[out]
main:3: error: "A" not callable

-- assert_type()

[case testAssertType]
from typing import assert_type, Any
from typing_extensions import Literal
a: int = 1
returned = assert_type(a, int)
reveal_type(returned) # N: Revealed type is "builtins.int"
assert_type(a, str) # E: Expression is of type "int", not "str"
assert_type(a, Any) # E: Expression is of type "int", not "Any"
assert_type(a, Literal[1]) # E: Expression is of type "int", not "Literal[1]"
[builtins fixtures/tuple.pyi]

-- None return type
-- ----------------
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/typing-full.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from abc import abstractmethod, ABCMeta
class GenericMeta(type): pass

def cast(t, o): ...
def assert_type(o, t): ...
overload = 0
Any = 0
Union = 0
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/lib-stub/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# the stubs under fixtures/.

cast = 0
assert_type = 0
overload = 0
Any = 0
Union = 0
Expand Down
Loading