diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index faf4ea3d1..b25a8e4b1 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -488,19 +488,21 @@ impl Map { Some(state.clone()) } - pub(super) fn insert(&self, entry: Arc) { + pub(super) fn on_new_path_secrets(&self, entry: Arc) { // On insert clear our interest in a handshake. self.state.requested_handshakes.pin().remove(&entry.peer); - let id = *entry.secret.id(); - let peer = entry.peer; - if self.state.ids.insert(id, entry.clone()).is_some() { + if self.state.ids.insert(*entry.secret.id(), entry).is_some() { // FIXME: Make insertion fallible and fail handshakes instead? panic!("inserting a path secret ID twice"); } + } + + pub(super) fn on_handshake_complete(&self, entry: Arc) { + let id = *entry.secret.id(); - if let Some(prev) = self.state.peers.insert(peer, entry) { - // This shouldn't happen due to the panic above, but just in case something went wrong - // with the secret map we double check here. + if let Some(prev) = self.state.peers.insert(entry.peer, entry) { + // This shouldn't happen due to the panic in on_new_path_secrets, but just + // in case something went wrong with the secret map we double check here. // FIXME: Make insertion fallible and fail handshakes instead? assert_ne!(*prev.secret.id(), id, "duplicate path secret id"); @@ -546,7 +548,8 @@ impl Map { dc::testing::TEST_REHANDSHAKE_PERIOD, ); let entry = Arc::new(entry); - provider.insert(entry); + provider.on_new_path_secrets(entry.clone()); + provider.on_handshake_complete(entry); } (provider, ids) @@ -573,7 +576,9 @@ impl Map { dc::testing::TEST_APPLICATION_PARAMS, dc::testing::TEST_REHANDSHAKE_PERIOD, ); - self.insert(Arc::new(entry)); + let entry = Arc::new(entry); + self.on_new_path_secrets(entry.clone()); + self.on_handshake_complete(entry); } fn send_control(&self, entry: &Entry, credentials: &Credentials, error: receiver::Error) { @@ -1057,7 +1062,15 @@ impl dc::Path for HandshakingPath { ); let entry = Arc::new(entry); self.entry = Some(entry.clone()); - self.map.insert(entry); + self.map.on_new_path_secrets(entry); + } + + fn on_dc_handshake_complete(&mut self) { + let entry = self.entry.clone().expect( + "the dc handshake cannot be complete without \ + on_peer_stateless_reset_tokens creating a map entry", + ); + self.map.on_handshake_complete(entry); } fn on_mtu_updated(&mut self, mtu: u16) { diff --git a/dc/s2n-quic-dc/src/path/secret/map/test.rs b/dc/s2n-quic-dc/src/path/secret/map/test.rs index b0a8fa41b..db9fac18e 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/test.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/test.rs @@ -41,8 +41,10 @@ fn cleans_after_delay() { let first = fake_entry(1); let second = fake_entry(1); let third = fake_entry(1); - map.insert(first.clone()); - map.insert(second.clone()); + map.on_new_path_secrets(first.clone()); + map.on_handshake_complete(first.clone()); + map.on_new_path_secrets(second.clone()); + map.on_handshake_complete(second.clone()); assert!(map.state.ids.contains_key(first.secret.id())); assert!(map.state.ids.contains_key(second.secret.id())); @@ -50,7 +52,8 @@ fn cleans_after_delay() { map.state.cleaner.clean(&map.state, 1); map.state.cleaner.clean(&map.state, 1); - map.insert(third.clone()); + map.on_new_path_secrets(third.clone()); + map.on_handshake_complete(third.clone()); assert!(!map.state.ids.contains_key(first.secret.id())); assert!(map.state.ids.contains_key(second.secret.id())); @@ -86,9 +89,10 @@ struct Model { #[derive(bolero::TypeGenerator, Debug, Copy, Clone)] enum Operation { - Insert { ip: u8, path_secret_id: TestId }, + NewPathSecret { ip: u8, path_secret_id: TestId }, AdvanceTime, ReceiveUnknown { path_secret_id: TestId }, + HandshakeComplete { path_secret_id: TestId }, } #[derive(bolero::TypeGenerator, PartialEq, Eq, Hash, Copy, Clone)] @@ -130,13 +134,13 @@ enum Invariant { impl Model { fn perform(&mut self, operation: Operation, state: &Map) { match operation { - Operation::Insert { ip, path_secret_id } => { + Operation::NewPathSecret { ip, path_secret_id } => { let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([0, 0, 0, ip]), 0)); let secret = path_secret_id.secret(); let id = *secret.id(); let stateless_reset = state.state.signer.sign(&id); - state.insert(Arc::new(Entry::new( + state.on_new_path_secrets(Arc::new(Entry::new( ip, secret, sender::State::new(stateless_reset), @@ -145,9 +149,16 @@ impl Model { dc::testing::TEST_REHANDSHAKE_PERIOD, ))); - self.invariants.insert(Invariant::ContainsIp(ip)); self.invariants.insert(Invariant::ContainsId(id)); } + Operation::HandshakeComplete { path_secret_id } => { + if let Some(entry) = state.state.ids.get_by_key(&path_secret_id.id()) { + if !state.state.peers.contains_key(&entry.peer) { + state.on_handshake_complete(entry.clone()); + } + self.invariants.insert(Invariant::ContainsIp(entry.peer)); + } + } Operation::AdvanceTime => { let mut invalidated = Vec::new(); self.invariants.retain(|invariant| { @@ -232,7 +243,7 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool { let mut ids = HashSet::new(); for op in ops.iter() { match op { - Operation::Insert { + Operation::NewPathSecret { ip: _, path_secret_id, } => { @@ -244,6 +255,10 @@ fn has_duplicate_pids(ops: &[Operation]) -> bool { Operation::ReceiveUnknown { path_secret_id: _ } => { // no-op, we're fine receiving unknown pids. } + Operation::HandshakeComplete { .. } => { + // no-op, a handshake complete for the same pid as a + // new path secret is expected + } } } @@ -320,7 +335,9 @@ fn no_memory_growth() { map.state.cleaner.stop(); for idx in 0..500_000 { // FIXME: this ends up 2**16 peers in the `peers` map - map.insert(fake_entry(idx as u16)); + let entry = fake_entry(idx as u16); + map.on_new_path_secrets(entry.clone()); + map.on_handshake_complete(entry) } } diff --git a/quic/s2n-quic-core/src/dc/disabled.rs b/quic/s2n-quic-core/src/dc/disabled.rs index e3650b128..6efe7278c 100644 --- a/quic/s2n-quic-core/src/dc/disabled.rs +++ b/quic/s2n-quic-core/src/dc/disabled.rs @@ -45,6 +45,10 @@ impl Path for () { unimplemented!() } + fn on_dc_handshake_complete(&mut self) { + unimplemented!() + } + fn on_mtu_updated(&mut self, _mtu: u16) { unimplemented!() } diff --git a/quic/s2n-quic-core/src/dc/testing.rs b/quic/s2n-quic-core/src/dc/testing.rs index 6fd588062..3eb4883d6 100644 --- a/quic/s2n-quic-core/src/dc/testing.rs +++ b/quic/s2n-quic-core/src/dc/testing.rs @@ -34,6 +34,7 @@ impl MockDcEndpoint { pub struct MockDcPath { pub on_path_secrets_ready_count: u8, pub on_peer_stateless_reset_tokens_count: u8, + pub on_dc_handshake_complete: u8, pub stateless_reset_tokens: Vec, pub peer_stateless_reset_tokens: Vec, pub mtu: u16, @@ -69,6 +70,7 @@ impl dc::Path for MockDcPath { &mut self, _session: &impl TlsSession, ) -> Result, transport::Error> { + debug_assert_eq!(0, self.on_path_secrets_ready_count); self.on_path_secrets_ready_count += 1; Ok(self.stateless_reset_tokens.clone()) } @@ -77,11 +79,17 @@ impl dc::Path for MockDcPath { &mut self, stateless_reset_tokens: impl Iterator, ) { + debug_assert_eq!(0, self.on_peer_stateless_reset_tokens_count); self.on_peer_stateless_reset_tokens_count += 1; self.peer_stateless_reset_tokens .extend(stateless_reset_tokens); } + fn on_dc_handshake_complete(&mut self) { + debug_assert_eq!(0, self.on_dc_handshake_complete); + self.on_dc_handshake_complete += 1; + } + fn on_mtu_updated(&mut self, mtu: u16) { self.mtu = mtu } diff --git a/quic/s2n-quic-core/src/dc/traits.rs b/quic/s2n-quic-core/src/dc/traits.rs index 198c72301..012e3752f 100644 --- a/quic/s2n-quic-core/src/dc/traits.rs +++ b/quic/s2n-quic-core/src/dc/traits.rs @@ -46,6 +46,11 @@ pub trait Path: 'static + Send { stateless_reset_tokens: impl Iterator, ); + /// Called when the peer has confirmed receipt of `DC_STATELESS_RESET_TOKENS`, either + /// by the server sending back its own `DC_STATELESS_RESET_TOKENS` or by the client + /// acknowledging the `DC_STATELESS_RESET_TOKENS` frame was received. + fn on_dc_handshake_complete(&mut self); + /// Called when the MTU has been updated for the path fn on_mtu_updated(&mut self, mtu: u16); } @@ -73,6 +78,13 @@ impl Path for Option

{ } } + #[inline] + fn on_dc_handshake_complete(&mut self) { + if let Some(path) = self { + path.on_dc_handshake_complete() + } + } + #[inline] fn on_mtu_updated(&mut self, max_datagram_size: u16) { if let Some(path) = self { diff --git a/quic/s2n-quic-transport/src/dc/manager.rs b/quic/s2n-quic-transport/src/dc/manager.rs index 2384cc879..bae270e0d 100644 --- a/quic/s2n-quic-transport/src/dc/manager.rs +++ b/quic/s2n-quic-transport/src/dc/manager.rs @@ -160,6 +160,7 @@ impl Manager { if Config::ENDPOINT_TYPE.is_server() { self.stateless_reset_token_sync.send(); } else { + self.path.on_dc_handshake_complete(); publisher.on_dc_state_changed(DcStateChanged { state: DcState::Complete, }); @@ -176,6 +177,7 @@ impl Manager { ensure!(self.state.on_stateless_reset_tokens_acked().is_ok()); debug_assert!(Config::ENDPOINT_TYPE.is_server()); + self.path.on_dc_handshake_complete(); publisher.on_dc_state_changed(DcStateChanged { state: DcState::Complete, }); diff --git a/quic/s2n-quic-transport/src/dc/manager/tests.rs b/quic/s2n-quic-transport/src/dc/manager/tests.rs index 90a6544f4..b5841852c 100644 --- a/quic/s2n-quic-transport/src/dc/manager/tests.rs +++ b/quic/s2n-quic-transport/src/dc/manager/tests.rs @@ -149,7 +149,9 @@ fn on_peer_dc_stateless_reset_tokens( if Config::ENDPOINT_TYPE.is_server() { assert!(manager.state.is_server_tokens_sent()); + assert_eq!(0, manager.path().on_dc_handshake_complete); } else { + assert_eq!(1, manager.path().on_dc_handshake_complete); assert!(manager.state.is_complete()); } @@ -169,6 +171,7 @@ fn on_packet_ack_client() { // Client completes when it has received stateless reset tokens from the peer assert!(!manager.state.is_complete()); + assert_eq!(0, manager.path().on_dc_handshake_complete); } #[test] @@ -182,6 +185,7 @@ fn on_packet_ack_server() { // Server completes when its stateless reset tokens are acked assert!(manager.state.is_complete()); + assert_eq!(1, manager.path().on_dc_handshake_complete); } fn on_packet_ack(