diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d012e4a1cf..f748eac8de 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -50,6 +50,7 @@ def process_inputs(wrapped_fn): @functools.wraps(wrapped_fn) def decorator_fn(self, node, context): subs = [] + for arg in node.args: arg_ir = process_arg(arg, arg._metadata["type"], context) # TODO annotate arg_ir with argname from self._inputs? @@ -137,10 +138,17 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: def check_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifiability) -> bool: return self._modifiability <= modifiability - def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: Optional[VyperType] = None + ) -> Optional[VyperType]: self._validate_arg_types(node) - return self._return_type + ret = self._return_type + + if expected_type is not None and not expected_type.compare_type(ret): # type: ignore + raise TypeMismatch(f"{self._id}() returns {ret}, but expected {expected_type}", node) + + return ret def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]: self._validate_arg_types(node) @@ -152,6 +160,7 @@ def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[V if len(varargs) > 0: assert self._has_varargs ret.extend(get_exact_type_from_node(arg) for arg in varargs) + return ret def infer_kwarg_types(self, node: vy_ast.Call) -> dict[str, VyperType]: diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 674efda7ce..11c72d8674 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -119,8 +119,8 @@ class TypenameFoldedFunctionT(FoldedFunctionT): # (2) should always be folded. _inputs = [("typename", TYPE_T.any())] - def fetch_call_return(self, node): - type_ = self.infer_arg_types(node)[0].typedef + def get_return_type(self, node, expected_type=None): + type_ = self.infer_arg_types(node, expected_type)[0].typedef return type_ def infer_arg_types(self, node, expected_return_typ=None): @@ -194,8 +194,8 @@ def build_IR(self, expr, args, kwargs, context): class Convert(BuiltinFunctionT): _id = "convert" - def fetch_call_return(self, node): - _, target_typedef = self.infer_arg_types(node) + def get_return_type(self, node, expected_typ=None): + _, target_typedef = self.infer_arg_types(node, expected_return_typ=expected_typ) # note: more type conversion validation happens in convert.py return target_typedef.typedef @@ -293,14 +293,9 @@ class Slice(BuiltinFunctionT): ("length", UINT256_T), ] - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): arg_type, _, _ = self.infer_arg_types(node) - if isinstance(arg_type, StringT): - return_type = StringT() - else: - return_type = BytesT() - # validate start and length are in bounds arg = node.args[0] @@ -330,10 +325,11 @@ def fetch_call_return(self, node): raise ArgumentException(f"slice out of bounds for {arg_type}", node) # we know the length statically - if length_literal is not None: - return_type.set_length(length_literal) + length = length_literal if length_literal is not None else 0 + if isinstance(arg_type, StringT): + return_type = StringT(length) else: - return_type.set_min_length(arg_type.length) + return_type = BytesT(length) return return_type @@ -491,7 +487,7 @@ def build_IR(self, node, context): class Concat(BuiltinFunctionT): _id = "concat" - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): arg_types = self.infer_arg_types(node) length = 0 @@ -499,10 +495,9 @@ def fetch_call_return(self, node): length += arg_t.length if isinstance(arg_types[0], (StringT)): - return_type = StringT() + return_type = StringT(length) else: - return_type = BytesT() - return_type.set_length(length) + return_type = BytesT(length) return return_type def infer_arg_types(self, node, expected_return_typ=None): @@ -732,7 +727,7 @@ def _try_fold(self, node): else: return vy_ast.Bytes.from_node(node, value=value) - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): validate_call_args(node, 1, ["output_type"]) type_ = self.infer_kwarg_types(node)["output_type"].typedef @@ -837,7 +832,7 @@ class Extract32(BuiltinFunctionT): _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] _kwargs = {"output_type": KwargSettings(TYPE_T.any(), BYTES32_T)} - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): self._validate_arg_types(node) return_type = self.infer_kwarg_types(node)["output_type"].typedef return return_type @@ -952,7 +947,7 @@ def _try_fold(self, node): return vy_ast.Int.from_node(node, value=int(value * denom)) - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): self.infer_arg_types(node) return self._return_type @@ -1015,7 +1010,7 @@ class RawCall(BuiltinFunctionT): "revert_on_failure": KwargSettings(BoolT(), True, require_literal=True), } - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): self._validate_arg_types(node) kwargz = {i.arg: i.value for i in node.keywords} @@ -1039,8 +1034,7 @@ def fetch_call_return(self, node): raise if outsize.value: - return_type = BytesT() - return_type.set_min_length(outsize.value) + return_type = BytesT(outsize.value) if revert_on_failure: return return_type @@ -1231,7 +1225,7 @@ class RawRevert(BuiltinFunctionT): _return_type = None _is_terminus = True - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): return None def infer_arg_types(self, node, expected_return_typ=None): @@ -1251,7 +1245,7 @@ class RawLog(BuiltinFunctionT): _id = "raw_log" _inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))] - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): self.infer_arg_types(node) def infer_arg_types(self, node, expected_return_typ=None): @@ -1422,7 +1416,7 @@ def _try_fold(self, node): value = (value << shift) % (2**256) return vy_ast.Int.from_node(node, value=value) - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): # return type is the type of the first argument return self.infer_arg_types(node)[0] @@ -1902,7 +1896,7 @@ class _UnsafeMath(BuiltinFunctionT): def __repr__(self): return f"builtin function unsafe_{self.op}" - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): return_type = self.infer_arg_types(node).pop() return return_type @@ -1987,7 +1981,11 @@ def _try_fold(self, node): value = self._eval_fn(left.value, right.value) return type(left).from_node(node, value=value) - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): + return_type = self.infer_arg_types(node).pop() + return return_type + + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) types_list = get_common_types( @@ -1998,12 +1996,6 @@ def fetch_call_return(self, node): return types_list - def infer_arg_types(self, node, expected_return_typ=None): - types_list = self.fetch_call_return(node) - # type mismatch should have been caught in `fetch_call_return` - assert expected_return_typ in types_list - return [expected_return_typ, expected_return_typ] - @process_inputs def build_IR(self, expr, args, kwargs, context): op = self._opcode @@ -2040,7 +2032,7 @@ class Uint2Str(BuiltinFunctionT): _id = "uint2str" _inputs = [("x", IntegerT.unsigneds())] - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): arg_t = self.infer_arg_types(node)[0] bits = arg_t.bits len_needed = math.ceil(bits * math.log(2) / math.log(10)) @@ -2065,7 +2057,7 @@ def infer_arg_types(self, node, expected_return_typ=None): @process_inputs def build_IR(self, expr, args, kwargs, context): - return_t = self.fetch_call_return(expr) + return_t = self.get_return_type(expr) n_digits = return_t.maxlen with args[0].cache_when_complex("val") as (b1, val): @@ -2227,7 +2219,7 @@ def build_IR(self, expr, args, kwargs, context): class Empty(TypenameFoldedFunctionT): _id = "empty" - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): type_ = self.infer_arg_types(node)[0].typedef if isinstance(type_, HashMapT): raise TypeMismatch("Cannot use empty on HashMap", node) @@ -2245,7 +2237,7 @@ class Breakpoint(BuiltinFunctionT): _warned = False - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): if not self._warned: vyper_warn("`breakpoint` should only be used for debugging!", node) self._warned = True @@ -2265,7 +2257,7 @@ class Print(BuiltinFunctionT): _warned = False - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): if not self._warned: vyper_warn("`print` should only be used for debugging!", node) self._warned = True @@ -2366,7 +2358,7 @@ def infer_kwarg_types(self, node): ret[kwarg_name] = get_exact_type_from_node(kwarg.value) return ret - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): self._validate_arg_types(node) ensure_tuple = next( (arg.value.value for arg in node.keywords if arg.arg == "ensure_tuple"), True @@ -2393,9 +2385,7 @@ def fetch_call_return(self, node): # the output includes 4 bytes for the method_id. maxlen += 4 - ret = BytesT() - ret.set_length(maxlen) - return ret + return BytesT(maxlen) @staticmethod def _parse_method_id(method_id_literal): @@ -2412,6 +2402,7 @@ def _parse_method_id(method_id_literal): def build_IR(self, expr, args, kwargs, context): ensure_tuple = kwargs["ensure_tuple"] method_id = self._parse_method_id(kwargs["method_id"]) + expr_type = expr._metadata["type"] if len(args) < 1: raise StructureException("abi_encode expects at least one argument", expr) @@ -2429,7 +2420,7 @@ def build_IR(self, expr, args, kwargs, context): maxlen += 4 buf_t = BytesT(maxlen) - assert self.fetch_call_return(expr).length == maxlen + assert self.get_return_type(expr, expr_type).length == maxlen buf = context.new_internal_variable(buf_t) ret = ["seq"] @@ -2465,8 +2456,8 @@ class ABIDecode(BuiltinFunctionT): _inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} - def fetch_call_return(self, node): - _, output_type = self.infer_arg_types(node) + def get_return_type(self, node, expected_type=None): + _, output_type = self.infer_arg_types(node, expected_return_typ=expected_type) return output_type.typedef def infer_arg_types(self, node, expected_return_typ=None): diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index 331b991bfe..aa10d3d927 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -20,7 +20,7 @@ from vyper.codegen.ir_node import Encoding, IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import TypeCheckFailure -from vyper.semantics.types import InterfaceT, TupleT +from vyper.semantics.types import VOID_TYPE, InterfaceT, TupleT from vyper.semantics.types.function import StateMutability @@ -32,7 +32,7 @@ class _CallKwargs: default_return_value: IRnode -def _pack_arguments(fn_type, args, context): +def _pack_arguments(fn_type, return_t, args, context): # abi encoding just treats all args as a big tuple args_tuple_t = TupleT([x.typ for x in args]) args_as_tuple = IRnode.from_list(["multi"] + [x for x in args], typ=args_tuple_t) @@ -42,8 +42,8 @@ def _pack_arguments(fn_type, args, context): dst_tuple_t = TupleT(fn_type.argument_types[: len(args)]) check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple) - if fn_type.return_type is not None: - return_abi_t = calculate_type_for_external_return(fn_type.return_type).abi_type + if return_t is not VOID_TYPE: + return_abi_t = calculate_type_for_external_return(return_t).abi_type # we use the same buffer for args and returndata, # so allocate enough space here for the returndata too. @@ -78,10 +78,8 @@ def _pack_arguments(fn_type, args, context): return buf, pack_args, args_ofst, args_len -def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, expr): - return_t = fn_type.return_type - - if return_t is None: +def _unpack_returndata(buf, fn_type, return_t, call_kwargs, contract_address, context, expr): + if return_t is VOID_TYPE: return ["pass"], 0, 0 wrapped_return_t = calculate_type_for_external_return(return_t) @@ -190,6 +188,9 @@ def _extcodesize_check(address): def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context): fn_type = call_expr.func._metadata["type"] + # the return type may differ from the function's return type if the function was + # imported via ABI e.g. widening of bytestrings + return_t = call_expr._metadata["type"] # sanity check assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args @@ -201,15 +202,15 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con # a duplicate label exception will get thrown during assembly. ret.append(eval_once_check(_freshname(call_expr.node_source_code))) - buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, args_ir, context) + buf, arg_packer, args_ofst, args_len = _pack_arguments(fn_type, return_t, args_ir, context) ret_unpacker, ret_ofst, ret_len = _unpack_returndata( - buf, fn_type, call_kwargs, contract_address, context, call_expr + buf, fn_type, return_t, call_kwargs, contract_address, context, call_expr ) ret += arg_packer - if fn_type.return_type is None and not call_kwargs.skip_contract_check: + if return_t is None and not call_kwargs.skip_contract_check: # if we do not expect return data, check that a contract exists at the # target address. we must perform this check BEFORE the call because # the contract might selfdestruct. on the other hand we can omit this @@ -232,7 +233,6 @@ def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, con ret.append(check_external_call(call_op)) - return_t = fn_type.return_type if return_t is not None: ret.append(ret_unpacker) diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md index 7a8a384c6d..44326a2618 100644 --- a/vyper/semantics/README.md +++ b/vyper/semantics/README.md @@ -203,7 +203,7 @@ bar: bytes32 = sha256(b"hash me!") 1. We look up `sha256` in `namespace` and retrieve the definition for the builtin function. -2. We call `fetch_call_return` on the function definition object, with the AST +2. We call `get_return_type` on the function definition object, with the AST node representing the call. This method validates the input arguments, and returns a `BytesM_T` with `m=32`. 3. We validation of the declaration of `bar` in the same manner as the first diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b5292b1dad..3d804f689c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -508,7 +508,7 @@ def visit_Expr(self, node): raise StructureException("Struct creation without assignment is disallowed", node) # NOTE: fetch_call_return validates call args. - return_value = map_void(fn_type.fetch_call_return(call_node)) + return_value = map_void(fn_type.get_return_type(call_node)) if ( return_value is not VOID_TYPE and not isinstance(fn_type, MemberFunctionT) @@ -603,7 +603,7 @@ def visit_Log(self, node): raise StructureException( f"Cannot emit logs from {self.func.mutability} functions", node ) - t = map_void(f.fetch_call_return(node.value)) + t = map_void(f.get_return_type(node.value)) # CMC 2024-02-05 annotate the event type for codegen usage # TODO: refactor this node._metadata["type"] = f.typedef @@ -614,18 +614,19 @@ def visit_Raise(self, node): self._validate_revert_reason(node.exc) def visit_Return(self, node): - values = node.value - if values is None: - if self.func.return_type: + return_value = node.value + if return_value is None: + if self.func.return_type is not None: raise FunctionDeclarationException("Return statement is missing a value", node) return elif self.func.return_type is None: raise FunctionDeclarationException("Function should not return any values", node) - if isinstance(values, vy_ast.Tuple): - values = values.elements + if isinstance(return_value, vy_ast.Tuple): + values = return_value.elements if not isinstance(self.func.return_type, TupleT): raise FunctionDeclarationException("Function only returns a single value", node) + if self.func.return_type.length != len(values): raise FunctionDeclarationException( f"Incorrect number of return values: " diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index d30eee79e0..055a461721 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -266,7 +266,7 @@ def types_from_Compare(self, node): raise InvalidOperation( "Right operand must be Array for membership comparison", node.right ) - types_list = [i for i in left if _is_type_in_list(i, [i.value_type for i in right])] + types_list = [i for i in left if _any_compare_type(i, [i.value_type for i in right])] if not types_list: raise TypeMismatch( "Cannot perform membership comparison between dislike types", node @@ -287,7 +287,7 @@ def types_from_StaticCall(self, node): def types_from_Call(self, node): # function calls, e.g. `foo()` or `MyStruct()` var = self.get_exact_type_from_node(node.func, include_type_exprs=True) - return_value = var.fetch_call_return(node) + return_value = var.get_return_type(node) if return_value: if isinstance(return_value, list): return return_value @@ -432,9 +432,9 @@ def _is_empty_list(node): return all(_is_empty_list(t) for t in node.elements) -def _is_type_in_list(obj, types_list): - # check if a type object is in a list of types - return any(i.compare_type(obj) for i in types_list) +def _any_compare_type(obj, types_list): + # check if an expression of a list of types can be assigned to a type object + return any(obj.compare_type(i) for i in types_list) # NOTE: dead fn @@ -541,7 +541,7 @@ def _validate_literal_array(node, expected): for item in node.elements: try: validate_expected_type(item, expected.value_type) - except (InvalidType, TypeMismatch): + except TypeMismatch: return False return True diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 128ede0d5b..4f7e462d5f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -307,7 +307,10 @@ def compare_type(self, other: "VyperType") -> bool: """ return isinstance(other, type(self)) - def fetch_call_return(self, node: vy_ast.Call) -> Optional["VyperType"]: + def get_return_type( + self, node: vy_ast.Call, expected_type: Optional["VyperType"] = None + ) -> Optional["VyperType"]: + # TODO will be cleaner to separate into validate_call and get_return_type """ Validate a call to this type and return the result. @@ -416,7 +419,10 @@ def check_modifiability_for_call(self, node, modifiability): raise StructureException("Value is not callable", node) # dispatch into ctor if it's called - def fetch_call_return(self, node): + def get_return_type(self, node, expected_type=None): + if expected_type is not None: + if not self.typedef.compare_type(expected_type): + raise CompilerPanic(f"bad type passed to {self.typedef} ctor", node) if hasattr(self.typedef, "_ctor_call_return"): return self.typedef._ctor_call_return(node) raise StructureException("Value is not callable", node) diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index cd330681cf..8d9c7f2699 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -6,41 +6,40 @@ from vyper.utils import ceil32 +class _UnknownLength: + pass + + +UNKNOWN_LENGTH = _UnknownLength() + + class _BytestringT(VyperType): """ Private base class for single-value types which occupy multiple memory slots and where a maximum length must be given via a subscript (string, bytes). - Types for literals have an inferred minimum length. For example, `b"hello"` - has a length of 5 of more and so can be used in an operation with `bytes[5]` - or `bytes[10]`, but not `bytes[4]`. Upon comparison to a fixed length type, - the minimum length is discarded and the type assumes the fixed length it was - compared against. + Types for literals are initialized to the literal's length. Attributes ---------- _length : int - The maximum allowable length of the data within the type. - _min_length: int - The minimum length of the data within the type. Used when the type - is applied to a literal definition. + The length of the data within the type. """ # this is a carveout because currently we allow dynamic arrays of # bytestrings, but not static arrays of bytestrings _as_darray = True _as_hashmap_key = True - _equality_attrs = ("_length", "_min_length") + _equality_attrs = ("_length",) _is_bytestring: bool = True - def __init__(self, length: int = 0) -> None: + def __init__(self, length: int | _UnknownLength = UNKNOWN_LENGTH) -> None: super().__init__() self._length = length - self._min_length = length def __repr__(self): - return f"{self._id}[{self.length}]" + return f"{self._id}[{self._length}]" def _addl_dict_fields(self): return {"length": self.length} @@ -50,9 +49,9 @@ def length(self): """ Property method used to check the length of a type. """ - if self._length: - return self._length - return self._min_length + if self._length is UNKNOWN_LENGTH: + return 0 + return self._length @property def maxlen(self): @@ -78,51 +77,21 @@ def size_in_bytes(self): return 32 + ceil32(self.length) - def set_length(self, length): - """ - Sets the exact length of the type. - - May only be called once, and only on a type that does not yet have - a fixed length. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - self._length = length - self._min_length = length - - def set_min_length(self, min_length): - """ - Sets the minimum length of the type. - - May only be used to increase the minimum length. May not be called if - an exact length has been set. - """ - if self._length: - raise CompilerPanic("Type already has a fixed length") - if self._min_length > min_length: - raise CompilerPanic("Cannot reduce the min_length of ArrayValueType") - self._min_length = min_length - + # note: definition of compare_type is: + # expr of type `other` can be assigned to expr of type `self` def compare_type(self, other): if not super().compare_type(other): return False - # CMC 2022-03-18 TODO this method should be refactored so it does not have side effects - - # when comparing two literals, both now have an equal min-length - if not self._length and not other._length: - min_length = max(self._min_length, other._min_length) - self.set_min_length(min_length) - other.set_min_length(min_length) + # relax typechecking if length has not been set for either type + # (e.g. JSON ABI import, `address.code`) so that it can be updated in + # annotation phase + # note that if both lengths are unknown, there is an exception + # but it will be handled elsewhere. + if self._length is UNKNOWN_LENGTH or other._length is UNKNOWN_LENGTH: return True - # comparing a defined length to a literal causes the literal to have a fixed length - if self._length: - if not other._length: - other.set_length(max(self._length, other._min_length)) - return self._length >= other._length - - return other.compare_type(self) + return self._length >= other._length @classmethod def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": @@ -150,9 +119,7 @@ def from_annotation(cls, node: vy_ast.VyperNode) -> "_BytestringT": def from_literal(cls, node: vy_ast.Constant) -> "_BytestringT": if not isinstance(node, cls._valid_literal): raise UnexpectedNodeType(f"Not a {cls._id}: {node}") - t = cls() - t.set_min_length(len(node.value)) - return t + return cls(len(node.value)) class BytesT(_BytestringT): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7a56b01281..0eee4a7e01 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -62,7 +62,7 @@ class ContractFunctionT(VyperType): Contract function type. Functions compare false against all types and so cannot be assigned without - being called. Calls are validated by `fetch_call_return`, check the call + being called. Calls are validated by `get_return_type`, check the call arguments against `positional_args` and `keyword_arg`, and return `return_type`. Attributes @@ -97,6 +97,7 @@ def __init__( state_mutability: StateMutability, from_interface: bool = False, nonreentrant: bool = False, + is_from_abi: Optional[bool] = False, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -108,6 +109,7 @@ def __init__( self.visibility = function_visibility self.mutability = state_mutability self.nonreentrant = nonreentrant + self.is_from_abi = is_from_abi self.from_interface = from_interface # sanity check, nonreentrant used to be Optional[str] @@ -250,6 +252,7 @@ def from_abi(cls, abi: dict) -> "ContractFunctionT": from_interface=True, function_visibility=FunctionVisibility.EXTERNAL, state_mutability=StateMutability.from_abi(abi), + is_from_abi=True, ) @classmethod @@ -616,7 +619,9 @@ def _enhance_call_exception(self, e, ast_node=None): e.hint = self._pp_signature return e - def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> Optional[VyperType]: # mypy hint - right now, the only way a ContractFunctionT can be # called is via `Attribute`, e.x. self.foo() or library.bar() assert isinstance(node.func, vy_ast.Attribute) @@ -876,7 +881,9 @@ def _id(self): def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" - def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: validate_call_args(node, len(self.arg_types)) assert len(node.args) == len(self.arg_types) # validate_call_args postcondition diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ca8e99bc92..d7a896b73c 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -138,7 +138,9 @@ def from_FlagDef(cls, base_node: vy_ast.FlagDef) -> "FlagT": return cls(base_node.name, members) - def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: # TODO return None