Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: some multiplex followup #5553

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions crates/net/eth-wire/src/capability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,17 +376,24 @@ impl SharedCapabilities {

/// Returns the matching shared capability for the given capability offset.
///
/// `offset` is the multiplexed message id offset of the capability relative to
/// [`MAX_RESERVED_MESSAGE_ID`].
/// `offset` is the multiplexed message id offset of the capability relative to the reserved
/// message id space. In other words, counting starts at [`MAX_RESERVED_MESSAGE_ID`] + 1, which
/// corresponds to the first non-reserved message id.
///
/// For example: `offset == 0` corresponds to the first shared message across the shared
/// capabilities and will return the first shared capability that supports messages.
#[inline]
pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID))
self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
}

/// Returns the matching shared capability for the given capability offset.
///
/// `offset` is the multiplexed message id offset of the capability that includes the reserved
/// message id space.
///
/// This will always return None if `offset` is less than or equal to
/// [`MAX_RESERVED_MESSAGE_ID`] because the reserved message id space is not shared.
#[inline]
pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
let mut iter = self.0.iter();
Expand Down Expand Up @@ -637,12 +644,14 @@ mod tests {

let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

assert!(shared.find_by_relative_offset(0).is_none());
let shared_eth = shared.find_by_relative_offset(1).unwrap();
let shared_eth = shared.find_by_relative_offset(0).unwrap();
assert_eq!(shared_eth.name(), "eth");

let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
assert_eq!(shared_eth.name(), "eth");

// reserved message id space
assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
}

#[test]
Expand All @@ -654,15 +663,14 @@ mod tests {

let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

assert!(shared.find_by_relative_offset(0).is_none());
let shared_eth = shared.find_by_relative_offset(1).unwrap();
let shared_eth = shared.find_by_relative_offset(0).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);

let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);

// the 5th shared message is the last message of the aaa capability
let shared_eth = shared.find_by_relative_offset(5).unwrap();
// the 5th shared message (0,1,2,3,4) is the last message of the aaa capability
let shared_eth = shared.find_by_relative_offset(4).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);
let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
assert_eq!(shared_eth.name(), proto.cap.name);
Expand Down
134 changes: 88 additions & 46 deletions crates/net/eth-wire/src/multiplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ impl<St> RlpxProtocolMultiplexer<St> {
mut self,
cap: &Capability,
handshake: F,
) -> Result<RlpxSatelliteStream<St, Primary>, Self>
) -> Result<RlpxSatelliteStream<St, Primary>, Err>
where
F: FnOnce(ProtocolProxy) -> Fut,
Fut: Future<Output = Result<Primary, Err>>,
St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
P2PStreamError: Into<Err>,
{
let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
else {
return Err(self)
return Err(P2PStreamError::CapabilityNotShared.into())
};

let (to_primary, from_wire) = mpsc::unbounded_channel();
Expand All @@ -87,20 +88,36 @@ impl<St> RlpxProtocolMultiplexer<St> {
let f = handshake(proxy);
pin_mut!(f);

// handle messages until the handshake is complete
// this polls the connection and the primary stream concurrently until the handshake is
// complete
loop {
// TODO error handling
tokio::select! {
Some(Ok(msg)) = self.conn.next() => {
// TODO handle multiplex
let _ = to_primary.send(msg);
// Ensure the message belongs to the primary protocol
let offset = msg[0];
if let Some(cap) = self.conn.shared_capabilities().find_by_relative_offset(offset) {
if cap == &shared_cap {
// delegate to primary
let _ = to_primary.send(msg);
} else {
// delegate to satellite
for proto in &self.protocols {
if proto.cap == *cap {
// TODO: need some form of backpressure here so buffering can't be abused
proto.send_raw(msg);
break
}
}
}
} else {
return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
}
}
Some(msg) = from_primary.recv() => {
// TODO error handling
self.conn.send(msg).await.unwrap();
self.conn.send(msg).await.map_err(Into::into)?;
}
res = &mut f => {
let Ok(primary) = res else { return Err(self) };
let primary = res?;
return Ok(RlpxSatelliteStream {
conn: self.conn,
to_primary,
Expand All @@ -117,24 +134,47 @@ impl<St> RlpxProtocolMultiplexer<St> {
}

/// A Stream and Sink type that acts as a wrapper around a primary RLPx subprotocol (e.g. "eth")
///
/// Only emits and sends _non-empty_ messages
#[derive(Debug)]
pub struct ProtocolProxy {
cap: SharedCapability,
/// Receives _non-empty_ messages from the wire
from_wire: UnboundedReceiverStream<BytesMut>,
/// Sends _non-empty_ messages from the wire
to_wire: UnboundedSender<Bytes>,
}

impl ProtocolProxy {
/// Sends a _non-empty_ message on the wire.
fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
if msg.is_empty() {
// message must not be empty
return Err(io::ErrorKind::InvalidInput.into())
}
self.to_wire.send(self.mask_msg_id(msg)).map_err(|_| io::ErrorKind::BrokenPipe.into())
}

/// Masks the message ID of a message to be sent on the wire.
///
/// # Panics
///
/// If the message is empty.
#[inline]
fn mask_msg_id(&self, msg: Bytes) -> Bytes {
mattsse marked this conversation as resolved.
Show resolved Hide resolved
// TODO handle empty messages
let mut masked_bytes = BytesMut::zeroed(msg.len());
masked_bytes[0] = msg[0] + self.cap.relative_message_id_offset();
masked_bytes[1..].copy_from_slice(&msg[1..]);
masked_bytes.freeze()
}

/// Unmasks the message ID of a message received from the wire.
///
/// # Panics
///
/// If the message is empty.
#[inline]
fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
// TODO handle empty messages
msg[0] -= self.cap.relative_message_id_offset();
msg
}
Expand All @@ -157,8 +197,7 @@ impl Sink<Bytes> for ProtocolProxy {
}

fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let msg = self.mask_msg_id(item);
self.to_wire.send(msg).map_err(|_| io::ErrorKind::BrokenPipe.into())
self.get_mut().try_send(item)
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand All @@ -181,7 +220,7 @@ impl CanDisconnect<Bytes> for ProtocolProxy {
}
}

/// A connection channel to receive messages for the negotiated protocol.
/// A connection channel to receive _non_empty_ messages for the negotiated protocol.
///
/// This is a [Stream] that returns raw bytes of the received messages for this protocol.
#[derive(Debug)]
Expand Down Expand Up @@ -287,34 +326,28 @@ where
Poll::Ready(Some(Ok(msg))) => {
delegated = true;
let offset = msg[0];
// find the protocol that matches the offset
// TODO optimize this by keeping a better index
let mut lowest_satellite = None;
// find the protocol with the lowest offset that is greater than the message
// offset
for (i, proto) in this.satellites.iter().enumerate() {
let proto_offset = proto.cap.relative_message_id_offset();
if proto_offset >= offset {
if let Some((_, lowest_offset)) = lowest_satellite {
if proto_offset < lowest_offset {
lowest_satellite = Some((i, proto_offset));
// delegate the multiplexed message to the correct protocol
if let Some(cap) =
this.conn.shared_capabilities().find_by_relative_offset(offset)
{
if cap == &this.primary_capability {
// delegate to primary
let _ = this.to_primary.send(msg);
} else {
// delegate to satellite
for proto in &this.satellites {
if proto.cap == *cap {
proto.send_raw(msg);
break
}
} else {
lowest_satellite = Some((i, proto_offset));
}
}
} else {
return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
offset,
)
.into())))
}

if let Some((idx, lowest_offset)) = lowest_satellite {
if lowest_offset < this.primary_capability.relative_message_id_offset()
{
// delegate to satellite
this.satellites[idx].send_raw(msg);
continue
}
}
// delegate to primary
let _ = this.to_primary.send(msg);
}
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
Poll::Ready(None) => {
Expand Down Expand Up @@ -373,18 +406,29 @@ struct ProtocolStream {
}

impl ProtocolStream {
/// Masks the message ID of a message to be sent on the wire.
///
/// # Panics
///
/// If the message is empty.
#[inline]
fn mask_msg_id(&self, mut msg: BytesMut) -> Bytes {
// TODO handle empty messages
msg[0] += self.cap.relative_message_id_offset();
msg.freeze()
}

/// Unmasks the message ID of a message received from the wire.
///
/// # Panics
///
/// If the message is empty.
#[inline]
fn unmask_id(&self, mut msg: BytesMut) -> BytesMut {
// TODO handle empty messages
msg[0] -= self.cap.relative_message_id_offset();
msg
}

/// Sends the message to the satellite stream.
fn send_raw(&self, msg: BytesMut) {
let _ = self.to_satellite.send(self.unmask_id(msg));
}
Expand All @@ -396,7 +440,7 @@ impl Stream for ProtocolStream {
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
Poll::Ready(msg.filter(|msg| !msg.is_empty()).map(|msg| this.mask_msg_id(msg)))
}
}

Expand All @@ -408,15 +452,13 @@ impl fmt::Debug for ProtocolStream {

#[cfg(test)]
mod tests {
use tokio::net::TcpListener;
use tokio_util::codec::Decoder;

use super::*;
use crate::{
test_utils::{connect_passthrough, eth_handshake, eth_hello},
UnauthedEthStream, UnauthedP2PStream,
};

use super::*;
use tokio::net::TcpListener;
use tokio_util::codec::Decoder;

#[tokio::test]
async fn eth_satellite() {
Expand Down
14 changes: 10 additions & 4 deletions crates/net/eth-wire/src/p2pstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,10 @@ where
///
/// See also <https://github.com/ethereum/devp2p/blob/master/rlpx.md#message-id-based-multiplexing>
///
/// This stream emits Bytes that start with the normalized message id, so that the first byte of
/// each message starts from 0. If this stream only supports a single capability, for example `eth`
/// then the first byte of each message will match [EthMessageID](crate::types::EthMessageID).
/// This stream emits _non-empty_ Bytes that start with the normalized message id, so that the first
/// byte of each message starts from 0. If this stream only supports a single capability, for
/// example `eth` then the first byte of each message will match
/// [EthMessageID](crate::types::EthMessageID).
#[pin_project]
#[derive(Debug)]
pub struct P2PStream<S> {
Expand Down Expand Up @@ -405,6 +406,11 @@ where
None => return Poll::Ready(None),
};

if bytes.is_empty() {
// empty messages are not allowed
return Poll::Ready(Some(Err(P2PStreamError::EmptyProtocolMessage)))
}

// first check that the compressed message length does not exceed the max
// payload size
let decompressed_len = snap::raw::decompress_len(&bytes[1..])?;
Expand All @@ -430,7 +436,7 @@ where
err
})?;

let id = *bytes.first().ok_or(P2PStreamError::EmptyProtocolMessage)?;
let id = bytes[0];
match id {
_ if id == P2PMessageID::Ping as u8 => {
trace!("Received Ping, Sending Pong");
Expand Down
Loading