Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: infer expected types #3765

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6b9fff2
rename validate_expected_type to infer_type and have it return a type
charles-cooper Feb 9, 2024
c1b8979
Merge branch 'master' into feat/infer_expected_types
charles-cooper Feb 13, 2024
b3e2fd9
update a comment
charles-cooper Feb 13, 2024
3fd9fb8
improve type inference for revert reason strings
charles-cooper Feb 13, 2024
1350838
format some comments
charles-cooper Feb 13, 2024
2013d81
use the result of infer_type
charles-cooper Feb 13, 2024
71aab61
remove length modification functions on bytestring
charles-cooper Feb 13, 2024
d7993ec
feat[lang]: allow downcasting of bytestrings
charles-cooper Mar 5, 2024
93e53c1
fix direction of some comparisons
charles-cooper Mar 5, 2024
b2e62a2
fix existing tests, add tests for new functionality, add compile-time…
charles-cooper Mar 6, 2024
5348802
Merge branch 'master' into feat/bytestring-cast
charles-cooper Mar 8, 2024
f8689ab
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 8, 2024
37880cb
Merge branch 'feat/bytestring-cast' into feat/infer_expected_types
charles-cooper Mar 8, 2024
35ec413
allow bytestrings with ellipsis length
charles-cooper Mar 8, 2024
d8169e5
wip - fix external call codegen
charles-cooper Mar 9, 2024
45960ef
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 13, 2024
a7556d8
add a hint
charles-cooper Mar 14, 2024
0eedb38
handle more cases in generic
charles-cooper Mar 19, 2024
06217ef
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 19, 2024
6ec7730
handle TYPE_T
charles-cooper Mar 19, 2024
9988eb3
Merge branch 'master' into feat/infer_expected_types
charles-cooper Mar 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests.utils import parse_and_fold
from vyper.exceptions import OverflowException, TypeMismatch
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import infer_type
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
from vyper.utils import unsigned_to_signed

Expand Down Expand Up @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256:
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
# more types are added).
validate_expected_type(new_node, UINT256_T)
_ = infer_type(new_node, UINT256_T)
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except OverflowException: # here: check the wrapped value matches runtime
Expand All @@ -82,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256:
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
validate_expected_type(new_node, INT256_T) # force bounds check
_ = infer_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except (TypeMismatch, OverflowException):
Expand Down
2 changes: 1 addition & 1 deletion vyper/abi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def is_complex_type(self):

class ABI_Bytes(ABIType):
def __init__(self, bytes_bound):
if not bytes_bound >= 0:
if bytes_bound is not None and not bytes_bound >= 0:
raise InvalidABIType("Negative bytes_bound provided to ABI_Bytes")

self.bytes_bound = bytes_bound
Expand Down
3 changes: 2 additions & 1 deletion vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ def _cast_bytestring(expr, arg, out_typ):
_FAIL(arg.typ, out_typ, expr)

ret = ["seq"]
if out_typ.maxlen < arg.typ.maxlen:
assert out_typ.maxlen is not None
if arg.typ.maxlen is None or out_typ.maxlen < arg.typ.maxlen:
ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]])
ret.append(arg)
# NOTE: this is a pointer cast
Expand Down
8 changes: 2 additions & 6 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node, infer_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation

Expand Down Expand Up @@ -101,7 +97,7 @@
# for its side effects (will throw if is not a type)
type_from_annotation(arg)
else:
validate_expected_type(arg, expected_type)
infer_type(arg, expected_type)

def _validate_arg_types(self, node: vy_ast.Call) -> None:
num_args = len(self._inputs) # the number of args the signature indicates
Expand Down
58 changes: 22 additions & 36 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
get_common_types,
get_exact_type_from_node,
get_possible_types_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.types import (
TYPE_T,
Expand Down Expand Up @@ -202,12 +202,14 @@ def infer_arg_types(self, node, expected_return_typ=None):
target_type = type_from_annotation(node.args[1])
value_types = get_possible_types_from_node(node.args[0])

# For `convert` of integer literals, we need to match type inference rules in
# convert.py codegen routines.
# For `convert` of integer literals, we need to match type inference
# rules in convert.py codegen routines.
# TODO: This can probably be removed once constant folding for `convert` is implemented
if len(value_types) > 1 and all(isinstance(v, IntegerT) for v in value_types):
# Get the smallest (and unsigned if available) type for non-integer target types
# (note this is different from the ordering returned by `get_possible_types_from_node`)
# Get the smallest (and unsigned if available) type for
# non-integer target types
# (note this is different from the ordering returned by
# `get_possible_types_from_node`)
if not isinstance(target_type, IntegerT):
value_types = sorted(value_types, key=lambda v: (v.is_signed, v.bits), reverse=True)
else:
Expand All @@ -218,7 +220,10 @@ def infer_arg_types(self, node, expected_return_typ=None):

# block conversions between same type
if target_type.compare_type(value_type):
raise InvalidType(f"Value and target type are both '{target_type}'", node)
raise InvalidType(
f"Value and target type are both `{target_type}`",
hint="try removing the call to `convert()`",
)

return [value_type, TYPE_T(target_type)]

Expand Down Expand Up @@ -301,11 +306,6 @@ class Slice(BuiltinFunctionT):
def fetch_call_return(self, node):
arg_type, _, _ = self.infer_arg_types(node)

if isinstance(arg_type, StringT):
return_type = StringT()
else:
return_type = BytesT()

# validate start and length are in bounds

arg = node.args[0]
Expand Down Expand Up @@ -334,20 +334,14 @@ def fetch_call_return(self, node):
if length_literal is not None and start_literal + length_literal > arg_type.length:
raise ArgumentException(f"slice out of bounds for {arg_type}", node)

# we know the length statically
return_cls = arg_type.__class__
if length_literal is not None:
return_type.set_length(length_literal)
return_type = return_cls(length_literal)
else:
return_type.set_min_length(arg_type.length)
return_type = return_cls(arg_type.length)

return return_type

def infer_arg_types(self, node, expected_return_typ=None):
self._validate_arg_types(node)
# return a concrete type for `b`
b_type = get_possible_types_from_node(node.args[0]).pop()
return [b_type, self._inputs[1][1], self._inputs[2][1]]

@process_inputs
def build_IR(self, expr, args, kwargs, context):
src, start, length = args
Expand Down Expand Up @@ -500,12 +494,8 @@ def fetch_call_return(self, node):
for arg_t in arg_types:
length += arg_t.length

if isinstance(arg_types[0], (StringT)):
return_type = StringT()
else:
return_type = BytesT()
return_type.set_length(length)
return return_type
return_type_cls = arg_types[0].__class__
return return_type_cls(length)

def infer_arg_types(self, node, expected_return_typ=None):
if len(node.args) < 2:
Expand All @@ -517,8 +507,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
ret = []
prev_typeclass = None
for arg in node.args:
validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
arg_t = get_possible_types_from_node(arg).pop()
arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes"
if prev_typeclass and current_typeclass != prev_typeclass:
raise TypeMismatch(
Expand Down Expand Up @@ -874,7 +863,7 @@ def infer_kwarg_types(self, node):
"Output type must be one of integer, bytes32 or address", node.keywords[0].value
)
output_typedef = TYPE_T(output_type)
node.keywords[0].value._metadata["type"] = output_typedef
# node.keywords[0].value._metadata["type"] = output_typedef
else:
output_typedef = TYPE_T(BYTES32_T)

Expand Down Expand Up @@ -1089,8 +1078,7 @@ def fetch_call_return(self, node):
raise

if outsize.value:
return_type = BytesT()
return_type.set_min_length(outsize.value)
return_type = BytesT(outsize.value)

if revert_on_failure:
return return_type
Expand Down Expand Up @@ -2382,8 +2370,8 @@ def infer_kwarg_types(self, node):
ret = {}
for kwarg in node.keywords:
kwarg_name = kwarg.arg
validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = get_exact_type_from_node(kwarg.value)
typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = typ
return ret

def fetch_call_return(self, node):
Expand Down Expand Up @@ -2413,9 +2401,7 @@ def fetch_call_return(self, node):
# the output includes 4 bytes for the method_id.
maxlen += 4

ret = BytesT()
ret.set_length(maxlen)
return ret
return BytesT(maxlen)

@staticmethod
def _parse_method_id(method_id_literal):
Expand Down
6 changes: 5 additions & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,11 @@ def dummy_node_for_type(typ):


def _check_assign_bytes(left, right):
if right.typ.maxlen > left.typ.maxlen: # pragma: nocover
if (
left.typ.maxlen is not None
and right.typ.maxlen is not None
and right.typ.maxlen > left.typ.maxlen
): # pragma: nocover
raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}")

# stricter check for zeroing a byte array.
Expand Down
15 changes: 8 additions & 7 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,26 +699,27 @@ def parse_Call(self):
assert func_t.is_internal or func_t.is_constructor
return self_call.ir_for_self_call(self.expr, self.context)

@classmethod
def handle_external_call(cls, expr, context):
def handle_external_call(self):
# TODO fix cyclic import
from vyper.builtins._signatures import BuiltinFunctionT

call_node = expr.value
call_node = self.expr.value
assert isinstance(call_node, vy_ast.Call)

func_t = call_node.func._metadata["type"]

if isinstance(func_t, BuiltinFunctionT):
return func_t.build_IR(call_node, context)
return func_t.build_IR(call_node, self.context)

return external_call.ir_for_external_call(call_node, context)
return external_call.ir_for_external_call(
call_node, self.context, discard_output=self.is_stmt
)

def parse_ExtCall(self):
return self.handle_external_call(self.expr, self.context)
return self.handle_external_call()

def parse_StaticCall(self):
return self.handle_external_call(self.expr, self.context)
return self.handle_external_call()

def parse_List(self):
typ = self.expr._metadata["type"]
Expand Down
Loading
Loading