Skip to content

Commit

Permalink
Merge pull request #84 from argentlabs/feature/small_refactoring
Browse files Browse the repository at this point in the history
Small refactoring
  • Loading branch information
juniset authored Dec 21, 2022
2 parents 34649b9 + f8d6ef2 commit b7c4af7
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 205 deletions.
27 changes: 11 additions & 16 deletions contracts/account/ArgentAccount.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from starkware.starknet.common.syscalls import (

from contracts.utils.calls import (
CallArray,
execute_call_array,
execute_multicall,
)

from contracts.account.library import (
Expand All @@ -34,7 +34,7 @@ from contracts.account.library import (
/////////////////////

const NAME = 'ArgentAccount';
const VERSION = '0.2.3';
const VERSION = '0.2.4';

/////////////////////
// EVENTS
Expand Down Expand Up @@ -134,7 +134,7 @@ func __execute__{
assert_non_reentrant();

// execute calls
let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata);
let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata);

// emit event
transaction_executed.emit(
Expand Down Expand Up @@ -231,18 +231,13 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
// upgrades the implementation
ArgentModel.upgrade(implementation);
// library call to implementation.execute_after_upgrade
if (calldata_len == 0) {
let (retdata: felt*) = alloc();
return (retdata_len=0, retdata=retdata);
} else {
let (retdata_size: felt, retdata: felt*) = library_call(
class_hash=implementation,
function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR,
calldata_size=calldata_len,
calldata=calldata,
);
return (retdata_len=retdata_size, retdata=retdata);
}
let (retdata_size: felt, retdata: felt*) = library_call(
class_hash=implementation,
function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR,
calldata_size=calldata_len,
calldata=calldata,
);
return (retdata_len=retdata_size, retdata=retdata);
}

// @dev Logic or multicall to execute after an upgrade.
Expand All @@ -261,7 +256,7 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range
let (self) = get_contract_address();
assert_no_self_call(self, call_array_len, call_array);
// execute calls
let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata);
let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata);
return (retdata_len=retdata_len, retdata=retdata);
}

Expand Down
47 changes: 18 additions & 29 deletions contracts/account/ArgentPluginAccount.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ from starkware.starknet.common.syscalls import (
)
from contracts.plugins.IPlugin import IPlugin
from contracts.utils.calls import (
Call,
CallArray,
execute_call_array,
execute_calls,
from_call_array_to_call
execute_multicall,
)
from contracts.account.library import (
ArgentModel,
Expand All @@ -36,7 +33,7 @@ from contracts.account.library import (
///////////////////////

const NAME = 'ArgentPluginAccount';
const VERSION = '0.0.2';
const VERSION = '0.0.3';

// get_selector_from_name('use_plugin')
const USE_PLUGIN_SELECTOR = 1121675007639292412441492001821602921366030142137563176027248191276862353634;
Expand Down Expand Up @@ -153,7 +150,7 @@ func __execute__{
// no reentrant call to prevent signature reutilization
assert_non_reentrant();

let (retdata_len, retdata) = execute_call_array_plugin(call_array_len, call_array, calldata_len, calldata);
let (retdata_len, retdata) = execute_multicall_plugin(call_array_len, call_array, calldata);

// emit event
transaction_executed.emit(
Expand Down Expand Up @@ -341,20 +338,16 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
implementation: felt, calldata_len: felt, calldata: felt*
) -> (retdata_len: felt, retdata: felt*) {
alloc_locals;
// upgrades the implementation
ArgentModel.upgrade(implementation);

if (calldata_len == 0) {
let (retdata: felt*) = alloc();
return (retdata_len=0, retdata=retdata);
} else {
let (retdata_size: felt, retdata: felt*) = library_call(
class_hash=implementation,
function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR,
calldata_size=calldata_len,
calldata=calldata,
);
return (retdata_len=retdata_size, retdata=retdata);
}
// library call to implementation.execute_after_upgrade
let (retdata_size: felt, retdata: felt*) = library_call(
class_hash=implementation,
function_selector=ArgentModel.EXECUTE_AFTER_UPGRADE_SELECTOR,
calldata_size=calldata_len,
calldata=calldata,
);
return (retdata_len=retdata_size, retdata=retdata);
}

// @dev Logic or multicall to execute after an upgrade.
Expand All @@ -373,7 +366,7 @@ func execute_after_upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range
let (self) = get_contract_address();
assert_no_self_call(self, call_array_len, call_array);
// execute calls
let (retdata_len, retdata) = execute_call_array(call_array_len, call_array, calldata_len, calldata);
let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata);
return (retdata_len=retdata_len, retdata=retdata);
}

Expand Down Expand Up @@ -526,20 +519,16 @@ func is_valid_signature{
return (is_valid=is_valid);
}

func execute_call_array_plugin{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt*
func execute_multicall_plugin{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata: felt*
) -> (retdata_len: felt, retdata: felt*) {
alloc_locals;

let (calls: Call*) = alloc();
from_call_array_to_call(call_array_len, call_array, calldata, calls);

let (response: felt*) = alloc();
if (calls[0].selector == USE_PLUGIN_SELECTOR) {
let (response_len) = execute_calls(call_array_len - 1, calls + Call.SIZE, response, 0);
if (call_array[0].selector == USE_PLUGIN_SELECTOR) {
let (response_len, response) = execute_multicall(call_array_len - 1, call_array + CallArray.SIZE, calldata);
return (retdata_len=response_len, retdata=response);
} else {
let (response_len) = execute_calls(call_array_len, calls, response, 0);
let (response_len, response) = execute_multicall(call_array_len, call_array, calldata);
return (retdata_len=response_len, retdata=response);
}
}
7 changes: 4 additions & 3 deletions contracts/account/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ from contracts.upgrade.Upgradable import _set_implementation
from contracts.utils.calls import CallArray

const SUPPORTS_INTERFACE_SELECTOR = 1184015894760294494673613438913361435336722154500302038630992932234692784845;
const ERC165_ACCOUNT_INTERFACE_ID = 0x3943f10f;
const ERC165_ACCOUNT_INTERFACE_ID_OLD = 0xf10dbd44; // this is needed to upgrade to this version
const ERC165_ACCOUNT_INTERFACE_ID = 0xa66bd575;
const ERC165_ACCOUNT_INTERFACE_ID_OLD_1 = 0x3943f10f; // this is needed to upgrade to this version
const ERC165_ACCOUNT_INTERFACE_ID_OLD_2 = 0xf10dbd44; // this is needed to upgrade to this version

const TRANSACTION_VERSION = 1;
const QUERY_VERSION = 2**128 + TRANSACTION_VERSION;
Expand Down Expand Up @@ -408,7 +409,7 @@ namespace ArgentModel {
return (TRUE,);
}
// Old IAccount
if (interface_id == ERC165_ACCOUNT_INTERFACE_ID_OLD) {
if ((interface_id - ERC165_ACCOUNT_INTERFACE_ID_OLD_1) * (interface_id - ERC165_ACCOUNT_INTERFACE_ID_OLD_2) == 0) {
return (TRUE,);
}
return (FALSE,);
Expand Down
70 changes: 3 additions & 67 deletions contracts/lib/Multicall.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ from starkware.starknet.common.syscalls import call_contract, get_block_number
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy
from contracts.utils.calls import (
Call,
CallArray,
from_call_array_to_call,
execute_multicall,
)


Expand All @@ -15,7 +14,7 @@ from contracts.utils.calls import (
// contracts and return the aggregate response as an array.
// Input: same as the IAccount.__execute__
// @return (block_number, retdata_size, retdata)
// Where retdata is [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data]
// Where retdata is [len(call_1_data), *call_1_data, len(call_2_data), *call_2_data, ..., len(call_N_data), *call_N_data]
// ///////////////////////////////////////////////////////////////////////
@view
func aggregate{syscall_ptr: felt*, range_check_ptr}(
Expand All @@ -27,70 +26,7 @@ func aggregate{syscall_ptr: felt*, range_check_ptr}(
block_number: felt, retdata_len: felt, retdata: felt*
) {
alloc_locals;
let (retdata_len, retdata) = do_call_array(call_array_len, call_array, calldata_len, calldata);
let (retdata_len, retdata) = execute_multicall(call_array_len, call_array, calldata);
let (block_number) = get_block_number();
return (block_number=block_number, retdata_len=retdata_len, retdata=retdata);
}


// @notice Convenience method to convert an execute a call array
// @return response_len: The size of the returned data
// @return response: Data return
// in the form [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data]
func do_call_array{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt*
) -> (retdata_len: felt, retdata: felt*) {
alloc_locals;
// convert calls
let (calls: Call*) = alloc();
from_call_array_to_call(call_array_len, call_array, calldata, calls);

// execute them
let (response: felt*) = alloc();
let (response_len) = do_calls(call_array_len, calls, response, 0);
return (retdata_len=response_len, retdata=response);
}


// @notice Executes a list of contract calls recursively.
// @param calls_len The number of calls to execute
// @param calls A pointer to the first call to execute
// @param response The array of felt to populate with the returned data
// in the form [len(call_1_data), *call_1_data, len(call_1_data), *call_2_data, ..., len(call_N_data), *call_N_data]
// @return response_len The size of the returned data
func do_calls{syscall_ptr: felt*}(calls_len: felt, calls: Call*, response: felt*, index: felt) -> (
response_len: felt
) {
alloc_locals;

// if no more calls
if (calls_len == 0) {
return (0,);
}

// do the current call
let this_call: Call = [calls];
with_attr error_message("multicall {index} failed") {
let res = call_contract(
contract_address=this_call.to,
function_selector=this_call.selector,
calldata_size=this_call.calldata_len,
calldata=this_call.calldata,
);
}
// copy the result in response
assert [response] = res.retdata_size;
memcpy(
dst=response + 1,
src=res.retdata,
len=res.retdata_size
);
// do the next calls recursively
let (response_len) = do_calls(
calls_len=calls_len - 1,
calls=calls + Call.SIZE,
response=response + res.retdata_size + 1,
index=index + 1
);
return (response_len + res.retdata_size + 1,);
}
99 changes: 25 additions & 74 deletions contracts/utils/calls.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -23,88 +23,39 @@ struct CallArray {
data_len: felt,
}

func from_call_array_to_call{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata: felt*, calls: Call*
) {
// if no more calls
if (call_array_len == 0) {
return ();
}

// parse the current call
assert [calls] = Call(
to=[call_array].to,
selector=[call_array].selector,
calldata_len=[call_array].data_len,
calldata=calldata + [call_array].data_offset
);

// parse the remaining calls recursively
from_call_array_to_call(
call_array_len - 1, call_array + CallArray.SIZE, calldata, calls + Call.SIZE
);
return ();
}


// @notice Convenience method to convert an execute a call array
// @notice Executes a list of call array recursively
// @return response_len: The size of the returned data
// @return response: Data return
// in the form [*call_1_data, *call_2_data, ..., *call_N_data]
func execute_call_array{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata_len: felt, calldata: felt*
) -> (retdata_len: felt, retdata: felt*) {
// @return response: An array of felt populated with the returned data
// in the form [len(call_1_data), *call_1_data, len(call_2_data), *call_2_data, ..., len(call_N_data), *call_N_data]
func execute_multicall{syscall_ptr: felt*}(
call_array_len: felt, call_array: CallArray*, calldata: felt*
) -> (response_len: felt, response: felt*) {
alloc_locals;
// convert calls
let (calls: Call*) = alloc();
from_call_array_to_call(call_array_len, call_array, calldata, calls);

// execute them
let (response: felt*) = alloc();
let (response_len) = execute_calls(call_array_len, calls, response, 0);
return (retdata_len=response_len, retdata=response);
}
if (call_array_len == 0) {
let (response) = alloc();
return (0, response);
}

// call recursively all previous calls
let (response_len, response: felt*) = execute_multicall(call_array_len - 1, call_array, calldata);

// @notice Executes a list of contract calls recursively.
// @param calls_len The number of calls to execute
// @param calls A pointer to the first call to execute
// @param response The array of felt to populate with the returned data
// in the form [*call_1_data, *call_2_data, ..., *call_N_data]
// @return response_len The size of the returned data
func execute_calls{syscall_ptr: felt*}(calls_len: felt, calls: Call*, response: felt*, index: felt) -> (
response_len: felt
) {
alloc_locals;
// handle the last call
let last_call = call_array[call_array_len - 1];

// if no more calls
if (calls_len == 0) {
return (0,);
}

// do the current call
let this_call: Call = [calls];
with_attr error_message("multicall {index} failed") {
// call the last call
with_attr error_message("multicall {call_array_len} failed") {
let res = call_contract(
contract_address=this_call.to,
function_selector=this_call.selector,
calldata_size=this_call.calldata_len,
calldata=this_call.calldata,
contract_address=last_call.to,
function_selector=last_call.selector,
calldata_size=last_call.data_len,
calldata=calldata + last_call.data_offset,
);
}
// copy the result in response
memcpy(
dst=response,
src=res.retdata,
len=res.retdata_size
);
// do the next calls recursively
let (response_len) = execute_calls(
calls_len=calls_len - 1,
calls=calls + Call.SIZE,
response=response + res.retdata_size,
index=index + 1
);
return (response_len + res.retdata_size,);

// store response data
assert [response + response_len] = res.retdata_size;
memcpy(response + response_len + 1, res.retdata, res.retdata_size);
return (response_len + res.retdata_size + 1, response);
}

Loading

0 comments on commit b7c4af7

Please sign in to comment.