Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: some multiplex followup #5553

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions crates/net/eth-wire/src/capability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,17 +376,24 @@ impl SharedCapabilities {

/// Returns the matching shared capability for the given capability offset.
///
/// `offset` is the multiplexed message id offset of the capability relative to
/// [`MAX_RESERVED_MESSAGE_ID`].
/// `offset` is the multiplexed message id offset of the capability relative to the reserved
/// message id space. In other words, counting starts at [`MAX_RESERVED_MESSAGE_ID`] + 1, which
/// corresponds to the first non-reserved message id.
///
/// For example: `offset == 0` corresponds to the first shared message across the shared
/// capabilities and will return the first shared capability that supports messages.
#[inline]
pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID))
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
}

/// Returns the matching shared capability for the given capability offset.
///
/// `offset` is the multiplexed message id offset of the capability that includes the reserved
/// message id space.
///
/// This will always return None if `offset` is less than or equal to
/// [`MAX_RESERVED_MESSAGE_ID`] because the reserved message id space is not shared.
#[inline]
pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
let mut iter = self.0.iter();
Expand Down Expand Up @@ -637,12 +644,14 @@ mod tests {

let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

assert!(shared.find_by_relative_offset(0).is_none());
let shared_eth = shared.find_by_relative_offset(1).unwrap();
let shared_eth = shared.find_by_relative_offset(0).unwrap();
assert_eq!(shared_eth.name(), "eth");

let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
assert_eq!(shared_eth.name(), "eth");

// reserved message id space
assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
}

#[test]
Expand All @@ -654,15 +663,14 @@ mod tests {

let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

assert!(shared.find_by_relative_offset(0).is_none());
let shared_eth = shared.find_by_relative_offset(1).unwrap();
let shared_eth = shared.find_by_relative_offset(0).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);

let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);

// the 5th shared message is the last message of the aaa capability
let shared_eth = shared.find_by_relative_offset(5).unwrap();
// the 5th shared message (0,1,2,3,4) is the last message of the aaa capability
let shared_eth = shared.find_by_relative_offset(4).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);
let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);
Expand Down
94 changes: 52 additions & 42 deletions crates/net/eth-wire/src/multiplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ impl<St> RlpxProtocolMultiplexer<St> {
mut self,
cap: &Capability,
handshake: F,
) -> Result<RlpxSatelliteStream<St, Primary>, Self>
) -> Result<RlpxSatelliteStream<St, Primary>, Err>
where
F: FnOnce(ProtocolProxy) -> Fut,
Fut: Future<Output = Result<Primary, Err>>,
St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
P2PStreamError: Into<Err>,
{
let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
else {
return Err(self)
return Err(P2PStreamError::CapabilityNotShared.into())
};

let (to_primary, from_wire) = mpsc::unbounded_channel();
Expand All @@ -87,20 +88,36 @@ impl<St> RlpxProtocolMultiplexer<St> {
let f = handshake(proxy);
pin_mut!(f);

// handle messages until the handshake is complete
// this polls the connection and the primary stream concurrently until the handshake is
// complete
loop {
// TODO error handling
tokio::select! {
Some(Ok(msg)) = self.conn.next() => {
// TODO handle multiplex
let _ = to_primary.send(msg);
// Ensure the message belongs to the primary protocol
let offset = msg[0];
if let Some(cap) = self.conn.shared_capabilities().find_by_relative_offset(offset) {
if cap == &shared_cap {
// delegate to primary
let _ = to_primary.send(msg);
} else {
// delegate to satellite
for proto in &self.protocols {
if proto.cap == *cap {
// TODO: need some form of backpressure here so buffering can't be abused
proto.send_raw(msg);
break
}
}
}
} else {
return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
}
}
Some(msg) = from_primary.recv() => {
// TODO error handling
self.conn.send(msg).await.unwrap();
self.conn.send(msg).await.map_err(Into::into)?;
}
res = &mut f => {
let Ok(primary) = res else { return Err(self) };
let primary = res?;
return Ok(RlpxSatelliteStream {
conn: self.conn,
to_primary,
Expand All @@ -126,15 +143,13 @@ pub struct ProtocolProxy {

impl ProtocolProxy {
fn mask_msg_id(&self, msg: Bytes) -> Bytes {
mattsse marked this conversation as resolved.
Show resolved Hide resolved
// TODO handle empty messages
let mut masked_bytes = BytesMut::zeroed(msg.len());
masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset();
masked_bytes[1..].copy_from_slice(&msg[1..]);
masked_bytes.freeze()
}

fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
// TODO handle empty messages
msg[0] -= self.cap.relative_message_id_offset();
msg
}
Expand All @@ -157,6 +172,10 @@ impl Sink<Bytes> for ProtocolProxy {
}

fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
if item.is_empty() {
// message must not be empty
return Err(io::ErrorKind::InvalidInput.into())
}
let msg = self.mask_msg_id(item);
self.to_wire.send(msg).map_err(|_| io::ErrorKind::BrokenPipe.into())
}
Expand Down Expand Up @@ -287,34 +306,28 @@ where
Poll::Ready(Some(Ok(msg))) => {
delegated = true;
let offset = msg[0];
// find the protocol that matches the offset
// TODO optimize this by keeping a better index
let mut lowest_satellite = None;
// find the protocol with the lowest offset that is greater than the message
// offset
for (i, proto) in this.satellites.iter().enumerate() {
let proto_offset = proto.cap.relative_message_id_offset();
if proto_offset >= offset {
if let Some((_, lowest_offset)) = lowest_satellite {
if proto_offset < lowest_offset {
lowest_satellite = Some((i, proto_offset));
// delegate the multiplexed message to the correct protocol
if let Some(cap) =
this.conn.shared_capabilities().find_by_relative_offset(offset)
{
if cap == &this.primary_capability {
// delegate to primary
let _ = this.to_primary.send(msg);
} else {
// delegate to satellite
for proto in &this.satellites {
if proto.cap == *cap {
proto.send_raw(msg);
break
}
} else {
lowest_satellite = Some((i, proto_offset));
}
}
} else {
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
offset,
)
.into())))
}

if let Some((idx, lowest_offset)) = lowest_satellite {
if lowest_offset < this.primary_capability.relative_message_id_offset()
{
// delegate to satellite
this.satellites[idx].send_raw(msg);
continue
}
}
// delegate to primary
let _ = this.to_primary.send(msg);
}
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
Poll::Ready(None) => {
Expand Down Expand Up @@ -373,14 +386,13 @@ struct ProtocolStream {
}

impl ProtocolStream {
#[inline]
fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes {
// TODO handle empty messages
msg[0] += self.cap.relative_message_id_offset();
msg.freeze()
}

fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
// TODO handle empty messages
msg[0] -= self.cap.relative_message_id_offset();
msg
}
Expand Down Expand Up @@ -408,15 +420,13 @@ impl fmt::Debug for ProtocolStream {

#[cfg(test)]
mod tests {
use tokio::net::TcpListener;
use tokio_util::codec::Decoder;

use super::*;
use crate::{
test_utils::{connect_passthrough, eth_handshake, eth_hello},
UnauthedEthStream, UnauthedP2PStream,
};

use super::*;
use tokio::net::TcpListener;
use tokio_util::codec::Decoder;

#[tokio::test]
async fn eth_satellite() {
Expand Down
14 changes: 10 additions & 4 deletions crates/net/eth-wire/src/p2pstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,10 @@ where
///
/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
///
/// This stream emits Bytes that start with the normalized message id, so that the first byte of
/// each message starts from 0. If this stream only supports a single capability, for example `eth`
/// then the first byte of each message will match [EthMessageID](crate::types::EthMessageID).
/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
/// byte of each message starts from 0. If this stream only supports a single capability, for
/// example `eth` then the first byte of each message will match
/// [EthMessageID](crate::types::EthMessageID).
#[pin_project]
#[derive(Debug)]
pub struct P2PStream<S> {
Expand Down Expand Up @@ -405,6 +406,11 @@ where
None => return Poll::Ready(None),
};

if bytes.is_empty() {
// empty messages are not allowed
return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
}

// first check that the compressed message length does not exceed the max
// payload size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
Expand All @@ -430,7 +436,7 @@ where
err
})?;

let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
let id = bytes[0];
match id {
_ if id == P2PMessageID::Ping as u8 => {
trace!("Received Ping, Sending Pong");
Expand Down
Loading