Skip to content

Commit

Permalink
proto: Make Connection internally use SideState
Browse files Browse the repository at this point in the history
Moves server/client-specific fields of proto::Connection to a new
SideState enum.
  • Loading branch information
gretchenfrage committed Dec 15, 2024
1 parent 73545b6 commit ed22361
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 27 deletions.
89 changes: 63 additions & 26 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ use timer::{Timer, TimerTable};
/// events or timeouts with different instants must not be interleaved.
pub struct Connection {
endpoint_config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
config: Arc<TransportConfig>,
rng: StdRng,
crypto: Box<dyn crypto::Session>,
Expand All @@ -145,7 +144,7 @@ pub struct Connection {
allow_mtud: bool,
prev_path: Option<(ConnectionId, PathData)>,
state: State,
side: Side,
side: ConnectionSide,
/// Whether or not 0-RTT was enabled during the handshake. Does not imply acceptance.
zero_rtt_enabled: bool,
/// Set if 0-RTT is supported, then cleared when no longer needed.
Expand Down Expand Up @@ -191,9 +190,6 @@ pub struct Connection {
authentication_failures: u64,
/// Why the connection was lost, if it has been
error: Option<ConnectionError>,
/// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are
/// discarded.
retry_token: Bytes,
/// Identifies Data-space packet numbers to skip. Not used in earlier spaces.
packet_number_filter: PacketNumberFilter,

Expand Down Expand Up @@ -258,11 +254,16 @@ impl Connection {
rng_seed: [u8; 32],
path_validated: bool,
) -> Self {
let side = if server_config.is_some() {
Side::Server
let connection_side = if let Some(server_config) = server_config.clone() {
ConnectionSide::Server { server_config }
} else {
Side::Client
assert!(pref_addr_cid.is_none());
assert!(path_validated);
ConnectionSide::Client {
token: Bytes::new(),
}
};
let side = connection_side.side();
let initial_space = PacketSpace {
crypto: Some(crypto.initial_keys(&init_cid, side)),
..PacketSpace::new(now)
Expand All @@ -275,7 +276,6 @@ impl Connection {
let mut rng = StdRng::from_seed(rng_seed);
let mut this = Self {
endpoint_config,
server_config,
crypto,
handshake_cid: loc_cid,
rem_handshake_cid: rem_cid,
Expand All @@ -289,8 +289,8 @@ impl Connection {
allow_mtud,
local_ip,
prev_path: None,
side,
state,
side: connection_side,
zero_rtt_enabled: false,
zero_rtt_crypto: None,
key_phase: false,
Expand Down Expand Up @@ -323,7 +323,6 @@ impl Connection {
timers: TimerTable::default(),
authentication_failures: 0,
error: None,
retry_token: Bytes::new(),
#[cfg(test)]
packet_number_filter: match config.deterministic_packet_numbers {
false => PacketNumberFilter::new(&mut rng),
Expand Down Expand Up @@ -420,7 +419,7 @@ impl Connection {
/// Provide control over streams
#[must_use]
pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> {
assert!(id.dir() == Dir::Bi || id.initiator() != self.side);
assert!(id.dir() == Dir::Bi || id.initiator() != self.side.side());
RecvStream {
id,
state: &mut self.streams,
Expand All @@ -431,7 +430,7 @@ impl Connection {
/// Provide control over streams
#[must_use]
pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> {
assert!(id.dir() == Dir::Bi || id.initiator() == self.side);
assert!(id.dir() == Dir::Bi || id.initiator() == self.side.side());
SendStream {
id,
state: &mut self.streams,
Expand Down Expand Up @@ -1075,9 +1074,7 @@ impl Connection {
// If this packet could initiate a migration and we're a client or a server that
// forbids migration, drop the datagram. This could be relaxed to heuristically
// permit NAT-rebinding-like migration.
if remote != self.path.remote
&& self.server_config.as_ref().map_or(true, |x| !x.migration)
{
if remote != self.path.remote && !self.side.remote_may_migrate() {
trace!("discarding packet from unrecognized peer {}", remote);
return;
}
Expand Down Expand Up @@ -1297,7 +1294,7 @@ impl Connection {

/// Look up whether we're the client or server of this Connection
pub fn side(&self) -> Side {
self.side
self.side.side()
}

/// The latest socket address for this connection's peer
Expand Down Expand Up @@ -2101,7 +2098,9 @@ impl Connection {
trace!("discarding {:?} keys", space_id);
if space_id == SpaceId::Initial {
// No longer needed
self.retry_token = Bytes::new();
if let ConnectionSide::Client { token, .. } = &mut self.side {
*token = Bytes::new();
}
}
let space = &mut self.spaces[space_id];
space.crypto = None;
Expand Down Expand Up @@ -2398,7 +2397,7 @@ impl Connection {

self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials
self.spaces[SpaceId::Initial] = PacketSpace {
crypto: Some(self.crypto.initial_keys(&rem_cid, self.side)),
crypto: Some(self.crypto.initial_keys(&rem_cid, self.side.side())),
next_packet_number: self.spaces[SpaceId::Initial].next_packet_number,
crypto_offset: client_hello.len() as u64,
..PacketSpace::new(now)
Expand All @@ -2420,7 +2419,10 @@ impl Connection {
self.streams.retransmit_all_for_0rtt();

let token_len = packet.payload.len() - 16;
self.retry_token = packet.payload.freeze().split_to(token_len);
let ConnectionSide::Client { ref mut token, .. } = self.side else {
unreachable!("we already short-circuited if we're server");
};
*token = packet.payload.freeze().split_to(token_len);
self.state = State::Handshake(state::Handshake {
expected_token: Bytes::new(),
rem_cid_set: false,
Expand Down Expand Up @@ -2745,7 +2747,7 @@ impl Connection {
debug!(offset, "peer claims to be blocked at connection level");
}
Frame::StreamDataBlocked { id, offset } => {
if id.initiator() == self.side && id.dir() == Dir::Uni {
if id.initiator() == self.side.side() && id.dir() == Dir::Uni {
debug!("got STREAM_DATA_BLOCKED on send-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
"STREAM_DATA_BLOCKED on send-only stream",
Expand All @@ -2768,7 +2770,7 @@ impl Connection {
);
}
Frame::StopSending(frame::StopSending { id, error_code }) => {
if id.initiator() != self.side {
if id.initiator() != self.side.side() {
if id.dir() == Dir::Uni {
debug!("got STOP_SENDING on recv-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
Expand Down Expand Up @@ -2938,11 +2940,11 @@ impl Connection {
&& !is_probing_packet
&& number == self.spaces[SpaceId::Data].rx_packet
{
let ConnectionSide::Server { ref server_config } = self.side else {
panic!("packets from unknown remote should be dropped by clients");
};
debug_assert!(
self.server_config
.as_ref()
.expect("packets from unknown remote should be dropped by clients")
.migration,
server_config.migration,
"migration-initiating packets should have been dropped immediately"
);
self.migrate(now, remote);
Expand Down Expand Up @@ -3618,6 +3620,41 @@ impl fmt::Debug for Connection {
}
}

/// Fields of `Connection` specific to it being client-side or server-side
enum ConnectionSide {
Client {
/// Sent in every outgoing Initial packet. Always empty after Initial keys are discarded
token: Bytes,
},
Server {
server_config: Arc<ServerConfig>,
},
}

impl ConnectionSide {
fn side(&self) -> Side {
match *self {
Self::Client { .. } => Side::Client,
Self::Server { .. } => Side::Server,
}
}

fn is_client(&self) -> bool {
self.side().is_client()
}

fn is_server(&self) -> bool {
self.side().is_server()
}

fn remote_may_migrate(&self) -> bool {
match self {
Self::Server { server_config } => server_config.migration,
Self::Client { .. } => false,
}
}
}

/// Reasons why a connection might be lost
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
Expand Down
6 changes: 5 additions & 1 deletion quinn-proto/src/connection/packet_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use tracing::{trace, trace_span};

use super::{spaces::SentPacket, Connection, SentFrames};
use crate::{
connection::ConnectionSide,
frame::{self, Close},
packet::{Header, InitialHeader, LongType, PacketNumber, PartialEncode, SpaceId, FIXED_BIT},
ConnectionId, Instant, TransportError, TransportErrorCode,
Expand Down Expand Up @@ -113,7 +114,10 @@ impl PacketBuilder {
SpaceId::Initial => Header::Initial(InitialHeader {
src_cid: conn.handshake_cid,
dst_cid,
token: conn.retry_token.clone(),
token: match &conn.side {
ConnectionSide::Client { token, .. } => token.clone(),
ConnectionSide::Server { .. } => Bytes::new(),
},
number,
version,
}),
Expand Down

0 comments on commit ed22361

Please sign in to comment.