diff --git a/contract/src/lib.rs b/contract/src/lib.rs index c96bc3e7f..d8ff04f9c 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -48,7 +48,7 @@ pub enum ProtocolContractState { } #[near_bindgen] -#[derive(BorshDeserialize, BorshSerialize)] +#[derive(BorshDeserialize, BorshSerialize, Debug)] pub enum VersionedMpcContract { V0(MpcContract), } @@ -59,7 +59,7 @@ impl Default for VersionedMpcContract { } } -#[derive(BorshDeserialize, BorshSerialize, PanicOnDefault)] +#[derive(BorshDeserialize, BorshSerialize, PanicOnDefault, Debug)] pub struct MpcContract { protocol_state: ProtocolContractState, pending_requests: LookupMap<[u8; 32], Option<(String, String)>>, @@ -88,13 +88,25 @@ impl MpcContract { } } - pub fn clean_payloads(&mut self, payloads: Vec<[u8; 32]>, counter: u32) { + fn clean_payloads(&mut self, payloads: Vec<[u8; 32]>, counter: u32) { log!("clean_payloads"); for payload in payloads.iter() { self.pending_requests.remove(payload); } self.request_counter = counter; } + + pub fn test_init() -> Self { + MpcContract { + protocol_state: ProtocolContractState::Initializing(InitializingContractState { + candidates: Candidates::default(), + threshold: 2, + pk_votes: PkVotes::new(), + }), + pending_requests: LookupMap::new(b"m"), + request_counter: 2, + } + } } #[near_bindgen] @@ -546,12 +558,12 @@ impl VersionedMpcContract { #[private] #[init(ignore_state)] - pub fn migrate_state() -> Self { + pub fn migrate_state_old_to_v0() -> Self { let old_contract: MpcContract = env::state_read().expect("Old state doesn't exist"); Self::V0(MpcContract { protocol_state: old_contract.protocol_state, - pending_requests: LookupMap::new(b"m"), - request_counter: 0, + pending_requests: old_contract.pending_requests, + request_counter: old_contract.request_counter, }) } } diff --git a/contract/tests/tests.rs b/contract/tests/tests.rs index c26c8b6ba..7a4fac432 100644 --- a/contract/tests/tests.rs +++ b/contract/tests/tests.rs @@ -1,4 +1,5 @@ -use mpc_contract::primitives::CandidateInfo; +use mpc_contract::{primitives::CandidateInfo, MpcContract, VersionedMpcContract}; +use near_sdk::env; use near_workspaces::AccountId; use std::collections::HashMap; @@ -37,3 +38,19 @@ async fn test_contract_can_not_be_reinitialized() -> anyhow::Result<()> { Ok(()) } + +#[test] +fn test_old_state_can_be_migrated_to_v0() -> anyhow::Result<()> { + let old_contract = MpcContract::test_init(); + env::state_write(&old_contract); + + let v0_contract = VersionedMpcContract::migrate_state_old_to_v0(); + let expected_contract = VersionedMpcContract::V0(old_contract); + + assert_eq!( + format!("{v0_contract:#?}"), + format!("{expected_contract:#?}") + ); + + Ok(()) +}