From d9a986e5f4df278dd87cf08cf6b77ab725a70455 Mon Sep 17 00:00:00 2001 From: oetyng Date: Tue, 2 Mar 2021 02:31:00 +0100 Subject: [PATCH] fix: issues pointed out in review comments --- examples/stress.rs | 4 +- src/messages/mod.rs | 18 +++---- src/routing/approved.rs | 106 +++++++++++++++++++-------------------- src/routing/command.rs | 9 ++-- src/routing/mod.rs | 6 +-- src/routing/stage.rs | 8 +-- src/routing/tests/mod.rs | 2 +- tests/messages.rs | 10 ++-- 8 files changed, 84 insertions(+), 79 deletions(-) diff --git a/examples/stress.rs b/examples/stress.rs index edf116c0cc..17d8ecf344 100644 --- a/examples/stress.rs +++ b/examples/stress.rs @@ -457,13 +457,13 @@ impl Network { }, }; let bytes = bincode::serialize(&message)?.into(); - let itry = Itinerary { + let itinerary = Itinerary { src: SrcLocation::Node(src), dst: DstLocation::Section(dst), aggregation: Aggregation::None, }; - match node.send_message(itry, bytes).await { + match node.send_message(itinerary, bytes).await { Ok(()) => Ok(true), Err(RoutingError::InvalidSrcLocation) => Ok(false), // node name changed Err(error) => Err(error.into()), diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 6a1a1978b4..8080c27f2e 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -40,7 +40,7 @@ pub(crate) struct Message { src: SrcAuthority, /// Destination location. dst: DstLocation, - /// + /// The aggregation scheme to be used. aggregation: Aggregation, /// The body of the message. variant: Variant, @@ -311,15 +311,13 @@ impl Message { } /// Elders will aggregate a group sig before - /// they all all send one copy of it each to dst. - pub fn aggregate_at_src(&self) -> bool { - matches!(self.src, SrcAuthority::Section { .. }) - } - - /// Elders will send their signed message, which - /// recipients aggregate. - pub fn aggregate_at_dst(&self) -> bool { - matches!(self.src, SrcAuthority::BlsShare { .. }) + /// they all send one copy of it each to dst. + pub fn aggregation(&self) -> Aggregation { + match self.src { + SrcAuthority::Section { .. } => Aggregation::AtSource, + SrcAuthority::BlsShare { .. } => Aggregation::AtDestination, + SrcAuthority::Node { .. } => Aggregation::None, + } } /// Returns the attached proof chain, if any. diff --git a/src/routing/approved.rs b/src/routing/approved.rs index 11688eb770..8302342146 100644 --- a/src/routing/approved.rs +++ b/src/routing/approved.rs @@ -49,7 +49,7 @@ use sn_messaging::{ section_info::{ Error as TargetSectionError, GetSectionResponse, Message as SectionInfoMsg, SectionInfo, }, - DstLocation, EndUser, Itinerary, MessageType, SrcLocation, + Aggregation, DstLocation, EndUser, Itinerary, MessageType, SrcLocation, }; use std::{cmp, net::SocketAddr, slice}; use tokio::sync::mpsc; @@ -820,7 +820,7 @@ impl Approved { }; let bounce_dst_key = *self.section_key_by_name(&src_name); - let bounce_dst = if msg.aggregate_at_src() { + let bounce_dst = if matches!(msg.aggregation(), Aggregation::AtSource) { DstLocation::Section(src_name) } else { DstLocation::Node(src_name) @@ -1002,46 +1002,39 @@ impl Approved { message: MessageType::ClientMessage(ClientMessage::from(content)?), }]); } - if msg.aggregate_at_dst() { - if !matches!(dst, DstLocation::Node(_)) { - return Err(Error::InvalidDstLocation); - } - if let SrcAuthority::BlsShare { - proof_share, - src_section, - .. - } = &src + if let SrcAuthority::BlsShare { + proof_share, + src_section, + .. + } = &src + { + let signed_bytes = bincode::serialize(&msg.signable_view())?; + match self + .message_accumulator + .add(&signed_bytes, proof_share.clone()) { - let signed_bytes = bincode::serialize(&msg.signable_view())?; - match self - .message_accumulator - .add(&signed_bytes, proof_share.clone()) - { - Ok(proof) => { - trace!("Successfully aggregated signatures for message: {:?}", msg); - let key = msg.proof_chain_last_key()?; - if key.verify(&proof.signature, signed_bytes) { - self.send_event(Event::MessageReceived { - content, - src: SrcLocation::Section(*src_section), - dst, - }); - } else { - trace!( - "Aggregated signature is invalid. Handling message {:?} skipped", - msg - ); - } - } - Err(AggregatorError::NotEnoughShares) => {} - Err(err) => { - trace!("Error accumulating message at destination: {:?}", err); + Ok(proof) => { + trace!("Successfully aggregated signatures for message: {:?}", msg); + let key = msg.proof_chain_last_key()?; + if key.verify(&proof.signature, signed_bytes) { + self.send_event(Event::MessageReceived { + content, + src: SrcLocation::Section(*src_section), + dst, + }); + } else { + trace!( + "Aggregated signature is invalid. Handling message {:?} skipped", + msg + ); } } - return Ok(vec![]); - } else { - return Err(Error::InvalidSrcLocation); + Err(AggregatorError::NotEnoughShares) => {} + Err(err) => { + trace!("Error accumulating message at destination: {:?}", err); + } } + return Ok(vec![]); } self.send_event(Event::MessageReceived { @@ -2036,53 +2029,60 @@ impl Approved { Ok(commands) } - pub fn send_user_message(&mut self, itry: Itinerary, content: Bytes) -> Result> { - let are_we_src = - matches!(itry.src, SrcLocation::Node(_)) && itry.src.name() == self.node.name(); + pub fn send_user_message( + &mut self, + itinerary: Itinerary, + content: Bytes, + ) -> Result> { + let are_we_src = itinerary.src.equals(&self.node.name()) + || itinerary.src.equals(&self.section().prefix().name()); if !are_we_src { error!( "Not sending user message {:?} -> {:?}: we are not the source location", - itry.src, itry.dst + itinerary.src, itinerary.dst ); return Err(Error::InvalidSrcLocation); } - if matches!(itry.src, SrcLocation::EndUser(_)) { + if matches!(itinerary.src, SrcLocation::EndUser(_)) { return Err(Error::InvalidSrcLocation); } - if matches!(itry.dst, DstLocation::Direct) { + if matches!(itinerary.dst, DstLocation::Direct) { error!( "Not sending user message {:?} -> {:?}: direct dst not supported", - itry.src, itry.dst + itinerary.src, itinerary.dst ); return Err(Error::InvalidDstLocation); } - // If the source is a single node, we don't even need to vote, so let's cut this short. - let msg = if itry.aggregate_at_dst() { + // If the msg is to be aggregated at dst, we don't vote among our peers, wemsimply send the msg as our vote to the dst. + let msg = if itinerary.aggregate_at_dst() { Message::for_dst_accumulation( &self.node, self.section_keys_provider.key_share()?, - itry.dst, + itinerary.dst, content, self.section().create_proof_chain_for_our_info(None), None, self.section().prefix().name(), )? - } else if itry.aggregate_at_src() { + } else if itinerary.aggregate_at_src() { let variant = Variant::UserMessage(content); - let vote = self.create_send_message_vote(itry.dst, variant, None)?; + let vote = self.create_send_message_vote(itinerary.dst, variant, None)?; let recipients = delivery_group::signature_targets( - &itry.dst, + &itinerary.dst, self.section.elders_info().peers().copied(), ); return self.send_vote(&recipients, vote); } else { let variant = Variant::UserMessage(content); - Message::single_src(&self.node, itry.dst, variant, None, None)? + Message::single_src(&self.node, itinerary.dst, variant, None, None)? }; let mut commands = vec![]; - if itry.dst.contains(&self.node.name(), self.section.prefix()) { + if itinerary + .dst + .contains(&self.node.name(), self.section.prefix()) + { commands.push(Command::HandleMessage { sender: Some(self.node.addr), message: msg.clone(), diff --git a/src/routing/command.rs b/src/routing/command.rs index 4d9338eedb..8c81baf8ad 100644 --- a/src/routing/command.rs +++ b/src/routing/command.rs @@ -70,7 +70,10 @@ pub(crate) enum Command { message: MessageType, }, /// Send `UserMessage` with the given source and destination. - SendUserMessage { itry: Itinerary, content: Bytes }, + SendUserMessage { + itinerary: Itinerary, + content: Bytes, + }, /// Schedule a timeout after the given duration. When the timeout expires, a `HandleTimeout` /// command is raised. The token is used to identify the timeout. ScheduleTimeout { duration: Duration, token: u64 }, @@ -162,9 +165,9 @@ impl Debug for Command { .field("delivery_group_size", delivery_group_size) .field("message", message) .finish(), - Self::SendUserMessage { itry, content } => f + Self::SendUserMessage { itinerary, content } => f .debug_struct("SendUserMessage") - .field("itry", itry) + .field("itinerary", itinerary) .field("content", &format_args!("{:10}", HexFmt(content))) .finish(), Self::ScheduleTimeout { duration, token } => f diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 6cb3448712..0fb0d9601e 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -312,11 +312,11 @@ impl Routing { /// Send a message. /// Messages sent here, either section to section or node to node are signed /// and validated upon receipt by routing itself. - pub async fn send_message(&self, itry: Itinerary, content: Bytes) -> Result<()> { + pub async fn send_message(&self, itinerary: Itinerary, content: Bytes) -> Result<()> { if let DstLocation::EndUser(EndUser::Client { socket_id, public_key, - }) = itry.dst + }) = itinerary.dst { let socket_addr = self .stage @@ -338,7 +338,7 @@ impl Routing { debug!("Sending user message instead.. (Command::SendUserMessage)"); } } - let command = Command::SendUserMessage { itry, content }; + let command = Command::SendUserMessage { itinerary, content }; self.stage.clone().handle_commands(command).await } diff --git a/src/routing/stage.rs b/src/routing/stage.rs index efbcb3dfaa..4e3ba2aab1 100644 --- a/src/routing/stage.rs +++ b/src/routing/stage.rs @@ -147,9 +147,11 @@ impl Stage { self.send_message(&recipients, delivery_group_size, message) .await } - Command::SendUserMessage { itry, content } => { - self.state.lock().await.send_user_message(itry, content) - } + Command::SendUserMessage { itinerary, content } => self + .state + .lock() + .await + .send_user_message(itinerary, content), Command::ScheduleTimeout { duration, token } => Ok(self .handle_schedule_timeout(duration, token) .await diff --git a/src/routing/tests/mod.rs b/src/routing/tests/mod.rs index 6d7f6594a1..aadcea4c46 100644 --- a/src/routing/tests/mod.rs +++ b/src/routing/tests/mod.rs @@ -1459,7 +1459,7 @@ async fn message_to_self(dst: MessageDst) -> Result<()> { let commands = stage .handle_command(Command::SendUserMessage { - itry: Itinerary { + itinerary: Itinerary { src, dst, aggregation: Aggregation::None, diff --git a/tests/messages.rs b/tests/messages.rs index 8611d53644..ae9ce7fb5b 100644 --- a/tests/messages.rs +++ b/tests/messages.rs @@ -153,13 +153,15 @@ async fn test_messages_between_nodes() -> Result<()> { println!("sending msg.."); - let itry = Itinerary { + let itinerary = Itinerary { src: SrcLocation::Node(node2_name), dst: DstLocation::Node(node1_name), aggregation: Aggregation::None, }; - node2.send_message(itry, Bytes::from_static(msg)).await?; + node2 + .send_message(itinerary, Bytes::from_static(msg)) + .await?; println!("msg sent"); @@ -168,7 +170,7 @@ async fn test_messages_between_nodes() -> Result<()> { println!("Got dst: {:?} (expecting: {}", dst.name(), node2_name); println!("sending response from {:?}..", node1_name); - let itry = Itinerary { + let itinerary = Itinerary { src: SrcLocation::Node(node1_name), dst, aggregation: Aggregation::None, @@ -176,7 +178,7 @@ async fn test_messages_between_nodes() -> Result<()> { // send response from node1 to node2 node1 - .send_message(itry, Bytes::from_static(response)) + .send_message(itinerary, Bytes::from_static(response)) .await?; println!("checking response received..");