diff --git a/src/lib.cairo b/src/lib.cairo index 235e12392..ea35f72ca 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -3,6 +3,7 @@ mod account; mod introspection; mod security; mod token; +mod upgrades; mod utils; #[cfg(test)] diff --git a/src/tests.cairo b/src/tests.cairo index cc78a11ec..7145f5650 100644 --- a/src/tests.cairo +++ b/src/tests.cairo @@ -4,4 +4,5 @@ mod introspection; mod mocks; mod security; mod token; +mod upgrades; mod utils; diff --git a/src/tests/mocks.cairo b/src/tests/mocks.cairo index cb3112db1..11a30de47 100644 --- a/src/tests/mocks.cairo +++ b/src/tests/mocks.cairo @@ -17,3 +17,5 @@ mod snake721_mock; mod snake_accesscontrol_mock; mod snake_account_mock; mod src5_mocks; +mod upgrades_v1; +mod upgrades_v2; diff --git a/src/tests/mocks/upgrades_v1.cairo b/src/tests/mocks/upgrades_v1.cairo new file mode 100644 index 000000000..d851f59d9 --- /dev/null +++ b/src/tests/mocks/upgrades_v1.cairo @@ -0,0 +1,55 @@ +// This contract is a mock used to test the core functionality of the upgrade functions. +// The functions are NOT PROTECTED. +// DO NOT USE IN PRODUCTION. + +use array::ArrayTrait; +use starknet::ClassHash; + +#[starknet::interface] +trait IUpgradesV1 { + fn upgrade(ref self: TState, impl_hash: ClassHash); + fn set_value(ref self: TState, val: felt252); + fn get_value(self: @TState) -> felt252; + fn remove_selector(self: @TState); +} + +trait UpgradesV1Trait { + fn set_value(ref self: TState, val: felt252); + fn get_value(self: @TState) -> felt252; + fn remove_selector(self: @TState); +} + +#[starknet::contract] +mod UpgradesV1 { + use array::ArrayTrait; + use starknet::ClassHash; + use starknet::ContractAddress; + use openzeppelin::upgrades::interface::IUpgradeable; + use openzeppelin::upgrades::upgradeable::Upgradeable; + + #[storage] + struct Storage { + value: felt252 + } + + #[external(v0)] + impl UpgradeableImpl of IUpgradeable { + fn upgrade(ref self: ContractState, impl_hash: ClassHash) { + let mut unsafe_state = Upgradeable::unsafe_new_contract_state(); + Upgradeable::InternalImpl::_upgrade(ref unsafe_state, impl_hash); + } + } + + #[external(v0)] + impl UpgradesV1Impl of super::UpgradesV1Trait { + fn set_value(ref self: ContractState, val: felt252) { + self.value.write(val); + } + + fn get_value(self: @ContractState) -> felt252 { + self.value.read() + } + + fn remove_selector(self: @ContractState) {} + } +} diff --git a/src/tests/mocks/upgrades_v2.cairo b/src/tests/mocks/upgrades_v2.cairo new file mode 100644 index 000000000..f0ad60ee7 --- /dev/null +++ b/src/tests/mocks/upgrades_v2.cairo @@ -0,0 +1,64 @@ +// This contract is a mock used to test the core functionality of the upgrade functions. +// The functions are NOT PROTECTED. +// DO NOT USE IN PRODUCTION. + +use array::ArrayTrait; +use starknet::ClassHash; + +#[starknet::interface] +trait IUpgradesV2 { + fn upgrade(ref self: TState, impl_hash: ClassHash); + fn set_value(ref self: TState, val: felt252); + fn set_value2(ref self: TState, val: felt252); + fn get_value(self: @TState) -> felt252; + fn get_value2(self: @TState) -> felt252; +} + +trait UpgradesV2Trait { + fn set_value(ref self: TState, val: felt252); + fn set_value2(ref self: TState, val: felt252); + fn get_value(self: @TState) -> felt252; + fn get_value2(self: @TState) -> felt252; +} + +#[starknet::contract] +mod UpgradesV2 { + use array::ArrayTrait; + use starknet::ClassHash; + use starknet::ContractAddress; + use openzeppelin::upgrades::interface::IUpgradeable; + use openzeppelin::upgrades::upgradeable::Upgradeable; + + #[storage] + struct Storage { + value: felt252, + value2: felt252, + } + + #[external(v0)] + impl UpgradeableImpl of IUpgradeable { + fn upgrade(ref self: ContractState, impl_hash: ClassHash) { + let mut unsafe_state = Upgradeable::unsafe_new_contract_state(); + Upgradeable::InternalImpl::_upgrade(ref unsafe_state, impl_hash); + } + } + + #[external(v0)] + impl UpgradesV2Impl of super::UpgradesV2Trait { + fn set_value(ref self: ContractState, val: felt252) { + self.value.write(val); + } + + fn set_value2(ref self: ContractState, val: felt252) { + self.value2.write(val); + } + + fn get_value(self: @ContractState) -> felt252 { + self.value.read() + } + + fn get_value2(self: @ContractState) -> felt252 { + self.value2.read() + } + } +} diff --git a/src/tests/upgrades.cairo b/src/tests/upgrades.cairo new file mode 100644 index 000000000..fe66e01e3 --- /dev/null +++ b/src/tests/upgrades.cairo @@ -0,0 +1 @@ +mod test_upgradeable; diff --git a/src/tests/upgrades/test_upgradeable.cairo b/src/tests/upgrades/test_upgradeable.cairo new file mode 100644 index 000000000..66de7a2a7 --- /dev/null +++ b/src/tests/upgrades/test_upgradeable.cairo @@ -0,0 +1,104 @@ +use array::ArrayTrait; +use openzeppelin::tests::mocks::upgrades_v1::IUpgradesV1Dispatcher; +use openzeppelin::tests::mocks::upgrades_v1::IUpgradesV1DispatcherTrait; +use openzeppelin::tests::mocks::upgrades_v2::IUpgradesV2Dispatcher; +use openzeppelin::tests::mocks::upgrades_v2::IUpgradesV2DispatcherTrait; +use openzeppelin::tests::mocks::upgrades_v1::UpgradesV1; +use openzeppelin::tests::mocks::upgrades_v2::UpgradesV2; +use openzeppelin::tests::utils; +use openzeppelin::upgrades::upgradeable::Upgradeable::Upgraded; +use option::OptionTrait; +use starknet::ClassHash; +use starknet::ContractAddress; +use starknet::Felt252TryIntoClassHash; +use starknet::class_hash_const; +use starknet::contract_address_const; +use starknet::testing; +use traits::TryInto; + +const VALUE: felt252 = 123; + +fn V2_CLASS_HASH() -> ClassHash { + UpgradesV2::TEST_CLASS_HASH.try_into().unwrap() +} + +fn CLASS_HASH_ZERO() -> ClassHash { + class_hash_const::<0>() +} + +fn ZERO() -> ContractAddress { + contract_address_const::<0>() +} + +// +// Setup +// + +fn deploy_v1() -> IUpgradesV1Dispatcher { + let calldata = array![]; + let address = utils::deploy(UpgradesV1::TEST_CLASS_HASH, calldata); + IUpgradesV1Dispatcher { contract_address: address } +} + +// +// upgrade +// + +#[test] +#[available_gas(2000000)] +#[should_panic(expected: ('Class hash cannot be zero', 'ENTRYPOINT_FAILED', ))] +fn test_upgrade_with_class_hash_zero() { + let v1 = deploy_v1(); + v1.upgrade(CLASS_HASH_ZERO()); +} + +#[test] +#[available_gas(2000000)] +fn test_upgraded_event() { + let v1 = deploy_v1(); + v1.upgrade(V2_CLASS_HASH()); + + let event = testing::pop_log::(v1.contract_address).unwrap(); + assert(event.class_hash == V2_CLASS_HASH(), 'Invalid class hash'); +} + +#[test] +#[available_gas(2000000)] +fn test_new_selector_after_upgrade() { + let v1 = deploy_v1(); + + v1.upgrade(V2_CLASS_HASH()); + let v2 = IUpgradesV2Dispatcher { contract_address: v1.contract_address }; + + v2.set_value2(VALUE); + assert(v2.get_value2() == VALUE, 'New selector should be callable'); +} + +#[test] +#[available_gas(2000000)] +fn test_state_persists_after_upgrade() { + let v1 = deploy_v1(); + v1.set_value(VALUE); + + v1.upgrade(V2_CLASS_HASH()); + let v2 = IUpgradesV2Dispatcher { contract_address: v1.contract_address }; + + assert(v2.get_value() == VALUE, 'Should keep state after upgrade'); +} + +#[test] +#[available_gas(2000000)] +fn test_remove_selector_passes_in_v1() { + let v1 = deploy_v1(); + v1.remove_selector(); +} + +#[test] +#[available_gas(2000000)] +#[should_panic(expected: ('ENTRYPOINT_NOT_FOUND', ))] +fn test_remove_selector_fails_in_v2() { + let v1 = deploy_v1(); + v1.upgrade(V2_CLASS_HASH()); + // We use the v1 dispatcher because remove_selector is not in v2 interface + v1.remove_selector(); +} diff --git a/src/upgrades.cairo b/src/upgrades.cairo new file mode 100644 index 000000000..18d78fc73 --- /dev/null +++ b/src/upgrades.cairo @@ -0,0 +1,2 @@ +mod upgradeable; +mod interface; diff --git a/src/upgrades/interface.cairo b/src/upgrades/interface.cairo new file mode 100644 index 000000000..1857c003d --- /dev/null +++ b/src/upgrades/interface.cairo @@ -0,0 +1,6 @@ +use starknet::ClassHash; + +#[starknet::interface] +trait IUpgradeable { + fn upgrade(ref self: TState, impl_hash: ClassHash); +} diff --git a/src/upgrades/upgradeable.cairo b/src/upgrades/upgradeable.cairo new file mode 100644 index 000000000..444cf36b7 --- /dev/null +++ b/src/upgrades/upgradeable.cairo @@ -0,0 +1,29 @@ +#[starknet::contract] +mod Upgradeable { + use starknet::ClassHash; + use starknet::SyscallResult; + use zeroable::Zeroable; + + #[storage] + struct Storage {} + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + Upgraded: Upgraded + } + + #[derive(Drop, starknet::Event)] + struct Upgraded { + class_hash: ClassHash + } + + #[generate_trait] + impl InternalImpl of InternalState { + fn _upgrade(ref self: ContractState, new_class_hash: ClassHash) { + assert(!new_class_hash.is_zero(), 'Class hash cannot be zero'); + starknet::replace_class_syscall(new_class_hash).unwrap_syscall(); + self.emit(Upgraded { class_hash: new_class_hash }); + } + } +}