diff --git a/contracts/account/ArgentAccount.cairo b/contracts/account/ArgentAccount.cairo index 473fbc3d..e1418097 100644 --- a/contracts/account/ArgentAccount.cairo +++ b/contracts/account/ArgentAccount.cairo @@ -11,7 +11,7 @@ from starkware.starknet.common.syscalls import ( from contracts.utils.calls import ( CallArray, - execute_call_array, + execute_multicall, ) from contracts.account.library import ( @@ -34,7 +34,7 @@ from contracts.account.library import ( ///////////////////// const NAME = 'ArgentAccount'; -const VERSION = '0.2.3'; +const VERSION = '0.2.4'; ///////////////////// // EVENTS @@ -134,7 +134,7 @@ func __execute__{ assert_non_reentrant(); // execute calls - let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata); + let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata); // emit event transaction_executed.emit( @@ -231,18 +231,13 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( // upgrades the implementation ArgentModel.upgrade(implementation); // library call to implementation.execute_after_upgrade - if (calldata_len == 0) { - let (retdata: felt*) = alloc(); - return (retdata_len=0, retdata=retdata); - } else { - let (retdata_size: felt, retdata: felt*) = library_call( - class_hash=implementation, - function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR, - calldata_size=calldata_len, - calldata=calldata, - ); - return (retdata_len=retdata_size, retdata=retdata); - } + let (retdata_size: felt, retdata: felt*) = library_call( + class_hash=implementation, + function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR, + calldata_size=calldata_len, + calldata=calldata, + ); + return (retdata_len=retdata_size, retdata=retdata); } // @dev Logic or multicall to execute after an upgrade. @@ -261,7 +256,7 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range let (self) = get_contract_address(); assert_no_self_call(self, call_array_len, call_array); // execute calls - let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata); + let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata); return (retdata_len=retdata_len, retdata=retdata); } diff --git a/contracts/account/ArgentPluginAccount.cairo b/contracts/account/ArgentPluginAccount.cairo index e6e3b077..4eed85fc 100644 --- a/contracts/account/ArgentPluginAccount.cairo +++ b/contracts/account/ArgentPluginAccount.cairo @@ -10,11 +10,8 @@ from starkware.starknet.common.syscalls import ( ) from contracts.plugins.IPlugin import IPlugin from contracts.utils.calls import ( - Call, CallArray, - execute_call_array, - execute_calls, - from_call_array_to_call + execute_multicall, ) from contracts.account.library import ( ArgentModel, @@ -36,7 +33,7 @@ from contracts.account.library import ( /////////////////////// const NAME = 'ArgentPluginAccount'; -const VERSION = '0.0.2'; +const VERSION = '0.0.3'; // get_selector_from_name('use_plugin') const USE_PLUGIN_SELECTOR = 1121675007639292412441492001821602921366030142137563176027248191276862353634; @@ -153,7 +150,7 @@ func __execute__{ // no reentrant call to prevent signature reutilization assert_non_reentrant(); - let (retdata_len, retdata) = execute_call_array_plugin(call_array_len, call_array, calldata_len, calldata); + let (retdata_len, retdata) = execute_multicall_plugin(call_array_len, call_array, calldata); // emit event transaction_executed.emit( @@ -341,20 +338,16 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( implementation: felt, calldata_len: felt, calldata: felt* ) -> (retdata_len: felt, retdata: felt*) { alloc_locals; + // upgrades the implementation ArgentModel.upgrade(implementation); - - if (calldata_len == 0) { - let (retdata: felt*) = alloc(); - return (retdata_len=0, retdata=retdata); - } else { - let (retdata_size: felt, retdata: felt*) = library_call( - class_hash=implementation, - function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR, - calldata_size=calldata_len, - calldata=calldata, - ); - return (retdata_len=retdata_size, retdata=retdata); - } + // library call to implementation.execute_after_upgrade + let (retdata_size: felt, retdata: felt*) = library_call( + class_hash=implementation, + function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR, + calldata_size=calldata_len, + calldata=calldata, + ); + return (retdata_len=retdata_size, retdata=retdata); } // @dev Logic or multicall to execute after an upgrade. @@ -373,7 +366,7 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range let (self) = get_contract_address(); assert_no_self_call(self, call_array_len, call_array); // execute calls - let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata); + let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata); return (retdata_len=retdata_len, retdata=retdata); } @@ -526,20 +519,16 @@ func is_valid_signature{ return (is_valid=is_valid); } -func execute_call_array_plugin{syscall_ptr: felt*}( - call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* +func execute_multicall_plugin{syscall_ptr: felt*}( + call_array_len: felt, call_array: CallArray*, calldata: felt* ) -> (retdata_len: felt, retdata: felt*) { alloc_locals; - let (calls: Call*) = alloc(); - from_call_array_to_call(call_array_len, call_array, calldata, calls); - - let (response: felt*) = alloc(); - if (calls[0].selector == USE_PLUGIN_SELECTOR) { - let (response_len) = execute_calls(call_array_len - 1, calls + Call.SIZE, response, 0); + if (call_array[0].selector == USE_PLUGIN_SELECTOR) { + let (response_len, response) = execute_multicall(call_array_len - 1, call_array + CallArray.SIZE, calldata); return (retdata_len=response_len, retdata=response); } else { - let (response_len) = execute_calls(call_array_len, calls, response, 0); + let (response_len, response) = execute_multicall(call_array_len, call_array, calldata); return (retdata_len=response_len, retdata=response); } } diff --git a/contracts/account/library.cairo b/contracts/account/library.cairo index 1e00eff9..d3995404 100644 --- a/contracts/account/library.cairo +++ b/contracts/account/library.cairo @@ -17,8 +17,9 @@ from contracts.upgrade.Upgradable import _set_implementation from contracts.utils.calls import CallArray const SUPPORTS_INTERFACE_SELECTOR = 1184015894760294494673613438913361435336722154500302038630992932234692784845; -const ERC165_ACCOUNT_INTERFACE_ID = 0x3943f10f; -const ERC165_ACCOUNT_INTERFACE_ID_OLD = 0xf10dbd44; // this is needed to upgrade to this version +const ERC165_ACCOUNT_INTERFACE_ID = 0xa66bd575; +const ERC165_ACCOUNT_INTERFACE_ID_OLD_1 = 0x3943f10f; // this is needed to upgrade to this version +const ERC165_ACCOUNT_INTERFACE_ID_OLD_2 = 0xf10dbd44; // this is needed to upgrade to this version const TRANSACTION_VERSION = 1; const QUERY_VERSION = 2**128 + TRANSACTION_VERSION; @@ -408,7 +409,7 @@ namespace ArgentModel { return (TRUE,); } // Old IAccount - if (interface_id == ERC165_ACCOUNT_INTERFACE_ID_OLD) { + if ((interface_id - ERC165_ACCOUNT_INTERFACE_ID_OLD_1) * (interface_id - ERC165_ACCOUNT_INTERFACE_ID_OLD_2) == 0) { return (TRUE,); } return (FALSE,); diff --git a/contracts/lib/Multicall.cairo b/contracts/lib/Multicall.cairo index 0fa7e08e..c50f9c1f 100644 --- a/contracts/lib/Multicall.cairo +++ b/contracts/lib/Multicall.cairo @@ -4,9 +4,8 @@ from starkware.starknet.common.syscalls import call_contract, get_block_number from starkware.cairo.common.alloc import alloc from starkware.cairo.common.memcpy import memcpy from contracts.utils.calls import ( - Call, CallArray, - from_call_array_to_call, + execute_multicall, ) @@ -15,7 +14,7 @@ from contracts.utils.calls import ( // contracts and return the aggregate response as an array. // Input: same as the IAccount.__execute__ // @return (block_number, retdata_size, retdata) -// Where retdata is [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data] +// Where retdata is [len(call_1_data), *call_1_data, len(call_2_data), *call_2_data, ..., len(call_N_data), *call_N_data] // /////////////////////////////////////////////////////////////////////// @view func aggregate{syscall_ptr: felt*, range_check_ptr}( @@ -27,70 +26,7 @@ func aggregate{syscall_ptr: felt*, range_check_ptr}( block_number: felt, retdata_len: felt, retdata: felt* ) { alloc_locals; - let (retdata_len, retdata) = do_call_array(call_array_len, call_array, calldata_len, calldata); + let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata); let (block_number) = get_block_number(); return (block_number=block_number, retdata_len=retdata_len, retdata=retdata); } - - -// @notice Convenience method to convert an execute a call array -// @return response_len: The size of the returned data -// @return response: Data return -// in the form [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data] -func do_call_array{syscall_ptr: felt*}( - call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* -) -> (retdata_len: felt, retdata: felt*) { - alloc_locals; - // convert calls - let (calls: Call*) = alloc(); - from_call_array_to_call(call_array_len, call_array, calldata, calls); - - // execute them - let (response: felt*) = alloc(); - let (response_len) = do_calls(call_array_len, calls, response, 0); - return (retdata_len=response_len, retdata=response); -} - - -// @notice Executes a list of contract calls recursively. -// @param calls_len The number of calls to execute -// @param calls A pointer to the first call to execute -// @param response The array of felt to populate with the returned data -// in the form [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data] -// @return response_len The size of the returned data -func do_calls{syscall_ptr: felt*}(calls_len: felt, calls: Call*, response: felt*, index: felt) -> ( - response_len: felt -) { - alloc_locals; - - // if no more calls - if (calls_len == 0) { - return (0,); - } - - // do the current call - let this_call: Call = [calls]; - with_attr error_message("multicall {index} failed") { - let res = call_contract( - contract_address=this_call.to, - function_selector=this_call.selector, - calldata_size=this_call.calldata_len, - calldata=this_call.calldata, - ); - } - // copy the result in response - assert [response] = res.retdata_size; - memcpy( - dst=response + 1, - src=res.retdata, - len=res.retdata_size - ); - // do the next calls recursively - let (response_len) = do_calls( - calls_len=calls_len - 1, - calls=calls + Call.SIZE, - response=response + res.retdata_size + 1, - index=index + 1 - ); - return (response_len + res.retdata_size + 1,); -} diff --git a/contracts/utils/calls.cairo b/contracts/utils/calls.cairo index 84e470ae..b130d83e 100644 --- a/contracts/utils/calls.cairo +++ b/contracts/utils/calls.cairo @@ -23,88 +23,39 @@ struct CallArray { data_len: felt, } -func from_call_array_to_call{syscall_ptr: felt*}( - call_array_len: felt, call_array: CallArray*, calldata: felt*, calls: Call* -) { - // if no more calls - if (call_array_len == 0) { - return (); - } - - // parse the current call - assert [calls] = Call( - to=[call_array].to, - selector=[call_array].selector, - calldata_len=[call_array].data_len, - calldata=calldata + [call_array].data_offset - ); - - // parse the remaining calls recursively - from_call_array_to_call( - call_array_len - 1, call_array + CallArray.SIZE, calldata, calls + Call.SIZE - ); - return (); -} - - -// @notice Convenience method to convert an execute a call array +// @notice Executes a list of call array recursively // @return response_len: The size of the returned data -// @return response: Data return -// in the form [*call_1_data, *call_2_data, ..., *call_N_data] -func execute_call_array{syscall_ptr: felt*}( - call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt* -) -> (retdata_len: felt, retdata: felt*) { +// @return response: An array of felt populated with the returned data +// in the form [len(call_1_data), *call_1_data, len(call_2_data), *call_2_data, ..., len(call_N_data), *call_N_data] +func execute_multicall{syscall_ptr: felt*}( + call_array_len: felt, call_array: CallArray*, calldata: felt* +) -> (response_len: felt, response: felt*) { alloc_locals; - // convert calls - let (calls: Call*) = alloc(); - from_call_array_to_call(call_array_len, call_array, calldata, calls); - // execute them - let (response: felt*) = alloc(); - let (response_len) = execute_calls(call_array_len, calls, response, 0); - return (retdata_len=response_len, retdata=response); -} + if (call_array_len == 0) { + let (response) = alloc(); + return (0, response); + } + // call recursively all previous calls + let (response_len, response: felt*) = execute_multicall(call_array_len - 1, call_array, calldata); -// @notice Executes a list of contract calls recursively. -// @param calls_len The number of calls to execute -// @param calls A pointer to the first call to execute -// @param response The array of felt to populate with the returned data -// in the form [*call_1_data, *call_2_data, ..., *call_N_data] -// @return response_len The size of the returned data -func execute_calls{syscall_ptr: felt*}(calls_len: felt, calls: Call*, response: felt*, index: felt) -> ( - response_len: felt -) { - alloc_locals; + // handle the last call + let last_call = call_array[call_array_len - 1]; - // if no more calls - if (calls_len == 0) { - return (0,); - } - - // do the current call - let this_call: Call = [calls]; - with_attr error_message("multicall {index} failed") { + // call the last call + with_attr error_message("multicall {call_array_len} failed") { let res = call_contract( - contract_address=this_call.to, - function_selector=this_call.selector, - calldata_size=this_call.calldata_len, - calldata=this_call.calldata, + contract_address=last_call.to, + function_selector=last_call.selector, + calldata_size=last_call.data_len, + calldata=calldata + last_call.data_offset, ); } - // copy the result in response - memcpy( - dst=response, - src=res.retdata, - len=res.retdata_size - ); - // do the next calls recursively - let (response_len) = execute_calls( - calls_len=calls_len - 1, - calls=calls + Call.SIZE, - response=response + res.retdata_size, - index=index + 1 - ); - return (response_len + res.retdata_size,); + + // store response data + assert [response + response_len] = res.retdata_size; + memcpy(response + response_len + 1, res.retdata, res.retdata_size); + return (response_len + res.retdata_size + 1, response); } diff --git a/test/test_argent_account.py b/test/test_argent_account.py index 7ed5750e..3c155c35 100644 --- a/test/test_argent_account.py +++ b/test/test_argent_account.py @@ -20,10 +20,11 @@ ESCAPE_SECURITY_PERIOD = 24*7*60*60 -VERSION = str_to_felt('0.2.3') +VERSION = str_to_felt('0.2.4') NAME = str_to_felt('ArgentAccount') -IACCOUNT_ID = 0x3943f10f +IACCOUNT_ID = 0xa66bd575 +IACCOUNT_ID_OLD = 0x3943f10f ESCAPE_TYPE_GUARDIAN = 1 ESCAPE_TYPE_SIGNER = 2 @@ -226,11 +227,11 @@ async def test_multicall(contract_factory): # should indicate which called failed await assert_revert( sender.send_transaction([(dapp.contract_address, 'set_number', [47]), (dapp.contract_address, 'throw_error', [1])], [signer, guardian]), - "multicall 1 failed" + "multicall 2 failed" ) await assert_revert( sender.send_transaction([(dapp.contract_address, 'throw_error', [1]), (dapp.contract_address, 'set_number', [47])], [signer, guardian]), - "multicall 0 failed" + "multicall 1 failed" ) # should call the dapp @@ -621,7 +622,7 @@ async def test_support_interface(contract_factory): res = (await account.supportsInterface(IACCOUNT_ID).call()).result assert (res.success == 1) # IAccount old - res = (await account.supportsInterface(0xf10dbd44).call()).result + res = (await account.supportsInterface(IACCOUNT_ID_OLD).call()).result assert (res.success == 1) # unsupported res = (await account.supportsInterface(0xffffffff).call()).result diff --git a/test/test_proxy.py b/test/test_proxy.py index 5e98027f..7f21802a 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -153,8 +153,7 @@ async def test_upgrade(contract_factory): ) ret_execute = get_execute_data(tx_exec_info) - assert len(ret_execute) == 1, "Unexpected return data length" - assert ret_execute[0] == 0, "Expected 0 calls to be executed after upgrade" + assert len(ret_execute) == 0, "Expected 0 calls to be executed after upgrade" assert_event_emitted( tx_exec_info, @@ -181,8 +180,7 @@ async def test_upgrade_exec(contract_factory): ) ret_execute = get_execute_data(tx_exec_info) - assert len(ret_execute) == 2, "Unexpected return data length" - assert ret_execute[0] == 1, "Expected 1 call to be executed after upgrade" + assert ret_execute[0] == 1, "Expected call response to be of length 1 " assert ret_execute[1] == 47, "Expected new_number returned" assert_event_emitted( @@ -212,10 +210,10 @@ async def test_upgrade_many_calls(contract_factory): ) ret_execute = get_execute_data(tx_exec_info) - assert len(ret_execute) == 3, "Unexpected return data length" - assert ret_execute[0] == 2, "Expected 2 calls to be executed after upgrade" + assert ret_execute[0] == 1, "Expected call 1 response to be of length 1 " assert ret_execute[1] == 47, "Expected new_number returned from first call" - assert ret_execute[2] == 48, "Expected new_number returned form second call" + assert ret_execute[2] == 1, "Expected call 2 response to be of length 1 " + assert ret_execute[3] == 48, "Expected new_number returned form second call" assert_event_emitted( tx_exec_info, @@ -278,5 +276,5 @@ async def test_execute_after_upgrade_safety(contract_factory): [build_upgrade_call(account, account_2_class_hash, [change_signer_call])], [signer, guardian] ), - "multicall 0 failed" + "multicall 1 failed" ) diff --git a/test/utils/utilities.py b/test/utils/utilities.py index 8a87ee66..d65202c2 100644 --- a/test/utils/utilities.py +++ b/test/utils/utilities.py @@ -89,9 +89,10 @@ def cached_contract(state: StarknetState, _class: ContractClass, deployed: Stark def get_execute_data(tx_exec_info: TransactionExecutionInfo) -> List[int]: raw_data: List[int] = tx_exec_info.call_info.retdata - ret_execute_size, *ret_execute = raw_data - assert ret_execute_size == len(ret_execute), "Unexpected return size" - return ret_execute + response_len = raw_data[2] + response = raw_data[3:] + assert response_len == len(response), "Unexpected return size" + return response def copy_contract_state(contract: StarknetContract) -> StarknetContract: