diff --git a/contract/src/lib.rs b/contract/src/lib.rs index db72b9801..6f0cbdae1 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -3,10 +3,9 @@ pub mod primitives; use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::collections::LookupMap; use near_sdk::serde::{Deserialize, Serialize}; -use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, Promise, PromiseOrValue, PublicKey}; +use near_sdk::{env, near_bindgen, AccountId, Promise, PromiseOrValue, PublicKey}; use near_sdk::{log, Gas}; -use primitives::ParticipantInfo; -use primitives::{CandidateInfo, Candidates, Participants, PkVotes, Votes}; +use primitives::{CandidateInfo, Candidates, ParticipantInfo, Participants, PkVotes, Votes}; use std::collections::{BTreeMap, HashSet}; const GAS_FOR_SIGN_CALL: Gas = Gas::from_tgas(250); @@ -49,23 +48,55 @@ pub enum ProtocolContractState { } #[near_bindgen] -#[derive(BorshDeserialize, BorshSerialize, PanicOnDefault)] +#[derive(BorshDeserialize, BorshSerialize, Debug)] +pub enum VersionedMpcContract { + V0(MpcContract), +} + +impl Default for VersionedMpcContract { + fn default() -> Self { + env::panic_str("Calling default not allowed."); + } +} + +#[derive(BorshDeserialize, BorshSerialize, Debug)] pub struct MpcContract { protocol_state: ProtocolContractState, pending_requests: LookupMap<[u8; 32], Option<(String, String)>>, request_counter: u32, } -#[near_bindgen] impl MpcContract { - #[init] + fn add_request(&mut self, payload: &[u8; 32], signature: &Option<(String, String)>) { + if self.request_counter > 8 { + env::panic_str("Too many pending requests. Please, try again later."); + } + if !self.pending_requests.contains_key(payload) { + self.request_counter += 1; + } + self.pending_requests.insert(payload, signature); + } + + fn remove_request(&mut self, payload: &[u8; 32]) { + self.pending_requests.remove(payload); + self.request_counter -= 1; + } + + fn add_signature(&mut self, payload: &[u8; 32], signature: (String, String)) { + if self.pending_requests.contains_key(payload) { + self.pending_requests.insert(payload, &Some(signature)); + } + } + + fn clean_payloads(&mut self, payloads: Vec<[u8; 32]>, counter: u32) { + log!("clean_payloads"); + for payload in payloads.iter() { + self.pending_requests.remove(payload); + } + self.request_counter = counter; + } + pub fn init(threshold: usize, candidates: BTreeMap) -> Self { - log!( - "init: signer={}, treshhold={}, candidates={}", - env::signer_account_id(), - threshold, - serde_json::to_string(&candidates).unwrap() - ); MpcContract { protocol_state: ProtocolContractState::Initializing(InitializingContractState { candidates: Candidates { candidates }, @@ -76,6 +107,20 @@ impl MpcContract { request_counter: 0, } } +} + +#[near_bindgen] +impl VersionedMpcContract { + #[init] + pub fn init(threshold: usize, candidates: BTreeMap) -> Self { + log!( + "init: signer={}, treshhold={}, candidates={}", + env::signer_account_id(), + threshold, + serde_json::to_string(&candidates).unwrap() + ); + Self::V0(MpcContract::init(threshold, candidates)) + } // This function can be used to transfer the MPC network to a new contract. #[init] @@ -93,7 +138,7 @@ impl MpcContract { threshold, public_key ); - MpcContract { + Self::V0(MpcContract { protocol_state: ProtocolContractState::Running(RunningContractState { epoch, participants: Participants { participants }, @@ -105,11 +150,27 @@ impl MpcContract { }), pending_requests: LookupMap::new(b"m"), request_counter: 0, - } + }) + } + + /// Key versions refer new versions of the root key that we may choose to generate on cohort changes + /// Older key versions will always work but newer key versions were never held by older signers + /// Newer key versions may also add new security features, like only existing within a secure enclave + /// Currently only 0 is a valid key version + pub const fn latest_key_version(&self) -> u32 { + 0 } pub fn state(self) -> ProtocolContractState { - self.protocol_state + match self { + Self::V0(mpc_contract) => mpc_contract.protocol_state, + } + } + + fn mutable_state(&mut self) -> &mut ProtocolContractState { + match self { + Self::V0(ref mut mpc_contract) => &mut mpc_contract.protocol_state, + } } pub fn join( @@ -125,10 +186,11 @@ impl MpcContract { cipher_pk, sign_pk ); - match &mut self.protocol_state { + let protocol_state = self.mutable_state(); + match protocol_state { ProtocolContractState::Running(RunningContractState { participants, - candidates, + ref mut candidates, .. }) => { let signer_account_id = env::signer_account_id(); @@ -155,7 +217,8 @@ impl MpcContract { env::signer_account_id(), candidate_account_id ); - match &mut self.protocol_state { + let protocol_state = self.mutable_state(); + match protocol_state { ProtocolContractState::Running(RunningContractState { epoch, participants, @@ -178,15 +241,14 @@ impl MpcContract { let mut new_participants = participants.clone(); new_participants .insert(candidate_account_id.clone(), candidate_info.clone().into()); - self.protocol_state = - ProtocolContractState::Resharing(ResharingContractState { - old_epoch: *epoch, - old_participants: participants.clone(), - new_participants, - threshold: *threshold, - public_key: public_key.clone(), - finished_votes: HashSet::new(), - }); + *protocol_state = ProtocolContractState::Resharing(ResharingContractState { + old_epoch: *epoch, + old_participants: participants.clone(), + new_participants, + threshold: *threshold, + public_key: public_key.clone(), + finished_votes: HashSet::new(), + }); true } else { false @@ -202,7 +264,8 @@ impl MpcContract { env::signer_account_id(), kick ); - match &mut self.protocol_state { + let protocol_state = self.mutable_state(); + match protocol_state { ProtocolContractState::Running(RunningContractState { epoch, participants, @@ -226,15 +289,14 @@ impl MpcContract { if voted.len() >= *threshold { let mut new_participants = participants.clone(); new_participants.remove(&kick); - self.protocol_state = - ProtocolContractState::Resharing(ResharingContractState { - old_epoch: *epoch, - old_participants: participants.clone(), - new_participants, - threshold: *threshold, - public_key: public_key.clone(), - finished_votes: HashSet::new(), - }); + *protocol_state = ProtocolContractState::Resharing(ResharingContractState { + old_epoch: *epoch, + old_participants: participants.clone(), + new_participants, + threshold: *threshold, + public_key: public_key.clone(), + finished_votes: HashSet::new(), + }); true } else { false @@ -250,7 +312,8 @@ impl MpcContract { env::signer_account_id(), public_key ); - match &mut self.protocol_state { + let protocol_state = self.mutable_state(); + match protocol_state { ProtocolContractState::Initializing(InitializingContractState { candidates, threshold, @@ -263,7 +326,7 @@ impl MpcContract { let voted = pk_votes.entry(public_key.clone()); voted.insert(signer_account_id); if voted.len() >= *threshold { - self.protocol_state = ProtocolContractState::Running(RunningContractState { + *protocol_state = ProtocolContractState::Running(RunningContractState { epoch: 0, participants: candidates.clone().into(), threshold: *threshold, @@ -289,7 +352,8 @@ impl MpcContract { env::signer_account_id(), epoch ); - match &mut self.protocol_state { + let protocol_state = self.mutable_state(); + match protocol_state { ProtocolContractState::Resharing(ResharingContractState { old_epoch, old_participants, @@ -307,7 +371,7 @@ impl MpcContract { } finished_votes.insert(signer_account_id); if finished_votes.len() >= *threshold { - self.protocol_state = ProtocolContractState::Running(RunningContractState { + *protocol_state = ProtocolContractState::Running(RunningContractState { epoch: *old_epoch + 1, participants: new_participants.clone(), threshold: *threshold, @@ -332,37 +396,33 @@ impl MpcContract { } } - #[allow(unused_variables)] - /// `key_version` must be less than or equal to the value at `latest_key_version` - pub fn sign(&mut self, payload: [u8; 32], path: String, key_version: u32) -> Promise { - let latest_key_version: u32 = self.latest_key_version(); - assert!( - key_version <= latest_key_version, - "This version of the signer contract doesn't support versions greater than {}", - latest_key_version, - ); - // Make sure sign call will not run out of gas doing recursive calls because the payload will never be removed - assert!( - env::prepaid_gas() >= GAS_FOR_SIGN_CALL, - "Insufficient gas provided. Provided: {} Required: {}", - env::prepaid_gas(), - GAS_FOR_SIGN_CALL - ); - log!( - "sign: signer={}, predecessor={}, payload={:?}, path={:?}, key_version={}", - env::signer_account_id(), - env::predecessor_account_id(), - payload, - path, - key_version - ); - match self.pending_requests.get(&payload) { - None => { - self.add_request(&payload, None); - log!(&serde_json::to_string(&near_sdk::env::random_seed_array()).unwrap()); - Self::ext(env::current_account_id()).sign_helper(payload, 0) + fn signature_per_payload(&self, payload: [u8; 32]) -> Option> { + match self { + Self::V0(mpc_contract) => mpc_contract.pending_requests.get(&payload), + } + } + + fn remove_request(&mut self, payload: &[u8; 32]) { + match self { + Self::V0(mpc_contract) => { + mpc_contract.remove_request(payload); + } + } + } + + fn add_request(&mut self, payload: &[u8; 32], signature: &Option<(String, String)>) { + match self { + Self::V0(mpc_contract) => { + mpc_contract.add_request(payload, signature); + } + } + } + + fn add_signature(&mut self, payload: &[u8; 32], signature: (String, String)) { + match self { + Self::V0(mpc_contract) => { + mpc_contract.add_signature(payload, signature); } - Some(_) => env::panic_str("Signature for this payload already requested"), } } @@ -372,7 +432,7 @@ impl MpcContract { payload: [u8; 32], depth: usize, ) -> PromiseOrValue<(String, String)> { - if let Some(signature) = self.pending_requests.get(&payload) { + if let Some(signature) = self.signature_per_payload(payload) { match signature { Some(signature) => { log!( @@ -417,8 +477,46 @@ impl MpcContract { env::panic_str(&message); } + #[allow(unused_variables)] + /// `key_version` must be less than or equal to the value at `latest_key_version` + pub fn sign(&mut self, payload: [u8; 32], path: String, key_version: u32) -> Promise { + let latest_key_version: u32 = self.latest_key_version(); + assert!( + key_version <= latest_key_version, + "This version of the signer contract doesn't support versions greater than {}", + latest_key_version, + ); + // Make sure sign call will not run out of gas doing recursive calls because the payload will never be removed + assert!( + env::prepaid_gas() >= GAS_FOR_SIGN_CALL, + "Insufficient gas provided. Provided: {} Required: {}", + env::prepaid_gas(), + GAS_FOR_SIGN_CALL + ); + log!( + "sign: signer={}, payload={:?}, path={:?}, key_version={}", + env::signer_account_id(), + payload, + path, + key_version + ); + match self.signature_per_payload(payload) { + None => { + self.add_request(&payload, &None); + log!(&serde_json::to_string(&near_sdk::env::random_seed_array()).unwrap()); + Self::ext(env::current_account_id()).sign_helper(payload, 0) + } + Some(_) => env::panic_str("Signature for this payload already requested"), + } + } + + pub fn version(&self) -> String { + env!("CARGO_PKG_VERSION").to_string() + } + pub fn respond(&mut self, payload: [u8; 32], big_r: String, s: String) { - if let ProtocolContractState::Running(state) = &self.protocol_state { + let protocol_state = self.mutable_state(); + if let ProtocolContractState::Running(state) = protocol_state { let signer = env::signer_account_id(); if state.participants.contains_key(&signer) { log!( @@ -437,49 +535,6 @@ impl MpcContract { } } - /// This is the root public key combined from all the public keys of the participants. - pub fn public_key(&self) -> PublicKey { - match &self.protocol_state { - ProtocolContractState::Running(state) => state.public_key.clone(), - ProtocolContractState::Resharing(state) => state.public_key.clone(), - _ => env::panic_str("public key not available (protocol is not running or resharing)"), - } - } - - pub fn version(&self) -> String { - env!("CARGO_PKG_VERSION").to_string() - } - - /// Key versions refer new versions of the root key that we may choose to generate on cohort changes - /// Older key versions will always work but newer key versions were never held by older signers - /// Newer key versions may also add new security features, like only existing within a secure enclave - /// Currently only 0 is a valid key version - pub const fn latest_key_version(&self) -> u32 { - 0 - } - - fn add_signature(&mut self, payload: &[u8; 32], signature: (String, String)) { - if self.pending_requests.contains_key(payload) { - self.pending_requests.insert(payload, &Some(signature)); - } - } - - fn add_request(&mut self, payload: &[u8; 32], signature: Option<(String, String)>) { - if self.request_counter > 8 { - env::panic_str("Too many pending requests. Please, try again later."); - } - if !self.pending_requests.contains_key(payload) { - self.request_counter += 1; - } - self.pending_requests.insert(payload, &signature); - } - - fn remove_request(&mut self, payload: &[u8; 32]) { - self.pending_requests.remove(payload); - self.request_counter -= 1; - } - - // Helper functions #[private] #[init(ignore_state)] pub fn clean(keys: Vec) -> Self { @@ -487,19 +542,39 @@ impl MpcContract { for key in keys.iter() { env::storage_remove(&key.0); } - Self { + Self::V0(MpcContract { protocol_state: ProtocolContractState::NotInitialized, pending_requests: LookupMap::new(b"m"), request_counter: 0, - } + }) } #[private] pub fn clean_payloads(&mut self, payloads: Vec<[u8; 32]>, counter: u32) { - log!("clean_payloads"); - for payload in payloads.iter() { - self.pending_requests.remove(payload); + match self { + Self::V0(mpc_contract) => { + mpc_contract.clean_payloads(payloads, counter); + } } - self.request_counter = counter; + } + + /// This is the root public key combined from all the public keys of the participants. + pub fn public_key(self) -> PublicKey { + match self.state() { + ProtocolContractState::Running(state) => state.public_key.clone(), + ProtocolContractState::Resharing(state) => state.public_key.clone(), + _ => env::panic_str("public key not available (protocol is not running or resharing)"), + } + } + + #[private] + #[init(ignore_state)] + pub fn migrate_state_old_to_v0() -> Self { + let old_contract: MpcContract = env::state_read().expect("Old state doesn't exist"); + Self::V0(MpcContract { + protocol_state: old_contract.protocol_state, + pending_requests: old_contract.pending_requests, + request_counter: old_contract.request_counter, + }) } } diff --git a/contract/tests/tests.rs b/contract/tests/tests.rs index c26c8b6ba..0f254d05f 100644 --- a/contract/tests/tests.rs +++ b/contract/tests/tests.rs @@ -1,6 +1,7 @@ -use mpc_contract::primitives::CandidateInfo; +use mpc_contract::{primitives::CandidateInfo, MpcContract, VersionedMpcContract}; +use near_sdk::env; use near_workspaces::AccountId; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; const CONTRACT_FILE_PATH: &str = "./../target/seperate_wasm/wasm32-unknown-unknown/release/mpc_contract.wasm"; @@ -37,3 +38,19 @@ async fn test_contract_can_not_be_reinitialized() -> anyhow::Result<()> { Ok(()) } + +#[test] +fn test_old_state_can_be_migrated_to_v0() -> anyhow::Result<()> { + let old_contract = MpcContract::init(3, BTreeMap::new()); + env::state_write(&old_contract); + + let v0_contract = VersionedMpcContract::migrate_state_old_to_v0(); + let expected_contract = VersionedMpcContract::V0(old_contract); + + assert_eq!( + format!("{v0_contract:#?}"), + format!("{expected_contract:#?}") + ); + + Ok(()) +}