Skip to content

Commit

Permalink
add assert_type() (#434)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 23, 2022
1 parent 9879790 commit b11f15a
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add support for `assert_type()` (#433)
- `reveal_type()` and `dump_value()` now return their argument,
the anticipated behavior for `typing.reveal_type()` in Python
3.11 (#432)
Expand Down
22 changes: 22 additions & 0 deletions pyanalyze/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,28 @@ def f(x: int) -> None:
return value


def assert_type(val: _T, typ: Any) -> _T:
"""Assert the inferred static type of an expression.
When a static type checker encounters a call to this function,
it checks that the inferred type of `val` matches the `typ`
argument, and if it dooes not, it emits an error.
Example::
def f(x: int) -> None:
assert_type(x, int) # ok
assert_type(x, str) # error
This is useful for checking that the type checker interprets
a complicated set of type annotations in the way the user intended.
At runtime this returns the first argument unchanged.
"""
return val


_overloads: Dict[str, List[Callable[..., Any]]] = defaultdict(list)
_type_evaluations: Dict[str, List[Callable[..., Any]]] = defaultdict(list)

Expand Down
50 changes: 46 additions & 4 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import typing
import typing_extensions
from .annotations import type_from_value
from .error_code import ErrorCode
from .extensions import reveal_type
from .extensions import assert_type, reveal_type
from .format_strings import parse_format_string
from .predicates import IsAssignablePredicate
from .safe import safe_hasattr, safe_isinstance, safe_issubclass
Expand Down Expand Up @@ -52,6 +50,7 @@
concrete_values_from_iterable,
kv_pairs_from_mapping,
make_weak,
unannotate,
unite_values,
flatten_values,
replace_known_sequence_value,
Expand All @@ -66,6 +65,8 @@
import inspect
import warnings
from types import FunctionType
import typing
import typing_extensions
from typing import (
Sequence,
TypeVar,
Expand Down Expand Up @@ -1042,6 +1043,20 @@ def _cast_impl(ctx: CallContext) -> Value:
return type_from_value(typ, visitor=ctx.visitor, node=ctx.node)


def _assert_type_impl(ctx: CallContext) -> Value:
# TODO maybe we should walk over the whole value and remove Annotated.
val = unannotate(ctx.vars["val"])
typ = ctx.vars["typ"]
expected_type = type_from_value(typ, visitor=ctx.visitor, node=ctx.node)
if val != expected_type:
ctx.show_error(
f"Type is {val} (expected {expected_type})",
error_code=ErrorCode.inference_failure,
arg="obj",
)
return val


def _subclasses_impl(ctx: CallContext) -> Value:
"""Overridden because typeshed types make it (T) => List[T] instead."""
self_obj = ctx.vars["self"]
Expand Down Expand Up @@ -1423,7 +1438,18 @@ def get_default_argspecs() -> Dict[object, Signature]:
callable=str.format,
),
Signature.make(
[SigParameter("typ"), SigParameter("val")], callable=cast, impl=_cast_impl
[SigParameter("typ", _POS_ONLY), SigParameter("val", _POS_ONLY)],
callable=cast,
impl=_cast_impl,
),
Signature.make(
[
SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)),
SigParameter("typ", _POS_ONLY),
],
TypeVarValue(T),
callable=assert_type,
impl=_assert_type_impl,
),
# workaround for https://github.com/python/typeshed/pull/3501
Signature.make(
Expand Down Expand Up @@ -1566,4 +1592,20 @@ def get_default_argspecs() -> Dict[object, Signature]:
callable=reveal_type_func,
)
signatures.append(sig)
# Anticipating that this will be added to the stdlib
try:
assert_type_func = getattr(mod, "assert_type")
except AttributeError:
pass
else:
sig = Signature.make(
[
SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)),
SigParameter("typ", _POS_ONLY),
],
TypeVarValue(T),
callable=assert_type_func,
impl=_assert_type_impl,
)
signatures.append(sig)
return {sig.callable: sig for sig in signatures}
11 changes: 11 additions & 0 deletions pyanalyze/test_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,17 @@ def capybara():
assert_is_value(x, KnownValue(1))
assert_is_value(y, KnownValue(1))

@assert_passes()
def test_assert_type(self) -> None:
from pyanalyze.extensions import assert_type
from typing import Any

def capybara(x: int) -> None:
assert_type(x, int)
assert_type(x, "int")
assert_type(x, Any) # E: inference_failure
assert_type(x, str) # E: inference_failure


class TestCallableGuards(TestNameCheckVisitorBase):
@assert_passes()
Expand Down

0 comments on commit b11f15a

Please sign in to comment.