Skip to content

Commit

Permalink
feat: increase test coverage + minor fix (#41)
Browse files Browse the repository at this point in the history
* feat: increase test coverage + minor fix

* fix: format code

* fix: 0 address check for mailbox client component

* fix: minor fixes and assertions

* feat: minor code changes and structure update

* fix: format
  • Loading branch information
JordyRo1 authored Jun 19, 2024
1 parent ac4a2e6 commit 233c584
Show file tree
Hide file tree
Showing 17 changed files with 568 additions and 348 deletions.
8 changes: 7 additions & 1 deletion contracts/src/contracts/client/mailboxclient_component.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod MailboxclientComponent {
};
use openzeppelin::access::ownable::{OwnableComponent, OwnableComponent::InternalImpl};
use openzeppelin::upgrades::{interface::IUpgradeable, upgradeable::UpgradeableComponent};
use starknet::ContractAddress;
use starknet::{ContractAddress, contract_address_const};

#[storage]
struct Storage {
Expand All @@ -16,6 +16,10 @@ pub mod MailboxclientComponent {
interchain_security_module: ContractAddress,
}

pub mod Errors {
pub const ADDRESS_CANNOT_BE_ZERO: felt252 = 'Address cannot be zero';
}


#[embeddable_as(MailboxClientImpl)]
impl MailboxClient<
Expand All @@ -26,6 +30,7 @@ pub mod MailboxclientComponent {
fn set_hook(ref self: ComponentState<TContractState>, _hook: ContractAddress) {
let ownable_comp = get_dep_component!(@self, Owner);
ownable_comp.assert_only_owner();
assert(_hook != contract_address_const::<0>(), Errors::ADDRESS_CANNOT_BE_ZERO);
self.hook.write(_hook);
}

Expand All @@ -34,6 +39,7 @@ pub mod MailboxclientComponent {
) {
let ownable_comp = get_dep_component!(@self, Owner);
ownable_comp.assert_only_owner();
assert(_module != contract_address_const::<0>(), Errors::ADDRESS_CANNOT_BE_ZERO);
self.interchain_security_module.write(_module);
}

Expand Down
13 changes: 11 additions & 2 deletions contracts/src/contracts/hooks/merkle_tree_hook.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub mod merkle_tree_hook {
pub const INVALID_METADATA_VARIANT: felt252 = 'Invalid metadata variant';
pub const MERKLE_TREE_FULL: felt252 = 'Merkle tree full';
pub const CANNOT_EXCEED_TREE_DEPTH: felt252 = 'Cannot exceed tree depth';
pub const TREE_DEPTH_NOT_REACHED: felt252 = 'Tree depth not reached';
}

#[event]
Expand Down Expand Up @@ -111,7 +112,7 @@ pub mod merkle_tree_hook {
}

fn latest_checkpoint(self: @ContractState) -> (u256, u32) {
(self._root(), self.count())
(self._root(), self.count() - 1)
}
}

Expand Down Expand Up @@ -139,6 +140,7 @@ pub mod merkle_tree_hook {
pub impl InternalImpl of InternalTrait {
fn _post_dispatch(ref self: ContractState, _metadata: Bytes, _message: Message) {
let (id, _) = MessageTrait::format_message(_message);
// ensure messages which were not dispatched are not inserted into the tree
assert(self.mailboxclient._is_latest_dispatched(id), Errors::MESSAGE_NOT_DISPATCHING);
let index = self.count();
self._insert(ByteData { value: id, size: HASH_SIZE });
Expand Down Expand Up @@ -173,13 +175,19 @@ pub mod merkle_tree_hook {
fn _root_with_ctx(self: @ContractState, _zeroes: Array<u256>) -> u256 {
let mut cur_idx = 0;
let index = self.count.read();

// Not present in the solidity implementation
let mut current = ByteData { value: *_zeroes[0], size: HASH_SIZE };
loop {
if (cur_idx == TREE_DEPTH) {
break ();
}
let ith_bit = _get_ith_bit(index.into(), cur_idx);
let next = self.tree.read(cur_idx);
let next = self
.tree
.read(
cur_idx
); // Will return 0 if no values is stored, in accordance with solidity impl
if (ith_bit == 1) {
current =
ByteData {
Expand All @@ -206,6 +214,7 @@ pub mod merkle_tree_hook {
current.value
}
fn _branch_root(_item: ByteData, _branch: Span<ByteData>, _index: u256) -> u256 {
assert(_branch.len() >= TREE_DEPTH, Errors::TREE_DEPTH_NOT_REACHED);
let mut cur_idx = 0;
let mut current = _item;
loop {
Expand Down
90 changes: 64 additions & 26 deletions contracts/src/contracts/isms/aggregation/aggregation.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[starknet::contract]
pub mod aggregation {
use alexandria_bytes::Bytes;
use alexandria_bytes::{Bytes, BytesTrait};
use hyperlane_starknet::contracts::libs::aggregation_ism_metadata::aggregation_ism_metadata::AggregationIsmMetadata;
use hyperlane_starknet::contracts::libs::message::{Message, MessageTrait};
use hyperlane_starknet::interfaces::{
Expand Down Expand Up @@ -44,6 +44,7 @@ pub mod aggregation {
pub const THRESHOLD_NOT_REACHED: felt252 = 'Threshold not reached';
pub const MODULE_ADDRESS_CANNOT_BE_NULL: felt252 = 'Module address cannot be null';
pub const THRESHOLD_NOT_SET: felt252 = 'Threshold not set';
pub const MODULES_ALREADY_STORED: felt252 = 'Modules already stored';
}

#[constructor]
Expand All @@ -63,15 +64,19 @@ pub mod aggregation {
) -> (Span<ContractAddress>, u8) {
// THE USER CAN DEFINE HERE CONDITIONS FOR THE MODULE AND THRESHOLD SELECTION
let threshold = self.threshold.read();
(build_modules_span(self), threshold)
(self.build_modules_span(), threshold)
}

fn verify(self: @ContractState, _metadata: Bytes, _message: Message,) -> bool {
let (isms, mut threshold) = self.modules_and_threshold(_message.clone());

assert(threshold != 0, Errors::THRESHOLD_NOT_SET);
let modules = build_modules_span(self);
let modules = self.build_modules_span();
let mut cur_idx: u8 = 0;
loop {
if (threshold == 0) {
break ();
}
if (cur_idx.into() == isms.len()) {
break ();
}
Expand All @@ -82,6 +87,7 @@ pub mod aggregation {
let ism = IInterchainSecurityModuleDispatcher {
contract_address: *modules.at(cur_idx.into())
};

let metadata = AggregationIsmMetadata::metadata_at(_metadata.clone(), cur_idx);
assert(ism.verify(metadata, _message.clone()), Errors::VERIFICATION_FAILED);
threshold -= 1;
Expand All @@ -92,7 +98,7 @@ pub mod aggregation {
}

fn get_modules(self: @ContractState) -> Span<ContractAddress> {
build_modules_span(self)
self.build_modules_span()
}

fn get_threshold(self: @ContractState) -> u8 {
Expand All @@ -101,7 +107,8 @@ pub mod aggregation {

fn set_modules(ref self: ContractState, _modules: Span<ContractAddress>) {
self.ownable.assert_only_owner();
let mut last_module = find_last_module(@self);
assert(!self.are_modules_stored(_modules), Errors::MODULES_ALREADY_STORED);
let mut last_module = self.find_last_module();
let mut cur_idx = 0;
loop {
if (cur_idx == _modules.len()) {
Expand All @@ -122,29 +129,60 @@ pub mod aggregation {
self.threshold.write(_threshold);
}
}

fn find_last_module(self: @ContractState) -> ContractAddress {
let mut current_module = self.modules.read(contract_address_const::<0>());
loop {
let next_module = self.modules.read(current_module);
if next_module == contract_address_const::<0>() {
break current_module;
#[generate_trait]
impl InternalImpl of InternalTrait {
fn find_last_module(self: @ContractState) -> ContractAddress {
let mut current_module = self.modules.read(contract_address_const::<0>());
loop {
let next_module = self.modules.read(current_module);
if next_module == contract_address_const::<0>() {
break current_module;
}
current_module = next_module;
}
current_module = next_module;
}
}

fn build_modules_span(self: @ContractState) -> Span<ContractAddress> {
let mut cur_address = contract_address_const::<0>();
let mut modules = array![];
loop {
let next_address = self.modules.read(cur_address);
if (next_address == contract_address_const::<0>()) {
break ();
fn find_module_index(
self: @ContractState, _module: ContractAddress
) -> Option<ContractAddress> {
let mut current_module: ContractAddress = 0.try_into().unwrap();
loop {
let next_module = self.modules.read(current_module);
if next_module == _module {
break Option::Some(current_module);
} else if next_module == 0.try_into().unwrap() {
break Option::None(());
}
current_module = next_module;
}
modules.append(cur_address);
cur_address = next_address
};
modules.span()
}

fn are_modules_stored(self: @ContractState, _modules: Span<ContractAddress>) -> bool {
let mut cur_idx = 0;
let mut result = false;
while cur_idx < _modules.len()
&& result == false {
let module = *_modules.at(cur_idx);
match self.find_module_index(module) {
Option::Some => result = true,
Option::None => {}
};
cur_idx += 1;
};
result
}

fn build_modules_span(self: @ContractState) -> Span<ContractAddress> {
let mut cur_address = contract_address_const::<0>();
let mut modules = array![];
loop {
let next_address = self.modules.read(cur_address);
if (next_address == contract_address_const::<0>()) {
break ();
}
modules.append(next_address);
cur_address = next_address
};
modules.span()
}
}
}
Loading

0 comments on commit 233c584

Please sign in to comment.