From aeeefbb5e4695ed19b945b05a6df5ebc6f415c3d Mon Sep 17 00:00:00 2001 From: Stan Bondi Date: Mon, 3 Jan 2022 15:12:08 +0200 Subject: [PATCH] perf(comms)!: optimise connection establishment (#3658) Description --- Optimise connection establishment: - perform identity protocol before yamux upgrade - identity protocol uses socket rather than yamux - trim down bytes sent for identity protocol Additional: remove tokio-tungstonite dependency from base node (part of warp default-features for websockets) Motivation and Context --- By performing identity protocol before yamux upgrade, we can send the identity message without incurring the header cost of yamux and also avoids having to close a substream after identities are exchanged. It is important that the identity protocol is as trim as possible, so a manual framing implementation is used rather than the tokio length-delimited framing codec. This is a network breaking change, nodes upgraded to this will not be able to communicate with older nodes and vice-versa. How Has This Been Tested? --- Existing tests pass, manually between two upgraded nodes --- Cargo.lock | 61 -------- applications/tari_base_node/Cargo.toml | 2 +- base_layer/p2p/src/lib.rs | 4 +- comms/dht/tests/dht.rs | 14 +- comms/src/builder/mod.rs | 2 +- comms/src/connection_manager/common.rs | 25 ++-- comms/src/connection_manager/dialer.rs | 30 ++-- comms/src/connection_manager/listener.rs | 26 ++-- comms/src/connection_manager/tests/manager.rs | 6 +- comms/src/connectivity/manager.rs | 6 +- comms/src/multiplexing/yamux.rs | 27 +--- comms/src/peer_manager/peer.rs | 7 +- comms/src/proto/identity.proto | 4 - comms/src/protocol/identity.rs | 140 ++++++++++-------- comms/src/protocol/mod.rs | 2 +- comms/src/protocol/network_info.rs | 4 +- comms/src/protocol/protocols.rs | 15 +- comms/src/test_utils/transport.rs | 8 +- 18 files changed, 160 insertions(+), 223 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 56684a9d13..5131ade8b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3548,24 +3548,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "multipart" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00dec633863867f29cb39df64a397cdf4a6354708ddd7759f70c7fb51c5f9182" -dependencies = [ - "buf_redux", - "httparse", - "log", - "mime", - "mime_guess", - "quick-error", - "rand 0.8.4", - "safemem", - "tempfile", - "twoway", -] - [[package]] name = "native-tls" version = "0.2.8" @@ -7353,19 +7335,6 @@ dependencies = [ "tokio-stream", ] -[[package]] -name = "tokio-tungstenite" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "511de3f85caf1c98983545490c3d09685fa8eb634e57eec22bb4db271f46cbd8" -dependencies = [ - "futures-util", - "log", - "pin-project 1.0.8", - "tokio 1.14.0", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.6.9" @@ -7731,25 +7700,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "tungstenite" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0b2d8558abd2e276b0a8df5c05a2ec762609344191e5fd23e292c910e9165b5" -dependencies = [ - "base64 0.13.0", - "byteorder", - "bytes 1.1.0", - "http", - "httparse", - "log", - "rand 0.8.4", - "sha-1 0.9.8", - "thiserror", - "url 2.2.2", - "utf-8", -] - [[package]] name = "twofish" version = "0.5.0" @@ -7761,15 +7711,6 @@ dependencies = [ "opaque-debug 0.3.0", ] -[[package]] -name = "twoway" -version = "0.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b11b2b5241ba34be09c3cc85a36e56e48f9888862e19cedf23336d35316ed1" -dependencies = [ - "memchr", -] - [[package]] name = "typemap" version = "0.3.3" @@ -8012,7 +7953,6 @@ dependencies = [ "log", "mime", "mime_guess", - "multipart", "percent-encoding 2.1.0", "pin-project 1.0.8", "scoped-tls", @@ -8021,7 +7961,6 @@ dependencies = [ "serde_urlencoded 0.7.0", "tokio 1.14.0", "tokio-stream", - "tokio-tungstenite", "tokio-util", "tower-service", "tracing", diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index e530446a8f..a32b47988d 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -49,7 +49,7 @@ tracing-subscriber = "0.2.20" # Metrics tari_metrics = { path = "../../infrastructure/metrics", optional = true } -warp = { version = "0.3.1", optional = true } +warp = { version = "0.3.1", optional = true, default-features = false } reqwest = { version = "0.11.4", default-features = false, optional = true } [features] diff --git a/base_layer/p2p/src/lib.rs b/base_layer/p2p/src/lib.rs index 9f8b992bbf..58e65e6eb4 100644 --- a/base_layer/p2p/src/lib.rs +++ b/base_layer/p2p/src/lib.rs @@ -53,7 +53,7 @@ pub use tari_common::configuration::Network; pub const DEFAULT_DNS_NAME_SERVER: &str = "1.1.1.1:853/cloudflare-dns.com"; /// Major network version. Peers will refuse connections if this value differs -pub const MAJOR_NETWORK_VERSION: u32 = 0; +pub const MAJOR_NETWORK_VERSION: u8 = 0; /// Minor network version. This should change with each time the network protocol has changed in a backward-compatible /// way. -pub const MINOR_NETWORK_VERSION: u32 = 0; +pub const MINOR_NETWORK_VERSION: u8 = 0; diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 8fdf80992a..97f1de4af0 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -550,8 +550,11 @@ async fn dht_propagate_dedup() { let mut node_A_messaging = node_A.messaging_events.subscribe(); let mut node_B_messaging = node_B.messaging_events.subscribe(); + let mut node_B_messaging2 = node_B.messaging_events.subscribe(); let mut node_C_messaging = node_C.messaging_events.subscribe(); + let mut node_C_messaging2 = node_C.messaging_events.subscribe(); let mut node_D_messaging = node_D.messaging_events.subscribe(); + let mut node_D_messaging2 = node_D.messaging_events.subscribe(); #[derive(Clone, PartialEq, ::prost::Message)] struct Person { @@ -596,6 +599,11 @@ async fn dht_propagate_dedup() { let node_C_id = node_C.node_identity().node_id().clone(); let node_D_id = node_D.node_identity().node_id().clone(); + // Ensure that the message has propagated before disconnecting everyone + let _ = node_B_messaging2.recv().await.unwrap(); + let _ = node_C_messaging2.recv().await.unwrap(); + let _ = node_D_messaging2.recv().await.unwrap(); + node_A.shutdown().await; node_B.shutdown().await; node_C.shutdown().await; @@ -611,7 +619,11 @@ async fn dht_propagate_dedup() { let received = filter_received(collect_try_recv!(node_B_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_C_id]); // Expected race condition: If A->B->C before A->C then C->B does not happen - assert!((1..=2).contains(&recv_count)); + assert!( + (1..=2).contains(&recv_count), + "expected recv_count to be in [1-2] but was {}", + recv_count + ); let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20))); let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); diff --git a/comms/src/builder/mod.rs b/comms/src/builder/mod.rs index 2885ab17fb..75d62f1b8e 100644 --- a/comms/src/builder/mod.rs +++ b/comms/src/builder/mod.rs @@ -124,7 +124,7 @@ impl CommsBuilder { } /// Set a network major and minor version as per [RFC-173 Versioning](https://rfc.tari.com/RFC-0173_Versioning.html) - pub fn with_node_version(mut self, major_version: u32, minor_version: u32) -> Self { + pub fn with_node_version(mut self, major_version: u8, minor_version: u8) -> Self { self.connection_manager_config.network_info.major_version = major_version; self.connection_manager_config.network_info.minor_version = minor_version; self diff --git a/comms/src/connection_manager/common.rs b/comms/src/connection_manager/common.rs index c2ec62f50d..f835f507c3 100644 --- a/comms/src/connection_manager/common.rs +++ b/comms/src/connection_manager/common.rs @@ -22,14 +22,13 @@ use std::{convert::TryFrom, net::Ipv6Addr}; -use futures::StreamExt; use log::*; +use tokio::io::{AsyncRead, AsyncWrite}; use super::types::ConnectionDirection; use crate::{ connection_manager::error::ConnectionManagerError, multiaddr::{Multiaddr, Protocol}, - multiplexing::Yamux, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, proto::identity::PeerIdentityMsg, protocol, @@ -43,30 +42,24 @@ const LOG_TARGET: &str = "comms::connection_manager::common"; /// The maximum size of the peer's user agent string. If the peer sends a longer string it is truncated. const MAX_USER_AGENT_LEN: usize = 100; -pub async fn perform_identity_exchange<'p, P: IntoIterator>( - muxer: &mut Yamux, +pub async fn perform_identity_exchange< + 'p, + P: IntoIterator, + TSocket: AsyncRead + AsyncWrite + Unpin, +>( + socket: &mut TSocket, node_identity: &NodeIdentity, direction: ConnectionDirection, our_supported_protocols: P, network_info: NodeNetworkInfo, ) -> Result { - let mut control = muxer.get_yamux_control(); - let stream = match direction { - ConnectionDirection::Inbound => muxer - .incoming_mut() - .next() - .await - .ok_or(ConnectionManagerError::IncomingListenerStreamClosed)?, - ConnectionDirection::Outbound => control.open_stream().await?, - }; - debug!( target: LOG_TARGET, - "{} substream opened to peer. Performing identity exchange.", direction + "{} socket opened to peer. Performing identity exchange.", direction ); let peer_identity = - protocol::identity_exchange(node_identity, direction, our_supported_protocols, network_info, stream).await?; + protocol::identity_exchange(node_identity, direction, our_supported_protocols, network_info, socket).await?; Ok(peer_identity) } diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index aa2735fed9..dc21ab5310 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -352,7 +352,7 @@ where async fn perform_socket_upgrade_procedure( peer_manager: Arc, node_identity: Arc, - socket: NoiseSocket, + mut socket: NoiseSocket, dialed_addr: Multiaddr, authenticated_public_key: CommsPublicKey, conn_man_notifier: mpsc::Sender, @@ -361,33 +361,29 @@ where cancel_signal: ShutdownSignal, ) -> Result { static CONNECTION_DIRECTION: ConnectionDirection = ConnectionDirection::Outbound; - let mut muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) - .await - .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; - debug!( target: LOG_TARGET, "Starting peer identity exchange for peer with public key '{}'", authenticated_public_key ); - if cancel_signal.is_terminated() { - return Err(ConnectionManagerError::DialCancelled); - } + + // Check if we know the peer and if it is banned + let known_peer = common::find_unbanned_peer(&peer_manager, &authenticated_public_key).await?; let peer_identity = common::perform_identity_exchange( - &mut muxer, + &mut socket, &node_identity, CONNECTION_DIRECTION, &our_supported_protocols, config.network_info.clone(), ) .await?; + if cancel_signal.is_terminated() { - muxer.get_yamux_control().close().await?; return Err(ConnectionManagerError::DialCancelled); } let features = PeerFeatures::from_bits_truncate(peer_identity.features); - trace!( + debug!( target: LOG_TARGET, "Peer identity exchange succeeded on Outbound connection for peer '{}' (Features = {:?})", authenticated_public_key, @@ -395,9 +391,6 @@ where ); trace!(target: LOG_TARGET, "{:?}", peer_identity); - // Check if we know the peer and if it is banned - let known_peer = common::find_unbanned_peer(&peer_manager, &authenticated_public_key).await?; - let (peer_node_id, their_supported_protocols) = common::validate_and_add_peer_from_peer_identity( &peer_manager, known_peer, @@ -409,7 +402,6 @@ where .await?; if cancel_signal.is_terminated() { - muxer.get_yamux_control().close().await?; return Err(ConnectionManagerError::DialCancelled); } @@ -420,6 +412,14 @@ where peer_node_id.short_str() ); + let muxer = Yamux::upgrade_connection(socket, CONNECTION_DIRECTION) + .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; + + if cancel_signal.is_terminated() { + muxer.get_yamux_control().close().await?; + return Err(ConnectionManagerError::DialCancelled); + } + peer_connection::create( muxer, dialed_addr, diff --git a/comms/src/connection_manager/listener.rs b/comms/src/connection_manager/listener.rs index f261fbd7c1..3f50376ba0 100644 --- a/comms/src/connection_manager/listener.rs +++ b/comms/src/connection_manager/listener.rs @@ -28,7 +28,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, - time::Duration, + time::{Duration, Instant}, }; use futures::{future, FutureExt}; @@ -351,7 +351,8 @@ where "Starting noise protocol upgrade for peer at address '{}'", peer_addr ); - let noise_socket = time::timeout( + let timer = Instant::now(); + let mut noise_socket = time::timeout( Duration::from_secs(30), noise_config.upgrade_socket(socket, CONNECTION_DIRECTION), ) @@ -362,21 +363,23 @@ where .get_remote_public_key() .ok_or(ConnectionManagerError::InvalidStaticPublicKey)?; + debug!( + target: LOG_TARGET, + "Noise socket upgrade completed in {:.2?} with public key '{}'", + timer.elapsed(), + authenticated_public_key + ); + // Check if we know the peer and if it is banned let known_peer = common::find_unbanned_peer(&peer_manager, &authenticated_public_key).await?; - let mut muxer = Yamux::upgrade_connection(noise_socket, CONNECTION_DIRECTION) - .await - .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; - - trace!( + debug!( target: LOG_TARGET, - "Starting peer identity exchange for peer with public key '{}'", - authenticated_public_key + "Starting peer identity exchange for peer with public key '{}'", authenticated_public_key ); let peer_identity = common::perform_identity_exchange( - &mut muxer, + &mut noise_socket, &node_identity, CONNECTION_DIRECTION, &our_supported_protocols, @@ -410,6 +413,9 @@ where peer_node_id.short_str() ); + let muxer = Yamux::upgrade_connection(noise_socket, CONNECTION_DIRECTION) + .map_err(|err| ConnectionManagerError::YamuxUpgradeFailure(err.to_string()))?; + peer_connection::create( muxer, peer_addr, diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index cacda2bd0d..c6ebf133b6 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -42,7 +42,7 @@ use crate::{ }, noise::NoiseConfig, peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags, PeerManagerError}, - protocol::{ProtocolEvent, ProtocolId, Protocols, IDENTITY_PROTOCOL}, + protocol::{ProtocolEvent, ProtocolId, Protocols}, runtime, runtime::task, test_utils::{ @@ -156,7 +156,7 @@ async fn dial_success() { let mut conn_out = conn_man1.dial_peer(node_identity2.node_id().clone()).await.unwrap(); assert_eq!(conn_out.peer_node_id(), node_identity2.node_id()); let peer2 = peer_manager1.find_by_node_id(conn_out.peer_node_id()).await.unwrap(); - assert_eq!(peer2.supported_protocols, [&IDENTITY_PROTOCOL, &TEST_PROTO]); + assert_eq!(peer2.supported_protocols, [&TEST_PROTO]); assert_eq!(peer2.user_agent, "node2"); let event = subscription2.recv().await.unwrap(); @@ -164,7 +164,7 @@ async fn dial_success() { assert_eq!(conn_in.peer_node_id(), node_identity1.node_id()); let peer1 = peer_manager2.find_by_node_id(node_identity1.node_id()).await.unwrap(); - assert_eq!(peer1.supported_protocols(), [&IDENTITY_PROTOCOL, &TEST_PROTO]); + assert_eq!(peer1.supported_protocols(), [&TEST_PROTO]); assert_eq!(peer1.user_agent, "node1"); let err = conn_out diff --git a/comms/src/connectivity/manager.rs b/comms/src/connectivity/manager.rs index 31e9eff330..b902d781f9 100644 --- a/comms/src/connectivity/manager.rs +++ b/comms/src/connectivity/manager.rs @@ -490,9 +490,9 @@ impl ConnectivityManagerActor { #[allow(clippy::single_match)] match event { PeerConnected(new_conn) => { - self.connection_manager - .cancel_dial(new_conn.peer_node_id().clone()) - .await?; + // self.connection_manager + // .cancel_dial(new_conn.peer_node_id().clone()) + // .await?; match self.pool.get_connection(new_conn.peer_node_id()) { Some(existing_conn) if !existing_conn.is_connected() => { diff --git a/comms/src/multiplexing/yamux.rs b/comms/src/multiplexing/yamux.rs index 9a5fdd9a01..f17b48a9e4 100644 --- a/comms/src/multiplexing/yamux.rs +++ b/comms/src/multiplexing/yamux.rs @@ -55,7 +55,7 @@ const RECEIVE_WINDOW: u32 = 5 * 1024 * 1024; // 5MiB impl Yamux { /// Upgrade the underlying socket to use yamux - pub async fn upgrade_connection(socket: TSocket, direction: ConnectionDirection) -> io::Result + pub fn upgrade_connection(socket: TSocket, direction: ConnectionDirection) -> io::Result where TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static { let mode = match direction { ConnectionDirection::Inbound => Mode::Server, @@ -360,9 +360,7 @@ mod test { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"The Way of Kings"; - let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound) - .await - .unwrap(); + let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound)?; let mut dialer_control = dialer.get_yamux_control(); task::spawn(async move { @@ -373,9 +371,7 @@ mod test { substream.shutdown().await.unwrap(); }); - let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) - .await? - .into_incoming(); + let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); let mut substream = listener .next() .await @@ -396,9 +392,7 @@ mod test { const NUM_SUBSTREAMS: usize = 10; let (dialer, listener) = MemorySocket::new_pair(); - let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound) - .await - .unwrap(); + let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).unwrap(); let mut dialer_control = dialer.get_yamux_control(); let substreams_out = task::spawn(async move { @@ -410,7 +404,6 @@ mod test { }); let mut listener = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) - .await .unwrap() .into_incoming(); let substreams_in = collect_stream!(&mut listener, take = NUM_SUBSTREAMS, timeout = Duration::from_secs(10)); @@ -430,7 +423,7 @@ mod test { let (dialer, listener) = MemorySocket::new_pair(); let msg = b"Words of Radiance"; - let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).await?; + let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound)?; let mut dialer_control = dialer.get_yamux_control(); task::spawn(async move { @@ -444,9 +437,7 @@ mod test { assert_eq!(buf, b""); }); - let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) - .await? - .into_incoming(); + let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); let mut substream = incoming.next().await.unwrap(); let mut buf = vec![0; msg.len()]; @@ -473,7 +464,7 @@ mod test { let (dialer, listener) = MemorySocket::new_pair(); - let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound).await?; + let dialer = Yamux::upgrade_connection(dialer, ConnectionDirection::Outbound)?; let mut dialer_control = dialer.get_yamux_control(); task::spawn(async move { @@ -492,9 +483,7 @@ mod test { assert_eq!(buf, vec![0xAAu8; MSG_LEN]); }); - let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound) - .await? - .into_incoming(); + let mut incoming = Yamux::upgrade_connection(listener, ConnectionDirection::Inbound)?.into_incoming(); assert_eq!(incoming.substream_count(), 0); let mut substream = incoming.next().await.unwrap(); assert_eq!(incoming.substream_count(), 1); diff --git a/comms/src/peer_manager/peer.rs b/comms/src/peer_manager/peer.rs index 785edf5af0..7e6482cedf 100644 --- a/comms/src/peer_manager/peer.rs +++ b/comms/src/peer_manager/peer.rs @@ -350,6 +350,7 @@ impl Hash for Peer { #[cfg(test)] mod test { + use bytes::Bytes; use serde_json::Value; use tari_crypto::{ keys::PublicKey, @@ -361,7 +362,6 @@ mod test { use crate::{ net_address::MultiaddressesWithStats, peer_manager::NodeId, - protocol, test_utils::node_identity::build_node_identity, types::CommsPublicKey, }; @@ -424,6 +424,7 @@ mod test { let net_address2 = "/ip4/125.0.0.125/tcp/8000".parse::().unwrap(); let net_address3 = "/ip4/126.0.0.126/tcp/9000".parse::().unwrap(); + static DUMMY_PROTOCOL: Bytes = Bytes::from_static(b"dummy"); peer.update( Some(vec![net_address2.clone(), net_address3.clone()]), None, @@ -431,7 +432,7 @@ mod test { Some("".to_string()), None, Some(PeerFeatures::MESSAGE_PROPAGATION), - Some(vec![protocol::IDENTITY_PROTOCOL.clone()]), + Some(vec![DUMMY_PROTOCOL.clone()]), ); assert_eq!(peer.public_key, public_key1); @@ -453,7 +454,7 @@ mod test { .any(|net_address_with_stats| net_address_with_stats.address == net_address3)); assert!(peer.is_banned()); assert!(peer.has_features(PeerFeatures::MESSAGE_PROPAGATION)); - assert_eq!(peer.supported_protocols, vec![protocol::IDENTITY_PROTOCOL.clone()]); + assert_eq!(peer.supported_protocols, vec![DUMMY_PROTOCOL.clone()]); } #[test] diff --git a/comms/src/proto/identity.proto b/comms/src/proto/identity.proto index bf4d9f5a78..9147b29f94 100644 --- a/comms/src/proto/identity.proto +++ b/comms/src/proto/identity.proto @@ -7,8 +7,4 @@ message PeerIdentityMsg { uint64 features = 2; repeated bytes supported_protocols = 3; string user_agent = 4; - // Major node version. This must match the current node's version in order for the connection to be established. - uint32 major = 5; - // Minor node version. This indicates minor non-breaking changes. - uint32 minor = 6; } diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 6a430adfc4..a3f2c44e26 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -19,17 +19,16 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use std::{io, time::Duration}; +use std::{convert::TryFrom, io, time::Duration}; -use futures::{SinkExt, StreamExt}; +use bytes::Bytes; use log::*; use prost::Message; use thiserror::Error; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, time, }; -use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing; use crate::{ @@ -37,55 +36,25 @@ use crate::{ message::MessageExt, peer_manager::NodeIdentity, proto::identity::PeerIdentityMsg, - protocol::{NodeNetworkInfo, ProtocolError, ProtocolId, ProtocolNegotiation}, + protocol::{NodeNetworkInfo, ProtocolError, ProtocolId}, }; -pub static IDENTITY_PROTOCOL: ProtocolId = ProtocolId::from_static(b"t/identity/1.0"); const LOG_TARGET: &str = "comms::protocol::identity"; +const MAX_IDENTITY_PROTOCOL_MSG_SIZE: u16 = 1024; + #[tracing::instrument(skip(socket, our_supported_protocols))] pub async fn identity_exchange<'p, TSocket, P>( node_identity: &NodeIdentity, direction: ConnectionDirection, our_supported_protocols: P, network_info: NodeNetworkInfo, - mut socket: TSocket, + socket: &mut TSocket, ) -> Result where TSocket: AsyncRead + AsyncWrite + Unpin, P: IntoIterator, { - // Negotiate the identity protocol - let mut negotiation = ProtocolNegotiation::new(&mut socket); - let proto = match direction { - ConnectionDirection::Outbound => { - debug!( - target: LOG_TARGET, - "[ThisNode={}] Starting Outbound identity exchange with peer.", - node_identity.node_id().short_str() - ); - negotiation - .negotiate_protocol_outbound_optimistic(&IDENTITY_PROTOCOL.clone()) - .await? - }, - ConnectionDirection::Inbound => { - debug!( - target: LOG_TARGET, - "[ThisNode={}] Starting Inbound identity exchange with peer.", - node_identity.node_id().short_str() - ); - negotiation - .negotiate_protocol_inbound(&[IDENTITY_PROTOCOL.clone()]) - .await? - }, - }; - - debug_assert_eq!(proto, IDENTITY_PROTOCOL); - - // Create length-delimited frame codec - let framed = Framed::new(socket, LengthDelimitedCodec::new()); - let (mut sink, mut stream) = framed.split(); - let supported_protocols = our_supported_protocols.into_iter().map(|p| p.to_vec()).collect(); // Send this node's identity @@ -93,28 +62,23 @@ where addresses: vec![node_identity.public_address().to_vec()], features: node_identity.features().bits(), supported_protocols, - major: network_info.major_version, - minor: network_info.minor_version, user_agent: network_info.user_agent, } .to_encoded_bytes(); - sink.send(msg_bytes.into()).await?; - sink.close().await?; + write_protocol_frame(socket, network_info.major_version as u8, &msg_bytes).await?; + socket.flush().await?; - // Receive the connecting nodes identity - let msg_bytes = time::timeout(Duration::from_secs(10), stream.next()) - .await? - .ok_or(IdentityProtocolError::PeerUnexpectedCloseConnection)??; - let identity_msg = PeerIdentityMsg::decode(msg_bytes)?; + // Receive the connecting node's identity + let (version, msg_bytes) = time::timeout(Duration::from_secs(10), read_protocol_frame(socket)).await??; + let identity_msg = PeerIdentityMsg::decode(Bytes::from(msg_bytes))?; - if identity_msg.major != network_info.major_version { + if version > network_info.major_version { warn!( target: LOG_TARGET, - "Peer sent mismatching major protocol version '{}'. This node has version '{}.{}'", - identity_msg.major, - network_info.major_version, - network_info.minor_version + "Peer sent mismatching major protocol version '{}'. This node has version '{}'", + version, + network_info.major_version ); return Err(IdentityProtocolError::ProtocolVersionMismatch); } @@ -122,6 +86,56 @@ where Ok(identity_msg) } +async fn read_protocol_frame(socket: &mut S) -> Result<(u8, Vec), IdentityProtocolError> { + let mut buf = [0u8; 3]; + socket.read_exact(&mut buf).await?; + let version = buf[0]; + let buf = [buf[1], buf[2]]; + let len = u16::from_le_bytes(buf); + if len > MAX_IDENTITY_PROTOCOL_MSG_SIZE { + return Err(IdentityProtocolError::MaxMsgSizeExceeded { + expected: MAX_IDENTITY_PROTOCOL_MSG_SIZE, + got: len, + }); + } + let len = len as usize; + let mut msg = vec![0u8; len]; + socket.read_exact(&mut msg).await?; + Ok((version, msg)) +} + +async fn write_protocol_frame( + socket: &mut S, + version: u8, + msg_bytes: &[u8], +) -> Result<(), IdentityProtocolError> { + debug_assert!( + msg_bytes.len() <= MAX_IDENTITY_PROTOCOL_MSG_SIZE as usize, + "Sending identity protocol message of size {}, greater than {} bytes. This is a protocol violation", + msg_bytes.len(), + MAX_IDENTITY_PROTOCOL_MSG_SIZE + ); + + let len = u16::try_from(msg_bytes.len()).map_err(|_| { + IdentityProtocolError::ProtocolError(format!( + "Identity protocol attempted to send a message larger than u16::MAX bytes. len = {}", + msg_bytes.len() + )) + })?; + let version_bytes = [version]; + let len_bytes = len.to_le_bytes(); + + trace!( + target: LOG_TARGET, + "Writing {} bytes", + len_bytes.len() + msg_bytes.len() + 1 + ); + socket.write_all(&version_bytes[..]).await?; + socket.write_all(&len_bytes[..]).await?; + socket.write_all(msg_bytes).await?; + Ok(()) +} + #[derive(Debug, Error, Clone)] pub enum IdentityProtocolError { #[error("IoError: {0}")] @@ -138,6 +152,8 @@ pub enum IdentityProtocolError { Timeout, #[error("Protocol version mismatch")] ProtocolVersionMismatch, + #[error("Max identity protocol message size exceeded. Expected <= {expected} got {got}")] + MaxMsgSizeExceeded { expected: u16, got: u16 }, } impl From for IdentityProtocolError { @@ -185,8 +201,8 @@ mod test { let (out_sock, in_sock) = future::join(transport.dial(addr), listener.next()).await; - let out_sock = out_sock.unwrap(); - let (in_sock, _) = in_sock.unwrap().unwrap(); + let mut out_sock = out_sock.unwrap(); + let (mut in_sock, _) = in_sock.unwrap().unwrap(); let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); @@ -200,7 +216,7 @@ mod test { minor_version: 1, ..Default::default() }, - in_sock, + &mut in_sock, ), super::identity_exchange( &node_identity2, @@ -210,7 +226,7 @@ mod test { minor_version: 2, ..Default::default() }, - out_sock, + &mut out_sock, ), ) .await; @@ -234,8 +250,8 @@ mod test { let (out_sock, in_sock) = future::join(transport.dial(addr), listener.next()).await; - let out_sock = out_sock.unwrap(); - let (in_sock, _) = in_sock.unwrap().unwrap(); + let mut out_sock = out_sock.unwrap(); + let (mut in_sock, _) = in_sock.unwrap().unwrap(); let node_identity1 = build_node_identity(PeerFeatures::COMMUNICATION_NODE); let node_identity2 = build_node_identity(PeerFeatures::COMMUNICATION_CLIENT); @@ -249,7 +265,7 @@ mod test { major_version: 0, ..Default::default() }, - in_sock, + &mut in_sock, ), super::identity_exchange( &node_identity2, @@ -259,7 +275,7 @@ mod test { major_version: 1, ..Default::default() }, - out_sock, + &mut out_sock, ), ) .await; @@ -267,7 +283,7 @@ mod test { let err = result1.unwrap_err(); assert!(matches!(err, IdentityProtocolError::ProtocolVersionMismatch)); - let err = result2.unwrap_err(); - assert!(matches!(err, IdentityProtocolError::ProtocolVersionMismatch)); + // Passes because older versions are supported + result2.unwrap(); } } diff --git a/comms/src/protocol/mod.rs b/comms/src/protocol/mod.rs index 27628a183f..8fd41836ae 100644 --- a/comms/src/protocol/mod.rs +++ b/comms/src/protocol/mod.rs @@ -27,7 +27,7 @@ mod extensions; pub use extensions::{ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolExtensions}; mod identity; -pub use identity::{identity_exchange, IdentityProtocolError, IDENTITY_PROTOCOL}; +pub use identity::{identity_exchange, IdentityProtocolError}; mod negotiation; pub use negotiation::ProtocolNegotiation; diff --git a/comms/src/protocol/network_info.rs b/comms/src/protocol/network_info.rs index 7a122233a1..90968f8e65 100644 --- a/comms/src/protocol/network_info.rs +++ b/comms/src/protocol/network_info.rs @@ -25,10 +25,10 @@ pub struct NodeNetworkInfo { /// Major protocol version. This indicates the protocol version that is supported by this node. A peer MAY reject /// the connection if a remote peer advertises a different major version number. - pub major_version: u32, + pub major_version: u8, /// Minor protocol version. A version number that represents backward-compatible protocol changes. A peer SHOULD /// NOT reject the connection if a remote peer advertises a different minor version number. - pub minor_version: u32, + pub minor_version: u8, /// The byte that MUST be sent (outbound connections) or MUST be received (inbound connections) for a connection to /// be established. This byte cannot be 0x46 (E) because that is reserved for liveness. /// Default: 0x00 diff --git a/comms/src/protocol/protocols.rs b/comms/src/protocol/protocols.rs index 0f394a61aa..fd24c0f9ff 100644 --- a/comms/src/protocol/protocols.rs +++ b/comms/src/protocol/protocols.rs @@ -26,14 +26,7 @@ use tokio::sync::mpsc; use crate::{ peer_manager::NodeId, - protocol::{ - ProtocolError, - ProtocolExtension, - ProtocolExtensionContext, - ProtocolExtensionError, - ProtocolId, - IDENTITY_PROTOCOL, - }, + protocol::{ProtocolError, ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolId}, Substream, }; @@ -102,10 +95,7 @@ impl Protocols { } pub fn get_supported_protocols(&self) -> Vec { - let mut p = Vec::with_capacity(self.protocols.len() + 1); - p.push(IDENTITY_PROTOCOL.clone()); - p.extend(self.protocols.keys().cloned()); - p + self.protocols.keys().cloned().collect() } pub async fn notify( @@ -152,7 +142,6 @@ mod test { fn add() { let (tx, _) = mpsc::channel(1); let protos = [ - IDENTITY_PROTOCOL.clone(), ProtocolId::from_static(b"/tari/test/1"), ProtocolId::from_static(b"/tari/test/2"), ]; diff --git a/comms/src/test_utils/transport.rs b/comms/src/test_utils/transport.rs index 7d128d0726..7a770440fa 100644 --- a/comms/src/test_utils/transport.rs +++ b/comms/src/test_utils/transport.rs @@ -40,13 +40,9 @@ pub async fn build_connected_sockets() -> (Multiaddr, MemorySocket, MemorySocket pub async fn build_multiplexed_connections() -> (Multiaddr, Yamux, Yamux) { let (addr, socket_out, socket_in) = build_connected_sockets().await; - let muxer_out = Yamux::upgrade_connection(socket_out, ConnectionDirection::Outbound) - .await - .unwrap(); + let muxer_out = Yamux::upgrade_connection(socket_out, ConnectionDirection::Outbound).unwrap(); - let muxer_in = Yamux::upgrade_connection(socket_in, ConnectionDirection::Inbound) - .await - .unwrap(); + let muxer_in = Yamux::upgrade_connection(socket_in, ConnectionDirection::Inbound).unwrap(); (addr, muxer_out, muxer_in) }