Skip to content

Commit

Permalink
refactor: ExprVisitor type validation (#3739)
Browse files Browse the repository at this point in the history
this commit simplifies the `ExprVisitor` implementation by moving calls
to `validate_expected_type` into the generic `visit()` function, instead
of having ad-hoc calls to validate_expected_type in the specialized
visitor functions.

in doing so, some inconsistencies in the generic implementation were
found and fixed:
- fix validate_expected_type for tuples
- introduce a void type for dealing with function calls/statements which
  don't return anything.

---------

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
tserg and charles-cooper authored Feb 5, 2024
1 parent 01ec9a1 commit e20885e
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vyper import compile_code
from vyper.exceptions import TypeMismatch
from vyper.exceptions import InvalidType

pytestmark = pytest.mark.usefixtures("memory_mocker")

Expand Down Expand Up @@ -159,5 +159,5 @@ def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation):
def getTimeAndBalance() -> (bool, address):
return block.timestamp, self.balance
"""
with pytest.raises(TypeMismatch):
with pytest.raises(InvalidType):
compile_code(code)
82 changes: 28 additions & 54 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types import (
TYPE_T,
VOID_TYPE,
AddressT,
BoolT,
DArrayT,
Expand All @@ -45,6 +46,7 @@
VyperType,
_BytestringT,
is_type_t,
map_void,
)
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability
from vyper.semantics.types.utils import type_from_annotation
Expand Down Expand Up @@ -235,12 +237,13 @@ def visit_AnnAssign(self, node):
)

typ = type_from_annotation(node.annotation, DataLocation.MEMORY)
validate_expected_type(node.value, typ)

# validate the value before adding it to the namespace
self.expr_visitor.visit(node.value, typ)

self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY)

self.expr_visitor.visit(node.target, typ)
self.expr_visitor.visit(node.value, typ)

def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None:
if isinstance(msg_node, vy_ast.Str):
Expand All @@ -259,10 +262,6 @@ def visit_Assert(self, node):
if node.msg:
self._validate_revert_reason(node.msg)

try:
validate_expected_type(node.test, BoolT())
except InvalidType:
raise InvalidType("Assertion test value must be a boolean", node.test)
self.expr_visitor.visit(node.test, BoolT())

# repeated code for Assign and AugAssign
Expand All @@ -276,7 +275,6 @@ def _assign_helper(self, node):
"Left-hand side of assignment cannot be a HashMap without a key", node
)

validate_expected_type(node.value, target.typ)
target.validate_modification(node, self.func.mutability)

self.expr_visitor.visit(node.value, target.typ)
Expand Down Expand Up @@ -341,16 +339,16 @@ def visit_Expr(self, node):
expr_info.validate_modification(node, self.func.mutability)

# NOTE: fetch_call_return validates call args.
return_value = fn_type.fetch_call_return(node.value)
return_value = map_void(fn_type.fetch_call_return(node.value))
if (
return_value
return_value is not VOID_TYPE
and not isinstance(fn_type, MemberFunctionT)
and not isinstance(fn_type, ContractFunctionT)
):
raise StructureException(
f"Function '{fn_type._id}' cannot be called without assigning the result", node
)
self.expr_visitor.visit(node.value, fn_type)
self.expr_visitor.visit(node.value, return_value)

def visit_For(self, node):
if not isinstance(node.target.target, vy_ast.Name):
Expand Down Expand Up @@ -443,7 +441,6 @@ def visit_For(self, node):
self.expr_visitor.visit(node.iter, iter_type)

def visit_If(self, node):
validate_expected_type(node.test, BoolT())
self.expr_visitor.visit(node.test, BoolT())
with self.namespace.enter_scope():
for n in node.body:
Expand All @@ -462,9 +459,11 @@ def visit_Log(self, node):
raise StructureException(
f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node
)
f.fetch_call_return(node.value)
t = map_void(f.fetch_call_return(node.value))
# CMC 2024-02-05 annotate the event type for codegen usage
# TODO: refactor this
node._metadata["type"] = f.typedef
self.expr_visitor.visit(node.value, f.typedef)
self.expr_visitor.visit(node.value, t)

def visit_Raise(self, node):
if node.exc:
Expand All @@ -489,10 +488,7 @@ def visit_Return(self, node):
f"expected {self.func.return_type.length}, got {len(values)}",
node,
)
for given, expected in zip(values, self.func.return_type.member_types):
validate_expected_type(given, expected)
else:
validate_expected_type(values, self.func.return_type)

self.expr_visitor.visit(node.value, self.func.return_type)


Expand All @@ -503,14 +499,11 @@ def __init__(self, fn_node: Optional[ContractFunctionT] = None):
self.func = fn_node

def visit(self, node, typ):
if typ is not VOID_TYPE and not isinstance(typ, TYPE_T):
validate_expected_type(node, typ)

# recurse and typecheck in case we are being fed the wrong type for
# some reason. note that `validate_expected_type` is unnecessary
# for nodes that already call `get_exact_type_from_node` and
# `get_possible_types_from_node` because `validate_expected_type`
# would be calling the same function again.
# CMC 2023-06-27 would be cleanest to call validate_expected_type()
# before recursing but maybe needs some refactoring before that
# can happen.
# some reason.
super().visit(node, typ)

# annotate
Expand Down Expand Up @@ -541,28 +534,21 @@ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None:
self.visit(node.value, value_type)

def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None:
validate_expected_type(node.left, typ)
self.visit(node.left, typ)

rtyp = typ
if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)):
rtyp = get_possible_types_from_node(node.right).pop()

validate_expected_type(node.right, rtyp)

self.visit(node.right, rtyp)

def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None:
assert typ == BoolT() # sanity check
for value in node.values:
validate_expected_type(value, BoolT())
self.visit(value, BoolT())

def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
call_type = get_exact_type_from_node(node.func)
# except for builtin functions, `get_exact_type_from_node`
# already calls `validate_expected_type` on the call args
# and kwargs via `call_type.fetch_call_return`
self.visit(node.func, call_type)

if isinstance(call_type, ContractFunctionT):
Expand Down Expand Up @@ -594,7 +580,6 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
else:
# builtin functions
arg_types = call_type.infer_arg_types(node, expected_return_typ=typ)
# `infer_arg_types` already calls `validate_expected_type`
for arg, arg_type in zip(node.args, arg_types):
self.visit(arg, arg_type)
kwarg_types = call_type.infer_kwarg_types(node)
Expand All @@ -610,7 +595,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:

rlen = len(node.right.elements)
rtyp = SArrayT(ltyp, rlen)
validate_expected_type(node.right, rtyp)
else:
rtyp = get_exact_type_from_node(node.right)
if isinstance(rtyp, FlagT):
Expand All @@ -621,8 +605,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:
assert isinstance(rtyp, (SArrayT, DArrayT))
ltyp = rtyp.value_type

validate_expected_type(node.left, ltyp)

self.visit(node.left, ltyp)
self.visit(node.right, rtyp)

Expand All @@ -638,28 +620,27 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:
rtyp = get_exact_type_from_node(node.right)
else:
ltyp = rtyp = cmp_typ
validate_expected_type(node.left, ltyp)
validate_expected_type(node.right, rtyp)

self.visit(node.left, ltyp)
self.visit(node.right, rtyp)

def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None:
validate_expected_type(node, typ)
pass

def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None:
self.visit(node.test, BoolT())
self.visit(node.body, typ)
self.visit(node.orelse, typ)

def visit_List(self, node: vy_ast.List, typ: VyperType) -> None:
assert isinstance(typ, (SArrayT, DArrayT))
for element in node.elements:
validate_expected_type(element, typ.value_type)
self.visit(element, typ.value_type)

def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None:
if self.func and self.func.mutability == StateMutability.PURE:
_validate_self_reference(node)

if not isinstance(typ, TYPE_T):
validate_expected_type(node, typ)

def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None:
if isinstance(typ, TYPE_T):
# don't recurse; can't annotate AST children of type definition
Expand Down Expand Up @@ -694,23 +675,16 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None:
# don't recurse; can't annotate AST children of type definition
return

# these guarantees should be provided by validate_expected_type
assert isinstance(typ, TupleT)
for element, subtype in zip(node.elements, typ.member_types):
validate_expected_type(element, subtype)
self.visit(element, subtype)
assert len(node.elements) == len(typ.member_types)

for item_ast, item_type in zip(node.elements, typ.member_types):
self.visit(item_ast, item_type)

def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: VyperType) -> None:
validate_expected_type(node.operand, typ)
self.visit(node.operand, typ)

def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None:
validate_expected_type(node.test, BoolT())
self.visit(node.test, BoolT())
validate_expected_type(node.body, typ)
self.visit(node.body, typ)
validate_expected_type(node.orelse, typ)
self.visit(node.orelse, typ)


def _validate_range_call(node: vy_ast.Call):
"""
Expand Down
41 changes: 28 additions & 13 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,17 @@ def types_from_Constant(self, node):
)
raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node)

def types_from_IfExp(self, node):
validate_expected_type(node.test, BoolT())
types_list = get_common_types(node.body, node.orelse)

if not types_list:
a = get_possible_types_from_node(node.body)[0]
b = get_possible_types_from_node(node.orelse)[0]
raise TypeMismatch(f"Dislike types: {a} and {b}", node)

return types_list

def types_from_List(self, node):
# literal array
if _is_empty_list(node):
Expand Down Expand Up @@ -399,17 +410,6 @@ def types_from_UnaryOp(self, node):
types_list = self.get_possible_types_from_node(node.operand)
return _validate_op(node, types_list, "validate_numeric_op")

def types_from_IfExp(self, node):
validate_expected_type(node.test, BoolT())
types_list = get_common_types(node.body, node.orelse)

if not types_list:
a = get_possible_types_from_node(node.body)[0]
b = get_possible_types_from_node(node.orelse)[0]
raise TypeMismatch(f"Dislike types: {a} and {b}", node)

return types_list


def _is_empty_list(node):
# Checks if a node is a `List` node with an empty list for `elements`,
Expand Down Expand Up @@ -550,11 +550,26 @@ def validate_expected_type(node, expected_type):
-------
None
"""
given_types = _ExprAnalyser().get_possible_types_from_node(node)

if not isinstance(expected_type, tuple):
expected_type = (expected_type,)

if isinstance(node, vy_ast.Tuple):
possible_tuple_types = [t for t in expected_type if isinstance(t, TupleT)]
for t in possible_tuple_types:
if len(t.member_types) != len(node.elements):
continue
for item_ast, item_type in zip(node.elements, t.member_types):
try:
validate_expected_type(item_ast, item_type)
return
except VyperException:
pass
else:
# fail block
pass

given_types = _ExprAnalyser().get_possible_types_from_node(node)

if isinstance(node, vy_ast.List):
# special case - for literal arrays we individually validate each item
for expected in expected_type:
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import primitives, subscriptable, user
from .base import TYPE_T, KwargSettings, VyperType, is_type_t
from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void
from .bytestrings import BytesT, StringT, _BytestringT
from .function import MemberFunctionT
from .module import InterfaceT
Expand Down
16 changes: 15 additions & 1 deletion vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, type_):
self.type_ = type_

def compare_type(self, other):
return isinstance(other, self.type_)
return isinstance(other, self.type_) or self == other


class VyperType:
Expand Down Expand Up @@ -324,6 +324,20 @@ def __init__(self, typ, default, require_literal=False):
self.require_literal = require_literal


class _VoidType(VyperType):
_id = "(void)"


# sentinel for function calls which return nothing
VOID_TYPE = _VoidType()


def map_void(typ: Optional[VyperType]) -> VyperType:
if typ is None:
return VOID_TYPE
return typ


# A type type. Used internally for types which can live in expression
# position, ex. constructors (events, interfaces and structs), and also
# certain builtins which take types as parameters
Expand Down

0 comments on commit e20885e

Please sign in to comment.