diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index e6a3c043e..9f1f82c09 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -30,6 +30,7 @@ use xmtp_mls::groups::intents::PermissionUpdateType; use xmtp_mls::groups::GroupMetadataOptions; use xmtp_mls::storage::consent_record::ConsentState; use xmtp_mls::storage::consent_record::ConsentType; +use xmtp_mls::storage::consent_record::StoredConsentRecord; use xmtp_mls::{ api::ApiClientWrapper, builder::ClientBuilder, @@ -331,16 +332,12 @@ impl FfiXmtpClient { Ok(state.into()) } - pub async fn set_consent_state( - &self, - state: FfiConsentState, - entity_type: FfiConsentEntityType, - entity: String, - ) -> Result<(), GenericError> { + pub async fn set_consent_states(&self, records: Vec) -> Result<(), GenericError> { let inner = self.inner_client.as_ref(); - inner - .set_consent_state(state.into(), entity_type.into(), entity) - .await?; + let stored_records: Vec = + records.into_iter().map(StoredConsentRecord::from).collect(); + + inner.set_consent_states(stored_records).await?; Ok(()) } @@ -1527,6 +1524,23 @@ impl From for FfiMessage { } } +#[derive(uniffi::Record)] +pub struct FfiConsent { + pub entity_type: FfiConsentEntityType, + pub state: FfiConsentState, + pub entity: String, +} + +impl From for StoredConsentRecord { + fn from(consent: FfiConsent) -> Self { + Self { + entity_type: consent.entity_type.into(), + state: consent.state.into(), + entity: consent.entity, + } + } +} + #[derive(uniffi::Object, Clone, Debug)] pub struct FfiStreamCloser { #[allow(clippy::type_complexity)] @@ -1666,7 +1680,7 @@ impl FfiGroupPermissions { mod tests { use super::{create_client, signature_verifier, FfiMessage, FfiMessageCallback, FfiXmtpClient}; use crate::{ - get_inbox_id_for_address, inbox_owner::SigningError, logger::FfiLogger, + get_inbox_id_for_address, inbox_owner::SigningError, logger::FfiLogger, FfiConsent, FfiConsentEntityType, FfiConsentState, FfiConversationCallback, FfiCreateGroupOptions, FfiGroup, FfiGroupMessageKind, FfiGroupPermissionsOptions, FfiInboxOwner, FfiListConversationsOptions, FfiListMessagesOptions, FfiMetadataField, FfiPermissionPolicy, @@ -3696,12 +3710,11 @@ mod tests { .unwrap(); let alix_updated_consent = alix_group.consent_state().unwrap(); assert_eq!(alix_updated_consent, FfiConsentState::Denied); - - bo.set_consent_state( - FfiConsentState::Allowed, - FfiConsentEntityType::GroupId, - hex::encode(bo_group.id()), - ) + bo.set_consent_states(vec![FfiConsent { + state: FfiConsentState::Allowed, + entity_type: FfiConsentEntityType::GroupId, + entity: hex::encode(bo_group.id()), + }]) .await .unwrap(); let bo_updated_consent = bo_group.consent_state().unwrap(); @@ -3721,11 +3734,11 @@ mod tests { ) .await .unwrap(); - alix.set_consent_state( - FfiConsentState::Allowed, - FfiConsentEntityType::Address, - bo.account_address.clone(), - ) + alix.set_consent_states(vec![FfiConsent { + state: FfiConsentState::Allowed, + entity_type: FfiConsentEntityType::Address, + entity: bo.account_address.clone(), + }]) .await .unwrap(); let bo_consent = alix diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index b3a5dede3..a71a3dba2 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -313,17 +313,33 @@ where &self, address: String, ) -> Result, ClientError> { - if let Some(sanitized_address) = sanitize_evm_addresses(vec![address])?.pop() { - let mut results = self - .api_client - .get_inbox_ids(vec![sanitized_address.clone()]) - .await?; - Ok(results.remove(&sanitized_address)) + let results = self + .find_inbox_ids_from_addresses(vec![address.clone()]) + .await?; + if let Some(first_result) = results.into_iter().next() { + Ok(first_result) } else { Ok(None) } } + pub async fn find_inbox_ids_from_addresses( + &self, + addresses: Vec, + ) -> Result>, ClientError> { + let sanitized_addresses = sanitize_evm_addresses(addresses.clone())?; + let mut results = self + .api_client + .get_inbox_ids(sanitized_addresses.clone()) + .await?; + let inbox_ids: Vec> = sanitized_addresses + .into_iter() + .map(|address| results.remove(&address)) + .collect(); + + Ok(inbox_ids) + } + /// Get sequence id, may not be consistent with the backend pub fn inbox_sequence_id(&self, conn: &DbConnection) -> Result { self.context.inbox_sequence_id(conn) @@ -344,28 +360,40 @@ where // set the consent record in the database // if the consent record is an address also set the inboxId - pub async fn set_consent_state( + pub async fn set_consent_states( &self, - state: ConsentState, - entity_type: ConsentType, - entity: String, + mut records: Vec, ) -> Result<(), ClientError> { let conn = self.store().conn()?; - conn.insert_or_replace_consent_record(StoredConsentRecord::new( - entity_type, - state, - entity.clone(), - ))?; - if entity_type == ConsentType::Address { - if let Some(inbox_id) = self.find_inbox_id_from_address(entity.clone()).await? { - conn.insert_or_replace_consent_record(StoredConsentRecord::new( + let mut new_records = Vec::new(); + let mut addresses_to_lookup = Vec::new(); + let mut record_indices = Vec::new(); + + for (index, record) in records.iter().enumerate() { + if record.entity_type == ConsentType::Address { + addresses_to_lookup.push(record.entity.clone()); + record_indices.push(index); + } + } + + let inbox_ids = self + .find_inbox_ids_from_addresses(addresses_to_lookup) + .await?; + + for (i, inbox_id_opt) in inbox_ids.into_iter().enumerate() { + if let Some(inbox_id) = inbox_id_opt { + let record = &records[record_indices[i]]; + new_records.push(StoredConsentRecord::new( ConsentType::InboxId, - state, + record.state, inbox_id, - ))?; + )); } - }; + } + + records.extend(new_records); + conn.insert_or_replace_consent_records(records)?; Ok(()) } @@ -819,7 +847,7 @@ mod tests { builder::ClientBuilder, groups::GroupMetadataOptions, hpke::{decrypt_welcome, encrypt_welcome}, - storage::consent_record::{ConsentState, ConsentType}, + storage::consent_record::{ConsentState, ConsentType, StoredConsentRecord}, }; #[tokio::test] @@ -1103,14 +1131,12 @@ mod tests { let bo_wallet = generate_local_wallet(); let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bo = ClientBuilder::new_test_client(&bo_wallet).await; - - alix.set_consent_state( - ConsentState::Denied, + let record = StoredConsentRecord::new( ConsentType::Address, + ConsentState::Denied, bo_wallet.get_address(), - ) - .await - .unwrap(); + ); + alix.set_consent_states(vec![record]).await.unwrap(); let inbox_consent = alix .get_consent_state(ConsentType::InboxId, bo.inbox_id()) .await diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 44dde5ecb..e2db02144 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -961,11 +961,11 @@ impl MlsGroup { pub fn update_consent_state(&self, state: ConsentState) -> Result<(), GroupError> { let conn = self.context.store.conn()?; - conn.insert_or_replace_consent_record(StoredConsentRecord::new( + conn.insert_or_replace_consent_records(vec![StoredConsentRecord::new( ConsentType::GroupId, state, hex::encode(self.group_id.clone()), - ))?; + )])?; Ok(()) } diff --git a/xmtp_mls/src/storage/encrypted_store/consent_record.rs b/xmtp_mls/src/storage/encrypted_store/consent_record.rs index dceb8f8e8..1a6cd5ef4 100644 --- a/xmtp_mls/src/storage/encrypted_store/consent_record.rs +++ b/xmtp_mls/src/storage/encrypted_store/consent_record.rs @@ -57,19 +57,23 @@ impl DbConnection { })?) } - /// Insert consent_record, and replace existing entries - pub fn insert_or_replace_consent_record( + /// Insert consent_records, and replace existing entries + pub fn insert_or_replace_consent_records( &self, - record: StoredConsentRecord, + records: Vec, ) -> Result<(), StorageError> { self.raw_query(|conn| { - diesel::insert_into(dsl::consent_records) - .values(&record) - .on_conflict((dsl::entity_type, dsl::entity)) - .do_update() - .set(dsl::state.eq(excluded(dsl::state))) - .execute(conn)?; - Ok(()) + conn.transaction::<_, diesel::result::Error, _>(|conn| { + for record in records.iter() { + diesel::insert_into(dsl::consent_records) + .values(record) + .on_conflict((dsl::entity_type, dsl::entity)) + .do_update() + .set(dsl::state.eq(excluded(dsl::state))) + .execute(conn)?; + } + Ok(()) + }) })?; Ok(()) @@ -179,7 +183,7 @@ mod tests { ); let consent_record_entity = consent_record.entity.clone(); - conn.insert_or_replace_consent_record(consent_record) + conn.insert_or_replace_consent_records(vec![consent_record]) .expect("should store without error"); let consent_record = conn