Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[ux]: remove side effects in compare_type for bytestrings #3379

Open
wants to merge 68 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
30c54db
try what happens
tserg Apr 29, 2023
1212c50
minor fix
tserg Apr 30, 2023
135629e
try fix
tserg Apr 30, 2023
c87630c
fix lint
tserg Apr 30, 2023
809278a
update comment
tserg May 1, 2023
2cc75ef
handle nested bytestrings
tserg May 1, 2023
59c8f78
fix lint
tserg May 1, 2023
40077ea
try aggressive overriding of contract function return type
tserg May 1, 2023
b7e52c5
undo changes
tserg May 1, 2023
9f9981e
undo changes
tserg May 1, 2023
b07efdd
add widening test
tserg May 5, 2023
77cdd99
Merge branch 'master' of https://github.com/vyperlang/vyper into fix/…
tserg May 5, 2023
ace4367
check for propagated type
tserg May 5, 2023
44e0023
widen bytestring by deriving larger type
tserg May 5, 2023
f4060d7
remove _min_length; add _is_literal
tserg May 6, 2023
828929a
update comment
tserg May 6, 2023
cc1d76a
fix formatting
tserg May 6, 2023
7fcbacd
improve comment
tserg May 6, 2023
36028c2
use length instead of _length
tserg May 6, 2023
1b5a390
add comment
tserg May 6, 2023
00c26cf
fix lint
tserg May 6, 2023
2cc968e
Merge branch 'master' of https://github.com/vyperlang/vyper into fix/…
tserg Nov 8, 2023
014e989
fix literal cmp
tserg Nov 8, 2023
c1a9bbc
uncomment
tserg Nov 8, 2023
38e94d2
fix wip
tserg Nov 9, 2023
013e594
Merge branch 'master' of https://github.com/vyperlang/vyper into fix/…
tserg Nov 9, 2023
9ce8c4a
fix test
tserg Nov 9, 2023
1cb5066
fix any syntax
tserg Nov 9, 2023
5d6ad38
fix comment
tserg Nov 9, 2023
114973d
rename helper; change semantics
tserg Nov 9, 2023
eb8ffc5
fix lint
tserg Nov 9, 2023
ba527d3
add is_from_abi attr
tserg Nov 11, 2023
50e5000
remove set_length fn
tserg Nov 11, 2023
a325b99
fix slice
tserg Nov 12, 2023
d4d598a
fix lint
tserg Nov 12, 2023
8e8f204
remove is_literal
tserg Nov 12, 2023
6d29d1b
fix compare_type
tserg Nov 12, 2023
b175311
revert overwriting of ContractFunctionT return typ
tserg Nov 12, 2023
5a8490e
add comment
tserg Nov 12, 2023
ac9b222
fix lint
tserg Nov 12, 2023
e788797
wip - review
charles-cooper Nov 20, 2023
b33c878
fix lint
charles-cooper Nov 20, 2023
4a2744f
fix some return types
charles-cooper Nov 20, 2023
8b537c4
fix return type for call stmts
tserg Nov 20, 2023
27490e3
Merge branch 'master' into fix/bytestring_compare_type
charles-cooper Dec 7, 2023
a44779b
rename to _any_compare_type
charles-cooper Dec 7, 2023
1208ae1
fix lint
charles-cooper Dec 7, 2023
832f17b
rename fetch_call_return to get_return_type
charles-cooper Dec 7, 2023
23ef9d1
fix some signatures
charles-cooper Dec 8, 2023
8c82d1d
Merge branch 'master' of https://github.com/vyperlang/vyper into fix/…
tserg Dec 18, 2023
87096e2
Merge branch 'fix/bytestring_compare_type' of https://github.com/tser…
tserg Dec 18, 2023
f3553f5
fix lint
tserg Dec 18, 2023
ce5b120
fix more lint
tserg Dec 18, 2023
b124308
use Optional
tserg Dec 18, 2023
b94ae77
revert return analysis
tserg Dec 18, 2023
b9a474d
fix some tests
tserg Dec 18, 2023
6b58aeb
fix more tests
tserg Dec 18, 2023
5c6135c
revert circular call in infer_arg_types
tserg Dec 18, 2023
73193d5
fix lint
tserg Dec 18, 2023
9f667ca
fix more tests
tserg Dec 18, 2023
b3eb9f2
fix lint
tserg Dec 18, 2023
34ab4bc
fix import
tserg Dec 18, 2023
1b25514
Merge branch 'master' of https://github.com/vyperlang/vyper into fix/…
tserg Sep 21, 2024
729df1e
fix lint
tserg Sep 21, 2024
a1fbb11
fix convert
tserg Sep 21, 2024
24a306a
fix lint
tserg Sep 21, 2024
7c60586
fix typo
tserg Sep 21, 2024
b91e442
fix another typo
tserg Sep 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def process_inputs(wrapped_fn):
@functools.wraps(wrapped_fn)
def decorator_fn(self, node, context):
subs = []

for arg in node.args:
arg_ir = process_arg(arg, arg._metadata["type"], context)
# TODO annotate arg_ir with argname from self._inputs?
Expand Down Expand Up @@ -137,10 +138,15 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool:
return self._modifiability <= modifiability

def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
def get_return_type(self, node: vy_ast.Call, expected_type: VyperType | None = None) -> Optional[VyperType]:
self._validate_arg_types(node)

return self._return_type
ret = self._return_type

if expected_type is not None and not expected_type.compare_type(ret):
raise TypeMismatch("{self._id}() returns {ret}, but expected {expected_type}", node)

return ret

def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]:
self._validate_arg_types(node)
Expand All @@ -152,6 +158,7 @@ def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[V
if len(varargs) > 0:
assert self._has_varargs
ret.extend(get_exact_type_from_node(arg) for arg in varargs)

return ret

def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]:
Expand Down
77 changes: 37 additions & 40 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@
# (2) should always be folded.
_inputs = [("typename", TYPE_T.any())]

def fetch_call_return(self, node):
type_ = self.infer_arg_types(node)[0].typedef
def get_return_type(self, node, expected_type=None):
type_ = self.infer_arg_types(node, expected_type)[0].typedef
return type_

def infer_arg_types(self, node, expected_return_typ=None):
Expand Down Expand Up @@ -194,8 +194,8 @@
class Convert(BuiltinFunctionT):
_id = "convert"

def fetch_call_return(self, node):
_, target_typedef = self.infer_arg_types(node)
def get_return_type(self, node, expected_return_typ=None):
_, target_typedef = self.infer_arg_types(node, expected_return_typ=expected_return_typ)

# note: more type conversion validation happens in convert.py
return target_typedef.typedef
Expand Down Expand Up @@ -293,14 +293,9 @@
("length", UINT256_T),
]

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
arg_type, _, _ = self.infer_arg_types(node)

if isinstance(arg_type, StringT):
return_type = StringT()
else:
return_type = BytesT()

# validate start and length are in bounds

arg = node.args[0]
Expand Down Expand Up @@ -330,10 +325,11 @@
raise ArgumentException(f"slice out of bounds for {arg_type}", node)

# we know the length statically
if length_literal is not None:
return_type.set_length(length_literal)
length = length_literal if length_literal is not None else 0
if isinstance(arg_type, StringT):
return_type = StringT(length)
else:
return_type.set_min_length(arg_type.length)
return_type = BytesT(length)

return return_type

Expand Down Expand Up @@ -491,18 +487,17 @@
class Concat(BuiltinFunctionT):
_id = "concat"

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
arg_types = self.infer_arg_types(node)

length = 0
for arg_t in arg_types:
length += arg_t.length

if isinstance(arg_types[0], (StringT)):
return_type = StringT()
return_type = StringT(length)
else:
return_type = BytesT()
return_type.set_length(length)
return_type = BytesT(length)
return return_type

def infer_arg_types(self, node, expected_return_typ=None):
Expand Down Expand Up @@ -732,7 +727,7 @@
else:
return vy_ast.Bytes.from_node(node, value=value)

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
validate_call_args(node, 1, ["output_type"])

type_ = self.infer_kwarg_types(node)["output_type"].typedef
Expand Down Expand Up @@ -837,7 +832,7 @@
_inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())]
_kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)}

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
self._validate_arg_types(node)
return_type = self.infer_kwarg_types(node)["output_type"].typedef
return return_type
Expand Down Expand Up @@ -952,7 +947,7 @@

return vy_ast.Int.from_node(node, value=int(value * denom))

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
self.infer_arg_types(node)
return self._return_type

Expand Down Expand Up @@ -1015,7 +1010,7 @@
"revert_on_failure": KwargSettings(BoolT(), True, require_literal=True),
}

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
self._validate_arg_types(node)

kwargz = {i.arg: i.value for i in node.keywords}
Expand All @@ -1039,8 +1034,7 @@
raise

if outsize.value:
return_type = BytesT()
return_type.set_min_length(outsize.value)
return_type = BytesT(outsize.value)

if revert_on_failure:
return return_type
Expand Down Expand Up @@ -1231,7 +1225,7 @@
_return_type = None
_is_terminus = True

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
return None

def infer_arg_types(self, node, expected_return_typ=None):
Expand All @@ -1251,7 +1245,7 @@
_id = "raw_log"
_inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))]

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
self.infer_arg_types(node)

def infer_arg_types(self, node, expected_return_typ=None):
Expand Down Expand Up @@ -1422,7 +1416,7 @@
value = (value << shift) % (2**256)
return vy_ast.Int.from_node(node, value=value)

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
# return type is the type of the first argument
return self.infer_arg_types(node)[0]

Expand Down Expand Up @@ -1902,7 +1896,7 @@
def __repr__(self):
return f"builtin function unsafe_{self.op}"

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
return_type = self.infer_arg_types(node).pop()
return return_type

Expand Down Expand Up @@ -1987,7 +1981,11 @@
value = self._eval_fn(left.value, right.value)
return type(left).from_node(node, value=value)

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
return_type = self.infer_arg_types(node).pop()
return return_type

def infer_arg_types(self, node, expected_return_type=None):
Fixed Show fixed Hide fixed
self._validate_arg_types(node)

types_list = get_common_types(
Expand Down Expand Up @@ -2040,7 +2038,7 @@
_id = "uint2str"
_inputs = [("x", IntegerT.unsigneds())]

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
arg_t = self.infer_arg_types(node)[0]
bits = arg_t.bits
len_needed = math.ceil(bits * math.log(2) / math.log(10))
Expand All @@ -2065,7 +2063,7 @@

@process_inputs
def build_IR(self, expr, args, kwargs, context):
return_t = self.fetch_call_return(expr)
return_t = self.get_return_type(expr)
n_digits = return_t.maxlen

with args[0].cache_when_complex("val") as (b1, val):
Expand Down Expand Up @@ -2227,7 +2225,7 @@
class Empty(TypenameFoldedFunctionT):
_id = "empty"

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
type_ = self.infer_arg_types(node)[0].typedef
if isinstance(type_, HashMapT):
raise TypeMismatch("Cannot use empty on HashMap", node)
Expand All @@ -2245,7 +2243,7 @@

_warned = False

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
if not self._warned:
vyper_warn("`breakpoint` should only be used for debugging!", node)
self._warned = True
Expand All @@ -2265,7 +2263,7 @@

_warned = False

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
if not self._warned:
vyper_warn("`print` should only be used for debugging!", node)
self._warned = True
Expand Down Expand Up @@ -2366,7 +2364,7 @@
ret[kwarg_name] = get_exact_type_from_node(kwarg.value)
return ret

def fetch_call_return(self, node):
def get_return_type(self, node, expected_type=None):
self._validate_arg_types(node)
ensure_tuple = next(
(arg.value.value for arg in node.keywords if arg.arg == "ensure_tuple"), True
Expand All @@ -2393,9 +2391,7 @@
# the output includes 4 bytes for the method_id.
maxlen += 4

ret = BytesT()
ret.set_length(maxlen)
return ret
return BytesT(maxlen)

@staticmethod
def _parse_method_id(method_id_literal):
Expand All @@ -2412,6 +2408,7 @@
def build_IR(self, expr, args, kwargs, context):
ensure_tuple = kwargs["ensure_tuple"]
method_id = self._parse_method_id(kwargs["method_id"])
expr_type = expr._metadata["type"]

if len(args) < 1:
raise StructureException("abi_encode expects at least one argument", expr)
Expand All @@ -2429,7 +2426,7 @@
maxlen += 4

buf_t = BytesT(maxlen)
assert self.fetch_call_return(expr).length == maxlen
assert self.get_return_type(expr, expr_type).length == maxlen
buf = context.new_internal_variable(buf_t)

ret = ["seq"]
Expand Down Expand Up @@ -2465,8 +2462,8 @@
_inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())]
_kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)}

def fetch_call_return(self, node):
_, output_type = self.infer_arg_types(node)
def get_return_type(self, node, expected_type=None):
_, output_type = self.infer_arg_types(node, return_type=expected_type)
Fixed Show fixed Hide fixed
return output_type.typedef

def infer_arg_types(self, node, expected_return_typ=None):
Expand Down
24 changes: 12 additions & 12 deletions vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import InterfaceT, TupleT
from vyper.semantics.types import InterfaceT, TupleT, VOID_TYPE
from vyper.semantics.types.function import StateMutability


Expand All @@ -32,7 +32,7 @@ class _CallKwargs:
default_return_value: IRnode


def _pack_arguments(fn_type, args, context):
def _pack_arguments(fn_type, return_t, args, context):
# abi encoding just treats all args as a big tuple
args_tuple_t = TupleT([x.typ for x in args])
args_as_tuple = IRnode.from_list(["multi"] + [x for x in args], typ=args_tuple_t)
Expand All @@ -42,8 +42,8 @@ def _pack_arguments(fn_type, args, context):
dst_tuple_t = TupleT(fn_type.argument_types[: len(args)])
check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple)

if fn_type.return_type is not None:
return_abi_t = calculate_type_for_external_return(fn_type.return_type).abi_type
if return_t is not VOID_TYPE:
return_abi_t = calculate_type_for_external_return(return_t).abi_type

# we use the same buffer for args and returndata,
# so allocate enough space here for the returndata too.
Expand Down Expand Up @@ -78,10 +78,8 @@ def _pack_arguments(fn_type, args, context):
return buf, pack_args, args_ofst, args_len


def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, expr):
return_t = fn_type.return_type

if return_t is None:
def _unpack_returndata(buf, fn_type, return_t, call_kwargs, contract_address, context, expr):
if return_t is VOID_TYPE:
return ["pass"], 0, 0

wrapped_return_t = calculate_type_for_external_return(return_t)
Expand Down Expand Up @@ -190,6 +188,9 @@ def _extcodesize_check(address):

def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context):
fn_type = call_expr.func._metadata["type"]
# the return type may differ from the function's return type if the function was
# imported via ABI e.g. widening of bytestrings
return_t = call_expr._metadata["type"]

# sanity check
assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args
Expand All @@ -201,15 +202,15 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con
# a duplicate label exception will get thrown during assembly.
ret.append(eval_once_check(_freshname(call_expr.node_source_code)))

buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context)
buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, return_t, args_ir, context)

ret_unpacker, ret_ofst, ret_len = _unpack_returndata(
buf, fn_type, call_kwargs, contract_address, context, call_expr
buf, fn_type, return_t, call_kwargs, contract_address, context, call_expr
)

ret += arg_packer

if fn_type.return_type is None and not call_kwargs.skip_contract_check:
if return_t is None and not call_kwargs.skip_contract_check:
# if we do not expect return data, check that a contract exists at the
# target address. we must perform this check BEFORE the call because
# the contract might selfdestruct. on the other hand we can omit this
Expand All @@ -232,7 +233,6 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con

ret.append(check_external_call(call_op))

return_t = fn_type.return_type
if return_t is not None:
ret.append(ret_unpacker)

Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ bar: bytes32 = sha256(b"hash me!")

1. We look up `sha256` in `namespace` and retrieve the definition for the builtin
function.
2. We call `fetch_call_return` on the function definition object, with the AST
2. We call `get_return_type` on the function definition object, with the AST
node representing the call. This method validates the input arguments, and returns
a `BytesM_T` with `m=32`.
3. We validation of the declaration of `bar` in the same manner as the first
Expand Down
Loading
Loading