Skip to content

Commit

Permalink
fix(s2n-quic-dc): wait to insert in peer map until handshake completes (
Browse files Browse the repository at this point in the history
  • Loading branch information
WesleyRosenblum authored Oct 30, 2024
1 parent 6c7057f commit 01cbb44
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 19 deletions.
33 changes: 23 additions & 10 deletions dc/s2n-quic-dc/src/path/secret/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,19 +488,21 @@ impl Map {
Some(state.clone())
}

pub(super) fn insert(&self, entry: Arc<Entry>) {
pub(super) fn on_new_path_secrets(&self, entry: Arc<Entry>) {
// On insert clear our interest in a handshake.
self.state.requested_handshakes.pin().remove(&entry.peer);
let id = *entry.secret.id();
let peer = entry.peer;
if self.state.ids.insert(id, entry.clone()).is_some() {
if self.state.ids.insert(*entry.secret.id(), entry).is_some() {
// FIXME: Make insertion fallible and fail handshakes instead?
panic!("inserting a path secret ID twice");
}
}

pub(super) fn on_handshake_complete(&self, entry: Arc<Entry>) {
let id = *entry.secret.id();

if let Some(prev) = self.state.peers.insert(peer, entry) {
// This shouldn't happen due to the panic above, but just in case something went wrong
// with the secret map we double check here.
if let Some(prev) = self.state.peers.insert(entry.peer, entry) {
// This shouldn't happen due to the panic in on_new_path_secrets, but just
// in case something went wrong with the secret map we double check here.
// FIXME: Make insertion fallible and fail handshakes instead?
assert_ne!(*prev.secret.id(), id, "duplicate path secret id");

Expand Down Expand Up @@ -546,7 +548,8 @@ impl Map {
dc::testing::TEST_REHANDSHAKE_PERIOD,
);
let entry = Arc::new(entry);
provider.insert(entry);
provider.on_new_path_secrets(entry.clone());
provider.on_handshake_complete(entry);
}

(provider, ids)
Expand All @@ -573,7 +576,9 @@ impl Map {
dc::testing::TEST_APPLICATION_PARAMS,
dc::testing::TEST_REHANDSHAKE_PERIOD,
);
self.insert(Arc::new(entry));
let entry = Arc::new(entry);
self.on_new_path_secrets(entry.clone());
self.on_handshake_complete(entry);
}

fn send_control(&self, entry: &Entry, credentials: &Credentials, error: receiver::Error) {
Expand Down Expand Up @@ -1057,7 +1062,15 @@ impl dc::Path for HandshakingPath {
);
let entry = Arc::new(entry);
self.entry = Some(entry.clone());
self.map.insert(entry);
self.map.on_new_path_secrets(entry);
}

fn on_dc_handshake_complete(&mut self) {
let entry = self.entry.clone().expect(
"the dc handshake cannot be complete without \
on_peer_stateless_reset_tokens creating a map entry",
);
self.map.on_handshake_complete(entry);
}

fn on_mtu_updated(&mut self, mtu: u16) {
Expand Down
35 changes: 26 additions & 9 deletions dc/s2n-quic-dc/src/path/secret/map/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,19 @@ fn cleans_after_delay() {
let first = fake_entry(1);
let second = fake_entry(1);
let third = fake_entry(1);
map.insert(first.clone());
map.insert(second.clone());
map.on_new_path_secrets(first.clone());
map.on_handshake_complete(first.clone());
map.on_new_path_secrets(second.clone());
map.on_handshake_complete(second.clone());

assert!(map.state.ids.contains_key(first.secret.id()));
assert!(map.state.ids.contains_key(second.secret.id()));

map.state.cleaner.clean(&map.state, 1);
map.state.cleaner.clean(&map.state, 1);

map.insert(third.clone());
map.on_new_path_secrets(third.clone());
map.on_handshake_complete(third.clone());

assert!(!map.state.ids.contains_key(first.secret.id()));
assert!(map.state.ids.contains_key(second.secret.id()));
Expand Down Expand Up @@ -86,9 +89,10 @@ struct Model {

#[derive(bolero::TypeGenerator, Debug, Copy, Clone)]
enum Operation {
Insert { ip: u8, path_secret_id: TestId },
NewPathSecret { ip: u8, path_secret_id: TestId },
AdvanceTime,
ReceiveUnknown { path_secret_id: TestId },
HandshakeComplete { path_secret_id: TestId },
}

#[derive(bolero::TypeGenerator, PartialEq, Eq, Hash, Copy, Clone)]
Expand Down Expand Up @@ -130,13 +134,13 @@ enum Invariant {
impl Model {
fn perform(&mut self, operation: Operation, state: &Map) {
match operation {
Operation::Insert { ip, path_secret_id } => {
Operation::NewPathSecret { ip, path_secret_id } => {
let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([0, 0, 0, ip]), 0));
let secret = path_secret_id.secret();
let id = *secret.id();

let stateless_reset = state.state.signer.sign(&id);
state.insert(Arc::new(Entry::new(
state.on_new_path_secrets(Arc::new(Entry::new(
ip,
secret,
sender::State::new(stateless_reset),
Expand All @@ -145,9 +149,16 @@ impl Model {
dc::testing::TEST_REHANDSHAKE_PERIOD,
)));

self.invariants.insert(Invariant::ContainsIp(ip));
self.invariants.insert(Invariant::ContainsId(id));
}
Operation::HandshakeComplete { path_secret_id } => {
if let Some(entry) = state.state.ids.get_by_key(&path_secret_id.id()) {
if !state.state.peers.contains_key(&entry.peer) {
state.on_handshake_complete(entry.clone());
}
self.invariants.insert(Invariant::ContainsIp(entry.peer));
}
}
Operation::AdvanceTime => {
let mut invalidated = Vec::new();
self.invariants.retain(|invariant| {
Expand Down Expand Up @@ -232,7 +243,7 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool {
let mut ids = HashSet::new();
for op in ops.iter() {
match op {
Operation::Insert {
Operation::NewPathSecret {
ip: _,
path_secret_id,
} => {
Expand All @@ -244,6 +255,10 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool {
Operation::ReceiveUnknown { path_secret_id: _ } => {
// no-op, we're fine receiving unknown pids.
}
Operation::HandshakeComplete { .. } => {
// no-op, a handshake complete for the same pid as a
// new path secret is expected
}
}
}

Expand Down Expand Up @@ -320,7 +335,9 @@ fn no_memory_growth() {
map.state.cleaner.stop();
for idx in 0..500_000 {
// FIXME: this ends up 2**16 peers in the `peers` map
map.insert(fake_entry(idx as u16));
let entry = fake_entry(idx as u16);
map.on_new_path_secrets(entry.clone());
map.on_handshake_complete(entry)
}
}

Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-core/src/dc/disabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl Path for () {
unimplemented!()
}

fn on_dc_handshake_complete(&mut self) {
unimplemented!()
}

fn on_mtu_updated(&mut self, _mtu: u16) {
unimplemented!()
}
Expand Down
8 changes: 8 additions & 0 deletions quic/s2n-quic-core/src/dc/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl MockDcEndpoint {
pub struct MockDcPath {
pub on_path_secrets_ready_count: u8,
pub on_peer_stateless_reset_tokens_count: u8,
pub on_dc_handshake_complete: u8,
pub stateless_reset_tokens: Vec<stateless_reset::Token>,
pub peer_stateless_reset_tokens: Vec<stateless_reset::Token>,
pub mtu: u16,
Expand Down Expand Up @@ -69,6 +70,7 @@ impl dc::Path for MockDcPath {
&mut self,
_session: &impl TlsSession,
) -> Result<Vec<stateless_reset::Token>, transport::Error> {
debug_assert_eq!(0, self.on_path_secrets_ready_count);
self.on_path_secrets_ready_count += 1;
Ok(self.stateless_reset_tokens.clone())
}
Expand All @@ -77,11 +79,17 @@ impl dc::Path for MockDcPath {
&mut self,
stateless_reset_tokens: impl Iterator<Item = &'a stateless_reset::Token>,
) {
debug_assert_eq!(0, self.on_peer_stateless_reset_tokens_count);
self.on_peer_stateless_reset_tokens_count += 1;
self.peer_stateless_reset_tokens
.extend(stateless_reset_tokens);
}

fn on_dc_handshake_complete(&mut self) {
debug_assert_eq!(0, self.on_dc_handshake_complete);
self.on_dc_handshake_complete += 1;
}

fn on_mtu_updated(&mut self, mtu: u16) {
self.mtu = mtu
}
Expand Down
12 changes: 12 additions & 0 deletions quic/s2n-quic-core/src/dc/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub trait Path: 'static + Send {
stateless_reset_tokens: impl Iterator<Item = &'a stateless_reset::Token>,
);

/// Called when the peer has confirmed receipt of `DC_STATELESS_RESET_TOKENS`, either
/// by the server sending back its own `DC_STATELESS_RESET_TOKENS` or by the client
/// acknowledging the `DC_STATELESS_RESET_TOKENS` frame was received.
fn on_dc_handshake_complete(&mut self);

/// Called when the MTU has been updated for the path
fn on_mtu_updated(&mut self, mtu: u16);
}
Expand Down Expand Up @@ -73,6 +78,13 @@ impl<P: Path> Path for Option<P> {
}
}

#[inline]
fn on_dc_handshake_complete(&mut self) {
if let Some(path) = self {
path.on_dc_handshake_complete()
}
}

#[inline]
fn on_mtu_updated(&mut self, max_datagram_size: u16) {
if let Some(path) = self {
Expand Down
2 changes: 2 additions & 0 deletions quic/s2n-quic-transport/src/dc/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ impl<Config: endpoint::Config> Manager<Config> {
if Config::ENDPOINT_TYPE.is_server() {
self.stateless_reset_token_sync.send();
} else {
self.path.on_dc_handshake_complete();
publisher.on_dc_state_changed(DcStateChanged {
state: DcState::Complete,
});
Expand All @@ -176,6 +177,7 @@ impl<Config: endpoint::Config> Manager<Config> {
ensure!(self.state.on_stateless_reset_tokens_acked().is_ok());

debug_assert!(Config::ENDPOINT_TYPE.is_server());
self.path.on_dc_handshake_complete();
publisher.on_dc_state_changed(DcStateChanged {
state: DcState::Complete,
});
Expand Down
4 changes: 4 additions & 0 deletions quic/s2n-quic-transport/src/dc/manager/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ fn on_peer_dc_stateless_reset_tokens<Config, Endpoint>(

if Config::ENDPOINT_TYPE.is_server() {
assert!(manager.state.is_server_tokens_sent());
assert_eq!(0, manager.path().on_dc_handshake_complete);
} else {
assert_eq!(1, manager.path().on_dc_handshake_complete);
assert!(manager.state.is_complete());
}

Expand All @@ -169,6 +171,7 @@ fn on_packet_ack_client() {

// Client completes when it has received stateless reset tokens from the peer
assert!(!manager.state.is_complete());
assert_eq!(0, manager.path().on_dc_handshake_complete);
}

#[test]
Expand All @@ -182,6 +185,7 @@ fn on_packet_ack_server() {

// Server completes when its stateless reset tokens are acked
assert!(manager.state.is_complete());
assert_eq!(1, manager.path().on_dc_handshake_complete);
}

fn on_packet_ack<Config, Endpoint>(
Expand Down

0 comments on commit 01cbb44

Please sign in to comment.