diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 7c3fdc14d..d3ca6f2ad 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -16,7 +16,7 @@ use openmls::{ framing::{MlsMessageBodyIn, MlsMessageIn}, group::GroupEpoch, messages::Welcome, - prelude::tls_codec::{Deserialize, Error as TlsCodecError, Serialize}, + prelude::tls_codec::{Deserialize, Error as TlsCodecError}, }; use openmls_traits::OpenMlsProvider; use prost::EncodeError; @@ -586,11 +586,13 @@ where /// Upload a new key package to the network replacing an existing key package /// This is expected to be run any time the client receives new Welcome messages pub async fn rotate_key_package(&self) -> Result<(), ClientError> { - let provider: XmtpOpenMlsProvider = self.store().conn()?.into(); - - let kp = self.identity().new_key_package(&provider)?; - let kp_bytes = kp.tls_serialize_detached()?; - self.api_client.upload_key_package(kp_bytes, true).await?; + self.store() + .transaction_async(|provider| async move { + self.identity() + .rotate_key_package(&provider, &self.api_client) + .await + }) + .await?; Ok(()) } @@ -668,6 +670,7 @@ where /// Returns any new groups created in the operation pub async fn sync_welcomes(&self) -> Result, ClientError> { let envelopes = self.query_welcome_messages(&self.store().conn()?).await?; + let num_envelopes = envelopes.len(); let id = self.installation_public_key(); let groups: Vec = stream::iter(envelopes.into_iter()) @@ -717,6 +720,11 @@ where .collect() .await; + // If any welcomes were found, rotate your key package + if num_envelopes > 0 { + self.rotate_key_package().await?; + } + Ok(groups) } @@ -848,12 +856,16 @@ mod tests { builder::ClientBuilder, groups::GroupMetadataOptions, hpke::{decrypt_welcome, encrypt_welcome}, + identity::serialize_key_package_hash_ref, storage::{ consent_record::{ConsentState, ConsentType, StoredConsentRecord}, schema::identity_updates, }, + XmtpApi, }; + use super::Client; + #[tokio::test] async fn test_group_member_recovery() { let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -1179,4 +1191,82 @@ mod tests { assert_eq!(inbox_consent, ConsentState::Denied); assert_eq!(address_consent, ConsentState::Denied); } + + async fn get_key_package_init_key( + client: &Client, + installation_id: &[u8], + ) -> Vec { + let kps = client + .get_key_packages_for_installation_ids(vec![installation_id.to_vec()]) + .await + .unwrap(); + let kp = kps.first().unwrap(); + + serialize_key_package_hash_ref(&kp.inner, &client.mls_provider().unwrap()).unwrap() + } + + #[tokio::test] + async fn test_key_package_rotation() { + let alix_wallet = generate_local_wallet(); + let bo_wallet = generate_local_wallet(); + let alix = ClientBuilder::new_test_client(&alix_wallet).await; + let bo = ClientBuilder::new_test_client(&bo_wallet).await; + let bo_store = bo.store(); + + let alix_original_init_key = + get_key_package_init_key(&alix, &alix.installation_public_key()).await; + let bo_original_init_key = + get_key_package_init_key(&bo, &bo.installation_public_key()).await; + + // Bo's original key should be deleted + let bo_original_from_db = bo_store + .conn() + .unwrap() + .find_key_package_history_entry_by_hash_ref(bo_original_init_key.clone()); + assert!(bo_original_from_db.is_ok()); + + alix.create_group_with_members( + vec![bo_wallet.get_address()], + None, + GroupMetadataOptions::default(), + ) + .await + .unwrap(); + + bo.sync_welcomes().await.unwrap(); + + let bo_new_key = get_key_package_init_key(&bo, &bo.installation_public_key()).await; + // Bo's key should have changed + assert_ne!(bo_original_init_key, bo_new_key); + + bo.sync_welcomes().await.unwrap(); + let bo_new_key_2 = get_key_package_init_key(&bo, &bo.installation_public_key()).await; + // Bo's key should not have changed syncing the second time. + assert_eq!(bo_new_key, bo_new_key_2); + + alix.sync_welcomes().await.unwrap(); + let alix_key_2 = get_key_package_init_key(&alix, &alix.installation_public_key()).await; + // Alix's key should not have changed at all + assert_eq!(alix_original_init_key, alix_key_2); + + alix.create_group_with_members( + vec![bo_wallet.get_address()], + None, + GroupMetadataOptions::default(), + ) + .await + .unwrap(); + bo.sync_welcomes().await.unwrap(); + + // Bo should have two groups now + let bo_groups = bo.find_groups(None, None, None, None).unwrap(); + assert_eq!(bo_groups.len(), 2); + + // Bo's original key should be deleted + let bo_original_after_delete = bo_store + .conn() + .unwrap() + .find_key_package_history_entry_by_hash_ref(bo_original_init_key); + assert!(bo_original_after_delete.is_err()); + } } diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index 0ca912e9a..bee387072 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -18,6 +18,7 @@ use crate::{retryable, Fetch, Store}; use ed25519_dalek::SigningKey; use log::debug; use log::info; +use openmls::prelude::hash_ref::HashReference; use openmls::prelude::tls_codec::Serialize; use openmls::{ credentials::{errors::BasicCredentialError, BasicCredential, CredentialWithKey}, @@ -29,6 +30,7 @@ use openmls::{ prelude_test::KeyPackage, }; use openmls_basic_credential::SignatureKeyPair; +use openmls_traits::storage::StorageProvider; use openmls_traits::types::CryptoError; use openmls_traits::OpenMlsProvider; use prost::Message; @@ -162,6 +164,8 @@ pub enum IdentityError { RequiredIdentityNotFound, #[error("error creating new identity: {0}")] NewIdentity(String), + #[error(transparent)] + DieselResult(#[from] diesel::result::Error), } impl RetryableError for IdentityError { @@ -171,6 +175,7 @@ impl RetryableError for IdentityError { Self::WrappedApi(err) => retryable!(err), Self::StorageError(err) => retryable!(err), Self::OpenMlsStorageError(err) => retryable!(err), + Self::DieselResult(err) => retryable!(err), _ => false, } } @@ -424,16 +429,7 @@ impl Identity { // This is needed to get to the private key when decrypting welcome messages. let public_init_key = kp.key_package().hpke_init_key().tls_serialize_detached()?; - let key_package_hash_ref = match kp.key_package().hash_ref(provider.crypto()) { - Ok(key_package_hash_ref) => key_package_hash_ref, - Err(_) => return Err(IdentityError::UninitializedIdentity), - }; - - // Serialize the hash reference (with bincode) - let hash_ref = match bincode::serialize(&key_package_hash_ref) { - Ok(hash_ref) => hash_ref, - Err(_) => return Err(IdentityError::UninitializedIdentity), - }; + let hash_ref = serialize_key_package_hash_ref(kp.key_package(), &provider)?; // Store the hash reference, keyed with the public init key provider .storage() @@ -455,15 +451,70 @@ impl Identity { info!("Identity already registered. skipping key package publishing"); return Ok(()); } + + self.rotate_key_package(provider, api_client).await?; + self.is_ready.store(true, Ordering::SeqCst); + + Ok(StoredIdentity::try_from(self)?.store(provider.conn_ref())?) + } + + pub(crate) async fn rotate_key_package( + &self, + provider: &XmtpOpenMlsProvider, + api_client: &ApiClientWrapper, + ) -> Result<(), IdentityError> { let kp = self.new_key_package(provider)?; let kp_bytes = kp.tls_serialize_detached()?; + let conn = provider.conn_ref(); + let hash_ref = serialize_key_package_hash_ref(&kp, provider)?; + let history_id = conn.store_key_package_history_entry(hash_ref)?.id; + let old_id = history_id - 1; + + // Find all key packages that are not the current or previous KPs + // We can delete before uploading because this is either run inside a transaction or is being applied to a brand + // new identity + let old_key_packages = conn.find_key_package_history_entries_before_id(old_id)?; + for kp in old_key_packages { + self.delete_key_package(provider, kp.key_package_hash_ref)?; + } + conn.delete_key_package_history_entries_before_id(old_id)?; + api_client.upload_key_package(kp_bytes, true).await?; - self.is_ready.store(true, Ordering::SeqCst); + Ok(()) + } - Ok(StoredIdentity::try_from(self)?.store(provider.conn_ref())?) + pub(crate) fn delete_key_package( + &self, + provider: &XmtpOpenMlsProvider, + hash_ref: Vec, + ) -> Result<(), IdentityError> { + let openmls_hash_ref = deserialize_key_package_hash_ref(&hash_ref)?; + provider.storage().delete_key_package(&openmls_hash_ref)?; + + Ok(()) } } +pub(crate) fn serialize_key_package_hash_ref( + kp: &KeyPackage, + provider: &impl OpenMlsProvider, +) -> Result, IdentityError> { + let key_package_hash_ref = kp + .hash_ref(provider.crypto()) + .map_err(|_| IdentityError::UninitializedIdentity)?; + let serialized = bincode::serialize(&key_package_hash_ref) + .map_err(|_| IdentityError::UninitializedIdentity)?; + + Ok(serialized) +} + +fn deserialize_key_package_hash_ref(hash_ref: &[u8]) -> Result { + let key_package_hash_ref: HashReference = + bincode::deserialize(hash_ref).map_err(|_| IdentityError::UninitializedIdentity)?; + + Ok(key_package_hash_ref) +} + async fn sign_with_installation_key( signature_text: String, installation_private_key: &[u8; 32], diff --git a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs index 6fa4ebecb..1f759a8bc 100644 --- a/xmtp_mls/src/storage/encrypted_store/key_package_history.rs +++ b/xmtp_mls/src/storage/encrypted_store/key_package_history.rs @@ -59,6 +59,21 @@ impl DbConnection { Ok(result) } + + pub fn delete_key_package_history_entries_before_id( + &self, + id: i32, + ) -> Result<(), StorageError> { + self.raw_query(|conn| { + diesel::delete( + key_package_history::dsl::key_package_history + .filter(key_package_history::dsl::id.lt(id)), + ) + .execute(conn) + })?; + + Ok(()) + } } #[cfg(test)]